// hirn_exec/operators/topic_loom.rs

//! `TopicLoomExec` — topic-scoped grouping of memory records.
//!
//! Groups memories by their explicit `topic` value when one is present, and
//! otherwise by word-overlap (Jaccard) similarity of their content, emitting one
//! `(topic_id, topic_label, memory_id, relevance_score)` row per memory.
//! Clustering is a simple greedy single pass (no external dep).
5
6use std::any::Any;
7use std::fmt;
8use std::sync::Arc;
9
10use arrow_array::{Array, Float32Array, RecordBatch, StringArray, UInt32Array};
11use arrow_schema::{DataType, Field, Schema, SchemaRef};
12use datafusion_common::Result;
13use datafusion_execution::{SendableRecordBatchStream, TaskContext};
14use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
15use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
16use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
17
/// Tuning knobs for [`TopicLoomExec`].
///
/// Both settings are only consulted by the word-overlap clustering path, i.e.
/// when none of the input records carries an explicit topic.
#[derive(Clone, Debug)]
pub struct TopicLoomConfig {
    /// Similarity a record must exceed (strictly) against a cluster's first
    /// member to join that cluster.
    pub similarity_threshold: f32,
    /// Upper bound on the number of clusters that may be opened.
    pub max_clusters: usize,
}

impl Default for TopicLoomConfig {
    /// Returns the stock configuration: threshold `0.6`, at most `20` clusters.
    fn default() -> Self {
        let similarity_threshold = 0.6;
        let max_clusters = 20;
        TopicLoomConfig {
            similarity_threshold,
            max_clusters,
        }
    }
}
35
36/// Topic loom operator for DataFusion.
37///
38/// Input: memory records with content and optional topic columns.
39/// Output: topic_id, topic_label, memory_id, relevance_score per record.
40///
41/// When explicit `topic` column exists, groups by that.
42/// Otherwise, uses greedy modularity on content similarity.
43#[derive(Debug)]
44pub struct TopicLoomExec {
45    input: Arc<dyn ExecutionPlan>,
46    schema: SchemaRef,
47    properties: PlanProperties,
48    config: TopicLoomConfig,
49}
50
51impl TopicLoomExec {
52    pub fn new(input: Arc<dyn ExecutionPlan>, config: TopicLoomConfig) -> Self {
53        let schema = Self::output_schema();
54        let properties = PlanProperties::new(
55            datafusion_physical_expr::EquivalenceProperties::new(schema.clone()),
56            datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
57            EmissionType::Final,
58            Boundedness::Bounded,
59        );
60        Self {
61            input,
62            schema,
63            properties,
64            config,
65        }
66    }
67
68    pub fn output_schema() -> SchemaRef {
69        Arc::new(Schema::new(vec![
70            Field::new("topic_id", DataType::UInt32, false),
71            Field::new("topic_label", DataType::Utf8, false),
72            Field::new("memory_id", DataType::Utf8, false),
73            Field::new("relevance_score", DataType::Float32, false),
74        ]))
75    }
76}
77
78impl DisplayAs for TopicLoomExec {
79    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
80        write!(
81            f,
82            "TopicLoomExec: threshold={}, max_clusters={}",
83            self.config.similarity_threshold, self.config.max_clusters
84        )
85    }
86}
87
88impl ExecutionPlan for TopicLoomExec {
89    fn name(&self) -> &str {
90        "TopicLoomExec"
91    }
92
93    fn as_any(&self) -> &dyn Any {
94        self
95    }
96
97    fn schema(&self) -> SchemaRef {
98        self.schema.clone()
99    }
100
101    fn properties(&self) -> &PlanProperties {
102        &self.properties
103    }
104
105    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
106        vec![&self.input]
107    }
108
109    fn with_new_children(
110        self: Arc<Self>,
111        children: Vec<Arc<dyn ExecutionPlan>>,
112    ) -> Result<Arc<dyn ExecutionPlan>> {
113        Ok(Arc::new(Self::new(
114            children[0].clone(),
115            self.config.clone(),
116        )))
117    }
118
119    fn execute(
120        &self,
121        partition: usize,
122        context: Arc<TaskContext>,
123    ) -> Result<SendableRecordBatchStream> {
124        let input = self.input.execute(partition, context)?;
125        let schema = self.schema.clone();
126        let stream_schema = schema.clone();
127        let config = self.config.clone();
128
129        let fut = async move {
130            use futures::StreamExt;
131
132            // Collect all records first.
133            let mut records: Vec<(String, String, Option<String>)> = Vec::new(); // (id, content, topic)
134
135            let mut stream = input;
136            while let Some(batch) = stream.next().await {
137                let batch = batch?;
138
139                let id_col = batch.column_by_name("id");
140                let content_col = batch.column_by_name("content");
141                let topic_col = batch.column_by_name("topic");
142
143                if let (Some(ids), Some(contents)) = (id_col, content_col) {
144                    if let (Some(id_arr), Some(content_arr)) = (
145                        ids.as_any().downcast_ref::<StringArray>(),
146                        contents.as_any().downcast_ref::<StringArray>(),
147                    ) {
148                        let topics = topic_col
149                            .and_then(|c| c.as_any().downcast_ref::<StringArray>().cloned());
150
151                        for i in 0..id_arr.len() {
152                            if id_arr.is_null(i) || content_arr.is_null(i) {
153                                continue;
154                            }
155                            let topic = topics.as_ref().and_then(|t| {
156                                if t.is_null(i) {
157                                    None
158                                } else {
159                                    Some(t.value(i).to_string())
160                                }
161                            });
162                            records.push((
163                                id_arr.value(i).to_string(),
164                                content_arr.value(i).to_string(),
165                                topic,
166                            ));
167                        }
168                    }
169                }
170            }
171
172            // Cluster records by topic.
173            let clusters = if records.iter().any(|(_, _, t)| t.is_some()) {
174                // Use explicit topics.
175                cluster_by_explicit_topic(&records)
176            } else {
177                // Greedy modularity clustering by word overlap.
178                cluster_by_word_overlap(&records, config.similarity_threshold, config.max_clusters)
179            };
180
181            // Build output.
182            let mut topic_ids = Vec::new();
183            let mut topic_labels = Vec::new();
184            let mut memory_ids = Vec::new();
185            let mut relevance_scores = Vec::new();
186
187            for (cluster_id, label, members) in &clusters {
188                for (mem_id, score) in members {
189                    topic_ids.push(*cluster_id);
190                    topic_labels.push(label.clone());
191                    memory_ids.push(mem_id.clone());
192                    relevance_scores.push(*score);
193                }
194            }
195
196            let batch = RecordBatch::try_new(
197                schema,
198                vec![
199                    Arc::new(UInt32Array::from(topic_ids)),
200                    Arc::new(StringArray::from(topic_labels)),
201                    Arc::new(StringArray::from(memory_ids)),
202                    Arc::new(Float32Array::from(relevance_scores)),
203                ],
204            )?;
205
206            Ok(batch)
207        };
208
209        let stream = futures::stream::once(fut);
210        Ok(Box::pin(RecordBatchStreamAdapter::new(
211            stream_schema,
212            stream,
213        )))
214    }
215}
216
217// ── Clustering helpers ─────────────────────────────────────────────────
218
/// Groups records by the value of their explicit topic.
///
/// Records without a topic are collected under the label `"unknown"`. Every
/// member receives a relevance score of `1.0`.
///
/// Topics are numbered in the order they are first encountered, and members
/// keep their input order, so the result is deterministic for a given input
/// (numbering from `HashMap` iteration order would change between runs).
///
/// Returns one `(topic_id, topic_label, members)` tuple per distinct topic,
/// where `members` holds `(memory_id, relevance_score)` pairs.
fn cluster_by_explicit_topic(
    records: &[(String, String, Option<String>)],
) -> Vec<(u32, String, Vec<(String, f32)>)> {
    use std::collections::HashMap;

    // Maps a topic label to its position in `clusters`.
    let mut slot_by_label: HashMap<&str, usize> = HashMap::new();
    let mut clusters: Vec<(u32, String, Vec<(String, f32)>)> = Vec::new();

    for (id, _content, topic) in records {
        let label = topic.as_deref().unwrap_or("unknown");
        let slot = *slot_by_label.entry(label).or_insert_with(|| {
            let next = clusters.len();
            // Topic counts are far below u32::MAX for the batch sizes handled here.
            clusters.push((next as u32, label.to_string(), Vec::new()));
            next
        });
        clusters[slot].2.push((id.clone(), 1.0));
    }

    clusters
}
240
241/// Greedy modularity clustering by word overlap similarity.
242///
243/// Simple O(n²) approach suitable for consolidation batch sizes (<1000 records).
244fn cluster_by_word_overlap(
245    records: &[(String, String, Option<String>)],
246    threshold: f32,
247    max_clusters: usize,
248) -> Vec<(u32, String, Vec<(String, f32)>)> {
249    if records.is_empty() {
250        return Vec::new();
251    }
252
253    // Tokenize content into word sets.
254    let word_sets: Vec<std::collections::HashSet<&str>> = records
255        .iter()
256        .map(|(_, content, _)| {
257            content
258                .split_whitespace()
259                .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()))
260                .filter(|w| w.len() > 2)
261                .collect()
262        })
263        .collect();
264
265    // Greedy clustering: assign each record to the first cluster above threshold.
266    let mut clusters: Vec<Vec<usize>> = Vec::new();
267
268    for i in 0..records.len() {
269        let mut best_cluster = None;
270        let mut best_sim = 0.0_f32;
271
272        for (c_idx, cluster) in clusters.iter().enumerate() {
273            // Average similarity to cluster centroid (first member).
274            let centroid = cluster[0];
275            let sim = jaccard_similarity(&word_sets[i], &word_sets[centroid]);
276            if sim > threshold && sim > best_sim {
277                best_sim = sim;
278                best_cluster = Some(c_idx);
279            }
280        }
281
282        if let Some(c_idx) = best_cluster {
283            clusters[c_idx].push(i);
284        } else if clusters.len() < max_clusters {
285            clusters.push(vec![i]);
286        } else {
287            // Assign to most similar cluster even if below threshold.
288            let closest = clusters
289                .iter()
290                .enumerate()
291                .map(|(idx, c)| {
292                    let sim = jaccard_similarity(&word_sets[i], &word_sets[c[0]]);
293                    (idx, sim)
294                })
295                .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
296                .map(|(idx, _)| idx)
297                .unwrap_or(0);
298            clusters[closest].push(i);
299        }
300    }
301
302    // Build output with labels derived from first record's content prefix.
303    clusters
304        .into_iter()
305        .enumerate()
306        .map(|(idx, member_indices)| {
307            let label = records
308                .get(member_indices[0])
309                .map(|(_, content, _)| content.chars().take(40).collect::<String>())
310                .unwrap_or_else(|| format!("cluster_{idx}"));
311
312            let members: Vec<(String, f32)> = member_indices
313                .iter()
314                .map(|&mi| {
315                    let sim = if mi == member_indices[0] {
316                        1.0
317                    } else {
318                        jaccard_similarity(&word_sets[member_indices[0]], &word_sets[mi])
319                    };
320                    (records[mi].0.clone(), sim)
321                })
322                .collect();
323
324            (idx as u32, label, members)
325        })
326        .collect()
327}
328
/// Jaccard index of two word sets: `|a ∩ b| / |a ∪ b|`.
///
/// Returns `0.0` when both sets are empty (the union is empty, so the ratio is
/// undefined); otherwise a value in `[0.0, 1.0]`.
fn jaccard_similarity(
    a: &std::collections::HashSet<&str>,
    b: &std::collections::HashSet<&str>,
) -> f32 {
    let shared = a.intersection(b).count();
    // |a ∪ b| = |a| + |b| - |a ∩ b|
    let combined = a.len() + b.len() - shared;
    match combined {
        0 => 0.0,
        total => shared as f32 / total as f32,
    }
}