1use std::{
2 iter::Sum,
3 marker::PhantomData,
4 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
5};
6
7use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
8
9use super::HeapArray;
10use crate::{
11 izip_eq,
12 random::{CryptoRngCore, Random},
13 types::Positive,
14};
15
16impl<T: Sized + Random, M: Positive> Random for HeapArray<T, M> {
21 fn random(mut rng: impl CryptoRngCore) -> Self {
23 Self {
24 data: (0..M::to_usize()).map(|_| T::random(&mut rng)).collect(),
25 _len: PhantomData,
26 }
27 }
28}
29
30impl<T: Sized, M: Positive> Add for HeapArray<T, M>
37where
38 T: Add<Output = T>,
39{
40 type Output = HeapArray<T, M>;
41
42 fn add(self, other: Self) -> Self::Output {
43 Self::Output {
44 data: izip_eq!(self, other)
45 .map(|(lhs_value, rhs_value)| lhs_value + rhs_value)
46 .collect(),
47 _len: PhantomData,
48 }
49 }
50}
51
52impl<T: Sized, M: Positive> Add<HeapArray<T, M>> for &HeapArray<T, M>
53where
54 T: for<'b> Add<&'b T, Output = T>,
55{
56 type Output = HeapArray<T, M>;
57
58 fn add(self, other: HeapArray<T, M>) -> Self::Output {
59 Self::Output {
60 data: izip_eq!(self, other)
61 .map(|(lhs_value, rhs_value)| rhs_value + lhs_value)
62 .collect(),
63 _len: PhantomData,
64 }
65 }
66}
67
68impl<T: Sized, M: Positive> Add<&HeapArray<T, M>> for HeapArray<T, M>
69where
70 T: for<'b> Add<&'b T, Output = T>,
71{
72 type Output = HeapArray<T, M>;
73
74 fn add(self, other: &HeapArray<T, M>) -> Self::Output {
75 Self::Output {
76 data: izip_eq!(self, other)
77 .map(|(lhs_value, rhs_value)| lhs_value + rhs_value)
78 .collect(),
79 _len: PhantomData,
80 }
81 }
82}
83
84impl<T: Sized + Copy, M: Positive> Add for &HeapArray<T, M>
85where
86 T: for<'b> Add<&'b T, Output = T>,
87{
88 type Output = HeapArray<T, M>;
89
90 fn add(self, other: Self) -> Self::Output {
91 Self::Output {
92 data: izip_eq!(self, other)
93 .map(|(lhs_value, rhs_value)| *lhs_value + rhs_value)
94 .collect(),
95 _len: PhantomData,
96 }
97 }
98}
99
100impl<T: Sized, M: Positive> Add<T> for HeapArray<T, M>
103where
104 T: for<'a> Add<&'a T, Output = T>,
105{
106 type Output = HeapArray<T, M>;
107
108 fn add(self, other: T) -> Self::Output {
109 Self::Output {
110 data: self.into_iter().map(|value| value + &other).collect(),
111 _len: PhantomData,
112 }
113 }
114}
115
116impl<T: Sized, M: Positive> Add<&T> for HeapArray<T, M>
117where
118 T: for<'a> AddAssign<&'a T>,
119{
120 type Output = HeapArray<T, M>;
121
122 fn add(mut self, other: &T) -> Self::Output {
123 self.iter_mut().for_each(|value| *value += other);
124 self
125 }
126}
127
128impl<'a, T: Sized, M: Positive> Add<T> for &'a mut HeapArray<T, M>
129where
130 T: for<'b> AddAssign<&'b T>,
131{
132 type Output = &'a mut HeapArray<T, M>;
133
134 fn add(self, other: T) -> Self::Output {
135 self.iter_mut().for_each(|value| *value += &other);
136 self
137 }
138}
139
140impl<'a, T: Sized, M: Positive> Add<&T> for &'a mut HeapArray<T, M>
141where
142 T: for<'b> AddAssign<&'b T>,
143{
144 type Output = &'a mut HeapArray<T, M>;
145
146 fn add(self, other: &T) -> Self::Output {
147 self.iter_mut().for_each(|value| *value += other);
148 self
149 }
150}
151
152impl<T: Sized, M: Positive> AddAssign for HeapArray<T, M>
155where
156 T: for<'a> AddAssign<&'a T>,
157{
158 fn add_assign(&mut self, other: Self) {
159 izip_eq!(self, &other).for_each(|(lhs_value, rhs_value)| lhs_value.add_assign(rhs_value));
160 }
161}
162
163impl<T: Sized, M: Positive> AddAssign<&HeapArray<T, M>> for HeapArray<T, M>
164where
165 T: for<'b> AddAssign<&'b T>,
166{
167 fn add_assign(&mut self, other: &Self) {
168 izip_eq!(self, other).for_each(|(lhs_value, rhs_value)| lhs_value.add_assign(rhs_value));
169 }
170}
171
172impl<T: Sized, M: Positive> AddAssign<T> for HeapArray<T, M>
175where
176 T: for<'a> AddAssign<&'a T>,
177{
178 fn add_assign(&mut self, other: T) {
179 self.iter_mut()
180 .for_each(|lhs_value| lhs_value.add_assign(&other));
181 }
182}
183
184impl<T: Sized, M: Positive> AddAssign<&T> for HeapArray<T, M>
185where
186 T: for<'a> AddAssign<&'a T>,
187{
188 fn add_assign(&mut self, other: &T) {
189 self.iter_mut()
190 .for_each(|lhs_value| lhs_value.add_assign(other));
191 }
192}
193
194impl<T: Sized, M: Positive> Sub for HeapArray<T, M>
197where
198 T: Sub<Output = T>,
199{
200 type Output = HeapArray<T, M>;
201
202 fn sub(self, other: Self) -> Self::Output {
203 Self::Output {
204 data: izip_eq!(self, other)
205 .map(|(lhs_value, rhs_value)| lhs_value - rhs_value)
206 .collect(),
207 _len: PhantomData,
208 }
209 }
210}
211
212impl<T: Sized, M: Positive> Sub<&HeapArray<T, M>> for HeapArray<T, M>
213where
214 T: for<'b> Sub<&'b T, Output = T>,
215{
216 type Output = HeapArray<T, M>;
217
218 fn sub(self, other: &HeapArray<T, M>) -> Self::Output {
219 Self::Output {
220 data: izip_eq!(self, other)
221 .map(|(lhs_value, rhs_value)| lhs_value - rhs_value)
222 .collect(),
223 _len: PhantomData,
224 }
225 }
226}
227
228impl<T: Sized, M: Positive> Sub<&mut HeapArray<T, M>> for HeapArray<T, M>
229where
230 T: for<'b> Sub<&'b T, Output = T>,
231{
232 type Output = HeapArray<T, M>;
233
234 fn sub(self, other: &mut HeapArray<T, M>) -> Self::Output {
235 Self::Output {
236 data: izip_eq!(self, other)
237 .map(|(lhs_value, rhs_value)| lhs_value - rhs_value)
238 .collect(),
239 _len: PhantomData,
240 }
241 }
242}
243
244impl<T: Sized + Copy, M: Positive> Sub<HeapArray<T, M>> for &HeapArray<T, M>
245where
246 T: Sub<Output = T>,
247{
248 type Output = HeapArray<T, M>;
249
250 fn sub(self, other: HeapArray<T, M>) -> Self::Output {
251 Self::Output {
252 data: izip_eq!(self, other)
253 .map(|(lhs_value, rhs_value)| *lhs_value - rhs_value)
254 .collect(),
255 _len: PhantomData,
256 }
257 }
258}
259
260impl<T: Sized + Copy, M: Positive> Sub for &HeapArray<T, M>
261where
262 T: for<'b> Sub<&'b T, Output = T>,
263{
264 type Output = HeapArray<T, M>;
265
266 fn sub(self, other: &HeapArray<T, M>) -> Self::Output {
267 Self::Output {
268 data: izip_eq!(self, other)
269 .map(|(lhs_value, rhs_value)| *lhs_value - rhs_value)
270 .collect(),
271 _len: PhantomData,
272 }
273 }
274}
275
276impl<'a, T: Sized + Copy, M: Positive> Sub<&'a mut HeapArray<T, M>> for &'a HeapArray<T, M>
277where
278 T: for<'b> Sub<&'b T, Output = T>,
279{
280 type Output = HeapArray<T, M>;
281
282 fn sub(self, other: &mut HeapArray<T, M>) -> Self::Output {
283 Self::Output {
284 data: izip_eq!(self, other)
285 .map(|(lhs_value, rhs_value)| *lhs_value - rhs_value)
286 .collect(),
287 _len: PhantomData,
288 }
289 }
290}
291
292impl<T: Sized, M: Positive> Sub<T> for HeapArray<T, M>
295where
296 T: for<'a> SubAssign<&'a T>,
297{
298 type Output = HeapArray<T, M>;
299
300 fn sub(mut self, other: T) -> Self::Output {
301 self.iter_mut().for_each(|lhs_value| *lhs_value -= &other);
302 self
303 }
304}
305
306impl<T: Sized, M: Positive> Sub<&T> for HeapArray<T, M>
307where
308 T: for<'a> SubAssign<&'a T>,
309{
310 type Output = HeapArray<T, M>;
311
312 fn sub(mut self, other: &T) -> Self::Output {
313 self.iter_mut().for_each(|lhs_value| *lhs_value -= other);
314 self
315 }
316}
317
318impl<T: Sized + Copy, M: Positive> Sub<T> for &HeapArray<T, M>
319where
320 T: for<'b> Sub<&'b T, Output = T>,
321{
322 type Output = HeapArray<T, M>;
323
324 fn sub(self, other: T) -> Self::Output {
325 self.iter().map(|lhs_value| *lhs_value - &other).collect()
326 }
327}
328
329impl<T: Sized + Copy, M: Positive> Sub<&T> for &HeapArray<T, M>
330where
331 T: for<'b> Sub<&'b T, Output = T>,
332{
333 type Output = HeapArray<T, M>;
334
335 fn sub(self, other: &T) -> Self::Output {
336 self.iter().map(|lhs_value| *lhs_value - other).collect()
337 }
338}
339
340impl<T: Sized, M: Positive> SubAssign for HeapArray<T, M>
343where
344 T: for<'a> SubAssign<&'a T>,
345{
346 fn sub_assign(&mut self, other: Self) {
347 izip_eq!(self, &other).for_each(|(lhs_value, rhs_value)| lhs_value.sub_assign(rhs_value));
348 }
349}
350
351impl<T: Sized, M: Positive> SubAssign<&HeapArray<T, M>> for HeapArray<T, M>
352where
353 T: for<'b> SubAssign<&'b T>,
354{
355 fn sub_assign(&mut self, other: &Self) {
356 izip_eq!(self, other).for_each(|(lhs_value, rhs_value)| lhs_value.sub_assign(rhs_value));
357 }
358}
359
360impl<T: Sized, M: Positive> SubAssign<T> for HeapArray<T, M>
363where
364 T: for<'a> SubAssign<&'a T>,
365{
366 fn sub_assign(&mut self, other: T) {
367 self.iter_mut()
368 .for_each(|lhs_value| lhs_value.sub_assign(&other));
369 }
370}
371
372impl<T: Sized, M: Positive> SubAssign<&T> for HeapArray<T, M>
373where
374 T: for<'a> SubAssign<&'a T>,
375{
376 fn sub_assign(&mut self, other: &T) {
377 self.iter_mut()
378 .for_each(|lhs_value| lhs_value.sub_assign(other));
379 }
380}
381
382impl<T: Sized, M: Positive> Mul for HeapArray<T, M>
385where
386 T: for<'a> Mul<Output = T>,
387{
388 type Output = HeapArray<T, M>;
389
390 fn mul(self, other: Self) -> Self::Output {
391 Self::Output {
392 data: izip_eq!(self, other)
393 .map(|(lhs_value, rhs_value)| lhs_value * rhs_value)
394 .collect(),
395 _len: PhantomData,
396 }
397 }
398}
399
400impl<T: Sized + Copy, M: Positive> Mul for &HeapArray<T, M>
401where
402 T: for<'b> Mul<&'b T, Output = T>,
403{
404 type Output = HeapArray<T, M>;
405
406 fn mul(self, other: Self) -> Self::Output {
407 Self::Output {
408 data: izip_eq!(self, other)
409 .map(|(lhs_value, rhs_value)| *lhs_value * rhs_value)
410 .collect(),
411 _len: PhantomData,
412 }
413 }
414}
415
416impl<T: Sized, M: Positive> Mul<HeapArray<T, M>> for &HeapArray<T, M>
417where
418 T: for<'b> MulAssign<&'b T>,
419{
420 type Output = HeapArray<T, M>;
421
422 fn mul(self, mut other: HeapArray<T, M>) -> Self::Output {
423 izip_eq!(self, &mut other).for_each(|(lhs, rhs)| *rhs *= lhs);
424 other
425 }
426}
427
428impl<T: Sized, M: Positive> Mul<&HeapArray<T, M>> for HeapArray<T, M>
429where
430 T: for<'b> MulAssign<&'b T>,
431{
432 type Output = HeapArray<T, M>;
433
434 fn mul(mut self, other: &HeapArray<T, M>) -> Self::Output {
435 izip_eq!(&mut self, other).for_each(|(lhs, rhs)| *lhs *= rhs);
436 self
437 }
438}
439
440impl<T: Sized, M: Positive> Mul<T> for HeapArray<T, M>
443where
444 T: for<'a> MulAssign<&'a T>,
445{
446 type Output = HeapArray<T, M>;
447
448 fn mul(mut self, other: T) -> Self::Output {
449 self.iter_mut().for_each(|value| *value *= &other);
450 self
451 }
452}
453
454impl<T: Sized, M: Positive> Mul<&T> for HeapArray<T, M>
455where
456 T: for<'a> MulAssign<&'a T>,
457{
458 type Output = HeapArray<T, M>;
459
460 fn mul(mut self, other: &T) -> Self::Output {
461 self.iter_mut().for_each(|value| *value *= other);
462 self
463 }
464}
465
466impl<T: Sized + Copy, M: Positive> Mul<T> for &HeapArray<T, M>
467where
468 T: for<'b> Mul<&'b T, Output = T>,
469{
470 type Output = HeapArray<T, M>;
471
472 fn mul(self, other: T) -> Self::Output {
473 Self::Output {
474 data: self.iter().map(|value| *value * &other).collect(),
475 _len: PhantomData,
476 }
477 }
478}
479
480impl<T: Sized + Copy, M: Positive> Mul<&T> for &HeapArray<T, M>
481where
482 T: for<'b> Mul<&'b T, Output = T>,
483{
484 type Output = HeapArray<T, M>;
485
486 fn mul(self, other: &T) -> Self::Output {
487 Self::Output {
488 data: self.iter().map(|value| *value * other).collect(),
489 _len: PhantomData,
490 }
491 }
492}
493
494impl<T: Sized, M: Positive> MulAssign for HeapArray<T, M>
497where
498 T: for<'a> MulAssign<&'a T>,
499{
500 fn mul_assign(&mut self, other: Self) {
501 izip_eq!(self, &other).for_each(|(lhs, rhs)| *lhs *= rhs);
502 }
503}
504
505impl<T: Sized, M: Positive> MulAssign<&HeapArray<T, M>> for HeapArray<T, M>
506where
507 T: for<'b> MulAssign<&'b T>,
508{
509 fn mul_assign(&mut self, other: &Self) {
510 izip_eq!(self, other).for_each(|(lhs, rhs)| *lhs *= rhs);
511 }
512}
513
514impl<T: Sized, M: Positive> MulAssign<T> for HeapArray<T, M>
517where
518 T: for<'a> MulAssign<&'a T>,
519{
520 fn mul_assign(&mut self, other: T) {
521 self.iter_mut().for_each(|lhs| *lhs *= &other);
522 }
523}
524
525impl<T: Sized, M: Positive> MulAssign<&T> for HeapArray<T, M>
526where
527 T: for<'a> MulAssign<&'a T>,
528{
529 fn mul_assign(&mut self, other: &T) {
530 self.iter_mut().for_each(|lhs| *lhs *= other);
531 }
532}
533
534impl<T: Sized, M: Positive> Neg for HeapArray<T, M>
537where
538 T: Neg<Output = T>,
539{
540 type Output = HeapArray<T, M>;
541
542 fn neg(self) -> Self::Output {
543 Self::Output {
544 data: self.into_iter().map(|value| value.neg()).collect(),
545 _len: PhantomData,
546 }
547 }
548}
549
550impl<T: Sized + Copy, M: Positive> Neg for &HeapArray<T, M>
551where
552 T: Neg<Output = T>,
553{
554 type Output = HeapArray<T, M>;
555
556 fn neg(self) -> Self::Output {
557 Self::Output {
558 data: self.iter().map(|value| value.neg()).collect(),
559 _len: PhantomData,
560 }
561 }
562}
563
564impl<T: Sized + ConditionallySelectable, M: Positive> HeapArray<T, M> {
567 pub fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
574 Self {
575 data: izip_eq!(a, b)
576 .map(|(lhs, rhs)| T::conditional_select(lhs, rhs, choice))
577 .collect(),
578 _len: PhantomData,
579 }
580 }
581}
582
583impl<T: Sized + ConditionallySelectable, M: Positive, L: Positive> HeapArray<HeapArray<T, M>, L> {
584 pub fn conditional_select_2d(a: &Self, b: &Self, choice: Choice) -> Self {
591 Self {
592 data: izip_eq!(a, b)
593 .map(|(lhs, rhs)| HeapArray::conditional_select(lhs, rhs, choice))
594 .collect(),
595 _len: PhantomData,
596 }
597 }
598}
599
600impl<T: Sized, M: Positive> ConstantTimeEq for HeapArray<T, M>
603where
604 T: ConstantTimeEq,
605{
606 fn ct_eq(&self, other: &Self) -> Choice {
607 izip_eq!(self, other).fold(Choice::from(1u8), |acc, (lhs_value, rhs_value)| {
608 acc & lhs_value.ct_eq(rhs_value)
609 })
610 }
611}
612
613impl<T: Sized, M: Positive> PartialEq for HeapArray<T, M>
616where
617 T: PartialEq,
618{
619 fn eq(&self, other: &Self) -> bool {
620 izip_eq!(self, other).all(|(lhs_value, rhs_value)| lhs_value == rhs_value)
621 }
622}
623
624impl<T: Sized + Default + Add<Output = T>, M: Positive> Sum for HeapArray<T, M> {
627 fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
628 let first = iter.next().unwrap_or_default();
629 iter.fold(first, |acc, value| acc + value)
630 }
631}
632
633impl<T: Sized + num_traits::Zero + Eq, M: Positive + Eq> num_traits::Zero for HeapArray<T, M> {
635 fn zero() -> Self {
636 Self {
637 data: (0..M::to_usize()).map(|_| T::zero()).collect(),
638 _len: PhantomData,
639 }
640 }
641
642 fn is_zero(&self) -> bool {
643 let zero = T::zero();
644 self.iter().all(|value| value == &zero)
645 }
646}
647
648impl<T: Sized + Default + num_traits::One + Eq, M: Positive + Eq> num_traits::One
649 for HeapArray<T, M>
650{
651 fn one() -> Self {
652 Self {
653 data: (0..M::to_usize()).map(|_| T::one()).collect(),
654 _len: PhantomData,
655 }
656 }
657
658 fn is_one(&self) -> bool {
659 let one = T::one();
660 self.iter().all(|value| value == &one)
661 }
662}
663
#[cfg(test)]
mod tests {
    use typenum::U10;

