commonware_utils/
lib.rs

1//! Leverage common functionality across multiple primitives.
2
3#![doc(
4    html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
5    html_favicon_url = "https://commonware.xyz/favicon.ico"
6)]
7#![cfg_attr(not(feature = "std"), no_std)]
8
9#[cfg(not(feature = "std"))]
10extern crate alloc;
11
12#[cfg(not(feature = "std"))]
13use alloc::{boxed::Box, string::String, vec::Vec};
14use bytes::{BufMut, BytesMut};
15use commonware_codec::{EncodeSize, Write};
16use core::{
17    fmt::{Debug, Write as FmtWrite},
18    time::Duration,
19};
20
21pub mod sequence;
22pub use sequence::{Array, Span};
23#[cfg(feature = "std")]
24pub mod acknowledgement;
25#[cfg(feature = "std")]
26pub use acknowledgement::Acknowledgement;
27pub mod bitmap;
28#[cfg(feature = "std")]
29pub mod channels;
30pub mod hex_literal;
31pub mod hostname;
32pub use hostname::Hostname;
33#[cfg(feature = "std")]
34pub mod net;
35pub mod ordered;
36pub mod vec;
37
/// Fallible counterpart of [`FromIterator`]: builds `Self` from the items of an iterator, or
/// reports why it could not.
pub trait TryFromIterator<T>: Sized {
    /// Error reported when the items cannot form a valid `Self`.
    type Error;

    /// Consumes `iter` and attempts to assemble a `Self` from its items.
    fn try_from_iter<I>(iter: I) -> Result<Self, Self::Error>
    where
        I: IntoIterator<Item = T>;
}

/// Extension trait giving every iterator a fallible analogue of [`Iterator::collect`].
pub trait TryCollect: Iterator + Sized {
    /// Collects the remaining items into `C`, surfacing `C`'s error if construction fails.
    fn try_collect<C>(self) -> Result<C, C::Error>
    where
        C: TryFromIterator<Self::Item>,
    {
        <C as TryFromIterator<Self::Item>>::try_from_iter(self)
    }
}

// Blanket implementation: every iterator gets `try_collect` for free.
impl<It: Iterator> TryCollect for It {}
56#[cfg(feature = "std")]
57pub use net::IpAddrExt;
58#[cfg(feature = "std")]
59pub mod time;
60#[cfg(feature = "std")]
61pub use time::{DurationExt, SystemTimeExt};
62#[cfg(feature = "std")]
63pub mod rational;
64#[cfg(feature = "std")]
65pub use rational::BigRationalExt;
66#[cfg(feature = "std")]
67mod priority_set;
68#[cfg(feature = "std")]
69pub use priority_set::PrioritySet;
70#[cfg(feature = "std")]
71pub mod futures;
72mod stable_buf;
73pub use stable_buf::StableBuf;
74#[cfg(feature = "std")]
75pub mod concurrency;
76
/// Type-erased, heap-allocated error that is `Send + Sync`, so it can cross thread boundaries.
pub type BoxedError = Box<dyn core::error::Error + Sync + Send>;
79
/// Converts bytes to a lowercase hexadecimal string (two digits per byte, no `0x` prefix).
///
/// The output round-trips through [`from_hex`].
pub fn hex(bytes: &[u8]) -> String {
    // Every byte expands to exactly two hex digits, so reserve the final size up front
    // instead of letting the string grow (and reallocate) as digits are appended.
    let mut hex = String::with_capacity(bytes.len() * 2);
    for byte in bytes {
        write!(hex, "{byte:02x}").expect("writing to string should never fail");
    }
    hex
}
88
/// Decodes a string made of hexadecimal digit pairs into bytes.
///
/// Digits may be lowercase or uppercase. Returns `None` if the input has an odd number of
/// bytes or contains any byte that is not an ASCII hex digit; no prefix or whitespace is
/// tolerated (see [`from_hex_formatted`] for that).
pub fn from_hex(hex: &str) -> Option<Vec<u8>> {
    let raw = hex.as_bytes();
    // Work on raw bytes rather than chars: any multi-byte UTF-8 sequence is simply rejected
    // as non-hex instead of being split at a character boundary.
    if !raw.len().is_multiple_of(2) {
        return None;
    }

    let mut decoded = Vec::with_capacity(raw.len() / 2);
    for pair in raw.chunks_exact(2) {
        let high = decode_hex_digit(pair[0])?;
        let low = decode_hex_digit(pair[1])?;
        decoded.push((high << 4) | low);
    }
    Some(decoded)
}

