use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use serde::{Deserialize, Serialize};
use subtle::{Choice, ConstantTimeEq};
use super::{GlobalFieldKey, ScalarKey};
use crate::{
algebra::elliptic_curve::{Curve, Point, Scalar, ScalarAsExtension, ScalarField},
errors::PrimitiveError,
random::{CryptoRngCore, Random, RandomWith},
sharing::OpenPointShare,
types::ConditionallySelectable,
};
/// Global MAC key used by [`CurveKey`]: a [`GlobalFieldKey`] over the scalar field of curve `C`.
pub type GlobalCurveKey<C> = GlobalFieldKey<ScalarField<C>>;
/// MAC key for authenticating elliptic-curve points.
///
/// Under this key, the MAC of a point `v` is `beta + v * alpha`
/// (see [`CurveKey::compute_mac`]).
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
// Replace serde's inferred bounds with a single `C: Curve` bound for (de)serialization.
#[serde(bound = "C: Curve")]
pub struct CurveKey<C: Curve> {
    // Global scalar key component; it multiplies the authenticated value in the MAC.
    pub(crate) alpha: GlobalCurveKey<C>,
    // Point offset added to `value * alpha` in the MAC.
    pub(crate) beta: Point<C>,
}
impl<C: Curve> CurveKey<C> {
    /// Builds a key from its global component `alpha` and its point offset `beta`.
    pub fn new(alpha: GlobalCurveKey<C>, beta: Point<C>) -> Self {
        CurveKey { alpha, beta }
    }

    /// Computes the MAC this key assigns to `value`, namely `beta + value * alpha`.
    pub fn compute_mac(&self, value: &Point<C>) -> Point<C> {
        self.beta + value * *self.alpha
    }

    /// Checks that `open_share.mac` equals the MAC this key computes for `open_share.value`.
    ///
    /// The MAC comparison uses [`ConstantTimeEq`].
    ///
    /// # Errors
    ///
    /// Returns [`PrimitiveError::WrongMAC`] when the MACs differ. The error carries the
    /// JSON-serialized received MAC. If that serialization itself fails, it carries a
    /// description of the serialization failure instead; building the error never panics.
    #[inline]
    pub fn verify_mac(&self, open_share: &OpenPointShare<C>) -> Result<(), PrimitiveError> {
        let expected_mac = self.compute_mac(&open_share.value);
        if bool::from(expected_mac.ct_eq(&open_share.mac)) {
            return Ok(());
        }
        // Do not `unwrap()` here: a serialization failure must not turn a recoverable
        // verification error into a panic.
        let received = serde_json::to_string(&open_share.mac)
            .unwrap_or_else(|err| format!("<unserializable MAC: {err}>"));
        Err(PrimitiveError::WrongMAC(received))
    }

    /// Returns a clone of the global key component `alpha`.
    pub fn get_alpha(&self) -> GlobalCurveKey<C> {
        self.alpha.clone()
    }

    /// Returns the scalar value behind `alpha`, obtained by dereferencing the global key.
    pub fn get_alpha_value(&self) -> ScalarAsExtension<C> {
        *self.alpha
    }

    /// Returns the point offset `beta`.
    pub fn get_beta(&self) -> Point<C> {
        self.beta
    }

    /// Builds one key per entry of `alphas`, each with `beta` set to the identity point.
    ///
    /// Each resulting key's MAC of `v` is therefore just `v * alpha`.
    pub fn zero_batch(alphas: Vec<GlobalCurveKey<C>>) -> Vec<CurveKey<C>> {
        // `alphas` is owned, so move each key out instead of cloning it.
        alphas
            .into_iter()
            .map(|alpha| CurveKey::new(alpha, Point::<C>::identity()))
            .collect()
    }
}
impl<C: Curve> ConditionallySelectable for CurveKey<C> {
    /// Selects between `a` and `b` field by field, according to `choice`.
    ///
    /// NOTE(review): presumably this follows the `subtle` convention (`a` when `choice` is 0,
    /// `b` when it is 1). That depends on the crate's `ConditionallySelectable` contract, so confirm.
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let selected_alpha = GlobalCurveKey::<C>::conditional_select(&a.alpha, &b.alpha, choice);
        let selected_beta = Point::conditional_select(&a.beta, &b.beta, choice);
        Self {
            alpha: selected_alpha,
            beta: selected_beta,
        }
    }
}
impl<C: Curve> Random for CurveKey<C> {
    /// Samples a key whose `alpha` and `beta` are both drawn from `rng`.
    ///
    /// `alpha` is sampled before `beta`.
    fn random(mut rng: impl CryptoRngCore) -> Self {
        let sampled_alpha = GlobalCurveKey::<C>::random(&mut rng);
        Self::new(sampled_alpha, Point::<C>::random(&mut rng))
    }
}
impl<C: Curve> RandomWith<GlobalCurveKey<C>> for CurveKey<C> {
    /// Samples a key that uses the given `alpha` and a `beta` freshly drawn from `rng`.
    fn random_with(mut rng: impl CryptoRngCore, alpha: GlobalCurveKey<C>) -> Self {
        let fresh_beta = Point::random(&mut rng);
        Self {
            alpha,
            beta: fresh_beta,
        }
    }
}
#[macros::op_variants(owned, borrowed, flipped_commutative)]
impl<'a, C: Curve> Add<&'a CurveKey<C>> for CurveKey<C> {
    type Output = CurveKey<C>;

