1use crate::activation::{ActivationEngine, DimensionResult, HybridEngine};
4use crate::error::M1ndResult;
5use crate::graph::Graph;
6use crate::types::PropagationConfig;
7use crate::types::*;
8
/// Default number of diverse seed trials averaged per counterfactual simulation.
pub const DEFAULT_SEED_TRIALS: u8 = 8;
/// Default number of top-impact nodes reported by keystone analysis.
pub const DEFAULT_KEYSTONE_TOP_N: usize = 20;
17
/// Boolean masks marking nodes and edges as logically removed from a graph,
/// without mutating the underlying `Graph` storage.
pub struct RemovalMask {
    /// `removed_nodes[i]` is true when the node with index `i` is removed.
    pub removed_nodes: Vec<bool>,
    /// `removed_edges[j]` is true when the forward edge with CSR index `j` is removed.
    pub removed_edges: Vec<bool>,
}
33
34impl RemovalMask {
35 pub fn new(num_nodes: u32, num_edges: usize) -> Self {
37 Self {
38 removed_nodes: vec![false; num_nodes as usize],
39 removed_edges: vec![false; num_edges],
40 }
41 }
42
43 pub fn remove_node(&mut self, graph: &Graph, node: NodeId) {
45 let idx = node.as_usize();
46 if idx >= self.removed_nodes.len() {
47 return;
48 }
49 self.removed_nodes[idx] = true;
50
51 let out_range = graph.csr.out_range(node);
53 for j in out_range {
54 if j < self.removed_edges.len() {
55 self.removed_edges[j] = true;
56 }
57 }
58
59 let in_range = graph.csr.in_range(node);
61 for j in in_range {
62 let fwd_idx = graph.csr.rev_edge_idx[j].as_usize();
63 if fwd_idx < self.removed_edges.len() {
64 self.removed_edges[fwd_idx] = true;
65 }
66 }
67 }
68
69 pub fn remove_edge(&mut self, edge: EdgeIdx) {
71 self.removed_edges[edge.as_usize()] = true;
72 }
73
74 #[inline]
76 pub fn is_node_removed(&self, node: NodeId) -> bool {
77 self.removed_nodes[node.as_usize()]
78 }
79
80 #[inline]
82 pub fn is_edge_removed(&self, edge: EdgeIdx) -> bool {
83 self.removed_edges[edge.as_usize()]
84 }
85
86 pub fn reset(&mut self) {
88 self.removed_nodes.fill(false);
89 self.removed_edges.fill(false);
90 }
91}
92
/// Outcome of simulating the removal of a set of nodes.
#[derive(Clone, Debug)]
pub struct CounterfactualResult {
    /// The nodes whose removal was simulated.
    pub removed_nodes: Vec<NodeId>,
    /// Overall impact of the removal (currently equal to `pct_activation_lost`).
    pub total_impact: FiniteF32,
    /// Fraction of total baseline activation lost, clamped to [0, 1].
    pub pct_activation_lost: FiniteF32,
    /// Surviving nodes whose average activation loss exceeded 99% across trials.
    pub orphaned_nodes: Vec<NodeId>,
    /// Surviving nodes with 50–99% average activation loss, paired with the loss.
    pub weakened_nodes: Vec<(NodeId, FiniteF32)>,
    /// Number of communities split by the removal.
    /// NOTE(review): always 0 — `simulate_removal` hard-codes it.
    pub communities_split: u32,
    /// Undirected reachable-node count before removal (BFS from a high-degree node).
    pub reachability_before: u32,
    /// Undirected reachable-node count after removal.
    pub reachability_after: u32,
}
117
/// One high-impact node found by keystone analysis.
#[derive(Clone, Debug)]
pub struct KeystoneEntry {
    /// The node whose removal was simulated.
    pub node: NodeId,
    /// Impact of removing this node (taken from `CounterfactualResult::total_impact`).
    pub avg_impact: FiniteF32,
    /// Spread of impact across trials.
    /// NOTE(review): `find_keystones` always sets this to zero — confirm before use.
    pub impact_std: FiniteF32,
}
133
/// Result of keystone analysis: the highest-impact nodes, strongest first.
#[derive(Clone, Debug)]
pub struct KeystoneResult {
    /// Top keystone entries, sorted by descending impact.
    pub keystones: Vec<KeystoneEntry>,
    /// Number of seed trials the engine was configured with.
    pub num_trials: u8,
}
142
/// Result of structural cascade analysis after removing one node.
#[derive(Clone, Debug)]
pub struct CascadeResult {
    /// The node that was removed.
    pub removed_node: NodeId,
    /// Number of non-empty downstream BFS levels reached (at most 5).
    pub cascade_depth: u8,
    /// Nodes newly reached at each downstream depth, in BFS order.
    pub affected_by_depth: Vec<Vec<NodeId>>,
    /// Total count of affected nodes across all depths.
    pub total_affected: u32,
}
159
/// Compares the combined impact of removing several nodes against the sum of
/// their individual impacts.
#[derive(Clone, Debug)]
pub struct SynergyResult {
    /// Impact of removing each node on its own.
    pub individual_impacts: Vec<(NodeId, FiniteF32)>,
    /// Impact of removing all nodes at once.
    pub combined_impact: FiniteF32,
    /// combined / sum(individual), capped at 10; values > 1 indicate synergy.
    pub synergy_factor: FiniteF32,
}
175
/// How redundant (replaceable) a single node is.
#[derive(Clone, Debug)]
pub struct RedundancyResult {
    /// The node that was analyzed.
    pub node: NodeId,
    /// 1 - removal impact, clamped to [0, 1]; higher means more redundant.
    pub redundancy_score: FiniteF32,
    /// Confidence level derived from the engine's configured trial count.
    pub confidence: RedundancyConfidence,
    /// Rough alternative-path estimate: min(out-degree, in-degree).
    pub alternative_paths: u32,
    /// True when out-degree >= 5 and removal impact > 0.3.
    pub is_architectural: bool,
}
195
/// Confidence bucket for a redundancy estimate, based on how many seed
/// trials backed it (>= 8 high, >= 4 medium, otherwise low).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RedundancyConfidence {
    High,
    Medium,
    Low,
}
202
/// Graph-level robustness summary.
#[derive(Clone, Debug)]
pub struct AntifragilityResult {
    /// 1 - (impact of the single worst keystone), clamped to [0, 1].
    pub score: FiniteF32,
    /// The keystone entries the score was derived from.
    pub top_keystones: Vec<KeystoneEntry>,
    /// Most redundant nodes.
    /// NOTE(review): always empty — `antifragility_score` does not populate it.
    pub most_redundant: Vec<RedundancyResult>,
    /// Least redundant nodes.
    /// NOTE(review): always empty — `antifragility_score` does not populate it.
    pub least_redundant: Vec<RedundancyResult>,
}
220
221fn run_baseline_activation(
226 graph: &Graph,
227 engine: &HybridEngine,
228 config: &PropagationConfig,
229 seeds: &[(NodeId, FiniteF32)],
230) -> M1ndResult<Vec<(NodeId, FiniteF32)>> {
231 let result = engine.propagate(graph, seeds, config)?;
232 Ok(result.scores)
233}
234
/// Depth-limited activation spreading over `graph`, honoring `mask`:
/// removed nodes never activate and removed edges never carry signal.
///
/// Seeds are capped at `saturation_cap`; each hop multiplies by edge weight
/// and `decay`; excitatory signals max-combine into the target while
/// inhibitory signals subtract (scaled by `inhibitory_factor`, floored at 0).
/// Depth is clamped to at most 20 hops regardless of `config.max_depth`.
///
/// Returns only strictly positive scores, sorted by descending activation.
fn propagate_with_mask(
    graph: &Graph,
    seeds: &[(NodeId, FiniteF32)],
    config: &PropagationConfig,
    mask: &RemovalMask,
) -> M1ndResult<Vec<(NodeId, FiniteF32)>> {
    let n = graph.num_nodes() as usize;
    if n == 0 || seeds.is_empty() {
        return Ok(Vec::new());
    }

    let threshold = config.threshold.get();
    let decay = config.decay.get();
    // Hard cap of 20 keeps the worst-case BFS bounded regardless of config.
    let max_depth = config.max_depth.min(20) as usize;

    let mut activation = vec![0.0f32; n];
    let mut visited = vec![false; n];
    let mut frontier: Vec<NodeId> = Vec::new();

    // Seed the frontier; masked seeds are dropped and duplicates keep the max.
    for &(node, score) in seeds {
        let idx = node.as_usize();
        if idx < n && !mask.is_node_removed(node) {
            let s = score.get().min(config.saturation_cap.get());
            if s > activation[idx] {
                activation[idx] = s;
            }
            if !visited[idx] {
                frontier.push(node);
                visited[idx] = true;
            }
        }
    }

    for _depth in 0..max_depth {
        if frontier.is_empty() {
            break;
        }
        let mut next_frontier: Vec<NodeId> = Vec::new();

        for &src in &frontier {
            let src_act = activation[src.as_usize()];
            // Sub-threshold sources do not propagate.
            if src_act < threshold {
                continue;
            }

            let range = graph.csr.out_range(src);
            for j in range {
                if mask.is_edge_removed(EdgeIdx::new(j as u32)) {
                    continue;
                }

                let tgt = graph.csr.targets[j];
                let tgt_idx = tgt.as_usize();

                if tgt_idx >= n || mask.is_node_removed(tgt) {
                    continue;
                }

                let w = graph.csr.read_weight(EdgeIdx::new(j as u32)).get();
                let is_inhib = graph.csr.inhibitory[j];

                let mut signal = src_act * w * decay;
                if is_inhib {
                    // Inhibitory edges flip the sign and apply their own scale.
                    signal = -signal * config.inhibitory_factor.get();
                }

                if !is_inhib && signal > threshold {
                    // Excitatory: max-combine into the target, enqueue it once.
                    if signal > activation[tgt_idx] {
                        activation[tgt_idx] = signal;
                    }
                    if !visited[tgt_idx] {
                        visited[tgt_idx] = true;
                        next_frontier.push(tgt);
                    }
                } else if is_inhib {
                    // Inhibitory: subtract from the target, never below zero.
                    // NOTE(review): inhibited targets are not re-enqueued here.
                    activation[tgt_idx] = (activation[tgt_idx] + signal).max(0.0);
                }
            }
        }

        frontier = next_frontier;
    }

    // Collect surviving activations, strongest first.
    let mut scores: Vec<(NodeId, FiniteF32)> = activation
        .iter()
        .enumerate()
        .filter(|(i, &v)| v > 0.0 && !mask.is_node_removed(NodeId::new(*i as u32)))
        .map(|(i, &v)| (NodeId::new(i as u32), FiniteF32::new(v)))
        .collect();
    scores.sort_by(|a, b| b.1.cmp(&a.1));

    Ok(scores)
}
332
333fn total_activation(scores: &[(NodeId, FiniteF32)]) -> f32 {
335 scores.iter().map(|(_, s)| s.get()).sum()
336}
337
338fn generate_diverse_seeds(graph: &Graph, num_trials: u8) -> Vec<Vec<(NodeId, FiniteF32)>> {
341 let n = graph.num_nodes() as usize;
342 if n == 0 {
343 return Vec::new();
344 }
345
346 let mut candidates: Vec<(usize, f32)> = (0..n)
348 .filter(|&i| {
349 let r = graph.csr.out_range(NodeId::new(i as u32));
350 r.end > r.start })
352 .map(|i| (i, graph.nodes.pagerank[i].get()))
353 .collect();
354 candidates.sort_by(|a, b| b.1.total_cmp(&a.1));
355
356 if candidates.is_empty() {
357 return Vec::new();
358 }
359
360 let mut trials = Vec::new();
362 let stride = candidates.len().max(1) / (num_trials as usize).max(1);
363 for t in 0..num_trials as usize {
364 let idx = (t * stride.max(1)) % candidates.len();
365 let (node_idx, _) = candidates[idx];
366 trials.push(vec![(NodeId::new(node_idx as u32), FiniteF32::ONE)]);
367 }
368 trials
369}
370
/// Engine for "what if we removed this?" analyses: removal simulation,
/// keystone discovery, cascades, synergy, redundancy, and antifragility.
pub struct CounterfactualEngine {
    // Number of diverse seed trials averaged per simulation.
    num_trials: u8,
    // Number of top-impact nodes reported by keystone analysis.
    keystone_top_n: usize,
}
382
impl CounterfactualEngine {
    /// Creates an engine with an explicit trial count and keystone list size.
    pub fn new(num_trials: u8, keystone_top_n: usize) -> Self {
        Self {
            num_trials,
            keystone_top_n,
        }
    }

