use std::{
fmt::{Display, Formatter, Result as FmtResult},
iter::{Product, Sum},
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
};
use ark_ff::{batch_inversion, Field, Fp256, MontBackend, MontConfig, PrimeField};
use itertools::Itertools;
use num_bigint::BigUint;
use rand::{CryptoRng, Rng, RngCore};
use serde::{Deserialize, Serialize};
use crate::fabric::{ResultHandle, ResultValue};
use super::macros::{impl_borrow_variants, impl_commutative};
/// Number of bytes in a serialized base field element
pub const BASE_FIELD_BYTES: usize = 32;
/// Number of bytes in a serialized scalar field element; `to_bytes_be` pads to
/// exactly this width
pub const SCALAR_BYTES: usize = 32;
/// Montgomery-form configuration for the Starknet base field, specified by its
/// prime modulus and a multiplicative generator
#[derive(MontConfig)]
#[modulus = "3618502788666131213697322783095070105623107215331596699973092056135872020481"]
#[generator = "3"]
pub struct StarknetFqConfig;
/// A base field element of the Starknet curve, backed by a 4-limb (256-bit)
/// Montgomery representation
pub type StarknetBaseFelt = Fp256<MontBackend<StarknetFqConfig, 4>>;
/// Montgomery-form configuration for the Starknet scalar field, specified by
/// its prime modulus and a multiplicative generator
#[derive(MontConfig)]
#[modulus = "3618502788666131213697322783095070105526743751716087489154079457884512865583"]
#[generator = "3"]
pub struct StarknetFrConfig;
/// The raw arkworks field element underlying a `Scalar`
pub(crate) type ScalarInner = Fp256<MontBackend<StarknetFrConfig, 4>>;
/// A newtype wrapper around the scalar field element, giving the crate a place
/// to hang arithmetic, serialization, and fabric-gate implementations
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct Scalar(pub(crate) ScalarInner);
impl Scalar {
    /// The underlying field type the scalar wraps
    pub type Field = ScalarInner;

    /// The additive identity of the scalar field
    pub fn zero() -> Scalar {
        Scalar(ScalarInner::from(0))
    }

    /// The multiplicative identity of the scalar field
    pub fn one() -> Scalar {
        Scalar(ScalarInner::from(1))
    }

    /// Return a copy of the inner field element
    pub fn inner(&self) -> ScalarInner {
        self.0
    }

    /// Sample a uniformly random scalar from a cryptographically secure RNG
    pub fn random<R: RngCore + CryptoRng>(rng: &mut R) -> Scalar {
        let inner: ScalarInner = rng.sample(rand::distributions::Standard);
        Scalar(inner)
    }

    /// Compute the multiplicative inverse of the scalar
    ///
    /// # Panics
    /// Panics if the scalar is zero, which has no inverse
    pub fn inverse(&self) -> Scalar {
        Scalar(self.0.inverse().unwrap())
    }

    /// Invert a batch of scalars in place, amortizing the cost of field
    /// inversion across the batch
    ///
    /// NOTE(review): ark's `batch_inversion` skips zero elements rather than
    /// panicking, so zeros pass through unchanged — confirm callers do not
    /// rely on a panic for zero inputs
    pub fn batch_inverse(vals: &mut [Scalar]) {
        let mut values = vals.iter().map(|x| x.0).collect_vec();
        batch_inversion(&mut values);

        // Copy the inverted values back into the caller's slice; zipping keeps
        // the two buffers aligned without any index bookkeeping
        for (val, inverted) in vals.iter_mut().zip(values) {
            *val = Scalar(inverted);
        }
    }

    /// Construct a scalar from big-endian bytes, reducing modulo the field order
    pub fn from_be_bytes_mod_order(bytes: &[u8]) -> Scalar {
        let inner = ScalarInner::from_be_bytes_mod_order(bytes);
        Scalar(inner)
    }

