use std::{
fmt::{Display, Formatter, Result as FmtResult},
iter::{Product, Sum},
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
};
use ark_ec::CurveGroup;
use ark_ff::{batch_inversion, Field, PrimeField};
use ark_std::UniformRand;
use itertools::Itertools;
use num_bigint::BigUint;
use rand::{CryptoRng, RngCore};
use serde::{Deserialize, Serialize};
use crate::fabric::{ResultHandle, ResultValue};
use super::macros::{impl_borrow_variants, impl_commutative};
/// Returns the number of bytes needed to represent an element of the field `F`
///
/// Computed as the ceiling of `MODULUS_BIT_SIZE / 8`
#[inline]
pub const fn n_bytes_field<F: PrimeField>() -> usize {
    // Ceiling division: one extra byte whenever the modulus bit width is not
    // a whole number of bytes
    let n_bits = F::MODULUS_BIT_SIZE as usize;
    n_bits / 8 + ((n_bits % 8 != 0) as usize)
}
/// A wrapper around the scalar field element of the curve group `C`
///
/// Wrapping the arkworks field type lets this crate define its own arithmetic
/// impls (including mixed arithmetic against `ScalarResult` handles) as well
/// as serde support for the scalar
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct Scalar<C: CurveGroup>(pub(crate) C::ScalarField);
impl<C: CurveGroup> Scalar<C> {
    /// The underlying field that the scalar wraps
    ///
    /// NOTE(review): this is an inherent associated type, which is an unstable
    /// feature (`inherent_associated_types`) -- confirm the crate's toolchain
    pub type Field = C::ScalarField;

    /// Construct a scalar directly from the underlying field element
    pub fn new(inner: C::ScalarField) -> Self {
        Scalar(inner)
    }

    /// The additive identity of the field
    pub fn zero() -> Self {
        Scalar(C::ScalarField::from(0u8))
    }

    /// The multiplicative identity of the field
    pub fn one() -> Self {
        Scalar(C::ScalarField::from(1u8))
    }

    /// Returns the wrapped field element
    pub fn inner(&self) -> C::ScalarField {
        self.0
    }

    /// Sample a uniformly random scalar from the given CSPRNG
    pub fn random<R: RngCore + CryptoRng>(rng: &mut R) -> Self {
        Self(C::ScalarField::rand(rng))
    }

    /// Compute the multiplicative inverse of the scalar
    ///
    /// # Panics
    /// Panics if the scalar is zero, which has no multiplicative inverse
    pub fn inverse(&self) -> Self {
        Scalar(self.0.inverse().expect("cannot invert zero"))
    }

    /// Compute the multiplicative inverses of a slice of scalars in place
    ///
    /// Delegates to arkworks' `batch_inversion`, which amortizes the cost of
    /// field inversion across all the values in the slice
    pub fn batch_inverse(vals: &mut [Self]) {
        let mut values = vals.iter().map(|x| x.0).collect_vec();
        batch_inversion(&mut values);

        // Write the inverted field elements back into the wrappers
        for (val, inverted) in vals.iter_mut().zip(values) {
            *val = Scalar(inverted);
        }
    }

    /// Construct a scalar from big-endian bytes, reducing modulo the field order
    pub fn from_be_bytes_mod_order(bytes: &[u8]) -> Self {
        let inner = C::ScalarField::from_be_bytes_mod_order(bytes);
        Scalar(inner)
    }

    /// Serialize the scalar to big-endian bytes, left-padded with zeros to the
    /// full byte width of the field
    pub fn to_bytes_be(&self) -> Vec<u8> {
        let val_biguint = self.to_biguint();
        let mut bytes = val_biguint.to_bytes_be();

        // `BigUint::to_bytes_be` drops leading zero bytes, so pad back up to
        // the field width; a reduced field element never needs more than
        // `n_bytes` bytes, so the subtraction cannot underflow
        let n_bytes = n_bytes_field::<C::ScalarField>();
        let mut padding = vec![0u8; n_bytes - bytes.len()];
        padding.append(&mut bytes);

        padding
    }

    /// Convert the scalar to a `BigUint` in its canonical (reduced) form
    pub fn to_biguint(&self) -> BigUint {
        self.0.into()
    }

