commonware_utils/
lib.rs

1//! Leverage common functionality across multiple primitives.
2
3#![doc(
4    html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
5    html_favicon_url = "https://commonware.xyz/favicon.ico"
6)]
7#![cfg_attr(not(feature = "std"), no_std)]
8
9#[cfg(not(feature = "std"))]
10extern crate alloc;
11
12#[cfg(not(feature = "std"))]
13use alloc::{string::String, vec::Vec};
14use bytes::{BufMut, BytesMut};
15use commonware_codec::{EncodeSize, Write};
16use core::{
17    fmt::{Debug, Write as FmtWrite},
18    time::Duration,
19};
20
21pub mod sequence;
22pub use sequence::{Array, Span};
23pub mod bitmap;
24#[cfg(feature = "std")]
25pub mod channels;
26pub mod hex_literal;
27#[cfg(feature = "std")]
28pub mod net;
29pub mod set;
30#[cfg(feature = "std")]
31pub use net::IpAddrExt;
32#[cfg(feature = "std")]
33pub mod time;
34#[cfg(feature = "std")]
35pub use time::{DurationExt, SystemTimeExt};
36#[cfg(feature = "std")]
37pub mod rational;
38#[cfg(feature = "std")]
39pub use rational::BigRationalExt;
40#[cfg(feature = "std")]
41mod priority_set;
42#[cfg(feature = "std")]
43pub use priority_set::PrioritySet;
44#[cfg(feature = "std")]
45pub mod futures;
46mod stable_buf;
47pub use stable_buf::StableBuf;
48#[cfg(feature = "std")]
49pub mod concurrency;
50
/// Converts bytes to a lowercase hexadecimal string: two digits per byte, no `0x` prefix.
///
/// The output can be decoded again with [from_hex].
pub fn hex(bytes: &[u8]) -> String {
    // Every byte renders as exactly two hex digits, so the final length is known up front and
    // the string never needs to reallocate while being built.
    let mut hex = String::with_capacity(bytes.len() * 2);
    for byte in bytes {
        write!(hex, "{byte:02x}").expect("writing to string should never fail");
    }
    hex
}
59
60/// Converts a hexadecimal string to bytes.
61pub fn from_hex(hex: &str) -> Option<Vec<u8>> {
62    let bytes = hex.as_bytes();
63    if !bytes.len().is_multiple_of(2) {
64        return None;
65    }
66
67    bytes
68        .chunks_exact(2)
69        .map(|chunk| {
70            let hi = decode_hex_digit(chunk[0])?;
71            let lo = decode_hex_digit(chunk[1])?;
72            Some((hi << 4) | lo)
73        })
74        .collect()
75}
76
/// Decodes one ASCII hexadecimal digit (`0-9`, `a-f`, `A-F`) into its value in `0..=15`.
///
/// Returns `None` for every other byte, including all non-ASCII bytes.
#[inline]
fn decode_hex_digit(byte: u8) -> Option<u8> {
    // `char::to_digit(16)` accepts exactly the ASCII decimal digits plus `a-f` and `A-F`.
    // Bytes >= 0x80 become Latin-1 characters, which it rejects. The value is at most 15, so
    // narrowing to `u8` is lossless.
    char::from(byte).to_digit(16).map(|nibble| nibble as u8)
}
86
87/// Converts a hexadecimal string to bytes, stripping whitespace and/or a `0x` prefix. Commonly used
88/// in testing to encode external test vectors without modification.
89pub fn from_hex_formatted(hex: &str) -> Option<Vec<u8>> {
90    let hex = hex.replace(['\t', '\n', '\r', ' '], "");
91    let res = hex.strip_prefix("0x").unwrap_or(&hex);
92    from_hex(res)
93}
94
/// Returns the largest number of faulty participants `f` tolerated by a set of `n` participants.
///
/// This is the maximum integer `f` with `n >= 3*f + 1`. The result is zero for `n` in `1..=3`.
/// For `n == 0` it is also zero.
pub fn max_faults(n: u32) -> u32 {
    match n.checked_sub(1) {
        Some(below_n) => below_n / 3,
        None => 0,
    }
}
100
101/// Compute the quorum size for a given set of `n` participants. This is the minimum integer `q`
102/// such that `3*q >= 2*n + 1`. It is also equal to `n - f`, where `f` is the maximum number of
103/// faults.
104///
105/// # Panics
106///
107/// Panics if `n` is zero.
108pub fn quorum(n: u32) -> u32 {
109    assert!(n > 0, "n must not be zero");
110    n - max_faults(n)
111}
112
113/// Compute the quorum size for a given slice.
114///
115/// # Panics
116///
117/// Panics if the slice length is greater than [u32::MAX].
118pub fn quorum_from_slice<T>(slice: &[T]) -> u32 {
119    let n: u32 = slice
120        .len()
121        .try_into()
122        .expect("slice length must be less than u32::MAX");
123    quorum(n)
124}
125
/// Concatenates two byte slices into a newly allocated vector: all of `a`, followed by all of `b`.
///
/// Despite the name, this is a plain concatenation, not a set union. Duplicates are kept and
/// order is preserved.
pub fn union(a: &[u8], b: &[u8]) -> Vec<u8> {
    [a, b].concat()
}
133
134/// Concatenate a namespace and a message, prepended by a varint encoding of the namespace length.
135///
136/// This produces a unique byte sequence (i.e. no collisions) for each `(namespace, msg)` pair.
137pub fn union_unique(namespace: &[u8], msg: &[u8]) -> Vec<u8> {
138    let len_prefix = namespace.len();
139    let mut buf = BytesMut::with_capacity(len_prefix.encode_size() + namespace.len() + msg.len());
140    len_prefix.write(&mut buf);
141    BufMut::put_slice(&mut buf, namespace);
142    BufMut::put_slice(&mut buf, msg);
143    buf.into()
144}
145
/// Reduces `bytes`, read as an arbitrary-length big-endian unsigned integer, modulo `n`.
///
/// Used to map random seed bytes to an index in `0..n`, e.g. to select an entry from an array
/// of length `n`.
///
/// # Panics
///
/// Panics if `n` is zero.
pub fn modulo(bytes: &[u8], n: u64) -> u64 {
    assert_ne!(n, 0, "modulus must be non-zero");

    // Horner's rule, reducing after every byte. The running remainder stays below
    // `n <= u64::MAX`, so shifting it left by 8 bits always fits in a u128.
    let modulus = u128::from(n);
    let remainder = bytes
        .iter()
        .fold(0u128, |acc, &byte| ((acc << 8) | u128::from(byte)) % modulus);

    // `remainder < n`, so narrowing back to u64 never truncates.
    remainder as u64
}
166
/// Creates a [`core::num::NonZeroUsize`] from a `usize` expression.
///
/// # Panics
///
/// Panics with "value must be non-zero" if the expression evaluates to zero. The check is
/// performed when the expanded expression is evaluated.
#[macro_export]
macro_rules! NZUsize {
    ($value:expr) => {
        match core::num::NonZeroUsize::new($value) {
            Some(non_zero) => non_zero,
            None => panic!("value must be non-zero"),
        }
    };
}
176
/// Creates a [`core::num::NonZeroU8`] from a `u8` expression.
///
/// # Panics
///
/// Panics with "value must be non-zero" if the expression evaluates to zero.
#[macro_export]
macro_rules! NZU8 {
    ($value:expr) => {
        match core::num::NonZeroU8::new($value) {
            Some(non_zero) => non_zero,
            None => panic!("value must be non-zero"),
        }
    };
}
184
/// Creates a [`core::num::NonZeroU16`] from a `u16` expression.
///
/// # Panics
///
/// Panics with "value must be non-zero" if the expression evaluates to zero.
#[macro_export]
macro_rules! NZU16 {
    ($value:expr) => {
        match core::num::NonZeroU16::new($value) {
            Some(non_zero) => non_zero,
            None => panic!("value must be non-zero"),
        }
    };
}
192
/// Creates a [`core::num::NonZeroU32`] from a `u32` expression.
///
/// # Panics
///
/// Panics with "value must be non-zero" if the expression evaluates to zero. The check is
/// performed when the expanded expression is evaluated.
#[macro_export]
macro_rules! NZU32 {
    ($value:expr) => {
        match core::num::NonZeroU32::new($value) {
            Some(non_zero) => non_zero,
            None => panic!("value must be non-zero"),
        }
    };
}
202
/// Creates a [`core::num::NonZeroU64`] from a `u64` expression.
///
/// # Panics
///
/// Panics with "value must be non-zero" if the expression evaluates to zero. The check is
/// performed when the expanded expression is evaluated.
#[macro_export]
macro_rules! NZU64 {
    ($value:expr) => {
        match core::num::NonZeroU64::new($value) {
            Some(non_zero) => non_zero,
            None => panic!("value must be non-zero"),
        }
    };
}
212
/// A wrapper around [`Duration`] that guarantees the duration is non-zero.
///
/// This mirrors the `core::num::NonZero*` integer types. Construct it with one of:
/// - [`NonZeroDuration::new`];
/// - [`NonZeroDuration::new_panic`];
/// - the `NZDuration!` macro.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NonZeroDuration(Duration);

