Skip to main content

primitives/types/heap_array/
ops.rs

1use std::{
2    iter::Sum,
3    marker::PhantomData,
4    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
5};
6
7use ff::Field;
8use subtle::{Choice, ConstantTimeEq};
9
10use super::HeapArray;
11use crate::{
12    errors::PrimitiveError,
13    izip_eq,
14    random::{CryptoRngCore, Random, RandomNonZero},
15    types::{ConditionallySelectable, Element, Positive},
16};
17
18// ----------
19// | Random |
20// ----------
21
22impl<T: Sized + Random, M: Positive> Random for HeapArray<T, M> {
23    /// Generate a random array of length `M`.
24    fn random(mut rng: impl CryptoRngCore) -> Self {
25        Self {
26            data: T::random_n::<Box<[_]>>(&mut rng, M::SIZE),
27            _len: PhantomData,
28        }
29    }
30}
31
32impl<T: Sized + RandomNonZero, M: Positive> RandomNonZero for HeapArray<T, M> {
33    /// Generate a random array of length `M`.
34    fn random_non_zero(mut rng: impl CryptoRngCore) -> Result<Self, PrimitiveError> {
35        Ok(Self {
36            data: T::random_n_non_zero::<Box<[_]>>(&mut rng, M::SIZE)?,
37            _len: PhantomData,
38        })
39    }
40}
41
42// --------------
43// | Arithmetic |
44// --------------
45
46// === Addition === //
47
48#[macros::op_variants(owned, borrowed, flipped_commutative)]
49impl<T: Sized, M: Positive> Add<&HeapArray<T, M>> for HeapArray<T, M>
50where
51    T: for<'b> Add<&'b T, Output = T>,
52{
53    type Output = HeapArray<T, M>;
54
55    fn add(self, other: &HeapArray<T, M>) -> Self::Output {
56        Self::Output {
57            data: izip_eq!(self.data, &other.data)
58                .map(|(lhs_value, rhs_value)| lhs_value + rhs_value)
59                .collect(),
60            _len: PhantomData,
61        }
62    }
63}
64
65#[macros::op_variants(owned)]
66impl<T: Sized, M: Positive> Add<&T> for HeapArray<T, M>
67where
68    T: for<'a> Add<&'a T, Output = T>,
69{
70    type Output = HeapArray<T, M>;
71
72    fn add(self, other: &T) -> Self::Output {
73        Self::Output {
74            data: IntoIterator::into_iter(self.data)
75                .map(|value| value + other)
76                .collect(),
77            _len: PhantomData,
78        }
79    }
80}
81
82#[macros::op_variants(owned)]
83impl<T: Sized, M: Positive> AddAssign<&HeapArray<T, M>> for HeapArray<T, M>
84where
85    T: for<'b> AddAssign<&'b T>,
86{
87    fn add_assign(&mut self, other: &Self) {
88        izip_eq!(self, other).for_each(|(lhs_value, rhs_value)| lhs_value.add_assign(rhs_value));
89    }
90}
91
92#[macros::op_variants(owned)]
93impl<T: Sized, M: Positive> AddAssign<&T> for HeapArray<T, M>
94where
95    T: for<'a> AddAssign<&'a T>,
96{
97    fn add_assign(&mut self, other: &T) {
98        self.iter_mut()
99            .for_each(|lhs_value| lhs_value.add_assign(other));
100    }
101}
102
103// === Subtraction === //
104
105#[macros::op_variants(owned, borrowed, flipped)]
106impl<T: Sized, M: Positive> Sub<&HeapArray<T, M>> for HeapArray<T, M>
107where
108    T: for<'b> Sub<&'b T, Output = T>,
109{
110    type Output = HeapArray<T, M>;
111
112    fn sub(self, other: &HeapArray<T, M>) -> Self::Output {
113        Self::Output {
114            data: izip_eq!(self, other)
115                .map(|(lhs_value, rhs_value)| lhs_value - rhs_value)
116                .collect(),
117            _len: PhantomData,
118        }
119    }
120}
121
122#[macros::op_variants(owned, borrowed, flipped)]
123impl<T: Sized, M: Positive> Sub<&T> for HeapArray<T, M>
124where
125    T: for<'a> Sub<&'a T, Output = T>,
126{
127    type Output = HeapArray<T, M>;
128
129    fn sub(self, other: &T) -> Self::Output {
130        Self::Output {
131            data: IntoIterator::into_iter(self.data)
132                .map(|value| value - other)
133                .collect(),
134            _len: PhantomData,
135        }
136    }
137}
138
139#[macros::op_variants(owned)]
140impl<T: Sized, M: Positive> SubAssign<&HeapArray<T, M>> for HeapArray<T, M>
141where
142    T: for<'b> SubAssign<&'b T>,
143{
144    fn sub_assign(&mut self, other: &Self) {
145        izip_eq!(self, other).for_each(|(lhs_value, rhs_value)| lhs_value.sub_assign(rhs_value));
146    }
147}
148
149#[macros::op_variants(owned)]
150impl<T: Sized, M: Positive> SubAssign<&T> for HeapArray<T, M>
151where
152    T: for<'a> SubAssign<&'a T>,
153{
154    fn sub_assign(&mut self, other: &T) {
155        self.iter_mut()
156            .for_each(|lhs_value| lhs_value.sub_assign(other));
157    }
158}
159
160// === Multiplication === //
161
162#[macros::op_variants(owned, borrowed, flipped)]
163impl<T1: Sized, T2: Sized, T3: Sized, M: Positive> Mul<&HeapArray<T2, M>> for HeapArray<T1, M>
164where
165    T1: for<'b> Mul<&'b T2, Output = T3>,
166{
167    type Output = HeapArray<T3, M>;
168    fn mul(self, other: &HeapArray<T2, M>) -> Self::Output {
169        Self::Output {
170            data: izip_eq!(self, other)
171                .map(|(lhs_value, rhs_value)| lhs_value * rhs_value)
172                .collect(),
173            _len: PhantomData,
174        }
175    }
176}
177
178#[macros::op_variants(borrowed)]
179impl<T1: Sized, T2: Sized + Element, T3: Sized, M: Positive> Mul<&T2> for HeapArray<T1, M>
180where
181    T1: for<'a> Mul<&'a T2, Output = T3>,
182{
183    type Output = HeapArray<T3, M>;
184    fn mul(self, other: &T2) -> Self::Output {
185        Self::Output {
186            data: IntoIterator::into_iter(self.data)
187                .map(|value| value * other)
188                .collect(),
189            _len: PhantomData,
190        }
191    }
192}
193
194// === MulAssign === //
195
196#[macros::op_variants(owned)]
197impl<T1: Sized, T2: Sized, M: Positive> MulAssign<&HeapArray<T2, M>> for HeapArray<T1, M>
198where
199    T1: for<'b> MulAssign<&'b T2>,
200{
201    fn mul_assign(&mut self, other: &HeapArray<T2, M>) {
202        izip_eq!(self, other).for_each(|(lhs, rhs)| *lhs *= rhs);
203    }
204}
205
206impl<T: Sized, T2: Sized + Element, M: Positive> MulAssign<&T2> for HeapArray<T, M>
207where
208    T: for<'a> MulAssign<&'a T2>,
209{
210    fn mul_assign(&mut self, other: &T2) {
211        self.iter_mut().for_each(|lhs| *lhs *= other);
212    }
213}
214
215// === Negation === //
216
217#[macros::op_variants(borrowed)]
218impl<T: Sized, M: Positive> Neg for HeapArray<T, M>
219where
220    T: Neg<Output = T>,
221{
222    type Output = HeapArray<T, M>;
223
224    fn neg(self) -> Self::Output {
225        HeapArray::<T, M> {
226            data: IntoIterator::into_iter(self.data)
227                .map(|value| value.neg())
228                .collect(),
229            _len: PhantomData,
230        }
231    }
232}
233
234// === Field Ops === //
235
236impl<T: Field, M: Positive> HeapArray<T, M> {
237    /// Squares each element in the array.
238    pub fn square(&self) -> Self {
239        Self {
240            data: self.iter().map(|value| value.square()).collect(),
241            _len: PhantomData,
242        }
243    }
244
245    /// Doubles each element in the array.
246    pub fn double(&self) -> Self {
247        Self {
248            data: self.iter().map(|value| value.double()).collect(),
249            _len: PhantomData,
250        }
251    }
252
253    /// Inverts each element in the array.
254    pub fn invert(&self) -> Option<Self> {
255        let inverted_data: Option<Vec<T>> =
256            self.iter().map(|value| value.invert().into()).collect();
257
258        inverted_data.map(|data| HeapArray {
259            data: data.into(),
260            _len: PhantomData,
261        })
262    }
263
264    /// Exponentiates `self` by `exp`, where `exp` is a little-endian order integer
265    /// exponent.
266    pub fn pow<S: AsRef<[u64]> + Clone>(&self, exp: S) -> Self {
267        Self {
268            data: self.iter().map(|value| value.pow(exp.clone())).collect(),
269            _len: PhantomData,
270        }
271    }
272}
273
274// === ConditionallySelectable === //
275
276impl<T: Sized + ConditionallySelectable, M: Positive> ConditionallySelectable for HeapArray<T, M> {
277    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
278        Self {
279            data: izip_eq!(a, b)
280                .map(|(lhs, rhs)| T::conditional_select(lhs, rhs, choice))
281                .collect(),
282            _len: PhantomData,
283        }
284    }
285}
286
287impl<T: Sized + ConditionallySelectable, M: Positive, L: Positive> HeapArray<HeapArray<T, M>, L> {
288    /// Select `a` or `b` according to `choice`.
289    ///
290    /// # Returns
291    ///
292    /// * `a` if `choice == Choice(0)`;
293    /// * `b` if `choice == Choice(1)`.
294    pub fn conditional_select_2d(a: &Self, b: &Self, choice: Choice) -> Self {
295        Self {
296            data: izip_eq!(a, b)
297                .map(|(lhs, rhs)| HeapArray::conditional_select(lhs, rhs, choice))
298                .collect(),
299            _len: PhantomData,
300        }
301    }
302}
303
304// === ConstantTimeEq === //
305
306impl<T: Sized, M: Positive> ConstantTimeEq for HeapArray<T, M>
307where
308    T: ConstantTimeEq,
309{
310    fn ct_eq(&self, other: &Self) -> Choice {
311        izip_eq!(self, other).fold(Choice::from(1u8), |acc, (lhs_value, rhs_value)| {
312            acc & lhs_value.ct_eq(rhs_value)
313        })
314    }
315}
316
317// === PartialEq === //
318
319impl<T: Sized, M: Positive> PartialEq for HeapArray<T, M>
320where
321    T: PartialEq,
322{
323    fn eq(&self, other: &Self) -> bool {
324        izip_eq!(self, other).all(|(lhs_value, rhs_value)| lhs_value == rhs_value)
325    }
326}
327
328// === Sum === //
329
330impl<T: Default, M: Positive> Sum for HeapArray<T, M>
331where
332    HeapArray<T, M>: Add<Output = HeapArray<T, M>>,
333{
334    fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
335        let first = iter.next().unwrap_or_default();
336        iter.fold(first, |acc, value| acc + value)
337    }
338}
339
340impl<'a, T: Default + Clone, M: Positive> Sum<&'a HeapArray<T, M>> for HeapArray<T, M>
341where
342    HeapArray<T, M>: for<'b> Add<&'b HeapArray<T, M>, Output = HeapArray<T, M>>,
343{
344    fn sum<I: Iterator<Item = &'a Self>>(mut iter: I) -> Self {
345        let first = iter.next().cloned().unwrap_or_default();
346        iter.fold(first, |acc, value| acc + value)
347    }
348}
349
350// === Zero and One === //
351impl<T: Clone + num_traits::Zero + Eq + for<'b> Add<&'b T, Output = T>, M: Positive + Eq>
352    num_traits::Zero for HeapArray<T, M>
353{
354    fn zero() -> Self {
355        Self {
356            data: (0..M::to_usize()).map(|_| T::zero()).collect(),
357            _len: PhantomData,
358        }
359    }
360
361    fn is_zero(&self) -> bool {
362        let zero = T::zero();
363        self.iter().all(|value| value == &zero)
364    }
365}
366
367impl<
368        T: Sized + Default + num_traits::One + Eq + for<'b> Mul<&'b T, Output = T>,
369        M: Positive + Eq,
370    > num_traits::One for HeapArray<T, M>
371{
372    fn one() -> Self {
373        Self {
374            data: (0..M::to_usize()).map(|_| T::one()).collect(),
375            _len: PhantomData,
376        }
377    }
378
379    fn is_one(&self) -> bool {
380        let one = T::one();
381        self.iter().all(|value| value == &one)
382    }
383}
384
// Unit tests for the element-wise operator impls, `Sum`, and `Zero`/`One`.
#[cfg(test)]
mod tests {
    use typenum::U10;