    /// Construct a scalar from a `BigUint`, reducing modulo the field order
    pub fn from_biguint(val: &BigUint) -> Self {
        let le_bytes = val.to_bytes_le();
        let inner = C::ScalarField::from_le_bytes_mod_order(&le_bytes);
        Scalar(inner)
    }
}
impl<C: CurveGroup> Display for Scalar<C> {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
write!(f, "{}", self.to_biguint())
}
}
impl<C: CurveGroup> Serialize for Scalar<C> {
    /// Serializes the scalar as its padded big-endian byte representation
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        self.to_bytes_be().serialize(serializer)
    }
}
impl<'de, C: CurveGroup> Deserialize<'de> for Scalar<C> {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let bytes = <Vec<u8>>::deserialize(deserializer)?;
let scalar = Scalar::from_be_bytes_mod_order(&bytes);
Ok(scalar)
}
}
/// A handle to a single scalar-valued result in the computation fabric
pub type ScalarResult<C> = ResultHandle<C, Scalar<C>>;
/// A handle to a batch of scalar-valued results in the computation fabric
pub type BatchScalarResult<C> = ResultHandle<C, Vec<Scalar<C>>>;
impl<C: CurveGroup> ScalarResult<C> {
    /// Compute the multiplicative inverse of the result
    ///
    /// The returned gate panics at evaluation time if the underlying value is
    /// zero, which has no multiplicative inverse
    pub fn inverse(&self) -> ScalarResult<C> {
        // Read the argument by value as the other gates in this module do,
        // avoiding the shifting `Vec::remove(0)`
        self.fabric.new_gate_op(vec![self.id], |args| {
            let val: Scalar<C> = args[0].to_owned().into();
            ResultValue::Scalar(Scalar(val.0.inverse().expect("cannot invert zero")))
        })
    }
}
impl<C: CurveGroup> Add<&Scalar<C>> for &Scalar<C> {
    type Output = Scalar<C>;

    /// Field addition on the wrapped elements
    fn add(self, rhs: &Scalar<C>) -> Self::Output {
        Scalar(self.0 + rhs.0)
    }
}
impl_borrow_variants!(Scalar<C>, Add, add, +, Scalar<C>, C: CurveGroup);
impl<C: CurveGroup> Add<&Scalar<C>> for &ScalarResult<C> {
    type Output = ScalarResult<C>;

    /// Add a public constant to the result, evaluated as a single gate
    fn add(self, rhs: &Scalar<C>) -> Self::Output {
        // Copy the constant's field element so the gate closure owns it
        let rhs = rhs.0;
        self.fabric.new_gate_op(vec![self.id], move |args| {
            let lhs: Scalar<C> = args[0].to_owned().into();
            ResultValue::Scalar(Scalar(lhs.0 + rhs))
        })
    }
}
impl_borrow_variants!(ScalarResult<C>, Add, add, +, Scalar<C>, C: CurveGroup);
impl_commutative!(ScalarResult<C>, Add, add, +, Scalar<C>, C: CurveGroup);
impl<C: CurveGroup> Add<&ScalarResult<C>> for &ScalarResult<C> {
    type Output = ScalarResult<C>;

    /// Add two results, evaluated as a single gate over both inputs
    fn add(self, rhs: &ScalarResult<C>) -> Self::Output {
        self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
            let lhs: Scalar<C> = args[0].to_owned().into();
            let rhs: Scalar<C> = args[1].to_owned().into();
            ResultValue::Scalar(lhs + rhs)
        })
    }
}
impl_borrow_variants!(ScalarResult<C>, Add, add, +, ScalarResult<C>, C: CurveGroup);
impl<C: CurveGroup> ScalarResult<C> {
pub fn batch_add(a: &[ScalarResult<C>], b: &[ScalarResult<C>]) -> Vec<ScalarResult<C>> {
assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
let n = a.len();
let fabric = &a[0].fabric;
let ids = a.iter().chain(b.iter()).map(|v| v.id).collect_vec();
fabric.new_batch_gate_op(ids, n , move |args| {
let mut res = Vec::with_capacity(n);
for i in 0..n {
let lhs: Scalar<C> = args[i].to_owned().into();
let rhs: Scalar<C> = args[i + n].to_owned().into();
res.push(ResultValue::Scalar(Scalar(lhs.0 + rhs.0)));
}
res
})
}
}
impl<C: CurveGroup> AddAssign for Scalar<C> {
    /// In-place field addition
    fn add_assign(&mut self, rhs: Scalar<C>) {
        self.0 += rhs.0;
    }
}
impl<C: CurveGroup> Sub<&Scalar<C>> for &Scalar<C> {
    type Output = Scalar<C>;