/// Maps one ASCII hex digit (either case) to its 4-bit value, or `None` for any other byte.
#[inline]
const fn decode_hex_digit(byte: u8) -> Option<u8> {
    let value = match byte {
        b'0'..=b'9' => byte - b'0',
        b'a'..=b'f' => byte - b'a' + 10,
        b'A'..=b'F' => byte - b'A' + 10,
        _ => return None,
    };
    Some(value)
}

/// Like [`from_hex`], but first drops tab, newline, carriage-return and space characters and
/// then an optional leading `0x`. Handy for pasting external test vectors verbatim.
pub fn from_hex_formatted(hex: &str) -> Option<Vec<u8>> {
    const IGNORED: [char; 4] = ['\t', '\n', '\r', ' '];
    let compact: String = hex.chars().filter(|c| !IGNORED.contains(c)).collect();
    let digits = compact.strip_prefix("0x").unwrap_or(compact.as_str());
    from_hex(digits)
}
123
/// Compute the maximum number of `f` (faults) that can be tolerated for a given set of `n`
/// participants. This is the maximum integer `f` such that `n >= 3*f + 1`. `f` may be zero
/// (and is zero for `n == 0`).
pub const fn max_faults(n: u32) -> u32 {
    n.saturating_sub(1) / 3
}

/// Compute the quorum size for a given set of `n` participants. This is the minimum integer `q`
/// such that `3*q >= 2*n + 1`. It is also equal to `n - f`, where `f` is the maximum number of
/// faults.
///
/// Like [`max_faults`], this is a `const fn` and can be evaluated at compile time.
///
/// # Panics
///
/// Panics if `n` is zero (a compile-time error when evaluated in a `const` context).
pub const fn quorum(n: u32) -> u32 {
    assert!(n > 0, "n must not be zero");
    // `max_faults(n) <= (n - 1) / 3 < n`, so this subtraction cannot underflow.
    n - max_faults(n)
}

