use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use scirs2_symbolic::eml::eval::{eval_real, EvalCtx};
use scirs2_symbolic::eml::LoweredOp;
use std::sync::Arc;
/// Errors returned by [`init_weights_from_formula`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InitFromFormulaError {
    /// Evaluating the formula at a sample point failed; carries a textual
    /// description of the underlying failure.
    EvalError(String),
    /// `sample_grid` had zero rows or zero columns.
    EmptyGrid,
    /// The least-squares solve for the weights did not succeed.
    SingularDesignMatrix,
    /// Building an intermediate array failed; carries the underlying message.
    ShapeError(String),
}
impl std::fmt::Display for InitFromFormulaError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
InitFromFormulaError::EvalError(msg) => {
write!(f, "formula evaluation error: {msg}")
}
InitFromFormulaError::EmptyGrid => {
write!(f, "sample_grid has zero rows or zero columns")
}
InitFromFormulaError::SingularDesignMatrix => {
write!(
f,
"design matrix is singular — cannot compute least-squares weights"
)
}
InitFromFormulaError::ShapeError(msg) => {
write!(f, "internal shape error: {msg}")
}
}
}
}
impl std::error::Error for InitFromFormulaError {}
/// Fits an affine map `y ≈ w · x + b` to `formula` by ordinary least squares.
///
/// Each row of `sample_grid` is treated as one input point with `n_inputs`
/// coordinates (`n_inputs` = number of columns). The formula is evaluated at
/// every row to produce the targets. The weights and bias are then solved for
/// using the augmented design matrix `[x | 1]`, whose trailing column of ones
/// carries the intercept.
///
/// The grid values themselves are passed through as-is; non-finite entries in
/// `sample_grid` are not checked here.
///
/// # Returns
///
/// `(w, b)`, where `w` has shape `(1, n_inputs)` and `b` has length 1.
///
/// # Errors
///
/// * [`InitFromFormulaError::EmptyGrid`] if `sample_grid` has zero rows or
///   zero columns.
/// * [`InitFromFormulaError::EvalError`] if evaluating the formula fails at any
///   sample, or if it yields a non-finite value (NaN or ±inf).
/// * [`InitFromFormulaError::SingularDesignMatrix`] if `scirs2_linalg::lstsq`
///   returns an error (every solver error is mapped to this variant).
/// * [`InitFromFormulaError::ShapeError`] if an intermediate array cannot be
///   built, or the solver returns a solution of unexpected length.
pub fn init_weights_from_formula(
    formula: &Arc<LoweredOp>,
    sample_grid: ArrayView2<'_, f64>,
) -> Result<(Array2<f64>, Array1<f64>), InitFromFormulaError> {
    let (n_samples, n_inputs) = sample_grid.dim();
    if n_samples == 0 || n_inputs == 0 {
        return Err(InitFromFormulaError::EmptyGrid);
    }

    let n_aug = n_inputs + 1;
    let mut y_vec: Vec<f64> = Vec::with_capacity(n_samples);
    let mut x_data: Vec<f64> = Vec::with_capacity(n_samples * n_aug);
    // One contiguous buffer reused for every sample instead of a fresh Vec per
    // row; the evaluation context borrows it only for the current iteration.
    let mut row: Vec<f64> = Vec::with_capacity(n_inputs);

    // Single pass over the grid: evaluate the target and emit the matching
    // design-matrix row `[x_0, ..., x_{n-1}, 1]`.
    for (i, sample) in sample_grid.outer_iter().enumerate() {
        row.clear();
        row.extend(sample.iter().copied());

        let ctx = EvalCtx::new(&row);
        let val = eval_real(formula, &ctx)
            .map_err(|e| InitFromFormulaError::EvalError(format!("{e:?}")))?;
        // A NaN/inf target would silently poison every fitted weight, or show
        // up later as a misleading solver ("singular") failure.
        if !val.is_finite() {
            return Err(InitFromFormulaError::EvalError(format!(
                "formula produced non-finite value {val} at sample row {i}"
            )));
        }
        y_vec.push(val);

        x_data.extend_from_slice(&row);
        x_data.push(1.0);
    }

    let x_aug = Array2::from_shape_vec((n_samples, n_aug), x_data)
        .map_err(|e| InitFromFormulaError::ShapeError(e.to_string()))?;
    let y_arr = Array1::from_vec(y_vec);

    // NOTE(review): any error from the solver is reported as
    // `SingularDesignMatrix`; the underlying error detail is discarded.
    let lstsq_result = scirs2_linalg::lstsq(&x_aug.view(), &y_arr.view(), None)
        .map_err(|_| InitFromFormulaError::SingularDesignMatrix)?;
    let theta: Array1<f64> = lstsq_result.x;

    // theta is laid out as [w_0, ..., w_{n_inputs-1}, b]. Validate the length
    // rather than panicking on the bias index below.
    if theta.len() != n_aug {
        return Err(InitFromFormulaError::ShapeError(format!(
            "least-squares solution has length {}, expected {n_aug}",
            theta.len()
        )));
    }

    let w_vec: Vec<f64> = theta.iter().take(n_inputs).copied().collect();
    let w = Array2::from_shape_vec((1, n_inputs), w_vec)
        .map_err(|e| InitFromFormulaError::ShapeError(e.to_string()))?;
    let b = Array1::from_vec(vec![theta[n_inputs]]);
    Ok((w, b))
}