use std::collections::HashMap;
use crate::feffdat::{FeffDatFile, PathParams, gnxas, sigma2_debye, sigma2_eins};
use crate::lm::{LmConfig, lmdif};
use crate::params::{Expr, ExprError, FuncCtx, ParamError, Parameters, parse};
use crate::dataset::DataSet;
/// How a single path parameter is specified: a fixed number or an
/// expression given as source text.
#[derive(Debug, Clone)]
pub enum Spec {
    /// A literal value used as-is.
    Const(f64),
    /// Expression source text; it is parsed when the spec is compiled.
    Expr(String),
}

/// One [`Spec`] for each of the eight per-path parameters.
#[derive(Debug, Clone)]
pub struct PathSpec {
    /// Specification of the `degen` parameter.
    pub degen: Spec,
    /// Specification of the `s02` parameter.
    pub s02: Spec,
    /// Specification of the `e0` parameter.
    pub e0: Spec,
    /// Specification of the `ei` parameter.
    pub ei: Spec,
    /// Specification of the `deltar` parameter.
    pub deltar: Spec,
    /// Specification of the `sigma2` parameter.
    pub sigma2: Spec,
    /// Specification of the `third` parameter.
    pub third: Spec,
    /// Specification of the `fourth` parameter.
    pub fourth: Spec,
}

