commonware_utils/
lib.rs

1//! Leverage common functionality across multiple primitives.
2
3use bytes::{BufMut, BytesMut};
4use commonware_codec::{EncodeSize, Write};
5
6pub mod array;
7pub use array::Array;
8mod bitvec;
9pub use bitvec::{BitIterator, BitVec};
10mod time;
11pub use time::SystemTimeExt;
12mod priority_set;
13pub use priority_set::PrioritySet;
14pub mod futures;
15mod stable_buf;
16pub use stable_buf::StableBuf;
17
/// Converts bytes to a lowercase hexadecimal string.
///
/// Each byte is rendered as exactly two characters (e.g. `0x0a` -> `"0a"`), so the output is
/// always `2 * bytes.len()` characters long.
pub fn hex(bytes: &[u8]) -> String {
    // Table lookup instead of `format!` per byte: no temporary `String` allocation for each
    // byte, and the output buffer is sized once up front.
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut out = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
26
/// Converts a hexadecimal string to bytes.
///
/// Both upper- and lowercase digits are accepted. Returns `None` if the string has an odd
/// number of bytes or contains anything other than ASCII hex digits (this includes sign
/// characters such as `+`, whitespace, and any non-ASCII character).
pub fn from_hex(hex: &str) -> Option<Vec<u8>> {
    let bytes = hex.as_bytes();
    if bytes.len() % 2 != 0 {
        return None;
    }

    // Decode nibbles directly rather than via `u8::from_str_radix`, which accepts an optional
    // leading `+` and would therefore decode inputs like "+1" or "+f+f". Working on raw bytes
    // also sidesteps UTF-8 char-boundary issues: any byte of a multi-byte character maps to a
    // non-ASCII `char`, which `to_digit(16)` rejects.
    bytes
        .chunks_exact(2)
        .map(|pair| {
            let high = char::from(pair[0]).to_digit(16)?;
            let low = char::from(pair[1]).to_digit(16)?;
            // Both nibbles are < 16, so the combined value always fits in a u8.
            Some(((high << 4) | low) as u8)
        })
        .collect()
}
38
39/// Converts a hexadecimal string to bytes, stripping whitespace and/or a `0x` prefix. Commonly used
40/// in testing to encode external test vectors without modification.
41pub fn from_hex_formatted(hex: &str) -> Option<Vec<u8>> {
42    let hex = hex.replace(['\t', '\n', '\r', ' '], "");
43    let res = hex.strip_prefix("0x").unwrap_or(&hex);
44    from_hex(res)
45}
46
/// Returns the largest number of faulty participants `f` tolerable among `n` participants,
/// i.e. the greatest integer `f` with `n >= 3*f + 1`.
///
/// The result is `0` for every `n <= 3`, including `n == 0`.
pub fn max_faults(n: u32) -> u32 {
    match n.checked_sub(1) {
        Some(others) => others / 3,
        None => 0,
    }
}
52
53/// Compute the quorum size for a given set of `n` participants. This is the minimum integer `q`
54/// such that `3*q >= 2*n + 1`. It is also equal to `n - f`, where `f` is the maximum number of
55/// faults.
56///
57/// # Panics
58///
59/// Panics if `n` is zero.
60pub fn quorum(n: u32) -> u32 {
61    assert!(n > 0, "n must not be zero");
62    n - max_faults(n)
63}
64
/// Returns a new vector holding the bytes of `a` followed by the bytes of `b`.
///
/// Despite the name, this is plain concatenation rather than a set union: order is preserved
/// and duplicate bytes are kept.
pub fn union(a: &[u8], b: &[u8]) -> Vec<u8> {
    [a, b].concat()
}
72
73/// Concatenate a namespace and a message, prepended by a varint encoding of the namespace length.
74///
75/// This produces a unique byte sequence (i.e. no collisions) for each `(namespace, msg)` pair.
76pub fn union_unique(namespace: &[u8], msg: &[u8]) -> Vec<u8> {
77    let len_prefix = namespace.len();
78    let mut buf = BytesMut::with_capacity(len_prefix.encode_size() + namespace.len() + msg.len());
79    len_prefix.write(&mut buf);
80    BufMut::put_slice(&mut buf, namespace);
81    BufMut::put_slice(&mut buf, msg);
82    buf.into()
83}
84
/// Interprets `bytes` as a big-endian unsigned integer and returns it modulo `n`.
///
/// Intended for picking an entry from an array of length `n` when `bytes` is a random seed.
/// An empty slice is treated as the integer zero.
///
/// # Panics
///
/// Panics if `n` is zero.
pub fn modulo(bytes: &[u8], n: u64) -> u64 {
    assert_ne!(n, 0, "modulus must be non-zero");

    // Reduce after every byte (Horner's rule). The running value is always < n <= u64::MAX,
    // so shifting it left by 8 bits cannot overflow a u128.
    let modulus = u128::from(n);
    let remainder = bytes
        .iter()
        .fold(0u128, |acc, &byte| ((acc << 8) | u128::from(byte)) % modulus);

    // `remainder` is strictly less than `n`, so this narrowing cast is lossless.
    remainder as u64
}
105
/// Creates a `NonZeroUsize` from a value, panicking with "value must be non-zero" if it is zero.
///
/// Expands to a `match` on `NonZeroUsize::new` (a `const fn`), so it can also initialise
/// `const`/`static` items, where a zero value becomes a compile-time error. Paths are absolute
/// (`::core::...`) so the expansion does not depend on what `std` means at the call site.
#[macro_export]
macro_rules! NZUsize {
    ($val:expr) => {
        match ::core::num::NonZeroUsize::new($val) {
            ::core::option::Option::Some(v) => v,
            ::core::option::Option::None => ::core::panic!("value must be non-zero"),
        }
    };
}

/// Creates a `NonZeroU8` from a value, panicking with "value must be non-zero" if it is zero.
///
/// Usable in `const`/`static` initialisers; see `NZUsize!` for details.
#[macro_export]
macro_rules! NZU8 {
    ($val:expr) => {
        match ::core::num::NonZeroU8::new($val) {
            ::core::option::Option::Some(v) => v,
            ::core::option::Option::None => ::core::panic!("value must be non-zero"),
        }
    };
}

/// Creates a `NonZeroU16` from a value, panicking with "value must be non-zero" if it is zero.
///
/// Usable in `const`/`static` initialisers; see `NZUsize!` for details.
#[macro_export]
macro_rules! NZU16 {
    ($val:expr) => {
        match ::core::num::NonZeroU16::new($val) {
            ::core::option::Option::Some(v) => v,
            ::core::option::Option::None => ::core::panic!("value must be non-zero"),
        }
    };
}

/// Creates a `NonZeroU32` from a value, panicking with "value must be non-zero" if it is zero.
///
/// Usable in `const`/`static` initialisers; see `NZUsize!` for details.
#[macro_export]
macro_rules! NZU32 {
    ($val:expr) => {
        match ::core::num::NonZeroU32::new($val) {
            ::core::option::Option::Some(v) => v,
            ::core::option::Option::None => ::core::panic!("value must be non-zero"),
        }
    };
}

/// Creates a `NonZeroU64` from a value, panicking with "value must be non-zero" if it is zero.
///
/// Usable in `const`/`static` initialisers; see `NZUsize!` for details.
#[macro_export]
macro_rules! NZU64 {
    ($val:expr) => {
        match ::core::num::NonZeroU64::new($val) {
            ::core::option::Option::Some(v) => v,
            ::core::option::Option::None => ::core::panic!("value must be non-zero"),
        }
    };
}
151
#[cfg(test)]
mod tests {
    use super::*;
    use num_bigint::BigUint;
    use rand::{rngs::StdRng, Rng, SeedableRng};

    #[test]
    fn test_hex() {
        // Test case 0: empty bytes
        let b = &[];
        let h = hex(b);
        assert_eq!(h, "");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 1: single byte
        let b = &[0x01];
        let h = hex(b);
        assert_eq!(h, "01");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 2: multiple bytes
        let b = &[0x01, 0x02, 0x03];
        let h = hex(b);
        assert_eq!(h, "010203");
        assert_eq!(from_hex(&h).unwrap(), b.to_vec());

        // Test case 3: odd number of bytes
        let h = "0102030";
        assert!(from_hex(h).is_none());

        // Test case 4: invalid hexadecimal character
        let h = "01g3";
        assert!(from_hex(h).is_none());

        // Test case 5: uppercase and mixed-case digits are accepted
        assert_eq!(from_hex("ABcd").unwrap(), vec![0xab, 0xcd]);
    }

    #[test]
    fn test_from_hex_formatted() {
        // Test case 0: empty bytes
        let b = &[];
        let h = hex(b);
        assert_eq!(h, "");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 1: single byte
        let b = &[0x01];
        let h = hex(b);
        assert_eq!(h, "01");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 2: multiple bytes
        let b = &[0x01, 0x02, 0x03];
        let h = hex(b);
        assert_eq!(h, "010203");
        assert_eq!(from_hex_formatted(&h).unwrap(), b.to_vec());

        // Test case 3: odd number of bytes
        let h = "0102030";
        assert!(from_hex_formatted(h).is_none());

        // Test case 4: invalid hexadecimal character
        let h = "01g3";
        assert!(from_hex_formatted(h).is_none());

        // Test case 5: whitespace
        let h = "01 02 03";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());

        // Test case 6: 0x prefix
        let h = "0x010203";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());

        // Test case 7: 0x prefix + different whitespace chars
        let h = "    \n\n0x\r\n01
                            02\t03\n";
        assert_eq!(from_hex_formatted(h).unwrap(), b.to_vec());
    }

    #[test]
    fn test_from_hex_utf8_char_boundaries() {
        const MISALIGNMENT_CASE: &str = "쀘\n";

        // Ensure that `from_hex` can handle misaligned UTF-8 character boundaries.
        let b = from_hex(MISALIGNMENT_CASE);
        assert!(b.is_none());
    }

    #[test]
    fn test_max_faults_zero() {
        assert_eq!(max_faults(0), 0);
    }

    #[test]
    #[should_panic(expected = "n must not be zero")]
    fn test_quorum_zero() {
        quorum(0);
    }

    #[test]
    fn test_quorum_and_max_faults() {
        // n, expected_f, expected_q
        let test_cases = [
            (1, 0, 1),
            (2, 0, 2),
            (3, 0, 3),
            (4, 1, 3),
            (5, 1, 4),
            (6, 1, 5),
            (7, 2, 5),
            (8, 2, 6),
            (9, 2, 7),
            (10, 3, 7),
            (11, 3, 8),
            (12, 3, 9),
            (13, 4, 9),
            (14, 4, 10),
            (15, 4, 11),
            (16, 5, 11),
            (17, 5, 12),
            (18, 5, 13),
            (19, 6, 13),
            (20, 6, 14),
            (21, 6, 15),
        ];

        for (n, ef, eq) in test_cases {
            assert_eq!(max_faults(n), ef);
            assert_eq!(quorum(n), eq);
            assert_eq!(n, ef + eq);
        }
    }

    #[test]
    fn test_union() {
        // Test case 0: empty slices
        assert_eq!(union(&[], &[]), []);

        // Test case 1: empty and non-empty slices
        assert_eq!(union(&[], &[0x01, 0x02, 0x03]), [0x01, 0x02, 0x03]);

        // Test case 2: non-empty and non-empty slices
        assert_eq!(
            union(&[0x01, 0x02, 0x03], &[0x04, 0x05, 0x06]),
            [0x01, 0x02, 0x03, 0x04, 0x05, 0x06]
        );
    }

    #[test]
    fn test_union_unique() {
        let namespace = b"namespace";
        let msg = b"message";

        let length_encoding = vec![0b0000_1001];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(namespace);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_union_unique_zero_length() {
        let namespace = b"";
        let msg = b"message";

        let length_encoding = vec![0];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_union_unique_long_length() {
        // Use a namespace of over length 127.
        let namespace = &b"n".repeat(256);
        let msg = b"message";

        let length_encoding = vec![0b1000_0000, 0b0000_0010];
        let mut expected = Vec::with_capacity(length_encoding.len() + namespace.len() + msg.len());
        expected.extend_from_slice(&length_encoding);
        expected.extend_from_slice(namespace);
        expected.extend_from_slice(msg);

        let result = union_unique(namespace, msg);
        assert_eq!(result, expected);
        assert_eq!(result.len(), result.capacity());
    }

    #[test]
    fn test_modulo() {
        // Test case 0: empty bytes
        assert_eq!(modulo(&[], 1), 0);

        // Test case 1: single byte
        assert_eq!(modulo(&[0x01], 1), 0);

        // Test case 2: multiple bytes
        assert_eq!(modulo(&[0x01, 0x02, 0x03], 10), 1);

        // Test case 3: check equivalence with BigUint
        for i in 0..100 {
            let mut rng = StdRng::seed_from_u64(i);
            let bytes: [u8; 32] = rng.gen();

            // 1-byte modulus
            let n = 11u64;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // 2-byte modulus
            let n = 11_111u64;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));

            // 8-byte modulus
            let n = 0xDFFFFFFFFFFFFFFD;
            let big_modulo = BigUint::from_bytes_be(&bytes) % n;
            let utils_modulo = modulo(&bytes, n);
            assert_eq!(big_modulo, BigUint::from(utils_modulo));
        }
    }

    #[test]
    #[should_panic(expected = "modulus must be non-zero")]
    fn test_modulo_zero_panics() {
        modulo(&[0x01, 0x02, 0x03], 0);
    }

    #[test]
    fn test_non_zero_macros() {
        // Test case 0: zero value
        assert!(std::panic::catch_unwind(|| NZUsize!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU8!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU16!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU32!(0)).is_err());
        assert!(std::panic::catch_unwind(|| NZU64!(0)).is_err());

        // Test case 1: non-zero value
        assert_eq!(NZUsize!(1).get(), 1);
        assert_eq!(NZU8!(4).get(), 4);
        assert_eq!(NZU16!(5).get(), 5);
        assert_eq!(NZU32!(2).get(), 2);
        assert_eq!(NZU64!(3).get(), 3);
    }
}