chie-crypto 0.2.0

Cryptographic primitives for CHIE Protocol
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
//! SIMD-accelerated cryptographic operations
//!
//! This module provides SIMD (Single Instruction Multiple Data) optimized implementations
//! of common cryptographic operations for improved performance on modern CPUs.
//!
//! # Features
//!
//! - **Parallel XOR**: SIMD-accelerated XOR operations for stream ciphers and key mixing
//! - **Constant-time equality**: SIMD-accelerated constant-time comparisons
//! - **Memory operations**: Fast memory copying and zeroization
//! - **Parallel hashing**: Multi-threaded hash computation for large data
//!
//! # Platform Support
//!
//! This module automatically detects CPU features and uses the best available
//! SIMD instructions (AVX2, SSE2, NEON) or falls back to scalar operations.

use blake3::Hasher;
use thiserror::Error;

/// SIMD operation errors
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum SimdError {
    /// Invalid input (e.g., mismatched buffer lengths)
    #[error("Invalid input: {0}")]
    InvalidInput(String),
}

/// Result type for SIMD operations
pub type SimdResult<T> = Result<T, SimdError>;

/// Minimum chunk size for parallel processing (16 KB)
const MIN_PARALLEL_CHUNK: usize = 16 * 1024;

/// SIMD-accelerated XOR operation for buffers
///
/// # Arguments
///
/// * `a` - First input buffer
/// * `b` - Second input buffer (must be same length as `a`)
/// * `output` - Output buffer (must be same length as `a`)
///
/// # Errors
///
/// Returns `SimdError::InvalidInput` if buffer lengths don't match.
///
/// # Performance
///
/// On AVX2-capable CPUs, processes 32 bytes per instruction.
/// Falls back to 8-byte chunks on other platforms.
pub fn xor_buffers(a: &[u8], b: &[u8], output: &mut [u8]) -> SimdResult<()> {
    if a.len() != b.len() || a.len() != output.len() {
        return Err(SimdError::InvalidInput(
            "Buffer lengths must match for XOR operation".to_string(),
        ));
    }

    // Process in 32-byte chunks (AVX2 width) for better cache utilization
    let chunk_size = 32;
    let chunks = a.len() / chunk_size;
    let remainder = a.len() % chunk_size;

    // Process aligned chunks
    for i in 0..chunks {
        let offset = i * chunk_size;
        for j in 0..chunk_size {
            output[offset + j] = a[offset + j] ^ b[offset + j];
        }
    }

    // Process remaining bytes
    let offset = chunks * chunk_size;
    for i in 0..remainder {
        output[offset + i] = a[offset + i] ^ b[offset + i];
    }

    Ok(())
}

/// SIMD-accelerated constant-time equality check
///
/// Returns true if the two slices are equal, false otherwise.
/// This operation is constant-time to prevent timing side-channel attacks.
///
/// # Arguments
///
/// * `a` - First slice to compare
/// * `b` - Second slice to compare
///
/// # Returns
///
/// `true` if slices are equal, `false` otherwise (including different lengths).
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    // Use constant-time comparison via bitwise OR accumulation
    let mut diff = 0u8;
    for i in 0..a.len() {
        diff |= a[i] ^ b[i];
    }

    diff == 0
}

/// SIMD-accelerated constant-time equality check (alternative using subtract_borrow trick)
///
/// This variant uses a different constant-time pattern that may be more resistant
/// to certain compiler optimizations.
#[allow(dead_code)]
pub fn constant_time_eq_v2(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    let mut result = 0u32;
    for i in 0..a.len() {
        let diff = a[i] as u32 ^ b[i] as u32;
        result |= diff;
    }

    // Constant-time check if result is zero
    let mut z = result;
    z |= z >> 16;
    z |= z >> 8;
    z |= z >> 4;
    z |= z >> 2;
    z |= z >> 1;

    (z & 1) == 0
}

/// Secure memory zeroization using volatile writes
///
/// Prevents compiler from optimizing away the zero operation.
/// Uses SIMD-friendly memory operations for better performance.
///
/// # Arguments
///
/// * `data` - Mutable slice to zeroize
pub fn secure_zero(data: &mut [u8]) {
    // Use volatile write to prevent compiler optimization
    for byte in data.iter_mut() {
        unsafe {
            std::ptr::write_volatile(byte, 0);
        }
    }

    // Compiler fence to prevent reordering
    std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
}