    /// Field subtraction on the wrapped elements
    fn sub(self, rhs: &Scalar<C>) -> Self::Output {
        Scalar(self.0 - rhs.0)
    }
}
impl_borrow_variants!(Scalar<C>, Sub, sub, -, Scalar<C>, C: CurveGroup);
impl<C: CurveGroup> Sub<&Scalar<C>> for &ScalarResult<C> {
    type Output = ScalarResult<C>;

    /// Subtract a public constant from the result, evaluated as a single gate
    fn sub(self, rhs: &Scalar<C>) -> Self::Output {
        // Copy the constant's field element so the gate closure owns it
        let rhs = rhs.0;
        self.fabric.new_gate_op(vec![self.id], move |args| {
            let lhs: Scalar<C> = args[0].to_owned().into();
            ResultValue::Scalar(Scalar(lhs.0 - rhs))
        })
    }
}
impl_borrow_variants!(ScalarResult<C>, Sub, sub, -, Scalar<C>, C: CurveGroup);
impl<C: CurveGroup> Sub<&ScalarResult<C>> for &Scalar<C> {
type Output = ScalarResult<C>;
fn sub(self, rhs: &ScalarResult<C>) -> Self::Output {
let lhs = *self;
rhs.fabric.new_gate_op(vec![rhs.id], move |args| {
let rhs: Scalar<C> = args[0].to_owned().into();
ResultValue::Scalar(lhs - rhs)
})
}
}
impl_borrow_variants!(Scalar<C>, Sub, sub, -, ScalarResult<C>, Output=ScalarResult<C>, C: CurveGroup);
impl<C: CurveGroup> Sub<&ScalarResult<C>> for &ScalarResult<C> {
    type Output = ScalarResult<C>;

    /// Subtract two results, evaluated as a single gate over both inputs
    fn sub(self, rhs: &ScalarResult<C>) -> Self::Output {
        self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
            let lhs: Scalar<C> = args[0].to_owned().into();
            let rhs: Scalar<C> = args[1].to_owned().into();
            ResultValue::Scalar(lhs - rhs)
        })
    }
}
impl_borrow_variants!(ScalarResult<C>, Sub, sub, -, ScalarResult<C>, C: CurveGroup);
impl<C: CurveGroup> ScalarResult<C> {
pub fn batch_sub(a: &[ScalarResult<C>], b: &[ScalarResult<C>]) -> Vec<ScalarResult<C>> {
assert_eq!(a.len(), b.len(), "Batch sub requires equal length inputs");
let n = a.len();
let fabric = &a[0].fabric;
let ids = a.iter().chain(b.iter()).map(|v| v.id).collect_vec();
fabric.new_batch_gate_op(ids, n , move |args| {
let mut res = Vec::with_capacity(n);
for i in 0..n {
let lhs: Scalar<C> = args[i].to_owned().into();
let rhs: Scalar<C> = args[i + n].to_owned().into();
res.push(ResultValue::Scalar(Scalar(lhs.0 - rhs.0)));
}
res
})
}
}
impl<C: CurveGroup> SubAssign for Scalar<C> {
    /// In-place field subtraction
    fn sub_assign(&mut self, rhs: Scalar<C>) {
        self.0 -= rhs.0;
    }
}
impl<C: CurveGroup> Mul<&Scalar<C>> for &Scalar<C> {
    type Output = Scalar<C>;

