use std::cmp::min;
use std::io;
use std::io::prelude::*;
use std::io::SeekFrom;
#[cfg(test)]
mod test;
/// Length in bytes of a long-term key.
pub const KEY_LEN: usize = 32;
/// Plaintext bytes carried by every non-final chunk (16 KiB).
pub const CHUNK_LEN: usize = 1 << 14;
/// Length in bytes of the nonce stored at the front of every ciphertext.
pub const NONCE_LEN: usize = 24;
/// Length in bytes of the authentication tag appended to every chunk.
pub const TAG_LEN: usize = 32;

type Key = [u8; KEY_LEN];
type Nonce = [u8; NONCE_LEN];
/// Error returned when a ciphertext cannot be decrypted.
///
/// Converts into an `io::Error` of kind `InvalidData`.
#[derive(Debug)]
pub struct Error {
    msg: &'static str,
}

impl Error {
    const fn with_message(msg: &'static str) -> Self {
        Error { msg }
    }

    /// The ciphertext is shorter than the format requires.
    fn truncated() -> Self {
        Self::with_message("ciphertext has been truncated")
    }

    /// A chunk failed authentication.
    fn corrupt() -> Self {
        Self::with_message("ciphertext is corrupt")
    }
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str(self.msg)
    }
}

impl From<Error> for io::Error {
    fn from(err: Error) -> io::Error {
        io::Error::new(io::ErrorKind::InvalidData, err.msg)
    }
}