    /// Creates an engine using the module-level defaults
    /// (`DEFAULT_SEED_TRIALS`, `DEFAULT_KEYSTONE_TOP_N`).
    pub fn with_defaults() -> Self {
        Self {
            num_trials: DEFAULT_SEED_TRIALS,
            keystone_top_n: DEFAULT_KEYSTONE_TOP_N,
        }
    }

    /// Simulates removing `remove_nodes` and quantifies the damage.
    ///
    /// For each diverse seed trial this compares a full baseline propagation
    /// against a masked propagation with the nodes (and their incident edges)
    /// removed. Seeds that land on a removed node are redirected to a
    /// surviving neighbor (outgoing first, then incoming, then any surviving
    /// node) so the removed run is not trivially empty.
    ///
    /// The result reports the fraction of activation lost, per-node fallout
    /// (orphaned: avg loss > 99%; weakened: 50–99%), and undirected
    /// reachability counts before/after removal. `communities_split` is
    /// hard-coded to 0.
    pub fn simulate_removal(
        &self,
        graph: &Graph,
        engine: &HybridEngine,
        config: &PropagationConfig,
        remove_nodes: &[NodeId],
    ) -> M1ndResult<CounterfactualResult> {
        let n = graph.num_nodes() as usize;

        let seed_trials = generate_diverse_seeds(graph, self.num_trials);

        let mut total_baseline = 0.0f32;
        let mut total_removed = 0.0f32;

        // Mark removals both in the propagation mask and in a plain bool set
        // used for seed redirection and reachability.
        let mut mask = RemovalMask::new(graph.num_nodes(), graph.num_edges());
        let mut removed_set = vec![false; n];
        for &node in remove_nodes {
            if node.as_usize() < n {
                removed_set[node.as_usize()] = true;
                mask.remove_node(graph, node);
            }
        }

        // Accumulates, per node, the relative activation lost in each trial.
        let mut per_node_loss = vec![0.0f32; n];

        for seeds in &seed_trials {
            // Redirect seeds that sit on removed nodes; zero them out only if
            // literally every node in the graph is removed.
            let adjusted_seeds: Vec<(NodeId, FiniteF32)> = seeds
                .iter()
                .map(|&(node, score)| {
                    if removed_set[node.as_usize()] {
                        // Prefer a surviving outgoing neighbor...
                        let range = graph.csr.out_range(node);
                        for j in range {
                            let tgt = graph.csr.targets[j];
                            if !removed_set[tgt.as_usize()] {
                                return (tgt, score);
                            }
                        }
                        // ...then a surviving incoming neighbor...
                        let rev_range = graph.csr.in_range(node);
                        for j in rev_range {
                            let src = graph.csr.rev_sources[j];
                            if !removed_set[src.as_usize()] {
                                return (src, score);
                            }
                        }
                        // ...then any surviving node at all.
                        for i in 0..n {
                            if !removed_set[i] {
                                return (NodeId::new(i as u32), score);
                            }
                        }
                        (node, FiniteF32::ZERO) } else {
                        (node, score)
                    }
                })
                .filter(|(_, s)| s.get() > 0.0)
                .collect();

            // Baseline run on the intact graph.
            let baseline = run_baseline_activation(graph, engine, config, seeds)?;
            let baseline_total = total_activation(&baseline);
            total_baseline += baseline_total;

            // Masked run on the damaged graph.
            let removed_scores = propagate_with_mask(graph, &adjusted_seeds, config, &mask)?;
            let removed_total = total_activation(&removed_scores);
            total_removed += removed_total;

            // Index both score lists by raw node id for the per-node diff.
            let mut baseline_map = std::collections::HashMap::new();
            for &(node, score) in &baseline {
                baseline_map.insert(node.0, score.get());
            }
            let mut removed_map = std::collections::HashMap::new();
            for &(node, score) in &removed_scores {
                removed_map.insert(node.0, score.get());
            }

            // Relative loss per node for this trial (only where baseline > 0).
            for i in 0..n {
                let base = baseline_map.get(&(i as u32)).copied().unwrap_or(0.0);
                let rem = removed_map.get(&(i as u32)).copied().unwrap_or(0.0);
                if base > 0.0 {
                    per_node_loss[i] += (base - rem) / base;
                }
            }
        }

        // Avoid division by zero when no trials were generated.
        let num_trials = seed_trials.len().max(1) as f32;

        let pct_lost = if total_baseline > 0.0 {
            ((total_baseline - total_removed) / total_baseline)
                .max(0.0)
                .min(1.0)
        } else {
            0.0
        };

        // Surviving nodes that lost essentially all activation on average.
        let orphaned: Vec<NodeId> = (0..n)
            .filter(|&i| per_node_loss[i] / num_trials > 0.99 && !removed_set[i])
            .map(|i| NodeId::new(i as u32))
            .collect();

        // Surviving nodes with substantial but partial average loss.
        let weakened: Vec<(NodeId, FiniteF32)> = (0..n)
            .filter(|&i| {
                let avg = per_node_loss[i] / num_trials;
                avg > 0.5 && avg <= 0.99 && !removed_set[i]
            })
            .map(|i| {
                let avg = per_node_loss[i] / num_trials;
                (NodeId::new(i as u32), FiniteF32::new(avg))
            })
            .collect();

        let reachability_before = Self::compute_reachability(graph, n, &vec![false; n]);
        let reachability_after = Self::compute_reachability(graph, n, &removed_set);

        Ok(CounterfactualResult {
            removed_nodes: remove_nodes.to_vec(),
            total_impact: FiniteF32::new(pct_lost),
            pct_activation_lost: FiniteF32::new(pct_lost),
            orphaned_nodes: orphaned,
            weakened_nodes: weakened,
            // Community split detection is not implemented yet.
            communities_split: 0, reachability_before,
            reachability_after,
        })
    }

