use crate::entry::EntryName;
use crate::read_mla_entries_header;
pub use super::layers::traits::{InnerReaderTrait, InnerWriterTrait};
use super::{ArchiveEntryBlock, ArchiveEntryId, ArchiveReader, ArchiveWriter, Error};
use std::collections::HashMap;
use std::hash::BuildHasher;
use std::io::{self, Read, Seek, Write};
/// Capacity in bytes (128 KiB) of the buffered reader wrapped around the archive source,
/// e.g. by [`linear_extract`].
const DEFAULT_BUFFER_SIZE: usize = 1024 * 128;
/// Percent-escapes `bytes`.
///
/// Bytes contained in `bytes_to_preserve` are copied unchanged. Every other byte `b` is
/// replaced by the three bytes `%`, then the lowercase hex digit of the high nibble of `b`,
/// then the lowercase hex digit of its low nibble.
///
/// # Panics
///
/// Panics if the worst-case output size (`3 * bytes.len()`) overflows `usize`.
pub fn mla_percent_escape(bytes: &[u8], bytes_to_preserve: &[u8]) -> Vec<u8> {
    // Worst case: every input byte expands to a `%xx` triple.
    let worst_case_len = bytes.len().checked_mul(3).unwrap();
    let mut escaped = Vec::with_capacity(worst_case_len);
    for &byte in bytes {
        if bytes_to_preserve.contains(&byte) {
            escaped.push(byte);
        } else {
            escaped.extend_from_slice(&[
                b'%',
                nibble_to_hex_char(byte >> 4),
                nibble_to_hex_char(byte & 0x0F),
            ]);
        }
    }
    escaped
}
/// Reverses [`mla_percent_escape`].
///
/// Every byte of `input` must either belong to `bytes_to_allow`, in which case it is copied
/// unchanged, or start a `%xx` escape made of two lowercase hex digits (`0-9`, `a-f`).
///
/// Returns `None` when `input` is not in the canonical form produced by
/// [`mla_percent_escape`], i.e. when it contains:
/// - a byte that is neither in `bytes_to_allow` nor `%`,
/// - a truncated escape, or one whose digits are not lowercase hex,
/// - an escape decoding to a byte of `bytes_to_allow` (such a byte must appear raw).
pub fn mla_percent_unescape(input: &[u8], bytes_to_allow: &[u8]) -> Option<Vec<u8>> {
    let mut result = Vec::with_capacity(input.len());
    let mut bytes = input.iter();
    // Not a `for` loop: the escape branch consumes extra bytes from the same iterator.
    while let Some(b) = bytes.next() {
        if bytes_to_allow.contains(b) {
            result.push(*b);
        } else if *b == b'%' {
            let high_nibble = bytes.next().and_then(|c| hex_char_to_nibble(*c));
            let low_nibble = bytes.next().and_then(|c| hex_char_to_nibble(*c));
            match (high_nibble, low_nibble) {
                (Some(high_nibble), Some(low_nibble)) => {
                    let decoded_byte = (high_nibble << 4) | low_nibble;
                    // Allowed bytes are never escaped by `mla_percent_escape`; accepting
                    // them here would give a single value several encodings.
                    if bytes_to_allow.contains(&decoded_byte) {
                        return None;
                    }
                    result.push(decoded_byte);
                }
                _ => return None,
            }
        } else {
            // A byte that is neither allowed nor an escape cannot come from
            // `mla_percent_escape`. Reject it instead of silently dropping it, which
            // would make distinct inputs (e.g. `a/b` and `ab`) decode to the same value.
            return None;
        }
    }
    Some(result)
}
/// Converts a nibble (`0..=15`) into its lowercase ASCII hex digit (`0-9`, `a-f`).
#[inline(always)]
#[allow(
    clippy::arithmetic_side_effects,
    reason = "Given valid values as input, cannot overflow"
)]
fn nibble_to_hex_char(nibble: u8) -> u8 {
    match nibble {
        0..=9 => nibble + b'0',
        _ => (nibble - 0xa) + b'a',
    }
}
/// Converts a lowercase ASCII hex digit (`0-9`, `a-f`) into its value.
///
/// Returns `None` for any other byte, including uppercase `A-F`.
#[inline(always)]
#[allow(
    clippy::arithmetic_side_effects,
    reason = "Given conditions on each branch, cannot overflow"
)]
fn hex_char_to_nibble(hex_char: u8) -> Option<u8> {
    match hex_char {
        b'0'..=b'9' => Some(hex_char - b'0'),
        b'a'..=b'f' => Some(hex_char - b'a' + 0xa),
        _ => None,
    }
}
/// Extracts the content of every entry whose name is a key of `export`, writing it into
/// the associated writer.
///
/// The archive source is rewound and its entries header re-read. Entry blocks are then
/// read sequentially, in a single forward pass, until `EndOfArchiveData`. Content blocks of
/// entries not listed in `export` are still consumed (copied to a sink), so the next block
/// is read from the correct position. Requested names that never appear in the archive
/// leave their writer untouched.
///
/// Reading goes through a `BufReader`, so on return `archive.src` may have been advanced
/// beyond the end-of-archive block by read-ahead.
///
/// # Errors
///
/// Returns an error if rewinding the source, reading the entries header, parsing a block,
/// or copying content into a writer fails.
pub fn linear_extract<W1: InnerWriterTrait, R: InnerReaderTrait, S: BuildHasher>(
    archive: &mut ArchiveReader<R>,
    export: &mut HashMap<&EntryName, W1, S>,
) -> Result<(), Error> {
    archive.src.rewind()?;
    read_mla_entries_header(&mut archive.src)?;
    let mut src = io::BufReader::with_capacity(DEFAULT_BUFFER_SIZE, &mut archive.src);
    // Only entries requested in `export` are tracked, keyed by their in-archive id.
    let mut id2name: HashMap<ArchiveEntryId, EntryName> = HashMap::new();
    loop {
        match ArchiveEntryBlock::from(&mut src)? {
            ArchiveEntryBlock::EntryStart { name, id, opts: _ } => {
                if export.contains_key(&name) {
                    // `name` is not used afterwards: move it instead of cloning.
                    id2name.insert(id, name);
                }
            }
            ArchiveEntryBlock::EndOfEntry { id, .. } => {
                id2name.remove(&id);
            }
            ArchiveEntryBlock::EntryContent { length, id, .. } => {
                let mut content = (&mut src).take(length);
                let writer = id2name.get(&id).and_then(|name| export.get_mut(name));
                match writer {
                    Some(writer) => io::copy(&mut content, writer)?,
                    // Unrequested content must still be consumed to keep the reader
                    // aligned on the next block.
                    None => io::copy(&mut content, &mut io::sink())?,
                };
            }
            ArchiveEntryBlock::EndOfArchiveData => break,
        }
    }
    Ok(())
}
/// Adapter exposing one entry of an [`ArchiveWriter`] as a [`Write`] sink: each write is
/// appended as content of the entry identified by `file_id`.
pub struct StreamWriter<'a, 'b, W: InnerWriterTrait> {
    archive: &'b mut ArchiveWriter<'a, W>,
    file_id: ArchiveEntryId,
}
impl<'a, 'b, W: InnerWriterTrait> StreamWriter<'a, 'b, W> {
    /// Wraps `archive` so that data written to the result goes to entry `file_id`.
    pub fn new(archive: &'b mut ArchiveWriter<'a, W>, file_id: ArchiveEntryId) -> Self {
        StreamWriter { file_id, archive }
    }
}
/// Every `write` hands the whole buffer to `append_entry_content`, so a successful call
/// always reports `buf.len()` bytes written. `flush` delegates to the archive writer.
impl<W: InnerWriterTrait> Write for StreamWriter<'_, '_, W> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let len = buf.len();
        self.archive
            .append_entry_content(self.file_id, len as u64, buf)?;
        Ok(len)
    }
    fn flush(&mut self) -> io::Result<()> {
        self.archive.flush()
    }
}
pub mod shared_secret {
    //! Extraction of the hybrid multi-recipient encapsulated key stored in an archive's
    //! encryption layer header, and its decapsulation into a shared secret.
    use std::io::{Read, Write};
    use zeroize::Zeroize as _;
    use crate::{
        MLADeserialize, MLASerialize,
        crypto::hybrid::{
            HybridKemSharedSecret, HybridMultiRecipientEncapsulatedKey, MLADecryptionPrivateKey,
        },
        errors::{ConfigError, Error},
        format::ArchiveHeader,
        layers::{
            encrypt::{ENCRYPTION_LAYER_MAGIC, read_encryption_header_after_magic},
            raw::RawLayerTruncatedReader,
            signature::{SIGNATURE_LAYER_MAGIC, SignatureLayerTruncatedReader},
            traits::LayerTruncatedReader,
        },
        read_layer_magic,
    };

