Skip to main content

primitives/types/heap_array/
ops.rs

1use std::{
2    iter::Sum,
3    marker::PhantomData,
4    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
5};
6
7use ff::Field;
8use subtle::{Choice, ConstantTimeEq};
9
10use super::HeapArray;
11use crate::{
12    izip_eq,
13    random::{CryptoRngCore, Random},
14    types::{ConditionallySelectable, Element, Positive},
15};
16
17// ----------
18// | Random |
19// ----------
20
21impl<T: Sized + Random, M: Positive> Random for HeapArray<T, M> {
22    /// Generate a random array of length `M`.
23    fn random(mut rng: impl CryptoRngCore) -> Self {
24        Self {
25            data: T::random_n::<Box<[_]>>(&mut rng, M::SIZE),
26            _len: PhantomData,
27        }
28    }
29}
30
31// --------------
32// | Arithmetic |
33// --------------
34
35// === Addition === //
36
37#[macros::op_variants(owned, borrowed, flipped_commutative)]
38impl<T: Sized, M: Positive> Add<&HeapArray<T, M>> for HeapArray<T, M>
39where
40    T: for<'b> Add<&'b T, Output = T>,
41{
42    type Output = HeapArray<T, M>;
43
44    fn add(self, other: &HeapArray<T, M>) -> Self::Output {
45        Self::Output {
46            data: izip_eq!(self.data, &other.data)
47                .map(|(lhs_value, rhs_value)| lhs_value + rhs_value)
48                .collect(),
49            _len: PhantomData,
50        }
51    }
52}
53
54#[macros::op_variants(owned)]
55impl<T: Sized, M: Positive> Add<&T> for HeapArray<T, M>
56where
57    T: for<'a> Add<&'a T, Output = T>,
58{
59    type Output = HeapArray<T, M>;
60
61    fn add(self, other: &T) -> Self::Output {
62        Self::Output {
63            data: IntoIterator::into_iter(self.data)
64                .map(|value| value + other)
65                .collect(),
66            _len: PhantomData,
67        }
68    }
69}
70
71#[macros::op_variants(owned)]
72impl<T: Sized, M: Positive> AddAssign<&HeapArray<T, M>> for HeapArray<T, M>
73where
74    T: for<'b> AddAssign<&'b T>,
75{
76    fn add_assign(&mut self, other: &Self) {
77        izip_eq!(self, other).for_each(|(lhs_value, rhs_value)| lhs_value.add_assign(rhs_value));
78    }
79}
80
81#[macros::op_variants(owned)]
82impl<T: Sized, M: Positive> AddAssign<&T> for HeapArray<T, M>
83where
84    T: for<'a> AddAssign<&'a T>,
85{
86    fn add_assign(&mut self, other: &T) {
87        self.iter_mut()
88            .for_each(|lhs_value| lhs_value.add_assign(other));
89    }
90}
91
92// === Subtraction === //
93
94#[macros::op_variants(owned, borrowed, flipped)]
95impl<T: Sized, M: Positive> Sub<&HeapArray<T, M>> for HeapArray<T, M>
96where
97    T: for<'b> Sub<&'b T, Output = T>,
98{
99    type Output = HeapArray<T, M>;
100
101    fn sub(self, other: &HeapArray<T, M>) -> Self::Output {
102        Self::Output {
103            data: izip_eq!(self, other)
104                .map(|(lhs_value, rhs_value)| lhs_value - rhs_value)
105                .collect(),
106            _len: PhantomData,
107        }
108    }
109}
110
111#[macros::op_variants(owned, borrowed, flipped)]
112impl<T: Sized, M: Positive> Sub<&T> for HeapArray<T, M>
113where
114    T: for<'a> Sub<&'a T, Output = T>,
115{
116    type Output = HeapArray<T, M>;
117
118    fn sub(self, other: &T) -> Self::Output {
119        Self::Output {
120            data: IntoIterator::into_iter(self.data)
121                .map(|value| value - other)
122                .collect(),
123            _len: PhantomData,
124        }
125    }
126}
127
128#[macros::op_variants(owned)]
129impl<T: Sized, M: Positive> SubAssign<&HeapArray<T, M>> for HeapArray<T, M>
130where
131    T: for<'b> SubAssign<&'b T>,
132{
133    fn sub_assign(&mut self, other: &Self) {
134        izip_eq!(self, other).for_each(|(lhs_value, rhs_value)| lhs_value.sub_assign(rhs_value));
135    }
136}
137
138#[macros::op_variants(owned)]
139impl<T: Sized, M: Positive> SubAssign<&T> for HeapArray<T, M>
140where
141    T: for<'a> SubAssign<&'a T>,
142{
143    fn sub_assign(&mut self, other: &T) {
144        self.iter_mut()
145            .for_each(|lhs_value| lhs_value.sub_assign(other));
146    }
147}
148
149// === Multiplication === //
150
151#[macros::op_variants(owned, borrowed, flipped)]
152impl<T1: Sized, T2: Sized, T3: Sized, M: Positive> Mul<&HeapArray<T2, M>> for HeapArray<T1, M>
153where
154    T1: for<'b> Mul<&'b T2, Output = T3>,
155{
156    type Output = HeapArray<T3, M>;
157    fn mul(self, other: &HeapArray<T2, M>) -> Self::Output {
158        Self::Output {
159            data: izip_eq!(self, other)
160                .map(|(lhs_value, rhs_value)| lhs_value * rhs_value)
161                .collect(),
162            _len: PhantomData,
163        }
164    }
165}
166
167#[macros::op_variants(borrowed)]
168impl<T1: Sized, T2: Sized + Element, T3: Sized, M: Positive> Mul<&T2> for HeapArray<T1, M>
169where
170    T1: for<'a> Mul<&'a T2, Output = T3>,
171{
172    type Output = HeapArray<T3, M>;
173    fn mul(self, other: &T2) -> Self::Output {
174        Self::Output {
175            data: IntoIterator::into_iter(self.data)
176                .map(|value| value * other)
177                .collect(),
178            _len: PhantomData,
179        }
180    }
181}
182
183// === MulAssign === //
184
185#[macros::op_variants(owned)]
186impl<T1: Sized, T2: Sized, M: Positive> MulAssign<&HeapArray<T2, M>> for HeapArray<T1, M>
187where
188    T1: for<'b> MulAssign<&'b T2>,
189{
190    fn mul_assign(&mut self, other: &HeapArray<T2, M>) {
191        izip_eq!(self, other).for_each(|(lhs, rhs)| *lhs *= rhs);
192    }
193}
194
195impl<T: Sized, T2: Sized + Element, M: Positive> MulAssign<&T2> for HeapArray<T, M>
196where
197    T: for<'a> MulAssign<&'a T2>,
198{
199    fn mul_assign(&mut self, other: &T2) {
200        self.iter_mut().for_each(|lhs| *lhs *= other);
201    }
202}
203
204// === Negation === //
205
206#[macros::op_variants(borrowed)]
207impl<T: Sized, M: Positive> Neg for HeapArray<T, M>
208where
209    T: Neg<Output = T>,
210{
211    type Output = HeapArray<T, M>;
212
213    fn neg(self) -> Self::Output {
214        HeapArray::<T, M> {
215            data: IntoIterator::into_iter(self.data)
216                .map(|value| value.neg())
217                .collect(),
218            _len: PhantomData,
219        }
220    }
221}
222
223// === Field Ops === //
224
225impl<T: Field, M: Positive> HeapArray<T, M> {
226    /// Squares each element in the array.
227    pub fn square(&self) -> Self {
228        Self {
229            data: self.iter().map(|value| value.square()).collect(),
230            _len: PhantomData,
231        }
232    }
233
234    /// Doubles each element in the array.
235    pub fn double(&self) -> Self {
236        Self {
237            data: self.iter().map(|value| value.double()).collect(),
238            _len: PhantomData,
239        }
240    }
241
242    /// Inverts each element in the array.
243    pub fn invert(&self) -> Option<Self> {
244        let inverted_data: Option<Vec<T>> =
245            self.iter().map(|value| value.invert().into()).collect();
246
247        inverted_data.map(|data| HeapArray {
248            data: data.into(),
249            _len: PhantomData,
250        })
251    }
252
253    /// Exponentiates `self` by `exp`, where `exp` is a little-endian order integer
254    /// exponent.
255    pub fn pow<S: AsRef<[u64]> + Clone>(&self, exp: S) -> Self {
256        Self {
257            data: self.iter().map(|value| value.pow(exp.clone())).collect(),
258            _len: PhantomData,
259        }
260    }
261}
262
263// === ConditionallySelectable === //
264
265impl<T: Sized + ConditionallySelectable, M: Positive> ConditionallySelectable for HeapArray<T, M> {
266    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
267        Self {
268            data: izip_eq!(a, b)
269                .map(|(lhs, rhs)| T::conditional_select(lhs, rhs, choice))
270                .collect(),
271            _len: PhantomData,
272        }
273    }
274}
275
276impl<T: Sized + ConditionallySelectable, M: Positive, L: Positive> HeapArray<HeapArray<T, M>, L> {
277    /// Select `a` or `b` according to `choice`.
278    ///
279    /// # Returns
280    ///
281    /// * `a` if `choice == Choice(0)`;
282    /// * `b` if `choice == Choice(1)`.
283    pub fn conditional_select_2d(a: &Self, b: &Self, choice: Choice) -> Self {
284        Self {
285            data: izip_eq!(a, b)
286                .map(|(lhs, rhs)| HeapArray::conditional_select(lhs, rhs, choice))
287                .collect(),
288            _len: PhantomData,
289        }
290    }
291}
292
293// === ConstantTimeEq === //
294
295impl<T: Sized, M: Positive> ConstantTimeEq for HeapArray<T, M>
296where
297    T: ConstantTimeEq,
298{
299    fn ct_eq(&self, other: &Self) -> Choice {
300        izip_eq!(self, other).fold(Choice::from(1u8), |acc, (lhs_value, rhs_value)| {
301            acc & lhs_value.ct_eq(rhs_value)
302        })
303    }
304}
305
306// === PartialEq === //
307
308impl<T: Sized, M: Positive> PartialEq for HeapArray<T, M>
309where
310    T: PartialEq,
311{
312    fn eq(&self, other: &Self) -> bool {
313        izip_eq!(self, other).all(|(lhs_value, rhs_value)| lhs_value == rhs_value)
314    }
315}
316
317// === Sum === //
318
319impl<T: Default, M: Positive> Sum for HeapArray<T, M>
320where
321    HeapArray<T, M>: Add<Output = HeapArray<T, M>>,
322{
323    fn sum<I: Iterator<Item = Self>>(mut iter: I) -> Self {
324        let first = iter.next().unwrap_or_default();
325        iter.fold(first, |acc, value| acc + value)
326    }
327}
328
329impl<'a, T: Default + Clone, M: Positive> Sum<&'a HeapArray<T, M>> for HeapArray<T, M>
330where
331    HeapArray<T, M>: for<'b> Add<&'b HeapArray<T, M>, Output = HeapArray<T, M>>,
332{
333    fn sum<I: Iterator<Item = &'a Self>>(mut iter: I) -> Self {
334        let first = iter.next().cloned().unwrap_or_default();
335        iter.fold(first, |acc, value| acc + value)
336    }
337}
338
339// === Zero and One === //
340impl<T: Clone + num_traits::Zero + Eq + for<'b> Add<&'b T, Output = T>, M: Positive + Eq>
341    num_traits::Zero for HeapArray<T, M>
342{
343    fn zero() -> Self {
344        Self {
345            data: (0..M::to_usize()).map(|_| T::zero()).collect(),
346            _len: PhantomData,
347        }
348    }
349
350    fn is_zero(&self) -> bool {
351        let zero = T::zero();
352        self.iter().all(|value| value == &zero)
353    }
354}
355
356impl<
357        T: Sized + Default + num_traits::One + Eq + for<'b> Mul<&'b T, Output = T>,
358        M: Positive + Eq,
359    > num_traits::One for HeapArray<T, M>
360{
361    fn one() -> Self {
362        Self {
363            data: (0..M::to_usize()).map(|_| T::one()).collect(),
364            _len: PhantomData,
365        }
366    }
367
368    fn is_one(&self) -> bool {
369        let one = T::one();
370        self.iter().all(|value| value == &one)
371    }
372}
373
#[cfg(test)]
mod tests {
    use typenum::{U10, U2, U3};