impl std::error::Error for Error {}
/// Whether a chunk is the last one in the stream.
///
/// The discriminant byte is fed into each chunk's key derivation, so these
/// explicit values are part of the ciphertext format and must not change.
#[derive(Copy, Clone, Debug)]
#[repr(u8)]
enum FinalFlag {
    /// A full `CHUNK_LEN` chunk that is not the last one.
    NotFinal = 0,
    /// The last chunk, shorter than `CHUNK_LEN` (possibly empty).
    Final = 1,
}
/// Derives the per-chunk `(auth_key, stream_key)` pair.
///
/// The input is `nonce || chunk_index (u64, little-endian) || final_flag (1 byte)`.
/// It is hashed with BLAKE3 keyed by `long_term_key` and read out as `2 * KEY_LEN`
/// bytes of extended output. The first `KEY_LEN` bytes become the authentication
/// key; the next `KEY_LEN` bytes become the stream key.
fn chunk_keys(
    long_term_key: &Key,
    nonce: &Nonce,
    chunk_index: u64,
    final_flag: FinalFlag,
) -> (Key, Key) {
    const INDEX_START: usize = NONCE_LEN;
    const FLAG_POS: usize = INDEX_START + 8;

    let mut message = [0u8; FLAG_POS + 1];
    message[..INDEX_START].copy_from_slice(nonce);
    message[INDEX_START..FLAG_POS].copy_from_slice(&chunk_index.to_le_bytes());
    message[FLAG_POS] = final_flag as u8;

    let mut hasher = blake3::Hasher::new_keyed(long_term_key);
    hasher.update(&message);
    let mut xof = hasher.finalize_xof();

    // `fill` advances the reader, so two consecutive fills read the same 64-byte
    // stream as a single fill of both halves.
    let mut auth_key: Key = [0; KEY_LEN];
    let mut stream_key: Key = [0; KEY_LEN];
    xof.fill(&mut auth_key);
    xof.fill(&mut stream_key);
    (auth_key, stream_key)
}
fn xor_stream(mut stream_reader: blake3::OutputReader, input: &[u8], output: &mut [u8]) {
assert_eq!(input.len(), output.len());
let mut position = 0;
const STREAM_ARRAY_LEN: usize = 16 * 64;
let mut stream_array = [0; STREAM_ARRAY_LEN];
while position < input.len() {
let take = min(STREAM_ARRAY_LEN, input.len() - position);
let stream_slice = &mut stream_array[..take];
stream_reader.fill(stream_slice);
for _ in 0..take {
output[position] = input[position] ^ stream_slice[position % STREAM_ARRAY_LEN];
position += 1;
}
}
}
/// Encrypts and authenticates one chunk.
///
/// The tag is a keyed BLAKE3 hash of `plaintext` under the chunk's auth key. The
/// keystream is the keyed BLAKE3 extended output of that tag under the chunk's stream
/// key. `ciphertext` receives `(plaintext XOR keystream) || tag`, so it must be exactly
/// `plaintext.len() + TAG_LEN` bytes long.
fn encrypt_chunk(
    long_term_key: &Key,
    nonce: &Nonce,
    chunk_index: u64,
    final_flag: FinalFlag,
    plaintext: &[u8],
    ciphertext: &mut [u8],
) {
    debug_assert!(plaintext.len() <= CHUNK_LEN);
    debug_assert_eq!(plaintext.len() + TAG_LEN, ciphertext.len());
    match final_flag {
        FinalFlag::NotFinal => debug_assert_eq!(plaintext.len(), CHUNK_LEN),
        FinalFlag::Final => debug_assert!(plaintext.len() < CHUNK_LEN),
    }

    let (auth_key, stream_key) = chunk_keys(long_term_key, nonce, chunk_index, final_flag);
    let tag = blake3::keyed_hash(&auth_key, plaintext);
    let tag_bytes = tag.as_bytes();

    let mut keystream = blake3::Hasher::new_keyed(&stream_key);
    keystream.update(tag_bytes);

    let body_len = plaintext.len();
    xor_stream(keystream.finalize_xof(), plaintext, &mut ciphertext[..body_len]);
    ciphertext[body_len..].copy_from_slice(tag_bytes);
}
fn decrypt_chunk(
long_term_key: &Key,
nonce: &Nonce,
chunk_index: u64,
final_flag: FinalFlag,
ciphertext: &[u8],
plaintext: &mut [u8],
) -> Result<(), Error> {
if ciphertext.len() < TAG_LEN {
return Err(Error::truncated());
}
debug_assert!(plaintext.len() <= CHUNK_LEN);
debug_assert_eq!(plaintext.len() + TAG_LEN, ciphertext.len());
match final_flag {
FinalFlag::NotFinal => debug_assert_eq!(plaintext.len(), CHUNK_LEN),
FinalFlag::Final => debug_assert!(plaintext.len() < CHUNK_LEN),
}
let (auth_key, stream_key) = chunk_keys(long_term_key, nonce, chunk_index, final_flag);
let tag_bytes: &[u8; TAG_LEN] = ciphertext[ciphertext.len() - TAG_LEN..].try_into().unwrap();
let stream_reader = blake3::Hasher::new_keyed(&stream_key)
.update(tag_bytes)
.finalize_xof();
xor_stream(stream_reader, &ciphertext[..plaintext.len()], plaintext);
let computed_tag: blake3::Hash = blake3::keyed_hash(&auth_key, plaintext);
if &computed_tag != tag_bytes {
plaintext.fill(0);
return Err(Error::corrupt());
}
Ok(())
}
/// Returns the ciphertext length for a plaintext of `plaintext_len` bytes, or `None` if
/// it would not fit in a `u64`.
///
/// The ciphertext is the nonce, followed by the plaintext, plus one tag per chunk.
/// There is always one more chunk than the number of full `CHUNK_LEN` blocks, because
/// the final chunk is shorter than `CHUNK_LEN` and may be empty.
pub fn ciphertext_len(plaintext_len: u64) -> Option<u64> {
    let tag_count = plaintext_len / CHUNK_LEN as u64 + 1;
    // Cannot overflow: tag_count is at most 2^50 + 1, so the product stays well below 2^64.
    let overhead = NONCE_LEN as u64 + tag_count * TAG_LEN as u64;
    plaintext_len.checked_add(overhead)
}
/// Returns the plaintext length encoded by a ciphertext of `ciphertext_len` bytes.
///
/// Returns `None` when no valid ciphertext has that length: it is shorter than the
/// nonce, or its trailing partial chunk is too short to hold a tag.
pub fn plaintext_len(ciphertext_len: u64) -> Option<u64> {
    const FULL_CHUNK: u64 = (CHUNK_LEN + TAG_LEN) as u64;
    let body_len = ciphertext_len.checked_sub(NONCE_LEN as u64)?;
    let (full_chunks, tail) = (body_len / FULL_CHUNK, body_len % FULL_CHUNK);
    // The final chunk always carries a tag, even when its plaintext is empty.
    let tail_plaintext = tail.checked_sub(TAG_LEN as u64)?;
    Some(full_chunks * CHUNK_LEN as u64 + tail_plaintext)
}
/// Returns a new random long-term key drawn from `rand::random`.
pub fn generate_key() -> Key {
    rand::random::<Key>()
}
/// Returns a new random nonce drawn from `rand::random`.
fn generate_nonce() -> Nonce {
    rand::random::<Nonce>()
}
/// Encrypts `plaintext` under `key` with a fresh random nonce and returns the full
/// ciphertext, which is `ciphertext_len(plaintext.len())` bytes long.
///
/// # Panics
///
/// Panics if the ciphertext length does not fit in a `u64` or a `usize`.
pub fn encrypt(key: &Key, plaintext: &[u8]) -> Vec<u8> {
    let total = ciphertext_len(plaintext.len() as u64).expect("length overflows a u64");
    let total = usize::try_from(total).expect("length overflows a usize");
    let mut output = vec![0; total];
    encrypt_to_slice(key, plaintext, &mut output);
    output
}
/// Extra entry points that take an explicit nonce instead of generating one.
pub mod testing {
    use super::*;