impl PathSpec {
    /// Builds an all-constant spec: `degen` is `file_degen`, `s02` is `1.0`,
    /// and every remaining parameter is `0.0`.
    pub fn defaults(file_degen: f64) -> Self {
        let zero = || Spec::Const(0.0);
        Self {
            degen: Spec::Const(file_degen),
            s02: Spec::Const(1.0),
            e0: zero(),
            ei: zero(),
            deltar: zero(),
            sigma2: zero(),
            third: zero(),
            fourth: zero(),
        }
    }
}
/// A dataset bundled with the per-path specifications used when fitting it.
pub struct FitDataSet {
/// The dataset being fitted; `feffit` updates its path parameters in place.
pub dataset: DataSet,
/// One `PathSpec` per entry of `dataset.paths`; `feffit` returns
/// `FitError::Shape` when the two lengths differ.
pub specs: Vec<PathSpec>,
/// Forwarded unchanged to `DataSet::prepare_fit` before fitting.
pub epsilon_k: Option<f64>,
}
/// A named value with its estimated standard error. `feffit` uses this for
/// both varied parameters (`best`) and expression parameters (`derived`).
#[derive(Debug, Clone)]
pub struct Best {
/// Parameter name.
pub name: String,
/// Value after the fit.
pub value: f64,
/// Standard error; `NaN` when no covariance matrix was available.
pub stderr: f64,
}
/// Names of the eight per-path parameters. The order matters: `feffit` uses
/// the index into the array returned by `CompiledPathSpec::eval_dual` to look
/// up the name here, so both must list the parameters in the same order.
pub const PATH_PNAMES: [&str; 8] = [
"degen", "s02", "e0", "ei", "deltar", "sigma2", "third", "fourth",
];
/// The fitted value and propagated standard error of one path parameter.
#[derive(Debug, Clone)]
pub struct PathParam {
/// Index of the dataset in the slice passed to `feffit`.
pub dataset: usize,
/// Index of the path within that dataset's `paths`.
pub path: usize,
/// Parameter name, taken from `PATH_PNAMES`.
pub name: String,
/// Evaluated parameter value at the fit solution.
pub value: f64,
/// Standard error propagated through the covariance matrix; `NaN` when no
/// covariance matrix was available.
pub stderr: f64,
}
/// Everything `feffit` reports about a completed fit.
#[derive(Debug, Clone)]
pub struct FeffitResult {
/// One entry per varied parameter, in `Parameters::var_names` order.
pub best: Vec<Best>,
/// One entry per name returned by `Parameters::expr_names`.
pub derived: Vec<Best>,
/// Eight entries (in `PATH_PNAMES` order) for every path of every dataset.
pub path_params: Vec<PathParam>,
/// `nvarys` x `nvarys` covariance matrix, already multiplied by the error
/// scale; `None` when the minimizer result provided no covariance.
pub covar: Option<Vec<Vec<f64>>>,
/// Number of varied parameters.
pub nvarys: usize,
/// `ndata - nvarys`, saturating at zero.
pub nfree: usize,
/// Length of the concatenated residual vector.
pub ndata: usize,
/// Sum of `n_idp()` over all datasets.
pub n_idp: f64,
/// `nfev` copied from the minimizer result.
pub nfev: i32,
/// `info` copied from the minimizer result.
pub info: i32,
/// Squared residual norm multiplied by `n_idp / ndata`.
pub chi_square: f64,
/// `chi_square / (n_idp - nvarys)`.
pub chi2_reduced: f64,
/// Squared residual norm divided by the sum of squares of
/// `DataSet::residual(true)` over all datasets.
pub rfactor: f64,
/// `n_idp * ln(chi_square / n_idp) + 2 * nvarys`.
pub aic: f64,
/// `n_idp * ln(chi_square / n_idp) + ln(n_idp) * nvarys`.
pub bic: f64,
}
/// Errors returned by `feffit` and its helpers.
#[derive(Debug)]
pub enum FitError {
/// An error reported by the parameter set (see `From<ParamError>`).
Param(ParamError),
/// A path-parameter expression failed to parse or evaluate.
Expr(ExprError),
/// The inputs are inconsistent; the message describes the mismatch
/// (e.g. a dataset whose path and spec counts differ).
Shape(String),
}
impl std::fmt::Display for FitError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FitError::Param(e) => write!(f, "{e}"),
FitError::Expr(e) => write!(f, "path parameter expression: {e}"),
FitError::Shape(s) => write!(f, "{s}"),
}
}
}
// Marker impl: uses the derived `Debug` and the `Display` impl for `FitError`;
// `source()` is not overridden, so wrapped errors are not exposed as a cause.
impl std::error::Error for FitError {}
impl From<ParamError> for FitError {
fn from(e: ParamError) -> Self {
FitError::Param(e)
}
}
impl From<ExprError> for FitError {
fn from(e: ExprError) -> Self {
FitError::Expr(e)
}
}
/// A `Spec` after compilation: constants are kept as-is and expression text
/// has been parsed into an `Expr` so it can be evaluated repeatedly.
enum CompiledSpec {
/// A literal value.
Const(f64),
/// A parsed expression.
Expr(Expr),
}
/// Compiled counterpart of `PathSpec`: one `CompiledSpec` per path parameter,
/// with the same field names.
struct CompiledPathSpec {
degen: CompiledSpec,
s02: CompiledSpec,
e0: CompiledSpec,
ei: CompiledSpec,
deltar: CompiledSpec,
sigma2: CompiledSpec,
third: CompiledSpec,
fourth: CompiledSpec,
}
/// Turns a [`Spec`] into a [`CompiledSpec`], parsing expression text.
///
/// # Errors
/// Propagates the parser's error when an expression fails to parse.
fn compile(spec: &Spec) -> Result<CompiledSpec, ExprError> {
    match spec {
        Spec::Expr(src) => {
            let parsed = parse(src)?;
            Ok(CompiledSpec::Expr(parsed))
        }
        Spec::Const(value) => Ok(CompiledSpec::Const(*value)),
    }
}
impl CompiledPathSpec {
    /// Compiles every field of `spec`, in declaration order, stopping at the
    /// first expression that fails to parse.
    fn from(spec: &PathSpec) -> Result<Self, ExprError> {
        let degen = compile(&spec.degen)?;
        let s02 = compile(&spec.s02)?;
        let e0 = compile(&spec.e0)?;
        let ei = compile(&spec.ei)?;
        let deltar = compile(&spec.deltar)?;
        let sigma2 = compile(&spec.sigma2)?;
        let third = compile(&spec.third)?;
        let fourth = compile(&spec.fourth)?;
        Ok(Self {
            degen,
            s02,
            e0,
            ei,
            deltar,
            sigma2,
            third,
            fourth,
        })
    }