    /// Counts nodes reachable (treating edges as undirected) from the
    /// surviving node with the highest total degree, skipping `removed`
    /// nodes. Returns 0 when the graph is empty or fully removed.
    fn compute_reachability(graph: &Graph, n: usize, removed: &[bool]) -> u32 {
        if n == 0 {
            return 0;
        }
        // Start BFS from the highest-degree surviving node.
        let start = (0..n).filter(|&i| !removed[i]).max_by_key(|&i| {
            let nid = NodeId::new(i as u32);
            let out = graph.csr.out_range(nid);
            let inv = graph.csr.in_range(nid);
            (out.end - out.start) + (inv.end - inv.start)
        });
        let start = match start {
            Some(s) => s,
            None => return 0,
        };

        let mut visited = vec![false; n];
        let mut queue = std::collections::VecDeque::new();
        queue.push_back(start);
        visited[start] = true;
        let mut count = 1u32;

        while let Some(node) = queue.pop_front() {
            let nid = NodeId::new(node as u32);
            // Follow outgoing edges.
            let range = graph.csr.out_range(nid);
            for j in range {
                let tgt = graph.csr.targets[j].as_usize();
                if tgt < n && !visited[tgt] && !removed[tgt] {
                    visited[tgt] = true;
                    queue.push_back(tgt);
                    count += 1;
                }
            }
            // Follow incoming edges too, making the traversal undirected.
            let rev_range = graph.csr.in_range(nid);
            for j in rev_range {
                let src = graph.csr.rev_sources[j].as_usize();
                if src < n && !visited[src] && !removed[src] {
                    visited[src] = true;
                    queue.push_back(src);
                    count += 1;
                }
            }
        }

        count
    }

