use nalgebra::{DMatrix, DVector};
use std::error::Error; use std::f64;
use std::fmt;
/// Error returned by the curve-fitting functions in this module.
///
/// Holds a human-readable description; its `Display` output is
/// `"Fitting Error: <description>"`.
#[derive(Debug)]
struct FittingError(String);

impl fmt::Display for FittingError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "Fitting Error: {}", self.0)
    }
}

impl Error for FittingError {}

/// Result of [`fit_quadratic_least_squares`].
///
/// The fitted model is `y = a * u^2 + b * u + c` with `u = x - x_coords[n / 2]`, i.e. the
/// abscissa is shifted so that the middle sample sits at `u = 0`. `a` does not depend on that
/// shift, but `b` and `c` are expressed in the shifted coordinate `u`, not in `x`.
#[derive(Debug, Default)]
// NOTE(review): presumably silences unused-field warnings where the fields are never read
// inside this crate; confirm before removing.
#[allow(dead_code)]
pub struct QuadraticFitResult {
    /// Abscissa of the parabola's vertex (its maximum), in the original `x` coordinate.
    pub peak_x: f64,
    /// Quadratic coefficient (curvature).
    pub a: f64,
    /// Linear coefficient with respect to the shifted coordinate `u`.
    pub b: f64,
    /// Constant term: the fitted value at `u = 0`, i.e. at `x = x_coords[n / 2]`.
    pub c: f64,
}