    /// Evaluates all eight parameters for one path.
    ///
    /// `base` supplies the global symbols; the per-path symbols derived from
    /// `fdat` are layered on top and the path-specific functions are made
    /// available through a `PathFuncCtx`.
    ///
    /// # Errors
    /// Returns the first expression-evaluation error, in field order.
    fn eval(
        &self,
        base: &HashMap<String, f64>,
        fdat: &FeffDatFile,
    ) -> Result<PathParams, ExprError> {
        let symbols = path_symbols(base, fdat);
        let funcs = PathFuncCtx { fdat };
        let value_of = |spec: &CompiledSpec| -> Result<f64, ExprError> {
            match spec {
                CompiledSpec::Expr(expr) => expr.eval_ctx(&symbols, &funcs),
                CompiledSpec::Const(value) => Ok(*value),
            }
        };
        let degen = value_of(&self.degen)?;
        let s02 = value_of(&self.s02)?;
        let e0 = value_of(&self.e0)?;
        let ei = value_of(&self.ei)?;
        let deltar = value_of(&self.deltar)?;
        let sigma2 = value_of(&self.sigma2)?;
        let third = value_of(&self.third)?;
        let fourth = value_of(&self.fourth)?;
        Ok(PathParams {
            degen,
            s02,
            e0,
            ei,
            deltar,
            sigma2,
            third,
            fourth,
        })
    }