    use super::*;
    use crate::types::heap_array::HeapArray;

    #[test]
    fn test_addition() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i + i + 1);

        assert_eq!(a.clone() + &b, c);
        assert_eq!(&a + b.clone(), c);
        assert_eq!(&a + &b, c);
        assert_eq!(a + b, c);
    }

    #[test]
    fn test_broadcast_addition() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i + 1);

        assert_eq!(a + b, c);
    }

    #[test]
    fn test_add_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i + i + 1);

        a += b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_broadcast_add_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = 1;
        let c = HeapArray::from_fn(|i| i + 1);

        a += b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_subtraction() {
        let a = HeapArray::<usize, U10>::from_fn(|i| 2 * i + 1);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|i| i + 1);

        assert_eq!(a - b, c);
    }

    #[test]
    fn test_broadcast_subtraction() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        assert_eq!(a - b, c);
    }

    #[test]
    fn test_sub_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|_| 1);

        a -= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_broadcast_sub_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i + 1);
        let b = 1;
        let c = HeapArray::from_fn(|i| i);

        a -= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_multiplication() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i * (i + 1));

        assert_eq!(a * b, c);
    }

    #[test]
    fn test_broadcast_multiplication() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        // A non-unit scalar, so an implementation that ignored `b` would fail.
        let b = 3;
        let c = HeapArray::from_fn(|i| 3 * i);

        assert_eq!(a * &b, c);
    }

    #[test]
    fn test_mul_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i + 1);
        let c = HeapArray::from_fn(|i| i * (i + 1));

        a *= b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_broadcast_mul_assign() {
        let mut a = HeapArray::<usize, U10>::from_fn(|i| i);
        // A non-unit scalar, so an implementation that ignored `b` would fail.
        let b = 3;
        let c = HeapArray::from_fn(|i| 3 * i);

        a *= &b;
        assert_eq!(a, c);
    }

    #[test]
    fn test_negation() {
        let a = HeapArray::<i64, U10>::from_fn(|i| i as i64);
        let b = HeapArray::from_fn(|i| -(i as i64));

        assert_eq!(-a, b);
    }

    #[test]
    fn test_conditional_select() {
        let a = HeapArray::<u32, U10>::from_fn(|i| i as u32);
        let b = HeapArray::<u32, U10>::from_fn(|i| i as u32 + 1);
        let choice = Choice::from(1u8);

        let selected = HeapArray::conditional_select(&a, &b, choice);
        let non_selected = HeapArray::conditional_select(&a, &b, !choice);
        assert_eq!(selected, b);
        assert_eq!(non_selected, a);
    }

    #[test]
    fn test_conditional_select_2d() {
        let a = HeapArray::<HeapArray<u32, U3>, U2>::from_fn(|i| {
            HeapArray::from_fn(|j| (3 * i + j) as u32)
        });
        let b = HeapArray::<HeapArray<u32, U3>, U2>::from_fn(|i| {
            HeapArray::from_fn(|j| (3 * i + j) as u32 + 100)
        });

        let selected = HeapArray::conditional_select_2d(&a, &b, Choice::from(1u8));
        let non_selected = HeapArray::conditional_select_2d(&a, &b, Choice::from(0u8));
        assert_eq!(selected, b);
        assert_eq!(non_selected, a);
    }

    #[test]
    fn test_constant_time_eq() {
        let a = HeapArray::<usize, U10>::from_fn(|i| i);
        let b = HeapArray::from_fn(|i| i);
        let c = HeapArray::from_fn(|i| i + 1);

        assert!(a.ct_eq(&b).unwrap_u8() == 1);
        assert!(a.ct_eq(&c).unwrap_u8() == 0);
    }

    #[test]
    fn test_sum() {
        let arrays = [
            HeapArray::<usize, U10>::from_fn(|i| i),
            HeapArray::from_fn(|i| 2 * i),
            HeapArray::from_fn(|i| 3 * i),
        ];
        let expected = HeapArray::<usize, U10>::from_fn(|i| 6 * i);

        let borrowed_sum: HeapArray<usize, U10> = arrays.iter().sum();
        assert_eq!(borrowed_sum, expected);

        let owned_sum: HeapArray<usize, U10> = arrays.into_iter().sum();
        assert_eq!(owned_sum, expected);
    }

    #[test]
    fn test_zero_and_one() {
        use num_traits::{One, Zero};

        let counting = HeapArray::<usize, U10>::from_fn(|i| i);

        let zero = <HeapArray<usize, U10> as Zero>::zero();
        assert_eq!(zero, HeapArray::from_fn(|_| 0));
        assert!(Zero::is_zero(&zero));
        assert!(!Zero::is_zero(&counting));

        let one = <HeapArray<usize, U10> as One>::one();
        assert_eq!(one, HeapArray::from_fn(|_| 1));
        assert!(One::is_one(&one));
        assert!(!One::is_one(&counting));
    }
}