commonware_utils/
lib.rs

1//! Leverage common functionality across multiple primitives.
2
3use bytes::{BufMut, BytesMut};
4use commonware_codec::{EncodeSize, Write};
5use std::time::Duration;
6
7pub mod array;
8pub use array::Array;
9mod bitvec;
10pub use bitvec::{BitIterator, BitVec};
11mod time;
12pub use time::SystemTimeExt;
13mod priority_set;
14pub use priority_set::PrioritySet;
15pub mod futures;
16mod stable_buf;
17pub use stable_buf::StableBuf;
18
/// Converts bytes to a lowercase hexadecimal string.
///
/// Each input byte becomes exactly two characters from `0-9a-f`, so the output length is
/// `2 * bytes.len()`. [`from_hex`] performs the inverse conversion.
pub fn hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    // Reserve the exact output size once and push digits directly, instead of allocating a
    // temporary `String` per byte via `format!`.
    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
27
/// Converts a hexadecimal string to bytes.
///
/// Both upper- and lowercase digits are accepted. Returns `None` if the input has an odd
/// length (in bytes) or contains anything other than ASCII hex digits, such as a `+` sign,
/// whitespace, or non-ASCII characters.
pub fn from_hex(hex: &str) -> Option<Vec<u8>> {
    let digits = hex.as_bytes();
    if digits.len() % 2 != 0 {
        return None;
    }

    // Decode each pair of bytes directly instead of using `u8::from_str_radix`, which accepts
    // a leading `+` sign (so "+f" would wrongly decode to 0x0f). Working on raw bytes also
    // avoids UTF-8 char-boundary slicing: non-ASCII bytes simply fail `to_digit`.
    digits
        .chunks_exact(2)
        .map(|pair| {
            let high = char::from(pair[0]).to_digit(16)?;
            let low = char::from(pair[1]).to_digit(16)?;
            // Both nibbles are < 16, so the combined value is <= 0xff and the cast is lossless.
            Some(((high << 4) | low) as u8)
        })
        .collect()
}
39
40/// Converts a hexadecimal string to bytes, stripping whitespace and/or a `0x` prefix. Commonly used
41/// in testing to encode external test vectors without modification.
42pub fn from_hex_formatted(hex: &str) -> Option<Vec<u8>> {
43    let hex = hex.replace(['\t', '\n', '\r', ' '], "");
44    let res = hex.strip_prefix("0x").unwrap_or(&hex);
45    from_hex(res)
46}
47
/// Returns the largest number of faulty participants `f` tolerable among `n` participants,
/// i.e. the greatest integer `f` with `n >= 3*f + 1`. Yields zero for `n` in `0..=3`.
pub fn max_faults(n: u32) -> u32 {
    match n {
        0 => 0,
        participants => (participants - 1) / 3,
    }
}
53
54/// Compute the quorum size for a given set of `n` participants. This is the minimum integer `q`
55/// such that `3*q >= 2*n + 1`. It is also equal to `n - f`, where `f` is the maximum number of
56/// faults.
57///
58/// # Panics
59///
60/// Panics if `n` is zero.
61pub fn quorum(n: u32) -> u32 {
62    assert!(n > 0, "n must not be zero");
63    n - max_faults(n)
64}
65
/// Returns a freshly allocated vector holding the bytes of `a` followed by the bytes of `b`.
///
/// Despite the name, this is plain concatenation (not a set union): order and duplicates
/// are preserved.
pub fn union(a: &[u8], b: &[u8]) -> Vec<u8> {
    [a, b].concat()
}
73
74/// Concatenate a namespace and a message, prepended by a varint encoding of the namespace length.
75///
76/// This produces a unique byte sequence (i.e. no collisions) for each `(namespace, msg)` pair.
77pub fn union_unique(namespace: &[u8], msg: &[u8]) -> Vec<u8> {
78    let len_prefix = namespace.len();
79    let mut buf = BytesMut::with_capacity(len_prefix.encode_size() + namespace.len() + msg.len());
80    len_prefix.write(&mut buf);
81    BufMut::put_slice(&mut buf, namespace);
82    BufMut::put_slice(&mut buf, msg);
83    buf.into()
84}
85
/// Interprets `bytes` as a big-endian unsigned integer of any length and returns it
/// reduced modulo `n`.
///
/// Intended for turning a random seed into an index in `0..n`, e.g. to pick an array entry.
///
/// # Panics
///
/// Panics if `n` is zero.
pub fn modulo(bytes: &[u8], n: u64) -> u64 {
    assert_ne!(n, 0, "modulus must be non-zero");

    let modulus = u128::from(n);
    // The running remainder is always < n <= u64::MAX, so `acc << 8` stays below 2^72
    // and can never overflow a u128.
    let remainder = bytes
        .iter()
        .fold(0u128, |acc, &byte| ((acc << 8) | u128::from(byte)) % modulus);

    // `remainder < modulus <= u64::MAX`, so this narrowing cast never truncates.
    remainder as u64
}
106
/// Creates a [`NonZeroUsize`](core::num::NonZeroUsize) from a `usize` expression.
///
/// # Panics
///
/// Panics at runtime with "value must be non-zero" if the value is zero.
#[macro_export]
macro_rules! NZUsize {
    ($val:expr) => {
        // Absolute `::core` path: the expansion must resolve in the caller's crate, which may be
        // `no_std` or may define its own item named `std`.
        ::core::num::NonZeroUsize::new($val).expect("value must be non-zero")
    };
}
116
/// Creates a [`NonZeroU8`](core::num::NonZeroU8) from a `u8` expression.
///
/// # Panics
///
/// Panics at runtime with "value must be non-zero" if the value is zero.
#[macro_export]
macro_rules! NZU8 {
    ($val:expr) => {
        // Absolute `::core` path: the expansion must resolve in the caller's crate, which may be
        // `no_std` or may define its own item named `std`.
        ::core::num::NonZeroU8::new($val).expect("value must be non-zero")
    };
}
124
/// Creates a [`NonZeroU16`](core::num::NonZeroU16) from a `u16` expression.
///
/// # Panics
///
/// Panics at runtime with "value must be non-zero" if the value is zero.
#[macro_export]
macro_rules! NZU16 {
    ($val:expr) => {
        // Absolute `::core` path: the expansion must resolve in the caller's crate, which may be
        // `no_std` or may define its own item named `std`.
        ::core::num::NonZeroU16::new($val).expect("value must be non-zero")
    };
}
132
/// Creates a [`NonZeroU32`](core::num::NonZeroU32) from a `u32` expression.
///
/// # Panics
///
/// Panics at runtime with "value must be non-zero" if the value is zero.
#[macro_export]
macro_rules! NZU32 {
    ($val:expr) => {
        // Absolute `::core` path: the expansion must resolve in the caller's crate, which may be
        // `no_std` or may define its own item named `std`.
        ::core::num::NonZeroU32::new($val).expect("value must be non-zero")
    };
}
142
/// Creates a [`NonZeroU64`](core::num::NonZeroU64) from a `u64` expression.
///
/// # Panics
///
/// Panics at runtime with "value must be non-zero" if the value is zero.
#[macro_export]
macro_rules! NZU64 {
    ($val:expr) => {
        // Absolute `::core` path: the expansion must resolve in the caller's crate, which may be
        // `no_std` or may define its own item named `std`.
        ::core::num::NonZeroU64::new($val).expect("value must be non-zero")
    };
}
152
/// A [`Duration`] that is guaranteed to be non-zero.
///
/// Build one with [`NonZeroDuration::new`] (returns `None` for zero) or
/// [`NonZeroDuration::new_panic`]. Read it back with [`NonZeroDuration::get`] or
/// `Duration::from`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NonZeroDuration(Duration);

