use crate::{Error, One, Zero, field::Field};
use crypto_bigint::NonZero;
pub use crypto_bigint::Uint;
use num::traits::{ToBytes, WrappingAdd, WrappingMul, WrappingNeg, WrappingSub};
use serde::{Deserialize, Deserializer, Serialize};
use std::{
marker::PhantomData,
ops::{Add, Mul, Neg, Sub},
};
use subtle::{Choice, ConditionallySelectable};
use sunscreen_math_macros::refify_binary_op;
mod barrett;
pub use barrett::*;
/// A ring-like algebraic type.
///
/// Elements support addition, subtraction and multiplication (each producing `Self`), and
/// negation. They have additive (`Zero`) and multiplicative (`One`) identities.
///
/// Every binary operator must accept both an owned and a borrowed right-hand side, so generic
/// code can combine elements without cloning. Implementors must also be `Send + Sync`.
pub trait Ring:
    Clone
    + std::fmt::Debug
    + Eq
    + PartialEq
    + Zero
    + One
    + Add<Self, Output = Self>
    + for<'a> Add<&'a Self, Output = Self>
    + Sub<Self, Output = Self>
    + for<'a> Sub<&'a Self, Output = Self>
    + Mul<Self, Output = Self>
    + for<'a> Mul<&'a Self, Output = Self>
    + Neg<Output = Self>
    + Send
    + Sync
{
}
/// Exposes the modulus of a modular ring as an `N`-limb unsigned integer.
pub trait RingModulus<const N: usize> {
    /// The ring's modulus, represented with `N` limbs.
    fn field_modulus() -> Uint<N>;

    /// Half of the ring's modulus, represented with `N` limbs.
    ///
    /// NOTE(review): presumably `floor(modulus / 2)`; confirm against the backend constants.
    fn field_modulus_div_2() -> Uint<N>;
}
/// Integer types whose `Wrapping*` operations define the arithmetic of [`ZInt`].
///
/// For the built-in unsigned types implemented below, this is arithmetic modulo `2^bits`.
pub trait WrappingSemantics:
    Copy
    + Clone
    + std::fmt::Debug
    + Eq
    + Zero
    + One
    + WrappingAdd
    + WrappingSub
    + WrappingMul
    + WrappingNeg
    + ToBytes
    + Send
    + Sync
{
}

// Marker impls for the primitive unsigned integers; all required behavior comes from the
// supertraits.
macro_rules! impl_wrapping_semantics {
    ($($ty:ty),* $(,)?) => {
        $(impl WrappingSemantics for $ty {})*
    };
}

impl_wrapping_semantics!(u8, u16, u32, u64, u128);
impl Zero for u8 {
#[inline(always)]
fn zero() -> Self {
0
}
#[inline(always)]
fn vartime_is_zero(&self) -> bool {
self == &0
}
}
impl Zero for u16 {
#[inline(always)]
fn zero() -> Self {
0
}
#[inline(always)]
fn vartime_is_zero(&self) -> bool {
self == &0
}
}
impl Zero for u32 {
#[inline(always)]
fn zero() -> Self {
0
}
#[inline(always)]
fn vartime_is_zero(&self) -> bool {
self == &0
}
}
impl Zero for u64 {
#[inline(always)]
fn zero() -> Self {
0
}
#[inline(always)]
fn vartime_is_zero(&self) -> bool {
self == &0
}
}
impl Zero for u128 {
#[inline(always)]
fn zero() -> Self {
0
}
#[inline(always)]
fn vartime_is_zero(&self) -> bool {
self == &0
}
}
impl One for u8 {
#[inline(always)]
fn one() -> Self {
1
}
}
impl One for u16 {
#[inline(always)]
fn one() -> Self {
1
}
}
impl One for u32 {
#[inline(always)]
fn one() -> Self {
1
}
}
impl One for u64 {
#[inline(always)]
fn one() -> Self {
1
}
}
impl One for u128 {
#[inline(always)]
fn one() -> Self {
1
}
}
/// An integer whose arithmetic uses the wrapping operations of `T`.
///
/// For the built-in unsigned types (`u8`..`u128`) this is arithmetic in `Z / 2^k Z`, where
/// `k` is the bit width of `T`. The wrapper is `repr(transparent)`, so it has the same layout
/// as `T`.
#[repr(transparent)]
#[derive(Clone, Copy, Debug)]
pub struct ZInt<T>(pub T)
where
    T: WrappingSemantics;

