// arcium-core-utils 0.4.2
//
// Arcium core utils
// Documentation
use ff::{Field, PrimeField};
use primitives::algebra::{
    elliptic_curve::{BaseField, BaseFieldElement, Curve},
    field::FieldExtension,
};

use crate::key_recovery::{KeyRecoveryError, MXE_KEY_RECOVERY_D, MXE_KEY_RECOVERY_N};

/// Compute the error locator polynomial Lambda using the Berlekamp-Massey algorithm.
///
/// `d_minus_one` is a field element encoding how many leading entries of `syndromes` take
/// part in the computation; every entry after that prefix must be zero.
///
/// Returns `(L, C)`, where `L` is the number of errors found and `C` holds the coefficients
/// of Lambda with the constant term first.
///
/// # Errors
///
/// Returns [`KeyRecoveryError::InvalidInput`] when
/// * `d_minus_one` does not fit into 8 little-endian bytes,
/// * `d_minus_one` exceeds `MXE_KEY_RECOVERY_D - 1` (or does not fit into `usize`),
/// * `syndromes` does not contain exactly `MXE_KEY_RECOVERY_D - 1` entries,
/// * any syndrome past the first `d_minus_one` entries is non-zero.
#[allow(non_snake_case)]
fn compute_error_locator_polynomial<C: Curve>(
    d_minus_one: BaseFieldElement<C>,
    syndromes: &[BaseFieldElement<C>],
) -> Result<(usize, Vec<BaseFieldElement<C>>), KeyRecoveryError> {
    let zero = BaseFieldElement::<C>::from(0u8);
    let one = BaseFieldElement::<C>::from(1u8);

    let d_minus_one_bytes = d_minus_one.to_le_bytes();
    if d_minus_one_bytes.iter().skip(8).any(|byte| *byte != 0u8) {
        return Err(KeyRecoveryError::InvalidInput(String::from(
            "Key recovery compute errors expects d_minus_one to be at most 8 bytes long.",
        )));
    }
    // Decode through a fixed-width u64 and convert explicitly, so that targets where `usize`
    // is narrower than 8 bytes report an error instead of panicking.
    let mut low_bytes = [0u8; 8];
    for (dst, src) in low_bytes.iter_mut().zip(d_minus_one_bytes.iter()) {
        *dst = *src;
    }
    let d_minus_one = match usize::try_from(u64::from_le_bytes(low_bytes)) {
        Ok(value) if value <= MXE_KEY_RECOVERY_D - 1 => value,
        _ => {
            return Err(KeyRecoveryError::InvalidInput(String::from(
                "Key recovery compute errors d_minus_one too large.",
            )))
        }
    };
    if syndromes.len() != MXE_KEY_RECOVERY_D - 1 {
        return Err(KeyRecoveryError::InvalidInput(String::from(
            "Key recovery compute errors syndromes length mismatch.",
        )));
    }
    // `d_minus_one <= MXE_KEY_RECOVERY_D - 1 == syndromes.len()`, so the split cannot panic.
    let (active, trailing) = syndromes.split_at(d_minus_one);
    if trailing.iter().any(|syndrome| *syndrome != zero) {
        return Err(KeyRecoveryError::InvalidInput(String::from(
            "Key recovery compute errors trailing syndromes should be zero.",
        )));
    }

    // Notation-wise we follow [this](https://en.wikipedia.org/wiki/Berlekamp%E2%80%93Massey_algorithm#Pseudocode).
    let mut C = vec![one];
    let mut B = vec![one];
    let mut L = 0usize;
    let mut m = 1usize;
    let mut b = one;
    for (n, &syndrome) in active.iter().enumerate() {
        // Discrepancy d = S_n + sum_{i=1..=L} C_i * S_{n-i}.
        let d = C[1..=L]
            .iter()
            .zip(active[..n].iter().rev())
            .fold(syndrome, |acc, (c, s)| acc + *c * *s);

        if d == zero {
            m += 1;
            continue;
        }

        // A length change is required iff 2L <= n; in that case the current C becomes the
        // next B, so it has to be saved before it is modified.
        let snapshot = (2 * L <= n).then(|| C.clone());

        // C(x) <- C(x) - (d / b) * x^m * B(x)
        let new_len = (m + B.len()).max(C.len());
        C.resize(new_len, zero);
        let db_inv = d * b.invert().unwrap_or(zero);
        for (c, b_coeff) in C[m..].iter_mut().zip(B.iter()) {
            *c -= db_inv * *b_coeff;
        }

        match snapshot {
            Some(previous) => {
                L = n + 1 - L;
                B = previous;
                b = d;
                m = 1;
            }
            None => m += 1,
        }
    }

    // L now corresponds to the number of errors and C corresponds to Lambda.
    Ok((L, C))
}

