use scirs2_core::ndarray::ArrayView1;
use scirs2_symbolic::{
cas::{grad_canonical, solve_system, SystemSolveError},
eml::{eval_real, EvalCtx, LoweredOp},
};
/// Reasons [`derive`] refuses to build an [`Estimator`].
#[derive(Debug, Clone)]
pub enum DeriveError {
    /// The `params` slice was empty.
    EmptyParams,
    /// `n_samples` was zero.
    ZeroSamples,
    /// A parameter id lies inside the data-slot range `data_var..data_var + n_samples`.
    VarIdCollision {
        /// The offending variable id.
        id: usize,
    },
    /// The symbolic system solver reported an internal error; its message is forwarded.
    InternalError(String),
}

impl std::fmt::Display for DeriveError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::EmptyParams => out.write_str("params slice must not be empty"),
            Self::ZeroSamples => out.write_str("n_samples must be ≥ 1"),
            Self::VarIdCollision { id } => write!(
                out,
                "Var({id}) is used by both a data slot and a parameter — ids must not overlap"
            ),
            Self::InternalError(msg) => write!(out, "internal error in mle::derive: {msg}"),
        }
    }
}

impl std::error::Error for DeriveError {}
/// Reasons [`Estimator::fit`] can fail.
#[derive(Debug, Clone)]
pub enum FitError {
    /// The data slice does not have the length the estimator was derived for.
    DataLengthMismatch {
        /// Number of samples the estimator was built for.
        expected: usize,
        /// Length of the data actually supplied.
        got: usize,
    },
    /// Numerical evaluation failed. This covers both the closed-form path and the
    /// Newton path; the payload names the failing step.
    NumericalFailed(String),
    /// No estimation path is available. Not currently constructed by `Estimator::fit`.
    NoEstimator,
}

impl std::fmt::Display for FitError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FitError::DataLengthMismatch { expected, got } => write!(
                f,
                "data length mismatch: estimator built for {expected} samples, got {got}"
            ),
            // `NumericalFailed` is raised by the closed-form evaluation as well as by
            // Newton's method, so the prefix must not name either path.
            FitError::NumericalFailed(msg) => write!(f, "MLE fit failed: {msg}"),
            FitError::NoEstimator => write!(f, "no estimator path available"),
        }
    }
}

