hirn_exec/operators/
hebbian_buffer.rs1use std::any::Any;
4use std::fmt;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::{Context, Poll};
8
9use arrow_array::Array;
10use arrow_array::{RecordBatch, StringArray};
11use arrow_schema::SchemaRef;
12use crossbeam_queue::SegQueue;
13use datafusion_common::Result;
14use datafusion_execution::{SendableRecordBatchStream, TaskContext};
15use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
16use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
17use futures::Stream;
18
/// Shared crossbeam `SegQueue` of ID pairs.
///
/// `HebbianBufferStream` pushes one `(String, String)` entry for every pair of
/// IDs it observes together in a single record batch.
/// NOTE(review): the consumer that drains this queue is not in this file.
pub type CoRetrievalQueue = Arc<SegQueue<(String, String)>>;

/// Maximum number of non-null IDs taken from a single batch when generating
/// pairs. Pair generation is quadratic, so this caps the work at
/// 100 * 99 / 2 = 4950 queue entries per batch.
const MAX_IDS_FOR_PAIRS: usize = 100;
25
/// Physical plan node that forwards its input's record batches unchanged and,
/// as a side effect, pushes pairs of IDs found together in each batch onto a
/// shared [`CoRetrievalQueue`].
#[derive(Debug)]
pub struct HebbianBufferExec {
    /// Child plan whose output is passed through.
    input: Arc<dyn ExecutionPlan>,
    /// Plan properties computed once in `new` and returned by `properties()`.
    properties: PlanProperties,
    /// Destination for the ID pairs; shared with every stream this node creates.
    queue: CoRetrievalQueue,
}
36
37impl HebbianBufferExec {
38 pub fn new(input: Arc<dyn ExecutionPlan>, queue: CoRetrievalQueue) -> Self {
39 let schema = input.schema();
40 let properties = PlanProperties::new(
41 datafusion_physical_expr::EquivalenceProperties::new(schema),
42 datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
43 EmissionType::Incremental,
44 Boundedness::Bounded,
45 );
46
47 Self {
48 input,
49 properties,
50 queue,
51 }
52 }
53}
54
55impl DisplayAs for HebbianBufferExec {
56 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 write!(f, "HebbianBufferExec")
58 }
59}
60
61impl ExecutionPlan for HebbianBufferExec {
62 fn name(&self) -> &str {
63 "HebbianBufferExec"
64 }
65
66 fn as_any(&self) -> &dyn Any {
67 self
68 }
69
70 fn schema(&self) -> SchemaRef {
71 self.input.schema()
72 }
73
74 fn properties(&self) -> &PlanProperties {
75 &self.properties
76 }
77
78 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
79 vec![&self.input]
80 }
81
82 fn with_new_children(
83 self: Arc<Self>,
84 children: Vec<Arc<dyn ExecutionPlan>>,
85 ) -> Result<Arc<dyn ExecutionPlan>> {
86 Ok(Arc::new(Self::new(children[0].clone(), self.queue.clone())))
87 }
88
89 fn execute(
90 &self,
91 partition: usize,
92 context: Arc<TaskContext>,
93 ) -> Result<SendableRecordBatchStream> {
94 let input = self.input.execute(partition, context)?;
95 let schema = self.input.schema();
96 let queue = self.queue.clone();
97
98 Ok(Box::pin(HebbianBufferStream {
99 input,
100 schema,
101 queue,
102 }))
103 }
104}
105
/// Stream returned by `HebbianBufferExec::execute`: yields the input's items
/// unchanged while pushing ID pairs from each successful batch onto `queue`.
struct HebbianBufferStream {
    /// Upstream batch stream being wrapped.
    input: SendableRecordBatchStream,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Destination for the ID pairs.
    queue: CoRetrievalQueue,
}
111
112impl Stream for HebbianBufferStream {
113 type Item = Result<RecordBatch>;
114
115 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
116 let queue = self.queue.clone();
117 match Pin::new(&mut self.input).poll_next(cx) {
118 Poll::Ready(Some(Ok(batch))) => {
119 let id_col = batch
121 .column_by_name("id")
122 .or_else(|| batch.column_by_name("node_id"))
123 .or_else(|| batch.column_by_name("memory_id"));
124
125 if let Some(col) = id_col {
126 if let Some(strings) = col.as_any().downcast_ref::<StringArray>() {
127 let total_non_null =
128 (0..strings.len()).filter(|&i| !strings.is_null(i)).count();
129 let ids: Vec<&str> = (0..strings.len())
130 .filter(|&i| !strings.is_null(i))
131 .map(|i| strings.value(i))
132 .take(MAX_IDS_FOR_PAIRS)
133 .collect();
134
135 if total_non_null > MAX_IDS_FOR_PAIRS {
136 tracing::debug!(
137 total = total_non_null,
138 limit = MAX_IDS_FOR_PAIRS,
139 "HebbianBufferExec: truncating co-retrieval IDs to limit"
140 );
141 }
142
143 for i in 0..ids.len() {
145 for j in (i + 1)..ids.len() {
146 queue.push((ids[i].to_string(), ids[j].to_string()));
147 }
148 }
149 }
150 }
151
152 Poll::Ready(Some(Ok(batch)))
153 }
154 other => other,
155 }
156 }
157}
158
159impl datafusion_execution::RecordBatchStream for HebbianBufferStream {
160 fn schema(&self) -> SchemaRef {
161 self.schema.clone()
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Float32Array;
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::prelude::SessionContext;
    use datafusion_datasource::memory::MemorySourceConfig;
    use futures::StreamExt;

    /// Builds a two-column (`id`, `score`) batch with one row per entry in `ids`.
    fn test_batch(ids: &[&str]) -> RecordBatch {
        let fields = vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("score", DataType::Float32, false),
        ];
        let id_column = StringArray::from(ids.to_vec());
        let score_column = Float32Array::from(vec![1.0_f32; ids.len()]);

        RecordBatch::try_new(
            Arc::new(Schema::new(fields)),
            vec![Arc::new(id_column), Arc::new(score_column)],
        )
        .unwrap()
    }

    /// Runs partition 0 of a `HebbianBufferExec` over an in-memory source and
    /// returns every item the stream produced together with the queue it fed.
    async fn run_partition_zero(
        partitions: &[Vec<RecordBatch>],
        schema: SchemaRef,
    ) -> (Vec<Result<RecordBatch>>, CoRetrievalQueue) {
        let source = MemorySourceConfig::try_new_exec(partitions, schema, None).unwrap();
        let queue: CoRetrievalQueue = Arc::new(SegQueue::new());
        let exec = HebbianBufferExec::new(source, Arc::clone(&queue));

        let session = SessionContext::new();
        let stream = exec.execute(0, session.task_ctx()).unwrap();
        let items = stream.collect::<Vec<_>>().await;

        (items, queue)
    }

    #[tokio::test]
    async fn passthrough_rows() {
        let batch = test_batch(&["m1", "m2", "m3", "m4"]);
        let schema = batch.schema();

        let (items, _queue) = run_partition_zero(&[vec![batch]], schema).await;

        let total: usize = items.into_iter().map(|item| item.unwrap().num_rows()).sum();
        assert_eq!(total, 4, "all rows should pass through");
    }

    #[tokio::test]
    async fn pairs_recorded() {
        let batch = test_batch(&["m1", "m2", "m3"]);
        let schema = batch.schema();

        let (_items, queue) = run_partition_zero(&[vec![batch]], schema).await;

        // Three IDs give C(3, 2) = 3 pairs.
        assert_eq!(queue.len(), 3, "should record 3 co-retrieval pairs");
    }

    #[tokio::test]
    async fn empty_input_no_pairs() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Utf8, false)]));

        let (_items, queue) = run_partition_zero(&[vec![]], schema).await;

        assert_eq!(queue.len(), 0);
    }
}