    /// Adds two keys that share the same `alpha` by adding their `beta` offsets.
    ///
    /// The MAC is `beta + value * alpha`, so the summed key yields the sum of the two MACs
    /// for the sum of the two values.
    ///
    /// # Panics
    ///
    /// Panics if the two keys have different `alpha` components.
    #[inline]
    fn add(self, other: &'a CurveKey<C>) -> Self::Output {
        // Use `assert!` over a constant-time comparison instead of `assert_eq!`. This way the
        // panic message never `Debug`-formats the global keys, which may be secret.
        assert!(
            bool::from(self.alpha.ct_eq(&other.alpha)),
            "cannot add CurveKeys with different alpha"
        );
        CurveKey {
            beta: self.beta + other.beta,
            ..self
        }
    }
}
#[macros::op_variants(owned)]
impl<'a, C: Curve> AddAssign<&'a CurveKey<C>> for CurveKey<C> {
    /// Adds `rhs.beta` into `self.beta` in place; `alpha` is unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the two keys have different `alpha` components.
    #[inline]
    fn add_assign(&mut self, rhs: &'a CurveKey<C>) {
        // Constant-time check with a value-free message, so global keys are never
        // written into panic output (unlike `assert_eq!`).
        assert!(
            bool::from(self.alpha.ct_eq(&rhs.alpha)),
            "cannot add CurveKeys with different alpha"
        );
        self.beta += rhs.beta;
    }
}
#[macros::op_variants(owned, borrowed, flipped)]
impl<'a, C: Curve> Sub<&'a CurveKey<C>> for CurveKey<C> {
    type Output = CurveKey<C>;

    /// Subtracts two keys that share the same `alpha` by subtracting their `beta` offsets.
    ///
    /// The MAC is `beta + value * alpha`, so the resulting key yields the difference of
    /// the two MACs for the difference of the two values.
    ///
    /// # Panics
    ///
    /// Panics if the two keys have different `alpha` components.
    #[inline]
    fn sub(self, other: &'a CurveKey<C>) -> Self::Output {
        // Constant-time check with a value-free message, so global keys are never
        // written into panic output (unlike `assert_eq!`).
        assert!(
            bool::from(self.alpha.ct_eq(&other.alpha)),
            "cannot subtract CurveKeys with different alpha"
        );
        CurveKey {
            beta: self.beta - other.beta,
            ..self
        }
    }
}
#[macros::op_variants(owned)]
impl<'a, C: Curve> SubAssign<&'a CurveKey<C>> for CurveKey<C> {
    /// Subtracts `rhs.beta` from `self.beta` in place; `alpha` is unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the two keys have different `alpha` components.
    #[inline]
    fn sub_assign(&mut self, rhs: &'a CurveKey<C>) {
        // Constant-time check with a value-free message, so global keys are never
        // written into panic output (unlike `assert_eq!`).
        assert!(
            bool::from(self.alpha.ct_eq(&rhs.alpha)),
            "cannot subtract CurveKeys with different alpha"
        );
        self.beta -= rhs.beta;
    }
}
#[macros::op_variants(owned, borrowed, flipped)]
impl<'a, C: Curve> Mul<&'a ScalarAsExtension<C>> for CurveKey<C> {
    type Output = CurveKey<C>;

    /// Scales the key by `other`: `beta` is multiplied and `alpha` is kept as is.
    ///
    /// With MAC `beta + value * alpha`, the scaled key matches the scaled MAC of the
    /// scaled value.
    #[inline]
    fn mul(self, other: &'a ScalarAsExtension<C>) -> Self::Output {
        let scaled_beta = self.beta * other;
        CurveKey {
            alpha: self.alpha,
            beta: scaled_beta,
        }
    }
}
#[macros::op_variants(owned, borrowed, flipped)]
impl<'a, C: Curve> Mul<&'a Scalar<C>> for CurveKey<C> {
    type Output = CurveKey<C>;

