use std::io::{Error, ErrorKind, Read, Write};
use std::sync::{Arc, Mutex};
use serde_tuple::{Deserialize_tuple, Serialize_tuple};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::capsule::common::{CapsuleError, KEY_SIZE, NONCE_BLOCK_SIZE, NONCE_SIZE};
/// Maximum number of plaintext bytes sealed into a single chunk (256 KiB).
const CHUNK_SIZE: usize = 1 << 18;
/// Length in bytes of the AES-256-GCM authentication tag appended to every chunk.
const AES_256_GCM_TAG_LEN: usize = 16;
/// Additional authenticated data (AAD) bound to every encrypted chunk.
///
/// Authenticating the chunk's position and whether it is the last chunk lets the decryptor
/// reject chunks that were reordered, dropped, or followed by a forged end-of-stream marker.
/// `serde_tuple` makes it serialize as a sequence (a CBOR array) rather than a map.
#[derive(Serialize_tuple, Deserialize_tuple, Debug, Clone, Copy, PartialEq, Eq)]
pub struct AD {
    /// Zero-based index of the chunk within its stream.
    pub chunk_num: u64,
    /// `true` only for the last chunk of the stream.
    pub final_chunk: bool,
}

impl AD {
    /// Creates the AAD for chunk number `chunk_num`.
    pub fn new(chunk_num: u64, final_chunk: bool) -> Self {
        Self {
            chunk_num,
            final_chunk,
        }
    }

    /// Decodes AAD previously produced by [`AD::marshal`].
    ///
    /// # Errors
    /// Returns `CapsuleError::CBORDecodeFailed` if `serialized` is not a valid encoding.
    pub fn unmarshal(serialized: &[u8]) -> Result<Self, CapsuleError> {
        // `&[u8]` implements `Read` itself, so the input is decoded in place without first
        // being copied into a new `Vec`.
        ciborium::from_reader(serialized)
            .map_err(|e| CapsuleError::CBORDecodeFailed(format!("decoding AD for chunk: {}", e)))
    }

    /// Encodes the AAD as CBOR; these bytes are what gets authenticated alongside a chunk.
    ///
    /// # Errors
    /// Returns `CapsuleError::CBOREncodeFailed` if serialization fails.
    pub fn marshal(self) -> Result<Vec<u8>, CapsuleError> {
        let mut result: Vec<u8> = Vec::new();
        ciborium::ser::into_writer(&self, &mut result)
            .map_err(|e| CapsuleError::CBOREncodeFailed(format!("encoding AD for chunk: {}", e)))?;
        Ok(result)
    }
}
/// Increments the counter part of `nonce` (bytes `NONCE_BLOCK_SIZE..NONCE_SIZE`) as a
/// big-endian integer, leaving the leading `NONCE_BLOCK_SIZE` bytes untouched.
///
/// Returns `true` when the counter wrapped around (every counter byte was `0xFF` and is now
/// `0x00`, or the counter is empty), i.e. the nonce space of this block is exhausted.
fn increment_nonce(nonce: &mut [u8; NONCE_SIZE]) -> bool {
    for byte in nonce[NONCE_BLOCK_SIZE..].iter_mut().rev() {
        let (incremented, wrapped) = byte.overflowing_add(1);
        *byte = incremented;
        if !wrapped {
            // No carry into the next more-significant byte: the increment is complete.
            return false;
        }
    }
    // The carry propagated out of the most significant counter byte.
    true
}
/// `Read` adapter that yields the chunked AES-256-GCM encryption of `input`.
///
/// The produced bytes are a sequence of frames `len (u32 LE) | nonce | ciphertext | tag`
/// terminated by a zero `len` — the same layout `streaming_encrypt_aes_256_gcm` writes.
pub struct EncryptingAEADReader<R: Read> {
    /// Plaintext source.
    input: R,
    /// AES-256-GCM key used to seal every chunk.
    cipher: ring::aead::LessSafeKey,
    /// Nonce for the next chunk; its trailing counter bytes are advanced between chunks.
    nonce_block: [u8; NONCE_SIZE],
    /// Index of the next chunk, authenticated through [`AD`].
    chunk_num: u64,
    /// One byte of look-ahead used to tell whether the current chunk is the last one.
    next_byte: [u8; 1],
    /// Staging area for one framed chunk (header, payload, tag and optional sentinel).
    buffer: Vec<u8>,
    /// Number of framed bytes in `buffer` not yet handed to the caller.
    buffer_len: usize,
    /// Position in `buffer` of the next byte to hand to the caller.
    buffer_offset: usize,
    /// Set once no further chunk will be produced.
    eof: bool,
}

