Skip to main content

primitives/types/
mod.rs

1pub mod heap_array;
2pub mod identifiers;
3
4use std::{error::Error, fmt::Debug, ops::Add};
5
/// Marker for scalar element types used in broadcast arithmetic on [`HeapArray`].
///
/// `HeapArray` itself deliberately does **not** implement this trait. That is what lets the
/// element-wise (`HeapArray * HeapArray`) and broadcast (`HeapArray * scalar`) operator impls
/// coexist without coherence conflicts.
///
/// Every `Copy` type receives this trait through the blanket impl right below. `HeapArray`
/// owns a `Box<[T]>`, so it can never be `Copy` and therefore never satisfies `Element`.
pub trait Element {}

impl<T> Element for T where T: Copy {}
16
17pub use heap_array::{HeapArray, HeapMatrix, RowMajorHeapMatrix};
18pub use identifiers::{FaultyPeer, PeerId, PeerIndex, PeerNumber, ProtocolInfo, SessionId};
19use subtle::Choice;
20use typenum::{NonZero, Unsigned, B1};
21
22use crate::utils::IntoExactSizeIterator;
23
24// ---------- Typenum ---------- //
25
26/// A trait to represent positive nonzero unsigned integer typenum constants.
27pub trait Positive: Unsigned + NonZero + Debug + Eq + Send + Clone {
28    const SIZE: usize;
29}
30impl<T: Unsigned + NonZero + Debug + Eq + Send + Clone> Positive for T {
31    const SIZE: usize = <T as Unsigned>::USIZE;
32}
33
34/// A trait to represent nonnegative (zero or positive) unsigned integer typenum constants.
35pub trait NonNegative: Unsigned + Debug + Eq + Send + Clone {}
36impl<T: Unsigned + Debug + Eq + Send + Clone> NonNegative for T {}
37
38/// A trait to represent positive nonzero unsigned integer typenum constants which accept "+1"
39/// operation.
40pub trait PositivePlusOne: Positive + Add<B1, Output: Positive> {}
41impl<T: Positive + Add<B1, Output: Positive>> PositivePlusOne for T {}
42
43// ---------- Traits ---------- //
44
45/// A trait to define batching types that contain a compile-time known size of elements.
46pub trait Batched: Sized + IntoExactSizeIterator<Item = <Self as Batched>::Item> {
47    /// The type of elements contained in the batch.
48    type Item;
49
50    /// The size of the batch.
51    type Size: Positive;
52
53    /// Returns the size of the batch as a usize.
54    fn batch_size() -> usize {
55        Self::Size::SIZE
56    }
57}
58
59/// A type which can be conditionally selected with a loose promise of constant time.
60/// Compared to `subtle::ConditionallySelectable`, this trait does not require
61/// Copy, but the internal elements should be Copy for the promise to loosely hold.
62pub trait ConditionallySelectable: Sized {
63    /// Select `a` or `b` according to `choice`.
64    ///
65    /// # Returns
66    ///
67    /// * `a` if `choice == Choice(0)`;
68    /// * `b` if `choice == Choice(1)`.
69    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self;
70}
71
72impl<T: subtle::ConditionallySelectable> ConditionallySelectable for T {
73    #[inline]
74    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
75        <T as subtle::ConditionallySelectable>::conditional_select(a, b, choice)
76    }
77}
78
/// Extension trait for iterators of `Result<T, E>` that gathers *every* error instead of
/// stopping at the first one, as `collect::<Result<_, E>>()` does.
///
/// The error type must be constructible from a `Vec<E>` so that several errors can be packed
/// into a single value of the same type. A lone error is returned unchanged, not wrapped.
pub trait CollectAll<T, E: Error + From<Vec<E>>>: Iterator<Item = Result<T, E>> + Sized {
    /// Collects the successes into a `TC` container, or reports every error encountered.
    ///
    /// Unlike `collect::<Result<TC, E>>()`, the iterator is always consumed to the end so that
    /// all errors are seen.
    ///
    /// # Errors
    ///
    /// * Exactly one `Err` item: that error, unchanged.
    /// * Two or more `Err` items: `E::from(errors)`, with the errors in iteration order.
    ///
    /// Successes gathered so far are discarded whenever an error is returned.
    fn collect_all<TC: FromIterator<T> + Extend<T> + Default>(self) -> Result<TC, E> {
        let mut values = TC::default();
        let mut errors = Vec::new();
        // Route each item to its side; no short-circuiting on the first error.
        self.for_each(|item| match item {
            Ok(value) => values.extend(Some(value)),
            Err(error) => errors.push(error),
        });
        pack_errors(values, errors)
    }