    /// Serialize the scalar to a fixed-width (`SCALAR_BYTES`) big-endian byte vector
    pub fn to_bytes_be(&self) -> Vec<u8> {
        let val_biguint = self.to_biguint();
        let mut bytes = val_biguint.to_bytes_be();

        // `BigUint::to_bytes_be` strips leading zeros, so left-pad back to the
        // fixed width; a reduced scalar never exceeds `SCALAR_BYTES` bytes, so
        // the subtraction cannot underflow
        let mut padding = vec![0u8; SCALAR_BYTES - bytes.len()];
        padding.append(&mut bytes);

        padding
    }

    /// Convert the scalar to its canonical `BigUint` representation
    pub fn to_biguint(&self) -> BigUint {
        self.0.into()
    }

    /// Construct a scalar from a `BigUint`, reducing modulo the field order
    pub fn from_biguint(val: &BigUint) -> Scalar {
        let le_bytes = val.to_bytes_le();
        let inner = ScalarInner::from_le_bytes_mod_order(&le_bytes);
        Scalar(inner)
    }
}
impl Display for Scalar {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
write!(f, "{}", self.to_biguint())
}
}
impl Serialize for Scalar {
    /// Serialize the scalar as a fixed-width big-endian byte vector
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        self.to_bytes_be().serialize(serializer)
    }
}
impl<'de> Deserialize<'de> for Scalar {
    /// Deserialize a scalar from a big-endian byte vector, reducing modulo the
    /// field order
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let be_bytes: Vec<u8> = Vec::deserialize(deserializer)?;
        Ok(Scalar::from_be_bytes_mod_order(&be_bytes))
    }
}
/// A fabric handle to a single scalar-valued result
pub type ScalarResult = ResultHandle<Scalar>;
/// A fabric handle to a batch of scalar-valued results
pub type BatchScalarResult = ResultHandle<Vec<Scalar>>;
impl ScalarResult {
    /// Queue a gate that computes the multiplicative inverse of this result's
    /// scalar; the gate panics if the value is zero
    pub fn inverse(&self) -> ScalarResult {
        self.fabric.new_gate_op(vec![self.id], |mut args| {
            let value: Scalar = args.remove(0).into();
            ResultValue::Scalar(value.inverse())
        })
    }
}
impl Add<&Scalar> for &Scalar {
    type Output = Scalar;

    /// Field addition of two scalars
    fn add(self, rhs: &Scalar) -> Self::Output {
        Scalar(self.0 + rhs.0)
    }
}
impl_borrow_variants!(Scalar, Add, add, +, Scalar);
impl Add<&Scalar> for &ScalarResult {
    type Output = ScalarResult;

    /// Queue a gate that adds a public constant to this result's scalar
    fn add(self, rhs: &Scalar) -> Self::Output {
        // Copy the constant so the closure can own it
        let constant = *rhs;
        self.fabric.new_gate_op(vec![self.id], move |args| {
            let lhs: Scalar = args[0].to_owned().into();
            ResultValue::Scalar(Scalar(lhs.0 + constant.0))
        })
    }
}
impl_borrow_variants!(ScalarResult, Add, add, +, Scalar);
impl_commutative!(ScalarResult, Add, add, +, Scalar);
impl Add<&ScalarResult> for &ScalarResult {
    type Output = ScalarResult;