impl std::error::Error for FitError {}
/// Maximum-likelihood estimator produced by [`derive`].
///
/// It holds the symbolic score equations for the likelihood and, when the symbolic
/// solver produced one, a closed-form solution. Call [`Estimator::fit`] with observed
/// data to obtain parameter estimates.
#[derive(Debug)]
pub struct Estimator {
    /// One expression per entry of `params` (same order), taken from the first solution
    /// returned by `solve_system`. `None` when no closed form was obtained.
    /// `fit` evaluates these with only the data slots bound, so they are presumably
    /// expressed in terms of the data variables alone.
    pub closed_form: Option<Vec<LoweredOp>>,
    /// One expression per entry of `params` (same order): `grad_canonical` of the summed
    /// log-likelihood with respect to that parameter. Newton's method drives these to
    /// zero when there is no closed form.
    pub score_equations: Vec<LoweredOp>,
    /// `true` exactly when `closed_form` is `None`, i.e. `fit` will use Newton's method.
    pub falls_back_to_numeric: bool,
    // Variable ids of the parameters, in the order `fit` returns estimates.
    pub(crate) params: Vec<usize>,
    // Variable id of the first data slot; observation `i` is bound to `data_var + i`.
    pub(crate) data_var: usize,
    // Number of observations the estimator was built for; `fit` rejects other lengths.
    pub(crate) n_samples: usize,
}
/// Symbolically derives a maximum-likelihood estimator for `pdf`.
///
/// For each `i in 0..n_samples`, `Var(data_var)` in `pdf` is renamed to
/// `Var(data_var + i)`, and the log-likelihood is the sum of `ln(pdf_i)`. This treats the
/// observations as independent and identically distributed. One score equation per
/// parameter is obtained with `grad_canonical`, and the system `score = 0` is passed to
/// `solve_system`.
///
/// If the first solution gives an expression for *every* parameter, it becomes
/// [`Estimator::closed_form`]. Otherwise the estimator falls back to Newton's method at
/// fit time.
///
/// # Errors
///
/// - [`DeriveError::EmptyParams`] if `params` is empty.
/// - [`DeriveError::ZeroSamples`] if `n_samples == 0`.
/// - [`DeriveError::VarIdCollision`] if a parameter id lies in
///   `data_var..data_var + n_samples`.
/// - [`DeriveError::InternalError`] if `solve_system` reports an internal error.
pub fn derive(
    pdf: &LoweredOp,
    params: &[usize],
    data_var: usize,
    n_samples: usize,
) -> Result<Estimator, DeriveError> {
    if params.is_empty() {
        return Err(DeriveError::EmptyParams);
    }
    if n_samples == 0 {
        return Err(DeriveError::ZeroSamples);
    }
    // Data slots occupy data_var..data_var + n_samples. Testing via `checked_sub` avoids
    // overflowing `data_var + n_samples` for very large ids.
    if let Some(&id) = params
        .iter()
        .find(|&&p| matches!(p.checked_sub(data_var), Some(offset) if offset < n_samples))
    {
        return Err(DeriveError::VarIdCollision { id });
    }

    let log_terms: Vec<LoweredOp> = (0..n_samples)
        .map(|i| LoweredOp::Ln(Box::new(substitute_var(pdf, data_var, data_var + i))))
        .collect();
    // A balanced Add tree keeps the expression depth logarithmic in n_samples.
    let log_likelihood = balanced_sum(&log_terms);
    let score_equations: Vec<LoweredOp> = params
        .iter()
        .map(|&p| grad_canonical(&log_likelihood, p))
        .collect();
    let eqs: Vec<(LoweredOp, LoweredOp)> = score_equations
        .iter()
        .map(|s| (s.clone(), LoweredOp::Const(0.0)))
        .collect();

    let closed_form = match solve_system(&eqs, params) {
        // Accept the first solution only if it covers every parameter. A partial solution
        // used to be padded with Const(NaN), making `fit` return NaN instead of
        // falling back to Newton's method.
        Ok(result) => result.solutions.first().and_then(|sol| {
            params
                .iter()
                .map(|&p| sol.get(&p).cloned())
                .collect::<Option<Vec<LoweredOp>>>()
        }),
        Err(
            SystemSolveError::CannotEliminateTranscendental
            | SystemSolveError::EmptyVars
            | SystemSolveError::EmptyEquations,
        ) => None,
        Err(SystemSolveError::InternalError(msg)) => {
            return Err(DeriveError::InternalError(msg));
        }
    };
    let falls_back_to_numeric = closed_form.is_none();

    Ok(Estimator {
        closed_form,
        score_equations,
        falls_back_to_numeric,
        params: params.to_vec(),
        data_var,
        n_samples,
    })
}
impl Estimator {
    /// Estimates the parameters from one set of observations.
    ///
    /// `data[i]` is bound to variable id `data_var + i`. Estimates are returned in the
    /// order of the `params` slice given to [`derive`]. The symbolic closed form is used
    /// when one exists; otherwise the score equations are solved with Newton's method.
    ///
    /// # Errors
    ///
    /// - [`FitError::DataLengthMismatch`] if `data.len()` differs from the sample count
    ///   the estimator was derived for.
    /// - [`FitError::NumericalFailed`] if any of the following happens:
    ///   - an expression fails to evaluate;
    ///   - a score or finite difference is non-finite;
    ///   - the Hessian is singular;
    ///   - Newton's method diverges or does not converge within its iteration budget.
    pub fn fit(&self, data: ArrayView1<f64>) -> Result<Vec<f64>, FitError> {
        if data.len() != self.n_samples {
            return Err(FitError::DataLengthMismatch {
                expected: self.n_samples,
                got: data.len(),
            });
        }
        match &self.closed_form {
            Some(cf) => self.fit_closed_form(data, cf),
            None => self.fit_newton(data),
        }
    }