impl<R: Read> EncryptingAEADReader<R> {
    /// Creates a reader that encrypts `input` with `key`, starting from `nonce_block`.
    ///
    /// The first byte of `input` is read eagerly as look-ahead.
    ///
    /// # Errors
    /// Returns `CapsuleError::Generic` if the key is rejected or the first byte of `input`
    /// cannot be read (which includes an empty `input`).
    pub fn new(
        nonce_block: [u8; NONCE_SIZE],
        key: &[u8; KEY_SIZE],
        mut input: R,
    ) -> Result<Self, CapsuleError> {
        let unbound_key = ring::aead::UnboundKey::new(&ring::aead::AES_256_GCM, key)
            .map_err(|e| CapsuleError::Generic(format!("creating AES 256 GCM key: {}", e)))?;
        let cipher = ring::aead::LessSafeKey::new(unbound_key);

        let mut next_byte = [0u8; 1];
        input
            .read_exact(&mut next_byte)
            .map_err(|e| CapsuleError::Generic(format!("reading input stream: {}", e)))?;

        // Room for: length prefix, nonce, a full chunk, the tag and the zero-length sentinel.
        let frame_capacity = 4 + NONCE_SIZE + CHUNK_SIZE + AES_256_GCM_TAG_LEN + 4;
        Ok(Self {
            input,
            cipher,
            nonce_block,
            chunk_num: 0,
            next_byte,
            buffer: vec![0; frame_capacity],
            buffer_len: 0,
            buffer_offset: 0,
            eof: false,
        })
    }
}
impl<R: Read> Read for EncryptingAEADReader<R> {
fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
let mut bytes_read: usize = 0;
if self.buffer_len > 0 {
let to_copy = std::cmp::min(buf.len(), self.buffer_len);
buf[..to_copy]
.copy_from_slice(&self.buffer[self.buffer_offset..self.buffer_offset + to_copy]);
self.buffer_len -= to_copy;
self.buffer_offset += to_copy;
bytes_read += to_copy;
if self.buffer_len > 0 {
return Ok(bytes_read);
}
}
if self.eof {
return Ok(bytes_read);
}
self.buffer[4 + NONCE_SIZE] = self.next_byte[0];
let n = self
.input
.read(&mut self.buffer[4 + NONCE_SIZE + 1..4 + NONCE_SIZE + 1 + CHUNK_SIZE - 1])?
+ 1;
let mut final_chunk: bool = false;
match self.input.read(&mut self.next_byte) {
Ok(0) => final_chunk = true,
Ok(_) => {}
Err(e) => return Err(e),
}
let tag = self
.cipher
.seal_in_place_separate_tag(
ring::aead::Nonce::assume_unique_for_key(self.nonce_block),
ring::aead::Aad::from(
AD {
final_chunk,
chunk_num: self.chunk_num,
}
.marshal()
.map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::Other,
format!("marshaling additional data: {}", e),
)
})?
.as_slice(),
),
&mut self.buffer[4 + NONCE_SIZE..4 + NONCE_SIZE + n],
)
.map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::Other,
format!("failed to seal in place: {}", e),
)
})?;
self.buffer[4 + NONCE_SIZE + n..4 + NONCE_SIZE + n + AES_256_GCM_TAG_LEN]
.copy_from_slice(tag.as_ref());
self.buffer[..4].copy_from_slice(&((n + AES_256_GCM_TAG_LEN) as u32).to_le_bytes());
self.buffer[4..4 + NONCE_SIZE].copy_from_slice(&self.nonce_block);
self.buffer_offset = 0;
self.buffer_len = 4 + NONCE_SIZE + n + AES_256_GCM_TAG_LEN;
if final_chunk {
self.buffer[4 + NONCE_SIZE + n + AES_256_GCM_TAG_LEN
..4 + NONCE_SIZE + n + AES_256_GCM_TAG_LEN + 4]
.copy_from_slice(&0_u32.to_le_bytes());
self.buffer_len += 4;
self.eof = true;
}
self.chunk_num += 1;
if increment_nonce(&mut self.nonce_block) {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"nonce block exhausted".to_string(),
));
}
let to_copy = std::cmp::min(buf.len() - bytes_read, self.buffer_len);
buf[bytes_read..bytes_read + to_copy].copy_from_slice(&self.buffer[..to_copy]);
self.buffer_len -= to_copy;
self.buffer_offset += to_copy;
bytes_read += to_copy;
Ok(bytes_read)
}
}
pub struct EncryptingAEADWriter<W: Write> {
output: Arc<Mutex<W>>,
cipher: ring::aead::LessSafeKey,
nonce_block: [u8; NONCE_SIZE],
chunk_num: u64,
buffer: Vec<u8>,
buffer_idx: usize,
}
impl<W: Write> EncryptingAEADWriter<W> {
pub fn new(
nonce_block: [u8; NONCE_SIZE],
key: &[u8; KEY_SIZE],
output: Arc<Mutex<W>>,
) -> Result<Self, CapsuleError> {
let result = Self {
output,
cipher: ring::aead::LessSafeKey::new(
ring::aead::UnboundKey::new(&ring::aead::AES_256_GCM, key).map_err(|e| {
CapsuleError::Generic(format!("creating AES 256 GCM key: {}", e))
})?,
),
nonce_block,
chunk_num: 0,
buffer: vec![0; 4 + NONCE_SIZE + CHUNK_SIZE + AES_256_GCM_TAG_LEN + 4],
buffer_idx: 4 + NONCE_SIZE,
};
Ok(result)
}
fn flush_buffer(&mut self, final_chunk: bool) -> std::io::Result<()> {
let payload_length = self.buffer_idx + AES_256_GCM_TAG_LEN - 4 - NONCE_SIZE;
let mut chunk_length = self.buffer_idx + AES_256_GCM_TAG_LEN;
let binding = AD {
final_chunk,
chunk_num: self.chunk_num,
}
.marshal()
.map_err(|e| {
Error::new(
ErrorKind::Other,
format!("marshaling additional data: {}", e),
)
})?;
let ad_bytes = binding.as_slice();
let tag = self
.cipher
.seal_in_place_separate_tag(
ring::aead::Nonce::assume_unique_for_key(self.nonce_block),
ring::aead::Aad::from(ad_bytes),
&mut self.buffer[4 + NONCE_SIZE..self.buffer_idx],
)
.map_err(|e| Error::new(ErrorKind::Other, format!("failed to seal in place: {}", e)))?;
self.buffer[self.buffer_idx..self.buffer_idx + AES_256_GCM_TAG_LEN]
.copy_from_slice(tag.as_ref());
self.buffer[..4].copy_from_slice(&(payload_length as u32).to_le_bytes());
self.buffer[4..4 + NONCE_SIZE].copy_from_slice(&self.nonce_block);
if final_chunk {
self.buffer
[self.buffer_idx + AES_256_GCM_TAG_LEN..self.buffer_idx + AES_256_GCM_TAG_LEN + 4]
.copy_from_slice(&0_u32.to_le_bytes());
chunk_length += 4;
}
self.chunk_num += 1;
if increment_nonce(&mut self.nonce_block) {
return Err(Error::new(
ErrorKind::Other,
"nonce block exhausted".to_string(),
));
}
self.buffer_idx = 4 + NONCE_SIZE;
let mut writer = self.output.lock().unwrap();
writer.write_all(&self.buffer[..chunk_length])?;
writer.flush()
}
}
impl<W: Write> Write for EncryptingAEADWriter<W> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let mut bytes_read: usize = 0;
let buffer_size = self.buffer.len() - AES_256_GCM_TAG_LEN - 4;
loop {
if self.buffer_idx == CHUNK_SIZE {
self.flush_buffer(false)?;
}
let to_copy = std::cmp::min(buf.len() - bytes_read, buffer_size - self.buffer_idx);
if to_copy == 0 {
break; }
self.buffer[self.buffer_idx..self.buffer_idx + to_copy]
.copy_from_slice(&buf[bytes_read..bytes_read + to_copy]);
bytes_read += to_copy;
self.buffer_idx += to_copy;
}
Ok(bytes_read)
}
fn flush(&mut self) -> std::io::Result<()> {
self.flush_buffer(true)
}
}
/// `Read` adapter that authenticates and decrypts a chunked AES-256-GCM stream from a
/// shared `input`.
pub struct DecryptingAEAD<R: Read> {
    /// Shared ciphertext source.
    input: Arc<Mutex<R>>,
    /// AES-256-GCM key used to open every chunk.
    cipher: ring::aead::LessSafeKey,
    /// Length prefix of the next frame, read ahead of the frame itself.
    len_buffer: [u8; 4],
    /// Nonce of the most recently read frame.
    nonce_buffer: [u8; NONCE_SIZE],
    /// Holds the most recently read frame; after opening, its front is the plaintext.
    chunk_buffer: Vec<u8>,
    /// Plaintext bytes in `chunk_buffer` not yet handed to the caller.
    chunk_buffer_len: usize,
    /// Position in `chunk_buffer` of the next plaintext byte to hand out.
    chunk_buffer_offset: usize,
    /// Length (ciphertext plus tag) of the most recently read frame.
    chunk_len: u32,
    /// Index of the next expected chunk, authenticated through [`AD`].
    chunk_num: u64,
    /// When set, `read` stops after returning any buffered plaintext.
    eof: bool,
}