    /// Field multiplication on the wrapped elements
    fn mul(self, rhs: &Scalar<C>) -> Self::Output {
        Scalar(self.0 * rhs.0)
    }
}
impl_borrow_variants!(Scalar<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
impl<C: CurveGroup> Mul<&Scalar<C>> for &ScalarResult<C> {
    type Output = ScalarResult<C>;

    /// Multiply the result by a public constant, evaluated as a single gate
    fn mul(self, rhs: &Scalar<C>) -> Self::Output {
        // Copy the constant's field element so the gate closure owns it
        let rhs = rhs.0;
        self.fabric.new_gate_op(vec![self.id], move |args| {
            let lhs: Scalar<C> = args[0].to_owned().into();
            ResultValue::Scalar(Scalar(lhs.0 * rhs))
        })
    }
}
impl_borrow_variants!(ScalarResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
impl_commutative!(ScalarResult<C>, Mul, mul, *, Scalar<C>, C: CurveGroup);
impl<C: CurveGroup> Mul<&ScalarResult<C>> for &ScalarResult<C> {
    type Output = ScalarResult<C>;

    /// Multiply two results, evaluated as a single gate over both inputs
    fn mul(self, rhs: &ScalarResult<C>) -> Self::Output {
        self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
            let lhs: Scalar<C> = args[0].to_owned().into();
            let rhs: Scalar<C> = args[1].to_owned().into();
            ResultValue::Scalar(lhs * rhs)
        })
    }
}
impl_borrow_variants!(ScalarResult<C>, Mul, mul, *, ScalarResult<C>, C: CurveGroup);
impl<C: CurveGroup> ScalarResult<C> {
pub fn batch_mul(a: &[ScalarResult<C>], b: &[ScalarResult<C>]) -> Vec<ScalarResult<C>> {
assert_eq!(a.len(), b.len(), "Batch mul requires equal length inputs");
let n = a.len();
let fabric = &a[0].fabric;
let ids = a.iter().chain(b.iter()).map(|v| v.id).collect_vec();
fabric.new_batch_gate_op(ids, n , move |args| {
let mut res = Vec::with_capacity(n);
for i in 0..n {
let lhs: Scalar<C> = args[i].to_owned().into();
let rhs: Scalar<C> = args[i + n].to_owned().into();
res.push(ResultValue::Scalar(Scalar(lhs.0 * rhs.0)));
}
res
})
}
}
impl<C: CurveGroup> Neg for &Scalar<C> {
    type Output = Scalar<C>;

    /// Additive negation of the wrapped field element
    fn neg(self) -> Self::Output {
        Scalar(self.0.neg())
    }
}
impl_borrow_variants!(Scalar<C>, Neg, neg, -, C: CurveGroup);
impl<C: CurveGroup> Neg for &ScalarResult<C> {
    type Output = ScalarResult<C>;

