commonware_utils/
lib.rs

1//! Leverage common functionality across multiple primitives.
2
3#![doc(
4    html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
5    html_favicon_url = "https://commonware.xyz/favicon.ico"
6)]
7#![cfg_attr(not(feature = "std"), no_std)]
8
9#[cfg(not(feature = "std"))]
10extern crate alloc;
11
12#[cfg(not(feature = "std"))]
13use alloc::{boxed::Box, string::String, vec::Vec};
14use bytes::{Buf, BufMut, BytesMut};
15use commonware_codec::{varint::UInt, EncodeSize, Error as CodecError, Read, ReadExt, Write};
16use core::{
17    fmt::{Debug, Write as FmtWrite},
18    time::Duration,
19};
20
21pub mod faults;
22pub use faults::{Faults, N3f1, N5f1};
23pub mod sequence;
24pub use sequence::{Array, Span};
25#[cfg(feature = "std")]
26pub mod acknowledgement;
27#[cfg(feature = "std")]
28pub use acknowledgement::Acknowledgement;
29pub mod bitmap;
30#[cfg(feature = "std")]
31pub mod channels;
32pub mod hex_literal;
33pub mod hostname;
34pub use hostname::Hostname;
35#[cfg(feature = "std")]
36pub mod net;
37pub mod ordered;
38pub mod vec;
39
/// Index of a participant (validator) within a consensus committee.
///
/// The index is the position of the validator's public key in the ordered
/// participant set; attestations, votes, and certificates use it to refer to
/// a specific validator.
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct Participant(u32);
48
49impl Participant {
50    /// Creates a new participant from a u32 index.
51    pub const fn new(index: u32) -> Self {
52        Self(index)
53    }
54
55    /// Creates a new participant from a usize index.
56    ///
57    /// # Panics
58    ///
59    /// Panics if `index` exceeds `u32::MAX`.
60    pub fn from_usize(index: usize) -> Self {
61        Self(u32::try_from(index).expect("participant index exceeds u32::MAX"))
62    }
63
64    /// Returns the underlying u32 index.
65    pub const fn get(self) -> u32 {
66        self.0
67    }
68}
69
70impl From<Participant> for usize {
71    fn from(p: Participant) -> Self {
72        p.0 as Self
73    }
74}
75
76impl core::fmt::Display for Participant {
77    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
78        write!(f, "{}", self.0)
79    }
80}
81
82impl Read for Participant {
83    type Cfg = ();
84
85    fn read_cfg(buf: &mut impl Buf, _cfg: &Self::Cfg) -> Result<Self, CodecError> {
86        let value: u32 = UInt::read(buf)?.into();
87        Ok(Self(value))
88    }
89}
90
91impl Write for Participant {
92    fn write(&self, buf: &mut impl BufMut) {
93        UInt(self.0).write(buf);
94    }
95}
96
97impl EncodeSize for Participant {
98    fn encode_size(&self) -> usize {
99        UInt(self.0).encode_size()
100    }
101}
102
/// Fallible counterpart to [`FromIterator`]: builds `Self` from a sequence of `T`,
/// returning an error when construction fails.
pub trait TryFromIterator<T>: Sized {
    /// Error produced when the items cannot form a valid `Self`.
    type Error;

