commonware_utils/
lib.rs

1//! Leverage common functionality across multiple primitives.
2
3#![doc(
4    html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
5    html_favicon_url = "https://commonware.xyz/favicon.ico"
6)]
7#![cfg_attr(not(feature = "std"), no_std)]
8
9#[cfg(not(feature = "std"))]
10extern crate alloc;
11
12#[cfg(not(feature = "std"))]
13use alloc::{string::String, vec::Vec};
14use bytes::{BufMut, BytesMut};
15use commonware_codec::{EncodeSize, Write};
16use core::{
17    fmt::{Debug, Write as FmtWrite},
18    time::Duration,
19};
20
21pub mod sequence;
22pub use sequence::{Array, Span};
23mod bitvec;
24pub use bitvec::{BitIterator, BitVec};
25#[cfg(feature = "std")]
26pub mod channels;
27#[cfg(feature = "std")]
28mod time;
29#[cfg(feature = "std")]
30pub use time::{parse_duration, SystemTimeExt};
31#[cfg(feature = "std")]
32mod priority_set;
33#[cfg(feature = "std")]
34pub use priority_set::PrioritySet;
35#[cfg(feature = "std")]
36pub mod futures;
37mod stable_buf;
38pub use stable_buf::StableBuf;
39
/// Converts bytes to a lowercase hexadecimal string.
///
/// Every byte is rendered as exactly two digits (zero-padded) and no `0x` prefix is added, so the
/// output is always `2 * bytes.len()` characters long. The inverse operation is `from_hex`.
pub fn hex(bytes: &[u8]) -> String {
    // The final length is known up front; reserve it once instead of growing repeatedly.
    let mut hex = String::with_capacity(bytes.len() * 2);
    for byte in bytes {
        write!(hex, "{byte:02x}").expect("writing to string should never fail");
    }
    hex
}
48
/// Decodes a hexadecimal string into bytes.
///
/// Lowercase and uppercase digits are both accepted and no prefix is allowed. Returns `None` when
/// the input has an odd number of bytes or contains any byte outside `[0-9a-fA-F]` (including
/// non-ASCII UTF-8 bytes). An empty string decodes to an empty vector.
pub fn from_hex(hex: &str) -> Option<Vec<u8>> {
    let digits = hex.as_bytes();
    if digits.len() % 2 == 1 {
        return None;
    }

    let mut decoded = Vec::new();
    for pair in digits.chunks_exact(2) {
        let high = decode_hex_digit(pair[0])?;
        let low = decode_hex_digit(pair[1])?;
        decoded.push((high << 4) | low);
    }
    Some(decoded)
}

/// Maps a single ASCII hex digit to its 4-bit value, or `None` if `byte` is not a hex digit.
#[inline]
fn decode_hex_digit(byte: u8) -> Option<u8> {
    let value = match byte {
        b'0'..=b'9' => byte - b'0',
        b'a'..=b'f' => byte - b'a' + 10,
        b'A'..=b'F' => byte - b'A' + 10,
        _ => return None,
    };
    Some(value)
}
75
76/// Converts a hexadecimal string to bytes, stripping whitespace and/or a `0x` prefix. Commonly used
77/// in testing to encode external test vectors without modification.
78pub fn from_hex_formatted(hex: &str) -> Option<Vec<u8>> {
79    let hex = hex.replace(['\t', '\n', '\r', ' '], "");
80    let res = hex.strip_prefix("0x").unwrap_or(&hex);
81    from_hex(res)
82}
83
/// Compute the maximum number of faults `f` that can be tolerated by a set of `n` participants.
///
/// This is the largest integer `f` such that `n >= 3*f + 1`. It is zero for every `n <= 3`,
/// including `n == 0`.
pub fn max_faults(n: u32) -> u32 {
    // `saturating_sub` keeps `n == 0` from underflowing (yielding `f == 0`).
    n.saturating_sub(1) / 3
}

/// Compute the quorum size for a given set of `n` participants.
///
/// This is the smallest integer `q` such that `3*q >= 2*n + 1`. It is also equal to `n - f`,
/// where `f` is `max_faults(n)`.
///
/// # Panics
///
/// Panics if `n` is zero.
pub fn quorum(n: u32) -> u32 {
    assert!(n > 0, "n must not be zero");
    n - max_faults(n)
}