impl<R: Read> DecryptingAEAD<R> {
    /// Creates a decrypting reader over `input` using `key`.
    ///
    /// The first frame's 4-byte length prefix is read eagerly.
    ///
    /// # Errors
    /// Returns `CapsuleError::Generic` if the key is rejected, or
    /// `CapsuleError::DecryptionFailure` if the length prefix cannot be read.
    pub fn new(key: &[u8; KEY_SIZE], input: Arc<Mutex<R>>) -> Result<Self, CapsuleError> {
        let unbound_key = ring::aead::UnboundKey::new(&ring::aead::AES_256_GCM, key)
            .map_err(|e| CapsuleError::Generic(format!("creating AES 256 GCM key: {}", e)))?;

        let mut len_buffer = [0u8; 4];
        input
            .lock()
            .unwrap()
            .read_exact(&mut len_buffer)
            .map_err(|e| CapsuleError::DecryptionFailure(format!("reading input stream: {}", e)))?;

        Ok(Self {
            input,
            cipher: ring::aead::LessSafeKey::new(unbound_key),
            len_buffer,
            nonce_buffer: [0u8; NONCE_SIZE],
            chunk_buffer: vec![0; CHUNK_SIZE + AES_256_GCM_TAG_LEN],
            chunk_buffer_len: 0,
            chunk_buffer_offset: 0,
            chunk_len: 0,
            chunk_num: 0,
            eof: false,
        })
    }
}
impl<R: Read> Read for DecryptingAEAD<R> {
fn read(&mut self, buf: &mut [u8]) -> Result<usize, std::io::Error> {
let mut bytes_read: usize = 0;
if self.chunk_buffer_len > 0 {
let to_copy = std::cmp::min(buf.len(), self.chunk_buffer_len);
buf[..to_copy].copy_from_slice(
&self.chunk_buffer[self.chunk_buffer_offset..self.chunk_buffer_offset + to_copy],
);
self.chunk_buffer_len -= to_copy;
self.chunk_buffer_offset += to_copy;
bytes_read += to_copy;
if self.chunk_buffer_len > 0 {
return Ok(bytes_read);
}
}
if self.eof {
return Ok(bytes_read);
}
self.chunk_len = u32::from_le_bytes(self.len_buffer);
if self.chunk_len == 0 {
return Ok(bytes_read);
} else if self.chunk_len > ((CHUNK_SIZE + AES_256_GCM_TAG_LEN) as u32) {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
format!(
"chunk length {} exceeds maximum chunk size {} (this is probably a bug)",
self.chunk_len,
CHUNK_SIZE + AES_256_GCM_TAG_LEN
),
));
}
self.input
.lock()
.unwrap()
.read_exact(&mut self.nonce_buffer)
.map_err(|e| {
std::io::Error::new(std::io::ErrorKind::Other, format!("reading nonce: {}", e))
})?;
self.input
.lock()
.unwrap()
.read_exact(&mut self.chunk_buffer[..(self.chunk_len as usize)])
.map_err(|e| {
std::io::Error::new(std::io::ErrorKind::Other, format!("reading chunk: {}", e))
})?;
self.input
.lock()
.unwrap()
.read_exact(&mut self.len_buffer)
.map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::Other,
format!("reading chunk length: {}", e),
)
})?;
let final_chunk = u32::from_le_bytes(self.len_buffer) == 0;
self.cipher
.open_in_place(
ring::aead::Nonce::assume_unique_for_key(self.nonce_buffer),
ring::aead::Aad::from(
AD {
final_chunk,
chunk_num: self.chunk_num,
}
.marshal()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("{}", e)))?
.as_slice(),
),
&mut self.chunk_buffer[..(self.chunk_len as usize)],
)
.map_err(|e| {
std::io::Error::new(
std::io::ErrorKind::Other,
format!("decrypting chunk {}: {}", self.chunk_num, e),
)
})?;
let to_copy = std::cmp::min(
buf.len() - bytes_read,
self.chunk_len as usize - AES_256_GCM_TAG_LEN,
);
buf[bytes_read..bytes_read + to_copy].copy_from_slice(&self.chunk_buffer[..to_copy]);
self.chunk_buffer_offset = to_copy;
self.chunk_buffer_len = self.chunk_len as usize - AES_256_GCM_TAG_LEN - to_copy;
bytes_read += to_copy;
self.chunk_num += 1;
Ok(bytes_read)
}
}
pub fn streaming_decrypt_aes_256_gcm<R, W>(
key: &[u8; KEY_SIZE],
mut input: R,
mut output: W,
) -> Result<(), CapsuleError>
where
R: Read + Unpin,
W: Write + Unpin,
{
let key = ring::aead::UnboundKey::new(&ring::aead::AES_256_GCM, key)
.map_err(|e| CapsuleError::Generic(format!("creating AES 256 GCM key: {}", e)))?;
let cipher = ring::aead::LessSafeKey::new(key);
let mut len_buffer = [0; std::mem::size_of::<u32>()]; let mut nonce_buffer = [0; NONCE_SIZE];
let mut chunk_buffer = [0; CHUNK_SIZE + AES_256_GCM_TAG_LEN];
let mut chunk_len: u32;
let mut chunk_num: u64 = 0;
let mut final_chunk: bool;
input
.read_exact(&mut len_buffer)
.map_err(|e| CapsuleError::DecryptionFailure(format!("reading input stream: {}", e)))?;
loop {
chunk_len = u32::from_le_bytes(len_buffer);
if chunk_len == 0 {
break;
} else if chunk_len > ((CHUNK_SIZE + AES_256_GCM_TAG_LEN) as u32) {
return Err(CapsuleError::DecryptionFailure(format!(
"chunk length {} exceeds maximum chunk size {} (this is probably a bug)",
chunk_len,
CHUNK_SIZE + AES_256_GCM_TAG_LEN
)));
}
input
.read_exact(&mut nonce_buffer)
.map_err(|e| CapsuleError::DecryptionFailure(format!("reading nonce: {}", e)))?;
input
.read_exact(&mut chunk_buffer[..(chunk_len as usize)])
.map_err(|e| CapsuleError::DecryptionFailure(format!("reading chunk: {}", e)))?;
input
.read_exact(&mut len_buffer)
.map_err(|e| CapsuleError::DecryptionFailure(format!("reading chunk length: {}", e)))?;
final_chunk = u32::from_le_bytes(len_buffer) == 0;
cipher
.open_in_place(
ring::aead::Nonce::assume_unique_for_key(nonce_buffer),
ring::aead::Aad::from(
AD {
final_chunk,
chunk_num,
}
.marshal()?
.as_slice(),
),
&mut chunk_buffer[..(chunk_len as usize)],
)
.map_err(|e| {
CapsuleError::DecryptionFailure(format!("decrypting chunk {}: {}", chunk_num, e))
})?;
let (decrypted_data, _) =
chunk_buffer.split_at_mut(chunk_len as usize - AES_256_GCM_TAG_LEN);
output
.write_all(decrypted_data)
.map_err(|e| CapsuleError::DecryptionFailure(format!("writing output: {}", e)))?;
chunk_num += 1;
}
Ok(())
}
pub fn streaming_encrypt_aes_256_gcm<R, W>(
key: &[u8; KEY_SIZE],
nonce_block: &mut [u8; NONCE_SIZE],
mut input: R,
mut output: W,
) -> Result<(), CapsuleError>
where
R: Read + Unpin,
W: Write + Unpin,
{
let key = ring::aead::UnboundKey::new(&ring::aead::AES_256_GCM, key)
.map_err(|e| CapsuleError::Generic(format!("creating AES 256 GCM key: {}", e)))?;
let cipher = ring::aead::LessSafeKey::new(key);
let mut chunk_num: u64 = 0;
let mut final_chunk: bool = false;
let mut next_byte = [0u8; 1];
let mut buffer = [0u8; CHUNK_SIZE + AES_256_GCM_TAG_LEN];
match input.read(&mut next_byte) {
Ok(0) => {
return Err(CapsuleError::EncryptionFailure(
"empty plaintext".to_string(),
))
}
Ok(_) => {}
Err(e) => {
return Err(CapsuleError::EncryptionFailure(format!(
"reading input stream: {}",
e
)))
}
}
loop {
buffer[0] = next_byte[0];
let n = input
.read(&mut buffer[1..CHUNK_SIZE])
.map_err(|e| CapsuleError::EncryptionFailure(format!("reading input stream: {}", e)))?
+ 1;
match input.read(&mut next_byte) {
Ok(0) => final_chunk = true,
Ok(_) => {}
Err(e) => {
return Err(CapsuleError::EncryptionFailure(format!(
"reading input stream: {}",
e
)))
}
}
let tag = cipher
.seal_in_place_separate_tag(
ring::aead::Nonce::assume_unique_for_key(*nonce_block),
ring::aead::Aad::from(
AD {
final_chunk,
chunk_num,
}
.marshal()?
.as_slice(),
),
&mut buffer[..n],
)
.map_err(|e| CapsuleError::Generic(format!("failed to seal in place: {}", e)))?;
buffer[n..n + AES_256_GCM_TAG_LEN].copy_from_slice(tag.as_ref());
output
.write_all(&((n + AES_256_GCM_TAG_LEN) as u32).to_le_bytes())
.map_err(|e| {
CapsuleError::EncryptionFailure(format!(
"writing chunk length to output stream: {}",
e
))
})?;
output.write_all(nonce_block).map_err(|e| {
CapsuleError::EncryptionFailure(format!("writing nonce to output stream: {}", e))
})?;
output
.write_all(&buffer[..n + AES_256_GCM_TAG_LEN])
.map_err(|e| {
CapsuleError::EncryptionFailure(format!(
"writing encrypted data to output stream: {}",
e
))
})?;
if final_chunk {
output.write_all(&0_u32.to_le_bytes()).map_err(|e| {
CapsuleError::EncryptionFailure(format!(
"writing sentinel chunk length to output stream: {}",
e
))
})?;
break;
}
chunk_num += 1;
if increment_nonce(nonce_block) {
return Err(CapsuleError::EncryptionFailure(
"nonce block exhausted".to_string(),
));
}
}
Ok(())
}
pub async fn async_streaming_decrypt_aes_256_gcm<R, W>(
key: &[u8; KEY_SIZE],
mut input: R,
mut output: W,
) -> Result<(), CapsuleError>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
let key = ring::aead::UnboundKey::new(&ring::aead::AES_256_GCM, key)
.map_err(|e| CapsuleError::Generic(format!("creating AES 256 GCM key: {}", e)))?;
let cipher = ring::aead::LessSafeKey::new(key);
let mut len_buffer = [0; std::mem::size_of::<u32>()]; let mut nonce_buffer = [0; NONCE_SIZE];
let mut chunk_buffer = [0; CHUNK_SIZE + AES_256_GCM_TAG_LEN];
let mut chunk_len: u32;
let mut chunk_num: u64 = 0;
let mut final_chunk: bool;
input
.read_exact(&mut len_buffer)
.await
.map_err(|e| CapsuleError::DecryptionFailure(format!("reading input stream: {}", e)))?;
loop {
chunk_len = u32::from_le_bytes(len_buffer);
if chunk_len == 0 {
break;
} else if chunk_len > ((CHUNK_SIZE + AES_256_GCM_TAG_LEN) as u32) {
return Err(CapsuleError::DecryptionFailure(format!(
"chunk length {} exceeds maximum chunk size {} (this is probably a bug)",
chunk_len,
CHUNK_SIZE + AES_256_GCM_TAG_LEN
)));
}
input
.read_exact(&mut nonce_buffer)
.await
.map_err(|e| CapsuleError::DecryptionFailure(format!("reading nonce: {}", e)))?;
input
.read_exact(&mut chunk_buffer[..(chunk_len as usize)])
.await
.map_err(|e| CapsuleError::DecryptionFailure(format!("reading chunk: {}", e)))?;
input
.read_exact(&mut len_buffer)
.await
.map_err(|e| CapsuleError::DecryptionFailure(format!("reading chunk length: {}", e)))?;
final_chunk = u32::from_le_bytes(len_buffer) == 0;
cipher
.open_in_place(
ring::aead::Nonce::assume_unique_for_key(nonce_buffer),
ring::aead::Aad::from(
AD {
final_chunk,
chunk_num,
}
.marshal()?
.as_slice(),
),
&mut chunk_buffer[..(chunk_len as usize)],
)
.map_err(|e| {
CapsuleError::DecryptionFailure(format!("decrypting chunk {}: {}", chunk_num, e))
})?;
let (decrypted_data, _) =
chunk_buffer.split_at_mut(chunk_len as usize - AES_256_GCM_TAG_LEN);
output
.write_all(decrypted_data)
.await
.map_err(|e| CapsuleError::DecryptionFailure(format!("writing output: {}", e)))?;
chunk_num += 1;
}
Ok(())
}
pub async fn async_streaming_encrypt_aes_256_gcm<R, W>(
key: &[u8; KEY_SIZE],
nonce_block: &mut [u8; NONCE_SIZE],
mut input: R,
mut output: W,
) -> Result<(), CapsuleError>
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
let key = ring::aead::UnboundKey::new(&ring::aead::AES_256_GCM, key)
.map_err(|e| CapsuleError::Generic(format!("creating AES 256 GCM key: {}", e)))?;
let cipher = ring::aead::LessSafeKey::new(key);
let mut chunk_num: u64 = 0;
let mut final_chunk: bool = false;
let mut next_byte = [0u8; 1];
let mut buffer = [0u8; CHUNK_SIZE + AES_256_GCM_TAG_LEN];
match input.read(&mut next_byte).await {
Ok(0) => {
return Err(CapsuleError::EncryptionFailure(
"empty plaintext".to_string(),
))
}
Ok(_) => {}
Err(e) => {
return Err(CapsuleError::EncryptionFailure(format!(
"reading input stream: {}",
e
)))
}
}
loop {
buffer[0] = next_byte[0];
let n =
input.read(&mut buffer[1..CHUNK_SIZE]).await.map_err(|e| {
CapsuleError::EncryptionFailure(format!("reading input stream: {}", e))
})? + 1;
match input.read(&mut next_byte).await {
Ok(0) => final_chunk = true,
Ok(_) => {}
Err(e) => {
return Err(CapsuleError::EncryptionFailure(format!(
"reading input stream: {}",
e
)))
}
}
let tag = cipher
.seal_in_place_separate_tag(
ring::aead::Nonce::assume_unique_for_key(*nonce_block),
ring::aead::Aad::from(
AD {
final_chunk,
chunk_num,
}
.marshal()?
.as_slice(),
),
&mut buffer[..n],
)
.map_err(|e| CapsuleError::Generic(format!("failed to seal in place: {}", e)))?;
buffer[n..n + AES_256_GCM_TAG_LEN].copy_from_slice(tag.as_ref());
output
.write_all(&((n + AES_256_GCM_TAG_LEN) as u32).to_le_bytes())
.await
.map_err(|e| {
CapsuleError::EncryptionFailure(format!(
"writing chunk length to output stream: {}",
e
))
})?;
output.write_all(nonce_block).await.map_err(|e| {
CapsuleError::EncryptionFailure(format!("writing nonce to output stream: {}", e))
})?;
output
.write_all(&buffer[..n + AES_256_GCM_TAG_LEN])
.await
.map_err(|e| {
CapsuleError::EncryptionFailure(format!(
"writing encrypted data to output stream: {}",
e
))
})?;
if final_chunk {
output.write_all(&0_u32.to_le_bytes()).await.map_err(|e| {
CapsuleError::EncryptionFailure(format!(
"writing sentinel chunk length to output stream: {}",
e
))
})?;
break;
}
chunk_num += 1;
if increment_nonce(nonce_block) {
return Err(CapsuleError::EncryptionFailure(
"nonce block exhausted".to_string(),
));
}
}
Ok(())
}
#[cfg(test)]
pub mod tests {
    use super::*;
    use rand::Rng;

