use std::collections::{HashMap, HashSet, VecDeque};
use scirs2_core::ndarray::ArrayView2;
use super::conditional_independence::{ConditionalIndependenceTest, PartialCorrelationTest};
use super::pc_algorithm::subsets;
use super::{CausalGraph, EdgeMark};
use crate::error::{StatsError, StatsResult};
/// Configuration for the Fast Causal Inference (FCI) algorithm.
///
/// Holds the significance level used for conditional-independence testing and the
/// conditioning-set size limits for the two search phases.
#[derive(Debug, Clone)]
pub struct FciAlgorithm {
    /// Significance level handed to the conditional-independence test.
    pub alpha: f64,
    /// Largest conditioning-set size tried during skeleton discovery.
    pub max_cond_set_size: usize,
    /// Largest conditioning-set size tried during the Possible-D-SEP phase.
    pub max_pdsep_size: usize,
}

impl Default for FciAlgorithm {
    /// `alpha = 0.05`, with conditioning sets of at most four variables in both phases.
    fn default() -> Self {
        let alpha = 0.05;
        let search_limit = 4;
        FciAlgorithm {
            alpha,
            max_cond_set_size: search_limit,
            max_pdsep_size: search_limit,
        }
    }
}
/// Output of [`FciAlgorithm::fit`] / [`FciAlgorithm::fit_with_test`].
#[derive(Debug, Clone)]
pub struct FciResult {
/// Learned graph; every remaining adjacency carries FCI edge marks (tail, arrow or circle).
pub graph: CausalGraph,
/// Separating set recorded for each removed adjacency, keyed by `(min(x, y), max(x, y))`.
pub sep_sets: HashMap<(usize, usize), Vec<usize>>,
/// Total number of conditional-independence tests run (skeleton plus Possible-D-SEP phase).
pub n_tests: usize,
/// `true` when the final graph contains at least one bidirected edge.
pub has_latent_confounders: bool,
}
impl FciAlgorithm {
pub fn new(alpha: f64) -> Self {
Self {
alpha,
..Default::default()
}
}
pub fn with_params(alpha: f64, max_cond_set_size: usize, max_pdsep_size: usize) -> Self {
Self {
alpha,
max_cond_set_size,
max_pdsep_size,
}
}
pub fn fit(&self, data: ArrayView2<f64>, var_names: &[&str]) -> StatsResult<FciResult> {
let ci_test = PartialCorrelationTest::new(self.alpha);
self.fit_with_test(data, var_names, &ci_test)
}
pub fn fit_with_test<T: ConditionalIndependenceTest>(
&self,
data: ArrayView2<f64>,
var_names: &[&str],
ci_test: &T,
) -> StatsResult<FciResult> {
let p = data.ncols();
if var_names.len() != p {
return Err(StatsError::DimensionMismatch(
"var_names length must match number of columns".to_owned(),
));
}
if p == 0 {
return Ok(FciResult {
graph: CausalGraph::new(var_names),
sep_sets: HashMap::new(),
n_tests: 0,
has_latent_confounders: false,
});
}
let (mut adj, mut sep_sets, mut n_tests) =
skeleton_discovery(data, p, self.alpha, self.max_cond_set_size, ci_test)?;
let mut graph = CausalGraph::new(var_names);
for i in 0..p {
for j in (i + 1)..p {
if adj[i][j] {
graph.set_edge(i, j, EdgeMark::Circle, EdgeMark::Circle);
}
}
}
orient_unshielded_colliders(&mut graph, &adj, &sep_sets, p);
let pdsep_removals = possible_dsep_phase(
&graph,
data,
&adj,
p,
self.alpha,
self.max_pdsep_size,
ci_test,
&mut n_tests,
)?;
for (x, y, z_set) in pdsep_removals {
adj[x][y] = false;
adj[y][x] = false;
graph.remove_edge(x, y);
let key = (x.min(y), x.max(y));
sep_sets.insert(key, z_set);
}
for i in 0..p {
for j in (i + 1)..p {
if adj[i][j] {
graph.set_edge(i, j, EdgeMark::Circle, EdgeMark::Circle);
}
}
}
orient_unshielded_colliders(&mut graph, &adj, &sep_sets, p);
apply_fci_rules(&mut graph, &adj, &sep_sets, p);
let has_latent_confounders =
(0..p).any(|i| (0..p).any(|j| i != j && graph.is_bidirected(i, j)));
Ok(FciResult {
graph,
sep_sets,
n_tests,
has_latent_confounders,
})
}
}
/// Learns the undirected skeleton with a PC-stable adjacency search.
///
/// Starting from the complete graph, level `ord = 0, 1, ..., max_cond_set_size` tests each
/// remaining edge `x - y` for `x ⊥ y | S` with `|S| = ord`. `S` ranges over subsets of
/// `adj(x) \ {y}` and then `adj(y) \ {x}`. Neighbourhoods are frozen at the start of each level
/// and removals are applied at its end, so the result does not depend on edge order. The first
/// separating set found removes the edge and is recorded under `(min(x, y), max(x, y))`.
///
/// The search stops early once no edge has a neighbourhood of size `>= ord`. Adjacencies only
/// shrink, so no later level could run a test; this keeps huge limits (e.g. `usize::MAX`) from
/// looping uselessly. Conditioning sets drawn from `adj(y)` that are also subsets of `adj(x)` were
/// already tested in the first pass and are not re-tested.
///
/// Returns `(adjacency matrix, separating sets, number of CI tests performed)`.
///
/// # Errors
/// Propagates any error returned by `ci_test`.
fn skeleton_discovery<T: ConditionalIndependenceTest>(
    data: ArrayView2<f64>,
    p: usize,
    alpha: f64,
    max_cond_set_size: usize,
    ci_test: &T,
) -> StatsResult<(Vec<Vec<bool>>, HashMap<(usize, usize), Vec<usize>>, usize)> {
    let mut adj = vec![vec![true; p]; p];
    for (i, row) in adj.iter_mut().enumerate() {
        row[i] = false;
    }
    let mut sep_sets: HashMap<(usize, usize), Vec<usize>> = HashMap::new();
    let mut n_tests = 0usize;
    for ord in 0..=max_cond_set_size {
        let adj_snapshot = adj.clone();
        let edges: Vec<(usize, usize)> = (0..p)
            .flat_map(|i| ((i + 1)..p).map(move |j| (i, j)))
            .filter(|&(i, j)| adj_snapshot[i][j])
            .collect();
        let mut removals = Vec::new();
        // Whether any edge can still be tested at this level.
        let mut any_testable = false;
        for (x, y) in edges {
            let z_x: Vec<usize> = (0..p)
                .filter(|&k| k != x && k != y && adj_snapshot[x][k])
                .collect();
            let z_y: Vec<usize> = (0..p)
                .filter(|&k| k != x && k != y && adj_snapshot[y][k])
                .collect();
            any_testable |= z_x.len() >= ord || z_y.len() >= ord;
            let mut found = false;
            if z_x.len() >= ord {
                for z_set in subsets(&z_x, ord) {
                    n_tests += 1;
                    if ci_test.is_independent(x, y, &z_set, data, alpha)? {
                        removals.push((x, y, z_set));
                        found = true;
                        break;
                    }
                }
            }
            if !found && z_y.len() >= ord {
                for z_set in subsets(&z_y, ord) {
                    // A set contained in adj(x) was already tested (and rejected) above.
                    if z_set.iter().all(|&k| adj_snapshot[x][k]) {
                        continue;
                    }
                    n_tests += 1;
                    if ci_test.is_independent(x, y, &z_set, data, alpha)? {
                        removals.push((x, y, z_set));
                        break;
                    }
                }
            }
        }
        for (x, y, z_set) in removals {
            adj[x][y] = false;
            adj[y][x] = false;
            sep_sets.insert((x.min(y), x.max(y)), z_set);
        }
        if !any_testable {
            break;
        }
    }
    Ok((adj, sep_sets, n_tests))
}
/// Orients every unshielded triple `x *-* z *-* y` (with `x`, `y` non-adjacent in `adj`) as a
/// collider `x *-> z <-* y` when `z` is absent from the recorded separating set of `(x, y)`.
///
/// The existing marks at `x` and `y` are kept (a missing mark is treated as a circle). A pair with
/// no recorded separating set is treated as having an empty one.
fn orient_unshielded_colliders(
    graph: &mut CausalGraph,
    adj: &[Vec<bool>],
    sep_sets: &HashMap<(usize, usize), Vec<usize>>,
    p: usize,
) {
    for z in 0..p {
        let nbrs: Vec<usize> = (0..p).filter(|&k| k != z && adj[z][k]).collect();
        for (i, &x) in nbrs.iter().enumerate() {
            for &y in &nbrs[i + 1..] {
                if adj[x][y] {
                    // Shielded triple: nothing to orient.
                    continue;
                }
                let z_separates = sep_sets
                    .get(&(x.min(y), x.max(y)))
                    .is_some_and(|sep| sep.contains(&z));
                if z_separates {
                    continue;
                }
                for endpoint in [x, y] {
                    let keep = graph.get_mark_from(endpoint, z).unwrap_or(EdgeMark::Circle);
                    graph.set_edge(endpoint, z, keep, EdgeMark::Arrow);
                }
            }
        }
    }
}
/// Computes Possible-D-SEP(`a`, `b`): every vertex reachable from `a` along a path on which each
/// interior triple `⟨prev, cur, next⟩` either has `cur` as a *possible collider* or forms a
/// triangle (`prev` adjacent to `next`).
///
/// `cur` is a possible collider when neither edge mark at `cur` (from `prev` nor from `next`) is
/// a tail, i.e. each is an arrowhead or a circle. The search is a breadth-first walk over
/// `(vertex, predecessor)` states, so a vertex can be re-entered through a different edge. `a` is
/// never re-entered. `b` is not seeded as a start vertex but may be reached later; callers filter
/// it out.
fn possible_dsep(graph: &CausalGraph, a: usize, b: usize, p: usize) -> HashSet<usize> {
    let mut pdsep = HashSet::new();
    let mut visited = HashSet::new();
    let mut queue = VecDeque::new();
    for k in 0..p {
        if k != a && k != b && graph.is_adjacent(a, k) {
            // State: (current vertex, vertex we came from).
            queue.push_back((k, a));
        }
    }
    while let Some((cur, prev)) = queue.pop_front() {
        if !visited.insert((cur, prev)) {
            continue;
        }
        pdsep.insert(cur);
        // The mark at `cur` on the edge from `x` is not a tail (arrowhead or circle).
        let maybe_into_cur = |x: usize| {
            matches!(
                graph.get_mark_at(x, cur),
                Some(EdgeMark::Arrow | EdgeMark::Circle)
            )
        };
        let prev_into_cur = maybe_into_cur(prev);
        for next in 0..p {
            if next == prev || next == a || !graph.is_adjacent(cur, next) {
                continue;
            }
            let possible_collider = prev_into_cur && maybe_into_cur(next);
            let in_triangle = graph.is_adjacent(prev, next);
            if possible_collider || in_triangle {
                queue.push_back((next, cur));
            }
        }
    }
    pdsep
}
/// Possible-D-SEP phase of FCI: for every remaining adjacency `x - y`, search conditioning sets
/// drawn from Possible-D-SEP(x, y) ∪ Possible-D-SEP(y, x) (excluding `x`, `y`) in order of
/// increasing size, up to `max_pdsep_size`.
///
/// The candidate pool is sorted before the subset search. `HashSet` iteration order is arbitrary
/// (randomly seeded per process), so without sorting, which separating set is found first — and
/// hence the recorded sepsets and every orientation that depends on them — could differ from run
/// to run on identical data.
///
/// Every test increments `n_tests`. Returns the edges to remove, each with its separating set; the
/// graph itself is not modified.
///
/// # Errors
/// Propagates any error returned by `ci_test`.
fn possible_dsep_phase<T: ConditionalIndependenceTest>(
    graph: &CausalGraph,
    data: ArrayView2<f64>,
    adj: &[Vec<bool>],
    p: usize,
    alpha: f64,
    max_pdsep_size: usize,
    ci_test: &T,
    n_tests: &mut usize,
) -> StatsResult<Vec<(usize, usize, Vec<usize>)>> {
    let mut removals = Vec::new();
    for x in 0..p {
        for y in (x + 1)..p {
            if !adj[x][y] {
                continue;
            }
            let pdsep_x = possible_dsep(graph, x, y, p);
            let pdsep_y = possible_dsep(graph, y, x, p);
            let mut combined: Vec<usize> = pdsep_x
                .union(&pdsep_y)
                .copied()
                .filter(|&k| k != x && k != y)
                .collect();
            if combined.is_empty() {
                continue;
            }
            combined.sort_unstable();
            let max_size = max_pdsep_size.min(combined.len());
            'search: for ord in 0..=max_size {
                for z_set in subsets(&combined, ord) {
                    *n_tests += 1;
                    if ci_test.is_independent(x, y, &z_set, data, alpha)? {
                        removals.push((x, y, z_set));
                        break 'search;
                    }
                }
            }
        }
    }
    Ok(removals)
}
/// Applies FCI orientation rules R1–R10 repeatedly until a full pass changes nothing.
///
/// Every rule runs on every pass, in order R1..R10. The number of passes is capped at
/// `2p² + 10` so the loop always terminates.
fn apply_fci_rules(
    graph: &mut CausalGraph,
    adj: &[Vec<bool>],
    sep_sets: &HashMap<(usize, usize), Vec<usize>>,
    p: usize,
) {
    let pass_limit = p * p * 2 + 10;
    for _ in 0..pass_limit {
        // Non-short-circuiting `|`: all ten rules are evaluated, left to right, on each pass.
        let progressed = fci_r1(graph, p)
            | fci_r2(graph, p)
            | fci_r3(graph, adj, p)
            | fci_r4(graph, adj, sep_sets, p)
            | fci_r5(graph, adj, p)
            | fci_r6(graph, p)
            | fci_r7(graph, p)
            | fci_r8(graph, p)
            | fci_r9(graph, p)
            | fci_r10(graph, p);
        if !progressed {
            break;
        }
    }
}
/// FCI rule R1: if `a *-> b o-* c` and `a`, `c` are not adjacent, orient `b o-* c` as `b -> c`.
///
/// Both ends of the `b - c` edge are set: tail at `b` and arrowhead at `c`. The previous version
/// kept the old mark at `c`, leaving e.g. `b o-o c` as `b -o c` instead of `b -> c`.
/// Returns `true` if any edge was reoriented.
fn fci_r1(graph: &mut CausalGraph, p: usize) -> bool {
    let mut changed = false;
    for b in 0..p {
        for a in 0..p {
            if a == b || graph.get_mark_at(a, b) != Some(EdgeMark::Arrow) {
                continue;
            }
            for c in 0..p {
                if c == a || c == b {
                    continue;
                }
                if !graph.is_adjacent(b, c) || graph.is_adjacent(a, c) {
                    continue;
                }
                // Needs a circle at `b` on the `b - c` edge.
                if graph.get_mark_from(b, c) != Some(EdgeMark::Circle) {
                    continue;
                }
                graph.set_edge(b, c, EdgeMark::Tail, EdgeMark::Arrow);
                changed = true;
            }
        }
    }
    changed
}
/// FCI rule R2: if `a -> b *-> c` or `a *-> b -> c`, and `a *-o c`, orient `a *-o c` as
/// `a *-> c` (the mark at `a` is kept). Returns `true` if any edge was reoriented.
fn fci_r2(graph: &mut CausalGraph, p: usize) -> bool {
    let mut changed = false;
    for a in 0..p {
        for c in 0..p {
            let circle_at_c = a != c
                && graph.is_adjacent(a, c)
                && graph.get_mark_at(a, c) == Some(EdgeMark::Circle);
            if !circle_at_c {
                continue;
            }
            // Both cases need arrowheads at `b` and at `c`; they differ only in which tail exists.
            let witness = (0..p).filter(|&b| b != a && b != c).any(|b| {
                let arrow_at_b = graph.get_mark_at(a, b) == Some(EdgeMark::Arrow);
                let arrow_at_c = graph.get_mark_at(b, c) == Some(EdgeMark::Arrow);
                let tail_at_a = graph.get_mark_from(a, b) == Some(EdgeMark::Tail);
                let tail_at_b = graph.get_mark_from(b, c) == Some(EdgeMark::Tail);
                arrow_at_b && arrow_at_c && (tail_at_a || tail_at_b)
            });
            if witness {
                let keep = graph.get_mark_from(a, c).unwrap_or(EdgeMark::Circle);
                graph.set_edge(a, c, keep, EdgeMark::Arrow);
                changed = true;
            }
        }
    }
    changed
}
/// FCI rule R3: if `a *-> b <-* c`, `a *-o d o-* c`, `a` and `c` are not adjacent, and
/// `d *-o b`, orient `d *-o b` as `d *-> b` (the mark at `d` is kept).
///
/// Non-adjacency of `a`, `c` is read from `adj`; all other checks use the current graph.
/// Returns `true` if any edge was reoriented.
fn fci_r3(graph: &mut CausalGraph, adj: &[Vec<bool>], p: usize) -> bool {
    let mut changed = false;
    for d in 0..p {
        for b in 0..p {
            let circle_at_b = d != b
                && graph.is_adjacent(d, b)
                && graph.get_mark_at(d, b) == Some(EdgeMark::Circle);
            if !circle_at_b {
                continue;
            }
            // Vertices (other than `d`) with an arrowhead into `b`.
            let into_b: Vec<usize> = (0..p)
                .filter(|&k| {
                    k != b
                        && k != d
                        && graph.is_adjacent(k, b)
                        && graph.get_mark_at(k, b) == Some(EdgeMark::Arrow)
                })
                .collect();
            let circle_at_d =
                |k: usize| graph.is_adjacent(k, d) && graph.get_mark_at(k, d) == Some(EdgeMark::Circle);
            let triggered = into_b.iter().enumerate().any(|(i, &a)| {
                into_b[i + 1..]
                    .iter()
                    .any(|&c| !adj[a][c] && circle_at_d(a) && circle_at_d(c))
            });
            if triggered {
                let keep = graph.get_mark_from(d, b).unwrap_or(EdgeMark::Circle);
                graph.set_edge(d, b, keep, EdgeMark::Arrow);
                changed = true;
            }
        }
    }
    changed
}
fn fci_r4(
graph: &mut CausalGraph,
_adj: &[Vec<bool>],
sep_sets: &HashMap<(usize, usize), Vec<usize>>,
p: usize,
) -> bool {
let mut changed = false;
for c in 0..p {
for b in 0..p {
if b == c || !graph.is_adjacent(b, c) {
continue;
}
if graph.get_mark_at(b, c) != Some(EdgeMark::Arrow) {
continue;
}
if graph.get_mark_from(b, c) != Some(EdgeMark::Circle) {
continue;
}
for a in 0..p {
if a == b || a == c || !graph.is_adjacent(a, c) {
continue;
}
}
for a in 0..p {
if a == b || a == c {
continue;
}
if graph.is_adjacent(a, c) {
continue; }
if !graph.is_adjacent(a, b) {
continue;
}
if graph.get_mark_at(a, b) != Some(EdgeMark::Arrow) {
continue;
}
let key = (a.min(c), a.max(c));
let sep = sep_sets.get(&key).cloned().unwrap_or_default();
if sep.contains(&b) {
let mark_from_b = graph.get_mark_from(b, c).unwrap_or(EdgeMark::Circle);
let _mark_at_c = EdgeMark::Arrow;
graph.set_edge(b, c, EdgeMark::Tail, EdgeMark::Arrow);
let _ = mark_from_b;
} else {
graph.set_edge(b, c, EdgeMark::Arrow, EdgeMark::Arrow);
}
changed = true;
break;
}
}
}
changed
}
/// FCI rule R5: orient `a o-o b` as `a - b` (tails at both ends) when
/// `has_uncovered_circle_path` finds a qualifying circle path between `a` and `b`.
///
/// NOTE(review): only the edge `a - b` itself is reoriented here; the edges along the witnessing
/// path are not touched by this function. `_adj` is unused. Returns `true` if any edge changed.
fn fci_r5(graph: &mut CausalGraph, _adj: &[Vec<bool>], p: usize) -> bool {
    let mut changed = false;
    for a in 0..p {
        for b in (a + 1)..p {
            let circle_circle = graph.is_adjacent(a, b)
                && graph.get_mark_from(a, b) == Some(EdgeMark::Circle)
                && graph.get_mark_at(a, b) == Some(EdgeMark::Circle);
            if circle_circle && has_uncovered_circle_path(graph, a, b, p) {
                graph.set_edge(a, b, EdgeMark::Tail, EdgeMark::Tail);
                changed = true;
            }
        }
    }
    changed
}
/// FCI rule R6: if `a - b` (tails at both ends) and `b o-* c`, orient `b o-* c` as `b -* c`
/// (tail at `b`, mark at `c` kept). Adjacency of `a` and `c` does not matter.
/// Returns `true` if any edge was reoriented.
fn fci_r6(graph: &mut CausalGraph, p: usize) -> bool {
    let mut changed = false;
    for b in 0..p {
        for a in (0..p).filter(|&a| a != b) {
            let a_undirected_b = graph.is_adjacent(a, b)
                && graph.get_mark_from(a, b) == Some(EdgeMark::Tail)
                && graph.get_mark_at(a, b) == Some(EdgeMark::Tail);
            if !a_undirected_b {
                continue;
            }
            for c in (0..p).filter(|&c| c != a && c != b) {
                if graph.is_adjacent(b, c) && graph.get_mark_from(b, c) == Some(EdgeMark::Circle) {
                    let far_mark = graph.get_mark_at(b, c).unwrap_or(EdgeMark::Circle);
                    graph.set_edge(b, c, EdgeMark::Tail, far_mark);
                    changed = true;
                }
            }
        }
    }
    changed
}
/// FCI rule R7: if `a -o b o-* c` and `a`, `c` are not adjacent, orient `b o-* c` as `b -* c`
/// (tail at `b`, mark at `c` kept). Returns `true` if any edge was reoriented.
fn fci_r7(graph: &mut CausalGraph, p: usize) -> bool {
    let mut changed = false;
    for b in 0..p {
        for a in (0..p).filter(|&a| a != b) {
            let a_tail_circle_b = graph.is_adjacent(a, b)
                && graph.get_mark_from(a, b) == Some(EdgeMark::Tail)
                && graph.get_mark_at(a, b) == Some(EdgeMark::Circle);
            if !a_tail_circle_b {
                continue;
            }
            for c in (0..p).filter(|&c| c != a && c != b) {
                let qualifies = graph.is_adjacent(b, c)
                    && !graph.is_adjacent(a, c)
                    && graph.get_mark_from(b, c) == Some(EdgeMark::Circle);
                if qualifies {
                    let far_mark = graph.get_mark_at(b, c).unwrap_or(EdgeMark::Circle);
                    graph.set_edge(b, c, EdgeMark::Tail, far_mark);
                    changed = true;
                }
            }
        }
    }
    changed
}
/// FCI rule R8: if `a -> b -> c` or `a -o b -> c`, and `a o-> c`, orient `a o-> c` as `a -> c`.
/// Returns `true` if any edge was reoriented.
fn fci_r8(graph: &mut CausalGraph, p: usize) -> bool {
    let mut changed = false;
    for a in 0..p {
        for c in 0..p {
            if a == c || !graph.is_adjacent(a, c) {
                continue;
            }
            let a_circle_arrow_c = graph.get_mark_from(a, c) == Some(EdgeMark::Circle)
                && graph.get_mark_at(a, c) == Some(EdgeMark::Arrow);
            if !a_circle_arrow_c {
                continue;
            }
            let witness = (0..p).filter(|&b| b != a && b != c).any(|b| {
                let b_into_c = graph.get_mark_from(b, c) == Some(EdgeMark::Tail)
                    && graph.get_mark_at(b, c) == Some(EdgeMark::Arrow);
                let a_tail_toward_b = graph.get_mark_from(a, b) == Some(EdgeMark::Tail)
                    && matches!(
                        graph.get_mark_at(a, b),
                        Some(EdgeMark::Arrow | EdgeMark::Circle)
                    );
                b_into_c && a_tail_toward_b
            });
            if witness {
                graph.set_edge(a, c, EdgeMark::Tail, EdgeMark::Arrow);
                changed = true;
            }
        }
    }
    changed
}
/// FCI rule R9 (restricted form): orient `a o-> c` as `a -> c` when `a` reaches `c` through a
/// directed path of `->` edges that does not start with the direct `a - c` edge.
///
/// Only fully directed paths are considered, not general uncovered potentially directed paths,
/// so this is a conservative subset of Zhang's R9. Returns `true` if any edge was reoriented.
fn fci_r9(graph: &mut CausalGraph, p: usize) -> bool {
    let mut changed = false;
    for a in 0..p {
        for c in (0..p).filter(|&c| c != a) {
            let a_circle_arrow_c = graph.is_adjacent(a, c)
                && graph.get_mark_from(a, c) == Some(EdgeMark::Circle)
                && graph.get_mark_at(a, c) == Some(EdgeMark::Arrow);
            if a_circle_arrow_c && has_directed_path_excluding_direct(graph, a, c, p) {
                graph.set_edge(a, c, EdgeMark::Tail, EdgeMark::Arrow);
                changed = true;
            }
        }
    }
    changed
}
/// FCI rule R10 (length-one special case): if `a o-> c`, `b -> c <- d`, `a` reaches both `b` and
/// `d` along potentially directed edges, and `b`, `d` are not adjacent, orient `a o-> c` as
/// `a -> c`.
///
/// An edge `a - k` is potentially directed from `a` to `k` when it is not into `a` (mark at `a`
/// is a tail or circle) and not out of `k` (mark at `k` is an arrowhead or circle). This is R10
/// with the two uncovered p.d. paths of Zhang (2008) restricted to single edges, so `μ = b` and
/// `ω = d`. The previous version instead required a directed path from `b`/`d` *to* `a` (the wrong
/// direction) and never checked that `b`, `d` are non-adjacent.
/// Returns `true` if any edge was reoriented.
fn fci_r10(graph: &mut CausalGraph, p: usize) -> bool {
    let mut changed = false;
    for a in 0..p {
        for c in 0..p {
            if a == c || !graph.is_adjacent(a, c) {
                continue;
            }
            if graph.get_mark_from(a, c) != Some(EdgeMark::Circle)
                || graph.get_mark_at(a, c) != Some(EdgeMark::Arrow)
            {
                continue;
            }
            // Parents of `c` (other than `a`): `k -> c`.
            let parents_c: Vec<usize> = (0..p)
                .filter(|&k| {
                    k != a
                        && k != c
                        && graph.get_mark_from(k, c) == Some(EdgeMark::Tail)
                        && graph.get_mark_at(k, c) == Some(EdgeMark::Arrow)
                })
                .collect();
            if parents_c.len() < 2 {
                continue;
            }
            // pd_from_a[k]: the edge `a - k` is potentially directed from `a` to `k`.
            let pd_from_a: Vec<bool> = (0..p)
                .map(|k| {
                    k != a
                        && graph.is_adjacent(a, k)
                        && matches!(
                            graph.get_mark_from(a, k),
                            Some(EdgeMark::Tail | EdgeMark::Circle)
                        )
                        && matches!(
                            graph.get_mark_at(a, k),
                            Some(EdgeMark::Arrow | EdgeMark::Circle)
                        )
                })
                .collect();
            let orient = parents_c.iter().enumerate().any(|(i, &b)| {
                pd_from_a[b]
                    && parents_c[i + 1..]
                        .iter()
                        .any(|&d| pd_from_a[d] && !graph.is_adjacent(b, d))
            });
            if orient {
                graph.set_edge(a, c, EdgeMark::Tail, EdgeMark::Arrow);
                changed = true;
            }
        }
    }
    changed
}
/// Returns `true` if there is an uncovered circle path `⟨src, γ, ..., θ, dst⟩` as required by
/// FCI rule R5 (Zhang 2008). The conditions are:
/// * every edge on the path is `o-o`;
/// * every consecutive triple on the path is unshielded (its endpoints are non-adjacent);
/// * `src` is not adjacent to `θ`, and `dst` is not adjacent to `γ`.
///
/// Together these force the path to have at least three edges. The previous version checked
/// only for some `o-o` path, so any `o-o` triangle (e.g. three mutually dependent variables)
/// wrongly triggered R5.
///
/// The search is an exact depth-first enumeration of simple paths. Its worst case is exponential
/// in `p`, but the uncovered constraint prunes dense neighbourhoods aggressively.
fn has_uncovered_circle_path(graph: &CausalGraph, src: usize, dst: usize, p: usize) -> bool {
    let mut on_path = vec![false; p];
    on_path[src] = true;
    for first in 0..p {
        if first == src || first == dst || !is_circle_circle_edge(graph, src, first) {
            continue;
        }
        // `dst` must not be adjacent to `γ` (the vertex after `src`).
        if graph.is_adjacent(first, dst) {
            continue;
        }
        on_path[first] = true;
        let found = extend_uncovered_circle_path(graph, src, dst, src, first, &mut on_path);
        on_path[first] = false;
        if found {
            return true;
        }
    }
    false
}