    /// Builds the zero-filled base binding table, long enough to index every data slot
    /// and every parameter id, with `data[i]` stored at `data_var + i`.
    ///
    /// Callers must already have checked `data.len() == n_samples`.
    fn data_bindings(&self, data: ArrayView1<f64>) -> Vec<f64> {
        let data_end = self.data_var + self.n_samples;
        let param_end = self.params.iter().copied().max().unwrap_or(0) + 1;
        let mut bindings = vec![0.0_f64; data_end.max(param_end)];
        for (slot, &x) in bindings[self.data_var..data_end].iter_mut().zip(data.iter()) {
            *slot = x;
        }
        bindings
    }

    /// Returns a copy of `base` with each parameter slot set from `theta` (in `params`
    /// order). If `nudge` is `Some((j, h))`, parameter `j` is set to `theta[j] + h`.
    fn bind_params(&self, base: &[f64], theta: &[f64], nudge: Option<(usize, f64)>) -> Vec<f64> {
        let mut bindings = base.to_vec();
        for (k, (&p, &t)) in self.params.iter().zip(theta).enumerate() {
            let value = match nudge {
                Some((j, h)) if j == k => t + h,
                _ => t,
            };
            if let Some(slot) = bindings.get_mut(p) {
                *slot = value;
            }
        }
        bindings
    }

    /// Euclidean norm of `values`.
    fn l2_norm(values: &[f64]) -> f64 {
        values.iter().map(|v| v * v).sum::<f64>().sqrt()
    }

    /// Evaluates each closed-form expression with only the data slots bound (parameter
    /// slots stay at `0.0`).
    fn fit_closed_form(
        &self,
        data: ArrayView1<f64>,
        cf: &[LoweredOp],
    ) -> Result<Vec<f64>, FitError> {
        let bindings = self.data_bindings(data);
        let ctx = EvalCtx::new(&bindings);
        cf.iter()
            .map(|expr| {
                eval_real(expr, &ctx)
                    .map_err(|e| FitError::NumericalFailed(format!("closed-form eval: {e}")))
            })
            .collect()
    }