    /// Same as [`CollectAll::collect_all`], with the successes collected into a `Vec<T>`.
    ///
    /// # Errors
    ///
    /// Identical to [`CollectAll::collect_all`].
    fn collect_all_vec(self) -> Result<Vec<T>, E> {
        self.collect_all::<Vec<T>>()
    }

    /// Consumes the whole iterator, discarding successes, and reports every error encountered.
    ///
    /// # Errors
    ///
    /// * Exactly one `Err` item: that error, unchanged.
    /// * Two or more `Err` items: `E::from(errors)`, with the errors in iteration order.
    fn collect_errors(self) -> Result<(), E> {
        let errors: Vec<E> = self.filter_map(Result::err).collect();
        pack_errors((), errors)
    }
}

impl<T, E: Error + From<Vec<E>>, I: Iterator<Item = Result<T, E>>> CollectAll<T, E> for I {}

/// Turns the errors gathered by [`CollectAll`] into a `Result`.
///
/// Returns `Ok(value)` when `errors` is empty, the lone error when there is exactly one, and
/// `E::from(errors)` (order preserved) otherwise.
fn pack_errors<V, E: From<Vec<E>>>(value: V, mut errors: Vec<E>) -> Result<V, E> {
    if errors.len() > 1 {
        return Err(E::from(errors));
    }
    // At most one error is left, so popping yields exactly that error (or nothing).
    match errors.pop() {
        None => Ok(value),
        Some(error) => Err(error),
    }
}
118
/// Extension trait that folds an iterator while gathering *every* error reported along the
/// way, instead of stopping at the first one.
pub trait TryFoldAll<T>: Iterator<Item = T> + Sized {
    /// Folds all elements into an accumulator, collecting the errors reported by `f`.
    ///
    /// For each element, `f` receives the current accumulator and the element, and returns the
    /// next accumulator together with an optional error. Folding always continues over the
    /// whole iterator, and the accumulator `f` returns is carried forward even when it also
    /// reports an error.
    ///
    /// # Errors
    ///
    /// * `f` reported exactly one error: that error, unchanged.
    /// * `f` reported two or more errors: `E::from(errors)`, in iteration order.
    ///
    /// The final accumulator is discarded whenever an error is returned.
    fn try_fold_all<Acc, E: Error + From<Vec<E>>, F>(self, init: Acc, mut f: F) -> Result<Acc, E>
    where
        F: FnMut(Acc, T) -> (Acc, Option<E>),
    {
        let (acc, mut errors) =
            self.fold((init, Vec::<E>::new()), |(acc, mut errors), element| {
                let (next_acc, maybe_error) = f(acc, element);
                if let Some(error) = maybe_error {
                    errors.push(error);
                }
                (next_acc, errors)
            });
        if errors.len() > 1 {
            return Err(E::from(errors));
        }
        // At most one error is left, so popping yields exactly that error (or nothing).
        match errors.pop() {
            None => Ok(acc),
            Some(error) => Err(error),
        }
    }
}

impl<T, I: Iterator<Item = T>> TryFoldAll<T> for I {}
142
#[cfg(test)]
mod tests {
    use std::{collections::BTreeSet, fmt};

    use super::*;

    /// Error type for the tests: a leaf variant plus the packed variant produced through
    /// `From<Vec<TestError>>`, which the error-packing helpers require.
    #[derive(Debug, PartialEq)]
    enum TestError {
        Single(u32),
        Multiple(Vec<TestError>),
    }