    /// Finds the nodes whose individual removal hurts the graph most.
    ///
    /// To bound cost, only the `keystone_top_n * 2` highest out-degree nodes
    /// are simulated; the top `keystone_top_n` by impact are returned.
    /// `impact_std` is currently always zero (per-trial impacts are not
    /// retained by `simulate_removal`).
    pub fn find_keystones(
        &self,
        graph: &Graph,
        engine: &HybridEngine,
        config: &PropagationConfig,
    ) -> M1ndResult<KeystoneResult> {
        let n = graph.num_nodes() as usize;
        let mut impacts: Vec<(NodeId, f32)> = Vec::new();

        // Pre-select candidates by out-degree to avoid simulating every node.
        let mut candidates: Vec<(usize, usize)> = (0..n)
            .map(|i| {
                let range = graph.csr.out_range(NodeId::new(i as u32));
                (i, range.end - range.start)
            })
            .collect();
        candidates.sort_by(|a, b| b.1.cmp(&a.1));
        candidates.truncate(self.keystone_top_n * 2);

        for (node_idx, _) in &candidates {
            let result =
                self.simulate_removal(graph, engine, config, &[NodeId::new(*node_idx as u32)])?;
            impacts.push((NodeId::new(*node_idx as u32), result.total_impact.get()));
        }

        impacts.sort_by(|a, b| b.1.total_cmp(&a.1));
        let keystones: Vec<KeystoneEntry> = impacts
            .iter()
            .take(self.keystone_top_n)
            .map(|&(node, impact)| KeystoneEntry {
                node,
                avg_impact: FiniteF32::new(impact),
                impact_std: FiniteF32::ZERO, })
            .collect();

        Ok(KeystoneResult {
            keystones,
            num_trials: self.num_trials,
        })
    }

