1use std::{
2 iter::Sum,
3 marker::PhantomData,
4 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
5};
6
7use ff::Field;
8use subtle::{Choice, ConstantTimeEq};
9
10use super::HeapArray;
11use crate::{
12 izip_eq,
13 random::{CryptoRngCore, Random},
14 types::{ConditionallySelectable, Element, Positive},
15};
16
17impl<T: Sized + Random, M: Positive> Random for HeapArray<T, M> {
22 fn random(mut rng: impl CryptoRngCore) -> Self {
24 Self {
25 data: T::random_n::<Box<[_]>>(&mut rng, M::SIZE),
26 _len: PhantomData,
27 }
28 }
29}
30
31#[macros::op_variants(owned, borrowed, flipped_commutative)]
38impl<T: Sized, M: Positive> Add<&HeapArray<T, M>> for HeapArray<T, M>
39where
40 T: for<'b> Add<&'b T, Output = T>,
41{
42 type Output = HeapArray<T, M>;
43
44 fn add(self, other: &HeapArray<T, M>) -> Self::Output {
45 Self::Output {
46 data: izip_eq!(self.data, &other.data)
47 .map(|(lhs_value, rhs_value)| lhs_value + rhs_value)
48 .collect(),
49 _len: PhantomData,
50 }
51 }
52}
53
54#[macros::op_variants(owned)]
55impl<T: Sized, M: Positive> Add<&T> for HeapArray<T, M>
56where
57 T: for<'a> Add<&'a T, Output = T>,
58{
59 type Output = HeapArray<T, M>;
60
61 fn add(self, other: &T) -> Self::Output {
62 Self::Output {
63 data: IntoIterator::into_iter(self.data)
64 .map(|value| value + other)
65 .collect(),
66 _len: PhantomData,
67 }
68 }
69}
70
71#[macros::op_variants(owned)]
72impl<T: Sized, M: Positive> AddAssign<&HeapArray<T, M>> for HeapArray<T, M>
73where
74 T: for<'b> AddAssign<&'b T>,
75{
76 fn add_assign(&mut self, other: &Self) {
77 izip_eq!(self, other).for_each(|(lhs_value, rhs_value)| lhs_value.add_assign(rhs_value));
78 }
79}
80
81#[macros::op_variants(owned)]
82impl<T: Sized, M: Positive> AddAssign<&T> for HeapArray<T, M>
83where
84 T: for<'a> AddAssign<&'a T>,
85{
86 fn add_assign(&mut self, other: &T) {
87 self.iter_mut()
88 .for_each(|lhs_value| lhs_value.add_assign(other));
89 }
90}
91
92#[macros::op_variants(owned, borrowed, flipped)]
95impl<T: Sized, M: Positive> Sub<&HeapArray<T, M>> for HeapArray<T, M>
96where
97 T: for<'b> Sub<&'b T, Output = T>,
98{
99 type Output = HeapArray<T, M>;
100
101 fn sub(self, other: &HeapArray<T, M>) -> Self::Output {
102 Self::Output {
103 data: izip_eq!(self, other)
104 .map(|(lhs_value, rhs_value)| lhs_value - rhs_value)
105 .collect(),
106 _len: PhantomData,
107 }
108 }
109}
110
111#[macros::op_variants(owned, borrowed, flipped)]
112impl<T: Sized, M: Positive> Sub<&T> for HeapArray<T, M>
113where
114 T: for<'a> Sub<&'a T, Output = T>,
115{
116 type Output = HeapArray<T, M>;
117
118 fn sub(self, other: &T) -> Self::Output {
119 Self::Output {
120 data: IntoIterator::into_iter(self.data)
121 .map(|value| value - other)
122 .collect(),
123 _len: PhantomData,
124 }
125 }
126}
127
128#[macros::op_variants(owned)]
129impl<T: Sized, M: Positive> SubAssign<&HeapArray<T, M>> for HeapArray<T, M>
130where
131 T: for<'b> SubAssign<&'b T>,
132{
133 fn sub_assign(&mut self, other: &Self) {
134 izip_eq!(self, other).for_each(|(lhs_value, rhs_value)| lhs_value.sub_assign(rhs_value));
135 }
136}
137
138#[macros::op_variants(owned)]
139impl<T: Sized, M: Positive> SubAssign<&T> for HeapArray<T, M>
140where
141 T: for<'a> SubAssign<&'a T>,
142{
143 fn sub_assign(&mut self, other: &T) {
144 self.iter_mut()
145 .for_each(|lhs_value| lhs_value.sub_assign(other));
146 }
147}
148
149#[macros::op_variants(owned, borrowed, flipped)]
152impl<T1: Sized, T2: Sized, T3: Sized, M: Positive> Mul<&HeapArray<T2, M>> for HeapArray<T1, M>
153where
154 T1: for<'b> Mul<&'b T2, Output = T3>,
155{
156 type Output = HeapArray<T3, M>;
157 fn mul(self, other: &HeapArray<T2, M>) -> Self::Output {
158 Self::Output {
159 data: izip_eq!(self, other)
160 .map(|(lhs_value, rhs_value)| lhs_value * rhs_value)
161 .collect(),
162 _len: PhantomData,
163 }
164 }
165}
166
167#[macros::op_variants(borrowed)]
168impl<T1: Sized, T2: Sized + Element, T3: Sized, M: Positive> Mul<&T2> for HeapArray<T1, M>
169where
170 T1: for<'a> Mul<&'a T2, Output = T3>,
171{
172 type Output = HeapArray<T3, M>;
173 fn mul(self, other: &T2) -> Self::Output {
174 Self::Output {
175 data: IntoIterator::into_iter(self.data)
176 .map(|value| value * other)
177 .collect(),
178 _len: PhantomData,
179 }
180 }
181}
182
183#[macros::op_variants(owned)]
186impl<T1: Sized, T2: Sized, M: Positive> MulAssign<&HeapArray<T2, M>> for HeapArray<T1, M>
187where
188 T1: for<'b> MulAssign<&'b T2>,
189{
190 fn mul_assign(&mut self, other: &HeapArray<T2, M>) {
191 izip_eq!(self, other).for_each(|(lhs, rhs)| *lhs *= rhs);
192 }
193}
194
195impl<T: Sized, T2: Sized + Element, M: Positive> MulAssign<&T2> for HeapArray<T, M>
196where
197 T: for<'a> MulAssign<&'a T2>,
198{
199 fn mul_assign(&mut self, other: &T2) {
200 self.iter_mut().for_each(|lhs| *lhs *= other);
201 }
202}
203
204#[macros::op_variants(borrowed)]
207impl<T: Sized, M: Positive> Neg for HeapArray<T, M>
208where
209 T: Neg<Output = T>,
210{
211 type Output = HeapArray<T, M>;
212
213 fn neg(self) -> Self::Output {
214 HeapArray::<T, M> {
215 data: IntoIterator::into_iter(self.data)
216 .map(|value| value.neg())
217 .collect(),
218 _len: PhantomData,
219 }
220 }
221}
222
223impl<T: Field, M: Positive> HeapArray<T, M> {
226 pub fn square(&self) -> Self {
228 Self {
229 data: self.iter().map(|value| value.square()).collect(),
230 _len: PhantomData,
231 }
232 }
233
234 pub fn double(&self) -> Self {
236 Self {
237 data: self.iter().map(|value| value.double()).collect(),
238 _len: PhantomData,
239 }
240 }
241
242 pub fn invert(&self) -> Option<Self> {
244 let inverted_data: Option<Vec<T>> =
245 self.iter().map(|value| value.invert().into()).collect();
246
247 inverted_data.map(|data| HeapArray {
248 data: data.into(),
249 _len: PhantomData,
250 })
251 }
252
253 pub fn pow<S: AsRef<[u64]> + Clone>(&self, exp: S) -> Self {
256 Self {
257 data: self.iter().map(|value| value.pow(exp.clone())).collect(),
258 _len: PhantomData,
259 }
260 }
261}
262
263impl<T: Sized + ConditionallySelectable, M: Positive> ConditionallySelectable for HeapArray<T, M> {
266 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
267 Self {
268 data: izip_eq!(a, b)
269 .map(|(lhs, rhs)| T::conditional_select(lhs, rhs, choice))
270 .collect(),
271 _len: PhantomData,
272 }
273 }
274}
275
276impl<T: Sized + ConditionallySelectable, M: Positive, L: Positive> HeapArray<HeapArray<T, M>, L> {
277 pub fn conditional_select_2d(a: &Self, b: &Self, choice: Choice) -> Self {
284 Self {
285 data: izip_eq!(a, b)
286 .map(|(lhs, rhs)| HeapArray::conditional_select(lhs, rhs, choice))
287 .collect(),
288 _len: PhantomData,
289 }
290 }
291}
292
293impl<T: Sized, M: Positive> ConstantTimeEq for HeapArray<T, M>
296where
297 T: ConstantTimeEq,
298{
299 fn ct_eq(&self, other: &Self) -> Choice {
300 izip_eq!(self, other).fold(Choice::from(1u8), |acc, (lhs_value, rhs_value)| {
301 acc & lhs_value.ct_eq(rhs_value)
302 })
303 }
304}
305
306impl<T: Sized, M: Positive> PartialEq for HeapArray<T, M>
309where
310 T: PartialEq,
311{
312 fn eq(&self, other: &Self) -> bool {
313 izip_eq!(self, other).all(|(lhs_value, rhs_value)| lhs_value == rhs_value)
314 }
315}
316
317impl<T: Default, M: Positive> Sum for HeapArray<T, M>
320where
321 HeapArray<T, M>: Add<Output = HeapArray<T, M>>,
322{
323 fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
324 let first = iter.next().unwrap_or_default();
325 iter.fold(first, |acc, value| acc + value)
326 }
327}
328
329impl<'a, T: Default + Clone, M: Positive> Sum<&'a HeapArray<T, M>> for HeapArray<T, M>
330where
331 HeapArray<T, M>: for<'b> Add<&'b HeapArray<T, M>, Output = HeapArray<T, M>>,
332{
333 fn sum<I: Iterator<Item = &'a Self>>(mut iter: I) -> Self {
334 let first = iter.next().cloned().unwrap_or_default();
335 iter.fold(first, |acc, value| acc + value)
336 }
337}
338
339impl<T: Clone + num_traits::Zero + Eq + for<'b> Add<&'b T, Output = T>, M: Positive + Eq>
341 num_traits::Zero for HeapArray<T, M>
342{
343 fn zero() -> Self {
344 Self {
345 data: (0..M::to_usize()).map(|_| T::zero()).collect(),
346 _len: PhantomData,
347 }
348 }
349
350 fn is_zero(&self) -> bool {
351 let zero = T::zero();
352 self.iter().all(|value| value == &zero)
353 }
354}
355
356impl<
357 T: Sized + Default + num_traits::One + Eq + for<'b> Mul<&'b T, Output = T>,
358 M: Positive + Eq,
359 > num_traits::One for HeapArray<T, M>
360{
361 fn one() -> Self {
362 Self {
363 data: (0..M::to_usize()).map(|_| T::one()).collect(),
364 _len: PhantomData,
365 }
366 }
367
368 fn is_one(&self) -> bool {
369 let one = T::one();
370 self.iter().all(|value| value == &one)
371 }
372}
373
// Unit tests for the element-wise arithmetic, selection and equality impls on `HeapArray`.
// All tests use 10-element arrays (`U10`). `HeapArray::from_fn(f)` is assumed to set
// element `i` to `f(i)`; the expected values below are written under that assumption.
#[cfg(test)]
mod tests {
    use typenum::U10;

