Skip to main content

hirn_exec/operators/
graph_activation.rs

1//! `GraphActivationExec` — graph activation as a DataFusion operator.
2//!
3//! Runs real static activation, spreading activation, or PPR through the
4//! authoritative graph runtime. When the child produces standardized recall
5//! rows, this operator preserves that row shape and hydrates graph-expanded
6//! neighbors back into recall rows while appending `activation_score` and
7//! `depth`.
8
9use std::any::Any;
10use std::collections::HashMap;
11use std::fmt;
12use std::sync::Arc;
13
14use arrow_array::{
15    Array, ArrayRef, Float32Array, Int64Array, RecordBatch, StringArray, UInt32Array, UInt64Array,
16};
17use arrow_schema::{DataType, Field, Schema, SchemaRef};
18use datafusion_common::Result;
19use datafusion_execution::{SendableRecordBatchStream, TaskContext};
20use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
21use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
22use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
23
24use hirn_core::id::MemoryId;
25use hirn_core::types::Namespace;
26use hirn_graph::ActivationConfig;
27#[cfg(test)]
28use hirn_graph::PropertyGraph;
29#[cfg(test)]
30use parking_lot::RwLock;
31
32use crate::extensions::HirnSessionExt;
33use crate::operators::lance_hybrid_search::{RecallRow, fetch_recall_rows_by_ids};
34
/// Activation mode for the graph traversal.
///
/// Selects which activation algorithm `GraphActivationExec` asks the graph
/// runtime to run. The type is a plain `Copy` tag, so it can be compared and
/// used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ActivationMode {
    /// One-hop static neighborhood expansion.
    Static,
    /// Full spreading activation with lateral inhibition.
    Spreading,
    /// Personalized PageRank — random-walk-based retrieval.
    Ppr,
}
45
46/// DataFusion physical operator that runs graph activation through the runtime.
47///
48/// Input: child plan providing seed node IDs (column `node_id: Utf8` or `id: Utf8`).
49/// Output: `node_id (Utf8)`, `activation_score (Float32)`, `depth (UInt32)`.
50///
51/// Retrieves the graph-read runtime from `HirnSessionExt` via the `TaskContext`
52/// config extensions and fails if that runtime is not registered.
53#[derive(Debug)]
54pub struct GraphActivationExec {
55    input: Arc<dyn ExecutionPlan>,
56    schema: SchemaRef,
57    properties: PlanProperties,
58    seed_limit: usize,
59    mode: ActivationMode,
60    max_depth: u32,
61    epsilon: f32,
62    inhibition_mu: f32,
63    preserve_recall_rows: bool,
64}
65
66impl GraphActivationExec {
67    pub fn new(
68        input: Arc<dyn ExecutionPlan>,
69        seed_limit: usize,
70        mode: ActivationMode,
71        max_depth: u32,
72        epsilon: f32,
73        inhibition_mu: f32,
74    ) -> Result<Self> {
75        let seed_limit = seed_limit.max(1);
76        let config = ActivationConfig {
77            max_depth: max_depth as usize,
78            epsilon: f64::from(epsilon),
79            inhibition_strength: f64::from(inhibition_mu),
80            ..Default::default()
81        };
82        config.validate().map_err(|error| {
83            datafusion_common::DataFusionError::Execution(format!(
84                "invalid graph activation config: {error}"
85            ))
86        })?;
87
88        let preserve_recall_rows = supports_recall_row_passthrough(input.schema().as_ref());
89        let schema = if preserve_recall_rows {
90            recall_activation_schema(input.schema())
91        } else {
92            Self::output_schema()
93        };
94        let properties = PlanProperties::new(
95            datafusion_physical_expr::EquivalenceProperties::new(schema.clone()),
96            datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
97            // N-M18: operator collects all results into a single batch before emitting;
98            // declare Final not Incremental to match actual emission semantics.
99            EmissionType::Final,
100            Boundedness::Bounded,
101        );
102        Ok(Self {
103            input,
104            schema,
105            properties,
106            seed_limit,
107            mode,
108            max_depth,
109            epsilon,
110            inhibition_mu,
111            preserve_recall_rows,
112        })
113    }
114
115    /// Output schema: `(node_id, activation_score, depth)`.
116    pub fn output_schema() -> SchemaRef {
117        Arc::new(Schema::new(vec![
118            Field::new("node_id", DataType::Utf8, false),
119            Field::new("activation_score", DataType::Float32, false),
120            Field::new("depth", DataType::UInt32, false),
121        ]))
122    }
123
124    pub fn mode(&self) -> ActivationMode {
125        self.mode
126    }
127
128    pub fn seed_limit(&self) -> usize {
129        self.seed_limit
130    }
131
132    pub fn max_depth(&self) -> u32 {
133        self.max_depth
134    }
135
136    pub fn epsilon(&self) -> f32 {
137        self.epsilon
138    }
139
140    pub fn inhibition_mu(&self) -> f32 {
141        self.inhibition_mu
142    }
143
144    pub fn preserves_recall_rows(&self) -> bool {
145        self.preserve_recall_rows
146    }
147}
148
149impl DisplayAs for GraphActivationExec {
150    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151        write!(
152            f,
153            "GraphActivationExec: seed_limit={}, mode={:?}, depth={}, ε={}, µ={}",
154            self.seed_limit, self.mode, self.max_depth, self.epsilon, self.inhibition_mu
155        )
156    }
157}
158
159impl ExecutionPlan for GraphActivationExec {
160    fn name(&self) -> &str {
161        "GraphActivationExec"
162    }
163
164    fn as_any(&self) -> &dyn Any {
165        self
166    }
167
168    fn schema(&self) -> SchemaRef {
169        self.schema.clone()
170    }
171
172    fn properties(&self) -> &PlanProperties {
173        &self.properties
174    }
175
176    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
177        vec![&self.input]
178    }
179
180    fn with_new_children(
181        self: Arc<Self>,
182        children: Vec<Arc<dyn ExecutionPlan>>,
183    ) -> Result<Arc<dyn ExecutionPlan>> {
184        let [child]: [Arc<dyn ExecutionPlan>; 1] = children.try_into().map_err(|v: Vec<_>| {
185            datafusion_common::DataFusionError::Plan(format!(
186                "GraphActivationExec requires exactly 1 child, got {}",
187                v.len()
188            ))
189        })?;
190        Ok(Arc::new(Self::new(
191            child,
192            self.seed_limit,
193            self.mode,
194            self.max_depth,
195            self.epsilon,
196            self.inhibition_mu,
197        )?))
198    }
199
200    fn execute(
201        &self,
202        partition: usize,
203        context: Arc<TaskContext>,
204    ) -> Result<SendableRecordBatchStream> {
205        let input = self.input.execute(partition, context.clone())?;
206        let schema = self.schema.clone();
207        let stream_schema = schema.clone();
208        let max_depth = self.max_depth;
209        let epsilon = self.epsilon;
210        let inhibition_mu = self.inhibition_mu;
211        let mode = self.mode;
212        let preserve_recall_rows = self.preserve_recall_rows;
213        let seed_limit = self.seed_limit;
214
215        let session_ext = context
216            .session_config()
217            .options()
218            .extensions
219            .get::<HirnSessionExt>()
220            .cloned();
221        let graph_read_runtime = session_ext
222            .as_ref()
223            .and_then(|ext| ext.graph_read_runtime());
224        let storage = session_ext.as_ref().and_then(|ext| ext.storage_arc());
225        let delegation_threshold = session_ext
226            .as_ref()
227            .map(|ext| ext.config.graph_depth_delegation_threshold)
228            .unwrap_or(usize::MAX);
229        let allowed_namespaces = session_ext.as_ref().and_then(|ext| {
230            ext.allowed_namespaces().map(|namespaces| {
231                namespaces
232                    .iter()
233                    .filter_map(|namespace| Namespace::new(namespace).ok())
234                    .collect::<Vec<_>>()
235            })
236        });
237
238        let fut = async move {
239            use futures::StreamExt;
240
241            let mut seed_strings = Vec::new();
242            let mut passthrough_rows = if preserve_recall_rows {
243                Some(RecallPassthroughRows::default())
244            } else {
245                None
246            };
247            let mut stream = input;
248            while let Some(batch) = stream.next().await {
249                let batch = batch?;
250                if let Some(rows) = passthrough_rows.as_mut() {
251                    accumulate_recall_rows(rows, &batch).map_err(|error| {
252                        datafusion_common::DataFusionError::Execution(error.to_string())
253                    })?;
254                }
255
256                if seed_strings.len() < seed_limit {
257                    let col = batch
258                        .column_by_name("node_id")
259                        .or_else(|| batch.column_by_name("id"));
260                    if let Some(col) = col {
261                        if let Some(arr) = col.as_any().downcast_ref::<StringArray>() {
262                            for i in 0..arr.len() {
263                                if seed_strings.len() >= seed_limit {
264                                    break;
265                                }
266                                if !arr.is_null(i) {
267                                    seed_strings.push(arr.value(i).to_string());
268                                }
269                            }
270                        }
271                    }
272                }
273
274                if !preserve_recall_rows && seed_strings.len() >= seed_limit {
275                    break;
276                }
277            }
278
279            if seed_strings.is_empty() {
280                let empty = RecordBatch::new_empty(schema);
281                return Ok(empty);
282            }
283
284            // 2. Parse MemoryIds, logging failures.
285            let mut seeds = Vec::with_capacity(seed_strings.len());
286            let mut parse_failures = 0_usize;
287            let mut first_errors: Vec<String> = Vec::new();
288            for s in &seed_strings {
289                match MemoryId::parse(s) {
290                    Ok(id) => seeds.push(id),
291                    Err(e) => {
292                        parse_failures += 1;
293                        if first_errors.len() < 3 {
294                            first_errors.push(format!("{s}: {e}"));
295                        }
296                        tracing::warn!(
297                            seed = %s,
298                            "GraphActivationExec: failed to parse seed MemoryId, skipping"
299                        );
300                    }
301                }
302            }
303
304            if seeds.is_empty() {
305                // All seeds failed to parse — this is an error, not a quiet empty result.
306                return Err(datafusion_common::DataFusionError::Execution(format!(
307                    "GraphActivationExec: all {} seed IDs failed to parse (first errors: {})",
308                    parse_failures,
309                    first_errors.join("; ")
310                )));
311            }
312
313            // 3. Run activation on the authoritative graph runtime.
314            let Some(runtime) = graph_read_runtime else {
315                return Err(datafusion_common::DataFusionError::Execution(
316                    "GraphActivationExec requires HirnSessionExt graph runtime".to_string(),
317                ));
318            };
319            let (ids, scores, depths) = {
320                let output = runtime
321                    .activate_graph(
322                        &seeds,
323                        mode,
324                        None,
325                        max_depth,
326                        epsilon,
327                        inhibition_mu,
328                        delegation_threshold,
329                        allowed_namespaces.as_deref(),
330                    )
331                    .await
332                    .map_err(|error| {
333                        datafusion_common::DataFusionError::Execution(error.to_string())
334                    })?;
335                (output.ids, output.scores, output.depths)
336            };
337
338            if ids.is_empty() {
339                return Ok(RecordBatch::new_empty(schema));
340            }
341
342            if preserve_recall_rows {
343                return build_recall_activation_output_batch(
344                    schema,
345                    passthrough_rows.unwrap_or_default(),
346                    storage.as_deref(),
347                    &ids,
348                    &scores,
349                    &depths,
350                )
351                .await
352                .map_err(|error| datafusion_common::DataFusionError::Execution(error.to_string()));
353            }
354
355            let id_refs: Vec<&str> = ids.iter().map(String::as_str).collect();
356            RecordBatch::try_new(
357                schema,
358                vec![
359                    Arc::new(StringArray::from(id_refs)),
360                    Arc::new(Float32Array::from(scores)),
361                    Arc::new(UInt32Array::from(depths)),
362                ],
363            )
364            .map_err(Into::into)
365        };
366
367        let stream = futures::stream::once(fut);
368        Ok(Box::pin(RecordBatchStreamAdapter::new(
369            stream_schema,
370            stream,
371        )))
372    }
373}
374
/// Test-only reference path that runs static activation, spreading activation,
/// or PPR directly on `graph`. The result is flattened into parallel
/// `(ids, scores, depths)` vectors ordered by descending score.
///
/// Depth per mode:
/// * static — 0 for seeds, 1 for every other node;
/// * spreading — trace path length minus one, or 0 when there is no trace;
/// * PPR — always 0, because PPR does not track depth.
///
/// `max_depth`, `epsilon`, and `inhibition_mu` only influence spreading
/// activation. PPR always uses `PprConfig::default()`.
///
/// # Panics
///
/// Panics if `hirn_graph` rejects the spreading-activation config or the
/// default PPR config.
#[cfg(test)]
fn run_activation(
    graph: &PropertyGraph,
    seeds: &[MemoryId],
    mode: ActivationMode,
    max_depth: u32,
    epsilon: f32,
    inhibition_mu: f32,
    allowed_namespaces: Option<&[Namespace]>,
) -> (Vec<String>, Vec<f32>, Vec<u32>) {
    /// Stable sort by descending score; incomparable scores compare as equal.
    fn sort_by_score_desc<K, S: PartialOrd>(entries: &mut [(K, S)]) {
        entries.sort_by(|lhs, rhs| {
            rhs.1
                .partial_cmp(&lhs.1)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
    }

    let mut out_ids = Vec::new();
    let mut out_scores = Vec::new();
    let mut out_depths = Vec::new();

    match mode {
        ActivationMode::Static => {
            let mut ranked: Vec<_> =
                hirn_graph::static_activation(graph, seeds, allowed_namespaces)
                    .into_iter()
                    .collect();
            sort_by_score_desc(&mut ranked);
            for (node, score) in ranked {
                // Seeds stay at depth 0; every other node is one hop out.
                let depth = if seeds.contains(&node) { 0 } else { 1 };
                out_ids.push(node.to_string());
                out_scores.push(score as f32);
                out_depths.push(depth);
            }
        }
        ActivationMode::Spreading => {
            // F-103: scale frontier cap to observed graph density to avoid building
            // 100 K-entry heaps on hub nodes in large graphs.
            let config = ActivationConfig {
                max_depth: max_depth as usize,
                epsilon: f64::from(epsilon),
                inhibition_strength: f64::from(inhibition_mu),
                ..Default::default()
            }
            .tuned_for_graph(graph.node_count(), graph.edge_count());
            let outcome =
                hirn_graph::spread_activation(graph, seeds, &config, None, allowed_namespaces)
                    .expect("test activation config should be valid");
            let mut ranked: Vec<_> = outcome.activations.into_iter().collect();
            sort_by_score_desc(&mut ranked);
            for (node, score) in ranked {
                let depth = outcome
                    .traces
                    .get(&node)
                    .map_or(0, |trace| trace.path.len().saturating_sub(1) as u32);
                out_ids.push(node.to_string());
                out_scores.push(score as f32);
                out_depths.push(depth);
            }
        }
        ActivationMode::Ppr => {
            let ppr_config = hirn_graph::PprConfig::default();
            let mut ranked: Vec<_> =
                hirn_graph::personalized_pagerank(graph, seeds, &ppr_config, allowed_namespaces)
                    .expect("default PPR config should be valid")
                    .into_iter()
                    .collect();
            sort_by_score_desc(&mut ranked);
            for (node, score) in ranked {
                out_ids.push(node.to_string());
                out_scores.push(score as f32);
                out_depths.push(0); // PPR doesn't track depth.
            }
        }
    }

    (out_ids, out_scores, out_depths)
}
450
451fn supports_recall_row_passthrough(schema: &Schema) -> bool {
452    [
453        "id",
454        "content",
455        "layer",
456        "namespace",
457        "score",
458        "temporal_ms",
459        "created_at_ms",
460        "importance",
461        "access_count",
462        "surprise",
463        "evidence_count",
464        "invocation_count",
465    ]
466    .iter()
467    .all(|field| schema.field_with_name(field).is_ok())
468}
469
470/// Canonical output schema for recall-row passthrough mode.
471///
472/// `build_recall_activation_output_batch` always reconstructs the batch from
473/// `RecallRow` structs in a fixed column order — it does NOT pass through
474/// arbitrary input columns.  Therefore the schema is fixed here rather than
475/// derived from `input_schema`, which would silently omit `full_content` when
476/// the upstream batch didn't include it.
477fn recall_activation_schema(_input_schema: SchemaRef) -> SchemaRef {
478    Arc::new(Schema::new(vec![
479        Field::new("id", DataType::Utf8, false),
480        Field::new("content", DataType::Utf8, false),
481        Field::new("full_content", DataType::Utf8, false),
482        Field::new("layer", DataType::Utf8, false),
483        Field::new("namespace", DataType::Utf8, false),
484        Field::new("score", DataType::Float32, false),
485        Field::new("temporal_ms", DataType::Int64, false),
486        Field::new("created_at_ms", DataType::Int64, false),
487        Field::new("importance", DataType::Float32, false),
488        Field::new("access_count", DataType::UInt32, false),
489        Field::new("surprise", DataType::Float32, true),
490        Field::new("evidence_count", DataType::UInt32, true),
491        Field::new("invocation_count", DataType::UInt64, true),
492        Field::new("activation_score", DataType::Float32, false),
493        Field::new("depth", DataType::UInt32, false),
494    ]))
495}
496
497async fn build_recall_activation_output_batch(
498    schema: SchemaRef,
499    mut passthrough_rows: RecallPassthroughRows,
500    storage: Option<&dyn hirn_storage::PhysicalStore>,
501    activated_ids: &[String],
502    activation_scores: &[f32],
503    depths: &[u32],
504) -> Result<RecordBatch, hirn_storage::HirnDbError> {
505    let mut ordered_ids = std::mem::take(&mut passthrough_rows.ordered_ids);
506    let mut base_rows = std::mem::take(&mut passthrough_rows.base_rows);
507
508    let missing_ids = activated_ids
509        .iter()
510        .filter(|id| !base_rows.contains_key(*id))
511        .filter_map(|id| MemoryId::parse(id).ok())
512        .collect::<Vec<_>>();
513
514    if !missing_ids.is_empty() {
515        let Some(storage) = storage else {
516            return Err(hirn_storage::HirnDbError::InvalidArgument(
517                "graph activation recall expansion requires storage access".to_string(),
518            ));
519        };
520        for row in fetch_recall_rows_by_ids(storage, &missing_ids).await? {
521            base_rows.entry(row.id.clone()).or_insert(row);
522        }
523    }
524
525    let activation_by_id = activated_ids
526        .iter()
527        .zip(activation_scores.iter())
528        .zip(depths.iter())
529        .map(|((activated_id, activation_score), depth)| {
530            (activated_id.as_str(), (*activation_score, *depth))
531        })
532        .collect::<HashMap<_, _>>();
533
534    for activated_id in activated_ids {
535        if !ordered_ids.iter().any(|id| id == activated_id) {
536            ordered_ids.push(activated_id.clone());
537        }
538    }
539
540    let mut rows = Vec::with_capacity(ordered_ids.len());
541    let mut activation_values = Vec::with_capacity(ordered_ids.len());
542    let mut depth_values = Vec::with_capacity(ordered_ids.len());
543    for ordered_id in ordered_ids {
544        if let Some(row) = base_rows.get(&ordered_id).cloned() {
545            let (activation_score, depth) = activation_by_id
546                .get(ordered_id.as_str())
547                .copied()
548                .unwrap_or((0.0, 0));
549            rows.push(row);
550            activation_values.push(activation_score);
551            depth_values.push(depth);
552        }
553    }
554
555    if rows.is_empty() {
556        return Ok(RecordBatch::new_empty(schema));
557    }
558
559    let ids = rows.iter().map(|row| row.id.as_str()).collect::<Vec<_>>();
560    let contents = rows
561        .iter()
562        .map(|row| row.content.as_str())
563        .collect::<Vec<_>>();
564    let full_contents = rows
565        .iter()
566        .map(|row| row.full_content.as_str())
567        .collect::<Vec<_>>();
568    let layers = rows.iter().map(|row| row.layer).collect::<Vec<_>>();
569    let namespaces = rows
570        .iter()
571        .map(|row| row.namespace.as_str())
572        .collect::<Vec<_>>();
573    let scores = rows.iter().map(|row| row.score).collect::<Vec<_>>();
574    let temporal = rows.iter().map(|row| row.temporal_ms).collect::<Vec<_>>();
575    let created_at = rows.iter().map(|row| row.created_at_ms).collect::<Vec<_>>();
576    let importances = rows.iter().map(|row| row.importance).collect::<Vec<_>>();
577    let access_counts = rows.iter().map(|row| row.access_count).collect::<Vec<_>>();
578    let surprises = rows.iter().map(|row| row.surprise).collect::<Vec<_>>();
579    let evidence_counts = rows
580        .iter()
581        .map(|row| row.evidence_count)
582        .collect::<Vec<_>>();
583    let invocation_counts = rows
584        .iter()
585        .map(|row| row.invocation_count)
586        .collect::<Vec<_>>();
587
588    RecordBatch::try_new(
589        schema,
590        vec![
591            Arc::new(StringArray::from(ids)) as ArrayRef,
592            Arc::new(StringArray::from(contents)) as ArrayRef,
593            Arc::new(StringArray::from(full_contents)) as ArrayRef,
594            Arc::new(StringArray::from(layers)) as ArrayRef,
595            Arc::new(StringArray::from(namespaces)) as ArrayRef,
596            Arc::new(Float32Array::from(scores)) as ArrayRef,
597            Arc::new(Int64Array::from(temporal)) as ArrayRef,
598            Arc::new(Int64Array::from(created_at)) as ArrayRef,
599            Arc::new(Float32Array::from(importances)) as ArrayRef,
600            Arc::new(UInt32Array::from(access_counts)) as ArrayRef,
601            Arc::new(Float32Array::from(surprises)) as ArrayRef,
602            Arc::new(UInt32Array::from(evidence_counts)) as ArrayRef,
603            Arc::new(UInt64Array::from(invocation_counts)) as ArrayRef,
604            Arc::new(Float32Array::from(activation_values)) as ArrayRef,
605            Arc::new(UInt32Array::from(depth_values)) as ArrayRef,
606        ],
607    )
608    .map_err(hirn_storage::HirnDbError::ArrowError)
609}
610
611#[derive(Debug, Default)]
612struct RecallPassthroughRows {
613    ordered_ids: Vec<String>,
614    base_rows: HashMap<String, RecallRow>,
615}
616
617fn accumulate_recall_rows(
618    rows: &mut RecallPassthroughRows,
619    batch: &RecordBatch,
620) -> Result<(), hirn_storage::HirnDbError> {
621    for row in recall_rows_from_batch(batch)? {
622        let row_id = row.id.clone();
623        if !rows.base_rows.contains_key(&row_id) {
624            rows.ordered_ids.push(row_id.clone());
625        }
626        rows.base_rows.entry(row_id).or_insert(row);
627    }
628
629    Ok(())
630}
631
632fn recall_rows_from_batch(
633    batch: &RecordBatch,
634) -> Result<Vec<RecallRow>, hirn_storage::HirnDbError> {
635    let ids = batch
636        .column_by_name("id")
637        .and_then(|column| column.as_any().downcast_ref::<StringArray>())
638        .ok_or_else(|| {
639            hirn_storage::HirnDbError::InvalidArgument(
640                "graph activation recall passthrough batch is missing `id`".to_string(),
641            )
642        })?;
643    let contents = batch
644        .column_by_name("content")
645        .and_then(|column| column.as_any().downcast_ref::<StringArray>())
646        .ok_or_else(|| {
647            hirn_storage::HirnDbError::InvalidArgument(
648                "graph activation recall passthrough batch is missing `content`".to_string(),
649            )
650        })?;
651    let full_contents = batch
652        .column_by_name("full_content")
653        .and_then(|column| column.as_any().downcast_ref::<StringArray>());
654    let layers = batch
655        .column_by_name("layer")
656        .and_then(|column| column.as_any().downcast_ref::<StringArray>())
657        .ok_or_else(|| {
658            hirn_storage::HirnDbError::InvalidArgument(
659                "graph activation recall passthrough batch is missing `layer`".to_string(),
660            )
661        })?;
662    let namespaces = batch
663        .column_by_name("namespace")
664        .and_then(|column| column.as_any().downcast_ref::<StringArray>())
665        .ok_or_else(|| {
666            hirn_storage::HirnDbError::InvalidArgument(
667                "graph activation recall passthrough batch is missing `namespace`".to_string(),
668            )
669        })?;
670    let scores = batch
671        .column_by_name("score")
672        .and_then(|column| column.as_any().downcast_ref::<Float32Array>())
673        .ok_or_else(|| {
674            hirn_storage::HirnDbError::InvalidArgument(
675                "graph activation recall passthrough batch is missing `score`".to_string(),
676            )
677        })?;
678    let created_at = batch
679        .column_by_name("created_at_ms")
680        .and_then(|column| column.as_any().downcast_ref::<Int64Array>())
681        .ok_or_else(|| {
682            hirn_storage::HirnDbError::InvalidArgument(
683                "graph activation recall passthrough batch is missing `created_at_ms`".to_string(),
684            )
685        })?;
686    let temporal = batch
687        .column_by_name("temporal_ms")
688        .and_then(|column| column.as_any().downcast_ref::<Int64Array>())
689        .unwrap_or(created_at);
690    let importances = batch
691        .column_by_name("importance")
692        .and_then(|column| column.as_any().downcast_ref::<Float32Array>())
693        .ok_or_else(|| {
694            hirn_storage::HirnDbError::InvalidArgument(
695                "graph activation recall passthrough batch is missing `importance`".to_string(),
696            )
697        })?;
698    let access_counts = batch
699        .column_by_name("access_count")
700        .and_then(|column| column.as_any().downcast_ref::<UInt32Array>())
701        .ok_or_else(|| {
702            hirn_storage::HirnDbError::InvalidArgument(
703                "graph activation recall passthrough batch is missing `access_count`".to_string(),
704            )
705        })?;
706    let surprises = batch
707        .column_by_name("surprise")
708        .and_then(|column| column.as_any().downcast_ref::<Float32Array>())
709        .ok_or_else(|| {
710            hirn_storage::HirnDbError::InvalidArgument(
711                "graph activation recall passthrough batch is missing `surprise`".to_string(),
712            )
713        })?;
714    let evidence_counts = batch
715        .column_by_name("evidence_count")
716        .and_then(|column| column.as_any().downcast_ref::<UInt32Array>())
717        .ok_or_else(|| {
718            hirn_storage::HirnDbError::InvalidArgument(
719                "graph activation recall passthrough batch is missing `evidence_count`".to_string(),
720            )
721        })?;
722    let invocation_counts = batch
723        .column_by_name("invocation_count")
724        .and_then(|column| column.as_any().downcast_ref::<UInt64Array>())
725        .ok_or_else(|| {
726            hirn_storage::HirnDbError::InvalidArgument(
727                "graph activation recall passthrough batch is missing `invocation_count`"
728                    .to_string(),
729            )
730        })?;
731
732    let mut rows = Vec::with_capacity(batch.num_rows());
733    for row in 0..batch.num_rows() {
734        rows.push(RecallRow {
735            id: ids.value(row).to_string(),
736            content: contents.value(row).to_string(),
737            full_content: full_contents
738                .map(|fc| fc.value(row).to_string())
739                .unwrap_or_else(|| contents.value(row).to_string()),
740            layer: match layers.value(row) {
741                "episodic" => "episodic",
742                "semantic" => "semantic",
743                "procedural" => "procedural",
744                other => {
745                    return Err(hirn_storage::HirnDbError::InvalidArgument(format!(
746                        "unsupported recall layer `{other}` in graph activation"
747                    )));
748                }
749            },
750            namespace: namespaces.value(row).to_string(),
751            score: if scores.is_null(row) {
752                0.0
753            } else {
754                scores.value(row)
755            },
756            temporal_ms: temporal.value(row),
757            created_at_ms: created_at.value(row),
758            importance: if importances.is_null(row) {
759                0.0
760            } else {
761                importances.value(row)
762            },
763            access_count: if access_counts.is_null(row) {
764                0
765            } else {
766                access_counts.value(row)
767            },
768            surprise: if surprises.is_null(row) {
769                None
770            } else {
771                Some(surprises.value(row))
772            },
773            evidence_count: if evidence_counts.is_null(row) {
774                None
775            } else {
776                Some(evidence_counts.value(row))
777            },
778            invocation_count: if invocation_counts.is_null(row) {
779                None
780            } else {
781                Some(invocation_counts.value(row))
782            },
783        });
784    }
785
786    Ok(rows)
787}
788
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;
    use arrow_array::{Array, RecordBatch};
    use async_trait::async_trait;
    use datafusion::prelude::SessionContext;
    use datafusion_datasource::memory::MemorySourceConfig;
    use futures::StreamExt;
    use hirn_core::HirnResult;
    use hirn_core::metadata::Metadata;
    use hirn_core::types::{EdgeRelation, Layer};
    use hirn_graph::PropertyGraph;

    use crate::{GraphActivationOutput, GraphCausalChainRow, GraphReadRuntime};

    /// Database path placed in every test `HirnConfig`.
    const TEST_DB_PATH: &str = "/tmp/test";

    /// Builds a single-column `node_id: Utf8` batch holding `ids` in order.
    fn seed_batch(ids: &[&str]) -> RecordBatch {
        RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new(
                "node_id",
                DataType::Utf8,
                false,
            )])),
            vec![Arc::new(StringArray::from(ids.to_vec()))],
        )
        .unwrap()
    }

    /// Wraps `batch` in an in-memory source plan with one partition holding one batch.
    fn batch_input(batch: RecordBatch) -> Arc<dyn ExecutionPlan> {
        let schema = batch.schema();
        MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap()
    }

    /// Source plan that emits the `seed_batch` built from `ids`.
    fn seed_input(ids: &[&str]) -> Arc<dyn ExecutionPlan> {
        batch_input(seed_batch(ids))
    }

    /// Executes partition 0 of `exec` under `ctx` and returns the first batch.
    ///
    /// Panics if the stream is empty or its first item is an error.
    async fn first_batch(exec: &GraphActivationExec, ctx: &SessionContext) -> RecordBatch {
        let mut stream = exec.execute(0, ctx.task_ctx()).unwrap();
        stream
            .next()
            .await
            .expect("stream should yield at least one batch")
            .expect("first batch should not be an error")
    }

    /// Copies column `index` of `batch` out as `f32` values (panics unless it is `Float32`).
    fn f32_column(batch: &RecordBatch, index: usize) -> Vec<f32> {
        let column = batch
            .column(index)
            .as_any()
            .downcast_ref::<Float32Array>()
            .expect("column should be Float32");
        (0..column.len()).map(|row| column.value(row)).collect()
    }

    /// Copies column `index` of `batch` out as owned strings (panics unless it is `Utf8`).
    fn string_column(batch: &RecordBatch, index: usize) -> Vec<String> {
        let column = batch
            .column(index)
            .as_any()
            .downcast_ref::<StringArray>()
            .expect("column should be Utf8");
        (0..column.len())
            .map(|row| column.value(row).to_string())
            .collect()
    }

    /// Registers a `HirnSessionExt` on `ctx` carrying `graph_handle` and `runtime`.
    ///
    /// Shared by every test that needs a graph read runtime so the config and
    /// registration boilerplate lives in one place.
    fn register_session<R>(
        ctx: &SessionContext,
        graph_handle: Arc<dyn Any + Send + Sync>,
        runtime: R,
    ) where
        R: GraphReadRuntime + Send + Sync + 'static,
    {
        let config = hirn_core::HirnConfig::builder()
            .db_path(std::path::Path::new(TEST_DB_PATH))
            .build()
            .unwrap();
        HirnSessionExt::new(graph_handle, Arc::new(config), None)
            .with_graph_read_runtime(Arc::new(runtime))
            .register(ctx)
            .expect("register should succeed");
    }

    /// Build a small graph: n1 -> n2 -> n3 (RelatedTo edges, weights 0.8 and 0.7).
    fn build_test_graph() -> (Arc<RwLock<PropertyGraph>>, Vec<MemoryId>) {
        let mut g = PropertyGraph::new();
        let ids: Vec<MemoryId> = (0..3).map(|_| MemoryId::new()).collect();
        let now = hirn_core::timestamp::Timestamp::now();
        for &id in &ids {
            g.add_node(id, Layer::Episodic, 0.5, now);
        }
        for (from, to, weight) in [(0, 1, 0.8), (1, 2, 0.7)] {
            g.add_edge(
                ids[from],
                ids[to],
                EdgeRelation::RelatedTo,
                weight,
                Metadata::new(),
            )
            .unwrap();
        }
        (Arc::new(RwLock::new(g)), ids)
    }

    #[tokio::test]
    async fn activation_spreads_to_neighbors() {
        let (graph, ids) = build_test_graph();
        let id_strs: Vec<String> = ids.iter().map(|id| id.to_string()).collect();

        // Seed only the first node.
        let exec = GraphActivationExec::new(
            seed_input(&[&id_strs[0]]),
            10,
            ActivationMode::Spreading,
            3,
            0.001,
            0.1,
        )
        .unwrap();

        let ctx = SessionContext::new();
        register_graph_runtime(graph, &ctx);

        let mut stream = exec.execute(0, ctx.task_ctx()).unwrap();
        let mut all_ids = Vec::new();
        while let Some(result) = stream.next().await {
            let batch = result.unwrap();
            assert_eq!(batch.schema(), GraphActivationExec::output_schema());
            all_ids.extend(string_column(&batch, 0));
        }
        // Activation should spread from n1 to n2 (and possibly n3).
        assert!(
            all_ids.len() >= 2,
            "should activate seed + at least 1 neighbor, got {} ids: {:?}",
            all_ids.len(),
            all_ids
        );
        // Seed should be in results.
        assert!(
            all_ids.contains(&id_strs[0]),
            "seed node should be in activation results"
        );
    }

    #[tokio::test]
    async fn missing_graph_runtime_returns_error() {
        let id_str = MemoryId::new().to_string();
        let exec = GraphActivationExec::new(
            seed_input(&[&id_str]),
            10,
            ActivationMode::Spreading,
            3,
            0.001,
            0.1,
        )
        .unwrap();
        // No HirnSessionExt is registered on this context.
        let ctx = SessionContext::new();
        let mut stream = exec.execute(0, ctx.task_ctx()).unwrap();

        let err = stream.next().await.unwrap().unwrap_err().to_string();
        assert!(
            err.contains("requires HirnSessionExt graph runtime"),
            "expected missing graph runtime error, got: {err}"
        );
    }

    #[tokio::test]
    async fn all_invalid_seeds_returns_error() {
        let exec = GraphActivationExec::new(
            seed_input(&["not-a-valid-ulid"]),
            10,
            ActivationMode::Spreading,
            3,
            0.001,
            0.1,
        )
        .unwrap();
        let ctx = SessionContext::new();
        let mut stream = exec.execute(0, ctx.task_ctx()).unwrap();

        let result = stream.next().await.unwrap();
        assert!(result.is_err(), "all invalid seeds should produce an error");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("failed to parse"),
            "error should mention parse failure: {err}"
        );
    }

    #[test]
    fn output_schema_correct() {
        let schema = GraphActivationExec::output_schema();
        assert_eq!(schema.fields().len(), 3);
        assert_eq!(schema.field(0).name(), "node_id");
        assert_eq!(schema.field(1).name(), "activation_score");
        assert_eq!(schema.field(2).name(), "depth");
    }

    /// Graph read runtime that runs activation directly against an in-process graph.
    struct LocalGraphReadRuntime {
        graph: Arc<RwLock<PropertyGraph>>,
    }

    #[async_trait]
    impl GraphReadRuntime for LocalGraphReadRuntime {
        async fn activate_graph(
            &self,
            seeds: &[MemoryId],
            mode: ActivationMode,
            ppr_config: Option<&hirn_graph::PprConfig>,
            max_depth: u32,
            epsilon: f32,
            inhibition_mu: f32,
            _delegation_threshold: usize,
            allowed_namespaces: Option<&[Namespace]>,
        ) -> HirnResult<GraphActivationOutput> {
            let graph = self.graph.read();
            let (ids, scores, depths) = match mode {
                ActivationMode::Ppr => {
                    let default_ppr = hirn_graph::PprConfig::default();
                    let ppr_config = ppr_config.unwrap_or(&default_ppr);
                    let activations = hirn_graph::personalized_pagerank(
                        &graph,
                        seeds,
                        ppr_config,
                        allowed_namespaces,
                    )
                    .expect("test PPR config should be valid");
                    // Rank by descending score. Ties are broken on the id string so
                    // the output order does not depend on the order in which
                    // `personalized_pagerank` yields its entries.
                    let mut entries: Vec<(String, f32)> = activations
                        .into_iter()
                        .map(|(node_id, score)| (node_id.to_string(), score as f32))
                        .collect();
                    entries.sort_by(|a, b| b.1.total_cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
                    (
                        entries.iter().map(|(node_id, _)| node_id.clone()).collect(),
                        entries.iter().map(|(_, score)| *score).collect(),
                        // This runtime does not report PPR depths; every entry gets 0.
                        vec![0; entries.len()],
                    )
                }
                ActivationMode::Static | ActivationMode::Spreading => run_activation(
                    &graph,
                    seeds,
                    mode,
                    max_depth,
                    epsilon,
                    inhibition_mu,
                    allowed_namespaces,
                ),
            };
            Ok(GraphActivationOutput {
                ids,
                scores,
                depths,
            })
        }

        async fn causal_chain(
            &self,
            _start_ids: &[MemoryId],
            _max_depth: u32,
            _confidence_threshold: f32,
            _delegation_threshold: usize,
            _relation: hirn_core::types::EdgeRelation,
            _allowed_namespaces: Option<&[Namespace]>,
        ) -> HirnResult<Vec<GraphCausalChainRow>> {
            Ok(Vec::new())
        }

        async fn traverse_graph(
            &self,
            _start_ids: &[MemoryId],
            _max_depth: u32,
            _delegation_threshold: usize,
            _relation_filter: Option<&[hirn_core::types::EdgeRelation]>,
            _allowed_namespaces: Option<&[Namespace]>,
        ) -> HirnResult<Vec<crate::GraphTraverseRow>> {
            Ok(Vec::new())
        }
    }

    /// Registers a session on `ctx` whose graph read runtime activates over `graph`.
    fn register_graph_runtime(graph: Arc<RwLock<PropertyGraph>>, ctx: &SessionContext) {
        register_session(
            ctx,
            graph.clone() as Arc<dyn Any + Send + Sync>,
            LocalGraphReadRuntime { graph },
        );
    }

    /// Graph read runtime that returns a fixed activation output regardless of input.
    #[derive(Debug)]
    struct MockGraphReadRuntime {
        output: GraphActivationOutput,
    }

    #[async_trait]
    impl crate::GraphReadRuntime for MockGraphReadRuntime {
        async fn activate_graph(
            &self,
            _seeds: &[MemoryId],
            _mode: ActivationMode,
            _ppr_config: Option<&hirn_graph::PprConfig>,
            _max_depth: u32,
            _epsilon: f32,
            _inhibition_mu: f32,
            _delegation_threshold: usize,
            _allowed_namespaces: Option<&[Namespace]>,
        ) -> HirnResult<GraphActivationOutput> {
            Ok(self.output.clone())
        }

        async fn causal_chain(
            &self,
            _start_ids: &[MemoryId],
            _max_depth: u32,
            _confidence_threshold: f32,
            _delegation_threshold: usize,
            _relation: hirn_core::types::EdgeRelation,
            _allowed_namespaces: Option<&[Namespace]>,
        ) -> HirnResult<Vec<GraphCausalChainRow>> {
            Ok(Vec::new())
        }

        async fn traverse_graph(
            &self,
            _start_ids: &[MemoryId],
            _max_depth: u32,
            _delegation_threshold: usize,
            _relation_filter: Option<&[hirn_core::types::EdgeRelation]>,
            _allowed_namespaces: Option<&[Namespace]>,
        ) -> HirnResult<Vec<crate::GraphTraverseRow>> {
            Ok(Vec::new())
        }
    }

    /// Graph read runtime that records the seeds it receives and echoes each back
    /// with score 1.0 and depth 0.
    #[derive(Debug)]
    struct RecordingGraphReadRuntime {
        seen_seeds: Arc<Mutex<Vec<MemoryId>>>,
    }

    #[async_trait]
    impl crate::GraphReadRuntime for RecordingGraphReadRuntime {
        async fn activate_graph(
            &self,
            seeds: &[MemoryId],
            _mode: ActivationMode,
            _ppr_config: Option<&hirn_graph::PprConfig>,
            _max_depth: u32,
            _epsilon: f32,
            _inhibition_mu: f32,
            _delegation_threshold: usize,
            _allowed_namespaces: Option<&[Namespace]>,
        ) -> HirnResult<GraphActivationOutput> {
            *self.seen_seeds.lock().expect("lock should succeed") = seeds.to_vec();
            Ok(GraphActivationOutput {
                ids: seeds.iter().map(ToString::to_string).collect(),
                scores: vec![1.0; seeds.len()],
                depths: vec![0; seeds.len()],
            })
        }

        async fn causal_chain(
            &self,
            _start_ids: &[MemoryId],
            _max_depth: u32,
            _confidence_threshold: f32,
            _delegation_threshold: usize,
            _relation: hirn_core::types::EdgeRelation,
            _allowed_namespaces: Option<&[Namespace]>,
        ) -> HirnResult<Vec<GraphCausalChainRow>> {
            Ok(Vec::new())
        }

        async fn traverse_graph(
            &self,
            _start_ids: &[MemoryId],
            _max_depth: u32,
            _delegation_threshold: usize,
            _relation_filter: Option<&[hirn_core::types::EdgeRelation]>,
            _allowed_namespaces: Option<&[Namespace]>,
        ) -> HirnResult<Vec<crate::GraphTraverseRow>> {
            Ok(Vec::new())
        }
    }

    #[tokio::test]
    async fn prefers_registered_graph_read_runtime() {
        let id_str = MemoryId::new().to_string();
        let exec = GraphActivationExec::new(
            seed_input(&[&id_str]),
            10,
            ActivationMode::Spreading,
            6,
            0.001,
            0.1,
        )
        .unwrap();
        let ctx = SessionContext::new();
        register_session(
            &ctx,
            Arc::new(()),
            MockGraphReadRuntime {
                output: GraphActivationOutput {
                    ids: vec![id_str.clone()],
                    scores: vec![0.42],
                    depths: vec![6],
                },
            },
        );

        let result = first_batch(&exec, &ctx).await;
        let scores = f32_column(&result, 1);
        let depths = result
            .column(2)
            .as_any()
            .downcast_ref::<UInt32Array>()
            .unwrap();

        // The operator must surface exactly what the registered runtime returned.
        assert!((scores[0] - 0.42).abs() < f32::EPSILON);
        assert_eq!(depths.value(0), 6);
    }

    /// Runs spreading activation and PPR from the same seed over the same graph and
    /// asserts that the two score vectors differ.
    #[tokio::test]
    async fn ppr_mode_returns_different_ranking_than_spreading() {
        let (graph, ids) = build_test_graph();
        let id_strs: Vec<String> = ids.iter().map(|id| id.to_string()).collect();

        // Run spreading mode.
        let exec_spread = GraphActivationExec::new(
            seed_input(&[&id_strs[0]]),
            10,
            ActivationMode::Spreading,
            3,
            0.001,
            0.0,
        )
        .unwrap();
        let ctx_s = SessionContext::new();
        register_graph_runtime(graph.clone(), &ctx_s);
        let spread_scores = f32_column(&first_batch(&exec_spread, &ctx_s).await, 1);

        // Run PPR mode.
        let exec_ppr = GraphActivationExec::new(
            seed_input(&[&id_strs[0]]),
            10,
            ActivationMode::Ppr,
            3,
            0.001,
            0.0,
        )
        .unwrap();
        let ctx_p = SessionContext::new();
        register_graph_runtime(graph, &ctx_p);
        let ppr_scores = f32_column(&first_batch(&exec_ppr, &ctx_p).await, 1);

        // Both should return results, but scores should differ.
        assert!(
            !spread_scores.is_empty() && !ppr_scores.is_empty(),
            "both modes should return results"
        );
        // Different algorithms: the score vectors should not be identical.
        assert_ne!(
            spread_scores, ppr_scores,
            "PPR and spreading should produce different score vectors"
        );
    }

    /// Topology: ids[0] -> ids[2] <- ids[1], ids[2] -> ids[3], ids[2] -> ids[4].
    ///
    /// Only ids[0] is seeded. The test asserts only that raising the inhibition
    /// strength from 0.0 to 0.5 does not increase the summed activation score.
    #[tokio::test]
    async fn lateral_inhibition_suppresses_competing_cluster() {
        let mut g = PropertyGraph::new();
        let ids: Vec<MemoryId> = (0..5).map(|_| MemoryId::new()).collect();
        let now = hirn_core::timestamp::Timestamp::now();
        for &id in &ids {
            g.add_node(id, Layer::Episodic, 0.5, now);
        }
        // Two inbound edges into the center (ids[2]), then two outbound from it.
        for (from, to, weight) in [(0, 2, 0.9), (1, 2, 0.9), (2, 3, 0.8), (2, 4, 0.8)] {
            g.add_edge(
                ids[from],
                ids[to],
                EdgeRelation::RelatedTo,
                weight,
                Metadata::new(),
            )
            .unwrap();
        }

        let graph = Arc::new(RwLock::new(g));
        let id_strs: Vec<String> = ids.iter().map(|id| id.to_string()).collect();

        // Run WITHOUT inhibition (mu=0.0).
        let exec = GraphActivationExec::new(
            seed_input(&[&id_strs[0]]),
            10,
            ActivationMode::Spreading,
            3,
            0.001,
            0.0,
        )
        .unwrap();
        let ctx_no_inh = SessionContext::new();
        register_graph_runtime(graph.clone(), &ctx_no_inh);
        let total_no: f32 = f32_column(&first_batch(&exec, &ctx_no_inh).await, 1)
            .iter()
            .sum();

        // Run WITH strong inhibition (mu=0.5).
        let exec = GraphActivationExec::new(
            seed_input(&[&id_strs[0]]),
            10,
            ActivationMode::Spreading,
            3,
            0.001,
            0.5,
        )
        .unwrap();
        let ctx_inh = SessionContext::new();
        register_graph_runtime(graph, &ctx_inh);
        let total_inh: f32 = f32_column(&first_batch(&exec, &ctx_inh).await, 1)
            .iter()
            .sum();

        assert!(
            total_inh <= total_no,
            "inhibition should reduce total activation: {total_inh} should be <= {total_no}"
        );
    }

    #[tokio::test]
    async fn mixed_valid_and_invalid_seeds_processes_valid_ones() {
        let (graph, ids) = build_test_graph();
        let valid_str = ids[0].to_string();
        // Mix one valid ULID with one garbage string.
        let exec = GraphActivationExec::new(
            seed_input(&[&valid_str, "not-a-valid-ulid"]),
            10,
            ActivationMode::Spreading,
            3,
            0.001,
            0.1,
        )
        .unwrap();
        let ctx = SessionContext::new();
        register_graph_runtime(graph, &ctx);

        let result = first_batch(&exec, &ctx).await;
        // The valid seed should be processed; the invalid seed is skipped with a warning.
        assert!(
            result.num_rows() >= 1,
            "valid seed should produce activation results"
        );
        let result_ids = string_column(&result, 0);
        assert!(
            result_ids.contains(&valid_str),
            "valid seed should appear in results"
        );
    }

    #[tokio::test]
    async fn respects_seed_limit_before_graph_activation() {
        let ids: Vec<MemoryId> = (0..3).map(|_| MemoryId::new()).collect();
        let id_strs: Vec<String> = ids.iter().map(|id| id.to_string()).collect();

        // Three seeds in, limit of 2: only the first two may reach the runtime.
        let exec = GraphActivationExec::new(
            seed_input(&[&id_strs[0], &id_strs[1], &id_strs[2]]),
            2,
            ActivationMode::Spreading,
            3,
            0.001,
            0.1,
        )
        .unwrap();
        let seen_seeds = Arc::new(Mutex::new(Vec::new()));
        let ctx = SessionContext::new();
        register_session(
            &ctx,
            Arc::new(()),
            RecordingGraphReadRuntime {
                seen_seeds: Arc::clone(&seen_seeds),
            },
        );

        let _ = first_batch(&exec, &ctx).await;

        let recorded = seen_seeds.lock().expect("lock should succeed").clone();
        assert_eq!(recorded, ids[..2].to_vec());
    }

    #[tokio::test]
    async fn preserve_recall_rows_keeps_nonseed_candidates() {
        let ids: Vec<MemoryId> = (0..2).map(|_| MemoryId::new()).collect();
        let id_strs: Vec<String> = ids.iter().map(|id| id.to_string()).collect();

        // A standardized recall-row batch with two candidates.
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                Field::new("id", DataType::Utf8, false),
                Field::new("content", DataType::Utf8, false),
                Field::new("layer", DataType::Utf8, false),
                Field::new("namespace", DataType::Utf8, false),
                Field::new("score", DataType::Float32, false),
                Field::new("temporal_ms", DataType::Int64, false),
                Field::new("created_at_ms", DataType::Int64, false),
                Field::new("importance", DataType::Float32, false),
                Field::new("access_count", DataType::UInt32, false),
                Field::new("surprise", DataType::Float32, true),
                Field::new("evidence_count", DataType::UInt32, true),
                Field::new("invocation_count", DataType::UInt64, true),
            ])),
            vec![
                Arc::new(StringArray::from(vec![
                    id_strs[0].as_str(),
                    id_strs[1].as_str(),
                ])),
                Arc::new(StringArray::from(vec!["seed", "nonseed candidate"])),
                Arc::new(StringArray::from(vec!["episodic", "episodic"])),
                Arc::new(StringArray::from(vec!["default", "default"])),
                Arc::new(Float32Array::from(vec![0.9, 0.8])),
                Arc::new(Int64Array::from(vec![1_i64, 2_i64])),
                Arc::new(Int64Array::from(vec![1_i64, 2_i64])),
                Arc::new(Float32Array::from(vec![0.7, 0.6])),
                Arc::new(UInt32Array::from(vec![1_u32, 1_u32])),
                Arc::new(Float32Array::from(vec![Some(0.0_f32), Some(0.0_f32)])),
                Arc::new(UInt32Array::from(vec![Some(0_u32), Some(0_u32)])),
                Arc::new(UInt64Array::from(vec![Some(0_u64), Some(0_u64)])),
            ],
        )
        .unwrap();

        // Seed limit of 1: only the first row is used as a seed.
        let exec = GraphActivationExec::new(
            batch_input(batch),
            1,
            ActivationMode::Spreading,
            3,
            0.001,
            0.1,
        )
        .unwrap();

        let seen_seeds = Arc::new(Mutex::new(Vec::new()));
        let ctx = SessionContext::new();
        register_session(
            &ctx,
            Arc::new(()),
            RecordingGraphReadRuntime {
                seen_seeds: Arc::clone(&seen_seeds),
            },
        );

        let result = first_batch(&exec, &ctx).await;
        let id_column = result
            .column_by_name("id")
            .and_then(|column| column.as_any().downcast_ref::<StringArray>())
            .unwrap();
        let activation_scores = result
            .column_by_name("activation_score")
            .and_then(|column| column.as_any().downcast_ref::<Float32Array>())
            .unwrap();

        let output_ids = (0..id_column.len())
            .map(|index| id_column.value(index).to_string())
            .collect::<Vec<_>>();
        // Both recall rows survive; only the seed received activation from the runtime.
        assert_eq!(output_ids, id_strs);
        assert!((activation_scores.value(0) - 1.0).abs() < f32::EPSILON);
        assert_eq!(activation_scores.value(1), 0.0);
    }

    #[test]
    fn invalid_config_rejected_at_construction() {
        // max_depth = 0 must be rejected when the operator is built, not at execution.
        let err = GraphActivationExec::new(
            seed_input(&["not-used"]),
            10,
            ActivationMode::Spreading,
            0,
            0.001,
            0.1,
        )
        .unwrap_err();
        assert!(err.to_string().contains("invalid graph activation config"));
    }
}