/// Compute the quorum size for a set of participants given as a slice (uses `slice.len()`).
///
/// # Panics
///
/// Panics if the slice is empty (see [quorum]) or if its length exceeds [u32::MAX].
pub fn quorum_from_slice<T>(slice: &[T]) -> u32 {
    let n = u32::try_from(slice.len()).expect("slice length must not exceed u32::MAX");
    quorum(n)
}
114
/// Concatenates two byte slices (`a` followed by `b`) into a newly allocated vector.
///
/// Despite the name, this is a plain concatenation: no deduplication or set semantics apply.
pub fn union(a: &[u8], b: &[u8]) -> Vec<u8> {
    [a, b].concat()
}
122
123/// Concatenate a namespace and a message, prepended by a varint encoding of the namespace length.
124///
125/// This produces a unique byte sequence (i.e. no collisions) for each `(namespace, msg)` pair.
126pub fn union_unique(namespace: &[u8], msg: &[u8]) -> Vec<u8> {
127    let len_prefix = namespace.len();
128    let mut buf = BytesMut::with_capacity(len_prefix.encode_size() + namespace.len() + msg.len());
129    len_prefix.write(&mut buf);
130    BufMut::put_slice(&mut buf, namespace);
131    BufMut::put_slice(&mut buf, msg);
132    buf.into()
133}
134
/// Interprets `bytes` as a big-endian unsigned integer and returns it reduced modulo `n`.
///
/// Used to pick an entry from an array when `bytes` is a random seed. The value is reduced one
/// byte at a time, so inputs of any length work without a big-integer type. Empty input yields 0.
///
/// # Panics
///
/// Panics if `n` is zero.
pub fn modulo(bytes: &[u8], n: u64) -> u64 {
    assert_ne!(n, 0, "modulus must be non-zero");

    let modulus = u128::from(n);
    // The running remainder is always below `modulus` (< 2^64), so shifting it left by one byte
    // and OR-ing in the next byte always fits in a u128.
    let remainder = bytes
        .iter()
        .fold(0u128, |acc, &byte| ((acc << 8) | u128::from(byte)) % modulus);

    // `remainder < modulus <= u64::MAX`, so narrowing back to u64 is lossless.
    remainder as u64
}
155
/// Creates a [core::num::NonZeroUsize] from a `usize` expression.
///
/// # Panics
///
/// Panics with "value must be non-zero" if the expression evaluates to zero.
#[macro_export]
macro_rules! NZUsize {
    ($val:expr) => {{
        let value: usize = $val;
        core::num::NonZeroUsize::new(value).expect("value must be non-zero")
    }};
}

/// Creates a [core::num::NonZeroU8] from a `u8` expression.
///
/// # Panics
///
/// Panics with "value must be non-zero" if the expression evaluates to zero.
#[macro_export]
macro_rules! NZU8 {
    ($val:expr) => {{
        let value: u8 = $val;
        core::num::NonZeroU8::new(value).expect("value must be non-zero")
    }};
}

/// Creates a [core::num::NonZeroU16] from a `u16` expression.
///
/// # Panics
///
/// Panics with "value must be non-zero" if the expression evaluates to zero.
#[macro_export]
macro_rules! NZU16 {
    ($val:expr) => {{
        let value: u16 = $val;
        core::num::NonZeroU16::new(value).expect("value must be non-zero")
    }};
}

/// Creates a [core::num::NonZeroU32] from a `u32` expression.
///
/// # Panics
///
/// Panics with "value must be non-zero" if the expression evaluates to zero.
#[macro_export]
macro_rules! NZU32 {
    ($val:expr) => {{
        let value: u32 = $val;
        core::num::NonZeroU32::new(value).expect("value must be non-zero")
    }};
}

/// Creates a [core::num::NonZeroU64] from a `u64` expression.
///
/// # Panics
///
/// Panics with "value must be non-zero" if the expression evaluates to zero.
#[macro_export]
macro_rules! NZU64 {
    ($val:expr) => {{
        let value: u64 = $val;
        core::num::NonZeroU64::new(value).expect("value must be non-zero")
    }};
}
201
/// A wrapper around [Duration] that guarantees the duration is non-zero.
///
/// The constructors and accessor are `const fn`, so values can be built in `const` items
/// (for example with the `NZDuration!` macro).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NonZeroDuration(Duration);

impl NonZeroDuration {
    /// Creates a `NonZeroDuration`, returning `None` if `duration` is zero.
    pub const fn new(duration: Duration) -> Option<Self> {
        if duration.is_zero() {
            None
        } else {
            Some(Self(duration))
        }
    }

