aead_io/
writer.rs

1use crate::buffer::CappedBuffer;
2use crate::error::{Error, IntoInnerError, InvalidCapacity};
3use crate::rw::Write;
4use aead::generic_array::typenum::Unsigned;
5use aead::generic_array::ArrayLength;
6use aead::stream::{Encryptor, NewStream, Nonce, NonceSize, StreamPrimitive};
7use aead::{AeadCore, AeadInPlace, Key, KeyInit};
8use core::ops::Sub;
9use core::{mem, ptr};
10
/// Tracks how far the writer has progressed in producing its output stream.
#[derive(Copy, Clone)]
enum State {
    /// Nothing has been emitted yet; the nonce still has to be written.
    Init,
    /// The nonce has been written and chunks are being emitted.
    Writing,
    /// The final chunk has been written; further writes are rejected.
    Finished,
}
17
18/// A wrapper around a [`Write`](Write) object and a [`StreamPrimitive`](`StreamPrimitive`)
19/// providing a [`Write`](Write) interface which automatically encrypts the underlying stream when
20/// writing
21pub struct EncryptBufWriter<A, B, W, S>
22where
23    A: AeadInPlace,
24    B: CappedBuffer,
25    W: Write,
26    S: StreamPrimitive<A>,
27    A::NonceSize: Sub<S::NonceOverhead>,
28    NonceSize<A, S>: ArrayLength<u8>,
29{
30    encryptor: Option<Encryptor<A, S>>,
31    nonce: Nonce<A, S>,
32    buffer: B,
33    writer: W,
34    capacity: usize,
35    state: State,
36}
37
38impl<A, B, W, S> EncryptBufWriter<A, B, W, S>
39where
40    A: AeadInPlace,
41    B: CappedBuffer,
42    W: Write,
43    S: StreamPrimitive<A>,
44    A::NonceSize: Sub<S::NonceOverhead>,
45    NonceSize<A, S>: ArrayLength<u8>,
46{
47    /// Constructs a new Writer using an AEAD key, buffer and reader
48    pub fn new(
49        key: &Key<A>,
50        nonce: &Nonce<A, S>,
51        mut buffer: B,
52        writer: W,
53    ) -> Result<Self, InvalidCapacity>
54    where
55        A: KeyInit,
56        S: NewStream<A>,
57    {
58        buffer.truncate(0);
59        let capacity = Self::capacity_for_buffer(&buffer)?;
60        Ok(Self {
61            encryptor: Some(Encryptor::new(key, nonce)),
62            nonce: nonce.clone(),
63            writer,
64            buffer,
65            capacity,
66            state: State::Init,
67        })
68    }
69
70    /// Constructs a new Writer using an AEAD primitive, buffer and reader
71    pub fn from_aead(
72        aead: A,
73        nonce: &Nonce<A, S>,
74        mut buffer: B,
75        writer: W,
76    ) -> Result<Self, InvalidCapacity>
77    where
78        A: KeyInit,
79        S: NewStream<A>,
80    {
81        buffer.truncate(0);
82        let capacity = Self::capacity_for_buffer(&buffer)?;
83        Ok(Self {
84            encryptor: Some(Encryptor::from_aead(aead, nonce)),
85            nonce: nonce.clone(),
86            writer,
87            buffer,
88            capacity,
89            state: State::Init,
90        })
91    }
92
93    fn capacity_for_buffer(buffer: &B) -> Result<usize, InvalidCapacity> {
94        let capacity = buffer
95            .capacity()
96            .min(u32::MAX as usize)
97            .checked_sub(<<A as AeadCore>::TagSize as Unsigned>::to_usize())
98            .ok_or(InvalidCapacity)?;
99        if capacity < 1 {
100            Err(InvalidCapacity)
101        } else {
102            Ok(capacity)
103        }
104    }
105
106    /// Gets a reference to the inner writer
107    pub fn inner(&self) -> &W {
108        &self.writer
109    }
110
111    /// Consumes the Writer and returns the inner writer
112    pub fn into_inner(mut self) -> Result<W, IntoInnerError<Self, W::Error>> {
113        match self.flush_buffer(true) {
114            Ok(()) => {
115                let inner = unsafe { ptr::read(&self.writer) };
116                mem::forget(self);
117                Ok(inner)
118            }
119            Err(err) => Err(IntoInnerError::new(self, err)),
120        }
121    }
122
123    fn capacity_remaining(&self) -> usize {
124        self.capacity - self.buffer.len()
125    }
126
127    fn flush_buffer(&mut self, last: bool) -> Result<(), Error<W::Error>> {
128        if matches!(self.state, State::Finished) {
129            return Ok(());
130        }
131
132        if last {
133            self.encryptor
134                .take()
135                .ok_or(Error::Aead)?
136                .encrypt_last_in_place(&[], &mut self.buffer)
137                .map_err(|_| Error::Aead)?;
138        } else {
139            self.encryptor
140                .as_mut()
141                .ok_or(Error::Aead)?
142                .encrypt_next_in_place(&[], &mut self.buffer)
143                .map_err(|_| Error::Aead)?;
144        }
145
146        if matches!(self.state, State::Init) {
147            self.writer.write_all(self.nonce.as_slice())?;
148            self.state = State::Writing;
149        }
150
151        self.writer
152            .write_all(&(self.buffer.len() as u32).to_be_bytes())?;
153        self.writer.write_all(self.buffer.as_ref())?;
154        if last {
155            self.state = State::Finished;
156        }
157
158        self.buffer.truncate(0);
159        Ok(())
160    }
161
162    fn write(&mut self, buf: &[u8]) -> Result<usize, Error<W::Error>> {
163        if matches!(self.state, State::Finished) {
164            return Err(Error::Aead);
165        }
166        if buf.len() > self.capacity_remaining() {
167            self.flush_buffer(false)?;
168        }
169        let bytes_to_write = buf.len().min(self.capacity_remaining());
170        self.buffer
171            .extend_from_slice(&buf[..bytes_to_write])
172            .map_err(|_| Error::Aead)?;
173        Ok(bytes_to_write)
174    }
175
176    fn flush(&mut self) -> Result<(), Error<W::Error>> {
177        self.flush_buffer(true)?;
178        self.writer.flush()?;
179        Ok(())
180    }
181}
182
183impl<A, B, W, S> Drop for EncryptBufWriter<A, B, W, S>
184where
185    A: AeadInPlace,
186    B: CappedBuffer,
187    W: Write,
188    S: StreamPrimitive<A>,
189    A::NonceSize: Sub<S::NonceOverhead>,
190    NonceSize<A, S>: ArrayLength<u8>,
191{
192    fn drop(&mut self) {
193        let _ = self.flush_buffer(true);
194    }
195}
196
#[cfg(feature = "std")]
impl<A, B, W, S> std::io::Write for EncryptBufWriter<A, B, W, S>
where
    A: AeadInPlace,
    B: CappedBuffer,
    W: Write,
    W::Error: Into<std::io::Error>,
    S: StreamPrimitive<A>,
    A::NonceSize: Sub<S::NonceOverhead>,
    NonceSize<A, S>: ArrayLength<u8>,
{
    // Both methods delegate to the inherent methods of the same name (inherent
    // methods take precedence over trait methods) and convert the crate error
    // into `std::io::Error`.

    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let accepted = self.write(buf)?;
        Ok(accepted)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        self.flush()?;
        Ok(())
    }
}
215
216#[cfg(not(feature = "std"))]
217impl<A, B, W, S> Write for EncryptBufWriter<A, B, W, S>
218where
219    A: AeadInPlace,
220    B: CappedBuffer,
221    W: Write,
222    S: StreamPrimitive<A>,
223    A::NonceSize: Sub<S::NonceOverhead>,
224    NonceSize<A, S>: ArrayLength<u8>,
225{
226    type Error = Error<W::Error>;
227    fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
228        Ok(self.write(buf)?)
229    }
230    fn flush(&mut self) -> Result<(), Self::Error> {
231        Ok(self.flush()?)
232    }
233    fn write_all(&mut self, mut buf: &[u8]) -> Result<(), Self::Error> {
234        while !buf.is_empty() {
235            match self.write(buf) {
236                Ok(0) => return Err(Error::Aead),
237                Ok(n) => buf = &buf[n..],
238                Err(e) => return Err(e),
239            }
240        }
241        Ok(())
242    }
243}