use crate::faer_ndarray::FaerCholesky;
use faer::Side;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
/// Penalty family used to build the per-term quadratic weights in
/// `lqa_weight` during each refit of `sindy_stlsq_solve`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SindyPenaltyKind {
/// Constant weight equal to `lam`, independent of the coefficient magnitude.
Ridge,
/// Weight is `scad_grad(|xi|, lam, a) / max(|xi|, eps)`;
/// `sindy_stlsq_solve` rejects `concave_a <= 2` for this variant.
Scad,
/// Weight is `mcp_grad(|xi|, lam, a) / max(|xi|, eps)`;
/// `sindy_stlsq_solve` rejects `concave_a <= 1` for this variant.
Mcp,
}
/// Derivative of the SCAD penalty evaluated at a non-negative magnitude.
///
/// * `abs_xi <= lam`           -> `lam` (flat, lasso-like slope)
/// * `lam < abs_xi <= a * lam` -> `(a * lam - abs_xi) / (a - 1)`, floored at 0
/// * otherwise                 -> `0` (no shrinkage)
///
/// Any comparison involving NaN is false, so NaN inputs fall through to `0`.
#[inline]
fn scad_grad(abs_xi: f64, lam: f64, a: f64) -> f64 {
    if abs_xi <= lam {
        return lam;
    }
    // Magnitude beyond which the penalty stops shrinking altogether.
    let knee = a * lam;
    if abs_xi <= knee {
        return ((knee - abs_xi) / (a - 1.0)).max(0.0);
    }
    0.0
}
/// Derivative of the MCP penalty evaluated at a non-negative magnitude.
///
/// Returns `lam - abs_xi / gamma` (floored at 0) while
/// `abs_xi <= gamma * lam`, and `0` once the magnitude is past that point.
#[inline]
fn mcp_grad(abs_xi: f64, lam: f64, gamma: f64) -> f64 {
    let within_taper = abs_xi <= gamma * lam;
    if !within_taper {
        return 0.0;
    }
    let slope = lam - abs_xi / gamma;
    slope.max(0.0)
}
/// Quadratic weight assigned to one library term for the next refit.
///
/// Ridge uses the constant `lam`. SCAD and MCP divide their penalty
/// derivative by the coefficient magnitude, with the denominator floored at
/// `eps` so a vanishing coefficient cannot produce a division by zero.
/// `a` is passed as the concavity parameter (`a` for SCAD, `gamma` for MCP).
#[inline]
fn lqa_weight(kind: SindyPenaltyKind, abs_xi: f64, lam: f64, a: f64, eps: f64) -> f64 {
    let slope = match kind {
        SindyPenaltyKind::Ridge => return lam,
        SindyPenaltyKind::Scad => scad_grad(abs_xi, lam, a),
        SindyPenaltyKind::Mcp => mcp_grad(abs_xi, lam, a),
    };
    slope / abs_xi.max(eps)
}
/// Outcome of `sindy_stlsq_solve`.
#[derive(Debug, Clone)]
pub struct SindyStlsqResult {
/// Coefficient matrix with one row per column of `theta` and one column
/// per column of `dz_dt`; entries with magnitude below `tol` are zeroed.
pub coefficients: Array2<f64>,
/// Number of threshold/refit rounds entered (at most `max_rounds`).
pub rounds_used: usize,
/// `true` when the set of active terms was unchanged between two
/// consecutive rounds before `max_rounds` was exhausted.
pub converged: bool,
}
/// Sequentially thresholded least squares with a ridge, SCAD or MCP penalty.
///
/// Starting from a full ridge solve, each round
/// 1. zeroes every coefficient whose magnitude is below `tol` and keeps a
///    library term (row of the coefficient matrix) active only if at least
///    one of its coefficients survived,
/// 2. stops if the active set equals the previous round's,
/// 3. otherwise refits the active terms with per-term weights from
///    `lqa_weight`, using the largest coefficient magnitude in the row.
///
/// After the loop, any coefficient below `tol` is zeroed once more.
///
/// # Arguments
/// * `theta` - library matrix, `n x p`.
/// * `dz_dt` - target matrix, `n x d`.
/// * `tol` - hard threshold on coefficient magnitude (finite, `>= 0`).
/// * `max_rounds` - maximum number of threshold/refit rounds (`>= 1`).
/// * `lam` - penalty strength (finite, `>= 0`).
/// * `kind` - penalty family.
/// * `concave_a` - concavity parameter; must be `> 2` for SCAD and `> 1`
///   for MCP, and is not validated for Ridge.
///
/// # Errors
/// Returns a message when the inputs are empty, the row counts disagree, a
/// scalar parameter is out of range, or an underlying ridge solve fails.
pub fn sindy_stlsq_solve(
    theta: ArrayView2<'_, f64>,
    dz_dt: ArrayView2<'_, f64>,
    tol: f64,
    max_rounds: usize,
    lam: f64,
    kind: SindyPenaltyKind,
    concave_a: f64,
) -> Result<SindyStlsqResult, String> {
    let (n, p) = theta.dim();
    let (n_dz, d) = dz_dt.dim();

    // Input validation. The order is significant: the first failing check
    // decides which message the caller sees.
    if n == 0 || p == 0 || d == 0 {
        return Err(format!(
            "sindy_stlsq_solve requires non-empty theta and dz_dt; got theta=({n},{p}), dz_dt=({n_dz},{d})"
        ));
    }
    if n_dz != n {
        return Err(format!(
            "sindy_stlsq_solve requires theta.nrows == dz_dt.nrows; got {n} vs {n_dz}"
        ));
    }
    if !tol.is_finite() || tol < 0.0 {
        return Err(format!(
            "sindy_stlsq_solve requires finite tol >= 0, got {tol}"
        ));
    }
    if max_rounds == 0 {
        return Err(String::from("sindy_stlsq_solve requires max_rounds >= 1"));
    }
    if !lam.is_finite() || lam < 0.0 {
        return Err(format!(
            "sindy_stlsq_solve requires finite lam >= 0, got {lam}"
        ));
    }
    match kind {
        SindyPenaltyKind::Ridge => {}
        SindyPenaltyKind::Scad => {
            if !concave_a.is_finite() || concave_a <= 2.0 {
                return Err(format!(
                    "sindy_stlsq_solve SCAD requires concave_a > 2, got {concave_a}"
                ));
            }
        }
        SindyPenaltyKind::Mcp => {
            if !concave_a.is_finite() || concave_a <= 1.0 {
                return Err(format!(
                    "sindy_stlsq_solve MCP requires concave_a > 1, got {concave_a}"
                ));
            }
        }
    }

    // Seed with a full ridge fit; a zero `lam` is replaced by a tiny ridge.
    let lam_seed = if lam > 0.0 { lam } else { 1.0e-12 };
    let mut xi = ridge_full_solve(theta, dz_dt, lam_seed)?;

    // Floor for the LQA denominator, tied to the threshold scale.
    let lqa_eps = (tol * 1.0e-2).max(1.0e-10);

    let mut active = vec![true; p];
    let mut prev_active = vec![false; p];
    let mut rounds_used = 0usize;
    let mut converged = false;

    while rounds_used < max_rounds {
        rounds_used += 1;

        // Hard-threshold: zero sub-threshold entries and retire any term
        // whose whole row fell below `tol` (or that was already inactive).
        for (j, flag) in active.iter_mut().enumerate() {
            let was_active = *flag;
            let mut survives = false;
            for v in xi.row_mut(j).iter_mut() {
                if was_active && v.abs() >= tol {
                    survives = true;
                } else {
                    *v = 0.0;
                }
            }
            *flag = survives;
        }

        // A repeated active set (including a repeated empty one) is a fixed point.
        if active == prev_active {
            converged = true;
            break;
        }
        prev_active.copy_from_slice(&active);

        // Nothing left to fit; the next round will observe the repeat.
        if !active.iter().any(|&on| on) {
            continue;
        }

        // Gather the surviving library columns and their penalty weights.
        let support: Vec<usize> = (0..p).filter(|&j| active[j]).collect();
        let mut theta_act = Array2::<f64>::zeros((n, support.len()));
        let mut diag = Array1::<f64>::zeros(support.len());
        for (k, &j) in support.iter().enumerate() {
            theta_act.column_mut(k).assign(&theta.column(j));
            // Largest magnitude across the term's coefficients.
            let mag = xi.row(j).iter().fold(0.0_f64, |acc, v| {
                let a = v.abs();
                if a > acc {
                    a
                } else {
                    acc
                }
            });
            diag[k] = lqa_weight(kind, mag, lam, concave_a, lqa_eps);
        }

        // Refit on the support and scatter the result back into full rows.
        let xi_act = ridge_diag_solve(theta_act.view(), dz_dt, diag.view())?;
        xi.fill(0.0);
        for (k, &j) in support.iter().enumerate() {
            xi.row_mut(j).assign(&xi_act.row(k));
        }
    }

    // Final clean-up so the returned matrix never holds sub-threshold entries.
    xi.mapv_inplace(|v| if v.abs() < tol { 0.0 } else { v });

    Ok(SindyStlsqResult {
        coefficients: xi,
        rounds_used,
        converged,
    })
}
/// Ridge solve over every column of `theta` with one shared penalty.
///
/// The penalty is floored at `1e-12` and broadcast to a per-column weight
/// vector before delegating to `ridge_diag_solve`.
fn ridge_full_solve(
    theta: ArrayView2<'_, f64>,
    dz_dt: ArrayView2<'_, f64>,
    lam: f64,
) -> Result<Array2<f64>, String> {
    let floored = lam.max(1.0e-12);
    let uniform = Array1::<f64>::from_elem(theta.ncols(), floored);
    ridge_diag_solve(theta, dz_dt, uniform.view())
}
/// Solves the penalized normal equations
/// `(thetaᵀ theta + diag(w)) X = thetaᵀ dz_dt` for `X`.
///
/// Each weight in `diag` is floored at `1e-12` before being added to the
/// Gram diagonal. The system is factorized with the crate's Cholesky wrapper
/// and solved in place on the right-hand side.
///
/// # Errors
/// Returns a message if the Cholesky factorization fails.
fn ridge_diag_solve(
    theta: ArrayView2<'_, f64>,
    dz_dt: ArrayView2<'_, f64>,
    diag: ArrayView1<'_, f64>,
) -> Result<Array2<f64>, String> {
    let theta_t = theta.t();

    // Gram matrix with the floored weights added along the diagonal.
    let mut gram = theta_t.dot(&theta);
    for (i, g) in gram.diag_mut().iter_mut().enumerate() {
        *g += diag[i].max(1.0e-12);
    }

    // The right-hand side is built before factorizing, and later overwritten
    // with the solution.
    let mut solution = theta_t.dot(&dz_dt);
    let factor = gram
        .cholesky(Side::Lower)
        .map_err(|err| format!("sindy_stlsq_solve ridge Cholesky failed: {err}"))?;
    factor.solve_mat_in_place(&mut solution);
    Ok(solution)
}
/// Chooses `lam` from a fixed nine-point log grid by minimizing `bic_score`.
///
/// The grid is `max(tol, 1e-6) * 10^k` for `k = -4..=4`, visited from the
/// smallest value upward. Each candidate is fitted with
/// `sindy_stlsq_solve`; on ties the earlier (smaller) `lam` is kept.
///
/// # Returns
/// The selected `lam` together with its fit.
///
/// # Errors
/// Returns a message when `theta` has no rows, and propagates the first
/// error from `sindy_stlsq_solve`.
pub fn sindy_stlsq_auto_lam(
    theta: ArrayView2<'_, f64>,
    dz_dt: ArrayView2<'_, f64>,
    tol: f64,
    max_rounds: usize,
    kind: SindyPenaltyKind,
    concave_a: f64,
) -> Result<(f64, SindyStlsqResult), String> {
    if theta.nrows() == 0 {
        return Err(String::from("sindy_stlsq_auto_lam requires n > 0"));
    }

    let base = tol.max(1.0e-6);
    // (lam, bic, fit) of the best candidate seen so far.
    let mut incumbent: Option<(f64, f64, SindyStlsqResult)> = None;

    for step in 0..9 {
        let lam = base * 10f64.powf((step as f64) - 4.0);
        let fit = sindy_stlsq_solve(theta, dz_dt, tol, max_rounds, lam, kind, concave_a)?;
        let score = bic_score(theta, dz_dt, &fit.coefficients);

        // Strict improvement only, so the first of equal scores wins.
        let improves = incumbent
            .as_ref()
            .map_or(true, |(_, best_score, _)| score < *best_score);
        if improves {
            incumbent = Some((lam, score, fit));
        }
    }

    incumbent
        .map(|(lam, _score, fit)| (lam, fit))
        .ok_or_else(|| "sindy_stlsq_auto_lam: empty grid produced no candidates".to_string())
}
fn bic_score(theta: ArrayView2<'_, f64>, dz_dt: ArrayView2<'_, f64>, xi: &Array2<f64>) -> f64 {
let (n, _p) = theta.dim();
let d = dz_dt.ncols();
let resid = &theta.dot(xi) - &dz_dt;
let mut bic = 0.0_f64;
let n_f = n as f64;
for c in 0..d {
let r = resid.index_axis(Axis(1), c);
let rss = r.iter().map(|&x| x * x).sum::<f64>().max(1.0e-300);
let k_active = xi.column(c).iter().filter(|&&v| v != 0.0).count() as f64;
bic += n_f * (rss / n_f).ln() + k_active * n_f.ln();
}
bic
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Noise-free linear dynamics `dz0 = 2x - 3y`, `dz1 = -x` with library
    /// `[1, x, y]`: the solver must recover the three non-zero coefficients
    /// and drop the constant term and the `y -> dz1` entry.
    #[test]
    fn stlsq_recovers_pure_linear_system() {
        const SAMPLES: usize = 200;

        // Deterministic LCG so the test needs no external RNG.
        // NOTE(review): `>> 33` keeps 31 bits while the divisor is `u32::MAX`,
        // so samples fall in [-0.5, 0) rather than a centred interval —
        // confirm whether that is intended.
        let mut state = 0u64;
        let mut next_sample = || {
            state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
            ((state >> 33) as f64) / (u32::MAX as f64) - 0.5
        };

        let mut theta = Array2::<f64>::zeros((SAMPLES, 3));
        let mut dz = Array2::<f64>::zeros((SAMPLES, 2));
        for (mut lib_row, mut rate_row) in theta.outer_iter_mut().zip(dz.outer_iter_mut()) {
            let x = next_sample();
            let y = next_sample();
            lib_row[0] = 1.0;
            lib_row[1] = x;
            lib_row[2] = y;
            rate_row[0] = 2.0 * x - 3.0 * y;
            rate_row[1] = -x;
        }

        let res = sindy_stlsq_solve(
            theta.view(),
            dz.view(),
            0.05,
            10,
            1.0e-3,
            SindyPenaltyKind::Ridge,
            3.7,
        )
        .expect("stlsq must succeed");

        assert!(res.converged);

        // (row, column, expected value) for the terms that must be recovered.
        let recovered: [(usize, usize, f64); 3] = [(1, 0, 2.0), (2, 0, -3.0), (1, 1, -1.0)];
        for &(row, col, expected) in &recovered {
            assert!((res.coefficients[(row, col)] - expected).abs() < 0.05);
        }

        // Terms absent from the true dynamics must be pruned.
        let pruned: [(usize, usize); 2] = [(0, 0), (2, 1)];
        for &(row, col) in &pruned {
            assert!(res.coefficients[(row, col)].abs() < 1.0e-6);
        }
    }
}