arcium-primitives 0.4.1

Arcium primitives
Documentation
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};

use serde::{Deserialize, Serialize};
use subtle::{Choice, ConstantTimeEq};

use super::{GlobalFieldKey, ScalarKey};
use crate::{
    algebra::elliptic_curve::{Curve, Point, Scalar, ScalarAsExtension, ScalarField},
    errors::PrimitiveError,
    random::{CryptoRngCore, Random, RandomWith},
    sharing::OpenPointShare,
    types::ConditionallySelectable,
};

/// Global MAC key for curve keys: a [`GlobalFieldKey`] over the scalar field of curve `C`.
pub type GlobalCurveKey<C> = GlobalFieldKey<ScalarField<C>>;

/// MAC key made of α and β, such that MAC(x) = α · x + β.
/// In the context of VOLE, this corresponds to w = Δ · u + v.
///
/// Here the authenticated value `x` and the MAC are curve points (`Point<C>`), α is a
/// scalar-field global key and β is a point offset. Keys that are combined arithmetically
/// must share the same α (the `Add`/`Sub` impls in this file assert it).
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(bound = "C: Curve")]
pub struct CurveKey<C: Curve> {
    /// Global MAC key α (multiplies the authenticated value).
    pub(crate) alpha: GlobalCurveKey<C>,
    /// Per-value offset β (added to α · x).
    pub(crate) beta: Point<C>,
}

impl<C: Curve> CurveKey<C> {
    /// Creates a key from its global MAC key `alpha` and its per-value offset `beta`.
    pub fn new(alpha: GlobalCurveKey<C>, beta: Point<C>) -> Self {
        CurveKey { alpha, beta }
    }

    /// Computes the MAC this key expects for `value`, i.e. `beta + value · alpha`.
    pub fn compute_mac(&self, value: &Point<C>) -> Point<C> {
        self.beta + value * *self.alpha
    }

    /// Checks that `open_share.mac` equals the MAC this key expects for `open_share.value`.
    ///
    /// The points are compared with `ConstantTimeEq`; only the final pass/fail bit is
    /// branched on.
    ///
    /// # Errors
    /// Returns [`PrimitiveError::WrongMAC`] carrying the received MAC (JSON-encoded) when the
    /// check fails. If the MAC cannot be JSON-encoded, a placeholder description is carried
    /// instead, so building the error never panics.
    #[inline]
    pub fn verify_mac(&self, open_share: &OpenPointShare<C>) -> Result<(), PrimitiveError> {
        let expected_mac = self.compute_mac(&open_share.value);
        if bool::from(expected_mac.ct_eq(&open_share.mac)) {
            Ok(())
        } else {
            // Reporting a failure must not itself panic: fall back to a description.
            let received = serde_json::to_string(&open_share.mac)
                .unwrap_or_else(|err| format!("<unserializable MAC: {err}>"));
            Err(PrimitiveError::WrongMAC(received))
        }
    }

    /// Returns a clone of the global MAC key α.
    pub fn get_alpha(&self) -> GlobalCurveKey<C> {
        self.alpha.clone()
    }

    /// Returns the scalar value wrapped by the global MAC key α.
    pub fn get_alpha_value(&self) -> ScalarAsExtension<C> {
        *self.alpha
    }

    /// Returns the offset β.
    pub fn get_beta(&self) -> Point<C> {
        self.beta
    }

    /// Builds one key per supplied `alpha`, each with β set to the identity point.
    ///
    /// `alphas` is consumed, so each global key is moved into its `CurveKey` rather than cloned.
    pub fn zero_batch(alphas: Vec<GlobalCurveKey<C>>) -> Vec<CurveKey<C>> {
        alphas
            .into_iter()
            .map(|alpha| CurveKey::new(alpha, Point::<C>::identity()))
            .collect()
    }
}

impl<C: Curve> ConditionallySelectable for CurveKey<C> {
    /// Selects between `a` and `b` according to `choice`, one component at a time.
    ///
    /// The work is delegated to the `conditional_select` of the α and β types.
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let alpha = GlobalCurveKey::<C>::conditional_select(&a.alpha, &b.alpha, choice);
        let beta = Point::conditional_select(&a.beta, &b.beta, choice);
        Self { alpha, beta }
    }
}

