use super::error::CausalError;
/// A table of observational samples over a fixed, ordered list of named variables.
///
/// Each sample is a row holding one `f64` per variable, in the same order as the
/// variable names. Rows whose length differs from the number of variables are
/// rejected by `add_sample`.
#[derive(Debug, Clone, PartialEq)]
pub struct ObservationalData {
    /// Variable (column) names, in column order.
    variables: Vec<String>,
    /// Sample rows; `add_sample` only stores rows with `variables.len()` entries.
    samples: Vec<Vec<f64>>,
}
impl ObservationalData {
    /// Absolute tolerance used by `conditional_mean` when comparing a sample's
    /// conditioning value with the requested value.
    const CONDITION_TOLERANCE: f64 = 1e-9;

    /// Creates an empty data set whose samples will carry one value per entry
    /// of `variables`, in the same order.
    ///
    /// Names are not checked for uniqueness. If a name repeats, every lookup by
    /// name resolves to its first occurrence.
    pub fn new(variables: Vec<String>) -> Self {
        Self {
            variables,
            samples: Vec::new(),
        }
    }

    /// Appends one sample (a row with one value per variable, in variable order).
    ///
    /// # Errors
    ///
    /// Returns `CausalError::DimensionMismatch` if `sample.len()` differs from
    /// the number of variables. In that case the data set is left unchanged.
    pub fn add_sample(&mut self, sample: Vec<f64>) -> Result<(), CausalError> {
        if sample.len() != self.variables.len() {
            return Err(CausalError::DimensionMismatch);
        }
        self.samples.push(sample);
        Ok(())
    }

    /// Number of samples (rows) stored so far.
    pub fn n_samples(&self) -> usize {
        self.samples.len()
    }

    /// Number of variables (columns) each sample holds.
    pub fn n_variables(&self) -> usize {
        self.variables.len()
    }

    /// Column position of `var`, or `None` if no variable has that name.
    pub(super) fn var_index(&self, var: &str) -> Option<usize> {
        self.variables.iter().position(|v| v == var)
    }

    /// Returns a copy of `var`'s values, one per sample in insertion order.
    ///
    /// Returns `None` if `var` is not a known variable.
    pub fn column(&self, var: &str) -> Option<Vec<f64>> {
        let idx = self.var_index(var)?;
        Some(self.samples.iter().map(|s| s[idx]).collect())
    }

    /// Arithmetic mean of `var` over all samples.
    ///
    /// Returns `None` if `var` is unknown or no samples have been added.
    pub fn mean(&self, var: &str) -> Option<f64> {
        let idx = self.var_index(var)?;
        // Stream the column instead of materialising it with `column()`.
        Self::mean_of_values(self.samples.iter().map(|s| s[idx]))
    }

    /// Mean of `outcome` over the samples whose `condition_var` value is
    /// strictly within `1e-9` (absolute) of `condition_val`.
    ///
    /// Returns `None` if either variable is unknown or no sample matches.
    pub fn conditional_mean(
        &self,
        outcome: &str,
        condition_var: &str,
        condition_val: f64,
    ) -> Option<f64> {
        self.conditional_mean_with_tolerance(
            outcome,
            condition_var,
            condition_val,
            Self::CONDITION_TOLERANCE,
        )
    }

    /// Like `conditional_mean`, but with a caller-chosen absolute `tolerance`.
    ///
    /// A sample matches when `|value - condition_val| < tolerance`. A NaN
    /// tolerance or a NaN `condition_val` therefore matches nothing.
    ///
    /// Returns `None` if either variable is unknown or no sample matches.
    pub fn conditional_mean_with_tolerance(
        &self,
        outcome: &str,
        condition_var: &str,
        condition_val: f64,
        tolerance: f64,
    ) -> Option<f64> {
        let out_idx = self.var_index(outcome)?;
        let cond_idx = self.var_index(condition_var)?;
        Self::mean_of_values(
            self.samples
                .iter()
                .filter(|s| (s[cond_idx] - condition_val).abs() < tolerance)
                .map(|s| s[out_idx]),
        )
    }

    /// Variable names, in column order.
    pub fn variables(&self) -> &[String] {
        &self.variables
    }

    /// All stored samples, in insertion order.
    pub fn samples(&self) -> &[Vec<f64>] {
        &self.samples
    }

    /// Arithmetic mean of `values` without allocating, or `None` when empty.
    fn mean_of_values(values: impl Iterator<Item = f64>) -> Option<f64> {
        let mut count = 0usize;
        // `inspect` counts while the standard `f64` `Sum` runs. This keeps the
        // summation order and result identical to summing a collected `Vec`.
        let sum: f64 = values.inspect(|_| count += 1).sum();
        if count == 0 {
            None
        } else {
            Some(sum / count as f64)
        }
    }
}
/// An intervention that fixes a named variable to a given value.
#[derive(Debug, Clone, PartialEq)]
pub struct Intervention {
    /// Name of the variable being intervened on.
    pub variable: String,
    /// Value the variable is set to.
    pub value: f64,
}
/// Result of a treatment-effect estimation.
#[derive(Debug, Clone, PartialEq)]
pub struct TreatmentEffect {
    /// Estimated average treatment effect.
    pub ate: f64,
    /// Effect estimate for the treated group (presumably the ATT; confirm with the estimator).
    pub ate_treated: f64,
    /// Effect estimate for the control group (presumably the ATC; confirm with the estimator).
    pub ate_control: f64,
    /// Name of the estimator that produced this result.
    pub estimator: String,
    /// Number of samples the estimate was computed from.
    pub n_samples: usize,
    /// Optional confidence interval, assumed to be `(lower, upper)`; `None` when not computed.
    pub confidence_interval: Option<(f64, f64)>,
}
/// A candidate adjustment set together with a validity flag.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BackdoorAdjustment {
    /// Names of the variables to adjust (condition) on.
    pub adjustment_set: Vec<String>,
    /// Whether the set is valid; presumably this means it satisfies the backdoor criterion (confirm with the producer).
    pub valid: bool,
}