use super::arithmetic::Polynomial;
use crate::error::{FFTError, FFTResult};
/// Builds the subproduct tree of the linear factors `x - xs[i]`.
///
/// Levels are returned root first: `tree[0]` holds the single product of all
/// factors and the last level holds the leaves `x - xs[i]` in input order.
/// Each level is produced from the one below by multiplying adjacent pairs
/// `(2j, 2j + 1)`; an unpaired last node is copied up unchanged. An empty
/// input yields `[[1]]`.
///
/// # Errors
/// Propagates errors from `Polynomial::mul_fft`.
fn build_subproduct_tree(xs: &[f64]) -> FFTResult<Vec<Vec<Polynomial>>> {
    if xs.is_empty() {
        return Ok(vec![vec![Polynomial::one()]]);
    }
    let leaves: Vec<Polynomial> = xs
        .iter()
        .map(|&root| Polynomial::new(vec![-root, 1.0]))
        .collect();
    let mut levels: Vec<Vec<Polynomial>> = vec![leaves];
    loop {
        let below = &levels[levels.len() - 1];
        if below.len() <= 1 {
            break;
        }
        let above = below
            .chunks(2)
            .map(|pair| match pair {
                [left, right] => left.mul_fft(right),
                [single] => Ok(single.clone()),
                _ => unreachable!("chunks(2) yields one or two nodes"),
            })
            .collect::<FFTResult<Vec<Polynomial>>>()?;
        levels.push(above);
    }
    // Built leaves-first; callers expect the root at index 0.
    levels.reverse();
    Ok(levels)
}
/// Evaluates `poly` at every entry of `points`, returning the values in the
/// same order as `points`.
///
/// With at most 8 points, or when `poly.degree()` is at most 8, each point is
/// evaluated directly with `Polynomial::eval`. Otherwise the subproduct tree
/// of `points` is built and the remainder-tree recursion is used.
///
/// # Errors
/// Returns `FFTError::ValueError` when `points` is empty; propagates errors
/// from building the tree and from the polynomial divisions.
pub fn multipoint_eval(poly: &Polynomial, points: &[f64]) -> FFTResult<Vec<f64>> {
    const DIRECT_EVAL_MAX: usize = 8;
    if points.is_empty() {
        return Err(FFTError::ValueError("no evaluation points given".into()));
    }
    let use_direct = points.len() <= DIRECT_EVAL_MAX || poly.degree() <= DIRECT_EVAL_MAX;
    if !use_direct {
        let tree = build_subproduct_tree(points)?;
        return multipoint_eval_tree(poly, points, &tree, 0, 0, points.len());
    }
    Ok(points.iter().map(|&x| poly.eval(x)).collect())
}
/// Recursively evaluates `poly` at `points[lo..hi]` by walking the remainder
/// tree over `tree`, the subproduct tree returned by `build_subproduct_tree`.
///
/// `tree[0]` is the root and `tree[tree.len() - 1]` holds the linear leaves
/// `x - points[i]`. That tree is built by pairing adjacent nodes bottom-up,
/// with an unpaired last node carried up unchanged. As a result, node
/// `tree[d][j]` covers the leaves `j << k .. min((j + 1) << k, n)` where
/// `k = tree.len() - 1 - d`.
///
/// The range `lo..hi` is therefore split at the power-of-two boundary
/// `lo + (1 << (k - 1))`, not at its midpoint, so that every visited range is
/// an actual tree node. The initial call is
/// `(depth = 0, lo = 0, hi = points.len())`.
///
/// Returns the values in the order of `points[lo..hi]`.
///
/// # Errors
/// Propagates errors from `Polynomial::div_rem` and, when a node has to be
/// rebuilt, from `Polynomial::mul_fft`.
fn multipoint_eval_tree(
    poly: &Polynomial,
    points: &[f64],
    tree: &[Vec<Polynomial>],
    depth: usize,
    lo: usize,
    hi: usize,
) -> FFTResult<Vec<f64>> {
    if hi <= lo {
        return Ok(Vec::new());
    }
    if hi - lo == 1 {
        return Ok(vec![poly.eval(points[lo])]);
    }
    // `node_poly` vanishes exactly on `points[lo..hi]`, so the remainder agrees
    // with `poly` at all of those points while having lower degree.
    let node_poly = get_tree_node(tree, points, depth, lo, hi)?;
    let (_, rem) = poly.div_rem(&node_poly)?;
    let levels_below = tree.len().saturating_sub(depth + 1);
    let mid = if levels_below == 0 {
        // Only reachable with arguments inconsistent with `tree`: split evenly;
        // `get_tree_node` rebuilds such nodes directly from `points`.
        lo + (hi - lo) / 2
    } else {
        // Left child spans the first 2^(k-1) leaves; a carried node has an
        // empty right child (mid == hi).
        (lo + (1_usize << (levels_below - 1))).min(hi)
    };
    let mut result = multipoint_eval_tree(&rem, points, tree, depth + 1, lo, mid)?;
    result.extend(multipoint_eval_tree(&rem, points, tree, depth + 1, mid, hi)?);
    Ok(result)
}
/// Returns the subproduct `∏_{i in lo..hi} (x - points[i])`.
///
/// The polynomial is taken from `tree[depth]` when the tree's layout (see
/// `multipoint_eval_tree`) places exactly the leaf range `lo..hi` at that
/// position. Otherwise it is rebuilt from `points` with
/// `build_product_polynomial`, so the result always vanishes precisely on
/// `points[lo..hi]`.
///
/// # Errors
/// Propagates errors from `build_product_polynomial` when a rebuild is needed.
fn get_tree_node(
    tree: &[Vec<Polynomial>],
    points: &[f64],
    depth: usize,
    lo: usize,
    hi: usize,
) -> FFTResult<Polynomial> {
    let leaves = tree.last().map_or(0, Vec::len);
    if depth < tree.len() && leaves == points.len() {
        // log2 of the leaf span of a full node at this depth.
        let span_log2 = tree.len() - 1 - depth;
        let idx = lo >> span_log2;
        let node_lo = idx << span_log2;
        let node_hi = ((idx + 1) << span_log2).min(leaves);
        if node_lo == lo && node_hi == hi {
            if let Some(node) = tree[depth].get(idx) {
                return Ok(node.clone());
            }
        }
    }
    build_product_polynomial(points, lo, hi)
}
/// Returns `∏ (x - xs[i])` over the index range `lo..hi`, built by
/// divide-and-conquer with FFT multiplication.
///
/// An empty range (including `lo >= hi`) yields `1`. An index past the end of
/// `xs` contributes the factor `1` instead of a linear term.
///
/// # Errors
/// Propagates errors from `Polynomial::mul_fft`.
fn build_product_polynomial(xs: &[f64], lo: usize, hi: usize) -> FFTResult<Polynomial> {
    match hi.saturating_sub(lo) {
        0 => Ok(Polynomial::one()),
        1 => Ok(match xs.get(lo) {
            Some(&root) => Polynomial::new(vec![-root, 1.0]),
            None => Polynomial::one(),
        }),
        width => {
            let split = lo + width / 2;
            let left = build_product_polynomial(xs, lo, split)?;
            let right = build_product_polynomial(xs, split, hi)?;
            left.mul_fft(&right)
        }
    }
}
/// Returns the polynomial of degree `< points.len()` that takes the value
/// `values[i]` at `points[i]` for every `i`.
///
/// The method depends on the number of points:
///
/// * one point: the constant `values[0]`;
/// * two points: the line through both, in closed form;
/// * 3..=32 points: Newton divided differences (`interpolate_newton`);
/// * more points: the subproduct-tree algorithm. With
///   `M(x) = ∏ (x - points[i])`, the result is
///   `Σ values[i] / M'(points[i]) · M(x) / (x - points[i])`. The derivative
///   values come from `multipoint_eval` and the sum is assembled by
///   `interpolate_from_tree`.
///
/// # Errors
/// Returns `FFTError::ValueError` when:
///
/// * the slices have different lengths or are empty;
/// * two points differ by less than `f64::EPSILON`;
/// * some `M'(points[i])` evaluates to zero or a non-finite value.
///
/// Errors from the underlying polynomial arithmetic are propagated.
pub fn interpolate(points: &[f64], values: &[f64]) -> FFTResult<Polynomial> {
    if points.len() != values.len() {
        return Err(FFTError::ValueError(format!(
            "points ({}) and values ({}) must have the same length",
            points.len(),
            values.len()
        )));
    }
    if points.is_empty() {
        return Err(FFTError::ValueError("no interpolation points".into()));
    }
    let n = points.len();
    if n == 1 {
        return Ok(Polynomial::new(vec![values[0]]));
    }
    if n == 2 {
        let dx = points[1] - points[0];
        // Same duplicate test as `interpolate_newton`; without it the slope
        // below would silently become infinite or NaN.
        if dx.abs() < f64::EPSILON {
            return Err(FFTError::ValueError(
                "duplicate interpolation points at index 1".into(),
            ));
        }
        let slope = (values[1] - values[0]) / dx;
        let intercept = values[0] - slope * points[0];
        return Ok(Polynomial::new(vec![intercept, slope]));
    }
    if n <= 32 {
        return interpolate_newton(points, values);
    }
    // Reject repeated nodes up front, using the same tolerance as
    // `interpolate_newton`. A repeated node makes M'(x_i) vanish, and the
    // weight y_i / M'(x_i) would be meaningless.
    let mut sorted = points.to_vec();
    sorted.sort_unstable_by(f64::total_cmp);
    if let Some(pair) = sorted
        .windows(2)
        .find(|w| (w[1] - w[0]).abs() < f64::EPSILON)
    {
        return Err(FFTError::ValueError(format!(
            "duplicate interpolation points (value {})",
            pair[0]
        )));
    }
    let tree = build_subproduct_tree(points)?;
    let m_deriv = tree[0][0].derivative();
    let m_deriv_vals = multipoint_eval(&m_deriv, points)?;
    // For distinct points M'(x_i) is nonzero, but it can legitimately be far
    // below any fixed absolute tolerance (e.g. many nodes packed into [-1, 1]).
    // Only an exact zero or a non-finite value is treated as a failure, rather
    // than silently zeroing the weight.
    let weights = values
        .iter()
        .zip(&m_deriv_vals)
        .enumerate()
        .map(|(i, (&y, &md))| {
            if md == 0.0 || !md.is_finite() {
                Err(FFTError::ValueError(format!(
                    "M'(points[{i}]) = {md}; cannot form interpolation weight"
                )))
            } else {
                Ok(y / md)
            }
        })
        .collect::<FFTResult<Vec<f64>>>()?;
    interpolate_from_tree(&tree, &weights, points)
}
/// Combines Lagrange weights bottom-up over a subproduct tree.
///
/// Leaf `i` starts as the fraction `weights[i] / leaf_i`, with numerator `0`
/// for leaves beyond `weights`. Sibling fractions `a/qa` and `b/qb` are merged
/// into `(a·qb + b·qa) / (qa·qb)`.
///
/// The merged denominator is the stored node from `tree[depth]` when one
/// exists at that position; otherwise it is computed as `qa·qb`. An unpaired
/// last node is carried up unchanged, mirroring how `build_subproduct_tree`
/// pairs nodes.
///
/// Returns the numerator left at the root. An empty tree or empty leaf level
/// yields the zero polynomial. `_points` is unused.
///
/// # Errors
/// Propagates errors from `Polynomial::mul_fft`.
fn interpolate_from_tree(
    tree: &[Vec<Polynomial>],
    weights: &[f64],
    _points: &[f64],
) -> FFTResult<Polynomial> {
    let leaf_depth = match tree.len().checked_sub(1) {
        Some(depth) => depth,
        None => return Ok(Polynomial::zero()),
    };
    // Each entry is the (numerator, denominator) pair of a partial sum.
    let mut frontier: Vec<(Polynomial, Polynomial)> = tree[leaf_depth]
        .iter()
        .enumerate()
        .map(|(idx, leaf)| {
            let numer = match weights.get(idx) {
                Some(&w) => Polynomial::new(vec![w]),
                None => Polynomial::zero(),
            };
            (numer, leaf.clone())
        })
        .collect();
    for depth in (0..leaf_depth).rev() {
        let parents = &tree[depth];
        let mut merged: Vec<(Polynomial, Polynomial)> = Vec::with_capacity(parents.len());
        let mut children = frontier.into_iter();
        while let Some((num_a, den_a)) = children.next() {
            match children.next() {
                Some((num_b, den_b)) => {
                    let cross_ab = num_a.mul_fft(&den_b)?;
                    let cross_ba = num_b.mul_fft(&den_a)?;
                    let numer = cross_ab.add(&cross_ba);
                    // `merged.len()` is the parent's index within `parents`.
                    let denom = match parents.get(merged.len()) {
                        Some(node) => node.clone(),
                        None => den_a.mul_fft(&den_b)?,
                    };
                    merged.push((numer, denom));
                }
                None => merged.push((num_a, den_a)),
            }
        }
        frontier = merged;
    }
    Ok(frontier
        .into_iter()
        .next()
        .map_or_else(Polynomial::zero, |(numer, _)| numer))
}
/// Interpolates through `(points[i], values[i])` using Newton's divided
/// differences.
///
/// The divided-difference table is built in place. The Newton form
/// `c_0 + (x - x_0)(c_1 + (x - x_1)(c_2 + ...))` is then expanded into
/// monomial coefficients from the innermost term outward with `mul_naive`.
///
/// Expects `points` and `values` to have equal, non-zero length; the callers
/// in this module pass three or more points. Empty input panics on indexing.
///
/// # Errors
/// Returns `FFTError::ValueError` if two points differ by less than
/// `f64::EPSILON`.
fn interpolate_newton(points: &[f64], values: &[f64]) -> FFTResult<Polynomial> {
    let len = points.len();
    let mut table = values.to_vec();
    for order in 1..len {
        // Walk downward so `table[i - 1]` still holds the previous order's value.
        for i in (order..len).rev() {
            let denom = points[i] - points[i - order];
            if denom.abs() < f64::EPSILON {
                return Err(FFTError::ValueError(format!(
                    "duplicate interpolation points at index {i}"
                )));
            }
            table[i] = (table[i] - table[i - 1]) / denom;
        }
    }
    let mut poly = Polynomial::new(vec![table[len - 1]]);
    for (&node, &coeff) in points[..len - 1].iter().zip(&table[..len - 1]).rev() {
        poly = poly.mul_naive(&Polynomial::new(vec![-node, 1.0]));
        poly.coeffs[0] += coeff;
    }
    Ok(poly)
}
/// Computes the residues `r_i` of the expansion
/// `numerator(x) / ∏ (x - poles[j]) = Σ r_i / (x - poles[i])`.
///
/// Each residue is `numerator(poles[i]) / M'(poles[i])`, where `M` is the
/// product of the linear factors. Both are evaluated with `multipoint_eval`.
///
/// # Errors
/// Returns `FFTError::ValueError` when:
///
/// * `poles` is empty;
/// * the numerator degree is not below the number of poles;
/// * `|M'(poles[i])|` is below `f64::EPSILON * 1e6 * (1 + |poles[i]|)`,
///   which suggests repeated poles.
///
/// Errors from polynomial arithmetic are propagated.
pub fn partial_fraction_decomp(
    numerator: &Polynomial,
    poles: &[f64],
) -> FFTResult<Vec<f64>> {
    if poles.is_empty() {
        return Err(FFTError::ValueError("no poles provided".into()));
    }
    let pole_count = poles.len();
    if numerator.degree() >= pole_count {
        return Err(FFTError::ValueError(format!(
            "numerator degree {} must be < number of poles {}",
            numerator.degree(),
            pole_count
        )));
    }
    let denom_deriv = build_product_poly(poles)?.derivative();
    let numer_at_poles = multipoint_eval(numerator, poles)?;
    let deriv_at_poles = multipoint_eval(&denom_deriv, poles)?;
    numer_at_poles
        .iter()
        .zip(&deriv_at_poles)
        .enumerate()
        .map(|(i, (&pv, &dv))| {
            let tolerance = f64::EPSILON * 1e6 * (1.0 + poles[i].abs());
            if dv.abs() < tolerance {
                Err(FFTError::ValueError(format!(
                    "M'(pole[{i}]) ≈ 0; poles may not be distinct"
                )))
            } else {
                Ok(pv / dv)
            }
        })
        .collect()
}
/// Returns the monic polynomial `∏ (x - roots[i])`, built by splitting
/// `roots` in half recursively and multiplying the halves with FFT
/// multiplication. An empty slice yields `1`.
///
/// # Errors
/// Propagates errors from `Polynomial::mul_fft`.
pub fn build_product_poly(roots: &[f64]) -> FFTResult<Polynomial> {
    match roots {
        [] => Ok(Polynomial::one()),
        [root] => Ok(Polynomial::new(vec![-*root, 1.0])),
        _ => {
            let (head, tail) = roots.split_at(roots.len() / 2);
            build_product_poly(head)?.mul_fft(&build_product_poly(tail)?)
        }
    }
}
/// Returns the `n` Chebyshev nodes of the first kind,
/// `cos(π (2k + 1) / (2n))` for `k = 0..n`.
///
/// These are the roots of `T_n`, listed in decreasing order inside `(-1, 1)`.
/// `n == 0` yields an empty vector.
pub fn chebyshev_nodes_first(n: usize) -> Vec<f64> {
    use std::f64::consts::PI;
    let denom = (2 * n) as f64;
    let mut nodes = Vec::with_capacity(n);
    for k in 0..n {
        let theta = PI * (2 * k + 1) as f64 / denom;
        nodes.push(theta.cos());
    }
    nodes
}
/// Returns the `n` Chebyshev nodes of the second kind,
/// `cos(π k / (n - 1))` for `k = 0..n`.
///
/// The nodes run from `1` down to `-1` and include both endpoints.
/// `n == 0` yields an empty vector, matching `chebyshev_nodes_first(0)`.
/// `n == 1` yields `[0.0]` because the formula is undefined for a single node.
pub fn chebyshev_nodes_second(n: usize) -> Vec<f64> {
    use std::f64::consts::PI;
    match n {
        0 => Vec::new(),
        1 => vec![0.0],
        _ => {
            let denom = (n - 1) as f64;
            (0..n).map(|k| (PI * k as f64 / denom).cos()).collect()
        }
    }
}
// Unit tests for the polynomial utilities in this module. Comments note which
// internal code path each test's inputs select.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
// Shorthand constructor. The tests treat coefficients as ascending powers
// (e.g. `[2.0, 3.0]` is expected to evaluate as `2 + 3x`).
fn p(c: Vec<f64>) -> Polynomial {
Polynomial::new(c)
}
// Constant polynomial 7 at five points. With <= 8 points, `multipoint_eval`
// takes the direct-evaluation branch.
#[test]
fn test_multipoint_eval_constant() {
let poly = p(vec![7.0]);
let xs = vec![0.0, 1.0, 2.0, -1.0, 100.0];
let ys = multipoint_eval(&poly, &xs).expect("eval");
for y in ys {
assert_relative_eq!(y, 7.0, epsilon = 1e-10);
}
}
// 2 + 3x at x = 0..9. The low degree selects the direct-evaluation branch.
#[test]
fn test_multipoint_eval_linear() {
let poly = p(vec![2.0, 3.0]);
let xs: Vec<f64> = (0..10).map(|i| i as f64).collect();
let ys = multipoint_eval(&poly, &xs).expect("eval");
for (i, y) in ys.iter().enumerate() {
let expected = 2.0 + 3.0 * i as f64;
assert_relative_eq!(y, &expected, epsilon = 1e-10);
}
}
// 1 + x^2 at four points.
#[test]
fn test_multipoint_eval_quadratic() {
let poly = p(vec![1.0, 0.0, 1.0]);
let xs = vec![0.0, 1.0, 2.0, 3.0];
let ys = multipoint_eval(&poly, &xs).expect("eval");
let expected = vec![1.0, 2.0, 5.0, 10.0];
for (y, e) in ys.iter().zip(expected.iter()) {
assert_relative_eq!(y, e, epsilon = 1e-10);
}
}
// Compares `multipoint_eval` with direct `eval` at 50 points.
// NOTE(review): the polynomial has nine coefficients, presumably degree 8.
// That satisfies the `poly.degree() <= 8` shortcut in `multipoint_eval`, so
// the subproduct/remainder-tree path is NOT exercised here. A degree >= 9
// polynomial with a non-power-of-two point count would cover it.
#[test]
fn test_multipoint_eval_many_points() {
let coeffs: Vec<f64> = (0..9).map(|i| i as f64 + 1.0).collect();
let poly = p(coeffs);
let xs: Vec<f64> = (0..50).map(|i| i as f64 * 0.1 - 2.5).collect();
let ys_mp = multipoint_eval(&poly, &xs).expect("multipoint");
let ys_direct: Vec<f64> = xs.iter().map(|&x| poly.eval(x)).collect();
for (a, b) in ys_mp.iter().zip(ys_direct.iter()) {
assert_relative_eq!(a, b, epsilon = 1e-8);
}
}
// Three equal samples give a constant interpolant; checks the value at x = 3.
// Three points go through the Newton divided-difference path.
#[test]
fn test_interpolate_constant() {
let xs = vec![0.0, 1.0, 2.0];
let ys = vec![5.0, 5.0, 5.0];
let q = interpolate(&xs, &ys).expect("interp");
assert_relative_eq!(q.eval(3.0), 5.0, epsilon = 1e-8);
}
// Two points use the closed-form line branch; the line here is y = 1 + 2x.
#[test]
fn test_interpolate_linear() {
let xs = vec![0.0, 1.0];
let ys = vec![1.0, 3.0];
let q = interpolate(&xs, &ys).expect("interp");
assert_relative_eq!(q.eval(0.5), 2.0, epsilon = 1e-10);
assert_relative_eq!(q.eval(2.0), 5.0, epsilon = 1e-10);
}
// Recovers 1 - 2x + x^3 from four samples (Newton path). Checks the
// interpolant at x = 2 (outside the sample range) and at x = 0.5.
#[test]
fn test_interpolate_through_known_polynomial() {
let poly = p(vec![1.0, -2.0, 0.0, 1.0]);
let xs: Vec<f64> = vec![-2.0, -1.0, 0.0, 1.0];
let ys: Vec<f64> = xs.iter().map(|&x| poly.eval(x)).collect();
let q = interpolate(&xs, &ys).expect("interp");
assert_relative_eq!(q.eval(2.0), poly.eval(2.0), epsilon = 1e-6);
assert_relative_eq!(q.eval(0.5), poly.eval(0.5), epsilon = 1e-6);
}
// Slices of different lengths are rejected.
#[test]
fn test_interpolate_mismatched_lengths_error() {
assert!(interpolate(&[0.0, 1.0], &[1.0]).is_err());
}
// An empty point set is rejected.
#[test]
fn test_interpolate_empty_error() {
assert!(interpolate(&[], &[]).is_err());
}
// 1 / ((x - 1)(x - 2)) = -1/(x - 1) + 1/(x - 2).
#[test]
fn test_pfd_simple() {
let num = p(vec![1.0]);
let poles = vec![1.0, 2.0];
let res = partial_fraction_decomp(&num, &poles).expect("pfd");
assert_eq!(res.len(), 2);
assert_relative_eq!(res[0], -1.0, epsilon = 1e-10);
assert_relative_eq!(res[1], 1.0, epsilon = 1e-10);
}
// 1 / ((x - 1)(x - 2)(x - 3)) has residues 1/2, -1, 1/2.
#[test]
fn test_pfd_three_poles() {
let num = p(vec![1.0]);
let poles = vec![1.0, 2.0, 3.0];
let res = partial_fraction_decomp(&num, &poles).expect("pfd");
assert_eq!(res.len(), 3);
assert_relative_eq!(res[0], 0.5, epsilon = 1e-8);
assert_relative_eq!(res[1], -1.0, epsilon = 1e-8);
assert_relative_eq!(res[2], 0.5, epsilon = 1e-8);
}
// A degree-2 numerator over two poles is rejected: the numerator degree must
// be below the pole count.
#[test]
fn test_pfd_numerator_too_high_error() {
let num = p(vec![1.0, 0.0, 1.0]); let poles = vec![1.0, 2.0]; assert!(partial_fraction_decomp(&num, &poles).is_err());
}
// 3 / (x - 5) has the single residue 3.
#[test]
fn test_pfd_single_pole() {
let num = p(vec![3.0]); let poles = vec![5.0];
let res = partial_fraction_decomp(&num, &poles).expect("pfd");
assert_eq!(res.len(), 1);
assert_relative_eq!(res[0], 3.0, epsilon = 1e-10); }
// Despite its name, this only checks the node count and that every node lies
// in [-1, 1] (with 1e-12 slack). Symmetry is not asserted.
#[test]
fn test_chebyshev_nodes_first_symmetry() {
let nodes = chebyshev_nodes_first(4);
assert_eq!(nodes.len(), 4);
for &x in &nodes {
assert!(x >= -1.0 - 1e-12 && x <= 1.0 + 1e-12);
}
}
// Second-kind nodes include both endpoints (|x| = 1 at first and last).
#[test]
fn test_chebyshev_nodes_second_endpoints() {
let nodes = chebyshev_nodes_second(5);
assert_eq!(nodes.len(), 5);
assert_relative_eq!(nodes[0].abs(), 1.0, epsilon = 1e-12);
assert_relative_eq!(nodes[4].abs(), 1.0, epsilon = 1e-12);
}
// (x - 1)(x - 2)(x - 3) vanishes at its roots and equals 3·2·1 = 6 at x = 4.
#[test]
fn test_build_product_poly() {
let roots = vec![1.0, 2.0, 3.0];
let poly = build_product_poly(&roots).expect("product poly");
assert_relative_eq!(poly.eval(1.0), 0.0, epsilon = 1e-10);
assert_relative_eq!(poly.eval(2.0), 0.0, epsilon = 1e-10);
assert_relative_eq!(poly.eval(3.0), 0.0, epsilon = 1e-10);
assert_relative_eq!(poly.eval(4.0), 6.0, epsilon = 1e-10); }
}