Skip to main content

primitives/types/
mod.rs

1pub mod heap_array;
2pub mod identifiers;
3
4use std::{error::Error, fmt::Debug, ops::Add};
5
6pub use heap_array::{HeapArray, HeapMatrix, RowMajorHeapMatrix};
7pub use identifiers::{FaultyPeer, PeerId, PeerIndex, PeerNumber, ProtocolInfo, SessionId};
8use subtle::Choice;
9use typenum::{NonZero, Unsigned, B1};
10
11use crate::utils::IntoExactSizeIterator;
12
13// ---------- Typenum ---------- //
14
15/// A trait to represent positive nonzero unsigned integer typenum constants.
16pub trait Positive: Unsigned + NonZero + Debug + Eq + Send + Clone {
17    const SIZE: usize;
18}
19impl<T: Unsigned + NonZero + Debug + Eq + Send + Clone> Positive for T {
20    const SIZE: usize = <T as Unsigned>::USIZE;
21}
22
23/// A trait to represent nonnegative (zero or positive) unsigned integer typenum constants.
24pub trait NonNegative: Unsigned + Debug + Eq + Send + Clone {}
25impl<T: Unsigned + Debug + Eq + Send + Clone> NonNegative for T {}
26
27/// A trait to represent positive nonzero unsigned integer typenum constants which accept "+1"
28/// operation.
29pub trait PositivePlusOne: Positive + Add<B1, Output: Positive> {}
30impl<T: Positive + Add<B1, Output: Positive>> PositivePlusOne for T {}
31
32// ---------- Traits ---------- //
33
/// Marker for scalar element types that take part in broadcast arithmetic on [`HeapArray`].
///
/// `HeapArray` deliberately does **not** implement this trait. That is what lets the
/// element-wise (`HeapArray * HeapArray`) and broadcast (`HeapArray * scalar`) operator impls
/// coexist without coherence conflicts.
///
/// The blanket impl covers every `Copy` type. `HeapArray` owns a `Box<[T]>` and so can never
/// be `Copy`, which guarantees it never becomes an `Element`.
pub trait Element {}
impl<T> Element for T where T: Copy {}
44
45/// A trait to define batching types that contain a compile-time known number of elements.
46pub trait Batched: Sized + IntoExactSizeIterator<Item: Send> + FromIterator<Self::Item> {
47    /// The size of the batch.
48    type Size: Positive;
49
50    /// Returns the size of the batch as a usize.
51    fn batch_size() -> usize {
52        Self::Size::SIZE
53    }
54}
55
56/// A type which can be conditionally selected with a loose promise of constant time.
57/// Compared to `subtle::ConditionallySelectable`, this trait does not require
58/// Copy, but the internal elements should be Copy for the promise to loosely hold.
59pub trait ConditionallySelectable: Sized {
60    /// Select `a` or `b` according to `choice`.
61    ///
62    /// # Returns
63    ///
64    /// * `a` if `choice == Choice(0)`;
65    /// * `b` if `choice == Choice(1)`.
66    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self;
67}
68
69impl<T: subtle::ConditionallySelectable> ConditionallySelectable for T {
70    #[inline]
71    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
72        <T as subtle::ConditionallySelectable>::conditional_select(a, b, choice)
73    }
74}
75
76/// Collect an iterator of Results<T, E> into a container of T and a vector of E,
77/// then pack all the errors in a single error of the same type. Requires that E can be
78/// constructed from a Vec<E>.
79pub trait CollectAll<T, E: Error + From<Vec<E>>>: Iterator<Item = Result<T, E>> + Sized {
80    /// Collect the results into a container of successes, or pack all errors.
81    /// Contrary to `collect<Result<Vec<T>, E>>`, this method collects all errors
82    /// instead of stopping at the first one.
83    fn collect_all<TC: FromIterator<T> + Extend<T> + Default>(self) -> Result<TC, E> {
84        let (values, errors): (TC, Vec<E>) =
85            itertools::Itertools::partition_map(self, |v| match v {
86                Ok(v) => itertools::Either::Left(v),
87                Err(e) => itertools::Either::Right(e),
88            });
89        match errors.len() {
90            0 => Ok(values),
91            1 => Err(errors.into_iter().next().unwrap()),
92            _ => Err(E::from(errors)),
93        }
94    }
95
96    /// Collect the results into a container of successes, or pack all errors.
97    /// Contrary to `collect<Result<Vec<T>, E>>`, this method collects all errors
98    /// instead of stopping at the first one.
99    fn collect_all_vec(self) -> Result<Vec<T>, E> {
100        self.collect_all::<Vec<T>>()
101    }
102
103    /// Collect the results, returning Ok(()) if all succeeded, or packing all errors
104    fn collect_errors(self) -> Result<(), E> {
105        let errors: Vec<E> = self.filter_map(Result::err).collect();
106        match errors.len() {
107            0 => Ok(()),
108            1 => Err(errors.into_iter().next().unwrap()),
109            _ => Err(E::from(errors)),
110        }
111    }
112}
113
114impl<T, E: Error + From<Vec<E>>, I: Iterator<Item = Result<T, E>>> CollectAll<T, E> for I {}
115
/// Extension method for folding an iterator while collecting every error reported along the
/// way, instead of stopping at the first one.
///
/// Blanket-implemented for every iterator.
pub trait TryFoldAll<T>: Iterator<Item = T> + Sized {
    /// Folds every element into an accumulator and collects the errors reported by `f`.
    ///
    /// For each element, `f` receives the current accumulator and the element. It returns the
    /// next accumulator together with an optional error. Folding always visits every
    /// element, even after `f` has reported an error.
    ///
    /// # Errors
    ///
    /// If `f` reported exactly one error, that error is returned unchanged. If it reported
    /// several, they are packed in iteration order via `E::from(Vec<E>)`. In both cases the
    /// final accumulator is discarded.
    fn try_fold_all<Acc, E: Error + From<Vec<E>>, F>(self, init: Acc, mut f: F) -> Result<Acc, E>
    where
        F: FnMut(Acc, T) -> (Acc, Option<E>),
    {
        // `fold` (internal iteration) lets iterator adapters use their specialized fast paths.
        let (acc, mut errors) = self.fold((init, Vec::new()), |(acc, mut errors), element| {
            let (next_acc, maybe_error) = f(acc, element);
            if let Some(error) = maybe_error {
                errors.push(error);
            }
            (next_acc, errors)
        });
        match errors.len() {
            0 => Ok(acc),
            1 => Err(errors.pop().expect("exactly one error is present")),
            _ => Err(E::from(errors)),
        }
    }
}