    /// Consumes `iter` and builds `Self`, or returns [`Self::Error`].
    fn try_from_iter<I>(iter: I) -> Result<Self, Self::Error>
    where
        I: IntoIterator<Item = T>;
}
111
112/// Extension trait for iterators that provides fallible collection.
113pub trait TryCollect: Iterator + Sized {
114    /// Attempts to collect elements into a collection that may fail.
115    fn try_collect<C: TryFromIterator<Self::Item>>(self) -> Result<C, C::Error> {
116        C::try_from_iter(self)
117    }
118}
119
120impl<I: Iterator> TryCollect for I {}
121#[cfg(feature = "std")]
122pub use net::IpAddrExt;
123#[cfg(feature = "std")]
124pub mod time;
125#[cfg(feature = "std")]
126pub use time::{DurationExt, SystemTimeExt};
127#[cfg(feature = "std")]
128pub mod rational;
129#[cfg(feature = "std")]
130pub use rational::BigRationalExt;
131#[cfg(feature = "std")]
132mod priority_set;
133#[cfg(feature = "std")]
134pub use priority_set::PrioritySet;
135#[cfg(feature = "std")]
136pub mod futures;
137mod stable_buf;
138pub use stable_buf::StableBuf;
139#[cfg(feature = "std")]
140pub mod concurrency;
141
/// Returns a deterministic RNG for tests.
///
/// Always seeded with `0`, so every call yields the same random stream and test
/// results are reproducible.
#[cfg(feature = "std")]
pub fn test_rng() -> rand::rngs::StdRng {
    const DEFAULT_SEED: u64 = 0;
    <rand::rngs::StdRng as rand::SeedableRng>::seed_from_u64(DEFAULT_SEED)
}
149
/// Returns a deterministic RNG for tests, seeded with the caller-chosen `seed`.
///
/// Use this when a test needs several independent random streams, or when a helper
/// needs its own RNG that does not collide with the one its caller is using.
#[cfg(feature = "std")]
pub fn test_rng_seeded(seed: u64) -> rand::rngs::StdRng {
    <rand::rngs::StdRng as rand::SeedableRng>::seed_from_u64(seed)
}
158
/// Heap-allocated, type-erased error that can be sent and shared across threads.
pub type BoxedError = Box<dyn core::error::Error + Sync + Send>;
161
/// Converts bytes to a lower-case hexadecimal string (two digits per byte, no prefix).
///
/// An empty slice yields an empty string.
pub fn hex(bytes: &[u8]) -> String {
    // Every byte renders as exactly two characters, so the final length is known up front.
    // Reserving it avoids repeated reallocation while the string grows.
    let mut hex = String::with_capacity(bytes.len() * 2);
    for byte in bytes {
        write!(hex, "{byte:02x}").expect("writing to string should never fail");
    }
    hex
}
170
171/// Converts a hexadecimal string to bytes.
172pub fn from_hex(hex: &str) -> Option<Vec<u8>> {
173    let bytes = hex.as_bytes();
174    if !bytes.len().is_multiple_of(2) {
175        return None;
176    }
177
178    bytes
179        .chunks_exact(2)
180        .map(|chunk| {
181            let hi = decode_hex_digit(chunk[0])?;
182            let lo = decode_hex_digit(chunk[1])?;
183            Some((hi << 4) | lo)
184        })
185        .collect()
186}
187
/// Maps one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value in `0..=15`.
///
/// Returns `None` for every other byte, including non-ASCII bytes.
#[inline]
const fn decode_hex_digit(byte: u8) -> Option<u8> {
    if byte.is_ascii_digit() {
        return Some(byte - b'0');
    }
    // Setting bit 5 folds ASCII `A-F` onto `a-f`. The only bytes that land in
    // `a..=f` after folding are the six upper-case and six lower-case hex letters.
    let folded = byte | 0x20;
    if matches!(folded, b'a'..=b'f') {
        Some(folded - b'a' + 10)
    } else {
        None
    }
}
197
198/// Converts a hexadecimal string to bytes, stripping whitespace and/or a `0x` prefix. Commonly used
199/// in testing to encode external test vectors without modification.
200pub fn from_hex_formatted(hex: &str) -> Option<Vec<u8>> {
201    let hex = hex.replace(['\t', '\n', '\r', ' '], "");
202    let res = hex.strip_prefix("0x").unwrap_or(&hex);
203    from_hex(res)
204}
205
/// Concatenates two byte slices into a newly allocated vector: `a` followed by `b`.
///
/// Despite the name this is plain concatenation, not a set union: duplicates are kept
/// and order is preserved.
pub fn union(a: &[u8], b: &[u8]) -> Vec<u8> {
    [a, b].concat()
}
213
214/// Concatenate a namespace and a message, prepended by a varint encoding of the namespace length.
215///
216/// This produces a unique byte sequence (i.e. no collisions) for each `(namespace, msg)` pair.
217pub fn union_unique(namespace: &[u8], msg: &[u8]) -> Vec<u8> {
218    let len_prefix = namespace.len();
219    let mut buf = BytesMut::with_capacity(len_prefix.encode_size() + namespace.len() + msg.len());
220    len_prefix.write(&mut buf);
221    BufMut::put_slice(&mut buf, namespace);
222    BufMut::put_slice(&mut buf, msg);
223    buf.into()
224}
225
/// Reduces `bytes`, read as a big-endian unsigned integer, modulo `n`.
///
/// Used to select an entry from an array when `bytes` is a random seed. The result is only
/// close to uniform when `bytes` spans many more bits than `n`; short inputs carry a
/// noticeable modulo bias. An empty slice is treated as zero.
///
/// # Panics
///
/// Panics if `n` is zero.
pub fn modulo(bytes: &[u8], n: u64) -> u64 {
    assert_ne!(n, 0, "modulus must be non-zero");

    let modulus = u128::from(n);
    // Horner's rule: shift in one byte at a time and reduce immediately. The accumulator
    // stays below `modulus` (< 2^64), so shifting it left by 8 bits always fits in a u128.
    let remainder = bytes
        .iter()
        .fold(0u128, |acc, &byte| ((acc << 8) | u128::from(byte)) % modulus);

    // `remainder < modulus <= u64::MAX`, so narrowing cannot truncate.
    remainder as u64
}
246
/// Builds a `NonZeroUsize`, panicking if the value is zero.
///
/// A literal argument is checked inside an inline `const` block, so `NZUsize!(0)` is
/// rejected at compile time. Any other expression is checked when it is evaluated at runtime.
#[macro_export]
macro_rules! NZUsize {
    ($val:literal) => {
        const {
            match ::core::num::NonZeroUsize::new($val) {
                ::core::option::Option::Some(value) => value,
                ::core::option::Option::None => ::core::panic!("value must be non-zero"),
            }
        }
    };
    ($val:expr) => {
        match ::core::num::NonZeroUsize::new($val) {
            ::core::option::Option::Some(value) => value,
            ::core::option::Option::None => ::core::panic!("value must be non-zero"),
        }
    };
}
260
/// Builds a `NonZeroU8`, panicking if the value is zero.
///
/// A literal argument is checked inside an inline `const` block, so `NZU8!(0)` is
/// rejected at compile time. Any other expression is checked when it is evaluated at runtime.
#[macro_export]
macro_rules! NZU8 {
    ($val:literal) => {
        const {
            match ::core::num::NonZeroU8::new($val) {
                ::core::option::Option::Some(value) => value,
                ::core::option::Option::None => ::core::panic!("value must be non-zero"),
            }
        }
    };
    ($val:expr) => {
        match ::core::num::NonZeroU8::new($val) {
            ::core::option::Option::Some(value) => value,
            ::core::option::Option::None => ::core::panic!("value must be non-zero"),
        }
    };
}
274
/// Builds a `NonZeroU16`, panicking if the value is zero.
///
/// A literal argument is checked inside an inline `const` block, so `NZU16!(0)` is
/// rejected at compile time. Any other expression is checked when it is evaluated at runtime.
#[macro_export]
macro_rules! NZU16 {
    ($val:literal) => {
        const {
            match ::core::num::NonZeroU16::new($val) {
                ::core::option::Option::Some(value) => value,
                ::core::option::Option::None => ::core::panic!("value must be non-zero"),
            }
        }
    };
    ($val:expr) => {
        match ::core::num::NonZeroU16::new($val) {
            ::core::option::Option::Some(value) => value,
            ::core::option::Option::None => ::core::panic!("value must be non-zero"),
        }
    };
}
288
/// Builds a `NonZeroU32`, panicking if the value is zero.
///
/// A literal argument is checked inside an inline `const` block, so `NZU32!(0)` is
/// rejected at compile time. Any other expression is checked when it is evaluated at runtime.
#[macro_export]
macro_rules! NZU32 {
    ($val:literal) => {
        const {
            match ::core::num::NonZeroU32::new($val) {
                ::core::option::Option::Some(value) => value,
                ::core::option::Option::None => ::core::panic!("value must be non-zero"),
            }
        }
    };
    ($val:expr) => {
        match ::core::num::NonZeroU32::new($val) {
            ::core::option::Option::Some(value) => value,
            ::core::option::Option::None => ::core::panic!("value must be non-zero"),
        }
    };
}
302
/// Builds a `NonZeroU64`, panicking if the value is zero.
///
/// A literal argument is checked inside an inline `const` block, so `NZU64!(0)` is
/// rejected at compile time. Any other expression is checked when it is evaluated at runtime.
#[macro_export]
macro_rules! NZU64 {
    ($val:literal) => {
        const {
            match ::core::num::NonZeroU64::new($val) {
                ::core::option::Option::Some(value) => value,
                ::core::option::Option::None => ::core::panic!("value must be non-zero"),
            }
        }
    };
    ($val:expr) => {
        match ::core::num::NonZeroU64::new($val) {
            ::core::option::Option::Some(value) => value,
            ::core::option::Option::None => ::core::panic!("value must be non-zero"),
        }
    };
}
316
/// A wrapper around `Duration` that guarantees the duration is non-zero.
///
/// The inner field is private, so outside this module a value can only be built through
/// [`NonZeroDuration::new`] (which rejects `Duration::ZERO`) or
/// [`NonZeroDuration::new_panic`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NonZeroDuration(Duration);

