feffit 0.1.0

Pure-Rust EXAFS toolkit — data reduction (pre-edge/normalize/AUTOBK), Fourier transforms, FEFF path fitting (feffit), and feff.inp build/run; a port of larch.xafs
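//! Parity tests for `params` against `asteval` (expression evaluation) and
//! `lmfit.Parameters` (constraint resolution), using reference values generated
//! by `scripts/ref_params.py`. Also checks the forward-mode gradients from
//! `value_grads()` against central finite differences, and checks that circular
//! constraint expressions are rejected.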

use std::collections::HashMap;
use std::path::PathBuf;

use feffit::params::{Parameters, parse};

/// Directory containing the reference fixtures, resolved relative to this
/// crate's manifest so the tests work from any working directory.
fn data_dir() -> PathBuf {
    let manifest_dir = env!("CARGO_MANIFEST_DIR");
    let mut dir = PathBuf::from(manifest_dir);
    dir.push("tests/data");
    dir
}

/// Read the reference fixture `name` from [`data_dir`] into a string.
///
/// # Panics
///
/// Panics if the file is missing or unreadable. The message names the full
/// path and the I/O error, and points at `scripts/ref_params.py`, which
/// generates these references.
fn read(name: &str) -> String {
    let path = data_dir().join(name);
    std::fs::read_to_string(&path).unwrap_or_else(|e| {
        panic!(
            "cannot read reference file {}: {e} (regenerate with scripts/ref_params.py)",
            path.display()
        )
    })
}

/// Evaluate every `#expr <value> :: <expression>` case from
/// `ref_params_expr.txt` with the `#sym <name> <value>` bindings from the same
/// file, and require agreement with the value `asteval` recorded.
#[test]
fn expressions_match_asteval() {
    let text = read("ref_params_expr.txt");

    // Gather every binding up front so an expression may use a symbol that is
    // declared later in the file. A repeated name keeps its last value.
    let sym: HashMap<String, f64> = text
        .lines()
        .filter_map(|line| line.strip_prefix("#sym "))
        .map(|rest| {
            let mut fields = rest.split_whitespace();
            let name = fields.next().unwrap().to_string();
            let val: f64 = fields.next().unwrap().parse().unwrap();
            (name, val)
        })
        .collect();

    let mut n_checked = 0;
    for rest in text.lines().filter_map(|line| line.strip_prefix("#expr ")) {
        let (val_str, expr) = rest.split_once(" :: ").expect("malformed #expr line");
        let want: f64 = val_str.trim().parse().unwrap();
        let ast = parse(expr.trim()).unwrap_or_else(|e| panic!("parse '{expr}': {e}"));
        let got = ast
            .eval(&sym)
            .unwrap_or_else(|e| panic!("eval '{expr}': {e}"));

        // Relative tolerance for ordinary values. The absolute fallback covers
        // references at or near zero, where the relative error blows up.
        let abs_err = (got - want).abs();
        let rel = abs_err / want.abs().max(1e-300);
        println!("{expr:>32}  got={got:.12e} want={want:.12e} rel={rel:.2e}");
        assert!(
            rel < 1e-12 || abs_err < 1e-15,
            "expr '{expr}': got {got}, want {want} (rel {rel:.2e})"
        );
        n_checked += 1;
    }

    // Catch a truncated or reformatted fixture that would otherwise pass trivially.
    assert!(n_checked >= 20, "expected at least 20 expression cases");
}

/// Build a parameter set from `ref_params_resolve.txt`, resolve its
/// constraints, and compare each `#expect` value with the `lmfit.Parameters`
/// result.
///
/// Recognised directives (one per line; other lines are ignored):
/// `#const name val`, `#var name val`, `#varb name val min max`,
/// `#fix name val`, `#expr name <expression to end of line>`,
/// `#expect name val`.
#[test]
fn constraints_match_lmfit() {
    let text = read("ref_params_resolve.txt");
    let mut p = Parameters::new();
    let mut expect: Vec<(String, f64)> = Vec::new();

    for line in text.lines() {
        if let Some(rest) = line.strip_prefix("#const ") {
            let mut it = rest.split_whitespace();
            let name = it.next().unwrap();
            let val: f64 = it.next().unwrap().parse().unwrap();
            p.set_const(name, val);
        } else if let Some(rest) = line.strip_prefix("#var ") {
            let mut it = rest.split_whitespace();
            let name = it.next().unwrap();
            let val: f64 = it.next().unwrap().parse().unwrap();
            p.add_var(name, val);
        } else if let Some(rest) = line.strip_prefix("#varb ") {
            let mut it = rest.split_whitespace();
            let name = it.next().unwrap();
            let val: f64 = it.next().unwrap().parse().unwrap();
            let min: f64 = it.next().unwrap().parse().unwrap();
            let max: f64 = it.next().unwrap().parse().unwrap();
            p.add_var_bounded(name, val, min, max);
        } else if let Some(rest) = line.strip_prefix("#fix ") {
            let mut it = rest.split_whitespace();
            let name = it.next().unwrap();
            let val: f64 = it.next().unwrap().parse().unwrap();
            p.add_fixed(name, val);
        } else if let Some(rest) = line.strip_prefix("#expr ") {
            // "name <expr to end of line>": the expression itself may contain spaces.
            let (name, expr) = rest.split_once(char::is_whitespace).unwrap();
            p.add_expr(name, expr.trim());
        } else if let Some(rest) = line.strip_prefix("#expect ") {
            let (name, val) = rest.split_once(char::is_whitespace).unwrap();
            expect.push((name.to_string(), val.trim().parse().unwrap()));
        }
    }

    // Without any `#expect` lines the comparison loop below checks nothing,
    // so the test would pass vacuously on a broken or reformatted fixture.
    assert!(
        !expect.is_empty(),
        "no #expect lines found in ref_params_resolve.txt"
    );

    p.update_constraints()
        .expect("constraint resolution failed");

    for (name, want) in &expect {
        let got = p
            .value(name)
            .unwrap_or_else(|| panic!("missing param {name}"));
        let d = (got - want).abs();
        println!("{name:>10} got={got:.12e} want={want:.12e} |Δ|={d:.2e}");
        assert!(
            d < 1e-12,
            "param '{name}': got {got}, want {want} (|Δ| {d:.2e})"
        );
    }
    assert_eq!(
        p.n_vary(),
        4,
        "amp, alpha, s02, tight are the free variables"
    );
}

#[test]
fn gradients_match_finite_difference() {
    // `value_grads()` (forward-mode AD via `eval_dual`/`call_func_dual`) against
    // central finite differences — an independent oracle. Exercises nonlinear,
    // multi-variable, chained constraints (`w` depends on `u` and `v`), covering
    // the AD paths the linear feffit fit never touches (sin/exp/pow/atan2/sqrt/
    // log + the chain rule). The end-to-end feffit parity test only has linear,
    // single-variable path expressions, so this is where the AD math is checked.
    //
    // Free variables and their starting values are declared once. `build()` and
    // the finite-difference base point `x0` both derive from this table, so
    // they cannot drift apart.
    const VARS: [(&str, f64); 3] = [("x", 0.7), ("y", 1.3), ("z", 0.4)];

    let build = || {
        let mut p = Parameters::new();
        for &(name, val) in &VARS {
            p.add_var(name, val);
        }
        p.add_expr("u", "sin(x) * exp(y)");
        p.add_expr("v", "x**y + atan2(z, x)");
        p.add_expr("w", "u * v + log(x + 2.0)"); // chained: depends on u and v
        p.add_expr("pp", "sqrt(x*x + y*y)");
        p
    };

    let targets = ["u", "v", "w", "pp"];
    let var_names = VARS.map(|(name, _)| name);
    let x0 = VARS.map(|(_, val)| val);
    // Central-difference step. Truncation error scales like h², round-off like eps/h.
    let h = 1e-6;

    let mut p = build();
    p.update_constraints().unwrap();
    let vg = p.value_grads().unwrap();

    // AD value must agree with the resolved constraint value
    for t in &targets {
        let (val, _) = &vg[*t];
        let want = p.value(t).unwrap();
        assert!((val - want).abs() < 1e-12, "{t}: AD value {val} vs {want}");
    }

    // Re-resolve a fresh parameter set at the given variable values.
    let fd_value = |xs: &[f64], t: &str| -> f64 {
        let mut q = build();
        q.set_var_values(xs);
        q.update_constraints().unwrap();
        q.value(t).unwrap()
    };

    for t in &targets {
        let (_, g) = &vg[*t];
        // Below, `g[i]` is treated as d`t`/d`var_names[i]`. Check the length
        // first so a mismatch fails clearly instead of panicking on the index
        // or silently ignoring extra components.
        assert_eq!(g.len(), var_names.len(), "{t}: gradient length");
        for (i, vn) in var_names.iter().enumerate() {
            let mut xp = x0;
            xp[i] += h;
            let mut xm = x0;
            xm[i] -= h;
            let fd = (fd_value(&xp, t) - fd_value(&xm, t)) / (2.0 * h);
            let d = (g[i] - fd).abs();
            let rel = d / fd.abs().max(1e-12);
            println!("d{t}/d{vn} ad={:.10e} fd={:.10e} rel={rel:.2e}", g[i], fd);
            assert!(rel < 1e-6 || d < 1e-8, "d{t}/d{vn}: ad {} vs fd {fd}", g[i]);
        }
    }
}

/// `a` depends on `b` and `b` depends on `a`, so no resolution order exists.
/// `update_constraints` must report an error for such a set.
#[test]
fn cycle_is_detected() {
    let mut params = Parameters::new();
    for (name, expr) in [("a", "b + 1"), ("b", "a * 2")] {
        params.add_expr(name, expr);
    }
    let outcome = params.update_constraints();
    assert!(outcome.is_err(), "expected a cycle error");
}