use crate::sizes;
/// Default chunk size in bytes (512 KiB).
///
/// A chunk is one encrypted block on the wire: its plaintext followed by the AEAD
/// authentication tag (`sizes::TAG_LEN` bytes).
pub const DEFAULT_CHUNK: usize = 512 * 1024;
/// Nonce of the `StreamLE31` STREAM construction over AES-256-GCM-SIV.
///
/// A freshly generated one is written into every encrypted stream's header.
type Nonce = aead::stream::Nonce<
aes_gcm_siv::Aes256GcmSiv,
aead::stream::StreamLE31<aes_gcm_siv::Aes256GcmSiv>,
>;
/// Streaming writer that encrypts everything written to it with AES-256-GCM-SIV
/// (via the `aead` STREAM construction, `StreamLE31`) and forwards the ciphertext to
/// `output`.
///
/// The output begins with a random nonce (preceded by the Argon salt when the encrypter
/// was created from a password), followed by encrypted blocks. The stream is terminated
/// by `finish`, or best-effort on drop.
pub struct Encrypter<Out>
where
Out: std::io::Write,
{
/// Active STREAM encryptor; `None` once the final block has been written.
stream: Option<aead::stream::EncryptorLE31<aes_gcm_siv::Aes256GcmSiv>>,
/// Plaintext of the block being assembled; it is encrypted in place (tag appended).
buffer: Vec<u8>,
/// Destination for the header and the encrypted blocks.
output: Out,
/// Plaintext bytes per block: the buffer's capacity minus the tag length.
plain_capacity: usize,
}
impl<Out> Encrypter<Out>
where
    Out: std::io::Write,
{
    /// Smallest accepted chunk size in bytes; smaller values are rejected with
    /// [`std::io::ErrorKind::InvalidInput`].
    const MIN_CHUNK: usize = 32;

    /// Creates an encrypter that uses [`DEFAULT_CHUNK`]-byte chunks.
    ///
    /// A fresh random nonce is written to `output` immediately.
    ///
    /// # Errors
    /// Returns any error raised while writing the nonce to `output`.
    pub fn new<'k, Key>(key: Key, output: Out) -> std::io::Result<Self>
    where
        Key: Into<&'k aes_gcm_siv::Key<aes_gcm_siv::Aes256GcmSiv>>,
    {
        Self::with_chunk(key, output, DEFAULT_CHUNK)
    }

    /// Creates a password-keyed encrypter that uses [`DEFAULT_CHUNK`]-byte chunks.
    ///
    /// See [`Encrypter::with_chunk_and_password`] for the header layout and errors.
    #[cfg(feature = "argon")]
    pub fn new_with_password<Password>(password: Password, output: Out) -> std::io::Result<Self>
    where
        Password: AsRef<[u8]>,
    {
        Self::with_chunk_and_password(password, output, DEFAULT_CHUNK)
    }

    /// Creates an encrypter whose encrypted chunks (plaintext plus tag) are `chunk`
    /// bytes long.
    ///
    /// A fresh random nonce is written to `output` immediately. The matching
    /// `Decrypter` must be created with the same `chunk`.
    ///
    /// # Errors
    /// - [`std::io::ErrorKind::InvalidInput`] if `chunk` is smaller than 32 bytes.
    /// - Any error raised while writing the nonce to `output`.
    pub fn with_chunk<'k, Key>(key: Key, output: Out, chunk: usize) -> std::io::Result<Self>
    where
        Key: Into<&'k aes_gcm_siv::Key<aes_gcm_siv::Aes256GcmSiv>>,
    {
        if chunk < Self::MIN_CHUNK {
            return Err(std::io::ErrorKind::InvalidInput.into());
        }
        Self::from_key(key.into(), output, chunk)
    }

    /// Like [`Encrypter::with_chunk`], but derives the key from `password` with
    /// `crate::argon::derive_key`.
    ///
    /// The salt returned by the derivation is written to `output` before the nonce.
    ///
    /// # Errors
    /// - [`std::io::ErrorKind::InvalidInput`] if `chunk` is smaller than 32 bytes or the
    ///   key derivation fails.
    /// - Any error raised while writing the salt or nonce to `output`.
    #[cfg(feature = "argon")]
    pub fn with_chunk_and_password<Password>(
        password: Password,
        mut output: Out,
        chunk: usize,
    ) -> std::io::Result<Self>
    where
        Password: AsRef<[u8]>,
    {
        // Validate before anything is written so a bad chunk size leaves `output` untouched.
        if chunk < Self::MIN_CHUNK {
            return Err(std::io::ErrorKind::InvalidInput.into());
        }
        let (key, salt) =
            crate::argon::derive_key(password).ok_or(std::io::ErrorKind::InvalidInput)?;
        output.write_all(&salt)?;
        Self::from_key(&key, output, chunk)
    }

    /// Shared tail of all constructors: writes a fresh nonce to `output` and sets up the
    /// STREAM encryptor and the block buffer.
    fn from_key(
        key: &aes_gcm_siv::Key<aes_gcm_siv::Aes256GcmSiv>,
        mut output: Out,
        chunk: usize,
    ) -> std::io::Result<Self> {
        use aes_gcm_siv::aead::KeyInit;
        let nonce = make_nonce();
        output.write_all(&nonce)?;
        let stream = aead::stream::EncryptorLE31::from_aead(
            aes_gcm_siv::Aes256GcmSiv::new(key),
            &nonce,
        );
        let buffer = Vec::with_capacity(chunk);
        // The block size comes from the actual allocation, which is also what `Decrypter`
        // uses to size the chunks it reads.
        let plain_capacity = buffer.capacity() - sizes::TAG_LEN;
        Ok(Self {
            stream: Some(stream),
            buffer,
            output,
            plain_capacity,
        })
    }

    /// Terminates the stream: writes the final block and flushes `output`.
    ///
    /// # Errors
    /// Returns any encryption error or any error raised by `output`.
    pub fn finish(mut self) -> std::io::Result<()> {
        self.finish_inner()
    }

    /// Encrypts the buffered plaintext as a non-final block, writes it to `output`
    /// and empties the buffer.
    fn flush_block(&mut self) -> std::io::Result<()> {
        let stream = self
            .stream
            .as_mut()
            .expect("the stream is only taken by finish_inner, which ends the encrypter");
        stream
            .encrypt_next_in_place(b"", &mut self.buffer)
            .map_err(|err| std::io::Error::other(err.to_string()))?;
        self.output.write_all(&self.buffer)?;
        self.buffer.clear();
        Ok(())
    }

    /// Writes the final block(s) and flushes `output`; a no-op if already finished.
    fn finish_inner(&mut self) -> std::io::Result<()> {
        // Take the stream first so the final block is attempted at most once, even if
        // writing fails and `Drop` runs afterwards.
        let Some(mut stream) = self.stream.take() else {
            return Ok(());
        };
        if self.buffer.len() == self.plain_capacity {
            // `Decrypter` treats a chunk shorter than its buffer capacity as the final one,
            // so a completely full last block is emitted as a regular block and followed by
            // an empty (tag-only) final block.
            stream
                .encrypt_next_in_place(b"", &mut self.buffer)
                .map_err(|err| std::io::Error::other(err.to_string()))?;
            self.output.write_all(&self.buffer)?;
            self.buffer.clear();
        }
        stream
            .encrypt_last_in_place(b"", &mut self.buffer)
            .map_err(|err| std::io::Error::other(err.to_string()))?;
        self.output.write_all(&self.buffer)?;
        self.output.flush()
    }

    /// Appends `buf` to the block being assembled.
    ///
    /// # Safety
    /// The body is safe: it uses `extend_from_slice` and never exposes uninitialised
    /// memory. The function stays `unsafe` so existing `unsafe { .. }` call sites keep
    /// compiling. Callers must still keep `buf.len()` within the remaining plaintext room
    /// (`plain_capacity - buffer.len()`). Otherwise the block would exceed the chunk size
    /// the decrypter expects.
    unsafe fn fill_buf(&mut self, buf: &[u8]) {
        self.buffer.extend_from_slice(buf);
    }
}
impl<Out> std::io::Write for Encrypter<Out>
where
    Out: std::io::Write,
{
    /// Buffers `buf`, encrypting and emitting every block that fills up along the way.
    ///
    /// A block that becomes exactly full is kept buffered rather than emitted. It is sealed
    /// by the next write, by `flush`, or by finishing the stream.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let total = buf.len();
        let mut pending = buf;
        let mut room = self.plain_capacity.saturating_sub(self.buffer.len());
        while pending.len() > room {
            let (head, tail) = pending.split_at(room);
            // SAFETY: `head` is exactly the free plaintext room left in the current block.
            unsafe { self.fill_buf(head) };
            self.flush_block()?;
            pending = tail;
            room = self.plain_capacity;
        }
        // SAFETY: the loop exits only once `pending` fits in the remaining room.
        unsafe { self.fill_buf(pending) };
        Ok(total)
    }

    /// Emits the buffered block if it is full, then flushes the underlying writer.
    fn flush(&mut self) -> std::io::Result<()> {
        let block_is_full = self.buffer.len() == self.plain_capacity;
        if block_is_full {
            self.flush_block()?;
        }
        self.output.flush()
    }
}
impl<Out> Drop for Encrypter<Out>
where
    Out: std::io::Write,
{
    /// Best-effort termination of a stream that was never explicitly finished.
    ///
    /// Errors are discarded because `drop` cannot report them; call `Encrypter::finish`
    /// to observe them. While the thread is panicking the stream is deliberately left
    /// unterminated. The data written so far is most likely incomplete, and sealing it
    /// with a final block would make a truncated stream authenticate as complete. Without
    /// the final block, the decrypter reports an error instead.
    fn drop(&mut self) {
        if self.stream.is_some() && !std::thread::panicking() {
            let _ = self.finish_inner();
        }
    }
}
/// Streaming reader that decrypts a stream produced by `Encrypter`, pulling ciphertext
/// from `input`.
///
/// It must be created with the same key (or password) and the same chunk size that were
/// used for encryption; a mismatched chunk size makes authentication fail.
pub struct Decrypter<In>
where
In: std::io::Read,
{
/// Active STREAM decryptor; `None` once the final block has been processed.
stream: Option<aead::stream::DecryptorLE31<aes_gcm_siv::Aes256GcmSiv>>,
/// Current block: ciphertext while being read, plaintext after decryption.
/// Its capacity is the chunk size used to decide how many bytes make up a block.
buffer: Vec<u8>,
/// Index of the next plaintext byte in `buffer` to hand to the caller.
cursor: usize,
/// Source of the header and the encrypted blocks.
input: In,
}
impl<In> Decrypter<In>
where
    In: std::io::Read,
{
    /// Smallest accepted chunk size in bytes; smaller values are rejected with
    /// [`std::io::ErrorKind::InvalidInput`].
    const MIN_CHUNK: usize = 32;

    /// Creates a decrypter for streams encrypted with [`DEFAULT_CHUNK`]-byte chunks.
    ///
    /// The nonce is read from `input` immediately.
    ///
    /// # Errors
    /// Returns any error raised while reading the nonce (e.g. `UnexpectedEof`).
    pub fn new<'k, Key>(key: Key, input: In) -> std::io::Result<Self>
    where
        Key: Into<&'k aes_gcm_siv::Key<aes_gcm_siv::Aes256GcmSiv>>,
    {
        Self::with_chunk(key, input, DEFAULT_CHUNK)
    }

    /// Creates a password-keyed decrypter for [`DEFAULT_CHUNK`]-byte chunks.
    ///
    /// See [`Decrypter::with_chunk_with_password`] for details and errors.
    #[cfg(feature = "argon")]
    pub fn new_with_password<Password>(password: Password, input: In) -> std::io::Result<Self>
    where
        Password: AsRef<[u8]>,
    {
        Self::with_chunk_with_password(password, input, DEFAULT_CHUNK)
    }

    /// Creates a decrypter for streams whose encrypted chunks are `chunk` bytes long.
    ///
    /// # Errors
    /// - [`std::io::ErrorKind::InvalidInput`] if `chunk` is smaller than 32 bytes.
    /// - Any error raised while reading the nonce from `input`.
    pub fn with_chunk<'k, Key>(key: Key, input: In, chunk: usize) -> std::io::Result<Self>
    where
        Key: Into<&'k aes_gcm_siv::Key<aes_gcm_siv::Aes256GcmSiv>>,
    {
        if chunk < Self::MIN_CHUNK {
            return Err(std::io::ErrorKind::InvalidInput.into());
        }
        Self::from_key(key.into(), input, chunk)
    }

    /// Like [`Decrypter::with_chunk`], but first reads the salt from `input` and derives
    /// the key from `password` with `crate::argon::derive_with_salt`.
    ///
    /// # Errors
    /// - [`std::io::ErrorKind::InvalidInput`] if `chunk` is smaller than 32 bytes.
    /// - [`std::io::ErrorKind::InvalidData`] if key derivation fails.
    /// - Any error raised while reading the salt or nonce from `input`.
    #[cfg(feature = "argon")]
    pub fn with_chunk_with_password<Password>(
        password: Password,
        mut input: In,
        chunk: usize,
    ) -> std::io::Result<Self>
    where
        Password: AsRef<[u8]>,
    {
        // Validate before consuming anything from `input`.
        if chunk < Self::MIN_CHUNK {
            return Err(std::io::ErrorKind::InvalidInput.into());
        }
        let mut salt = crate::argon::Salt::default();
        input.read_exact(&mut salt)?;
        let key = crate::argon::derive_with_salt(password, &salt)
            .ok_or(std::io::ErrorKind::InvalidData)?;
        Self::from_key(&key, input, chunk)
    }

    /// Shared tail of all constructors: reads the nonce from `input` and sets up the
    /// STREAM decryptor and the block buffer.
    fn from_key(
        key: &aes_gcm_siv::Key<aes_gcm_siv::Aes256GcmSiv>,
        mut input: In,
        chunk: usize,
    ) -> std::io::Result<Self> {
        use aead::KeyInit;
        let mut nonce = Nonce::default();
        input.read_exact(&mut nonce)?;
        let stream = aead::stream::DecryptorLE31::from_aead(
            aes_gcm_siv::Aes256GcmSiv::new(key),
            nonce.as_slice().into(),
        );
        Ok(Self {
            stream: Some(stream),
            buffer: Vec::with_capacity(chunk),
            cursor: 0,
            input,
        })
    }

    /// Reads the next ciphertext chunk (up to the buffer's capacity) into `buffer` and
    /// resets `cursor`.
    ///
    /// Reading stops early only at end of input, and a short chunk marks the final block.
    /// `Interrupted` errors are retried. On any other error the buffer is emptied, so no
    /// stale or partial bytes can be served as plaintext.
    fn fill_buf(&mut self) -> std::io::Result<()> {
        let chunk = self.buffer.capacity();
        self.cursor = 0;
        // Grow to a full chunk with *initialised* bytes before lending the slice to
        // `Read::read`: readers are allowed to inspect the buffer they are given.
        self.buffer.resize(chunk, 0);
        let mut filled = 0;
        while filled < chunk {
            match self.input.read(&mut self.buffer[filled..]) {
                Ok(0) => break,
                Ok(bytes) => filled += bytes,
                Err(err) if err.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(err) => {
                    self.buffer.clear();
                    return Err(err);
                }
            }
        }
        self.buffer.truncate(filled);
        Ok(())
    }

    /// Authenticates and decrypts the chunk held in `buffer` in place.
    ///
    /// A chunk shorter than the buffer capacity is the final block; decrypting it consumes
    /// the stream. If authentication fails, the buffer is emptied so unauthenticated bytes
    /// are never handed to the caller.
    ///
    /// # Safety
    /// Callers must ensure `self.stream` is `Some`. The body now panics instead of
    /// invoking UB if that contract is broken. It stays `unsafe` so existing call sites
    /// keep compiling.
    unsafe fn decrypt(&mut self) -> std::io::Result<()> {
        const MISSING: &str = "decrypt called after the final block";
        let outcome = if self.buffer.len() < self.buffer.capacity() {
            let stream = self.stream.take().expect(MISSING);
            stream.decrypt_last_in_place(b"", &mut self.buffer)
        } else {
            let stream = self.stream.as_mut().expect(MISSING);
            stream.decrypt_next_in_place(b"", &mut self.buffer)
        };
        outcome.map_err(|err| {
            self.buffer.clear();
            self.cursor = 0;
            std::io::Error::other(err.to_string())
        })
    }
}
impl<In> std::io::Read for Decrypter<In>
where
    In: std::io::Read,
{
    /// Copies decrypted plaintext into `buf`, returning at most one block's worth.
    ///
    /// A new chunk is read and decrypted only when no buffered plaintext remains. Any
    /// error therefore surfaces before a single byte has been copied into `buf`, so no
    /// plaintext is lost when an error is returned. Returns `Ok(0)` once the final block
    /// has been consumed. After an error the decrypter should be discarded.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        // Loop because the final block may decrypt to zero bytes (tag-only block).
        while self.cursor >= self.buffer.len() {
            if self.stream.is_none() {
                return Ok(0);
            }
            self.fill_buf()?;
            // SAFETY: `self.stream` was just checked to be `Some`.
            unsafe { self.decrypt() }?;
        }
        let pending = &self.buffer[self.cursor..];
        let len = pending.len().min(buf.len());
        buf[..len].copy_from_slice(&pending[..len]);
        self.cursor += len;
        Ok(len)
    }
}
/// Returns a freshly generated random STREAM nonce drawn from the operating system RNG.
fn make_nonce() -> Nonce {
    use aes_gcm_siv::aead::rand_core::RngCore as _;
    let mut fresh = Nonce::default();
    aes_gcm_siv::aead::OsRng.fill_bytes(&mut fresh);
    fresh
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};

    /// Generates a random AES-256-GCM-SIV key from the operating system RNG.
    fn make_key() -> aes_gcm_siv::Key<aes_gcm_siv::Aes256GcmSiv> {
        <aes_gcm_siv::Aes256GcmSiv as aes_gcm_siv::KeyInit>::generate_key(aes_gcm_siv::aead::OsRng)
    }

    /// Test payload: the byte ramp `0..=254`, repeated 256 times (65 280 bytes).
    fn sample_input() -> Vec<u8> {
        (u8::MIN..=u8::MAX)
            .flat_map(|_| u8::MIN..u8::MAX)
            .collect()
    }

    #[test]
    fn round_trip() {
        let key = make_key();
        let plaintext = sample_input();
        let mut ciphertext = Vec::with_capacity(plaintext.len());
        let mut encrypter = Encrypter::new(&key, &mut ciphertext).unwrap();
        encrypter.write_all(&plaintext).unwrap();
        encrypter.finish().unwrap();
        assert_ne!(plaintext, ciphertext);

        let mut recovered = Vec::with_capacity(plaintext.len());
        Decrypter::new(&key, ciphertext.as_slice())
            .unwrap()
            .read_to_end(&mut recovered)
            .unwrap();
        assert_eq!(plaintext, recovered);
    }

    #[test]
    #[cfg(feature = "argon")]
    fn round_trip_with_password() {
        let password = "super secret password";
        let plaintext = sample_input();
        let mut ciphertext = Vec::with_capacity(plaintext.len());
        let mut encrypter = Encrypter::new_with_password(password, &mut ciphertext).unwrap();
        encrypter.write_all(&plaintext).unwrap();
        encrypter.finish().unwrap();
        assert_ne!(plaintext, ciphertext);

        let mut recovered = Vec::with_capacity(plaintext.len());
        Decrypter::new_with_password(password, ciphertext.as_slice())
            .unwrap()
            .read_to_end(&mut recovered)
            .unwrap();
        assert_eq!(plaintext, recovered);
    }

    #[test]
    fn round_trip_with_drop() {
        let key = make_key();
        let plaintext = sample_input();
        let mut ciphertext = Vec::with_capacity(plaintext.len());
        let mut encrypter = Encrypter::new(&key, &mut ciphertext).unwrap();
        encrypter.write_all(&plaintext).unwrap();
        // No `finish`: dropping the encrypter has to terminate the stream.
        drop(encrypter);
        assert_ne!(plaintext, ciphertext);

        let mut recovered = Vec::with_capacity(plaintext.len());
        Decrypter::new(&key, ciphertext.as_slice())
            .unwrap()
            .read_to_end(&mut recovered)
            .unwrap();
        assert_eq!(plaintext, recovered);
    }

    #[test]
    fn different_block_size() {
        let key = make_key();
        let plaintext = sample_input();
        let mut ciphertext = Vec::with_capacity(plaintext.len());
        let mut encrypter = Encrypter::with_chunk(&key, &mut ciphertext, 256).unwrap();
        encrypter.write_all(&plaintext).unwrap();
        encrypter.finish().unwrap();
        assert_ne!(plaintext, ciphertext);

        // A decrypter with a different chunk size must fail authentication.
        let mut recovered = Vec::with_capacity(plaintext.len());
        let err = Decrypter::with_chunk(&key, ciphertext.as_slice(), 128)
            .unwrap()
            .read_to_end(&mut recovered)
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::Other);
        assert_eq!(err.to_string(), "aead::Error");
    }

    #[test]
    fn variable_chunk() {
        let key = make_key();
        let bases = [64, 240, 256, 272, 512, 1024, 16 * 1024 * 1024];
        for chunk in bases.iter().flat_map(|&base| base - 1..=base + 1) {
            // Each size runs on its own thread so a panic can be reported with its size.
            let outcome = std::thread::spawn(move || {
                let plaintext = sample_input();
                let mut ciphertext = Vec::with_capacity(plaintext.len());
                let mut recovered = Vec::with_capacity(plaintext.len());
                eprintln!("Chunk {chunk}");
                let mut encrypter = Encrypter::with_chunk(&key, &mut ciphertext, chunk).unwrap();
                encrypter.write_all(&plaintext).unwrap();
                encrypter.finish().unwrap();
                assert_ne!(plaintext, ciphertext);
                Decrypter::with_chunk(&key, ciphertext.as_slice(), chunk)
                    .unwrap()
                    .read_to_end(&mut recovered)
                    .unwrap();
                assert_eq!(plaintext, recovered);
            })
            .join();
            if let Err(err) = outcome {
                panic!("Chunk {chunk}: {err:?}");
            }
        }
    }

    #[test]
    fn minimum_chunk() {
        let key = make_key();
        for chunk in 0..64 {
            let mut sink = Vec::new();
            match (chunk < 32, Encrypter::with_chunk(&key, &mut sink, chunk)) {
                (true, result) => assert!(result.is_err(), "Chunk {chunk}: Expected error"),
                (false, Err(err)) => panic!("Chunk {chunk}: {err:?}"),
                (false, Ok(_)) => {}
            }
        }
    }
}