    /// Creates a `NonZeroDuration` from the given duration.
    ///
    /// # Panics
    ///
    /// Panics if `duration` is zero. When evaluated in a `const` context, a zero duration is
    /// reported as a compile-time error instead.
    pub const fn new_panic(duration: Duration) -> Self {
        // `match` + `panic!` (rather than `Option::expect`) keeps this usable in `const` contexts.
        match Self::new(duration) {
            Some(non_zero) => non_zero,
            None => panic!("duration must be non-zero"),
        }
    }

    /// Returns the wrapped [Duration].
    pub const fn get(self) -> Duration {
        self.0
    }
}

impl From<NonZeroDuration> for Duration {
    fn from(nz_duration: NonZeroDuration) -> Self {
        nz_duration.0
    }
}

/// A macro to create a `NonZeroDuration` from a duration, panicking if the duration is zero.
///
/// Usable in `const` items; there a zero duration fails compilation instead of panicking.
#[macro_export]
macro_rules! NZDuration {
    ($val:expr) => {
        $crate::NonZeroDuration::new_panic($val)
    };
}
241
#[cfg(test)]
mod tests {
    use super::*;
    use num_bigint::BigUint;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    #[test]
    fn test_hex() {
        // Test case 0: empty bytes
        let b = &[];
        let h = hex(b);
        assert_eq!(h, "");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 1: single byte
        let b = &[0x01];
        let h = hex(b);
        assert_eq!(h, "01");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 2: multiple bytes
        let b = &[0x01, 0x02, 0x03];
        let h = hex(b);
        assert_eq!(h, "010203");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 3: odd number of bytes
        let h = "0102030";
        assert!(from_hex(h).is_none());

        // Test case 4: invalid hexadecimal character
        let h = "01g3";
        assert!(from_hex(h).is_none());

        // Test case 5: invalid `+` in string
        let h = "+123";
        assert!(from_hex(h).is_none());

        // Test case 6: empty string
        assert_eq!(from_hex(""), Some(vec![]));

        // Test case 7: uppercase and mixed-case digits decode like lowercase ones
        assert_eq!(from_hex("ABCDEF").unwrap(), vec![0xab, 0xcd, 0xef]);
        assert_eq!(from_hex("aBcD").unwrap(), vec![0xab, 0xcd]);

        // Test case 8: `hex` emits lowercase, zero-padded digits
        assert_eq!(hex(&[0x00, 0x0a, 0xff]), "000aff");
    }

    #[test]
    fn test_from_hex_formatted() {
        // Test case 0: empty bytes
        let b = &[];
        let h = hex(b);
        assert_eq!(h, "");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 1: single byte
        let b = &[0x01];
        let h = hex(b);
        assert_eq!(h, "01");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 2: multiple bytes
        let b = &[0x01, 0x02, 0x03];
        let h = hex(b);
        assert_eq!(h, "010203");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 3: odd number of bytes
        let h = "0102030";
        assert!(from_hex_formatted(h).is_none());

        // Test case 4: invalid hexadecimal character
        let h = "01g3";
        assert!(from_hex_formatted(h).is_none());

        // Test case 5: whitespace
        let h = "01 02 03";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());

