Skip to main content

primitives/types/heap_array/
ops.rs

1use std::{
2    iter::Sum,
3    marker::PhantomData,
4    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
5};
6
7use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
8
9use super::HeapArray;
10use crate::{
11    izip_eq,
12    random::{CryptoRngCore, Random},
13    types::Positive,
14};
15
16// ----------
17// | Random |
18// ----------
19
20impl<T: Sized + Random, M: Positive> Random for HeapArray<T, M> {
21    /// Generate a random array of length `M`.
22    fn random(mut rng: impl CryptoRngCore) -> Self {
23        Self {
24            data: (0..M::to_usize()).map(|_| T::random(&mut rng)).collect(),
25            _len: PhantomData,
26        }
27    }
28}
29
30// --------------
31// | Arithmetic |
32// --------------
33
34// === Addition === //
35
36#[macros::op_variants(owned, borrowed, flipped_commutative)]
37impl<T: Sized, M: Positive> Add<&HeapArray<T, M>> for HeapArray<T, M>
38where
39    T: for<'b> Add<&'b T, Output = T>,
40{
41    type Output = HeapArray<T, M>;
42
43    fn add(self, other: &HeapArray<T, M>) -> Self::Output {
44        Self::Output {
45            data: izip_eq!(self.data, &other.data)
46                .map(|(lhs_value, rhs_value)| lhs_value + rhs_value)
47                .collect(),
48            _len: PhantomData,
49        }
50    }
51}
52
53#[macros::op_variants(owned)]
54impl<T: Sized, M: Positive> Add<&T> for HeapArray<T, M>
55where
56    T: for<'a> Add<&'a T, Output = T>,
57{
58    type Output = HeapArray<T, M>;
59
60    fn add(self, other: &T) -> Self::Output {
61        Self::Output {
62            data: IntoIterator::into_iter(self.data)
63                .map(|value| value + other)
64                .collect(),
65            _len: PhantomData,
66        }
67    }
68}
69
70#[macros::op_variants(owned)]
71impl<T: Sized, M: Positive> AddAssign<&HeapArray<T, M>> for HeapArray<T, M>
72where
73    T: for<'b> AddAssign<&'b T>,
74{
75    fn add_assign(&mut self, other: &Self) {
76        izip_eq!(self, other).for_each(|(lhs_value, rhs_value)| lhs_value.add_assign(rhs_value));
77    }
78}
79
80#[macros::op_variants(owned)]
81impl<T: Sized, M: Positive> AddAssign<&T> for HeapArray<T, M>
82where
83    T: for<'a> AddAssign<&'a T>,
84{
85    fn add_assign(&mut self, other: &T) {
86        self.iter_mut()
87            .for_each(|lhs_value| lhs_value.add_assign(other));
88    }
89}
90
91// === Subtraction === //
92
93#[macros::op_variants(owned, borrowed, flipped)]
94impl<T: Sized, M: Positive> Sub<&HeapArray<T, M>> for HeapArray<T, M>
95where
96    T: for<'b> Sub<&'b T, Output = T>,
97{
98    type Output = HeapArray<T, M>;
99
100    fn sub(self, other: &HeapArray<T, M>) -> Self::Output {
101        Self::Output {
102            data: izip_eq!(self, other)
103                .map(|(lhs_value, rhs_value)| lhs_value - rhs_value)
104                .collect(),
105            _len: PhantomData,
106        }
107    }
108}
109
110#[macros::op_variants(owned, borrowed, flipped)]
111impl<T: Sized, M: Positive> Sub<&T> for HeapArray<T, M>
112where
113    T: for<'a> Sub<&'a T, Output = T>,
114{
115    type Output = HeapArray<T, M>;
116
117    fn sub(self, other: &T) -> Self::Output {
118        Self::Output {
119            data: IntoIterator::into_iter(self.data)
120                .map(|value| value - other)
121                .collect(),
122            _len: PhantomData,
123        }
124    }
125}
126
127#[macros::op_variants(owned)]
128impl<T: Sized, M: Positive> SubAssign<&HeapArray<T, M>> for HeapArray<T, M>
129where
130    T: for<'b> SubAssign<&'b T>,
131{
132    fn sub_assign(&mut self, other: &Self) {
133        izip_eq!(self, other).for_each(|(lhs_value, rhs_value)| lhs_value.sub_assign(rhs_value));
134    }
135}
136
137#[macros::op_variants(owned)]
138impl<T: Sized, M: Positive> SubAssign<&T> for HeapArray<T, M>
139where
140    T: for<'a> SubAssign<&'a T>,
141{
142    fn sub_assign(&mut self, other: &T) {
143        self.iter_mut()
144            .for_each(|lhs_value| lhs_value.sub_assign(other));
145    }
146}
147
148// === Multiplication === //
149
150#[macros::op_variants(owned, borrowed, flipped_commutative)]
151impl<T: Sized, M: Positive> Mul<&HeapArray<T, M>> for HeapArray<T, M>
152where
153    T: for<'b> Mul<&'b T, Output = T>,
154{
155    type Output = HeapArray<T, M>;
156
157    fn mul(self, other: &HeapArray<T, M>) -> Self::Output {
158        Self::Output {
159            data: izip_eq!(self, other)
160                .map(|(lhs_value, rhs_value)| lhs_value * rhs_value)
161                .collect(),
162            _len: PhantomData,
163        }
164    }
165}
166
167#[macros::op_variants(owned, borrowed, flipped)]
168impl<T: Sized, M: Positive> Mul<&T> for HeapArray<T, M>
169where
170    T: for<'a> Mul<&'a T, Output = T>,
171{
172    type Output = HeapArray<T, M>;
173
174    fn mul(self, other: &T) -> Self::Output {
175        Self::Output {
176            data: IntoIterator::into_iter(self.data)
177                .map(|value| value * other)
178                .collect(),
179            _len: PhantomData,
180        }
181    }
182}
183
184// === MulAssign === //
185
186#[macros::op_variants(owned)]
187impl<T: Sized, M: Positive> MulAssign<&HeapArray<T, M>> for HeapArray<T, M>
188where
189    T: for<'b> MulAssign<&'b T>,
190{
191    fn mul_assign(&mut self, other: &Self) {
192        izip_eq!(self, other).for_each(|(lhs, rhs)| *lhs *= rhs);
193    }
194}
195
196#[macros::op_variants(owned)]
197impl<T: Sized, M: Positive> MulAssign<&T> for HeapArray<T, M>
198where
199    T: for<'a> MulAssign<&'a T>,
200{
201    fn mul_assign(&mut self, other: &T) {
202        self.iter_mut().for_each(|lhs| *lhs *= other);
203    }
204}
205
206// === Negation === //
207
208#[macros::op_variants(borrowed)]
209impl<T: Sized, M: Positive> Neg for HeapArray<T, M>
210where
211    T: Neg<Output = T>,
212{
213    type Output = HeapArray<T, M>;
214
215    fn neg(self) -> Self::Output {
216        HeapArray::<T, M> {
217            data: IntoIterator::into_iter(self.data)
218                .map(|value| value.neg())
219                .collect(),
220            _len: PhantomData,
221        }
222    }
223}
224
225// === ConditionallySelectable === //
226
227impl<T: Sized + ConditionallySelectable, M: Positive> HeapArray<T, M> {
228    /// Select `a` or `b` according to `choice`.
229    ///
230    /// # Returns
231    ///
232    /// * `a` if `choice == Choice(0)`;
233    /// * `b` if `choice == Choice(1)`.
234    pub fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
235        Self {
236            data: izip_eq!(a, b)
237                .map(|(lhs, rhs)| T::conditional_select(lhs, rhs, choice))
238                .collect(),
239            _len: PhantomData,
240        }
241    }
242}
243
244impl<T: Sized + ConditionallySelectable, M: Positive, L: Positive> HeapArray<HeapArray<T, M>, L> {
245    /// Select `a` or `b` according to `choice`.
246    ///
247    /// # Returns
248    ///
249    /// * `a` if `choice == Choice(0)`;
250    /// * `b` if `choice == Choice(1)`.
251    pub fn conditional_select_2d(a: &Self, b: &Self, choice: Choice) -> Self {
252        Self {
253            data: izip_eq!(a, b)
254                .map(|(lhs, rhs)| HeapArray::conditional_select(lhs, rhs, choice))
255                .collect(),
256            _len: PhantomData,
257        }
258    }
259}
260
261// === ConstantTimeEq === //
262
263impl<T: Sized, M: Positive> ConstantTimeEq for HeapArray<T, M>
264where
265    T: ConstantTimeEq,
266{
267    fn ct_eq(&self, other: &Self) -> Choice {
268        izip_eq!(self, other).fold(Choice::from(1u8), |acc, (lhs_value, rhs_value)| {
269            acc & lhs_value.ct_eq(rhs_value)
270        })
271    }
272}
273
274// === PartialEq === //
275
276impl<T: Sized, M: Positive> PartialEq for HeapArray<T, M>
277where
278    T: PartialEq,
279{
280    fn eq(&self, other: &Self) -> bool {
281        izip_eq!(self, other).all(|(lhs_value, rhs_value)| lhs_value == rhs_value)
282    }
283}
284
285// === Sum === //
286
287impl<T: Clone + Default + for<'b> Add<&'b T, Output = T>, M: Positive> Sum for HeapArray<T, M> {
288    fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
289        let first = iter.next().unwrap_or_default();
290        iter.fold(first, |acc, value| acc + &value)
291    }
292}
293
294// === Zero and One === //
295impl<T: Clone + num_traits::Zero + Eq + for<'b> Add<&'b T, Output = T>, M: Positive + Eq>
296    num_traits::Zero for HeapArray<T, M>
297{
298    fn zero() -> Self {
299        Self {
300            data: (0..M::to_usize()).map(|_| T::zero()).collect(),
301            _len: PhantomData,
302        }
303    }
304
305    fn is_zero(&self) -> bool {
306        let zero = T::zero();
307        self.iter().all(|value| value == &zero)
308    }
309}
310
311impl<
312        T: Sized + Default + num_traits::One + Eq + for<'b> Mul<&'b T, Output = T>,
313        M: Positive + Eq,
314    > num_traits::One for HeapArray<T, M>
315{
316    fn one() -> Self {
317        Self {
318            data: (0..M::to_usize()).map(|_| T::one()).collect(),
319            _len: PhantomData,
320        }
321    }
322
323    fn is_one(&self) -> bool {
324        let one = T::one();
325        self.iter().all(|value| value == &one)
326    }
327}
328
#[cfg(test)]
mod tests {
    use typenum::U10;