    /// Queue a gate that adds two in-fabric scalar results
    fn add(self, rhs: &ScalarResult) -> Self::Output {
        self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
            let a: Scalar = args[0].to_owned().into();
            let b: Scalar = args[1].to_owned().into();
            ResultValue::Scalar(Scalar(a.0 + b.0))
        })
    }
}
impl_borrow_variants!(ScalarResult, Add, add, +, ScalarResult);
impl ScalarResult {
    /// Add two equal-length batches of `ScalarResult`s element-wise in a
    /// single fabric gate
    ///
    /// # Panics
    /// Panics if the two batches have different lengths
    pub fn batch_add(a: &[ScalarResult], b: &[ScalarResult]) -> Vec<ScalarResult> {
        assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");

        // An empty batch yields an empty result; return early so we never
        // index into an empty slice to fetch the fabric below
        if a.is_empty() {
            return Vec::new();
        }

        let n = a.len();
        let fabric = &a[0].fabric;
        let ids = a.iter().chain(b.iter()).map(|v| v.id).collect_vec();
        fabric.new_batch_gate_op(ids, n, move |args| {
            // The first `n` args are the lhs values, the next `n` the rhs
            let mut res = Vec::with_capacity(n);
            for i in 0..n {
                let lhs: Scalar = args[i].to_owned().into();
                let rhs: Scalar = args[i + n].to_owned().into();
                res.push(ResultValue::Scalar(Scalar(lhs.0 + rhs.0)));
            }

            res
        })
    }
}
impl AddAssign for Scalar {
    /// In-place field addition, delegating to the inner field's operator
    fn add_assign(&mut self, rhs: Scalar) {
        self.0 += rhs.0;
    }
}
impl Sub<&Scalar> for &Scalar {
    type Output = Scalar;

    /// Field subtraction of two scalars
    fn sub(self, rhs: &Scalar) -> Self::Output {
        Scalar(self.0 - rhs.0)
    }
}
impl_borrow_variants!(Scalar, Sub, sub, -, Scalar);
impl Sub<&Scalar> for &ScalarResult {
    type Output = ScalarResult;

    /// Queue a gate that subtracts a public constant from this result's scalar
    fn sub(self, rhs: &Scalar) -> Self::Output {
        // Copy the constant so the closure can own it
        let constant = *rhs;
        self.fabric.new_gate_op(vec![self.id], move |args| {
            let lhs: Scalar = args[0].to_owned().into();
            ResultValue::Scalar(Scalar(lhs.0 - constant.0))
        })
    }
}
impl_borrow_variants!(ScalarResult, Sub, sub, -, Scalar);
impl Sub<&ScalarResult> for &Scalar {
type Output = ScalarResult;
fn sub(self, rhs: &ScalarResult) -> Self::Output {
let lhs = *self;
rhs.fabric.new_gate_op(vec![rhs.id], move |args| {
let rhs: Scalar = args[0].to_owned().into();
ResultValue::Scalar(lhs - rhs)
})
}
}
impl_borrow_variants!(Scalar, Sub, sub, -, ScalarResult, Output=ScalarResult);
impl Sub<&ScalarResult> for &ScalarResult {
    type Output = ScalarResult;

    /// Queue a gate that subtracts one in-fabric scalar result from another
    fn sub(self, rhs: &ScalarResult) -> Self::Output {
        self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
            let minuend: Scalar = args[0].to_owned().into();
            let subtrahend: Scalar = args[1].to_owned().into();
            ResultValue::Scalar(Scalar(minuend.0 - subtrahend.0))
        })
    }
}
impl_borrow_variants!(ScalarResult, Sub, sub, -, ScalarResult);
impl ScalarResult {
    /// Subtract two equal-length batches of `ScalarResult`s element-wise in a
    /// single fabric gate
    ///
    /// # Panics
    /// Panics if the two batches have different lengths
    pub fn batch_sub(a: &[ScalarResult], b: &[ScalarResult]) -> Vec<ScalarResult> {
        assert_eq!(a.len(), b.len(), "Batch sub requires equal length inputs");

        // An empty batch yields an empty result; return early so we never
        // index into an empty slice to fetch the fabric below
        if a.is_empty() {
            return Vec::new();
        }

        let n = a.len();
        let fabric = &a[0].fabric;
        let ids = a.iter().chain(b.iter()).map(|v| v.id).collect_vec();
        fabric.new_batch_gate_op(ids, n, move |args| {
            // The first `n` args are the lhs values, the next `n` the rhs
            let mut res = Vec::with_capacity(n);
            for i in 0..n {
                let lhs: Scalar = args[i].to_owned().into();
                let rhs: Scalar = args[i + n].to_owned().into();
                res.push(ResultValue::Scalar(Scalar(lhs.0 - rhs.0)));
            }

            res
        })
    }
}
impl SubAssign for Scalar {
    /// In-place field subtraction, delegating to the inner field's operator
    fn sub_assign(&mut self, rhs: Scalar) {
        self.0 -= rhs.0;
    }
}
impl Mul<&Scalar> for &Scalar {
    type Output = Scalar;

