use assert_fs::prelude::*;
use ciborium::Value;
use secrecy::SecretString;
use std::fs;
use enc_file::{AeadAlg, EncFileError, EncryptOptions, decrypt_file, encrypt_file_streaming};
/// Encrypts a small fixed plaintext in streaming mode (64 KiB chunks) with `alg`.
///
/// Returns, in order:
/// - the temp directory holding the files (returned so the caller controls its lifetime);
/// - the path returned by `encrypt_file_streaming`;
/// - the password used;
/// - the raw ciphertext bytes read back from that path.
///
/// Panics (via `unwrap`) if any filesystem or encryption step fails.
fn make_stream_ct(
    alg: AeadAlg,
) -> (
    assert_fs::TempDir,
    std::path::PathBuf,
    SecretString,
    Vec<u8>,
) {
    let dir = assert_fs::TempDir::new().unwrap();
    let password = SecretString::new("pw".to_string().into());

    let plaintext = dir.child("in.bin");
    plaintext.write_binary(b"some streaming input").unwrap();
    let target = dir.child("ct.enc");

    let written = encrypt_file_streaming(
        plaintext.path(),
        Some(target.path()),
        password.clone(),
        EncryptOptions {
            alg,
            stream: true,
            chunk_size: 65536,
            ..Default::default()
        },
    )
    .unwrap();

    let ciphertext = fs::read(&written).unwrap();
    (dir, written, password, ciphertext)
}
/// Sets `key` to `new_value` in a CBOR map's entry list.
///
/// If an entry with an equal key exists, the first such entry has its value
/// overwritten in place. Otherwise a new `(key, new_value)` pair is appended.
fn update_or_insert_map_key(map: &mut Vec<(Value, Value)>, key: Value, new_value: Value) {
    let existing = map.iter().position(|(candidate, _)| *candidate == key);
    if let Some(idx) = existing {
        map[idx].1 = new_value;
    } else {
        map.push((key, new_value));
    }
}
/// Returns a mutable reference to the value of the first entry in a CBOR map's
/// entry list whose key equals `key`, or `None` if no entry matches.
///
/// Takes a slice rather than `&mut Vec` because only element access is needed.
/// Existing callers that pass `&mut Vec<(Value, Value)>` coerce automatically.
fn find_map_value_mut<'a>(map: &'a mut [(Value, Value)], key: &Value) -> Option<&'a mut Value> {
    map.iter_mut()
        .find(|(candidate, _)| candidate == key)
        .map(|(_, value)| value)
}
fn tamper_chunk_size(file_bytes: Vec<u8>, new_chunk: u32) -> Vec<u8> {
assert!(file_bytes.len() >= 4);
let mut len_le = [0u8; 4];
len_le.copy_from_slice(&file_bytes[..4]);
let header_len = u32::from_le_bytes(len_le) as usize;
let start = 4;
let end = 4 + header_len;
let mut header_val: Value = ciborium::de::from_reader(&file_bytes[start..end]).unwrap();
if let Value::Map(ref mut top) = header_val {
let stream_key = Value::Text("stream".to_string());
let cs_key = Value::Text("chunk_size".to_string());
if let Some(stream_value) = find_map_value_mut(top, &stream_key) {
if let Value::Map(stream_map) = stream_value {
update_or_insert_map_key(stream_map, cs_key, Value::Integer(new_chunk.into()));
}
} else {
let stream_map = vec![(cs_key, Value::Integer(new_chunk.into()))];
update_or_insert_map_key(top, stream_key, Value::Map(stream_map));
}
} else {
panic!("header is not a CBOR map");
}
let mut new_header = Vec::new();
ciborium::ser::into_writer(&header_val, &mut new_header).unwrap();
let mut rebuilt = Vec::with_capacity(4 + new_header.len() + (file_bytes.len() - end));
rebuilt.extend_from_slice(&(new_header.len() as u32).to_le_bytes());
rebuilt.extend_from_slice(&new_header);
rebuilt.extend_from_slice(&file_bytes[end..]); rebuilt
}
/// Case-insensitive check that `msg` contains at least one of `needles`.
///
/// Both the message and each needle are lowercased before comparison. A needle
/// written with capitals therefore still matches. Returns `false` when
/// `needles` is empty.
fn msg_contains_any(msg: &str, needles: &[&str]) -> bool {
    let haystack = msg.to_lowercase();
    needles
        .iter()
        .any(|needle| haystack.contains(needle.to_lowercase().as_str()))
}
/// A streaming ciphertext whose header claims `chunk_size == 0` must be
/// rejected with `EncFileError::Invalid` for every supported AEAD.
#[test]
fn dec_rejects_zero_chunk_size_in_header_for_both_algs() {
    // Any of these (case-insensitive) fragments is accepted as a relevant message.
    const EXPECTED: &[&str] = &[
        "chunk_size",
        "chunk size",
        "must be > 0",
        "cannot be zero",
        "zero",
    ];
    let algs = [AeadAlg::XChaCha20Poly1305, AeadAlg::Aes256GcmSiv];
    for alg in algs {
        let (tmp, _ct_path, password, ct_bytes) = make_stream_ct(alg);
        let bad = tmp.child("bad.enc");
        bad.write_binary(&tamper_chunk_size(ct_bytes, 0)).unwrap();
        let out = tmp.child("out.bin");

        match decrypt_file(bad.path(), Some(out.path()), password.clone()) {
            Err(EncFileError::Invalid(msg)) => assert!(
                msg_contains_any(&msg, EXPECTED),
                "unexpected Invalid message: {msg}"
            ),
            other => panic!("expected Invalid for zero chunk_size, got: {:?}", other),
        }

        // Cleanup is best-effort; a failure here is not what this test checks.
        tmp.close().ok();
    }
}
/// A streaming ciphertext whose header claims a chunk size too large to frame
/// must be rejected with `EncFileError::Invalid` for every supported AEAD.
#[test]
fn dec_rejects_too_large_chunk_size_in_header_for_both_algs() {
    // Both AEADs exercised here use a 16-byte authentication tag.
    const TAG: u32 = 16;
    // Any of these (case-insensitive) fragments is accepted as a relevant message.
    const EXPECTED: &[&str] = &[
        "chunk_size",
        "chunk size",
        "32-bit",
        "too large for 32-bit",
        "too large for frame",
        "frame format",
    ];
    // Smallest chunk size for which `chunk_size + TAG` no longer fits in a u32
    // (presumably the per-frame length limit; NOTE(review): confirm against the frame format).
    let too_big = u32::MAX - TAG + 1;

    for alg in [AeadAlg::XChaCha20Poly1305, AeadAlg::Aes256GcmSiv] {
        let (tmp, _ct_path, password, ct_bytes) = make_stream_ct(alg);
        let tampered = tamper_chunk_size(ct_bytes, too_big);
        let bad = tmp.child("bad2.enc");
        bad.write_binary(&tampered).unwrap();
        let out = tmp.child("out2.bin");

        let outcome = decrypt_file(bad.path(), Some(out.path()), password.clone());
        if let Err(EncFileError::Invalid(msg)) = outcome {
            assert!(
                msg_contains_any(&msg, EXPECTED),
                "unexpected Invalid message: {msg}"
            );
        } else {
            panic!(
                "expected Invalid for oversized chunk_size, got: {:?}",
                outcome
            );
        }

        // Cleanup is best-effort; a failure here is not what this test checks.
        tmp.close().ok();
    }
}