1use std::collections::{HashMap, HashSet};
13use std::sync::Arc;
14
15use arrow_array::Array;
16use futures::TryStreamExt;
17use hirn_core::HirnResult;
18use hirn_core::id::MemoryId;
19use hirn_core::metadata::Metadata;
20use hirn_core::timestamp::Timestamp;
21use hirn_core::types::{EdgeRelation, Layer, Namespace};
22use hirn_graph::graph::{
23 EdgeId, GraphEdge, GraphNodeData, MAX_EDGES_PER_NODE, validate_edge_metadata,
24};
25
26use hirn_storage::PhysicalStore;
27use hirn_storage::datasets::graph::{self, DATASET_EDGES_NAME, DATASET_NODES_NAME};
28use hirn_storage::store::{ExactMatchFilter, IndexConfig, IndexType, ScanOptions};
29
/// Outcome of a level-by-level breadth-first traversal.
#[derive(Debug, Clone)]
pub struct BfsResult {
    /// Edges read at each traversal level; index 0 holds the edges leaving the start nodes.
    pub depths: Vec<Vec<GraphEdge>>,
    /// Every node id recorded as visited, including the (namespace-filtered) start nodes.
    /// Collected from a hash set, so the order is unspecified.
    pub visited: Vec<MemoryId>,
}
42
43impl BfsResult {
44 pub fn all_targets(&self) -> Vec<MemoryId> {
46 use std::collections::HashSet;
47 let mut seen = HashSet::new();
48 let mut targets = Vec::new();
49 for depth_edges in &self.depths {
50 for edge in depth_edges {
51 if seen.insert(edge.target) {
52 targets.push(edge.target);
53 }
54 }
55 }
56 targets
57 }
58
59 pub fn total_edges(&self) -> usize {
61 self.depths.iter().map(Vec::len).sum()
62 }
63}
64
/// One edge of an emitted causal chain, flattened into a row.
#[derive(Debug, Clone)]
pub struct CausalBfsRow {
    /// Identifier shared by all rows of one chain, formatted as `chain_<n>`.
    pub chain_id: String,
    /// Source node of this edge.
    pub source_id: MemoryId,
    /// Target node of this edge.
    pub target_id: MemoryId,
    /// Causal strength of this edge.
    pub strength: f32,
    /// Confidence of this edge.
    pub confidence: f32,
    /// Evidence count of this edge.
    pub evidence_count: u32,
    /// Optional mechanism description of this edge.
    pub mechanism: Option<String>,
    /// Zero-based position of this edge within its chain.
    pub depth: u32,
    /// Score of the whole chain; identical on every row of the chain.
    pub chain_score: f32,
}
81
/// Internal per-edge record accumulated while a causal chain is being built,
/// before it is flattened into [`CausalBfsRow`]s.
#[derive(Debug, Clone)]
struct CausalBfsEdge {
    /// Source node of the edge.
    source: MemoryId,
    /// Target node of the edge.
    target: MemoryId,
    /// Causal strength used for scoring.
    strength: f32,
    /// Confidence used for scoring.
    confidence: f32,
    /// Evidence count used for scoring.
    evidence_count: u32,
    /// Optional mechanism description.
    mechanism: Option<String>,
}
92
93fn emit_causal_rows(
95 chain_edges: &[CausalBfsEdge],
96 rows: &mut Vec<CausalBfsRow>,
97 chain_counter: &mut u32,
98) {
99 let chain_id = format!("chain_{}", *chain_counter);
100 *chain_counter += 1;
101
102 let score_sum: f32 = chain_edges
104 .iter()
105 .map(|e| e.strength * e.confidence * (1.0_f32 + e.evidence_count as f32).ln())
106 .sum();
107 let chain_score = score_sum / chain_edges.len().max(1) as f32;
108
109 for (depth, edge) in chain_edges.iter().enumerate() {
110 rows.push(CausalBfsRow {
111 chain_id: chain_id.clone(),
112 source_id: edge.source,
113 target_id: edge.target,
114 strength: edge.strength,
115 confidence: edge.confidence,
116 evidence_count: edge.evidence_count,
117 mechanism: edge.mechanism.clone(),
118 depth: depth as u32,
119 chain_score,
120 });
121 }
122}
123
/// Graph of nodes and edges persisted through a shared [`PhysicalStore`].
pub struct PersistentGraph {
    /// Storage backend through which every dataset read and write is issued.
    storage: Arc<dyn PhysicalStore>,
}
131
132impl PersistentGraph {
133 fn layer_exact_filter(layer: Layer) -> ExactMatchFilter {
134 ExactMatchFilter::utf8_value("layer", layer_to_str(layer))
135 }
136
137 fn namespace_exact_filter(namespace: &Namespace) -> ExactMatchFilter {
138 ExactMatchFilter::utf8_value("namespace", namespace.as_str())
139 }
140
141 fn source_exact_filter(source: MemoryId) -> ExactMatchFilter {
142 ExactMatchFilter::utf8_value("source", source.to_string())
143 }
144
145 fn target_exact_filter(target: MemoryId) -> ExactMatchFilter {
146 ExactMatchFilter::utf8_value("target", target.to_string())
147 }
148
149 fn quoted_in_values<T>(ids: &[T]) -> Vec<String>
150 where
151 T: ToString,
152 {
153 ids.iter()
154 .map(|id| {
155 let value = id.to_string();
156 let escaped = value.replace('\'', "''");
157 format!("'{escaped}'")
158 })
159 .collect()
160 }
161
162 fn quoted_namespace_values(namespaces: &[Namespace]) -> Vec<String> {
163 namespaces
164 .iter()
165 .map(|namespace| {
166 let escaped = namespace.as_str().replace('\'', "''");
167 format!("'{escaped}'")
168 })
169 .collect()
170 }
171
172 #[must_use]
178 pub fn new(storage: Arc<dyn PhysicalStore>) -> Self {
179 Self { storage }
180 }
181
182 pub async fn open(storage: Arc<dyn PhysicalStore>) -> HirnResult<Self> {
187 let pg = Self { storage };
188 pg.ensure_indices().await?;
189 Ok(pg)
190 }
191
192 async fn ensure_indices(&self) -> HirnResult<()> {
194 if self
196 .storage
197 .exists(DATASET_NODES_NAME)
198 .await
199 .unwrap_or(false)
200 {
201 let count = self
202 .storage
203 .count(DATASET_NODES_NAME, None)
204 .await
205 .unwrap_or(0);
206 if count > 0 {
207 let _ = self
208 .storage
209 .create_index(
210 DATASET_NODES_NAME,
211 IndexConfig {
212 columns: vec!["id".into()],
213 index_type: IndexType::BTree,
214 replace: false,
215 params: Default::default(),
216 },
217 )
218 .await;
219 let _ = self
220 .storage
221 .create_index(
222 DATASET_NODES_NAME,
223 IndexConfig {
224 columns: vec!["layer".into()],
225 index_type: IndexType::Bitmap,
226 replace: false,
227 params: Default::default(),
228 },
229 )
230 .await;
231 }
232 }
233 if self
234 .storage
235 .exists(DATASET_EDGES_NAME)
236 .await
237 .unwrap_or(false)
238 {
239 let count = self
240 .storage
241 .count(DATASET_EDGES_NAME, None)
242 .await
243 .unwrap_or(0);
244 if count > 0 {
245 let _ = self
246 .storage
247 .create_index(
248 DATASET_EDGES_NAME,
249 IndexConfig {
250 columns: vec!["source".into()],
251 index_type: IndexType::Bitmap,
252 replace: false,
253 params: Default::default(),
254 },
255 )
256 .await;
257 let _ = self
258 .storage
259 .create_index(
260 DATASET_EDGES_NAME,
261 IndexConfig {
262 columns: vec!["target".into()],
263 index_type: IndexType::BTree,
264 replace: false,
265 params: Default::default(),
266 },
267 )
268 .await;
269 let _ = self
270 .storage
271 .create_index(
272 DATASET_EDGES_NAME,
273 IndexConfig {
274 columns: vec!["relation".into()],
275 index_type: IndexType::Bitmap,
276 replace: false,
277 params: Default::default(),
278 },
279 )
280 .await;
281 }
282 }
283 Ok(())
284 }
285
286 async fn scan_nodes(&self, options: ScanOptions) -> HirnResult<Vec<GraphNodeData>> {
287 let mut stream = self
288 .storage
289 .scan_stream(DATASET_NODES_NAME, options)
290 .await?;
291 let mut nodes = Vec::new();
292
293 while let Some(batch) = stream.try_next().await? {
294 nodes.extend(graph::nodes_from_batch(&batch)?);
295 }
296
297 Ok(nodes)
298 }
299
300 async fn scan_edges(&self, options: ScanOptions) -> HirnResult<Vec<GraphEdge>> {
301 let mut stream = self
302 .storage
303 .scan_stream(DATASET_EDGES_NAME, options)
304 .await?;
305 let mut edges = Vec::new();
306
307 while let Some(batch) = stream.try_next().await? {
308 edges.extend(
313 graph::edges_from_batch(&batch)?
314 .into_iter()
315 .filter(|e| e.is_currently_active()),
316 );
317 }
318
319 Ok(edges)
320 }
321
322 pub async fn add_node(
326 &self,
327 id: MemoryId,
328 layer: Layer,
329 importance: f32,
330 created_at: Timestamp,
331 namespace: Namespace,
332 ) -> HirnResult<bool> {
333 let node = GraphNodeData {
334 id,
335 layer,
336 importance,
337 created_at,
338 namespace,
339 access_count: 0,
340 };
341 let batch = graph::nodes_to_batch(&[node])?;
342 self.storage
343 .merge_insert(DATASET_NODES_NAME, &["id"], batch)
344 .await?;
345 Ok(true)
346 }
347
348 pub async fn add_nodes(&self, nodes: &[GraphNodeData]) -> HirnResult<()> {
350 if nodes.is_empty() {
351 return Ok(());
352 }
353
354 let batch = graph::nodes_to_batch(nodes)?;
355 self.storage
356 .merge_insert(DATASET_NODES_NAME, &["id"], batch)
357 .await?;
358 Ok(())
359 }
360
361 pub async fn get_node(&self, id: MemoryId) -> HirnResult<Option<GraphNodeData>> {
363 let id_str = id.to_string();
364 let nodes = self
365 .scan_nodes(ScanOptions {
366 columns: None,
367 filter: None,
368 exact_filter: Some(ExactMatchFilter::utf8_value("id", id_str)),
369 order_by: None,
370 limit: Some(1),
371 offset: None,
372 })
373 .await?;
374
375 Ok(nodes.into_iter().next())
376 }
377
378 pub async fn update_node(&self, node: GraphNodeData) -> HirnResult<()> {
380 let batch = graph::nodes_to_batch(&[node])?;
381 self.storage
382 .merge_insert(DATASET_NODES_NAME, &["id"], batch)
383 .await?;
384 Ok(())
385 }
386
387 pub async fn flush_access_counts(&self, dirty: &[(MemoryId, u64)]) -> HirnResult<()> {
395 if dirty.is_empty() {
396 return Ok(());
397 }
398
399 for chunk in dirty.chunks(500) {
401 let id_list: Vec<String> = chunk
403 .iter()
404 .map(|(id, _)| format!("'{}'", id.to_string().replace('\'', "''")))
405 .collect();
406 let filter = format!("id IN ({})", id_list.join(", "));
407
408 let mut case_expr = String::from("CASE id");
410 for (id, count) in chunk {
411 case_expr.push_str(&format!(
412 " WHEN '{}' THEN {}",
413 id.to_string().replace('\'', "''"),
414 count
415 ));
416 }
417 case_expr.push_str(" ELSE access_count END");
419
420 let case_expr_ref: &str = &case_expr;
421 let updates: &[(&str, &str)] = &[("access_count", case_expr_ref)];
422
423 if let Err(e) = self
424 .storage
425 .update_where(DATASET_NODES_NAME, &filter, updates)
426 .await
427 {
428 tracing::warn!(error = %e, "flush_access_counts: update_where failed; skipping chunk");
429 }
430 }
431
432 Ok(())
433 }
434
435 pub async fn remove_node(&self, id: MemoryId) -> HirnResult<bool> {
437 let id_str = id.to_string();
438
439 if self.get_node(id).await?.is_none() {
441 return Ok(false);
442 }
443
444 self.expire_node_edges(id, Timestamp::now()).await?;
447
448 let exact_filter = ExactMatchFilter::utf8_value("id", id_str);
450 self.storage
451 .delete_exact(DATASET_NODES_NAME, &exact_filter)
452 .await?;
453
454 Ok(true)
455 }
456
457 pub async fn expire_node_edges(
464 &self,
465 node_id: MemoryId,
466 expiry: Timestamp,
467 ) -> HirnResult<()> {
468 let id_str = node_id.to_string();
469 let expiry_ms = expiry.timestamp_ms() as i64;
470 let expiry_expr = expiry_ms.to_string();
471
472 let filter_source = format!(
474 "source = '{}' AND (valid_until_ms IS NULL OR valid_until_ms = 0)",
475 id_str.replace('\'', "''")
476 );
477 let filter_target = format!(
478 "target = '{}' AND (valid_until_ms IS NULL OR valid_until_ms = 0)",
479 id_str.replace('\'', "''")
480 );
481
482 let updates: &[(&str, &str)] = &[("valid_until_ms", &expiry_expr)];
483 if let Err(e) = self
485 .storage
486 .update_where(DATASET_EDGES_NAME, &filter_source, updates)
487 .await
488 {
489 tracing::warn!(node_id = %node_id, error = %e, "expire_node_edges: failed to expire source edges");
490 }
491 if let Err(e) = self
492 .storage
493 .update_where(DATASET_EDGES_NAME, &filter_target, updates)
494 .await
495 {
496 tracing::warn!(node_id = %node_id, error = %e, "expire_node_edges: failed to expire target edges");
497 }
498 Ok(())
499 }
500
501
502 pub async fn has_node(&self, id: MemoryId) -> HirnResult<bool> {
504 Ok(self.get_node(id).await?.is_some())
505 }
506
507 pub async fn node_count(&self) -> HirnResult<u64> {
509 if !self.storage.exists(DATASET_NODES_NAME).await? {
510 return Ok(0);
511 }
512 self.storage
513 .count(DATASET_NODES_NAME, None)
514 .await
515 .map_err(Into::into)
516 }
517
518 pub async fn node_ids(&self) -> HirnResult<Vec<MemoryId>> {
520 if !self.storage.exists(DATASET_NODES_NAME).await? {
521 return Ok(vec![]);
522 }
523 let mut stream = self
524 .storage
525 .scan_stream(
526 DATASET_NODES_NAME,
527 ScanOptions {
528 columns: Some(vec!["id".into()]),
529 filter: None,
530 exact_filter: None,
531 order_by: None,
532 limit: None,
533 offset: None,
534 },
535 )
536 .await?;
537
538 let mut ids = Vec::new();
539 while let Some(batch) = stream.try_next().await? {
540 let col = batch
541 .column_by_name("id")
542 .and_then(|c| c.as_any().downcast_ref::<arrow_array::StringArray>());
543 if let Some(arr) = col {
544 for i in 0..arr.len() {
545 if let Ok(id) = MemoryId::parse(arr.value(i)) {
546 ids.push(id);
547 }
548 }
549 }
550 }
551 Ok(ids)
552 }
553
554 pub async fn nodes_by_layer(&self, layer: Layer) -> HirnResult<Vec<GraphNodeData>> {
556 if !self.storage.exists(DATASET_NODES_NAME).await? {
557 return Ok(vec![]);
558 }
559 self.scan_nodes(ScanOptions {
560 columns: None,
561 filter: None,
562 exact_filter: Some(Self::layer_exact_filter(layer)),
563 order_by: None,
564 limit: None,
565 offset: None,
566 })
567 .await
568 }
569
570 pub async fn nodes_by_namespace(&self, ns: &Namespace) -> HirnResult<Vec<GraphNodeData>> {
572 if !self.storage.exists(DATASET_NODES_NAME).await? {
573 return Ok(vec![]);
574 }
575 self.scan_nodes(ScanOptions {
576 columns: None,
577 filter: None,
578 exact_filter: Some(Self::namespace_exact_filter(ns)),
579 order_by: None,
580 limit: None,
581 offset: None,
582 })
583 .await
584 }
585
586 pub async fn node_importance(&self, id: MemoryId) -> HirnResult<Option<f32>> {
588 Ok(self.get_node(id).await?.map(|n| n.importance))
589 }
590
591 pub async fn set_node_importance(&self, id: MemoryId, importance: f32) -> HirnResult<()> {
593 if let Some(mut node) = self.get_node(id).await? {
594 node.importance = importance;
595 self.update_node(node).await?;
596 }
597 Ok(())
598 }
599
600 pub async fn add_edge(
611 &self,
612 source: MemoryId,
613 target: MemoryId,
614 relation: EdgeRelation,
615 weight: f32,
616 metadata: Metadata,
617 ) -> HirnResult<EdgeId> {
618 let id = self
619 .add_edge_one_dir(source, target, relation, weight, metadata.clone(), None)
620 .await?;
621
622 if relation.is_bidirectional() && source != target {
624 match self
625 .add_edge_one_dir(target, source, relation, weight, metadata, None)
626 .await
627 {
628 Ok(_) => {}
629 Err(hirn_core::HirnError::AlreadyExists(_)) => {}
630 Err(e) => return Err(e),
631 }
632 }
633
634 Ok(id)
635 }
636
637 pub async fn add_causal_edge(
643 &self,
644 source: MemoryId,
645 target: MemoryId,
646 relation: EdgeRelation,
647 weight: f32,
648 metadata: Metadata,
649 causal: hirn_graph::CausalEdgeData,
650 ) -> HirnResult<EdgeId> {
651 let id = self
652 .add_edge_one_dir(
653 source,
654 target,
655 relation,
656 weight,
657 metadata.clone(),
658 Some(Box::new(causal.clone())),
659 )
660 .await?;
661
662 if relation.is_bidirectional() && source != target {
663 match self
664 .add_edge_one_dir(
665 target,
666 source,
667 relation,
668 weight,
669 metadata,
670 Some(Box::new(causal)),
671 )
672 .await
673 {
674 Ok(_) => {}
675 Err(hirn_core::HirnError::AlreadyExists(_)) => {}
676 Err(e) => return Err(e),
677 }
678 }
679
680 Ok(id)
681 }
682
683 async fn add_edge_one_dir(
685 &self,
686 source: MemoryId,
687 target: MemoryId,
688 relation: EdgeRelation,
689 weight: f32,
690 metadata: Metadata,
691 causal: Option<Box<hirn_graph::CausalEdgeData>>,
692 ) -> HirnResult<EdgeId> {
693 validate_edge_metadata(&metadata)?;
694
695 let existing = self.get_edges_from(source).await?;
697 if existing.len() >= MAX_EDGES_PER_NODE {
698 return Err(hirn_core::HirnError::InvalidInput(format!(
699 "node {} has reached the maximum of {} edges",
700 source, MAX_EDGES_PER_NODE
701 )));
702 }
703
704 for e in &existing {
706 if e.target == target && e.relation == relation {
707 return Err(hirn_core::HirnError::AlreadyExists(format!(
708 "edge {source} -[{relation:?}]-> {target} already exists"
709 )));
710 }
711 }
712
713 let now = Timestamp::now();
714 let id = MemoryId::new();
715
716 let ns = match self.get_node(source).await? {
718 Some(n) => n.namespace,
719 None => Namespace::default(),
720 };
721
722 let edge = GraphEdge {
723 id,
724 source,
725 target,
726 relation,
727 weight: weight.clamp(0.01, 1.0),
728 co_retrieval_count: 0,
729 created_at: now,
730 updated_at: now,
731 valid_from: None,
732 valid_until: None,
733 metadata,
734 resolved: false,
735 namespace: ns,
736 causal,
737 };
738
739 let batch = graph::edges_to_batch(&[edge])?;
740 self.storage
741 .merge_insert(DATASET_EDGES_NAME, &["id"], batch)
742 .await?;
743
744 Ok(id)
745 }
746
747 pub async fn get_edges_from(&self, source: MemoryId) -> HirnResult<Vec<GraphEdge>> {
749 if !self.storage.exists(DATASET_EDGES_NAME).await? {
750 return Ok(vec![]);
751 }
752 self.scan_edges(ScanOptions {
753 columns: None,
754 filter: None,
755 exact_filter: Some(Self::source_exact_filter(source)),
756 order_by: None,
757 limit: None,
758 offset: None,
759 })
760 .await
761 }
762
763 pub async fn get_edges_to(&self, target: MemoryId) -> HirnResult<Vec<GraphEdge>> {
765 if !self.storage.exists(DATASET_EDGES_NAME).await? {
766 return Ok(vec![]);
767 }
768 self.scan_edges(ScanOptions {
769 columns: None,
770 filter: None,
771 exact_filter: Some(Self::target_exact_filter(target)),
772 order_by: None,
773 limit: None,
774 offset: None,
775 })
776 .await
777 }
778
779 pub async fn get_edges(&self, node_id: MemoryId) -> HirnResult<Vec<GraphEdge>> {
781 if !self.storage.exists(DATASET_EDGES_NAME).await? {
782 return Ok(vec![]);
783 }
784 let id_str = node_id.to_string();
785 self.scan_edges(ScanOptions {
786 columns: None,
787 filter: None,
788 exact_filter: Some(ExactMatchFilter::utf8_multi_column_or(
789 vec!["source".to_string(), "target".to_string()],
790 &id_str,
791 )),
792 order_by: None,
793 limit: None,
794 offset: None,
795 })
796 .await
797 }
798
799 pub async fn get_edges_between(&self, a: MemoryId, b: MemoryId) -> HirnResult<Vec<GraphEdge>> {
801 if !self.storage.exists(DATASET_EDGES_NAME).await? {
802 return Ok(vec![]);
803 }
804 let a_str = a.to_string();
805 let b_str = b.to_string();
806 self.scan_edges(ScanOptions {
807 columns: None,
808 filter: Some(format!(
809 "(source = '{a_str}' AND target = '{b_str}') OR (source = '{b_str}' AND target = '{a_str}')"
810 )),
811 exact_filter: None,
812 order_by: None,
813 limit: None,
814 offset: None,
815 })
816 .await
817 }
818
819 pub async fn get_edges_of_type(
821 &self,
822 node_id: MemoryId,
823 relation: EdgeRelation,
824 ) -> HirnResult<Vec<GraphEdge>> {
825 if !self.storage.exists(DATASET_EDGES_NAME).await? {
826 return Ok(vec![]);
827 }
828 let id_str = node_id.to_string();
829 let rel_str = edge_relation_to_str(relation);
830 self.scan_edges(ScanOptions {
831 columns: None,
832 filter: Some(format!(
833 "(source = '{id_str}' OR target = '{id_str}') AND relation = '{rel_str}'"
834 )),
835 exact_filter: None,
836 order_by: None,
837 limit: None,
838 offset: None,
839 })
840 .await
841 }
842
843 pub async fn update_edge_weight(
845 &self,
846 edge_id: EdgeId,
847 new_weight: f32,
848 co_retrieval_count: Option<u64>,
849 ) -> HirnResult<()> {
850 if let Some(mut edge) = self.get_edges_by_ids(&[edge_id]).await?.into_iter().next() {
851 edge.weight = new_weight.clamp(0.01, 1.0);
852 edge.updated_at = Timestamp::now();
853 if let Some(count) = co_retrieval_count {
854 edge.co_retrieval_count = count;
855 }
856 self.upsert_edges(&[edge]).await?;
857 }
858 Ok(())
859 }
860
861 pub async fn get_edges_by_ids(&self, edge_ids: &[EdgeId]) -> HirnResult<Vec<GraphEdge>> {
863 if edge_ids.is_empty() {
864 return Ok(vec![]);
865 }
866 if !self.storage.exists(DATASET_EDGES_NAME).await? {
867 return Ok(vec![]);
868 }
869
870 let unique_ids: Vec<EdgeId> = edge_ids
871 .iter()
872 .copied()
873 .collect::<HashSet<_>>()
874 .into_iter()
875 .collect();
876 let predicate = format!("id IN ({})", Self::quoted_in_values(&unique_ids).join(", "));
877 self.scan_edges(ScanOptions {
878 columns: None,
879 filter: Some(predicate),
880 exact_filter: None,
881 order_by: None,
882 limit: None,
883 offset: None,
884 })
885 .await
886 }
887
888 pub async fn get_edges_for_nodes(&self, node_ids: &[MemoryId]) -> HirnResult<Vec<GraphEdge>> {
890 if node_ids.is_empty() {
891 return Ok(vec![]);
892 }
893 if !self.storage.exists(DATASET_EDGES_NAME).await? {
894 return Ok(vec![]);
895 }
896
897 let unique_ids: Vec<MemoryId> = node_ids
898 .iter()
899 .copied()
900 .collect::<HashSet<_>>()
901 .into_iter()
902 .collect();
903 let in_values = Self::quoted_in_values(&unique_ids).join(", ");
904 self.scan_edges(ScanOptions {
905 columns: None,
906 filter: Some(format!(
907 "source IN ({in_values}) OR target IN ({in_values})"
908 )),
909 exact_filter: None,
910 order_by: None,
911 limit: None,
912 offset: None,
913 })
914 .await
915 }
916
917 pub async fn get_edge(&self, edge_id: EdgeId) -> HirnResult<Option<GraphEdge>> {
919 if !self.storage.exists(DATASET_EDGES_NAME).await? {
920 return Ok(None);
921 }
922 let id_str = edge_id.to_string();
923 let edges = self
924 .scan_edges(ScanOptions {
925 columns: None,
926 filter: None,
927 exact_filter: Some(ExactMatchFilter::utf8_value("id", id_str)),
928 order_by: None,
929 limit: Some(1),
930 offset: None,
931 })
932 .await?;
933
934 Ok(edges.into_iter().next())
935 }
936
937 pub async fn remove_edge(&self, edge_id: EdgeId) -> HirnResult<()> {
939 let id_str = edge_id.to_string();
940 let exact_filter = ExactMatchFilter::utf8_value("id", id_str);
941 self.storage
942 .delete_exact(DATASET_EDGES_NAME, &exact_filter)
943 .await?;
944 Ok(())
945 }
946
947 pub async fn edge_count(&self) -> HirnResult<u64> {
953 if !self.storage.exists(DATASET_EDGES_NAME).await? {
954 return Ok(0);
955 }
956 let now_ms = hirn_core::timestamp::Timestamp::now().timestamp_ms();
958 let active_filter = format!(
959 "valid_until_ms IS NULL OR valid_until_ms = 0 OR valid_until_ms > {now_ms}"
960 );
961 self.storage
962 .count(DATASET_EDGES_NAME, Some(&active_filter))
963 .await
964 .map_err(Into::into)
965 }
966
967 pub async fn add_edges(&self, edges: &[GraphEdge]) -> HirnResult<()> {
969 self.upsert_edges(edges).await
970 }
971
972 pub async fn upsert_edges(&self, edges: &[GraphEdge]) -> HirnResult<()> {
974 if edges.is_empty() {
975 return Ok(());
976 }
977 let batch = graph::edges_to_batch(edges)?;
978 self.storage
979 .merge_insert(DATASET_EDGES_NAME, &["id"], batch)
980 .await?;
981 Ok(())
982 }
983
984 pub async fn outgoing_weighted(
989 &self,
990 node_id: MemoryId,
991 ) -> HirnResult<Vec<(MemoryId, f32, EdgeRelation)>> {
992 let edges = self.get_edges_from(node_id).await?;
993 Ok(edges
994 .into_iter()
995 .map(|e| (e.target, e.weight, e.relation))
996 .collect())
997 }
998
999 pub async fn batch_adjacency_read(&self, frontier: &[MemoryId]) -> HirnResult<Vec<GraphEdge>> {
1002 self.batch_adjacency_read_scoped(frontier, None).await
1003 }
1004
1005 pub async fn batch_adjacency_read_scoped(
1011 &self,
1012 frontier: &[MemoryId],
1013 allowed_namespaces: Option<&[Namespace]>,
1014 ) -> HirnResult<Vec<GraphEdge>> {
1015 if frontier.is_empty() {
1016 return Ok(vec![]);
1017 }
1018 if allowed_namespaces.is_some_and(<[Namespace]>::is_empty) {
1019 return Ok(vec![]);
1020 }
1021 if !self.storage.exists(DATASET_EDGES_NAME).await? {
1022 return Ok(vec![]);
1023 }
1024
1025 let mut predicate = format!(
1026 "source IN ({})",
1027 Self::quoted_in_values(frontier).join(", ")
1028 );
1029 if let Some(allowed_namespaces) = allowed_namespaces {
1030 predicate.push_str(" AND namespace IN (");
1031 predicate.push_str(&Self::quoted_namespace_values(allowed_namespaces).join(", "));
1032 predicate.push(')');
1033 }
1034
1035 let edges = self
1036 .scan_edges(ScanOptions {
1037 columns: None,
1038 filter: Some(predicate),
1039 exact_filter: None,
1040 order_by: None,
1041 limit: None,
1042 offset: None,
1043 })
1044 .await?;
1045
1046 self.filter_edges_by_target_namespace(edges, allowed_namespaces)
1047 .await
1048 }
1049
1050 pub async fn batch_adjacency_read_filtered(
1056 &self,
1057 frontier: &[MemoryId],
1058 relation: EdgeRelation,
1059 ) -> HirnResult<Vec<GraphEdge>> {
1060 self.batch_adjacency_read_filtered_scoped(frontier, relation, None)
1061 .await
1062 }
1063
1064 pub async fn batch_adjacency_read_filtered_scoped(
1066 &self,
1067 frontier: &[MemoryId],
1068 relation: EdgeRelation,
1069 allowed_namespaces: Option<&[Namespace]>,
1070 ) -> HirnResult<Vec<GraphEdge>> {
1071 if frontier.is_empty() {
1072 return Ok(vec![]);
1073 }
1074 if allowed_namespaces.is_some_and(<[Namespace]>::is_empty) {
1075 return Ok(vec![]);
1076 }
1077 if !self.storage.exists(DATASET_EDGES_NAME).await? {
1078 return Ok(vec![]);
1079 }
1080
1081 let rel_str = edge_relation_to_str(relation);
1082 let mut predicate = format!(
1083 "source IN ({}) AND relation = '{rel_str}'",
1084 Self::quoted_in_values(frontier).join(", ")
1085 );
1086 if let Some(allowed_namespaces) = allowed_namespaces {
1087 predicate.push_str(" AND namespace IN (");
1088 predicate.push_str(&Self::quoted_namespace_values(allowed_namespaces).join(", "));
1089 predicate.push(')');
1090 }
1091
1092 let edges = self
1093 .scan_edges(ScanOptions {
1094 columns: None,
1095 filter: Some(predicate),
1096 exact_filter: None,
1097 order_by: None,
1098 limit: None,
1099 offset: None,
1100 })
1101 .await?;
1102
1103 self.filter_edges_by_target_namespace(edges, allowed_namespaces)
1104 .await
1105 }
1106
1107 async fn filter_edges_by_target_namespace(
1108 &self,
1109 edges: Vec<GraphEdge>,
1110 allowed_namespaces: Option<&[Namespace]>,
1111 ) -> HirnResult<Vec<GraphEdge>> {
1112 let Some(allowed_namespaces) = allowed_namespaces else {
1113 return Ok(edges);
1114 };
1115 if allowed_namespaces.is_empty() || edges.is_empty() {
1116 return Ok(vec![]);
1117 }
1118
1119 let mut namespace_cache = HashMap::new();
1120 let mut visible_edges = Vec::with_capacity(edges.len());
1121 for edge in edges {
1122 if let std::collections::hash_map::Entry::Vacant(entry) =
1123 namespace_cache.entry(edge.target)
1124 {
1125 let is_visible = self
1126 .node_namespace(edge.target)
1127 .await?
1128 .is_some_and(|namespace| allowed_namespaces.contains(&namespace));
1129 entry.insert(is_visible);
1130 }
1131 if namespace_cache.get(&edge.target).copied().unwrap_or(false) {
1132 visible_edges.push(edge);
1133 }
1134 }
1135
1136 Ok(visible_edges)
1137 }
1138
1139 async fn filter_node_ids_by_namespace(
1140 &self,
1141 ids: &[MemoryId],
1142 allowed_namespaces: Option<&[Namespace]>,
1143 ) -> HirnResult<Vec<MemoryId>> {
1144 let Some(allowed_namespaces) = allowed_namespaces else {
1145 return Ok(ids.to_vec());
1146 };
1147 if allowed_namespaces.is_empty() || ids.is_empty() {
1148 return Ok(vec![]);
1149 }
1150
1151 let mut visible = Vec::with_capacity(ids.len());
1152 for &id in ids {
1153 if self
1154 .node_namespace(id)
1155 .await?
1156 .is_some_and(|namespace| allowed_namespaces.contains(&namespace))
1157 {
1158 visible.push(id);
1159 }
1160 }
1161
1162 Ok(visible)
1163 }
1164
1165 pub async fn batch_bfs(
1174 &self,
1175 start_ids: &[MemoryId],
1176 max_depth: usize,
1177 ) -> HirnResult<BfsResult> {
1178 self.batch_bfs_filtered(start_ids, max_depth, None).await
1179 }
1180
1181 pub async fn batch_bfs_filtered(
1183 &self,
1184 start_ids: &[MemoryId],
1185 max_depth: usize,
1186 relation: Option<EdgeRelation>,
1187 ) -> HirnResult<BfsResult> {
1188 self.batch_bfs_filtered_scoped(start_ids, max_depth, relation, None)
1189 .await
1190 }
1191
1192 pub async fn batch_bfs_filtered_scoped(
1194 &self,
1195 start_ids: &[MemoryId],
1196 max_depth: usize,
1197 relation: Option<EdgeRelation>,
1198 allowed_namespaces: Option<&[Namespace]>,
1199 ) -> HirnResult<BfsResult> {
1200 use std::collections::HashSet;
1201
1202 let start_ids = self
1203 .filter_node_ids_by_namespace(start_ids, allowed_namespaces)
1204 .await?;
1205 let mut visited: HashSet<MemoryId> = start_ids.iter().copied().collect();
1206 let mut depths: Vec<Vec<GraphEdge>> = Vec::with_capacity(max_depth);
1207 let mut frontier: Vec<MemoryId> = start_ids;
1208
1209 for _ in 0..max_depth {
1210 if frontier.is_empty() {
1211 break;
1212 }
1213
1214 let edges = match relation {
1215 Some(rel) => {
1216 self.batch_adjacency_read_filtered_scoped(&frontier, rel, allowed_namespaces)
1217 .await?
1218 }
1219 None => {
1220 self.batch_adjacency_read_scoped(&frontier, allowed_namespaces)
1221 .await?
1222 }
1223 };
1224
1225 let mut next_frontier = Vec::new();
1226 let mut depth_edges = Vec::new();
1227
1228 for edge in edges {
1229 depth_edges.push(edge.clone());
1230 if visited.insert(edge.target) {
1231 next_frontier.push(edge.target);
1232 }
1233 }
1234
1235 depths.push(depth_edges);
1236 frontier = next_frontier;
1237 }
1238
1239 Ok(BfsResult {
1240 depths,
1241 visited: visited.into_iter().collect(),
1242 })
1243 }
1244
    /// Enumerates causal chains reachable from `start_ids` and returns them
    /// flattened into one row per chain edge.
    ///
    /// A scoped, relation-filtered BFS first collects the candidate edges.
    /// Then, for every seed in `start_ids`, an explicit-stack depth-first walk
    /// extends chains along edges whose confidence is at least
    /// `confidence_threshold` and whose target is not already on the chain.
    /// A chain is emitted when it reaches `max_depth` edges or cannot be
    /// extended; empty chains are never emitted. Chain ids are numbered in
    /// emission order, which depends on the stack's pop order.
    ///
    /// # Errors
    /// Propagates the errors of `batch_bfs_filtered_scoped`.
    pub async fn deep_causal_bfs(
        &self,
        start_ids: &[MemoryId],
        max_depth: usize,
        confidence_threshold: f32,
        relation: EdgeRelation,
        allowed_namespaces: Option<&[Namespace]>,
    ) -> HirnResult<Vec<CausalBfsRow>> {
        use std::collections::{HashMap, HashSet};

        let bfs = self
            .batch_bfs_filtered_scoped(start_ids, max_depth, Some(relation), allowed_namespaces)
            .await?;

        // Index the edges found by the BFS by their source node.
        let mut adjacency: HashMap<MemoryId, Vec<&GraphEdge>> = HashMap::new();
        for depth_edges in &bfs.depths {
            for edge in depth_edges {
                adjacency.entry(edge.source).or_default().push(edge);
            }
        }

        let mut rows = Vec::new();
        // Shared across all seeds so chain ids are unique within the result.
        let mut chain_counter = 0_u32;

        for &seed in start_ids {
            // Each stack entry: (current node, edges walked so far, the chain
            // itself, nodes already on this chain).
            let mut stack: Vec<(MemoryId, usize, Vec<CausalBfsEdge>, HashSet<MemoryId>)> = vec![{
                let mut visited = HashSet::new();
                visited.insert(seed);
                (seed, 0, Vec::new(), visited)
            }];

            while let Some((node, depth, chain_edges, visited)) = stack.pop() {
                // Depth limit reached: emit what has been built and stop extending.
                if depth >= max_depth {
                    if !chain_edges.is_empty() {
                        emit_causal_rows(&chain_edges, &mut rows, &mut chain_counter);
                    }
                    continue;
                }

                let neighbors = adjacency.get(&node);
                // Candidate extensions: confident enough and not revisiting a
                // node already on this chain. Missing confidence counts as 0.5.
                let causal: Vec<&GraphEdge> = neighbors
                    .map(|edges| {
                        edges
                            .iter()
                            .filter(|e| {
                                let conf = e.confidence().unwrap_or(0.5);
                                conf >= confidence_threshold && !visited.contains(&e.target)
                            })
                            .copied()
                            .collect()
                    })
                    .unwrap_or_default();

                // Dead end: the chain is complete as it stands.
                if causal.is_empty() {
                    if !chain_edges.is_empty() {
                        emit_causal_rows(&chain_edges, &mut rows, &mut chain_counter);
                    }
                    continue;
                }

                // Branch: every candidate edge continues its own copy of the chain.
                for edge in causal {
                    let mut new_chain = chain_edges.clone();
                    new_chain.push(CausalBfsEdge {
                        source: edge.source,
                        target: edge.target,
                        // Fallbacks when the edge lacks causal data: strength
                        // falls back to the edge weight, confidence to 0.5,
                        // evidence count to 1.
                        strength: edge.strength().unwrap_or(edge.weight),
                        confidence: edge.confidence().unwrap_or(0.5),
                        evidence_count: edge.evidence_count().unwrap_or(1) as u32,
                        mechanism: edge.mechanism().map(str::to_owned),
                    });
                    let mut new_visited = visited.clone();
                    new_visited.insert(edge.target);
                    stack.push((edge.target, depth + 1, new_chain, new_visited));
                }
            }
        }

        Ok(rows)
    }
1339
1340 pub async fn get_neighbors(
1342 &self,
1343 start: MemoryId,
1344 max_depth: usize,
1345 min_weight: f32,
1346 ) -> HirnResult<Vec<MemoryId>> {
1347 self.get_neighbors_filtered(start, max_depth, min_weight, None)
1348 .await
1349 }
1350
1351 pub async fn get_neighbors_filtered(
1356 &self,
1357 start: MemoryId,
1358 max_depth: usize,
1359 min_weight: f32,
1360 namespace: Option<&Namespace>,
1361 ) -> HirnResult<Vec<MemoryId>> {
1362 use std::collections::HashSet;
1363
1364 let mut visited = HashSet::new();
1365 visited.insert(start);
1366
1367 let mut frontier = vec![start];
1368 let mut result = Vec::new();
1369
1370 for _ in 0..max_depth {
1371 if frontier.is_empty() {
1372 break;
1373 }
1374
1375 let edges = self.batch_adjacency_read(&frontier).await?;
1376 let mut next_frontier = Vec::new();
1377
1378 for edge in edges {
1379 if edge.weight < min_weight {
1380 continue;
1381 }
1382 if visited.contains(&edge.target) {
1383 continue;
1384 }
1385
1386 if let Some(ns) = namespace {
1388 if let Some(node) = self.get_node(edge.target).await? {
1389 let shared = Namespace::shared();
1390 if node.namespace != *ns && node.namespace != shared && *ns != shared {
1391 continue;
1392 }
1393 }
1394 }
1395
1396 visited.insert(edge.target);
1397 result.push(edge.target);
1398 next_frontier.push(edge.target);
1399 }
1400
1401 frontier = next_frontier;
1402 }
1403
1404 Ok(result)
1405 }
1406
1407 pub async fn shortest_path(
1411 &self,
1412 source: MemoryId,
1413 target: MemoryId,
1414 ) -> HirnResult<Option<Vec<MemoryId>>> {
1415 use std::collections::{HashMap as StdMap, HashSet};
1416
1417 if source == target {
1418 return Ok(Some(vec![source]));
1419 }
1420
1421 let mut visited = HashSet::new();
1422 visited.insert(source);
1423 let mut parent: StdMap<MemoryId, MemoryId> = StdMap::new();
1424 let mut frontier = vec![source];
1425
1426 while !frontier.is_empty() {
1427 let edges = self.batch_adjacency_read(&frontier).await?;
1428 let mut next_frontier = Vec::new();
1429
1430 for edge in edges {
1431 if visited.contains(&edge.target) {
1432 continue;
1433 }
1434 parent.insert(edge.target, edge.source);
1435 if edge.target == target {
1436 let mut path = vec![target];
1438 let mut node = target;
1439 while let Some(&prev) = parent.get(&node) {
1440 path.push(prev);
1441 node = prev;
1442 }
1443 path.reverse();
1444 return Ok(Some(path));
1445 }
1446 visited.insert(edge.target);
1447 next_frontier.push(edge.target);
1448 }
1449
1450 frontier = next_frontier;
1451 }
1452 Ok(None)
1453 }
1454
1455 pub async fn subgraph(&self, node_ids: &[MemoryId]) -> HirnResult<Vec<GraphEdge>> {
1460 if node_ids.is_empty() {
1461 return Ok(vec![]);
1462 }
1463
1464 let id_set: std::collections::HashSet<MemoryId> = node_ids.iter().copied().collect();
1465 let all_edges = self.batch_adjacency_read(node_ids).await?;
1466
1467 Ok(all_edges
1468 .into_iter()
1469 .filter(|e| id_set.contains(&e.target))
1470 .collect())
1471 }
1472
1473 pub async fn degree_centrality(&self) -> HirnResult<HashMap<MemoryId, usize>> {
1475 if !self.storage.exists(DATASET_EDGES_NAME).await? {
1476 return Ok(HashMap::new());
1477 }
1478 let mut stream = self
1479 .storage
1480 .scan_stream(
1481 DATASET_EDGES_NAME,
1482 ScanOptions {
1483 columns: Some(vec!["source".into(), "target".into()]),
1484 filter: None,
1485 exact_filter: None,
1486 order_by: None,
1487 limit: None,
1488 offset: None,
1489 },
1490 )
1491 .await?;
1492
1493 let mut degrees: HashMap<MemoryId, usize> = HashMap::new();
1494 while let Some(batch) = stream.try_next().await? {
1495 let src = batch
1496 .column_by_name("source")
1497 .and_then(|c| c.as_any().downcast_ref::<arrow_array::StringArray>());
1498 let tgt = batch
1499 .column_by_name("target")
1500 .and_then(|c| c.as_any().downcast_ref::<arrow_array::StringArray>());
1501 if let (Some(s), Some(t)) = (src, tgt) {
1502 for i in 0..batch.num_rows() {
1503 if let Ok(id) = MemoryId::parse(s.value(i)) {
1504 *degrees.entry(id).or_default() += 1;
1505 }
1506 if let Ok(id) = MemoryId::parse(t.value(i)) {
1507 *degrees.entry(id).or_default() += 1;
1508 }
1509 }
1510 }
1511 }
1512 Ok(degrees)
1513 }
1514
1515 pub async fn path_exists_via(
1519 &self,
1520 source: MemoryId,
1521 target: MemoryId,
1522 allowed_relations: &[EdgeRelation],
1523 ) -> HirnResult<bool> {
1524 use std::collections::HashSet;
1525
1526 if source == target {
1527 return Ok(true);
1528 }
1529
1530 let mut visited = HashSet::new();
1531 visited.insert(source);
1532 let mut frontier = vec![source];
1533
1534 while !frontier.is_empty() {
1535 let edges = self.batch_adjacency_read(&frontier).await?;
1536 let mut next_frontier = Vec::new();
1537
1538 for edge in edges {
1539 if !allowed_relations.contains(&edge.relation) {
1540 continue;
1541 }
1542 if visited.contains(&edge.target) {
1543 continue;
1544 }
1545 if edge.target == target {
1546 return Ok(true);
1547 }
1548 visited.insert(edge.target);
1549 next_frontier.push(edge.target);
1550 }
1551
1552 frontier = next_frontier;
1553 }
1554 Ok(false)
1555 }
1556
1557 pub async fn node_layer(&self, id: MemoryId) -> HirnResult<Option<Layer>> {
1559 Ok(self.get_node(id).await?.map(|n| n.layer))
1560 }
1561
1562 pub async fn node_namespace(&self, id: MemoryId) -> HirnResult<Option<Namespace>> {
1564 Ok(self.get_node(id).await?.map(|n| n.namespace))
1565 }
1566
1567 pub async fn all_edges(&self) -> HirnResult<Vec<GraphEdge>> {
1569 if !self.storage.exists(DATASET_EDGES_NAME).await? {
1570 return Ok(vec![]);
1571 }
1572 let mut batches = self
1573 .storage
1574 .scan_stream(
1575 DATASET_EDGES_NAME,
1576 ScanOptions {
1577 columns: None,
1578 filter: None,
1579 exact_filter: None,
1580 order_by: None,
1581 limit: None,
1582 offset: None,
1583 },
1584 )
1585 .await?;
1586
1587 let mut result = Vec::new();
1588 while let Some(batch) = batches.try_next().await? {
1589 result.extend(graph::edges_from_batch(&batch)?);
1590 }
1591 Ok(result)
1592 }
1593
1594 pub async fn namespaces_compatible(&self, a: MemoryId, b: MemoryId) -> HirnResult<bool> {
1597 let ns_a = self.node_namespace(a).await?;
1598 let ns_b = self.node_namespace(b).await?;
1599 match (ns_a, ns_b) {
1600 (Some(a), Some(b)) => {
1601 let shared = Namespace::shared();
1602 Ok(a == b || a == shared || b == shared)
1603 }
1604 _ => Ok(false),
1605 }
1606 }
1607}
1608
1609fn layer_to_str(l: Layer) -> &'static str {
1610 match l {
1611 Layer::Working => "Working",
1612 Layer::Episodic => "Episodic",
1613 Layer::Semantic => "Semantic",
1614 Layer::Procedural => "Procedural",
1615 }
1616}
1617
1618fn edge_relation_to_str(r: EdgeRelation) -> &'static str {
1619 match r {
1620 EdgeRelation::RelatedTo => "RelatedTo",
1621 EdgeRelation::Causes => "Causes",
1622 EdgeRelation::CausedBy => "CausedBy",
1623 EdgeRelation::DerivedFrom => "DerivedFrom",
1624 EdgeRelation::Contradicts => "Contradicts",
1625 EdgeRelation::Supports => "Supports",
1626 EdgeRelation::TemporalNext => "TemporalNext",
1627 EdgeRelation::PartOf => "PartOf",
1628 EdgeRelation::InstanceOf => "InstanceOf",
1629 EdgeRelation::SimilarTo => "SimilarTo",
1630 EdgeRelation::Inhibits => "Inhibits",
1631 EdgeRelation::ParticipatesIn => "ParticipatesIn",
1632 }
1633}
1634
1635use crate::graph_store::GraphStore;
1638use async_trait::async_trait;
1639
1640#[async_trait]
1641impl GraphStore for PersistentGraph {
1642 async fn add_node(
1643 &self,
1644 id: MemoryId,
1645 layer: Layer,
1646 importance: f32,
1647 created_at: Timestamp,
1648 namespace: Namespace,
1649 ) -> HirnResult<bool> {
1650 self.add_node(id, layer, importance, created_at, namespace)
1651 .await
1652 }
1653
1654 async fn remove_node(&self, id: MemoryId) -> HirnResult<bool> {
1655 self.remove_node(id).await
1656 }
1657
1658 async fn has_node(&self, id: MemoryId) -> HirnResult<bool> {
1659 self.has_node(id).await
1660 }
1661
1662 async fn get_node(&self, id: MemoryId) -> HirnResult<Option<GraphNodeData>> {
1663 self.get_node(id).await
1664 }
1665
1666 async fn node_ids(&self) -> HirnResult<Vec<MemoryId>> {
1667 self.node_ids().await
1668 }
1669
1670 async fn node_importance(&self, id: MemoryId) -> HirnResult<Option<f32>> {
1671 self.node_importance(id).await
1672 }
1673
1674 async fn set_node_importance(&self, id: MemoryId, importance: f32) -> HirnResult<()> {
1675 self.set_node_importance(id, importance).await
1676 }
1677
1678 async fn node_layer(&self, id: MemoryId) -> HirnResult<Option<Layer>> {
1679 self.node_layer(id).await
1680 }
1681
1682 async fn node_namespace(&self, id: MemoryId) -> HirnResult<Option<Namespace>> {
1683 self.node_namespace(id).await
1684 }
1685
1686 async fn namespaces_compatible(&self, a: MemoryId, b: MemoryId) -> HirnResult<bool> {
1687 self.namespaces_compatible(a, b).await
1688 }
1689
1690 async fn add_edge(
1691 &self,
1692 source: MemoryId,
1693 target: MemoryId,
1694 relation: EdgeRelation,
1695 weight: f32,
1696 metadata: Metadata,
1697 ) -> HirnResult<EdgeId> {
1698 self.add_edge(source, target, relation, weight, metadata)
1699 .await
1700 }
1701
1702 async fn add_causal_edge(
1703 &self,
1704 source: MemoryId,
1705 target: MemoryId,
1706 relation: EdgeRelation,
1707 weight: f32,
1708 metadata: Metadata,
1709 causal: hirn_graph::CausalEdgeData,
1710 ) -> HirnResult<EdgeId> {
1711 self.add_causal_edge(source, target, relation, weight, metadata, causal)
1712 .await
1713 }
1714
1715 async fn remove_edge(&self, edge_id: EdgeId) -> HirnResult<()> {
1716 self.remove_edge(edge_id).await
1717 }
1718
1719 async fn get_edge(&self, edge_id: EdgeId) -> HirnResult<Option<GraphEdge>> {
1720 self.get_edge(edge_id).await
1721 }
1722
1723 async fn get_edges(&self, node_id: MemoryId) -> HirnResult<Vec<GraphEdge>> {
1724 self.get_edges(node_id).await
1725 }
1726
1727 async fn get_edges_between(&self, a: MemoryId, b: MemoryId) -> HirnResult<Vec<GraphEdge>> {
1728 self.get_edges_between(a, b).await
1729 }
1730
1731 async fn get_edges_of_type(
1732 &self,
1733 node_id: MemoryId,
1734 relation: EdgeRelation,
1735 ) -> HirnResult<Vec<GraphEdge>> {
1736 self.get_edges_of_type(node_id, relation).await
1737 }
1738
1739 async fn all_edges(&self) -> HirnResult<Vec<GraphEdge>> {
1740 self.all_edges().await
1741 }
1742
1743 async fn update_edge_weight(
1744 &self,
1745 edge_id: EdgeId,
1746 new_weight: f32,
1747 co_retrieval_count: Option<u64>,
1748 ) -> HirnResult<()> {
1749 self.update_edge_weight(edge_id, new_weight, co_retrieval_count)
1750 .await
1751 }
1752
1753 async fn get_neighbors(
1754 &self,
1755 start: MemoryId,
1756 depth: usize,
1757 min_weight: f32,
1758 ) -> HirnResult<Vec<MemoryId>> {
1759 self.get_neighbors(start, depth, min_weight).await
1760 }
1761
1762 async fn get_neighbors_filtered(
1763 &self,
1764 start: MemoryId,
1765 depth: usize,
1766 min_weight: f32,
1767 namespace: Option<&Namespace>,
1768 ) -> HirnResult<Vec<MemoryId>> {
1769 self.get_neighbors_filtered(start, depth, min_weight, namespace)
1770 .await
1771 }
1772
1773 async fn outgoing_weighted(
1774 &self,
1775 node_id: MemoryId,
1776 ) -> HirnResult<Vec<(MemoryId, f32, EdgeRelation)>> {
1777 self.outgoing_weighted(node_id).await
1778 }
1779
1780 async fn shortest_path(
1781 &self,
1782 source: MemoryId,
1783 target: MemoryId,
1784 ) -> HirnResult<Option<Vec<MemoryId>>> {
1785 self.shortest_path(source, target).await
1786 }
1787
1788 async fn node_count(&self) -> HirnResult<usize> {
1789 self.node_count().await.map(|c| c as usize)
1790 }
1791
1792 async fn edge_count(&self) -> HirnResult<usize> {
1793 self.edge_count().await.map(|c| c as usize)
1794 }
1795}
1796
#[cfg(test)]
mod tests {
    use super::*;
    use hirn_core::metadata::MetadataValue;
    use hirn_graph::MAX_EDGE_METADATA_BYTES;

