use crate::crypto::{
AEAD_TAG_LEN, create_aes256gcmsiv_cipher, create_xchacha20poly1305_cipher, generate_salt,
};
use crate::file::default_out_path;
use crate::format::{DiskHeader, StreamInfo};
use crate::kdf::derive_key_argon2id;
use crate::types::{AeadAlg, EncFileError, EncryptOptions};
use aead::Aead;
use chacha20poly1305::aead::generic_array::{GenericArray, typenum::U19};
use chacha20poly1305::aead::stream::{DecryptorBE32, EncryptorBE32};
use getrandom::fill as getrandom;
use secrecy::SecretString;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use tempfile::NamedTempFile;
use zeroize::{Zeroize, Zeroizing};
/// Bit set in a frame's leading flags byte when the frame is the last one of the stream.
const FLAG_FINAL: u8 = 1;
/// Checks that `chunk_size` is usable as a streaming plaintext chunk size.
///
/// Every frame records its ciphertext length (the chunk plus `AEAD_TAG_LEN`
/// tag bytes) as a `u32`, so the chunk must leave room for the tag below
/// `u32::MAX`.
///
/// # Errors
/// Returns [`EncFileError::Invalid`] when `chunk_size` is zero or too large to
/// fit a 32-bit frame length.
pub fn validate_chunk_size_for_streaming(chunk_size: usize) -> Result<(), EncFileError> {
    let frame_limit = (u32::MAX as usize).saturating_sub(AEAD_TAG_LEN);
    match chunk_size {
        0 => Err(EncFileError::Invalid("streaming chunk_size must be > 0")),
        size if size > frame_limit => Err(EncFileError::Invalid(
            "chunk_size too large for 32-bit frame",
        )),
        _ => Ok(()),
    }
}
fn calculate_optimal_chunk_size(
user_chunk_size: usize,
file_size_hint: Option<u64>,
) -> Result<usize, EncFileError> {
if user_chunk_size != 0 {
validate_chunk_size_for_streaming(user_chunk_size)?;
return Ok(user_chunk_size);
}
let optimal_size = if let Some(file_size) = file_size_hint {
match file_size {
0..=1_048_576 => crate::types::MIN_CHUNK_SIZE, 1_048_577..=104_857_600 => crate::types::DEFAULT_CHUNK_SIZE, _ => {
(file_size / 1000).clamp(
crate::types::DEFAULT_CHUNK_SIZE as u64,
crate::types::MAX_CHUNK_SIZE as u64,
) as usize
}
}
} else {
crate::types::DEFAULT_CHUNK_SIZE
};
validate_chunk_size_for_streaming(optimal_size)?;
Ok(optimal_size)
}
/// Validates a `chunk_size` value read from a stream header.
///
/// Header values are held to the same limits as chunk sizes chosen when
/// encrypting; see [`validate_chunk_size_for_streaming`].
fn validate_header_chunk_size(chunk_size: usize) -> Result<(), EncFileError> {
    validate_chunk_size_for_streaming(chunk_size)?;
    Ok(())
}
fn write_frame<W: Write>(mut w: W, ct: &[u8], is_final: bool) -> Result<(), EncFileError> {
let flags = if is_final { FLAG_FINAL } else { 0 };
w.write_all(&[flags])?;
w.write_all(&(ct.len() as u32).to_be_bytes())?;
w.write_all(ct)?;
Ok(())
}
pub fn encrypt_file_streaming(
input: &Path,
output: Option<&Path>,
password: SecretString,
mut opts: EncryptOptions,
) -> Result<PathBuf, EncFileError> {
let file_metadata = std::fs::metadata(input).ok();
let file_size_hint = file_metadata.map(|m| m.len());
if let Some(file_size) = file_size_hint {
crate::crypto::validate_file_size(file_size)?;
}
let eff_chunk_size = calculate_optimal_chunk_size(opts.chunk_size, file_size_hint)?;
if !opts.stream {
opts.stream = true;
}
let out_path = default_out_path(input, output, "enc");
if out_path.exists() && !opts.force {
return Err(EncFileError::Invalid(
"output exists; use --force to overwrite",
));
}
let salt = generate_salt()?;
let key = derive_key_argon2id(&password, opts.kdf_params, &salt)?;
let mut file_id = vec![0u8; 16];
getrandom(&mut file_id).map_err(|_| EncFileError::Crypto)?;
let stream_info = match opts.alg {
AeadAlg::XChaCha20Poly1305 => {
let mut prefix = vec![0u8; 19];
getrandom(&mut prefix).map_err(|_| EncFileError::Crypto)?;
StreamInfo {
chunk_size: eff_chunk_size as u32,
nonce_prefix: prefix,
file_id: Some(file_id),
}
}
AeadAlg::Aes256GcmSiv => {
let mut prefix = vec![0u8; 8];
getrandom(&mut prefix).map_err(|_| EncFileError::Crypto)?;
StreamInfo {
chunk_size: eff_chunk_size as u32,
nonce_prefix: prefix,
file_id: Some(file_id),
}
}
};
let header = DiskHeader::new_stream(opts.alg, opts.kdf, opts.kdf_params, salt, stream_info);
let mut header_bytes = Vec::new();
ciborium::ser::into_writer(&header, &mut header_bytes)?;
let tmp = NamedTempFile::new_in(
out_path
.parent()
.ok_or(EncFileError::Invalid("output path has no parent"))?,
)?;
let mut writer = BufWriter::with_capacity(64 * 1024, tmp);
writer.write_all(&(header_bytes.len() as u32).to_le_bytes())?;
writer.write_all(&header_bytes)?;
let file = File::open(input)?;
let buffer_size = (eff_chunk_size / 4).clamp(64 * 1024, 512 * 1024); let mut reader = BufReader::with_capacity(buffer_size, file);
let mut buf = vec![0u8; eff_chunk_size];
match opts.alg {
AeadAlg::XChaCha20Poly1305 => {
let cipher = create_xchacha20poly1305_cipher(&key)?;
let stream_info = match &header.stream {
Some(s) => s,
None => return Err(EncFileError::Invalid("missing stream info")),
};
let nonce_prefix = GenericArray::<u8, U19>::from_slice(&stream_info.nonce_prefix);
let mut enc = EncryptorBE32::from_aead(cipher, nonce_prefix);
loop {
let n = reader.read(&mut buf)?;
if n == 0 {
break;
}
let pt = &buf[..n];
let ct = enc.encrypt_next(pt).map_err(|_| EncFileError::Crypto)?;
write_frame(&mut writer, &ct, false)?;
buf[..n].zeroize();
}
let ct_final = enc
.encrypt_last(&[] as &[u8])
.map_err(|_| EncFileError::Crypto)?;
write_frame(&mut writer, &ct_final, true)?;
buf.zeroize();
}
AeadAlg::Aes256GcmSiv => {
let cipher = create_aes256gcmsiv_cipher(&key)?;
let stream = match header.stream.as_ref() {
Some(s) => s,
None => return Err(EncFileError::Invalid("missing stream info")),
};
let prefix = &stream.nonce_prefix;
let mut counter = 0u32;
let mut nonce_bytes = Vec::with_capacity(12);
loop {
let n = reader.read(&mut buf)?;
let is_final = n == 0 || n < eff_chunk_size;
if counter == u32::MAX && !is_final {
return Err(EncFileError::Invalid("too many frames for 32-bit counter"));
}
nonce_bytes.clear();
nonce_bytes.extend_from_slice(prefix);
nonce_bytes.extend_from_slice(&counter.to_be_bytes());
counter = counter.wrapping_add(1);
let pt = &buf[..n];
let ct = cipher
.encrypt(GenericArray::from_slice(&nonce_bytes), pt)
.map_err(|_| EncFileError::Crypto)?;
write_frame(&mut writer, &ct, is_final)?;
if n > 0 {
buf[..n].zeroize();
}
nonce_bytes.zeroize();
if is_final {
break;
}
}
buf.zeroize();
nonce_bytes.zeroize();
}
}
writer.flush()?;
let tmp = writer
.into_inner()
.map_err(|e| EncFileError::Io(e.into_error()))?;
tmp.as_file().sync_all()?;
tmp.persist(&out_path)
.map_err(|e| EncFileError::Io(e.error))?;
let mut key_z = key;
key_z.zeroize();
Ok(out_path)
}
/// Splits the next frame header off the front of `body`.
///
/// A frame header is one flags byte followed by a big-endian `u32` ciphertext
/// length. Returns `(flags, ct_len, rest)`, where `rest` begins at the
/// ciphertext and holds at least `ct_len` bytes.
///
/// # Errors
/// * [`EncFileError::Malformed`] if `body` is too short for the header or the
///   declared ciphertext.
/// * [`EncFileError::Invalid`] if the declared length exceeds
///   `MAX_CHUNK_SIZE + AEAD_TAG_LEN`.
fn parse_frame_from_slice(body: &[u8]) -> Result<(u8, usize, &[u8]), EncFileError> {
    const HEADER_LEN: usize = 5;
    if body.len() < HEADER_LEN {
        return Err(EncFileError::Malformed);
    }
    let (header, rest) = body.split_at(HEADER_LEN);
    let flags = header[0];
    let ct_len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
    if ct_len > rest.len() {
        return Err(EncFileError::Malformed);
    }
    let max_frame = crate::types::MAX_CHUNK_SIZE + crate::crypto::AEAD_TAG_LEN;
    if ct_len > max_frame {
        return Err(EncFileError::Invalid("frame size exceeds maximum allowed"));
    }
    Ok((flags, ct_len, rest))
}
/// Reads the next 5-byte frame header from `reader`.
///
/// Returns the flags byte and the big-endian `u32` ciphertext length. The
/// ciphertext itself is left unread.
///
/// # Errors
/// * [`EncFileError::Malformed`] if the input ends inside the header.
/// * [`EncFileError::Io`] for other read failures.
/// * [`EncFileError::Invalid`] if the length exceeds
///   `MAX_CHUNK_SIZE + AEAD_TAG_LEN`.
fn parse_frame_from_reader<R: Read>(reader: &mut R) -> Result<(u8, usize), EncFileError> {
    let mut header = [0u8; 5];
    if let Err(err) = reader.read_exact(&mut header) {
        return Err(match err.kind() {
            std::io::ErrorKind::UnexpectedEof => EncFileError::Malformed,
            _ => EncFileError::Io(err),
        });
    }
    let [flags, b0, b1, b2, b3] = header;
    let ct_len = u32::from_be_bytes([b0, b1, b2, b3]) as usize;
    let max_frame = crate::types::MAX_CHUNK_SIZE + crate::crypto::AEAD_TAG_LEN;
    if ct_len > max_frame {
        return Err(EncFileError::Invalid("frame size exceeds maximum allowed"));
    }
    Ok((flags, ct_len))
}
pub fn decrypt_stream_into_vec(
alg: AeadAlg,
key: &[u8; 32],
stream: &StreamInfo,
mut body: &[u8],
) -> Result<Vec<u8>, EncFileError> {
validate_header_chunk_size(stream.chunk_size as usize)?;
let mut out = Vec::new();
match alg {
AeadAlg::XChaCha20Poly1305 => {
let cipher = create_xchacha20poly1305_cipher(key)?;
if stream.nonce_prefix.len() != 19 {
return Err(EncFileError::Malformed);
}
let nonce_prefix = GenericArray::<u8, U19>::from_slice(&stream.nonce_prefix);
let mut dec = DecryptorBE32::from_aead(cipher, nonce_prefix);
loop {
let (flags, ct_len, remaining_body) = parse_frame_from_slice(body)?;
let ct = &remaining_body[..ct_len];
body = &remaining_body[ct_len..];
let is_final = (flags & FLAG_FINAL) != 0;
if is_final {
let mut pt = dec.decrypt_last(ct).map_err(|_| EncFileError::Crypto)?;
out.extend_from_slice(&pt);
pt.zeroize();
break;
} else {
let mut pt = dec.decrypt_next(ct).map_err(|_| EncFileError::Crypto)?;
out.extend_from_slice(&pt);
pt.zeroize();
}
}
}
AeadAlg::Aes256GcmSiv => {
let cipher = create_aes256gcmsiv_cipher(key)?;
let prefix = &stream.nonce_prefix;
if prefix.len() != 8 {
return Err(EncFileError::Malformed);
}
let mut counter = 0u32;
let mut nonce_bytes = Vec::with_capacity(12);
loop {
let (flags, ct_len, remaining_body) = parse_frame_from_slice(body)?;
let ct = &remaining_body[..ct_len];
body = &remaining_body[ct_len..];
let is_final = (flags & FLAG_FINAL) != 0;
nonce_bytes.clear();
nonce_bytes.extend_from_slice(prefix);
nonce_bytes.extend_from_slice(&counter.to_be_bytes());
counter = counter.wrapping_add(1);
let mut pt = cipher
.decrypt(GenericArray::from_slice(&nonce_bytes), ct)
.map_err(|_| EncFileError::Crypto)?;
out.extend_from_slice(&pt);
pt.zeroize();
nonce_bytes.zeroize();
if is_final {
break;
}
}
nonce_bytes.zeroize();
}
}
Ok(out)
}
/// Reads the `ct_len`-byte ciphertext of the current frame into `ct_buf`
/// (growing it when needed) and returns that ciphertext slice.
///
/// # Errors
/// [`EncFileError::Malformed`] if the input ends early; other read failures
/// become [`EncFileError::Io`].
fn read_frame_ciphertext<'b, R: Read>(
    reader: &mut R,
    ct_buf: &'b mut Vec<u8>,
    ct_len: usize,
) -> Result<&'b [u8], EncFileError> {
    if ct_buf.len() < ct_len {
        ct_buf.resize(ct_len, 0);
    }
    let ct = &mut ct_buf[..ct_len];
    reader.read_exact(ct).map_err(|e| {
        if e.kind() == std::io::ErrorKind::UnexpectedEof {
            EncFileError::Malformed
        } else {
            EncFileError::Io(e)
        }
    })?;
    Ok(ct)
}