    /// Like `encrypt`, but uses the caller-supplied `nonce` instead of a random one.
    ///
    /// # Panics
    ///
    /// Panics if the ciphertext length does not fit in a `u64` or a `usize`.
    pub fn encrypt_with_nonce(key: &Key, nonce: &Nonce, plaintext: &[u8]) -> Vec<u8> {
        let total = ciphertext_len(plaintext.len() as u64).expect("length overflows a u64");
        let total = usize::try_from(total).expect("length overflows a usize");
        let mut output = vec![0; total];
        encrypt_to_slice_with_nonce(key, nonce, plaintext, &mut output);
        output
    }
}
/// Encrypts `plaintext` into `ciphertext` under `key`, using a fresh random nonce.
///
/// `ciphertext` must be exactly `ciphertext_len(plaintext.len())` bytes long.
pub fn encrypt_to_slice(key: &Key, plaintext: &[u8], ciphertext: &mut [u8]) {
    encrypt_to_slice_with_nonce(key, &generate_nonce(), plaintext, ciphertext);
}
fn encrypt_to_slice_with_nonce(key: &Key, nonce: &Nonce, plaintext: &[u8], ciphertext: &mut [u8]) {
ciphertext[..NONCE_LEN].copy_from_slice(nonce);
let mut chunk_index = 0;
let mut plaintext_chunks = plaintext.chunks_exact(CHUNK_LEN);
let mut ciphertext_chunks = ciphertext[NONCE_LEN..].chunks_exact_mut(CHUNK_LEN + TAG_LEN);
for (plaintext_chunk, ciphertext_chunk) in plaintext_chunks.by_ref().zip(&mut ciphertext_chunks)
{
encrypt_chunk(
key,
&nonce,
chunk_index,
FinalFlag::NotFinal,
plaintext_chunk,
ciphertext_chunk,
);
chunk_index += 1;
}
encrypt_chunk(
key,
&nonce,
chunk_index,
FinalFlag::Final,
plaintext_chunks.remainder(),
ciphertext_chunks.into_remainder(),
);
}
pub fn decrypt(key: &Key, ciphertext: &[u8]) -> Result<Vec<u8>, Error> {
let plaintext_len = if let Some(len) = plaintext_len(ciphertext.len() as u64) {
len as usize
} else {
return Err(Error::truncated());
};
let mut plaintext = vec![0; plaintext_len];
decrypt_to_slice(key, ciphertext, &mut plaintext)?;
Ok(plaintext)
}
pub fn decrypt_to_slice(key: &Key, ciphertext: &[u8], plaintext: &mut [u8]) -> Result<(), Error> {
let nonce: &Nonce = &ciphertext[..NONCE_LEN].try_into().unwrap();
let mut chunk_index = 0;
let mut ciphertext_chunks = ciphertext[NONCE_LEN..].chunks_exact(CHUNK_LEN + TAG_LEN);
let mut plaintext_chunks = plaintext.chunks_exact_mut(CHUNK_LEN);
for (ciphertext_chunk, plaintext_chunk) in ciphertext_chunks.by_ref().zip(&mut plaintext_chunks)
{
let chunk_result = decrypt_chunk(
key,
&nonce,
chunk_index,
FinalFlag::NotFinal,
ciphertext_chunk,
plaintext_chunk,
);
if let Err(e) = chunk_result {
plaintext.fill(0);
return Err(e);
}
chunk_index += 1;
}
let chunk_result = decrypt_chunk(
key,
&nonce,
chunk_index,
FinalFlag::Final,
ciphertext_chunks.remainder(),
plaintext_chunks.into_remainder(),
);
if let Err(e) = chunk_result {
plaintext.fill(0);
return Err(e);
}
Ok(())
}
/// Encrypts plaintext written to it, one chunk at a time, and writes the ciphertext to
/// an inner writer.
pub struct EncryptWriter<W: Write> {
    inner_writer: W,
    long_term_key: Key,
    nonce: Nonce,
    /// Index of the next chunk to be encrypted.
    chunk_index: u64,
    /// Plaintext collected for the chunk currently being filled.
    plaintext_buf: [u8; CHUNK_LEN],
    /// Number of valid bytes in `plaintext_buf`.
    plaintext_buf_len: usize,
    /// Raised while a chunk is being written and left raised if that fails; while it
    /// is set, further writes are refused.
    did_error: bool,
}
impl<W: Write> EncryptWriter<W> {
    /// Creates a writer that encrypts under `key` and writes ciphertext to `inner_writer`.
    ///
    /// A random nonce is generated here. It is written to `inner_writer` immediately
    /// before the first chunk, so nothing reaches `inner_writer` until either a full
    /// chunk has been buffered or `finalize` is called.
    pub fn new(key: &Key, inner_writer: W) -> Self {
        Self {
            inner_writer,
            long_term_key: *key,
            nonce: generate_nonce(),
            chunk_index: 0,
            plaintext_buf: [0; CHUNK_LEN],
            plaintext_buf_len: 0,
            did_error: false,
        }
    }

