commonware_utils/
lib.rs

1//! Leverage common functionality across multiple primitives.
2
3use bytes::{BufMut, BytesMut};
4use commonware_codec::{EncodeSize, Write};
5use std::time::Duration;
6
7pub mod sequence;
8pub use sequence::{Array, Span};
9mod bitvec;
10pub use bitvec::{BitIterator, BitVec};
11pub mod channels;
12mod time;
13pub use time::{parse_duration, SystemTimeExt};
14mod priority_set;
15pub use priority_set::PrioritySet;
16pub mod futures;
17mod stable_buf;
18pub use stable_buf::StableBuf;
19
/// Converts bytes to a lowercase hexadecimal string.
///
/// Every input byte becomes exactly two characters (`0-9`, `a-f`), so the output length is
/// always `2 * bytes.len()`.
pub fn hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    // Pre-size the output and emit both nibbles directly instead of allocating a temporary
    // `String` via `format!` for every byte.
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
28
/// Converts a hexadecimal string to bytes.
///
/// Each output byte must be spelled as exactly two hexadecimal digits (`0-9`, `a-f`, `A-F`).
/// Returns `None` if the input has an odd length or contains any other character. That
/// includes a leading `+`, which [`u8::from_str_radix`] would otherwise silently accept.
pub fn from_hex(hex: &str) -> Option<Vec<u8>> {
    /// Decodes a single ASCII hexadecimal digit into its 4-bit value.
    fn nibble(digit: u8) -> Option<u8> {
        match digit {
            b'0'..=b'9' => Some(digit - b'0'),
            b'a'..=b'f' => Some(digit - b'a' + 10),
            b'A'..=b'F' => Some(digit - b'A' + 10),
            _ => None,
        }
    }

    // Operate on raw bytes: a multi-byte UTF-8 character can never be split "mid-character"
    // here, and any non-ASCII byte simply fails to decode as a nibble.
    let pairs = hex.as_bytes().chunks_exact(2);
    if !pairs.remainder().is_empty() {
        return None;
    }
    pairs
        // Both nibbles are < 16, so `hi << 4 | lo` always fits in a `u8`.
        .map(|pair| Some((nibble(pair[0])? << 4) | nibble(pair[1])?))
        .collect()
}
40
41/// Converts a hexadecimal string to bytes, stripping whitespace and/or a `0x` prefix. Commonly used
42/// in testing to encode external test vectors without modification.
43pub fn from_hex_formatted(hex: &str) -> Option<Vec<u8>> {
44    let hex = hex.replace(['\t', '\n', '\r', ' '], "");
45    let res = hex.strip_prefix("0x").unwrap_or(&hex);
46    from_hex(res)
47}
48
/// Returns the largest number of faulty participants `f` that `n` participants can tolerate,
/// i.e. the greatest integer `f` with `n >= 3*f + 1`.
///
/// For `n == 0` there is no such participant set, and the result is `0`.
pub fn max_faults(n: u32) -> u32 {
    match n.checked_sub(1) {
        Some(rest) => rest / 3,
        None => 0,
    }
}
54
55/// Compute the quorum size for a given set of `n` participants. This is the minimum integer `q`
56/// such that `3*q >= 2*n + 1`. It is also equal to `n - f`, where `f` is the maximum number of
57/// faults.
58///
59/// # Panics
60///
61/// Panics if `n` is zero.
62pub fn quorum(n: u32) -> u32 {
63    assert!(n > 0, "n must not be zero");
64    n - max_faults(n)
65}
66
67/// Compute the quorum size for a given slice.
68///
69/// # Panics
70///
71/// Panics if the slice length is greater than [u32::MAX].
72pub fn quorum_from_slice<T>(slice: &[T]) -> u32 {
73    let n: u32 = slice
74        .len()
75        .try_into()
76        .expect("slice length must be less than u32::MAX");
77    quorum(n)
78}
79
/// Concatenates two byte slices into a freshly allocated vector (`a` followed by `b`).
pub fn union(a: &[u8], b: &[u8]) -> Vec<u8> {
    [a, b].concat()
}
87
88/// Concatenate a namespace and a message, prepended by a varint encoding of the namespace length.
89///
90/// This produces a unique byte sequence (i.e. no collisions) for each `(namespace, msg)` pair.
91pub fn union_unique(namespace: &[u8], msg: &[u8]) -> Vec<u8> {
92    let len_prefix = namespace.len();
93    let mut buf = BytesMut::with_capacity(len_prefix.encode_size() + namespace.len() + msg.len());
94    len_prefix.write(&mut buf);
95    BufMut::put_slice(&mut buf, namespace);
96    BufMut::put_slice(&mut buf, msg);
97    buf.into()
98}
99
/// Reduces `bytes`, read as an unsigned big-endian integer of any length, modulo `n`.
///
/// Intended for turning a random seed into an index in `0..n`, e.g. to pick an entry from an
/// array. An empty `bytes` slice is treated as the value `0`.
///
/// # Panics
///
/// Panics if `n` is zero.
pub fn modulo(bytes: &[u8], n: u64) -> u64 {
    assert_ne!(n, 0, "modulus must be non-zero");

    // Horner's scheme: fold in one byte at a time and reduce after every step. The running
    // remainder is always < n <= u64::MAX, so shifting it left by 8 bits stays below 2^72
    // and cannot overflow the u128 accumulator.
    let modulus = u128::from(n);
    let remainder = bytes
        .iter()
        .fold(0u128, |acc, &byte| ((acc << 8) | u128::from(byte)) % modulus);

    // `remainder < modulus <= u64::MAX`, so this narrowing cast is lossless.
    remainder as u64
}
120
/// Builds a [`std::num::NonZeroUsize`] from an expression.
///
/// Panics with `"value must be non-zero"` if the expression evaluates to zero.
#[macro_export]
macro_rules! NZUsize {
    ($val:expr) => {{
        // Evaluate the argument once, then validate it.
        let value = $val;
        std::num::NonZeroUsize::new(value).expect("value must be non-zero")
    }};
}
130
/// Builds a [`std::num::NonZeroU8`] from an expression.
///
/// Panics with `"value must be non-zero"` if the expression evaluates to zero.
#[macro_export]
macro_rules! NZU8 {
    ($val:expr) => {{
        // Evaluate the argument once, then validate it.
        let value = $val;
        std::num::NonZeroU8::new(value).expect("value must be non-zero")
    }};
}
138
/// Builds a [`std::num::NonZeroU16`] from an expression.
///
/// Panics with `"value must be non-zero"` if the expression evaluates to zero.
#[macro_export]
macro_rules! NZU16 {
    ($val:expr) => {{
        // Evaluate the argument once, then validate it.
        let value = $val;
        std::num::NonZeroU16::new(value).expect("value must be non-zero")
    }};
}
146
/// Builds a [`std::num::NonZeroU32`] from an expression.
///
/// Panics with `"value must be non-zero"` if the expression evaluates to zero.
#[macro_export]
macro_rules! NZU32 {
    ($val:expr) => {{
        // Evaluate the argument once, then validate it.
        let value = $val;
        std::num::NonZeroU32::new(value).expect("value must be non-zero")
    }};
}
156
/// Builds a [`std::num::NonZeroU64`] from an expression.
///
/// Panics with `"value must be non-zero"` if the expression evaluates to zero.
#[macro_export]
macro_rules! NZU64 {
    ($val:expr) => {{
        // Evaluate the argument once, then validate it.
        let value = $val;
        std::num::NonZeroU64::new(value).expect("value must be non-zero")
    }};
}
166
/// A [`Duration`] that is guaranteed to be strictly greater than zero.
///
/// Build one with [`NonZeroDuration::new`] (returns `None` on zero) or
/// [`NonZeroDuration::new_panic`] (panics on zero). Read the inner value back with
/// [`NonZeroDuration::get`] or `Duration::from`. Comparison, ordering and hashing all
/// delegate to the wrapped `Duration`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NonZeroDuration(Duration);