    /// Encapsulated key material read from the encryption layer header of an archive.
    pub struct MLADecryptionMetadata(pub(crate) HybridMultiRecipientEncapsulatedKey);

    impl MLADecryptionMetadata {
        /// Reads the archive header and first layer magic from `src`, skipping a signature
        /// layer if it comes first, then returns the encapsulated key found in the
        /// encryption layer header.
        ///
        /// # Errors
        ///
        /// Returns `Error::EncryptionAskedButNotMarkedPresent` when the layer found (after
        /// an optional signature layer) is not an encryption layer, or any error raised
        /// while reading the headers.
        pub fn from_archive<R: Read>(mut src: R) -> Result<Self, Error> {
            // The header itself is not needed, only the layers following it.
            let _ = ArchiveHeader::deserialize(&mut src)?;
            let mut reader: Box<dyn LayerTruncatedReader<R>> =
                Box::new(RawLayerTruncatedReader::new(src));
            let mut magic = read_layer_magic(&mut reader)?;
            if magic == SIGNATURE_LAYER_MAGIC {
                // Look for the encryption layer magic right after the signature layer.
                reader = Box::new(SignatureLayerTruncatedReader::new_skip_magic(reader)?);
                magic = read_layer_magic(&mut reader)?;
            }
            if &magic != ENCRYPTION_LAYER_MAGIC {
                return Err(Error::EncryptionAskedButNotMarkedPresent);
            }
            let (encryption_header, _) = read_encryption_header_after_magic(&mut reader)?;
            Ok(Self(encryption_header.hybrid_multi_recipient_encapsulate_key))
        }

