cipher/stream/
wrapper.rs

1use super::{
2    Block, OverflowError, SeekNum, StreamCipher, StreamCipherCore, StreamCipherSeek,
3    StreamCipherSeekCore, errors::StreamCipherError,
4};
5use core::fmt;
6use crypto_common::{Iv, IvSizeUser, Key, KeyInit, KeyIvInit, KeySizeUser, typenum::Unsigned};
7use inout::InOutBuf;
8#[cfg(feature = "zeroize")]
9use zeroize::{Zeroize, ZeroizeOnDrop};
10
/// Buffering wrapper around a [`StreamCipherCore`] implementation.
///
/// It handles data buffering and implements the slice-based traits.
///
/// The core produces keystream one block at a time. Keystream bytes of the
/// most recently generated block that have not been consumed yet are kept in
/// `buffer`, so callers can process data of arbitrary length.
pub struct StreamCipherCoreWrapper<T: StreamCipherCore> {
    core: T,
    // Keystream of the most recently generated block. Byte 0 doubles as the
    // cursor: it stores the index of the next unused keystream byte and must
    // stay within `1..=BlockSize` (`BlockSize` meaning "buffer exhausted").
    // Keystream byte 0 of a freshly generated block is only ever used right
    // after generation, before the cursor is written over it.
    buffer: Block<T>,
}
19
20impl<T: StreamCipherCore + Default> Default for StreamCipherCoreWrapper<T> {
21    #[inline]
22    fn default() -> Self {
23        Self::from_core(T::default())
24    }
25}
26
27impl<T: StreamCipherCore + Clone> Clone for StreamCipherCoreWrapper<T> {
28    #[inline]
29    fn clone(&self) -> Self {
30        Self {
31            core: self.core.clone(),
32            buffer: self.buffer.clone(),
33        }
34    }
35}
36
37impl<T: StreamCipherCore + fmt::Debug> fmt::Debug for StreamCipherCoreWrapper<T> {
38    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39        let pos = self.get_pos().into();
40        let buf_data = &self.buffer[pos..];
41        f.debug_struct("StreamCipherCoreWrapper")
42            .field("core", &self.core)
43            .field("buffer_data", &buf_data)
44            .finish()
45    }
46}
47
48impl<T: StreamCipherCore> StreamCipherCoreWrapper<T> {
49    /// Return reference to the core type.
50    pub fn get_core(&self) -> &T {
51        &self.core
52    }
53
54    /// Return reference to the core type.
55    pub fn from_core(core: T) -> Self {
56        let mut buffer: Block<T> = Default::default();
57        buffer[0] = T::BlockSize::U8;
58        Self { core, buffer }
59    }
60
61    /// Return current cursor position.
62    #[inline]
63    fn get_pos(&self) -> u8 {
64        let pos = self.buffer[0];
65        if pos == 0 || pos > T::BlockSize::U8 {
66            debug_assert!(false);
67            // SAFETY: `pos` never breaks the invariant
68            unsafe {
69                core::hint::unreachable_unchecked();
70            }
71        }
72        pos
73    }
74
75    /// Set buffer position without checking that it's smaller
76    /// than buffer size.
77    ///
78    /// # Safety
79    /// `pos` MUST be bigger than zero and smaller or equal to `T::BlockSize::USIZE`.
80    #[inline]
81    unsafe fn set_pos_unchecked(&mut self, pos: usize) {
82        debug_assert!(pos != 0 && pos <= T::BlockSize::USIZE);
83        // Block size is always smaller than 256 because of the `BlockSizes` bound,
84        // so if the safety condition is satisfied, the `as` cast does not truncate
85        // any non-zero bits.
86        self.buffer[0] = pos as u8;
87    }
88
89    /// Return number of remaining bytes in the internal buffer.
90    #[inline]
91    fn remaining(&self) -> u8 {
92        // This never underflows because of the safety invariant
93        T::BlockSize::U8 - self.get_pos()
94    }
95
96    fn check_remaining(&self, data_len: usize) -> Result<(), StreamCipherError> {
97        let rem_blocks = match self.core.remaining_blocks() {
98            Some(v) => v,
99            None => return Ok(()),
100        };
101
102        let buf_rem = usize::from(self.remaining());
103        let data_len = match data_len.checked_sub(buf_rem) {
104            Some(0) | None => return Ok(()),
105            Some(res) => res,
106        };
107
108        let bs = T::BlockSize::USIZE;
109        let blocks = data_len.div_ceil(bs);
110        if blocks > rem_blocks {
111            Err(StreamCipherError)
112        } else {
113            Ok(())
114        }
115    }
116}
117
impl<T: StreamCipherCore> StreamCipher for StreamCipherCoreWrapper<T> {
    /// XOR keystream into `data`, continuing from the current stream position.
    ///
    /// Work is split into up to three stages:
    /// 1. keystream bytes still left in `self.buffer` are used first;
    /// 2. whole blocks are processed directly by the core;
    /// 3. for a trailing partial block, a fresh keystream block is generated
    ///    into `self.buffer`, its first `tail.len()` bytes are used, and the
    ///    rest is kept for the next call.
    ///
    /// # Errors
    /// Returns [`StreamCipherError`] before modifying `data` if
    /// `check_remaining` reports that the core cannot supply enough blocks.
    #[inline]
    fn try_apply_keystream_inout(
        &mut self,
        mut data: InOutBuf<'_, '_, u8>,
    ) -> Result<(), StreamCipherError> {
        // Fail early so that nothing is processed on error.
        self.check_remaining(data.len())?;

        let pos = usize::from(self.get_pos());
        let rem = usize::from(self.remaining());
        let data_len = data.len();

        // Stage 1: consume keystream left over in the buffer from a previous call.
        if rem != 0 {
            // Everything fits into the buffered keystream: advance the cursor and stop.
            if data_len <= rem {
                data.xor_in2out(&self.buffer[pos..][..data_len]);
                // SAFETY: we have checked that `data_len` is less or equal to length
                // of remaining keystream data, thus `pos + data_len` can not be bigger
                // than block size. Since `pos` is never zero, `pos + data_len` can not
                // be zero. Thus `pos + data_len` satisfies the safety invariant required
                // by `set_pos_unchecked`.
                unsafe {
                    self.set_pos_unchecked(pos + data_len);
                }
                return Ok(());
            }
            // Otherwise use up the whole buffered remainder and continue with the rest.
            let (mut left, right) = data.split_at(rem);
            data = right;
            left.xor_in2out(&self.buffer[pos..]);
        }

        // Stage 2: full blocks go straight through the core, bypassing the buffer.
        let (blocks, mut tail) = data.into_chunks();
        self.core.apply_keystream_blocks_inout(blocks);

        // Stage 3: a partial tail (if any) is served from a freshly buffered block.
        let new_pos = if tail.is_empty() {
            // No tail: mark the buffer as exhausted.
            T::BlockSize::USIZE
        } else {
            // Note that we temporarily write a pseudo-random byte into
            // the first byte of `self.buffer`. It may break the safety invariant,
            // but after XORing keystream block with `tail`, we immediately
            // overwrite the first byte with a correct value.
            self.core.write_keystream_block(&mut self.buffer);
            tail.xor_in2out(&self.buffer[..tail.len()]);
            tail.len()
        };

        // SAFETY: `into_chunks` always returns tail with size
        // less than block size. If `tail.len()` is zero, we replace
        // it with block size. Thus the invariant required by
        // `set_pos_unchecked` is satisfied.
        unsafe {
            self.set_pos_unchecked(new_pos);
        }

        Ok(())
    }
}
174
175impl<T: StreamCipherSeekCore> StreamCipherSeek for StreamCipherCoreWrapper<T> {
176    fn try_current_pos<SN: SeekNum>(&self) -> Result<SN, OverflowError> {
177        let pos = self.get_pos();
178        SN::from_block_byte(self.core.get_block_pos(), pos, T::BlockSize::U8)
179    }
180
181    fn try_seek<SN: SeekNum>(&mut self, new_pos: SN) -> Result<(), StreamCipherError> {
182        let (block_pos, byte_pos) = new_pos.into_block_byte(T::BlockSize::U8)?;
183        // For correct implementations of `SeekNum` compiler should be able to
184        // eliminate this assert
185        assert!(byte_pos < T::BlockSize::U8);
186
187        self.core.set_block_pos(block_pos);
188        let new_pos = if byte_pos != 0 {
189            // See comment in `try_apply_keystream_inout` for use of `write_keystream_block`
190            self.core.write_keystream_block(&mut self.buffer);
191            byte_pos.into()
192        } else {
193            T::BlockSize::USIZE
194        };
195        // SAFETY: we assert that `byte_pos` is always smaller than block size.
196        // If `byte_pos` is zero, we replace it with block size. Thus the invariant
197        // required by `set_pos_unchecked` is satisfied.
198        unsafe {
199            self.set_pos_unchecked(new_pos);
200        }
201        Ok(())
202    }
203}
204
// Note: ideally we would implement only the `InitInner` trait and let blanket
// impls handle everything else, but unfortunately that does not work properly
// without mutually exclusive traits, see:
// https://github.com/rust-lang/rfcs/issues/1053
209
210impl<T: KeySizeUser + StreamCipherCore> KeySizeUser for StreamCipherCoreWrapper<T> {
211    type KeySize = T::KeySize;
212}
213
214impl<T: IvSizeUser + StreamCipherCore> IvSizeUser for StreamCipherCoreWrapper<T> {
215    type IvSize = T::IvSize;
216}
217
218impl<T: KeyIvInit + StreamCipherCore> KeyIvInit for StreamCipherCoreWrapper<T> {
219    #[inline]
220    fn new(key: &Key<Self>, iv: &Iv<Self>) -> Self {
221        let mut buffer = Block::<T>::default();
222        buffer[0] = T::BlockSize::U8;
223        Self {
224            core: T::new(key, iv),
225            buffer,
226        }
227    }
228}
229
230impl<T: KeyInit + StreamCipherCore> KeyInit for StreamCipherCoreWrapper<T> {
231    #[inline]
232    fn new(key: &Key<Self>) -> Self {
233        let mut buffer = Block::<T>::default();
234        buffer[0] = T::BlockSize::U8;
235        Self {
236            core: T::new(key),
237            buffer,
238        }
239    }
240}
241
#[cfg(feature = "zeroize")]
impl<T: StreamCipherCore> Drop for StreamCipherCoreWrapper<T> {
    /// Wipe the buffered keystream when the wrapper goes away.
    fn drop(&mut self) {
        // Only the buffer is wiped here; `core` is left to its own `Drop`, if any.
        let Self { buffer, .. } = self;
        buffer.zeroize();
    }
}
249
// Marker: the wrapper's `Drop` wipes the buffer, so the whole wrapper zeroizes
// on drop whenever the wrapped core does.
#[cfg(feature = "zeroize")]
impl<T> ZeroizeOnDrop for StreamCipherCoreWrapper<T>
where
    T: StreamCipherCore + ZeroizeOnDrop,
{
}