    /// Purely structural cascade analysis: BFS downstream from `remove_node`
    /// along outgoing edges, up to 5 levels deep, recording which nodes are
    /// first reached at each depth. No propagation engine is involved.
    pub fn cascade_analysis(
        &self,
        graph: &Graph,
        _engine: &HybridEngine,
        _config: &PropagationConfig,
        remove_node: NodeId,
    ) -> M1ndResult<CascadeResult> {
        let n = graph.num_nodes() as usize;
        // Out-of-range node: empty cascade.
        if remove_node.as_usize() >= n {
            return Ok(CascadeResult {
                removed_node: remove_node,
                cascade_depth: 0,
                affected_by_depth: Vec::new(),
                total_affected: 0,
            });
        }

        let mut affected_by_depth: Vec<Vec<NodeId>> = Vec::new();
        let mut visited = vec![false; n];
        visited[remove_node.as_usize()] = true;

        let mut frontier = vec![remove_node];
        let max_depth = 5u8;

        for _depth in 0..max_depth {
            if frontier.is_empty() {
                break;
            }
            let mut next = Vec::new();
            let mut depth_affected = Vec::new();

            for &node in &frontier {
                let range = graph.csr.out_range(node);
                for j in range {
                    let tgt = graph.csr.targets[j];
                    let tgt_idx = tgt.as_usize();
                    if tgt_idx < n && !visited[tgt_idx] {
                        visited[tgt_idx] = true;
                        next.push(tgt);
                        depth_affected.push(tgt);
                    }
                }
            }

            // Only record depths where something new was reached.
            if !depth_affected.is_empty() {
                affected_by_depth.push(depth_affected);
            }
            frontier = next;
        }

        let total_affected: u32 = affected_by_depth.iter().map(|d| d.len() as u32).sum();

        Ok(CascadeResult {
            removed_node: remove_node,
            cascade_depth: affected_by_depth.len() as u8,
            affected_by_depth,
            total_affected,
        })
    }

