use alloc::vec;
use alloc::vec::Vec;
use crate::curve::OsstScalar;
use crate::error::OsstError;
/// Computes the Lagrange coefficients for interpolating a polynomial at `x = 0`.
///
/// Given distinct, non-zero evaluation points `indices = [x_1, …, x_k]`, the
/// returned vector holds, in the same order as `indices`,
///
/// ```text
/// λ_i = ∏_{j≠i} x_j / (x_j − x_i)
/// ```
///
/// so that `f(0) = Σ_i λ_i · f(x_i)` for any polynomial `f` of degree `< k`.
/// A single index yields the coefficient `1`.
///
/// Each coefficient is computed as `λ_i = X / d_i`, where `X = ∏_j x_j` and
/// `d_i = x_i · ∏_{j≠i} (x_j − x_i)`. Rather than inverting every `d_i`, the
/// function inverts only `D = ∏_j d_j` once. It then recovers each `1 / d_i` as
/// `(∏_{j≠i} d_j) / D` from prefix and suffix products. The cost is `O(k²)`
/// scalar multiplications plus one inversion.
///
/// NOTE(review): `D` is non-zero only if distinct non-zero `u32` indices map to
/// distinct non-zero scalars, i.e. the scalar field's order must exceed
/// `u32::MAX`. That holds for ristretto255 (order ≈ 2^252); confirm it for any
/// other `OsstScalar` backend.
///
/// # Errors
///
/// Checks run in this order, and the first failure is returned:
///
/// - [`OsstError::EmptyContributions`] if `indices` is empty.
/// - [`OsstError::InvalidIndex`] if any index is `0`.
/// - [`OsstError::DuplicateIndex`] if an index repeats; the payload is the
///   smallest repeated index.
pub fn compute_lagrange_coefficients<S: OsstScalar>(indices: &[u32]) -> Result<Vec<S>, OsstError> {
    let k = indices.len();
    if k == 0 {
        return Err(OsstError::EmptyContributions);
    }
    // 0 is the interpolation point itself, so it can never be an evaluation index.
    if indices.contains(&0) {
        return Err(OsstError::InvalidIndex);
    }
    // Sorting a copy makes duplicates adjacent, so the first equal pair found
    // is the smallest repeated index. The caller's order is left untouched.
    let mut sorted = indices.to_vec();
    sorted.sort_unstable();
    if let Some(pair) = sorted.windows(2).find(|pair| pair[0] == pair[1]) {
        return Err(OsstError::DuplicateIndex(pair[1]));
    }
    if k == 1 {
        return Ok(vec![S::one()]);
    }

    let scalars: Vec<S> = indices.iter().map(|&i| S::from_u32(i)).collect();

    // Numerator shared by every coefficient: X = ∏_j x_j.
    let xi = scalars.iter().fold(S::one(), |acc, x| acc.mul(x));

    // Per-index denominators d_i = x_i · ∏_{j≠i} (x_j − x_i), so λ_i = X / d_i.
    let d_values: Vec<S> = scalars
        .iter()
        .enumerate()
        .map(|(i, x_i)| {
            scalars
                .iter()
                .enumerate()
                .filter(|&(j, _)| j != i)
                .fold(x_i.clone(), |d, (_, x_j)| d.mul(&x_j.sub(x_i)))
        })
        .collect();

    // Batch inversion, step 1: rho_i = ∏_{j<i} d_j (prefix products).
    let mut rho: Vec<S> = Vec::with_capacity(k);
    let mut prefix = S::one();
    rho.push(prefix.clone());
    for d_i in &d_values[..k - 1] {
        prefix = prefix.mul(d_i);
        rho.push(prefix.clone());
    }

    // Step 2: multiply in the suffix products so that rho_i = ∏_{j≠i} d_j.
    // After the loop, `suffix` is D = ∏_j d_j.
    let mut suffix = S::one();
    for (rho_i, d_i) in rho.iter_mut().zip(&d_values).rev() {
        *rho_i = rho_i.mul(&suffix);
        suffix = suffix.mul(d_i);
    }

    // λ_i = X · rho_i / D = X / d_i, using the single inversion of D.
    let delta = xi.mul(&suffix.invert());
    Ok(rho.iter().map(|rho_i| delta.mul(rho_i)).collect())
}
#[cfg(all(test, feature = "ristretto255"))]
mod tests {
    use super::*;
    use curve25519_dalek::scalar::Scalar;

