use ff::{Field, PrimeField};
use primitives::algebra::{
elliptic_curve::{BaseField, BaseFieldElement, Curve},
field::FieldExtension,
};
use crate::key_recovery::{KeyRecoveryError, MXE_KEY_RECOVERY_D, MXE_KEY_RECOVERY_N};
/// Runs the Berlekamp–Massey algorithm over the first `d_minus_one` syndromes.
///
/// * `d_minus_one` — a field element that must encode an integer no larger than
///   `MXE_KEY_RECOVERY_D - 1` (at most 8 little-endian bytes may be non-zero).
/// * `syndromes` — exactly `MXE_KEY_RECOVERY_D - 1` values; every entry at index
///   `>= d_minus_one` must be zero.
///
/// Returns `(L, C)`, where `L` is the linear complexity found and `C` holds the coefficients
/// of the connection (error-locator) polynomial, lowest degree first. `C[0]` is always 1 and
/// `C` has `L + 1` coefficients.
///
/// # Errors
/// Returns `KeyRecoveryError::InvalidInput` if any of the input constraints above is violated.
#[allow(non_snake_case)]
fn compute_error_locator_polynomial<C: Curve>(
    d_minus_one: BaseFieldElement<C>,
    syndromes: &[BaseFieldElement<C>],
) -> Result<(usize, Vec<BaseFieldElement<C>>), KeyRecoveryError> {
    let zero = BaseFieldElement::<C>::from(0u8);
    let one = BaseFieldElement::<C>::from(1u8);

    let d_minus_one_bytes = d_minus_one.to_le_bytes();
    if d_minus_one_bytes.iter().skip(8).any(|byte| *byte != 0u8) {
        return Err(KeyRecoveryError::InvalidInput(String::from(
            "Key recovery compute errors expects d_minus_one to be at most 8 bytes long.",
        )));
    }
    // Decode through a zero-padded 8-byte u64 buffer. `usize::from_le_bytes` would need
    // exactly `size_of::<usize>()` bytes and would panic on 32-bit targets.
    let mut le_bytes = [0u8; 8];
    for (dst, src) in le_bytes.iter_mut().zip(d_minus_one_bytes.iter()) {
        *dst = *src;
    }
    let d_minus_one = match usize::try_from(u64::from_le_bytes(le_bytes)) {
        Ok(value) if value <= MXE_KEY_RECOVERY_D - 1 => value,
        _ => {
            return Err(KeyRecoveryError::InvalidInput(String::from(
                "Key recovery compute errors d_minus_one too large.",
            )))
        }
    };
    if syndromes.len() != MXE_KEY_RECOVERY_D - 1 {
        return Err(KeyRecoveryError::InvalidInput(String::from(
            "Key recovery compute errors syndromes length mismatch.",
        )));
    }
    // `d_minus_one <= syndromes.len()` was established by the two checks above.
    let (syndromes, trailing) = syndromes.split_at(d_minus_one);
    if trailing.iter().any(|syndrome| *syndrome != zero) {
        return Err(KeyRecoveryError::InvalidInput(String::from(
            "Key recovery compute errors trailing syndromes should be zero.",
        )));
    }

    let mut C = vec![one];
    let mut B = vec![one];
    let mut L = 0usize;
    let mut m = 1usize;
    let mut b = one;
    for (n, &s_n) in syndromes.iter().enumerate() {
        // Discrepancy between s_n and the value predicted by the current connection polynomial.
        // `L <= n` here, so `n - i` never underflows.
        let d = (1..=L).fold(s_n, |acc, i| acc + C[i] * syndromes[n - i]);
        if d == zero {
            m += 1;
            continue;
        }
        // `b` starts at 1 and is only ever replaced by a non-zero discrepancy, so the
        // inversion always succeeds.
        let scale = d * b.invert().unwrap_or(zero);
        // When 2L <= n the linear complexity grows; keep the pre-update C to become the new B.
        let previous_C = (2 * L <= n).then(|| C.clone());
        // C(x) <- C(x) - (d / b) * x^m * B(x)
        C.resize((m + B.len()).max(C.len()), zero);
        for (i, &b_i) in B.iter().enumerate() {
            C[m + i] -= scale * b_i;
        }
        match previous_C {
            Some(previous_C) => {
                L = n + 1 - L;
                B = previous_C;
                b = d;
                m = 1usize;
            }
            None => m += 1,
        }
    }
    Ok((L, C))
}
/// Chien search for the error positions described by the locator polynomial `C`.
///
/// `C` holds coefficients lowest degree first, and `L` is its linear complexity as returned
/// by `compute_error_locator_polynomial`. Let `g` be the base field's multiplicative
/// generator. For each position `k` in `0..MXE_KEY_RECOVERY_N`, the search evaluates
/// `sum_i C[i] * g^{(C.len() - 1 - i) * k}`, which equals `g^{(C.len() - 1) * k} * C(g^{-k})`.
/// Entry `k` of the result is therefore `true` exactly when `C(g^{-k}) == 0`.
#[allow(non_snake_case)]
fn find_roots<C: Curve>(L: usize, C: &[BaseFieldElement<C>]) -> Vec<bool> {
    debug_assert!(L < C.len(), "locator polynomial shorter than its linear complexity");
    let zero = BaseFieldElement::<C>::from(0u8);
    let multiplicative_generator = BaseFieldElement::<C>::from_le_bytes(
        &BaseField::<C>::MULTIPLICATIVE_GENERATOR
            .to_le_bytes()
            .into_iter()
            .collect::<Vec<u8>>(),
    )
    .unwrap();

    // alpha_pows[j] = g^j, one per coefficient. The table is sized from C itself so no
    // coefficient is ever dropped by the zip below.
    let mut alpha_pows = Vec::with_capacity(C.len());
    let mut pow = BaseFieldElement::<C>::from(1u8);
    for _ in 0..C.len() {
        alpha_pows.push(pow);
        pow *= multiplicative_generator;
    }

    // Reversed coefficients: lambdas[j] = C[C.len() - 1 - j]. After k update rounds,
    // lambdas[j] has been multiplied by g^{j * k}.
    let mut lambdas = C.iter().rev().copied().collect::<Vec<BaseFieldElement<C>>>();
    let mut error_locations = Vec::with_capacity(MXE_KEY_RECOVERY_N);
    for k in 0..MXE_KEY_RECOVERY_N {
        if k > 0 {
            for (lambda, alpha_pow) in lambdas.iter_mut().zip(&alpha_pows) {
                *lambda *= *alpha_pow;
            }
        }
        let evaluation = lambdas.iter().fold(zero, |acc, lambda| acc + *lambda);
        error_locations.push(evaluation == zero);
    }
    error_locations
}
/// Computes error magnitudes at the located positions using Forney's formula.
///
/// * `syndromes` — syndrome values; only the first `L` entries are read.
/// * `L`, `C` — output of `compute_error_locator_polynomial` (`C` lowest degree first).
/// * `error_locations` — for each position `k`, whether position `k` was flagged as an error
///   (as produced by `find_roots`).
///
/// For every flagged position `k`, let `X = g^k`, where `g` is the base field's
/// multiplicative generator. The error value is `-Omega(X^{-1}) / Lambda'(X^{-1})`.
/// `Lambda'` is the formal derivative of `C`, and `Omega` is the first `L` coefficients of
/// `S(x) * C(x)`. Positions that are not flagged get zero.
#[allow(non_snake_case)]
fn compute_error_values<C: Curve>(
    syndromes: &[BaseFieldElement<C>],
    L: usize,
    C: &[BaseFieldElement<C>],
    error_locations: Vec<bool>,
) -> [BaseFieldElement<C>; MXE_KEY_RECOVERY_N] {
    /// Evaluates `poly` (lowest-degree coefficient first) at `x` using Horner's rule.
    /// The empty polynomial evaluates to zero.
    fn eval_poly<C: Curve>(
        poly: &[BaseFieldElement<C>],
        x: BaseFieldElement<C>,
    ) -> BaseFieldElement<C> {
        poly.iter()
            .rev()
            .fold(BaseFieldElement::<C>::from(0u8), |acc, coeff| acc * x + *coeff)
    }

    let zero = BaseFieldElement::<C>::from(0u8);
    let multiplicative_generator = BaseFieldElement::<C>::from_le_bytes(
        &BaseField::<C>::MULTIPLICATIVE_GENERATOR
            .to_le_bytes()
            .into_iter()
            .collect::<Vec<u8>>(),
    )
    .unwrap();

    // Formal derivative: Lambda_prime[i - 1] = i * C[i].
    let Lambda_prime = C
        .iter()
        .enumerate()
        .skip(1)
        .map(|(i, lambda)| BaseFieldElement::<C>::from(i as u64) * lambda)
        .collect::<Vec<BaseFieldElement<C>>>();

    // Omega_i = sum_{j=0..=i} S_{i-j} * C_j for i < L, i.e. S(x) * C(x) mod x^L.
    let Omega = (0..L)
        .map(|i| {
            (0..=i)
                .filter(|&j| i - j < syndromes.len() && j < C.len())
                .fold(zero, |acc, j| acc + syndromes[i - j] * C[j])
        })
        .collect::<Vec<BaseFieldElement<C>>>();

    let mut errors = [zero; MXE_KEY_RECOVERY_N];
    // X walks through g^0, g^1, ... in step with the position index.
    let mut X = BaseFieldElement::<C>::from(1u8);
    for (error, is_error) in errors.iter_mut().zip(error_locations) {
        if is_error {
            let X_inv = X.invert().unwrap_or(zero);
            let num = eval_poly::<C>(&Omega, X_inv);
            let denom = eval_poly::<C>(&Lambda_prime, X_inv);
            *error = -num * denom.invert().unwrap_or(zero);
        }
        X *= multiplicative_generator;
    }
    errors
}
#[allow(non_snake_case)]
pub fn compute_errors<C: Curve>(
d_minus_one: BaseFieldElement<C>,
syndromes: &[BaseFieldElement<C>],
) -> Result<[BaseFieldElement<C>; MXE_KEY_RECOVERY_N], KeyRecoveryError> {
let (L, C) = compute_error_locator_polynomial::<C>(d_minus_one, syndromes)?;
let error_locations = find_roots::<C>(L, &C);
let errors = compute_error_values::<C>(syndromes, L, &C, error_locations);
Ok(errors)
}