        /// Serializes the encapsulated key into `dest` with `MLASerialize`, forwarding its
        /// `u64` result (presumably the number of bytes written — TODO confirm).
        pub fn serialize_metadata(&self, mut dest: impl Write) -> Result<u64, Error> {
            self.0.serialize(&mut dest)
        }

        /// Deserializes an encapsulated key from `src` with `MLADeserialize`, the
        /// counterpart of [`Self::serialize_metadata`].
        pub fn deserialize_metadata(mut src: impl Read) -> Result<Self, Error> {
            let key: HybridMultiRecipientEncapsulatedKey = MLADeserialize::deserialize(&mut src)?;
            Ok(Self(key))
        }

        /// Decapsulates the stored key with `decryption_private_key`.
        ///
        /// # Errors
        ///
        /// Returns `ConfigError::PrivateKeyNotFound` whenever decapsulation fails; the
        /// underlying error is discarded.
        pub fn decapsulate_shared_secret(
            &self,
            decryption_private_key: &MLADecryptionPrivateKey,
        ) -> Result<MLADecryptionSharedSecret, ConfigError> {
            decryption_private_key
                .decapsulate(&self.0)
                .map(MLADecryptionSharedSecret)
                .map_err(|_| ConfigError::PrivateKeyNotFound)
        }
    }

    /// Shared secret produced by [`MLADecryptionMetadata::decapsulate_shared_secret`];
    /// its content is zeroized when dropped.
    #[derive(Clone)]
    pub struct MLADecryptionSharedSecret(pub(crate) HybridKemSharedSecret);

    impl MLADecryptionSharedSecret {
        /// Serializes the secret into `dest` with `MLASerialize`, forwarding its `u64`
        /// result.
        pub fn serialize_shared_secret(&self, mut dest: impl Write) -> Result<u64, Error> {
            self.0.serialize(&mut dest)
        }