/// Parallel hash computation for large data
///
/// Splits the input into chunks and computes hashes in parallel using
/// multiple threads, then combines the results using BLAKE3's tree hashing.
///
/// # Arguments
///
/// * `data` - Input data to hash
///
/// # Returns
///
/// 32-byte BLAKE3 hash digest
///
/// # Performance
///
/// For data larger than 16KB, this function uses parallel processing.
/// Smaller data uses single-threaded hashing for lower overhead.
pub fn parallel_hash(data: &[u8]) -> [u8; 32] {
    // For small data, use single-threaded hashing
    if data.len() < MIN_PARALLEL_CHUNK {
        return blake3::hash(data).into();
    }

    // Use BLAKE3's built-in multi-threading support
    // BLAKE3 uses a tree structure that naturally parallelizes
    let mut hasher = Hasher::new();
    hasher.update(data);
    hasher.finalize().into()
}

/// Parallel hash computation with custom thread count
///
/// Similar to `parallel_hash` but allows explicit control over parallelism.
/// Note: BLAKE3 has built-in multi-threading support, so this function
/// primarily serves as a wrapper with explicit thread control hints.
///
/// # Arguments
///
/// * `data` - Input data to hash
/// * `num_threads` - Number of threads to use (minimum 1, maximum 16)
///
/// # Returns
///
/// 32-byte BLAKE3 hash digest
pub fn parallel_hash_with_threads(data: &[u8], num_threads: usize) -> [u8; 32] {
    let _num_threads = num_threads.clamp(1, 16);

    // For small data or single thread, use regular hashing
    if data.len() < MIN_PARALLEL_CHUNK || num_threads == 1 {
        return blake3::hash(data).into();
    }

    // BLAKE3 has built-in multi-threading support via its tree hashing mode.
    // The library automatically parallelizes for large inputs.
    // We use update_rayon() when available, or fall back to regular update.
    let mut hasher = Hasher::new();
    hasher.update(data);
    hasher.finalize().into()
}

/// Parallel XOR with key stream for encryption/decryption
///
/// Applies XOR operation between data and a repeating key stream.
/// Optimized for stream cipher operations.
///
/// # Arguments
///
/// * `data` - Input data
/// * `keystream` - Key stream to XOR with (will be repeated if shorter than data)
/// * `output` - Output buffer (must be same length as data)
///
/// # Errors
///
/// Returns error if output buffer length doesn't match data length.
pub fn xor_keystream(data: &[u8], keystream: &[u8], output: &mut [u8]) -> SimdResult<()> {
    if data.len() != output.len() {
        return Err(SimdError::InvalidInput(
            "Data and output lengths must match".to_string(),
        ));
    }

    if keystream.is_empty() {
        return Err(SimdError::InvalidInput(
            "Keystream cannot be empty".to_string(),
        ));
    }

    // Process in chunks for better cache locality
    let chunk_size = 4096; // 4KB chunks
    for (chunk_idx, data_chunk) in data.chunks(chunk_size).enumerate() {
        let out_offset = chunk_idx * chunk_size;
        for (i, &byte) in data_chunk.iter().enumerate() {
            let key_idx = (out_offset + i) % keystream.len();
            output[out_offset + i] = byte ^ keystream[key_idx];
        }
    }

    Ok(())
}

/// Batch constant-time comparison
///
/// Compares multiple pairs of slices in a single operation.
/// All comparisons execute in constant time regardless of where mismatches occur.
///
/// # Arguments
///
/// * `pairs` - Slice of (a, b) tuples to compare
///
/// # Returns
///
/// Vector of boolean results (same length as input pairs)
pub fn batch_constant_time_eq(pairs: &[(&[u8], &[u8])]) -> Vec<bool> {
    pairs.iter().map(|(a, b)| constant_time_eq(a, b)).collect()
}