impl NonZeroDuration {
    /// Creates a `NonZeroDuration` if the given duration is non-zero.
    ///
    /// Returns `None` when `duration` is zero. This is a `const fn`, so it can be used to
    /// define constants, like the integer `NZ*` macros.
    pub const fn new(duration: Duration) -> Option<Self> {
        if duration.is_zero() {
            None
        } else {
            Some(Self(duration))
        }
    }

    /// Creates a `NonZeroDuration` from the given duration, panicking if it's zero.
    ///
    /// # Panics
    ///
    /// Panics if `duration` is zero. When evaluated in a const context, this becomes a
    /// compile-time error instead. `#[track_caller]` makes the panic report the caller's
    /// location rather than this function's.
    #[track_caller]
    pub const fn new_panic(duration: Duration) -> Self {
        Self::new(duration).expect("duration must be non-zero")
    }

    /// Returns the wrapped `Duration`.
    pub const fn get(self) -> Duration {
        self.0
    }
}
341
342impl From<NonZeroDuration> for Duration {
343    fn from(nz_duration: NonZeroDuration) -> Self {
344        nz_duration.0
345    }
346}
347
/// Builds a [`NonZeroDuration`] from a `Duration` expression, panicking if it is zero.
///
/// Unlike the integer `NZ*` macros there is no literal arm, because `Duration` has no
/// literal syntax. The check happens wherever the expression is evaluated.
#[macro_export]
macro_rules! NZDuration {
    ($duration:expr) => {
        <$crate::NonZeroDuration>::new_panic($duration)
    };
}
356
#[cfg(test)]
mod tests {
    use super::*;
    use num_bigint::BigUint;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    #[test]
    fn test_hex() {
        // Test case 0: empty bytes
        let b = &[];
        let h = hex(b);
        assert_eq!(h, "");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 1: single byte
        let b = &hex!("0x01");
        let h = hex(b);
        assert_eq!(h, "01");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 2: multiple bytes
        let b = &hex!("0x010203");
        let h = hex(b);
        assert_eq!(h, "010203");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 3: odd number of hex digits
        let h = "0102030";
        assert!(from_hex(h).is_none());

        // Test case 4: invalid hexadecimal character
        let h = "01g3";
        assert!(from_hex(h).is_none());

        // Test case 5: invalid `+` in string
        let h = "+123";
        assert!(from_hex(h).is_none());

        // Test case 6: empty string
        assert_eq!(from_hex(""), Some(vec![]));
    }

