1use std::{
8 fmt::{Display, Formatter, Result as FmtResult},
9 iter::{Product, Sum},
10 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
11};
12
13use ark_ff::{batch_inversion, Field, Fp256, MontBackend, MontConfig, PrimeField};
14use itertools::Itertools;
15use num_bigint::BigUint;
16use rand::{CryptoRng, Rng, RngCore};
17use serde::{Deserialize, Serialize};
18
19use crate::fabric::{ResultHandle, ResultValue};
20
21use super::macros::{impl_borrow_variants, impl_commutative};
22
/// The number of bytes in a serialized base field element
pub const BASE_FIELD_BYTES: usize = 32;
/// The number of bytes in a serialized `Scalar` (see `Scalar::to_bytes_be`)
pub const SCALAR_BYTES: usize = 32;
27
/// Montgomery-form config for the Starknet base field, i.e. the 252-bit
/// prime field the curve is defined over
#[derive(MontConfig)]
#[modulus = "3618502788666131213697322783095070105623107215331596699973092056135872020481"]
#[generator = "3"]
pub struct StarknetFqConfig;
/// A base field element, stored as a 256-bit Montgomery-backed `Fp`
pub type StarknetBaseFelt = Fp256<MontBackend<StarknetFqConfig, 4>>;
35
/// Montgomery-form config for the Starknet scalar field
#[derive(MontConfig)]
#[modulus = "3618502788666131213697322783095070105526743751716087489154079457884512865583"]
#[generator = "3"]
pub struct StarknetFrConfig;
/// The inner arkworks representation of a scalar field element;
/// crate-visible so sibling modules can unwrap `Scalar` directly
pub(crate) type ScalarInner = Fp256<MontBackend<StarknetFrConfig, 4>>;
/// A scalar field element; a thin newtype over the arkworks field type that
/// carries the arithmetic trait impls defined in this module
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct Scalar(pub(crate) ScalarInner);
50
51impl Scalar {
56 pub type Field = ScalarInner;
58
59 pub fn zero() -> Scalar {
61 Scalar(ScalarInner::from(0))
62 }
63
64 pub fn one() -> Scalar {
66 Scalar(ScalarInner::from(1))
67 }
68
69 pub fn inner(&self) -> ScalarInner {
71 self.0
72 }
73
74 pub fn random<R: RngCore + CryptoRng>(rng: &mut R) -> Scalar {
79 let inner: ScalarInner = rng.sample(rand::distributions::Standard);
80 Scalar(inner)
81 }
82
83 pub fn inverse(&self) -> Scalar {
85 Scalar(self.0.inverse().unwrap())
86 }
87
88 pub fn batch_inverse(vals: &mut [Scalar]) {
90 let mut values = vals.iter().map(|x| x.0).collect_vec();
91 batch_inversion(&mut values);
92
93 for (i, val) in vals.iter_mut().enumerate() {
94 *val = Scalar(values[i]);
95 }
96 }
97
98 pub fn from_be_bytes_mod_order(bytes: &[u8]) -> Scalar {
100 let inner = ScalarInner::from_be_bytes_mod_order(bytes);
101 Scalar(inner)
102 }
103
104 pub fn to_bytes_be(&self) -> Vec<u8> {
109 let val_biguint = self.to_biguint();
110 let mut bytes = val_biguint.to_bytes_be();
111
112 let mut padding = vec![0u8; SCALAR_BYTES - bytes.len()];
113 padding.append(&mut bytes);
114
115 padding
116 }
117
118 pub fn to_biguint(&self) -> BigUint {
120 self.0.into()
121 }
122
123 pub fn from_biguint(val: &BigUint) -> Scalar {
125 let le_bytes = val.to_bytes_le();
126 let inner = ScalarInner::from_le_bytes_mod_order(&le_bytes);
127 Scalar(inner)
128 }
129}
130
131impl Display for Scalar {
132 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
133 write!(f, "{}", self.to_biguint())
134 }
135}
136
137impl Serialize for Scalar {
138 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
139 let bytes = self.to_bytes_be();
140 bytes.serialize(serializer)
141 }
142}
143
144impl<'de> Deserialize<'de> for Scalar {
145 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
146 let bytes = <Vec<u8>>::deserialize(deserializer)?;
147 let scalar = Scalar::from_be_bytes_mod_order(&bytes);
148 Ok(scalar)
149 }
150}
151
/// A handle to a single `Scalar` result produced by the fabric
pub type ScalarResult = ResultHandle<Scalar>;
/// A handle to a batch of `Scalar` results produced by the fabric
pub type BatchScalarResult = ResultHandle<Vec<Scalar>>;
162impl ScalarResult {
163 pub fn inverse(&self) -> ScalarResult {
165 self.fabric.new_gate_op(vec![self.id], |mut args| {
166 let val: Scalar = args.remove(0).into();
167 ResultValue::Scalar(Scalar(val.0.inverse().unwrap()))
168 })
169 }
170}
171
172impl Add<&Scalar> for &Scalar {
173 type Output = Scalar;
174
175 fn add(self, rhs: &Scalar) -> Self::Output {
176 let rhs = *rhs;
177 Scalar(self.0 + rhs.0)
178 }
179}
180impl_borrow_variants!(Scalar, Add, add, +, Scalar);
181
182impl Add<&Scalar> for &ScalarResult {
183 type Output = ScalarResult;
184
185 fn add(self, rhs: &Scalar) -> Self::Output {
186 let rhs = *rhs;
187 self.fabric.new_gate_op(vec![self.id], move |args| {
188 let lhs: Scalar = args[0].to_owned().into();
189 ResultValue::Scalar(Scalar(lhs.0 + rhs.0))
190 })
191 }
192}
193impl_borrow_variants!(ScalarResult, Add, add, +, Scalar);
194impl_commutative!(ScalarResult, Add, add, +, Scalar);
195
196impl Add<&ScalarResult> for &ScalarResult {
197 type Output = ScalarResult;
198
199 fn add(self, rhs: &ScalarResult) -> Self::Output {
200 self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
201 let lhs: Scalar = args[0].to_owned().into();
202 let rhs: Scalar = args[1].to_owned().into();
203 ResultValue::Scalar(Scalar(lhs.0 + rhs.0))
204 })
205 }
206}
207impl_borrow_variants!(ScalarResult, Add, add, +, ScalarResult);
208
209impl ScalarResult {
210 pub fn batch_add(a: &[ScalarResult], b: &[ScalarResult]) -> Vec<ScalarResult> {
212 assert_eq!(a.len(), b.len(), "Batch add requires equal length inputs");
213
214 let n = a.len();
215 let fabric = &a[0].fabric;
216 let ids = a.iter().chain(b.iter()).map(|v| v.id).collect_vec();
217 fabric.new_batch_gate_op(ids, n , move |args| {
218 let mut res = Vec::with_capacity(n);
219 for i in 0..n {
220 let lhs: Scalar = args[i].to_owned().into();
221 let rhs: Scalar = args[i + n].to_owned().into();
222 res.push(ResultValue::Scalar(Scalar(lhs.0 + rhs.0)));
223 }
224
225 res
226 })
227 }
228}
229
230impl AddAssign for Scalar {
233 fn add_assign(&mut self, rhs: Scalar) {
234 *self = *self + rhs;
235 }
236}
237
238impl Sub<&Scalar> for &Scalar {
241 type Output = Scalar;
242
243 fn sub(self, rhs: &Scalar) -> Self::Output {
244 let rhs = *rhs;
245 Scalar(self.0 - rhs.0)
246 }
247}
248impl_borrow_variants!(Scalar, Sub, sub, -, Scalar);
249
250impl Sub<&Scalar> for &ScalarResult {
251 type Output = ScalarResult;
252
253 fn sub(self, rhs: &Scalar) -> Self::Output {
254 let rhs = *rhs;
255 self.fabric.new_gate_op(vec![self.id], move |args| {
256 let lhs: Scalar = args[0].to_owned().into();
257 ResultValue::Scalar(Scalar(lhs.0 - rhs.0))
258 })
259 }
260}
261impl_borrow_variants!(ScalarResult, Sub, sub, -, Scalar);
262
263impl Sub<&ScalarResult> for &Scalar {
264 type Output = ScalarResult;
265
266 fn sub(self, rhs: &ScalarResult) -> Self::Output {
267 let lhs = *self;
268 rhs.fabric.new_gate_op(vec![rhs.id], move |args| {
269 let rhs: Scalar = args[0].to_owned().into();
270 ResultValue::Scalar(lhs - rhs)
271 })
272 }
273}
274impl_borrow_variants!(Scalar, Sub, sub, -, ScalarResult, Output=ScalarResult);
275
276impl Sub<&ScalarResult> for &ScalarResult {
277 type Output = ScalarResult;
278
279 fn sub(self, rhs: &ScalarResult) -> Self::Output {
280 self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
281 let lhs: Scalar = args[0].to_owned().into();
282 let rhs: Scalar = args[1].to_owned().into();
283 ResultValue::Scalar(Scalar(lhs.0 - rhs.0))
284 })
285 }
286}
287impl_borrow_variants!(ScalarResult, Sub, sub, -, ScalarResult);
288
289impl ScalarResult {
290 pub fn batch_sub(a: &[ScalarResult], b: &[ScalarResult]) -> Vec<ScalarResult> {
292 assert_eq!(a.len(), b.len(), "Batch sub requires equal length inputs");
293
294 let n = a.len();
295 let fabric = &a[0].fabric;
296 let ids = a.iter().chain(b.iter()).map(|v| v.id).collect_vec();
297 fabric.new_batch_gate_op(ids, n , move |args| {
298 let mut res = Vec::with_capacity(n);
299 for i in 0..n {
300 let lhs: Scalar = args[i].to_owned().into();
301 let rhs: Scalar = args[i + n].to_owned().into();
302 res.push(ResultValue::Scalar(Scalar(lhs.0 - rhs.0)));
303 }
304
305 res
306 })
307 }
308}
309
310impl SubAssign for Scalar {
313 fn sub_assign(&mut self, rhs: Scalar) {
314 *self = *self - rhs;
315 }
316}
317
318impl Mul<&Scalar> for &Scalar {
321 type Output = Scalar;
322
323 fn mul(self, rhs: &Scalar) -> Self::Output {
324 let rhs = *rhs;
325 Scalar(self.0 * rhs.0)
326 }
327}
328impl_borrow_variants!(Scalar, Mul, mul, *, Scalar);
329
330impl Mul<&Scalar> for &ScalarResult {
331 type Output = ScalarResult;
332
333 fn mul(self, rhs: &Scalar) -> Self::Output {
334 let rhs = *rhs;
335 self.fabric.new_gate_op(vec![self.id], move |args| {
336 let lhs: Scalar = args[0].to_owned().into();
337 ResultValue::Scalar(Scalar(lhs.0 * rhs.0))
338 })
339 }
340}
341impl_borrow_variants!(ScalarResult, Mul, mul, *, Scalar);
342impl_commutative!(ScalarResult, Mul, mul, *, Scalar);
343
344impl Mul<&ScalarResult> for &ScalarResult {
345 type Output = ScalarResult;
346
347 fn mul(self, rhs: &ScalarResult) -> Self::Output {
348 self.fabric.new_gate_op(vec![self.id, rhs.id], |args| {
349 let lhs: Scalar = args[0].to_owned().into();
350 let rhs: Scalar = args[1].to_owned().into();
351 ResultValue::Scalar(Scalar(lhs.0 * rhs.0))
352 })
353 }
354}
355impl_borrow_variants!(ScalarResult, Mul, mul, *, ScalarResult);
356
357impl ScalarResult {
358 pub fn batch_mul(a: &[ScalarResult], b: &[ScalarResult]) -> Vec<ScalarResult> {
360 assert_eq!(a.len(), b.len(), "Batch mul requires equal length inputs");
361
362 let n = a.len();
363 let fabric = &a[0].fabric;
364 let ids = a.iter().chain(b.iter()).map(|v| v.id).collect_vec();
365 fabric.new_batch_gate_op(ids, n , move |args| {
366 let mut res = Vec::with_capacity(n);
367 for i in 0..n {
368 let lhs: Scalar = args[i].to_owned().into();
369 let rhs: Scalar = args[i + n].to_owned().into();
370 res.push(ResultValue::Scalar(Scalar(lhs.0 * rhs.0)));
371 }
372
373 res
374 })
375 }
376}
377
378impl Neg for &Scalar {
379 type Output = Scalar;
380
381 fn neg(self) -> Self::Output {
382 Scalar(-self.0)
383 }
384}
385impl_borrow_variants!(Scalar, Neg, neg, -);
386
387impl Neg for &ScalarResult {
388 type Output = ScalarResult;
389
390 fn neg(self) -> Self::Output {
391 self.fabric.new_gate_op(vec![self.id], |args| {
392 let lhs: Scalar = args[0].to_owned().into();
393 ResultValue::Scalar(Scalar(-lhs.0))
394 })
395 }
396}
397impl_borrow_variants!(ScalarResult, Neg, neg, -);
398
399impl ScalarResult {
400 pub fn batch_neg(a: &[ScalarResult]) -> Vec<ScalarResult> {
402 let n = a.len();
403 let fabric = &a[0].fabric;
404 let ids = a.iter().map(|v| v.id).collect_vec();
405 fabric.new_batch_gate_op(ids, n , move |args| {
406 args.into_iter()
407 .map(Scalar::from)
408 .map(|x| -x)
409 .map(ResultValue::Scalar)
410 .collect_vec()
411 })
412 }
413}
414
415impl MulAssign for Scalar {
418 fn mul_assign(&mut self, rhs: Scalar) {
419 *self = *self * rhs;
420 }
421}
422
/// Blanket conversion: anything convertible into the inner field type
/// (e.g. the unsigned integer types arkworks supports) converts into a `Scalar`
impl<T: Into<ScalarInner>> From<T> for Scalar {
    fn from(val: T) -> Self {
        Scalar(val.into())
    }
}
432
433impl Sum for Scalar {
438 fn sum<I: Iterator<Item = Scalar>>(iter: I) -> Self {
439 iter.fold(Scalar::zero(), |acc, x| acc + x)
440 }
441}
442
443impl Product for Scalar {
444 fn product<I: Iterator<Item = Scalar>>(iter: I) -> Self {
445 iter.fold(Scalar::one(), |acc, x| acc * x)
446 }
447}
448
#[cfg(test)]
mod test {
    use crate::{
        algebra::scalar::{Scalar, SCALAR_BYTES},
        test_helpers::mock_fabric,
    };
    use rand::thread_rng;