    use super::*;
    use crate::types::heap_array::HeapArray;

    #[test]
    fn test_addition() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i + i + 1);

        assert_eq!(a.clone() + &b, c);
        assert_eq!(&a + b.clone(), c);
        assert_eq!(&a + &b, c);
        assert_eq!(a + b, c);
    }

    #[test]
    fn test_broadcast_addition() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i + 1);

        assert_eq!(a + b, c);
    }

    #[test]
    fn test_add_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i + i + 1);

        a += b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_broadcast_add_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i + 1);

        a += b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_subtraction() {
        let a = HeapArray::<usize, U10>::from_fn(|i| 2 * i + 1);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|i| i + 1);

        assert_eq!(a - b, c);
    }

    #[test]
    fn test_broadcast_subtraction() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        assert_eq!(a - b, c);
    }

    #[test]
    fn test_sub_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|_| 1);

        a -= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_broadcast_sub_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        a -= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_multiplication() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i * (i + 1));

        assert_eq!(a * b, c);
    }

    #[test]
    fn test_broadcast_multiplication() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        // Use a factor other than 1: multiplying by the identity would also
        // pass for an implementation that ignored the scalar entirely.
        let b = 3;
        let c = HeapArray::from_fn(|i| 3 * i);

        assert_eq!(a * b, c);
    }

    #[test]
    fn test_mul_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i * (i + 1));

        a *= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_broadcast_mul_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        // Non-identity factor, for the same reason as in
        // `test_broadcast_multiplication`.
        let b = 3;
        let c = HeapArray::from_fn(|i| 3 * i);

        a *= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_negation() {
        let a = HeapArray::<i64, U10>::from_fn(|i| i as i64);
        let b = HeapArray::from_fn(|i| -(i as i64));

        assert_eq!(-a, b);
    }

    #[test]
    fn test_conditional_select() {
        let a = HeapArray::<u32, U10>::from_fn(|i| i as u32);
        let b = HeapArray::<u32, U10>::from_fn(|i| i as u32 + 1);
        let choice = Choice::from(1u8);

        let selected = HeapArray::conditional_select(&a, &b, choice);
        let non_selected = HeapArray::conditional_select(&a, &b, !choice);
        assert_eq!(selected, b);
        assert_eq!(non_selected, a);
    }

    #[test]
    fn test_conditional_select_2d() {
        let a = HeapArray::<HeapArray<u32, U10>, U10>::from_fn(|i| {
            HeapArray::from_fn(move |j| (10 * i + j) as u32)
        });
        let b = HeapArray::<HeapArray<u32, U10>, U10>::from_fn(|i| {
            HeapArray::from_fn(move |j| (10 * i + j) as u32 + 1000)
        });
        let choice = Choice::from(1u8);

        assert_eq!(HeapArray::conditional_select_2d(&a, &b, choice), b);
        assert_eq!(HeapArray::conditional_select_2d(&a, &b, !choice), a);
    }

    #[test]
    fn test_constant_time_eq() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|i| i + 1);

        assert_eq!(a.ct_eq(&b).unwrap_u8(), 1);
        assert_eq!(a.ct_eq(&c).unwrap_u8(), 0);
    }

    #[test]
    fn test_sum() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| 2 * i + 1);

        // A single array sums to itself.
        let single: HeapArray<usize, U10> = std::iter::once(a.clone()).sum();
        assert_eq!(single, a);

        let total: HeapArray<usize, U10> = vec![a, b].into_iter().sum();
        assert_eq!(total, c);
    }

    #[test]
    fn test_zero_and_one() {
        let zero = <HeapArray<usize, U10> as num_traits::Zero>::zero();
        let one = <HeapArray<usize, U10> as num_traits::One>::one();

        assert_eq!(zero, HeapArray::from_fn(|_| 0));
        assert_eq!(one, HeapArray::from_fn(|_| 1));

        assert!(num_traits::Zero::is_zero(&zero));
        assert!(!num_traits::Zero::is_zero(&one));
        assert!(num_traits::One::is_one(&one));
        assert!(!num_traits::One::is_one(&zero));
    }
}