    impl fmt::Display for TestError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match self {
                TestError::Single(n) => write!(f, "error {n}"),
                TestError::Multiple(errs) => write!(f, "multiple errors: {}", errs.len()),
            }
        }
    }

    impl Error for TestError {}

    impl From<Vec<TestError>> for TestError {
        fn from(errors: Vec<TestError>) -> Self {
            TestError::Multiple(errors)
        }
    }

    // --- collect_all ---

    #[test]
    fn collect_all_all_ok() {
        let result: Result<Vec<_>, TestError> = vec![Ok(1), Ok(2), Ok(3)].into_iter().collect_all();
        assert_eq!(result, Ok(vec![1, 2, 3]));
    }

    #[test]
    fn collect_all_empty_iterator() {
        let result: Result<Vec<i32>, TestError> = std::iter::empty().collect_all();
        assert_eq!(result, Ok(vec![]));
    }

    #[test]
    fn collect_all_single_error() {
        let result: Result<Vec<_>, TestError> = vec![Ok(1), Err(TestError::Single(42)), Ok(3)]
            .into_iter()
            .collect_all();
        assert_eq!(result, Err(TestError::Single(42)));
    }

    #[test]
    fn collect_all_multiple_errors() {
        let result: Result<Vec<i32>, TestError> =
            vec![Err(TestError::Single(1)), Ok(2), Err(TestError::Single(3))]
                .into_iter()
                .collect_all();
        // Pin the packed contents and their order, not just the count.
        assert_eq!(
            result,
            Err(TestError::Multiple(vec![TestError::Single(1), TestError::Single(3)]))
        );
    }

    #[test]
    fn collect_all_into_non_vec_container() {
        let result: Result<BTreeSet<i32>, TestError> =
            vec![Ok(3), Ok(1), Ok(3)].into_iter().collect_all();
        assert_eq!(result, Ok(BTreeSet::from([1, 3])));
    }

    // --- collect_all_vec ---

    #[test]
    fn collect_all_vec_all_ok() {
        let result = vec![Ok::<i32, TestError>(4), Ok(5)].into_iter().collect_all_vec();
        assert_eq!(result, Ok(vec![4, 5]));
    }

    #[test]
    fn collect_all_vec_single_error() {
        let result = vec![Ok(4), Err(TestError::Single(9))].into_iter().collect_all_vec();
        assert_eq!(result, Err(TestError::Single(9)));
    }

    // --- collect_errors ---

    #[test]
    fn collect_errors_all_ok() {
        let result: Result<(), TestError> = vec![Ok::<i32, TestError>(1), Ok(2), Ok(3)]
            .into_iter()
            .collect_errors();
        assert_eq!(result, Ok(()));
    }

    #[test]
    fn collect_errors_empty_iterator() {
        let result: Result<(), TestError> =
            std::iter::empty::<Result<i32, TestError>>().collect_errors();
        assert_eq!(result, Ok(()));
    }

    #[test]
    fn collect_errors_single_error() {
        let result: Result<(), TestError> = vec![Ok(1), Err(TestError::Single(7)), Ok(3)]
            .into_iter()
            .collect_errors();
        assert_eq!(result, Err(TestError::Single(7)));
    }

    #[test]
    fn collect_errors_multiple_errors() {
        let result: Result<(), TestError> =
            vec![Err(TestError::Single(1)), Ok(2), Err(TestError::Single(3))]
                .into_iter()
                .collect_errors();
        assert_eq!(
            result,
            Err(TestError::Multiple(vec![TestError::Single(1), TestError::Single(3)]))
        );
    }

    // --- try_fold_all ---

    #[test]
    fn try_fold_all_all_ok() {
        let result: Result<i32, TestError> = vec![1, 2, 3]
            .into_iter()
            .try_fold_all(0, |acc, x| (acc + x, None));
        assert_eq!(result, Ok(6));
    }

    #[test]
    fn try_fold_all_empty_iterator() {
        let result: Result<i32, TestError> =
            std::iter::empty::<i32>().try_fold_all(42, |acc, x| (acc + x, None));
        assert_eq!(result, Ok(42));
    }

    #[test]
    fn try_fold_all_single_fn_error() {
        // f returns an error for value 2
        let result: Result<i32, TestError> = vec![1, 2, 3].into_iter().try_fold_all(0, |acc, x| {
            if x == 2 {
                (acc, Some(TestError::Single(x as u32)))
            } else {
                (acc + x, None)
            }
        });
        assert_eq!(result, Err(TestError::Single(2)));
    }

    #[test]
    fn try_fold_all_multiple_fn_errors() {
        // f returns errors for values 1 and 3
        let result: Result<i32, TestError> = vec![1, 2, 3].into_iter().try_fold_all(0, |acc, x| {
            if x == 1 || x == 3 {
                (acc, Some(TestError::Single(x as u32)))
            } else {
                (acc + x, None)
            }
        });
        assert_eq!(
            result,
            Err(TestError::Multiple(vec![TestError::Single(1), TestError::Single(3)]))
        );
    }

    #[test]
    fn try_fold_all_keeps_folding_after_an_error() {
        let mut seen = Vec::new();
        let result: Result<i32, TestError> = vec![1, 2, 3].into_iter().try_fold_all(0, |acc, x| {
            seen.push((acc, x));
            if x == 1 {
                (acc, Some(TestError::Single(1)))
            } else {
                (acc + x, None)
            }
        });
        assert_eq!(result, Err(TestError::Single(1)));
        // Every element is visited and the accumulator keeps evolving past the error.
        assert_eq!(seen, vec![(0, 1), (0, 2), (2, 3)]);
    }
}