/// Compute the quorum size for a set of participants given as a slice (`n = slice.len()`).
///
/// # Panics
///
/// Panics if the slice is empty (see [`quorum`]) or if its length exceeds [u32::MAX].
pub fn quorum_from_slice<T>(slice: &[T]) -> u32 {
    let n: u32 = slice
        .len()
        .try_into()
        .expect("slice length must not exceed u32::MAX");
    quorum(n)
}
154
/// Returns a new vector holding the bytes of `a` followed by the bytes of `b`.
///
/// Despite the name, this is plain concatenation rather than a set union: order is preserved
/// and duplicates are kept.
pub fn union(a: &[u8], b: &[u8]) -> Vec<u8> {
    [a, b].concat()
}
162
163/// Concatenate a namespace and a message, prepended by a varint encoding of the namespace length.
164///
165/// This produces a unique byte sequence (i.e. no collisions) for each `(namespace, msg)` pair.
166pub fn union_unique(namespace: &[u8], msg: &[u8]) -> Vec<u8> {
167    let len_prefix = namespace.len();
168    let mut buf = BytesMut::with_capacity(len_prefix.encode_size() + namespace.len() + msg.len());
169    len_prefix.write(&mut buf);
170    BufMut::put_slice(&mut buf, namespace);
171    BufMut::put_slice(&mut buf, msg);
172    buf.into()
173}
174
/// Reduces `bytes`, read as a single big-endian unsigned integer, modulo `n`.
///
/// Intended for turning a random seed into an index (e.g. choosing an entry of an array of
/// length `n`). An empty slice is treated as the value zero.
///
/// # Panics
///
/// Panics if `n` is zero.
pub fn modulo(bytes: &[u8], n: u64) -> u64 {
    assert_ne!(n, 0, "modulus must be non-zero");

    let modulus = u128::from(n);
    // Horner's rule one byte at a time: the running remainder stays below `n < 2^64`, so
    // shifting it left by 8 bits and or-ing in a byte always fits in a `u128`.
    let remainder = bytes
        .iter()
        .fold(0u128, |acc, &byte| ((acc << 8) | u128::from(byte)) % modulus);

    // The remainder is strictly less than `n`, which itself fits in a `u64`.
    remainder as u64
}
195
/// A macro to create a `NonZeroUsize` from a value, panicking if the value is zero.
/// For literal values, validation occurs at compile time (a zero literal fails to compile).
/// For expressions, validation occurs at runtime.
///
/// The expansion uses the absolute path `::core::...`, so it keeps working even when the
/// calling scope defines its own item named `core`.
#[macro_export]
macro_rules! NZUsize {
    ($val:literal) => {
        const { ::core::num::NonZeroUsize::new($val).expect("value must be non-zero") }
    };
    ($val:expr) => {
        // This will panic at runtime if $val is zero.
        ::core::num::NonZeroUsize::new($val).expect("value must be non-zero")
    };
}

/// A macro to create a `NonZeroU8` from a value, panicking if the value is zero.
/// For literal values, validation occurs at compile time (a zero literal fails to compile).
/// For expressions, validation occurs at runtime.
///
/// The expansion uses the absolute path `::core::...`, so it keeps working even when the
/// calling scope defines its own item named `core`.
#[macro_export]
macro_rules! NZU8 {
    ($val:literal) => {
        const { ::core::num::NonZeroU8::new($val).expect("value must be non-zero") }
    };
    ($val:expr) => {
        // This will panic at runtime if $val is zero.
        ::core::num::NonZeroU8::new($val).expect("value must be non-zero")
    };
}

/// A macro to create a `NonZeroU16` from a value, panicking if the value is zero.
/// For literal values, validation occurs at compile time (a zero literal fails to compile).
/// For expressions, validation occurs at runtime.
///
/// The expansion uses the absolute path `::core::...`, so it keeps working even when the
/// calling scope defines its own item named `core`.
#[macro_export]
macro_rules! NZU16 {
    ($val:literal) => {
        const { ::core::num::NonZeroU16::new($val).expect("value must be non-zero") }
    };
    ($val:expr) => {
        // This will panic at runtime if $val is zero.
        ::core::num::NonZeroU16::new($val).expect("value must be non-zero")
    };
}

/// A macro to create a `NonZeroU32` from a value, panicking if the value is zero.
/// For literal values, validation occurs at compile time (a zero literal fails to compile).
/// For expressions, validation occurs at runtime.
///
/// The expansion uses the absolute path `::core::...`, so it keeps working even when the
/// calling scope defines its own item named `core`.
#[macro_export]
macro_rules! NZU32 {
    ($val:literal) => {
        const { ::core::num::NonZeroU32::new($val).expect("value must be non-zero") }
    };
    ($val:expr) => {
        // This will panic at runtime if $val is zero.
        ::core::num::NonZeroU32::new($val).expect("value must be non-zero")
    };
}

/// A macro to create a `NonZeroU64` from a value, panicking if the value is zero.
/// For literal values, validation occurs at compile time (a zero literal fails to compile).
/// For expressions, validation occurs at runtime.
///
/// The expansion uses the absolute path `::core::...`, so it keeps working even when the
/// calling scope defines its own item named `core`.
#[macro_export]
macro_rules! NZU64 {
    ($val:literal) => {
        const { ::core::num::NonZeroU64::new($val).expect("value must be non-zero") }
    };
    ($val:expr) => {
        // This will panic at runtime if $val is zero.
        ::core::num::NonZeroU64::new($val).expect("value must be non-zero")
    };
}
265
/// A wrapper around `Duration` that guarantees the duration is non-zero.
///
/// The constructors are `const fn`, so values can be built in `const` and `static` items.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NonZeroDuration(Duration);

