Skip to main content

clawft_kernel/
causal.rs

//! Causal graph with typed, weighted directed edges.
//!
//! Provides a concurrent causal reasoning graph where nodes represent
//! events or observations and edges encode causal relationships with
//! weights and provenance metadata. The graph is meant to capture causal
//! (DAG-like) structure, but acyclicity is not enforced when edges are
//! added. Built on `DashMap` (sharded maps with fine-grained locking, not
//! lock-free) for safe concurrent access from multiple agent threads.
7
8use dashmap::DashMap;
9use serde::{Deserialize, Serialize};
10use std::collections::VecDeque;
11use std::fmt;
12use std::sync::atomic::{AtomicU64, Ordering};
13
/// Numeric identifier for causal graph nodes.
///
/// This is local to the causal module and distinct from
/// [`crate::cluster::NodeId`] which is a `String`.
///
/// Within one `CausalGraph`, `add_node` hands these out from a counter that
/// starts at 1 and is never reset (not even by `clear`), so IDs are not
/// reused by that graph.
pub type NodeId = u64;
19
20// ---------------------------------------------------------------------------
21// CausalEdgeType
22// ---------------------------------------------------------------------------
23
/// The kind of causal relationship an edge represents.
///
/// The enum is `#[non_exhaustive]`: code outside this crate that matches on
/// it must include a wildcard arm.
///
/// NOTE(review): variant docs appear to read `A` as the edge source and `B`
/// as the edge target — confirm against the producers of each edge type.
/// Also, serde's derive identifies variants by index in some compact
/// formats (e.g. bincode); if such a format is used for persistence, add new
/// variants at the end instead of reordering existing ones.
#[non_exhaustive]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub enum CausalEdgeType {
    /// A directly causes B.
    Causes,
    /// A suppresses or prevents B.
    Inhibits,
    /// A and B are statistically correlated (non-directional semantics,
    /// but stored in the directed graph for traversal purposes).
    Correlates,
    /// A is a precondition that enables B.
    Enables,
    /// A temporally follows B.
    Follows,
    /// A provides evidence against B.
    Contradicts,
    /// Edge was created by a ClawStage trigger.
    TriggeredBy,
    /// A provides supporting evidence for B.
    EvidenceFor,
}
46
47impl fmt::Display for CausalEdgeType {
48    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49        match self {
50            Self::Causes => write!(f, "Causes"),
51            Self::Inhibits => write!(f, "Inhibits"),
52            Self::Correlates => write!(f, "Correlates"),
53            Self::Enables => write!(f, "Enables"),
54            Self::Follows => write!(f, "Follows"),
55            Self::Contradicts => write!(f, "Contradicts"),
56            Self::TriggeredBy => write!(f, "TriggeredBy"),
57            Self::EvidenceFor => write!(f, "EvidenceFor"),
58        }
59    }
60}
61
62// ---------------------------------------------------------------------------
63// CausalEdge
64// ---------------------------------------------------------------------------
65
/// A weighted, typed directed edge between two causal graph nodes.
///
/// `CausalGraph::link` stores one copy of each edge in the source's
/// outgoing list and another in the target's incoming list.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CausalEdge {
    /// Source node (tail of the arrow).
    pub source: NodeId,
    /// Target node (head of the arrow).
    pub target: NodeId,
    /// Semantic type of the relationship.
    pub edge_type: CausalEdgeType,
    /// Strength / confidence of the relationship (0.0 .. 1.0 typical).
    /// `CausalGraph::link` stores the caller's value as-is (no clamping or
    /// validation).
    pub weight: f32,
    /// Hybrid logical clock timestamp at creation (supplied by the caller
    /// of `CausalGraph::link`).
    pub timestamp: u64,
    /// ExoChain sequence number for provenance tracking (supplied by the
    /// caller of `CausalGraph::link`).
    pub chain_seq: u64,
    /// Universal Node ID bytes for the source node.
    /// `CausalGraph::link` currently fills this with zeros.
    pub source_universal_id: [u8; 32],
    /// Universal Node ID bytes for the target node.
    /// `CausalGraph::link` currently fills this with zeros.
    pub target_universal_id: [u8; 32],
}
86
87// ---------------------------------------------------------------------------
88// CausalNode
89// ---------------------------------------------------------------------------
90
/// A node in the causal graph representing an event, observation, or concept.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CausalNode {
    /// Local numeric identifier (assigned by `CausalGraph::add_node` when
    /// the node is created through the graph).
    pub id: NodeId,
    /// Human-readable label.
    pub label: String,
    /// HLC timestamp at creation.
    /// NOTE: `CausalGraph::add_node` currently always stores 0 here because
    /// no HLC is available to it; its inline comment suggests callers carry
    /// a timestamp in `metadata` instead.
    pub created_at: u64,
    /// Arbitrary JSON metadata attached to this node.
    pub metadata: serde_json::Value,
}
103
104// ---------------------------------------------------------------------------
105// CausalGraph
106// ---------------------------------------------------------------------------
107
108/// A concurrent directed acyclic graph for causal reasoning.
109///
110/// Internally backed by [`DashMap`] for lock-free concurrent reads and
111/// fine-grained write locking. Edge lists are stored in both forward
112/// (outgoing) and reverse (incoming) adjacency maps for efficient
113/// bidirectional traversal.
114pub struct CausalGraph {
115    nodes: DashMap<NodeId, CausalNode>,
116    forward_edges: DashMap<NodeId, Vec<CausalEdge>>,
117    reverse_edges: DashMap<NodeId, Vec<CausalEdge>>,
118    next_node_id: AtomicU64,
119    node_count: AtomicU64,
120    edge_count: AtomicU64,
121}
122
123impl CausalGraph {
124    /// Create an empty causal graph.
125    pub fn new() -> Self {
126        Self {
127            nodes: DashMap::new(),
128            forward_edges: DashMap::new(),
129            reverse_edges: DashMap::new(),
130            next_node_id: AtomicU64::new(1),
131            node_count: AtomicU64::new(0),
132            edge_count: AtomicU64::new(0),
133        }
134    }
135
136    /// Add a node with an auto-assigned ID.
137    pub fn add_node(&self, label: String, metadata: serde_json::Value) -> NodeId {
138        let id = self.next_node_id.fetch_add(1, Ordering::SeqCst);
139        let node = CausalNode {
140            id,
141            label,
142            created_at: 0, // caller may set via metadata; HLC not available here
143            metadata,
144        };
145        self.nodes.insert(id, node);
146        self.forward_edges.insert(id, Vec::new());
147        self.reverse_edges.insert(id, Vec::new());
148        self.node_count.fetch_add(1, Ordering::SeqCst);
149        id
150    }
151
152    /// Retrieve a clone of the node with the given ID.
153    pub fn get_node(&self, id: NodeId) -> Option<CausalNode> {
154        self.nodes.get(&id).map(|r| r.value().clone())
155    }
156
157    /// Remove a node and all edges incident to it.
158    ///
159    /// Returns the removed node, or `None` if the ID was not found.
160    pub fn remove_node(&self, id: NodeId) -> Option<CausalNode> {
161        let (_, node) = self.nodes.remove(&id)?;
162
163        // Remove forward edges from this node and update reverse adjacency.
164        if let Some((_, fwd)) = self.forward_edges.remove(&id) {
165            let removed = fwd.len() as u64;
166            for edge in &fwd {
167                if let Some(mut rev) = self.reverse_edges.get_mut(&edge.target) {
168                    rev.retain(|e| e.source != id);
169                }
170            }
171            self.edge_count.fetch_sub(removed, Ordering::SeqCst);
172        }
173
174        // Remove reverse edges to this node and update forward adjacency.
175        if let Some((_, rev)) = self.reverse_edges.remove(&id) {
176            let removed = rev.len() as u64;
177            for edge in &rev {
178                if let Some(mut fwd) = self.forward_edges.get_mut(&edge.source) {
179                    fwd.retain(|e| e.target != id);
180                }
181            }
182            self.edge_count.fetch_sub(removed, Ordering::SeqCst);
183        }
184
185        self.node_count.fetch_sub(1, Ordering::SeqCst);
186        Some(node)
187    }
188
189    /// Create an edge from `source` to `target`.
190    ///
191    /// Returns `false` if either endpoint does not exist.
192    pub fn link(
193        &self,
194        source: NodeId,
195        target: NodeId,
196        edge_type: CausalEdgeType,
197        weight: f32,
198        timestamp: u64,
199        chain_seq: u64,
200    ) -> bool {
201        if !self.nodes.contains_key(&source) || !self.nodes.contains_key(&target) {
202            return false;
203        }
204
205        let edge = CausalEdge {
206            source,
207            target,
208            edge_type,
209            weight,
210            timestamp,
211            chain_seq,
212            source_universal_id: [0u8; 32],
213            target_universal_id: [0u8; 32],
214        };
215
216        if let Some(mut fwd) = self.forward_edges.get_mut(&source) {
217            fwd.push(edge.clone());
218        }
219        if let Some(mut rev) = self.reverse_edges.get_mut(&target) {
220            rev.push(edge);
221        }
222
223        self.edge_count.fetch_add(1, Ordering::SeqCst);
224        true
225    }
226
227    /// Remove all edges between `source` and `target` (in that direction).
228    ///
229    /// Returns the number of edges removed.
230    pub fn unlink(&self, source: NodeId, target: NodeId) -> usize {
231        let mut count = 0usize;
232
233        if let Some(mut fwd) = self.forward_edges.get_mut(&source) {
234            let before = fwd.len();
235            fwd.retain(|e| e.target != target);
236            count = before - fwd.len();
237        }
238
239        if let Some(mut rev) = self.reverse_edges.get_mut(&target) {
240            rev.retain(|e| e.source != source);
241        }
242
243        self.edge_count
244            .fetch_sub(count as u64, Ordering::SeqCst);
245        count
246    }
247
248    /// Return all edges originating from `id`.
249    pub fn get_forward_edges(&self, id: NodeId) -> Vec<CausalEdge> {
250        self.forward_edges
251            .get(&id)
252            .map(|r| r.value().clone())
253            .unwrap_or_default()
254    }
255
256    /// Return all edges targeting `id`.
257    pub fn get_reverse_edges(&self, id: NodeId) -> Vec<CausalEdge> {
258        self.reverse_edges
259            .get(&id)
260            .map(|r| r.value().clone())
261            .unwrap_or_default()
262    }
263
264    /// Return forward edges from `id` that match `edge_type`.
265    pub fn get_edges_by_type(&self, id: NodeId, edge_type: &CausalEdgeType) -> Vec<CausalEdge> {
266        self.forward_edges
267            .get(&id)
268            .map(|r| {
269                r.value()
270                    .iter()
271                    .filter(|e| &e.edge_type == edge_type)
272                    .cloned()
273                    .collect()
274            })
275            .unwrap_or_default()
276    }
277
278    /// Number of nodes currently in the graph.
279    pub fn node_count(&self) -> u64 {
280        self.node_count.load(Ordering::SeqCst)
281    }
282
283    /// Number of edges currently in the graph.
284    pub fn edge_count(&self) -> u64 {
285        self.edge_count.load(Ordering::SeqCst)
286    }
287
288    /// Remove all nodes and edges (used during calibration cleanup).
289    pub fn clear(&self) {
290        self.nodes.clear();
291        self.forward_edges.clear();
292        self.reverse_edges.clear();
293        self.node_count.store(0, Ordering::SeqCst);
294        self.edge_count.store(0, Ordering::SeqCst);
295        // Note: next_node_id is intentionally NOT reset so IDs remain unique
296        // across clear cycles.
297    }
298
299    /// BFS traversal forward from `start` up to `depth` hops.
300    ///
301    /// Returns all discovered node IDs (excluding `start`).
302    pub fn traverse_forward(&self, start: NodeId, depth: usize) -> Vec<NodeId> {
303        self.bfs(start, depth, true)
304    }
305
306    /// BFS traversal backward (following reverse edges) from `start`
307    /// up to `depth` hops.
308    ///
309    /// Returns all discovered node IDs (excluding `start`).
310    pub fn traverse_reverse(&self, start: NodeId, depth: usize) -> Vec<NodeId> {
311        self.bfs(start, depth, false)
312    }
313
314    /// Find the shortest path from `from` to `to` using BFS, limited
315    /// to `max_depth` hops.
316    ///
317    /// Returns the node sequence including both endpoints, or `None`
318    /// if no path exists within the depth limit.
319    pub fn find_path(&self, from: NodeId, to: NodeId, max_depth: usize) -> Option<Vec<NodeId>> {
320        if from == to {
321            return Some(vec![from]);
322        }
323        if !self.nodes.contains_key(&from) || !self.nodes.contains_key(&to) {
324            return None;
325        }
326
327        // BFS with parent tracking.
328        let mut visited: std::collections::HashSet<NodeId> = std::collections::HashSet::new();
329        let mut parent: std::collections::HashMap<NodeId, NodeId> =
330            std::collections::HashMap::new();
331        let mut queue: VecDeque<(NodeId, usize)> = VecDeque::new();
332
333        visited.insert(from);
334        queue.push_back((from, 0));
335
336        while let Some((current, d)) = queue.pop_front() {
337            if d >= max_depth {
338                continue;
339            }
340            let edges = self.get_forward_edges(current);
341            for edge in edges {
342                if visited.contains(&edge.target) {
343                    continue;
344                }
345                visited.insert(edge.target);
346                parent.insert(edge.target, current);
347
348                if edge.target == to {
349                    // Reconstruct path.
350                    let mut path = Vec::new();
351                    let mut cur = to;
352                    while cur != from {
353                        path.push(cur);
354                        cur = parent[&cur];
355                    }
356                    path.push(from);
357                    path.reverse();
358                    return Some(path);
359                }
360
361                queue.push_back((edge.target, d + 1));
362            }
363        }
364
365        None
366    }
367
368    // -- private helpers --
369
370    fn bfs(&self, start: NodeId, depth: usize, forward: bool) -> Vec<NodeId> {
371        let mut visited: std::collections::HashSet<NodeId> = std::collections::HashSet::new();
372        let mut result = Vec::new();
373        let mut queue: VecDeque<(NodeId, usize)> = VecDeque::new();
374
375        visited.insert(start);
376        queue.push_back((start, 0));
377
378        while let Some((current, d)) = queue.pop_front() {
379            if d >= depth {
380                continue;
381            }
382            let edges = if forward {
383                self.get_forward_edges(current)
384            } else {
385                self.get_reverse_edges(current)
386            };
387            for edge in edges {
388                let neighbor = if forward { edge.target } else { edge.source };
389                if visited.contains(&neighbor) {
390                    continue;
391                }
392                visited.insert(neighbor);
393                result.push(neighbor);
394                queue.push_back((neighbor, d + 1));
395            }
396        }
397
398        result
399    }
400
401    /// List all node IDs currently in the graph.
402    pub fn node_ids(&self) -> Vec<NodeId> {
403        self.nodes.iter().map(|r| *r.key()).collect()
404    }
405
406    /// Degree of a node (in + out edges, treating graph as undirected).
407    pub fn degree(&self, id: NodeId) -> usize {
408        let fwd = self.forward_edges.get(&id).map_or(0, |e| e.len());
409        let rev = self.reverse_edges.get(&id).map_or(0, |e| e.len());
410        fwd + rev
411    }
412
413    /// In-degree (number of incoming edges).
414    pub fn in_degree(&self, id: NodeId) -> usize {
415        self.reverse_edges.get(&id).map_or(0, |e| e.len())
416    }
417
418    /// Out-degree (number of outgoing edges).
419    pub fn out_degree(&self, id: NodeId) -> usize {
420        self.forward_edges.get(&id).map_or(0, |e| e.len())
421    }
422
423    // -----------------------------------------------------------------------
424    // Connected Components (undirected)
425    // -----------------------------------------------------------------------
426
427    /// Find connected components treating the graph as undirected.
428    ///
429    /// Returns a vec of components, each component being a vec of node IDs.
430    /// Components are sorted largest-first.
431    pub fn connected_components(&self) -> Vec<Vec<NodeId>> {
432        let ids = self.node_ids();
433        let mut visited: std::collections::HashSet<NodeId> = std::collections::HashSet::new();
434        let mut components = Vec::new();
435
436        for &id in &ids {
437            if visited.contains(&id) {
438                continue;
439            }
440            // BFS over both directions (undirected).
441            let mut component = Vec::new();
442            let mut queue: VecDeque<NodeId> = VecDeque::new();
443            visited.insert(id);
444            queue.push_back(id);
445
446            while let Some(current) = queue.pop_front() {
447                component.push(current);
448                // Forward neighbors.
449                for edge in self.get_forward_edges(current) {
450                    if visited.insert(edge.target) {
451                        queue.push_back(edge.target);
452                    }
453                }
454                // Reverse neighbors.
455                for edge in self.get_reverse_edges(current) {
456                    if visited.insert(edge.source) {
457                        queue.push_back(edge.source);
458                    }
459                }
460            }
461            component.sort();
462            components.push(component);
463        }
464
465        components.sort_by_key(|b| std::cmp::Reverse(b.len()));
466        components
467    }
468
469    // -----------------------------------------------------------------------
470    // Community Detection (Label Propagation)
471    // -----------------------------------------------------------------------
472
473    /// Detect communities using label propagation on the undirected graph.
474    ///
475    /// Each node starts with its own label. In each iteration, every node
476    /// adopts the most frequent label among its neighbors (weighted by edge
477    /// weight). Converges when no labels change, or after `max_iterations`.
478    ///
479    /// Returns a map from community label (a NodeId) to the set of node IDs
480    /// in that community.
481    pub fn detect_communities(&self, max_iterations: usize) -> Vec<Vec<NodeId>> {
482        let ids = self.node_ids();
483        if ids.is_empty() {
484            return Vec::new();
485        }
486
487        // Initialize: each node gets its own ID as label.
488        let mut labels: std::collections::HashMap<NodeId, NodeId> = std::collections::HashMap::new();
489        for &id in &ids {
490            labels.insert(id, id);
491        }
492
493        for _iter in 0..max_iterations {
494            let mut changed = false;
495
496            // Process nodes in a deterministic order.
497            let mut process_order = ids.clone();
498            process_order.sort();
499
500            for &id in &process_order {
501                // Gather neighbor labels weighted by edge weight.
502                let mut label_weights: std::collections::HashMap<NodeId, f32> =
503                    std::collections::HashMap::new();
504
505                for edge in self.get_forward_edges(id) {
506                    if let Some(&lbl) = labels.get(&edge.target) {
507                        *label_weights.entry(lbl).or_insert(0.0) += edge.weight;
508                    }
509                }
510                for edge in self.get_reverse_edges(id) {
511                    if let Some(&lbl) = labels.get(&edge.source) {
512                        *label_weights.entry(lbl).or_insert(0.0) += edge.weight;
513                    }
514                }
515
516                if label_weights.is_empty() {
517                    continue; // isolated node keeps its label
518                }
519
520                // Pick the label with the highest total weight.
521                // On ties, pick the smallest label for determinism.
522                let best_label = label_weights
523                    .iter()
524                    .max_by(|a, b| {
525                        a.1.partial_cmp(b.1)
526                            .unwrap_or(std::cmp::Ordering::Equal)
527                            .then_with(|| b.0.cmp(a.0)) // smaller ID wins ties
528                    })
529                    .map(|(&lbl, _)| lbl)
530                    .unwrap();
531
532                if labels[&id] != best_label {
533                    labels.insert(id, best_label);
534                    changed = true;
535                }
536            }
537
538            if !changed {
539                break;
540            }
541        }
542
543        // Group nodes by label.
544        let mut communities: std::collections::HashMap<NodeId, Vec<NodeId>> =
545            std::collections::HashMap::new();
546        for (&node, &label) in &labels {
547            communities.entry(label).or_default().push(node);
548        }
549
550        let mut result: Vec<Vec<NodeId>> = communities.into_values().collect();
551        for community in &mut result {
552            community.sort();
553        }
554        result.sort_by_key(|b| std::cmp::Reverse(b.len()));
555        result
556    }
557
558    // -----------------------------------------------------------------------
559    // Spectral Analysis
560    // -----------------------------------------------------------------------
561
562    /// Compute the algebraic connectivity (lambda_2) of the graph.
563    ///
564    /// Lambda_2 is the second-smallest eigenvalue of the graph Laplacian.
565    /// - lambda_2 = 0 means the graph is disconnected.
566    /// - Higher values indicate stronger connectivity.
567    ///
568    /// Uses sparse Lanczos iteration at O(k*m) where m = number of edges
569    /// and k = `max_iterations`. For typical ECC graphs with average degree
570    /// ~10, this is ~200x faster than the dense O(k*n^2) approach.
571    ///
572    /// Returns `(lambda_2, fiedler_vector)` where the Fiedler vector can be
573    /// used for spectral partitioning (sign of each component indicates which
574    /// partition the node belongs to).
575    pub fn spectral_analysis(&self, max_iterations: usize) -> SpectralResult {
576        let ids = self.node_ids();
577        let n = ids.len();
578
579        if n < 2 {
580            return SpectralResult {
581                lambda_2: 0.0,
582                fiedler_vector: Vec::new(),
583                node_ids: ids,
584            };
585        }
586
587        // Build index map: NodeId -> matrix index.
588        let mut id_to_idx: std::collections::HashMap<NodeId, usize> =
589            std::collections::HashMap::new();
590        let mut sorted_ids = ids.clone();
591        sorted_ids.sort();
592        for (i, &id) in sorted_ids.iter().enumerate() {
593            id_to_idx.insert(id, i);
594        }
595
596        // Build sparse adjacency: adj[i] = Vec<(j, weight)> (symmetric).
597        // Also accumulate degrees.  O(m) space instead of O(n^2).
598        let mut adj: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
599        let mut degree: Vec<f64> = vec![0.0; n];
600
601        for &id in &sorted_ids {
602            let i = id_to_idx[&id];
603
604            // Forward edges.
605            for edge in self.get_forward_edges(id) {
606                if let Some(&j) = id_to_idx.get(&edge.target) {
607                    if i != j {
608                        let w = edge.weight as f64;
609                        adj[i].push((j, w));
610                        adj[j].push((i, w));
611                        degree[i] += w;
612                        degree[j] += w;
613                    }
614                }
615            }
616            // Reverse edges — only upper triangle to avoid double-counting.
617            for edge in self.get_reverse_edges(id) {
618                if let Some(&j) = id_to_idx.get(&edge.source) {
619                    if i != j && j > i {
620                        let w = edge.weight as f64;
621                        adj[i].push((j, w));
622                        adj[j].push((i, w));
623                        degree[i] += w;
624                        degree[j] += w;
625                    }
626                }
627            }
628        }
629
630        // Fix up degree: recompute from adjacency for correctness (handles
631        // any double-adds from symmetric storage).
632        for i in 0..n {
633            let mut d = 0.0f64;
634            for &(_, w) in &adj[i] {
635                d += w;
636            }
637            degree[i] = d;
638        }
639
640        // Sparse Laplacian mat-vec: result = L * x = D*x - A*x
641        let laplacian_mul = |x: &[f64], out: &mut [f64]| {
642            for i in 0..n {
643                let mut sum = degree[i] * x[i]; // D*x
644                for &(j, w) in &adj[i] {
645                    sum -= w * x[j]; // -A*x
646                }
647                out[i] = sum;
648            }
649        };
650
651        // ── Lanczos iteration ──────────────────────────────────────────
652        // Builds a k x k tridiagonal matrix T whose eigenvalues approximate
653        // those of L restricted to the subspace orthogonal to the constant
654        // (null-space) vector.  We then extract lambda_2 from T.
655        //
656        // The Fiedler vector is recovered by mapping the corresponding
657        // eigenvector of T back through the Lanczos basis.
658
659        let inv_sqrt_n = 1.0 / (n as f64).sqrt();
660
661        // Initial vector: deterministic, orthogonal to the constant vector.
662        let mut q: Vec<f64> = (0..n)
663            .map(|i| (i as f64) - (n as f64 - 1.0) / 2.0)
664            .collect();
665
666        // Project out the constant (null-space) direction.
667        let dot_ones: f64 = q.iter().sum::<f64>() * inv_sqrt_n;
668        for qi in q.iter_mut() {
669            *qi -= dot_ones * inv_sqrt_n;
670        }
671        normalize_vec(&mut q);
672
673        let k = max_iterations.min(n - 1); // can't exceed n-1 Lanczos steps
674        let mut alpha: Vec<f64> = Vec::with_capacity(k); // diagonal of T
675        let mut beta: Vec<f64> = Vec::with_capacity(k);  // sub-diagonal of T
676        let mut basis: Vec<Vec<f64>> = Vec::with_capacity(k); // Lanczos vectors
677
678        let mut q_prev: Vec<f64> = vec![0.0; n];
679        let mut w_buf: Vec<f64> = vec![0.0; n];
680
681        for j in 0..k {
682            basis.push(q.clone());
683
684            // w = L * q_j
685            laplacian_mul(&q, &mut w_buf);
686
687            // alpha_j = q_j^T * w
688            let aj: f64 = q.iter().zip(w_buf.iter()).map(|(a, b)| a * b).sum();
689            alpha.push(aj);
690
691            // w = w - alpha_j * q_j - beta_{j-1} * q_{j-1}
692            let bj_prev = if j > 0 { beta[j - 1] } else { 0.0 };
693            for i in 0..n {
694                w_buf[i] -= aj * q[i] + bj_prev * q_prev[i];
695            }
696
697            // Re-orthogonalize against all previous basis vectors AND the
698            // constant vector (full reorth for numerical stability).
699            let dot_c: f64 = w_buf.iter().sum::<f64>() * inv_sqrt_n;
700            for wi in w_buf.iter_mut() {
701                *wi -= dot_c * inv_sqrt_n;
702            }
703            for prev in &basis {
704                let dot: f64 = w_buf.iter().zip(prev.iter()).map(|(a, b)| a * b).sum();
705                for i in 0..n {
706                    w_buf[i] -= dot * prev[i];
707                }
708            }
709
710            let bj: f64 = w_buf.iter().map(|x| x * x).sum::<f64>().sqrt();
711            beta.push(bj);
712
713            if bj < 1e-12 {
714                // Invariant subspace found; stop early.
715                break;
716            }
717
718            // q_{j+1} = w / beta_j
719            q_prev = q.clone();
720            q = w_buf.iter().map(|&x| x / bj).collect();
721        }
722
723        // ── Extract eigenvalues from the tridiagonal matrix T ──────────
724        // Use the implicit-shift QR algorithm on the symmetric tridiagonal
725        // matrix (alpha, beta).  We only need the smallest eigenvalue of T
726        // (which approximates lambda_2 since we projected out the null space).
727        let m = alpha.len();
728        let (evals, evecs) = tridiag_eigen(&alpha, &beta[..m.saturating_sub(1).max(0).min(beta.len())], m);
729
730        // Find the smallest eigenvalue (approximation to lambda_2).
731        let mut min_idx = 0;
732        let mut min_val = f64::MAX;
733        for (i, &ev) in evals.iter().enumerate() {
734            if ev < min_val {
735                min_val = ev;
736                min_idx = i;
737            }
738        }
739        let lambda_2 = min_val.max(0.0);
740
741        // Recover the Fiedler vector: v = Q * s, where Q is the n x m Lanczos
742        // basis and s is the eigenvector of T corresponding to lambda_2.
743        let s = &evecs[min_idx];
744        let mut fiedler = vec![0.0f64; n];
745        for (j, bvec) in basis.iter().enumerate() {
746            if j < s.len() {
747                let sj = s[j];
748                for i in 0..n {
749                    fiedler[i] += sj * bvec[i];
750                }
751            }
752        }
753        normalize_vec(&mut fiedler);
754
755        SpectralResult {
756            lambda_2,
757            fiedler_vector: fiedler,
758            node_ids: sorted_ids,
759        }
760    }
761
762    /// Partition the graph into two halves using the Fiedler vector.
763    ///
764    /// Nodes with positive Fiedler vector components go to partition A,
765    /// negative to partition B. This is spectral bisection — the
766    /// minimum-cut balanced partition.
767    pub fn spectral_partition(&self) -> (Vec<NodeId>, Vec<NodeId>) {
768        let result = self.spectral_analysis(50);
769        let mut a = Vec::new();
770        let mut b = Vec::new();
771
772        for (i, &id) in result.node_ids.iter().enumerate() {
773            if i < result.fiedler_vector.len() && result.fiedler_vector[i] >= 0.0 {
774                a.push(id);
775            } else {
776                b.push(id);
777            }
778        }
779        (a, b)
780    }
781
782    // -----------------------------------------------------------------------
783    // Predictive Analysis
784    // -----------------------------------------------------------------------
785
786    /// Compute co-modification coupling between nodes based on temporal
787    /// co-occurrence patterns.
788    ///
789    /// Given a list of "change events" (each event is a set of node IDs that
790    /// changed together, plus a timestamp), computes a coupling score for
791    /// every pair of nodes that have been modified together.
792    ///
793    /// The coupling score for nodes (A, B) is:
794    ///   coupling = co_changes(A,B) / max(changes(A), changes(B))
795    ///
796    /// Returns pairs sorted by coupling score descending.
797    pub fn compute_coupling(
798        &self,
799        change_events: &[ChangeEvent],
800    ) -> Vec<CouplingPair> {
801        let mut change_count: std::collections::HashMap<NodeId, usize> =
802            std::collections::HashMap::new();
803        let mut co_change_count: std::collections::HashMap<(NodeId, NodeId), usize> =
804            std::collections::HashMap::new();
805
806        for event in change_events {
807            let mut nodes: Vec<NodeId> = event.node_ids.clone();
808            nodes.sort();
809            nodes.dedup();
810
811            for &id in &nodes {
812                *change_count.entry(id).or_insert(0) += 1;
813            }
814            // Count co-occurrences.
815            for i in 0..nodes.len() {
816                for j in (i + 1)..nodes.len() {
817                    let key = (nodes[i], nodes[j]);
818                    *co_change_count.entry(key).or_insert(0) += 1;
819                }
820            }
821        }
822
823        let mut pairs: Vec<CouplingPair> = co_change_count
824            .iter()
825            .map(|(&(a, b), &co)| {
826                let max_changes = change_count
827                    .get(&a)
828                    .copied()
829                    .unwrap_or(1)
830                    .max(change_count.get(&b).copied().unwrap_or(1));
831                CouplingPair {
832                    node_a: a,
833                    node_b: b,
834                    co_changes: co,
835                    coupling_score: co as f64 / max_changes as f64,
836                }
837            })
838            .collect();
839
840        pairs.sort_by(|a, b| {
841            b.coupling_score
842                .partial_cmp(&a.coupling_score)
843                .unwrap_or(std::cmp::Ordering::Equal)
844        });
845        pairs
846    }
847
848    /// Detect burst patterns in change events and predict which nodes are
849    /// likely to change next.
850    ///
851    /// A "burst" is a period where a node has significantly more changes
852    /// than its baseline rate. Nodes currently in a burst, or recently
853    /// co-modified with nodes in a burst, are predicted to change next.
854    ///
855    /// `window_size` is the number of recent events to consider for the
856    /// burst window. `baseline_factor` is the multiplier above which a
857    /// node's activity is considered a burst (e.g., 2.0 = 2x baseline).
858    ///
859    /// Returns nodes sorted by prediction confidence (descending).
860    pub fn predict_changes(
861        &self,
862        change_events: &[ChangeEvent],
863        window_size: usize,
864        baseline_factor: f64,
865    ) -> Vec<ChangePrediction> {
866        if change_events.is_empty() {
867            return Vec::new();
868        }
869
870        // Sort events by timestamp.
871        let mut sorted_events = change_events.to_vec();
872        sorted_events.sort_by_key(|e| e.timestamp);
873
874        let total = sorted_events.len();
875        let window_start = total.saturating_sub(window_size);
876
877        // Compute baseline rate (changes per event across all history).
878        let mut total_counts: std::collections::HashMap<NodeId, usize> =
879            std::collections::HashMap::new();
880        for event in &sorted_events {
881            for &id in &event.node_ids {
882                *total_counts.entry(id).or_insert(0) += 1;
883            }
884        }
885
886        // Compute window rate.
887        let mut window_counts: std::collections::HashMap<NodeId, usize> =
888            std::collections::HashMap::new();
889        for event in &sorted_events[window_start..] {
890            for &id in &event.node_ids {
891                *window_counts.entry(id).or_insert(0) += 1;
892            }
893        }
894
895        let window_len = total - window_start;
896
897        // Identify nodes in burst.
898        let mut burst_nodes: Vec<(NodeId, f64)> = Vec::new();
899        for (&id, &window_count) in &window_counts {
900            let total_count = total_counts.get(&id).copied().unwrap_or(0);
901            let baseline_rate = total_count as f64 / total as f64;
902            let window_rate = window_count as f64 / window_len as f64;
903
904            if baseline_rate > 0.0 && window_rate / baseline_rate >= baseline_factor {
905                burst_nodes.push((id, window_rate / baseline_rate));
906            }
907        }
908
909        // Compute coupling to identify co-modification partners.
910        let coupling = self.compute_coupling(change_events);
911        let coupling_map: std::collections::HashMap<(NodeId, NodeId), f64> = coupling
912            .iter()
913            .map(|p| ((p.node_a, p.node_b), p.coupling_score))
914            .collect();
915
916        // Score all nodes.
917        let mut predictions: std::collections::HashMap<NodeId, f64> =
918            std::collections::HashMap::new();
919
920        // Burst nodes get high base confidence.
921        for &(id, burst_ratio) in &burst_nodes {
922            *predictions.entry(id).or_insert(0.0) += burst_ratio * 0.6;
923        }
924
925        // Coupled partners of burst nodes get transitive confidence.
926        for &(burst_id, burst_ratio) in &burst_nodes {
927            for (&(a, b), &coupling_score) in &coupling_map {
928                let partner = if a == burst_id {
929                    Some(b)
930                } else if b == burst_id {
931                    Some(a)
932                } else {
933                    None
934                };
935                if let Some(partner_id) = partner {
936                    *predictions.entry(partner_id).or_insert(0.0) +=
937                        burst_ratio * coupling_score * 0.4;
938                }
939            }
940        }
941
942        // Recent activity boost: nodes that appeared in the last few events.
943        let recency_window = (window_size / 3).max(1);
944        let recency_start = total.saturating_sub(recency_window);
945        for event in &sorted_events[recency_start..] {
946            for &id in &event.node_ids {
947                *predictions.entry(id).or_insert(0.0) += 0.1;
948            }
949        }
950
951        let mut result: Vec<ChangePrediction> = predictions
952            .into_iter()
953            .map(|(id, confidence)| {
954                let label = self
955                    .get_node(id)
956                    .map(|n| n.label.clone())
957                    .unwrap_or_else(|| format!("node:{id}"));
958                let in_burst = burst_nodes.iter().any(|&(bid, _)| bid == id);
959                ChangePrediction {
960                    node_id: id,
961                    label,
962                    confidence: confidence.min(1.0),
963                    in_burst,
964                    recent_changes: window_counts.get(&id).copied().unwrap_or(0),
965                }
966            })
967            .collect();
968
969        result.sort_by(|a, b| {
970            b.confidence
971                .partial_cmp(&a.confidence)
972                .unwrap_or(std::cmp::Ordering::Equal)
973        });
974        result
975    }
976}
977
978// ---------------------------------------------------------------------------
979// Supporting types
980// ---------------------------------------------------------------------------
981
982/// Result of spectral analysis on the causal graph.
983#[derive(Debug, Clone)]
984pub struct SpectralResult {
985    /// Algebraic connectivity (second-smallest Laplacian eigenvalue).
986    /// 0.0 means disconnected. Higher = more connected.
987    pub lambda_2: f64,
988    /// Fiedler vector — sign indicates spectral partition membership.
989    pub fiedler_vector: Vec<f64>,
990    /// Node IDs in the same order as the Fiedler vector.
991    pub node_ids: Vec<NodeId>,
992}
993
994/// A temporal change event: a set of nodes that changed together.
995#[derive(Debug, Clone)]
996pub struct ChangeEvent {
997    /// Nodes that changed in this event (e.g., modules modified in a commit).
998    pub node_ids: Vec<NodeId>,
999    /// Timestamp of the event.
1000    pub timestamp: u64,
1001}
1002
1003/// Coupling between two nodes based on co-modification frequency.
1004#[derive(Debug, Clone)]
1005pub struct CouplingPair {
1006    /// First node.
1007    pub node_a: NodeId,
1008    /// Second node.
1009    pub node_b: NodeId,
1010    /// Number of times both changed in the same event.
1011    pub co_changes: usize,
1012    /// Coupling score: co_changes / max(changes_a, changes_b).
1013    pub coupling_score: f64,
1014}
1015
1016/// A prediction that a node will change soon.
1017#[derive(Debug, Clone)]
1018pub struct ChangePrediction {
1019    /// Node ID.
1020    pub node_id: NodeId,
1021    /// Human-readable label.
1022    pub label: String,
1023    /// Prediction confidence (0.0 .. 1.0).
1024    pub confidence: f64,
1025    /// Whether this node is currently in a burst pattern.
1026    pub in_burst: bool,
1027    /// Number of changes in the recent window.
1028    pub recent_changes: usize,
1029}
1030
/// Scale `v` in place so that it has unit Euclidean (L2) length.
///
/// When the length is not above `1e-12` (empty, all-zero, or tiny vectors),
/// `v` is left untouched instead of dividing by a near-zero value.
fn normalize_vec(v: &mut [f64]) {
    let squared_len = v.iter().fold(0.0_f64, |acc, &x| acc + x * x);
    let len = squared_len.sqrt();
    if len > 1e-12 {
        for x in v.iter_mut() {
            *x /= len;
        }
    }
}
1038
/// Compute all eigenvalues and eigenvectors of a symmetric tridiagonal matrix.
///
/// * `diag` — main diagonal (length m).
/// * `off`  — sub-diagonal (length m-1).
/// * `m`    — matrix dimension.
///
/// Returns `(eigenvalues, eigenvectors)` where `eigenvectors[i]` is the
/// eigenvector for `eigenvalues[i]`, each of length `m`. The pairs are
/// returned in diagonal-index order and are NOT sorted by eigenvalue; callers
/// that need, e.g., the smallest eigenvalues must sort them themselves. The
/// eigenvectors form an orthonormal set (up to rounding), because they are
/// columns of a product of plane rotations applied to the identity.
///
/// Uses the classical Jacobi eigenvalue algorithm on the full m x m symmetric
/// matrix built from the tridiagonal.  Since m is small (Lanczos iteration count,
/// typically 20-50), the O(m^3) cost is negligible compared to the O(k*m)
/// sparse mat-vecs.
///
/// Iteration stops once every off-diagonal entry is below `1e-15` in absolute
/// value, or after `100 * m` rotations. In the latter case the current
/// (possibly unconverged) values are returned without any signal.
///
/// # Panics
///
/// Panics if `diag.len() < m`. Only the first `min(off.len(), m - 1)` entries
/// of `off` are used; missing sub-diagonal entries are treated as zero.
fn tridiag_eigen(diag: &[f64], off: &[f64], m: usize) -> (Vec<f64>, Vec<Vec<f64>>) {
    if m == 0 {
        return (Vec::new(), Vec::new());
    }
    // A 1x1 matrix is its own eigenvalue, with eigenvector [1].
    if m == 1 {
        return (vec![diag[0]], vec![vec![1.0]]);
    }

    // Build full symmetric matrix from the tridiagonal.
    let mut a = vec![vec![0.0f64; m]; m];
    for i in 0..m {
        a[i][i] = diag[i];
    }
    // Tolerate a short/long `off`: extra entries are ignored, missing ones stay 0.
    let off_len = off.len().min(m - 1);
    for i in 0..off_len {
        a[i][i + 1] = off[i];
        a[i + 1][i] = off[i];
    }

    // Eigenvector matrix V (columns are eigenvectors), starts as identity.
    let mut v = vec![vec![0.0f64; m]; m];
    for i in 0..m {
        v[i][i] = 1.0;
    }

    // Classical Jacobi iteration: every step pivots on the largest-magnitude
    // off-diagonal entry (this is not a cyclic sweep), capped at 100 * m
    // rotations.
    for _ in 0..100 * m {
        // Find the largest off-diagonal element. Scanning the upper triangle
        // is enough because every update below keeps A symmetric.
        let mut max_off = 0.0f64;
        let mut p = 0usize;
        let mut q = 1usize;
        for i in 0..m {
            for j in (i + 1)..m {
                if a[i][j].abs() > max_off {
                    max_off = a[i][j].abs();
                    p = i;
                    q = j;
                }
            }
        }

        // Converged: all off-diagonal entries are negligible (absolute test).
        if max_off < 1e-15 {
            break;
        }

        // Compute Jacobi rotation angle to zero out a[p][q].
        // `t` = tan(phi) is the smaller-magnitude root of
        // t^2 + 2*theta*t - 1 = 0, which keeps |phi| <= pi/4 for stability;
        // `c` and `s` are cos(phi) and sin(phi).
        let theta = (a[q][q] - a[p][p]) / (2.0 * a[p][q]);
        let t = theta.signum() / (theta.abs() + (1.0 + theta * theta).sqrt());
        let c = 1.0 / (1.0 + t * t).sqrt();
        let s = t * c;

        // Apply similarity rotation to A (A' = J^T A J); the pivot entry is
        // set to exactly zero rather than to its rounded value.
        let app = a[p][p];
        let aqq = a[q][q];
        let apq = a[p][q];
        a[p][p] = c * c * app - 2.0 * s * c * apq + s * s * aqq;
        a[q][q] = s * s * app + 2.0 * s * c * apq + c * c * aqq;
        a[p][q] = 0.0;
        a[q][p] = 0.0;

        // Update rows/columns p and q for every other index r, mirroring each
        // write to keep A symmetric.
        for r in 0..m {
            if r != p && r != q {
                let arp = a[r][p];
                let arq = a[r][q];
                a[r][p] = c * arp - s * arq;
                a[p][r] = a[r][p];
                a[r][q] = s * arp + c * arq;
                a[q][r] = a[r][q];
            }
        }

        // Accumulate rotation into eigenvector matrix.
        for r in 0..m {
            let vp = v[r][p];
            let vq = v[r][q];
            v[r][p] = c * vp - s * vq;
            v[r][q] = s * vp + c * vq;
        }
    }

    // Eigenvalues are the diagonal of A.
    let eigenvalues: Vec<f64> = (0..m).map(|i| a[i][i]).collect();

    // Eigenvectors: column j of V is the eigenvector for eigenvalue j.
    // Return as eigenvectors[j] = column j.
    let eigenvectors: Vec<Vec<f64>> = (0..m)
        .map(|j| (0..m).map(|i| v[i][j]).collect())
        .collect();

    (eigenvalues, eigenvectors)
}
1142
1143// ---------------------------------------------------------------------------
1144// Persistence
1145// ---------------------------------------------------------------------------
1146
1147/// Serializable snapshot of a [`CausalGraph`] for JSON persistence.
1148#[derive(Serialize, Deserialize)]
1149struct CausalGraphSnapshot {
1150    next_node_id: u64,
1151    nodes: Vec<CausalNode>,
1152    forward_edges: std::collections::HashMap<NodeId, Vec<CausalEdge>>,
1153}
1154
1155impl CausalGraph {
1156    /// Serialize the entire graph to a JSON writer.
1157    pub fn save_to_writer<W: std::io::Write>(&self, writer: W) -> Result<(), std::io::Error> {
1158        let snapshot = self.to_snapshot();
1159        serde_json::to_writer_pretty(writer, &snapshot)
1160            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))
1161    }
1162
1163    /// Deserialize a graph from a JSON reader.
1164    pub fn load_from_reader<R: std::io::Read>(reader: R) -> Result<Self, std::io::Error> {
1165        let snapshot: CausalGraphSnapshot = serde_json::from_reader(reader)
1166            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
1167        Ok(Self::from_snapshot(snapshot))
1168    }
1169
1170    /// Save the graph to a file path.
1171    pub fn save_to_file(&self, path: &std::path::Path) -> Result<(), std::io::Error> {
1172        if let Some(parent) = path.parent() {
1173            std::fs::create_dir_all(parent)?;
1174        }
1175        let file = std::fs::File::create(path)?;
1176        let writer = std::io::BufWriter::new(file);
1177        self.save_to_writer(writer)
1178    }
1179
1180    /// Load a graph from a file path.
1181    pub fn load_from_file(path: &std::path::Path) -> Result<Self, std::io::Error> {
1182        let file = std::fs::File::open(path)?;
1183        let reader = std::io::BufReader::new(file);
1184        Self::load_from_reader(reader)
1185    }
1186
1187    fn to_snapshot(&self) -> CausalGraphSnapshot {
1188        let nodes: Vec<CausalNode> = self.nodes.iter().map(|r| r.value().clone()).collect();
1189        let mut forward_edges = std::collections::HashMap::new();
1190        for entry in self.forward_edges.iter() {
1191            if !entry.value().is_empty() {
1192                forward_edges.insert(*entry.key(), entry.value().clone());
1193            }
1194        }
1195        CausalGraphSnapshot {
1196            next_node_id: self.next_node_id.load(Ordering::SeqCst),
1197            nodes,
1198            forward_edges,
1199        }
1200    }
1201
1202    fn from_snapshot(snapshot: CausalGraphSnapshot) -> Self {
1203        let graph = Self {
1204            nodes: DashMap::new(),
1205            forward_edges: DashMap::new(),
1206            reverse_edges: DashMap::new(),
1207            next_node_id: AtomicU64::new(snapshot.next_node_id),
1208            node_count: AtomicU64::new(0),
1209            edge_count: AtomicU64::new(0),
1210        };
1211
1212        // Restore nodes.
1213        for node in &snapshot.nodes {
1214            graph.nodes.insert(node.id, node.clone());
1215            graph.forward_edges.insert(node.id, Vec::new());
1216            graph.reverse_edges.insert(node.id, Vec::new());
1217        }
1218        graph.node_count.store(snapshot.nodes.len() as u64, Ordering::SeqCst);
1219
1220        // Restore edges from forward_edges map.
1221        let mut total_edges: u64 = 0;
1222        for (source_id, edges) in &snapshot.forward_edges {
1223            for edge in edges {
1224                if let Some(mut fwd) = graph.forward_edges.get_mut(source_id) {
1225                    fwd.push(edge.clone());
1226                }
1227                if let Some(mut rev) = graph.reverse_edges.get_mut(&edge.target) {
1228                    rev.push(edge.clone());
1229                }
1230                total_edges += 1;
1231            }
1232        }
1233        graph.edge_count.store(total_edges, Ordering::SeqCst);
1234
1235        graph
1236    }
1237}
1238
1239impl Default for CausalGraph {
1240    fn default() -> Self {
1241        Self::new()
1242    }
1243}
1244
1245// ===========================================================================
1246// Tests
1247// ===========================================================================
1248
1249#[cfg(test)]
1250mod tests {
1251    use super::*;
1252
1253    fn make_graph() -> CausalGraph {
1254        CausalGraph::new()
1255    }
1256
1257    // 1
1258    #[test]
1259    fn new_graph_empty() {
1260        let g = make_graph();
1261        assert_eq!(g.node_count(), 0);
1262        assert_eq!(g.edge_count(), 0);
1263    }
1264
1265    // 2
1266    #[test]
1267    fn add_node_returns_id() {
1268        let g = make_graph();
1269        let id1 = g.add_node("A".into(), serde_json::json!({}));
1270        let id2 = g.add_node("B".into(), serde_json::json!({}));
1271        assert_ne!(id1, id2);
1272        assert_eq!(g.node_count(), 2);
1273    }
1274
1275    // 3
1276    #[test]
1277    fn get_node() {
1278        let g = make_graph();
1279        let id = g.add_node("hello".into(), serde_json::json!({"key": "val"}));
1280        let node = g.get_node(id).unwrap();
1281        assert_eq!(node.label, "hello");
1282        assert_eq!(node.metadata["key"], "val");
1283    }
1284
1285    // 4
1286    #[test]
1287    fn remove_node() {
1288        let g = make_graph();
1289        let id = g.add_node("X".into(), serde_json::json!({}));
1290        assert!(g.get_node(id).is_some());
1291        let removed = g.remove_node(id).unwrap();
1292        assert_eq!(removed.label, "X");
1293        assert!(g.get_node(id).is_none());
1294        assert_eq!(g.node_count(), 0);
1295    }
1296
1297    // 5
1298    #[test]
1299    fn link_creates_edge() {
1300        let g = make_graph();
1301        let a = g.add_node("A".into(), serde_json::json!({}));
1302        let b = g.add_node("B".into(), serde_json::json!({}));
1303        assert!(g.link(a, b, CausalEdgeType::Causes, 0.9, 100, 1));
1304        assert_eq!(g.edge_count(), 1);
1305    }
1306
1307    // 6
1308    #[test]
1309    fn link_invalid_source_returns_false() {
1310        let g = make_graph();
1311        let b = g.add_node("B".into(), serde_json::json!({}));
1312        assert!(!g.link(9999, b, CausalEdgeType::Causes, 0.5, 0, 0));
1313        assert_eq!(g.edge_count(), 0);
1314    }
1315
1316    // 7
1317    #[test]
1318    fn link_invalid_target_returns_false() {
1319        let g = make_graph();
1320        let a = g.add_node("A".into(), serde_json::json!({}));
1321        assert!(!g.link(a, 9999, CausalEdgeType::Causes, 0.5, 0, 0));
1322        assert_eq!(g.edge_count(), 0);
1323    }
1324
1325    // 8
1326    #[test]
1327    fn get_forward_edges() {
1328        let g = make_graph();
1329        let a = g.add_node("A".into(), serde_json::json!({}));
1330        let b = g.add_node("B".into(), serde_json::json!({}));
1331        let c = g.add_node("C".into(), serde_json::json!({}));
1332        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1333        g.link(a, c, CausalEdgeType::Enables, 0.5, 0, 0);
1334        let fwd = g.get_forward_edges(a);
1335        assert_eq!(fwd.len(), 2);
1336        assert!(fwd.iter().any(|e| e.target == b));
1337        assert!(fwd.iter().any(|e| e.target == c));
1338    }
1339
1340    // 9
1341    #[test]
1342    fn get_reverse_edges() {
1343        let g = make_graph();
1344        let a = g.add_node("A".into(), serde_json::json!({}));
1345        let b = g.add_node("B".into(), serde_json::json!({}));
1346        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1347        let rev = g.get_reverse_edges(b);
1348        assert_eq!(rev.len(), 1);
1349        assert_eq!(rev[0].source, a);
1350    }
1351
1352    // 10
1353    #[test]
1354    fn get_edges_by_type() {
1355        let g = make_graph();
1356        let a = g.add_node("A".into(), serde_json::json!({}));
1357        let b = g.add_node("B".into(), serde_json::json!({}));
1358        let c = g.add_node("C".into(), serde_json::json!({}));
1359        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1360        g.link(a, c, CausalEdgeType::Inhibits, 0.3, 0, 0);
1361        let causes = g.get_edges_by_type(a, &CausalEdgeType::Causes);
1362        assert_eq!(causes.len(), 1);
1363        assert_eq!(causes[0].target, b);
1364    }
1365
1366    // 11
1367    #[test]
1368    fn unlink_removes_edges() {
1369        let g = make_graph();
1370        let a = g.add_node("A".into(), serde_json::json!({}));
1371        let b = g.add_node("B".into(), serde_json::json!({}));
1372        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1373        g.link(a, b, CausalEdgeType::Enables, 0.5, 0, 0);
1374        assert_eq!(g.edge_count(), 2);
1375        let removed = g.unlink(a, b);
1376        assert_eq!(removed, 2);
1377        assert_eq!(g.edge_count(), 0);
1378        assert!(g.get_forward_edges(a).is_empty());
1379        assert!(g.get_reverse_edges(b).is_empty());
1380    }
1381
1382    // 12
1383    #[test]
1384    fn node_count() {
1385        let g = make_graph();
1386        assert_eq!(g.node_count(), 0);
1387        g.add_node("A".into(), serde_json::json!({}));
1388        assert_eq!(g.node_count(), 1);
1389        let id = g.add_node("B".into(), serde_json::json!({}));
1390        assert_eq!(g.node_count(), 2);
1391        g.remove_node(id);
1392        assert_eq!(g.node_count(), 1);
1393    }
1394
1395    // 13
1396    #[test]
1397    fn edge_count() {
1398        let g = make_graph();
1399        let a = g.add_node("A".into(), serde_json::json!({}));
1400        let b = g.add_node("B".into(), serde_json::json!({}));
1401        let c = g.add_node("C".into(), serde_json::json!({}));
1402        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1403        g.link(b, c, CausalEdgeType::Follows, 0.8, 0, 0);
1404        assert_eq!(g.edge_count(), 2);
1405    }
1406
1407    // 14
1408    #[test]
1409    fn clear_empties_graph() {
1410        let g = make_graph();
1411        let a = g.add_node("A".into(), serde_json::json!({}));
1412        let b = g.add_node("B".into(), serde_json::json!({}));
1413        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1414        g.clear();
1415        assert_eq!(g.node_count(), 0);
1416        assert_eq!(g.edge_count(), 0);
1417    }
1418
1419    // 15 — A -> B, traverse 1 hop forward from A
1420    #[test]
1421    fn traverse_forward_single_hop() {
1422        let g = make_graph();
1423        let a = g.add_node("A".into(), serde_json::json!({}));
1424        let b = g.add_node("B".into(), serde_json::json!({}));
1425        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1426        let reachable = g.traverse_forward(a, 1);
1427        assert_eq!(reachable, vec![b]);
1428    }
1429
1430    // 16 — A -> B -> C, traverse 2 hops from A
1431    #[test]
1432    fn traverse_forward_multi_hop() {
1433        let g = make_graph();
1434        let a = g.add_node("A".into(), serde_json::json!({}));
1435        let b = g.add_node("B".into(), serde_json::json!({}));
1436        let c = g.add_node("C".into(), serde_json::json!({}));
1437        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1438        g.link(b, c, CausalEdgeType::Causes, 1.0, 0, 0);
1439        let reachable = g.traverse_forward(a, 2);
1440        assert!(reachable.contains(&b));
1441        assert!(reachable.contains(&c));
1442        assert_eq!(reachable.len(), 2);
1443    }
1444
1445    // 17 — A -> B -> C, traverse only 1 hop from A (should NOT reach C)
1446    #[test]
1447    fn traverse_forward_depth_limit() {
1448        let g = make_graph();
1449        let a = g.add_node("A".into(), serde_json::json!({}));
1450        let b = g.add_node("B".into(), serde_json::json!({}));
1451        let c = g.add_node("C".into(), serde_json::json!({}));
1452        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1453        g.link(b, c, CausalEdgeType::Causes, 1.0, 0, 0);
1454        let reachable = g.traverse_forward(a, 1);
1455        assert_eq!(reachable, vec![b]);
1456        assert!(!reachable.contains(&c));
1457    }
1458
1459    // 18 — A -> B -> C, traverse reverse from C
1460    #[test]
1461    fn traverse_reverse() {
1462        let g = make_graph();
1463        let a = g.add_node("A".into(), serde_json::json!({}));
1464        let b = g.add_node("B".into(), serde_json::json!({}));
1465        let c = g.add_node("C".into(), serde_json::json!({}));
1466        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1467        g.link(b, c, CausalEdgeType::Causes, 1.0, 0, 0);
1468        let reachable = g.traverse_reverse(c, 2);
1469        assert!(reachable.contains(&b));
1470        assert!(reachable.contains(&a));
1471    }
1472
1473    // 19
1474    #[test]
1475    fn find_path_exists() {
1476        let g = make_graph();
1477        let a = g.add_node("A".into(), serde_json::json!({}));
1478        let b = g.add_node("B".into(), serde_json::json!({}));
1479        let c = g.add_node("C".into(), serde_json::json!({}));
1480        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1481        g.link(b, c, CausalEdgeType::Causes, 1.0, 0, 0);
1482        let path = g.find_path(a, c, 5).unwrap();
1483        assert_eq!(path, vec![a, b, c]);
1484    }
1485
1486    // 20
1487    #[test]
1488    fn find_path_no_path() {
1489        let g = make_graph();
1490        let a = g.add_node("A".into(), serde_json::json!({}));
1491        let b = g.add_node("B".into(), serde_json::json!({}));
1492        // No edge between them.
1493        assert!(g.find_path(a, b, 5).is_none());
1494    }
1495
1496    // 21
1497    #[test]
1498    fn concurrent_add_nodes() {
1499        use std::sync::Arc;
1500        use std::thread;
1501
1502        let g = Arc::new(CausalGraph::new());
1503        let mut handles = Vec::new();
1504
1505        for t in 0..4 {
1506            let g = Arc::clone(&g);
1507            handles.push(thread::spawn(move || {
1508                for i in 0..25 {
1509                    g.add_node(format!("t{t}-n{i}"), serde_json::json!({}));
1510                }
1511            }));
1512        }
1513
1514        for h in handles {
1515            h.join().unwrap();
1516        }
1517
1518        assert_eq!(g.node_count(), 100);
1519    }
1520
1521    // 22
1522    #[test]
1523    fn causal_edge_type_display() {
1524        assert_eq!(CausalEdgeType::Causes.to_string(), "Causes");
1525        assert_eq!(CausalEdgeType::Inhibits.to_string(), "Inhibits");
1526        assert_eq!(CausalEdgeType::Correlates.to_string(), "Correlates");
1527        assert_eq!(CausalEdgeType::Enables.to_string(), "Enables");
1528        assert_eq!(CausalEdgeType::Follows.to_string(), "Follows");
1529        assert_eq!(CausalEdgeType::Contradicts.to_string(), "Contradicts");
1530        assert_eq!(CausalEdgeType::TriggeredBy.to_string(), "TriggeredBy");
1531        assert_eq!(CausalEdgeType::EvidenceFor.to_string(), "EvidenceFor");
1532    }
1533
1534    // =====================================================================
1535    // Degree tests
1536    // =====================================================================
1537
1538    // 23
1539    #[test]
1540    fn degree_computation() {
1541        let g = make_graph();
1542        let a = g.add_node("A".into(), serde_json::json!({}));
1543        let b = g.add_node("B".into(), serde_json::json!({}));
1544        let c = g.add_node("C".into(), serde_json::json!({}));
1545        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1546        g.link(a, c, CausalEdgeType::Enables, 0.5, 0, 0);
1547        g.link(b, c, CausalEdgeType::Follows, 0.8, 0, 0);
1548        assert_eq!(g.out_degree(a), 2);
1549        assert_eq!(g.in_degree(a), 0);
1550        assert_eq!(g.degree(a), 2);
1551        assert_eq!(g.in_degree(c), 2);
1552        assert_eq!(g.out_degree(c), 0);
1553        assert_eq!(g.degree(c), 2);
1554        assert_eq!(g.degree(b), 2); // 1 in + 1 out
1555    }
1556
1557    // =====================================================================
1558    // Connected Components tests
1559    // =====================================================================
1560
1561    // 24
1562    #[test]
1563    fn connected_components_single() {
1564        let g = make_graph();
1565        let a = g.add_node("A".into(), serde_json::json!({}));
1566        let b = g.add_node("B".into(), serde_json::json!({}));
1567        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1568        let cc = g.connected_components();
1569        assert_eq!(cc.len(), 1);
1570        assert_eq!(cc[0].len(), 2);
1571    }
1572
1573    // 25
1574    #[test]
1575    fn connected_components_two_islands() {
1576        let g = make_graph();
1577        let a = g.add_node("A".into(), serde_json::json!({}));
1578        let b = g.add_node("B".into(), serde_json::json!({}));
1579        let c = g.add_node("C".into(), serde_json::json!({}));
1580        let d = g.add_node("D".into(), serde_json::json!({}));
1581        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1582        g.link(c, d, CausalEdgeType::Causes, 1.0, 0, 0);
1583        let cc = g.connected_components();
1584        assert_eq!(cc.len(), 2);
1585        assert_eq!(cc[0].len(), 2);
1586        assert_eq!(cc[1].len(), 2);
1587    }
1588
1589    // 26
1590    #[test]
1591    fn connected_components_isolated_node() {
1592        let g = make_graph();
1593        let a = g.add_node("A".into(), serde_json::json!({}));
1594        let b = g.add_node("B".into(), serde_json::json!({}));
1595        g.add_node("isolated".into(), serde_json::json!({}));
1596        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1597        let cc = g.connected_components();
1598        assert_eq!(cc.len(), 2);
1599        assert_eq!(cc[0].len(), 2); // largest first
1600        assert_eq!(cc[1].len(), 1); // isolated
1601    }
1602
1603    // =====================================================================
1604    // Community Detection tests
1605    // =====================================================================
1606
1607    // 27
1608    #[test]
1609    fn community_detection_two_clusters() {
1610        let g = make_graph();
1611        // Cluster 1: A-B-C strongly connected.
1612        let a = g.add_node("A".into(), serde_json::json!({}));
1613        let b = g.add_node("B".into(), serde_json::json!({}));
1614        let c = g.add_node("C".into(), serde_json::json!({}));
1615        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1616        g.link(b, c, CausalEdgeType::Causes, 1.0, 0, 0);
1617        g.link(c, a, CausalEdgeType::Causes, 1.0, 0, 0);
1618
1619        // Cluster 2: D-E-F strongly connected.
1620        let d = g.add_node("D".into(), serde_json::json!({}));
1621        let e = g.add_node("E".into(), serde_json::json!({}));
1622        let f = g.add_node("F".into(), serde_json::json!({}));
1623        g.link(d, e, CausalEdgeType::Causes, 1.0, 0, 0);
1624        g.link(e, f, CausalEdgeType::Causes, 1.0, 0, 0);
1625        g.link(f, d, CausalEdgeType::Causes, 1.0, 0, 0);
1626
1627        // Weak bridge between clusters.
1628        g.link(c, d, CausalEdgeType::Correlates, 0.1, 0, 0);
1629
1630        let communities = g.detect_communities(20);
1631        // Should find 2 communities (clusters) even with the weak bridge.
1632        // Label propagation may merge them due to the bridge, but the strong
1633        // internal edges should dominate.
1634        assert!(!communities.is_empty());
1635        // At minimum, isolated nodes shouldn't be their own community.
1636        assert!(communities.len() <= 3);
1637    }
1638
1639    // 28
1640    #[test]
1641    fn community_detection_isolated_nodes() {
1642        let g = make_graph();
1643        g.add_node("A".into(), serde_json::json!({}));
1644        g.add_node("B".into(), serde_json::json!({}));
1645        let communities = g.detect_communities(10);
1646        // Each isolated node stays in its own community.
1647        assert_eq!(communities.len(), 2);
1648    }
1649
1650    // 29
1651    #[test]
1652    fn community_detection_empty_graph() {
1653        let g = make_graph();
1654        let communities = g.detect_communities(10);
1655        assert!(communities.is_empty());
1656    }
1657
1658    // =====================================================================
1659    // Spectral Analysis tests
1660    // =====================================================================
1661
1662    // 30
1663    #[test]
1664    fn spectral_connected_graph_positive_lambda2() {
1665        let g = make_graph();
1666        let a = g.add_node("A".into(), serde_json::json!({}));
1667        let b = g.add_node("B".into(), serde_json::json!({}));
1668        let c = g.add_node("C".into(), serde_json::json!({}));
1669        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1670        g.link(b, c, CausalEdgeType::Causes, 1.0, 0, 0);
1671        g.link(a, c, CausalEdgeType::Causes, 1.0, 0, 0);
1672
1673        let result = g.spectral_analysis(50);
1674        assert!(
1675            result.lambda_2 > 0.0,
1676            "connected graph should have lambda_2 > 0, got {}",
1677            result.lambda_2
1678        );
1679        assert_eq!(result.fiedler_vector.len(), 3);
1680        assert_eq!(result.node_ids.len(), 3);
1681    }
1682
1683    // 31
1684    #[test]
1685    fn spectral_disconnected_graph_zero_lambda2() {
1686        let g = make_graph();
1687        let a = g.add_node("A".into(), serde_json::json!({}));
1688        let b = g.add_node("B".into(), serde_json::json!({}));
1689        // No edges — disconnected.
1690        let result = g.spectral_analysis(50);
1691        assert!(
1692            result.lambda_2 < 0.01,
1693            "disconnected graph should have lambda_2 ~ 0, got {}",
1694            result.lambda_2
1695        );
1696        assert_eq!(result.node_ids.len(), 2);
1697    }
1698
1699    // 32
1700    #[test]
1701    fn spectral_partition_splits_graph() {
1702        let g = make_graph();
1703        let a = g.add_node("A".into(), serde_json::json!({}));
1704        let b = g.add_node("B".into(), serde_json::json!({}));
1705        let c = g.add_node("C".into(), serde_json::json!({}));
1706        let d = g.add_node("D".into(), serde_json::json!({}));
1707        // Two clusters with weak bridge.
1708        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1709        g.link(c, d, CausalEdgeType::Causes, 1.0, 0, 0);
1710        g.link(b, c, CausalEdgeType::Correlates, 0.1, 0, 0);
1711
1712        let (part_a, part_b) = g.spectral_partition();
1713        assert!(!part_a.is_empty());
1714        assert!(!part_b.is_empty());
1715        assert_eq!(part_a.len() + part_b.len(), 4);
1716    }
1717
1718    // 33
1719    #[test]
1720    fn spectral_single_node() {
1721        let g = make_graph();
1722        g.add_node("A".into(), serde_json::json!({}));
1723        let result = g.spectral_analysis(50);
1724        assert_eq!(result.lambda_2, 0.0);
1725    }
1726
1727    // =====================================================================
1728    // Coupling / Predictive Analysis tests
1729    // =====================================================================
1730
1731    // 34
1732    #[test]
1733    fn coupling_basic() {
1734        let g = make_graph();
1735        let a = g.add_node("A".into(), serde_json::json!({}));
1736        let b = g.add_node("B".into(), serde_json::json!({}));
1737        let c = g.add_node("C".into(), serde_json::json!({}));
1738
1739        let events = vec![
1740            ChangeEvent { node_ids: vec![a, b], timestamp: 1 },
1741            ChangeEvent { node_ids: vec![a, b], timestamp: 2 },
1742            ChangeEvent { node_ids: vec![a, c], timestamp: 3 },
1743            ChangeEvent { node_ids: vec![b, c], timestamp: 4 },
1744        ];
1745
1746        let coupling = g.compute_coupling(&events);
1747        assert!(!coupling.is_empty());
1748
1749        // A-B co-changed 2 times out of max(3,3)=3 → 0.67
1750        let ab = coupling.iter().find(|p| {
1751            (p.node_a == a && p.node_b == b) || (p.node_a == b && p.node_b == a)
1752        });
1753        assert!(ab.is_some());
1754        let ab = ab.unwrap();
1755        assert_eq!(ab.co_changes, 2);
1756        assert!((ab.coupling_score - 2.0 / 3.0).abs() < 0.01);
1757    }
1758
1759    // 35
1760    #[test]
1761    fn coupling_empty_events() {
1762        let g = make_graph();
1763        let coupling = g.compute_coupling(&[]);
1764        assert!(coupling.is_empty());
1765    }
1766
1767    // 36
1768    #[test]
1769    fn predict_changes_burst_detection() {
1770        let g = make_graph();
1771        let a = g.add_node("module_a".into(), serde_json::json!({}));
1772        let b = g.add_node("module_b".into(), serde_json::json!({}));
1773        let c = g.add_node("module_c".into(), serde_json::json!({}));
1774
1775        // History: 50 events. Module A changes rarely (every 10th event).
1776        // Module C fills the rest so there are plenty of events.
1777        let mut events = Vec::new();
1778        for i in 0..50 {
1779            events.push(ChangeEvent { node_ids: vec![c], timestamp: i });
1780            if i % 10 == 0 {
1781                events.push(ChangeEvent { node_ids: vec![a], timestamp: i });
1782            }
1783        }
1784        // Burst window: module A + B change together in every recent event.
1785        for i in 50..60 {
1786            events.push(ChangeEvent { node_ids: vec![a, b], timestamp: i });
1787        }
1788        // Module A baseline: ~5 changes in ~55 events before window (rate ~0.09).
1789        // Module A window: 10 changes in 10 events (rate 1.0).
1790        // Burst ratio: ~11x — well above 1.5 threshold.
1791
1792        let predictions = g.predict_changes(&events, 10, 1.5);
1793        assert!(!predictions.is_empty());
1794
1795        // Module A should be predicted (in burst).
1796        let pred_a = predictions.iter().find(|p| p.node_id == a);
1797        assert!(pred_a.is_some(), "module_a should be in predictions");
1798        assert!(pred_a.unwrap().in_burst, "module_a should be in burst");
1799
1800        // Module B should be predicted (co-modified with A during burst).
1801        let pred_b = predictions.iter().find(|p| p.node_id == b);
1802        assert!(pred_b.is_some(), "module_b should be predicted via coupling");
1803    }
1804
1805    // 37
1806    #[test]
1807    fn predict_changes_empty_events() {
1808        let g = make_graph();
1809        let predictions = g.predict_changes(&[], 5, 2.0);
1810        assert!(predictions.is_empty());
1811    }
1812
1813    // 38
1814    #[test]
1815    fn node_ids_returns_all() {
1816        let g = make_graph();
1817        let a = g.add_node("A".into(), serde_json::json!({}));
1818        let b = g.add_node("B".into(), serde_json::json!({}));
1819        let ids = g.node_ids();
1820        assert_eq!(ids.len(), 2);
1821        assert!(ids.contains(&a));
1822        assert!(ids.contains(&b));
1823    }
1824
1825    // 39
1826    #[test]
1827    fn spectral_strongly_connected_high_lambda2() {
1828        // A complete graph of 4 nodes should have high lambda_2.
1829        let g = make_graph();
1830        let nodes: Vec<NodeId> = (0..4)
1831            .map(|i| g.add_node(format!("N{i}"), serde_json::json!({})))
1832            .collect();
1833        for i in 0..4 {
1834            for j in 0..4 {
1835                if i != j {
1836                    g.link(nodes[i], nodes[j], CausalEdgeType::Causes, 1.0, 0, 0);
1837                }
1838            }
1839        }
1840        let result = g.spectral_analysis(50);
1841        // For K4, lambda_2 should be 4.0 (all eigenvalues of Laplacian of K4 are 0,4,4,4).
1842        assert!(
1843            result.lambda_2 > 3.0,
1844            "complete graph K4 lambda_2 should be ~4.0, got {}",
1845            result.lambda_2
1846        );
1847    }
1848
1849    // ── Persistence tests ────────────────────────────────────────────
1850
1851    fn tmp_path(name: &str) -> std::path::PathBuf {
1852        std::env::temp_dir().join(format!(
1853            "causal_test_{name}_{}",
1854            std::time::SystemTime::now()
1855                .duration_since(std::time::UNIX_EPOCH)
1856                .unwrap()
1857                .as_nanos()
1858        ))
1859    }
1860
1861    // 40
1862    #[test]
1863    fn persist_empty_graph_roundtrip() {
1864        let g = make_graph();
1865        let path = tmp_path("empty");
1866        g.save_to_file(&path).unwrap();
1867        let loaded = CausalGraph::load_from_file(&path).unwrap();
1868        assert_eq!(loaded.node_count(), 0);
1869        assert_eq!(loaded.edge_count(), 0);
1870        let _ = std::fs::remove_file(&path);
1871    }
1872
1873    // 41
1874    #[test]
1875    fn persist_nodes_and_edges_roundtrip() {
1876        let g = make_graph();
1877        let a = g.add_node("Alpha".into(), serde_json::json!({"role": "source"}));
1878        let b = g.add_node("Beta".into(), serde_json::json!({"role": "target"}));
1879        let c = g.add_node("Gamma".into(), serde_json::json!({}));
1880        g.link(a, b, CausalEdgeType::Causes, 0.9, 100, 1);
1881        g.link(b, c, CausalEdgeType::Enables, 0.5, 200, 2);
1882        g.link(a, c, CausalEdgeType::Inhibits, 0.3, 300, 3);
1883
1884        let path = tmp_path("nodes_edges");
1885        g.save_to_file(&path).unwrap();
1886        let loaded = CausalGraph::load_from_file(&path).unwrap();
1887
1888        assert_eq!(loaded.node_count(), 3);
1889        assert_eq!(loaded.edge_count(), 3);
1890
1891        // Verify node data.
1892        let na = loaded.get_node(a).unwrap();
1893        assert_eq!(na.label, "Alpha");
1894        assert_eq!(na.metadata["role"], "source");
1895
1896        let nb = loaded.get_node(b).unwrap();
1897        assert_eq!(nb.label, "Beta");
1898
1899        // Verify edges.
1900        let fwd_a = loaded.get_forward_edges(a);
1901        assert_eq!(fwd_a.len(), 2);
1902        assert!(fwd_a.iter().any(|e| e.target == b && e.edge_type == CausalEdgeType::Causes));
1903        assert!(fwd_a.iter().any(|e| e.target == c && e.edge_type == CausalEdgeType::Inhibits));
1904
1905        let _ = std::fs::remove_file(&path);
1906    }
1907
1908    // 42
1909    #[test]
1910    fn persist_node_metadata_survives() {
1911        let g = make_graph();
1912        let id = g.add_node("rich".into(), serde_json::json!({
1913            "tags": ["a", "b"],
1914            "count": 42,
1915            "nested": {"x": true}
1916        }));
1917
1918        let path = tmp_path("metadata");
1919        g.save_to_file(&path).unwrap();
1920        let loaded = CausalGraph::load_from_file(&path).unwrap();
1921        let node = loaded.get_node(id).unwrap();
1922        assert_eq!(node.metadata["tags"][0], "a");
1923        assert_eq!(node.metadata["count"], 42);
1924        assert_eq!(node.metadata["nested"]["x"], true);
1925        let _ = std::fs::remove_file(&path);
1926    }
1927
1928    // 43
1929    #[test]
1930    fn persist_edge_types_and_weights() {
1931        let g = make_graph();
1932        let a = g.add_node("A".into(), serde_json::json!({}));
1933        let b = g.add_node("B".into(), serde_json::json!({}));
1934        g.link(a, b, CausalEdgeType::Contradicts, 0.77, 555, 10);
1935
1936        let path = tmp_path("edge_types");
1937        g.save_to_file(&path).unwrap();
1938        let loaded = CausalGraph::load_from_file(&path).unwrap();
1939        let edges = loaded.get_forward_edges(a);
1940        assert_eq!(edges.len(), 1);
1941        assert_eq!(edges[0].edge_type, CausalEdgeType::Contradicts);
1942        assert!((edges[0].weight - 0.77).abs() < 0.001);
1943        assert_eq!(edges[0].timestamp, 555);
1944        assert_eq!(edges[0].chain_seq, 10);
1945        let _ = std::fs::remove_file(&path);
1946    }
1947
1948    // 44
1949    #[test]
1950    fn persist_next_node_id_preserved() {
1951        let g = make_graph();
1952        let _a = g.add_node("A".into(), serde_json::json!({}));
1953        let _b = g.add_node("B".into(), serde_json::json!({}));
1954        let _c = g.add_node("C".into(), serde_json::json!({}));
1955        // next_node_id should be 4 now (started at 1, added 3 nodes).
1956
1957        let path = tmp_path("next_id");
1958        g.save_to_file(&path).unwrap();
1959        let loaded = CausalGraph::load_from_file(&path).unwrap();
1960
1961        // Adding a new node should get id >= 4, not collide with existing.
1962        let new_id = loaded.add_node("D".into(), serde_json::json!({}));
1963        assert!(new_id >= 4, "new node should get id >= 4, got {new_id}");
1964        assert!(loaded.get_node(new_id).is_some());
1965        assert_eq!(loaded.node_count(), 4);
1966        let _ = std::fs::remove_file(&path);
1967    }
1968
1969    // 45
1970    #[test]
1971    fn persist_writer_reader_roundtrip() {
1972        let g = make_graph();
1973        let a = g.add_node("X".into(), serde_json::json!({}));
1974        let b = g.add_node("Y".into(), serde_json::json!({}));
1975        g.link(a, b, CausalEdgeType::Follows, 1.0, 0, 0);
1976
1977        let mut buf = Vec::new();
1978        g.save_to_writer(&mut buf).unwrap();
1979
1980        let loaded = CausalGraph::load_from_reader(buf.as_slice()).unwrap();
1981        assert_eq!(loaded.node_count(), 2);
1982        assert_eq!(loaded.edge_count(), 1);
1983        let edges = loaded.get_forward_edges(a);
1984        assert_eq!(edges[0].edge_type, CausalEdgeType::Follows);
1985    }
1986
1987    // 46
1988    #[test]
1989    fn persist_reverse_edges_rebuilt() {
1990        let g = make_graph();
1991        let a = g.add_node("A".into(), serde_json::json!({}));
1992        let b = g.add_node("B".into(), serde_json::json!({}));
1993        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
1994
1995        let path = tmp_path("reverse");
1996        g.save_to_file(&path).unwrap();
1997        let loaded = CausalGraph::load_from_file(&path).unwrap();
1998
1999        // Reverse edges should be rebuilt from forward edges.
2000        let rev = loaded.get_reverse_edges(b);
2001        assert_eq!(rev.len(), 1);
2002        assert_eq!(rev[0].source, a);
2003        let _ = std::fs::remove_file(&path);
2004    }
2005
2006    // =====================================================================
2007    // Sparse Lanczos vs dense reference comparison
2008    // =====================================================================
2009
2010    /// Dense reference: compute lambda_2 via Jacobi eigendecomposition of the
2011    /// full Laplacian.  O(n^3) — only used in tests for correctness checking.
2012    fn dense_spectral_lambda2(g: &CausalGraph, _max_iterations: usize) -> f64 {
2013        let ids = g.node_ids();
2014        let n = ids.len();
2015        if n < 2 {
2016            return 0.0;
2017        }
2018
2019        let mut id_to_idx: std::collections::HashMap<NodeId, usize> =
2020            std::collections::HashMap::new();
2021        let mut sorted_ids = ids.clone();
2022        sorted_ids.sort();
2023        for (i, &id) in sorted_ids.iter().enumerate() {
2024            id_to_idx.insert(id, i);
2025        }
2026
2027        let mut laplacian = vec![vec![0.0f64; n]; n];
2028        for &id in &sorted_ids {
2029            let i = id_to_idx[&id];
2030            for edge in g.get_forward_edges(id) {
2031                if let Some(&j) = id_to_idx.get(&edge.target) {
2032                    if i != j {
2033                        let w = edge.weight as f64;
2034                        laplacian[i][j] -= w;
2035                        laplacian[j][i] -= w;
2036                    }
2037                }
2038            }
2039            for edge in g.get_reverse_edges(id) {
2040                if let Some(&j) = id_to_idx.get(&edge.source) {
2041                    if i != j && j > i {
2042                        let w = edge.weight as f64;
2043                        laplacian[i][j] -= w;
2044                        laplacian[j][i] -= w;
2045                    }
2046                }
2047            }
2048        }
2049        for i in 0..n {
2050            let off_sum: f64 = (0..n).filter(|&j| j != i).map(|j| -laplacian[i][j]).sum();
2051            laplacian[i][i] = off_sum;
2052        }
2053
2054        // Jacobi eigendecomposition of the full Laplacian.
2055        let mut a = laplacian;
2056        for _ in 0..100 * n {
2057            let mut max_off = 0.0f64;
2058            let mut p = 0usize;
2059            let mut q = 1usize;
2060            for i in 0..n {
2061                for j in (i + 1)..n {
2062                    if a[i][j].abs() > max_off {
2063                        max_off = a[i][j].abs();
2064                        p = i;
2065                        q = j;
2066                    }
2067                }
2068            }
2069            if max_off < 1e-15 { break; }
2070            let theta = (a[q][q] - a[p][p]) / (2.0 * a[p][q]);
2071            let t = theta.signum() / (theta.abs() + (1.0 + theta * theta).sqrt());
2072            let c = 1.0 / (1.0 + t * t).sqrt();
2073            let s = t * c;
2074            let app = a[p][p]; let aqq = a[q][q]; let apq = a[p][q];
2075            a[p][p] = c * c * app - 2.0 * s * c * apq + s * s * aqq;
2076            a[q][q] = s * s * app + 2.0 * s * c * apq + c * c * aqq;
2077            a[p][q] = 0.0;
2078            a[q][p] = 0.0;
2079            for r in 0..n {
2080                if r != p && r != q {
2081                    let arp = a[r][p]; let arq = a[r][q];
2082                    a[r][p] = c * arp - s * arq; a[p][r] = a[r][p];
2083                    a[r][q] = s * arp + c * arq; a[q][r] = a[r][q];
2084                }
2085            }
2086        }
2087
2088        // Collect eigenvalues, sort, return second-smallest.
2089        let mut evals: Vec<f64> = (0..n).map(|i| a[i][i]).collect();
2090        evals.sort_by(|a, b| a.partial_cmp(b).unwrap());
2091        if n >= 2 { evals[1].max(0.0) } else { 0.0 }
2092    }
2093
2094    // 47
2095    #[test]
2096    fn spectral_lanczos_matches_dense_triangle() {
2097        // Triangle graph (K3): known lambda_2 = 3.0 for unit weights.
2098        let g = make_graph();
2099        let a = g.add_node("A".into(), serde_json::json!({}));
2100        let b = g.add_node("B".into(), serde_json::json!({}));
2101        let c = g.add_node("C".into(), serde_json::json!({}));
2102        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
2103        g.link(b, c, CausalEdgeType::Causes, 1.0, 0, 0);
2104        g.link(a, c, CausalEdgeType::Causes, 1.0, 0, 0);
2105
2106        let sparse_result = g.spectral_analysis(50);
2107        let dense_lambda2 = dense_spectral_lambda2(&g, 200);
2108
2109        assert!(
2110            (sparse_result.lambda_2 - dense_lambda2).abs() < 0.5,
2111            "Lanczos lambda_2={} vs dense lambda_2={} differ too much",
2112            sparse_result.lambda_2,
2113            dense_lambda2,
2114        );
2115        // Both should be close to 3.0 for K3 with symmetric unit edges.
2116        assert!(sparse_result.lambda_2 > 1.0, "lambda_2 should be > 1 for K3");
2117    }
2118
2119    // 48
2120    #[test]
2121    fn spectral_lanczos_matches_dense_path() {
2122        // Path graph: A - B - C - D (lambda_2 ~ 0.586 for unit weights).
2123        let g = make_graph();
2124        let a = g.add_node("A".into(), serde_json::json!({}));
2125        let b = g.add_node("B".into(), serde_json::json!({}));
2126        let c = g.add_node("C".into(), serde_json::json!({}));
2127        let d = g.add_node("D".into(), serde_json::json!({}));
2128        g.link(a, b, CausalEdgeType::Causes, 1.0, 0, 0);
2129        g.link(b, c, CausalEdgeType::Causes, 1.0, 0, 0);
2130        g.link(c, d, CausalEdgeType::Causes, 1.0, 0, 0);
2131
2132        let sparse_result = g.spectral_analysis(50);
2133        let dense_lambda2 = dense_spectral_lambda2(&g, 200);
2134
2135        assert!(
2136            (sparse_result.lambda_2 - dense_lambda2).abs() < 0.5,
2137            "Lanczos lambda_2={} vs dense lambda_2={} differ too much",
2138            sparse_result.lambda_2,
2139            dense_lambda2,
2140        );
2141        assert!(sparse_result.lambda_2 > 0.0, "path graph should be connected");
2142    }
2143
2144    // 49
2145    #[test]
2146    fn spectral_lanczos_matches_dense_k4() {
2147        // K4: lambda_2 = 4.0 for unit-weight complete graph on 4 nodes.
2148        let g = make_graph();
2149        let nodes: Vec<NodeId> = (0..4)
2150            .map(|i| g.add_node(format!("N{i}"), serde_json::json!({})))
2151            .collect();
2152        for i in 0..4 {
2153            for j in (i + 1)..4 {
2154                g.link(nodes[i], nodes[j], CausalEdgeType::Causes, 1.0, 0, 0);
2155            }
2156        }
2157
2158        let sparse_result = g.spectral_analysis(50);
2159        let dense_lambda2 = dense_spectral_lambda2(&g, 200);
2160
2161        assert!(
2162            (sparse_result.lambda_2 - dense_lambda2).abs() < 0.5,
2163            "K4: Lanczos lambda_2={} vs dense lambda_2={}",
2164            sparse_result.lambda_2,
2165            dense_lambda2,
2166        );
2167    }
2168
2169    // 50
2170    #[test]
2171    fn spectral_lanczos_disconnected() {
2172        // Two isolated nodes — lambda_2 should be 0.
2173        let g = make_graph();
2174        g.add_node("A".into(), serde_json::json!({}));
2175        g.add_node("B".into(), serde_json::json!({}));
2176
2177        let result = g.spectral_analysis(50);
2178        assert!(
2179            result.lambda_2 < 0.01,
2180            "disconnected graph should have lambda_2 ~ 0, got {}",
2181            result.lambda_2
2182        );
2183    }
2184}