    /// Field multiplication of two scalars
    fn mul(self, rhs: &Scalar) -> Self::Output {
        Scalar(self.0 * rhs.0)
    }
}
impl_borrow_variants!(Scalar, Mul, mul, *, Scalar);
impl Mul<&Scalar> for &ScalarResult {
    type Output = ScalarResult;

    /// Queue a gate that multiplies this result's scalar by a public constant
    fn mul(self, rhs: &Scalar) -> Self::Output {
        // Copy the constant so the closure can own it
        let constant = *rhs;
        self.fabric.new_gate_op(vec![self.id], move |args| {
            let lhs: Scalar = args[0].to_owned().into();
            ResultValue::Scalar(Scalar(lhs.0 * constant.0))
        })
    }
}
impl_borrow_variants!(ScalarResult, Mul, mul, *, Scalar);
impl_commutative!(ScalarResult, Mul, mul, *, Scalar);
impl Mul<&ScalarResult> for &ScalarResult {
    type Output = ScalarResult;

    /// Queue a gate that multiplies two in-fabric scalar results
    fn mul(self, rhs: &ScalarResult) -> Self::Output {
        self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
            let a: Scalar = args[0].to_owned().into();
            let b: Scalar = args[1].to_owned().into();
            ResultValue::Scalar(Scalar(a.0 * b.0))
        })
    }
}
impl_borrow_variants!(ScalarResult, Mul, mul, *, ScalarResult);
impl ScalarResult {
    /// Multiply two equal-length batches of `ScalarResult`s element-wise in a
    /// single fabric gate
    ///
    /// # Panics
    /// Panics if the two batches have different lengths
    pub fn batch_mul(a: &[ScalarResult], b: &[ScalarResult]) -> Vec<ScalarResult> {
        assert_eq!(a.len(), b.len(), "Batch mul requires equal length inputs");

        // An empty batch yields an empty result; return early so we never
        // index into an empty slice to fetch the fabric below
        if a.is_empty() {
            return Vec::new();
        }

        let n = a.len();
        let fabric = &a[0].fabric;
        let ids = a.iter().chain(b.iter()).map(|v| v.id).collect_vec();
        fabric.new_batch_gate_op(ids, n, move |args| {
            // The first `n` args are the lhs values, the next `n` the rhs
            let mut res = Vec::with_capacity(n);
            for i in 0..n {
                let lhs: Scalar = args[i].to_owned().into();
                let rhs: Scalar = args[i + n].to_owned().into();
                res.push(ResultValue::Scalar(Scalar(lhs.0 * rhs.0)));
            }

            res
        })
    }
}
impl Neg for &Scalar {
    type Output = Scalar;

    /// Additive inverse of the scalar
    fn neg(self) -> Self::Output {
        Scalar(self.0.neg())
    }
}
impl_borrow_variants!(Scalar, Neg, neg, -);
impl Neg for &ScalarResult {
    type Output = ScalarResult;