/// Decrypts a stream of frames from `reader` and writes the plaintext to
/// `writer`.
///
/// The first bytes read from `reader` are parsed as a frame header, so it
/// should be positioned just after the file header. Both ends are wrapped in
/// buffers sized from the header chunk size (`chunk_size / 4`, clamped to
/// 64 KiB..=512 KiB). Framing matches [`decrypt_stream_into_vec`]. Plaintext
/// is written as each frame authenticates, so after an error `writer` may
/// already hold part of the output.
///
/// # Errors
/// * [`EncFileError::Malformed`] in any of these cases:
///   - the nonce prefix has the wrong length;
///   - a frame is truncated;
///   - the input ends without a final frame;
///   - an AES-GCM-SIV stream has more frames than the 32-bit counter allows.
/// * [`EncFileError::Invalid`] for an invalid header chunk size or an
///   oversized frame.
/// * [`EncFileError::Crypto`] if a frame fails authentication.
/// * I/O errors from either side are propagated.
pub fn decrypt_stream_to_writer<R: Read, W: Write>(
    reader: &mut R,
    writer: &mut W,
    aead_alg: AeadAlg,
    key: &[u8; 32],
    stream_info: &StreamInfo,
) -> Result<(), EncFileError> {
    validate_header_chunk_size(stream_info.chunk_size as usize)?;
    let expected_chunk_size = stream_info.chunk_size as usize;
    let buffer_size = (expected_chunk_size / 4).clamp(64 * 1024, 512 * 1024);
    let mut buf_reader = BufReader::with_capacity(buffer_size, reader);
    let mut buf_writer = BufWriter::with_capacity(buffer_size, writer);
    // Ciphertext scratch space, grown on demand to the largest frame seen;
    // frame lengths are capped by parse_frame_from_reader.
    let mut ct_buf = Vec::new();
    match aead_alg {
        AeadAlg::XChaCha20Poly1305 => {
            let cipher = create_xchacha20poly1305_cipher(key)?;
            if stream_info.nonce_prefix.len() != 19 {
                return Err(EncFileError::Malformed);
            }
            let nonce_prefix = GenericArray::<u8, U19>::from_slice(&stream_info.nonce_prefix);
            let mut dec = DecryptorBE32::from_aead(cipher, nonce_prefix);
            loop {
                let (flags, ct_len) = parse_frame_from_reader(&mut buf_reader)?;
                let ct = read_frame_ciphertext(&mut buf_reader, &mut ct_buf, ct_len)?;
                if (flags & FLAG_FINAL) != 0 {
                    let pt = Zeroizing::new(
                        dec.decrypt_last(ct).map_err(|_| EncFileError::Crypto)?,
                    );
                    buf_writer.write_all(&pt)?;
                    break;
                }
                let pt = Zeroizing::new(dec.decrypt_next(ct).map_err(|_| EncFileError::Crypto)?);
                buf_writer.write_all(&pt)?;
            }
        }
        AeadAlg::Aes256GcmSiv => {
            let cipher = create_aes256gcmsiv_cipher(key)?;
            let prefix = &stream_info.nonce_prefix;
            if prefix.len() != 8 {
                return Err(EncFileError::Malformed);
            }
            // NOTE(review): in this mode the FLAG_FINAL byte is not covered by
            // the AEAD tag. A stream cut at a frame boundary, with the last
            // remaining frame relabelled as final, still decrypts.
            // encrypt_file_streaming only emits full-size non-final frames and
            // a shorter final frame. Enforcing that shape here would detect
            // truncation. Confirm no other encoder writes a full-size final
            // frame before adding the check.
            let mut nonce = [0u8; 12];
            nonce[..8].copy_from_slice(prefix);
            let mut counter = 0u32;
            loop {
                let (flags, ct_len) = parse_frame_from_reader(&mut buf_reader)?;
                let ct = read_frame_ciphertext(&mut buf_reader, &mut ct_buf, ct_len)?;
                nonce[8..].copy_from_slice(&counter.to_be_bytes());
                let pt = Zeroizing::new(
                    cipher
                        .decrypt(GenericArray::from_slice(&nonce[..]), ct)
                        .map_err(|_| EncFileError::Crypto)?,
                );
                buf_writer.write_all(&pt)?;
                if (flags & FLAG_FINAL) != 0 {
                    break;
                }
                // The encoder never writes more than 2^32 frames; refuse to
                // wrap the counter and reuse a nonce.
                counter = counter.checked_add(1).ok_or(EncFileError::Malformed)?;
            }
        }
    }
    buf_writer.flush()?;
    Ok(())
}