impl NonZeroDuration {
    /// Creates a `NonZeroDuration` if the given duration is non-zero.
    ///
    /// Returns `None` for a zero duration. Usable in `const` contexts.
    pub const fn new(duration: Duration) -> Option<Self> {
        // `Duration::is_zero` is a `const fn`, unlike `==`, which keeps this constructor const.
        if duration.is_zero() {
            None
        } else {
            Some(Self(duration))
        }
    }

    /// Creates a `NonZeroDuration` from the given duration.
    ///
    /// Usable in `const` contexts. There, a zero duration is reported as a compile-time error.
    ///
    /// # Panics
    ///
    /// Panics if `duration` is zero.
    pub const fn new_panic(duration: Duration) -> Self {
        match Self::new(duration) {
            Some(non_zero) => non_zero,
            None => panic!("duration must be non-zero"),
        }
    }

    /// Returns the wrapped [`Duration`], which is guaranteed to be non-zero.
    pub const fn get(self) -> Duration {
        self.0
    }
}

impl From<NonZeroDuration> for Duration {
    fn from(nz_duration: NonZeroDuration) -> Self {
        nz_duration.0
    }
}
243
/// Creates a [`NonZeroDuration`] from a `Duration` expression.
///
/// # Panics
///
/// Panics (via [`NonZeroDuration::new_panic`]) if the duration is zero.
#[macro_export]
macro_rules! NZDuration {
    ($duration:expr) => {
        // Delegate to the crate-level constructor so the zero check and its message live in
        // exactly one place.
        $crate::NonZeroDuration::new_panic($duration)
    };
}
252
#[cfg(test)]
mod tests {
    use super::*;
    use num_bigint::BigUint;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    #[test]
    fn test_hex() {
        // Test case 0: empty bytes
        let b = &[];
        let h = hex(b);
        assert_eq!(h, "");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 1: single byte
        let b = &hex!("0x01");
        let h = hex(b);
        assert_eq!(h, "01");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 2: multiple bytes
        let b = &hex!("0x010203");
        let h = hex(b);
        assert_eq!(h, "010203");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 3: odd number of bytes
        let h = "0102030";
        assert!(from_hex(h).is_none());

        // Test case 4: invalid hexadecimal character
        let h = "01g3";
        assert!(from_hex(h).is_none());

        // Test case 5: invalid `+` in string
        let h = "+123";
        assert!(from_hex(h).is_none());

        // Test case 6: empty string
        assert_eq!(from_hex(""), Some(vec![]));

        // Test case 7: upper- and mixed-case digits decode, while `hex` always emits lowercase
        assert_eq!(from_hex("ABcdEF"), Some(vec![0xab, 0xcd, 0xef]));
        assert_eq!(hex(&[0xab, 0xcd, 0xef]), "abcdef");

        // Test case 8: `from_hex` does not accept a `0x` prefix
        assert!(from_hex("0x01").is_none());
    }

