1use std::{
2 iter::Sum,
3 marker::PhantomData,
4 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
5};
6
7use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
8
9use super::HeapArray;
10use crate::{
11 izip_eq,
12 random::{CryptoRngCore, Random},
13 types::Positive,
14};
15
16impl<T: Sized + Random, M: Positive> Random for HeapArray<T, M> {
21 fn random(mut rng: impl CryptoRngCore) -> Self {
23 Self {
24 data: (0..M::to_usize()).map(|_| T::random(&mut rng)).collect(),
25 _len: PhantomData,
26 }
27 }
28}
29
30#[macros::op_variants(owned, borrowed, flipped_commutative)]
37impl<T: Sized, M: Positive> Add<&HeapArray<T, M>> for HeapArray<T, M>
38where
39 T: for<'b> Add<&'b T, Output = T>,
40{
41 type Output = HeapArray<T, M>;
42
43 fn add(self, other: &HeapArray<T, M>) -> Self::Output {
44 Self::Output {
45 data: izip_eq!(self.data, &other.data)
46 .map(|(lhs_value, rhs_value)| lhs_value + rhs_value)
47 .collect(),
48 _len: PhantomData,
49 }
50 }
51}
52
53#[macros::op_variants(owned)]
54impl<T: Sized, M: Positive> Add<&T> for HeapArray<T, M>
55where
56 T: for<'a> Add<&'a T, Output = T>,
57{
58 type Output = HeapArray<T, M>;
59
60 fn add(self, other: &T) -> Self::Output {
61 Self::Output {
62 data: IntoIterator::into_iter(self.data)
63 .map(|value| value + other)
64 .collect(),
65 _len: PhantomData,
66 }
67 }
68}
69
70#[macros::op_variants(owned)]
71impl<T: Sized, M: Positive> AddAssign<&HeapArray<T, M>> for HeapArray<T, M>
72where
73 T: for<'b> AddAssign<&'b T>,
74{
75 fn add_assign(&mut self, other: &Self) {
76 izip_eq!(self, other).for_each(|(lhs_value, rhs_value)| lhs_value.add_assign(rhs_value));
77 }
78}
79
80#[macros::op_variants(owned)]
81impl<T: Sized, M: Positive> AddAssign<&T> for HeapArray<T, M>
82where
83 T: for<'a> AddAssign<&'a T>,
84{
85 fn add_assign(&mut self, other: &T) {
86 self.iter_mut()
87 .for_each(|lhs_value| lhs_value.add_assign(other));
88 }
89}
90
91#[macros::op_variants(owned, borrowed, flipped)]
94impl<T: Sized, M: Positive> Sub<&HeapArray<T, M>> for HeapArray<T, M>
95where
96 T: for<'b> Sub<&'b T, Output = T>,
97{
98 type Output = HeapArray<T, M>;
99
100 fn sub(self, other: &HeapArray<T, M>) -> Self::Output {
101 Self::Output {
102 data: izip_eq!(self, other)
103 .map(|(lhs_value, rhs_value)| lhs_value - rhs_value)
104 .collect(),
105 _len: PhantomData,
106 }
107 }
108}
109
110#[macros::op_variants(owned, borrowed, flipped)]
111impl<T: Sized, M: Positive> Sub<&T> for HeapArray<T, M>
112where
113 T: for<'a> Sub<&'a T, Output = T>,
114{
115 type Output = HeapArray<T, M>;
116
117 fn sub(self, other: &T) -> Self::Output {
118 Self::Output {
119 data: IntoIterator::into_iter(self.data)
120 .map(|value| value - other)
121 .collect(),
122 _len: PhantomData,
123 }
124 }
125}
126
127#[macros::op_variants(owned)]
128impl<T: Sized, M: Positive> SubAssign<&HeapArray<T, M>> for HeapArray<T, M>
129where
130 T: for<'b> SubAssign<&'b T>,
131{
132 fn sub_assign(&mut self, other: &Self) {
133 izip_eq!(self, other).for_each(|(lhs_value, rhs_value)| lhs_value.sub_assign(rhs_value));
134 }
135}
136
137#[macros::op_variants(owned)]
138impl<T: Sized, M: Positive> SubAssign<&T> for HeapArray<T, M>
139where
140 T: for<'a> SubAssign<&'a T>,
141{
142 fn sub_assign(&mut self, other: &T) {
143 self.iter_mut()
144 .for_each(|lhs_value| lhs_value.sub_assign(other));
145 }
146}
147
148#[macros::op_variants(owned, borrowed, flipped_commutative)]
151impl<T: Sized, M: Positive> Mul<&HeapArray<T, M>> for HeapArray<T, M>
152where
153 T: for<'b> Mul<&'b T, Output = T>,
154{
155 type Output = HeapArray<T, M>;
156
157 fn mul(self, other: &HeapArray<T, M>) -> Self::Output {
158 Self::Output {
159 data: izip_eq!(self, other)
160 .map(|(lhs_value, rhs_value)| lhs_value * rhs_value)
161 .collect(),
162 _len: PhantomData,
163 }
164 }
165}
166
167#[macros::op_variants(owned, borrowed, flipped)]
168impl<T: Sized, M: Positive> Mul<&T> for HeapArray<T, M>
169where
170 T: for<'a> Mul<&'a T, Output = T>,
171{
172 type Output = HeapArray<T, M>;
173
174 fn mul(self, other: &T) -> Self::Output {
175 Self::Output {
176 data: IntoIterator::into_iter(self.data)
177 .map(|value| value * other)
178 .collect(),
179 _len: PhantomData,
180 }
181 }
182}
183
184#[macros::op_variants(owned)]
187impl<T: Sized, M: Positive> MulAssign<&HeapArray<T, M>> for HeapArray<T, M>
188where
189 T: for<'b> MulAssign<&'b T>,
190{
191 fn mul_assign(&mut self, other: &Self) {
192 izip_eq!(self, other).for_each(|(lhs, rhs)| *lhs *= rhs);
193 }
194}
195
196#[macros::op_variants(owned)]
197impl<T: Sized, M: Positive> MulAssign<&T> for HeapArray<T, M>
198where
199 T: for<'a> MulAssign<&'a T>,
200{
201 fn mul_assign(&mut self, other: &T) {
202 self.iter_mut().for_each(|lhs| *lhs *= other);
203 }
204}
205
206#[macros::op_variants(borrowed)]
209impl<T: Sized, M: Positive> Neg for HeapArray<T, M>
210where
211 T: Neg<Output = T>,
212{
213 type Output = HeapArray<T, M>;
214
215 fn neg(self) -> Self::Output {
216 HeapArray::<T, M> {
217 data: IntoIterator::into_iter(self.data)
218 .map(|value| value.neg())
219 .collect(),
220 _len: PhantomData,
221 }
222 }
223}
224
225impl<T: Sized + ConditionallySelectable, M: Positive> HeapArray<T, M> {
228 pub fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
235 Self {
236 data: izip_eq!(a, b)
237 .map(|(lhs, rhs)| T::conditional_select(lhs, rhs, choice))
238 .collect(),
239 _len: PhantomData,
240 }
241 }
242}
243
244impl<T: Sized + ConditionallySelectable, M: Positive, L: Positive> HeapArray<HeapArray<T, M>, L> {
245 pub fn conditional_select_2d(a: &Self, b: &Self, choice: Choice) -> Self {
252 Self {
253 data: izip_eq!(a, b)
254 .map(|(lhs, rhs)| HeapArray::conditional_select(lhs, rhs, choice))
255 .collect(),
256 _len: PhantomData,
257 }
258 }
259}
260
261impl<T: Sized, M: Positive> ConstantTimeEq for HeapArray<T, M>
264where
265 T: ConstantTimeEq,
266{
267 fn ct_eq(&self, other: &Self) -> Choice {
268 izip_eq!(self, other).fold(Choice::from(1u8), |acc, (lhs_value, rhs_value)| {
269 acc & lhs_value.ct_eq(rhs_value)
270 })
271 }
272}
273
274impl<T: Sized, M: Positive> PartialEq for HeapArray<T, M>
277where
278 T: PartialEq,
279{
280 fn eq(&self, other: &Self) -> bool {
281 izip_eq!(self, other).all(|(lhs_value, rhs_value)| lhs_value == rhs_value)
282 }
283}
284
285impl<T: Clone + Default + for<'b> Add<&'b T, Output = T>, M: Positive> Sum for HeapArray<T, M> {
288 fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
289 let first = iter.next().unwrap_or_default();
290 iter.fold(first, |acc, value| acc + &value)
291 }
292}
293
294impl<T: Clone + num_traits::Zero + Eq + for<'b> Add<&'b T, Output = T>, M: Positive + Eq>
296 num_traits::Zero for HeapArray<T, M>
297{
298 fn zero() -> Self {
299 Self {
300 data: (0..M::to_usize()).map(|_| T::zero()).collect(),
301 _len: PhantomData,
302 }
303 }
304
305 fn is_zero(&self) -> bool {
306 let zero = T::zero();
307 self.iter().all(|value| value == &zero)
308 }
309}
310
311impl<
312 T: Sized + Default + num_traits::One + Eq + for<'b> Mul<&'b T, Output = T>,
313 M: Positive + Eq,
314 > num_traits::One for HeapArray<T, M>
315{
316 fn one() -> Self {
317 Self {
318 data: (0..M::to_usize()).map(|_| T::one()).collect(),
319 _len: PhantomData,
320 }
321 }
322
323 fn is_one(&self) -> bool {
324 let one = T::one();
325 self.iter().all(|value| value == &one)
326 }
327}
328
#[cfg(test)]
mod tests {
    //! Unit tests for the `HeapArray` arithmetic, selection, equality and
    //! `Sum`/`Zero`/`One` impls in this module.