impl<T, I: Iterator<Item = T>> TryFoldAll<T> for I {}
139
#[cfg(test)]
mod tests {
    use std::fmt;

    use super::*;

    /// Error type used to observe how `CollectAll` / `TryFoldAll` pack errors.
    #[derive(Debug, PartialEq)]
    enum TestError {
        Single(u32),
        Multiple(Vec<TestError>),
    }

    impl fmt::Display for TestError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            match self {
                TestError::Single(n) => write!(f, "error {n}"),
                TestError::Multiple(errs) => write!(f, "multiple errors: {}", errs.len()),
            }
        }
    }

    impl Error for TestError {}

    impl From<Vec<TestError>> for TestError {
        fn from(errors: Vec<TestError>) -> Self {
            TestError::Multiple(errors)
        }
    }

    // --- collect_all ---

    #[test]
    fn collect_all_all_ok() {
        let result: Result<Vec<_>, TestError> = vec![Ok(1), Ok(2), Ok(3)].into_iter().collect_all();
        assert_eq!(result, Ok(vec![1, 2, 3]));
    }

    #[test]
    fn collect_all_empty_iterator() {
        let result: Result<Vec<i32>, TestError> = std::iter::empty().collect_all();
        assert_eq!(result, Ok(vec![]));
    }

    #[test]
    fn collect_all_single_error() {
        let result: Result<Vec<_>, TestError> = vec![Ok(1), Err(TestError::Single(42)), Ok(3)]
            .into_iter()
            .collect_all();
        assert_eq!(result, Err(TestError::Single(42)));
    }

    #[test]
    fn collect_all_multiple_errors() {
        let result: Result<Vec<i32>, TestError> =
            vec![Err(TestError::Single(1)), Ok(2), Err(TestError::Single(3))]
                .into_iter()
                .collect_all();
        // Every error is kept, in iteration order.
        assert_eq!(
            result,
            Err(TestError::Multiple(vec![TestError::Single(1), TestError::Single(3)]))
        );
    }

    // --- collect_all_vec ---

    #[test]
    fn collect_all_vec_all_ok() {
        let result = vec![Ok::<i32, TestError>(1), Ok(2)].into_iter().collect_all_vec();
        assert_eq!(result, Ok(vec![1, 2]));
    }

    #[test]
    fn collect_all_vec_multiple_errors() {
        let result = vec![Err::<i32, TestError>(TestError::Single(5)), Ok(2), Err(TestError::Single(6))]
            .into_iter()
            .collect_all_vec();
        assert_eq!(
            result,
            Err(TestError::Multiple(vec![TestError::Single(5), TestError::Single(6)]))
        );
    }

    // --- collect_errors ---

    #[test]
    fn collect_errors_all_ok() {
        let result: Result<(), TestError> = vec![Ok::<i32, TestError>(1), Ok(2), Ok(3)]
            .into_iter()
            .collect_errors();
        assert_eq!(result, Ok(()));
    }

    #[test]
    fn collect_errors_empty_iterator() {
        let result: Result<(), TestError> =
            std::iter::empty::<Result<i32, TestError>>().collect_errors();
        assert_eq!(result, Ok(()));
    }

    #[test]
    fn collect_errors_single_error() {
        let result: Result<(), TestError> = vec![Ok(1), Err(TestError::Single(7)), Ok(3)]
            .into_iter()
            .collect_errors();
        assert_eq!(result, Err(TestError::Single(7)));
    }

    #[test]
    fn collect_errors_multiple_errors() {
        let result: Result<(), TestError> =
            vec![Err(TestError::Single(1)), Ok(2), Err(TestError::Single(3))]
                .into_iter()
                .collect_errors();
        assert_eq!(
            result,
            Err(TestError::Multiple(vec![TestError::Single(1), TestError::Single(3)]))
        );
    }

    // --- try_fold_all ---

    #[test]
    fn try_fold_all_all_ok() {
        let result: Result<i32, TestError> = vec![1, 2, 3]
            .into_iter()
            .try_fold_all(0, |acc, x| (acc + x, None));
        assert_eq!(result, Ok(6));
    }

    #[test]
    fn try_fold_all_empty_iterator() {
        let result: Result<i32, TestError> =
            std::iter::empty::<i32>().try_fold_all(42, |acc, x| (acc + x, None));
        assert_eq!(result, Ok(42));
    }

    #[test]
    fn try_fold_all_single_fn_error() {
        // f returns an error for value 2
        let result: Result<i32, TestError> = vec![1, 2, 3].into_iter().try_fold_all(0, |acc, x| {
            if x == 2 {
                (acc, Some(TestError::Single(x as u32)))
            } else {
                (acc + x, None)
            }
        });
        assert_eq!(result, Err(TestError::Single(2)));
    }

    #[test]
    fn try_fold_all_multiple_fn_errors() {
        // f returns errors for values 1 and 3
        let result: Result<i32, TestError> = vec![1, 2, 3].into_iter().try_fold_all(0, |acc, x| {
            if x == 1 || x == 3 {
                (acc, Some(TestError::Single(x as u32)))
            } else {
                (acc + x, None)
            }
        });
        assert_eq!(
            result,
            Err(TestError::Multiple(vec![TestError::Single(1), TestError::Single(3)]))
        );
    }

    #[test]
    fn try_fold_all_keeps_folding_after_error() {
        // The first element fails, yet f must still be invoked for every remaining element.
        let mut calls = 0;
        let result: Result<i32, TestError> = vec![1, 2, 3].into_iter().try_fold_all(0, |acc, x| {
            calls += 1;
            if x == 1 {
                (acc, Some(TestError::Single(1)))
            } else {
                (acc + x, None)
            }
        });
        assert_eq!(result, Err(TestError::Single(1)));
        assert_eq!(calls, 3);
    }
}