impl NonZeroDuration {
    /// Creates a `NonZeroDuration`, returning `None` if `duration` is zero.
    ///
    /// This is a `const fn`, so it can be used to define constants.
    pub const fn new(duration: Duration) -> Option<Self> {
        if duration.is_zero() {
            None
        } else {
            Some(Self(duration))
        }
    }

    /// Creates a `NonZeroDuration` from `duration`.
    ///
    /// # Panics
    ///
    /// Panics with "duration must be non-zero" if `duration` is zero.
    pub fn new_panic(duration: Duration) -> Self {
        Self::new(duration).expect("duration must be non-zero")
    }

    /// Returns the wrapped `Duration`, which is never zero.
    pub const fn get(self) -> Duration {
        self.0
    }
}

impl From<NonZeroDuration> for Duration {
    fn from(nz_duration: NonZeroDuration) -> Self {
        nz_duration.0
    }
}
183
/// Builds a [`NonZeroDuration`] from a `Duration` expression.
///
/// # Panics
///
/// Panics at runtime if the duration is zero (delegates to
/// [`NonZeroDuration::new_panic`]).
#[macro_export]
macro_rules! NZDuration {
    ($duration:expr) => {
        // `$crate` pins the path to this crate no matter where the macro is invoked.
        $crate::NonZeroDuration::new_panic($duration)
    };
}
192
#[cfg(test)]
mod tests {
    use super::*;
    use num_bigint::BigUint;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    #[test]
    fn test_hex() {
        // Test case 0: empty bytes
        let b = &[];
        let h = hex(b);
        assert_eq!(h, "");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 1: single byte
        let b = &[0x01];
        let h = hex(b);
        assert_eq!(h, "01");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 2: multiple bytes
        let b = &[0x01, 0x02, 0x03];
        let h = hex(b);
        assert_eq!(h, "010203");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 3: odd number of bytes
        let h = "0102030";
        assert!(from_hex(h).is_none());

        // Test case 4: invalid hexadecimal character
        let h = "01g3";
        assert!(from_hex(h).is_none());

        // Test case 5: uppercase digits are accepted
        assert_eq!(from_hex("0A0b").unwrap(), vec![0x0a, 0x0b]);
    }

