commonware_utils/
lib.rs

1//! Leverage common functionality across multiple primitives.
2
3#![doc(
4    html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
5    html_favicon_url = "https://commonware.xyz/favicon.ico"
6)]
7#![cfg_attr(not(feature = "std"), no_std)]
8
9#[cfg(not(feature = "std"))]
10extern crate alloc;
11
12#[cfg(not(feature = "std"))]
13use alloc::{string::String, vec::Vec};
14use bytes::{BufMut, BytesMut};
15use commonware_codec::{EncodeSize, Write};
16use core::{fmt::Write as FmtWrite, time::Duration};
17
18pub mod sequence;
19pub use sequence::{Array, Span};
20mod bitvec;
21pub use bitvec::{BitIterator, BitVec};
22#[cfg(feature = "std")]
23pub mod channels;
24#[cfg(feature = "std")]
25mod time;
26#[cfg(feature = "std")]
27pub use time::{parse_duration, SystemTimeExt};
28#[cfg(feature = "std")]
29mod priority_set;
30#[cfg(feature = "std")]
31pub use priority_set::PrioritySet;
32#[cfg(feature = "std")]
33pub mod futures;
34mod stable_buf;
35pub use stable_buf::StableBuf;
36
/// Converts bytes to a lowercase hexadecimal string (two digits per byte, no `0x` prefix).
pub fn hex(bytes: &[u8]) -> String {
    // Every byte renders as exactly two characters, so reserve the final size up front
    // instead of letting the string grow (and reallocate) while writing.
    let mut hex = String::with_capacity(bytes.len() * 2);
    for byte in bytes {
        write!(hex, "{byte:02x}").expect("writing to string should never fail");
    }
    hex
}
45
/// Converts a hexadecimal string to bytes.
///
/// Both lowercase and uppercase digits are accepted and an empty string decodes to an empty
/// vector. Returns `None` if `hex` has an odd number of bytes or contains anything other than
/// ASCII hex digits (e.g. a `+` sign, whitespace, a `0x` prefix, or any non-ASCII character).
pub fn from_hex(hex: &str) -> Option<Vec<u8>> {
    /// Decodes a single ASCII hex digit into its 4-bit value.
    fn nibble(digit: u8) -> Option<u8> {
        match digit {
            b'0'..=b'9' => Some(digit - b'0'),
            b'a'..=b'f' => Some(digit - b'a' + 10),
            b'A'..=b'F' => Some(digit - b'A' + 10),
            _ => None,
        }
    }

    let pairs = hex.as_bytes().chunks_exact(2);
    // A leftover byte means the input has an odd length.
    if !pairs.remainder().is_empty() {
        return None;
    }

    // Decode raw bytes rather than `&str` sub-slices: every byte of a multi-byte UTF-8
    // character is >= 0x80 and is rejected by `nibble`, so no char-boundary handling is needed.
    // Decoding digits by hand (instead of `u8::from_str_radix`) also rejects a leading `+`,
    // which `from_str_radix` would accept (e.g. "+f" -> 0x0f).
    pairs
        .map(|pair| Some((nibble(pair[0])? << 4) | nibble(pair[1])?))
        .collect()
}
57
58/// Converts a hexadecimal string to bytes, stripping whitespace and/or a `0x` prefix. Commonly used
59/// in testing to encode external test vectors without modification.
60pub fn from_hex_formatted(hex: &str) -> Option<Vec<u8>> {
61    let hex = hex.replace(['\t', '\n', '\r', ' '], "");
62    let res = hex.strip_prefix("0x").unwrap_or(&hex);
63    from_hex(res)
64}
65
/// Returns the largest number of faulty participants `f` that a set of `n` participants can
/// tolerate, i.e. the greatest integer `f` satisfying `n >= 3*f + 1`.
///
/// The result may be zero; in particular `max_faults(0) == 0`.
pub fn max_faults(n: u32) -> u32 {
    match n.checked_sub(1) {
        Some(below_n) => below_n / 3,
        None => 0,
    }
}
71
72/// Compute the quorum size for a given set of `n` participants. This is the minimum integer `q`
73/// such that `3*q >= 2*n + 1`. It is also equal to `n - f`, where `f` is the maximum number of
74/// faults.
75///
76/// # Panics
77///
78/// Panics if `n` is zero.
79pub fn quorum(n: u32) -> u32 {
80    assert!(n > 0, "n must not be zero");
81    n - max_faults(n)
82}
83
84/// Compute the quorum size for a given slice.
85///
86/// # Panics
87///
88/// Panics if the slice length is greater than [u32::MAX].
89pub fn quorum_from_slice<T>(slice: &[T]) -> u32 {
90    let n: u32 = slice
91        .len()
92        .try_into()
93        .expect("slice length must be less than u32::MAX");
94    quorum(n)
95}
96
/// Returns a new vector holding the bytes of `a` followed by the bytes of `b`.
///
/// Despite the name, this is plain concatenation (order is preserved and duplicates are kept),
/// not a set union.
pub fn union(a: &[u8], b: &[u8]) -> Vec<u8> {
    // `concat` sizes the output once for both halves before copying them in.
    [a, b].concat()
}
104
105/// Concatenate a namespace and a message, prepended by a varint encoding of the namespace length.
106///
107/// This produces a unique byte sequence (i.e. no collisions) for each `(namespace, msg)` pair.
108pub fn union_unique(namespace: &[u8], msg: &[u8]) -> Vec<u8> {
109    let len_prefix = namespace.len();
110    let mut buf = BytesMut::with_capacity(len_prefix.encode_size() + namespace.len() + msg.len());
111    len_prefix.write(&mut buf);
112    BufMut::put_slice(&mut buf, namespace);
113    BufMut::put_slice(&mut buf, msg);
114    buf.into()
115}
116
/// Reduces `bytes`, read as an unsigned big-endian integer of any length, modulo `n`.
///
/// Typically used to turn random seed bytes into an index for selecting an entry from an array.
///
/// # Panics
///
/// Panics if `n` is zero.
pub fn modulo(bytes: &[u8], n: u64) -> u64 {
    assert_ne!(n, 0, "modulus must be non-zero");

    let modulus = u128::from(n);
    // Horner's rule in base 256, reducing after every byte. The accumulator always stays below
    // `modulus` (<= u64::MAX), so shifting it left by 8 bits cannot overflow a u128.
    let remainder = bytes
        .iter()
        .fold(0u128, |acc, &byte| ((acc << 8) | u128::from(byte)) % modulus);

    // `remainder` is below `modulus`, which fits in a u64, so this cast is lossless.
    remainder as u64
}
137
/// A macro to create a `NonZeroUsize` from a value, panicking if the value is zero.
///
/// The check happens at runtime: the value is passed to `NonZeroUsize::new` and the result is
/// unwrapped with `expect`.
#[macro_export]
macro_rules! NZUsize {
    ($val:expr) => {
        // Absolute `::core` path: paths in `macro_rules!` expansions resolve at the call site,
        // so a plain `core::...` would break if the caller has its own item named `core`.
        ::core::num::NonZeroUsize::new($val).expect("value must be non-zero")
    };
}
147
/// A macro to create a `NonZeroU8` from a value, panicking if the value is zero.
///
/// The check happens at runtime: the value is passed to `NonZeroU8::new` and the result is
/// unwrapped with `expect`.
#[macro_export]
macro_rules! NZU8 {
    ($val:expr) => {
        // Absolute `::core` path: paths in `macro_rules!` expansions resolve at the call site,
        // so a plain `core::...` would break if the caller has its own item named `core`.
        ::core::num::NonZeroU8::new($val).expect("value must be non-zero")
    };
}
155
/// A macro to create a `NonZeroU16` from a value, panicking if the value is zero.
///
/// The check happens at runtime: the value is passed to `NonZeroU16::new` and the result is
/// unwrapped with `expect`.
#[macro_export]
macro_rules! NZU16 {
    ($val:expr) => {
        // Absolute `::core` path: paths in `macro_rules!` expansions resolve at the call site,
        // so a plain `core::...` would break if the caller has its own item named `core`.
        ::core::num::NonZeroU16::new($val).expect("value must be non-zero")
    };
}
163
/// A macro to create a `NonZeroU32` from a value, panicking if the value is zero.
///
/// The check happens at runtime: the value is passed to `NonZeroU32::new` and the result is
/// unwrapped with `expect`.
#[macro_export]
macro_rules! NZU32 {
    ($val:expr) => {
        // Absolute `::core` path: paths in `macro_rules!` expansions resolve at the call site,
        // so a plain `core::...` would break if the caller has its own item named `core`.
        ::core::num::NonZeroU32::new($val).expect("value must be non-zero")
    };
}
173
/// A macro to create a `NonZeroU64` from a value, panicking if the value is zero.
///
/// The check happens at runtime: the value is passed to `NonZeroU64::new` and the result is
/// unwrapped with `expect`.
#[macro_export]
macro_rules! NZU64 {
    ($val:expr) => {
        // Absolute `::core` path: paths in `macro_rules!` expansions resolve at the call site,
        // so a plain `core::...` would break if the caller has its own item named `core`.
        ::core::num::NonZeroU64::new($val).expect("value must be non-zero")
    };
}
183
/// A wrapper around `Duration` that guarantees the duration is non-zero.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NonZeroDuration(Duration);

