aead_io/
reader.rs

1use crate::buffer::{CappedBuffer, ResizeBuffer};
2use crate::error::{Error, InvalidCapacity};
3use crate::rw::Read;
4use aead::generic_array::ArrayLength;
5use aead::stream::{Decryptor, NewStream, Nonce, NonceSize, StreamPrimitive};
6use aead::{AeadInPlace, Key, KeyInit};
7use core::ops::Sub;
8
9pub enum MaybeUninitDecryptor<A, S>
10where
11    A: AeadInPlace + KeyInit,
12    S: StreamPrimitive<A> + NewStream<A>,
13    A::NonceSize: Sub<S::NonceOverhead>,
14    NonceSize<A, S>: ArrayLength<u8>,
15{
16    Uninit(A),
17    Decryptor(Decryptor<A, S>),
18    Empty,
19}
20
21impl<A, S> MaybeUninitDecryptor<A, S>
22where
23    A: AeadInPlace + KeyInit,
24    S: StreamPrimitive<A> + NewStream<A>,
25    A::NonceSize: Sub<S::NonceOverhead>,
26    NonceSize<A, S>: ArrayLength<u8>,
27{
28    fn uninit(aead: A) -> Self {
29        Self::Uninit(aead)
30    }
31    fn init(&mut self, nonce: &Nonce<A, S>) -> Result<(), aead::Error> {
32        match core::mem::replace(self, Self::Empty) {
33            Self::Uninit(aead) => *self = Self::Decryptor(Decryptor::from_aead(aead, nonce)),
34            Self::Decryptor(decryptor) => *self = Self::Decryptor(decryptor),
35            Self::Empty => return Err(aead::Error),
36        }
37        Ok(())
38    }
39    fn is_uninit(&self) -> bool {
40        matches!(self, Self::Uninit(_))
41    }
42    fn as_mut(&mut self) -> Option<&mut Decryptor<A, S>> {
43        match self {
44            Self::Decryptor(decryptor) => Some(decryptor),
45            _ => None,
46        }
47    }
48    fn take(&mut self) -> Option<Decryptor<A, S>> {
49        match core::mem::replace(self, Self::Empty) {
50            Self::Decryptor(decryptor) => Some(decryptor),
51            Self::Uninit(_) => None,
52            Self::Empty => None,
53        }
54    }
55}
56
57/// A wrapper around a [`Read`](Read) object and a [`StreamPrimitive`](`StreamPrimitive`)
58/// providing a [`Read`](Read) interface which automatically decrypts the underlying stream when
59/// reading
60pub struct DecryptBufReader<A, B, R, S>
61where
62    A: AeadInPlace + KeyInit,
63    S: StreamPrimitive<A> + NewStream<A>,
64    A::NonceSize: Sub<S::NonceOverhead>,
65    NonceSize<A, S>: ArrayLength<u8>,
66{
67    decryptor: MaybeUninitDecryptor<A, S>,
68    buffer: B,
69    reader: R,
70    bytes_to_read: usize,
71    read_offset: usize,
72    capacity: usize,
73}
74
75impl<A, B, R, S> DecryptBufReader<A, B, R, S>
76where
77    A: AeadInPlace + KeyInit,
78    B: ResizeBuffer + CappedBuffer,
79    S: StreamPrimitive<A> + NewStream<A>,
80    A::NonceSize: Sub<S::NonceOverhead>,
81    NonceSize<A, S>: ArrayLength<u8>,
82{
83    /// Constructs a new Reader using an AEAD key, buffer and reader
84    pub fn new(key: &Key<A>, mut buffer: B, reader: R) -> Result<Self, InvalidCapacity> {
85        buffer.truncate(0);
86        let capacity = buffer.capacity().min(u32::MAX as usize);
87        if capacity < 1 {
88            Err(InvalidCapacity)
89        } else {
90            Ok(Self {
91                decryptor: MaybeUninitDecryptor::uninit(A::new(key)),
92                reader,
93                buffer,
94                bytes_to_read: 0,
95                read_offset: 0,
96                capacity,
97            })
98        }
99    }
100
101    /// Constructs a new Reader using an AEAD primitive, buffer and reader
102    pub fn from_aead(aead: A, mut buffer: B, reader: R) -> Result<Self, InvalidCapacity> {
103        buffer.truncate(0);
104        let capacity = buffer.capacity().min(u32::MAX as usize);
105        if capacity < 1 {
106            Err(InvalidCapacity)
107        } else {
108            Ok(Self {
109                decryptor: MaybeUninitDecryptor::uninit(aead),
110                reader,
111                buffer,
112                bytes_to_read: 0,
113                read_offset: 0,
114                capacity,
115            })
116        }
117    }
118
119    /// Gets a reference to the inner reader
120    pub fn inner(&self) -> &R {
121        &self.reader
122    }
123
124    /// Consumes the Reader and returns the inner reader
125    pub fn into_inner(self) -> R {
126        self.reader
127    }
128}
129
130impl<A, B, R, S> DecryptBufReader<A, B, R, S>
131where
132    A: AeadInPlace + KeyInit,
133    B: ResizeBuffer + CappedBuffer,
134    R: Read,
135    S: StreamPrimitive<A> + NewStream<A>,
136    A::NonceSize: Sub<S::NonceOverhead>,
137    NonceSize<A, S>: ArrayLength<u8>,
138{
139    fn read_chunk_size(&mut self) -> Result<(), Error<R::Error>> {
140        let mut bytes_to_read = [0u8; 4];
141        let mut offset = 0;
142        while offset < 4 {
143            let read = self.reader.read(&mut bytes_to_read[offset..])?;
144            if read == 0 {
145                if offset == 0 {
146                    self.bytes_to_read = 0;
147                    return Ok(());
148                } else {
149                    return Err(Error::Aead);
150                }
151            }
152            offset += read;
153        }
154        let bytes_to_read = u32::from_be_bytes(bytes_to_read) as usize;
155        if bytes_to_read > self.capacity {
156            Err(Error::Aead)
157        } else {
158            self.bytes_to_read = bytes_to_read;
159            Ok(())
160        }
161    }
162
163    fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error<R::Error>> {
164        if self.decryptor.is_uninit() {
165            let mut nonce = Nonce::<A, S>::default();
166            self.reader.read_exact(&mut nonce)?;
167            self.decryptor.init(&nonce).map_err(|_| Error::Aead)?;
168            self.read_chunk_size()?;
169        }
170
171        while self.buffer.is_empty() {
172            if self.bytes_to_read == 0 {
173                return Ok(0);
174            }
175            self.buffer
176                .resize_zeroed(self.bytes_to_read)
177                .map_err(|_| Error::Aead)?;
178            self.reader.read_exact(self.buffer.as_mut())?;
179            self.read_chunk_size()?;
180
181            if self.bytes_to_read == 0 {
182                self.decryptor
183                    .take()
184                    .ok_or(Error::Aead)?
185                    .decrypt_last_in_place(&[], &mut self.buffer)
186                    .map_err(|_| Error::Aead)?;
187            } else {
188                self.decryptor
189                    .as_mut()
190                    .ok_or(Error::Aead)?
191                    .decrypt_next_in_place(&[], &mut self.buffer)
192                    .map_err(|_| Error::Aead)?;
193            }
194        }
195
196        let bytes_to_copy = (self.buffer.len() - self.read_offset).min(buf.len());
197        buf[..bytes_to_copy].copy_from_slice(
198            &self.buffer.as_ref()[self.read_offset..self.read_offset + bytes_to_copy],
199        );
200        self.buffer.as_mut()[self.read_offset..self.read_offset + bytes_to_copy].fill(0);
201
202        if self.buffer.len() == self.read_offset + bytes_to_copy {
203            self.read_offset = 0;
204            self.buffer.truncate(0);
205        } else {
206            self.read_offset += bytes_to_copy;
207        }
208
209        Ok(bytes_to_copy)
210    }
211}
212
#[cfg(feature = "std")]
impl<A, B, R, S> std::io::Read for DecryptBufReader<A, B, R, S>
where
    A: AeadInPlace + KeyInit,
    B: ResizeBuffer + CappedBuffer,
    R: Read,
    R::Error: Into<std::io::Error>,
    S: StreamPrimitive<A> + NewStream<A>,
    A::NonceSize: Sub<S::NonceOverhead>,
    NonceSize<A, S>: ArrayLength<u8>,
{
    /// Reads decrypted bytes into `buf`, converting any failure into a
    /// [`std::io::Error`].
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // Method-call syntax resolves to the inherent `read`, not to this trait method.
        self.read(buf).map_err(std::io::Error::from)
    }
}
228
229#[cfg(not(feature = "std"))]
230impl<A, B, R, S> Read for DecryptBufReader<A, B, R, S>
231where
232    A: AeadInPlace + KeyInit,
233    B: ResizeBuffer + CappedBuffer,
234    R: Read,
235    S: StreamPrimitive<A> + NewStream<A>,
236    A::NonceSize: Sub<S::NonceOverhead>,
237    NonceSize<A, S>: ArrayLength<u8>,
238{
239    type Error = Error<R::Error>;
240    fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
241        Ok(self.read(buf)?)
242    }
243    fn read_exact(&mut self, mut buf: &mut [u8]) -> Result<(), Self::Error> {
244        while !buf.is_empty() {
245            match self.read(buf) {
246                Ok(0) => break,
247                Ok(n) => {
248                    let tmp = buf;
249                    buf = &mut tmp[n..];
250                }
251                Err(e) => return Err(e),
252            }
253        }
254        if !buf.is_empty() {
255            Err(Error::Aead)
256        } else {
257            Ok(())
258        }
259    }
260}