/// Fits a parabola to the points `(x_coords[i], y_values[i])` by least squares and returns
/// the position of its maximum.
///
/// The abscissae are shifted by `x_coords[n / 2]` and divided by the largest resulting
/// `|x - x_coords[n / 2]|`, so the 3x3 normal equations are solved for a dimensionless
/// coordinate in `[-1, 1]`. This makes the singularity check independent of the units and
/// spacing of `x` (finely spaced samples, e.g. `0.001` apart, are not mistaken for a singular
/// system) and improves the conditioning of the solve. The coefficients are converted back to
/// the shifted coordinate before being returned; see [`QuadraticFitResult`].
///
/// # Errors
///
/// Returns a `FittingError` if:
/// - the slices differ in length or hold fewer than 3 points;
/// - any input value is NaN or infinite, or the `x` range overflows;
/// - the normal equations are singular or ill-conditioned (e.g. fewer than 3 distinct `x`);
/// - the quadratic term's contribution over the sampled `x` range is at most `1e-9` times the
///   largest `|y|` (the data is effectively linear);
/// - the curvature is positive (the parabola opens upward and has no maximum);
/// - the resulting coefficients are not finite.
pub fn fit_quadratic_least_squares(
    x_coords: &[f64],
    y_values: &[f64],
) -> Result<QuadraticFitResult, Box<dyn Error>> {
    // Relative tolerance; every check below compares dimensionless or y-scaled quantities.
    const EPSILON: f64 = 1e-9;

    let n = x_coords.len();
    if n < 3 || n != y_values.len() {
        return Err(Box::new(FittingError(
            "Input vectors must be of the same size and contain at least 3 points for least squares."
                .to_string(),
        )));
    }
    // NaN would otherwise slip through every comparison below and yield Ok(NaN).
    if x_coords.iter().chain(y_values).any(|v| !v.is_finite()) {
        return Err(Box::new(FittingError(
            "Input values must be finite (no NaN or infinity).".to_string(),
        )));
    }

    let singular = |d: f64| -> Box<dyn Error> {
        Box::new(FittingError(format!(
            "Denominator D ({}) is almost zero. Matrix is singular or ill-conditioned.",
            d
        )))
    };

    // Shift so the middle sample is at 0, then scale so that max |u| == 1.
    let x_center = x_coords[n / 2];
    let shifted_x_coords: Vec<f64> = x_coords.iter().map(|&x| x - x_center).collect();
    let x_scale = shifted_x_coords
        .iter()
        .fold(0.0_f64, |acc, &dx| acc.max(dx.abs()));
    if !x_scale.is_finite() {
        return Err(Box::new(FittingError(
            "Range of x coordinates overflows; cannot fit.".to_string(),
        )));
    }
    if x_scale == 0.0 {
        // All x are identical: the normal matrix has rank 1.
        return Err(singular(0.0));
    }

    // Power sums of u (s_k = sum of u^k) and data moments (t_k = sum of u^k * y).
    let s0 = n as f64;
    let (mut s1, mut s2, mut s3, mut s4) = (0.0_f64, 0.0_f64, 0.0_f64, 0.0_f64);
    let (mut t0, mut t1, mut t2) = (0.0_f64, 0.0_f64, 0.0_f64);
    for (&dx, &y) in shifted_x_coords.iter().zip(y_values) {
        let u = dx / x_scale;
        let u_sq = u * u;
        s1 += u;
        s2 += u_sq;
        s3 += u_sq * u;
        s4 += u_sq * u_sq;
        t0 += y;
        t1 += u * y;
        t2 += u_sq * y;
    }

    // Determinant of the normal matrix [s0 s1 s2; s1 s2 s3; s2 s3 s4]. Since |u| <= 1, each
    // diagonal entry is at most n and, in exact arithmetic, 0 <= D <= n^3 (Hadamard's
    // inequality), so D / n^3 is a scale-free measure of how close the system is to singular.
    let d = s0 * (s2 * s4 - s3 * s3) - s1 * (s1 * s4 - s2 * s3) + s2 * (s1 * s3 - s2 * s2);
    if !d.is_finite() || d.abs() < EPSILON * s0.powi(3) {
        return Err(singular(d));
    }

    // Cramer's rule for the unknowns [c, b, a] in the normalized coordinate u.
    let dc_num = t0 * (s2 * s4 - s3 * s3) - s1 * (t1 * s4 - t2 * s3) + s2 * (t1 * s3 - t2 * s2);
    let db_num = s0 * (t1 * s4 - t2 * s3) - t0 * (s1 * s4 - s2 * s3) + s2 * (s1 * t2 - s2 * t1);
    let da_num = s0 * (s2 * t2 - s3 * t1) - s1 * (s1 * t2 - s2 * t1) + t0 * (s1 * s3 - s2 * s2);
    let a_norm = da_num / d;
    let b_norm = db_num / d;
    let c = dc_num / d;

    // a_norm is the quadratic term's largest contribution over |u| <= 1; rounding noise in it
    // grows with the magnitude of y, so judge it relative to the largest |y|.
    let y_scale = y_values.iter().fold(0.0_f64, |acc, &y| acc.max(y.abs()));
    if a_norm.abs() <= EPSILON * y_scale {
        return Err(Box::new(FittingError(
            "Coefficient 'a' is almost zero. Quadratic function is degenerate.".to_string(),
        )));
    }
    if a_norm > 0.0 {
        return Err(Box::new(FittingError("Coefficient 'a' is positive. Quadratic function is convex downwards, no maximum exists.".to_string())));
    }

    // Undo the scaling u = (x - x_center) / x_scale. The vertex is computed from the
    // normalized coefficients so it does not depend on x_scale^2 staying representable.
    let a = a_norm / (x_scale * x_scale);
    let b = b_norm / x_scale;
    let peak_x = x_center - x_scale * b_norm / (2.0 * a_norm);
    if ![peak_x, a, b, c].iter().all(|v| v.is_finite()) {
        return Err(Box::new(FittingError(
            "Fitted coefficients are not finite (numerical overflow).".to_string(),
        )));
    }
    Ok(QuadraticFitResult { peak_x, a, b, c })
}
/// Fits a straight line `y = m * x + c` to the points by ordinary least squares and returns
/// `(m, c)`.
///
/// Uses the centered (two-pass) formulas `m = sum((x - mean_x) * (y - mean_y)) /
/// sum((x - mean_x)^2)` and `c = mean_y - m * mean_x`. Unlike the textbook
/// `n * sum(x^2) - sum(x)^2` denominator, this does not suffer catastrophic cancellation when
/// the `x` values share a large offset (e.g. epoch timestamps), and the degeneracy check is
/// relative to the magnitude of `x` rather than an absolute constant.
///
/// # Errors
///
/// Returns an error if the slices differ in length or hold fewer than 2 points, if any value
/// is NaN or infinite, if all `x` values are equal to within rounding (the slope is
/// undefined), or if the result is not finite.
// NOTE(review): presumably allowed because nothing in this crate calls it yet; confirm.
#[allow(dead_code)]
pub fn fit_linear_least_squares(
    x_coords: &[f64],
    y_values: &[f64],
) -> Result<(f64, f64), Box<dyn Error>> {
    // Spread (relative to max |x|) that rounding alone can leave behind when all x are
    // equal; roughly 64 ulps of f64.
    const RELATIVE_NOISE: f64 = 1e-14;

    if x_coords.len() != y_values.len() || x_coords.len() < 2 {
        return Err("Input vectors must be of the same size and contain at least 2 points.".into());
    }
    if x_coords.iter().chain(y_values).any(|v| !v.is_finite()) {
        return Err("Input values must be finite (no NaN or infinity).".into());
    }

    let n = x_coords.len() as f64;
    let mean_x = x_coords.iter().sum::<f64>() / n;
    let mean_y = y_values.iter().sum::<f64>() / n;
    let (sxx, sxy) = x_coords
        .iter()
        .zip(y_values)
        .fold((0.0_f64, 0.0_f64), |(sxx, sxy), (&x, &y)| {
            let dx = x - mean_x;
            (sxx + dx * dx, sxy + dx * (y - mean_y))
        });

    // Even when every x is identical, rounding in mean_x can leave a tiny nonzero sxx, so
    // compare against the spread rounding could produce instead of testing for exact zero.
    let x_scale = x_coords.iter().fold(0.0_f64, |acc, &x| acc.max(x.abs()));
    let noise = RELATIVE_NOISE * x_scale;
    if sxx <= n * noise * noise {
        return Err("Denominator is zero, cannot fit a line.".into());
    }

    let m = sxy / sxx;
    let c = mean_y - m * mean_x;
    if !(m.is_finite() && c.is_finite()) {
        return Err("Fitted line is not finite (numerical overflow).".into());
    }
    Ok((m, c))
}
/// Fits a polynomial of the given `degree` to the points by linear least squares.
///
/// Builds the Vandermonde design matrix `A` whose row `i` is `[1, x_i, x_i^2, ..., x_i^degree]`
/// and solves the normal equations `A^T A * p = A^T y` with an LU decomposition.
///
/// Returns the coefficients in ascending order of power, `[p0, p1, ..., p_degree]`, so the
/// fitted model is `y = p0 + p1 * x + ... + p_degree * x^degree`.
///
/// Note: forming `A^T A` squares the condition number of `A`, so high degrees or `x` values
/// far from zero can lose precision.
///
/// # Errors
///
/// Returns a `FittingError` if there are fewer than `degree + 1` points, if the slices differ
/// in length, or if the LU solve reports that the normal matrix cannot be inverted.
pub fn fit_polynomial_least_squares(
    x_coords: &[f64],
    y_values: &[f64],
    degree: usize,
) -> Result<Vec<f64>, Box<dyn Error>> {
    let n = x_coords.len();
    if n <= degree {
        return Err(Box::new(FittingError(format!(
            "Not enough data points ({}) for a polynomial of degree {}. Need at least {} points.",
            n,
            degree,
            degree + 1
        ))));
    }
    // Without this check the matrix-vector product below panics on a dimension mismatch.
    if y_values.len() != n {
        return Err(Box::new(FittingError(format!(
            "Input vectors must be of the same size (x: {}, y: {}).",
            n,
            y_values.len()
        ))));
    }

    let cols = degree + 1;
    let mut a_data = Vec::with_capacity(n * cols);
    for &x in x_coords {
        a_data.extend((0..cols).map(|i| x.powi(i as i32)));
    }
    let a = DMatrix::from_row_slice(n, cols, &a_data);
    let y = DVector::from_vec(y_values.to_vec());

    // Normal equations; the transpose is formed once and reused for both products.
    let at = a.transpose();
    let ata = &at * &a;
    let aty = &at * &y;
    let coeffs = ata.lu().solve(&aty).ok_or_else(|| {
        FittingError(
            "Failed to solve linear system for polynomial fitting. Matrix might be singular."
                .to_string(),
        )
    })?;
    Ok(coeffs.iter().copied().collect())
}
/// Fits a polynomial of the given `degree` plus one sinusoid of fixed period by linear least
/// squares.
///
/// The model is `y = p0 + p1 * x + ... + p_degree * x^degree + s * sin(w * x) + k * cos(w * x)`
/// with `w = 2 * PI / period_sec`. Only the amplitudes are fitted; the period is fixed.
/// `period_sec` must be in the same units as `x` (the name suggests seconds — confirm with
/// callers).
///
/// Returns `[p0, ..., p_degree, s, k]` (`degree + 3` values), matching the column order of the
/// design matrix.
///
/// Like [`fit_polynomial_least_squares`], this solves the normal equations with LU, which
/// squares the condition number of the design matrix.
///
/// # Errors
///
/// Returns a `FittingError` if there are fewer than `degree + 3` points, if `period_sec` is
/// not a finite positive number, if the slices differ in length, or if the LU solve reports
/// that the normal matrix cannot be inverted.
pub fn fit_polynomial_plus_sinusoid_least_squares(
    x_coords: &[f64],
    y_values: &[f64],
    degree: usize,
    period_sec: f64,
) -> Result<Vec<f64>, Box<dyn Error>> {
    let n = x_coords.len();
    let cols = degree + 3;
    if n < cols {
        return Err(Box::new(FittingError(format!(
            "Not enough data points ({}) for polynomial+sin fitting with {} parameters. Need at least {} points.",
            n, cols, cols
        ))));
    }
    if !period_sec.is_finite() || period_sec <= 0.0 {
        return Err(Box::new(FittingError(format!(
            "Invalid sinusoid period: {}",
            period_sec
        ))));
    }
    // Without this check the matrix-vector product below panics on a dimension mismatch.
    if y_values.len() != n {
        return Err(Box::new(FittingError(format!(
            "Input vectors must be of the same size (x: {}, y: {}).",
            n,
            y_values.len()
        ))));
    }

    let omega = 2.0 * f64::consts::PI / period_sec;
    // Row i: [1, x_i, ..., x_i^degree, sin(omega * x_i), cos(omega * x_i)].
    let mut a_data = Vec::with_capacity(n * cols);
    for &x in x_coords {
        a_data.extend((0..=degree).map(|i| x.powi(i as i32)));
        let phase = omega * x;
        a_data.push(phase.sin());
        a_data.push(phase.cos());
    }
    let a = DMatrix::from_row_slice(n, cols, &a_data);
    let y = DVector::from_vec(y_values.to_vec());

    // Normal equations; the transpose is formed once and reused for both products.
    let at = a.transpose();
    let ata = &at * &a;
    let aty = &at * &y;
    let coeffs = ata.lu().solve(&aty).ok_or_else(|| {
        FittingError(
            "Failed to solve linear system for polynomial+sin fitting. Matrix might be singular."
                .to_string(),
        )
    })?;
    Ok(coeffs.iter().copied().collect())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies within `tol` of `expected`, reporting both on failure.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} (within {}), got {}",
            expected,
            tol,
            actual
        );
    }

    /// Unwraps a fit that is expected to fail and returns its `Display` message.
    fn error_message<T: std::fmt::Debug>(
        result: Result<T, Box<dyn std::error::Error>>,
    ) -> String {
        result.expect_err("expected the fit to fail").to_string()
    }

    #[test]
    fn test_fit_quadratic_least_squares_basic() {
        let fit = fit_quadratic_least_squares(&[0.0, 1.0, 2.0], &[3.0, 4.0, 3.0]).unwrap();
        assert_close(fit.peak_x, 1.0, 1e-9);
        assert_close(fit.a, -1.0, 1e-9);
        assert_close(fit.b, 0.0, 1e-9);
        assert_close(fit.c, 4.0, 1e-9);
    }

    #[test]
    fn test_fit_quadratic_least_squares_more_points() {
        let fit = fit_quadratic_least_squares(
            &[0.0, 1.0, 2.0, 3.0, 4.0],
            &[1.0, 7.0, 9.0, 7.0, 1.0],
        )
        .unwrap();
        assert_close(fit.peak_x, 2.0, 1e-9);
        assert_close(fit.a, -2.0, 1e-9);
        assert_close(fit.b, 0.0, 1e-9);
        assert_close(fit.c, 9.0, 1e-9);
    }

    #[test]
    fn test_fit_quadratic_least_squares_positive_a() {
        let msg = error_message(fit_quadratic_least_squares(
            &[-1.0, 0.0, 1.0],
            &[1.0, 0.0, 1.0],
        ));
        assert!(msg.contains("Coefficient 'a' is positive"), "{}", msg);
    }

    #[test]
    fn test_fit_quadratic_least_squares_degenerate() {
        let msg = error_message(fit_quadratic_least_squares(
            &[0.0, 1.0, 2.0],
            &[0.0, 1.0, 2.0],
        ));
        assert!(msg.contains("Coefficient 'a' is almost zero"), "{}", msg);
    }

    #[test]
    fn test_fit_quadratic_least_squares_insufficient_points() {
        let msg = error_message(fit_quadratic_least_squares(&[0.0, 1.0], &[0.0, 1.0]));
        assert!(msg.contains("at least 3 points"), "{}", msg);
    }

    #[test]
    fn test_fit_linear_least_squares_basic() {
        let (m, c) =
            fit_linear_least_squares(&[0.0, 1.0, 2.0, 3.0], &[1.0, 3.0, 5.0, 7.0]).unwrap();
        assert_close(m, 2.0, 1e-9);
        assert_close(c, 1.0, 1e-9);
    }

    #[test]
    fn test_fit_linear_least_squares_rejects_invalid_input() {
        assert!(fit_linear_least_squares(&[1.0], &[1.0]).is_err());
        assert!(fit_linear_least_squares(&[0.0, 1.0], &[1.0]).is_err());
        assert!(fit_linear_least_squares(&[2.0, 2.0, 2.0], &[0.0, 1.0, 2.0]).is_err());
    }

    #[test]
    fn test_fit_polynomial_least_squares_linear() {
        let coeffs =
            fit_polynomial_least_squares(&[0.0, 1.0, 2.0, 3.0], &[1.0, 3.0, 5.0, 7.0], 1)
                .unwrap();
        assert_eq!(coeffs.len(), 2);
        assert_close(coeffs[0], 1.0, 1e-9);
        assert_close(coeffs[1], 2.0, 1e-9);
    }

    #[test]
    fn test_fit_polynomial_least_squares_quadratic() {
        let coeffs = fit_polynomial_least_squares(
            &[0.0, 1.0, 2.0, 3.0, 4.0],
            &[3.0, 2.0, 3.0, 6.0, 11.0],
            2,
        )
        .unwrap();
        assert_eq!(coeffs.len(), 3);
        assert_close(coeffs[0], 3.0, 1e-9);
        assert_close(coeffs[1], -2.0, 1e-9);
        assert_close(coeffs[2], 1.0, 1e-9);
    }

    #[test]
    fn test_fit_polynomial_least_squares_insufficient_points() {
        let msg = error_message(fit_polynomial_least_squares(&[0.0, 1.0], &[0.0, 1.0], 2));
        assert!(msg.contains("Not enough data points"), "{}", msg);
    }

    #[test]
    fn test_fit_polynomial_plus_sinusoid_least_squares_recovers_coefficients() {
        let period = 3.0;
        let omega = 2.0 * std::f64::consts::PI / period;
        let x_coords: Vec<f64> = (0..12).map(|i| i as f64 * 0.5).collect();
        let y_values: Vec<f64> = x_coords
            .iter()
            .map(|&x| 1.0 + 0.5 * x + 2.0 * (omega * x).sin() - 0.75 * (omega * x).cos())
            .collect();
        let coeffs =
            fit_polynomial_plus_sinusoid_least_squares(&x_coords, &y_values, 1, period).unwrap();
        assert_eq!(coeffs.len(), 4);
        for (&got, &want) in coeffs.iter().zip([1.0, 0.5, 2.0, -0.75].iter()) {
            assert_close(got, want, 1e-8);
        }
    }

    #[test]
    fn test_fit_polynomial_plus_sinusoid_least_squares_invalid_input() {
        let x_coords = [0.0, 1.0, 2.0, 3.0, 4.0];
        let y_values = [0.0, 1.0, 0.0, -1.0, 0.0];
        let msg = error_message(fit_polynomial_plus_sinusoid_least_squares(
            &x_coords, &y_values, 1, 0.0,
        ));
        assert!(msg.contains("Invalid sinusoid period"), "{}", msg);
        let msg = error_message(fit_polynomial_plus_sinusoid_least_squares(
            &x_coords[..2],
            &y_values[..2],
            1,
            4.0,
        ));
        assert!(msg.contains("Not enough data points"), "{}", msg);
    }
}