use std::fmt;
/// A symbolic probability expression over named variables.
///
/// Expressions form a tree: [`ProbExpr::Joint`] terms are the leaves and every
/// other variant wraps or combines sub-expressions. Variable names are plain
/// strings; nothing in this type validates them.
///
/// `Eq` and `Hash` are derived alongside `PartialEq` (every field supports
/// them), so expressions can be stored in `HashSet`s or used as `HashMap` keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ProbExpr {
    /// Joint probability of the listed variables, e.g. `P(X, Y)`.
    Joint(Vec<String>),
    /// A `numerator` expression taken relative to a `denominator` expression.
    ///
    /// When both sides are [`ProbExpr::Joint`] terms this stands for the
    /// conditional probability of the numerator variables given the
    /// denominator variables.
    Conditional {
        /// Expression being conditioned (the outcome side).
        numerator: Box<ProbExpr>,
        /// Expression conditioned on.
        denominator: Box<ProbExpr>,
    },
    /// Summation of `expr` over the variables in `summand_vars`.
    Marginal {
        /// Expression under the summation sign.
        expr: Box<ProbExpr>,
        /// Variables that are summed out.
        summand_vars: Vec<String>,
    },
    /// `expr` evaluated under the intervention `do(do_vars)`.
    Interventional {
        /// Expression the intervention applies to.
        expr: Box<ProbExpr>,
        /// Variables that are intervened on.
        do_vars: Vec<String>,
    },
    /// Product of the contained factors, in order.
    Product(Vec<ProbExpr>),
    /// Ratio `num / den` of two arbitrary expressions.
    Quotient {
        /// Expression above the division bar.
        num: Box<ProbExpr>,
        /// Expression below the division bar.
        den: Box<ProbExpr>,
    },
}
/// Renders a [`ProbExpr`] in conventional probability notation.
///
/// * `Joint` — `P(a, b)`; an empty variable list prints as `P()`.
/// * `Conditional` of two `Joint`s — `P(y | z)`, or just `P(y)` when the
///   conditioning list is empty. Any other operands print as the bracketed
///   ratio `[numerator] / [denominator]`.
/// * `Marginal` — `Σ_{vars} expr`, or just `expr` when `summand_vars` is empty.
/// * `Interventional` — `P(y | do(x))` for a `Joint` inner expression,
///   `P(<expr> | do(x))` otherwise, or just `expr` when `do_vars` is empty.
/// * `Product` — factors joined by ` · `; the empty product prints as `1`.
/// * `Quotient` — `[num] / [den]`.
///
/// Formatter flags such as width or fill are not applied.
impl fmt::Display for ProbExpr {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // `join` on an empty list yields "", so an empty joint prints as `P()`.
            ProbExpr::Joint(vars) => write!(f, "P({})", vars.join(", ")),
            ProbExpr::Conditional {
                numerator,
                denominator,
            } => match (numerator.as_ref(), denominator.as_ref()) {
                (ProbExpr::Joint(num_vars), ProbExpr::Joint(den_vars)) => {
                    if den_vars.is_empty() {
                        // Nothing to condition on: avoid the malformed `P(Y | )`,
                        // mirroring how empty `Marginal`/`Interventional` lists
                        // collapse to the bare inner expression.
                        write!(f, "P({})", num_vars.join(", "))
                    } else {
                        write!(f, "P({} | {})", num_vars.join(", "), den_vars.join(", "))
                    }
                }
                // Non-joint operands cannot be merged into one `P(.. | ..)`.
                _ => write!(f, "[{numerator}] / [{denominator}]"),
            },
            ProbExpr::Marginal { expr, summand_vars } => {
                if summand_vars.is_empty() {
                    write!(f, "{expr}")
                } else {
                    write!(f, "Σ_{{{vars}}} {expr}", vars = summand_vars.join(", "))
                }
            }
            ProbExpr::Interventional { expr, do_vars } => {
                if do_vars.is_empty() {
                    return write!(f, "{expr}");
                }
                let interventions = do_vars.join(", ");
                match expr.as_ref() {
                    ProbExpr::Joint(outcome_vars) => {
                        write!(f, "P({} | do({}))", outcome_vars.join(", "), interventions)
                    }
                    _ => write!(f, "P({expr} | do({interventions}))"),
                }
            }
            // Factors are written straight to the formatter instead of being
            // collected into temporary strings first.
            // NOTE(review): factors are not bracketed, so a `Marginal` factor
            // followed by further factors renders the same as a `Marginal`
            // over the whole product.
            ProbExpr::Product(exprs) => match exprs.split_first() {
                None => f.write_str("1"),
                Some((first, rest)) => {
                    write!(f, "{first}")?;
                    for factor in rest {
                        write!(f, " · {factor}")?;
                    }
                    Ok(())
                }
            },
            ProbExpr::Quotient { num, den } => write!(f, "[{num}] / [{den}]"),
        }
    }
}
impl ProbExpr {
pub fn simplify(&self) -> ProbExpr {
match self {
ProbExpr::Marginal { expr, summand_vars } => {
let inner = expr.simplify();
if summand_vars.is_empty() {
inner
} else {
ProbExpr::Marginal {
expr: Box::new(inner),
summand_vars: summand_vars.clone(),
}
}
}
ProbExpr::Product(exprs) => {
let mut flat: Vec<ProbExpr> = Vec::new();
for e in exprs {
let s = e.simplify();
match s {
ProbExpr::Product(sub) => flat.extend(sub),
other => flat.push(other),
}
}
if flat.is_empty() {
ProbExpr::Joint(Vec::new())
} else if flat.len() == 1 {
flat.remove(0)
} else {
ProbExpr::Product(flat)
}
}
ProbExpr::Conditional {
numerator,
denominator,
} => ProbExpr::Conditional {
numerator: Box::new(numerator.simplify()),
denominator: Box::new(denominator.simplify()),
},
ProbExpr::Interventional { expr, do_vars } => {
let inner = expr.simplify();
if do_vars.is_empty() {
inner
} else {
ProbExpr::Interventional {
expr: Box::new(inner),
do_vars: do_vars.clone(),
}
}
}
ProbExpr::Quotient { num, den } => ProbExpr::Quotient {
num: Box::new(num.simplify()),
den: Box::new(den.simplify()),
},
ProbExpr::Joint(_) => self.clone(),
}
}
pub fn p_do(y_vars: Vec<String>, x_vars: Vec<String>) -> Self {
ProbExpr::Interventional {
expr: Box::new(ProbExpr::Joint(y_vars)),
do_vars: x_vars,
}
}
pub fn p(vars: Vec<String>) -> Self {
ProbExpr::Joint(vars)
}
pub fn marginal(expr: ProbExpr, summand_vars: Vec<String>) -> Self {
ProbExpr::Marginal {
expr: Box::new(expr),
summand_vars,
}
}
pub fn conditional(y_vars: Vec<String>, z_vars: Vec<String>) -> Self {
ProbExpr::Conditional {
numerator: Box::new(ProbExpr::Joint(y_vars)),
denominator: Box::new(ProbExpr::Joint(z_vars)),
}
}
pub fn product(exprs: Vec<ProbExpr>) -> Self {
ProbExpr::Product(exprs)
}
pub fn joint_vars(&self) -> Option<&[String]> {
match self {
ProbExpr::Joint(vars) => Some(vars),
_ => None,
}
}
}
/// Unit tests for `ProbExpr` rendering, simplification and accessors.
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an owned variable name.
    fn s(name: &str) -> String {
        name.to_owned()
    }

    #[test]
    fn test_display_joint() {
        let e = ProbExpr::p(vec![s("Y")]);
        assert_eq!(e.to_string(), "P(Y)");
    }

    #[test]
    fn test_display_joint_multi() {
        let e = ProbExpr::p(vec![s("X"), s("Y"), s("Z")]);
        assert_eq!(e.to_string(), "P(X, Y, Z)");
    }

    #[test]
    fn test_display_conditional() {
        let e = ProbExpr::conditional(vec![s("Y")], vec![s("X")]);
        assert_eq!(e.to_string(), "P(Y | X)");
    }

    #[test]
    fn test_conditional_display_complex() {
        let e = ProbExpr::conditional(vec![s("Y")], vec![s("X"), s("Z")]);
        assert_eq!(e.to_string(), "P(Y | X, Z)");
    }

    #[test]
    fn test_display_conditional_non_joint_falls_back_to_ratio() {
        // Operands that are not both `Joint` cannot be merged into `P(.. | ..)`.
        let e = ProbExpr::Conditional {
            numerator: Box::new(ProbExpr::marginal(
                ProbExpr::p(vec![s("Y"), s("Z")]),
                vec![s("Z")],
            )),
            denominator: Box::new(ProbExpr::p(vec![s("X")])),
        };
        assert_eq!(e.to_string(), "[Σ_{Z} P(Y, Z)] / [P(X)]");
    }

    #[test]
    fn test_display_interventional() {
        let e = ProbExpr::p_do(vec![s("Y")], vec![s("X")]);
        assert_eq!(e.to_string(), "P(Y | do(X))");
    }

    #[test]
    fn test_display_interventional_multiple_do() {
        let e = ProbExpr::p_do(vec![s("Y")], vec![s("X1"), s("X2")]);
        assert_eq!(e.to_string(), "P(Y | do(X1, X2))");
    }

    #[test]
    fn test_display_interventional_non_joint() {
        // A non-`Joint` inner expression is embedded verbatim inside `P(.. | do(..))`.
        let e = ProbExpr::Interventional {
            expr: Box::new(ProbExpr::marginal(
                ProbExpr::p(vec![s("Y"), s("Z")]),
                vec![s("Z")],
            )),
            do_vars: vec![s("X")],
        };
        assert_eq!(e.to_string(), "P(Σ_{Z} P(Y, Z) | do(X))");
    }

    #[test]
    fn test_display_marginal() {
        let inner = ProbExpr::p(vec![s("Y"), s("Z")]);
        let e = ProbExpr::marginal(inner, vec![s("Z")]);
        assert_eq!(e.to_string(), "Σ_{Z} P(Y, Z)");
    }

    #[test]
    fn test_display_product() {
        let e1 = ProbExpr::p(vec![s("X")]);
        let e2 = ProbExpr::p(vec![s("Y")]);
        let prod = ProbExpr::product(vec![e1, e2]);
        assert_eq!(prod.to_string(), "P(X) · P(Y)");
    }

    #[test]
    fn test_display_empty_product() {
        // The empty product renders as the multiplicative identity.
        let prod = ProbExpr::product(vec![]);
        assert_eq!(prod.to_string(), "1");
    }

    #[test]
    fn test_display_quotient() {
        let e = ProbExpr::Quotient {
            num: Box::new(ProbExpr::p(vec![s("Y"), s("X")])),
            den: Box::new(ProbExpr::p(vec![s("X")])),
        };
        assert_eq!(e.to_string(), "[P(Y, X)] / [P(X)]");
    }

    #[test]
    fn test_interventional_with_marginal() {
        let inner = ProbExpr::Interventional {
            expr: Box::new(ProbExpr::Joint(vec![s("Y")])),
            do_vars: vec![s("X")],
        };
        let marg = ProbExpr::marginal(inner, vec![s("M")]);
        let disp = marg.to_string();
        assert!(disp.contains("Σ_{M}"), "Should contain summation: {disp}");
        assert!(disp.contains("do(X)"), "Should contain do(X): {disp}");
        assert_eq!(disp, "Σ_{M} P(Y | do(X))");
    }

    #[test]
    fn test_simplify_empty_product() {
        let simplified = ProbExpr::product(vec![]).simplify();
        assert!(matches!(simplified, ProbExpr::Joint(ref v) if v.is_empty()));
    }

    #[test]
    fn test_simplify_trivial_marginal() {
        let inner = ProbExpr::p(vec![s("Y")]);
        let marginal = ProbExpr::marginal(inner.clone(), vec![]);
        assert_eq!(marginal.simplify(), inner);
    }

    #[test]
    fn test_simplify_trivial_interventional() {
        let inner = ProbExpr::p(vec![s("Y")]);
        let wrapped = ProbExpr::Interventional {
            expr: Box::new(inner.clone()),
            do_vars: vec![],
        };
        assert_eq!(wrapped.simplify(), inner);
    }

    #[test]
    fn test_simplify_nested_product_flattening() {
        let e1 = ProbExpr::p(vec![s("X")]);
        let e2 = ProbExpr::p(vec![s("Y")]);
        let e3 = ProbExpr::p(vec![s("Z")]);
        let inner_prod = ProbExpr::product(vec![e1.clone(), e2.clone()]);
        let outer_prod = ProbExpr::product(vec![inner_prod, e3.clone()]);
        // Check the factors and their order, not just how many there are.
        assert_eq!(outer_prod.simplify(), ProbExpr::product(vec![e1, e2, e3]));
    }

    #[test]
    fn test_simplify_single_element_product() {
        let e = ProbExpr::p(vec![s("Y")]);
        let prod = ProbExpr::product(vec![e.clone()]);
        assert_eq!(prod.simplify(), e);
    }

    #[test]
    fn test_joint_vars() {
        let joint = ProbExpr::p(vec![s("A"), s("B")]);
        assert_eq!(joint.joint_vars(), Some(&[s("A"), s("B")][..]));
        assert_eq!(ProbExpr::product(vec![joint]).joint_vars(), None);
    }
}