impl NonZeroDuration {
    /// Creates a `NonZeroDuration` if the given duration is non-zero, returning `None` for a
    /// zero duration.
    ///
    /// This is a `const fn`, so it can be used to build constants.
    pub const fn new(duration: Duration) -> Option<Self> {
        // `Duration::is_zero` is usable in const context, unlike `==` (which goes through
        // `PartialEq`).
        if duration.is_zero() {
            None
        } else {
            Some(Self(duration))
        }
    }

    /// Creates a `NonZeroDuration` from the given duration.
    ///
    /// # Panics
    ///
    /// Panics if `duration` is zero. Thanks to `#[track_caller]`, the reported panic location
    /// is the caller of this function rather than this function's body.
    #[track_caller]
    pub fn new_panic(duration: Duration) -> Self {
        Self::new(duration).expect("duration must be non-zero")
    }

    /// Returns the wrapped `Duration`.
    pub const fn get(self) -> Duration {
        self.0
    }
}
208
209impl From<NonZeroDuration> for Duration {
210    fn from(nz_duration: NonZeroDuration) -> Self {
211        nz_duration.0
212    }
213}
214
/// A macro to create a `NonZeroDuration` from a duration, panicking if the duration is zero.
///
/// Expands to a call to `NonZeroDuration::new_panic`, so the zero check happens at runtime.
#[macro_export]
macro_rules! NZDuration {
    ($duration:expr) => {
        // `$crate` keeps this path valid no matter how callers import this crate.
        $crate::NonZeroDuration::new_panic($duration)
    };
}
223
#[cfg(test)]
mod tests {
    use super::*;
    use num_bigint::BigUint;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    #[test]
    fn test_hex() {
        // Test case 0: empty bytes
        let b = &[];
        let h = hex(b);
        assert_eq!(h, "");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 1: single byte
        let b = &[0x01];
        let h = hex(b);
        assert_eq!(h, "01");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 2: multiple bytes
        let b = &[0x01, 0x02, 0x03];
        let h = hex(b);
        assert_eq!(h, "010203");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 3: odd number of hex characters
        let h = "0102030";
        assert!(from_hex(h).is_none());

        // Test case 4: invalid hexadecimal character
        let h = "01g3";
        assert!(from_hex(h).is_none());

        // Test case 5: uppercase digits are accepted, while `hex` always emits lowercase
        let h = "ABCDEF";
        assert_eq!(from_hex(h).unwrap(), vec![0xab, 0xcd, 0xef]);
        assert_eq!(hex(&[0xab, 0xcd, 0xef]), "abcdef");
    }