    #[test]
    fn test_from_hex_formatted() {
        // Test case 0: empty bytes
        let b = &[];
        let h = hex(b);
        assert_eq!(h, "");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 1: single byte
        let b = &hex!("0x01");
        let h = hex(b);
        assert_eq!(h, "01");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 2: multiple bytes
        let b = &hex!("0x010203");
        let h = hex(b);
        assert_eq!(h, "010203");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 3: odd number of hex digits
        let h = "0102030";
        assert!(from_hex_formatted(h).is_none());

        // Test case 4: invalid hexadecimal character
        let h = "01g3";
        assert!(from_hex_formatted(h).is_none());

        // Test case 5: whitespace
        let h = "01 02 03";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());

        // Test case 6: 0x prefix
        let h = "0x010203";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());

        // Test case 7: 0x prefix + different whitespace chars
        let h = "    \n\n0x\r\n01
                            02\t03\n";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());
    }

    #[test]
    fn test_from_hex_utf8_char_boundaries() {
        const MISALIGNMENT_CASE: &str = "쀘\n";

        // Ensure that `from_hex` can handle misaligned UTF-8 character boundaries.
        let b = from_hex(MISALIGNMENT_CASE);
        assert!(b.is_none());
    }

    #[test]
    fn test_from_hex_case_insensitive() {
        // Upper-case and mixed-case digits decode the same as lower-case ones.
        assert_eq!(from_hex("ABCDEF"), Some(vec![0xab, 0xcd, 0xef]));
        assert_eq!(from_hex("aBcDeF"), Some(vec![0xab, 0xcd, 0xef]));

        // `hex` always emits lower-case, so a round trip normalizes the case.
        assert_eq!(hex(&from_hex("DEADBEEF").unwrap()), "deadbeef");
    }

    #[test]
    fn test_union() {
        // Test case 0: empty slices
        assert_eq!(union(&[], &[]), []);

        // Test case 1: empty and non-empty slices
        assert_eq!(union(&[], &hex!("0x010203")), hex!("0x010203"));

        // Test case 2: non-empty and non-empty slices
        assert_eq!(
            union(&hex!("0x010203"), &hex!("0x040506")),
            hex!("0x010203040506")
        );
    }

    #[test]
    fn test_union_unique() {
        let namespace = b"namespace";
        let msg = b"message";

        let length_encoding = vec![0b0000_1001];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(namespace);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_union_unique_zero_length() {
        let namespace = b"";
        let msg = b"message";

        let length_encoding = vec![0];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_union_unique_long_length() {
        // Use a namespace longer than 127 bytes so the length prefix needs two varint bytes.
        let namespace = &b"n".repeat(256);
        let msg = b"message";

        let length_encoding = vec![0b1000_0000, 0b0000_0010];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(namespace);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_modulo() {
        // Test case 0: empty bytes
        assert_eq!(modulo(&[], 1), 0);

        // Test case 1: single byte
        assert_eq!(modulo(&hex!("0x01"), 1), 0);

        // Test case 2: multiple bytes
        assert_eq!(modulo(&hex!("0x010203"), 10), 1);

        // Test case 3: check equivalence with BigUint
        for i in 0..100 {
            let mut rng = StdRng::seed_from_u64(i);
            let bytes: [u8; 32] = rng.gen();

            // 1-byte modulus
            let n = 11u64;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // 2-byte modulus
            let n = 11_111u64;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // 8-byte modulus
            let n = 0xDFFFFFFFFFFFFFFD;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));
        }
    }

    #[test]
    #[should_panic]
    fn test_modulo_zero_panics() {
        modulo(&hex!("0x010203"), 0);
    }

    #[test]
    fn test_non_zero_macros_compile_time() {
        // Literal values are validated at compile time.
        // NZU32!(0) would be a compile error.
        assert_eq!(NZUsize!(1).get(), 1);
        assert_eq!(NZU8!(2).get(), 2);
        assert_eq!(NZU16!(3).get(), 3);
        assert_eq!(NZU32!(4).get(), 4);
        assert_eq!(NZU64!(5).get(), 5);

        // Literals can be used in const contexts
        const _: core::num::NonZeroUsize = NZUsize!(1);
        const _: core::num::NonZeroU8 = NZU8!(2);
        const _: core::num::NonZeroU16 = NZU16!(3);
        const _: core::num::NonZeroU32 = NZU32!(4);
        const _: core::num::NonZeroU64 = NZU64!(5);
    }

    #[test]
    fn test_non_zero_macros_runtime() {
        // Runtime variables are validated at runtime
        let one_usize: usize = 1;
        let two_u8: u8 = 2;
        let three_u16: u16 = 3;
        let four_u32: u32 = 4;
        let five_u64: u64 = 5;

        assert_eq!(NZUsize!(one_usize).get(), 1);
        assert_eq!(NZU8!(two_u8).get(), 2);
        assert_eq!(NZU16!(three_u16).get(), 3);
        assert_eq!(NZU32!(four_u32).get(), 4);
        assert_eq!(NZU64!(five_u64).get(), 5);

        // Zero runtime values panic
        let zero_usize: usize = 0;
        let zero_u8: u8 = 0;
        let zero_u16: u16 = 0;
        let zero_u32: u32 = 0;
        let zero_u64: u64 = 0;

        assert!(std::panic::catch_unwind(|| NZUsize!(zero_usize)).is_err());
        assert!(std::panic::catch_unwind(|| NZU8!(zero_u8)).is_err());
        assert!(std::panic::catch_unwind(|| NZU16!(zero_u16)).is_err());
        assert!(std::panic::catch_unwind(|| NZU32!(zero_u32)).is_err());
        assert!(std::panic::catch_unwind(|| NZU64!(zero_u64)).is_err());

        // NZDuration has no literal arm since Duration has no literal syntax
        assert!(std::panic::catch_unwind(|| NZDuration!(Duration::ZERO)).is_err());
        assert_eq!(
            NZDuration!(Duration::from_secs(1)).get(),
            Duration::from_secs(1)
        );
    }

    #[test]
    fn test_non_zero_duration() {
        // Test case 0: zero duration
        assert!(NonZeroDuration::new(Duration::ZERO).is_none());

        // Test case 1: non-zero duration
        let duration = Duration::from_millis(100);
        let nz_duration = NonZeroDuration::new(duration).unwrap();
        assert_eq!(nz_duration.get(), duration);
        assert_eq!(Duration::from(nz_duration), duration);

        // Test case 2: panic on zero
        assert!(std::panic::catch_unwind(|| NonZeroDuration::new_panic(Duration::ZERO)).is_err());

        // Test case 3: ordering
        let d1 = NonZeroDuration::new(Duration::from_millis(100)).unwrap();
        let d2 = NonZeroDuration::new(Duration::from_millis(200)).unwrap();
        assert!(d1 < d2);
    }

    #[test]
    fn test_participant_constructors() {
        assert_eq!(Participant::new(0).get(), 0);
        assert_eq!(Participant::new(42).get(), 42);
        assert_eq!(Participant::from_usize(0).get(), 0);
        assert_eq!(Participant::from_usize(42).get(), 42);
        assert_eq!(Participant::from_usize(u32::MAX as usize).get(), u32::MAX);
    }

    #[test]
    #[should_panic(expected = "participant index exceeds u32::MAX")]
    fn test_participant_from_usize_overflow() {
        Participant::from_usize((u32::MAX as usize) + 1);
    }

    #[test]
    fn test_participant_into_usize() {
        assert_eq!(usize::from(Participant::new(0)), 0);
        assert_eq!(usize::from(Participant::new(42)), 42);
        assert_eq!(usize::from(Participant::new(u32::MAX)), u32::MAX as usize);
    }

    #[test]
    fn test_participant_display() {
        assert_eq!(format!("{}", Participant::new(0)), "0");
        assert_eq!(format!("{}", Participant::new(42)), "42");
        assert_eq!(format!("{}", Participant::new(1000)), "1000");
    }

    #[test]
    fn test_participant_ordering() {
        assert!(Participant::new(0) < Participant::new(1));
        assert!(Participant::new(5) < Participant::new(10));
        assert!(Participant::new(10) > Participant::new(5));
        assert_eq!(Participant::new(42), Participant::new(42));
    }

    #[test]
    fn test_participant_encode_decode() {
        use commonware_codec::{DecodeExt, Encode};

        let cases = vec![0u32, 1, 127, 128, 255, 256, u32::MAX];
        for value in cases {
            let participant = Participant::new(value);
            let encoded = participant.encode();
            assert_eq!(encoded.len(), participant.encode_size());
            let decoded = Participant::decode(encoded).unwrap();
            assert_eq!(participant, decoded);
        }
    }

    #[test]
    fn test_try_collect() {
        // Collects into a fixed-size pair; the error carries the actual number of items.
        struct Pair([u8; 2]);

        impl TryFromIterator<u8> for Pair {
            type Error = usize;

            fn try_from_iter<I: IntoIterator<Item = u8>>(iter: I) -> Result<Self, Self::Error> {
                let items: Vec<u8> = iter.into_iter().collect();
                <[u8; 2]>::try_from(items)
                    .map(Pair)
                    .map_err(|rest| rest.len())
            }
        }

        // Call through the trait path to sidestep the unstable `Iterator::try_collect`.
        let ok: Result<Pair, usize> = TryCollect::try_collect([1u8, 2].into_iter());
        assert_eq!(ok.map(|p| p.0), Ok([1, 2]));

        let too_long: Result<Pair, usize> = TryCollect::try_collect([1u8, 2, 3].into_iter());
        assert_eq!(too_long.map(|p| p.0), Err(3));

        let empty: Result<Pair, usize> = TryCollect::try_collect(core::iter::empty::<u8>());
        assert_eq!(empty.map(|p| p.0), Err(0));
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_rng_is_deterministic() {
        let first: u64 = test_rng().gen();
        let second: u64 = test_rng().gen();
        assert_eq!(first, second);

        // `test_rng` is the seed-0 case of `test_rng_seeded`.
        let seeded: u64 = test_rng_seeded(0).gen();
        assert_eq!(first, seeded);
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Participant>,
        }
    }
}