/// `true` when `x` and `y` are joined by an `o-o` edge.
fn is_circle_circle_edge(graph: &CausalGraph, x: usize, y: usize) -> bool {
    graph.is_adjacent(x, y)
        && graph.get_mark_from(x, y) == Some(EdgeMark::Circle)
        && graph.get_mark_at(x, y) == Some(EdgeMark::Circle)
}

/// Depth-first extension of an uncovered circle path that currently ends `..., prev, cur`.
/// `on_path` marks the vertices already used (its length is the number of vertices).
fn extend_uncovered_circle_path(
    graph: &CausalGraph,
    src: usize,
    dst: usize,
    prev: usize,
    cur: usize,
    on_path: &mut [bool],
) -> bool {
    for next in 0..on_path.len() {
        if on_path[next] || !is_circle_circle_edge(graph, cur, next) {
            continue;
        }
        // Uncovered: the triple `prev, cur, next` must be unshielded.
        if graph.is_adjacent(prev, next) {
            continue;
        }
        if next == dst {
            // `cur` is θ, which must not be adjacent to `src`.
            if !graph.is_adjacent(cur, src) {
                return true;
            }
            continue;
        }
        on_path[next] = true;
        let found = extend_uncovered_circle_path(graph, src, dst, cur, next, on_path);
        on_path[next] = false;
        if found {
            return true;
        }
    }
    false
}
/// Returns `true` if `dst` is reachable from `src` through `->` edges (tail at the source end,
/// arrowhead at the target end) without taking the direct `src -> dst` edge as the first step.
///
/// Depth-first search with an explicit stack; each vertex is expanded at most once.
fn has_directed_path_excluding_direct(
    graph: &CausalGraph,
    src: usize,
    dst: usize,
    p: usize,
) -> bool {
    let points_to = |from: usize, to: usize| {
        graph.get_mark_from(from, to) == Some(EdgeMark::Tail)
            && graph.get_mark_at(from, to) == Some(EdgeMark::Arrow)
    };
    let mut seen = vec![false; p];
    let mut pending: Vec<usize> = (0..p)
        .filter(|&k| k != dst && points_to(src, k))
        .collect();
    while let Some(node) = pending.pop() {
        if node == dst {
            return true;
        }
        if seen[node] {
            continue;
        }
        seen[node] = true;
        pending.extend((0..p).filter(|&next| !seen[next] && points_to(node, next)));
    }
    false
}
/// Returns `true` if a non-empty path of `->` edges leads from `src` to `dst`.
///
/// `src == dst` alone does not count; depth-first search, each vertex expanded at most once.
fn has_directed_path_general(graph: &CausalGraph, src: usize, dst: usize, p: usize) -> bool {
    let mut seen = vec![false; p];
    let mut pending = vec![src];
    while let Some(node) = pending.pop() {
        if node == dst && node != src {
            return true;
        }
        if seen[node] {
            continue;
        }
        seen[node] = true;
        let children = (0..p).filter(|&next| {
            !seen[next]
                && graph.get_mark_from(node, next) == Some(EdgeMark::Tail)
                && graph.get_mark_at(node, next) == Some(EdgeMark::Arrow)
        });
        pending.extend(children);
    }
    false
}
// Unit tests for the FCI pipeline and individual orientation helpers. Synthetic data come from a
// seeded LCG, so every run sees identical samples.
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
// One step of a 64-bit linear congruential generator (Knuth MMIX constants); the top 53 bits
// are mapped to a float in [0, 1).
fn lcg_uniform(s: &mut u64) -> f64 {
*s = s
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
((*s >> 11) as f64) / ((1u64 << 53) as f64)
}
// Standard normal draw via the Box–Muller transform (cosine branch). `u1` is clamped away from
// zero so `ln` stays finite.
fn lcg_normal(s: &mut u64) -> f64 {
let u1 = lcg_uniform(s).max(1e-15);
let u2 = lcg_uniform(s);
(-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
}
// Linear chain X -> Y -> Z (coefficients 0.9, noise scaled by 0.3); columns are X, Y, Z.
fn chain_data(n: usize, seed: u64) -> Array2<f64> {
let mut data = Array2::<f64>::zeros((n, 3));
let mut lcg = seed;
for i in 0..n {
data[[i, 0]] = lcg_normal(&mut lcg);
data[[i, 1]] = 0.9 * data[[i, 0]] + lcg_normal(&mut lcg) * 0.3;
data[[i, 2]] = 0.9 * data[[i, 1]] + lcg_normal(&mut lcg) * 0.3;
}
data
}
// An unobserved `latent` variable drives both X and Y; Z depends on X and Y. Only X, Y, Z are
// returned, so the latent is hidden from the algorithm.
fn latent_confounder_data(n: usize, seed: u64) -> Array2<f64> {
let mut data = Array2::<f64>::zeros((n, 3));
let mut lcg = seed;
for i in 0..n {
let latent = lcg_normal(&mut lcg);
data[[i, 0]] = 0.8 * latent + lcg_normal(&mut lcg) * 0.3;
data[[i, 1]] = 0.8 * latent + lcg_normal(&mut lcg) * 0.3;
data[[i, 2]] = 0.5 * data[[i, 0]] + 0.5 * data[[i, 1]] + lcg_normal(&mut lcg) * 0.3;
}
data
}
// The skeleton of a chain keeps X-Y and Y-Z and drops X-Z (separated by Y).
#[test]
fn test_fci_chain() {
let data = chain_data(300, 12345);
let fci = FciAlgorithm::new(0.05);
let result = fci.fit(data.view(), &["X", "Y", "Z"]).expect("FCI failed");
assert!(
result.graph.is_adjacent(0, 1),
"X-Y should be adjacent in chain"
);
assert!(
result.graph.is_adjacent(1, 2),
"Y-Z should be adjacent in chain"
);
assert!(
!result.graph.is_adjacent(0, 2),
"X-Z should not be adjacent"
);
}
// Smoke test with a hidden common cause: some adjacency must survive and CI tests must run.
#[test]
fn test_fci_latent_confounder() {
let data = latent_confounder_data(500, 54321);
let fci = FciAlgorithm::new(0.05);
let result = fci.fit(data.view(), &["X", "Y", "Z"]).expect("FCI failed");
assert!(
result.graph.is_adjacent(0, 1) || result.graph.is_adjacent(0, 2),
"Should find some adjacency"
);
assert!(result.n_tests > 0, "Should perform CI tests");
}
// The returned graph has one node per input variable.
#[test]
fn test_fci_produces_pag() {
let data = chain_data(200, 99999);
let fci = FciAlgorithm::new(0.05);
let result = fci.fit(data.view(), &["X", "Y", "Z"]).expect("FCI failed");
assert_eq!(result.graph.n_nodes(), 3);
}
// X and Y are independent causes of Z (X -> Z <- Y); FCI should find the v-structure at Z.
#[test]
fn test_fci_collider_detection() {
let n = 300;
let mut data = Array2::<f64>::zeros((n, 3));
let mut lcg: u64 = 77777;
for i in 0..n {
data[[i, 0]] = lcg_normal(&mut lcg);
data[[i, 1]] = lcg_normal(&mut lcg);
data[[i, 2]] = 0.7 * data[[i, 0]] + 0.7 * data[[i, 1]] + lcg_normal(&mut lcg) * 0.3;
}
let fci = FciAlgorithm::new(0.05);
let result = fci.fit(data.view(), &["X", "Y", "Z"]).expect("FCI failed");
assert!(result.graph.is_adjacent(0, 2), "X-Z should be adjacent");
assert!(result.graph.is_adjacent(1, 2), "Y-Z should be adjacent");
assert!(
!result.graph.is_adjacent(0, 1),
"X-Y should not be adjacent"
);
assert!(
result.graph.get_mark_at(0, 2) == Some(EdgeMark::Arrow)
|| result.graph.get_mark_at(1, 2) == Some(EdgeMark::Arrow),
"Should detect v-structure at Z"
);
}
// Hand-built graph A o-> B o-o C o-> D: Possible-D-SEP(A, D) should reach the interior nodes.
#[test]
fn test_fci_possible_dsep() {
let mut graph = CausalGraph::new(&["A", "B", "C", "D"]);
graph.set_edge(0, 1, EdgeMark::Circle, EdgeMark::Arrow);
graph.set_edge(1, 2, EdgeMark::Circle, EdgeMark::Circle);
graph.set_edge(2, 3, EdgeMark::Circle, EdgeMark::Arrow);
let pdsep = possible_dsep(&graph, 0, 3, 4);
assert!(
pdsep.contains(&1) || pdsep.contains(&2),
"Possible-D-SEP should contain intermediate nodes"
);
}
// A -> B o-o C with A, C non-adjacent: R1 must put a tail at B on the B-C edge.
#[test]
fn test_fci_r1_orientation() {
let mut graph = CausalGraph::new(&["A", "B", "C"]);
graph.set_edge(0, 1, EdgeMark::Tail, EdgeMark::Arrow); graph.set_edge(1, 2, EdgeMark::Circle, EdgeMark::Circle);
let changed = fci_r1(&mut graph, 3);
assert!(changed, "R1 should make a change");
assert_eq!(
graph.get_mark_from(1, 2),
Some(EdgeMark::Tail),
"R1: b side should be tail"
);
}
// Checks CausalGraph's edge-type predicates for directed and bidirected edges.
#[test]
fn test_fci_edge_marks() {
let mut graph = CausalGraph::new(&["A", "B", "C"]);
graph.set_edge(0, 1, EdgeMark::Tail, EdgeMark::Arrow);
graph.set_edge(1, 2, EdgeMark::Arrow, EdgeMark::Arrow);
assert!(graph.is_directed(0, 1), "A -> B");
assert!(graph.is_bidirected(1, 2), "B <-> C");
assert!(!graph.is_undirected(0, 1), "A -> B is not undirected");
}
// Zero columns: the early-return path yields an empty graph and no tests.
#[test]
fn test_fci_empty_graph() {
let data = Array2::<f64>::zeros((10, 0));
let fci = FciAlgorithm::new(0.05);
let result = fci.fit(data.view(), &[]).expect("FCI should handle empty");
assert_eq!(result.graph.n_nodes(), 0);
assert_eq!(result.n_tests, 0);
}
// Two strongly dependent variables must stay adjacent.
#[test]
fn test_fci_two_vars() {
let n = 200;
let mut data = Array2::<f64>::zeros((n, 2));
let mut lcg: u64 = 11111;
for i in 0..n {
data[[i, 0]] = lcg_normal(&mut lcg);
data[[i, 1]] = 0.9 * data[[i, 0]] + lcg_normal(&mut lcg) * 0.3;
}
let fci = FciAlgorithm::new(0.05);
let result = fci.fit(data.view(), &["X", "Y"]).expect("FCI with 2 vars");
assert!(result.graph.is_adjacent(0, 1), "X-Y should be adjacent");
}
}