    /// Fresh in-memory store so every test starts from an empty graph.
    fn dummy_storage() -> Arc<dyn PhysicalStore> {
        Arc::new(hirn_storage::memory_store::MemoryStore::new())
    }

    /// Opening a graph on an empty store reports zero nodes and edges.
    #[tokio::test]
    async fn open_on_empty_storage() {
        let pg = PersistentGraph::open(dummy_storage()).await.unwrap();
        assert_eq!(pg.node_count().await.unwrap(), 0);
        assert_eq!(pg.edge_count().await.unwrap(), 0);
    }

    /// Edge metadata larger than `MAX_EDGE_METADATA_BYTES` is rejected.
    #[tokio::test]
    async fn add_edge_rejects_oversized_metadata() {
        let pg = PersistentGraph::new(dummy_storage());
        let ns = Namespace::default_ns();
        let now = Timestamp::now();
        let a = MemoryId::new();
        let b = MemoryId::new();
        pg.add_node(a, Layer::Episodic, 0.5, now, ns.clone())
            .await
            .unwrap();
        pg.add_node(b, Layer::Episodic, 0.5, now, ns).await.unwrap();

        let mut metadata = Metadata::new();
        metadata.insert(
            "payload".into(),
            MetadataValue::String("x".repeat(MAX_EDGE_METADATA_BYTES + 64)),
        );

        let err = pg
            .add_edge(a, b, EdgeRelation::Causes, 0.8, metadata)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("edge metadata exceeds"));
    }

    /// Builds a linear chain `ids[0] -> ids[1] -> ... -> ids[n-1]` of
    /// `TemporalNext` edges (weight 0.8) over `node_count` episodic nodes in
    /// the default namespace, and returns the graph with the ids in chain
    /// order.
    async fn populated_graph(node_count: usize) -> (PersistentGraph, Vec<MemoryId>) {
        let pg = PersistentGraph::new(dummy_storage());
        let ns = Namespace::default_ns();
        let now = Timestamp::now();
        let mut ids = Vec::with_capacity(node_count);

        for _ in 0..node_count {
            let id = MemoryId::new();
            ids.push(id);
            pg.add_node(id, Layer::Episodic, 0.5, now, ns.clone())
                .await
                .unwrap();
        }

        // `saturating_sub` keeps the zero-node case from underflowing.
        for i in 0..node_count.saturating_sub(1) {
            pg.add_edge(
                ids[i],
                ids[i + 1],
                EdgeRelation::TemporalNext,
                0.8,
                Metadata::default(),
            )
            .await
            .unwrap();
        }

        (pg, ids)
    }

    /// An empty frontier yields no edges.
    #[tokio::test]
    async fn batch_adjacency_read_empty_frontier() {
        let pg = PersistentGraph::new(dummy_storage());
        let result = pg.batch_adjacency_read(&[]).await.unwrap();
        assert!(result.is_empty());
    }

    /// A node without edges yields no edges.
    #[tokio::test]
    async fn batch_adjacency_read_no_edges() {
        let pg = PersistentGraph::new(dummy_storage());
        let ns = Namespace::default_ns();
        let id = MemoryId::new();
        pg.add_node(id, Layer::Episodic, 0.5, Timestamp::now(), ns)
            .await
            .unwrap();
        let result = pg.batch_adjacency_read(&[id]).await.unwrap();
        assert!(result.is_empty());
    }

    /// The head of the chain has exactly one edge, to the second node.
    #[tokio::test]
    async fn batch_adjacency_read_single_node() {
        let (pg, ids) = populated_graph(5).await;
        let result = pg.batch_adjacency_read(&[ids[0]]).await.unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].source, ids[0]);
        assert_eq!(result[0].target, ids[1]);
    }

    /// A three-node frontier on the chain returns the three edges leaving it.
    #[tokio::test]
    async fn batch_adjacency_read_multiple_nodes() {
        let (pg, ids) = populated_graph(5).await;
        let frontier = vec![ids[0], ids[1], ids[2]];
        let result = pg.batch_adjacency_read(&frontier).await.unwrap();
        assert_eq!(result.len(), 3);

        let targets: std::collections::HashSet<MemoryId> =
            result.iter().map(|e| e.target).collect();
        assert!(targets.contains(&ids[1]));
        assert!(targets.contains(&ids[2]));
        assert!(targets.contains(&ids[3]));
    }

    /// The relation filter selects only edges of the requested relation.
    #[tokio::test]
    async fn batch_adjacency_read_filtered_by_relation() {
        let (pg, ids) = populated_graph(5).await;
        // Add a second, differently-typed edge out of the chain head.
        pg.add_edge(
            ids[0],
            ids[3],
            EdgeRelation::Causes,
            0.9,
            Metadata::default(),
        )
        .await
        .unwrap();

        let result = pg
            .batch_adjacency_read_filtered(&[ids[0]], EdgeRelation::Causes)
            .await
            .unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].target, ids[3]);

        let result = pg
            .batch_adjacency_read_filtered(&[ids[0]], EdgeRelation::TemporalNext)
            .await
            .unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].target, ids[1]);
    }

    /// Depth 0 visits only the start node and records no edge levels.
    #[tokio::test]
    async fn batch_bfs_depth_zero() {
        let (pg, ids) = populated_graph(5).await;
        let result = pg.batch_bfs(&[ids[0]], 0).await.unwrap();
        assert!(result.depths.is_empty());
        assert_eq!(result.visited.len(), 1);
        assert!(result.visited.contains(&ids[0]));
    }

    /// Depth 1 on the chain reaches exactly the second node.
    #[tokio::test]
    async fn batch_bfs_depth_one() {
        let (pg, ids) = populated_graph(5).await;
        let result = pg.batch_bfs(&[ids[0]], 1).await.unwrap();
        assert_eq!(result.depths.len(), 1);
        assert_eq!(result.depths[0].len(), 1);
        assert_eq!(result.depths[0][0].target, ids[1]);
        assert_eq!(result.visited.len(), 2);
    }

    /// Depth 2 on the chain records one edge per level and three visits.
    #[tokio::test]
    async fn batch_bfs_depth_two() {
        let (pg, ids) = populated_graph(5).await;
        let result = pg.batch_bfs(&[ids[0]], 2).await.unwrap();
        assert_eq!(result.depths.len(), 2);
        assert_eq!(result.depths[0].len(), 1);
        assert_eq!(result.depths[1].len(), 1);
        assert_eq!(result.visited.len(), 3);
    }

    /// Two start nodes are expanded together within the same level.
    #[tokio::test]
    async fn batch_bfs_multiple_start_nodes() {
        let (pg, ids) = populated_graph(10).await;
        let result = pg.batch_bfs(&[ids[0], ids[5]], 1).await.unwrap();
        assert_eq!(result.depths.len(), 1);
        assert_eq!(result.depths[0].len(), 2);
        assert_eq!(result.visited.len(), 4);
    }

    /// A three-node cycle terminates even with a generous depth budget.
    #[tokio::test]
    async fn batch_bfs_cycle_terminates() {
        let pg = PersistentGraph::new(dummy_storage());
        let ns = Namespace::default_ns();
        let now = Timestamp::now();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();
        pg.add_node(a, Layer::Episodic, 0.5, now, ns.clone())
            .await
            .unwrap();
        pg.add_node(b, Layer::Episodic, 0.5, now, ns.clone())
            .await
            .unwrap();
        pg.add_node(c, Layer::Episodic, 0.5, now, ns).await.unwrap();

        // a -> b -> c -> a
        pg.add_edge(a, b, EdgeRelation::Causes, 0.8, Metadata::default())
            .await
            .unwrap();
        pg.add_edge(b, c, EdgeRelation::Causes, 0.8, Metadata::default())
            .await
            .unwrap();
        pg.add_edge(c, a, EdgeRelation::Causes, 0.8, Metadata::default())
            .await
            .unwrap();

        let result = pg.batch_bfs(&[a], 10).await.unwrap();
        assert_eq!(result.visited.len(), 3);
        assert!(result.depths.len() <= 3);
    }

    /// A node with no connecting edge is never visited.
    #[tokio::test]
    async fn batch_bfs_disconnected_graph() {
        let pg = PersistentGraph::new(dummy_storage());
        let ns = Namespace::default_ns();
        let now = Timestamp::now();
        let a = MemoryId::new();
        let b = MemoryId::new();
        // `c` is added as a node but never linked to anything.
        let c = MemoryId::new();
        pg.add_node(a, Layer::Episodic, 0.5, now, ns.clone())
            .await
            .unwrap();
        pg.add_node(b, Layer::Episodic, 0.5, now, ns.clone())
            .await
            .unwrap();
        pg.add_node(c, Layer::Episodic, 0.5, now, ns).await.unwrap();
        pg.add_edge(a, b, EdgeRelation::Causes, 0.8, Metadata::default())
            .await
            .unwrap();

        let result = pg.batch_bfs(&[a], 5).await.unwrap();
        assert!(result.visited.contains(&a));
        assert!(result.visited.contains(&b));
        assert!(!result.visited.contains(&c));
    }

    /// With a `Causes` filter, a `TemporalNext` edge is not followed.
    #[tokio::test]
    async fn batch_bfs_filtered_causal_only() {
        let pg = PersistentGraph::new(dummy_storage());
        let ns = Namespace::default_ns();
        let now = Timestamp::now();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();
        pg.add_node(a, Layer::Episodic, 0.5, now, ns.clone())
            .await
            .unwrap();
        pg.add_node(b, Layer::Episodic, 0.5, now, ns.clone())
            .await
            .unwrap();
        pg.add_node(c, Layer::Episodic, 0.5, now, ns).await.unwrap();
        pg.add_edge(a, b, EdgeRelation::Causes, 0.8, Metadata::default())
            .await
            .unwrap();
        pg.add_edge(a, c, EdgeRelation::TemporalNext, 0.8, Metadata::default())
            .await
            .unwrap();

        let result = pg
            .batch_bfs_filtered(&[a], 2, Some(EdgeRelation::Causes))
            .await
            .unwrap();
        assert!(result.visited.contains(&b));
        assert!(!result.visited.contains(&c));
    }

    /// Scoped BFS stops at a hidden-namespace node: with `a -> b -> c` and
    /// only `b` hidden, neither `b` nor the visible `c` behind it is reached.
    #[tokio::test]
    async fn batch_bfs_filtered_scoped_blocks_hidden_targets() {
        let pg = PersistentGraph::new(dummy_storage());
        let visible_ns = Namespace::new("visible").unwrap();
        let hidden_ns = Namespace::new("hidden").unwrap();
        let now = Timestamp::now();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();
        // `visible_ns` is reused below, so clone it for each by-value use
        // (same clone-then-move pattern as the other tests).
        pg.add_node(a, Layer::Episodic, 0.5, now, visible_ns.clone())
            .await
            .unwrap();
        pg.add_node(b, Layer::Episodic, 0.5, now, hidden_ns)
            .await
            .unwrap();
        pg.add_node(c, Layer::Episodic, 0.5, now, visible_ns.clone())
            .await
            .unwrap();
        pg.add_edge(a, b, EdgeRelation::Causes, 0.8, Metadata::default())
            .await
            .unwrap();
        pg.add_edge(b, c, EdgeRelation::Causes, 0.8, Metadata::default())
            .await
            .unwrap();

        let result = pg
            .batch_bfs_filtered_scoped(&[a], 3, Some(EdgeRelation::Causes), Some(&[visible_ns]))
            .await
            .unwrap();

        assert!(result.visited.contains(&a));
        assert!(!result.visited.contains(&b));
        assert!(!result.visited.contains(&c));
        assert_eq!(result.total_edges(), 0);
    }

    /// Scoped causal BFS emits no rows when the only route passes through a
    /// hidden-namespace node.
    #[tokio::test]
    async fn deep_causal_bfs_scoped_does_not_traverse_hidden_bridges() {
        let pg = PersistentGraph::new(dummy_storage());
        let visible_ns = Namespace::new("visible_causal").unwrap();
        let hidden_ns = Namespace::new("hidden_causal").unwrap();
        let now = Timestamp::now();
        let a = MemoryId::new();
        let b = MemoryId::new();
        let c = MemoryId::new();
        // `visible_ns` is reused below, so clone it for each by-value use.
        pg.add_node(a, Layer::Episodic, 0.5, now, visible_ns.clone())
            .await
            .unwrap();
        pg.add_node(b, Layer::Episodic, 0.5, now, hidden_ns)
            .await
            .unwrap();
        pg.add_node(c, Layer::Episodic, 0.5, now, visible_ns.clone())
            .await
            .unwrap();
        pg.add_edge(a, b, EdgeRelation::Causes, 0.9, Metadata::default())
            .await
            .unwrap();
        pg.add_edge(b, c, EdgeRelation::Causes, 0.9, Metadata::default())
            .await
            .unwrap();

        let rows = pg
            .deep_causal_bfs(&[a], 3, 0.0, EdgeRelation::Causes, Some(&[visible_ns]))
            .await
            .unwrap();

        assert!(rows.is_empty());
    }

    /// `all_targets` contains every node reached within three hops.
    #[tokio::test]
    async fn bfs_result_all_targets() {
        let (pg, ids) = populated_graph(5).await;
        let result = pg.batch_bfs(&[ids[0]], 3).await.unwrap();
        let targets = result.all_targets();
        assert!(targets.contains(&ids[1]));
        assert!(targets.contains(&ids[2]));
        assert!(targets.contains(&ids[3]));
    }

    /// A full walk of the five-node chain records its four edges.
    #[tokio::test]
    async fn bfs_result_total_edges() {
        let (pg, ids) = populated_graph(5).await;
        let result = pg.batch_bfs(&[ids[0]], 4).await.unwrap();
        assert_eq!(result.total_edges(), 4);
    }
}