        /// Deserializes a secret from `src` with `MLADeserialize`.
        pub fn deserialize_shared_secret(mut src: impl Read) -> Result<Self, Error> {
            let secret: HybridKemSharedSecret = MLADeserialize::deserialize(&mut src)?;
            Ok(Self(secret))
        }
    }

    impl Drop for MLADecryptionSharedSecret {
        fn drop(&mut self) {
            // Wipe the secret instead of leaving it in freed memory.
            self.0.zeroize();
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for percent escaping, linear extraction and `StreamWriter`.
    use crypto::hybrid::generate_keypair_from_seed;
    use rand::SeedableRng;
    use rand::distributions::Standard;
    use rand::prelude::Distribution;
    use rand_chacha::ChaChaRng;
    use super::*;
    use crate::entry::ENTRY_NAME_RAW_CONTENT_ALLOWED_BYTES;
    use crate::tests::build_archive;
    use crate::*;
    use std::io::Cursor;

    const UNCOMPRESSED_DATA_SIZE: u32 = 4 * 1024 * 1024;

    /// Extracting every listed entry must reproduce all original contents.
    #[test]
    fn full_linear_extract() {
        let (mla, _sender_key, receiver_key, files) = build_archive(true, true, false, false);
        let dest = Cursor::new(mla);
        let config = ArchiveReaderConfig::without_signature_verification()
            .with_encryption(&[receiver_key.0.get_decryption_private_key().clone()]);
        let mut mla_read = ArchiveReader::from_config(dest, config).unwrap().0;
        let file_list: Vec<EntryName> = mla_read
            .list_entries()
            .expect("reader.list_entries")
            .cloned()
            .collect();
        let mut export: HashMap<&EntryName, Vec<u8>> =
            file_list.iter().map(|fname| (fname, Vec::new())).collect();
        linear_extract(&mut mla_read, &mut export).expect("Extract error");
        for (entry, content) in &files {
            assert_eq!(export.get(entry).unwrap(), content);
        }
    }

    /// Extracting a single entry must only fill that entry's writer.
    #[test]
    fn one_linear_extract() {
        let (mla, _sender_key, receiver_key, files) = build_archive(true, true, false, false);
        let dest = Cursor::new(mla);
        let config = ArchiveReaderConfig::without_signature_verification()
            .with_encryption(&[receiver_key.0.get_decryption_private_key().clone()]);
        let mut mla_read = ArchiveReader::from_config(dest, config).unwrap().0;
        let mut export: HashMap<&EntryName, Vec<u8>> = HashMap::new();
        export.insert(&files[0].0, Vec::new());
        linear_extract(&mut mla_read, &mut export).expect("Extract error");
        assert_eq!(export.get(&files[0].0).unwrap(), &files[0].1);
    }

    /// An entry larger than the internal buffers must be extracted intact.
    #[test]
    fn linear_extract_big_file() {
        let file_length = 4 * UNCOMPRESSED_DATA_SIZE as usize;
        let file = Vec::new();
        let mut rng = ChaChaRng::seed_from_u64(0);
        let (private_key, public_key) = generate_keypair_from_seed([0; 32]);
        let config = ArchiveWriterConfig::with_encryption_without_signature(&[public_key]).unwrap();
        let mut mla = ArchiveWriter::from_config(file, config).expect("Writer init failed");
        let entry = EntryName::from_arbitrary_bytes(b"my_file").unwrap();
        let data: Vec<u8> = Standard.sample_iter(&mut rng).take(file_length).collect();
        assert_eq!(data.len(), file_length);
        mla.add_entry(entry.clone(), data.len() as u64, data.as_slice())
            .unwrap();
        let dest = mla.finalize().unwrap();
        let dest = Cursor::new(dest);
        let config =
            ArchiveReaderConfig::without_signature_verification().with_encryption(&[private_key]);
        let mut mla_read = ArchiveReader::from_config(dest, config).unwrap().0;
        let mut export: HashMap<&EntryName, Vec<u8>> = HashMap::new();
        export.insert(&entry, Vec::new());
        linear_extract(&mut mla_read, &mut export).expect("Extract error");
        assert_eq!(export.get(&entry).unwrap(), &data);
    }

    /// Data written through `StreamWriter` (split writes and `io::copy`) must be read back.
    #[test]
    fn stream_writer() {
        let file = Vec::new();
        let config = ArchiveWriterConfig::without_encryption_without_signature()
            .unwrap()
            .without_compression();
        let mut mla = ArchiveWriter::from_config(file, config).expect("Writer init failed");
        let fake_file = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let id = mla
            .start_entry(EntryName::from_arbitrary_bytes(b"my_file").unwrap())
            .unwrap();
        let mut sw = StreamWriter::new(&mut mla, id);
        sw.write_all(&fake_file[..5]).unwrap();
        sw.write_all(&fake_file[5..]).unwrap();
        mla.end_entry(id).unwrap();
        let id = mla
            .start_entry(EntryName::from_arbitrary_bytes(b"my_entry2").unwrap())
            .unwrap();
        let mut sw = StreamWriter::new(&mut mla, id);
        assert_eq!(
            io::copy(&mut fake_file.as_slice(), &mut sw).unwrap(),
            fake_file.len() as u64
        );
        mla.end_entry(id).unwrap();
        let dest = mla.finalize().unwrap();
        let buf = Cursor::new(dest.as_slice());
        let mut mla_read = ArchiveReader::from_config(
            buf,
            ArchiveReaderConfig::without_signature_verification().without_encryption(),
        )
        .unwrap()
        .0;
        let mut content1 = Vec::new();
        mla_read
            .get_entry(EntryName::from_arbitrary_bytes(b"my_file").unwrap())
            .unwrap()
            .unwrap()
            .data
            .read_to_end(&mut content1)
            .unwrap();
        assert_eq!(content1.as_slice(), fake_file.as_slice());
        let mut content2 = Vec::new();
        mla_read
            .get_entry(EntryName::from_arbitrary_bytes(b"my_entry2").unwrap())
            .unwrap()
            .unwrap()
            .data
            .read_to_end(&mut content2)
            .unwrap();
        assert_eq!(content2.as_slice(), fake_file.as_slice());
    }

    /// A single non-preserved byte escapes to a lowercase `%xx` triple and back.
    #[test]
    fn test_escape() {
        assert_eq!(
            b"%2f".as_slice(),
            mla_percent_escape(b"/", &ENTRY_NAME_RAW_CONTENT_ALLOWED_BYTES).as_slice()
        );
        assert_eq!(
            b"/".as_slice(),
            mla_percent_unescape(b"%2f", &ENTRY_NAME_RAW_CONTENT_ALLOWED_BYTES)
                .unwrap()
                .as_slice()
        );
    }

    /// Escaping then unescaping must be the identity for every byte value.
    #[test]
    fn test_escape_roundtrip_all_bytes() {
        let all_bytes: Vec<u8> = (0..=u8::MAX).collect();
        let escaped = mla_percent_escape(&all_bytes, &ENTRY_NAME_RAW_CONTENT_ALLOWED_BYTES);
        let decoded = mla_percent_unescape(&escaped, &ENTRY_NAME_RAW_CONTENT_ALLOWED_BYTES)
            .expect("escaped output must be accepted by unescape");
        assert_eq!(decoded, all_bytes);
    }

    /// Empty input maps to empty output in both directions.
    #[test]
    fn test_escape_empty() {
        assert!(mla_percent_escape(b"", &ENTRY_NAME_RAW_CONTENT_ALLOWED_BYTES).is_empty());
        assert_eq!(
            mla_percent_unescape(b"", &ENTRY_NAME_RAW_CONTENT_ALLOWED_BYTES),
            Some(Vec::new())
        );
    }

    /// Truncated escapes and non-lowercase-hex digits are rejected.
    #[test]
    fn test_unescape_rejects_malformed_escapes() {
        let allowed = &ENTRY_NAME_RAW_CONTENT_ALLOWED_BYTES;
        assert_eq!(mla_percent_unescape(b"%", allowed), None);
        assert_eq!(mla_percent_unescape(b"%2", allowed), None);
        assert_eq!(mla_percent_unescape(b"%2F", allowed), None);
        assert_eq!(mla_percent_unescape(b"%zz", allowed), None);
    }
}