1use std::{
2 iter::Sum,
3 marker::PhantomData,
4 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
5};
6
7use ff::Field;
8use subtle::{Choice, ConstantTimeEq};
9
10use super::HeapArray;
11use crate::{
12 errors::PrimitiveError,
13 izip_eq,
14 random::{CryptoRngCore, Random, RandomNonZero},
15 types::{ConditionallySelectable, Element, Positive},
16};
17
18impl<T: Sized + Random, M: Positive> Random for HeapArray<T, M> {
23 fn random(mut rng: impl CryptoRngCore) -> Self {
25 Self {
26 data: T::random_n::<Box<[_]>>(&mut rng, M::SIZE),
27 _len: PhantomData,
28 }
29 }
30}
31
32impl<T: Sized + RandomNonZero, M: Positive> RandomNonZero for HeapArray<T, M> {
33 fn random_non_zero(mut rng: impl CryptoRngCore) -> Result<Self, PrimitiveError> {
35 Ok(Self {
36 data: T::random_n_non_zero::<Box<[_]>>(&mut rng, M::SIZE)?,
37 _len: PhantomData,
38 })
39 }
40}
41
42#[macros::op_variants(owned, borrowed, flipped_commutative)]
49impl<T: Sized, M: Positive> Add<&HeapArray<T, M>> for HeapArray<T, M>
50where
51 T: for<'b> Add<&'b T, Output = T>,
52{
53 type Output = HeapArray<T, M>;
54
55 fn add(self, other: &HeapArray<T, M>) -> Self::Output {
56 Self::Output {
57 data: izip_eq!(self.data, &other.data)
58 .map(|(lhs_value, rhs_value)| lhs_value + rhs_value)
59 .collect(),
60 _len: PhantomData,
61 }
62 }
63}
64
65#[macros::op_variants(owned)]
66impl<T: Sized, M: Positive> Add<&T> for HeapArray<T, M>
67where
68 T: for<'a> Add<&'a T, Output = T>,
69{
70 type Output = HeapArray<T, M>;
71
72 fn add(self, other: &T) -> Self::Output {
73 Self::Output {
74 data: IntoIterator::into_iter(self.data)
75 .map(|value| value + other)
76 .collect(),
77 _len: PhantomData,
78 }
79 }
80}
81
82#[macros::op_variants(owned)]
83impl<T: Sized, M: Positive> AddAssign<&HeapArray<T, M>> for HeapArray<T, M>
84where
85 T: for<'b> AddAssign<&'b T>,
86{
87 fn add_assign(&mut self, other: &Self) {
88 izip_eq!(self, other).for_each(|(lhs_value, rhs_value)| lhs_value.add_assign(rhs_value));
89 }
90}
91
92#[macros::op_variants(owned)]
93impl<T: Sized, M: Positive> AddAssign<&T> for HeapArray<T, M>
94where
95 T: for<'a> AddAssign<&'a T>,
96{
97 fn add_assign(&mut self, other: &T) {
98 self.iter_mut()
99 .for_each(|lhs_value| lhs_value.add_assign(other));
100 }
101}
102
103#[macros::op_variants(owned, borrowed, flipped)]
106impl<T: Sized, M: Positive> Sub<&HeapArray<T, M>> for HeapArray<T, M>
107where
108 T: for<'b> Sub<&'b T, Output = T>,
109{
110 type Output = HeapArray<T, M>;
111
112 fn sub(self, other: &HeapArray<T, M>) -> Self::Output {
113 Self::Output {
114 data: izip_eq!(self, other)
115 .map(|(lhs_value, rhs_value)| lhs_value - rhs_value)
116 .collect(),
117 _len: PhantomData,
118 }
119 }
120}
121
122#[macros::op_variants(owned, borrowed, flipped)]
123impl<T: Sized, M: Positive> Sub<&T> for HeapArray<T, M>
124where
125 T: for<'a> Sub<&'a T, Output = T>,
126{
127 type Output = HeapArray<T, M>;
128
129 fn sub(self, other: &T) -> Self::Output {
130 Self::Output {
131 data: IntoIterator::into_iter(self.data)
132 .map(|value| value - other)
133 .collect(),
134 _len: PhantomData,
135 }
136 }
137}
138
139#[macros::op_variants(owned)]
140impl<T: Sized, M: Positive> SubAssign<&HeapArray<T, M>> for HeapArray<T, M>
141where
142 T: for<'b> SubAssign<&'b T>,
143{
144 fn sub_assign(&mut self, other: &Self) {
145 izip_eq!(self, other).for_each(|(lhs_value, rhs_value)| lhs_value.sub_assign(rhs_value));
146 }
147}
148
149#[macros::op_variants(owned)]
150impl<T: Sized, M: Positive> SubAssign<&T> for HeapArray<T, M>
151where
152 T: for<'a> SubAssign<&'a T>,
153{
154 fn sub_assign(&mut self, other: &T) {
155 self.iter_mut()
156 .for_each(|lhs_value| lhs_value.sub_assign(other));
157 }
158}
159
160#[macros::op_variants(owned, borrowed, flipped)]
163impl<T1: Sized, T2: Sized, T3: Sized, M: Positive> Mul<&HeapArray<T2, M>> for HeapArray<T1, M>
164where
165 T1: for<'b> Mul<&'b T2, Output = T3>,
166{
167 type Output = HeapArray<T3, M>;
168 fn mul(self, other: &HeapArray<T2, M>) -> Self::Output {
169 Self::Output {
170 data: izip_eq!(self, other)
171 .map(|(lhs_value, rhs_value)| lhs_value * rhs_value)
172 .collect(),
173 _len: PhantomData,
174 }
175 }
176}
177
178#[macros::op_variants(borrowed)]
179impl<T1: Sized, T2: Sized + Element, T3: Sized, M: Positive> Mul<&T2> for HeapArray<T1, M>
180where
181 T1: for<'a> Mul<&'a T2, Output = T3>,
182{
183 type Output = HeapArray<T3, M>;
184 fn mul(self, other: &T2) -> Self::Output {
185 Self::Output {
186 data: IntoIterator::into_iter(self.data)
187 .map(|value| value * other)
188 .collect(),
189 _len: PhantomData,
190 }
191 }
192}
193
194#[macros::op_variants(owned)]
197impl<T1: Sized, T2: Sized, M: Positive> MulAssign<&HeapArray<T2, M>> for HeapArray<T1, M>
198where
199 T1: for<'b> MulAssign<&'b T2>,
200{
201 fn mul_assign(&mut self, other: &HeapArray<T2, M>) {
202 izip_eq!(self, other).for_each(|(lhs, rhs)| *lhs *= rhs);
203 }
204}
205
206impl<T: Sized, T2: Sized + Element, M: Positive> MulAssign<&T2> for HeapArray<T, M>
207where
208 T: for<'a> MulAssign<&'a T2>,
209{
210 fn mul_assign(&mut self, other: &T2) {
211 self.iter_mut().for_each(|lhs| *lhs *= other);
212 }
213}
214
215#[macros::op_variants(borrowed)]
218impl<T: Sized, M: Positive> Neg for HeapArray<T, M>
219where
220 T: Neg<Output = T>,
221{
222 type Output = HeapArray<T, M>;
223
224 fn neg(self) -> Self::Output {
225 HeapArray::<T, M> {
226 data: IntoIterator::into_iter(self.data)
227 .map(|value| value.neg())
228 .collect(),
229 _len: PhantomData,
230 }
231 }
232}
233
234impl<T: Field, M: Positive> HeapArray<T, M> {
237 pub fn square(&self) -> Self {
239 Self {
240 data: self.iter().map(|value| value.square()).collect(),
241 _len: PhantomData,
242 }
243 }
244
245 pub fn double(&self) -> Self {
247 Self {
248 data: self.iter().map(|value| value.double()).collect(),
249 _len: PhantomData,
250 }
251 }
252
253 pub fn invert(&self) -> Option<Self> {
255 let inverted_data: Option<Vec<T>> =
256 self.iter().map(|value| value.invert().into()).collect();
257
258 inverted_data.map(|data| HeapArray {
259 data: data.into(),
260 _len: PhantomData,
261 })
262 }
263
264 pub fn pow<S: AsRef<[u64]> + Clone>(&self, exp: S) -> Self {
267 Self {
268 data: self.iter().map(|value| value.pow(exp.clone())).collect(),
269 _len: PhantomData,
270 }
271 }
272}
273
274impl<T: Sized + ConditionallySelectable, M: Positive> ConditionallySelectable for HeapArray<T, M> {
277 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
278 Self {
279 data: izip_eq!(a, b)
280 .map(|(lhs, rhs)| T::conditional_select(lhs, rhs, choice))
281 .collect(),
282 _len: PhantomData,
283 }
284 }
285}
286
287impl<T: Sized + ConditionallySelectable, M: Positive, L: Positive> HeapArray<HeapArray<T, M>, L> {
288 pub fn conditional_select_2d(a: &Self, b: &Self, choice: Choice) -> Self {
295 Self {
296 data: izip_eq!(a, b)
297 .map(|(lhs, rhs)| HeapArray::conditional_select(lhs, rhs, choice))
298 .collect(),
299 _len: PhantomData,
300 }
301 }
302}
303
304impl<T: Sized, M: Positive> ConstantTimeEq for HeapArray<T, M>
307where
308 T: ConstantTimeEq,
309{
310 fn ct_eq(&self, other: &Self) -> Choice {
311 izip_eq!(self, other).fold(Choice::from(1u8), |acc, (lhs_value, rhs_value)| {
312 acc & lhs_value.ct_eq(rhs_value)
313 })
314 }
315}
316
317impl<T: Sized, M: Positive> PartialEq for HeapArray<T, M>
320where
321 T: PartialEq,
322{
323 fn eq(&self, other: &Self) -> bool {
324 izip_eq!(self, other).all(|(lhs_value, rhs_value)| lhs_value == rhs_value)
325 }
326}
327
328impl<T: Default, M: Positive> Sum for HeapArray<T, M>
331where
332 HeapArray<T, M>: Add<Output = HeapArray<T, M>>,
333{
334 fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
335 let first = iter.next().unwrap_or_default();
336 iter.fold(first, |acc, value| acc + value)
337 }
338}
339
340impl<'a, T: Default + Clone, M: Positive> Sum<&'a HeapArray<T, M>> for HeapArray<T, M>
341where
342 HeapArray<T, M>: for<'b> Add<&'b HeapArray<T, M>, Output = HeapArray<T, M>>,
343{
344 fn sum<I: Iterator<Item = &'a Self>>(mut iter: I) -> Self {
345 let first = iter.next().cloned().unwrap_or_default();
346 iter.fold(first, |acc, value| acc + value)
347 }
348}
349
350impl<T: Clone + num_traits::Zero + Eq + for<'b> Add<&'b T, Output = T>, M: Positive + Eq>
352 num_traits::Zero for HeapArray<T, M>
353{
354 fn zero() -> Self {
355 Self {
356 data: (0..M::to_usize()).map(|_| T::zero()).collect(),
357 _len: PhantomData,
358 }
359 }
360
361 fn is_zero(&self) -> bool {
362 let zero = T::zero();
363 self.iter().all(|value| value == &zero)
364 }
365}
366
367impl<
368 T: Sized + Default + num_traits::One + Eq + for<'b> Mul<&'b T, Output = T>,
369 M: Positive + Eq,
370 > num_traits::One for HeapArray<T, M>
371{
372 fn one() -> Self {
373 Self {
374 data: (0..M::to_usize()).map(|_| T::one()).collect(),
375 _len: PhantomData,
376 }
377 }
378
379 fn is_one(&self) -> bool {
380 let one = T::one();
381 self.iter().all(|value| value == &one)
382 }
383}
384
#[cfg(test)]
mod tests {
    use typenum::U10;