// -------------------------
// |   Random Generation   |
// -------------------------

impl<C: Curve> Random for CurveKey<C> {
    /// Samples a key whose α and β are both freshly drawn from `rng`.
    ///
    /// α is drawn first and β second; arguments are evaluated left to right.
    fn random(mut rng: impl CryptoRngCore) -> Self {
        Self::new(
            GlobalCurveKey::<C>::random(&mut rng),
            Point::<C>::random(&mut rng),
        )
    }
}

impl<C: Curve> RandomWith<GlobalCurveKey<C>> for CurveKey<C> {
    /// Builds a key around the caller-supplied global `alpha`; only β is sampled from `rng`.
    fn random_with(mut rng: impl CryptoRngCore, alpha: GlobalCurveKey<C>) -> Self {
        let beta = Point::random(&mut rng);
        Self::new(alpha, beta)
    }
}

// ------------------------------------
// | Curve Arithmetic Implementations |
// ------------------------------------

// === Addition === //
#[macros::op_variants(owned, borrowed, flipped_commutative)]
impl<'a, C: Curve> Add<&'a CurveKey<C>> for CurveKey<C> {
    type Output = CurveKey<C>;

    /// Adds two keys issued under the same global α by summing their β offsets.
    ///
    /// # Panics
    /// Panics if the two keys carry different α values.
    #[inline]
    fn add(self, other: &'a CurveKey<C>) -> Self::Output {
        assert_eq!(self.alpha, other.alpha);
        let mut sum = self;
        sum.beta = sum.beta + other.beta;
        sum
    }
}

#[macros::op_variants(owned)]
impl<'a, C: Curve> AddAssign<&'a CurveKey<C>> for CurveKey<C> {
    #[inline]
    fn add_assign(&mut self, rhs: &'a CurveKey<C>) {
        assert_eq!(self.alpha, rhs.alpha);
        self.beta += rhs.beta;
    }
}

// === Subtraction === //

#[macros::op_variants(owned, borrowed, flipped)]
impl<'a, C: Curve> Sub<&'a CurveKey<C>> for CurveKey<C> {
    type Output = CurveKey<C>;

    /// Subtracts `other` from `self`, both issued under the same α, by taking the difference
    /// of their β offsets.
    ///
    /// # Panics
    /// Panics if the two keys carry different α values.
    #[inline]
    fn sub(self, other: &'a CurveKey<C>) -> Self::Output {
        assert_eq!(self.alpha, other.alpha);
        let CurveKey { alpha, beta } = self;
        CurveKey {
            alpha,
            beta: beta - other.beta,
        }
    }
}

// === SubAssign === //

#[macros::op_variants(owned)]
impl<'a, C: Curve> SubAssign<&'a CurveKey<C>> for CurveKey<C> {
    #[inline]
    fn sub_assign(&mut self, rhs: &'a CurveKey<C>) {
        assert_eq!(self.alpha, rhs.alpha);
        self.beta -= rhs.beta;
    }
}

// === Multiplication === //

#[macros::op_variants(owned, borrowed, flipped)]
impl<'a, C: Curve> Mul<&'a ScalarAsExtension<C>> for CurveKey<C> {
    type Output = CurveKey<C>;

    /// Scales the key by a scalar.
    ///
    /// Only β is multiplied; the global α is kept as-is, so c·(α·x + β) = α·(c·x) + c·β.
    #[inline]
    fn mul(self, other: &'a ScalarAsExtension<C>) -> Self::Output {
        let mut scaled = self;
        scaled.beta = scaled.beta * other;
        scaled
    }
}

#[macros::op_variants(owned, borrowed, flipped)]
impl<'a, C: Curve> Mul<&'a Scalar<C>> for CurveKey<C> {
    type Output = CurveKey<C>;

    /// Scales the key by a scalar: β is multiplied and the global α is left untouched.
    #[inline]
    fn mul(self, other: &'a Scalar<C>) -> Self::Output {
        let CurveKey { alpha, beta } = self;
        CurveKey {
            alpha,
            beta: beta * other,
        }
    }
}

