dcrypt_algorithms/block/modes/ctr/mod.rs
1//! Counter (CTR) mode with proper error propagation and secure memory handling
2//!
3//! Counter mode turns a block cipher into a stream cipher by encrypting
4//! successive values of a counter and XORing the result with the plaintext.
5//!
6//! This implementation follows NIST SP 800-38A recommendations for CTR mode,
7//! using a flexible nonce-counter format with secure memory handling.
8
9#[cfg(not(feature = "std"))]
10use alloc::vec::Vec;
11use byteorder::{BigEndian, ByteOrder};
12use zeroize::{Zeroize, ZeroizeOnDrop, Zeroizing};
13
14use super::super::BlockCipher;
15use crate::error::{validate, Result};
16use crate::types::nonce::AesCtrCompatible;
17use crate::types::Nonce;
18
19// Import security types for memory safety
20use dcrypt_common::security::barrier;
21
/// Location of the counter field inside the cipher's counter block.
///
/// The counter occupies `counter_size` consecutive bytes of the block; the
/// variant selects where those bytes start.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CounterPosition {
    /// Counter is placed at the beginning of the block (bytes 0 to counter_size-1).
    /// This is common in some implementations, especially with 8-byte counters.
    Prefix,

    /// Counter is placed at the end of the block (last counter_size bytes).
    /// This is the most common arrangement for AES-CTR.
    Postfix,

    /// Counter is placed at the given byte offset within the block.
    /// Allows for custom layouts; an offset of 0 is laid out like `Prefix`.
    Custom(usize),
}
37
/// Counter mode implementation with secure memory handling
///
/// Holds a block cipher together with the current counter block and the most
/// recently generated keystream block. Both buffers are wrapped in
/// [`Zeroizing`], and the struct itself derives `Zeroize` / `ZeroizeOnDrop`.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct Ctr<B: BlockCipher + Zeroize> {
    // Block cipher used to encrypt each counter block.
    cipher: B,
    // Current counter block (nonce bytes plus counter field); allocated with
    // `B::block_size()` bytes at construction.
    counter_block: Zeroizing<Vec<u8>>,
    // Byte offset of the counter field inside `counter_block`.
    counter_position: usize,
    // Width of the counter field in bytes (validated to 1..=8 at construction).
    counter_size: usize,
    // Last generated keystream block; empty until the first block is produced
    // and after `seek` / `reset`.
    keystream: Zeroizing<Vec<u8>>,
    // Index of the next unused byte in `keystream`; a value >= `keystream.len()`
    // means a fresh block must be generated before more bytes are consumed.
    keystream_pos: usize,
}
48
49impl<B: BlockCipher + Zeroize> Ctr<B> {
50    /// Creates a new CTR mode instance with the default configuration
51    ///
52    /// * `cipher` - The block cipher to use
53    /// * `nonce` - The nonce (must be compatible with CTR mode)
54    ///
55    /// This creates a standard CTR mode with the counter in the last 4 bytes
56    /// and the nonce filling the beginning of the counter block.
57    pub fn new<const N: usize>(cipher: B, nonce: &Nonce<N>) -> Result<Self>
58    where
59        Nonce<N>: AesCtrCompatible,
60    {
61        // Standard CTR mode with 4-byte counter at the end
62        Self::with_counter_params(cipher, nonce, CounterPosition::Postfix, 4)
63    }
64
65    /// Creates a new CTR mode instance with custom counter parameters
66    ///
67    /// * `cipher` - The block cipher to use
68    /// * `nonce` - The nonce (must be compatible with CTR mode)
69    /// * `counter_pos` - Position of the counter within the counter block
70    /// * `counter_size` - Size of the counter in bytes (1-8)
71    ///
72    /// This allows for flexible counter block layouts to match different standards
73    /// and implementations.
74    pub fn with_counter_params<const N: usize>(
75        cipher: B,
76        nonce: &Nonce<N>,
77        counter_pos: CounterPosition,
78        counter_size: usize,
79    ) -> Result<Self>
80    where
81        Nonce<N>: AesCtrCompatible,
82    {
83        let block_size = B::block_size();
84
85        // Validate counter size (1-8 bytes for u64 counter)
86        validate::parameter(
87            counter_size > 0 && counter_size <= 8,
88            "counter_size",
89            "Counter size must be between 1 and 8 bytes",
90        )?;
91
92        // Determine the counter position
93        let position = match counter_pos {
94            CounterPosition::Prefix => 0,
95            CounterPosition::Postfix => block_size - counter_size,
96            CounterPosition::Custom(offset) => {
97                validate::parameter(
98                    offset + counter_size <= block_size,
99                    "counter_position",
100                    "Counter with specified size doesn't fit at offset in block",
101                )?;
102                offset
103            }
104        };
105
106        // Create and initialize the counter block with Zeroizing
107        let mut counter_block = Zeroizing::new(vec![0u8; block_size]);
108
109        // Handle nonce according to its size
110        let max_nonce_size = block_size - counter_size;
111
112        // If nonce is too large, truncate it
113        let effective_nonce = if N > max_nonce_size {
114            &nonce.as_ref()[0..max_nonce_size]
115        } else {
116            nonce.as_ref()
117        };
118
119        // Fill in the nonce
120        if position == 0 {
121            // Counter is at the beginning, place nonce after it
122            counter_block[counter_size..counter_size + effective_nonce.len()]
123                .copy_from_slice(effective_nonce);
124        } else {
125            // Counter is elsewhere, place nonce at the beginning by default
126            counter_block[0..effective_nonce.len()].copy_from_slice(effective_nonce);
127        }
128
129        Ok(Self {
130            cipher,
131            counter_block,
132            counter_position: position,
133            counter_size,
134            keystream: Zeroizing::new(Vec::new()),
135            keystream_pos: 0,
136        })
137    }
138
139    /// Generate keystream for CTR mode with secure memory handling
140    fn generate_keystream(&mut self) -> Result<()> {
141        let block_size = B::block_size();
142
143        // Create a new zeroizing keystream buffer
144        self.keystream = Zeroizing::new(vec![0u8; block_size]);
145
146        // Use memory barrier to prevent optimization
147        barrier::compiler_fence_seq_cst();
148
149        // Copy current counter block to keystream
150        self.keystream.copy_from_slice(&self.counter_block);
151
152        // Encrypt the counter value
153        self.cipher.encrypt_block(&mut self.keystream)?;
154
155        // Increment the counter based on its size
156        self.increment_counter();
157
158        self.keystream_pos = 0;
159
160        // Use memory barrier after operation
161        barrier::compiler_fence_seq_cst();
162
163        Ok(())
164    }
165
166    /// Increment the counter in the counter block
167    fn increment_counter(&mut self) {
168        match self.counter_size {
169            8 => {
170                let mut counter = [0u8; 8];
171                counter.copy_from_slice(
172                    &self.counter_block[self.counter_position..self.counter_position + 8],
173                );
174                let value = BigEndian::read_u64(&counter);
175                BigEndian::write_u64(&mut counter, value.wrapping_add(1));
176                self.counter_block[self.counter_position..self.counter_position + 8]
177                    .copy_from_slice(&counter);
178
179                // Zeroize the temporary counter array
180                counter.zeroize();
181            }
182            4 => {
183                let mut counter = [0u8; 4];
184                counter.copy_from_slice(
185                    &self.counter_block[self.counter_position..self.counter_position + 4],
186                );
187                let value = BigEndian::read_u32(&counter);
188                BigEndian::write_u32(&mut counter, value.wrapping_add(1));
189                self.counter_block[self.counter_position..self.counter_position + 4]
190                    .copy_from_slice(&counter);
191
192                // Zeroize the temporary counter array
193                counter.zeroize();
194            }
195            // For other counter sizes, we'll read/write the appropriate number of bytes
196            size => {
197                let mut value: u64 = 0;
198
199                // Read counter value (big-endian)
200                for i in 0..size {
201                    value = (value << 8) | (self.counter_block[self.counter_position + i] as u64);
202                }
203
204                // Increment counter
205                value = value.wrapping_add(1);
206
207                // Write counter value back (big-endian)
208                for i in 0..size {
209                    self.counter_block[self.counter_position + size - 1 - i] = (value & 0xff) as u8;
210                    value >>= 8;
211                }
212            }
213        }
214    }
215
216    /// Encrypts a message using CTR mode
217    pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>> {
218        let mut ciphertext = Vec::with_capacity(plaintext.len());
219
220        // Use memory barrier before sensitive operations
221        barrier::compiler_fence_seq_cst();
222
223        for &byte in plaintext {
224            if self.keystream_pos >= self.keystream.len() {
225                self.generate_keystream()?;
226            }
227
228            ciphertext.push(byte ^ self.keystream[self.keystream_pos]);
229            self.keystream_pos += 1;
230        }
231
232        // Use memory barrier after sensitive operations
233        barrier::compiler_fence_seq_cst();
234
235        Ok(ciphertext)
236    }
237
238    /// Decrypts a message using CTR mode
239    /// In CTR mode, encryption and decryption are the same operation
240    pub fn decrypt(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>> {
241        self.encrypt(ciphertext)
242    }
243
244    /// Process data in place (encrypt or decrypt)
245    pub fn process(&mut self, data: &mut [u8]) -> Result<()> {
246        // Use memory barrier before sensitive operations
247        barrier::compiler_fence_seq_cst();
248
249        for byte in data.iter_mut() {
250            // Generate new keystream block if needed
251            if self.keystream_pos >= self.keystream.len() {
252                self.generate_keystream()?;
253            }
254
255            // XOR data with keystream
256            *byte ^= self.keystream[self.keystream_pos];
257            self.keystream_pos += 1;
258        }
259
260        // Use memory barrier after sensitive operations
261        barrier::compiler_fence_seq_cst();
262
263        Ok(())
264    }
265
266    /// Generate keystream directly into an output buffer
267    pub fn keystream(&mut self, output: &mut [u8]) -> Result<()> {
268        // Zero the output buffer
269        for byte in output.iter_mut() {
270            *byte = 0;
271        }
272
273        // Force generation from a block boundary (ignore any leftover position)
274        self.keystream_pos = self.keystream.len();
275
276        // Then run the encryption pass to copy the keystream
277        self.process(output)
278    }
279
280    /// Seek to a specific block position
281    ///
282    /// `block_offset` is the number of full blocks that have been consumed;
283    /// after seeking, the next generated block will be at `block_offset + 1`.
284    pub fn seek(&mut self, block_offset: u32) {
285        // Calculate the counter value based on the offset
286        let mut counter_value = [0u8; 8];
287        BigEndian::write_u32(&mut counter_value[4..], block_offset.wrapping_add(1));
288
289        // Update counter in the counter block
290        for i in 0..self.counter_size {
291            let idx = self.counter_position + self.counter_size - 1 - i;
292            self.counter_block[idx] = counter_value[7 - i];
293        }
294
295        // Force regeneration on next use
296        self.keystream_pos = self.keystream.len();
297
298        // Clear any old keystream with Zeroizing
299        self.keystream = Zeroizing::new(Vec::new());
300
301        // Zeroize the temporary counter value
302        counter_value.zeroize();
303    }
304
305    /// Set the counter value directly
306    ///
307    /// This allows for manual control of the counter, which can be useful for
308    /// seeking to specific positions in the stream.
309    ///
310    /// # Arguments
311    /// * `counter` - The new counter value
312    pub fn set_counter(&mut self, counter: u32) {
313        // Update counter in the counter block
314        let counter_pos = self.counter_position;
315
316        // Write the counter value in big-endian format
317        // This handles various counter sizes (1-8 bytes)
318        let counter_bytes = counter.to_be_bytes();
319        let start_idx = 4 - self.counter_size;
320
321        for i in 0..self.counter_size {
322            if start_idx + i < 4 {
323                // Only copy if within counter_bytes bounds
324                self.counter_block[counter_pos + i] = counter_bytes[start_idx + i];
325            }
326        }
327
328        // Force regeneration of keystream on next use
329        self.keystream_pos = self.keystream.len();
330    }
331
332    /// Reset to initial state with the same key and nonce
333    ///
334    /// This resets the counter to 0 and clears any buffered keystream.
335    ///
336    /// # Arguments
337    /// * `nonce` - Optional new nonce to use (if not provided, keeps the current nonce)
338    /// * `counter` - Optional initial counter value (defaults to 0)
339    pub fn reset<const N: usize>(&mut self, nonce: Option<&Nonce<N>>, counter: u32) -> Result<()>
340    where
341        Nonce<N>: AesCtrCompatible,
342    {
343        // Use memory barrier before sensitive operations
344        barrier::compiler_fence_seq_cst();
345
346        // Update nonce if provided
347        if let Some(new_nonce) = nonce {
348            let block_size = B::block_size();
349            let max_nonce_size = block_size - self.counter_size;
350
351            // If nonce is too large, truncate it
352            let effective_nonce = if N > max_nonce_size {
353                &new_nonce.as_ref()[0..max_nonce_size]
354            } else {
355                new_nonce.as_ref()
356            };
357
358            // Clear the counter block
359            for b in &mut *self.counter_block {
360                *b = 0;
361            }
362
363            // Fill in the nonce
364            let counter_pos = match self.counter_position {
365                0 => self.counter_size, // Counter is at beginning, nonce follows
366                _ => 0,                 // Otherwise nonce is at beginning
367            };
368
369            // Copy the new nonce
370            self.counter_block[counter_pos..counter_pos + effective_nonce.len()]
371                .copy_from_slice(effective_nonce);
372        }
373
374        // Set the counter value
375        self.set_counter(counter);
376
377        // Clear keystream
378        self.keystream = Zeroizing::new(Vec::new());
379        self.keystream_pos = 0;
380
381        // Use memory barrier after sensitive operations
382        barrier::compiler_fence_seq_cst();
383
384        Ok(())
385    }
386}
387
388#[cfg(test)]
389mod tests;