Skip to main content

commonware_utils/
lib.rs

1//! Leverage common functionality across multiple primitives.
2
3#![doc(
4    html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
5    html_favicon_url = "https://commonware.xyz/favicon.ico"
6)]
7#![cfg_attr(not(any(feature = "std", test)), no_std)]
8
9commonware_macros::stability_scope!(ALPHA, cfg(feature = "std") {
10    pub mod rng;
11    pub use rng::{test_rng, test_rng_seeded, FuzzRng};
12
13    pub mod thread_local;
14    pub use thread_local::Cached;
15});
16commonware_macros::stability_scope!(BETA {
17    #[cfg(not(feature = "std"))]
18    extern crate alloc;
19
20    #[cfg(not(feature = "std"))]
21    use alloc::{boxed::Box, string::String, vec::Vec};
22    use bytes::{BufMut, BytesMut};
23    use core::{fmt::Write as FmtWrite, time::Duration};
24    pub mod faults;
25    pub use faults::{Faults, N3f1, N5f1};
26
27    pub mod sequence;
28    pub use sequence::{Array, Span};
29
30    pub mod hostname;
31    pub use hostname::Hostname;
32
33    pub mod bitmap;
34    pub mod ordered;
35
36    use bytes::Buf;
37    use commonware_codec::{varint::UInt, EncodeSize, Error as CodecError, Read, ReadExt, Write};
38
39    /// Represents a participant/validator index within a consensus committee.
40    ///
41    /// Participant indices are used to identify validators in attestations,
42    /// votes, and certificates. The index corresponds to the position of the
43    /// validator's public key in the ordered participant set.
44    #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
45    #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
46    pub struct Participant(u32);
47
48    impl Participant {
49        /// Creates a new participant from a u32 index.
50        pub const fn new(index: u32) -> Self {
51            Self(index)
52        }
53
54        /// Creates a new participant from a usize index.
55        ///
56        /// # Panics
57        ///
58        /// Panics if `index` exceeds `u32::MAX`.
59        pub fn from_usize(index: usize) -> Self {
60            Self(u32::try_from(index).expect("participant index exceeds u32::MAX"))
61        }
62
63        /// Returns the underlying u32 index.
64        pub const fn get(self) -> u32 {
65            self.0
66        }
67    }
68
69    impl From<Participant> for usize {
70        fn from(p: Participant) -> Self {
71            p.0 as Self
72        }
73    }
74
75    impl core::fmt::Display for Participant {
76        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
77            write!(f, "{}", self.0)
78        }
79    }
80
81    impl Read for Participant {
82        type Cfg = ();
83
84        fn read_cfg(buf: &mut impl Buf, _cfg: &Self::Cfg) -> Result<Self, CodecError> {
85            let value: u32 = UInt::read(buf)?.into();
86            Ok(Self(value))
87        }
88    }
89
90    impl Write for Participant {
91        fn write(&self, buf: &mut impl bytes::BufMut) {
92            UInt(self.0).write(buf);
93        }
94    }
95
96    impl EncodeSize for Participant {
97        fn encode_size(&self) -> usize {
98            UInt(self.0).encode_size()
99        }
100    }
101
102    /// A type that can be constructed from an iterator, possibly failing.
103    pub trait TryFromIterator<T>: Sized {
104        /// The error type returned when construction fails.
105        type Error;
106
107        /// Attempts to construct `Self` from an iterator.
108        fn try_from_iter<I: IntoIterator<Item = T>>(iter: I) -> Result<Self, Self::Error>;
109    }
110
111    /// Extension trait for iterators that provides fallible collection.
112    pub trait TryCollect: Iterator + Sized {
113        /// Attempts to collect elements into a collection that may fail.
114        fn try_collect<C: TryFromIterator<Self::Item>>(self) -> Result<C, C::Error> {
115            C::try_from_iter(self)
116        }
117    }
118
119    impl<I: Iterator> TryCollect for I {}
120
121    /// Alias for boxed errors that are `Send` and `Sync`.
122    pub type BoxedError = Box<dyn core::error::Error + Send + Sync>;
123
124    /// Converts bytes to a hexadecimal string.
125    pub fn hex(bytes: &[u8]) -> String {
126        let mut hex = String::with_capacity(bytes.len() * 2);
127        for byte in bytes.iter() {
128            write!(hex, "{byte:02x}").expect("writing to string should never fail");
129        }
130        hex
131    }
132
133    /// Converts a hexadecimal string to bytes.
134    pub fn from_hex(hex: &str) -> Option<Vec<u8>> {
135        let bytes = hex.as_bytes();
136        if !bytes.len().is_multiple_of(2) {
137            return None;
138        }
139
140        bytes
141            .chunks_exact(2)
142            .map(|chunk| {
143                let hi = decode_hex_digit(chunk[0])?;
144                let lo = decode_hex_digit(chunk[1])?;
145                Some((hi << 4) | lo)
146            })
147            .collect()
148    }
149
150    /// Converts a hexadecimal string to bytes, stripping whitespace and/or a `0x` prefix. Commonly used
151    /// in testing to encode external test vectors without modification.
152    pub fn from_hex_formatted(hex: &str) -> Option<Vec<u8>> {
153        let hex = hex.replace(['\t', '\n', '\r', ' '], "");
154        let res = hex.strip_prefix("0x").unwrap_or(&hex);
155        from_hex(res)
156    }
157
158    /// Computes the union of two byte slices.
159    pub fn union(a: &[u8], b: &[u8]) -> Vec<u8> {
160        let mut union = Vec::with_capacity(a.len() + b.len());
161        union.extend_from_slice(a);
162        union.extend_from_slice(b);
163        union
164    }
165
166    /// Concatenate a namespace and a message, prepended by a varint encoding of the namespace length.
167    ///
168    /// This produces a unique byte sequence (i.e. no collisions) for each `(namespace, msg)` pair.
169    pub fn union_unique(namespace: &[u8], msg: &[u8]) -> Vec<u8> {
170        use commonware_codec::EncodeSize;
171        let len_prefix = namespace.len();
172        let mut buf =
173            BytesMut::with_capacity(len_prefix.encode_size() + namespace.len() + msg.len());
174        len_prefix.write(&mut buf);
175        BufMut::put_slice(&mut buf, namespace);
176        BufMut::put_slice(&mut buf, msg);
177        buf.into()
178    }
179
180    /// Compute the modulo of bytes interpreted as a big-endian integer.
181    ///
182    /// This function is used to select a random entry from an array when the bytes are a random seed.
183    ///
184    /// # Panics
185    ///
186    /// Panics if `n` is zero.
187    pub fn modulo(bytes: &[u8], n: u64) -> u64 {
188        assert_ne!(n, 0, "modulus must be non-zero");
189
190        let n = n as u128;
191        let mut result = 0u128;
192        for &byte in bytes {
193            result = (result << 8) | (byte as u128);
194            result %= n;
195        }
196
197        // Result is either 0 or modulo `n`, so we can safely cast to u64
198        result as u64
199    }
200
201    /// A wrapper around `Duration` that guarantees the duration is non-zero.
202    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
203    pub struct NonZeroDuration(Duration);
204
205    impl NonZeroDuration {
206        /// Creates a `NonZeroDuration` if the given duration is non-zero.
207        pub fn new(duration: Duration) -> Option<Self> {
208            if duration == Duration::ZERO {
209                None
210            } else {
211                Some(Self(duration))
212            }
213        }
214
215        /// Creates a `NonZeroDuration` from the given duration, panicking if it's zero.
216        pub fn new_panic(duration: Duration) -> Self {
217            Self::new(duration).expect("duration must be non-zero")
218        }
219
220        /// Returns the wrapped `Duration`.
221        pub const fn get(self) -> Duration {
222            self.0
223        }
224    }
225
226    impl From<NonZeroDuration> for Duration {
227        fn from(nz_duration: NonZeroDuration) -> Self {
228            nz_duration.0
229        }
230    }
231});
232commonware_macros::stability_scope!(BETA, cfg(feature = "std") {
233    pub mod acknowledgement;
234    pub use acknowledgement::Acknowledgement;
235
236    pub mod net;
237    pub use net::IpAddrExt;
238
239    pub mod time;
240    pub use time::{DurationExt, SystemTimeExt};
241
242    pub mod rational;
243    pub use rational::BigRationalExt;
244
245    mod priority_set;
246    pub use priority_set::PrioritySet;
247
248    pub mod channel;
249    pub mod concurrency;
250    pub mod futures;
251    pub mod sync;
252});
253#[cfg(not(any(
254    commonware_stability_GAMMA,
255    commonware_stability_DELTA,
256    commonware_stability_EPSILON,
257    commonware_stability_RESERVED
258)))] // BETA
259pub mod hex_literal;
260#[cfg(not(any(
261    commonware_stability_GAMMA,
262    commonware_stability_DELTA,
263    commonware_stability_EPSILON,
264    commonware_stability_RESERVED
265)))] // BETA
266pub mod vec;
267
268#[commonware_macros::stability(BETA)]
269#[inline]
270const fn decode_hex_digit(byte: u8) -> Option<u8> {
271    match byte {
272        b'0'..=b'9' => Some(byte - b'0'),
273        b'a'..=b'f' => Some(byte - b'a' + 10),
274        b'A'..=b'F' => Some(byte - b'A' + 10),
275        _ => None,
276    }
277}
278
/// Builds a `core::num::NonZeroUsize`, panicking if the value is zero.
///
/// A literal argument is checked at compile time (a zero literal fails to compile);
/// any other expression is checked when it is evaluated at runtime.
#[macro_export]
macro_rules! NZUsize {
    ($value:literal) => {
        const {
            match ::core::num::NonZeroUsize::new($value) {
                ::core::option::Option::Some(nz) => nz,
                ::core::option::Option::None => ::core::panic!("value must be non-zero"),
            }
        }
    };
    ($value:expr) => {
        // Checked at runtime: panics when `$value` evaluates to zero.
        match ::core::num::NonZeroUsize::new($value) {
            ::core::option::Option::Some(nz) => nz,
            ::core::option::Option::None => ::core::panic!("value must be non-zero"),
        }
    };
}
292
/// Builds a `core::num::NonZeroU8`, panicking if the value is zero.
///
/// A literal argument is checked at compile time (a zero literal fails to compile);
/// any other expression is checked when it is evaluated at runtime.
#[macro_export]
macro_rules! NZU8 {
    ($value:literal) => {
        const {
            match ::core::num::NonZeroU8::new($value) {
                ::core::option::Option::Some(nz) => nz,
                ::core::option::Option::None => ::core::panic!("value must be non-zero"),
            }
        }
    };
    ($value:expr) => {
        // Checked at runtime: panics when `$value` evaluates to zero.
        match ::core::num::NonZeroU8::new($value) {
            ::core::option::Option::Some(nz) => nz,
            ::core::option::Option::None => ::core::panic!("value must be non-zero"),
        }
    };
}
306
/// Builds a `core::num::NonZeroU16`, panicking if the value is zero.
///
/// A literal argument is checked at compile time (a zero literal fails to compile);
/// any other expression is checked when it is evaluated at runtime.
#[macro_export]
macro_rules! NZU16 {
    ($value:literal) => {
        const {
            match ::core::num::NonZeroU16::new($value) {
                ::core::option::Option::Some(nz) => nz,
                ::core::option::Option::None => ::core::panic!("value must be non-zero"),
            }
        }
    };
    ($value:expr) => {
        // Checked at runtime: panics when `$value` evaluates to zero.
        match ::core::num::NonZeroU16::new($value) {
            ::core::option::Option::Some(nz) => nz,
            ::core::option::Option::None => ::core::panic!("value must be non-zero"),
        }
    };
}
320
/// Builds a `core::num::NonZeroU32`, panicking if the value is zero.
///
/// A literal argument is checked at compile time (a zero literal fails to compile);
/// any other expression is checked when it is evaluated at runtime.
#[macro_export]
macro_rules! NZU32 {
    ($value:literal) => {
        const {
            match ::core::num::NonZeroU32::new($value) {
                ::core::option::Option::Some(nz) => nz,
                ::core::option::Option::None => ::core::panic!("value must be non-zero"),
            }
        }
    };
    ($value:expr) => {
        // Checked at runtime: panics when `$value` evaluates to zero.
        match ::core::num::NonZeroU32::new($value) {
            ::core::option::Option::Some(nz) => nz,
            ::core::option::Option::None => ::core::panic!("value must be non-zero"),
        }
    };
}
334
/// Builds a `core::num::NonZeroU64`, panicking if the value is zero.
///
/// A literal argument is checked at compile time (a zero literal fails to compile);
/// any other expression is checked when it is evaluated at runtime.
#[macro_export]
macro_rules! NZU64 {
    ($value:literal) => {
        const {
            match ::core::num::NonZeroU64::new($value) {
                ::core::option::Option::Some(nz) => nz,
                ::core::option::Option::None => ::core::panic!("value must be non-zero"),
            }
        }
    };
    ($value:expr) => {
        // Checked at runtime: panics when `$value` evaluates to zero.
        match ::core::num::NonZeroU64::new($value) {
            ::core::option::Option::Some(nz) => nz,
            ::core::option::Option::None => ::core::panic!("value must be non-zero"),
        }
    };
}
348
/// Builds a `NonZeroDuration` from a `Duration` expression, panicking at runtime if
/// the duration is zero.
///
/// Unlike the integer macros there is no compile-time form, because `Duration` has
/// no literal syntax.
#[macro_export]
macro_rules! NZDuration {
    ($duration:expr) => {
        $crate::NonZeroDuration::new_panic($duration)
    };
}
357
#[cfg(test)]
mod tests {
    use super::*;
    use num_bigint::BigUint;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    #[test]
    fn test_hex() {
        // Test case 0: empty bytes
        let b = &[];
        let h = hex(b);
        assert_eq!(h, "");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 1: single byte
        let b = &hex!("0x01");
        let h = hex(b);
        assert_eq!(h, "01");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 2: multiple bytes
        let b = &hex!("0x010203");
        let h = hex(b);
        assert_eq!(h, "010203");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 3: odd number of bytes
        let h = "0102030";
        assert!(from_hex(h).is_none());

        // Test case 4: invalid hexadecimal character
        let h = "01g3";
        assert!(from_hex(h).is_none());

        // Test case 5: invalid `+` in string
        let h = "+123";
        assert!(from_hex(h).is_none());

        // Test case 6: empty string
        assert_eq!(from_hex(""), Some(vec![]));
    }