    /// Queue a gate that negates this result's scalar
    fn neg(self) -> Self::Output {
        self.fabric.new_gate_op(vec![self.id], |args| {
            let value: Scalar = args[0].to_owned().into();
            ResultValue::Scalar(Scalar(-value.0))
        })
    }
}
impl_borrow_variants!(ScalarResult, Neg, neg, -);
impl ScalarResult {
    /// Negate a batch of `ScalarResult`s in a single fabric gate
    pub fn batch_neg(a: &[ScalarResult]) -> Vec<ScalarResult> {
        // An empty batch yields an empty result; return early so we never
        // index into an empty slice to fetch the fabric below
        if a.is_empty() {
            return Vec::new();
        }

        let n = a.len();
        let fabric = &a[0].fabric;
        let ids = a.iter().map(|v| v.id).collect_vec();
        fabric.new_batch_gate_op(ids, n, move |args| {
            args.into_iter()
                .map(Scalar::from)
                .map(|x| -x)
                .map(ResultValue::Scalar)
                .collect_vec()
        })
    }
}
impl MulAssign for Scalar {
    /// In-place field multiplication, delegating to the inner field's operator
    fn mul_assign(&mut self, rhs: Scalar) {
        self.0 *= rhs.0;
    }
}
impl<T: Into<ScalarInner>> From<T> for Scalar {
fn from(val: T) -> Self {
Scalar(val.into())
}
}
impl Sum for Scalar {
    /// Sum an iterator of scalars, starting from the additive identity
    fn sum<I: Iterator<Item = Scalar>>(iter: I) -> Self {
        let mut total = Scalar::zero();
        for term in iter {
            total = total + term;
        }
        total
    }
}
impl Product for Scalar {
    /// Multiply an iterator of scalars, starting from the multiplicative identity
    fn product<I: Iterator<Item = Scalar>>(iter: I) -> Self {
        let mut total = Scalar::one();
        for factor in iter {
            total = total * factor;
        }
        total
    }
}
#[cfg(test)]
mod test {
    use crate::{
        algebra::scalar::{Scalar, SCALAR_BYTES},
        test_helpers::mock_fabric,
    };
    use rand::thread_rng;

    /// Round-trip a random scalar through its fixed-width byte serialization
    #[test]
    fn test_scalar_serialize() {
        let mut rng = thread_rng();
        let original = Scalar::random(&mut rng);

        let serialized = original.to_bytes_be();
        assert_eq!(serialized.len(), SCALAR_BYTES);

        let recovered = Scalar::from_be_bytes_mod_order(&serialized);
        assert_eq!(original, recovered);
    }

    /// Check that fabric addition matches the plain scalar addition
    #[tokio::test]
    async fn test_scalar_add() {
        let mut rng = thread_rng();
        let lhs = Scalar::random(&mut rng);
        let rhs = Scalar::random(&mut rng);
        let expected = lhs + rhs;

        let fabric = mock_fabric();
        let lhs_res = fabric.allocate_scalar(lhs);
        let rhs_res = fabric.allocate_scalar(rhs);

        let sum = &lhs_res + &rhs_res;
        assert_eq!(sum.await, expected);

        fabric.shutdown();
    }

    /// Check that fabric subtraction matches the plain scalar subtraction
    #[tokio::test]
    async fn test_scalar_sub() {
        let mut rng = thread_rng();
        let lhs = Scalar::random(&mut rng);
        let rhs = Scalar::random(&mut rng);
        let expected = lhs - rhs;

        let fabric = mock_fabric();
        let lhs_res = fabric.allocate_scalar(lhs);
        let rhs_res = fabric.allocate_scalar(rhs);

        let difference = lhs_res - rhs_res;
        assert_eq!(difference.await, expected);

        fabric.shutdown();
    }

    /// Check that fabric negation matches the plain scalar negation
    #[tokio::test]
    async fn test_scalar_neg() {
        let mut rng = thread_rng();
        let value = Scalar::random(&mut rng);
        let expected = -value;

        let fabric = mock_fabric();
        let value_res = fabric.allocate_scalar(value);

        let negated = -value_res;
        assert_eq!(negated.await, expected);

        fabric.shutdown();
    }

    /// Check that fabric multiplication matches the plain scalar multiplication
    #[tokio::test]
    async fn test_scalar_mul() {
        let mut rng = thread_rng();
        let lhs = Scalar::random(&mut rng);
        let rhs = Scalar::random(&mut rng);
        let expected = lhs * rhs;

        let fabric = mock_fabric();
        let lhs_res = fabric.allocate_scalar(lhs);
        let rhs_res = fabric.allocate_scalar(rhs);

        let product = lhs_res * rhs_res;
        assert_eq!(product.await, expected);

        fabric.shutdown();
    }
}