    use super::*;
    use crate::types::heap_array::HeapArray;

    // Checks all four owned/borrowed operand combinations of `+`. Only `a.clone() + &b`
    // matches the hand-written base impl; the others presumably come from `op_variants`.
    #[test]
    fn test_addition() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i + i + 1);

        assert_eq!(a.clone() + &b, c);
        assert_eq!(&a + b.clone(), c);
        assert_eq!(&a + &b, c);
        assert_eq!(a + b, c);
    }

    // Array + owned scalar (presumably the `owned` variant of `Add<&T>`).
    #[test]
    fn test_broadcast_addition() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i + 1);

        assert_eq!(a + b, c);
    }

    // `+=` with an owned right-hand array (presumably the `owned` variant of AddAssign).
    #[test]
    fn test_add_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i + i + 1);

        a += b;
        assert_eq!(a, c);
    }

    // `+=` with an owned scalar right-hand side.
    #[test]
    fn test_broadcast_add_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i + 1);

        a += b;
        assert_eq!(a, c);
    }

    // Owned - owned. The inputs are chosen so that no `usize` element underflows.
    #[test]
    fn test_subtraction() {
        let a = HeapArray::<usize, U10>::from_fn(|i| 2 * i + 1);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|i| i + 1);

        assert_eq!(a - b, c);
    }

    // Array - owned scalar. `a` starts at 1 so the `usize` result never underflows.
    #[test]
    fn test_broadcast_subtraction() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        assert_eq!(a - b, c);
    }

    // `-=` with an owned right-hand array; every element difference is 1.
    #[test]
    fn test_sub_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|_| 1);

        a -= b;
        assert_eq!(a, c);
    }

    // `-=` with an owned scalar right-hand side.
    #[test]
    fn test_broadcast_sub_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        a -= b;
        assert_eq!(a, c);
    }

    // Element-wise (Hadamard) product of two owned arrays.
    #[test]
    fn test_multiplication() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i * (i + 1));

        assert_eq!(a * b, c);
    }

    // Array * &scalar: the hand-written `Mul<&T2>` base impl (scalar passed by reference).
    #[test]
    fn test_broadcast_multiplication() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        assert_eq!(a * &b, c);
    }

    // `*=` with an owned right-hand array.
    #[test]
    fn test_mul_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i * (i + 1));

        a *= b;
        assert_eq!(a, c);
    }

    // `*=` with a borrowed scalar (`MulAssign<&T2>` has no `op_variants` attribute).
    #[test]
    fn test_broadcast_mul_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        a *= &b;
        assert_eq!(a, c);
    }

    // Unary negation on a signed element type.
    #[test]
    fn test_negation() {
        let a = HeapArray::<i64, U10>::from_fn(|i| i as i64);
        let b = HeapArray::from_fn(|i| -(i as i64));

        assert_eq!(-a, b);
    }

    // Pins the selection direction: a choice of 1 picks `b`, a choice of 0 picks `a`.
    #[test]
    fn test_conditional_select() {
        let a = HeapArray::<u32, U10>::from_fn(|i| i as u32);
        let b = HeapArray::<u32, U10>::from_fn(|i| i as u32 + 1);
        let choice = Choice::from(1u8);

        let selected = HeapArray::conditional_select(&a, &b, choice);
        let non_selected = HeapArray::conditional_select(&a, &b, !choice);
        assert_eq!(selected, b);
        assert_eq!(non_selected, a);
    }

    // `ct_eq` returns 1 for identical contents and 0 when any element differs.
    #[test]
    fn test_constant_time_eq() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|i| i + 1);

        assert!(a.ct_eq(&b).unwrap_u8() == 1);
        assert!(a.ct_eq(&c).unwrap_u8() == 0);
    }
}