    #[test]
    fn test_from_hex_formatted() {
        // Test case 0: empty bytes
        let b = &[];
        let h = hex(b);
        assert_eq!(h, "");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 1: single byte
        let b = &[0x01];
        let h = hex(b);
        assert_eq!(h, "01");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 2: multiple bytes
        let b = &[0x01, 0x02, 0x03];
        let h = hex(b);
        assert_eq!(h, "010203");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 3: odd number of hex characters
        let h = "0102030";
        assert!(from_hex_formatted(h).is_none());

        // Test case 4: invalid hexadecimal character
        let h = "01g3";
        assert!(from_hex_formatted(h).is_none());

        // Test case 5: whitespace
        let h = "01 02 03";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());

        // Test case 6: 0x prefix
        let h = "0x010203";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());

        // Test case 7: 0x prefix + different whitespace chars
        let h = "    \n\n0x\r\n01
                            02\t03\n";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());
    }

    #[test]
    fn test_from_hex_utf8_char_boundaries() {
        const MISALIGNMENT_CASE: &str = "쀘\n";

        // Ensure that `from_hex` can handle misaligned UTF-8 character boundaries.
        let b = from_hex(MISALIGNMENT_CASE);
        assert!(b.is_none());
    }

    #[test]
    fn test_max_faults_zero() {
        assert_eq!(max_faults(0), 0);
    }

    #[test]
    #[should_panic]
    fn test_quorum_zero() {
        quorum(0);
    }

    #[test]
    fn test_quorum_and_max_faults() {
        // n, expected_f, expected_q
        let test_cases = [
            (1, 0, 1),
            (2, 0, 2),
            (3, 0, 3),
            (4, 1, 3),
            (5, 1, 4),
            (6, 1, 5),
            (7, 2, 5),
            (8, 2, 6),
            (9, 2, 7),
            (10, 3, 7),
            (11, 3, 8),
            (12, 3, 9),
            (13, 4, 9),
            (14, 4, 10),
            (15, 4, 11),
            (16, 5, 11),
            (17, 5, 12),
            (18, 5, 13),
            (19, 6, 13),
            (20, 6, 14),
            (21, 6, 15),
        ];

        for (n, ef, eq) in test_cases {
            assert_eq!(max_faults(n), ef);
            assert_eq!(quorum(n), eq);
            assert_eq!(n, ef + eq);
        }
    }

    #[test]
    fn test_quorum_from_slice() {
        assert_eq!(quorum_from_slice(&[0u8; 1]), 1);
        assert_eq!(quorum_from_slice(&[0u8; 4]), 3);
        assert_eq!(quorum_from_slice(&[(); 7]), 5);
    }

    #[test]
    #[should_panic]
    fn test_quorum_from_slice_empty() {
        quorum_from_slice::<u8>(&[]);
    }

    #[test]
    fn test_union() {
        // Test case 0: empty slices
        assert_eq!(union(&[], &[]), []);

        // Test case 1: empty and non-empty slices
        assert_eq!(union(&[], &[0x01, 0x02, 0x03]), [0x01, 0x02, 0x03]);

        // Test case 2: non-empty and non-empty slices
        assert_eq!(
            union(&[0x01, 0x02, 0x03], &[0x04, 0x05, 0x06]),
            [0x01, 0x02, 0x03, 0x04, 0x05, 0x06]
        );
    }

    #[test]
    fn test_union_unique() {
        let namespace = b"namespace";
        let msg = b"message";

        let length_encoding = vec![0b0000_1001];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(namespace);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_union_unique_zero_length() {
        let namespace = b"";
        let msg = b"message";

        let length_encoding = vec![0];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_union_unique_long_length() {
        // Use a namespace of over length 127.
        let namespace = &b"n".repeat(256);
        let msg = b"message";

        let length_encoding = vec![0b1000_0000, 0b0000_0010];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(namespace);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_modulo() {
        // Test case 0: empty bytes
        assert_eq!(modulo(&[], 1), 0);

        // Test case 1: single byte
        assert_eq!(modulo(&[0x01], 1), 0);

        // Test case 2: multiple bytes
        assert_eq!(modulo(&[0x01, 0x02, 0x03], 10), 1);

        // Test case 3: check equivalence with BigUint
        for i in 0..100 {
            let mut rng = StdRng::seed_from_u64(i);
            let bytes: [u8; 32] = rng.gen();

            // 1-byte modulus
            let n = 11u64;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // 2-byte modulus
            let n = 11_111u64;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // 8-byte modulus
            let n = 0xDFFFFFFFFFFFFFFD;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));
        }
    }

    #[test]
    #[should_panic]
    fn test_modulo_zero_panics() {
        modulo(&[0x01, 0x02, 0x03], 0);
    }

    #[test]
    fn test_non_zero_macros() {
        // Test case 0: zero value
        assert!(std::panic::catch_unwind(|| NZUsize!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU8!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU16!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU32!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU64!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZDuration!(Duration::ZERO)).is_err());

        // Test case 1: non-zero value
        assert_eq!(NZUsize!(1).get(), 1);
        assert_eq!(NZU8!(4).get(), 4);
        assert_eq!(NZU16!(5).get(), 5);
        assert_eq!(NZU32!(2).get(), 2);
        assert_eq!(NZU64!(3).get(), 3);
        assert_eq!(
            NZDuration!(Duration::from_secs(1)).get(),
            Duration::from_secs(1)
        );
    }

    #[test]
    fn test_non_zero_duration() {
        // Test case 0: zero duration
        assert!(NonZeroDuration::new(Duration::ZERO).is_none());

        // Test case 1: non-zero duration
        let duration = Duration::from_millis(100);
        let nz_duration = NonZeroDuration::new(duration).unwrap();
        assert_eq!(nz_duration.get(), duration);
        assert_eq!(Duration::from(nz_duration), duration);

        // Test case 2: panic on zero
        assert!(std::panic::catch_unwind(|| NonZeroDuration::new_panic(Duration::ZERO)).is_err());

        // Test case 3: ordering
        let d1 = NonZeroDuration::new(Duration::from_millis(100)).unwrap();
        let d2 = NonZeroDuration::new(Duration::from_millis(200)).unwrap();
        assert!(d1 < d2);
    }
}