    /// Scales the key by the base-field scalar `other`: only `beta` is multiplied.
    #[inline]
    fn mul(self, other: &'a Scalar<C>) -> Self::Output {
        let scaled_beta = self.beta * other;
        CurveKey {
            alpha: self.alpha,
            beta: scaled_beta,
        }
    }
}
#[macros::op_variants(owned)]
impl<'a, C: Curve> MulAssign<&'a ScalarAsExtension<C>> for CurveKey<C> {
#[inline]
fn mul_assign(&mut self, rhs: &'a ScalarAsExtension<C>) {
self.beta *= rhs;
}
}
#[macros::op_variants(owned)]
impl<'a, C: Curve> MulAssign<&'a Scalar<C>> for CurveKey<C> {
#[inline]
fn mul_assign(&mut self, rhs: &'a Scalar<C>) {
self.beta *= rhs;
}
}
#[macros::op_variants(borrowed)]
impl<C: Curve> Neg for CurveKey<C> {
    type Output = CurveKey<C>;

    /// Negates the key's point offset `beta` and keeps the same `alpha`.
    #[inline]
    fn neg(self) -> Self::Output {
        let CurveKey { alpha, beta } = self;
        CurveKey { alpha, beta: -beta }
    }
}
impl<C: Curve> ConstantTimeEq for CurveKey<C> {
    /// Compares both components in constant time.
    ///
    /// The two per-field results are combined with the non-short-circuiting `&`, so both
    /// comparisons always run.
    #[inline]
    fn ct_eq(&self, other: &Self) -> Choice {
        let alphas_match = self.alpha.ct_eq(&other.alpha);
        let betas_match = self.beta.ct_eq(&other.beta);
        alphas_match & betas_match
    }
}
impl<C: Curve> From<ScalarKey<C>> for CurveKey<C> {
    /// Lifts a scalar key to a curve key.
    ///
    /// `alpha` is kept and the scalar offset `beta` is mapped to `beta * G`,
    /// where `G` is the curve generator.
    #[inline]
    fn from(scalar_key: ScalarKey<C>) -> Self {
        let lifted_beta = scalar_key.beta * Point::<C>::generator();
        Self::new(scalar_key.alpha, lifted_beta)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::elliptic_curve::Curve25519Ristretto as C;

    pub type FrExt = ScalarAsExtension<C>;
    pub type P = Point<C>;

    /// Builds a test key that uses (a clone of) `alpha` and the given `beta`.
    fn key_for(alpha: &GlobalCurveKey<C>, beta: P) -> CurveKey<C> {
        CurveKey {
            alpha: alpha.clone(),
            beta,
        }
    }

    #[test]
    fn test_addition() {
        let mut rng = rand::thread_rng();
        let alpha = GlobalCurveKey::<C>::new(FrExt::random(&mut rng));
        let lhs_beta = P::random(&mut rng);
        let rhs_beta = P::random(&mut rng);

        let sum = key_for(&alpha, lhs_beta) + key_for(&alpha, rhs_beta);

        assert_eq!(sum, key_for(&alpha, lhs_beta + rhs_beta));
    }

    #[test]
    fn test_subtraction() {
        let mut rng = rand::thread_rng();
        let alpha = GlobalCurveKey::<C>::new(FrExt::random(&mut rng));
        let lhs_beta = P::random(&mut rng);
        let rhs_beta = P::random(&mut rng);

        let difference = key_for(&alpha, lhs_beta) - key_for(&alpha, rhs_beta);

        assert_eq!(difference, key_for(&alpha, lhs_beta - rhs_beta));
    }

    #[test]
    fn test_multiplication() {
        let mut rng = rand::thread_rng();
        let alpha = GlobalCurveKey::<C>::new(FrExt::random(&mut rng));
        let beta = P::random(&mut rng);
        let factor = FrExt::from(3u32);

        let scaled = key_for(&alpha, beta) * factor;

        assert_eq!(scaled, key_for(&alpha, beta * factor));
    }

    #[test]
    fn test_negation() {
        let mut rng = rand::thread_rng();
        let alpha = GlobalCurveKey::<C>::new(FrExt::random(&mut rng));
        let beta = P::random(&mut rng);

        let negated = -key_for(&alpha, beta);

        assert_eq!(negated, key_for(&alpha, -beta));
    }
}