use std::fmt;
use argmin::core::CostFunction;
use argmin::core::Executor;
use argmin::core::Gradient;
use argmin::core::State;
use argmin::solver::linesearch::MoreThuenteLineSearch;
use argmin::solver::quasinewton::LBFGS;
use ndarray::Array1;
use ndarray::ArrayView1;
use parking_lot::Mutex;
use super::DiffusionModel;
use super::density::DensityApprox;
/// Result of a maximum-likelihood fit: fitted parameters plus
/// goodness-of-fit summaries (log-likelihood, AIC, BIC).
#[derive(Clone, Debug)]
pub struct MleResult {
    /// Fitted parameter values (clamped to their bounds).
    pub params: Array1<f64>,
    /// Human-readable name for each entry of `params`, in the same order.
    pub param_names: Vec<String>,
    /// Maximized log-likelihood of the observed transitions.
    pub log_likelihood: f64,
    /// Number of transitions used (observations minus one).
    pub sample_size: usize,
    /// Akaike information criterion: 2k - 2*log_lik.
    pub aic: f64,
    /// Bayesian information criterion: k*ln(n) - 2*log_lik.
    pub bic: f64,
}
impl fmt::Display for MleResult {
    /// Renders the fit as a small human-readable report: one line per
    /// parameter followed by the log-likelihood, AIC, BIC and sample size.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "MLE Result")?;
        writeln!(f, "----------")?;
        self.param_names
            .iter()
            .zip(self.params.iter())
            .try_for_each(|(name, value)| writeln!(f, " {:<12} = {:.6}", name, value))?;
        writeln!(f, " log-lik = {:.4}", self.log_likelihood)?;
        writeln!(f, " AIC = {:.4}", self.aic)?;
        writeln!(f, " BIC = {:.4}", self.bic)?;
        writeln!(f, " sample size = {}", self.sample_size)
    }
}
/// Optimization problem handed to argmin: evaluates the negative
/// log-likelihood of `sample` under the model's transition density.
struct MleProblem<'a> {
    /// Model being fitted. Wrapped in a Mutex because argmin's cost/gradient
    /// traits take `&self`, while evaluating requires `set_params` on the
    /// mutable model (interior mutability, not cross-thread sharing —
    /// NOTE(review): confirm argmin never calls cost concurrently here).
    model: Mutex<&'a mut dyn DiffusionModel>,
    /// Observed path; transitions are taken between consecutive entries,
    /// assumed equally spaced in time (see the `(i - 1) * dt` usage).
    sample: ArrayView1<'a, f64>,
    /// Time step between consecutive observations.
    dt: f64,
    /// Transition-density approximation used to score each step.
    density: DensityApprox,
    /// Per-parameter (lower, upper) box constraints.
    bounds: Vec<(f64, f64)>,
}
impl MleProblem<'_> {
    /// Projects `params` onto the per-parameter box constraints.
    fn clamp(&self, params: &[f64]) -> Vec<f64> {
        params
            .iter()
            .zip(self.bounds.iter())
            .map(|(&value, &(lo, hi))| value.clamp(lo, hi))
            .collect()
    }

    /// Negative log-likelihood of the sample at `params` (after clamping).
    ///
    /// Each transition density is floored at 1e-30 before taking the log so
    /// a single vanishing density cannot produce -inf; a non-finite total is
    /// replaced by a large penalty (1e30) to keep the optimizer moving.
    fn eval_nll(&self, params: &[f64]) -> f64 {
        let bounded = self.clamp(params);
        let mut guard = self.model.lock();
        guard.set_params(&bounded);
        let mut nll = 0.0;
        for idx in 1..self.sample.len() {
            let start_time = (idx - 1) as f64 * self.dt;
            let p = self.density.density(
                &**guard,
                self.sample[idx - 1],
                self.sample[idx],
                start_time,
                self.dt,
            );
            nll -= p.max(1e-30).ln();
        }
        if nll.is_finite() { nll } else { 1e30 }
    }
}
impl CostFunction for MleProblem<'_> {
    type Param = Vec<f64>;
    type Output = f64;

    /// argmin cost hook: the negative log-likelihood at `params`.
    fn cost(&self, params: &Self::Param) -> Result<Self::Output, argmin::core::Error> {
        let nll = self.eval_nll(params);
        Ok(nll)
    }
}
impl Gradient for MleProblem<'_> {
    type Param = Vec<f64>;
    type Gradient = Vec<f64>;

    /// Numerical gradient of the negative log-likelihood by central
    /// differences, with the probe points kept inside the box constraints.
    ///
    /// The step is scaled relative to each coordinate's magnitude; when a
    /// bound collapses the probe interval to zero width, that component of
    /// the gradient is reported as 0.0.
    fn gradient(&self, params: &Self::Param) -> Result<Self::Gradient, argmin::core::Error> {
        let base = self.clamp(params);
        let grad = (0..base.len())
            .map(|j| {
                let step = 1e-7 * (1.0 + base[j].abs());
                let mut upper = base.clone();
                let mut lower = base.clone();
                upper[j] = (base[j] + step).min(self.bounds[j].1);
                lower[j] = (base[j] - step).max(self.bounds[j].0);
                let span = upper[j] - lower[j];
                if span > 0.0 {
                    (self.eval_nll(&upper) - self.eval_nll(&lower)) / span
                } else {
                    0.0
                }
            })
            .collect();
        Ok(grad)
    }
}
/// Fits `model` to an equally-spaced `sample` by maximum likelihood.
///
/// Runs L-BFGS (with More-Thuente line search) over the model parameters,
/// clamping iterates to `param_bounds` (or `model.param_bounds()` when
/// `None`). On solver failure the initial parameters are kept. The model is
/// left holding the fitted (clamped) parameters on return.
///
/// # Panics
/// Panics if `sample` has fewer than 2 observations, or if the bounds
/// length does not match the model's parameter count.
pub fn fit_mle(
    model: &mut dyn DiffusionModel,
    sample: ArrayView1<f64>,
    dt: f64,
    density: DensityApprox,
    param_bounds: Option<Vec<(f64, f64)>>,
) -> MleResult {
    // Validate before deriving: computing `sample.len() - 1` on an empty
    // sample would underflow usize (panic in debug, wrap in release) before
    // the diagnostic assert could ever fire.
    assert!(
        sample.len() >= 2,
        "sample must contain at least 2 observations"
    );
    let bounds = param_bounds.unwrap_or_else(|| model.param_bounds());
    let n_params = model.num_params();
    assert_eq!(
        bounds.len(),
        n_params,
        "bounds length must match number of parameters"
    );
    let n_transitions = sample.len() - 1;
    let x0 = model.params();
    let best_params = if n_params == 0 {
        // Nothing to optimize; keep the model's (empty) parameter vector.
        x0.to_vec()
    } else {
        let init: Vec<f64> = x0.to_vec();
        // Inner scope bounds the `&mut *model` borrow held by the problem,
        // so `model` can be used again below.
        {
            let problem = MleProblem {
                model: Mutex::new(&mut *model),
                sample,
                dt,
                density,
                bounds: bounds.clone(),
            };
            let linesearch = MoreThuenteLineSearch::new();
            let solver = LBFGS::new(linesearch, 10);
            let result = Executor::new(problem, solver)
                .configure(|state| state.param(init.clone()).max_iters(200))
                .run();
            // Best-effort: fall back to the initial guess if the solver
            // errored or never recorded a best parameter.
            match result {
                Ok(res) => res.state.get_best_param().cloned().unwrap_or(init),
                Err(_) => init,
            }
        }
    };
    // Re-clamp the winner and push it into the model so the reported
    // log-likelihood is evaluated at exactly the returned parameters.
    let clamped: Vec<f64> = best_params
        .iter()
        .enumerate()
        .map(|(i, &x)| x.clamp(bounds[i].0, bounds[i].1))
        .collect();
    model.set_params(&clamped);
    // Log-likelihood over the n-1 transitions; densities are floored at
    // 1e-30 so a single vanishing density cannot yield -inf.
    let mut log_lik = 0.0;
    for i in 1..sample.len() {
        let t0 = (i - 1) as f64 * dt;
        let d = density.density(model, sample[i - 1], sample[i], t0, dt);
        log_lik += d.max(1e-30).ln();
    }
    // Information criteria: k parameters, n transitions.
    let k = n_params as f64;
    let n = n_transitions as f64;
    let aic = 2.0 * k - 2.0 * log_lik;
    let bic = k * n.ln() - 2.0 * log_lik;
    MleResult {
        params: Array1::from_vec(clamped),
        param_names: model.param_names().into_iter().map(String::from).collect(),
        log_likelihood: log_lik,
        sample_size: n_transitions,
        aic,
        bic,
    }
}