use std::collections::HashMap;
use std::path::PathBuf;
use feffit::params::{Parameters, parse};
/// Directory holding the reference fixtures: `tests/data` under the crate's
/// manifest directory (as reported by Cargo at compile time).
fn data_dir() -> PathBuf {
    let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    dir.push("tests/data");
    dir
}
/// Reads the fixture file `name` from the test data directory into a `String`.
///
/// # Panics
///
/// Panics if the file cannot be read. The message includes the full path, so a
/// missing or misnamed fixture is immediately identifiable in test output.
fn read(name: &str) -> String {
    let path = data_dir().join(name);
    std::fs::read_to_string(&path)
        .unwrap_or_else(|e| panic!("failed to read fixture {}: {e}", path.display()))
}
/// Evaluates every `#expr` case in `ref_params_expr.txt` and compares it with the
/// reference value stored alongside it (the test name indicates asteval output).
///
/// Fixture format:
/// - `#sym name value` defines a symbol. The whole file is scanned for these
///   first, so they may appear before or after the expressions that use them.
/// - `#expr want :: expression` gives an expression and its expected value.
///
/// A case passes when its relative error is below 1e-12 or its absolute error is
/// below 1e-15. The fixture must contain at least 20 cases.
#[test]
fn expressions_match_asteval() {
    let text = read("ref_params_expr.txt");

    // Symbol table gathered up front; a later `#sym` with the same name wins.
    let sym: HashMap<String, f64> = text
        .lines()
        .filter_map(|line| line.strip_prefix("#sym "))
        .map(|rest| {
            let mut fields = rest.split_whitespace();
            let name = fields.next().unwrap().to_string();
            let val: f64 = fields.next().unwrap().parse().unwrap();
            (name, val)
        })
        .collect();

    let mut checked = 0;
    for rest in text.lines().filter_map(|line| line.strip_prefix("#expr ")) {
        let (want_str, expr) = rest.split_once(" :: ").expect("malformed #expr line");
        let want: f64 = want_str.trim().parse().unwrap();
        let tree = parse(expr.trim()).unwrap_or_else(|e| panic!("parse '{expr}': {e}"));
        let got = tree
            .eval(&sym)
            .unwrap_or_else(|e| panic!("eval '{expr}': {e}"));

        let abs_err = (got - want).abs();
        // Floor the denominator so a zero reference value does not divide by zero.
        // Such cases then effectively rely on the absolute tolerance below.
        let rel = abs_err / want.abs().max(1e-300);
        println!("{expr:>32} got={got:.12e} want={want:.12e} rel={rel:.2e}");
        assert!(
            rel < 1e-12 || abs_err < 1e-15,
            "expr '{expr}': got {got}, want {want} (rel {rel:.2e})"
        );
        checked += 1;
    }
    assert!(checked >= 20, "expected at least 20 expression cases");
}
/// Builds a parameter set from `ref_params_resolve.txt`, resolves its constraint
/// expressions, and checks every `#expect` value against the stored reference
/// results (the test name indicates lmfit) to an absolute tolerance of 1e-12.
///
/// Recognized directives (one per line; all other lines are ignored):
/// `#const name val`, `#var name val`, `#varb name val min max`,
/// `#fix name val`, `#expr name expression...`, `#expect name val`.
///
/// The fixture must contain at least one `#expect` line. Without that check, a
/// truncated or reformatted fixture would make the test pass vacuously.
#[test]
fn constraints_match_lmfit() {
    // Splits a directive body into its leading name and the next `n` numeric
    // fields. Any further tokens are ignored. Panics, quoting the line, on
    // malformed input.
    fn name_and_values(rest: &str, n: usize) -> (&str, Vec<f64>) {
        let mut tokens = rest.split_whitespace();
        let name = tokens
            .next()
            .unwrap_or_else(|| panic!("missing parameter name in '{rest}'"));
        let values: Vec<f64> = (0..n)
            .map(|i| {
                let tok = tokens
                    .next()
                    .unwrap_or_else(|| panic!("missing numeric field {i} in '{rest}'"));
                tok.parse::<f64>()
                    .unwrap_or_else(|e| panic!("bad number '{tok}' in '{rest}': {e}"))
            })
            .collect();
        (name, values)
    }

    let text = read("ref_params_resolve.txt");
    let mut p = Parameters::new();
    let mut expect: Vec<(String, f64)> = Vec::new();
    for line in text.lines() {
        if let Some(rest) = line.strip_prefix("#const ") {
            let (name, v) = name_and_values(rest, 1);
            p.set_const(name, v[0]);
        } else if let Some(rest) = line.strip_prefix("#var ") {
            let (name, v) = name_and_values(rest, 1);
            p.add_var(name, v[0]);
        } else if let Some(rest) = line.strip_prefix("#varb ") {
            // value, min, max
            let (name, v) = name_and_values(rest, 3);
            p.add_var_bounded(name, v[0], v[1], v[2]);
        } else if let Some(rest) = line.strip_prefix("#fix ") {
            let (name, v) = name_and_values(rest, 1);
            p.add_fixed(name, v[0]);
        } else if let Some(rest) = line.strip_prefix("#expr ") {
            // The expression itself may contain spaces, so split only at the first one.
            let (name, expr) = rest
                .split_once(char::is_whitespace)
                .unwrap_or_else(|| panic!("malformed #expr line: '{line}'"));
            p.add_expr(name, expr.trim());
        } else if let Some(rest) = line.strip_prefix("#expect ") {
            let (name, val) = rest
                .split_once(char::is_whitespace)
                .unwrap_or_else(|| panic!("malformed #expect line: '{line}'"));
            let want: f64 = val
                .trim()
                .parse()
                .unwrap_or_else(|e| panic!("bad #expect value in '{line}': {e}"));
            expect.push((name.to_string(), want));
        }
    }
    assert!(
        !expect.is_empty(),
        "no #expect lines found in ref_params_resolve.txt"
    );

    p.update_constraints()
        .expect("constraint resolution failed");
    for (name, want) in &expect {
        let got = p
            .value(name)
            .unwrap_or_else(|| panic!("missing param {name}"));
        let d = (got - want).abs();
        println!("{name:>10} got={got:.12e} want={want:.12e} |Δ|={d:.2e}");
        assert!(
            d < 1e-12,
            "param '{name}': got {got}, want {want} (|Δ| {d:.2e})"
        );
    }
    assert_eq!(
        p.n_vary(),
        4,
        "amp, alpha, s02, tight are the free variables"
    );
}
/// Checks the gradients returned by `Parameters::value_grads` against central
/// finite differences.
///
/// Three free variables (`x`, `y`, `z`) feed four constrained expressions. For
/// each target the test first confirms that the value reported alongside the
/// gradient equals the resolved parameter value. It then compares each partial
/// derivative with `(f(x0 + h·e_i) - f(x0 - h·e_i)) / (2h)`, accepting a
/// relative error below 1e-6 or an absolute error below 1e-8.
#[test]
fn gradients_match_finite_difference() {
    // Free variables and the point at which gradients are evaluated. `build` seeds
    // the parameters from these same arrays, so the finite-difference stencil is
    // always centred on the point the AD gradients were computed at.
    let var_names = ["x", "y", "z"];
    let x0 = [0.7_f64, 1.3, 0.4];
    let targets = ["u", "v", "w", "pp"];
    let h = 1e-6;

    let build = || {
        let mut p = Parameters::new();
        for (name, &val) in var_names.iter().zip(x0.iter()) {
            p.add_var(*name, val);
        }
        p.add_expr("u", "sin(x) * exp(y)");
        p.add_expr("v", "x**y + atan2(z, x)");
        p.add_expr("w", "u * v + log(x + 2.0)");
        p.add_expr("pp", "sqrt(x*x + y*y)");
        p
    };

    let mut p = build();
    p.update_constraints().unwrap();
    let vg = p.value_grads().unwrap();
    for t in &targets {
        let (val, _) = &vg[*t];
        let want = p.value(t).unwrap();
        assert!((val - want).abs() < 1e-12, "{t}: AD value {val} vs {want}");
    }

    // Rebuilds the parameter set at `xs` and returns the resolved value of `t`.
    // Assumes `set_var_values` takes values in the order the variables were
    // added (i.e. `var_names` order) — the gradient indexing below assumes the
    // same.
    let fd_value = |xs: &[f64], t: &str| -> f64 {
        let mut q = build();
        q.set_var_values(xs);
        q.update_constraints().unwrap();
        q.value(t).unwrap()
    };

    for t in &targets {
        let (_, g) = &vg[*t];
        for (i, vn) in var_names.iter().enumerate() {
            let mut xp = x0;
            xp[i] += h;
            let mut xm = x0;
            xm[i] -= h;
            let fd = (fd_value(&xp, t) - fd_value(&xm, t)) / (2.0 * h);
            let d = (g[i] - fd).abs();
            // Floor the denominator so near-zero derivatives fall back to the
            // absolute tolerance instead of blowing up the relative error.
            let rel = d / fd.abs().max(1e-12);
            println!("d{t}/d{vn} ad={:.10e} fd={:.10e} rel={rel:.2e}", g[i], fd);
            assert!(rel < 1e-6 || d < 1e-8, "d{t}/d{vn}: ad {} vs fd {fd}", g[i]);
        }
    }
}
/// Two expressions that depend on each other (`a` uses `b`, `b` uses `a`) must
/// make constraint resolution fail rather than loop or pick an arbitrary order.
#[test]
fn cycle_is_detected() {
    let mut params = Parameters::new();
    for (name, expr) in [("a", "b + 1"), ("b", "a * 2")] {
        params.add_expr(name, expr);
    }
    let outcome = params.update_constraints();
    assert!(outcome.is_err(), "expected a cycle error");
}