    /// Tests that a scalar round-trips through its fixed-width big-endian
    /// byte serialization
    #[test]
    fn test_scalar_serialize() {
        let mut rng = thread_rng();
        let scalar = Scalar::random(&mut rng);
        let bytes = scalar.to_bytes_be();

        // Serialization is padded to a fixed width
        assert_eq!(bytes.len(), SCALAR_BYTES);

        // Deserializing the bytes recovers the original scalar
        let scalar_deserialized = Scalar::from_be_bytes_mod_order(&bytes);
        assert_eq!(scalar, scalar_deserialized);
    }

    /// Tests addition of two fabric-allocated scalars against the plaintext result
    #[tokio::test]
    async fn test_scalar_add() {
        let mut rng = thread_rng();
        let a = Scalar::random(&mut rng);
        let b = Scalar::random(&mut rng);

        let expected_res = a + b;

        // Allocate the inputs in a mock fabric and add them as gate results
        let fabric = mock_fabric();
        let a_alloc = fabric.allocate_scalar(a);
        let b_alloc = fabric.allocate_scalar(b);

        let res = &a_alloc + &b_alloc;
        let res_final = res.await;

        assert_eq!(res_final, expected_res);
        fabric.shutdown();
    }

    /// Tests subtraction of two fabric-allocated scalars against the plaintext result
    #[tokio::test]
    async fn test_scalar_sub() {
        let mut rng = thread_rng();
        let a = Scalar::random(&mut rng);
        let b = Scalar::random(&mut rng);

        let expected_res = a - b;

        // Allocate the inputs in a mock fabric and subtract them as gate results
        let fabric = mock_fabric();
        let a_alloc = fabric.allocate_scalar(a);
        let b_alloc = fabric.allocate_scalar(b);

        let res = a_alloc - b_alloc;
        let res_final = res.await;

        assert_eq!(res_final, expected_res);
        fabric.shutdown();
    }

    /// Tests negation of a fabric-allocated scalar against the plaintext result
    #[tokio::test]
    async fn test_scalar_neg() {
        let mut rng = thread_rng();
        let a = Scalar::random(&mut rng);

        let expected_res = -a;

        // Allocate the input in a mock fabric and negate it as a gate result
        let fabric = mock_fabric();
        let a_alloc = fabric.allocate_scalar(a);

        let res = -a_alloc;
        let res_final = res.await;

        assert_eq!(res_final, expected_res);
        fabric.shutdown();
    }

    /// Tests multiplication of two fabric-allocated scalars against the plaintext result
    #[tokio::test]
    async fn test_scalar_mul() {
        let mut rng = thread_rng();
        let a = Scalar::random(&mut rng);
        let b = Scalar::random(&mut rng);

        let expected_res = a * b;

        // Allocate the inputs in a mock fabric and multiply them as gate results
        let fabric = mock_fabric();
        let a_alloc = fabric.allocate_scalar(a);
        let b_alloc = fabric.allocate_scalar(b);

        let res = a_alloc * b_alloc;
        let res_final = res.await;

        assert_eq!(res_final, expected_res);
        fabric.shutdown();
    }
}