    #[test]
    fn test_hex_round_trip_all_bytes() {
        // Every byte value must survive an encode/decode round trip, and decoding
        // must accept the uppercase form of the same digits.
        let all: Vec<u8> = (0..=u8::MAX).collect();
        let encoded = hex(&all);
        assert_eq!(encoded.len(), all.len() * 2);
        assert_eq!(from_hex(&encoded).unwrap(), all);
        assert_eq!(from_hex(&encoded.to_uppercase()).unwrap(), all);
    }

    #[test]
    fn test_from_hex_mixed_case() {
        // Uppercase and mixed-case digits are decoded, not just lowercase.
        assert_eq!(from_hex("0aFf"), Some(vec![0x0a, 0xff]));
        assert_eq!(from_hex("ABCDEF"), Some(vec![0xab, 0xcd, 0xef]));
        assert_eq!(from_hex_formatted("0x AB cd"), Some(vec![0xab, 0xcd]));
    }

    #[test]
    fn test_from_hex_formatted() {
        // Test case 0: empty bytes
        let b = &[];
        let h = hex(b);
        assert_eq!(h, "");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 1: single byte
        let b = &hex!("0x01");
        let h = hex(b);
        assert_eq!(h, "01");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 2: multiple bytes
        let b = &hex!("0x010203");
        let h = hex(b);
        assert_eq!(h, "010203");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 3: odd number of bytes
        let h = "0102030";
        assert!(from_hex_formatted(h).is_none());

        // Test case 4: invalid hexadecimal character
        let h = "01g3";
        assert!(from_hex_formatted(h).is_none());

        // Test case 5: whitespace
        let h = "01 02 03";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());