impl NonZeroDuration {
    /// Creates a `NonZeroDuration` if the given duration is non-zero; returns `None` for
    /// [`Duration::ZERO`].
    pub const fn new(duration: Duration) -> Option<Self> {
        // `Duration::is_zero` is callable in a `const fn`, unlike the `==` operator.
        if duration.is_zero() {
            None
        } else {
            Some(Self(duration))
        }
    }

    /// Creates a `NonZeroDuration` from the given duration.
    ///
    /// # Panics
    ///
    /// Panics if `duration` is zero (a compile-time error when evaluated in a `const`
    /// context). The reported panic location is the caller's.
    #[track_caller]
    pub const fn new_panic(duration: Duration) -> Self {
        Self::new(duration).expect("duration must be non-zero")
    }

    /// Returns the wrapped `Duration`.
    pub const fn get(self) -> Duration {
        self.0
    }
}

impl From<NonZeroDuration> for Duration {
    fn from(nz_duration: NonZeroDuration) -> Self {
        nz_duration.0
    }
}
296
/// Builds a [`NonZeroDuration`] from a `Duration` expression, panicking if it is zero.
///
/// `Duration` has no literal syntax, so there is no separate literal arm: the value is always
/// checked by `NonZeroDuration::new_panic` when the expression is evaluated.
#[macro_export]
macro_rules! NZDuration {
    ($duration:expr) => {
        // Panics (rather than returning `None`) when `$duration` is zero.
        $crate::NonZeroDuration::new_panic($duration)
    };
}
305
#[cfg(test)]
mod tests {
    use super::*;
    use num_bigint::BigUint;
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use rstest::rstest;

    #[test]
    fn test_hex() {
        // Test case 0: empty bytes
        let b = &[];
        let h = hex(b);
        assert_eq!(h, "");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 1: single byte
        let b = &hex!("0x01");
        let h = hex(b);
        assert_eq!(h, "01");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 2: multiple bytes
        let b = &hex!("0x010203");
        let h = hex(b);
        assert_eq!(h, "010203");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 3: odd number of bytes
        let h = "0102030";
        assert!(from_hex(h).is_none());

        // Test case 4: invalid hexadecimal character
        let h = "01g3";
        assert!(from_hex(h).is_none());

        // Test case 5: invalid `+` in string
        let h = "+123";
        assert!(from_hex(h).is_none());

        // Test case 6: empty string
        assert_eq!(from_hex(""), Some(vec![]));

        // Test case 7: uppercase and mixed-case digits decode the same as lowercase
        assert_eq!(from_hex("ABCDEF"), Some(vec![0xab, 0xcd, 0xef]));
        assert_eq!(from_hex("aBcD"), Some(vec![0xab, 0xcd]));

        // Test case 8: every byte value round-trips and is encoded in lowercase
        let all: Vec<u8> = (0..=u8::MAX).collect();
        let h = hex(&all);
        assert_eq!(h.len(), 2 * all.len());
        assert_eq!(h, h.to_lowercase());
        assert_eq!(from_hex(&h).unwrap(), all);
    }

    #[test]
    fn test_from_hex_formatted() {
        // Test case 0: empty bytes
        let b = &[];
        let h = hex(b);
        assert_eq!(h, "");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 1: single byte
        let b = &hex!("0x01");
        let h = hex(b);
        assert_eq!(h, "01");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 2: multiple bytes
        let b = &hex!("0x010203");
        let h = hex(b);
        assert_eq!(h, "010203");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 3: odd number of bytes
        let h = "0102030";
        assert!(from_hex_formatted(h).is_none());

        // Test case 4: invalid hexadecimal character
        let h = "01g3";
        assert!(from_hex_formatted(h).is_none());

        // Test case 5: whitespace
        let h = "01 02 03";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());