    /// Plaintext sizes covering a single byte, both sides of the chunk boundary, and a
    /// multi-chunk stream with a partial final chunk.
    const SIZES: [usize; 6] = [
        1,
        17,
        CHUNK_SIZE - 1,
        CHUNK_SIZE,
        CHUNK_SIZE + 1,
        3 * CHUNK_SIZE + 5,
    ];

    /// Returns `size` random bytes.
    fn random_data(size: usize) -> Vec<u8> {
        let mut data = vec![0u8; size];
        rand::thread_rng().fill(&mut data[..]);
        data
    }

    /// Encrypts `data` with the all-zero test key through [`EncryptingAEADReader`].
    fn encrypt_with_reader(data: &[u8]) -> Vec<u8> {
        let mut encrypt = EncryptingAEADReader::new(
            [0u8; NONCE_SIZE],
            &[0u8; KEY_SIZE],
            std::io::Cursor::new(data),
        )
        .unwrap();
        let mut encrypted = Vec::new();
        encrypt.read_to_end(&mut encrypted).unwrap();
        encrypted
    }

    /// Decrypts `encrypted` with the all-zero test key through [`DecryptingAEAD`].
    fn decrypt_with_reader(encrypted: &[u8]) -> Vec<u8> {
        let encrypted_reader = Arc::new(Mutex::new(std::io::Cursor::new(encrypted)));
        let mut decrypt = DecryptingAEAD::new(&[0u8; KEY_SIZE], encrypted_reader).unwrap();
        let mut decrypted = Vec::new();
        decrypt.read_to_end(&mut decrypted).unwrap();
        decrypted
    }