    /// Fails if an earlier chunk write failed; after that the writer is unusable.
    fn bail_if_errored_before(&self) -> io::Result<()> {
        if self.did_error {
            Err(io::Error::new(
                io::ErrorKind::Other,
                "already encountered an error",
            ))
        } else {
            Ok(())
        }
    }

    /// Encrypts the buffered plaintext as the next chunk and writes it to the inner writer
    /// (preceded by the nonce for the first chunk), then empties the buffer.
    ///
    /// `did_error` is raised before any output is attempted and lowered only after the
    /// whole chunk has been written. A failure part-way through, which may leave a
    /// partial chunk in the inner writer, therefore poisons this writer permanently.
    fn encrypt_and_write_buf(&mut self, final_flag: FinalFlag) -> io::Result<()> {
        debug_assert!(!self.did_error);
        self.did_error = true;
        if self.chunk_index == 0 {
            self.inner_writer.write_all(&self.nonce)?;
        }
        let mut ciphertext_array = [0; CHUNK_LEN + TAG_LEN];
        let ciphertext_slice = &mut ciphertext_array[..self.plaintext_buf_len + TAG_LEN];
        encrypt_chunk(
            &self.long_term_key,
            &self.nonce,
            self.chunk_index,
            final_flag,
            &self.plaintext_buf[..self.plaintext_buf_len],
            ciphertext_slice,
        );
        // Refuse to wrap the chunk counter back to an already-used index.
        assert!(self.chunk_index < u64::MAX, "chunk index overflow");
        self.chunk_index += 1;
        self.plaintext_buf_len = 0;
        self.inner_writer.write_all(ciphertext_slice)?;
        self.did_error = false;
        // The old code returned the just-zeroed buffer length (always 0); there is no
        // meaningful count to report here.
        Ok(())
    }

    /// Encrypts whatever is buffered as the final (possibly empty) chunk and writes it out.
    ///
    /// The ciphertext is incomplete until this succeeds. Call it once, after the last
    /// write. This type does not itself prevent writes after finalizing.
    pub fn finalize(&mut self) -> io::Result<()> {
        self.bail_if_errored_before()?;
        self.encrypt_and_write_buf(FinalFlag::Final)?;
        Ok(())
    }

