/// Evaluates at `x` the Lagrange interpolant of every column of `matrix`.
///
/// Row `i` (0-based) of `matrix` holds the sample values taken at the node `i + 1`, so the
/// interpolation nodes are `1, 2, ..., matrix.len()`. Each column is interpolated
/// independently, and entry `j` of the returned vector is the value at `x` of the
/// polynomial passing through column `j`.
///
/// The number of columns is taken from the first row. An empty `matrix` (or an empty
/// first row) yields an empty vector. Rows longer than the first row are accepted and
/// their extra columns are ignored.
///
/// # Panics
///
/// Panics if any row has fewer columns than the first row.
pub fn lagrange_interpolate(matrix: &[Vec<f64>], x: f64) -> Vec<f64> {
    let n_constraints = matrix.len();
    let n_vars = matrix.first().map_or(0, Vec::len);
    let mut result = vec![0.0; n_vars];

    for (i, row) in matrix.iter().enumerate() {
        assert!(
            row.len() >= n_vars,
            "lagrange_interpolate: row {} has {} columns, expected at least {}",
            i,
            row.len(),
            n_vars
        );

        // The basis polynomial for node `i + 1` is the same for every column, so it is
        // computed once per row rather than once per (row, column) pair.
        let node = (i + 1) as f64;
        let basis = (1..=n_constraints)
            .filter(|&c| c != i + 1)
            .fold(1.0_f64, |acc, c| acc * ((x - c as f64) / (node - c as f64)));

        // `zip` stops after `n_vars` entries, so surplus columns in `row` are ignored.
        for (acc, &value) in result.iter_mut().zip(row) {
            *acc += basis * value;
        }
    }

    result
}
/// A table of sample values that can be evaluated as a family of Lagrange interpolants.
///
/// The stored matrix is handed unchanged to [`lagrange_interpolate`] whenever the
/// polynomial is evaluated.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct LagrangePolynomial {
    /// Sample values, one row per interpolation node.
    matrix: Vec<Vec<f64>>,
}
impl LagrangePolynomial {
    /// Builds a polynomial that takes ownership of the given sample `matrix`.
    pub fn new(matrix: Vec<Vec<f64>>) -> Self {
        Self { matrix }
    }

    /// Evaluates the stored samples at `x` by delegating to [`lagrange_interpolate`].
    pub fn evaluate(&self, x: f64) -> Vec<f64> {
        let rows: &[Vec<f64>] = self.matrix.as_slice();
        lagrange_interpolate(rows, x)
    }
}
/// Evaluates at `x` the polynomial `(x - 1)(x - 2)...(x - n_constraints)`.
///
/// The result is zero at each of the integer points `1..=n_constraints`. With
/// `n_constraints == 0` the product is empty and the function returns `1.0`.
pub fn vanishing_polynomial(n_constraints: usize, x: f64) -> f64 {
    (1..=n_constraints).fold(1.0, |product, root| product * (x - root as f64))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point results.
    const TOLERANCE: f64 = 1e-10;

    /// Asserts that `actual` is within [`TOLERANCE`] of `expected`.
    ///
    /// `what` names the compared quantity so a failure reports which value was wrong
    /// along with both the expected and the actual number.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, what: &str) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "Expected {} = {:?}, got {}",
            what,
            expected,
            actual
        );
    }

    #[test]
    fn test_lagrange_interpolation() {
        let matrix = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];

        // At node 1 the interpolant reproduces the first row.
        let result = lagrange_interpolate(&matrix, 1.0);
        assert_close(result[0], 1.0, "result[0]");
        assert_close(result[1], 0.0, "result[1]");
        assert_close(result[2], 0.0, "result[2]");

        // At node 2 the interpolant reproduces the second row.
        let result = lagrange_interpolate(&matrix, 2.0);
        assert_close(result[0], 0.0, "result[0]");
        assert_close(result[1], 1.0, "result[1]");
        assert_close(result[2], 0.0, "result[2]");
    }

    #[test]
    fn test_lagrange_matches_python() {
        let matrix = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let result = lagrange_interpolate(&matrix, 4.0);
        println!("Rust LP(M, 4) = {:?}", result);
        assert_close(result[0], 1.0, "result[0]");
        assert_close(result[1], -3.0, "result[1]");
        assert_close(result[2], 3.0, "result[2]");
    }

    #[test]
    fn test_lagrange_empty_matrix() {
        let matrix: Vec<Vec<f64>> = Vec::new();
        assert!(lagrange_interpolate(&matrix, 2.0).is_empty());
    }

    #[test]
    fn test_lagrange_polynomial_evaluate() {
        let matrix = vec![vec![2.0, -1.0], vec![5.0, 7.0], vec![10.0, 0.5]];
        let polynomial = LagrangePolynomial::new(matrix.clone());
        // The wrapper must agree with the free function at and between the nodes.
        for &x in &[1.0, 2.0, 2.5, 4.0] {
            assert_eq!(polynomial.evaluate(x), lagrange_interpolate(&matrix, x));
        }
    }

    #[test]
    fn test_vanishing_polynomial() {
        // (4 - 1)(4 - 2)(4 - 3) = 6
        assert_close(vanishing_polynomial(3, 4.0), 6.0, "Z(4)");
        assert_close(vanishing_polynomial(3, 1.0), 0.0, "Z(1)");
        assert_close(vanishing_polynomial(3, 2.0), 0.0, "Z(2)");
        assert_close(vanishing_polynomial(3, 3.0), 0.0, "Z(3)");
    }

    #[test]
    fn test_vanishing_polynomial_no_constraints() {
        // With no constraints the product is empty, so the value is 1 for any `x`.
        assert_close(vanishing_polynomial(0, 7.5), 1.0, "Z(7.5)");
    }
}