impl NonZeroDuration {
    /// Wraps `duration`, or returns `None` if it is zero.
    pub fn new(duration: Duration) -> Option<Self> {
        (!duration.is_zero()).then_some(Self(duration))
    }

    /// Wraps `duration`.
    ///
    /// # Panics
    ///
    /// Panics with `"duration must be non-zero"` if `duration` is zero.
    pub fn new_panic(duration: Duration) -> Self {
        match Self::new(duration) {
            Some(non_zero) => non_zero,
            None => panic!("duration must be non-zero"),
        }
    }

    /// Returns the wrapped [`Duration`], which is never zero.
    pub fn get(self) -> Duration {
        let Self(inner) = self;
        inner
    }
}

impl From<NonZeroDuration> for Duration {
    fn from(nz_duration: NonZeroDuration) -> Self {
        nz_duration.get()
    }
}
197
/// Builds a [`NonZeroDuration`] from a `Duration` expression.
///
/// Delegates to [`NonZeroDuration::new_panic`], so it panics with
/// `"duration must be non-zero"` if the expression evaluates to a zero duration.
#[macro_export]
macro_rules! NZDuration {
    ($val:expr) => {{
        // Evaluate the argument once, then validate it.
        let duration = $val;
        $crate::NonZeroDuration::new_panic(duration)
    }};
}
206
// Unit tests for the helpers in this module. `BigUint` serves as an independent reference
// implementation for `modulo`, and `StdRng` is seeded so that every run uses the same inputs.
#[cfg(test)]
mod tests {
    use super::*;
    use num_bigint::BigUint;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    #[test]
    fn test_hex() {
        // Test case 0: empty bytes
        let b = &[];
        let h = hex(b);
        assert_eq!(h, "");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 1: single byte
        let b = &[0x01];
        let h = hex(b);
        assert_eq!(h, "01");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 2: multiple bytes
        let b = &[0x01, 0x02, 0x03];
        let h = hex(b);
        assert_eq!(h, "010203");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 3: odd number of bytes
        let h = "0102030";
        assert!(from_hex(h).is_none());

        // Test case 4: invalid hexadecimal character
        let h = "01g3";
        assert!(from_hex(h).is_none());

        // Test case 5: encoding always uses lowercase digits
        assert_eq!(hex(&[0xab, 0xcd, 0xef]), "abcdef");

        // Test case 6: decoding accepts upper- and mixed-case digits
        assert_eq!(from_hex("ABcdEF").unwrap(), vec![0xab, 0xcd, 0xef]);
    }