    /// Returns the inner writer. This does not call `finalize`.
    pub fn into_inner(self) -> W {
        self.inner_writer
    }
}
impl<W: Write> Write for EncryptWriter<W> {
fn write(&mut self, plaintext: &[u8]) -> io::Result<usize> {
self.bail_if_errored_before()?;
let want = CHUNK_LEN - self.plaintext_buf_len;
let take = min(want, plaintext.len());
self.plaintext_buf[self.plaintext_buf_len..][..take].copy_from_slice(&plaintext[..take]);
self.plaintext_buf_len += take;
if self.plaintext_buf_len == CHUNK_LEN {
self.encrypt_and_write_buf(FinalFlag::NotFinal)?;
}
Ok(take)
}
fn flush(&mut self) -> io::Result<()> {
self.inner_writer.flush()
}
}
impl<W: Write> std::fmt::Debug for EncryptWriter<W> {
    /// Prints only the type name, so the key and any buffered plaintext stay out of
    /// debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let builder = &mut f.debug_struct("EncryptWriter");
        builder.finish()
    }
}
/// Decrypts a ciphertext stream read from an inner reader.
///
/// Implements `Read`, and `Seek` when the inner reader is seekable.
#[derive(Clone)]
pub struct DecryptReader<R: Read> {
    inner_reader: R,
    long_term_key: Key,
    /// Read from the inner reader on first use.
    nonce: Option<Nonce>,
    /// Plaintext of the most recently decrypted chunk.
    plaintext_buf: [u8; CHUNK_LEN],
    /// Read cursor within `plaintext_buf`.
    plaintext_buf_pos: u16,
    /// Number of valid bytes in `plaintext_buf` (at most `CHUNK_LEN`).
    plaintext_buf_len: u16,
    /// Plaintext offset just past the end of the buffered chunk.
    plaintext_buf_end_offset: u64,
    /// Whether the buffered chunk is the final chunk.
    at_eof: bool,
    /// Total plaintext length, known once the final chunk has been authenticated.
    authenticated_plaintext_length: Option<u64>,
}
impl<R: Read> DecryptReader<R> {
pub fn new(key: &Key, inner_reader: R) -> Self {
Self {
inner_reader,
long_term_key: *key,
nonce: None,
plaintext_buf: [0; CHUNK_LEN],
plaintext_buf_pos: 0,
plaintext_buf_len: 0,
plaintext_buf_end_offset: 0,
at_eof: false,
authenticated_plaintext_length: None,
}
}
pub fn into_inner(self) -> R {
self.inner_reader
}
fn get_nonce(&mut self) -> io::Result<Nonce> {
match self.nonce {
Some(nonce) => Ok(nonce),
None => {
let mut nonce = [0; NONCE_LEN];
self.inner_reader.read_exact(&mut nonce)?;
self.nonce = Some(nonce);
Ok(nonce)
}
}
}
pub fn position(&self) -> u64 {
debug_assert!(self.plaintext_buf_pos <= self.plaintext_buf_len);
debug_assert!(self.plaintext_buf_len as u64 <= self.plaintext_buf_end_offset);
self.plaintext_buf_end_offset - self.plaintext_buf_len as u64
+ self.plaintext_buf_pos as u64
}
}
/// Reads into `buf` until it is full or the reader reports end-of-file.
///
/// Reads that fail with `ErrorKind::Interrupted` are retried. Unlike
/// `Read::read_exact`, reaching EOF early is not an error: the returned slice is the
/// filled prefix of `buf`, and it is shorter than `buf` only when EOF was hit.
/// Any other error is returned immediately.
fn read_exact_or_eof<'buf>(
    reader: &mut impl Read,
    buf: &'buf mut [u8],
) -> io::Result<&'buf mut [u8]> {
    let mut filled = 0;
    while filled < buf.len() {
        match reader.read(&mut buf[filled..]) {
            Ok(0) => break,
            Ok(n) => filled += n,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }
    Ok(&mut buf[..filled])
}
impl<R: Read> DecryptReader<R> {
    /// Reads one chunk of ciphertext from the inner reader's current position, decrypts
    /// and authenticates it as the chunk whose plaintext starts at `next_chunk_start_offset`,
    /// and loads the plaintext into the buffer.
    ///
    /// The chunk is treated as final exactly when fewer than `CHUNK_LEN + TAG_LEN` bytes
    /// remain before EOF. The final flag and the chunk index both feed the key derivation,
    /// so a chunk read under the wrong flag or index does not authenticate.
    ///
    /// On error the buffer is left empty, not at EOF, and positioned at
    /// `next_chunk_start_offset`.
    fn read_and_decrypt_next_chunk(&mut self, next_chunk_start_offset: u64) -> io::Result<()> {
        debug_assert_eq!(next_chunk_start_offset % CHUNK_LEN as u64, 0);
        let nonce = self.get_nonce()?;
        // Reset all buffer state before anything fallible. If `at_eof` survived a failed
        // read (e.g. a seek back from EOF that hit an I/O or authentication error), the
        // next `read` would see an empty buffer at EOF and report a spurious Ok(0).
        self.plaintext_buf_pos = 0;
        self.plaintext_buf_len = 0;
        self.plaintext_buf_end_offset = next_chunk_start_offset;
        self.at_eof = false;

        let mut ciphertext_array = [0; CHUNK_LEN + TAG_LEN];
        let chunk_ciphertext = read_exact_or_eof(&mut self.inner_reader, &mut ciphertext_array)?;
        // Every chunk, including an empty final one, carries a full tag.
        if chunk_ciphertext.len() < TAG_LEN {
            return Err(Error::truncated().into());
        }
        let next_chunk_index = next_chunk_start_offset / CHUNK_LEN as u64;
        let final_flag = if chunk_ciphertext.len() == CHUNK_LEN + TAG_LEN {
            FinalFlag::NotFinal
        } else {
            FinalFlag::Final
        };
        let chunk_plaintext = &mut self.plaintext_buf[..chunk_ciphertext.len() - TAG_LEN];
        decrypt_chunk(
            &self.long_term_key,
            &nonce,
            next_chunk_index,
            final_flag,
            chunk_ciphertext,
            chunk_plaintext,
        )?;
        self.plaintext_buf_end_offset = next_chunk_start_offset
            .checked_add(chunk_plaintext.len() as u64)
            .expect("position overflow");
        // Fits in u16: a chunk's plaintext is at most CHUNK_LEN (16384) bytes.
        self.plaintext_buf_len = chunk_plaintext.len() as u16;
        self.at_eof = matches!(final_flag, FinalFlag::Final);
        if self.at_eof {
            self.authenticated_plaintext_length = Some(self.plaintext_buf_end_offset);
        }
        Ok(())
    }
}
impl<R: Read> Read for DecryptReader<R> {
    /// Copies buffered plaintext into `buf`.
    ///
    /// If the buffer is exhausted and the final chunk has not been reached, the next
    /// chunk is decrypted first. At most one chunk's worth of plaintext is returned per
    /// call, and `Ok(0)` is returned once the final chunk has been fully consumed.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let drained = self.plaintext_buf_pos == self.plaintext_buf_len;
        if drained && !self.at_eof {
            self.read_and_decrypt_next_chunk(self.plaintext_buf_end_offset)?;
        }
        let available = usize::from(self.plaintext_buf_len - self.plaintext_buf_pos);
        let count = min(buf.len(), available);
        let start = usize::from(self.plaintext_buf_pos);
        buf[..count].copy_from_slice(&self.plaintext_buf[start..start + count]);
        self.plaintext_buf_pos += count as u16;
        Ok(count)
    }
}
impl<R: Read + Seek> DecryptReader<R> {
fn get_authenticated_plaintext_length(&mut self) -> io::Result<u64> {
if let Some(len) = self.authenticated_plaintext_length {
return Ok(len);
}
self.get_nonce()?;
let apparent_ciphertext_length = self.inner_reader.seek(SeekFrom::End(0))?;
let apparent_plaintext_length = plaintext_len(apparent_ciphertext_length)
.ok_or_else(|| io::Error::from(Error::truncated()))?;
let apparent_last_chunk_ciphertext_length =
(apparent_ciphertext_length - NONCE_LEN as u64) % (CHUNK_LEN + TAG_LEN) as u64;
let apparent_last_chunk_plaintext_length = apparent_plaintext_length % CHUNK_LEN as u64;
let apparent_last_chunk_ciphertext_start =
apparent_ciphertext_length - apparent_last_chunk_ciphertext_length;
let apparent_last_chunk_plaintext_start =
apparent_plaintext_length - apparent_last_chunk_plaintext_length;
self.inner_reader
.seek(SeekFrom::Start(apparent_last_chunk_ciphertext_start))?;
self.read_and_decrypt_next_chunk(apparent_last_chunk_plaintext_start)?;
if let Some(len) = self.authenticated_plaintext_length {
Ok(len)
} else {
Err(Error::truncated().into())
}
}
}
impl<R: Read + Seek> Seek for DecryptReader<R> {
    /// Seeks to a plaintext offset.
    ///
    /// The plaintext length is authenticated first, by decrypting the final chunk if that
    /// has not happened yet. Targets past the end are clamped to the end. Unless the target
    /// lies within the chunk already buffered, the chunk containing it is read and
    /// authenticated before returning.
    ///
    /// # Errors
    ///
    /// - `ErrorKind::InvalidInput` if the target would be negative. Previously this
    ///   panicked.
    /// - I/O errors from the inner reader.
    /// - Decryption failures, converted via `From<Error> for io::Error`.
    ///
    /// On error the buffered chunk, and so `position()`, may have changed.
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        self.get_nonce()?;
        let starting_position = self.position();
        let plaintext_len = self.get_authenticated_plaintext_length()?;