/// Find the roots of the reversed error locator polynomial Lambda. We perform a [Chien search](https://en.wikipedia.org/wiki/Chien_search).
///
/// `C` holds the coefficients of Lambda (constant term first) and `L` is the number of
/// errors reported alongside it. Entry `k` of the returned vector (of length
/// `MXE_KEY_RECOVERY_N`) is `true` iff the reversed polynomial evaluates to zero at `g^k`,
/// where `g` is the multiplicative generator of the base field.
///
/// # Panics
///
/// Panics if `C` is empty, or if the multiplicative generator cannot be converted into a
/// `BaseFieldElement`.
#[allow(non_snake_case)]
fn find_roots<C: Curve>(L: usize, C: &[BaseFieldElement<C>]) -> Vec<bool> {
    let zero = BaseFieldElement::<C>::from(0u8);
    let multiplicative_generator = BaseFieldElement::<C>::from_le_bytes(
        &BaseField::<C>::MULTIPLICATIVE_GENERATOR
            .to_le_bytes()
            .into_iter()
            .collect::<Vec<u8>>(),
    )
    .unwrap();

    // alpha_pows[j] = g^j for j in 0..=L.
    let mut alpha_pows = Vec::with_capacity(L + 1);
    let mut pow = BaseFieldElement::<C>::from(1u8);
    alpha_pows.push(pow);
    for _ in 0..L {
        pow *= multiplicative_generator;
        alpha_pows.push(pow);
    }

    // Coefficients of the reversed polynomial. After k update rounds, terms[j] holds
    // coefficient_j * g^(j * k), so summing the terms evaluates the polynomial at g^k.
    let mut terms = C
        .iter()
        .rev()
        .copied()
        .collect::<Vec<BaseFieldElement<C>>>();
    let is_root = |terms: &[BaseFieldElement<C>]| -> bool {
        terms
            .iter()
            .copied()
            .reduce(|a, b| a + b)
            .expect("error locator polynomial must have at least one coefficient")
            == zero
    };

    let mut error_locations = Vec::with_capacity(MXE_KEY_RECOVERY_N);
    error_locations.push(is_root(&terms));
    // Only the first `L + 1` terms take part in the search beyond position 0; any surplus
    // coefficients are dropped here, exactly as the previous zip-and-collect did.
    terms.truncate(alpha_pows.len());
    for _ in 1..MXE_KEY_RECOVERY_N {
        // Step every term to the next evaluation point in place, without reallocating.
        for (term, step) in terms.iter_mut().zip(alpha_pows.iter()) {
            *term *= *step;
        }
        error_locations.push(is_root(&terms));
    }

    error_locations
}