    #[test]
    fn test_from_hex_formatted() {
        // Test case 0: empty bytes
        let b = &[];
        let h = hex(b);
        assert_eq!(h, "");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 1: single byte
        let b = &[0x01];
        let h = hex(b);
        assert_eq!(h, "01");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 2: multiple bytes
        let b = &[0x01, 0x02, 0x03];
        let h = hex(b);
        assert_eq!(h, "010203");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 3: odd number of bytes
        let h = "0102030";
        assert!(from_hex_formatted(h).is_none());

        // Test case 4: invalid hexadecimal character
        let h = "01g3";
        assert!(from_hex_formatted(h).is_none());

        // Test case 5: whitespace
        let h = "01 02 03";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());

        // Test case 6: 0x prefix
        let h = "0x010203";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());

        // Test case 7: 0x prefix + different whitespace chars
        let h = "    \n\n0x\r\n01
                            02\t03\n";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());
    }

    #[test]
    fn test_from_hex_utf8_char_boundaries() {
        const MISALIGNMENT_CASE: &str = "쀘\n";

        // Ensure that `from_hex` can handle misaligned UTF-8 character boundaries.
        let b = from_hex(MISALIGNMENT_CASE);
        assert!(b.is_none());
    }

    #[test]
    fn test_max_faults_zero() {
        assert_eq!(max_faults(0), 0);
    }

    #[test]
    #[should_panic]
    fn test_quorum_zero() {
        quorum(0);
    }

    #[test]
    fn test_quorum_and_max_faults() {
        // n, expected_f, expected_q
        let test_cases = [
            (1, 0, 1),
            (2, 0, 2),
            (3, 0, 3),
            (4, 1, 3),
            (5, 1, 4),
            (6, 1, 5),
            (7, 2, 5),
            (8, 2, 6),
            (9, 2, 7),
            (10, 3, 7),
            (11, 3, 8),
            (12, 3, 9),
            (13, 4, 9),
            (14, 4, 10),
            (15, 4, 11),
            (16, 5, 11),
            (17, 5, 12),
            (18, 5, 13),
            (19, 6, 13),
            (20, 6, 14),
            (21, 6, 15),
        ];

        for (n, ef, eq) in test_cases {
            assert_eq!(max_faults(n), ef);
            assert_eq!(quorum(n), eq);
            assert_eq!(n, ef + eq);
        }
    }

    #[test]
    fn test_quorum_from_slice() {
        // The slice length is used as `n`; element type and contents are irrelevant.
        assert_eq!(quorum_from_slice(&[(); 1]), 1);
        assert_eq!(quorum_from_slice(&[0u8; 4]), 3);
        assert_eq!(quorum_from_slice(&[0u64; 7]), 5);
        assert_eq!(quorum_from_slice(&[0u8; 21]), 15);
    }

    #[test]
    #[should_panic]
    fn test_quorum_from_slice_empty() {
        // An empty slice means `n == 0`, which `quorum` rejects.
        quorum_from_slice::<u8>(&[]);
    }

    #[test]
    fn test_union() {
        // Test case 0: empty slices
        assert_eq!(union(&[], &[]), []);

        // Test case 1: empty and non-empty slices
        assert_eq!(union(&[], &[0x01, 0x02, 0x03]), [0x01, 0x02, 0x03]);

        // Test case 2: non-empty and non-empty slices
        assert_eq!(
            union(&[0x01, 0x02, 0x03], &[0x04, 0x05, 0x06]),
            [0x01, 0x02, 0x03, 0x04, 0x05, 0x06]
        );
    }

    #[test]
    fn test_union_unique() {
        let namespace = b"namespace";
        let msg = b"message";

        let length_encoding = vec![0b0000_1001];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(namespace);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_union_unique_zero_length() {
        let namespace = b"";
        let msg = b"message";

        let length_encoding = vec![0];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_union_unique_long_length() {
        // Use a namespace of over length 127.
        let namespace = &b"n".repeat(256);
        let msg = b"message";

        let length_encoding = vec![0b1000_0000, 0b0000_0010];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(namespace);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_modulo() {
        // Test case 0: empty bytes
        assert_eq!(modulo(&[], 1), 0);

        // Test case 1: single byte
        assert_eq!(modulo(&[0x01], 1), 0);

        // Test case 2: multiple bytes
        assert_eq!(modulo(&[0x01, 0x02, 0x03], 10), 1);

        // Test case 3: check equivalence with BigUint
        for i in 0..100 {
            let mut rng = StdRng::seed_from_u64(i);
            let bytes: [u8; 32] = rng.gen();

            // 1-byte modulus
            let n = 11u64;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // 2-byte modulus
            let n = 11_111u64;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // 8-byte modulus
            let n = 0xDFFFFFFFFFFFFFFD;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));
        }
    }

    #[test]
    #[should_panic]
    fn test_modulo_zero_panics() {
        modulo(&[0x01, 0x02, 0x03], 0);
    }

    #[test]
    fn test_non_zero_macros() {
        // Test case 0: zero value
        assert!(std::panic::catch_unwind(|| NZUsize!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU8!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU16!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU32!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU64!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZDuration!(Duration::ZERO)).is_err());

        // Test case 1: non-zero value
        assert_eq!(NZUsize!(1).get(), 1);
        assert_eq!(NZU8!(4).get(), 4);
        assert_eq!(NZU16!(5).get(), 5);
        assert_eq!(NZU32!(2).get(), 2);
        assert_eq!(NZU64!(3).get(), 3);
        assert_eq!(
            NZDuration!(Duration::from_secs(1)).get(),
            Duration::from_secs(1)
        );
    }

    #[test]
    fn test_non_zero_duration() {
        // Test case 0: zero duration
        assert!(NonZeroDuration::new(Duration::ZERO).is_none());

        // Test case 1: non-zero duration
        let duration = Duration::from_millis(100);
        let nz_duration = NonZeroDuration::new(duration).unwrap();
        assert_eq!(nz_duration.get(), duration);
        assert_eq!(Duration::from(nz_duration), duration);

        // Test case 2: panic on zero
        assert!(std::panic::catch_unwind(|| NonZeroDuration::new_panic(Duration::ZERO)).is_err());

        // Test case 3: ordering
        let d1 = NonZeroDuration::new(Duration::from_millis(100)).unwrap();
        let d2 = NonZeroDuration::new(Duration::from_millis(200)).unwrap();
        assert!(d1 < d2);
    }
}
480}