impl<T: WrappingSemantics> ZInt<T> {
    /// Wraps `val` without modification.
    #[inline(always)]
    pub fn new(val: T) -> Self {
        ZInt(val)
    }
}

impl<T: WrappingSemantics> From<T> for ZInt<T> {
    /// Equivalent to [`ZInt::new`].
    #[inline(always)]
    fn from(value: T) -> Self {
        Self::new(value)
    }
}
/// Byte conversion forwards to the wrapped integer, so a `ZInt<T>` has the same byte
/// representation as its inner `T`.
impl<T: WrappingSemantics> ToBytes for ZInt<T> {
    type Bytes = <T as ToBytes>::Bytes;

    fn to_be_bytes(&self) -> Self::Bytes {
        T::to_be_bytes(&self.0)
    }

    fn to_le_bytes(&self) -> Self::Bytes {
        T::to_le_bytes(&self.0)
    }

    fn to_ne_bytes(&self) -> Self::Bytes {
        T::to_ne_bytes(&self.0)
    }
}
#[refify_binary_op]
impl<T> Sub<&ZInt<T>> for &ZInt<T>
where
T: WrappingSemantics,
{
type Output = ZInt<T>;
fn sub(self, rhs: &ZInt<T>) -> Self::Output {
self.0.wrapping_sub(&rhs.0).into()
}
}
#[refify_binary_op]
impl<T> Add<&ZInt<T>> for &ZInt<T>
where
T: WrappingSemantics,
{
type Output = ZInt<T>;
fn add(self, rhs: &ZInt<T>) -> Self::Output {
self.0.wrapping_add(&rhs.0).into()
}
}
#[refify_binary_op]
impl<T> Mul<&ZInt<T>> for &ZInt<T>
where
T: WrappingSemantics,
{
type Output = ZInt<T>;
fn mul(self, rhs: &ZInt<T>) -> Self::Output {
self.0.wrapping_mul(&rhs.0).into()
}
}
impl<T> Zero for ZInt<T>
where
T: WrappingSemantics,
{
#[inline(always)]
fn zero() -> Self {
Self(T::zero())
}
#[inline(always)]
fn vartime_is_zero(&self) -> bool {
self.0.vartime_is_zero()
}
}
impl<T> One for ZInt<T>
where
T: WrappingSemantics,
{
#[inline(always)]
fn one() -> Self {
Self(T::one())
}
}
/// Negation computed as `0 - x` with wrapping subtraction. The owned form defers to the
/// borrowed one.
impl<T: WrappingSemantics> Neg for ZInt<T> {
    type Output = ZInt<T>;

    fn neg(self) -> Self::Output {
        -&self
    }
}

impl<T: WrappingSemantics> Neg for &ZInt<T> {
    type Output = ZInt<T>;

    fn neg(self) -> Self::Output {
        ZInt::zero() - self
    }
}
// Equality and ordering delegate to the wrapped integer.
impl<T: WrappingSemantics> PartialEq for ZInt<T> {
    fn eq(&self, other: &Self) -> bool {
        self.0.eq(&other.0)
    }
}

impl<T: WrappingSemantics> Eq for ZInt<T> {}

impl<T: WrappingSemantics> Ring for ZInt<T> {}

// `PartialOrd` only requires `T: PartialOrd`, so it cannot defer to `Ord::cmp` here.
impl<T: WrappingSemantics + PartialOrd> PartialOrd for ZInt<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        PartialOrd::partial_cmp(&self.0, &other.0)
    }
}

impl<T: WrappingSemantics + Ord> Ord for ZInt<T> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.0, &other.0)
    }
}
pub trait ArithmeticBackend<const N: usize>: Sync + Send {
const MODULUS: Uint<N>;
const MODULUS_DIV_2: Uint<N>;
const ZERO: Uint<N>;
const ONE: Uint<N>;
fn add_mod(lhs: &Uint<N>, rhs: &Uint<N>) -> Uint<N> {
lhs.add_mod(rhs, &Self::MODULUS)
}
fn sub_mod(lhs: &Uint<N>, rhs: &Uint<N>) -> Uint<N> {
lhs.sub_mod(rhs, &Self::MODULUS)
}
fn mul_mod(lhs: &Uint<N>, rhs: &Uint<N>) -> Uint<N>;
fn encode(val: &Uint<N>) -> Uint<N>;
fn decode(val: &Uint<N>) -> Uint<N>;
}
pub trait FieldBackend {}
/// An integer modulo `B::MODULUS`, stored in the backend's encoded representation.
pub struct Zq<const N: usize, B: ArithmeticBackend<N>> {
    /// The encoded value (the output of `B::encode`). Use [`Zq::into_bigint`] to get the
    /// decoded integer.
    pub val: Uint<N>,
    _phantom: PhantomData<B>,
}

