Skip to main content

hirn_engine/graph/
cached_graph_store.rs

1//! Two-tier graph store: in-memory hot cache backed by persistent cold tier.
2//!
//! Most read operations (`get_edges`, `neighbors`, `outgoing_weighted`,
//! spreading activation, PPR, Hebbian) execute on the hot in-memory
//! [`PropertyGraph`] with no I/O; graph-runtime queries whose depth exceeds
//! the caller's delegation threshold are delegated to the cold tier instead.
//! Write operations update the hot tier first, then write through to the
//! cold [`PersistentGraph`] (Lance datasets).
7//!
8//! ## Lock Ordering
9//!
10//! | Order | Lock | Purpose |
11//! |-------|------|---------|
12//! | 1 | `graph` (`RwLock`) | In-memory `PropertyGraph` |
13//! | 2 | `ns_index` (`RwLock`) | Namespace→node index |
14//!
15//! **Never** acquire `ns_index` before `graph`.
16
17use std::cmp::Ordering;
18use std::collections::{HashMap, HashSet};
19use std::sync::Arc;
20
21use parking_lot::RwLock;
22
23use async_trait::async_trait;
24
25use hirn_core::HirnResult;
26use hirn_core::id::MemoryId;
27use hirn_core::metadata::Metadata;
28use hirn_core::timestamp::Timestamp;
29use hirn_core::types::{EdgeRelation, Layer, Namespace};
30
31use crate::graph::{EdgeId, GraphEdge, GraphNodeData, PropertyGraph};
32use crate::graph_store::GraphStore;
33use crate::persistent_graph::PersistentGraph;
34use hirn_exec::{
35    ActivationMode as ExecActivationMode, GraphActivationOutput, GraphCausalChainRow,
36    GraphReadRuntime, GraphTraverseRow,
37};
38
39#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
40pub(crate) struct EdgeInsert {
41    pub(crate) source: MemoryId,
42    pub(crate) target: MemoryId,
43    pub(crate) relation: EdgeRelation,
44    pub(crate) weight: f32,
45    pub(crate) metadata: Metadata,
46}
47
48/// Two-tier graph: in-memory hot cache + persistent cold tier.
49///
50/// All read operations use the hot tier exclusively (sub-ms latency).
51/// Writes update hot tier synchronously, then flush to the cold tier
52/// asynchronously.
53#[derive(Clone)]
54pub struct CachedGraphStore {
55    /// Hot tier: in-memory property graph.
56    hot: Arc<RwLock<PropertyGraph>>,
57    /// Cold tier: LanceDB-backed persistent graph.
58    cold: Arc<PersistentGraph>,
59}
60
61impl CachedGraphStore {
62    /// Create a new cached graph store backed by the given persistent graph.
63    ///
64    /// The hot tier starts empty. Call [`load_from_cold`](Self::load_from_cold)
65    /// to populate it from storage.
66    pub fn new(cold: Arc<PersistentGraph>) -> Self {
67        Self {
68            hot: Arc::new(RwLock::new(PropertyGraph::new())),
69            cold,
70        }
71    }
72
73    /// Create with a custom max-node capacity for the hot tier.
74    pub fn with_max_nodes(cold: Arc<PersistentGraph>, max_node_count: usize) -> Self {
75        Self {
76            hot: Arc::new(RwLock::new(PropertyGraph::with_max_nodes(max_node_count))),
77            cold,
78        }
79    }
80
81    /// Load the hot tier from the cold tier (startup initialization).
82    ///
83    /// Fetches all nodes and edges from the persistent graph and inserts
84    /// them into the in-memory property graph.
85    pub async fn load_from_cold(&self) -> HirnResult<()> {
86        let all_edges = self.cold.all_edges().await?;
87        let all_node_ids = self.cold.node_ids().await?;
88
89        // Fetch all node data from cold tier *before* acquiring the write lock,
90        // so we don't hold a parking_lot guard across an await.
91        let mut node_data = Vec::with_capacity(all_node_ids.len());
92        for id in &all_node_ids {
93            if let Ok(Some(nd)) = self.cold.get_node(*id).await {
94                node_data.push(nd);
95            }
96        }
97
98        // Now apply everything synchronously under the write lock.
99        let mut graph = self.hot.write();
100
101        for nd in node_data {
102            graph.add_node_ns(
103                nd.id,
104                nd.layer,
105                nd.importance,
106                nd.created_at,
107                nd.namespace.clone(),
108            );
109        }
110
111        for edge in all_edges {
112            // Ensure both endpoints exist in hot tier.
113            if !graph.has_node(edge.source) {
114                graph.add_node(edge.source, Layer::Episodic, 0.5, edge.created_at);
115            }
116            if !graph.has_node(edge.target) {
117                graph.add_node(edge.target, Layer::Episodic, 0.5, edge.created_at);
118            }
119            // add_edge_one_dir to avoid double-reverse (edges already stored in both dirs).
120            let _ = graph.add_edge(
121                edge.source,
122                edge.target,
123                edge.relation,
124                edge.weight,
125                edge.metadata.clone(),
126            );
127        }
128
129        tracing::info!(
130            nodes = graph.node_count(),
131            edges = graph.edge_count(),
132            "CachedGraphStore: hot tier loaded from cold"
133        );
134
135        Ok(())
136    }
137
138    /// Get a read reference to the hot tier for synchronous algorithms
139    /// (spreading activation, PPR, Hebbian).
140    pub fn hot_graph(&self) -> parking_lot::RwLockReadGuard<'_, PropertyGraph> {
141        self.hot.read()
142    }
143
144    /// Get the `Arc<RwLock<PropertyGraph>>` handle for the hot tier.
145    ///
146    /// Used to pass the graph into `HirnSessionExt` so that DataFusion
147    /// operators in `hirn-exec` can downcast and access `PropertyGraph`
148    /// without depending on `hirn-engine`.
149    pub fn hot_arc(&self) -> Arc<RwLock<PropertyGraph>> {
150        self.hot.clone()
151    }
152
153    /// Get a write reference to the hot tier (e.g. for Hebbian flush).
154    pub fn hot_graph_mut(&self) -> parking_lot::RwLockWriteGuard<'_, PropertyGraph> {
155        self.hot.write()
156    }
157
158    /// Reference to the cold tier for direct operations.
159    pub fn cold(&self) -> &PersistentGraph {
160        &self.cold
161    }
162
163    /// Flush hot-tier `access_count` updates to the cold-tier Lance dataset.
164    ///
165    /// Drains the dirty set accumulated by `record_access()` calls and bulk-updates
166    /// the `access_count` column in the `graph_nodes` Lance dataset using a CASE
167    /// expression — one SQL round-trip per 500 nodes instead of one per node.
168    ///
169    /// This is a no-op when no accesses have occurred since the last flush.
170    pub async fn flush_hot_access_counts(&self) -> HirnResult<()> {
171        let dirty = {
172            let mut graph = self.hot.write();
173            graph.drain_dirty_access_counts()
174        };
175        if dirty.is_empty() {
176            return Ok(());
177        }
178        tracing::debug!(dirty_count = dirty.len(), "flushing access counts to cold tier");
179        self.cold.flush_access_counts(&dirty).await
180    }
181
182    /// Spawn a background tokio task that periodically flushes hot-tier access
183    /// counts to the cold tier.
184    ///
185    /// The returned `JoinHandle` can be aborted at shutdown, but the calling
186    /// code may also simply drop it — the task will run until the `Arc`s it
187    /// holds are the last remaining references (i.e. until the store is dropped).
188    pub fn spawn_access_count_flush_task(
189        &self,
190        interval: std::time::Duration,
191    ) -> tokio::task::JoinHandle<()> {
192        let store = self.clone();
193        tokio::spawn(async move {
194            let mut ticker = tokio::time::interval(interval);
195            ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
196            // First tick fires immediately; skip it so we don't flush on startup.
197            ticker.tick().await;
198            loop {
199                ticker.tick().await;
200                if let Err(e) = store.flush_hot_access_counts().await {
201                    tracing::warn!(error = %e, "access_count flush background task failed");
202                }
203            }
204        })
205    }
206
207    /// Add or update multiple nodes with one cold-tier write.
208    pub async fn add_nodes(&self, nodes: &[GraphNodeData]) -> HirnResult<()> {
209        if nodes.is_empty() {
210            return Ok(());
211        }
212
213        let mut inserted_ids = Vec::with_capacity(nodes.len());
214        {
215            let mut graph = self.hot.write();
216            for node in nodes {
217                if graph.add_node_ns(
218                    node.id,
219                    node.layer,
220                    node.importance,
221                    node.created_at,
222                    node.namespace,
223                ) {
224                    inserted_ids.push(node.id);
225                }
226            }
227        }
228
229        if let Err(error) = self.cold.add_nodes(nodes).await {
230            for node in nodes {
231                let _ = self.cold.remove_node(node.id).await;
232            }
233            if !inserted_ids.is_empty() {
234                let mut graph = self.hot.write();
235                for id in inserted_ids {
236                    graph.remove_node(id);
237                }
238            }
239            return Err(error);
240        }
241
242        Ok(())
243    }
244
245    fn created_edges_from_hot(
246        graph: &PropertyGraph,
247        edge_id: EdgeId,
248        source: MemoryId,
249        target: MemoryId,
250        relation: EdgeRelation,
251    ) -> HirnResult<Vec<GraphEdge>> {
252        let mut created_edges =
253            Vec::with_capacity(if relation.is_bidirectional() && source != target {
254                2
255            } else {
256                1
257            });
258
259        let primary = graph.edge_by_id(edge_id).cloned().ok_or_else(|| {
260            hirn_core::HirnError::DatabaseCorrupted(format!(
261                "cached graph missing newly created edge {edge_id}"
262            ))
263        })?;
264        created_edges.push(primary);
265
266        if relation.is_bidirectional() && source != target {
267            let reverse = graph
268                .get_edges_between(target, source)
269                .into_iter()
270                .find(|edge| {
271                    edge.source == target && edge.target == source && edge.relation == relation
272                })
273                .cloned()
274                .ok_or_else(|| {
275                    hirn_core::HirnError::DatabaseCorrupted(format!(
276                        "cached graph missing reverse edge for {source} -[{relation:?}]-> {target}"
277                    ))
278                })?;
279            created_edges.push(reverse);
280        }
281
282        Ok(created_edges)
283    }
284
285    fn rollback_hot_edges(&self, edge_ids: &[EdgeId]) {
286        let mut graph = self.hot.write();
287        for edge_id in edge_ids {
288            let _ = graph.remove_edge(*edge_id);
289        }
290    }
291
292    pub(crate) async fn add_edges_best_effort(
293        &self,
294        requests: &[EdgeInsert],
295    ) -> HirnResult<Vec<(EdgeInsert, EdgeId)>> {
296        if requests.is_empty() {
297            return Ok(Vec::new());
298        }
299
300        let (created, created_edges, rollback_edge_ids, fatal_error) = {
301            let mut graph = self.hot.write();
302            let mut created = Vec::with_capacity(requests.len());
303            let mut created_edges = Vec::with_capacity(requests.len() * 2);
304            let mut rollback_edge_ids = Vec::with_capacity(requests.len() * 2);
305            let mut fatal_error = None;
306
307            for request in requests {
308                match graph.add_edge(
309                    request.source,
310                    request.target,
311                    request.relation,
312                    request.weight,
313                    request.metadata.clone(),
314                ) {
315                    Ok(edge_id) => {
316                        created.push((request.clone(), edge_id));
317                        match Self::created_edges_from_hot(
318                            &graph,
319                            edge_id,
320                            request.source,
321                            request.target,
322                            request.relation,
323                        ) {
324                            Ok(new_edges) => {
325                                rollback_edge_ids.extend(new_edges.iter().map(|edge| edge.id));
326                                created_edges.extend(new_edges);
327                            }
328                            Err(error) => {
329                                fatal_error = Some(error);
330                                break;
331                            }
332                        }
333                    }
334                    Err(
335                        hirn_core::HirnError::AlreadyExists(_)
336                        | hirn_core::HirnError::InvalidInput(_)
337                        | hirn_core::HirnError::NotFound(_),
338                    ) => {}
339                    Err(error) => {
340                        fatal_error = Some(error);
341                        break;
342                    }
343                }
344            }
345
346            (created, created_edges, rollback_edge_ids, fatal_error)
347        };
348
349        if let Some(error) = fatal_error {
350            self.rollback_hot_edges(&rollback_edge_ids);
351            return Err(error);
352        }
353
354        if !created_edges.is_empty() {
355            if let Err(error) = self.cold.add_edges(&created_edges).await {
356                tracing::warn!(
357                    edge_count = created_edges.len(),
358                    error = %error,
359                    "CachedGraphStore: batched cold edge flush failed"
360                );
361            }
362        }
363
364        Ok(created)
365    }
366}
367
368#[async_trait]
369impl GraphReadRuntime for CachedGraphStore {
370    async fn activate_graph(
371        &self,
372        seeds: &[MemoryId],
373        mode: ExecActivationMode,
374        ppr_config: Option<&hirn_graph::PprConfig>,
375        max_depth: u32,
376        epsilon: f32,
377        inhibition_mu: f32,
378        delegation_threshold: usize,
379        allowed_namespaces: Option<&[Namespace]>,
380    ) -> HirnResult<GraphActivationOutput> {
381        if max_depth as usize > delegation_threshold {
382            tracing::debug!(
383                depth = max_depth,
384                delegation_threshold,
385                mode = ?mode,
386                "CachedGraphStore: delegating graph activation to persistent tier"
387            );
388            return self
389                .activate_via_persistent_graph(
390                    seeds,
391                    mode,
392                    ppr_config,
393                    max_depth,
394                    epsilon,
395                    inhibition_mu,
396                    allowed_namespaces,
397                )
398                .await;
399        }
400
401        tracing::trace!(
402            depth = max_depth,
403            delegation_threshold,
404            mode = ?mode,
405            "CachedGraphStore: running graph activation on hot tier"
406        );
407        self.activate_via_hot_graph(
408            seeds,
409            mode,
410            ppr_config,
411            max_depth,
412            epsilon,
413            inhibition_mu,
414            allowed_namespaces,
415        )
416    }
417
418    async fn causal_chain(
419        &self,
420        start_ids: &[MemoryId],
421        max_depth: u32,
422        confidence_threshold: f32,
423        delegation_threshold: usize,
424        relation: EdgeRelation,
425        allowed_namespaces: Option<&[Namespace]>,
426    ) -> HirnResult<Vec<GraphCausalChainRow>> {
427        if start_ids.is_empty() || max_depth == 0 {
428            return Ok(Vec::new());
429        }
430
431        if max_depth as usize > delegation_threshold {
432            tracing::debug!(
433                depth = max_depth,
434                delegation_threshold,
435                relation = ?relation,
436                "CachedGraphStore: delegating causal traversal to persistent tier"
437            );
438            return self
439                .causal_chain_via_persistent_graph(
440                    start_ids,
441                    max_depth,
442                    confidence_threshold,
443                    relation,
444                    allowed_namespaces,
445                )
446                .await;
447        }
448
449        tracing::trace!(
450            depth = max_depth,
451            delegation_threshold,
452            relation = ?relation,
453            "CachedGraphStore: running causal traversal on hot tier"
454        );
455        self.causal_chain_via_hot_graph(
456            start_ids,
457            max_depth,
458            confidence_threshold,
459            relation,
460            allowed_namespaces,
461        )
462        .await
463    }
464
465    async fn traverse_graph(
466        &self,
467        start_ids: &[MemoryId],
468        max_depth: u32,
469        delegation_threshold: usize,
470        relation_filter: Option<&[EdgeRelation]>,
471        allowed_namespaces: Option<&[Namespace]>,
472    ) -> HirnResult<Vec<GraphTraverseRow>> {
473        if start_ids.is_empty() || max_depth == 0 {
474            return Ok(Vec::new());
475        }
476        if matches!(relation_filter, Some([])) {
477            return Ok(Vec::new());
478        }
479
480        if max_depth as usize > delegation_threshold {
481            tracing::debug!(
482                depth = max_depth,
483                delegation_threshold,
484                relation_filter = ?relation_filter,
485                "CachedGraphStore: delegating graph traversal to persistent tier"
486            );
487            return self
488                .traverse_via_persistent_graph(
489                    start_ids,
490                    max_depth,
491                    relation_filter,
492                    allowed_namespaces,
493                )
494                .await;
495        }
496
497        tracing::trace!(
498            depth = max_depth,
499            delegation_threshold,
500            relation_filter = ?relation_filter,
501            "CachedGraphStore: running graph traversal on hot tier"
502        );
503        self.traverse_via_hot_graph(start_ids, max_depth, relation_filter, allowed_namespaces)
504    }
505}
506
507impl CachedGraphStore {
508    fn activate_via_hot_graph(
509        &self,
510        seeds: &[MemoryId],
511        mode: ExecActivationMode,
512        ppr_config: Option<&hirn_graph::PprConfig>,
513        max_depth: u32,
514        epsilon: f32,
515        inhibition_mu: f32,
516        allowed_namespaces: Option<&[Namespace]>,
517    ) -> HirnResult<GraphActivationOutput> {
518        let config = hirn_graph::ActivationConfig {
519            max_depth: max_depth as usize,
520            epsilon: f64::from(epsilon),
521            inhibition_strength: f64::from(inhibition_mu),
522            ..Default::default()
523        };
524        config.validate()?;
525
526        let graph = self.hot_graph();
527        match mode {
528            ExecActivationMode::Static => {
529                let mut entries: Vec<_> =
530                    hirn_graph::static_activation(&graph, seeds, allowed_namespaces)
531                        .into_iter()
532                        .collect();
533                entries
534                    .sort_by(|left, right| right.1.partial_cmp(&left.1).unwrap_or(Ordering::Equal));
535
536                Ok(GraphActivationOutput {
537                    ids: entries
538                        .iter()
539                        .map(|(node_id, _)| node_id.to_string())
540                        .collect(),
541                    scores: entries.iter().map(|(_, score)| *score as f32).collect(),
542                    depths: entries
543                        .iter()
544                        .map(|(node_id, _)| u32::from(!seeds.contains(node_id)))
545                        .collect(),
546                })
547            }
548            ExecActivationMode::Spreading => {
549                let result = hirn_graph::spread_activation(
550                    &graph,
551                    seeds,
552                    &config,
553                    None,
554                    allowed_namespaces,
555                )?;
556                let mut entries: Vec<_> = result.activations.into_iter().collect();
557                entries
558                    .sort_by(|left, right| right.1.partial_cmp(&left.1).unwrap_or(Ordering::Equal));
559
560                Ok(GraphActivationOutput {
561                    ids: entries
562                        .iter()
563                        .map(|(node_id, _)| node_id.to_string())
564                        .collect(),
565                    scores: entries.iter().map(|(_, score)| *score as f32).collect(),
566                    depths: entries
567                        .iter()
568                        .map(|(node_id, _)| {
569                            result
570                                .traces
571                                .get(node_id)
572                                .map(|trace| trace.path.len().saturating_sub(1) as u32)
573                                .unwrap_or(0)
574                        })
575                        .collect(),
576                })
577            }
578            ExecActivationMode::Ppr => {
579                let default_ppr = hirn_graph::PprConfig::default();
580                let ppr_config = ppr_config.unwrap_or(&default_ppr);
581                let mut entries: Vec<_> = hirn_graph::personalized_pagerank(
582                    &graph,
583                    seeds,
584                    ppr_config,
585                    allowed_namespaces,
586                )?
587                .into_iter()
588                .collect();
589                entries
590                    .sort_by(|left, right| right.1.partial_cmp(&left.1).unwrap_or(Ordering::Equal));
591
592                Ok(GraphActivationOutput {
593                    ids: entries
594                        .iter()
595                        .map(|(node_id, _)| node_id.to_string())
596                        .collect(),
597                    scores: entries.iter().map(|(_, score)| *score as f32).collect(),
598                    depths: vec![0; entries.len()],
599                })
600            }
601        }
602    }
603
604    async fn activate_via_persistent_graph(
605        &self,
606        seeds: &[MemoryId],
607        mode: ExecActivationMode,
608        ppr_config: Option<&hirn_graph::PprConfig>,
609        max_depth: u32,
610        epsilon: f32,
611        inhibition_mu: f32,
612        allowed_namespaces: Option<&[Namespace]>,
613    ) -> HirnResult<GraphActivationOutput> {
614        let config = hirn_graph::ActivationConfig {
615            max_depth: max_depth as usize,
616            epsilon: f64::from(epsilon),
617            inhibition_strength: f64::from(inhibition_mu),
618            ..Default::default()
619        };
620        config.validate()?;
621
622        match mode {
623            ExecActivationMode::Static => {
624                let mut entries: Vec<_> = crate::persistent_activation::static_activation(
625                    self.cold(),
626                    seeds,
627                    allowed_namespaces,
628                )
629                .await?
630                .into_iter()
631                .collect();
632                entries
633                    .sort_by(|left, right| right.1.partial_cmp(&left.1).unwrap_or(Ordering::Equal));
634
635                Ok(GraphActivationOutput {
636                    ids: entries
637                        .iter()
638                        .map(|(node_id, _)| node_id.to_string())
639                        .collect(),
640                    scores: entries.iter().map(|(_, score)| *score as f32).collect(),
641                    depths: entries
642                        .iter()
643                        .map(|(node_id, _)| u32::from(!seeds.contains(node_id)))
644                        .collect(),
645                })
646            }
647            ExecActivationMode::Spreading => {
648                let result = crate::persistent_activation::spread_activation(
649                    self.cold(),
650                    seeds,
651                    &config,
652                    None,
653                    allowed_namespaces,
654                )
655                .await?;
656                let mut entries: Vec<_> = result.activations.into_iter().collect();
657                entries
658                    .sort_by(|left, right| right.1.partial_cmp(&left.1).unwrap_or(Ordering::Equal));
659
660                Ok(GraphActivationOutput {
661                    ids: entries
662                        .iter()
663                        .map(|(node_id, _)| node_id.to_string())
664                        .collect(),
665                    scores: entries.iter().map(|(_, score)| *score as f32).collect(),
666                    depths: entries
667                        .iter()
668                        .map(|(node_id, _)| {
669                            result
670                                .traces
671                                .get(node_id)
672                                .map(|trace| trace.path.len().saturating_sub(1) as u32)
673                                .unwrap_or(0)
674                        })
675                        .collect(),
676                })
677            }
678            ExecActivationMode::Ppr => {
679                let default_ppr = hirn_graph::PprConfig::default();
680                let ppr_config = ppr_config.unwrap_or(&default_ppr);
681                let mut entries: Vec<_> = crate::persistent_activation::personalized_pagerank(
682                    self.cold(),
683                    seeds,
684                    ppr_config,
685                    allowed_namespaces,
686                )
687                .await?
688                .into_iter()
689                .collect();
690                entries
691                    .sort_by(|left, right| right.1.partial_cmp(&left.1).unwrap_or(Ordering::Equal));
692
693                Ok(GraphActivationOutput {
694                    ids: entries
695                        .iter()
696                        .map(|(node_id, _)| node_id.to_string())
697                        .collect(),
698                    scores: entries.iter().map(|(_, score)| *score as f32).collect(),
699                    depths: vec![0; entries.len()],
700                })
701            }
702        }
703    }
704
705    async fn causal_chain_via_hot_graph(
706        &self,
707        start_ids: &[MemoryId],
708        max_depth: u32,
709        confidence_threshold: f32,
710        relation: EdgeRelation,
711        allowed_namespaces: Option<&[Namespace]>,
712    ) -> HirnResult<Vec<GraphCausalChainRow>> {
713        let mut rows = Vec::new();
714        let mut chain_counter = 0_u32;
715
716        for &start_id in start_ids {
717            let chain_result = match relation {
718                EdgeRelation::Causes => {
719                    crate::causal::causal_chain_forward(
720                        self,
721                        start_id,
722                        max_depth as usize,
723                        confidence_threshold,
724                        allowed_namespaces,
725                    )
726                    .await?
727                }
728                EdgeRelation::CausedBy => {
729                    crate::causal::causal_chain_backward(
730                        self,
731                        start_id,
732                        max_depth as usize,
733                        confidence_threshold,
734                        allowed_namespaces,
735                    )
736                    .await?
737                }
738                other => {
739                    return Err(hirn_core::HirnError::InvalidInput(format!(
740                        "unsupported causal traversal relation `{other:?}`"
741                    )));
742                }
743            };
744
745            append_causal_rows(&chain_result.chains, &mut rows, &mut chain_counter);
746        }
747
748        Ok(rows)
749    }
750
751    async fn causal_chain_via_persistent_graph(
752        &self,
753        start_ids: &[MemoryId],
754        max_depth: u32,
755        confidence_threshold: f32,
756        relation: EdgeRelation,
757        allowed_namespaces: Option<&[Namespace]>,
758    ) -> HirnResult<Vec<GraphCausalChainRow>> {
759        let rows = self
760            .cold()
761            .deep_causal_bfs(
762                start_ids,
763                max_depth as usize,
764                confidence_threshold,
765                relation,
766                allowed_namespaces,
767            )
768            .await?
769            .into_iter()
770            .map(|row| GraphCausalChainRow {
771                chain_id: row.chain_id,
772                source_id: row.source_id.to_string(),
773                target_id: row.target_id.to_string(),
774                strength: row.strength,
775                confidence: row.confidence,
776                evidence_count: row.evidence_count,
777                mechanism: row.mechanism,
778                depth: row.depth,
779                chain_score: row.chain_score,
780            })
781            .collect::<Vec<_>>();
782
783        self.filter_causal_rows_by_namespace(rows, allowed_namespaces)
784            .await
785    }
786
787    async fn filter_causal_rows_by_namespace(
788        &self,
789        rows: Vec<GraphCausalChainRow>,
790        allowed_namespaces: Option<&[Namespace]>,
791    ) -> HirnResult<Vec<GraphCausalChainRow>> {
792        let Some(allowed_namespaces) = allowed_namespaces else {
793            return Ok(rows);
794        };
795        if rows.is_empty() {
796            return Ok(rows);
797        }
798
799        let mut visible_nodes = HashMap::new();
800        for row in &rows {
801            for node_id in [&row.source_id, &row.target_id] {
802                let Ok(node_id) = MemoryId::parse(node_id) else {
803                    continue;
804                };
805                if visible_nodes.contains_key(&node_id) {
806                    continue;
807                }
808                let is_visible = self
809                    .cold()
810                    .node_namespace(node_id)
811                    .await?
812                    .is_some_and(|namespace| allowed_namespaces.contains(&namespace));
813                visible_nodes.insert(node_id, is_visible);
814            }
815        }
816
817        let mut visible_chain_ids = HashSet::new();
818        let mut hidden_chain_ids = HashSet::new();
819        for row in &rows {
820            let source_visible = MemoryId::parse(&row.source_id)
821                .ok()
822                .and_then(|node_id| visible_nodes.get(&node_id).copied())
823                .unwrap_or(false);
824            let target_visible = MemoryId::parse(&row.target_id)
825                .ok()
826                .and_then(|node_id| visible_nodes.get(&node_id).copied())
827                .unwrap_or(false);
828
829            if source_visible && target_visible {
830                if !hidden_chain_ids.contains(&row.chain_id) {
831                    visible_chain_ids.insert(row.chain_id.clone());
832                }
833            } else {
834                hidden_chain_ids.insert(row.chain_id.clone());
835                visible_chain_ids.remove(&row.chain_id);
836            }
837        }
838
839        Ok(rows
840            .into_iter()
841            .filter(|row| visible_chain_ids.contains(&row.chain_id))
842            .collect())
843    }
844
845    fn traverse_via_hot_graph(
846        &self,
847        start_ids: &[MemoryId],
848        max_depth: u32,
849        relation_filter: Option<&[EdgeRelation]>,
850        allowed_namespaces: Option<&[Namespace]>,
851    ) -> HirnResult<Vec<GraphTraverseRow>> {
852        let graph = self.hot.read();
853        let mut visited = start_ids.iter().copied().collect::<HashSet<_>>();
854        let mut frontier = start_ids.to_vec();
855        let mut rows = Vec::new();
856
857        for depth in 0..max_depth {
858            if frontier.is_empty() {
859                break;
860            }
861
862            let mut next_frontier = Vec::new();
863            for node_id in frontier {
864                for (target, _weight, relation) in graph.outgoing_weighted(node_id) {
865                    if relation_filter.is_some_and(|relations| !relations.contains(&relation)) {
866                        continue;
867                    }
868                    if let Some(allowed_namespaces) = allowed_namespaces {
869                        let Some(namespace) = graph.node_namespace(target) else {
870                            continue;
871                        };
872                        if !allowed_namespaces.contains(namespace) {
873                            continue;
874                        }
875                    }
876                    if visited.insert(target) {
877                        next_frontier.push(target);
878                        rows.push(GraphTraverseRow {
879                            node_id: target.to_string(),
880                            depth: depth + 1,
881                        });
882                    }
883                }
884            }
885
886            frontier = next_frontier;
887        }
888
889        Ok(rows)
890    }
891
892    async fn traverse_via_persistent_graph(
893        &self,
894        start_ids: &[MemoryId],
895        max_depth: u32,
896        relation_filter: Option<&[EdgeRelation]>,
897        allowed_namespaces: Option<&[Namespace]>,
898    ) -> HirnResult<Vec<GraphTraverseRow>> {
899        let mut visited = start_ids.iter().copied().collect::<HashSet<_>>();
900        let mut frontier = start_ids.to_vec();
901        let mut rows = Vec::new();
902
903        for depth in 0..max_depth {
904            if frontier.is_empty() {
905                break;
906            }
907
908            let edges = match relation_filter {
909                Some([relation]) => {
910                    self.cold()
911                        .batch_adjacency_read_filtered(&frontier, *relation)
912                        .await?
913                }
914                _ => self.cold().batch_adjacency_read(&frontier).await?,
915            };
916
917            let mut next_frontier = Vec::new();
918            for edge in edges {
919                if relation_filter.is_some_and(|relations| !relations.contains(&edge.relation)) {
920                    continue;
921                }
922                if let Some(allowed_namespaces) = allowed_namespaces {
923                    let Some(namespace) = self.cold().node_namespace(edge.target).await? else {
924                        continue;
925                    };
926                    if !allowed_namespaces.contains(&namespace) {
927                        continue;
928                    }
929                }
930                if visited.insert(edge.target) {
931                    next_frontier.push(edge.target);
932                    rows.push(GraphTraverseRow {
933                        node_id: edge.target.to_string(),
934                        depth: depth + 1,
935                    });
936                }
937            }
938
939            frontier = next_frontier;
940        }
941
942        Ok(rows)
943    }
944}
945
946fn append_causal_rows(
947    chains: &[crate::causal::CausalChain],
948    rows: &mut Vec<GraphCausalChainRow>,
949    chain_counter: &mut u32,
950) {
951    for chain in chains {
952        if chain.links.is_empty() {
953            continue;
954        }
955
956        let chain_id = format!("chain_{}", *chain_counter);
957        *chain_counter += 1;
958        let chain_score = chain
959            .links
960            .iter()
961            .map(|link| {
962                let strength = link.strength.unwrap_or(link.weight);
963                let confidence = link.confidence.unwrap_or(0.5);
964                let evidence = link.evidence_count.unwrap_or(1).max(1) as f32;
965                strength * confidence * (1.0_f32 + evidence).ln()
966            })
967            .sum::<f32>()
968            / chain.links.len().max(1) as f32;
969
970        for (depth, link) in chain.links.iter().enumerate() {
971            rows.push(GraphCausalChainRow {
972                chain_id: chain_id.clone(),
973                source_id: link.source.to_string(),
974                target_id: link.target.to_string(),
975                strength: link.strength.unwrap_or(link.weight),
976                confidence: link.confidence.unwrap_or(0.5),
977                evidence_count: link.evidence_count.unwrap_or(1).max(1) as u32,
978                mechanism: link.mechanism.clone(),
979                depth: depth as u32,
980                chain_score,
981            });
982        }
983    }
984}
985
986#[async_trait]
987impl GraphStore for CachedGraphStore {
988    // ── Node operations ─────────────────────────────────────────────────
989
990    async fn add_node(
991        &self,
992        id: MemoryId,
993        layer: Layer,
994        importance: f32,
995        created_at: Timestamp,
996        namespace: Namespace,
997    ) -> HirnResult<bool> {
998        // Write-through: hot first, then cold.
999        let added = {
1000            let mut graph = self.hot.write();
1001            graph.add_node_ns(id, layer, importance, created_at, namespace.clone())
1002        };
1003        if let Err(error) = self
1004            .cold
1005            .add_node(id, layer, importance, created_at, namespace)
1006            .await
1007        {
1008            let _ = self.cold.remove_node(id).await;
1009            if added {
1010                let mut graph = self.hot.write();
1011                graph.remove_node(id);
1012            }
1013            return Err(error);
1014        }
1015        Ok(added)
1016    }
1017
1018    async fn remove_node(&self, id: MemoryId) -> HirnResult<bool> {
1019        let existed_cold = self.cold.remove_node(id).await?;
1020        let existed_hot = {
1021            let mut graph = self.hot.write();
1022            graph.remove_node(id)
1023        };
1024        Ok(existed_hot || existed_cold)
1025    }
1026
1027    async fn has_node(&self, id: MemoryId) -> HirnResult<bool> {
1028        let graph = self.hot.read();
1029        Ok(graph.has_node(id))
1030    }
1031
1032    async fn get_node(&self, id: MemoryId) -> HirnResult<Option<GraphNodeData>> {
1033        let graph = self.hot.read();
1034        let importance = graph.node_importance(id);
1035        let layer = graph.node_layer(id);
1036        match (importance, layer) {
1037            (Some(imp), Some(lay)) => Ok(Some(GraphNodeData {
1038                id,
1039                layer: lay,
1040                importance: imp,
1041                created_at: Timestamp::now(),
1042                namespace: graph.node_namespace(id).cloned().unwrap_or_default(),
1043                access_count: graph.access_count(id),
1044            })),
1045            _ => Ok(None),
1046        }
1047    }
1048
1049    async fn node_ids(&self) -> HirnResult<Vec<MemoryId>> {
1050        let graph = self.hot.read();
1051        Ok(graph.node_ids())
1052    }
1053
1054    async fn node_importance(&self, id: MemoryId) -> HirnResult<Option<f32>> {
1055        let graph = self.hot.read();
1056        Ok(graph.node_importance(id))
1057    }
1058
1059    async fn set_node_importance(&self, id: MemoryId, importance: f32) -> HirnResult<()> {
1060        self.cold.set_node_importance(id, importance).await?;
1061        {
1062            let mut graph = self.hot.write();
1063            graph.set_node_importance(id, importance);
1064        }
1065        Ok(())
1066    }
1067
1068    async fn node_layer(&self, id: MemoryId) -> HirnResult<Option<Layer>> {
1069        let graph = self.hot.read();
1070        Ok(graph.node_layer(id))
1071    }
1072
1073    async fn node_namespace(&self, id: MemoryId) -> HirnResult<Option<Namespace>> {
1074        let graph = self.hot.read();
1075        Ok(graph.node_namespace(id).cloned())
1076    }
1077
1078    async fn namespaces_compatible(&self, a: MemoryId, b: MemoryId) -> HirnResult<bool> {
1079        let graph = self.hot.read();
1080        let ns_a = graph.node_namespace(a).cloned();
1081        let ns_b = graph.node_namespace(b).cloned();
1082        match (ns_a, ns_b) {
1083            (Some(a_ns), Some(b_ns)) => {
1084                Ok(a_ns == b_ns || a_ns == Namespace::shared() || b_ns == Namespace::shared())
1085            }
1086            _ => Ok(false),
1087        }
1088    }
1089
1090    // ── Edge operations ─────────────────────────────────────────────────
1091
1092    async fn add_edge(
1093        &self,
1094        source: MemoryId,
1095        target: MemoryId,
1096        relation: EdgeRelation,
1097        weight: f32,
1098        metadata: Metadata,
1099    ) -> HirnResult<EdgeId> {
1100        // Write-through: hot first, then cold.
1101        let (edge_id, created_edges) = {
1102            let mut graph = self.hot.write();
1103            let edge_id = graph.add_edge(source, target, relation, weight, metadata)?;
1104            let created_edges =
1105                Self::created_edges_from_hot(&graph, edge_id, source, target, relation)?;
1106            (edge_id, created_edges)
1107        };
1108
1109        if let Err(error) = self.cold.add_edges(&created_edges).await {
1110            for edge in &created_edges {
1111                let _ = self.cold.remove_edge(edge.id).await;
1112            }
1113            let created_edge_ids = created_edges.iter().map(|edge| edge.id).collect::<Vec<_>>();
1114            self.rollback_hot_edges(&created_edge_ids);
1115            return Err(error);
1116        }
1117
1118        Ok(edge_id)
1119    }
1120
1121    async fn add_causal_edge(
1122        &self,
1123        source: MemoryId,
1124        target: MemoryId,
1125        relation: EdgeRelation,
1126        weight: f32,
1127        metadata: Metadata,
1128        causal: hirn_graph::CausalEdgeData,
1129    ) -> HirnResult<EdgeId> {
1130        // Write-through: hot first, then cold.
1131        let (edge_id, created_edges) = {
1132            let mut graph = self.hot.write();
1133            let edge_id =
1134                graph.add_causal_edge(source, target, relation, weight, metadata, causal)?;
1135            let created_edges =
1136                Self::created_edges_from_hot(&graph, edge_id, source, target, relation)?;
1137            (edge_id, created_edges)
1138        };
1139
1140        if let Err(error) = self.cold.add_edges(&created_edges).await {
1141            for edge in &created_edges {
1142                let _ = self.cold.remove_edge(edge.id).await;
1143            }
1144            let created_edge_ids = created_edges.iter().map(|edge| edge.id).collect::<Vec<_>>();
1145            self.rollback_hot_edges(&created_edge_ids);
1146            return Err(error);
1147        }
1148
1149        Ok(edge_id)
1150    }
1151
1152    async fn remove_edge(&self, edge_id: EdgeId) -> HirnResult<()> {
1153        self.cold.remove_edge(edge_id).await?;
1154        {
1155            let mut graph = self.hot.write();
1156            let _ = graph.remove_edge(edge_id);
1157        }
1158        Ok(())
1159    }
1160
1161    async fn get_edge(&self, edge_id: EdgeId) -> HirnResult<Option<GraphEdge>> {
1162        let graph = self.hot.read();
1163        Ok(graph.edge_by_id(edge_id).cloned())
1164    }
1165
1166    async fn get_edges(&self, node_id: MemoryId) -> HirnResult<Vec<GraphEdge>> {
1167        let graph = self.hot.read();
1168        Ok(graph.get_edges(node_id).into_iter().cloned().collect())
1169    }
1170
1171    async fn get_edges_between(&self, a: MemoryId, b: MemoryId) -> HirnResult<Vec<GraphEdge>> {
1172        let graph = self.hot.read();
1173        Ok(graph.get_edges_between(a, b).into_iter().cloned().collect())
1174    }
1175
1176    async fn get_edges_of_type(
1177        &self,
1178        node_id: MemoryId,
1179        relation: EdgeRelation,
1180    ) -> HirnResult<Vec<GraphEdge>> {
1181        let graph = self.hot.read();
1182        Ok(graph
1183            .get_edges_of_type(node_id, relation)
1184            .into_iter()
1185            .cloned()
1186            .collect())
1187    }
1188
1189    async fn get_edges_of_type_many(
1190        &self,
1191        node_ids: &[MemoryId],
1192        relation: EdgeRelation,
1193    ) -> HirnResult<HashMap<MemoryId, Vec<GraphEdge>>> {
1194        let graph = self.hot.read();
1195        Ok(graph
1196            .edges_for_nodes(node_ids)
1197            .into_iter()
1198            .filter_map(|(node_id, edges)| {
1199                let filtered = edges
1200                    .into_iter()
1201                    .filter(|edge| edge.relation == relation)
1202                    .cloned()
1203                    .collect::<Vec<_>>();
1204                if filtered.is_empty() {
1205                    None
1206                } else {
1207                    Some((node_id, filtered))
1208                }
1209            })
1210            .collect())
1211    }
1212
1213    async fn all_edges(&self) -> HirnResult<Vec<GraphEdge>> {
1214        let graph = self.hot.read();
1215        Ok(graph.all_edges().into_iter().cloned().collect())
1216    }
1217
1218    async fn update_edge_weight(
1219        &self,
1220        edge_id: EdgeId,
1221        new_weight: f32,
1222        co_retrieval_count: Option<u64>,
1223    ) -> HirnResult<()> {
1224        self.cold
1225            .update_edge_weight(edge_id, new_weight, co_retrieval_count)
1226            .await?;
1227        {
1228            let mut graph = self.hot.write();
1229            if let Some(edge) = graph.edge_mut(edge_id) {
1230                edge.weight = new_weight;
1231                if let Some(count) = co_retrieval_count {
1232                    edge.co_retrieval_count = count;
1233                }
1234            }
1235        }
1236        Ok(())
1237    }
1238
1239    // ── Traversal ───────────────────────────────────────────────────────
1240
1241    async fn get_neighbors(
1242        &self,
1243        start: MemoryId,
1244        depth: usize,
1245        min_weight: f32,
1246    ) -> HirnResult<Vec<MemoryId>> {
1247        let graph = self.hot.read();
1248        Ok(graph.get_neighbors(start, depth, min_weight))
1249    }
1250
1251    async fn get_neighbors_filtered(
1252        &self,
1253        start: MemoryId,
1254        depth: usize,
1255        min_weight: f32,
1256        namespace: Option<&Namespace>,
1257    ) -> HirnResult<Vec<MemoryId>> {
1258        let graph = self.hot.read();
1259        match namespace {
1260            Some(ns) => Ok(graph.get_neighbors_filtered(
1261                start,
1262                depth,
1263                min_weight,
1264                Some(std::slice::from_ref(ns)),
1265            )),
1266            None => Ok(graph.get_neighbors(start, depth, min_weight)),
1267        }
1268    }
1269
1270    async fn outgoing_weighted(
1271        &self,
1272        node_id: MemoryId,
1273    ) -> HirnResult<Vec<(MemoryId, f32, EdgeRelation)>> {
1274        let graph = self.hot.read();
1275        Ok(graph.outgoing_weighted(node_id))
1276    }
1277
1278    async fn shortest_path(
1279        &self,
1280        source: MemoryId,
1281        target: MemoryId,
1282    ) -> HirnResult<Option<Vec<MemoryId>>> {
1283        let graph = self.hot.read();
1284        Ok(graph.shortest_path(source, target))
1285    }
1286
1287    // ── Counts ──────────────────────────────────────────────────────────
1288
1289    async fn node_count(&self) -> HirnResult<usize> {
1290        let graph = self.hot.read();
1291        Ok(graph.node_count())
1292    }
1293
1294    async fn edge_count(&self) -> HirnResult<usize> {
1295        let graph = self.hot.read();
1296        Ok(graph.edge_count())
1297    }
1298}
1299
1300#[cfg(test)]
1301mod tests {
1302    use super::*;
1303    use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
1304
1305    use arrow_array::RecordBatch;
1306    use datafusion::catalog::TableProvider;
1307    use hirn_core::types::Namespace;
1308    use hirn_storage::HirnDbError;
1309    use hirn_storage::datasets::graph::{DATASET_EDGES_NAME, DATASET_NODES_NAME};
1310    use hirn_storage::memory_store::MemoryStore;
1311    use hirn_storage::store::{
1312        ColumnTransform, CompactOptions, CompactResult, DatasetInfo, FtsSearchOptions,
1313        HybridSearchOptions, IndexConfig, MultivectorSearchOptions, PhysicalStore, ScanOptions,
1314        VectorSearchOptions, VersionTag,
1315    };
1316
1317    struct FaultInjectingGraphStore {
1318        inner: MemoryStore,
1319        fail_node_merge_insert: AtomicBool,
1320        fail_edge_merge_insert: AtomicBool,
1321        fail_node_delete: AtomicBool,
1322        fail_edge_delete: AtomicBool,
1323    }
1324
1325    #[async_trait]
1326    impl PhysicalStore for FaultInjectingGraphStore {
1327        async fn append(&self, dataset: &str, batch: RecordBatch) -> Result<(), HirnDbError> {
1328            self.inner.append(dataset, batch).await
1329        }
1330
1331        async fn append_batches(
1332            &self,
1333            dataset: &str,
1334            batches: Vec<RecordBatch>,
1335        ) -> Result<(), HirnDbError> {
1336            self.inner.append_batches(dataset, batches).await
1337        }
1338
1339        async fn scan(
1340            &self,
1341            dataset: &str,
1342            opts: ScanOptions,
1343        ) -> Result<Vec<RecordBatch>, HirnDbError> {
1344            self.inner.scan(dataset, opts).await
1345        }
1346
1347        async fn scan_stream(
1348            &self,
1349            dataset: &str,
1350            opts: ScanOptions,
1351        ) -> Result<hirn_storage::store::RecordBatchStream, HirnDbError> {
1352            self.inner.scan_stream(dataset, opts).await
1353        }
1354
1355        async fn delete(&self, dataset: &str, predicate: &str) -> Result<u64, HirnDbError> {
1356            if dataset == DATASET_NODES_NAME && self.fail_node_delete.load(AtomicOrdering::Acquire)
1357            {
1358                return Err(HirnDbError::Unsupported(
1359                    "simulated graph node delete failure".to_string(),
1360                ));
1361            }
1362            if dataset == DATASET_EDGES_NAME && self.fail_edge_delete.load(AtomicOrdering::Acquire)
1363            {
1364                return Err(HirnDbError::Unsupported(
1365                    "simulated graph edge delete failure".to_string(),
1366                ));
1367            }
1368            self.inner.delete(dataset, predicate).await
1369        }
1370
1371        async fn update_where(
1372            &self,
1373            dataset: &str,
1374            filter: &str,
1375            updates: &[(&str, &str)],
1376        ) -> Result<u64, HirnDbError> {
1377            self.inner.update_where(dataset, filter, updates).await
1378        }
1379
1380        async fn merge_insert(
1381            &self,
1382            dataset: &str,
1383            on: &[&str],
1384            batch: RecordBatch,
1385        ) -> Result<(), HirnDbError> {
1386            if dataset == DATASET_NODES_NAME
1387                && self.fail_node_merge_insert.load(AtomicOrdering::Acquire)
1388            {
1389                return Err(HirnDbError::Unsupported(
1390                    "simulated graph node persist failure".to_string(),
1391                ));
1392            }
1393            if dataset == DATASET_EDGES_NAME
1394                && self.fail_edge_merge_insert.load(AtomicOrdering::Acquire)
1395            {
1396                return Err(HirnDbError::Unsupported(
1397                    "simulated graph edge persist failure".to_string(),
1398                ));
1399            }
1400            self.inner.merge_insert(dataset, on, batch).await
1401        }
1402
1403        async fn count(&self, dataset: &str, filter: Option<&str>) -> Result<u64, HirnDbError> {
1404            self.inner.count(dataset, filter).await
1405        }
1406
1407        async fn vector_search(
1408            &self,
1409            dataset: &str,
1410            opts: VectorSearchOptions,
1411        ) -> Result<Vec<RecordBatch>, HirnDbError> {
1412            self.inner.vector_search(dataset, opts).await
1413        }
1414
1415        async fn vector_search_many(
1416            &self,
1417            dataset: &str,
1418            queries: Vec<VectorSearchOptions>,
1419        ) -> Result<Vec<Vec<RecordBatch>>, HirnDbError> {
1420            self.inner.vector_search_many(dataset, queries).await
1421        }
1422
1423        async fn fts_search(
1424            &self,
1425            dataset: &str,
1426            opts: FtsSearchOptions,
1427        ) -> Result<Vec<RecordBatch>, HirnDbError> {
1428            self.inner.fts_search(dataset, opts).await
1429        }
1430
1431        async fn hybrid_search(
1432            &self,
1433            dataset: &str,
1434            opts: HybridSearchOptions,
1435        ) -> Result<Vec<RecordBatch>, HirnDbError> {
1436            self.inner.hybrid_search(dataset, opts).await
1437        }
1438
1439        async fn multivector_search(
1440            &self,
1441            dataset: &str,
1442            opts: MultivectorSearchOptions,
1443        ) -> Result<Vec<RecordBatch>, HirnDbError> {
1444            self.inner.multivector_search(dataset, opts).await
1445        }
1446
1447        async fn create_index(
1448            &self,
1449            dataset: &str,
1450            config: IndexConfig,
1451        ) -> Result<(), HirnDbError> {
1452            self.inner.create_index(dataset, config).await
1453        }
1454
1455        async fn optimize_indices(&self, dataset: &str) -> Result<(), HirnDbError> {
1456            self.inner.optimize_indices(dataset).await
1457        }
1458
1459        async fn compact(
1460            &self,
1461            dataset: &str,
1462            opts: CompactOptions,
1463        ) -> Result<CompactResult, HirnDbError> {
1464            self.inner.compact(dataset, opts).await
1465        }
1466
1467        async fn version(&self, dataset: &str) -> Result<u64, HirnDbError> {
1468            self.inner.version(dataset).await
1469        }
1470
1471        async fn tag(&self, dataset: &str, tag: &str) -> Result<(), HirnDbError> {
1472            self.inner.tag(dataset, tag).await
1473        }
1474
1475        async fn checkout(&self, dataset: &str, version: u64) -> Result<(), HirnDbError> {
1476            self.inner.checkout(dataset, version).await
1477        }
1478
1479        async fn list_tags(&self, dataset: &str) -> Result<Vec<VersionTag>, HirnDbError> {
1480            self.inner.list_tags(dataset).await
1481        }
1482
1483        async fn list_datasets(&self) -> Result<Vec<DatasetInfo>, HirnDbError> {
1484            self.inner.list_datasets().await
1485        }
1486
1487        async fn exists(&self, dataset: &str) -> Result<bool, HirnDbError> {
1488            self.inner.exists(dataset).await
1489        }
1490
1491        async fn list_namespaces(&self) -> Result<Vec<String>, HirnDbError> {
1492            self.inner.list_namespaces().await
1493        }
1494
1495        async fn create_namespace(&self, name: &str) -> Result<(), HirnDbError> {
1496            self.inner.create_namespace(name).await
1497        }
1498
1499        async fn drop_namespace(&self, name: &str) -> Result<(), HirnDbError> {
1500            self.inner.drop_namespace(name).await
1501        }
1502
1503        async fn add_columns(
1504            &self,
1505            dataset: &str,
1506            transforms: Vec<ColumnTransform>,
1507        ) -> Result<(), HirnDbError> {
1508            self.inner.add_columns(dataset, transforms).await
1509        }
1510
1511        async fn drop_columns(&self, dataset: &str, columns: &[&str]) -> Result<(), HirnDbError> {
1512            self.inner.drop_columns(dataset, columns).await
1513        }
1514
1515        async fn table_provider(&self, dataset: &str) -> Option<Arc<dyn TableProvider>> {
1516            self.inner.table_provider(dataset).await
1517        }
1518    }
1519
1520    /// Create a minimal PersistentGraph backed by an in-memory store.
1521    async fn test_cold() -> Arc<PersistentGraph> {
1522        let storage: Arc<dyn hirn_storage::PhysicalStore> =
1523            Arc::new(hirn_storage::memory_store::MemoryStore::new());
1524        Arc::new(PersistentGraph::new(storage))
1525    }
1526
1527    async fn fault_injecting_cold() -> (Arc<PersistentGraph>, Arc<FaultInjectingGraphStore>) {
1528        let storage = Arc::new(FaultInjectingGraphStore {
1529            inner: MemoryStore::new(),
1530            fail_node_merge_insert: AtomicBool::new(false),
1531            fail_edge_merge_insert: AtomicBool::new(false),
1532            fail_node_delete: AtomicBool::new(false),
1533            fail_edge_delete: AtomicBool::new(false),
1534        });
1535        let store: Arc<dyn hirn_storage::PhysicalStore> = storage.clone();
1536        (Arc::new(PersistentGraph::new(store)), storage)
1537    }
1538
1539    #[tokio::test]
1540    async fn hot_tier_reflects_writes_immediately() {
1541        let cold = test_cold().await;
1542        let cached = CachedGraphStore::new(cold);
1543
1544        let a = MemoryId::new();
1545        let b = MemoryId::new();
1546        let ns = Namespace::default();
1547
1548        cached
1549            .add_node(a, Layer::Episodic, 0.9, Timestamp::now(), ns.clone())
1550            .await
1551            .unwrap();
1552        cached
1553            .add_node(b, Layer::Semantic, 0.5, Timestamp::now(), ns)
1554            .await
1555            .unwrap();
1556
1557        assert!(cached.has_node(a).await.unwrap());
1558        assert!(cached.has_node(b).await.unwrap());
1559        assert_eq!(cached.node_count().await.unwrap(), 2);
1560
1561        let eid = cached
1562            .add_edge(a, b, EdgeRelation::Causes, 0.7, Metadata::new())
1563            .await
1564            .unwrap();
1565
1566        let edges = cached.get_edges(a).await.unwrap();
1567        assert!(!edges.is_empty());
1568        assert_eq!(edges[0].id, eid);
1569    }
1570
1571    #[tokio::test]
1572    async fn write_through_to_cold_tier() {
1573        let cold = test_cold().await;
1574        let cached = CachedGraphStore::new(cold.clone());
1575
1576        let a = MemoryId::new();
1577        let ns = Namespace::default();
1578        cached
1579            .add_node(a, Layer::Episodic, 0.8, Timestamp::now(), ns)
1580            .await
1581            .unwrap();
1582
1583        // Verify cold tier has the node too.
1584        assert!(cold.has_node(a).await.unwrap());
1585    }
1586
1587    #[tokio::test]
1588    async fn batch_add_nodes_rolls_back_hot_tier_when_cold_persist_fails() {
1589        let (cold, storage) = fault_injecting_cold().await;
1590        let cached = CachedGraphStore::new(cold);
1591
1592        let first = MemoryId::new();
1593        let second = MemoryId::new();
1594        let namespace = Namespace::default();
1595        let now = Timestamp::now();
1596
1597        storage
1598            .fail_node_merge_insert
1599            .store(true, AtomicOrdering::Release);
1600
1601        let result = cached
1602            .add_nodes(&[
1603                GraphNodeData {
1604                    id: first,
1605                    layer: Layer::Episodic,
1606                    importance: 0.8,
1607                    created_at: now,
1608                    namespace,
1609                    access_count: 0,
1610                },
1611                GraphNodeData {
1612                    id: second,
1613                    layer: Layer::Semantic,
1614                    importance: 0.6,
1615                    created_at: now,
1616                    namespace,
1617                    access_count: 0,
1618                },
1619            ])
1620            .await;
1621
1622        assert!(result.is_err());
1623        assert!(!cached.has_node(first).await.unwrap());
1624        assert!(!cached.has_node(second).await.unwrap());
1625    }
1626
1627    #[tokio::test]
1628    async fn write_through_edges_preserve_hot_edge_ids_in_cold_tier() {
1629        let cold = test_cold().await;
1630        let cached = CachedGraphStore::new(cold.clone());
1631
1632        let a = MemoryId::new();
1633        let b = MemoryId::new();
1634        let ns = Namespace::default();
1635
1636        cached
1637            .add_node(a, Layer::Episodic, 0.8, Timestamp::now(), ns.clone())
1638            .await
1639            .unwrap();
1640        cached
1641            .add_node(b, Layer::Semantic, 0.6, Timestamp::now(), ns)
1642            .await
1643            .unwrap();
1644
1645        let edge_id = cached
1646            .add_edge(a, b, EdgeRelation::Causes, 0.7, Metadata::new())
1647            .await
1648            .unwrap();
1649
1650        let cold_edge = cold.get_edge(edge_id).await.unwrap();
1651        assert!(
1652            cold_edge.is_some(),
1653            "cold tier should store the same edge id returned by the hot tier"
1654        );
1655    }
1656
1657    #[tokio::test]
1658    async fn add_node_rolls_back_hot_tier_when_cold_persist_fails() {
1659        let (cold, storage) = fault_injecting_cold().await;
1660        let cached = CachedGraphStore::new(cold);
1661
1662        let a = MemoryId::new();
1663        storage
1664            .fail_node_merge_insert
1665            .store(true, AtomicOrdering::Release);
1666
1667        let result = cached
1668            .add_node(
1669                a,
1670                Layer::Episodic,
1671                0.8,
1672                Timestamp::now(),
1673                Namespace::default(),
1674            )
1675            .await;
1676
1677        assert!(result.is_err());
1678        assert!(!cached.has_node(a).await.unwrap());
1679    }
1680
1681    #[tokio::test]
1682    async fn add_edge_rolls_back_hot_tier_when_cold_persist_fails() {
1683        let (cold, storage) = fault_injecting_cold().await;
1684        let cached = CachedGraphStore::new(cold.clone());
1685
1686        let a = MemoryId::new();
1687        let b = MemoryId::new();
1688        let ns = Namespace::default();
1689
1690        cached
1691            .add_node(a, Layer::Episodic, 0.8, Timestamp::now(), ns.clone())
1692            .await
1693            .unwrap();
1694        cached
1695            .add_node(b, Layer::Semantic, 0.6, Timestamp::now(), ns)
1696            .await
1697            .unwrap();
1698
1699        storage
1700            .fail_edge_merge_insert
1701            .store(true, AtomicOrdering::Release);
1702        let result = cached
1703            .add_edge(a, b, EdgeRelation::Causes, 0.7, Metadata::new())
1704            .await;
1705
1706        assert!(result.is_err());
1707        assert!(cached.get_edges(a).await.unwrap().is_empty());
1708        assert!(cold.get_edges(a).await.unwrap().is_empty());
1709    }
1710
1711    #[tokio::test]
1712    async fn remove_edge_preserves_hot_tier_when_cold_delete_fails() {
1713        let (cold, storage) = fault_injecting_cold().await;
1714        let cached = CachedGraphStore::new(cold);
1715
1716        let a = MemoryId::new();
1717        let b = MemoryId::new();
1718        let ns = Namespace::default();
1719
1720        cached
1721            .add_node(a, Layer::Episodic, 0.8, Timestamp::now(), ns.clone())
1722            .await
1723            .unwrap();
1724        cached
1725            .add_node(b, Layer::Semantic, 0.6, Timestamp::now(), ns)
1726            .await
1727            .unwrap();
1728        let edge_id = cached
1729            .add_edge(a, b, EdgeRelation::Causes, 0.7, Metadata::new())
1730            .await
1731            .unwrap();
1732
1733        storage
1734            .fail_edge_delete
1735            .store(true, AtomicOrdering::Release);
1736        let result = cached.remove_edge(edge_id).await;
1737
1738        assert!(result.is_err());
1739        let hot_edge = cached.get_edge(edge_id).await.unwrap();
1740        assert!(
1741            hot_edge.is_some(),
1742            "hot tier should keep the edge when cold deletion fails"
1743        );
1744    }
1745
1746    #[tokio::test]
1747    async fn reads_never_hit_cold_tier() {
1748        let cold = test_cold().await;
1749        let cached = CachedGraphStore::new(cold);
1750
1751        let a = MemoryId::new();
1752        let b = MemoryId::new();
1753        let ns = Namespace::default();
1754
1755        cached
1756            .add_node(a, Layer::Episodic, 0.5, Timestamp::now(), ns.clone())
1757            .await
1758            .unwrap();
1759        cached
1760            .add_node(b, Layer::Episodic, 0.5, Timestamp::now(), ns)
1761            .await
1762            .unwrap();
1763        cached
1764            .add_edge(a, b, EdgeRelation::SimilarTo, 0.6, Metadata::new())
1765            .await
1766            .unwrap();
1767
1768        // All read operations use hot tier (PropertyGraph).
1769        let neighbors = cached.get_neighbors(a, 1, 0.0).await.unwrap();
1770        assert!(!neighbors.is_empty());
1771
1772        let outgoing = cached.outgoing_weighted(a).await.unwrap();
1773        assert!(!outgoing.is_empty());
1774
1775        let path = cached.shortest_path(a, b).await.unwrap();
1776        assert!(path.is_some());
1777    }
1778
1779    #[tokio::test]
1780    async fn load_from_cold_populates_hot() {
1781        let cold = test_cold().await;
1782
1783        // Write directly to cold tier.
1784        let a = MemoryId::new();
1785        let b = MemoryId::new();
1786        let ns = Namespace::default();
1787        cold.add_node(a, Layer::Episodic, 0.5, Timestamp::now(), ns.clone())
1788            .await
1789            .unwrap();
1790        cold.add_node(b, Layer::Semantic, 0.7, Timestamp::now(), ns)
1791            .await
1792            .unwrap();
1793        cold.add_edge(a, b, EdgeRelation::Causes, 0.8, Metadata::new())
1794            .await
1795            .unwrap();
1796
1797        // Create cached store and load.
1798        let cached = CachedGraphStore::new(cold);
1799        cached.load_from_cold().await.unwrap();
1800
1801        // Hot tier should have everything.
1802        assert!(cached.has_node(a).await.unwrap());
1803        assert!(cached.has_node(b).await.unwrap());
1804        let edges = cached.get_edges(a).await.unwrap();
1805        assert!(!edges.is_empty());
1806    }
1807
1808    #[tokio::test]
1809    async fn concurrent_readers_dont_block() {
1810        let cold = test_cold().await;
1811        let cached = Arc::new(CachedGraphStore::new(cold));
1812
1813        let a = MemoryId::new();
1814        let ns = Namespace::default();
1815        cached
1816            .add_node(a, Layer::Episodic, 0.5, Timestamp::now(), ns)
1817            .await
1818            .unwrap();
1819
1820        // Spawn 4 concurrent readers.
1821        let mut handles = Vec::new();
1822        for _ in 0..4 {
1823            let cached = Arc::clone(&cached);
1824            handles.push(tokio::spawn(async move {
1825                for _ in 0..100 {
1826                    let _ = cached.has_node(a).await;
1827                    let _ = cached.node_count().await;
1828                }
1829            }));
1830        }
1831
1832        for h in handles {
1833            h.await.unwrap();
1834        }
1835
1836        // If we get here, no deadlocks occurred.
1837        assert!(cached.has_node(a).await.unwrap());
1838    }
1839
1840    #[tokio::test]
1841    async fn spreading_activation_on_hot_tier_is_fast() {
1842        use hirn_graph::activation::{ActivationConfig, spread_activation};
1843        use std::time::Instant;
1844
1845        // Build a 1000-node graph with realistic connectivity.
1846        let mut pg = PropertyGraph::new();
1847        let mut nodes = Vec::with_capacity(1000);
1848        for _ in 0..1000 {
1849            let id = MemoryId::new();
1850            pg.add_node(id, Layer::Episodic, 0.5, Timestamp::now());
1851            nodes.push(id);
1852        }
1853        // ~5 edges per node (5000 edges total).
1854        for i in 0..1000 {
1855            for j in 1..=5 {
1856                let target = (i + j * 7) % 1000;
1857                if i != target {
1858                    let _ = pg.add_edge(
1859                        nodes[i],
1860                        nodes[target],
1861                        EdgeRelation::Causes,
1862                        0.5,
1863                        Metadata::new(),
1864                    );
1865                }
1866            }
1867        }
1868
1869        let cfg = ActivationConfig::default();
1870        let seed = &[nodes[0]];
1871
1872        // Warm up.
1873        let _ = spread_activation(&pg, seed, &cfg, None, None).unwrap();
1874
1875        // Measure.
1876        let start = Instant::now();
1877        let result = spread_activation(&pg, seed, &cfg, None, None).unwrap();
1878        let elapsed = start.elapsed();
1879
1880        assert!(
1881            !result.activations.is_empty(),
1882            "activation should return results"
1883        );
1884        assert!(
1885            elapsed.as_millis() < 50,
1886            "spreading activation on 1000-node hot graph took {}ms (should be < 50ms)",
1887            elapsed.as_millis()
1888        );
1889    }
1890
1891    #[tokio::test]
1892    async fn deep_activation_runtime_delegates_to_cold_tier_when_hot_is_empty() {
1893        let cold = test_cold().await;
1894        let cached = CachedGraphStore::new(cold.clone());
1895
1896        let a = MemoryId::new();
1897        let b = MemoryId::new();
1898        let ns = Namespace::default();
1899        cold.add_node(a, Layer::Episodic, 0.5, Timestamp::now(), ns.clone())
1900            .await
1901            .unwrap();
1902        cold.add_node(b, Layer::Episodic, 0.5, Timestamp::now(), ns)
1903            .await
1904            .unwrap();
1905        cold.add_edge(a, b, EdgeRelation::RelatedTo, 0.9, Metadata::new())
1906            .await
1907            .unwrap();
1908
1909        let result = hirn_exec::GraphReadRuntime::activate_graph(
1910            &cached,
1911            &[a],
1912            hirn_exec::ActivationMode::Static,
1913            None,
1914            6,
1915            0.001,
1916            0.1,
1917            5,
1918            None,
1919        )
1920        .await
1921        .unwrap();
1922
1923        let seed = a.to_string();
1924        let neighbor = b.to_string();
1925        assert!(
1926            result.ids.iter().any(|id| id == &seed),
1927            "cold-tier activation should include the seed"
1928        );
1929        assert!(
1930            result.ids.iter().any(|id| id == &neighbor),
1931            "cold-tier activation should include the persisted neighbor even when the hot graph is empty"
1932        );
1933    }
1934
1935    #[tokio::test]
1936    async fn deep_causal_runtime_delegates_to_cold_tier_when_hot_is_empty() {
1937        let cold = test_cold().await;
1938        let cached = CachedGraphStore::new(cold.clone());
1939
1940        let a = MemoryId::new();
1941        let b = MemoryId::new();
1942        let ns = Namespace::default();
1943        cold.add_node(a, Layer::Episodic, 0.5, Timestamp::now(), ns.clone())
1944            .await
1945            .unwrap();
1946        cold.add_node(b, Layer::Episodic, 0.5, Timestamp::now(), ns)
1947            .await
1948            .unwrap();
1949        cold.add_edge(a, b, EdgeRelation::Causes, 0.9, Metadata::new())
1950            .await
1951            .unwrap();
1952
1953        let rows = hirn_exec::GraphReadRuntime::causal_chain(
1954            &cached,
1955            &[a],
1956            6,
1957            0.0,
1958            5,
1959            EdgeRelation::Causes,
1960            None,
1961        )
1962        .await
1963        .unwrap();
1964
1965        assert_eq!(
1966            rows.len(),
1967            1,
1968            "cold-tier causal traversal should emit one edge row"
1969        );
1970        assert_eq!(rows[0].source_id, a.to_string());
1971        assert_eq!(rows[0].target_id, b.to_string());
1972    }
1973
1974    #[tokio::test]
1975    async fn deep_traverse_runtime_delegates_to_cold_tier_when_hot_is_empty() {
1976        let cold = test_cold().await;
1977        let cached = CachedGraphStore::new(cold.clone());
1978
1979        let a = MemoryId::new();
1980        let b = MemoryId::new();
1981        let ns = Namespace::default();
1982        cold.add_node(a, Layer::Episodic, 0.5, Timestamp::now(), ns)
1983            .await
1984            .unwrap();
1985        cold.add_node(
1986            b,
1987            Layer::Episodic,
1988            0.5,
1989            Timestamp::now(),
1990            Namespace::default(),
1991        )
1992        .await
1993        .unwrap();
1994        cold.add_edge(a, b, EdgeRelation::RelatedTo, 0.9, Metadata::new())
1995            .await
1996            .unwrap();
1997
1998        let rows = hirn_exec::GraphReadRuntime::traverse_graph(
1999            &cached,
2000            &[a],
2001            6,
2002            5,
2003            Some(&[EdgeRelation::RelatedTo]),
2004            None,
2005        )
2006        .await
2007        .unwrap();
2008
2009        assert!(
2010            rows.iter().any(|row| row.node_id == b.to_string()),
2011            "cold-tier traversal should include the persisted neighbor even when the hot graph is empty"
2012        );
2013    }
2014}