    /// Solves `score(θ) = 0` with Newton's method, starting from `0.5` for every
    /// parameter.
    ///
    /// The Jacobian of the score (the Hessian of the log-likelihood) is approximated by
    /// central differences with step `1e-5`. Iteration stops successfully in either case:
    /// - the score norm drops below `1e-8`;
    /// - a Newton step is no larger than `1e-8 · (1 + ‖θ‖)`.
    ///
    /// If neither happens within 200 iterations, an error is returned rather than
    /// unconverged values.
    fn fit_newton(&self, data: ArrayView1<f64>) -> Result<Vec<f64>, FitError> {
        let n = self.params.len();
        let base_bindings = self.data_bindings(data);
        let mut theta: Vec<f64> = vec![0.5; n];
        let max_iter = 200_usize;
        let eps = 1e-5_f64;
        let tol = 1e-8_f64;

        for iter in 0..max_iter {
            let bindings = self.bind_params(&base_bindings, &theta, None);
            let ctx = EvalCtx::new(&bindings);
            let mut score = Vec::with_capacity(n);
            for s in &self.score_equations {
                let sv = eval_real(s, &ctx)
                    .map_err(|e| FitError::NumericalFailed(format!("score eval: {e}")))?;
                if !sv.is_finite() {
                    return Err(FitError::NumericalFailed(format!(
                        "non-finite score at iteration {iter}: {sv}"
                    )));
                }
                score.push(sv);
            }
            if Self::l2_norm(&score) < tol {
                return Ok(theta);
            }

            // Column j of the Jacobian: central difference of every score equation in θ_j.
            let mut hessian = vec![vec![0.0_f64; n]; n];
            for j in 0..n {
                let bindings_plus = self.bind_params(&base_bindings, &theta, Some((j, eps)));
                let bindings_minus = self.bind_params(&base_bindings, &theta, Some((j, -eps)));
                let ctx_plus = EvalCtx::new(&bindings_plus);
                let ctx_minus = EvalCtx::new(&bindings_minus);
                for (i, s) in self.score_equations.iter().enumerate() {
                    let sp = eval_real(s, &ctx_plus).map_err(|e| {
                        FitError::NumericalFailed(format!(
                            "finite-difference eval at H[{i}][{j}]: {e}"
                        ))
                    })?;
                    let sm = eval_real(s, &ctx_minus).map_err(|e| {
                        FitError::NumericalFailed(format!(
                            "finite-difference eval at H[{i}][{j}]: {e}"
                        ))
                    })?;
                    if !sp.is_finite() || !sm.is_finite() {
                        return Err(FitError::NumericalFailed(format!(
                            "non-finite finite-difference at H[{i}][{j}]"
                        )));
                    }
                    hessian[i][j] = (sp - sm) / (2.0 * eps);
                }
            }

            let delta = solve_linear_system_f64(&hessian, &score)
                .map_err(|e| FitError::NumericalFailed(format!("Hessian solve: {e}")))?;
            for (t, &d) in theta.iter_mut().zip(&delta) {
                *t -= d;
            }
            if theta.iter().any(|t| !t.is_finite()) {
                return Err(FitError::NumericalFailed(format!(
                    "non-finite parameter estimate at iteration {iter}"
                )));
            }
            // Step-size test: with floating-point rounding the score norm may never reach
            // `tol` for large or badly scaled samples even though θ has converged.
            if Self::l2_norm(&delta) <= tol * (1.0 + Self::l2_norm(&theta)) {
                return Ok(theta);
            }
        }

        Err(FitError::NumericalFailed(format!(
            "Newton iteration did not converge within {max_iter} iterations"
        )))
    }
}
/// Returns a deep copy of `op` with every `Var(old_id)` leaf renamed to `Var(new_id)`.
///
/// Every other node, including other variables and constants, is copied as-is. When
/// `old_id == new_id` the tree is simply cloned.
///
/// The copy is built with an explicit work list instead of recursion. An interior node
/// is visited twice. The first visit schedules its children, with the left child
/// processed first. The second visit happens after the children have been rebuilt onto
/// `done`; it pops them back off and assembles the copy. Leaves are emitted straight onto
/// `done` on their first visit.
fn substitute_var(op: &LoweredOp, old_id: usize, new_id: usize) -> LoweredOp {
    if old_id == new_id {
        return op.clone();
    }

    /// Pops the most recently rebuilt operand, or `Const(fallback)` if none is left.
    fn take(done: &mut Vec<LoweredOp>, fallback: f64) -> Box<LoweredOp> {
        Box::new(done.pop().unwrap_or(LoweredOp::Const(fallback)))
    }

    enum Task<'a> {
        /// First visit: emit a leaf, or schedule an interior node's children.
        Visit(&'a LoweredOp),
        /// Second visit: rebuild an interior node from its already-rebuilt children.
        Assemble(&'a LoweredOp),
    }

    let rename = |id: usize| if id == old_id { new_id } else { id };
    let mut todo: Vec<Task> = vec![Task::Visit(op)];
    let mut done: Vec<LoweredOp> = Vec::new();

    while let Some(task) = todo.pop() {
        match task {
            Task::Visit(node) => match node {
                LoweredOp::Const(value) => done.push(LoweredOp::Const(*value)),
                LoweredOp::Var(id) => done.push(LoweredOp::Var(rename(*id))),
                LoweredOp::Add(lhs, rhs)
                | LoweredOp::Sub(lhs, rhs)
                | LoweredOp::Mul(lhs, rhs)
                | LoweredOp::Div(lhs, rhs)
                | LoweredOp::Pow(lhs, rhs) => {
                    // Reverse push order: `lhs` is rebuilt first, so `rhs` ends on top.
                    todo.push(Task::Assemble(node));
                    todo.push(Task::Visit(rhs));
                    todo.push(Task::Visit(lhs));
                }
                LoweredOp::Neg(arg)
                | LoweredOp::Exp(arg)
                | LoweredOp::Ln(arg)
                | LoweredOp::Sin(arg)
                | LoweredOp::Cos(arg)
                | LoweredOp::Tan(arg)
                | LoweredOp::Sinh(arg)
                | LoweredOp::Cosh(arg)
                | LoweredOp::Tanh(arg)
                | LoweredOp::Arcsin(arg)
                | LoweredOp::Arccos(arg)
                | LoweredOp::Arctan(arg)
                | LoweredOp::Arcsinh(arg)
                | LoweredOp::Arccosh(arg)
                | LoweredOp::Arctanh(arg)
                | LoweredOp::Sqrt(arg)
                | LoweredOp::Abs(arg) => {
                    todo.push(Task::Assemble(node));
                    todo.push(Task::Visit(arg));
                }
            },
            Task::Assemble(node) => {
                let rebuilt = match node {
                    LoweredOp::Add(..) => {
                        let rhs = take(&mut done, 0.0);
                        LoweredOp::Add(take(&mut done, 0.0), rhs)
                    }
                    LoweredOp::Sub(..) => {
                        let rhs = take(&mut done, 0.0);
                        LoweredOp::Sub(take(&mut done, 0.0), rhs)
                    }
                    LoweredOp::Mul(..) => {
                        let rhs = take(&mut done, 0.0);
                        LoweredOp::Mul(take(&mut done, 0.0), rhs)
                    }
                    LoweredOp::Div(..) => {
                        let rhs = take(&mut done, 1.0);
                        LoweredOp::Div(take(&mut done, 0.0), rhs)
                    }
                    LoweredOp::Pow(..) => {
                        let rhs = take(&mut done, 1.0);
                        LoweredOp::Pow(take(&mut done, 0.0), rhs)
                    }
                    LoweredOp::Neg(_) => LoweredOp::Neg(take(&mut done, 0.0)),
                    LoweredOp::Exp(_) => LoweredOp::Exp(take(&mut done, 0.0)),
                    LoweredOp::Ln(_) => LoweredOp::Ln(take(&mut done, 0.0)),
                    LoweredOp::Sin(_) => LoweredOp::Sin(take(&mut done, 0.0)),
                    LoweredOp::Cos(_) => LoweredOp::Cos(take(&mut done, 0.0)),
                    LoweredOp::Tan(_) => LoweredOp::Tan(take(&mut done, 0.0)),
                    LoweredOp::Sinh(_) => LoweredOp::Sinh(take(&mut done, 0.0)),
                    LoweredOp::Cosh(_) => LoweredOp::Cosh(take(&mut done, 0.0)),
                    LoweredOp::Tanh(_) => LoweredOp::Tanh(take(&mut done, 0.0)),
                    LoweredOp::Arcsin(_) => LoweredOp::Arcsin(take(&mut done, 0.0)),
                    LoweredOp::Arccos(_) => LoweredOp::Arccos(take(&mut done, 0.0)),
                    LoweredOp::Arctan(_) => LoweredOp::Arctan(take(&mut done, 0.0)),
                    LoweredOp::Arcsinh(_) => LoweredOp::Arcsinh(take(&mut done, 0.0)),
                    LoweredOp::Arccosh(_) => LoweredOp::Arccosh(take(&mut done, 0.0)),
                    LoweredOp::Arctanh(_) => LoweredOp::Arctanh(take(&mut done, 0.0)),
                    LoweredOp::Sqrt(_) => LoweredOp::Sqrt(take(&mut done, 0.0)),
                    LoweredOp::Abs(_) => LoweredOp::Abs(take(&mut done, 0.0)),
                    LoweredOp::Const(_) | LoweredOp::Var(_) => {
                        unreachable!("leaves are emitted during their first visit")
                    }
                };
                done.push(rebuilt);
            }
        }
    }

    done.pop().unwrap_or(LoweredOp::Const(0.0))
}
/// Adds `ops` together as a balanced binary tree of `Add` nodes.
///
/// The slice is halved at each level (left half first), so the tree depth grows
/// logarithmically with `ops.len()` rather than forming a linear chain. An empty slice
/// yields `Const(0.0)`; a single element is cloned unchanged.
fn balanced_sum(ops: &[LoweredOp]) -> LoweredOp {
    match ops {
        [] => LoweredOp::Const(0.0),
        [only] => only.clone(),
        _ => {
            let (front, back) = ops.split_at(ops.len() / 2);
            LoweredOp::Add(Box::new(balanced_sum(front)), Box::new(balanced_sum(back)))
        }
    }
}
/// Solves the dense square system `a · x = b` by Gauss–Jordan elimination with partial
/// (row) pivoting.
///
/// `a` is a slice of rows and must be `b.len() × b.len()`. Both inputs are copied into an
/// augmented working matrix, so neither is modified. An empty system returns an empty
/// solution.
///
/// # Errors
///
/// Returns a descriptive `Err` in any of these cases:
/// - the row count differs from `b.len()`;
/// - some row does not have exactly `b.len()` entries;
/// - a pivot is non-finite;
/// - a pivot is smaller in magnitude than `1e-14` (treated as singular).
fn solve_linear_system_f64(a: &[Vec<f64>], b: &[f64]) -> Result<Vec<f64>, String> {
    let n = b.len();
    if a.len() != n {
        return Err(format!("matrix rows ({}) != rhs length ({})", a.len(), n));
    }
    // Ragged rows would otherwise panic on indexing during elimination.
    if let Some((r, row)) = a.iter().enumerate().find(|(_, row)| row.len() != n) {
        return Err(format!(
            "matrix row {r} has {} columns, expected {n}",
            row.len()
        ));
    }

    // Augmented matrix [a | b]: column n carries the right-hand side.
    let mut mat: Vec<Vec<f64>> = a
        .iter()
        .zip(b)
        .map(|(row, &bi)| {
            let mut r = Vec::with_capacity(n + 1);
            r.extend_from_slice(row);
            r.push(bi);
            r
        })
        .collect();

    for col in 0..n {
        // Partial pivoting: move the row with the largest |entry| in this column up.
        let pivot = (col..n)
            .max_by(|&i, &j| {
                mat[i][col]
                    .abs()
                    .partial_cmp(&mat[j][col].abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .ok_or_else(|| "empty pivot range".to_string())?;
        mat.swap(col, pivot);

        let diag = mat[col][col];
        // NaN/inf pivots slip past the magnitude test below and would silently yield a
        // NaN "solution".
        if !diag.is_finite() {
            return Err(format!("non-finite pivot at column {col}"));
        }
        if diag.abs() < 1e-14 {
            return Err(format!("singular matrix at column {col}"));
        }

        let inv_diag = 1.0 / diag;
        for k in col..=n {
            mat[col][k] *= inv_diag;
        }
        for row in 0..n {
            if row == col {
                continue;
            }
            let factor = mat[row][col];
            if factor.abs() < 1e-15 {
                continue;
            }
            for k in col..=n {
                let v = mat[col][k];
                mat[row][k] -= factor * v;
            }
        }
    }

    Ok(mat.iter().map(|row| row[n]).collect())
}