impl<const N: usize, B: ArithmeticBackend<N>> Zq<N, B> {
    /// Returns the decoded integer, i.e. `B::decode` applied to the stored value.
    pub fn into_bigint(self) -> Uint<N> {
        let Self { val, .. } = self;
        B::decode(&val)
    }
}

/// Equality compares the stored (encoded) values.
impl<const N: usize, B: ArithmeticBackend<N>> PartialEq for Zq<N, B> {
    fn eq(&self, other: &Self) -> bool {
        self.val == other.val
    }
}

/// Prints the stored (encoded) value, not the decoded integer.
impl<const N: usize, B: ArithmeticBackend<N>> std::fmt::Debug for Zq<N, B> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let stored = &self.val;
        write!(f, "Zq = {{ val = {:#?} }}", stored)
    }
}

// `Clone`/`Copy` are written by hand because `#[derive]` would add unwanted
// `B: Clone` / `B: Copy` bounds; `B` only appears inside `PhantomData`.
impl<const N: usize, B: ArithmeticBackend<N>> Clone for Zq<N, B> {
    fn clone(&self) -> Self {
        Self {
            val: self.val,
            _phantom: PhantomData,
        }
    }
}

impl<const N: usize, B: ArithmeticBackend<N>> Copy for Zq<N, B> {}
impl<const N: usize, B: ArithmeticBackend<N>> Serialize for Zq<N, B> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
fn cast<const M: usize, const N: usize, B: ArithmeticBackend<N>>(x: &Zq<N, B>) -> Uint<M> {
Uint::<M>::new(x.into_bigint().to_limbs()[..].try_into().unwrap())
}
match N {
#[cfg(target_pointer_width = "64")]
1 => cast::<1, N, B>(self).serialize(serializer),
2 => cast::<2, N, B>(self).serialize(serializer),
#[cfg(target_pointer_width = "64")]
3 => cast::<3, N, B>(self).serialize(serializer),
4 => cast::<4, N, B>(self).serialize(serializer),
#[cfg(target_pointer_width = "64")]
5 => cast::<5, N, B>(self).serialize(serializer),
6 => cast::<6, N, B>(self).serialize(serializer),
#[cfg(target_pointer_width = "64")]
7 => cast::<7, N, B>(self).serialize(serializer),
8 => cast::<8, N, B>(self).serialize(serializer),
#[cfg(target_pointer_width = "64")]
9 => cast::<9, N, B>(self).serialize(serializer),
10 => cast::<10, N, B>(self).serialize(serializer),
#[cfg(target_pointer_width = "64")]
11 => cast::<11, N, B>(self).serialize(serializer),
12 => cast::<12, N, B>(self).serialize(serializer),
#[cfg(target_pointer_width = "64")]
13 => cast::<13, N, B>(self).serialize(serializer),
14 => cast::<14, N, B>(self).serialize(serializer),
#[cfg(target_pointer_width = "64")]
15 => cast::<15, N, B>(self).serialize(serializer),
16 => cast::<16, N, B>(self).serialize(serializer),
_ => unimplemented!(),
}
}
}
impl<'de, const N: usize, B: ArithmeticBackend<N>> Deserialize<'de> for Zq<N, B> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
fn cast<'de, const M: usize, const N: usize, B, D>(x: Uint<M>) -> Result<Zq<N, B>, D::Error>
where
B: ArithmeticBackend<N>,
D: Deserializer<'de>,
{
Zq::try_from(Uint::<N>::new(x.to_limbs()[..].try_into().unwrap()))
.map_err(serde::de::Error::custom)
}
match N {
#[cfg(target_pointer_width = "64")]
1 => cast::<'_, 1, N, B, D>(Uint::deserialize(deserializer)?),
2 => cast::<'_, 2, N, B, D>(Uint::deserialize(deserializer)?),
#[cfg(target_pointer_width = "64")]
3 => cast::<'_, 3, N, B, D>(Uint::deserialize(deserializer)?),
4 => cast::<'_, 4, N, B, D>(Uint::deserialize(deserializer)?),
#[cfg(target_pointer_width = "64")]
5 => cast::<'_, 5, N, B, D>(Uint::deserialize(deserializer)?),
6 => cast::<'_, 6, N, B, D>(Uint::deserialize(deserializer)?),
#[cfg(target_pointer_width = "64")]
7 => cast::<'_, 7, N, B, D>(Uint::deserialize(deserializer)?),
8 => cast::<'_, 8, N, B, D>(Uint::deserialize(deserializer)?),
#[cfg(target_pointer_width = "64")]
9 => cast::<'_, 9, N, B, D>(Uint::deserialize(deserializer)?),
10 => cast::<'_, 10, N, B, D>(Uint::deserialize(deserializer)?),
#[cfg(target_pointer_width = "64")]
11 => cast::<'_, 11, N, B, D>(Uint::deserialize(deserializer)?),
12 => cast::<'_, 12, N, B, D>(Uint::deserialize(deserializer)?),
#[cfg(target_pointer_width = "64")]
13 => cast::<'_, 13, N, B, D>(Uint::deserialize(deserializer)?),
14 => cast::<'_, 14, N, B, D>(Uint::deserialize(deserializer)?),
#[cfg(target_pointer_width = "64")]
15 => cast::<'_, 15, N, B, D>(Uint::deserialize(deserializer)?),
16 => cast::<'_, 16, N, B, D>(Uint::deserialize(deserializer)?),
_ => unimplemented!(),
}
}
}
#[refify_binary_op]
impl<const N: usize, B> Add<&Zq<N, B>> for &Zq<N, B>
where
B: ArithmeticBackend<N>,
{
type Output = Zq<N, B>;
fn add(self, rhs: &Zq<N, B>) -> Self::Output {
Self::Output {
val: B::add_mod(&self.val, &rhs.val),
_phantom: PhantomData,
}
}
}
#[refify_binary_op]
impl<const N: usize, B> Sub<&Zq<N, B>> for &Zq<N, B>
where
B: ArithmeticBackend<N>,
{
type Output = Zq<N, B>;
fn sub(self, rhs: &Zq<N, B>) -> Self::Output {
Self::Output {
val: B::sub_mod(&self.val, &rhs.val),
_phantom: PhantomData,
}
}
}
#[refify_binary_op]
impl<const N: usize, B: ArithmeticBackend<N>> Mul<&Zq<N, B>> for &Zq<N, B> {
type Output = Zq<N, B>;
fn mul(self, rhs: &Zq<N, B>) -> Self::Output {
Self::Output {
val: B::mul_mod(&self.val, &rhs.val),
_phantom: PhantomData,
}
}
}
impl<const N: usize, B: ArithmeticBackend<N>> Zero for Zq<N, B> {
    /// The additive identity: the backend's encoded zero, `B::ZERO`.
    #[inline(always)]
    fn zero() -> Self {
        Self {
            val: B::ZERO,
            _phantom: PhantomData,
        }
    }