    use super::*;
    use crate::types::heap_array::HeapArray;

    #[test]
    fn test_addition() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i + i + 1);

        assert_eq!(a + b, c);
    }

    #[test]
    fn test_broadcast_addition() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i + 1);

        assert_eq!(a + b, c);
    }

    #[test]
    fn test_add_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i + i + 1);

        a += b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_broadcast_add_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i + 1);

        a += b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_subtraction() {
        let a = HeapArray::<usize, U10>::from_fn(|i| 2 * i + 1);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|i| i + 1);

        assert_eq!(a - b, c);
    }

    #[test]
    fn test_broadcast_subtraction() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        assert_eq!(a - b, c);
    }

    #[test]
    fn test_sub_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|_| 1);

        a -= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_broadcast_sub_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        a -= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_multiplication() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i * (i + 1));

        assert_eq!(a * b, c);
    }

    #[test]
    fn test_broadcast_multiplication() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        assert_eq!(a * b, c);
    }

    #[test]
    fn test_mul_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i * (i + 1));

        a *= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_broadcast_mul_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        a *= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_negation() {
        let a = HeapArray::<i64, U10>::from_fn(|i| i as i64);
        let b = HeapArray::from_fn(|i| -(i as i64));

        assert_eq!(-a, b);
    }

    #[test]
    fn test_conditional_select() {
        let a = HeapArray::<u32, U10>::from_fn(|i| i as u32);
        let b = HeapArray::<u32, U10>::from_fn(|i| i as u32 + 1);
        let choice = Choice::from(1u8);

        let selected = HeapArray::conditional_select(&a, &b, choice);
        let non_selected = HeapArray::conditional_select(&a, &b, !choice);
        assert_eq!(selected, b);
        assert_eq!(non_selected, a);
    }

    #[test]
    fn test_constant_time_eq() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|i| i + 1);

        assert!(a.ct_eq(&b).unwrap_u8() == 1);
        assert!(a.ct_eq(&c).unwrap_u8() == 0);
    }

    /// Exercises the borrowed-operand operator impls, which the owned-only tests above
    /// never reach.
    #[test]
    fn test_reference_operand_variants() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let sum = HeapArray::<usize, U10>::from_fn(|i| i + i + 1);
        let product = HeapArray::<usize, U10>::from_fn(|i| i * (i + 1));
        let ones = HeapArray::<usize, U10>::from_fn(|_| 1);

        assert_eq!(&a + &b, sum);
        assert_eq!(&a * &b, product);
        assert_eq!(&b - &a, ones);
        assert_eq!(&b - 1usize, a);
        assert_eq!(&a * 1usize, a);
        assert_eq!(&a * b, product);
    }

    #[test]
    fn test_sum() {
        let parts = (0..3).map(|k| HeapArray::<usize, U10>::from_fn(move |i| i + k));
        let total: HeapArray<usize, U10> = parts.sum();

        assert_eq!(total, HeapArray::from_fn(|i| 3 * i + 3));
    }

    #[test]
    fn test_zero_and_one() {
        use num_traits::{One, Zero};

        let zero = <HeapArray<usize, U10> as Zero>::zero();
        let one = <HeapArray<usize, U10> as One>::one();

        assert!(zero.is_zero());
        assert!(!one.is_zero());
        assert!(one.is_one());
        assert!(!zero.is_one());
    }
}
814}