Skip to main content

hirn_exec/operators/
hebbian_buffer.rs

1//! `HebbianBufferExec` — pass-through operator that records co-retrieval pairs.
2
3use std::any::Any;
4use std::fmt;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll};
8
9use arrow_array::Array;
10use arrow_array::{RecordBatch, StringArray};
11use arrow_schema::SchemaRef;
12use crossbeam_queue::SegQueue;
13use datafusion_common::Result;
14use datafusion_execution::{SendableRecordBatchStream, TaskContext};
15use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
16use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
17use futures::Stream;
18
/// Shared queue of co-retrieval pairs `(memory_id_a, memory_id_b)`.
///
/// An `Arc`-wrapped crossbeam [`SegQueue`], so the operator that produces pairs
/// and whatever component drains them can each hold a handle to the same queue.
/// Within a pair, `memory_id_a` is the ID that appeared earlier in its batch.
pub type CoRetrievalQueue = Arc<SegQueue<(String, String)>>;
21
/// Maximum number of IDs per batch to consider for pair generation.
///
/// Pair generation is quadratic in the number of IDs; with this cap a single
/// batch contributes at most C(100, 2) = 4950 pairs. Only the first
/// `MAX_IDS_FOR_PAIRS` non-null IDs of a batch take part in pairing.
const MAX_IDS_FOR_PAIRS: usize = 100;
25
/// Pass-through operator recording co-retrieval pairs for Hebbian learning.
///
/// Input flows through unchanged; side effect: pairs of memory IDs from the
/// same batch are pushed into a [`SegQueue`].
#[derive(Debug)]
pub struct HebbianBufferExec {
    /// Child plan whose batches are forwarded to the consumer.
    input: Arc<dyn ExecutionPlan>,
    /// Plan properties computed once in the constructor and returned by `properties()`.
    properties: PlanProperties,
    /// Destination for recorded `(id_a, id_b)` pairs; handed to every stream
    /// created by `execute`.
    queue: CoRetrievalQueue,
}
36
37impl HebbianBufferExec {
38    pub fn new(input: Arc<dyn ExecutionPlan>, queue: CoRetrievalQueue) -> Self {
39        let schema = input.schema();
40        let properties = PlanProperties::new(
41            datafusion_physical_expr::EquivalenceProperties::new(schema),
42            datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
43            EmissionType::Incremental,
44            Boundedness::Bounded,
45        );
46
47        Self {
48            input,
49            properties,
50            queue,
51        }
52    }
53}
54
55impl DisplayAs for HebbianBufferExec {
56    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        write!(f, "HebbianBufferExec")
58    }
59}
60
61impl ExecutionPlan for HebbianBufferExec {
62    fn name(&self) -> &str {
63        "HebbianBufferExec"
64    }
65
66    fn as_any(&self) -> &dyn Any {
67        self
68    }
69
70    fn schema(&self) -> SchemaRef {
71        self.input.schema()
72    }
73
74    fn properties(&self) -> &PlanProperties {
75        &self.properties
76    }
77
78    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
79        vec![&self.input]
80    }
81
82    fn with_new_children(
83        self: Arc<Self>,
84        children: Vec<Arc<dyn ExecutionPlan>>,
85    ) -> Result<Arc<dyn ExecutionPlan>> {
86        Ok(Arc::new(Self::new(children[0].clone(), self.queue.clone())))
87    }
88
89    fn execute(
90        &self,
91        partition: usize,
92        context: Arc<TaskContext>,
93    ) -> Result<SendableRecordBatchStream> {
94        let input = self.input.execute(partition, context)?;
95        let schema = self.input.schema();
96        let queue = self.queue.clone();
97
98        Ok(Box::pin(HebbianBufferStream {
99            input,
100            schema,
101            queue,
102        }))
103    }
104}
105
/// Stream returned by `HebbianBufferExec::execute`.
///
/// Yields exactly what `input` yields; successful batches are additionally
/// scanned for memory IDs whose pairs are pushed onto `queue`.
struct HebbianBufferStream {
    /// Upstream batch stream being forwarded.
    input: SendableRecordBatchStream,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Destination for recorded co-retrieval pairs.
    queue: CoRetrievalQueue,
}
111
112impl Stream for HebbianBufferStream {
113    type Item = Result<RecordBatch>;
114
115    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
116        let queue = self.queue.clone();
117        match Pin::new(&mut self.input).poll_next(cx) {
118            Poll::Ready(Some(Ok(batch))) => {
119                // Extract memory IDs and record all pairs
120                let id_col = batch
121                    .column_by_name("id")
122                    .or_else(|| batch.column_by_name("node_id"))
123                    .or_else(|| batch.column_by_name("memory_id"));
124
125                if let Some(col) = id_col {
126                    if let Some(strings) = col.as_any().downcast_ref::<StringArray>() {
127                        let total_non_null =
128                            (0..strings.len()).filter(|&i| !strings.is_null(i)).count();
129                        let ids: Vec<&str> = (0..strings.len())
130                            .filter(|&i| !strings.is_null(i))
131                            .map(|i| strings.value(i))
132                            .take(MAX_IDS_FOR_PAIRS)
133                            .collect();
134
135                        if total_non_null > MAX_IDS_FOR_PAIRS {
136                            tracing::debug!(
137                                total = total_non_null,
138                                limit = MAX_IDS_FOR_PAIRS,
139                                "HebbianBufferExec: truncating co-retrieval IDs to limit"
140                            );
141                        }
142
143                        // Record all pairs (combinatorial)
144                        for i in 0..ids.len() {
145                            for j in (i + 1)..ids.len() {
146                                queue.push((ids[i].to_string(), ids[j].to_string()));
147                            }
148                        }
149                    }
150                }
151
152                Poll::Ready(Some(Ok(batch)))
153            }
154            other => other,
155        }
156    }
157}
158
159impl datafusion_execution::RecordBatchStream for HebbianBufferStream {
160    fn schema(&self) -> SchemaRef {
161        self.schema.clone()
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Float32Array;
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::prelude::SessionContext;
    use datafusion_datasource::memory::MemorySourceConfig;
    use futures::StreamExt;

    /// Builds a two-column batch (`id`, `score`) with one row per entry in `ids`.
    fn test_batch(ids: &[&str]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("score", DataType::Float32, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(ids.to_vec())),
                Arc::new(Float32Array::from(vec![1.0; ids.len()])),
            ],
        )
        .unwrap()
    }

    /// Runs `HebbianBufferExec` over a single in-memory partition to completion.
    ///
    /// Returns the number of rows that came out and the queue the operator
    /// recorded into.
    async fn run_exec(batches: Vec<RecordBatch>, schema: SchemaRef) -> (usize, CoRetrievalQueue) {
        let input = MemorySourceConfig::try_new_exec(&[batches], schema, None).unwrap();
        let queue: CoRetrievalQueue = Arc::new(SegQueue::new());

        let exec = HebbianBufferExec::new(input, Arc::clone(&queue));
        let ctx = SessionContext::new();
        let mut stream = exec.execute(0, ctx.task_ctx()).unwrap();

        let mut rows = 0;
        while let Some(result) = stream.next().await {
            rows += result.unwrap().num_rows();
        }
        (rows, queue)
    }

    /// Empties `queue` into a vector, preserving push order.
    fn drain(queue: &CoRetrievalQueue) -> Vec<(String, String)> {
        std::iter::from_fn(|| queue.pop()).collect()
    }

    /// Owned pair literal, to keep expectations readable.
    fn pair(a: &str, b: &str) -> (String, String) {
        (a.to_string(), b.to_string())
    }

    #[tokio::test]
    async fn passthrough_rows() {
        let batch = test_batch(&["m1", "m2", "m3", "m4"]);
        let schema = batch.schema();

        let (total, _queue) = run_exec(vec![batch], schema).await;

        assert_eq!(total, 4, "all rows should pass through");
    }

    #[tokio::test]
    async fn pairs_recorded() {
        let batch = test_batch(&["m1", "m2", "m3"]);
        let schema = batch.schema();

        let (_rows, queue) = run_exec(vec![batch], schema).await;

        // 3 IDs → C(3,2) = 3 pairs
        assert_eq!(queue.len(), 3, "should record 3 co-retrieval pairs");
        assert_eq!(
            drain(&queue),
            vec![pair("m1", "m2"), pair("m1", "m3"), pair("m2", "m3")],
            "each ID is paired with every later ID, in row order"
        );
    }

    #[tokio::test]
    async fn empty_input_no_pairs() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)]));

        let (rows, queue) = run_exec(vec![], schema).await;

        assert_eq!(rows, 0);
        assert_eq!(queue.len(), 0);
    }

    #[tokio::test]
    async fn ids_beyond_limit_are_not_paired() {
        let owned: Vec<String> = (0..MAX_IDS_FOR_PAIRS + 5)
            .map(|i| format!("m{i}"))
            .collect();
        let ids: Vec<&str> = owned.iter().map(String::as_str).collect();
        let batch = test_batch(&ids);
        let schema = batch.schema();

        let (rows, queue) = run_exec(vec![batch], schema).await;

        assert_eq!(rows, ids.len(), "truncation must not drop rows");
        assert_eq!(
            queue.len(),
            MAX_IDS_FOR_PAIRS * (MAX_IDS_FOR_PAIRS - 1) / 2,
            "only the first MAX_IDS_FOR_PAIRS IDs are paired"
        );
    }

    #[tokio::test]
    async fn null_ids_are_skipped() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, true)]));
        let ids = StringArray::from(vec![Some("m1"), None, Some("m2")]);
        let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(ids)]).unwrap();

        let (rows, queue) = run_exec(vec![batch], schema).await;

        assert_eq!(rows, 3, "rows with null IDs still pass through");
        assert_eq!(drain(&queue), vec![pair("m1", "m2")]);
    }
}