        // Any u64 + i64 sum fits in an i128, so this arithmetic cannot overflow.
        let requested: i128 = match pos {
            SeekFrom::Start(n) => i128::from(n),
            SeekFrom::Current(n) => i128::from(starting_position) + i128::from(n),
            SeekFrom::End(n) => i128::from(plaintext_len) + i128::from(n),
        };
        if requested < 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid seek to a negative position",
            ));
        }
        // Past-the-end targets (including those beyond u64::MAX) clamp to the end.
        let target = u64::try_from(requested.min(i128::from(plaintext_len)))
            .expect("target is clamped into 0..=plaintext_len");

        // Fast path: the target lies within the chunk that is already buffered.
        if target <= self.plaintext_buf_end_offset {
            let remaining = self.plaintext_buf_end_offset - target;
            if remaining <= self.plaintext_buf_len as u64 {
                self.plaintext_buf_pos = self.plaintext_buf_len - remaining as u16;
                debug_assert_eq!(target, self.position());
                return Ok(target);
            }
        }

        // Slow path: seek the inner reader to the target chunk and decrypt it.
        let target_chunk_index = target / CHUNK_LEN as u64;
        let target_position_within_chunk = (target % CHUNK_LEN as u64) as u16;
        let target_chunk_start = target - target_position_within_chunk as u64;
        let target_ciphertext_chunk_start = ((CHUNK_LEN + TAG_LEN) as u64)
            .checked_mul(target_chunk_index)
            .and_then(|s| s.checked_add(NONCE_LEN as u64))
            .expect("ciphertext target overflow");
        self.inner_reader
            .seek(SeekFrom::Start(target_ciphertext_chunk_start))?;
        self.read_and_decrypt_next_chunk(target_chunk_start)?;
        // Presumably only reachable if the inner data changed after the length was
        // authenticated.
        if self.plaintext_buf_len < target_position_within_chunk {
            return Err(Error::truncated().into());
        }
        self.plaintext_buf_pos = target_position_within_chunk;
        debug_assert_eq!(target, self.position());
        Ok(target)
    }

    /// Returns the current plaintext offset without touching the inner reader.
    fn stream_position(&mut self) -> io::Result<u64> {
        Ok(self.position())
    }
}
impl<R: Read> std::fmt::Debug for DecryptReader<R> {
    /// Prints only the type name, so the key and any buffered plaintext stay out of
    /// debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let builder = &mut f.debug_struct("DecryptReader");
        builder.finish()
    }
}