// === MulAssign === //

#[macros::op_variants(owned)]
impl<'a, C: Curve> MulAssign<&'a ScalarAsExtension<C>> for CurveKey<C> {
    /// In-place scaling: replaces β with β · `rhs`; α is unchanged.
    #[inline]
    fn mul_assign(&mut self, rhs: &'a ScalarAsExtension<C>) {
        let scaled = self.beta * rhs;
        self.beta = scaled;
    }
}

#[macros::op_variants(owned)]
impl<'a, C: Curve> MulAssign<&'a Scalar<C>> for CurveKey<C> {
    /// In-place scaling by a scalar: replaces β with β · `rhs`; α is unchanged.
    #[inline]
    fn mul_assign(&mut self, rhs: &'a Scalar<C>) {
        let scaled = self.beta * rhs;
        self.beta = scaled;
    }
}

// === Negation === //

#[macros::op_variants(borrowed)]
impl<C: Curve> Neg for CurveKey<C> {
    type Output = CurveKey<C>;

    /// Negates the key by flipping the sign of β; the global α is carried over unchanged.
    #[inline]
    fn neg(self) -> Self::Output {
        let CurveKey { alpha, beta } = self;
        CurveKey { alpha, beta: -beta }
    }
}

// === Equality === //

impl<C: Curve> ConstantTimeEq for CurveKey<C> {
    /// Compares both α and β in constant time.
    ///
    /// The two results are combined with a bitwise `&` on `Choice`, so both comparisons
    /// always run and nothing short-circuits.
    #[inline]
    fn ct_eq(&self, other: &Self) -> Choice {
        let same_alpha = self.alpha.ct_eq(&other.alpha);
        let same_beta = self.beta.ct_eq(&other.beta);
        same_alpha & same_beta
    }
}

// === Type conversions === //

impl<C: Curve> From<ScalarKey<C>> for CurveKey<C> {
    /// Lifts a scalar key into the group.
    ///
    /// α is kept and β is mapped to β · G, where G is `Point::generator()`. A scalar relation
    /// m = α·x + β thus becomes m·G = α·(x·G) + β·G.
    #[inline]
    fn from(scalar_key: ScalarKey<C>) -> Self {
        let generator = Point::<C>::generator();
        Self::new(scalar_key.alpha, scalar_key.beta * generator)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::elliptic_curve::Curve25519Ristretto as C;

    pub type FrExt = ScalarAsExtension<C>;
    pub type P = Point<C>;

    /// Builds a key sharing the given global `alpha` with the supplied offset `beta`.
    fn key(alpha: &GlobalCurveKey<C>, beta: P) -> CurveKey<C> {
        CurveKey {
            alpha: alpha.clone(),
            beta,
        }
    }

    #[test]
    fn test_addition() {
        let mut rng = rand::thread_rng();
        let alpha = GlobalCurveKey::<C>::new(FrExt::random(&mut rng));
        let (lhs, rhs) = (P::random(&mut rng), P::random(&mut rng));

        assert_eq!(key(&alpha, lhs) + key(&alpha, rhs), key(&alpha, lhs + rhs));
    }

    #[test]
    fn test_subtraction() {
        let mut rng = rand::thread_rng();
        let alpha = GlobalCurveKey::<C>::new(FrExt::random(&mut rng));
        let (lhs, rhs) = (P::random(&mut rng), P::random(&mut rng));

        assert_eq!(key(&alpha, lhs) - key(&alpha, rhs), key(&alpha, lhs - rhs));
    }

    #[test]
    fn test_multiplication() {
        let mut rng = rand::thread_rng();
        let alpha = GlobalCurveKey::<C>::new(FrExt::random(&mut rng));
        let offset = P::random(&mut rng);
        let factor = FrExt::from(3u32);

        assert_eq!(key(&alpha, offset) * factor, key(&alpha, offset * factor));
    }

    #[test]
    fn test_negation() {
        let mut rng = rand::thread_rng();
        let alpha = GlobalCurveKey::<C>::new(FrExt::random(&mut rng));
        let offset = P::random(&mut rng);

        assert_eq!(-key(&alpha, offset), key(&alpha, -offset));
    }
}