cipher/stream/wrapper.rs

1use crate::StreamCipherCounter;
2
3use super::{
4    OverflowError, SeekNum, StreamCipher, StreamCipherCore, StreamCipherSeek, StreamCipherSeekCore,
5    errors::StreamCipherError,
6};
7use block_buffer::ReadBuffer;
8use core::fmt;
9use crypto_common::{
10    Iv, IvSizeUser, Key, KeyInit, KeyIvInit, KeySizeUser, array::Array, typenum::Unsigned,
11};
12use inout::InOutBuf;
13#[cfg(feature = "zeroize")]
14use zeroize::ZeroizeOnDrop;
15
16/// Buffering wrapper around a [`StreamCipherCore`] implementation.
17///
18/// It handles data buffering and implements the slice-based traits.
19pub struct StreamCipherCoreWrapper<T: StreamCipherCore> {
20    core: T,
21    buffer: ReadBuffer<T::BlockSize>,
22}
23
24impl<T: StreamCipherCore + Clone> Clone for StreamCipherCoreWrapper<T> {
25    #[inline]
26    fn clone(&self) -> Self {
27        Self {
28            core: self.core.clone(),
29            buffer: self.buffer.clone(),
30        }
31    }
32}
33
34impl<T: StreamCipherCore + fmt::Debug> fmt::Debug for StreamCipherCoreWrapper<T> {
35    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
36        f.debug_struct("StreamCipherCoreWrapper")
37            .finish_non_exhaustive()
38    }
39}
40
41impl<T: StreamCipherCore> StreamCipherCoreWrapper<T> {
42    /// Initialize from a [`StreamCipherCore`] instance.
43    pub fn from_core(core: T) -> Self {
44        Self {
45            core,
46            buffer: Default::default(),
47        }
48    }
49
50    /// Get reference to the wrapped [`StreamCipherCore`] instance.
51    pub fn get_core(&self) -> &T {
52        &self.core
53    }
54}
55
56impl<T: StreamCipherCore> StreamCipher for StreamCipherCoreWrapper<T> {
57    #[inline]
58    fn check_remaining(&self, data_len: usize) -> Result<(), StreamCipherError> {
59        let Some(rem_blocks) = self.core.remaining_blocks() else {
60            return Ok(());
61        };
62        let Some(data_len) = data_len.checked_sub(self.buffer.remaining()) else {
63            return Ok(());
64        };
65        let req_blocks = data_len.div_ceil(T::BlockSize::USIZE);
66        if req_blocks > rem_blocks {
67            Err(StreamCipherError)
68        } else {
69            Ok(())
70        }
71    }
72
73    #[inline]
74    fn unchecked_apply_keystream_inout(&mut self, data: InOutBuf<'_, '_, u8>) {
75        let head_ks = self.buffer.read_cached(data.len());
76
77        let (mut head, data) = data.split_at(head_ks.len());
78        let (blocks, mut tail) = data.into_chunks();
79
80        head.xor_in2out(head_ks);
81        self.core.apply_keystream_blocks_inout(blocks);
82
83        self.buffer.write_block(
84            tail.len(),
85            |b| self.core.write_keystream_block(b),
86            |tail_ks| {
87                tail.xor_in2out(tail_ks);
88            },
89        );
90    }
91
92    #[inline]
93    fn unchecked_write_keystream(&mut self, data: &mut [u8]) {
94        let head_ks = self.buffer.read_cached(data.len());
95
96        let (head, data) = data.split_at_mut(head_ks.len());
97        let (blocks, tail) = Array::slice_as_chunks_mut(data);
98
99        head.copy_from_slice(head_ks);
100        self.core.write_keystream_blocks(blocks);
101
102        self.buffer.write_block(
103            tail.len(),
104            |b| self.core.write_keystream_block(b),
105            |tail_ks| tail.copy_from_slice(tail_ks),
106        );
107    }
108}
109
110impl<T: StreamCipherSeekCore> StreamCipherSeek for StreamCipherCoreWrapper<T> {
111    fn try_current_pos<SN: SeekNum>(&self) -> Result<SN, OverflowError> {
112        let pos = u8::try_from(self.buffer.get_pos())
113            .expect("buffer position is always smaller than 256");
114        SN::from_block_byte(self.core.get_block_pos(), pos, T::BlockSize::U8)
115    }
116
117    fn try_seek<SN: SeekNum>(&mut self, new_pos: SN) -> Result<(), StreamCipherError> {
118        let (block_pos, byte_pos) = new_pos.into_block_byte::<T::Counter>(T::BlockSize::U8)?;
119        if byte_pos != 0 && block_pos.is_max() {
120            return Err(StreamCipherError);
121        }
122        // For correct implementations of `SeekNum` the compiler should be able to
123        // eliminate this assert
124        assert!(byte_pos < T::BlockSize::U8);
125
126        self.core.set_block_pos(block_pos);
127
128        self.buffer.reset();
129
130        self.buffer.write_block(
131            usize::from(byte_pos),
132            |b| self.core.write_keystream_block(b),
133            |_| {},
134        );
135        Ok(())
136    }
137}
138
// Note: ideally we would implement only the `InitInner` trait and let blanket
// impls handle everything else. Unfortunately, that does not work properly
// without mutually exclusive traits, see:
// https://github.com/rust-lang/rfcs/issues/1053
143
144impl<T: KeySizeUser + StreamCipherCore> KeySizeUser for StreamCipherCoreWrapper<T> {
145    type KeySize = T::KeySize;
146}
147
148impl<T: IvSizeUser + StreamCipherCore> IvSizeUser for StreamCipherCoreWrapper<T> {
149    type IvSize = T::IvSize;
150}
151
152impl<T: KeyIvInit + StreamCipherCore> KeyIvInit for StreamCipherCoreWrapper<T> {
153    #[inline]
154    fn new(key: &Key<Self>, iv: &Iv<Self>) -> Self {
155        Self {
156            core: T::new(key, iv),
157            buffer: Default::default(),
158        }
159    }
160}
161
162impl<T: KeyInit + StreamCipherCore> KeyInit for StreamCipherCoreWrapper<T> {
163    #[inline]
164    fn new(key: &Key<Self>) -> Self {
165        Self {
166            core: T::new(key),
167            buffer: Default::default(),
168        }
169    }
170}
171
// Marker impl: the core is `ZeroizeOnDrop` by this bound, and `ReadBuffer`'s
// `ZeroizeOnDrop` impl is checked separately at compile time. The wrapper can
// therefore advertise the same guarantee.
#[cfg(feature = "zeroize")]
impl<T> ZeroizeOnDrop for StreamCipherCoreWrapper<T> where T: StreamCipherCore + ZeroizeOnDrop {}
174
// Compile-time proof that `ReadBuffer` implements `ZeroizeOnDrop` for every
// block size. The coercion to `&dyn ZeroizeOnDrop` in the return position only
// type-checks if that impl exists; the function itself is never called.
#[cfg(feature = "zeroize")]
const _: () = {
    #[allow(dead_code)]
    fn check_buffer<BS: crate::array::ArraySize>(
        buf: &ReadBuffer<BS>,
    ) -> &dyn crate::zeroize::ZeroizeOnDrop {
        buf
    }
};