    #[test]
    fn test_lagrange_single() {
        let coeffs = compute_lagrange_coefficients::<Scalar>(&[1]).unwrap();
        assert_eq!(coeffs.len(), 1);
        assert_eq!(coeffs[0], Scalar::ONE);
    }

    #[test]
    fn test_lagrange_two_points() {
        let coeffs = compute_lagrange_coefficients::<Scalar>(&[1, 2]).unwrap();
        assert_eq!(coeffs[0], Scalar::from(2u32));
        assert_eq!(coeffs[1], -Scalar::ONE);
    }

    #[test]
    fn test_lagrange_three_points() {
        let coeffs = compute_lagrange_coefficients::<Scalar>(&[1, 2, 3]).unwrap();
        assert_eq!(coeffs[0], Scalar::from(3u32));
        assert_eq!(coeffs[1], -Scalar::from(3u32));
        assert_eq!(coeffs[2], Scalar::ONE);
    }

    #[test]
    fn test_lagrange_preserves_input_order() {
        // Indices are sorted internally only to detect duplicates; the
        // coefficients must line up with the caller's (unsorted) order.
        // For [3, 1, 2] they are the [1, 2, 3] coefficients permuted.
        let coeffs = compute_lagrange_coefficients::<Scalar>(&[3, 1, 2]).unwrap();
        assert_eq!(coeffs[0], Scalar::ONE);
        assert_eq!(coeffs[1], Scalar::from(3u32));
        assert_eq!(coeffs[2], -Scalar::from(3u32));
    }

    #[test]
    fn test_lagrange_non_consecutive() {
        // f(x) = 1 + 2x + 3x^2, so f(1) = 6, f(3) = 34, f(5) = 86 and f(0) = 1.
        let coeffs = compute_lagrange_coefficients::<Scalar>(&[1, 3, 5]).unwrap();
        let f_1 = Scalar::from(6u32);
        let f_3 = Scalar::from(34u32);
        let f_5 = Scalar::from(86u32);
        let interpolated = coeffs[0] * f_1 + coeffs[1] * f_3 + coeffs[2] * f_5;
        assert_eq!(interpolated, Scalar::ONE);
    }

    #[test]
    fn test_lagrange_duplicate_error() {
        let result = compute_lagrange_coefficients::<Scalar>(&[1, 2, 2]);
        assert!(matches!(result, Err(OsstError::DuplicateIndex(2))));
    }

    #[test]
    fn test_lagrange_duplicate_reports_smallest() {
        let result = compute_lagrange_coefficients::<Scalar>(&[5, 2, 5, 2]);
        assert!(matches!(result, Err(OsstError::DuplicateIndex(2))));
    }

    #[test]
    fn test_lagrange_zero_index_error() {
        let result = compute_lagrange_coefficients::<Scalar>(&[0, 1, 2]);
        assert!(matches!(result, Err(OsstError::InvalidIndex)));
    }

    #[test]
    fn test_lagrange_zero_checked_before_duplicates() {
        let result = compute_lagrange_coefficients::<Scalar>(&[0, 0]);
        assert!(matches!(result, Err(OsstError::InvalidIndex)));
    }

    #[test]
    fn test_lagrange_empty_error() {
        let result = compute_lagrange_coefficients::<Scalar>(&[]);
        assert!(matches!(result, Err(OsstError::EmptyContributions)));
    }

    #[test]
    fn test_lagrange_interpolation_property() {
        // Interpolating the constant polynomial f(x) = 1 at zero gives
        // Σ λ_i · 1 = 1, so the coefficients must always sum to one.
        for k in 2..=10 {
            let indices: Vec<u32> = (1..=k).collect();
            let coeffs = compute_lagrange_coefficients::<Scalar>(&indices).unwrap();
            let sum: Scalar = coeffs.iter().fold(Scalar::ZERO, |acc, c| acc + c);
            assert_eq!(
                sum,
                Scalar::ONE,
                "sum of lagrange coeffs should be 1 for k={}",
                k
            );
        }
    }

    #[test]
    fn test_lagrange_large_set() {
        let indices: Vec<u32> = (1..=100).collect();
        let coeffs = compute_lagrange_coefficients::<Scalar>(&indices).unwrap();
        assert_eq!(coeffs.len(), 100);
        let sum: Scalar = coeffs.iter().fold(Scalar::ZERO, |acc, c| acc + c);
        assert_eq!(sum, Scalar::ONE);
    }
}