    #[test]
    fn test_hex_roundtrip_all_bytes() {
        // Every byte value must survive hex -> from_hex unchanged.
        for b in 0..=u8::MAX {
            let h = hex(&[b]);
            assert_eq!(h.len(), 2);
            assert_eq!(from_hex(&h).unwrap(), vec![b]);
        }
    }

    #[test]
    fn test_from_hex_formatted() {
        // Test case 0: empty bytes
        let b = &[];
        let h = hex(b);
        assert_eq!(h, "");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 1: single byte
        let b = &[0x01];
        let h = hex(b);
        assert_eq!(h, "01");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 2: multiple bytes
        let b = &[0x01, 0x02, 0x03];
        let h = hex(b);
        assert_eq!(h, "010203");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 3: odd number of bytes
        let h = "0102030";
        assert!(from_hex_formatted(h).is_none());

        // Test case 4: invalid hexadecimal character
        let h = "01g3";
        assert!(from_hex_formatted(h).is_none());

        // Test case 5: whitespace
        let h = "01 02 03";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());

        // Test case 6: 0x prefix
        let h = "0x010203";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());

        // Test case 7: 0x prefix + different whitespace chars
        let h = "    \n\n0x\r\n01
                            02\t03\n";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());
    }

    #[test]
    fn test_from_hex_utf8_char_boundaries() {
        const MISALIGNMENT_CASE: &str = "쀘\n";

        // Ensure that `from_hex` can handle misaligned UTF-8 character boundaries.
        let b = from_hex(MISALIGNMENT_CASE);
        assert!(b.is_none());
    }

    #[test]
    fn test_max_faults_zero() {
        assert_eq!(max_faults(0), 0);
    }

    #[test]
    #[should_panic]
    fn test_quorum_zero() {
        quorum(0);
    }

    #[test]
    fn test_quorum_and_max_faults() {
        // n, expected_f, expected_q
        let test_cases = [
            (1, 0, 1),
            (2, 0, 2),
            (3, 0, 3),
            (4, 1, 3),
            (5, 1, 4),
            (6, 1, 5),
            (7, 2, 5),
            (8, 2, 6),
            (9, 2, 7),
            (10, 3, 7),
            (11, 3, 8),
            (12, 3, 9),
            (13, 4, 9),
            (14, 4, 10),
            (15, 4, 11),
            (16, 5, 11),
            (17, 5, 12),
            (18, 5, 13),
            (19, 6, 13),
            (20, 6, 14),
            (21, 6, 15),
        ];

        for (n, ef, eq) in test_cases {
            assert_eq!(max_faults(n), ef);
            assert_eq!(quorum(n), eq);
            assert_eq!(n, ef + eq);
        }
    }

    #[test]
    fn test_union() {
        // Test case 0: empty slices
        assert_eq!(union(&[], &[]), []);

        // Test case 1: empty and non-empty slices
        assert_eq!(union(&[], &[0x01, 0x02, 0x03]), [0x01, 0x02, 0x03]);

        // Test case 2: non-empty and non-empty slices
        assert_eq!(
            union(&[0x01, 0x02, 0x03], &[0x04, 0x05, 0x06]),
            [0x01, 0x02, 0x03, 0x04, 0x05, 0x06]
        );
    }

    #[test]
    fn test_union_unique() {
        let namespace = b"namespace";
        let msg = b"message";

        let length_encoding = vec![0b0000_1001];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(namespace);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_union_unique_zero_length() {
        let namespace = b"";
        let msg = b"message";

        let length_encoding = vec![0];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_union_unique_long_length() {
        // Use a namespace longer than 127 bytes so the varint prefix needs two bytes.
        let namespace = &b"n".repeat(256);
        let msg = b"message";

        let length_encoding = vec![0b1000_0000, 0b0000_0010];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(namespace);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_modulo() {
        // Test case 0: empty bytes
        assert_eq!(modulo(&[], 1), 0);

        // Test case 1: single byte
        assert_eq!(modulo(&[0x01], 1), 0);

        // Test case 2: multiple bytes
        assert_eq!(modulo(&[0x01, 0x02, 0x03], 10), 1);

        // Test case 3: check equivalence with BigUint
        for i in 0..100 {
            let mut rng = StdRng::seed_from_u64(i);
            let bytes: [u8; 32] = rng.gen();

            // 1-byte modulus
            let n = 11u64;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // 2-byte modulus
            let n = 11_111u64;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // 8-byte modulus
            let n = 0xDFFFFFFFFFFFFFFD;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));
        }
    }

    #[test]
    #[should_panic]
    fn test_modulo_zero_panics() {
        modulo(&[0x01, 0x02, 0x03], 0);
    }

    #[test]
    fn test_non_zero_macros() {
        // Test case 0: zero value
        assert!(std::panic::catch_unwind(|| NZUsize!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU8!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU16!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU32!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU64!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZDuration!(Duration::ZERO)).is_err());

        // Test case 1: non-zero value
        assert_eq!(NZUsize!(1).get(), 1);
        assert_eq!(NZU8!(4).get(), 4);
        assert_eq!(NZU16!(5).get(), 5);
        assert_eq!(NZU32!(2).get(), 2);
        assert_eq!(NZU64!(3).get(), 3);
        assert_eq!(
            NZDuration!(Duration::from_secs(1)).get(),
            Duration::from_secs(1)
        );

        // Test case 2: maximum values
        assert_eq!(NZU8!(u8::MAX).get(), u8::MAX);
        assert_eq!(NZU16!(u16::MAX).get(), u16::MAX);
    }

    #[test]
    fn test_non_zero_duration() {
        // Test case 0: zero duration
        assert!(NonZeroDuration::new(Duration::ZERO).is_none());

        // Test case 1: non-zero duration
        let duration = Duration::from_millis(100);
        let nz_duration = NonZeroDuration::new(duration).unwrap();
        assert_eq!(nz_duration.get(), duration);
        assert_eq!(Duration::from(nz_duration), duration);

        // Test case 2: panic on zero
        assert!(std::panic::catch_unwind(|| NonZeroDuration::new_panic(Duration::ZERO)).is_err());

        // Test case 3: ordering
        let d1 = NonZeroDuration::new(Duration::from_millis(100)).unwrap();
        let d2 = NonZeroDuration::new(Duration::from_millis(200)).unwrap();
        assert!(d1 < d2);
    }
}
466}