    use super::*;
    use crate::types::heap_array::HeapArray;

    #[test]
    fn test_addition() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i + i + 1);

        assert_eq!(a.clone() + &b, c);
        assert_eq!(&a + b.clone(), c);
        assert_eq!(&a + &b, c);
        assert_eq!(a + b, c);
    }

    #[test]
    fn test_broadcast_addition() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i + 1);

        assert_eq!(a + b, c);
    }

    #[test]
    fn test_add_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i + i + 1);

        a += b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_broadcast_add_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i + 1);

        a += b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_subtraction() {
        let a = HeapArray::<usize, U10>::from_fn(|i| 2 * i + 1);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|i| i + 1);

        // Borrowed operands exercise the `op_variants(borrowed)` form as well.
        assert_eq!(&a - &b, c);
        assert_eq!(a - b, c);
    }

    #[test]
    fn test_broadcast_subtraction() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        assert_eq!(a - b, c);
    }

    #[test]
    fn test_sub_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|_| 1);

        a -= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_broadcast_sub_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        a -= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_multiplication() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i * (i + 1));

        assert_eq!(a * b, c);
    }

    #[test]
    fn test_broadcast_multiplication() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        assert_eq!(a * &b, c);
    }

    #[test]
    fn test_mul_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i * (i + 1));

        a *= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_broadcast_mul_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        a *= &b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_negation() {
        let a = HeapArray::<i64, U10>::from_fn(|i| i as i64);
        let b = HeapArray::from_fn(|i| -(i as i64));

        assert_eq!(-a, b);
    }

    #[test]
    fn test_conditional_select() {
        let a = HeapArray::<u32, U10>::from_fn(|i| i as u32);
        let b = HeapArray::<u32, U10>::from_fn(|i| i as u32 + 1);
        let choice = Choice::from(1u8);

        let selected = HeapArray::conditional_select(&a, &b, choice);
        let non_selected = HeapArray::conditional_select(&a, &b, !choice);
        assert_eq!(selected, b);
        assert_eq!(non_selected, a);
    }

    #[test]
    fn test_constant_time_eq() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|i| i + 1);

        assert!(a.ct_eq(&b).unwrap_u8() == 1);
        assert!(a.ct_eq(&c).unwrap_u8() == 0);
    }

    /// Both `Sum` impls (owned and borrowed items) must agree with element-wise addition.
    #[test]
    fn test_sum() {
        let arrays = vec![
            HeapArray::<usize, U10>::from_fn(|i| i),
            HeapArray::from_fn(|i| 2 * i),
            HeapArray::from_fn(|_| 1),
        ];
        let expected = HeapArray::from_fn(|i| 3 * i + 1);

        assert_eq!(arrays.iter().sum::<HeapArray<usize, U10>>(), expected);
        assert_eq!(arrays.into_iter().sum::<HeapArray<usize, U10>>(), expected);
    }

    /// `Zero`/`One` build constant arrays and recognise them (and only them).
    #[test]
    fn test_zero_and_one() {
        use num_traits::{One, Zero};

        let zero = HeapArray::<usize, U10>::zero();
        let one = HeapArray::<usize, U10>::one();

        assert_eq!(zero, HeapArray::from_fn(|_| 0));
        assert_eq!(one, HeapArray::from_fn(|_| 1));
        assert!(zero.is_zero());
        assert!(!zero.is_one());
        assert!(one.is_one());
        assert!(!one.is_zero());
        assert!(!HeapArray::<usize, U10>::from_fn(|i| i).is_zero());
    }
}