    #[test]
    fn test_encrypt_decrypt_readers() {
        for &size in SIZES.iter() {
            let data = random_data(size);
            let decrypted = decrypt_with_reader(&encrypt_with_reader(&data));
            assert!(
                decrypted == data,
                "reader round trip mismatch for size {}",
                size
            );
        }
    }

    #[test]
    fn test_encrypt_writer_decrypt_reader() {
        for size in std::iter::once(0).chain(SIZES.iter().copied()) {
            let data = random_data(size);
            let mut encrypted: Vec<u8> = Vec::new();
            {
                let mut encrypt = EncryptingAEADWriter::new(
                    [0u8; NONCE_SIZE],
                    &[0u8; KEY_SIZE],
                    Arc::new(Mutex::new(&mut encrypted)),
                )
                .unwrap();
                encrypt
                    .write_all(&data)
                    .expect("failed to write data to the encrypting writer");
                encrypt.flush().expect("failed to flush writer");
            }
            let decrypted = decrypt_with_reader(&encrypted);
            assert!(
                decrypted == data,
                "writer round trip mismatch for size {}",
                size
            );
        }
    }

    #[test]
    fn streaming_encrypt_and_decrypt() {
        for &size in SIZES.iter() {
            let data = random_data(size);
            let mut encrypt_output = Vec::<u8>::new();
            let mut decrypt_output = Vec::<u8>::new();
            if let Err(e) = streaming_encrypt_aes_256_gcm(
                &[0u8; KEY_SIZE],
                &mut [0u8; NONCE_SIZE],
                std::io::Cursor::new(&data),
                &mut encrypt_output,
            ) {
                panic!("error encrypting size {}: {}", size, e);
            }
            if let Err(e) = streaming_decrypt_aes_256_gcm(
                &[0u8; KEY_SIZE],
                std::io::Cursor::new(&encrypt_output),
                &mut decrypt_output,
            ) {
                panic!("error decrypting size {}: {}", size, e);
            }
            assert!(
                data == decrypt_output,
                "decrypted data doesn't match for size {}",
                size
            );
        }
    }

    #[test]
    fn reader_and_streaming_formats_interoperate() {
        let data = random_data(CHUNK_SIZE + 1);

        let mut streamed = Vec::<u8>::new();
        if let Err(e) = streaming_decrypt_aes_256_gcm(
            &[0u8; KEY_SIZE],
            std::io::Cursor::new(encrypt_with_reader(&data)),
            &mut streamed,
        ) {
            panic!("error decrypting reader output: {}", e);
        }
        assert!(streamed == data, "streaming decrypt of reader output differs");

        let mut encrypted = Vec::<u8>::new();
        if let Err(e) = streaming_encrypt_aes_256_gcm(
            &[0u8; KEY_SIZE],
            &mut [0u8; NONCE_SIZE],
            std::io::Cursor::new(&data),
            &mut encrypted,
        ) {
            panic!("error encrypting: {}", e);
        }
        assert!(
            decrypt_with_reader(&encrypted) == data,
            "reader decrypt of streaming output differs"
        );
    }
}