    use typenum::U10;

    use super::*;
    use crate::types::heap_array::HeapArray;

    #[test]
    fn test_addition() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i + i + 1);

        assert_eq!(a.clone() + &b, c);
        assert_eq!(&a + b.clone(), c);
        assert_eq!(&a + &b, c);
        assert_eq!(a + b, c);
    }

    #[test]
    fn test_broadcast_addition() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i + 1);

        assert_eq!(a + b, c);
    }

    #[test]
    fn test_add_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i + i + 1);

        a += b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_broadcast_add_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i + 1);

        a += b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_subtraction() {
        let a = HeapArray::<usize, U10>::from_fn(|i| 2 * i + 1);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|i| i + 1);

        assert_eq!(a - b, c);
    }

    #[test]
    fn test_broadcast_subtraction() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        assert_eq!(a - b, c);
    }

    #[test]
    fn test_sub_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|_| 1);

        a -= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_broadcast_sub_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        a -= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_multiplication() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i * (i + 1));

        assert_eq!(a * b, c);
    }

    #[test]
    fn test_broadcast_multiplication() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        assert_eq!(a * b, c);
    }

    #[test]
    fn test_mul_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i * (i + 1));

        a *= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_broadcast_mul_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        a *= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_negation() {
        let a = HeapArray::<i64, U10>::from_fn(|i| i as i64);
        let b = HeapArray::from_fn(|i| -(i as i64));

        assert_eq!(-a, b);
    }

    #[test]
    fn test_conditional_select() {
        let a = HeapArray::<u32, U10>::from_fn(|i| i as u32);
        let b = HeapArray::<u32, U10>::from_fn(|i| i as u32 + 1);
        let choice = Choice::from(1u8);

        let selected = HeapArray::conditional_select(&a, &b, choice);
        let non_selected = HeapArray::conditional_select(&a, &b, !choice);
        assert_eq!(selected, b);
        assert_eq!(non_selected, a);
    }

    #[test]
    fn test_conditional_select_2d() {
        type Grid = HeapArray<HeapArray<u32, U10>, U10>;

        let a = Grid::from_fn(|i| HeapArray::from_fn(|j| (10 * i + j) as u32));
        let b = Grid::from_fn(|i| HeapArray::from_fn(|j| (10 * i + j) as u32 + 1_000));

        assert_eq!(Grid::conditional_select_2d(&a, &b, Choice::from(1u8)), b);
        assert_eq!(Grid::conditional_select_2d(&a, &b, Choice::from(0u8)), a);
    }

    #[test]
    fn test_constant_time_eq() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|i| i + 1);

        assert!(a.ct_eq(&b).unwrap_u8() == 1);
        assert!(a.ct_eq(&c).unwrap_u8() == 0);
    }

    #[test]
    fn test_equality() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i);
        // Differs from `a` only in the last slot, so a short-circuiting bug that
        // stops early would be caught.
        let c = HeapArray::from_fn(|i| if i == 9 { 0 } else { i });

        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_sum() {
        let arrays = vec![
            HeapArray::<usize, U10>::from_fn(|i| i),
            HeapArray::from_fn(|i| 2 * i),
            HeapArray::from_fn(|i| i + 3),
        ];
        let expected = HeapArray::<usize, U10>::from_fn(|i| 4 * i + 3);

        let total: HeapArray<usize, U10> = arrays.into_iter().sum();
        assert_eq!(total, expected);
    }

    #[test]
    fn test_sum_of_empty_iterator_is_default() {
        let total: HeapArray<usize, U10> = std::iter::empty::<HeapArray<usize, U10>>().sum();
        assert_eq!(total, HeapArray::<usize, U10>::default());
    }

    #[test]
    fn test_zero() {
        use num_traits::Zero;

        let zero: HeapArray<usize, U10> = Zero::zero();
        assert_eq!(zero, HeapArray::from_fn(|_| 0));
        assert!(zero.is_zero());
        assert!(!HeapArray::<usize, U10>::from_fn(|i| i).is_zero());
    }

    #[test]
    fn test_one() {
        use num_traits::One;

        let one: HeapArray<usize, U10> = One::one();
        assert_eq!(one, HeapArray::from_fn(|_| 1));
        assert!(one.is_one());
        assert!(!HeapArray::<usize, U10>::from_fn(|i| i + 1).is_one());
    }
}
482}