    /// Like [`Self::eval`], but each parameter is returned as a
    /// `(value, gradient)` pair, where the gradient has `nvar` entries.
    /// Constants get an all-zero gradient. The array is in the same order as
    /// `PATH_PNAMES`.
    ///
    /// # Errors
    /// Returns the first expression-evaluation error, in field order.
    fn eval_dual(
        &self,
        base: &HashMap<String, f64>,
        grads: &HashMap<String, Vec<f64>>,
        nvar: usize,
        fdat: &FeffDatFile,
    ) -> Result<[(f64, Vec<f64>); 8], ExprError> {
        let symbols = path_symbols(base, fdat);
        let funcs = PathFuncCtx { fdat };
        let dual_of = |spec: &CompiledSpec| -> Result<(f64, Vec<f64>), ExprError> {
            match spec {
                CompiledSpec::Expr(expr) => expr.eval_dual_ctx(&symbols, grads, nvar, &funcs),
                CompiledSpec::Const(value) => Ok((*value, vec![0.0; nvar])),
            }
        };
        let degen = dual_of(&self.degen)?;
        let s02 = dual_of(&self.s02)?;
        let e0 = dual_of(&self.e0)?;
        let ei = dual_of(&self.ei)?;
        let deltar = dual_of(&self.deltar)?;
        let sigma2 = dual_of(&self.sigma2)?;
        let third = dual_of(&self.third)?;
        let fourth = dual_of(&self.fourth)?;
        Ok([degen, s02, e0, ei, deltar, sigma2, third, fourth])
    }
}
/// Returns a copy of `base` extended with values read from `fdat` under fixed
/// names (`reff`, `nleg`, `degen`, `rmass`, `rnorman`, `gam_ch`, `rs_int`,
/// `vint`, `vmu`, `vfermi`). An entry of `base` with one of these names is
/// overwritten.
fn path_symbols(base: &HashMap<String, f64>, fdat: &FeffDatFile) -> HashMap<String, f64> {
    let per_path = [
        ("reff", fdat.reff),
        ("nleg", fdat.nleg as f64),
        ("degen", fdat.degen),
        ("rmass", fdat.rmass()),
        ("rnorman", fdat.rnorman),
        ("gam_ch", fdat.gam_ch),
        ("rs_int", fdat.rs_int),
        ("vint", fdat.vint),
        ("vmu", fdat.vmu),
        ("vfermi", fdat.vfermi),
    ];
    let mut symbols = base.clone();
    for (name, value) in per_path {
        symbols.insert(name.to_string(), value);
    }
    symbols
}
/// Function-call context for expression evaluation that is bound to one path:
/// its `FuncCtx` impl provides `sigma2_eins`, `sigma2_debye` and `gnxas`,
/// each of which reads data from `fdat`.
struct PathFuncCtx<'a> {
/// The path file whose data the functions read.
fdat: &'a FeffDatFile,
}
impl FuncCtx for PathFuncCtx<'_> {
    /// Dispatches a path-specific function by name.
    ///
    /// * `sigma2_eins(a, b)` and `sigma2_debye(a, b)` take two arguments.
    /// * `gnxas(r0, sigma, beta)` takes three arguments.
    ///
    /// Returns `None` for any other name, and `Some(Err(ExprError::Arity))`
    /// when a known function is called with the wrong number of arguments.
    fn call(&self, name: &str, args: &[f64]) -> Option<Result<f64, ExprError>> {
        let fdat = self.fdat;
        let bad_arity =
            || -> Result<f64, ExprError> { Err(ExprError::Arity(name.to_string())) };
        let outcome = match name {
            "sigma2_eins" => match *args {
                [t, theta] => Ok(sigma2_eins(t, theta, &fdat.geom)),
                _ => bad_arity(),
            },
            "sigma2_debye" => match *args {
                [t, theta] => Ok(sigma2_debye(t, theta, fdat.rnorman, &fdat.geom)),
                _ => bad_arity(),
            },
            "gnxas" => match *args {
                [r0, sigma, beta] => Ok(gnxas(r0, sigma, beta, fdat.reff)),
                _ => bad_arity(),
            },
            _ => return None,
        };
        Some(outcome)
    }
}
fn apply_params(
params: &mut Parameters,
datasets: &mut [FitDataSet],
compiled: &[Vec<CompiledPathSpec>],
bkg_names: &[Vec<String>],
) -> Result<(), FitError> {
params.update_constraints()?;
let base = params.symbols();
for (di, (fds, cds)) in datasets.iter_mut().zip(compiled).enumerate() {
for (path, cspec) in fds.dataset.paths.iter_mut().zip(cds) {
path.params = cspec.eval(&base, &path.feffdat)?;
}
if !bkg_names[di].is_empty() {
let coefs: Vec<f64> = bkg_names[di].iter().map(|n| base[n]).collect();
fds.dataset.set_bkg_coefs(&coefs);
}
}
Ok(())
}
/// Fits the varied parameters in `params` against all `datasets` by least
/// squares and returns the fit statistics.
///
/// Steps, in order:
/// 1. For each dataset: check that it has one spec per path, call
///    `prepare_fit(epsilon_k)`, and compile its specs.
/// 2. For each dataset whose `refine_bkg()` is true: add one variable per
///    background spline coefficient to `params`, named `bkgNN_ds<index>`,
///    with initial value `0.0`.
/// 3. Minimize the concatenated residuals of all datasets with `lmdif`.
/// 4. Write the solution back into `params` and the datasets, then compute
///    statistics, the covariance matrix and standard errors.
///
/// Both arguments are mutated: `params` gains the background variables and
/// ends at the solution; each dataset's path parameters (and background
/// coefficients, when refined) are left at the solution.
///
/// # Errors
/// * `FitError::Shape` if a dataset's path count differs from its spec count.
/// * `FitError::Expr` if a spec fails to parse or evaluate.
/// * Errors from `update_constraints` and `value_grads` are propagated.
///
/// # Panics
/// Panics if parameter application fails inside the residual callback (see
/// the `expect` below), and if `params.value(name)` has no value for a varied
/// parameter.
pub fn feffit(
params: &mut Parameters,
datasets: &mut [FitDataSet],
) -> Result<FeffitResult, FitError> {
// Validate shapes, prepare each dataset, and compile its path specs.
let mut compiled: Vec<Vec<CompiledPathSpec>> = Vec::with_capacity(datasets.len());
for fds in datasets.iter_mut() {
if fds.dataset.paths.len() != fds.specs.len() {
return Err(FitError::Shape(format!(
"dataset has {} paths but {} specs",
fds.dataset.paths.len(),
fds.specs.len()
)));
}
fds.dataset.prepare_fit(fds.epsilon_k);
let cspecs = fds
.specs
.iter()
.map(CompiledPathSpec::from)
.collect::<Result<Vec<_>, _>>()?;
compiled.push(cspecs);
}
// Register background-coefficient variables; datasets that do not refine the
// background get an empty name list so indices stay aligned with `datasets`.
let mut bkg_names: Vec<Vec<String>> = Vec::with_capacity(datasets.len());
for (di, fds) in datasets.iter().enumerate() {
if fds.dataset.refine_bkg() {
let names: Vec<String> = (0..fds.dataset.bkg_nspline())
.map(|i| format!("bkg{i:02}_ds{di}"))
.collect();
for name in &names {
// NOTE(review): how `add_var` treats a name that already exists (e.g. on a
// second `feffit` call with the same `params`) is not visible here — confirm.
params.add_var(name, 0.0);
}
bkg_names.push(names);
} else {
bkg_names.push(Vec::new());
}
}
// The variable list is captured after the background variables were added.
let var_names = params.var_names();
let nvarys = var_names.len();
params.update_constraints()?;
let x0 = params.internal_x0();
// Apply once up front so expression errors surface as `Err` before the fit.
apply_params(params, datasets, &compiled, &bkg_names)?;
let cfg = LmConfig {
ftol: 1.0e-6,
xtol: 1.0e-6,
gtol: 1.0e-6,
// Evaluation budget grows with the number of varied parameters.
maxfev: 4000 * (nvarys as i32 + 1),
epsfcn: 1.0e-10,
factor: 100.0,
};
// Reborrow inside a block so the closure's borrows end with the block and
// `params` / `datasets` can be used again afterwards.
let result = {
let params = &mut *params;
let datasets = &mut *datasets;
let compiled = &compiled;
let bkg_names = &bkg_names;
// Residual callback: install the trial variables, refresh every dataset,
// and concatenate the per-dataset residuals.
let fcn = |vars: &[f64]| -> Vec<f64> {
params.set_var_internal(vars);
// NOTE(review): the callback cannot return an error, so a failure here
// panics instead of surfacing as `FitError`.
apply_params(params, datasets, compiled, bkg_names)
.expect("constraint/expression resolution failed mid-fit");
let mut out = Vec::new();
for fds in datasets.iter_mut() {
out.extend(fds.dataset.residual(false));
}
out
};
lmdif(fcn, &x0, &cfg)
};
// Leave `params` and the datasets at the solution found by the minimizer.
params.set_var_internal(&result.x);
apply_params(params, datasets, &compiled, &bkg_names)?;
let ndata = result.fvec.len();
let nfree = ndata.saturating_sub(nvarys);
// `chisqr` is the squared residual norm; `n_idp` sums `n_idp()` over datasets.
let chisqr = result.fnorm * result.fnorm; let n_idp: f64 = datasets.iter().map(|d| d.dataset.n_idp()).sum();
// Rescale chi-square from `ndata` points to `n_idp`.
let chi_square = chisqr * n_idp / ndata as f64;
// NOTE(review): non-finite or negative when `n_idp <= nvarys`; not guarded.
let chi2_reduced = chi_square / (n_idp - nvarys as f64);
// Sum of squares of `residual(true)` over all datasets, used as the R-factor
// denominator. NOTE(review): the meaning of the boolean flag is defined by
// `DataSet::residual`, which is not visible here.
let mut dat_ss = 0.0;
for fds in datasets.iter_mut() {
for v in fds.dataset.residual(true) {
dat_ss += v * v;
}
}
let rfactor = chisqr / dat_ss;
let neg2_loglikel = n_idp * (chi_square / n_idp).ln();
let aic = neg2_loglikel + 2.0 * nvarys as f64;
let bic = neg2_loglikel + n_idp.ln() * nvarys as f64;
// Factor applied to the raw covariance; uses `n_idp - nvarys` as the degrees
// of freedom (same caveat as `chi2_reduced`).
let err_scale = chisqr / (n_idp - nvarys as f64);
// Convert the minimizer's covariance with `c[i][j] * grad[i] * grad[j]`.
// NOTE(review): presumably `grad` is the derivative of each external value
// with respect to its internal variable — confirm in `var_scale_gradients`.
let grad = params.var_scale_gradients(&result.x);
let cov_ext: Option<Vec<Vec<f64>>> = result.covar().as_ref().map(|c| {
(0..nvarys)
.map(|i| (0..nvarys).map(|j| c[i][j] * grad[i] * grad[j]).collect())
.collect()
});
// Reported covariance: the converted matrix multiplied by `err_scale`.
let covar: Option<Vec<Vec<f64>>> = cov_ext.as_ref().map(|c| {
c.iter()
.map(|row| row.iter().map(|v| v * err_scale).collect())
.collect()
});
// Standard error of each varied parameter is the square root of its scaled
// diagonal element, or NaN when there is no covariance matrix.
let best: Vec<Best> = var_names
.iter()
.enumerate()
.map(|(i, name)| {
let stderr = cov_ext
.as_ref()
.map(|c| (c[i][i] * err_scale).sqrt())
.unwrap_or(f64::NAN);
Best {
name: name.clone(),
value: params.value(name).unwrap(),
stderr,
}
})
.collect();
// Propagates a gradient `g` through the covariance: sqrt(max(g' C g, 0)),
// or NaN when there is no covariance matrix.
let propagate = |g: &[f64]| -> f64 {
match &covar {
Some(c) => {
let mut s = 0.0;
for i in 0..nvarys {
for j in 0..nvarys {
s += g[i] * c[i][j] * g[j];
}
}
s.max(0.0).sqrt()
}
None => f64::NAN,
}
};
// `value_grads` maps a name to `(value, gradient)`; `grads` keeps only the
// gradients, for use by `eval_dual`.
let value_grads = params.value_grads()?;
let grads: HashMap<String, Vec<f64>> = value_grads
.iter()
.map(|(k, (_, g))| (k.clone(), g.clone()))
.collect();
// Expression parameters, with uncertainties propagated from the variables.
// Indexing panics if `value_grads` lacks an entry for an expression name.
let derived: Vec<Best> = params
.expr_names()
.into_iter()
.map(|name| {
let (value, g) = &value_grads[&name];
Best {
name,
value: *value,
stderr: propagate(g),
}
})
.collect();
// Per-path parameter values and propagated uncertainties; `k` indexes both
// the `eval_dual` array and `PATH_PNAMES`.
let base = params.symbols();
let mut path_params = Vec::new();
for (di, (fds, cds)) in datasets.iter().zip(&compiled).enumerate() {
for (pi, (path, cspec)) in fds.dataset.paths.iter().zip(cds).enumerate() {
let vgs = cspec.eval_dual(&base, &grads, nvarys, &path.feffdat)?;
for (k, (value, g)) in vgs.iter().enumerate() {
path_params.push(PathParam {
dataset: di,
path: pi,
name: PATH_PNAMES[k].to_string(),
value: *value,
stderr: propagate(g),
});
}
}
}
Ok(FeffitResult {
best,
derived,
path_params,
covar,
nvarys,
nfree,
ndata,
n_idp,
nfev: result.nfev,
info: result.info,
chi_square,
chi2_reduced,
rfactor,
aic,
bic,
})
}
#[cfg(test)]
mod gnxas_wiring_tests {
    use super::*;
    use std::path::PathBuf;