        // Test case 6: 0x prefix
        let h = "0x010203";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());

        // Test case 7: 0x prefix + different whitespace chars
        let h = "    \n\n0x\r\n01
                            02\t03\n";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());
    }

    #[test]
    fn test_from_hex_utf8_char_boundaries() {
        const MISALIGNMENT_CASE: &str = "쀘\n";

        // Ensure that `from_hex` can handle misaligned UTF-8 character boundaries.
        let b = from_hex(MISALIGNMENT_CASE);
        assert!(b.is_none());
    }

    #[test]
    fn test_max_faults_zero() {
        assert_eq!(max_faults(0), 0);
    }

    // Pin the panic message so an unrelated panic cannot make this test pass.
    #[test]
    #[should_panic(expected = "n must not be zero")]
    fn test_quorum_zero() {
        quorum(0);
    }

    #[rstest]
    #[case(1, 0, 1)]
    #[case(2, 0, 2)]
    #[case(3, 0, 3)]
    #[case(4, 1, 3)]
    #[case(5, 1, 4)]
    #[case(6, 1, 5)]
    #[case(7, 2, 5)]
    #[case(8, 2, 6)]
    #[case(9, 2, 7)]
    #[case(10, 3, 7)]
    #[case(11, 3, 8)]
    #[case(12, 3, 9)]
    #[case(13, 4, 9)]
    #[case(14, 4, 10)]
    #[case(15, 4, 11)]
    #[case(16, 5, 11)]
    #[case(17, 5, 12)]
    #[case(18, 5, 13)]
    #[case(19, 6, 13)]
    #[case(20, 6, 14)]
    #[case(21, 6, 15)]
    fn test_quorum_and_max_faults(
        #[case] n: u32,
        #[case] expected_f: u32,
        #[case] expected_q: u32,
    ) {
        assert_eq!(max_faults(n), expected_f);
        assert_eq!(quorum(n), expected_q);
        assert_eq!(n, expected_f + expected_q);
    }

    #[test]
    fn test_quorum_from_slice() {
        assert_eq!(quorum_from_slice(&[0u8; 1]), 1);
        assert_eq!(quorum_from_slice(&[0u8; 4]), 3);
        assert_eq!(quorum_from_slice(&["x"; 7]), 5);
    }

    // An empty slice means `n == 0`, which `quorum` rejects.
    #[test]
    #[should_panic(expected = "n must not be zero")]
    fn test_quorum_from_slice_empty() {
        quorum_from_slice::<u8>(&[]);
    }

    #[test]
    fn test_union() {
        // Test case 0: empty slices
        assert_eq!(union(&[], &[]), []);

        // Test case 1: empty and non-empty slices
        assert_eq!(union(&[], &hex!("0x010203")), hex!("0x010203"));

        // Test case 2: non-empty and non-empty slices
        assert_eq!(
            union(&hex!("0x010203"), &hex!("0x040506")),
            hex!("0x010203040506")
        );
    }

    #[test]
    fn test_union_unique() {
        let namespace = b"namespace";
        let msg = b"message";

        let length_encoding = vec![0b0000_1001];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(namespace);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_union_unique_zero_length() {
        let namespace = b"";
        let msg = b"message";

        let length_encoding = vec![0];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_union_unique_long_length() {
        // Use a namespace of over length 127.
        let namespace = &b"n".repeat(256);
        let msg = b"message";

        let length_encoding = vec![0b1000_0000, 0b0000_0010];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(namespace);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_modulo() {
        // Test case 0: empty bytes
        assert_eq!(modulo(&[], 1), 0);

        // Test case 1: single byte
        assert_eq!(modulo(&hex!("0x01"), 1), 0);

        // Test case 2: multiple bytes
        assert_eq!(modulo(&hex!("0x010203"), 10), 1);

        // Test case 3: check equivalence with BigUint
        for i in 0..100 {
            let mut rng = StdRng::seed_from_u64(i);
            let bytes: [u8; 32] = rng.gen();

            // 1-byte modulus
            let n = 11u64;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // 2-byte modulus
            let n = 11_111u64;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // 8-byte modulus
            let n = 0xDFFFFFFFFFFFFFFD;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));
        }
    }

    // Pin the panic message so an unrelated panic cannot make this test pass.
    #[test]
    #[should_panic(expected = "modulus must be non-zero")]
    fn test_modulo_zero_panics() {
        modulo(&hex!("0x010203"), 0);
    }

    #[test]
    fn test_non_zero_macros_compile_time() {
        // Literal values are validated at compile time.
        // NZU32!(0) would be a compile error.
        assert_eq!(NZUsize!(1).get(), 1);
        assert_eq!(NZU8!(2).get(), 2);
        assert_eq!(NZU16!(3).get(), 3);
        assert_eq!(NZU32!(4).get(), 4);
        assert_eq!(NZU64!(5).get(), 5);

        // Literals can be used in const contexts
        const _: core::num::NonZeroUsize = NZUsize!(1);
        const _: core::num::NonZeroU8 = NZU8!(2);
        const _: core::num::NonZeroU16 = NZU16!(3);
        const _: core::num::NonZeroU32 = NZU32!(4);
        const _: core::num::NonZeroU64 = NZU64!(5);
    }

    #[test]
    fn test_non_zero_macros_runtime() {
        // Runtime variables are validated at runtime
        let one_usize: usize = 1;
        let two_u8: u8 = 2;
        let three_u16: u16 = 3;
        let four_u32: u32 = 4;
        let five_u64: u64 = 5;

        assert_eq!(NZUsize!(one_usize).get(), 1);
        assert_eq!(NZU8!(two_u8).get(), 2);
        assert_eq!(NZU16!(three_u16).get(), 3);
        assert_eq!(NZU32!(four_u32).get(), 4);
        assert_eq!(NZU64!(five_u64).get(), 5);

        // Zero runtime values panic
        let zero_usize: usize = 0;
        let zero_u8: u8 = 0;
        let zero_u16: u16 = 0;
        let zero_u32: u32 = 0;
        let zero_u64: u64 = 0;

        assert!(std::panic::catch_unwind(|| NZUsize!(zero_usize)).is_err());
        assert!(std::panic::catch_unwind(|| NZU8!(zero_u8)).is_err());
        assert!(std::panic::catch_unwind(|| NZU16!(zero_u16)).is_err());
        assert!(std::panic::catch_unwind(|| NZU32!(zero_u32)).is_err());
        assert!(std::panic::catch_unwind(|| NZU64!(zero_u64)).is_err());

        // NZDuration is runtime-only since Duration has no literal syntax
        assert!(std::panic::catch_unwind(|| NZDuration!(Duration::ZERO)).is_err());
        assert_eq!(
            NZDuration!(Duration::from_secs(1)).get(),
            Duration::from_secs(1)
        );
    }

    #[test]
    fn test_non_zero_duration() {
        // Test case 0: zero duration
        assert!(NonZeroDuration::new(Duration::ZERO).is_none());

        // Test case 1: non-zero duration
        let duration = Duration::from_millis(100);
        let nz_duration = NonZeroDuration::new(duration).unwrap();
        assert_eq!(nz_duration.get(), duration);
        assert_eq!(Duration::from(nz_duration), duration);

        // Test case 2: panic on zero
        assert!(std::panic::catch_unwind(|| NonZeroDuration::new_panic(Duration::ZERO)).is_err());

        // Test case 3: ordering
        let d1 = NonZeroDuration::new(Duration::from_millis(100)).unwrap();
        let d2 = NonZeroDuration::new(Duration::from_millis(200)).unwrap();
        assert!(d1 < d2);
    }
}