    use super::*;
    use crate::types::heap_array::HeapArray;

    #[test]
    fn test_addition() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i + i + 1);

        assert_eq!(a.clone() + &b, c);
        assert_eq!(&a + b.clone(), c);
        assert_eq!(&a + &b, c);
        assert_eq!(a + b, c);
    }

    #[test]
    fn test_broadcast_addition() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i + 1);

        assert_eq!(a + b, c);
    }

    #[test]
    fn test_add_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i + i + 1);

        a += b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_broadcast_add_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i + 1);

        a += b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_subtraction() {
        let a = HeapArray::<usize, U10>::from_fn(|i| 2 * i + 1);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|i| i + 1);

        assert_eq!(a - b, c);
    }

    #[test]
    fn test_broadcast_subtraction() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        assert_eq!(a - b, c);
    }

    #[test]
    fn test_sub_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|_| 1);

        a -= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_broadcast_sub_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        a -= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_multiplication() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i * (i + 1));

        assert_eq!(a * b, c);
    }

    #[test]
    fn test_broadcast_multiplication() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        assert_eq!(a * &b, c);
    }

    #[test]
    fn test_mul_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i * (i + 1));

        a *= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_broadcast_mul_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        a *= &b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_negation() {
        let a = HeapArray::<i64, U10>::from_fn(|i| i as i64);
        let b = HeapArray::from_fn(|i| -(i as i64));

        assert_eq!(-a, b);
    }

    #[test]
    fn test_conditional_select() {
        let a = HeapArray::<u32, U10>::from_fn(|i| i as u32);
        let b = HeapArray::<u32, U10>::from_fn(|i| i as u32 + 1);
        let choice = Choice::from(1u8);

        let selected = HeapArray::conditional_select(&a, &b, choice);
        let non_selected = HeapArray::conditional_select(&a, &b, !choice);
        assert_eq!(selected, b);
        assert_eq!(non_selected, a);
    }

    #[test]
    fn test_constant_time_eq() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|i| i + 1);

        assert!(a.ct_eq(&b).unwrap_u8() == 1);
        assert!(a.ct_eq(&c).unwrap_u8() == 0);
    }

    #[test]
    fn test_sum() {
        let arrays = vec![
            HeapArray::<usize, U10>::from_fn(|i| i),
            HeapArray::from_fn(|i| 2 * i),
            HeapArray::from_fn(|_| 1),
        ];
        let expected = HeapArray::<usize, U10>::from_fn(|i| 3 * i + 1);

        // `Sum<&HeapArray>`: sums borrowed arrays.
        let by_ref: HeapArray<usize, U10> = arrays.iter().sum();
        assert_eq!(by_ref, expected);

        // `Sum<HeapArray>`: sums owned arrays.
        let by_value: HeapArray<usize, U10> = arrays.into_iter().sum();
        assert_eq!(by_value, expected);
    }

    #[test]
    fn test_zero_and_one() {
        use num_traits::{One, Zero};

        let zero = <HeapArray<usize, U10> as Zero>::zero();
        assert_eq!(zero, HeapArray::<usize, U10>::from_fn(|_| 0));
        assert!(Zero::is_zero(&zero));
        assert!(!Zero::is_zero(&HeapArray::<usize, U10>::from_fn(|i| i)));

        let one = <HeapArray<usize, U10> as One>::one();
        assert_eq!(one, HeapArray::<usize, U10>::from_fn(|_| 1));
        assert!(One::is_one(&one));
        assert!(!One::is_one(&HeapArray::<usize, U10>::from_fn(|i| i + 1)));
    }
}