chie_shared/utils/
security.rs

1//! Security, cryptography, and byte manipulation utility functions.
2
3/// Generate a random nonce for challenge-response.
4#[must_use]
5pub fn generate_nonce() -> [u8; 32] {
6    let mut nonce = [0u8; 32];
7    getrandom::fill(&mut nonce).expect("Failed to generate random nonce");
8    nonce
9}
10
/// Compare two byte slices without short-circuiting on the first mismatch.
///
/// Intended for secrets such as signatures, hashes, and nonces, where an
/// early exit could leak how many leading bytes matched. Slices of different
/// lengths are rejected immediately; for equal lengths every byte pair is
/// folded into a single difference accumulator before the result is decided.
///
/// Returns `true` when both slices hold identical bytes, `false` otherwise.
///
/// # Examples
///
/// ```
/// use chie_shared::constant_time_eq;
///
/// // Comparing equal slices
/// let signature1 = b"valid_signature_data";
/// let signature2 = b"valid_signature_data";
/// assert!(constant_time_eq(signature1, signature2));
///
/// // Comparing different slices
/// let sig_a = b"signature_a";
/// let sig_b = b"signature_b";
/// assert!(!constant_time_eq(sig_a, sig_b));
///
/// // Different lengths always return false
/// assert!(!constant_time_eq(b"short", b"longer_text"));
/// ```
#[inline]
#[must_use]
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        // OR together the XOR of every pair: any differing bit leaves the
        // accumulator non-zero, and no pair is skipped.
        let diff = a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
        diff == 0
    } else {
        false
    }
}
47
/// Compare two 32-byte arrays (the usual size for hashes and keys) without
/// short-circuiting.
///
/// Fixed-size counterpart of the slice comparison: the array types already
/// guarantee equal lengths, so all 32 byte pairs are always examined.
///
/// Returns `true` when the arrays are byte-for-byte identical.
#[inline]
#[must_use]
pub fn constant_time_eq_32(a: &[u8; 32], b: &[u8; 32]) -> bool {
    // Accumulate every pairwise XOR; the result is zero only if all pairs match.
    let diff = a
        .iter()
        .zip(b.iter())
        .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
    diff == 0
}
59
/// Combine two byte slices with bitwise XOR, producing a new vector.
///
/// Bytes are paired positionally. When the inputs differ in length the
/// output is as long as the shorter one; surplus bytes of the longer input
/// are ignored.
///
/// # Examples
///
/// ```
/// use chie_shared::xor_bytes;
///
/// // Basic XOR operation
/// let a = &[0xFF, 0x00, 0xAA];
/// let b = &[0xFF, 0xFF, 0x55];
/// assert_eq!(xor_bytes(a, b), vec![0x00, 0xFF, 0xFF]);
///
/// // Different lengths - uses shorter
/// let short = &[0xFF, 0x00];
/// let long = &[0xFF, 0xFF, 0xFF];
/// assert_eq!(xor_bytes(short, long), vec![0x00, 0xFF]);
/// ```
#[must_use]
pub fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    // `zip` stops at the end of the shorter slice, which gives the
    // documented truncation behavior.
    a.iter().zip(b).map(|(lhs, rhs)| lhs ^ rhs).collect()
}
88
/// Return a copy of `bytes` rotated left by `n` positions.
///
/// Elements shifted off the front reappear at the back, e.g. `[1, 2, 3, 4]`
/// rotated left by 1 becomes `[2, 3, 4, 1]`. `n` is reduced modulo the slice
/// length, and an empty slice yields an empty vector.
#[must_use]
pub fn rotate_bytes_left(bytes: &[u8], n: usize) -> Vec<u8> {
    match bytes.len() {
        // Handled separately so the modulo below never divides by zero.
        0 => Vec::new(),
        len => {
            let (head, tail) = bytes.split_at(n % len);
            // The leading `n % len` bytes move to the end.
            [tail, head].concat()
        }
    }
}
103
104/// Rotate bytes right by n positions.
105/// Wraps around: \[1,2,3,4\] rotated right by 1 = \[4,1,2,3\]
106#[must_use]
107pub fn rotate_bytes_right(bytes: &[u8], n: usize) -> Vec<u8> {
108    if bytes.is_empty() {
109        return Vec::new();
110    }
111
112    let len = bytes.len();
113    let n = n % len;
114    rotate_bytes_left(bytes, len - n)
115}
116
/// Report whether every byte in `bytes` is zero.
///
/// An empty slice contains no non-zero byte, so it returns `true`.
#[inline]
#[must_use]
pub fn is_all_zeros(bytes: &[u8]) -> bool {
    // Stops scanning at the first non-zero byte.
    !bytes.iter().any(|&byte| byte != 0)
}
123
/// Count how many bits are set to 1 across all bytes of the slice.
///
/// Returns 0 for an empty slice.
#[must_use]
pub fn count_set_bits(bytes: &[u8]) -> usize {
    let mut total = 0usize;
    for byte in bytes {
        // `count_ones` yields a `u32` in 0..=8 for a single byte.
        total += byte.count_ones() as usize;
    }
    total
}
129
/// Encode bytes as a lowercase hexadecimal string.
///
/// Every input byte becomes exactly two characters (`00`..=`ff`), so the
/// result is `2 * data.len()` characters long. An empty slice yields an
/// empty string.
///
/// # Examples
///
/// ```
/// use chie_shared::encode_hex;
///
/// // Encode bytes to hex
/// let data = &[0xde, 0xad, 0xbe, 0xef];
/// assert_eq!(encode_hex(data), "deadbeef");
///
/// // Encode small numbers
/// let numbers = &[0, 15, 255];
/// assert_eq!(encode_hex(numbers), "000fff");
/// ```
#[must_use]
pub fn encode_hex(data: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    // One allocation sized up front, instead of a temporary `String` from
    // `format!` for every input byte.
    let mut hex = String::with_capacity(data.len() * 2);
    for &byte in data {
        // High nibble first, then low nibble.
        hex.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    hex
}
150
/// Decode a hexadecimal string into bytes.
///
/// Both lowercase and uppercase digits are accepted. Each pair of characters
/// produces one output byte, high nibble first. An empty string decodes to an
/// empty vector.
///
/// # Errors
///
/// Returns `Err` with a descriptive message when:
/// * the input has an odd number of bytes, or
/// * any byte is not an ASCII hex digit (`0-9`, `a-f`, `A-F`). This includes
///   sign characters such as `+`, whitespace, and non-ASCII text, all of
///   which are rejected rather than causing a panic.
///
/// # Examples
///
/// ```
/// use chie_shared::decode_hex;
///
/// // Decode hex string
/// let bytes = decode_hex("deadbeef").unwrap();
/// assert_eq!(bytes, vec![0xde, 0xad, 0xbe, 0xef]);
///
/// // Roundtrip encoding/decoding
/// let original = vec![1, 2, 3, 255];
/// let hex = chie_shared::encode_hex(&original);
/// let decoded = decode_hex(&hex).unwrap();
/// assert_eq!(original, decoded);
///
/// // Error on odd length
/// assert!(decode_hex("abc").is_err());
///
/// // Error on invalid hex
/// assert!(decode_hex("xyz").is_err());
///
/// // A sign is not a hex digit
/// assert!(decode_hex("+f").is_err());
/// ```
pub fn decode_hex(hex: &str) -> Result<Vec<u8>, String> {
    /// Map a single ASCII hex digit to its 4-bit value.
    fn hex_value(byte: u8, offset: usize) -> Result<u8, String> {
        match byte {
            b'0'..=b'9' => Ok(byte - b'0'),
            b'a'..=b'f' => Ok(byte - b'a' + 10),
            b'A'..=b'F' => Ok(byte - b'A' + 10),
            _ => Err(format!(
                "Invalid hex character: byte 0x{:02x} at offset {}",
                byte, offset
            )),
        }
    }

    // Work on raw bytes: slicing the `&str` at arbitrary byte offsets would
    // panic whenever an offset falls inside a multi-byte UTF-8 character.
    let raw = hex.as_bytes();
    if raw.len() % 2 != 0 {
        return Err("Hex string must have even length".to_string());
    }

    let mut decoded = Vec::with_capacity(raw.len() / 2);
    for (idx, pair) in raw.chunks_exact(2).enumerate() {
        let high = hex_value(pair[0], idx * 2)?;
        let low = hex_value(pair[1], idx * 2 + 1)?;
        decoded.push((high << 4) | low);
    }
    Ok(decoded)
}
187
188/// Apply random jitter to a value for backoff/retry timing.
189///
190/// # Arguments
191/// * `value` - Base value to apply jitter to
192/// * `factor` - Jitter factor as a fraction (e.g., 0.25 for ±25%)
193///
194/// # Returns
195/// Value with random jitter applied in the range `[value * (1 - factor), value * (1 + factor)]`
196#[must_use]
197pub fn random_jitter(value: u64, factor: f64) -> u64 {
198    if value == 0 || factor <= 0.0 {
199        return value;
200    }
201
202    let mut random_bytes = [0u8; 8];
203    getrandom::fill(&mut random_bytes).expect("Failed to generate random bytes");
204    let random_u64 = u64::from_le_bytes(random_bytes);
205
206    // Calculate jitter range
207    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
208    let jitter_range = ((value as f64) * factor) as u64;
209
210    if jitter_range == 0 {
211        return value;
212    }
213
214    // Random value in range [-jitter_range, +jitter_range]
215    let jitter = (random_u64 % (jitter_range * 2)).saturating_sub(jitter_range);
216
217    value.saturating_add(jitter)
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Two consecutive nonces are 32 bytes long and differ from each other.
    #[test]
    fn test_generate_nonce() {
        let first = generate_nonce();
        let second = generate_nonce();
        assert_eq!(first.len(), 32);
        assert_ne!(first, second);
    }

    /// Slice comparison: equal inputs, same-length mismatch, length mismatch, empty.
    #[test]
    fn test_constant_time_eq() {
        let reference: &[u8] = b"hello world";

        assert!(constant_time_eq(reference, b"hello world"));

        let mismatches: [&[u8]; 3] = [b"hello earth", b"different", b"short"];
        for other in mismatches.iter() {
            assert!(!constant_time_eq(reference, other));
        }

        assert!(constant_time_eq(&[], &[]));
    }

    /// Fixed-size comparison, including a difference in the very last byte.
    #[test]
    fn test_constant_time_eq_32() {
        let ones = [1u8; 32];
        let also_ones = [1u8; 32];
        let last_byte_differs = {
            let mut bytes = [1u8; 32];
            bytes[31] = 2;
            bytes
        };

        assert!(constant_time_eq_32(&ones, &also_ones));
        assert!(!constant_time_eq_32(&ones, &last_byte_differs));
        assert!(constant_time_eq_32(&[0u8; 32], &[0u8; 32]));
    }

    /// Pairwise XOR, including empty inputs and mismatched lengths.
    #[test]
    fn test_xor_bytes() {
        fn check(lhs: &[u8], rhs: &[u8], expected: &[u8]) {
            assert_eq!(xor_bytes(lhs, rhs), expected);
        }

        check(&[0xFF, 0x00], &[0xFF, 0xFF], &[0x00, 0xFF]);
        check(&[0xAA, 0x55], &[0x55, 0xAA], &[0xFF, 0xFF]);
        check(&[1, 2, 3], &[3, 2, 1], &[2, 0, 2]);
        check(&[], &[], &[]);
        // Mismatched lengths: the output is as long as the shorter input.
        check(&[1, 2, 3], &[1, 1], &[0, 3]);
    }

    /// Left rotation for several shift amounts, plus the empty slice.
    #[test]
    fn test_rotate_bytes_left() {
        let input = [1u8, 2, 3, 4];
        let cases: [(usize, [u8; 4]); 5] = [
            (1, [2, 3, 4, 1]),
            (2, [3, 4, 1, 2]),
            (0, [1, 2, 3, 4]),
            (4, [1, 2, 3, 4]), // Full rotation
            (5, [2, 3, 4, 1]), // Wraps around
        ];
        for (shift, expected) in cases.iter() {
            assert_eq!(rotate_bytes_left(&input, *shift), *expected);
        }

        // Empty
        assert_eq!(rotate_bytes_left(&[], 5), Vec::<u8>::new());
    }

    /// Right rotation for several shift amounts, plus the empty slice.
    #[test]
    fn test_rotate_bytes_right() {
        let input = [1u8, 2, 3, 4];
        let cases: [(usize, [u8; 4]); 5] = [
            (1, [4, 1, 2, 3]),
            (2, [3, 4, 1, 2]),
            (0, [1, 2, 3, 4]),
            (4, [1, 2, 3, 4]), // Full rotation
            (5, [4, 1, 2, 3]), // Wraps around
        ];
        for (shift, expected) in cases.iter() {
            assert_eq!(rotate_bytes_right(&input, *shift), *expected);
        }

        // Empty
        assert_eq!(rotate_bytes_right(&[], 5), Vec::<u8>::new());
    }

    /// All-zero detection, with the empty slice counted as all zeros.
    #[test]
    fn test_is_all_zeros() {
        let all_zero: [&[u8]; 3] = [&[0, 0, 0, 0], &[0], &[]];
        for bytes in all_zero.iter() {
            assert!(is_all_zeros(bytes));
        }

        let not_all_zero: [&[u8]; 2] = [&[0, 0, 1, 0], &[1, 2, 3]];
        for bytes in not_all_zero.iter() {
            assert!(!is_all_zeros(bytes));
        }
    }

    /// Population count over single bytes, multiple bytes, and empty input.
    #[test]
    fn test_count_set_bits() {
        let cases: [(Vec<u8>, usize); 7] = [
            (vec![0b1111_1111], 8),
            (vec![0b0000_0001], 1),
            (vec![0b1010_1010], 4),
            (vec![0b1111_1111, 0b0000_0000], 8),
            (vec![0b1010_1010, 0b0101_0101], 8),
            (vec![], 0),
            (vec![0, 0, 0], 0),
        ];
        for (bytes, expected) in cases.iter() {
            assert_eq!(count_set_bits(bytes), *expected);
        }
    }

    /// Hex encoding produces two lowercase digits per byte.
    #[test]
    fn test_encode_hex() {
        let cases: [(Vec<u8>, &str); 3] = [
            (vec![0, 15, 255], "000fff"),
            (vec![0xde, 0xad, 0xbe, 0xef], "deadbeef"),
            (vec![], ""),
        ];
        for (bytes, expected) in cases.iter() {
            assert_eq!(encode_hex(bytes), *expected);
        }
    }

    /// Hex decoding of valid input, plus the two rejection cases.
    #[test]
    fn test_decode_hex() {
        assert_eq!(decode_hex("000fff"), Ok(vec![0, 15, 255]));
        assert_eq!(decode_hex("deadbeef"), Ok(vec![0xde, 0xad, 0xbe, 0xef]));
        assert_eq!(decode_hex(""), Ok(Vec::<u8>::new()));

        // Error cases: odd length, then invalid hex.
        for malformed in ["abc", "xyz"].iter() {
            assert!(decode_hex(malformed).is_err());
        }
    }

    /// Encoding followed by decoding returns the original bytes.
    #[test]
    fn test_hex_roundtrip() {
        let original: Vec<u8> = vec![1, 2, 3, 4, 5, 255, 0, 128];
        let restored = decode_hex(&encode_hex(&original)).unwrap();
        assert_eq!(original, restored);
    }

    /// Jitter pass-through cases, range bounds, and variability.
    #[test]
    fn test_random_jitter() {
        // A zero value and a zero factor both pass the value through untouched.
        assert_eq!(random_jitter(0, 0.25), 0);
        assert_eq!(random_jitter(1000, 0.0), 1000);

        // With ±25% around 1000, every sample must stay within [750, 1250].
        for _ in 0..100 {
            let sample = random_jitter(1000, 0.25);
            assert!((750..=1250).contains(&sample));
        }

        // Ten samples around 10_000 should not all be identical.
        let samples: Vec<u64> = (0..10).map(|_| random_jitter(10_000, 0.25)).collect();
        let all_same = samples.iter().all(|&sample| sample == samples[0]);
        assert!(!all_same, "Random jitter should produce varying values");
    }
}