    /// Returns `true` if the stored value equals `B::ZERO`.
    ///
    /// This compares against the same constant `zero()` produces, rather than the literal
    /// `Uint::ZERO`. That keeps `Zq::zero().vartime_is_zero()` true whatever encoding the
    /// backend uses for zero.
    #[inline(always)]
    fn vartime_is_zero(&self) -> bool {
        self.val == B::ZERO
    }
}
impl<const N: usize, B: ArithmeticBackend<N>> One for Zq<N, B> {
#[inline(always)]
fn one() -> Self {
Self {
val: B::ONE,
_phantom: PhantomData,
}
}
}
impl<const N: usize, B: ArithmeticBackend<N>> Neg for Zq<N, B> {
type Output = Zq<N, B>;
fn neg(self) -> Self::Output {
Zq::zero() - self
}
}
impl<const N: usize, B: ArithmeticBackend<N>> Neg for &Zq<N, B> {
type Output = Zq<N, B>;
fn neg(self) -> Self::Output {
Zq::zero() - self
}
}
// `PartialEq` compares stored `Uint` values, which is a total equivalence, so `Eq` holds.
impl<const N: usize, B: ArithmeticBackend<N>> Eq for Zq<N, B> {}

// All `Ring` requirements are provided by the operator and identity impls for `Zq`.
impl<const N: usize, B: ArithmeticBackend<N>> Ring for Zq<N, B> {}
impl<const N: usize, B: ArithmeticBackend<N>> TryFrom<Uint<N>> for Zq<N, B> {
type Error = crate::Error;
fn try_from(value: Uint<N>) -> Result<Self, Self::Error> {
if value.ge(&B::MODULUS) {
return Err(Error::OutOfRange);
}
Ok(Zq {
val: B::encode(&value),
_phantom: PhantomData,
})
}
}
impl<const N: usize, B: ArithmeticBackend<N>> From<u32> for Zq<N, B> {
fn from(value: u32) -> Self {
(value as u64).into()
}
}
impl<const N: usize, B: ArithmeticBackend<N>> From<u64> for Zq<N, B> {
fn from(value: u64) -> Self {
let modulus = NonZero::new(B::MODULUS).unwrap();
let value = Uint::from_u64(value).rem(&modulus);
Self {
val: B::encode(&value),
_phantom: PhantomData,
}
}
}
impl<const N: usize, B: ArithmeticBackend<N>> From<i32> for Zq<N, B> {
fn from(value: i32) -> Self {
Self::from(value as i64)
}
}
impl<const N: usize, B: ArithmeticBackend<N>> From<i64> for Zq<N, B> {
    /// Maps a signed integer to its residue in `[0, B::MODULUS)` and encodes it.
    ///
    /// A non-negative `value` maps to `value mod q`. A negative one maps to
    /// `(q - (|value| mod q)) mod q`, so multiples of `q` map to zero. The sign is applied
    /// with `Uint::conditional_select` rather than a branch on the computed values.
    ///
    /// # Panics
    /// If `B::MODULUS` is zero.
    fn from(value: i64) -> Self {
        let modulus = NonZero::new(B::MODULUS).unwrap();

        // Reduce the magnitude *before* negating. Otherwise, when |value| >= q,
        // `q - |value|` wraps around the width of `Uint<N>`. The following `rem` would then
        // produce the wrong residue (e.g. q = 7, value = -10 must give 4).
        let abs = Uint::<N>::from_u64(value.unsigned_abs()).rem(&modulus);
        let is_negative = Choice::from(u8::from(value.is_negative()));

        // abs < q here, so this cannot underflow; the result lies in (0, q].
        let neg_val = modulus.wrapping_sub(&abs);
        let value = Uint::conditional_select(&abs, &neg_val, is_negative);

        // Folds the `q` case (negative multiple of q) back to 0.
        let value = value.rem(&modulus);

        Self {
            val: B::encode(&value),
            _phantom: PhantomData,
        }
    }
}
impl<const N: usize, B: ArithmeticBackend<N>> PartialOrd for Zq<N, B> {
    /// Always `Some`; defers to the total order from `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl<const N: usize, B: ArithmeticBackend<N>> Ord for Zq<N, B> {
    /// Orders elements by their stored (encoded) value.
    ///
    /// NOTE(review): if `B::encode` is not the identity, this is a consistent total order
    /// but not the numeric order of the decoded residues. Confirm callers do not rely on
    /// numeric ordering.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.val, &other.val)
    }
}
/// Reports the backend's modulus constants zero-extended from `M` to `N` limbs.
///
/// Both methods panic if `N < M`; that check is enforced by `extend_bigint`.
impl<B: ArithmeticBackend<M>, const M: usize, const N: usize> RingModulus<N> for Zq<M, B> {
    fn field_modulus() -> Uint<N> {
        extend_bigint::<N, M>(&B::MODULUS)
    }

