1use std::any::Any;
10use std::collections::HashMap;
11use std::fmt;
12use std::sync::Arc;
13
14use arrow_array::{
15 Array, ArrayRef, Float32Array, Int64Array, RecordBatch, StringArray, UInt32Array, UInt64Array,
16};
17use arrow_schema::{DataType, Field, Schema, SchemaRef};
18use datafusion_common::Result;
19use datafusion_execution::{SendableRecordBatchStream, TaskContext};
20use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
21use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
22use datafusion_physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
23
24use hirn_core::id::MemoryId;
25use hirn_core::types::Namespace;
26use hirn_graph::ActivationConfig;
27#[cfg(test)]
28use hirn_graph::PropertyGraph;
29#[cfg(test)]
30use parking_lot::RwLock;
31
32use crate::extensions::HirnSessionExt;
33use crate::operators::lance_hybrid_search::{RecallRow, fetch_recall_rows_by_ids};
34
/// Graph activation algorithm requested from the registered graph read runtime.
///
/// `GraphActivationExec` forwards the chosen mode unchanged to
/// `GraphReadRuntime::activate_graph`; the runtime decides how each mode is evaluated.
/// The test-only reference implementation in this module maps the variants to
/// `hirn_graph::static_activation`, `hirn_graph::spread_activation` and
/// `hirn_graph::personalized_pagerank` respectively.
#[derive(Clone, Copy, Debug)]
pub enum ActivationMode {
    /// Static (non-iterative) activation.
    Static,
    /// Spreading activation bounded by `max_depth` / `epsilon` / inhibition settings.
    Spreading,
    /// Personalized PageRank seeded with the input ids.
    Ppr,
}
45
46#[derive(Debug)]
54pub struct GraphActivationExec {
55 input: Arc<dyn ExecutionPlan>,
56 schema: SchemaRef,
57 properties: PlanProperties,
58 seed_limit: usize,
59 mode: ActivationMode,
60 max_depth: u32,
61 epsilon: f32,
62 inhibition_mu: f32,
63 preserve_recall_rows: bool,
64}
65
66impl GraphActivationExec {
67 pub fn new(
68 input: Arc<dyn ExecutionPlan>,
69 seed_limit: usize,
70 mode: ActivationMode,
71 max_depth: u32,
72 epsilon: f32,
73 inhibition_mu: f32,
74 ) -> Result<Self> {
75 let seed_limit = seed_limit.max(1);
76 let config = ActivationConfig {
77 max_depth: max_depth as usize,
78 epsilon: f64::from(epsilon),
79 inhibition_strength: f64::from(inhibition_mu),
80 ..Default::default()
81 };
82 config.validate().map_err(|error| {
83 datafusion_common::DataFusionError::Execution(format!(
84 "invalid graph activation config: {error}"
85 ))
86 })?;
87
88 let preserve_recall_rows = supports_recall_row_passthrough(input.schema().as_ref());
89 let schema = if preserve_recall_rows {
90 recall_activation_schema(input.schema())
91 } else {
92 Self::output_schema()
93 };
94 let properties = PlanProperties::new(
95 datafusion_physical_expr::EquivalenceProperties::new(schema.clone()),
96 datafusion_physical_plan::Partitioning::UnknownPartitioning(1),
97 EmissionType::Final,
100 Boundedness::Bounded,
101 );
102 Ok(Self {
103 input,
104 schema,
105 properties,
106 seed_limit,
107 mode,
108 max_depth,
109 epsilon,
110 inhibition_mu,
111 preserve_recall_rows,
112 })
113 }
114
115 pub fn output_schema() -> SchemaRef {
117 Arc::new(Schema::new(vec![
118 Field::new("node_id", DataType::Utf8, false),
119 Field::new("activation_score", DataType::Float32, false),
120 Field::new("depth", DataType::UInt32, false),
121 ]))
122 }
123
124 pub fn mode(&self) -> ActivationMode {
125 self.mode
126 }
127
128 pub fn seed_limit(&self) -> usize {
129 self.seed_limit
130 }
131
132 pub fn max_depth(&self) -> u32 {
133 self.max_depth
134 }
135
136 pub fn epsilon(&self) -> f32 {
137 self.epsilon
138 }
139
140 pub fn inhibition_mu(&self) -> f32 {
141 self.inhibition_mu
142 }
143
144 pub fn preserves_recall_rows(&self) -> bool {
145 self.preserve_recall_rows
146 }
147}
148
149impl DisplayAs for GraphActivationExec {
150 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151 write!(
152 f,
153 "GraphActivationExec: seed_limit={}, mode={:?}, depth={}, ε={}, µ={}",
154 self.seed_limit, self.mode, self.max_depth, self.epsilon, self.inhibition_mu
155 )
156 }
157}
158
159impl ExecutionPlan for GraphActivationExec {
160 fn name(&self) -> &str {
161 "GraphActivationExec"
162 }
163
164 fn as_any(&self) -> &dyn Any {
165 self
166 }
167
168 fn schema(&self) -> SchemaRef {
169 self.schema.clone()
170 }
171
172 fn properties(&self) -> &PlanProperties {
173 &self.properties
174 }
175
176 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
177 vec![&self.input]
178 }
179
180 fn with_new_children(
181 self: Arc<Self>,
182 children: Vec<Arc<dyn ExecutionPlan>>,
183 ) -> Result<Arc<dyn ExecutionPlan>> {
184 let [child]: [Arc<dyn ExecutionPlan>; 1] = children.try_into().map_err(|v: Vec<_>| {
185 datafusion_common::DataFusionError::Plan(format!(
186 "GraphActivationExec requires exactly 1 child, got {}",
187 v.len()
188 ))
189 })?;
190 Ok(Arc::new(Self::new(
191 child,
192 self.seed_limit,
193 self.mode,
194 self.max_depth,
195 self.epsilon,
196 self.inhibition_mu,
197 )?))
198 }
199
200 fn execute(
201 &self,
202 partition: usize,
203 context: Arc<TaskContext>,
204 ) -> Result<SendableRecordBatchStream> {
205 let input = self.input.execute(partition, context.clone())?;
206 let schema = self.schema.clone();
207 let stream_schema = schema.clone();
208 let max_depth = self.max_depth;
209 let epsilon = self.epsilon;
210 let inhibition_mu = self.inhibition_mu;
211 let mode = self.mode;
212 let preserve_recall_rows = self.preserve_recall_rows;
213 let seed_limit = self.seed_limit;
214
215 let session_ext = context
216 .session_config()
217 .options()
218 .extensions
219 .get::<HirnSessionExt>()
220 .cloned();
221 let graph_read_runtime = session_ext
222 .as_ref()
223 .and_then(|ext| ext.graph_read_runtime());
224 let storage = session_ext.as_ref().and_then(|ext| ext.storage_arc());
225 let delegation_threshold = session_ext
226 .as_ref()
227 .map(|ext| ext.config.graph_depth_delegation_threshold)
228 .unwrap_or(usize::MAX);
229 let allowed_namespaces = session_ext.as_ref().and_then(|ext| {
230 ext.allowed_namespaces().map(|namespaces| {
231 namespaces
232 .iter()
233 .filter_map(|namespace| Namespace::new(namespace).ok())
234 .collect::<Vec<_>>()
235 })
236 });
237
238 let fut = async move {
239 use futures::StreamExt;
240
241 let mut seed_strings = Vec::new();
242 let mut passthrough_rows = if preserve_recall_rows {
243 Some(RecallPassthroughRows::default())
244 } else {
245 None
246 };
247 let mut stream = input;
248 while let Some(batch) = stream.next().await {
249 let batch = batch?;
250 if let Some(rows) = passthrough_rows.as_mut() {
251 accumulate_recall_rows(rows, &batch).map_err(|error| {
252 datafusion_common::DataFusionError::Execution(error.to_string())
253 })?;
254 }
255
256 if seed_strings.len() < seed_limit {
257 let col = batch
258 .column_by_name("node_id")
259 .or_else(|| batch.column_by_name("id"));
260 if let Some(col) = col {
261 if let Some(arr) = col.as_any().downcast_ref::<StringArray>() {
262 for i in 0..arr.len() {
263 if seed_strings.len() >= seed_limit {
264 break;
265 }
266 if !arr.is_null(i) {
267 seed_strings.push(arr.value(i).to_string());
268 }
269 }
270 }
271 }
272 }
273
274 if !preserve_recall_rows && seed_strings.len() >= seed_limit {
275 break;
276 }
277 }
278
279 if seed_strings.is_empty() {
280 let empty = RecordBatch::new_empty(schema);
281 return Ok(empty);
282 }
283
284 let mut seeds = Vec::with_capacity(seed_strings.len());
286 let mut parse_failures = 0_usize;
287 let mut first_errors: Vec<String> = Vec::new();
288 for s in &seed_strings {
289 match MemoryId::parse(s) {
290 Ok(id) => seeds.push(id),
291 Err(e) => {
292 parse_failures += 1;
293 if first_errors.len() < 3 {
294 first_errors.push(format!("{s}: {e}"));
295 }
296 tracing::warn!(
297 seed = %s,
298 "GraphActivationExec: failed to parse seed MemoryId, skipping"
299 );
300 }
301 }
302 }
303
304 if seeds.is_empty() {
305 return Err(datafusion_common::DataFusionError::Execution(format!(
307 "GraphActivationExec: all {} seed IDs failed to parse (first errors: {})",
308 parse_failures,
309 first_errors.join("; ")
310 )));
311 }
312
313 let Some(runtime) = graph_read_runtime else {
315 return Err(datafusion_common::DataFusionError::Execution(
316 "GraphActivationExec requires HirnSessionExt graph runtime".to_string(),
317 ));
318 };
319 let (ids, scores, depths) = {
320 let output = runtime
321 .activate_graph(
322 &seeds,
323 mode,
324 None,
325 max_depth,
326 epsilon,
327 inhibition_mu,
328 delegation_threshold,
329 allowed_namespaces.as_deref(),
330 )
331 .await
332 .map_err(|error| {
333 datafusion_common::DataFusionError::Execution(error.to_string())
334 })?;
335 (output.ids, output.scores, output.depths)
336 };
337
338 if ids.is_empty() {
339 return Ok(RecordBatch::new_empty(schema));
340 }
341
342 if preserve_recall_rows {
343 return build_recall_activation_output_batch(
344 schema,
345 passthrough_rows.unwrap_or_default(),
346 storage.as_deref(),
347 &ids,
348 &scores,
349 &depths,
350 )
351 .await
352 .map_err(|error| datafusion_common::DataFusionError::Execution(error.to_string()));
353 }
354
355 let id_refs: Vec<&str> = ids.iter().map(String::as_str).collect();
356 RecordBatch::try_new(
357 schema,
358 vec![
359 Arc::new(StringArray::from(id_refs)),
360 Arc::new(Float32Array::from(scores)),
361 Arc::new(UInt32Array::from(depths)),
362 ],
363 )
364 .map_err(Into::into)
365 };
366
367 let stream = futures::stream::once(fut);
368 Ok(Box::pin(RecordBatchStreamAdapter::new(
369 stream_schema,
370 stream,
371 )))
372 }
373}
374
/// Test-only reference implementation of graph activation over an in-memory
/// [`PropertyGraph`].
///
/// Returns parallel `(ids, scores, depths)` vectors ordered by descending score
/// (incomparable scores keep their relative order). Depth per mode:
/// `Static` reports 0 for seed nodes and 1 otherwise; `Spreading` reports the recorded
/// trace path length minus one (0 without a trace); `Ppr` always reports 0.
///
/// # Panics
/// Panics if `hirn_graph` rejects the tuned activation config (`Spreading`) or the default
/// PPR config (`Ppr`).
#[cfg(test)]
fn run_activation(
    graph: &PropertyGraph,
    seeds: &[MemoryId],
    mode: ActivationMode,
    max_depth: u32,
    epsilon: f32,
    inhibition_mu: f32,
    allowed_namespaces: Option<&[Namespace]>,
) -> (Vec<String>, Vec<f32>, Vec<u32>) {
    /// Stable descending sort on the score in `.1`; incomparable scores compare equal.
    fn rank_by_score_desc<K, S: PartialOrd>(entries: &mut [(K, S)]) {
        entries.sort_by(|lhs, rhs| {
            rhs.1
                .partial_cmp(&lhs.1)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
    }

    let base = ActivationConfig {
        max_depth: max_depth as usize,
        epsilon: f64::from(epsilon),
        inhibition_strength: f64::from(inhibition_mu),
        ..Default::default()
    };
    let tuned = base.tuned_for_graph(graph.node_count(), graph.edge_count());

    let mut ids = Vec::new();
    let mut scores = Vec::new();
    let mut depths = Vec::new();
    let mut emit = |id: String, score: f32, depth: u32| {
        ids.push(id);
        scores.push(score);
        depths.push(depth);
    };

    match mode {
        ActivationMode::Static => {
            let mut ranked: Vec<_> =
                hirn_graph::static_activation(graph, seeds, allowed_namespaces)
                    .into_iter()
                    .collect();
            rank_by_score_desc(&mut ranked);
            for (node, score) in ranked {
                let depth = if seeds.contains(&node) { 0 } else { 1 };
                emit(node.to_string(), score as f32, depth);
            }
        }
        ActivationMode::Spreading => {
            let outcome =
                hirn_graph::spread_activation(graph, seeds, &tuned, None, allowed_namespaces)
                    .expect("test activation config should be valid");
            let mut ranked: Vec<_> = outcome.activations.into_iter().collect();
            rank_by_score_desc(&mut ranked);
            for (node, score) in ranked {
                let depth = outcome.traces.get(&node).map_or(0, |trace| {
                    trace.path.len().saturating_sub(1) as u32
                });
                emit(node.to_string(), score as f32, depth);
            }
        }
        ActivationMode::Ppr => {
            let ppr_config = hirn_graph::PprConfig::default();
            let mut ranked: Vec<_> =
                hirn_graph::personalized_pagerank(graph, seeds, &ppr_config, allowed_namespaces)
                    .expect("default PPR config should be valid")
                    .into_iter()
                    .collect();
            rank_by_score_desc(&mut ranked);
            for (node, score) in ranked {
                emit(node.to_string(), score as f32, 0);
            }
        }
    }

    (ids, scores, depths)
}
450
451fn supports_recall_row_passthrough(schema: &Schema) -> bool {
452 [
453 "id",
454 "content",
455 "layer",
456 "namespace",
457 "score",
458 "temporal_ms",
459 "created_at_ms",
460 "importance",
461 "access_count",
462 "surprise",
463 "evidence_count",
464 "invocation_count",
465 ]
466 .iter()
467 .all(|field| schema.field_with_name(field).is_ok())
468}
469
470fn recall_activation_schema(_input_schema: SchemaRef) -> SchemaRef {
478 Arc::new(Schema::new(vec![
479 Field::new("id", DataType::Utf8, false),
480 Field::new("content", DataType::Utf8, false),
481 Field::new("full_content", DataType::Utf8, false),
482 Field::new("layer", DataType::Utf8, false),
483 Field::new("namespace", DataType::Utf8, false),
484 Field::new("score", DataType::Float32, false),
485 Field::new("temporal_ms", DataType::Int64, false),
486 Field::new("created_at_ms", DataType::Int64, false),
487 Field::new("importance", DataType::Float32, false),
488 Field::new("access_count", DataType::UInt32, false),
489 Field::new("surprise", DataType::Float32, true),
490 Field::new("evidence_count", DataType::UInt32, true),
491 Field::new("invocation_count", DataType::UInt64, true),
492 Field::new("activation_score", DataType::Float32, false),
493 Field::new("depth", DataType::UInt32, false),
494 ]))
495}
496
497async fn build_recall_activation_output_batch(
498 schema: SchemaRef,
499 mut passthrough_rows: RecallPassthroughRows,
500 storage: Option<&dyn hirn_storage::PhysicalStore>,
501 activated_ids: &[String],
502 activation_scores: &[f32],
503 depths: &[u32],
504) -> Result<RecordBatch, hirn_storage::HirnDbError> {
505 let mut ordered_ids = std::mem::take(&mut passthrough_rows.ordered_ids);
506 let mut base_rows = std::mem::take(&mut passthrough_rows.base_rows);
507
508 let missing_ids = activated_ids
509 .iter()
510 .filter(|id| !base_rows.contains_key(*id))
511 .filter_map(|id| MemoryId::parse(id).ok())
512 .collect::<Vec<_>>();
513
514 if !missing_ids.is_empty() {
515 let Some(storage) = storage else {
516 return Err(hirn_storage::HirnDbError::InvalidArgument(
517 "graph activation recall expansion requires storage access".to_string(),
518 ));
519 };
520 for row in fetch_recall_rows_by_ids(storage, &missing_ids).await? {
521 base_rows.entry(row.id.clone()).or_insert(row);
522 }
523 }
524
525 let activation_by_id = activated_ids
526 .iter()
527 .zip(activation_scores.iter())
528 .zip(depths.iter())
529 .map(|((activated_id, activation_score), depth)| {
530 (activated_id.as_str(), (*activation_score, *depth))
531 })
532 .collect::<HashMap<_, _>>();
533
534 for activated_id in activated_ids {
535 if !ordered_ids.iter().any(|id| id == activated_id) {
536 ordered_ids.push(activated_id.clone());
537 }
538 }
539
540 let mut rows = Vec::with_capacity(ordered_ids.len());
541 let mut activation_values = Vec::with_capacity(ordered_ids.len());
542 let mut depth_values = Vec::with_capacity(ordered_ids.len());
543 for ordered_id in ordered_ids {
544 if let Some(row) = base_rows.get(&ordered_id).cloned() {
545 let (activation_score, depth) = activation_by_id
546 .get(ordered_id.as_str())
547 .copied()
548 .unwrap_or((0.0, 0));
549 rows.push(row);
550 activation_values.push(activation_score);
551 depth_values.push(depth);
552 }
553 }
554
555 if rows.is_empty() {
556 return Ok(RecordBatch::new_empty(schema));
557 }
558
559 let ids = rows.iter().map(|row| row.id.as_str()).collect::<Vec<_>>();
560 let contents = rows
561 .iter()
562 .map(|row| row.content.as_str())
563 .collect::<Vec<_>>();
564 let full_contents = rows
565 .iter()
566 .map(|row| row.full_content.as_str())
567 .collect::<Vec<_>>();
568 let layers = rows.iter().map(|row| row.layer).collect::<Vec<_>>();
569 let namespaces = rows
570 .iter()
571 .map(|row| row.namespace.as_str())
572 .collect::<Vec<_>>();
573 let scores = rows.iter().map(|row| row.score).collect::<Vec<_>>();
574 let temporal = rows.iter().map(|row| row.temporal_ms).collect::<Vec<_>>();
575 let created_at = rows.iter().map(|row| row.created_at_ms).collect::<Vec<_>>();
576 let importances = rows.iter().map(|row| row.importance).collect::<Vec<_>>();
577 let access_counts = rows.iter().map(|row| row.access_count).collect::<Vec<_>>();
578 let surprises = rows.iter().map(|row| row.surprise).collect::<Vec<_>>();
579 let evidence_counts = rows
580 .iter()
581 .map(|row| row.evidence_count)
582 .collect::<Vec<_>>();
583 let invocation_counts = rows
584 .iter()
585 .map(|row| row.invocation_count)
586 .collect::<Vec<_>>();
587
588 RecordBatch::try_new(
589 schema,
590 vec![
591 Arc::new(StringArray::from(ids)) as ArrayRef,
592 Arc::new(StringArray::from(contents)) as ArrayRef,
593 Arc::new(StringArray::from(full_contents)) as ArrayRef,
594 Arc::new(StringArray::from(layers)) as ArrayRef,
595 Arc::new(StringArray::from(namespaces)) as ArrayRef,
596 Arc::new(Float32Array::from(scores)) as ArrayRef,
597 Arc::new(Int64Array::from(temporal)) as ArrayRef,
598 Arc::new(Int64Array::from(created_at)) as ArrayRef,
599 Arc::new(Float32Array::from(importances)) as ArrayRef,
600 Arc::new(UInt32Array::from(access_counts)) as ArrayRef,
601 Arc::new(Float32Array::from(surprises)) as ArrayRef,
602 Arc::new(UInt32Array::from(evidence_counts)) as ArrayRef,
603 Arc::new(UInt64Array::from(invocation_counts)) as ArrayRef,
604 Arc::new(Float32Array::from(activation_values)) as ArrayRef,
605 Arc::new(UInt32Array::from(depth_values)) as ArrayRef,
606 ],
607 )
608 .map_err(hirn_storage::HirnDbError::ArrowError)
609}
610
611#[derive(Debug, Default)]
612struct RecallPassthroughRows {
613 ordered_ids: Vec<String>,
614 base_rows: HashMap<String, RecallRow>,
615}
616
617fn accumulate_recall_rows(
618 rows: &mut RecallPassthroughRows,
619 batch: &RecordBatch,
620) -> Result<(), hirn_storage::HirnDbError> {
621 for row in recall_rows_from_batch(batch)? {
622 let row_id = row.id.clone();
623 if !rows.base_rows.contains_key(&row_id) {
624 rows.ordered_ids.push(row_id.clone());
625 }
626 rows.base_rows.entry(row_id).or_insert(row);
627 }
628
629 Ok(())
630}
631
632fn recall_rows_from_batch(
633 batch: &RecordBatch,
634) -> Result<Vec<RecallRow>, hirn_storage::HirnDbError> {
635 let ids = batch
636 .column_by_name("id")
637 .and_then(|column| column.as_any().downcast_ref::<StringArray>())
638 .ok_or_else(|| {
639 hirn_storage::HirnDbError::InvalidArgument(
640 "graph activation recall passthrough batch is missing `id`".to_string(),
641 )
642 })?;
643 let contents = batch
644 .column_by_name("content")
645 .and_then(|column| column.as_any().downcast_ref::<StringArray>())
646 .ok_or_else(|| {
647 hirn_storage::HirnDbError::InvalidArgument(
648 "graph activation recall passthrough batch is missing `content`".to_string(),
649 )
650 })?;
651 let full_contents = batch
652 .column_by_name("full_content")
653 .and_then(|column| column.as_any().downcast_ref::<StringArray>());
654 let layers = batch
655 .column_by_name("layer")
656 .and_then(|column| column.as_any().downcast_ref::<StringArray>())
657 .ok_or_else(|| {
658 hirn_storage::HirnDbError::InvalidArgument(
659 "graph activation recall passthrough batch is missing `layer`".to_string(),
660 )
661 })?;
662 let namespaces = batch
663 .column_by_name("namespace")
664 .and_then(|column| column.as_any().downcast_ref::<StringArray>())
665 .ok_or_else(|| {
666 hirn_storage::HirnDbError::InvalidArgument(
667 "graph activation recall passthrough batch is missing `namespace`".to_string(),
668 )
669 })?;
670 let scores = batch
671 .column_by_name("score")
672 .and_then(|column| column.as_any().downcast_ref::<Float32Array>())
673 .ok_or_else(|| {
674 hirn_storage::HirnDbError::InvalidArgument(
675 "graph activation recall passthrough batch is missing `score`".to_string(),
676 )
677 })?;
678 let created_at = batch
679 .column_by_name("created_at_ms")
680 .and_then(|column| column.as_any().downcast_ref::<Int64Array>())
681 .ok_or_else(|| {
682 hirn_storage::HirnDbError::InvalidArgument(
683 "graph activation recall passthrough batch is missing `created_at_ms`".to_string(),
684 )
685 })?;
686 let temporal = batch
687 .column_by_name("temporal_ms")
688 .and_then(|column| column.as_any().downcast_ref::<Int64Array>())
689 .unwrap_or(created_at);
690 let importances = batch
691 .column_by_name("importance")
692 .and_then(|column| column.as_any().downcast_ref::<Float32Array>())
693 .ok_or_else(|| {
694 hirn_storage::HirnDbError::InvalidArgument(
695 "graph activation recall passthrough batch is missing `importance`".to_string(),
696 )
697 })?;
698 let access_counts = batch
699 .column_by_name("access_count")
700 .and_then(|column| column.as_any().downcast_ref::<UInt32Array>())
701 .ok_or_else(|| {
702 hirn_storage::HirnDbError::InvalidArgument(
703 "graph activation recall passthrough batch is missing `access_count`".to_string(),
704 )
705 })?;
706 let surprises = batch
707 .column_by_name("surprise")
708 .and_then(|column| column.as_any().downcast_ref::<Float32Array>())
709 .ok_or_else(|| {
710 hirn_storage::HirnDbError::InvalidArgument(
711 "graph activation recall passthrough batch is missing `surprise`".to_string(),
712 )
713 })?;
714 let evidence_counts = batch
715 .column_by_name("evidence_count")
716 .and_then(|column| column.as_any().downcast_ref::<UInt32Array>())
717 .ok_or_else(|| {
718 hirn_storage::HirnDbError::InvalidArgument(
719 "graph activation recall passthrough batch is missing `evidence_count`".to_string(),
720 )
721 })?;
722 let invocation_counts = batch
723 .column_by_name("invocation_count")
724 .and_then(|column| column.as_any().downcast_ref::<UInt64Array>())
725 .ok_or_else(|| {
726 hirn_storage::HirnDbError::InvalidArgument(
727 "graph activation recall passthrough batch is missing `invocation_count`"
728 .to_string(),
729 )
730 })?;
731
732 let mut rows = Vec::with_capacity(batch.num_rows());
733 for row in 0..batch.num_rows() {
734 rows.push(RecallRow {
735 id: ids.value(row).to_string(),
736 content: contents.value(row).to_string(),
737 full_content: full_contents
738 .map(|fc| fc.value(row).to_string())
739 .unwrap_or_else(|| contents.value(row).to_string()),
740 layer: match layers.value(row) {
741 "episodic" => "episodic",
742 "semantic" => "semantic",
743 "procedural" => "procedural",
744 other => {
745 return Err(hirn_storage::HirnDbError::InvalidArgument(format!(
746 "unsupported recall layer `{other}` in graph activation"
747 )));
748 }
749 },
750 namespace: namespaces.value(row).to_string(),
751 score: if scores.is_null(row) {
752 0.0
753 } else {
754 scores.value(row)
755 },
756 temporal_ms: temporal.value(row),
757 created_at_ms: created_at.value(row),
758 importance: if importances.is_null(row) {
759 0.0
760 } else {
761 importances.value(row)
762 },
763 access_count: if access_counts.is_null(row) {
764 0
765 } else {
766 access_counts.value(row)
767 },
768 surprise: if surprises.is_null(row) {
769 None
770 } else {
771 Some(surprises.value(row))
772 },
773 evidence_count: if evidence_counts.is_null(row) {
774 None
775 } else {
776 Some(evidence_counts.value(row))
777 },
778 invocation_count: if invocation_counts.is_null(row) {
779 None
780 } else {
781 Some(invocation_counts.value(row))
782 },
783 });
784 }
785
786 Ok(rows)
787}
788
789#[cfg(test)]
790mod tests {
791 use std::sync::Mutex;
792
793 use super::*;
794 use arrow_array::{Array, RecordBatch};
795 use async_trait::async_trait;
796 use datafusion::prelude::SessionContext;
797 use datafusion_datasource::memory::MemorySourceConfig;
798 use futures::StreamExt;
799 use hirn_core::HirnResult;
800 use hirn_core::metadata::Metadata;
801 use hirn_core::types::Layer;
802 use hirn_graph::PropertyGraph;
803
804 use crate::{GraphActivationOutput, GraphCausalChainRow, GraphReadRuntime};
805
806 fn seed_batch(ids: &[&str]) -> RecordBatch {
807 RecordBatch::try_new(
808 Arc::new(Schema::new(vec![Field::new(
809 "node_id",
810 DataType::Utf8,
811 false,
812 )])),
813 vec![Arc::new(StringArray::from(ids.to_vec()))],
814 )
815 .unwrap()
816 }
817
818 fn build_test_graph() -> (Arc<RwLock<PropertyGraph>>, Vec<MemoryId>) {
820 let mut g = PropertyGraph::new();
821 let ids: Vec<MemoryId> = (0..3).map(|_| MemoryId::new()).collect();
822 let now = hirn_core::timestamp::Timestamp::now();
823 for &id in &ids {
824 g.add_node(id, Layer::Episodic, 0.5, now);
825 }
826 use hirn_core::types::EdgeRelation;
827 g.add_edge(
828 ids[0],
829 ids[1],
830 EdgeRelation::RelatedTo,
831 0.8,
832 Metadata::new(),
833 )
834 .unwrap();
835 g.add_edge(
836 ids[1],
837 ids[2],
838 EdgeRelation::RelatedTo,
839 0.7,
840 Metadata::new(),
841 )
842 .unwrap();
843 (Arc::new(RwLock::new(g)), ids)
844 }
845
846 #[tokio::test]
847 async fn activation_spreads_to_neighbors() {
848 let (graph, ids) = build_test_graph();
849 let id_strs: Vec<String> = ids.iter().map(|id| id.to_string()).collect();
850
851 let batch = seed_batch(&[&id_strs[0]]);
853 let schema = batch.schema();
854 let input = MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap();
855
856 let exec =
857 GraphActivationExec::new(input, 10, ActivationMode::Spreading, 3, 0.001, 0.1).unwrap();
858
859 let ctx = SessionContext::new();
860 register_graph_runtime(graph, &ctx);
861
862 let mut stream = exec.execute(0, ctx.task_ctx()).unwrap();
863 let mut all_ids = Vec::new();
864 while let Some(result) = stream.next().await {
865 let batch = result.unwrap();
866 assert_eq!(batch.schema(), GraphActivationExec::output_schema());
867 let node_col = batch
868 .column(0)
869 .as_any()
870 .downcast_ref::<StringArray>()
871 .unwrap();
872 for i in 0..node_col.len() {
873 all_ids.push(node_col.value(i).to_string());
874 }
875 }
876 assert!(
878 all_ids.len() >= 2,
879 "should activate seed + at least 1 neighbor, got {} ids: {:?}",
880 all_ids.len(),
881 all_ids
882 );
883 assert!(
885 all_ids.contains(&id_strs[0]),
886 "seed node should be in activation results"
887 );
888 }
889
890 #[tokio::test]
891 async fn missing_graph_runtime_returns_error() {
892 let id = MemoryId::new();
893 let id_str = id.to_string();
894 let batch = seed_batch(&[&id_str]);
895 let schema = batch.schema();
896 let input = MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap();
897
898 let exec =
899 GraphActivationExec::new(input, 10, ActivationMode::Spreading, 3, 0.001, 0.1).unwrap();
900 let ctx = SessionContext::new();
901 let mut stream = exec.execute(0, ctx.task_ctx()).unwrap();
902
903 let err = stream.next().await.unwrap().unwrap_err().to_string();
904 assert!(
905 err.contains("requires HirnSessionExt graph runtime"),
906 "expected missing graph runtime error, got: {err}"
907 );
908 }
909
910 #[tokio::test]
911 async fn all_invalid_seeds_returns_error() {
912 let batch = seed_batch(&["not-a-valid-ulid"]);
913 let schema = batch.schema();
914 let input = MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap();
915
916 let exec =
917 GraphActivationExec::new(input, 10, ActivationMode::Spreading, 3, 0.001, 0.1).unwrap();
918 let ctx = SessionContext::new();
919 let mut stream = exec.execute(0, ctx.task_ctx()).unwrap();
920
921 let result = stream.next().await.unwrap();
922 assert!(result.is_err(), "all invalid seeds should produce an error");
923 let err = result.unwrap_err().to_string();
924 assert!(
925 err.contains("failed to parse"),
926 "error should mention parse failure: {err}"
927 );
928 }
929
930 #[test]
931 fn output_schema_correct() {
932 let schema = GraphActivationExec::output_schema();
933 assert_eq!(schema.fields().len(), 3);
934 assert_eq!(schema.field(0).name(), "node_id");
935 assert_eq!(schema.field(1).name(), "activation_score");
936 assert_eq!(schema.field(2).name(), "depth");
937 }
938
939 struct LocalGraphReadRuntime {
940 graph: Arc<RwLock<PropertyGraph>>,
941 }
942
943 #[async_trait]
944 impl GraphReadRuntime for LocalGraphReadRuntime {
945 async fn activate_graph(
946 &self,
947 seeds: &[MemoryId],
948 mode: ActivationMode,
949 ppr_config: Option<&hirn_graph::PprConfig>,
950 max_depth: u32,
951 epsilon: f32,
952 inhibition_mu: f32,
953 _delegation_threshold: usize,
954 allowed_namespaces: Option<&[Namespace]>,
955 ) -> HirnResult<GraphActivationOutput> {
956 let graph = self.graph.read();
957 let (ids, scores, depths) = match mode {
958 ActivationMode::Ppr => {
959 let default_ppr = hirn_graph::PprConfig::default();
960 let ppr_config = ppr_config.unwrap_or(&default_ppr);
961 let activations = hirn_graph::personalized_pagerank(
962 &graph,
963 seeds,
964 ppr_config,
965 allowed_namespaces,
966 )
967 .expect("test PPR config should be valid");
968 let mut entries: Vec<_> = activations.into_iter().collect();
969 entries
970 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
971 (
972 entries
973 .iter()
974 .map(|(node_id, _)| node_id.to_string())
975 .collect(),
976 entries.iter().map(|(_, score)| *score as f32).collect(),
977 vec![0; entries.len()],
978 )
979 }
980 _ => run_activation(
981 &graph,
982 seeds,
983 mode,
984 max_depth,
985 epsilon,
986 inhibition_mu,
987 allowed_namespaces,
988 ),
989 };
990 Ok(GraphActivationOutput {
991 ids,
992 scores,
993 depths,
994 })
995 }
996
997 async fn causal_chain(
998 &self,
999 _start_ids: &[MemoryId],
1000 _max_depth: u32,
1001 _confidence_threshold: f32,
1002 _delegation_threshold: usize,
1003 _relation: hirn_core::types::EdgeRelation,
1004 _allowed_namespaces: Option<&[Namespace]>,
1005 ) -> HirnResult<Vec<GraphCausalChainRow>> {
1006 Ok(Vec::new())
1007 }
1008
1009 async fn traverse_graph(
1010 &self,
1011 _start_ids: &[MemoryId],
1012 _max_depth: u32,
1013 _delegation_threshold: usize,
1014 _relation_filter: Option<&[hirn_core::types::EdgeRelation]>,
1015 _allowed_namespaces: Option<&[Namespace]>,
1016 ) -> HirnResult<Vec<crate::GraphTraverseRow>> {
1017 Ok(Vec::new())
1018 }
1019 }
1020
1021 fn register_graph_runtime(graph: Arc<RwLock<PropertyGraph>>, ctx: &SessionContext) {
1022 let config = hirn_core::HirnConfig::builder()
1023 .db_path(std::path::Path::new("/tmp/test"))
1024 .build()
1025 .unwrap();
1026 HirnSessionExt::new(
1027 graph.clone() as Arc<dyn Any + Send + Sync>,
1028 Arc::new(config),
1029 None,
1030 )
1031 .with_graph_read_runtime(Arc::new(LocalGraphReadRuntime { graph }))
1032 .register(ctx)
1033 .expect("register should succeed");
1034 }
1035
1036 #[derive(Debug)]
1037 struct MockGraphReadRuntime {
1038 output: GraphActivationOutput,
1039 }
1040
1041 #[async_trait]
1042 impl crate::GraphReadRuntime for MockGraphReadRuntime {
1043 async fn activate_graph(
1044 &self,
1045 _seeds: &[MemoryId],
1046 _mode: ActivationMode,
1047 _ppr_config: Option<&hirn_graph::PprConfig>,
1048 _max_depth: u32,
1049 _epsilon: f32,
1050 _inhibition_mu: f32,
1051 _delegation_threshold: usize,
1052 _allowed_namespaces: Option<&[Namespace]>,
1053 ) -> HirnResult<GraphActivationOutput> {
1054 Ok(self.output.clone())
1055 }
1056
1057 async fn causal_chain(
1058 &self,
1059 _start_ids: &[MemoryId],
1060 _max_depth: u32,
1061 _confidence_threshold: f32,
1062 _delegation_threshold: usize,
1063 _relation: hirn_core::types::EdgeRelation,
1064 _allowed_namespaces: Option<&[Namespace]>,
1065 ) -> HirnResult<Vec<GraphCausalChainRow>> {
1066 Ok(Vec::new())
1067 }
1068
1069 async fn traverse_graph(
1070 &self,
1071 _start_ids: &[MemoryId],
1072 _max_depth: u32,
1073 _delegation_threshold: usize,
1074 _relation_filter: Option<&[hirn_core::types::EdgeRelation]>,
1075 _allowed_namespaces: Option<&[Namespace]>,
1076 ) -> HirnResult<Vec<crate::GraphTraverseRow>> {
1077 Ok(Vec::new())
1078 }
1079 }
1080
1081 #[derive(Debug)]
1082 struct RecordingGraphReadRuntime {
1083 seen_seeds: Arc<Mutex<Vec<MemoryId>>>,
1084 }
1085
1086 #[async_trait]
1087 impl crate::GraphReadRuntime for RecordingGraphReadRuntime {
1088 async fn activate_graph(
1089 &self,
1090 seeds: &[MemoryId],
1091 _mode: ActivationMode,
1092 _ppr_config: Option<&hirn_graph::PprConfig>,
1093 _max_depth: u32,
1094 _epsilon: f32,
1095 _inhibition_mu: f32,
1096 _delegation_threshold: usize,
1097 _allowed_namespaces: Option<&[Namespace]>,
1098 ) -> HirnResult<GraphActivationOutput> {
1099 *self.seen_seeds.lock().expect("lock should succeed") = seeds.to_vec();
1100 Ok(GraphActivationOutput {
1101 ids: seeds.iter().map(ToString::to_string).collect(),
1102 scores: vec![1.0; seeds.len()],
1103 depths: vec![0; seeds.len()],
1104 })
1105 }
1106
1107 async fn causal_chain(
1108 &self,
1109 _start_ids: &[MemoryId],
1110 _max_depth: u32,
1111 _confidence_threshold: f32,
1112 _delegation_threshold: usize,
1113 _relation: hirn_core::types::EdgeRelation,
1114 _allowed_namespaces: Option<&[Namespace]>,
1115 ) -> HirnResult<Vec<GraphCausalChainRow>> {
1116 Ok(Vec::new())
1117 }
1118
1119 async fn traverse_graph(
1120 &self,
1121 _start_ids: &[MemoryId],
1122 _max_depth: u32,
1123 _delegation_threshold: usize,
1124 _relation_filter: Option<&[hirn_core::types::EdgeRelation]>,
1125 _allowed_namespaces: Option<&[Namespace]>,
1126 ) -> HirnResult<Vec<crate::GraphTraverseRow>> {
1127 Ok(Vec::new())
1128 }
1129 }
1130
1131 #[tokio::test]
1132 async fn prefers_registered_graph_read_runtime() {
1133 let id = MemoryId::new();
1134 let id_str = id.to_string();
1135 let batch = seed_batch(&[&id_str]);
1136 let schema = batch.schema();
1137 let input = MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap();
1138
1139 let exec =
1140 GraphActivationExec::new(input, 10, ActivationMode::Spreading, 6, 0.001, 0.1).unwrap();
1141 let ctx = SessionContext::new();
1142 let config = hirn_core::HirnConfig::builder()
1143 .db_path(std::path::Path::new("/tmp/test"))
1144 .build()
1145 .unwrap();
1146
1147 HirnSessionExt::new(
1148 Arc::new(()) as Arc<dyn Any + Send + Sync>,
1149 Arc::new(config),
1150 None,
1151 )
1152 .with_graph_read_runtime(Arc::new(MockGraphReadRuntime {
1153 output: GraphActivationOutput {
1154 ids: vec![id_str.clone()],
1155 scores: vec![0.42],
1156 depths: vec![6],
1157 },
1158 }))
1159 .register(&ctx)
1160 .expect("register should succeed");
1161
1162 let mut stream = exec.execute(0, ctx.task_ctx()).unwrap();
1163 let result = stream.next().await.unwrap().unwrap();
1164 let scores = result
1165 .column(1)
1166 .as_any()
1167 .downcast_ref::<Float32Array>()
1168 .unwrap();
1169 let depths = result
1170 .column(2)
1171 .as_any()
1172 .downcast_ref::<UInt32Array>()
1173 .unwrap();
1174
1175 assert!((scores.value(0) - 0.42).abs() < f32::EPSILON);
1176 assert_eq!(depths.value(0), 6);
1177 }
1178
1179 #[tokio::test]
1180 async fn ppr_mode_returns_different_ranking_than_spreading() {
1181 let (graph, ids) = build_test_graph();
1182 let id_strs: Vec<String> = ids.iter().map(|id| id.to_string()).collect();
1183
1184 let batch = seed_batch(&[&id_strs[0]]);
1186 let schema = batch.schema();
1187 let input = MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap();
1188 let exec_spread =
1189 GraphActivationExec::new(input, 10, ActivationMode::Spreading, 3, 0.001, 0.0).unwrap();
1190 let ctx_s = SessionContext::new();
1191 register_graph_runtime(graph.clone(), &ctx_s);
1192
1193 let mut stream = exec_spread.execute(0, ctx_s.task_ctx()).unwrap();
1194 let batch_s = stream.next().await.unwrap().unwrap();
1195 let scores_s = batch_s
1196 .column(1)
1197 .as_any()
1198 .downcast_ref::<Float32Array>()
1199 .unwrap();
1200 let spread_scores: Vec<f32> = (0..scores_s.len()).map(|i| scores_s.value(i)).collect();
1201
1202 let batch = seed_batch(&[&id_strs[0]]);
1204 let schema = batch.schema();
1205 let input = MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap();
1206 let exec_ppr =
1207 GraphActivationExec::new(input, 10, ActivationMode::Ppr, 3, 0.001, 0.0).unwrap();
1208 let ctx_p = SessionContext::new();
1209 register_graph_runtime(graph, &ctx_p);
1210
1211 let mut stream = exec_ppr.execute(0, ctx_p.task_ctx()).unwrap();
1212 let batch_p = stream.next().await.unwrap().unwrap();
1213 let scores_p = batch_p
1214 .column(1)
1215 .as_any()
1216 .downcast_ref::<Float32Array>()
1217 .unwrap();
1218 let ppr_scores: Vec<f32> = (0..scores_p.len()).map(|i| scores_p.value(i)).collect();
1219
1220 assert!(
1222 !spread_scores.is_empty() && !ppr_scores.is_empty(),
1223 "both modes should return results"
1224 );
1225 assert_ne!(
1228 spread_scores, ppr_scores,
1229 "PPR and spreading should produce different score vectors"
1230 );
1231 }
1232
1233 #[tokio::test]
1234 async fn lateral_inhibition_suppresses_competing_cluster() {
1235 let mut g = PropertyGraph::new();
1239 let ids: Vec<MemoryId> = (0..5).map(|_| MemoryId::new()).collect();
1240 let now = hirn_core::timestamp::Timestamp::now();
1241 for &id in &ids {
1242 g.add_node(id, Layer::Episodic, 0.5, now);
1243 }
1244 use hirn_core::types::EdgeRelation;
1245 g.add_edge(
1247 ids[0],
1248 ids[2],
1249 EdgeRelation::RelatedTo,
1250 0.9,
1251 Metadata::new(),
1252 )
1253 .unwrap();
1254 g.add_edge(
1256 ids[1],
1257 ids[2],
1258 EdgeRelation::RelatedTo,
1259 0.9,
1260 Metadata::new(),
1261 )
1262 .unwrap();
1263 g.add_edge(
1265 ids[2],
1266 ids[3],
1267 EdgeRelation::RelatedTo,
1268 0.8,
1269 Metadata::new(),
1270 )
1271 .unwrap();
1272 g.add_edge(
1273 ids[2],
1274 ids[4],
1275 EdgeRelation::RelatedTo,
1276 0.8,
1277 Metadata::new(),
1278 )
1279 .unwrap();
1280
1281 let graph = Arc::new(RwLock::new(g));
1282 let id_strs: Vec<String> = ids.iter().map(|id| id.to_string()).collect();
1283
1284 let batch = seed_batch(&[&id_strs[0]]);
1286 let schema = batch.schema();
1287 let input = MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap();
1288 let exec =
1289 GraphActivationExec::new(input, 10, ActivationMode::Spreading, 3, 0.001, 0.0).unwrap();
1290 let ctx_no_inh = SessionContext::new();
1291 register_graph_runtime(graph.clone(), &ctx_no_inh);
1292
1293 let mut stream = exec.execute(0, ctx_no_inh.task_ctx()).unwrap();
1294 let batch_no = stream.next().await.unwrap().unwrap();
1295 let scores_no = batch_no
1296 .column(1)
1297 .as_any()
1298 .downcast_ref::<Float32Array>()
1299 .unwrap();
1300 let total_no: f32 = (0..scores_no.len()).map(|i| scores_no.value(i)).sum();
1301
1302 let batch = seed_batch(&[&id_strs[0]]);
1304 let schema = batch.schema();
1305 let input = MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap();
1306 let exec =
1307 GraphActivationExec::new(input, 10, ActivationMode::Spreading, 3, 0.001, 0.5).unwrap();
1308 let ctx_inh = SessionContext::new();
1309 register_graph_runtime(graph, &ctx_inh);
1310
1311 let mut stream = exec.execute(0, ctx_inh.task_ctx()).unwrap();
1312 let batch_inh = stream.next().await.unwrap().unwrap();
1313 let scores_inh = batch_inh
1314 .column(1)
1315 .as_any()
1316 .downcast_ref::<Float32Array>()
1317 .unwrap();
1318 let total_inh: f32 = (0..scores_inh.len()).map(|i| scores_inh.value(i)).sum();
1319
1320 assert!(
1322 total_inh <= total_no,
1323 "inhibition should reduce total activation: {total_inh} should be <= {total_no}"
1324 );
1325 }
1326
1327 #[tokio::test]
1328 async fn mixed_valid_and_invalid_seeds_processes_valid_ones() {
1329 let (graph, ids) = build_test_graph();
1330 let valid_str = ids[0].to_string();
1331 let batch = seed_batch(&[&valid_str, "not-a-valid-ulid"]);
1333 let schema = batch.schema();
1334 let input = MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap();
1335
1336 let exec =
1337 GraphActivationExec::new(input, 10, ActivationMode::Spreading, 3, 0.001, 0.1).unwrap();
1338 let ctx = SessionContext::new();
1339 register_graph_runtime(graph, &ctx);
1340
1341 let mut stream = exec.execute(0, ctx.task_ctx()).unwrap();
1342 let result = stream.next().await.unwrap().unwrap();
1343 assert!(
1345 result.num_rows() >= 1,
1346 "valid seed should produce activation results"
1347 );
1348 let node_col = result
1349 .column(0)
1350 .as_any()
1351 .downcast_ref::<StringArray>()
1352 .unwrap();
1353 let result_ids: Vec<&str> = (0..node_col.len()).map(|i| node_col.value(i)).collect();
1354 assert!(
1355 result_ids.contains(&valid_str.as_str()),
1356 "valid seed should appear in results"
1357 );
1358 }
1359
1360 #[tokio::test]
1361 async fn respects_seed_limit_before_graph_activation() {
1362 let ids: Vec<MemoryId> = (0..3).map(|_| MemoryId::new()).collect();
1363 let id_strs: Vec<String> = ids.iter().map(|id| id.to_string()).collect();
1364 let batch = seed_batch(&[&id_strs[0], &id_strs[1], &id_strs[2]]);
1365 let schema = batch.schema();
1366 let input = MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap();
1367
1368 let exec =
1369 GraphActivationExec::new(input, 2, ActivationMode::Spreading, 3, 0.001, 0.1).unwrap();
1370 let seen_seeds = Arc::new(Mutex::new(Vec::new()));
1371 let ctx = SessionContext::new();
1372 let config = hirn_core::HirnConfig::builder()
1373 .db_path(std::path::Path::new("/tmp/test"))
1374 .build()
1375 .unwrap();
1376
1377 HirnSessionExt::new(
1378 Arc::new(()) as Arc<dyn Any + Send + Sync>,
1379 Arc::new(config),
1380 None,
1381 )
1382 .with_graph_read_runtime(Arc::new(RecordingGraphReadRuntime {
1383 seen_seeds: seen_seeds.clone(),
1384 }))
1385 .register(&ctx)
1386 .expect("register should succeed");
1387
1388 let mut stream = exec.execute(0, ctx.task_ctx()).unwrap();
1389 let _ = stream.next().await.unwrap().unwrap();
1390
1391 let recorded = seen_seeds.lock().expect("lock should succeed").clone();
1392 assert_eq!(recorded, ids[..2].to_vec());
1393 }
1394
1395 #[tokio::test]
1396 async fn preserve_recall_rows_keeps_nonseed_candidates() {
1397 let ids: Vec<MemoryId> = (0..2).map(|_| MemoryId::new()).collect();
1398 let id_strs: Vec<String> = ids.iter().map(|id| id.to_string()).collect();
1399
1400 let batch = RecordBatch::try_new(
1401 Arc::new(Schema::new(vec![
1402 Field::new("id", DataType::Utf8, false),
1403 Field::new("content", DataType::Utf8, false),
1404 Field::new("layer", DataType::Utf8, false),
1405 Field::new("namespace", DataType::Utf8, false),
1406 Field::new("score", DataType::Float32, false),
1407 Field::new("temporal_ms", DataType::Int64, false),
1408 Field::new("created_at_ms", DataType::Int64, false),
1409 Field::new("importance", DataType::Float32, false),
1410 Field::new("access_count", DataType::UInt32, false),
1411 Field::new("surprise", DataType::Float32, true),
1412 Field::new("evidence_count", DataType::UInt32, true),
1413 Field::new("invocation_count", DataType::UInt64, true),
1414 ])),
1415 vec![
1416 Arc::new(StringArray::from(vec![
1417 id_strs[0].as_str(),
1418 id_strs[1].as_str(),
1419 ])),
1420 Arc::new(StringArray::from(vec!["seed", "nonseed candidate"])),
1421 Arc::new(StringArray::from(vec!["episodic", "episodic"])),
1422 Arc::new(StringArray::from(vec!["default", "default"])),
1423 Arc::new(Float32Array::from(vec![0.9, 0.8])),
1424 Arc::new(Int64Array::from(vec![1_i64, 2_i64])),
1425 Arc::new(Int64Array::from(vec![1_i64, 2_i64])),
1426 Arc::new(Float32Array::from(vec![0.7, 0.6])),
1427 Arc::new(UInt32Array::from(vec![1_u32, 1_u32])),
1428 Arc::new(Float32Array::from(vec![Some(0.0_f32), Some(0.0_f32)])),
1429 Arc::new(UInt32Array::from(vec![Some(0_u32), Some(0_u32)])),
1430 Arc::new(UInt64Array::from(vec![Some(0_u64), Some(0_u64)])),
1431 ],
1432 )
1433 .unwrap();
1434
1435 let schema = batch.schema();
1436 let input = MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap();
1437 let exec =
1438 GraphActivationExec::new(input, 1, ActivationMode::Spreading, 3, 0.001, 0.1).unwrap();
1439
1440 let seen_seeds = Arc::new(Mutex::new(Vec::new()));
1441 let ctx = SessionContext::new();
1442 let config = hirn_core::HirnConfig::builder()
1443 .db_path(std::path::Path::new("/tmp/test"))
1444 .build()
1445 .unwrap();
1446
1447 HirnSessionExt::new(
1448 Arc::new(()) as Arc<dyn Any + Send + Sync>,
1449 Arc::new(config),
1450 None,
1451 )
1452 .with_graph_read_runtime(Arc::new(RecordingGraphReadRuntime {
1453 seen_seeds: seen_seeds.clone(),
1454 }))
1455 .register(&ctx)
1456 .expect("register should succeed");
1457
1458 let mut stream = exec.execute(0, ctx.task_ctx()).unwrap();
1459 let result = stream.next().await.unwrap().unwrap();
1460 let ids = result
1461 .column_by_name("id")
1462 .and_then(|column| column.as_any().downcast_ref::<StringArray>())
1463 .unwrap();
1464 let activation_scores = result
1465 .column_by_name("activation_score")
1466 .and_then(|column| column.as_any().downcast_ref::<Float32Array>())
1467 .unwrap();
1468
1469 let output_ids = (0..ids.len())
1470 .map(|index| ids.value(index).to_string())
1471 .collect::<Vec<_>>();
1472 assert_eq!(output_ids, id_strs);
1473 assert!((activation_scores.value(0) - 1.0).abs() < f32::EPSILON);
1474 assert_eq!(activation_scores.value(1), 0.0);
1475 }
1476
1477 #[test]
1478 fn invalid_config_rejected_at_construction() {
1479 let batch = seed_batch(&["not-used"]);
1480 let schema = batch.schema();
1481 let input = MemorySourceConfig::try_new_exec(&[vec![batch]], schema, None).unwrap();
1482
1483 let err = GraphActivationExec::new(input, 10, ActivationMode::Spreading, 0, 0.001, 0.1)
1484 .unwrap_err();
1485 assert!(err.to_string().contains("invalid graph activation config"));
1486 }
1487}