1use crate::buffer::CappedBuffer;
2use crate::error::{Error, IntoInnerError, InvalidCapacity};
3use crate::rw::Write;
4use aead::generic_array::typenum::Unsigned;
5use aead::generic_array::ArrayLength;
6use aead::stream::{Encryptor, NewStream, Nonce, NonceSize, StreamPrimitive};
7use aead::{AeadCore, AeadInPlace, Key, KeyInit};
8use core::ops::Sub;
9use core::{mem, ptr};
10
/// Progress of an [`EncryptBufWriter`] through the encrypted stream it produces.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum State {
    /// Nothing has been written to the underlying writer yet; the stream nonce is still
    /// pending and is emitted in front of the first chunk.
    Init,
    /// The nonce has been written; further (non-final) chunks may follow.
    Writing,
    /// The last chunk has been written; no more data is accepted.
    Finished,
}
17
/// Buffered writer that encrypts everything written to it with an `aead::stream`
/// [`Encryptor`] before forwarding it to an underlying writer.
///
/// Output layout (as produced by `flush_buffer`): the stream nonce, followed by one or
/// more chunks, each a big-endian `u32` ciphertext length and then the ciphertext. The
/// final chunk is encrypted with `encrypt_last_in_place`.
///
/// Calling `flush` finalizes the stream, after which writes fail. Dropping the writer
/// also attempts to finalize, but any error is ignored; use `into_inner` to observe it.
///
/// The where-clause bounds are needed by the `Encryptor`/`Nonce` field types and must
/// match the `Drop` impl exactly.
pub struct EncryptBufWriter<A, B, W, S>
where
    A: AeadInPlace,
    B: CappedBuffer,
    W: Write,
    S: StreamPrimitive<A>,
    A::NonceSize: Sub<S::NonceOverhead>,
    NonceSize<A, S>: ArrayLength<u8>,
{
    // Stream encryptor; `None` once it has been consumed by `encrypt_last_in_place`.
    encryptor: Option<Encryptor<A, S>>,
    // Stream nonce, written unencrypted in front of the first chunk.
    nonce: Nonce<A, S>,
    // Plaintext awaiting encryption; it is encrypted in place when a chunk is emitted.
    buffer: B,
    // Destination for the nonce and the length-prefixed ciphertext chunks.
    writer: W,
    // Maximum plaintext bytes per chunk: buffer capacity (capped at `u32::MAX`) minus
    // the AEAD tag size.
    capacity: usize,
    // Whether the nonce has been written and whether the stream has been finalized.
    state: State,
}
37
38impl<A, B, W, S> EncryptBufWriter<A, B, W, S>
39where
40 A: AeadInPlace,
41 B: CappedBuffer,
42 W: Write,
43 S: StreamPrimitive<A>,
44 A::NonceSize: Sub<S::NonceOverhead>,
45 NonceSize<A, S>: ArrayLength<u8>,
46{
47 pub fn new(
49 key: &Key<A>,
50 nonce: &Nonce<A, S>,
51 mut buffer: B,
52 writer: W,
53 ) -> Result<Self, InvalidCapacity>
54 where
55 A: KeyInit,
56 S: NewStream<A>,
57 {
58 buffer.truncate(0);
59 let capacity = Self::capacity_for_buffer(&buffer)?;
60 Ok(Self {
61 encryptor: Some(Encryptor::new(key, nonce)),
62 nonce: nonce.clone(),
63 writer,
64 buffer,
65 capacity,
66 state: State::Init,
67 })
68 }
69
70 pub fn from_aead(
72 aead: A,
73 nonce: &Nonce<A, S>,
74 mut buffer: B,
75 writer: W,
76 ) -> Result<Self, InvalidCapacity>
77 where
78 A: KeyInit,
79 S: NewStream<A>,
80 {
81 buffer.truncate(0);
82 let capacity = Self::capacity_for_buffer(&buffer)?;
83 Ok(Self {
84 encryptor: Some(Encryptor::from_aead(aead, nonce)),
85 nonce: nonce.clone(),
86 writer,
87 buffer,
88 capacity,
89 state: State::Init,
90 })
91 }
92
93 fn capacity_for_buffer(buffer: &B) -> Result<usize, InvalidCapacity> {
94 let capacity = buffer
95 .capacity()
96 .min(u32::MAX as usize)
97 .checked_sub(<<A as AeadCore>::TagSize as Unsigned>::to_usize())
98 .ok_or(InvalidCapacity)?;
99 if capacity < 1 {
100 Err(InvalidCapacity)
101 } else {
102 Ok(capacity)
103 }
104 }
105
106 pub fn inner(&self) -> &W {
108 &self.writer
109 }
110
111 pub fn into_inner(mut self) -> Result<W, IntoInnerError<Self, W::Error>> {
113 match self.flush_buffer(true) {
114 Ok(()) => {
115 let inner = unsafe { ptr::read(&self.writer) };
116 mem::forget(self);
117 Ok(inner)
118 }
119 Err(err) => Err(IntoInnerError::new(self, err)),
120 }
121 }
122
123 fn capacity_remaining(&self) -> usize {
124 self.capacity - self.buffer.len()
125 }
126
127 fn flush_buffer(&mut self, last: bool) -> Result<(), Error<W::Error>> {
128 if matches!(self.state, State::Finished) {
129 return Ok(());
130 }
131
132 if last {
133 self.encryptor
134 .take()
135 .ok_or(Error::Aead)?
136 .encrypt_last_in_place(&[], &mut self.buffer)
137 .map_err(|_| Error::Aead)?;
138 } else {
139 self.encryptor
140 .as_mut()
141 .ok_or(Error::Aead)?
142 .encrypt_next_in_place(&[], &mut self.buffer)
143 .map_err(|_| Error::Aead)?;
144 }
145
146 if matches!(self.state, State::Init) {
147 self.writer.write_all(self.nonce.as_slice())?;
148 self.state = State::Writing;
149 }
150
151 self.writer
152 .write_all(&(self.buffer.len() as u32).to_be_bytes())?;
153 self.writer.write_all(self.buffer.as_ref())?;
154 if last {
155 self.state = State::Finished;
156 }
157
158 self.buffer.truncate(0);
159 Ok(())
160 }
161
162 fn write(&mut self, buf: &[u8]) -> Result<usize, Error<W::Error>> {
163 if matches!(self.state, State::Finished) {
164 return Err(Error::Aead);
165 }
166 if buf.len() > self.capacity_remaining() {
167 self.flush_buffer(false)?;
168 }
169 let bytes_to_write = buf.len().min(self.capacity_remaining());
170 self.buffer
171 .extend_from_slice(&buf[..bytes_to_write])
172 .map_err(|_| Error::Aead)?;
173 Ok(bytes_to_write)
174 }
175
176 fn flush(&mut self) -> Result<(), Error<W::Error>> {
177 self.flush_buffer(true)?;
178 self.writer.flush()?;
179 Ok(())
180 }
181}
182
183impl<A, B, W, S> Drop for EncryptBufWriter<A, B, W, S>
184where
185 A: AeadInPlace,
186 B: CappedBuffer,
187 W: Write,
188 S: StreamPrimitive<A>,
189 A::NonceSize: Sub<S::NonceOverhead>,
190 NonceSize<A, S>: ArrayLength<u8>,
191{
192 fn drop(&mut self) {
193 let _ = self.flush_buffer(true);
194 }
195}
196
#[cfg(feature = "std")]
impl<A, B, W, S> std::io::Write for EncryptBufWriter<A, B, W, S>
where
    A: AeadInPlace,
    B: CappedBuffer,
    W: Write,
    W::Error: Into<std::io::Error>,
    S: StreamPrimitive<A>,
    A::NonceSize: Sub<S::NonceOverhead>,
    NonceSize<A, S>: ArrayLength<u8>,
{
    /// Forwards to the inherent `write`, converting its error into `std::io::Error`.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.write(buf).map_err(Into::into)
    }

    /// Forwards to the inherent `flush`, which finalizes the encrypted stream. The error
    /// is converted into `std::io::Error`.
    fn flush(&mut self) -> std::io::Result<()> {
        self.flush().map_err(Into::into)
    }
}
215
216#[cfg(not(feature = "std"))]
217impl<A, B, W, S> Write for EncryptBufWriter<A, B, W, S>
218where
219 A: AeadInPlace,
220 B: CappedBuffer,
221 W: Write,
222 S: StreamPrimitive<A>,
223 A::NonceSize: Sub<S::NonceOverhead>,
224 NonceSize<A, S>: ArrayLength<u8>,
225{
226 type Error = Error<W::Error>;
227 fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
228 Ok(self.write(buf)?)
229 }
230 fn flush(&mut self) -> Result<(), Self::Error> {
231 Ok(self.flush()?)
232 }
233 fn write_all(&mut self, mut buf: &[u8]) -> Result<(), Self::Error> {
234 while !buf.is_empty() {
235 match self.write(buf) {
236 Ok(0) => return Err(Error::Aead),
237 Ok(n) => buf = &buf[n..],
238 Err(e) => return Err(e),
239 }
240 }
241 Ok(())
242 }
243}