        // Test case 6: 0x prefix
        let h = "0x010203";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());

        // Test case 7: 0x prefix + different whitespace chars
        let h = "    \n\n0x\r\n01
                            02\t03\n";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());
    }

    #[test]
    fn test_from_hex_utf8_char_boundaries() {
        const MISALIGNMENT_CASE: &str = "쀘\n";

        // Ensure that `from_hex` can handle misaligned UTF-8 character boundaries.
        let b = from_hex(MISALIGNMENT_CASE);
        assert!(b.is_none());
    }

    #[test]
    fn test_max_faults_zero() {
        assert_eq!(max_faults(0), 0);
    }

    #[test]
    #[should_panic]
    fn test_quorum_zero() {
        quorum(0);
    }

    #[test]
    fn test_quorum_and_max_faults() {
        // n, expected_f, expected_q
        let test_cases = [
            (1, 0, 1),
            (2, 0, 2),
            (3, 0, 3),
            (4, 1, 3),
            (5, 1, 4),
            (6, 1, 5),
            (7, 2, 5),
            (8, 2, 6),
            (9, 2, 7),
            (10, 3, 7),
            (11, 3, 8),
            (12, 3, 9),
            (13, 4, 9),
            (14, 4, 10),
            (15, 4, 11),
            (16, 5, 11),
            (17, 5, 12),
            (18, 5, 13),
            (19, 6, 13),
            (20, 6, 14),
            (21, 6, 15),
        ];

        for (n, ef, eq) in test_cases {
            assert_eq!(max_faults(n), ef);
            assert_eq!(quorum(n), eq);
            assert_eq!(n, ef + eq);
        }
    }

    #[test]
    fn test_quorum_from_slice() {
        assert_eq!(quorum_from_slice(&[(); 1]), 1);
        assert_eq!(quorum_from_slice(&[0u8; 4]), 3);
        assert_eq!(quorum_from_slice(&[0u64; 7]), quorum(7));
    }

    #[test]
    #[should_panic]
    fn test_quorum_from_slice_empty() {
        quorum_from_slice::<u8>(&[]);
    }

    #[test]
    fn test_union() {
        // Test case 0: empty slices
        assert_eq!(union(&[], &[]), []);

        // Test case 1: empty and non-empty slices
        assert_eq!(union(&[], &[0x01, 0x02, 0x03]), [0x01, 0x02, 0x03]);

        // Test case 2: non-empty and non-empty slices
        assert_eq!(
            union(&[0x01, 0x02, 0x03], &[0x04, 0x05, 0x06]),
            [0x01, 0x02, 0x03, 0x04, 0x05, 0x06]
        );
    }

    #[test]
    fn test_union_unique() {
        let namespace = b"namespace";
        let msg = b"message";

        let length_encoding = vec![0b0000_1001];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(namespace);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_union_unique_zero_length() {
        let namespace = b"";
        let msg = b"message";

        let length_encoding = vec![0];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_union_unique_long_length() {
        // Use a namespace of over length 127.
        let namespace = &b"n".repeat(256);
        let msg = b"message";

        let length_encoding = vec![0b1000_0000, 0b0000_0010];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(namespace);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_modulo() {
        // Test case 0: empty bytes
        assert_eq!(modulo(&[], 1), 0);

        // Test case 1: single byte
        assert_eq!(modulo(&[0x01], 1), 0);

        // Test case 2: multiple bytes
        assert_eq!(modulo(&[0x01, 0x02, 0x03], 10), 1);

        // Test case 3: check equivalence with BigUint
        for i in 0..100 {
            let mut rng = StdRng::seed_from_u64(i);
            let bytes: [u8; 32] = rng.gen();

            // 1-byte modulus
            let n = 11u64;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // 2-byte modulus
            let n = 11_111u64;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // 8-byte modulus
            let n = 0xDFFFFFFFFFFFFFFD;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));
        }

        // Test case 4: largest modulus (2^256 - 1 is divisible by 2^64 - 1)
        assert_eq!(modulo(&[0xff; 32], u64::MAX), 0);
    }

    #[test]
    #[should_panic]
    fn test_modulo_zero_panics() {
        modulo(&[0x01, 0x02, 0x03], 0);
    }

    #[test]
    fn test_non_zero_macros() {
        // Test case 0: zero value
        assert!(std::panic::catch_unwind(|| NZUsize!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU8!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU16!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU32!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU64!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZDuration!(Duration::ZERO)).is_err());

        // Test case 1: non-zero value
        assert_eq!(NZUsize!(1).get(), 1);
        assert_eq!(NZU8!(4).get(), 4);
        assert_eq!(NZU16!(5).get(), 5);
        assert_eq!(NZU32!(2).get(), 2);
        assert_eq!(NZU64!(3).get(), 3);
        assert_eq!(
            NZDuration!(Duration::from_secs(1)).get(),
            Duration::from_secs(1)
        );
    }

    #[test]
    fn test_non_zero_duration() {
        // Test case 0: zero duration
        assert!(NonZeroDuration::new(Duration::ZERO).is_none());

        // Test case 1: non-zero duration
        let duration = Duration::from_millis(100);
        let nz_duration = NonZeroDuration::new(duration).unwrap();
        assert_eq!(nz_duration.get(), duration);
        assert_eq!(Duration::from(nz_duration), duration);

        // Test case 2: panic on zero
        assert!(std::panic::catch_unwind(|| NonZeroDuration::new_panic(Duration::ZERO)).is_err());

        // Test case 3: ordering
        let d1 = NonZeroDuration::new(Duration::from_millis(100)).unwrap();
        let d2 = NonZeroDuration::new(Duration::from_millis(200)).unwrap();
        assert!(d1 < d2);
    }
}