        // Test case 6: 0x prefix
        let h = "0x010203";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());

        // Test case 7: 0x prefix + different whitespace chars
        let h = "    \n\n0x\r\n01
                            02\t03\n";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());
    }

    #[test]
    fn test_from_hex_utf8_char_boundaries() {
        const MISALIGNMENT_CASE: &str = "쀘\n";

        // Ensure that `from_hex` can handle misaligned UTF-8 character boundaries.
        let b = from_hex(MISALIGNMENT_CASE);
        assert!(b.is_none());
    }

    #[test]
    fn test_union() {
        // Test case 0: empty slices
        assert_eq!(union(&[], &[]), Vec::<u8>::new());

        // Test case 1: empty and non-empty slices
        assert_eq!(union(&[], &hex!("0x010203")), hex!("0x010203"));

        // Test case 2: non-empty and non-empty slices
        assert_eq!(
            union(&hex!("0x010203"), &hex!("0x040506")),
            hex!("0x010203040506")
        );
    }

    #[test]
    fn test_union_unique() {
        let namespace = b"namespace";
        let msg = b"message";

        let length_encoding = vec![0b0000_1001];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(namespace);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_union_unique_zero_length() {
        let namespace = b"";
        let msg = b"message";

        let length_encoding = vec![0];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_union_unique_long_length() {
        // Use a namespace of over length 127.
        let namespace = &b"n".repeat(256);
        let msg = b"message";

        let length_encoding = vec![0b1000_0000, 0b0000_0010];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(namespace);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_modulo() {
        // Test case 0: empty bytes
        assert_eq!(modulo(&[], 1), 0);

        // Test case 1: single byte
        assert_eq!(modulo(&hex!("0x01"), 1), 0);

        // Test case 2: multiple bytes
        assert_eq!(modulo(&hex!("0x010203"), 10), 1);

        // Test case 3: check equivalence with BigUint
        for i in 0..100 {
            let mut rng = StdRng::seed_from_u64(i);
            let bytes: [u8; 32] = rng.gen();

            // 1-byte modulus
            let n = 11u64;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // 2-byte modulus
            let n = 11_111u64;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // 8-byte modulus
            let n = 0xDFFFFFFFFFFFFFFD;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // Largest possible modulus: the running remainder reaches its maximum
            // width here, exercising the u128 accumulator's headroom.
            let n = u64::MAX;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));
        }
    }

    #[test]
    fn test_modulo_boundaries() {
        // 2^256 - 1 is divisible by 2^64 - 1 (since 2^64 ≡ 1 mod 2^64 - 1).
        assert_eq!(modulo(&[0xff; 32], u64::MAX), 0);

        // 256 mod 7 = 4.
        assert_eq!(modulo(&[0x01, 0x00], 7), 4);

        // A value smaller than the modulus is returned unchanged.
        assert_eq!(modulo(&[0x12, 0x34], 1 << 16), 0x1234);
    }

    #[test]
    #[should_panic]
    fn test_modulo_zero_panics() {
        modulo(&hex!("0x010203"), 0);
    }

    #[test]
    fn test_non_zero_macros_compile_time() {
        // Literal values are validated at compile time.
        // NZU32!(0) would be a compile error.
        assert_eq!(NZUsize!(1).get(), 1);
        assert_eq!(NZU8!(2).get(), 2);
        assert_eq!(NZU16!(3).get(), 3);
        assert_eq!(NZU32!(4).get(), 4);
        assert_eq!(NZU64!(5).get(), 5);

        // Literals can be used in const contexts
        const _: core::num::NonZeroUsize = NZUsize!(1);
        const _: core::num::NonZeroU8 = NZU8!(2);
        const _: core::num::NonZeroU16 = NZU16!(3);
        const _: core::num::NonZeroU32 = NZU32!(4);
        const _: core::num::NonZeroU64 = NZU64!(5);
    }

    #[test]
    fn test_non_zero_macros_runtime() {
        // Runtime variables are validated at runtime
        let one_usize: usize = 1;
        let two_u8: u8 = 2;
        let three_u16: u16 = 3;
        let four_u32: u32 = 4;
        let five_u64: u64 = 5;

        assert_eq!(NZUsize!(one_usize).get(), 1);
        assert_eq!(NZU8!(two_u8).get(), 2);
        assert_eq!(NZU16!(three_u16).get(), 3);
        assert_eq!(NZU32!(four_u32).get(), 4);
        assert_eq!(NZU64!(five_u64).get(), 5);

        // Zero runtime values panic
        let zero_usize: usize = 0;
        let zero_u8: u8 = 0;
        let zero_u16: u16 = 0;
        let zero_u32: u32 = 0;
        let zero_u64: u64 = 0;

        assert!(std::panic::catch_unwind(|| NZUsize!(zero_usize)).is_err());
        assert!(std::panic::catch_unwind(|| NZU8!(zero_u8)).is_err());
        assert!(std::panic::catch_unwind(|| NZU16!(zero_u16)).is_err());
        assert!(std::panic::catch_unwind(|| NZU32!(zero_u32)).is_err());
        assert!(std::panic::catch_unwind(|| NZU64!(zero_u64)).is_err());

        // NZDuration is runtime-only since Duration has no literal syntax
        assert!(std::panic::catch_unwind(|| NZDuration!(Duration::ZERO)).is_err());
        assert_eq!(
            NZDuration!(Duration::from_secs(1)).get(),
            Duration::from_secs(1)
        );
    }

    #[test]
    fn test_non_zero_duration() {
        // Test case 0: zero duration
        assert!(NonZeroDuration::new(Duration::ZERO).is_none());

        // Test case 1: non-zero duration
        let duration = Duration::from_millis(100);
        let nz_duration = NonZeroDuration::new(duration).unwrap();
        assert_eq!(nz_duration.get(), duration);
        assert_eq!(Duration::from(nz_duration), duration);

        // Test case 2: panic on zero
        assert!(std::panic::catch_unwind(|| NonZeroDuration::new_panic(Duration::ZERO)).is_err());

        // Test case 3: ordering
        let d1 = NonZeroDuration::new(Duration::from_millis(100)).unwrap();
        let d2 = NonZeroDuration::new(Duration::from_millis(200)).unwrap();
        assert!(d1 < d2);
    }

    #[test]
    fn test_participant_constructors() {
        assert_eq!(Participant::new(0).get(), 0);
        assert_eq!(Participant::new(42).get(), 42);
        assert_eq!(Participant::from_usize(0).get(), 0);
        assert_eq!(Participant::from_usize(42).get(), 42);
        assert_eq!(Participant::from_usize(u32::MAX as usize).get(), u32::MAX);
    }

    #[test]
    #[should_panic(expected = "participant index exceeds u32::MAX")]
    fn test_participant_from_usize_overflow() {
        Participant::from_usize((u32::MAX as usize) + 1);
    }

    #[test]
    fn test_participant_into_usize() {
        assert_eq!(usize::from(Participant::new(0)), 0);
        assert_eq!(usize::from(Participant::new(42)), 42);
        assert_eq!(usize::from(Participant::new(u32::MAX)), u32::MAX as usize);
    }

    #[test]
    fn test_participant_display() {
        assert_eq!(format!("{}", Participant::new(0)), "0");
        assert_eq!(format!("{}", Participant::new(42)), "42");
        assert_eq!(format!("{}", Participant::new(1000)), "1000");
    }

    #[test]
    fn test_participant_ordering() {
        assert!(Participant::new(0) < Participant::new(1));
        assert!(Participant::new(5) < Participant::new(10));
        assert!(Participant::new(10) > Participant::new(5));
        assert_eq!(Participant::new(42), Participant::new(42));
    }

    #[test]
    fn test_participant_encode_decode() {
        use commonware_codec::{DecodeExt, Encode};

        let cases = vec![0u32, 1, 127, 128, 255, 256, u32::MAX];
        for value in cases {
            let participant = Participant::new(value);
            let encoded = participant.encode();
            assert_eq!(encoded.len(), participant.encode_size());
            let decoded = Participant::decode(encoded).unwrap();
            assert_eq!(participant, decoded);
        }
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Participant>,
        }
    }
}