    fn field_modulus_div_2() -> Uint<N> {
        extend_bigint::<N, M>(&B::MODULUS_DIV_2)
    }
}
impl<const N: usize, B: FieldBackend + ArithmeticBackend<N>> Field for Zq<N, B> {
    /// Computes the multiplicative inverse: decode, invert with `Uint::inv_odd_mod`,
    /// re-encode.
    ///
    /// NOTE(review): the "inverse exists" flag returned by `inv_odd_mod` is discarded.
    /// Inverting zero (or any value not coprime to the modulus) therefore silently yields an
    /// unspecified value rather than an error. `inv_odd_mod` also expects an odd modulus;
    /// confirm every `FieldBackend` satisfies that.
    fn inverse(&self) -> Self {
        let (inv, _invertible) = self.into_bigint().inv_odd_mod(&B::MODULUS);
        Self {
            val: B::encode(&inv),
            _phantom: PhantomData,
        }
    }
}
/// Zero-extends an `M`-limb integer to `N` limbs, preserving its value.
///
/// The low `M` limbs of the result are copied from `x`; the remaining high limbs are zero.
///
/// # Panics
/// If `M > N`.
pub fn extend_bigint<const N: usize, const M: usize>(x: &Uint<M>) -> Uint<N> {
    assert!(M <= N);
    let mut widened = Uint::<N>::ZERO;
    widened.as_words_mut()[..M].copy_from_slice(x.as_words());
    widened
}