Skip to main content

commonware_utils/
lib.rs

1//! Leverage common functionality across multiple primitives.
2
3#![doc(
4    html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
5    html_favicon_url = "https://commonware.xyz/favicon.ico"
6)]
7#![cfg_attr(not(any(feature = "std", test)), no_std)]
8
9commonware_macros::stability_scope!(ALPHA, cfg(feature = "std") {
10    pub mod rng;
11    pub use rng::{test_rng, test_rng_seeded, BytesRng};
12});
13commonware_macros::stability_scope!(BETA {
14    #[cfg(not(feature = "std"))]
15    extern crate alloc;
16
17    #[cfg(not(feature = "std"))]
18    use alloc::{boxed::Box, string::String, vec::Vec};
19    use bytes::{BufMut, BytesMut};
20    use core::{fmt::Write as FmtWrite, time::Duration};
21    pub mod faults;
22    pub use faults::{Faults, N3f1, N5f1};
23
24    pub mod sequence;
25    pub use sequence::{Array, Span};
26
27    pub mod hostname;
28    pub use hostname::Hostname;
29
30    pub mod bitmap;
31    pub mod ordered;
32
33    use bytes::Buf;
34    use commonware_codec::{varint::UInt, EncodeSize, Error as CodecError, Read, ReadExt, Write};
35
36    /// Represents a participant/validator index within a consensus committee.
37    ///
38    /// Participant indices are used to identify validators in attestations,
39    /// votes, and certificates. The index corresponds to the position of the
40    /// validator's public key in the ordered participant set.
41    #[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
42    #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
43    pub struct Participant(u32);
44
45    impl Participant {
46        /// Creates a new participant from a u32 index.
47        pub const fn new(index: u32) -> Self {
48            Self(index)
49        }
50
51        /// Creates a new participant from a usize index.
52        ///
53        /// # Panics
54        ///
55        /// Panics if `index` exceeds `u32::MAX`.
56        pub fn from_usize(index: usize) -> Self {
57            Self(u32::try_from(index).expect("participant index exceeds u32::MAX"))
58        }
59
60        /// Returns the underlying u32 index.
61        pub const fn get(self) -> u32 {
62            self.0
63        }
64    }
65
66    impl From<Participant> for usize {
67        fn from(p: Participant) -> Self {
68            p.0 as Self
69        }
70    }
71
72    impl core::fmt::Display for Participant {
73        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
74            write!(f, "{}", self.0)
75        }
76    }
77
78    impl Read for Participant {
79        type Cfg = ();
80
81        fn read_cfg(buf: &mut impl Buf, _cfg: &Self::Cfg) -> Result<Self, CodecError> {
82            let value: u32 = UInt::read(buf)?.into();
83            Ok(Self(value))
84        }
85    }
86
87    impl Write for Participant {
88        fn write(&self, buf: &mut impl bytes::BufMut) {
89            UInt(self.0).write(buf);
90        }
91    }
92
93    impl EncodeSize for Participant {
94        fn encode_size(&self) -> usize {
95            UInt(self.0).encode_size()
96        }
97    }
98
99    /// A type that can be constructed from an iterator, possibly failing.
100    pub trait TryFromIterator<T>: Sized {
101        /// The error type returned when construction fails.
102        type Error;
103
104        /// Attempts to construct `Self` from an iterator.
105        fn try_from_iter<I: IntoIterator<Item = T>>(iter: I) -> Result<Self, Self::Error>;
106    }
107
108    /// Extension trait for iterators that provides fallible collection.
109    pub trait TryCollect: Iterator + Sized {
110        /// Attempts to collect elements into a collection that may fail.
111        fn try_collect<C: TryFromIterator<Self::Item>>(self) -> Result<C, C::Error> {
112            C::try_from_iter(self)
113        }
114    }
115
116    impl<I: Iterator> TryCollect for I {}
117
118    /// Alias for boxed errors that are `Send` and `Sync`.
119    pub type BoxedError = Box<dyn core::error::Error + Send + Sync>;
120
121    /// Converts bytes to a hexadecimal string.
122    pub fn hex(bytes: &[u8]) -> String {
123        let mut hex = String::new();
124        for byte in bytes.iter() {
125            write!(hex, "{byte:02x}").expect("writing to string should never fail");
126        }
127        hex
128    }
129
130    /// Converts a hexadecimal string to bytes.
131    pub fn from_hex(hex: &str) -> Option<Vec<u8>> {
132        let bytes = hex.as_bytes();
133        if !bytes.len().is_multiple_of(2) {
134            return None;
135        }
136
137        bytes
138            .chunks_exact(2)
139            .map(|chunk| {
140                let hi = decode_hex_digit(chunk[0])?;
141                let lo = decode_hex_digit(chunk[1])?;
142                Some((hi << 4) | lo)
143            })
144            .collect()
145    }
146
147    /// Converts a hexadecimal string to bytes, stripping whitespace and/or a `0x` prefix. Commonly used
148    /// in testing to encode external test vectors without modification.
149    pub fn from_hex_formatted(hex: &str) -> Option<Vec<u8>> {
150        let hex = hex.replace(['\t', '\n', '\r', ' '], "");
151        let res = hex.strip_prefix("0x").unwrap_or(&hex);
152        from_hex(res)
153    }
154
155    /// Computes the union of two byte slices.
156    pub fn union(a: &[u8], b: &[u8]) -> Vec<u8> {
157        let mut union = Vec::with_capacity(a.len() + b.len());
158        union.extend_from_slice(a);
159        union.extend_from_slice(b);
160        union
161    }
162
163    /// Concatenate a namespace and a message, prepended by a varint encoding of the namespace length.
164    ///
165    /// This produces a unique byte sequence (i.e. no collisions) for each `(namespace, msg)` pair.
166    pub fn union_unique(namespace: &[u8], msg: &[u8]) -> Vec<u8> {
167        use commonware_codec::EncodeSize;
168        let len_prefix = namespace.len();
169        let mut buf =
170            BytesMut::with_capacity(len_prefix.encode_size() + namespace.len() + msg.len());
171        len_prefix.write(&mut buf);
172        BufMut::put_slice(&mut buf, namespace);
173        BufMut::put_slice(&mut buf, msg);
174        buf.into()
175    }
176
177    /// Compute the modulo of bytes interpreted as a big-endian integer.
178    ///
179    /// This function is used to select a random entry from an array when the bytes are a random seed.
180    ///
181    /// # Panics
182    ///
183    /// Panics if `n` is zero.
184    pub fn modulo(bytes: &[u8], n: u64) -> u64 {
185        assert_ne!(n, 0, "modulus must be non-zero");
186
187        let n = n as u128;
188        let mut result = 0u128;
189        for &byte in bytes {
190            result = (result << 8) | (byte as u128);
191            result %= n;
192        }
193
194        // Result is either 0 or modulo `n`, so we can safely cast to u64
195        result as u64
196    }
197
198    /// A wrapper around `Duration` that guarantees the duration is non-zero.
199    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
200    pub struct NonZeroDuration(Duration);
201
202    impl NonZeroDuration {
203        /// Creates a `NonZeroDuration` if the given duration is non-zero.
204        pub fn new(duration: Duration) -> Option<Self> {
205            if duration == Duration::ZERO {
206                None
207            } else {
208                Some(Self(duration))
209            }
210        }
211
212        /// Creates a `NonZeroDuration` from the given duration, panicking if it's zero.
213        pub fn new_panic(duration: Duration) -> Self {
214            Self::new(duration).expect("duration must be non-zero")
215        }
216
217        /// Returns the wrapped `Duration`.
218        pub const fn get(self) -> Duration {
219            self.0
220        }
221    }
222
223    impl From<NonZeroDuration> for Duration {
224        fn from(nz_duration: NonZeroDuration) -> Self {
225            nz_duration.0
226        }
227    }
228});
229commonware_macros::stability_scope!(BETA, cfg(feature = "std") {
230    pub mod acknowledgement;
231    pub use acknowledgement::Acknowledgement;
232
233    pub mod net;
234    pub use net::IpAddrExt;
235
236    pub mod time;
237    pub use time::{DurationExt, SystemTimeExt};
238
239    pub mod rational;
240    pub use rational::BigRationalExt;
241
242    mod priority_set;
243    pub use priority_set::PrioritySet;
244
245    pub mod channel;
246    pub mod concurrency;
247    pub mod futures;
248});
249#[cfg(not(any(
250    commonware_stability_GAMMA,
251    commonware_stability_DELTA,
252    commonware_stability_EPSILON,
253    commonware_stability_RESERVED
254)))] // BETA
255pub mod hex_literal;
256#[cfg(not(any(
257    commonware_stability_GAMMA,
258    commonware_stability_DELTA,
259    commonware_stability_EPSILON,
260    commonware_stability_RESERVED
261)))] // BETA
262pub mod vec;
263
264#[commonware_macros::stability(BETA)]
265#[inline]
/// Maps one ASCII hexadecimal digit (`0-9`, `a-f`, `A-F`) to its value in `0..=15`.
///
/// Returns `None` for every other byte.
const fn decode_hex_digit(byte: u8) -> Option<u8> {
    if byte.is_ascii_digit() {
        Some(byte - b'0')
    } else if matches!(byte, b'a'..=b'f') {
        Some(byte - b'a' + 10)
    } else if matches!(byte, b'A'..=b'F') {
        Some(byte - b'A' + 10)
    } else {
        None
    }
}
274
/// A macro to create a `NonZeroUsize` from a value, panicking if the value is zero.
/// For literal values, validation occurs at compile time. For expressions, validation
/// occurs at runtime.
///
/// # Panics
///
/// Panics with "value must be non-zero" if a non-literal expression evaluates to zero.
#[macro_export]
macro_rules! NZUsize {
    // Literal arm: the check runs inside an inline `const` block, so a zero literal is
    // rejected at compile time. It must stay first because a literal also matches `expr`.
    ($val:literal) => {
        const { ::core::num::NonZeroUsize::new($val).expect("value must be non-zero") }
    };
    // Expression arm: any non-literal value is checked when the code runs.
    ($val:expr) => {
        // This will panic at runtime if $val is zero.
        ::core::num::NonZeroUsize::new($val).expect("value must be non-zero")
    };
}
288
/// A macro to create a `NonZeroU8` from a value, panicking if the value is zero.
/// For literal values, validation occurs at compile time. For expressions, validation
/// occurs at runtime.
///
/// # Panics
///
/// Panics with "value must be non-zero" if a non-literal expression evaluates to zero.
#[macro_export]
macro_rules! NZU8 {
    // Literal arm: the check runs inside an inline `const` block, so a zero literal is
    // rejected at compile time. It must stay first because a literal also matches `expr`.
    ($val:literal) => {
        const { ::core::num::NonZeroU8::new($val).expect("value must be non-zero") }
    };
    // Expression arm: any non-literal value is checked when the code runs.
    ($val:expr) => {
        // This will panic at runtime if $val is zero.
        ::core::num::NonZeroU8::new($val).expect("value must be non-zero")
    };
}
302
/// A macro to create a `NonZeroU16` from a value, panicking if the value is zero.
/// For literal values, validation occurs at compile time. For expressions, validation
/// occurs at runtime.
///
/// # Panics
///
/// Panics with "value must be non-zero" if a non-literal expression evaluates to zero.
#[macro_export]
macro_rules! NZU16 {
    // Literal arm: the check runs inside an inline `const` block, so a zero literal is
    // rejected at compile time. It must stay first because a literal also matches `expr`.
    ($val:literal) => {
        const { ::core::num::NonZeroU16::new($val).expect("value must be non-zero") }
    };
    // Expression arm: any non-literal value is checked when the code runs.
    ($val:expr) => {
        // This will panic at runtime if $val is zero.
        ::core::num::NonZeroU16::new($val).expect("value must be non-zero")
    };
}
316
/// A macro to create a `NonZeroU32` from a value, panicking if the value is zero.
/// For literal values, validation occurs at compile time. For expressions, validation
/// occurs at runtime.
///
/// # Panics
///
/// Panics with "value must be non-zero" if a non-literal expression evaluates to zero.
#[macro_export]
macro_rules! NZU32 {
    // Literal arm: the check runs inside an inline `const` block, so a zero literal is
    // rejected at compile time. It must stay first because a literal also matches `expr`.
    ($val:literal) => {
        const { ::core::num::NonZeroU32::new($val).expect("value must be non-zero") }
    };
    // Expression arm: any non-literal value is checked when the code runs.
    ($val:expr) => {
        // This will panic at runtime if $val is zero.
        ::core::num::NonZeroU32::new($val).expect("value must be non-zero")
    };
}
330
/// A macro to create a `NonZeroU64` from a value, panicking if the value is zero.
/// For literal values, validation occurs at compile time. For expressions, validation
/// occurs at runtime.
///
/// # Panics
///
/// Panics with "value must be non-zero" if a non-literal expression evaluates to zero.
#[macro_export]
macro_rules! NZU64 {
    // Literal arm: the check runs inside an inline `const` block, so a zero literal is
    // rejected at compile time. It must stay first because a literal also matches `expr`.
    ($val:literal) => {
        const { ::core::num::NonZeroU64::new($val).expect("value must be non-zero") }
    };
    // Expression arm: any non-literal value is checked when the code runs.
    ($val:expr) => {
        // This will panic at runtime if $val is zero.
        ::core::num::NonZeroU64::new($val).expect("value must be non-zero")
    };
}
344
/// A macro to create a `NonZeroDuration` from a duration, panicking if the duration is zero.
///
/// Unlike the integer `NZ*` macros in this file there is no literal arm; the argument is
/// always forwarded to `NonZeroDuration::new_panic`, resolved through `$crate` so the macro
/// works from other crates.
///
/// # Panics
///
/// Panics if the given duration is zero.
#[macro_export]
macro_rules! NZDuration {
    ($val:expr) => {
        // This will panic at runtime if $val is zero.
        $crate::NonZeroDuration::new_panic($val)
    };
}
353
354#[cfg(test)]
355mod tests {
356    use super::*;
357    use num_bigint::BigUint;
358    use rand::{rngs::StdRng, Rng, SeedableRng};
359
360    #[test]
361    fn test_hex() {
362        // Test case 0: empty bytes
363        let b = &[];
364        let h = hex(b);
365        assert_eq!(h, "");
366        assert_eq!(from_hex(&h).unwrap(), b.to_vec());
367
368        // Test case 1: single byte
369        let b = &hex!("0x01");
370        let h = hex(b);
371        assert_eq!(h, "01");
372        assert_eq!(from_hex(&h).unwrap(), b.to_vec());
373
374        // Test case 2: multiple bytes
375        let b = &hex!("0x010203");
376        let h = hex(b);
377        assert_eq!(h, "010203");
378        assert_eq!(from_hex(&h).unwrap(), b.to_vec());
379
        // Test case 3: odd number of hex characters
381        let h = "0102030";
382        assert!(from_hex(h).is_none());
383
384        // Test case 4: invalid hexadecimal character
385        let h = "01g3";
386        assert!(from_hex(h).is_none());
387
388        // Test case 5: invalid `+` in string
389        let h = "+123";
390        assert!(from_hex(h).is_none());
391
392        // Test case 6: empty string
393        assert_eq!(from_hex(""), Some(vec![]));
394    }
395
396    #[test]
397    fn test_from_hex_formatted() {
398        // Test case 0: empty bytes
399        let b = &[];
400        let h = hex(b);
401        assert_eq!(h, "");
402        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());
403
404        // Test case 1: single byte
405        let b = &hex!("0x01");
406        let h = hex(b);
407        assert_eq!(h, "01");
408        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());
409
410        // Test case 2: multiple bytes
411        let b = &hex!("0x010203");
412        let h = hex(b);
413        assert_eq!(h, "010203");
414        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());
415
        // Test case 3: odd number of hex characters
