use std::io::Write;
use crate::{Aad, Aead, Rng, SystemRng};
use super::{Encryptor, Segment};
/// Adapter that encrypts everything written through it and forwards the ciphertext to a
/// borrowed sink.
///
/// Plaintext given to [`Write::write`] is passed to an internal [`Encryptor`], and whatever
/// ciphertext the encryptor hands back is written to `writer`. Calling `finalize` emits the
/// remaining ciphertext and flushes the sink.
///
/// Type parameters:
/// - `W`: sink that receives ciphertext.
/// - `A`: storage backing the associated data (`Aad`).
/// - `N`: random source handed to the encryptor; defaults to [`SystemRng`].
///
/// NOTE(review): only this struct carries `#[cfg(feature = "std")]`; the impl blocks in this
/// file do not, so the file presumably relies on its parent module also being gated on `std`
/// — confirm.
#[cfg(feature = "std")]
pub struct EncryptWriter<'write, W: Write, A: AsRef<[u8]>, N: Rng = SystemRng> {
    /// Underlying encryptor; constructed with a `Vec<u8>` (presumably its working buffer).
    encryptor: Encryptor<Vec<u8>, N>,
    /// Destination for produced ciphertext.
    writer: &'write mut W,
    /// Associated data supplied on every update and at finalization.
    aad: Aad<A>,
    /// Running total of ciphertext bytes written to `writer`.
    counter: usize,
}
impl<'write, W, A, N> EncryptWriter<'write, W, A, N>
where
    W: Write,
    A: AsRef<[u8]>,
    N: Rng,
{
    /// Creates an encrypting writer around `writer`.
    ///
    /// # Arguments
    /// - `rng`: random source forwarded to the underlying [`Encryptor`].
    /// - `writer`: sink that will receive ciphertext.
    /// - `segment`: segment configuration for the encryptor; its `usize` conversion is also
    ///   used as the initial capacity of the encryptor's `Vec<u8>`.
    /// - `aad`: associated data used on every update and at finalization.
    /// - `cipher`: AEAD algorithm, anything viewable as an [`Aead`].
    ///
    /// The ciphertext byte counter starts at zero.
    pub fn new<C>(rng: N, writer: &'write mut W, segment: Segment, aad: Aad<A>, cipher: C) -> Self
    where
        C: AsRef<Aead>,
    {
        // Reserve the segment size up front so the first segment does not reallocate.
        let capacity: usize = segment.into();
        let buffer = Vec::with_capacity(capacity);
        Self {
            encryptor: Encryptor::new(rng, cipher, Some(segment), buffer),
            counter: 0,
            aad,
            writer,
        }
    }
}
impl<'write, W, D> EncryptWriter<'write, W, D>
where
W: Write,
D: AsRef<[u8]>,
{
pub fn finalize(self) -> Result<usize, std::io::Error> {
let EncryptWriter {
encryptor,
writer,
mut counter,
aad,
} = self;
let ciphertext: Vec<u8> = encryptor.finalize(aad)?.flatten().collect();
writer.write_all(&ciphertext)?;
counter += ciphertext.len();
writer.flush()?;
Ok(counter)
}
}
impl<'write, W, D> Write for EncryptWriter<'write, W, D>
where
W: Write,
D: AsRef<[u8]>,
{
fn write(&mut self, buf: &[u8]) -> Result<usize, std::io::Error> {
self.encryptor.update(Aad(self.aad.as_ref()), buf)?;
if let Some(ciphertext) = self.encryptor.next() {
self.writer.write_all(&ciphertext)?;
self.counter += ciphertext.len();
}
Ok(buf.len())
}
fn flush(&mut self) -> Result<(), std::io::Error> {
self.writer.flush()
}
}