/// Compute the error values at known error locations, using [Forney's algorithm](https://en.wikipedia.org/wiki/Forney_algorithm).
///
/// * `syndromes` - the plaintext syndromes.
/// * `L` - the number of errors, i.e. the number of coefficients computed for Omega.
/// * `C` - coefficients of the error locator polynomial Lambda, constant term first.
/// * `error_locations` - flag `i` is `true` iff position `i` (locator `g^i`, with `g` the
///   multiplicative generator of the base field) is in error.
///
/// Returns an array whose entry `i` is `-Omega(g^-i) / Lambda'(g^-i)` for flagged positions
/// and zero everywhere else. A non-invertible denominator yields a zero error value.
///
/// # Panics
///
/// Panics if `C` has fewer than `L` coefficients, if `error_locations` has more than
/// `MXE_KEY_RECOVERY_N` entries, or if the multiplicative generator cannot be converted
/// into a `BaseFieldElement`.
#[allow(non_snake_case)]
fn compute_error_values<C: Curve>(
    syndromes: &[BaseFieldElement<C>],
    L: usize,
    C: &[BaseFieldElement<C>],
    error_locations: Vec<bool>,
) -> [BaseFieldElement<C>; MXE_KEY_RECOVERY_N] {
    /// Evaluate `poly` (constant term first) at `x` with Horner's rule. The empty polynomial
    /// evaluates to zero.
    fn eval_poly<C: Curve>(
        poly: &[BaseFieldElement<C>],
        x: BaseFieldElement<C>,
    ) -> BaseFieldElement<C> {
        poly.iter()
            .rev()
            .fold(BaseFieldElement::<C>::from(0u8), |acc, coeff| {
                acc * x + *coeff
            })
    }

    let zero = BaseFieldElement::<C>::from(0u8);
    let multiplicative_generator = BaseFieldElement::<C>::from_le_bytes(
        &BaseField::<C>::MULTIPLICATIVE_GENERATOR
            .to_le_bytes()
            .into_iter()
            .collect::<Vec<u8>>(),
    )
    .unwrap();

    // Formal derivative of Lambda: the coefficient of x^(i-1) is i * Lambda_i.
    let Lambda_prime = C
        .iter()
        .enumerate()
        .skip(1)
        .map(|(i, lambda)| BaseFieldElement::<C>::from(i as u64) * lambda)
        .collect::<Vec<BaseFieldElement<C>>>();

    // Omega(x) = S(x) * Lambda(x) mod x^L, i.e. Omega_i = sum_{j <= i} S_{i-j} * Lambda_j.
    // Syndrome indices beyond the end of `syndromes` contribute nothing.
    let Omega = (0..L)
        .map(|i| {
            (0..=i)
                .filter(|j| i - *j < syndromes.len())
                .fold(zero, |acc, j| acc + syndromes[i - j] * C[j])
        })
        .collect::<Vec<BaseFieldElement<C>>>();

    let mut errors = [zero; MXE_KEY_RECOVERY_N];
    // `pow` tracks the locator g^i of the position currently being visited.
    let mut pow = BaseFieldElement::<C>::from(1u8);
    for (i, is_error) in error_locations.into_iter().enumerate() {
        if is_error {
            let X_inv = pow.invert().unwrap_or(zero);
            let num = eval_poly::<C>(&Omega, X_inv);
            let denom = eval_poly::<C>(&Lambda_prime, X_inv);
            errors[i] = -num * denom.invert().unwrap_or(zero);
        }
        pow *= multiplicative_generator;
    }

    errors
}

/// Recover the per-position error values from plaintext syndromes.
///
/// Runs three stages in sequence: Berlekamp-Massey to obtain the error locator polynomial,
/// a Chien search to find the error positions, and Forney's algorithm to compute the error
/// value at each of those positions.
///
/// # Errors
///
/// Propagates the [`KeyRecoveryError`] returned when computing the error locator polynomial
/// from `d_minus_one` and `syndromes` fails.
pub fn compute_errors<C: Curve>(
    d_minus_one: BaseFieldElement<C>,
    syndromes: &[BaseFieldElement<C>],
) -> Result<[BaseFieldElement<C>; MXE_KEY_RECOVERY_N], KeyRecoveryError> {
    compute_error_locator_polynomial::<C>(d_minus_one, syndromes).map(
        |(error_count, locator)| {
            let positions = find_roots::<C>(error_count, &locator);
            compute_error_values::<C>(syndromes, error_count, &locator, positions)
        },
    )
}