417        let h = "0102030";
418        assert!(from_hex_formatted(h).is_none());
419
420        // Test case 4: invalid hexadecimal character
421        let h = "01g3";
422        assert!(from_hex_formatted(h).is_none());
423
424        // Test case 5: whitespace
425        let h = "01 02 03";
426        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());
427
428        // Test case 6: 0x prefix
429        let h = "0x010203";
430        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());
431
432        // Test case 7: 0x prefix + different whitespace chars
433        let h = "    \n\n0x\r\n01
434                            02\t03\n";
435        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());
436    }
437
    /// Regression test: non-ASCII input must be rejected with `None`, never a panic.
    #[test]
    fn test_from_hex_utf8_char_boundaries() {
        // One 3-byte UTF-8 character plus a newline is 4 bytes in total, so the input passes
        // the even-length check and the two-byte chunks split the character in the middle.
        const MISALIGNMENT_CASE: &str = "쀘\n";

        // Ensure that `from_hex` can handle misaligned UTF-8 character boundaries.
        let b = from_hex(MISALIGNMENT_CASE);
        assert!(b.is_none());
    }
446
447    #[test]
448    fn test_union() {
449        // Test case 0: empty slices
450        assert_eq!(union(&[], &[]), Vec::<u8>::new());
451
452        // Test case 1: empty and non-empty slices
453        assert_eq!(union(&[], &hex!("0x010203")), hex!("0x010203"));
454
455        // Test case 2: non-empty and non-empty slices
456        assert_eq!(
457            union(&hex!("0x010203"), &hex!("0x040506")),
458            hex!("0x010203040506")
459        );
460    }
461
462    #[test]
463    fn test_union_unique() {
464        let namespace = b"namespace";
465        let msg = b"message";
466
467        let length_encoding = vec![0b0000_1001];
468        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
469        expected.extend_from_slice(&length_encoding);
470        expected.extend_from_slice(namespace);
471        expected.extend_from_slice(msg);
472
473        let result = union_unique(namespace, msg);
474        assert_eq!(result, expected);
475        assert_eq!(result.len(), result.capacity());
476    }
477
478    #[test]
479    fn test_union_unique_zero_length() {
480        let namespace = b"";
481        let msg = b"message";
482
483        let length_encoding = vec![0];
484        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
485        expected.extend_from_slice(&length_encoding);
486        expected.extend_from_slice(msg);
487
488        let result = union_unique(namespace, msg);
489        assert_eq!(result, expected);
490        assert_eq!(result.len(), result.capacity());
491    }
492
493    #[test]
494    fn test_union_unique_long_length() {
495        // Use a namespace of over length 127.
496        let namespace = &b"n".repeat(256);
497        let msg = b"message";
498
499        let length_encoding = vec![0b1000_0000, 0b0000_0010];
500        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
501        expected.extend_from_slice(&length_encoding);
502        expected.extend_from_slice(namespace);
503        expected.extend_from_slice(msg);
504
505        let result = union_unique(namespace, msg);
506        assert_eq!(result, expected);
507        assert_eq!(result.len(), result.capacity());
508    }
509
510    #[test]
511    fn test_modulo() {
512        // Test case 0: empty bytes
513        assert_eq!(modulo(&[], 1), 0);
514
515        // Test case 1: single byte
516        assert_eq!(modulo(&hex!("0x01"), 1), 0);
517
518        // Test case 2: multiple bytes
519        assert_eq!(modulo(&hex!("0x010203"), 10), 1);
520
521        // Test case 3: check equivalence with BigUint
522        for i in 0..100 {
523            let mut rng = StdRng::seed_from_u64(i);
524            let bytes: [u8; 32] = rng.gen();
525
526            // 1-byte modulus
527            let n = 11u64;
528            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
529            let utils_modulo = modulo(&bytes, n);
530            assert_eq!(big_modulo, BigUint::from(utils_modulo));
531
532            // 2-byte modulus
533            let n = 11_111u64;
534            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
535            let utils_modulo = modulo(&bytes, n);
536            assert_eq!(big_modulo, BigUint::from(utils_modulo));
537
538            // 8-byte modulus
539            let n = 0xDFFFFFFFFFFFFFFD;
540            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
541            let utils_modulo = modulo(&bytes, n);
542            assert_eq!(big_modulo, BigUint::from(utils_modulo));
543        }
544    }
545
546    #[test]
547    #[should_panic]
548    fn test_modulo_zero_panics() {
549        modulo(&hex!("0x010203"), 0);
550    }
551
552    #[test]
553    fn test_non_zero_macros_compile_time() {
554        // Literal values are validated at compile time.
555        // NZU32!(0) would be a compile error.
556        assert_eq!(NZUsize!(1).get(), 1);
557        assert_eq!(NZU8!(2).get(), 2);
558        assert_eq!(NZU16!(3).get(), 3);
559        assert_eq!(NZU32!(4).get(), 4);
560        assert_eq!(NZU64!(5).get(), 5);
561
562        // Literals can be used in const contexts
563        const _: core::num::NonZeroUsize = NZUsize!(1);
564        const _: core::num::NonZeroU8 = NZU8!(2);
565        const _: core::num::NonZeroU16 = NZU16!(3);
566        const _: core::num::NonZeroU32 = NZU32!(4);
567        const _: core::num::NonZeroU64 = NZU64!(5);
568    }
569
570    #[test]
571    fn test_non_zero_macros_runtime() {
572        // Runtime variables are validated at runtime
573        let one_usize: usize = 1;
574        let two_u8: u8 = 2;
575        let three_u16: u16 = 3;
576        let four_u32: u32 = 4;
577        let five_u64: u64 = 5;
578
579        assert_eq!(NZUsize!(one_usize).get(), 1);
580        assert_eq!(NZU8!(two_u8).get(), 2);
581        assert_eq!(NZU16!(three_u16).get(), 3);
582        assert_eq!(NZU32!(four_u32).get(), 4);
583        assert_eq!(NZU64!(five_u64).get(), 5);
584
585        // Zero runtime values panic
586        let zero_usize: usize = 0;
587        let zero_u8: u8 = 0;
588        let zero_u16: u16 = 0;
589        let zero_u32: u32 = 0;
590        let zero_u64: u64 = 0;
591
592        assert!(std::panic::catch_unwind(|| NZUsize!(zero_usize)).is_err());
593        assert!(std::panic::catch_unwind(|| NZU8!(zero_u8)).is_err());
594        assert!(std::panic::catch_unwind(|| NZU16!(zero_u16)).is_err());
595        assert!(std::panic::catch_unwind(|| NZU32!(zero_u32)).is_err());
596        assert!(std::panic::catch_unwind(|| NZU64!(zero_u64)).is_err());
597
598        // NZDuration is runtime-only since Duration has no literal syntax
599        assert!(std::panic::catch_unwind(|| NZDuration!(Duration::ZERO)).is_err());
600        assert_eq!(
601            NZDuration!(Duration::from_secs(1)).get(),
602            Duration::from_secs(1)
603        );
604    }
605
606    #[test]
607    fn test_non_zero_duration() {
608        // Test case 0: zero duration
609        assert!(NonZeroDuration::new(Duration::ZERO).is_none());
610
611        // Test case 1: non-zero duration
612        let duration = Duration::from_millis(100);
613        let nz_duration = NonZeroDuration::new(duration).unwrap();
614        assert_eq!(nz_duration.get(), duration);
615        assert_eq!(Duration::from(nz_duration), duration);
616
617        // Test case 2: panic on zero
618        assert!(std::panic::catch_unwind(|| NonZeroDuration::new_panic(Duration::ZERO)).is_err());
619
620        // Test case 3: ordering
621        let d1 = NonZeroDuration::new(Duration::from_millis(100)).unwrap();
622        let d2 = NonZeroDuration::new(Duration::from_millis(200)).unwrap();
623        assert!(d1 < d2);
624    }
625
626    #[test]
627    fn test_participant_constructors() {
628        assert_eq!(Participant::new(0).get(), 0);
629        assert_eq!(Participant::new(42).get(), 42);
630        assert_eq!(Participant::from_usize(0).get(), 0);
631        assert_eq!(Participant::from_usize(42).get(), 42);
632        assert_eq!(Participant::from_usize(u32::MAX as usize).get(), u32::MAX);
633    }
634
635    #[test]
636    #[should_panic(expected = "participant index exceeds u32::MAX")]
637    fn test_participant_from_usize_overflow() {
638        Participant::from_usize((u32::MAX as usize) + 1);
639    }
640
641    #[test]
642    fn test_participant_display() {
643        assert_eq!(format!("{}", Participant::new(0)), "0");
644        assert_eq!(format!("{}", Participant::new(42)), "42");
645        assert_eq!(format!("{}", Participant::new(1000)), "1000");
646    }
647
648    #[test]
649    fn test_participant_ordering() {
650        assert!(Participant::new(0) < Participant::new(1));
651        assert!(Participant::new(5) < Participant::new(10));
652        assert!(Participant::new(10) > Participant::new(5));
653        assert_eq!(Participant::new(42), Participant::new(42));
654    }
655
656    #[test]
657    fn test_participant_encode_decode() {
658        use commonware_codec::{DecodeExt, Encode};
659
660        let cases = vec![0u32, 1, 127, 128, 255, 256, u32::MAX];
661        for value in cases {
662            let participant = Participant::new(value);
663            let encoded = participant.encode();
664            assert_eq!(encoded.len(), participant.encode_size());
665            let decoded = Participant::decode(encoded).unwrap();
666            assert_eq!(participant, decoded);
667        }
668    }
669
670    #[cfg(feature = "arbitrary")]
671    mod conformance {
672        use super::*;
673        use commonware_codec::conformance::CodecConformance;
674
675        commonware_conformance::conformance_tests! {
676            CodecConformance<Participant>,
677        }
678    }
679}