    /// Negate the result, evaluated as a single gate
    fn neg(self) -> Self::Output {
        self.fabric.new_gate_op(vec![self.id], |args| {
            let val: Scalar<C> = args[0].to_owned().into();
            ResultValue::Scalar(-val)
        })
    }
}
impl_borrow_variants!(ScalarResult<C>, Neg, neg, -, C: CurveGroup);
impl<C: CurveGroup> ScalarResult<C> {
    /// Negate a batch of `ScalarResult`s element-wise
    ///
    /// Evaluated as a single batch gate rather than `n` separate gates
    pub fn batch_neg(a: &[ScalarResult<C>]) -> Vec<ScalarResult<C>> {
        // Guard the empty case explicitly; otherwise indexing `a[0]` below panics
        if a.is_empty() {
            return Vec::new();
        }

        let n = a.len();
        let fabric = &a[0].fabric;
        let ids = a.iter().map(|v| v.id).collect_vec();
        fabric.new_batch_gate_op(ids, n, move |args| {
            args.into_iter()
                .map(Scalar::from)
                .map(|x| -x)
                .map(ResultValue::Scalar)
                .collect_vec()
        })
    }
}
impl<C: CurveGroup> MulAssign for Scalar<C> {
    /// In-place field multiplication
    fn mul_assign(&mut self, rhs: Scalar<C>) {
        self.0 *= rhs.0;
    }
}
impl<C: CurveGroup> From<bool> for Scalar<C> {
fn from(value: bool) -> Self {
Scalar(C::ScalarField::from(value))
}
}
impl<C: CurveGroup> From<u8> for Scalar<C> {
fn from(value: u8) -> Self {
Scalar(C::ScalarField::from(value))
}
}
impl<C: CurveGroup> From<u16> for Scalar<C> {
fn from(value: u16) -> Self {
Scalar(C::ScalarField::from(value))
}
}
impl<C: CurveGroup> From<u32> for Scalar<C> {
fn from(value: u32) -> Self {
Scalar(C::ScalarField::from(value))
}
}
impl<C: CurveGroup> From<u64> for Scalar<C> {
fn from(value: u64) -> Self {
Scalar(C::ScalarField::from(value))
}
}
impl<C: CurveGroup> From<u128> for Scalar<C> {
fn from(value: u128) -> Self {
Scalar(C::ScalarField::from(value))
}
}
impl<C: CurveGroup> From<usize> for Scalar<C> {
fn from(value: usize) -> Self {
Scalar(C::ScalarField::from(value as u64))
}
}
impl<C: CurveGroup> Sum for Scalar<C> {
    /// Sum the scalars, starting from the additive identity
    fn sum<I: Iterator<Item = Scalar<C>>>(iter: I) -> Self {
        let mut acc = Scalar::zero();
        for val in iter {
            acc += val;
        }
        acc
    }
}
impl<C: CurveGroup> Product for Scalar<C> {
    /// Multiply the scalars together, starting from the multiplicative identity
    fn product<I: Iterator<Item = Scalar<C>>>(iter: I) -> Self {
        let mut acc = Scalar::one();
        for val in iter {
            acc *= val;
        }
        acc
    }
}
#[cfg(test)]
mod test {
    use crate::{algebra::scalar::Scalar, test_helpers::mock_fabric};
    use rand::thread_rng;

    /// Tests that adding two allocated scalars matches plain scalar addition
    #[tokio::test]
    async fn test_scalar_add() {
        let mut rng = thread_rng();
        let lhs = Scalar::random(&mut rng);
        let rhs = Scalar::random(&mut rng);
        let expected = lhs + rhs;

        let fabric = mock_fabric();
        let lhs_alloc = fabric.allocate_scalar(lhs);
        let rhs_alloc = fabric.allocate_scalar(rhs);

        let res = &lhs_alloc + &rhs_alloc;
        assert_eq!(res.await, expected);

        fabric.shutdown();
    }

    /// Tests that subtracting two allocated scalars matches plain scalar
    /// subtraction
    #[tokio::test]
    async fn test_scalar_sub() {
        let mut rng = thread_rng();
        let lhs = Scalar::random(&mut rng);
        let rhs = Scalar::random(&mut rng);
        let expected = lhs - rhs;

        let fabric = mock_fabric();
        let lhs_alloc = fabric.allocate_scalar(lhs);
        let rhs_alloc = fabric.allocate_scalar(rhs);

        let res = lhs_alloc - rhs_alloc;
        assert_eq!(res.await, expected);

        fabric.shutdown();
    }

    /// Tests that negating an allocated scalar matches plain scalar negation
    #[tokio::test]
    async fn test_scalar_neg() {
        let mut rng = thread_rng();
        let val = Scalar::random(&mut rng);
        let expected = -val;

        let fabric = mock_fabric();
        let val_alloc = fabric.allocate_scalar(val);

        let res = -val_alloc;
        assert_eq!(res.await, expected);

        fabric.shutdown();
    }

    /// Tests that multiplying two allocated scalars matches plain scalar
    /// multiplication
    #[tokio::test]
    async fn test_scalar_mul() {
        let mut rng = thread_rng();
        let lhs = Scalar::random(&mut rng);
        let rhs = Scalar::random(&mut rng);
        let expected = lhs * rhs;

        let fabric = mock_fabric();
        let lhs_alloc = fabric.allocate_scalar(lhs);
        let rhs_alloc = fabric.allocate_scalar(rhs);

        let res = lhs_alloc * rhs_alloc;
        assert_eq!(res.await, expected);

        fabric.shutdown();
    }
}