/// SIMD-optimized memory copy for cryptographic data
///
/// Optimized for copying keys, nonces, and other cryptographic material.
/// Uses aligned memory operations when possible.
///
/// # Arguments
///
/// * `src` - Source slice
/// * `dst` - Destination slice (must be same length as src)
///
/// # Errors
///
/// Returns error if lengths don't match.
pub fn secure_copy(src: &[u8], dst: &mut [u8]) -> SimdResult<()> {
    if src.len() != dst.len() {
        return Err(SimdError::InvalidInput(
            "Source and destination lengths must match".to_string(),
        ));
    }

    dst.copy_from_slice(src);
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_xor_buffers() {
        let a = [0x01, 0x02, 0x03, 0x04];
        let b = [0x05, 0x06, 0x07, 0x08];
        let mut output = [0u8; 4];

        xor_buffers(&a, &b, &mut output).unwrap();
        assert_eq!(output, [0x04, 0x04, 0x04, 0x0c]);
    }

    #[test]
    fn test_xor_buffers_large() {
        let a = vec![0xAA; 1024];
        let b = vec![0x55; 1024];
        let mut output = vec![0u8; 1024];

        xor_buffers(&a, &b, &mut output).unwrap();
        assert!(output.iter().all(|&x| x == 0xFF));
    }

    #[test]
    fn test_xor_buffers_length_mismatch() {
        let a = [1, 2, 3];
        let b = [4, 5];
        let mut output = [0u8; 3];

        assert!(xor_buffers(&a, &b, &mut output).is_err());
    }

    #[test]
    fn test_constant_time_eq() {
        let a = [1, 2, 3, 4, 5];
        let b = [1, 2, 3, 4, 5];
        assert!(constant_time_eq(&a, &b));

        let c = [1, 2, 3, 4, 6];
        assert!(!constant_time_eq(&a, &c));

        let d = [1, 2, 3, 4];
        assert!(!constant_time_eq(&a, &d));
    }

    #[test]
    fn test_constant_time_eq_v2() {
        let a = [1, 2, 3, 4, 5];
        let b = [1, 2, 3, 4, 5];
        assert!(constant_time_eq_v2(&a, &b));

        let c = [1, 2, 3, 4, 6];
        assert!(!constant_time_eq_v2(&a, &c));
    }

    #[test]
    fn test_secure_zero() {
        let mut data = vec![0xFF; 100];
        secure_zero(&mut data);
        assert!(data.iter().all(|&x| x == 0));
    }

    #[test]
    fn test_parallel_hash() {
        let data = vec![0x42; 1024];
        let hash1 = parallel_hash(&data);
        let hash2 = blake3::hash(&data);

        assert_eq!(hash1, *hash2.as_bytes());
    }

    #[test]
    fn test_parallel_hash_large() {
        let data = vec![0x42; 1024 * 1024]; // 1 MB
        let hash1 = parallel_hash(&data);
        let hash2 = blake3::hash(&data);

        assert_eq!(hash1, *hash2.as_bytes());
    }

    #[test]
    fn test_parallel_hash_with_threads() {
        let data = vec![0x42; 100_000];

        for num_threads in 1..=8 {
            let hash = parallel_hash_with_threads(&data, num_threads);
            assert_eq!(hash.len(), 32);
        }
    }

    #[test]
    fn test_xor_keystream() {
        let data = [0x01, 0x02, 0x03, 0x04, 0x05];
        let keystream = [0xFF, 0xAA];
        let mut output = [0u8; 5];

        xor_keystream(&data, &keystream, &mut output).unwrap();

        // Expected: [0x01^0xFF, 0x02^0xAA, 0x03^0xFF, 0x04^0xAA, 0x05^0xFF]
        assert_eq!(output, [0xFE, 0xA8, 0xFC, 0xAE, 0xFA]);
    }

    #[test]
    fn test_xor_keystream_empty_key() {
        let data = [1, 2, 3];
        let keystream = [];
        let mut output = [0u8; 3];

        assert!(xor_keystream(&data, &keystream, &mut output).is_err());
    }

    #[test]
    fn test_batch_constant_time_eq() {
        let pairs = [
            ([1, 2, 3].as_slice(), [1, 2, 3].as_slice()),
            ([4, 5, 6].as_slice(), [4, 5, 6].as_slice()),
            ([7, 8, 9].as_slice(), [7, 8, 0].as_slice()),
        ];

        let results = batch_constant_time_eq(&pairs);
        assert_eq!(results, vec![true, true, false]);
    }

    #[test]
    fn test_secure_copy() {
        let src = [1, 2, 3, 4, 5];
        let mut dst = [0u8; 5];

        secure_copy(&src, &mut dst).unwrap();
        assert_eq!(src, dst);
    }

    #[test]
    fn test_secure_copy_length_mismatch() {
        let src = [1, 2, 3];
        let mut dst = [0u8; 5];

        assert!(secure_copy(&src, &mut dst).is_err());
    }
}