    /// Measures whether removing `remove_nodes` together is worse than the
    /// sum of removing them individually.
    ///
    /// `synergy_factor` = combined impact / sum of individual impacts,
    /// defaulting to 1.0 when the individual sum is zero and capped at 10.
    pub fn synergy_analysis(
        &self,
        graph: &Graph,
        engine: &HybridEngine,
        config: &PropagationConfig,
        remove_nodes: &[NodeId],
    ) -> M1ndResult<SynergyResult> {
        let mut individual_impacts = Vec::new();
        for &node in remove_nodes {
            let result = self.simulate_removal(graph, engine, config, &[node])?;
            individual_impacts.push((node, result.total_impact));
        }

        let combined = self.simulate_removal(graph, engine, config, remove_nodes)?;

        let sum_individual: f32 = individual_impacts.iter().map(|(_, s)| s.get()).sum();
        let synergy_factor = if sum_individual > 0.0 {
            combined.total_impact.get() / sum_individual
        } else {
            1.0
        };

        Ok(SynergyResult {
            individual_impacts,
            combined_impact: combined.total_impact,
            synergy_factor: FiniteF32::new(synergy_factor.min(10.0)),
        })
    }

    /// Estimates how replaceable `node` is.
    ///
    /// Redundancy is 1 - removal impact (clamped to [0, 1]). The
    /// alternative-path count is a crude min(out-degree, in-degree)
    /// estimate, and a node is flagged architectural when it has
    /// out-degree >= 5 and impact > 0.3. Confidence reflects only the
    /// configured trial count. Out-of-range nodes yield a zero-score,
    /// low-confidence result.
    pub fn check_redundancy(
        &self,
        graph: &Graph,
        engine: &HybridEngine,
        config: &PropagationConfig,
        node: NodeId,
    ) -> M1ndResult<RedundancyResult> {
        let n = graph.num_nodes() as usize;
        let idx = node.as_usize();
        if idx >= n {
            return Ok(RedundancyResult {
                node,
                redundancy_score: FiniteF32::ZERO,
                confidence: RedundancyConfidence::Low,
                alternative_paths: 0,
                is_architectural: false,
            });
        }

        let result = self.simulate_removal(graph, engine, config, &[node])?;
        let impact = result.total_impact.get();

        // High impact => low redundancy, and vice versa.
        let redundancy = (1.0 - impact).max(0.0).min(1.0);

        let out_range = graph.csr.out_range(node);
        let out_degree = out_range.end - out_range.start;

        let in_range = graph.csr.in_range(node);
        let in_degree = in_range.end - in_range.start;

        let alternative_paths = (out_degree.min(in_degree)) as u32;

        let is_architectural = out_degree >= 5 && impact > 0.3;

        let confidence = if self.num_trials >= 8 {
            RedundancyConfidence::High
        } else if self.num_trials >= 4 {
            RedundancyConfidence::Medium
        } else {
            RedundancyConfidence::Low
        };

        Ok(RedundancyResult {
            node,
            redundancy_score: FiniteF32::new(redundancy),
            confidence,
            alternative_paths,
            is_architectural,
        })
    }

    /// Scores graph robustness as 1 minus the impact of the single worst
    /// keystone (clamped to [0, 1]); a graph with no keystones scores 1.
    /// The `most_redundant`/`least_redundant` lists are left empty here.
    pub fn antifragility_score(
        &self,
        graph: &Graph,
        engine: &HybridEngine,
        config: &PropagationConfig,
    ) -> M1ndResult<AntifragilityResult> {
        let keystones = self.find_keystones(graph, engine, config)?;

        // Keystones are sorted by descending impact, so the first is the worst.
        let max_impact = keystones
            .keystones
            .first()
            .map(|k| k.avg_impact.get())
            .unwrap_or(0.0);
        let score = (1.0 - max_impact).max(0.0).min(1.0);

        Ok(AntifragilityResult {
            score: FiniteF32::new(score),
            top_keystones: keystones.keystones,
            most_redundant: Vec::new(), least_redundant: Vec::new(),
        })
    }
}
817
// Compile-time guarantee that the engine can be shared across threads.
static_assertions::assert_impl_all!(CounterfactualEngine: Send, Sync);