1use crate::buffer::{CappedBuffer, ResizeBuffer};
2use crate::error::{Error, InvalidCapacity};
3use crate::rw::Read;
4use aead::generic_array::ArrayLength;
5use aead::stream::{Decryptor, NewStream, Nonce, NonceSize, StreamPrimitive};
6use aead::{AeadInPlace, Key, KeyInit};
7use core::ops::Sub;
8
/// Decryption state of a [`DecryptBufReader`].
///
/// The reader starts with a bare cipher (`Uninit`). Once the stream nonce has
/// been read, the cipher is turned into a running stream decryptor. `Empty`
/// is the state left behind when the decryptor has been moved out.
pub enum MaybeUninitDecryptor<A, S>
where
    A: AeadInPlace + KeyInit,
    S: StreamPrimitive<A> + NewStream<A>,
    A::NonceSize: Sub<S::NonceOverhead>,
    NonceSize<A, S>: ArrayLength<u8>,
{
    /// Cipher still waiting for the nonce stored at the head of the stream.
    Uninit(A),
    /// Stream decryptor initialised with that nonce.
    Decryptor(Decryptor<A, S>),
    /// No decryptor. This is the placeholder used while the state is being
    /// replaced, and what remains after the decryptor has been taken (e.g. to
    /// decrypt the final chunk).
    Empty,
}
20
21impl<A, S> MaybeUninitDecryptor<A, S>
22where
23 A: AeadInPlace + KeyInit,
24 S: StreamPrimitive<A> + NewStream<A>,
25 A::NonceSize: Sub<S::NonceOverhead>,
26 NonceSize<A, S>: ArrayLength<u8>,
27{
28 fn uninit(aead: A) -> Self {
29 Self::Uninit(aead)
30 }
31 fn init(&mut self, nonce: &Nonce<A, S>) -> Result<(), aead::Error> {
32 match core::mem::replace(self, Self::Empty) {
33 Self::Uninit(aead) => *self = Self::Decryptor(Decryptor::from_aead(aead, nonce)),
34 Self::Decryptor(decryptor) => *self = Self::Decryptor(decryptor),
35 Self::Empty => return Err(aead::Error),
36 }
37 Ok(())
38 }
39 fn is_uninit(&self) -> bool {
40 matches!(self, Self::Uninit(_))
41 }
42 fn as_mut(&mut self) -> Option<&mut Decryptor<A, S>> {
43 match self {
44 Self::Decryptor(decryptor) => Some(decryptor),
45 _ => None,
46 }
47 }
48 fn take(&mut self) -> Option<Decryptor<A, S>> {
49 match core::mem::replace(self, Self::Empty) {
50 Self::Decryptor(decryptor) => Some(decryptor),
51 Self::Uninit(_) => None,
52 Self::Empty => None,
53 }
54 }
55}
56
/// Buffered reader that decrypts data produced with the `aead::stream`
/// construction, read from an inner reader `R`.
///
/// The input is parsed as follows:
/// - the stream nonce comes first;
/// - then a sequence of ciphertext chunks, each prefixed with a 4-byte
///   big-endian length;
/// - the chunk that is directly followed by end of input is decrypted as the
///   last chunk.
///
/// Chunk lengths larger than the buffer capacity are rejected.
pub struct DecryptBufReader<A, B, R, S>
where
    A: AeadInPlace + KeyInit,
    S: StreamPrimitive<A> + NewStream<A>,
    A::NonceSize: Sub<S::NonceOverhead>,
    NonceSize<A, S>: ArrayLength<u8>,
{
    /// Cipher / stream-decryptor state.
    decryptor: MaybeUninitDecryptor<A, S>,
    /// Current chunk: ciphertext while it is being read, plaintext once decrypted.
    buffer: B,
    /// Underlying source of encrypted bytes.
    reader: R,
    /// Length of the next ciphertext chunk announced by its header; 0 once the
    /// input has ended.
    bytes_to_read: usize,
    /// Offset in `buffer` of the first plaintext byte not yet returned.
    read_offset: usize,
    /// Largest accepted chunk length: the buffer capacity clamped to `u32::MAX`.
    capacity: usize,
}
74
75impl<A, B, R, S> DecryptBufReader<A, B, R, S>
76where
77 A: AeadInPlace + KeyInit,
78 B: ResizeBuffer + CappedBuffer,
79 S: StreamPrimitive<A> + NewStream<A>,
80 A::NonceSize: Sub<S::NonceOverhead>,
81 NonceSize<A, S>: ArrayLength<u8>,
82{
83 pub fn new(key: &Key<A>, mut buffer: B, reader: R) -> Result<Self, InvalidCapacity> {
85 buffer.truncate(0);
86 let capacity = buffer.capacity().min(u32::MAX as usize);
87 if capacity < 1 {
88 Err(InvalidCapacity)
89 } else {
90 Ok(Self {
91 decryptor: MaybeUninitDecryptor::uninit(A::new(key)),
92 reader,
93 buffer,
94 bytes_to_read: 0,
95 read_offset: 0,
96 capacity,
97 })
98 }
99 }
100
101 pub fn from_aead(aead: A, mut buffer: B, reader: R) -> Result<Self, InvalidCapacity> {
103 buffer.truncate(0);
104 let capacity = buffer.capacity().min(u32::MAX as usize);
105 if capacity < 1 {
106 Err(InvalidCapacity)
107 } else {
108 Ok(Self {
109 decryptor: MaybeUninitDecryptor::uninit(aead),
110 reader,
111 buffer,
112 bytes_to_read: 0,
113 read_offset: 0,
114 capacity,
115 })
116 }
117 }
118
119 pub fn inner(&self) -> &R {
121 &self.reader
122 }
123
124 pub fn into_inner(self) -> R {
126 self.reader
127 }
128}
129
130impl<A, B, R, S> DecryptBufReader<A, B, R, S>
131where
132 A: AeadInPlace + KeyInit,
133 B: ResizeBuffer + CappedBuffer,
134 R: Read,
135 S: StreamPrimitive<A> + NewStream<A>,
136 A::NonceSize: Sub<S::NonceOverhead>,
137 NonceSize<A, S>: ArrayLength<u8>,
138{
139 fn read_chunk_size(&mut self) -> Result<(), Error<R::Error>> {
140 let mut bytes_to_read = [0u8; 4];
141 let mut offset = 0;
142 while offset < 4 {
143 let read = self.reader.read(&mut bytes_to_read[offset..])?;
144 if read == 0 {
145 if offset == 0 {
146 self.bytes_to_read = 0;
147 return Ok(());
148 } else {
149 return Err(Error::Aead);
150 }
151 }
152 offset += read;
153 }
154 let bytes_to_read = u32::from_be_bytes(bytes_to_read) as usize;
155 if bytes_to_read > self.capacity {
156 Err(Error::Aead)
157 } else {
158 self.bytes_to_read = bytes_to_read;
159 Ok(())
160 }
161 }
162
163 fn read(&mut self, buf: &mut [u8]) -> Result<usize, Error<R::Error>> {
164 if self.decryptor.is_uninit() {
165 let mut nonce = Nonce::<A, S>::default();
166 self.reader.read_exact(&mut nonce)?;
167 self.decryptor.init(&nonce).map_err(|_| Error::Aead)?;
168 self.read_chunk_size()?;
169 }
170
171 while self.buffer.is_empty() {
172 if self.bytes_to_read == 0 {
173 return Ok(0);
174 }
175 self.buffer
176 .resize_zeroed(self.bytes_to_read)
177 .map_err(|_| Error::Aead)?;
178 self.reader.read_exact(self.buffer.as_mut())?;
179 self.read_chunk_size()?;
180
181 if self.bytes_to_read == 0 {
182 self.decryptor
183 .take()
184 .ok_or(Error::Aead)?
185 .decrypt_last_in_place(&[], &mut self.buffer)
186 .map_err(|_| Error::Aead)?;
187 } else {
188 self.decryptor
189 .as_mut()
190 .ok_or(Error::Aead)?
191 .decrypt_next_in_place(&[], &mut self.buffer)
192 .map_err(|_| Error::Aead)?;
193 }
194 }
195
196 let bytes_to_copy = (self.buffer.len() - self.read_offset).min(buf.len());
197 buf[..bytes_to_copy].copy_from_slice(
198 &self.buffer.as_ref()[self.read_offset..self.read_offset + bytes_to_copy],
199 );
200 self.buffer.as_mut()[self.read_offset..self.read_offset + bytes_to_copy].fill(0);
201
202 if self.buffer.len() == self.read_offset + bytes_to_copy {
203 self.read_offset = 0;
204 self.buffer.truncate(0);
205 } else {
206 self.read_offset += bytes_to_copy;
207 }
208
209 Ok(bytes_to_copy)
210 }
211}
212
#[cfg(feature = "std")]
impl<A, B, R, S> std::io::Read for DecryptBufReader<A, B, R, S>
where
    A: AeadInPlace + KeyInit,
    B: ResizeBuffer + CappedBuffer,
    R: Read,
    R::Error: Into<std::io::Error>,
    S: StreamPrimitive<A> + NewStream<A>,
    A::NonceSize: Sub<S::NonceOverhead>,
    NonceSize<A, S>: ArrayLength<u8>,
{
    /// Runs the inherent decrypting `read` and converts its error into an
    /// [`std::io::Error`].
    ///
    /// Inherent methods take precedence over trait methods, so the call below
    /// is not recursive.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let outcome = self.read(buf);
        outcome.map_err(std::io::Error::from)
    }
}
228
229#[cfg(not(feature = "std"))]
230impl<A, B, R, S> Read for DecryptBufReader<A, B, R, S>
231where
232 A: AeadInPlace + KeyInit,
233 B: ResizeBuffer + CappedBuffer,
234 R: Read,
235 S: StreamPrimitive<A> + NewStream<A>,
236 A::NonceSize: Sub<S::NonceOverhead>,
237 NonceSize<A, S>: ArrayLength<u8>,
238{
239 type Error = Error<R::Error>;
240 fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
241 Ok(self.read(buf)?)
242 }
243 fn read_exact(&mut self, mut buf: &mut [u8]) -> Result<(), Self::Error> {
244 while !buf.is_empty() {
245 match self.read(buf) {
246 Ok(0) => break,
247 Ok(n) => {
248 let tmp = buf;
249 buf = &mut tmp[n..];
250 }
251 Err(e) => return Err(e),
252 }
253 }
254 if !buf.is_empty() {
255 Err(Error::Aead)
256 } else {
257 Ok(())
258 }
259 }
260}