1use std::any::Any;
7use std::fmt;
8use std::sync::Arc;
9
10use arrow_array::{Array, Float32Array, RecordBatch, StringArray, UInt32Array};
11use arrow_schema::{DataType, Field, Schema, SchemaRef};
12use datafusion_common::Result;
13use datafusion_execution::{SendableRecordBatchStream, TaskContext};
14use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
15use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
16use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
17
/// Tuning parameters for the `TopicLoomExec` clustering operator.
#[derive(Debug, Clone, PartialEq)]
pub struct TopicLoomConfig {
    /// Jaccard word-overlap similarity that a memory must strictly exceed to join
    /// an existing cluster. Used only when no explicit `topic` values are present.
    pub similarity_threshold: f32,
    /// Maximum number of clusters created by word-overlap clustering; once the
    /// limit is reached, further memories join their nearest existing cluster.
    pub max_clusters: usize,
}
26
27impl Default for TopicLoomConfig {
28 fn default() -> Self {
29 Self {
30 similarity_threshold: 0.6,
31 max_clusters: 20,
32 }
33 }
34}
35
/// Physical operator that groups input rows ("memories") into topic clusters.
///
/// It reads Utf8 `id` and `content` columns, plus an optional Utf8 `topic`
/// column, from `input`, and emits rows shaped by `TopicLoomExec::output_schema()`.
#[derive(Debug)]
pub struct TopicLoomExec {
    /// Child plan supplying the memory rows to cluster.
    input: Arc<dyn ExecutionPlan>,
    /// Output schema, built once by `output_schema()` in `new`.
    schema: SchemaRef,
    /// Plan properties (partitioning, emission type, boundedness) built in `new`.
    properties: PlanProperties,
    /// Clustering parameters; also rendered by the `DisplayAs` impl.
    config: TopicLoomConfig,
}
50
51impl TopicLoomExec {
52 pub fn new(input: Arc<dyn ExecutionPlan>, config: TopicLoomConfig) -> Self {
53 let schema = Self::output_schema();
54 let properties = PlanProperties::new(
55 datafusion_physical_expr::EquivalenceProperties::new(schema.clone()),
56 datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
57 EmissionType::Final,
58 Boundedness::Bounded,
59 );
60 Self {
61 input,
62 schema,
63 properties,
64 config,
65 }
66 }
67
68 pub fn output_schema() -> SchemaRef {
69 Arc::new(Schema::new(vec![
70 Field::new("topic_id", DataType::UInt32, false),
71 Field::new("topic_label", DataType::Utf8, false),
72 Field::new("memory_id", DataType::Utf8, false),
73 Field::new("relevance_score", DataType::Float32, false),
74 ]))
75 }
76}
77
78impl DisplayAs for TopicLoomExec {
79 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80 write!(
81 f,
82 "TopicLoomExec: threshold={}, max_clusters={}",
83 self.config.similarity_threshold, self.config.max_clusters
84 )
85 }
86}
87
88impl ExecutionPlan for TopicLoomExec {
89 fn name(&self) -> &str {
90 "TopicLoomExec"
91 }
92
93 fn as_any(&self) -> &dyn Any {
94 self
95 }
96
97 fn schema(&self) -> SchemaRef {
98 self.schema.clone()
99 }
100
101 fn properties(&self) -> &PlanProperties {
102 &self.properties
103 }
104
105 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
106 vec![&self.input]
107 }
108
109 fn with_new_children(
110 self: Arc<Self>,
111 children: Vec<Arc<dyn ExecutionPlan>>,
112 ) -> Result<Arc<dyn ExecutionPlan>> {
113 Ok(Arc::new(Self::new(
114 children[0].clone(),
115 self.config.clone(),
116 )))
117 }
118
119 fn execute(
120 &self,
121 partition: usize,
122 context: Arc<TaskContext>,
123 ) -> Result<SendableRecordBatchStream> {
124 let input = self.input.execute(partition, context)?;
125 let schema = self.schema.clone();
126 let stream_schema = schema.clone();
127 let config = self.config.clone();
128
129 let fut = async move {
130 use futures::StreamExt;
131
132 let mut records: Vec<(String, String, Option<String>)> = Vec::new(); let mut stream = input;
136 while let Some(batch) = stream.next().await {
137 let batch = batch?;
138
139 let id_col = batch.column_by_name("id");
140 let content_col = batch.column_by_name("content");
141 let topic_col = batch.column_by_name("topic");
142
143 if let (Some(ids), Some(contents)) = (id_col, content_col) {
144 if let (Some(id_arr), Some(content_arr)) = (
145 ids.as_any().downcast_ref::<StringArray>(),
146 contents.as_any().downcast_ref::<StringArray>(),
147 ) {
148 let topics = topic_col
149 .and_then(|c| c.as_any().downcast_ref::<StringArray>().cloned());
150
151 for i in 0..id_arr.len() {
152 if id_arr.is_null(i) || content_arr.is_null(i) {
153 continue;
154 }
155 let topic = topics.as_ref().and_then(|t| {
156 if t.is_null(i) {
157 None
158 } else {
159 Some(t.value(i).to_string())
160 }
161 });
162 records.push((
163 id_arr.value(i).to_string(),
164 content_arr.value(i).to_string(),
165 topic,
166 ));
167 }
168 }
169 }
170 }
171
172 let clusters = if records.iter().any(|(_, _, t)| t.is_some()) {
174 cluster_by_explicit_topic(&records)
176 } else {
177 cluster_by_word_overlap(&records, config.similarity_threshold, config.max_clusters)
179 };
180
181 let mut topic_ids = Vec::new();
183 let mut topic_labels = Vec::new();
184 let mut memory_ids = Vec::new();
185 let mut relevance_scores = Vec::new();
186
187 for (cluster_id, label, members) in &clusters {
188 for (mem_id, score) in members {
189 topic_ids.push(*cluster_id);
190 topic_labels.push(label.clone());
191 memory_ids.push(mem_id.clone());
192 relevance_scores.push(*score);
193 }
194 }
195
196 let batch = RecordBatch::try_new(
197 schema,
198 vec![
199 Arc::new(UInt32Array::from(topic_ids)),
200 Arc::new(StringArray::from(topic_labels)),
201 Arc::new(StringArray::from(memory_ids)),
202 Arc::new(Float32Array::from(relevance_scores)),
203 ],
204 )?;
205
206 Ok(batch)
207 };
208
209 let stream = futures::stream::once(fut);
210 Ok(Box::pin(RecordBatchStreamAdapter::new(
211 stream_schema,
212 stream,
213 )))
214 }
215}
216
/// Groups records by their explicit `topic` value.
///
/// Records without a topic are grouped under the label `"unknown"`. Topics are
/// numbered in order of first appearance. Iterating a `HashMap` directly would
/// yield an arbitrary order, so the emitted `topic_id`s and row order would vary
/// between runs. Every member gets a relevance score of `1.0`. No cluster limit
/// is applied: each distinct topic becomes its own cluster.
///
/// Returns `(topic_id, label, members)` tuples, where `members` pairs each memory
/// id with its score.
fn cluster_by_explicit_topic(
    records: &[(String, String, Option<String>)],
) -> Vec<(u32, String, Vec<(String, f32)>)> {
    use std::collections::HashMap;

    // Label -> position in `groups`; `groups` keeps first-seen order.
    let mut index_of: HashMap<&str, usize> = HashMap::new();
    let mut groups: Vec<(String, Vec<(String, f32)>)> = Vec::new();

    for (id, _content, topic) in records {
        let label = topic.as_deref().unwrap_or("unknown");
        let slot = *index_of.entry(label).or_insert_with(|| {
            groups.push((label.to_string(), Vec::new()));
            groups.len() - 1
        });
        groups[slot].1.push((id.clone(), 1.0));
    }

    groups
        .into_iter()
        .enumerate()
        .map(|(idx, (label, members))| (idx as u32, label, members))
        .collect()
}
240
241fn cluster_by_word_overlap(
245 records: &[(String, String, Option<String>)],
246 threshold: f32,
247 max_clusters: usize,
248) -> Vec<(u32, String, Vec<(String, f32)>)> {
249 if records.is_empty() {
250 return Vec::new();
251 }
252
253 let word_sets: Vec<std::collections::HashSet<&str>> = records
255 .iter()
256 .map(|(_, content, _)| {
257 content
258 .split_whitespace()
259 .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
260 .filter(|w| w.len() > 2)
261 .collect()
262 })
263 .collect();
264
265 let mut clusters: Vec<Vec<usize>> = Vec::new();
267
268 for i in 0..records.len() {
269 let mut best_cluster = None;
270 let mut best_sim = 0.0_f32;
271
272 for (c_idx, cluster) in clusters.iter().enumerate() {
273 let centroid = cluster[0];
275 let sim = jaccard_similarity(&word_sets[i], &word_sets[centroid]);
276 if sim > threshold && sim > best_sim {
277 best_sim = sim;
278 best_cluster = Some(c_idx);
279 }
280 }
281
282 if let Some(c_idx) = best_cluster {
283 clusters[c_idx].push(i);
284 } else if clusters.len() < max_clusters {
285 clusters.push(vec![i]);
286 } else {
287 let closest = clusters
289 .iter()
290 .enumerate()
291 .map(|(idx, c)| {
292 let sim = jaccard_similarity(&word_sets[i], &word_sets[c[0]]);
293 (idx, sim)
294 })
295 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
296 .map(|(idx, _)| idx)
297 .unwrap_or(0);
298 clusters[closest].push(i);
299 }
300 }
301
302 clusters
304 .into_iter()
305 .enumerate()
306 .map(|(idx, member_indices)| {
307 let label = records
308 .get(member_indices[0])
309 .map(|(_, content, _)| content.chars().take(40).collect::<String>())
310 .unwrap_or_else(|| format!("cluster_{idx}"));
311
312 let members: Vec<(String, f32)> = member_indices
313 .iter()
314 .map(|&mi| {
315 let sim = if mi == member_indices[0] {
316 1.0
317 } else {
318 jaccard_similarity(&word_sets[member_indices[0]], &word_sets[mi])
319 };
320 (records[mi].0.clone(), sim)
321 })
322 .collect();
323
324 (idx as u32, label, members)
325 })
326 .collect()
327}
328
/// Jaccard similarity |A ∩ B| / |A ∪ B| of two word sets, in `[0.0, 1.0]`.
///
/// Two empty sets are defined to have similarity `0.0` rather than dividing by zero.
fn jaccard_similarity(
    a: &std::collections::HashSet<&str>,
    b: &std::collections::HashSet<&str>,
) -> f32 {
    let intersection = a.intersection(b).count();
    // Inclusion-exclusion gives the union size without walking a union view.
    let union = a.len() + b.len() - intersection;
    if union == 0 {
        0.0
    } else {
        intersection as f32 / union as f32
    }
}