    #[test]
    fn test_from_hex_formatted() {
        // Test case 0: empty bytes
        let b = &[];
        let h = hex(b);
        assert_eq!(h, "");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 1: single byte
        let b = &hex!("0x01");
        let h = hex(b);
        assert_eq!(h, "01");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 2: multiple bytes
        let b = &hex!("0x010203");
        let h = hex(b);
        assert_eq!(h, "010203");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 3: odd number of bytes
        let h = "0102030";
        assert!(from_hex_formatted(h).is_none());

        // Test case 4: invalid hexadecimal character
        let h = "01g3";
        assert!(from_hex_formatted(h).is_none());

        // Test case 5: whitespace
        let h = "01 02 03";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());

        // Test case 6: 0x prefix
        let h = "0x010203";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());

        // Test case 7: 0x prefix + different whitespace chars
        let h = "    \n\n0x\r\n01
                            02\t03\n";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());
    }

    #[test]
    fn test_from_hex_utf8_char_boundaries() {
        const MISALIGNMENT_CASE: &str = "쀘\n";

        // Ensure that `from_hex` can handle misaligned UTF-8 character boundaries.
        let b = from_hex(MISALIGNMENT_CASE);
        assert!(b.is_none());
    }

    #[test]
    fn test_max_faults_zero() {
        assert_eq!(max_faults(0), 0);
    }

    #[test]
    #[should_panic]
    fn test_quorum_zero() {
        quorum(0);
    }

    #[test]
    fn test_quorum_and_max_faults() {
        // n, expected_f, expected_q
        let test_cases = [
            (1, 0, 1),
            (2, 0, 2),
            (3, 0, 3),
            (4, 1, 3),
            (5, 1, 4),
            (6, 1, 5),
            (7, 2, 5),
            (8, 2, 6),
            (9, 2, 7),
            (10, 3, 7),
            (11, 3, 8),
            (12, 3, 9),
            (13, 4, 9),
            (14, 4, 10),
            (15, 4, 11),
            (16, 5, 11),
            (17, 5, 12),
            (18, 5, 13),
            (19, 6, 13),
            (20, 6, 14),
            (21, 6, 15),
        ];

        for (n, ef, eq) in test_cases {
            assert_eq!(max_faults(n), ef);
            assert_eq!(quorum(n), eq);
            assert_eq!(n, ef + eq);
        }
    }

    #[test]
    fn test_quorum_from_slice() {
        assert_eq!(quorum_from_slice(&[0u8; 1]), 1);
        assert_eq!(quorum_from_slice(&[0u8; 4]), 3);
        assert_eq!(quorum_from_slice(&[(); 7]), 5);
    }

    #[test]
    #[should_panic]
    fn test_quorum_from_slice_empty() {
        quorum_from_slice::<u8>(&[]);
    }

    #[test]
    fn test_union() {
        // Test case 0: empty slices
        assert_eq!(union(&[], &[]), []);

        // Test case 1: empty and non-empty slices
        assert_eq!(union(&[], &hex!("0x010203")), hex!("0x010203"));

        // Test case 2: non-empty and non-empty slices
        assert_eq!(
            union(&hex!("0x010203"), &hex!("0x040506")),
            hex!("0x010203040506")
        );
    }

    #[test]
    fn test_union_unique() {
        let namespace = b"namespace";
        let msg = b"message";

        let length_encoding = vec![0b0000_1001];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(namespace);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_union_unique_zero_length() {
        let namespace = b"";
        let msg = b"message";

        let length_encoding = vec![0];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_union_unique_long_length() {
        // Use a namespace of over length 127.
        let namespace = &b"n".repeat(256);
        let msg = b"message";

        let length_encoding = vec![0b1000_0000, 0b0000_0010];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(namespace);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_modulo() {
        // Test case 0: empty bytes
        assert_eq!(modulo(&[], 1), 0);

        // Test case 1: single byte
        assert_eq!(modulo(&hex!("0x01"), 1), 0);

        // Test case 2: multiple bytes
        assert_eq!(modulo(&hex!("0x010203"), 10), 1);

        // Test case 3: check equivalence with BigUint
        for i in 0..100 {
            let mut rng = StdRng::seed_from_u64(i);
            let bytes: [u8; 32] = rng.gen();

            // 1-byte modulus
            let n = 11u64;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // 2-byte modulus
            let n = 11_111u64;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // 8-byte modulus
            let n = 0xDFFFFFFFFFFFFFFD;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));
        }
    }

    #[test]
    #[should_panic]
    fn test_modulo_zero_panics() {
        modulo(&hex!("0x010203"), 0);
    }

    #[test]
    fn test_non_zero_macros() {
        // Test case 0: zero value
        assert!(std::panic::catch_unwind(|| NZUsize!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU8!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU16!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU32!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU64!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZDuration!(Duration::ZERO)).is_err());

        // Test case 1: non-zero value
        assert_eq!(NZUsize!(1).get(), 1);
        assert_eq!(NZU8!(4).get(), 4);
        assert_eq!(NZU16!(5).get(), 5);
        assert_eq!(NZU32!(2).get(), 2);
        assert_eq!(NZU64!(3).get(), 3);
        assert_eq!(
            NZDuration!(Duration::from_secs(1)).get(),
            Duration::from_secs(1)
        );
    }

    #[test]
    fn test_non_zero_duration() {
        // Test case 0: zero duration
        assert!(NonZeroDuration::new(Duration::ZERO).is_none());

        // Test case 1: non-zero duration
        let duration = Duration::from_millis(100);
        let nz_duration = NonZeroDuration::new(duration).unwrap();
        assert_eq!(nz_duration.get(), duration);
        assert_eq!(Duration::from(nz_duration), duration);

        // Test case 2: panic on zero
        assert!(std::panic::catch_unwind(|| NonZeroDuration::new_panic(Duration::ZERO)).is_err());

        // Test case 3: ordering
        let d1 = NonZeroDuration::new(Duration::from_millis(100)).unwrap();
        let d2 = NonZeroDuration::new(Duration::from_millis(200)).unwrap();
        assert!(d1 < d2);
    }
}