    /// Loads the `tests/data/feff0001.dat` fixture from the crate root.
    fn test_fdat() -> FeffDatFile {
        let fixture = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/data/feff0001.dat");
        FeffDatFile::from_path(fixture).unwrap()
    }

    /// `gnxas` called through the context must equal a direct call that is
    /// given the path's own `reff`.
    #[test]
    fn gnxas_routed_with_path_reff() {
        let fdat = test_fdat();
        let (r0, sigma, beta) = (2.5, 0.05, 0.30);
        let ctx = PathFuncCtx { fdat: &fdat };
        let got = ctx.call("gnxas", &[r0, sigma, beta]).unwrap().unwrap();
        let want = gnxas(r0, sigma, beta, fdat.reff);
        assert_eq!(got, want);
    }

    /// Two or four arguments must both be reported as an arity error.
    #[test]
    fn gnxas_wrong_arity_is_error() {
        let fdat = test_fdat();
        let ctx = PathFuncCtx { fdat: &fdat };
        let too_few: &[f64] = &[2.5, 0.05];
        let too_many: &[f64] = &[2.5, 0.05, 0.30, 2.55];
        for bad_args in [too_few, too_many] {
            assert!(matches!(
                ctx.call("gnxas", bad_args),
                Some(Err(ExprError::Arity(_)))
            ));
        }
    }
}