use crate::error::{InterpolateError, InterpolateResult};
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::safe_ops::{safe_divide, safe_sqrt};
use std::fmt::{Debug, Display};
/// Estimates how well an interpolation scheme generalizes, via leave-one-out cross-validation.
///
/// For every sample `i`, `interp_fn` is called with all samples except `i` as training data
/// and `[x[i]]` as the query; the first returned value is compared with `y[i]`. The result is
/// the root-mean-square of those `n` prediction errors.
///
/// # Arguments
/// * `x`, `y` - sample coordinates and values; must be the same length, at least 3.
/// * `interp_fn` - called as `interp_fn(x_train, y_train, x_query)`; expected to return one
///   prediction per query point.
///
/// # Errors
/// * `invalid_input` if `x` and `y` differ in length.
/// * `insufficient_points` if fewer than 3 samples are supplied.
/// * `ComputationError` if `interp_fn` returns an empty prediction array, or if the final
///   conversion / division / square-root step fails.
/// * Any error returned by `interp_fn` is propagated unchanged.
#[allow(dead_code)]
pub fn error_estimate<F, Func>(
    x: &ArrayView1<F>,
    y: &ArrayView1<F>,
    interp_fn: Func,
) -> InterpolateResult<F>
where
    F: Float + FromPrimitive + Debug + Display,
    Func: Fn(&ArrayView1<F>, &ArrayView1<F>, &ArrayView1<F>) -> InterpolateResult<Array1<F>>,
{
    if x.len() != y.len() {
        return Err(InterpolateError::invalid_input(
            "x and y arrays must have the same length",
        ));
    }
    if x.len() < 3 {
        return Err(InterpolateError::insufficient_points(
            3,
            x.len(),
            "interpolation error estimation",
        ));
    }
    let n = x.len();
    let mut sum_squared_error = F::zero();
    for i in 0..n {
        // Training set: every sample except the held-out one, in original order.
        let (x_train, y_train): (Vec<F>, Vec<F>) = x
            .iter()
            .zip(y.iter())
            .enumerate()
            .filter(|&(j, _)| j != i)
            .map(|(_, (&xj, &yj))| (xj, yj))
            .unzip();
        let x_train = Array1::from_vec(x_train);
        let y_train = Array1::from_vec(y_train);
        let x_query = Array1::from_vec(vec![x[i]]);
        let y_pred = interp_fn(&x_train.view(), &y_train.view(), &x_query.view())?;
        // A misbehaving interpolator may return no values; report that instead of panicking
        // on an out-of-bounds index.
        let prediction = y_pred.iter().next().copied().ok_or_else(|| {
            InterpolateError::ComputationError(
                "Interpolator returned no prediction for the held-out point".to_string(),
            )
        })?;
        let error = prediction - y[i];
        sum_squared_error = sum_squared_error + error * error;
    }
    let n_f = F::from_usize(n).ok_or_else(|| {
        InterpolateError::ComputationError(
            "Failed to convert array length to float type".to_string(),
        )
    })?;
    let variance = safe_divide(sum_squared_error, n_f).map_err(|_| {
        InterpolateError::ComputationError("Division by zero in RMSE calculation".to_string())
    })?;
    let rmse = safe_sqrt(variance).map_err(|_| {
        InterpolateError::ComputationError(
            "Square root of negative value in RMSE calculation".to_string(),
        )
    })?;
    Ok(rmse)
}
#[allow(dead_code)]
pub fn optimize_parameter<F, Func, BuilderFunc>(
x: &ArrayView1<F>,
y: &ArrayView1<F>,
param_values: &ArrayView1<F>,
interp_fn_builder: BuilderFunc,
) -> InterpolateResult<F>
where
F: Float + FromPrimitive + Debug + Display,
Func: Fn(&ArrayView1<F>, &ArrayView1<F>, &ArrayView1<F>) -> InterpolateResult<Array1<F>>,
BuilderFunc: Fn(F) -> Func,
{
if param_values.is_empty() {
return Err(InterpolateError::invalid_input(
"at least one parameter value must be provided",
));
}
let mut best_param = param_values[0];
let mut min_error = F::infinity();
for ¶m in param_values.iter() {
let interp_fn = interp_fn_builder(param);
let error = error_estimate(x, y, interp_fn)?;
if error < min_error {
min_error = error;
best_param = param;
}
}
Ok(best_param)
}
/// Approximates the first derivative of `evalfn` at `x` with a central difference.
///
/// Returns `(f(x + h) - f(x - h)) / (2h)`, which is second-order accurate in `h` for
/// smooth functions.
///
/// # Errors
/// * Any error from `evalfn` (evaluated at `x + h`, then at `x - h`) is propagated.
/// * `ComputationError` if the constant `2` cannot be represented in `F`, or if the
///   division by `2h` is rejected (the message attributes this to a too-small step).
#[allow(dead_code)]
pub fn differentiate<F, Func>(x: F, h: F, evalfn: Func) -> InterpolateResult<F>
where
    F: Float + FromPrimitive + Debug + Display,
    Func: Fn(F) -> InterpolateResult<F>,
{
    let forward = evalfn(x + h)?;
    let backward = evalfn(x - h)?;
    let Some(two) = F::from_f64(2.0) else {
        return Err(InterpolateError::ComputationError(
            "Failed to convert constant 2.0 to float type".to_string(),
        ));
    };
    let rise = forward - backward;
    let run = two * h;
    safe_divide(rise, run).map_err(|_| {
        InterpolateError::ComputationError(
            "Division by zero in finite difference calculation (step size too small)".to_string(),
        )
    })
}
/// Approximates the definite integral of `evalfn` over `[a, b]` with composite Simpson's rule.
///
/// The interval is split into `n` equal steps of width `h = (b - a) / n`. Endpoints get
/// weight 1, interior nodes with even index weight 2, and those with odd index weight 4.
/// The weighted sum is multiplied by `h / 3`. When `a > b` the bounds are swapped and the
/// result is negated.
///
/// # Errors
/// * `InvalidValue` if `n < 2` or `n` is odd.
/// * `ComputationError` if a numeric conversion or division fails.
/// * Any error returned by `evalfn` is propagated unchanged.
#[allow(dead_code)]
pub fn integrate<F, Func>(a: F, b: F, n: usize, evalfn: Func) -> InterpolateResult<F>
where
    F: Float + FromPrimitive + Debug + Display,
    Func: Fn(F) -> InterpolateResult<F>,
{
    if a > b {
        // Integrate over the ordered interval and flip the sign.
        return integrate(b, a, n, evalfn).map(|result| -result);
    }
    let invalid = match n {
        0 | 1 => Some("number of intervals must be at least 2"),
        _ if !n.is_multiple_of(2) => Some("number of intervals must be even"),
        _ => None,
    };
    if let Some(message) = invalid {
        return Err(InterpolateError::InvalidValue(message.to_string()));
    }

    let intervals = F::from_usize(n).ok_or_else(|| {
        InterpolateError::ComputationError(
            "Failed to convert number of intervals to float type".to_string(),
        )
    })?;
    let step = safe_divide(b - a, intervals).map_err(|_| {
        InterpolateError::ComputationError(
            "Division by zero in step size calculation (zero intervals)".to_string(),
        )
    })?;

    // Converts a Simpson weight into `F`, reporting `message` on failure.
    let constant = |value: f64, message: &str| {
        F::from_f64(value).ok_or_else(|| InterpolateError::ComputationError(message.to_string()))
    };
    // Abscissa of the `index`-th node.
    let node = |index: usize| {
        F::from_usize(index)
            .map(|offset| a + offset * step)
            .ok_or_else(|| {
                InterpolateError::ComputationError(
                    "Failed to convert index to float type".to_string(),
                )
            })
    };

    let mut weighted = evalfn(a)? + evalfn(b)?;
    let two = constant(2.0, "Failed to convert constant 2.0 to float type")?;
    for index in (2..n).step_by(2) {
        weighted = weighted + two * evalfn(node(index)?)?;
    }
    let four = constant(4.0, "Failed to convert constant 4.0 to float type")?;
    for index in (1..n).step_by(2) {
        weighted = weighted + four * evalfn(node(index)?)?;
    }
    let three = constant(3.0, "Failed to convert constant 3.0 to float type")?;
    safe_divide(step * weighted, three).map_err(|_| {
        InterpolateError::ComputationError(
            "Division by zero in Simpson's rule calculation".to_string(),
        )
    })
}
#[allow(dead_code)]
pub fn find_roots_bisection<F, Func>(
a: F,
b: F,
tolerance: F,
evalfn: Func,
) -> InterpolateResult<Vec<F>>
where
F: Float + FromPrimitive + Debug + Display,
Func: Fn(F) -> InterpolateResult<F>,
{
let mut roots = Vec::new();
if a >= b {
return Ok(roots);
}
let fa = evalfn(a)?;
let fb = evalfn(b)?;
if fa.abs() < tolerance {
roots.push(a);
}
if fb.abs() < tolerance && (b - a).abs() > tolerance {
roots.push(b);
}
if fa * fb > F::zero() {
return Ok(roots);
}
let mut left = a;
let mut right = b;
let mut f_left = fa;
let mut _f_right = fb;
while (right - left).abs() > tolerance {
let mid = left + (right - left) / F::from_f64(2.0).expect("Operation failed");
let f_mid = evalfn(mid)?;
if f_mid.abs() < tolerance {
roots.push(mid);
break;
}
if f_left * f_mid < F::zero() {
right = mid;
_f_right = f_mid;
} else {
left = mid;
f_left = f_mid;
}
}
if roots.is_empty() {
let root = left + (right - left) / F::from_f64(2.0).expect("Operation failed");
let f_root = evalfn(root)?;
if f_root.abs() < tolerance * F::from_f64(10.0).expect("Operation failed") {
roots.push(root);
}
}
Ok(roots)
}
#[allow(dead_code)]
pub fn find_multiple_roots<F, Func>(
a: F,
b: F,
tolerance: F,
subdivisions: usize,
evalfn: Func,
) -> InterpolateResult<Vec<F>>
where
F: Float + FromPrimitive + Debug + Display,
Func: Fn(F) -> InterpolateResult<F> + Copy,
{
let mut all_roots = Vec::new();
if subdivisions == 0 {
return Ok(all_roots);
}
let step = (b - a) / F::from_usize(subdivisions).expect("Operation failed");
for i in 0..subdivisions {
let left = a + F::from_usize(i).expect("Operation failed") * step;
let right = a + F::from_usize(i + 1).expect("Operation failed") * step;
match find_roots_bisection(left, right, tolerance, evalfn) {
Ok(mut roots) => all_roots.append(&mut roots),
Err(_) => continue,
}
}
all_roots.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
all_roots.dedup_by(|a, b| (*a - *b).abs() < tolerance);
Ok(all_roots)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Pins a closure's signature to the interpolator shape the estimators expect, so the
    /// compiler infers higher-ranked lifetimes for its reference arguments.
    fn as_interp<Func>(f: Func) -> Func
    where
        Func: Fn(
            &ArrayView1<f64>,
            &ArrayView1<f64>,
            &ArrayView1<f64>,
        ) -> InterpolateResult<Array1<f64>>,
    {
        f
    }

    #[test]
    fn test_error_estimate() {
        let x: Array1<f64> = array![0.0, 1.0, 2.0, 3.0, 4.0];
        let y: Array1<f64> = array![0.0, 1.0, 2.0, 3.0, 4.0];
        // Predict the mean of the training values everywhere. Holding out y = 0..4 gives
        // errors 2.5, 1.25, 0, -1.25, -2.5, so RMSE = sqrt(15.625 / 5) = sqrt(3.125).
        let mean_model = as_interp(|_xs, ys, xq| {
            let mean = ys.sum() / ys.len() as f64;
            Ok(Array1::from_elem(xq.len(), mean))
        });
        let rmse = error_estimate(&x.view(), &y.view(), mean_model).expect("Operation failed");
        assert!((rmse - 3.125_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn test_error_estimate_rejects_bad_input() {
        let short: Array1<f64> = array![0.0, 1.0];
        let longer: Array1<f64> = array![0.0, 1.0, 2.0];
        let zeros = || as_interp(|_xs, _ys, xq| Ok(Array1::zeros(xq.len())));
        // Mismatched lengths.
        assert!(error_estimate(&short.view(), &longer.view(), zeros()).is_err());
        // Fewer than three samples.
        assert!(error_estimate(&short.view(), &short.view(), zeros()).is_err());
    }

    #[test]
    fn test_optimize_parameter() {
        let x: Array1<f64> = array![0.0, 1.0, 2.0, 3.0, 4.0];
        let y: Array1<f64> = array![2.0, 2.0, 2.0, 2.0, 2.0];
        let params: Array1<f64> = array![0.0, 1.0, 2.0, 3.0];
        // A constant model predicting `p` has LOO RMSE |p - 2| on this data.
        let constant_model =
            |p: f64| as_interp(move |_xs, _ys, xq| Ok(Array1::from_elem(xq.len(), p)));
        let best = optimize_parameter(&x.view(), &y.view(), &params.view(), constant_model)
            .expect("Operation failed");
        assert_eq!(best, 2.0);
        let none: Array1<f64> = Array1::zeros(0);
        assert!(optimize_parameter(&x.view(), &y.view(), &none.view(), constant_model).is_err());
    }

    #[test]
    fn test_differentiate() {
        let f = |x: f64| -> InterpolateResult<f64> { Ok(x * x) };
        let derivative = differentiate(2.0, 0.001, f).expect("Operation failed");
        assert!((derivative - 4.0).abs() < 1e-5);
        let derivative = differentiate(3.0, 0.001, f).expect("Operation failed");
        assert!((derivative - 6.0).abs() < 1e-5);
    }

    #[test]
    fn test_integrate() {
        let f = |x: f64| -> InterpolateResult<f64> { Ok(x * x) };
        let integral = integrate(0.0, 1.0, 100, f).expect("Operation failed");
        assert!((integral - 1.0 / 3.0).abs() < 1e-5);
        let integral = integrate(0.0, 2.0, 100, f).expect("Operation failed");
        assert!((integral - 8.0 / 3.0).abs() < 1e-5);
        // Reversed bounds flip the sign; odd or too-small interval counts are rejected.
        let reversed = integrate(1.0, 0.0, 100, f).expect("Operation failed");
        assert!((reversed + 1.0 / 3.0).abs() < 1e-5);
        assert!(integrate(0.0, 1.0, 3, f).is_err());
        assert!(integrate(0.0, 1.0, 0, f).is_err());
    }

    #[test]
    fn test_find_roots_bisection() {
        let f = |x: f64| -> InterpolateResult<f64> { Ok(x * x - 2.0) };
        let roots = find_roots_bisection(0.0, 2.0, 1e-9, f).expect("Operation failed");
        assert_eq!(roots.len(), 1);
        assert!((roots[0] - std::f64::consts::SQRT_2).abs() < 1e-8);
        // No sign change and no endpoint root: nothing is reported.
        let roots = find_roots_bisection(3.0, 4.0, 1e-9, f).expect("Operation failed");
        assert!(roots.is_empty());
    }

    #[test]
    fn test_find_multiple_roots() {
        let f = |x: f64| -> InterpolateResult<f64> { Ok(x.sin()) };
        let roots = find_multiple_roots(0.0, 7.0, 1e-8, 10, f).expect("Operation failed");
        let expected = [0.0, std::f64::consts::PI, 2.0 * std::f64::consts::PI];
        assert_eq!(roots.len(), expected.len());
        for (root, want) in roots.iter().zip(expected.iter()) {
            assert!((root - want).abs() < 1e-7);
        }
        let none = find_multiple_roots(0.0, 7.0, 1e-8, 0, f).expect("Operation failed");
        assert!(none.is_empty());
    }
}