// lmm 0.1.2
//
// A language-agnostic framework for emulating reality.
// Documentation
use crate::equation::Expression;
use crate::tensor::Tensor;
use std::collections::HashMap;

pub fn symbolic_diff(expr: &Expression, var: &str) -> Expression {
    expr.symbolic_diff(var)
}

pub fn simplify(expr: &Expression) -> Expression {
    expr.simplify()
}

pub fn complexity_score(expr: &Expression) -> usize {
    expr.complexity()
}

pub fn format_expr(expr: &Expression) -> String {
    expr.to_string()
}

/// Approximates the partial derivative of `expr` with respect to `var` using a
/// central difference.
///
/// The estimate is `(f(x + h) - f(x - h)) / (2h)`. Here `x` is the value bound to
/// `var` in `data`, and every other binding is held fixed.
///
/// # Arguments
/// * `expr` - the expression to differentiate.
/// * `var` - name of the variable to perturb; it must have a binding in `data`.
/// * `data` - `(name, value)` bindings; if a name repeats, the last pair wins.
/// * `h` - step size; must be finite and non-zero.
///
/// # Returns
/// `None` in any of these cases:
/// * `h` is zero or not finite;
/// * `var` has no binding in `data`;
/// * either evaluation of `expr` returns an error.
///
/// Otherwise returns `Some` holding the central-difference estimate.
pub fn numerical_gradient(
    expr: &Expression,
    var: &str,
    data: &[(&str, f64)],
    h: f64,
) -> Option<f64> {
    // A zero step divides 0 by 0 (NaN), and a non-finite step cannot yield a meaningful
    // estimate. Report both as failure instead of returning Some(garbage).
    if !h.is_finite() || h == 0.0 {
        return None;
    }

    let mut bindings: HashMap<String, f64> =
        data.iter().map(|&(k, v)| (k.to_string(), v)).collect();
    let x = *bindings.get(var)?;

    // Evaluate on both sides of x. Rebinding `var` overwrites the existing entry.
    bindings.insert(var.to_string(), x + h);
    let f_plus = expr.evaluate(&bindings).ok()?;
    bindings.insert(var.to_string(), x - h);
    let f_minus = expr.evaluate(&bindings).ok()?;

    Some((f_plus - f_minus) / (2.0 * h))
}

/// Evaluates the Jacobian matrix of `exprs` with respect to `vars` at `point`.
///
/// The result has one row per expression and one column per variable. Entry
/// `(i, j)` is `exprs[i]` symbolically differentiated by `vars[j]`, then
/// simplified and evaluated at `point`.
///
/// Binding rules:
/// * Variables are bound positionally: `vars[k]` takes the value `point.data[k]`.
/// * When the two lengths differ, the surplus items of the longer one are ignored.
/// * When a name repeats in `vars`, its later binding wins.
///
/// An entry whose evaluation returns an error is reported as `0.0` rather than
/// being propagated.
pub fn jacobian(exprs: &[Expression], vars: &[&str], point: &Tensor) -> Vec<Vec<f64>> {
    let mut bindings: HashMap<String, f64> = HashMap::new();
    for (&name, value) in vars.iter().zip(point.data.iter()) {
        bindings.insert(String::from(name), *value);
    }

    let mut matrix: Vec<Vec<f64>> = Vec::with_capacity(exprs.len());
    for expr in exprs {
        let mut row: Vec<f64> = Vec::with_capacity(vars.len());
        for &var in vars {
            let partial = expr.symbolic_diff(var).simplify();
            row.push(partial.evaluate(&bindings).unwrap_or(0.0));
        }
        matrix.push(row);
    }
    matrix
}

/// Composes two expressions, producing `outer` with `inner` plugged in for `var`.
///
/// Every occurrence of the variable `var` inside `outer` is replaced by a copy of
/// `inner`. Both inputs are borrowed and left unmodified; a new tree is returned.
pub fn compose(outer: &Expression, inner: &Expression, var: &str) -> Expression {
    substitute(outer, var, inner)
}

/// Rebuilds `expr`, swapping each `Variable` named `var` for a clone of `replacement`.
///
/// Constants and variables with other names are cloned unchanged. Every compound
/// node is reconstructed from its recursively substituted children. `replacement`
/// itself is inserted verbatim and never searched, so a replacement that mentions
/// `var` does not trigger further substitution.
fn substitute(expr: &Expression, var: &str, replacement: &Expression) -> Expression {
    // Substitute inside one child and box the result so it can slot into a parent node.
    let rec = |child: &Expression| Box::new(substitute(child, var, replacement));

    match expr {
        Expression::Variable(name) if name == var => replacement.clone(),
        Expression::Variable(_) | Expression::Constant(_) => expr.clone(),

        // Binary nodes: rebuild from both substituted operands.
        Expression::Add(lhs, rhs) => Expression::Add(rec(lhs), rec(rhs)),
        Expression::Sub(lhs, rhs) => Expression::Sub(rec(lhs), rec(rhs)),
        Expression::Mul(lhs, rhs) => Expression::Mul(rec(lhs), rec(rhs)),
        Expression::Div(lhs, rhs) => Expression::Div(rec(lhs), rec(rhs)),
        Expression::Pow(base, exponent) => Expression::Pow(rec(base), rec(exponent)),

        // Unary nodes: rebuild around the single substituted operand.
        Expression::Neg(operand) => Expression::Neg(rec(operand)),
        Expression::Abs(operand) => Expression::Abs(rec(operand)),
        Expression::Sin(operand) => Expression::Sin(rec(operand)),
        Expression::Cos(operand) => Expression::Cos(rec(operand)),
        Expression::Exp(operand) => Expression::Exp(rec(operand)),
        Expression::Log(operand) => Expression::Log(rec(operand)),
    }
}