use crate::DType;
use numr::error::Result;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Tunable settings handed to the solvers declared on
/// [`LeastSquaresAlgorithms`].
///
/// The `Default` implementation yields `max_iter = 100` and `1e-8` for each of
/// the four floating-point fields.
///
/// NOTE(review): the code that consumes these fields is not in this file, so
/// the per-field descriptions below are inferred from the field names and
/// should be confirmed against the solver implementations.
#[derive(Clone, Debug)]
pub struct LeastSquaresOptions {
    /// Iteration budget; presumably an upper bound on solver iterations.
    pub max_iter: usize,
    /// Presumably the tolerance on the change of the cost function — confirm.
    pub f_tol: f64,
    /// Presumably the tolerance on the change of the parameters — confirm.
    pub x_tol: f64,
    /// Presumably the tolerance on the gradient norm — confirm.
    pub g_tol: f64,
    /// Presumably a finite-difference step size — confirm.
    pub eps: f64,
}
impl Default for LeastSquaresOptions {
fn default() -> Self {
Self {
max_iter: 100,
f_tol: 1e-8,
x_tol: 1e-8,
g_tol: 1e-8,
eps: 1e-8,
}
}
}
/// Outcome record returned by both methods of [`LeastSquaresAlgorithms`].
///
/// NOTE(review): no code in this file populates these fields, so the
/// descriptions below are inferred from the field names and types — confirm
/// against the solver implementations.
#[derive(Clone, Debug)]
pub struct LeastSquaresTensorResult<R>
where
    R: Runtime<DType = DType>,
{
    /// Parameter tensor reported by the solver (presumably the final estimate).
    pub x: Tensor<R>,
    /// Residual tensor (presumably the closure's output evaluated at `x`).
    pub residuals: Tensor<R>,
    /// Scalar cost associated with the result; the exact formula is not
    /// visible here — TODO confirm.
    pub cost: f64,
    /// Iteration count (presumably the number of iterations performed).
    pub iterations: usize,
    /// Presumably the number of times the user closure was evaluated.
    pub nfev: usize,
    /// Convergence flag; presumably `true` when a tolerance was met before
    /// the iteration budget ran out — confirm.
    pub converged: bool,
}
/// Least-squares solver interface over tensors living on runtime `R`.
///
/// Both methods take a fallible closure mapping a tensor to a tensor, a
/// starting tensor `x0`, and a [`LeastSquaresOptions`], and hand back a
/// [`LeastSquaresTensorResult`]. Only the signatures are declared here; the
/// algorithm used is up to each implementor.
///
/// NOTE(review): the method names appear to mirror SciPy's `leastsq` and
/// `least_squares`; confirm the intended semantics with the implementations.
pub trait LeastSquaresAlgorithms<R>
where
    R: Runtime<DType = DType>,
{
    /// Solves a least-squares problem without bound arguments.
    ///
    /// * `f` - closure evaluated on a parameter tensor; presumably returns the
    ///   residual tensor.
    /// * `x0` - starting tensor (presumably the initial parameter guess).
    /// * `options` - solver settings.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports. `f` is itself
    /// fallible, so implementations can presumably forward its errors.
    fn leastsq<F>(
        &self,
        f: F,
        x0: &Tensor<R>,
        options: &LeastSquaresOptions,
    ) -> Result<LeastSquaresTensorResult<R>>
    where
        F: Fn(&Tensor<R>) -> Result<Tensor<R>>;

    /// Solves a least-squares problem with an optional pair of bound tensors.
    ///
    /// * `f` - closure evaluated on a parameter tensor; presumably returns the
    ///   residual tensor.
    /// * `x0` - starting tensor (presumably the initial parameter guess).
    /// * `bounds` - optional pair of tensors; presumably `(lower, upper)`
    ///   limits on the parameters — TODO confirm the ordering.
    /// * `options` - solver settings.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports. `f` is itself
    /// fallible, so implementations can presumably forward its errors.
    fn least_squares<F>(
        &self,
        f: F,
        x0: &Tensor<R>,
        bounds: Option<(&Tensor<R>, &Tensor<R>)>,
        options: &LeastSquaresOptions,
    ) -> Result<LeastSquaresTensorResult<R>>
    where
        F: Fn(&Tensor<R>) -> Result<Tensor<R>>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Verifies the iteration cap and `f_tol` produced by the `Default`
    /// implementation of `LeastSquaresOptions`.
    #[test]
    fn test_default_options() {
        let opts: LeastSquaresOptions = Default::default();

        // Integer field: exact comparison.
        assert_eq!(opts.max_iter, 100);
        // Float field: compared with a small absolute tolerance.
        assert!((opts.f_tol - 1e-8).abs() < 1e-12);
    }
}