use std::collections::{HashMap, HashSet, VecDeque};
use serde::{Deserialize, Serialize};
use crate::bundle::Mechanism;
use crate::causal_graph::CausalGraph;
use crate::project::Project;
/// A counterfactual question: "if `intervene_on`'s confidence were forced to
/// `set_to`, what would `target`'s confidence be?"
///
/// Answered by [`answer_counterfactual`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CounterfactualQuery {
    /// Id of the finding whose confidence is overridden by the intervention.
    pub intervene_on: String,
    /// Confidence value to pin `intervene_on` to; `answer_counterfactual`
    /// rejects anything outside `[0.0, 1.0]` as `InvalidIntervention`.
    pub set_to: f64,
    /// Id of the finding whose response to the intervention is reported.
    pub target: String,
}
/// Outcome of [`answer_counterfactual`].
///
/// Serialized as an internally tagged object: a `"kind"` field carries the
/// snake_case variant name (e.g. `"mechanism_unspecified"`), alongside the
/// variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum CounterfactualVerdict {
    /// The intervention's effect on the target was computed.
    Resolved {
        /// The target's confidence score before the intervention.
        factual: f64,
        /// The target's confidence under the intervention, kept within `[0, 1]`.
        counterfactual: f64,
        /// `counterfactual - factual`.
        delta: f64,
        /// Node-id sequences, each starting at the intervened finding and
        /// ending at the target, along which the effect was propagated.
        paths_used: Vec<Vec<String>>,
    },
    /// Causal paths exist, but every one of them crosses an edge that has no
    /// mechanism or whose mechanism could not propagate the change.
    MechanismUnspecified {
        /// The blocking `(parent, child)` edges, sorted and de-duplicated.
        /// Only the first blocking edge of each path is reported.
        unspecified_edges: Vec<(String, String)>,
    },
    /// No directed path from the intervened finding to the target was found
    /// within the search limits; `factual` is the target's unchanged confidence.
    NoCausalPath { factual: f64 },
    /// An id in the query (`which`) is missing from the project's findings or
    /// from the causal graph.
    UnknownNode { which: String },
    /// The requested intervention value is unusable (outside `[0, 1]`);
    /// `reason` is a human-readable explanation.
    InvalidIntervention { reason: String },
}
#[must_use]
pub fn answer_counterfactual(
project: &Project,
query: &CounterfactualQuery,
) -> CounterfactualVerdict {
if !(0.0..=1.0).contains(&query.set_to) {
return CounterfactualVerdict::InvalidIntervention {
reason: format!(
"intervention must be on the confidence axis [0,1], got {}",
query.set_to
),
};
}
let confidence_index = build_confidence_index(project);
let factual_target = match confidence_index.get(&query.target) {
Some(&c) => c,
None => {
return CounterfactualVerdict::UnknownNode {
which: query.target.clone(),
};
}
};
let factual_source = match confidence_index.get(&query.intervene_on) {
Some(&c) => c,
None => {
return CounterfactualVerdict::UnknownNode {
which: query.intervene_on.clone(),
};
}
};
let graph = CausalGraph::from_project(project);
if !graph.contains(&query.intervene_on) {
return CounterfactualVerdict::UnknownNode {
which: query.intervene_on.clone(),
};
}
if !graph.contains(&query.target) {
return CounterfactualVerdict::UnknownNode {
which: query.target.clone(),
};
}
let paths = directed_paths_from_to(&graph, &query.intervene_on, &query.target, 8);
if paths.is_empty() {
return CounterfactualVerdict::NoCausalPath {
factual: factual_target,
};
}
let mech_index = build_mechanism_index(project);
let mut unspecified_edges: HashSet<(String, String)> = HashSet::new();
let mut path_deltas: Vec<f64> = Vec::new();
let mut paths_used: Vec<Vec<String>> = Vec::new();
let delta_x = query.set_to - factual_source;
for path in &paths {
let mut delta = delta_x;
let mut path_ok = true;
for window in path.windows(2) {
let parent = &window[0];
let child = &window[1];
match mech_index.get(&(parent.clone(), child.clone())) {
Some(m) => match m.apply(delta) {
Some(next_delta) => delta = next_delta,
None => {
unspecified_edges.insert((parent.clone(), child.clone()));
path_ok = false;
break;
}
},
None => {
unspecified_edges.insert((parent.clone(), child.clone()));
path_ok = false;
break;
}
}
}
if path_ok {
path_deltas.push(delta);
paths_used.push(path.clone());
}
}
if path_deltas.is_empty() {
let mut edges: Vec<(String, String)> = unspecified_edges.into_iter().collect();
edges.sort();
return CounterfactualVerdict::MechanismUnspecified {
unspecified_edges: edges,
};
}
let aggregate_delta = path_deltas
.iter()
.copied()
.fold(0.0_f64, |acc, d| if d.abs() > acc.abs() { d } else { acc });
let counterfactual = (factual_target + aggregate_delta).clamp(0.0, 1.0);
CounterfactualVerdict::Resolved {
factual: factual_target,
counterfactual,
delta: counterfactual - factual_target,
paths_used,
}
}
/// Breadth-first enumeration of simple directed paths from `cause` to `effect`.
///
/// Each returned path lists node ids from `cause` to `effect` inclusive and
/// never revisits a node. `max_depth` bounds the number of nodes in a path,
/// endpoints included, so `max_depth = 8` allows at most 7 edges. The search
/// stops once `MAX_PATHS` (32) paths have been collected. Because the search is
/// breadth-first, shorter paths are found before longer ones. A single-node
/// path is never returned, so `cause == effect` yields no paths.
fn directed_paths_from_to(
    graph: &CausalGraph,
    cause: &str,
    effect: &str,
    max_depth: usize,
) -> Vec<Vec<String>> {
    const MAX_PATHS: usize = 32;
    let mut out: Vec<Vec<String>> = Vec::new();
    let mut queue: VecDeque<Vec<String>> = VecDeque::new();
    queue.push_back(vec![cause.to_string()]);
    while let Some(path) = queue.pop_front() {
        if out.len() >= MAX_PATHS {
            break;
        }
        let last = path.last().expect("path non-empty");
        if last == effect && path.len() > 1 {
            out.push(path);
            continue;
        }
        // Any extension would exceed the node budget. Expanding here would only
        // enqueue a frontier of over-long paths that could never be accepted.
        if path.len() >= max_depth {
            continue;
        }
        for child in graph.children_of(last) {
            let child_owned = child.to_string();
            // Keep paths simple: never step back onto a node already on the path.
            if path.contains(&child_owned) {
                continue;
            }
            let mut next = path.clone();
            next.push(child_owned);
            queue.push_back(next);
        }
    }
    out
}
/// Map every finding id in `project` to its current confidence score.
///
/// If two findings share an id, the one appearing later in
/// `project.findings` wins.
fn build_confidence_index(project: &Project) -> HashMap<String, f64> {
    project
        .findings
        .iter()
        .map(|f| (f.id.clone(), f.confidence.score))
        .collect()
}
/// Index declared edge mechanisms by `(parent_id, child_id)`.
///
/// Only `depends` and `supports` links participate. The link's target is the
/// parent, and the finding that owns the link is the child. Everything up to
/// and including the first `:` in the target is stripped. Links without a
/// mechanism are skipped. When the same edge is declared more than once, the
/// last declaration in finding/link order wins.
fn build_mechanism_index(project: &Project) -> HashMap<(String, String), Mechanism> {
    project
        .findings
        .iter()
        .flat_map(|finding| {
            finding
                .links
                .iter()
                .filter(|link| matches!(link.link_type.as_str(), "depends" | "supports"))
                .filter_map(move |link| {
                    link.mechanism.map(|mechanism| {
                        let parent = link
                            .target
                            .split_once(':')
                            .map_or_else(|| link.target.clone(), |(_, id)| id.to_string());
                        ((parent, finding.id.clone()), mechanism)
                    })
                })
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bundle::{
        Assertion, Conditions, Confidence, Evidence, Extraction, FindingBundle, Flags, Link,
        Mechanism, MechanismSign, Provenance,
    };
    use crate::project;

    /// Blank `Conditions` for fixture findings.
    fn conditions() -> Conditions {
        Conditions {
            text: String::new(),
            species_verified: Vec::new(),
            species_unverified: Vec::new(),
            in_vitro: false,
            in_vivo: false,
            human_data: false,
            clinical_trial: false,
            concentration_range: None,
            duration: None,
            age_group: None,
            cell_type: None,
        }
    }

    /// Minimal published-paper `Provenance` for fixture findings.
    fn provenance() -> Provenance {
        Provenance {
            source_type: "published_paper".into(),
            doi: None,
            pmid: None,
            pmc: None,
            openalex_id: None,
            url: None,
            title: "Test".into(),
            authors: Vec::new(),
            year: Some(2025),
            journal: None,
            license: None,
            publisher: None,
            funders: Vec::new(),
            extraction: Extraction::default(),
            review: None,
            citation_count: None,
        }
    }

    /// A finding with the given id, confidence score, and outgoing links.
    fn finding(id: &str, conf: f64, links: Vec<Link>) -> FindingBundle {
        let assertion = Assertion {
            text: format!("claim {id}"),
            assertion_type: "mechanism".into(),
            entities: Vec::new(),
            relation: None,
            direction: None,
            causal_claim: None,
            causal_evidence_grade: None,
        };
        let evidence = Evidence {
            evidence_type: "experimental".into(),
            model_system: String::new(),
            species: None,
            method: String::new(),
            sample_size: None,
            effect_size: None,
            p_value: None,
            replicated: false,
            replication_count: None,
            evidence_spans: Vec::new(),
        };
        let mut bundle = FindingBundle::new(
            assertion,
            evidence,
            conditions(),
            Confidence::raw(conf, "test", 0.85),
            provenance(),
            Flags::default(),
        );
        bundle.id = id.to_string();
        bundle.links = links;
        bundle
    }

    /// A `depends` link to `target`, optionally carrying a mechanism.
    fn link_with_mechanism(target: &str, mech: Option<Mechanism>) -> Link {
        Link {
            target: target.into(),
            link_type: "depends".into(),
            note: String::new(),
            inferred_by: "test".into(),
            created_at: String::new(),
            mechanism: mech,
        }
    }

    /// Chain vf_aaa -> vf_bbb -> vf_ccc (confidences 0.9, 0.8, 0.7).
    /// `ab` is the mechanism on vf_bbb's link to vf_aaa; `bc` is the mechanism
    /// on vf_ccc's link to vf_bbb.
    fn fixture_chain(ab: Option<Mechanism>, bc: Option<Mechanism>) -> Project {
        let findings = vec![
            finding("vf_aaa", 0.9, vec![]),
            finding("vf_bbb", 0.8, vec![link_with_mechanism("vf_aaa", ab)]),
            finding("vf_ccc", 0.7, vec![link_with_mechanism("vf_bbb", bc)]),
        ];
        project::assemble("test", findings, 1, 0, "test")
    }

    /// Shorthand for a linear mechanism.
    fn linear(sign: MechanismSign, slope: f64) -> Option<Mechanism> {
        Some(Mechanism::Linear { sign, slope })
    }

    /// Shorthand for building a query.
    fn query(intervene_on: &str, set_to: f64, target: &str) -> CounterfactualQuery {
        CounterfactualQuery {
            intervene_on: intervene_on.into(),
            set_to,
            target: target.into(),
        }
    }

    /// Unpack a `Resolved` verdict into `(factual, counterfactual, delta)`.
    fn expect_resolved(verdict: CounterfactualVerdict) -> (f64, f64, f64) {
        match verdict {
            CounterfactualVerdict::Resolved {
                factual,
                counterfactual,
                delta,
                ..
            } => (factual, counterfactual, delta),
            other => panic!("expected Resolved, got {other:?}"),
        }
    }

    #[test]
    fn linear_chain_resolves() {
        let project = fixture_chain(
            linear(MechanismSign::Positive, 0.5),
            linear(MechanismSign::Positive, 0.4),
        );
        // Source moves 0.9 -> 0.5; expected target delta is -0.4 * 0.5 * 0.4 = -0.08.
        let (factual, counterfactual, delta) =
            expect_resolved(answer_counterfactual(&project, &query("vf_aaa", 0.5, "vf_ccc")));
        assert!((factual - 0.7).abs() < 1e-9);
        assert!((delta - (-0.08)).abs() < 1e-6, "delta = {delta}");
        assert!(counterfactual > 0.0 && counterfactual < 1.0);
    }

    #[test]
    fn missing_mechanism_blocks_propagation() {
        let project = fixture_chain(linear(MechanismSign::Positive, 0.5), None);
        let verdict = answer_counterfactual(&project, &query("vf_aaa", 0.5, "vf_ccc"));
        assert!(matches!(
            verdict,
            CounterfactualVerdict::MechanismUnspecified { .. }
        ));
    }

    #[test]
    fn unknown_mechanism_blocks_propagation() {
        let project = fixture_chain(
            linear(MechanismSign::Positive, 0.5),
            Some(Mechanism::Unknown),
        );
        let verdict = answer_counterfactual(&project, &query("vf_aaa", 0.5, "vf_ccc"));
        assert!(matches!(
            verdict,
            CounterfactualVerdict::MechanismUnspecified { .. }
        ));
    }

    #[test]
    fn out_of_range_intervention_rejected() {
        let project = fixture_chain(None, None);
        let verdict = answer_counterfactual(&project, &query("vf_aaa", 1.5, "vf_ccc"));
        assert!(matches!(
            verdict,
            CounterfactualVerdict::InvalidIntervention { .. }
        ));
    }

    #[test]
    fn no_path_yields_factual() {
        // Querying against the chain's direction: vf_ccc cannot influence vf_aaa.
        let project = fixture_chain(None, None);
        match answer_counterfactual(&project, &query("vf_ccc", 0.5, "vf_aaa")) {
            CounterfactualVerdict::NoCausalPath { factual } => {
                assert!((factual - 0.9).abs() < 1e-9);
            }
            v => panic!("expected NoCausalPath, got {v:?}"),
        }
    }

    #[test]
    fn negative_sign_flips_direction() {
        let project = fixture_chain(
            linear(MechanismSign::Negative, 0.5),
            linear(MechanismSign::Positive, 1.0),
        );
        // Source moves 0.9 -> 1.0 (+0.1); the negative edge turns it into -0.05.
        let (_, counterfactual, delta) =
            expect_resolved(answer_counterfactual(&project, &query("vf_aaa", 1.0, "vf_ccc")));
        assert!((delta - (-0.05)).abs() < 1e-6, "delta = {delta}");
        assert!((counterfactual - 0.65).abs() < 1e-6);
    }
}