use crate::config::TruncatedReaderDecryptionMode;
use crate::crypto::aesgcm::{
AesGcm256, ConstantTimeEq, KEY_COMMITMENT_SIZE, Key, Nonce, TAG_LENGTH, Tag,
};
use crate::crypto::MaybeSeededRNG;
use crate::crypto::hpke::{compute_nonce, key_schedule_base_hybrid_kem};
use crate::crypto::hybrid::{
HybridKemSharedSecret, HybridMultiRecipientEncapsulatedKey, HybridMultiRecipientsPublicKeys,
MLADecryptionPrivateKey, MLAEncryptionPublicKey,
};
use crate::layers::traits::{
InnerWriterTrait, InnerWriterType, LayerFailSafeReader, LayerReader, LayerWriter,
};
use crate::{EMPTY_TAIL_OPTS_SERIALIZATION, Error, MLADeserialize, MLASerialize, Opts};
use std::io;
use std::io::{BufReader, Cursor, Read, Seek, SeekFrom, Write};
use crate::errors::ConfigError;
use kem::{Decapsulate, Encapsulate};
use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;
use zeroize::{Zeroize, ZeroizeOnDrop};
use super::position::PositionLayerReader;
use super::strip_head_tail::StripHeadTailReader;
use super::traits::InnerReaderTrait;
/// Upper bound on the number of bytes encrypted by a single `write` call of the writer.
const CIPHER_BUF_SIZE: u64 = 4096;
/// Plaintext size of a full ("normal") chunk.
const NORMAL_CHUNK_PT_SIZE: u64 = 128 * 1024;
/// Size of a full chunk's ciphertext followed by its authentication tag.
const NORMAL_CHUNK_PT_AND_TAG_SIZE: u64 = NORMAL_CHUNK_PT_SIZE + TAG_LENGTH as u64;
const NORMAL_CHUNK_PT_AND_TAG_USIZE: usize = NORMAL_CHUNK_PT_AND_TAG_SIZE as usize;
/// Associated data used for the key commitment and for regular data chunks (empty).
const ASSOCIATED_DATA: &[u8; 0] = b"";
/// Associated data used by the cipher that protects the final block.
const FINAL_ASSOCIATED_DATA: &[u8; 8] = b"FINALAAD";
/// Plaintext encrypted into the final block by the writer's `finalize`.
const FINAL_BLOCK_CONTENT: &[u8; 10] = b"FINALBLOCK";
/// Size of the final block: magic + encrypted content + tag.
const FINAL_INFO_SIZE: u64 = FINAL_INFO_USIZE as u64;
const FINAL_INFO_USIZE: usize = FINAL_BLOCK_MAGIC.len() + FINAL_BLOCK_CONTENT.len() + TAG_LENGTH;
/// Size of the final block without its magic (encrypted content + tag).
const FINAL_INFO_SIZE_WOM: usize = FINAL_BLOCK_CONTENT.len() + TAG_LENGTH;
/// Clear-text marker written right before the encrypted final block content.
const FINAL_BLOCK_MAGIC: &[u8] = b"M0FNLBLK";
/// Marker written at the very beginning of the encryption layer.
pub const ENCRYPTION_LAYER_MAGIC: &[u8; 8] = b"ENCMLAAA";
/// Marker written after the final block's tag, before the tail options.
const END_OF_ENCRYPTED_INNER_LAYER_MAGIC: &[u8; 8] = b"ENCMLAAB";
/// Clear-text marker written at the start of every data chunk, followed by the chunk counter.
const CHUNK_MAGIC: &[u8; 8] = b"M0ENCCNK";
/// Size of a chunk header: `CHUNK_MAGIC` plus the 8-byte serialized chunk counter.
const CHUNK_HEAD_SIZE: u64 = CHUNK_HEAD_USIZE as u64;
const CHUNK_HEAD_USIZE: usize = CHUNK_MAGIC.len() + 8;
/// Size of a full chunk as stored in the inner layer: header + ciphertext + tag.
const M0_CHUNK_SIZE: u64 = M0_CHUNK_USIZE as u64;
const M0_CHUNK_USIZE: usize = CHUNK_MAGIC.len() + 8 + NORMAL_CHUNK_PT_AND_TAG_USIZE;
/// Fixed, public plaintext that is encrypted to build the key commitment.
const KEY_COMMITMENT_CHAIN: &[u8; KEY_COMMITMENT_SIZE] =
b"-KEY COMMITMENT--KEY COMMITMENT--KEY COMMITMENT--KEY COMMITMENT-";
/// Encrypts the fixed `KEY_COMMITMENT_CHAIN` under `key` and returns the
/// resulting ciphertext together with its authentication tag.
///
/// The nonce is derived with counter `0`; data chunks start at
/// `FIRST_DATA_CHUNK_NUMBER`, so the commitment never shares a nonce with them.
///
/// # Errors
/// Propagates any error returned while creating the cipher.
fn build_key_commitment_chain(key: &Key, nonce: &Nonce) -> Result<KeyCommitmentAndTag, Error> {
    let commitment_nonce = compute_nonce(nonce, 0);
    let mut cipher = AesGcm256::new(key, &commitment_nonce, ASSOCIATED_DATA)?;

    // Start from the public constant and encrypt it in place.
    let mut key_commitment = *KEY_COMMITMENT_CHAIN;
    cipher.encrypt(&mut key_commitment);

    let mut tag = [0u8; TAG_LENGTH];
    tag.copy_from_slice(&cipher.into_tag());

    Ok(KeyCommitmentAndTag {
        key_commitment,
        tag,
    })
}
fn check_key_commitment(
key: &Key,
nonce: &Nonce,
commitment: &KeyCommitmentAndTag,
) -> Result<(), ConfigError> {
let mut key_commitment = commitment.key_commitment;
let mut cipher = AesGcm256::new(key, &compute_nonce(nonce, 0), ASSOCIATED_DATA)
.or(Err(ConfigError::KeyCommitmentCheckingError))?;
let tag = cipher.decrypt(&mut key_commitment);
if tag.ct_eq(&commitment.tag).unwrap_u8() == 1 {
Ok(())
} else {
Err(ConfigError::KeyCommitmentCheckingError)
}
}
/// Nonce counter of the first data chunk (counter `0` is used for the key commitment).
const FIRST_DATA_CHUNK_NUMBER: u64 = 1;
/// `info` value passed to the key schedule when deriving this layer's key and nonce.
const HPKE_INFO_LAYER: &[u8] = b"MLA Encrypt Layer";
/// Encrypted key-commitment constant together with its authentication tag.
struct KeyCommitmentAndTag {
// `KEY_COMMITMENT_CHAIN` encrypted under the layer key.
key_commitment: [u8; KEY_COMMITMENT_SIZE],
// Authentication tag produced while encrypting `key_commitment`.
tag: [u8; TAG_LENGTH],
}
impl<W: Write> MLASerialize<W> for KeyCommitmentAndTag {
    /// Serializes the encrypted commitment, then its tag, and returns the sum
    /// of the lengths reported by both serializations.
    fn serialize(&self, dest: &mut W) -> Result<u64, Error> {
        let commitment_len = self.key_commitment.as_slice().serialize(dest)?;
        let tag_len = self.tag.as_slice().serialize(dest)?;
        Ok(commitment_len + tag_len)
    }
}
impl<R: Read> MLADeserialize<R> for KeyCommitmentAndTag {
    /// Reads the encrypted commitment, then its tag, in the order used by `serialize`.
    fn deserialize(src: &mut R) -> Result<Self, Error> {
        // Struct fields are evaluated in source order: commitment first, tag second.
        Ok(Self {
            key_commitment: MLADeserialize::deserialize(src)?,
            tag: MLADeserialize::deserialize(src)?,
        })
    }
}
/// Returns a ChaCha20-based RNG seeded through `SeedableRng::from_entropy`.
///
/// The `Result` return type is kept for callers; this implementation never returns `Err`.
pub(crate) fn get_crypto_rng() -> Result<ChaCha20Rng, Error> {
    let rng = ChaCha20Rng::from_entropy();
    Ok(rng)
}
/// Symmetric material used by the layer; zeroized when dropped (see the derives).
#[derive(Zeroize, ZeroizeOnDrop)]
pub(crate) struct InternalEncryptionConfig {
// Symmetric key used for every cipher instance of the layer.
pub(crate) key: Key,
// Base nonce; per-chunk nonces are derived from it with `compute_nonce`.
pub(crate) nonce: Nonce,
}
impl InternalEncryptionConfig {
    /// Derives the layer key and base nonce from a hybrid KEM shared secret,
    /// using `HPKE_INFO_LAYER` as the key-schedule `info`.
    ///
    /// # Errors
    /// Propagates any error returned by the key schedule.
    fn from(shared_secret: HybridKemSharedSecret) -> Result<Self, Error> {
        let schedule = key_schedule_base_hybrid_kem(&shared_secret.0, HPKE_INFO_LAYER)?;
        Ok(Self {
            key: schedule.0,
            nonce: schedule.1,
        })
    }
}
/// Encryption metadata stored in the layer header (serialized by the writer, read back by the readers).
pub struct EncryptionPersistentConfig {
/// Encapsulated key material for all recipients.
pub hybrid_multi_recipient_encapsulate_key: HybridMultiRecipientEncapsulatedKey,
// Key commitment checked by `EncryptionReaderConfig::load_persistent`.
key_commitment: KeyCommitmentAndTag,
}
impl<W: Write> MLASerialize<W> for EncryptionPersistentConfig {
    /// Serializes the encapsulated key followed by the key commitment and
    /// returns the sum of the lengths reported by both serializations.
    fn serialize(&self, dest: &mut W) -> Result<u64, Error> {
        let encapsulated_len = self.hybrid_multi_recipient_encapsulate_key.serialize(dest)?;
        let commitment_len = self.key_commitment.serialize(dest)?;
        Ok(encapsulated_len + commitment_len)
    }
}
impl<R: Read> MLADeserialize<R> for EncryptionPersistentConfig {
    /// Reads the encapsulated key, then the key commitment, in the order used by `serialize`.
    fn deserialize(src: &mut R) -> Result<Self, Error> {
        // Struct fields are evaluated in source order, matching the on-disk layout.
        Ok(Self {
            hybrid_multi_recipient_encapsulate_key: MLADeserialize::deserialize(src)?,
            key_commitment: MLADeserialize::deserialize(src)?,
        })
    }
}
/// Writer-side encryption configuration.
pub(crate) struct EncryptionConfig {
// Public keys of every recipient the shared secret is encapsulated for.
public_keys: HybridMultiRecipientsPublicKeys,
// Source of the RNG handed to the encapsulation in `to_persistent`.
pub(crate) rng: MaybeSeededRNG,
}
impl EncryptionConfig {
pub(crate) fn to_persistent(
&self,
) -> Result<(EncryptionPersistentConfig, InternalEncryptionConfig), Error> {
let (hybrid_multi_recipient_encapsulate_key, ss_hybrid) =
self.public_keys.encapsulate(&mut self.rng.get_rng()?)?;
let cryptographic_material = InternalEncryptionConfig::from(ss_hybrid).or(Err(
Error::ConfigError(ConfigError::KeyWrappingComputationError),
))?;
let key_commitment =
build_key_commitment_chain(&cryptographic_material.key, &cryptographic_material.nonce)
.or(Err(Error::ConfigError(
ConfigError::KeyCommitmentComputationError,
)))?;
Ok((
EncryptionPersistentConfig {
hybrid_multi_recipient_encapsulate_key,
key_commitment,
},
cryptographic_material,
))
}
pub(crate) fn new(
encryption_public_keys: &[MLAEncryptionPublicKey],
) -> Result<Self, ConfigError> {
if encryption_public_keys.is_empty() {
return Err(ConfigError::EncryptionKeyIsMissing);
}
let public_keys = HybridMultiRecipientsPublicKeys {
keys: encryption_public_keys.to_vec(),
};
Ok(Self {
public_keys,
rng: MaybeSeededRNG::default(),
})
}
}
/// Reader-side encryption configuration.
#[derive(Default)]
pub struct EncryptionReaderConfig {
// Candidate private keys tried in order by `load_persistent`.
private_keys: Vec<MLADecryptionPrivateKey>,
// Key and base nonce recovered by `load_persistent`; `None` until then.
encrypt_parameters: Option<(Key, Nonce)>,
}
impl EncryptionReaderConfig {
    /// Replaces the candidate private keys with a copy of `private_keys`.
    pub(crate) fn set_private_keys(&mut self, private_keys: &[MLADecryptionPrivateKey]) {
        self.private_keys = Vec::from(private_keys);
    }

    /// Recovers the layer key and nonce from the stored header `config`.
    ///
    /// Each private key is tried in order; the first one that decapsulates
    /// successfully is used to derive the key and nonce, which are then
    /// validated against the stored key commitment.
    ///
    /// # Errors
    /// * `ConfigError::PrivateKeyNotSet` if no private key was provided;
    /// * `ConfigError::KeyWrappingComputationError` if the key derivation fails;
    /// * `ConfigError::PrivateKeyNotFound` if no parameters are available after trying every key;
    /// * the error of `check_key_commitment` if the commitment does not verify.
    pub fn load_persistent(
        &mut self,
        config: EncryptionPersistentConfig,
    ) -> Result<(), ConfigError> {
        if self.private_keys.is_empty() {
            return Err(ConfigError::PrivateKeyNotSet);
        }

        // First private key able to decapsulate the stored encapsulated key, if any.
        let recovered_secret = self.private_keys.iter().find_map(|private_key| {
            private_key
                .decapsulate(&config.hybrid_multi_recipient_encapsulate_key)
                .ok()
        });
        if let Some(shared_secret) = recovered_secret {
            let derived = key_schedule_base_hybrid_kem(&shared_secret.0, HPKE_INFO_LAYER)
                .map_err(|_| ConfigError::KeyWrappingComputationError)?;
            self.encrypt_parameters = Some(derived);
        }

        match &self.encrypt_parameters {
            Some((key, nonce)) => check_key_commitment(key, nonce, &config.key_commitment),
            None => Err(ConfigError::PrivateKeyNotFound),
        }
    }
}
/// Encryption layer writer: thin wrapper around `InternalEncryptionLayerWriter`
/// whose constructor also emits the layer header.
pub(crate) struct EncryptionLayerWriter<'a, W: 'a + InnerWriterTrait>(
InternalEncryptionLayerWriter<'a, W>,
);
impl<'a, W: 'a + InnerWriterTrait> EncryptionLayerWriter<'a, W> {
pub fn new(
mut inner: InnerWriterType<'a, W>,
encryption_config: &EncryptionConfig,
) -> Result<Self, Error> {
let (persistent_config, internal_config) =
EncryptionConfig::to_persistent(encryption_config)?;
inner.write_all(ENCRYPTION_LAYER_MAGIC)?;
let _ = Opts.dump(&mut inner)?;
let encryption_method_id = 0u16;
encryption_method_id.serialize(&mut inner)?;
persistent_config.serialize(&mut inner)?;
inner.write_all(CHUNK_MAGIC)?;
1u64.serialize(&mut inner)?;
Ok(Self(InternalEncryptionLayerWriter::new(
inner,
&internal_config,
)?))
}
}
impl<'a, W: 'a + InnerWriterTrait> Write for EncryptionLayerWriter<'a, W> {
    /// Forwards to the wrapped internal writer.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let Self(writer) = self;
        writer.write(buf)
    }

    /// Forwards to the wrapped internal writer.
    fn flush(&mut self) -> io::Result<()> {
        let Self(writer) = self;
        writer.flush()
    }
}
impl<'a, W: 'a + InnerWriterTrait> LayerWriter<'a, W> for EncryptionLayerWriter<'a, W> {
    /// Finalizes the wrapped internal writer and returns the underlying writer.
    fn finalize(self: Box<Self>) -> Result<W, Error> {
        let Self(writer) = *self;
        Box::new(writer).finalize()
    }
}
/// Writer doing the actual chunked encryption.
struct InternalEncryptionLayerWriter<'a, W: 'a + InnerWriterTrait> {
// Destination of headers, ciphertext and tags.
inner: InnerWriterType<'a, W>,
// Cipher of the chunk currently being written.
cipher: AesGcm256,
// Symmetric key used for every chunk.
key: Key,
// Base nonce from which each chunk nonce is derived.
base_nonce: Nonce,
// Number of plaintext bytes already written in the current chunk.
current_chunk_offset: u64,
// Nonce counter of the current chunk.
current_ctr: u64,
}
impl<'a, W: 'a + InnerWriterTrait> InternalEncryptionLayerWriter<'a, W> {
pub fn new(
inner: InnerWriterType<'a, W>,
internal_config: &InternalEncryptionConfig,
) -> Result<Self, Error> {
Ok(Self {
inner,
key: internal_config.key,
base_nonce: internal_config.nonce,
cipher: AesGcm256::new(
&internal_config.key,
&compute_nonce(&internal_config.nonce, FIRST_DATA_CHUNK_NUMBER),
ASSOCIATED_DATA,
)?,
current_chunk_offset: 0,
current_ctr: FIRST_DATA_CHUNK_NUMBER,
})
}
fn renew_cipher_aad(&mut self, aad: &[u8]) -> Result<Tag, Error> {
self.current_ctr = self.current_ctr.checked_add(1).ok_or(Error::HPKEError)?;
self.current_chunk_offset = 0;
let cipher = AesGcm256::new(
&self.key,
&compute_nonce(&self.base_nonce, self.current_ctr),
aad,
)?;
let old_cipher = std::mem::replace(&mut self.cipher, cipher);
Ok(old_cipher.into_tag())
}
fn renew_cipher(&mut self) -> Result<Tag, Error> {
self.renew_cipher_aad(ASSOCIATED_DATA)
}
fn last_renew_cipher(&mut self) -> Result<Tag, Error> {
self.renew_cipher_aad(FINAL_ASSOCIATED_DATA)
}
}
impl<'a, W: 'a + InnerWriterTrait> LayerWriter<'a, W> for InternalEncryptionLayerWriter<'a, W> {
/// Closes the last data chunk, writes the final block and the end-of-layer
/// trailer, then finalizes the inner writer.
///
/// The statement order below defines the on-disk layout and must be preserved.
fn finalize(mut self: Box<Self>) -> Result<W, Error> {
// Tag of the last data chunk; the new cipher uses FINAL_ASSOCIATED_DATA.
let last_content_tag = self.last_renew_cipher()?;
self.inner.write_all(&last_content_tag)?;
// The final block magic is written in clear, directly to the inner writer.
self.inner.write_all(FINAL_BLOCK_MAGIC)?;
// Goes through this layer's `Write` impl, so the content is encrypted.
self.write_all(FINAL_BLOCK_CONTENT)?;
// Only the tag of the final block is needed; the renewed cipher is not used afterwards.
let final_tag = self.renew_cipher()?;
self.inner.write_all(&final_tag)?;
self.inner.write_all(END_OF_ENCRYPTED_INNER_LAYER_MAGIC)?;
self.inner.write_all(EMPTY_TAIL_OPTS_SERIALIZATION)?;
self.inner.finalize()
}
}
impl<W: InnerWriterTrait> Write for InternalEncryptionLayerWriter<'_, W> {
    /// Encrypts and writes at most `CIPHER_BUF_SIZE` bytes of `buf`, never
    /// crossing a chunk boundary. Returns the number of plaintext bytes consumed.
    ///
    /// When the current chunk is full, its tag is written first, followed by
    /// the header of the next chunk.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self.current_chunk_offset.cmp(&NORMAL_CHUNK_PT_SIZE) {
            std::cmp::Ordering::Greater => {
                return Err(
                    Error::WrongWriterState("[EncryptWriter] Chunk too big".to_string()).into(),
                );
            }
            std::cmp::Ordering::Equal => {
                // Chunk complete: emit its tag, then the next chunk header.
                let tag = self.renew_cipher()?;
                self.inner.write_all(&tag)?;
                self.inner.write_all(CHUNK_MAGIC)?;
                self.current_ctr.serialize(&mut self.inner)?;
            }
            std::cmp::Ordering::Less => {}
        }

        let remaining_in_chunk = NORMAL_CHUNK_PT_SIZE - self.current_chunk_offset;
        let size = CIPHER_BUF_SIZE
            .min(buf.len() as u64)
            .min(remaining_in_chunk);

        // Copy the plaintext so it can be encrypted in place.
        let mut encrypted = Vec::with_capacity(size as usize);
        io::copy(&mut BufReader::new(buf).take(size), &mut encrypted)?;
        self.cipher.encrypt(&mut encrypted);
        self.inner.write_all(&encrypted)?;

        self.current_chunk_offset += size;
        Ok(size as usize)
    }

    /// Flushes the inner writer.
    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}
pub(crate) fn read_encryption_header_after_magic<R: Read>(
src: &mut R,
) -> Result<(EncryptionPersistentConfig, u64), Error> {
let mut src = PositionLayerReader::new(src);
let _ = Opts::from_reader(&mut src)?; let _encryption_method_id = u16::deserialize(&mut src)?;
let read_encryption_metadata = EncryptionPersistentConfig::deserialize(&mut src)?;
let encryption_header_length = src
.position()
.checked_add(8)
.ok_or(Error::DeserializationError)?;
Ok((read_encryption_metadata, encryption_header_length))
}
/// Seekable encryption layer reader: thin wrapper around `InternalEncryptionLayerReader`.
pub(crate) struct EncryptionLayerReader<'a, R: InnerReaderTrait>(
InternalEncryptionLayerReader<Box<dyn 'a + LayerReader<'a, R>>>,
);
impl<'a, R: 'a + InnerReaderTrait> EncryptionLayerReader<'a, R> {
pub(crate) fn new_skip_magic(
mut inner: Box<dyn 'a + LayerReader<'a, R>>,
mut reader_config: EncryptionReaderConfig,
persistent_config: Option<EncryptionPersistentConfig>,
) -> Result<Self, Error> {
let (read_encryption_metadata, encryption_header_length) =
read_encryption_header_after_magic(&mut inner)?;
let persistent_config = persistent_config.unwrap_or(read_encryption_metadata);
let raw_encryption_layer_length = inner.seek(SeekFrom::End(0))?;
inner.seek(SeekFrom::Current(-8))?;
let encryption_footer_options_length = u64::deserialize(&mut inner)?;
let encryption_footer_length = encryption_footer_options_length
.checked_add(8)
.ok_or(Error::DeserializationError)?;
inner.seek(SeekFrom::Start(encryption_header_length))?;
reader_config.load_persistent(persistent_config)?;
let inner: Box<dyn 'a + LayerReader<'a, R>> = Box::new(StripHeadTailReader::new(
inner,
encryption_header_length,
encryption_footer_length,
raw_encryption_layer_length,
0,
)?);
let inner = InternalEncryptionLayerReader::new(inner, reader_config, None)?;
Ok(Self(inner))
}
}
impl<'a, R: 'a + InnerReaderTrait> Read for EncryptionLayerReader<'a, R> {
    /// Forwards to the wrapped internal reader.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let Self(reader) = self;
        reader.read(buf)
    }
}
impl<'a, R: 'a + InnerReaderTrait> Seek for EncryptionLayerReader<'a, R> {
    /// Forwards to the wrapped internal reader.
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        let Self(reader) = self;
        reader.seek(pos)
    }
}
impl<'a, R: 'a + InnerReaderTrait> LayerReader<'a, R> for EncryptionLayerReader<'a, R> {
    /// Unwraps every layer and returns the raw reader.
    fn into_raw(self: Box<Self>) -> R {
        let Self(reader) = *self;
        reader.inner.into_raw()
    }

    /// Initializes the inner layer first, then this layer.
    fn initialize(&mut self) -> Result<(), Error> {
        let Self(reader) = self;
        reader.inner.initialize()?;
        reader.initialize()
    }
}
/// Reader doing the actual chunked decryption; shared by the seekable and the fail-safe readers.
struct InternalEncryptionLayerReader<R> {
// Source of chunk headers, ciphertext and tags.
inner: R,
// Cipher of the chunk currently loaded in `chunk_cache`.
cipher: AesGcm256,
// Symmetric key used for every chunk.
key: Key,
// Base nonce from which each chunk nonce is derived.
nonce: Nonce,
// Current chunk: header followed by decrypted data (see `load_in_cache`).
chunk_cache: Cursor<Vec<u8>>,
// Raw bytes read ahead from `inner`, at most one full stored chunk.
next_chunk_cache: Vec<u8>,
// Total plaintext size, computed by `set_all_plaintext_size`.
all_plaintext_size: u64,
// Position, in plaintext bytes, within this layer.
current_position_in_this_layer: u64,
// Only consumed by `_check_final`; never set to `Some` in the visible code.
_final_block: Option<Vec<u8>>,
// `None` for the seekable reader, `Some(mode)` for the fail-safe reader.
truncated_decryption_mode: Option<TruncatedReaderDecryptionMode>,
}
impl<R: Read> InternalEncryptionLayerReader<R> {
fn new(
mut inner: R,
config: EncryptionReaderConfig,
truncated_decryption_mode: Option<TruncatedReaderDecryptionMode>,
) -> Result<Self, Error> {
let mut next_chunk_cache = Vec::with_capacity(M0_CHUNK_USIZE);
(&mut inner)
.take(M0_CHUNK_SIZE)
.read_to_end(&mut next_chunk_cache)?;
match config.encrypt_parameters {
Some((key, nonce)) => Ok(Self {
inner,
cipher: AesGcm256::new(
&key,
&compute_nonce(&nonce, FIRST_DATA_CHUNK_NUMBER),
ASSOCIATED_DATA,
)?,
key,
nonce,
chunk_cache: Cursor::new(Vec::with_capacity(M0_CHUNK_USIZE)),
next_chunk_cache,
all_plaintext_size: 0,
current_position_in_this_layer: 0,
_final_block: None,
truncated_decryption_mode,
}),
None => Err(Error::PrivateKeyNeeded),
}
}
/// Tells whether the current position lies in, or after, the last data chunk.
fn is_at_least_in_last_data_chunk(&self) -> bool {
    let last_chunk_start = self.last_data_chunk_number() * NORMAL_CHUNK_PT_SIZE;
    last_chunk_start <= self.current_position_in_this_layer
}
/// Zero-based index of the data chunk containing the current position.
fn current_data_chunk_number(&self) -> u64 {
    let position = self.current_position_in_this_layer;
    position / NORMAL_CHUNK_PT_SIZE
}
/// Zero-based index of the last data chunk (`0` when there is no data).
fn last_data_chunk_number(&self) -> u64 {
    // ceil(size / chunk) is the chunk count; the last index is one less.
    self.all_plaintext_size
        .div_ceil(NORMAL_CHUNK_PT_SIZE)
        .saturating_sub(1)
}
/// Verifies the final block stored in `self._final_block` (content and tag).
///
/// Unused in the visible code (underscore-prefixed). Returns
/// `Error::AssertionError` when no final block has been stored.
fn _check_final(&mut self) -> Result<(), Error> {
if let Some(mut final_block) = self._final_block.take() {
// The final block uses the counter following the current data chunk.
let mut cipher = AesGcm256::new(
&self.key,
&compute_nonce(
&self.nonce,
self.current_data_chunk_number() + 1 + FIRST_DATA_CHUNK_NUMBER,
),
FINAL_ASSOCIATED_DATA,
)?;
// First FINAL_BLOCK_CONTENT.len() bytes: encrypted content, decrypted in place.
let data_part = final_block
.get_mut(..FINAL_BLOCK_CONTENT.len())
.ok_or_else(|| {
Error::WrongReaderState("Invalid final block data part".to_owned())
})?;
let computed_tag = cipher.decrypt(data_part);
if data_part != FINAL_BLOCK_CONTENT {
return Err(Error::InvalidLastTag);
}
// Remaining bytes: the stored tag.
let tag_part = &final_block
.get(FINAL_BLOCK_CONTENT.len()..)
.ok_or_else(|| {
Error::WrongReaderState("Invalid final block data part".to_owned())
})?;
if computed_tag.ct_eq(tag_part).unwrap_u8() != 1 {
Err(Error::InvalidLastTag)
} else {
Ok(())
}
} else {
Err(Error::AssertionError(
"check_final should have Some(final_block)".to_owned(),
))
}
}
/// Searches for `FINAL_BLOCK_MAGIC` in the current and look-ahead caches.
///
/// Returns the offset of the first byte of the magic, counted from the start
/// of `chunk_cache` (offsets past its end fall in `next_chunk_cache`).
/// Search order: inside `chunk_cache`, inside `next_chunk_cache`, then
/// straddling the boundary between the two.
fn final_block_magic_scan(&self) -> Option<usize> {
if let Some(pos) = self
.chunk_cache
.get_ref()
.windows(FINAL_BLOCK_MAGIC.len())
.position(|candidate| candidate == FINAL_BLOCK_MAGIC)
{
Some(pos)
} else {
if let Some(pos) = self
.next_chunk_cache
.windows(FINAL_BLOCK_MAGIC.len())
.position(|candidate| candidate == FINAL_BLOCK_MAGIC)
{
Some(self.chunk_cache.get_ref().len() + pos)
} else {
// Try every split of the magic: the first `number_in_chunk_cache` bytes at the
// end of `chunk_cache`, the rest at the start of `next_chunk_cache`.
(1..FINAL_BLOCK_MAGIC.len())
.find(|number_in_chunk_cache| {
let chunk_cache = self.chunk_cache.get_ref();
let number_in_next_chunk_cache =
FINAL_BLOCK_MAGIC.len() - *number_in_chunk_cache;
if chunk_cache.len() < *number_in_chunk_cache
|| self.next_chunk_cache.len() < number_in_next_chunk_cache
{
false
} else {
let magic_part1 = &FINAL_BLOCK_MAGIC[..*number_in_chunk_cache];
let magic_part2 = &FINAL_BLOCK_MAGIC[*number_in_chunk_cache..];
let end_of_chunk_cache =
&chunk_cache[(chunk_cache.len() - *number_in_chunk_cache)..];
let start_of_next_chunk_cache =
&self.next_chunk_cache[..number_in_next_chunk_cache];
end_of_chunk_cache == magic_part1
&& start_of_next_chunk_cache == magic_part2
}
})
.map(|number_in_chunk_cache| {
self.chunk_cache.get_ref().len() - number_in_chunk_cache
})
}
}
}
/// Placeholder: always returns `None` and ignores `_max_pos`.
// NOTE(review): the name suggests a brute-force search for a valid tag position — not implemented here.
fn try_tag_pos_at_from_0_to(&self, _max_pos: usize) -> Option<usize> {
None
}
/// Determines where the tag of the chunk held in `chunk_cache` ends.
///
/// * Full chunk followed by more than a final block worth of data: the tag
///   ends at `M0_CHUNK_USIZE`.
/// * Fewer than `TAG_LENGTH` bytes cached: `Error::TruncatedTag`.
/// * Otherwise the tag ends where the final block magic starts; if no magic
///   is found, `DataEvenUnauthenticated` mode treats the whole cache as
///   ciphertext, any other mode fails with `Error::UnknownTagPosition`.
fn get_end_of_tag_pos(&self) -> Result<usize, Error> {
let chunk_cache_len = self.chunk_cache.get_ref().len();
let next_chunk_cache_len = self.next_chunk_cache.len();
let end_of_tag_pos =
if chunk_cache_len == M0_CHUNK_USIZE && next_chunk_cache_len > FINAL_INFO_USIZE {
M0_CHUNK_USIZE
} else if chunk_cache_len < TAG_LENGTH {
return Err(Error::TruncatedTag);
} else {
match self.final_block_magic_scan() {
None => {
if let Some(pos) = self
.try_tag_pos_at_from_0_to(chunk_cache_len.saturating_sub(TAG_LENGTH))
{
pos
} else if matches!(
self.truncated_decryption_mode,
Some(TruncatedReaderDecryptionMode::DataEvenUnauthenticated)
) {
// The caller subtracts TAG_LENGTH, so the ciphertext ends at `chunk_cache_len`.
chunk_cache_len + TAG_LENGTH } else {
return Err(Error::UnknownTagPosition);
}
}
Some(pos) => pos,
}
};
Ok(end_of_tag_pos)
}
/// Loads and decrypts the next chunk into `chunk_cache`.
///
/// The look-ahead cache becomes the current chunk and a new look-ahead is
/// read from `inner`. The chunk is decrypted in place (authenticated unless
/// the mode is `DataEvenUnauthenticated`), truncated to its plaintext end,
/// and the cursor is placed right after the chunk header.
///
/// Returns `Ok(None)` when there is no more data. In that case the cache is
/// emptied, so that the raw (still encrypted) bytes that were just swapped in
/// can never be handed out by a later read.
///
/// # Errors
/// * I/O and cipher creation errors are propagated;
/// * errors of `get_end_of_tag_pos` are propagated;
/// * `Error::WrongReaderState` if the computed ranges do not fit the cache;
/// * `Error::AuthenticatedDecryptionWrongTag` if the tag does not verify.
fn load_in_cache(&mut self) -> Result<Option<()>, Error> {
    // Promote the look-ahead to current chunk, then refill the look-ahead.
    std::mem::swap(self.chunk_cache.get_mut(), &mut self.next_chunk_cache);
    self.chunk_cache.set_position(0);
    self.next_chunk_cache.clear();
    (&mut self.inner)
        .take(M0_CHUNK_SIZE)
        .read_to_end(&mut self.next_chunk_cache)?;

    let end_of_tag_pos = if self.truncated_decryption_mode.is_none()
        && self.is_at_least_in_last_data_chunk()
    {
        // Seekable reader in the last chunk: the total size is known.
        let content_len =
            self.all_plaintext_size
                .saturating_sub(self.current_position_in_this_layer) as usize;
        if content_len == 0 {
            // Nothing left to decrypt: do not leave raw bytes readable.
            self.chunk_cache.get_mut().clear();
            return Ok(None);
        }
        content_len + TAG_LENGTH + CHUNK_HEAD_USIZE
    } else {
        self.get_end_of_tag_pos()?
    };
    let end_of_ciphertext_pos = end_of_tag_pos.saturating_sub(TAG_LENGTH);
    if end_of_ciphertext_pos == 0 {
        // Nothing left to decrypt: do not leave raw bytes readable.
        self.chunk_cache.get_mut().clear();
        return Ok(None);
    }

    self.cipher = AesGcm256::new(
        &self.key,
        &compute_nonce(
            &self.nonce,
            self.current_data_chunk_number() + FIRST_DATA_CHUNK_NUMBER,
        ),
        ASSOCIATED_DATA,
    )?;

    let chunk_cache_slice = self.chunk_cache.get_mut().as_mut_slice();
    let data_part = chunk_cache_slice
        .get_mut(CHUNK_HEAD_USIZE..end_of_ciphertext_pos)
        .ok_or_else(|| Error::WrongReaderState("Invalid chunk cache data part".to_owned()))?;
    if matches!(
        self.truncated_decryption_mode,
        Some(TruncatedReaderDecryptionMode::DataEvenUnauthenticated)
    ) {
        self.cipher.decrypt_unauthenticated(data_part);
    } else {
        let computed_tag = self.cipher.decrypt(data_part);
        let tag_part = chunk_cache_slice
            .get(end_of_ciphertext_pos..end_of_tag_pos)
            .ok_or_else(|| {
                Error::WrongReaderState("Invalid chunk cache tag part".to_owned())
            })?;
        if computed_tag.ct_eq(tag_part).unwrap_u8() != 1 {
            return Err(Error::AuthenticatedDecryptionWrongTag);
        }
    }

    // Keep header + plaintext only, and start reading right after the header.
    self.chunk_cache.get_mut().truncate(end_of_ciphertext_pos);
    self.chunk_cache.set_position(CHUNK_HEAD_SIZE);
    Ok(Some(()))
}
fn read_internal(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
let cache_to_consume = NORMAL_CHUNK_PT_AND_TAG_SIZE - self.chunk_cache.position();
if cache_to_consume == 0 {
if self.load_in_cache()?.is_none() {
return Ok(0);
}
return self.read_internal(buf);
}
let size = std::cmp::min(cache_to_consume as usize, buf.len());
let chunk_cache_read_size = self.chunk_cache.read(&mut buf[..size])?;
self.current_position_in_this_layer += chunk_cache_read_size as u64;
if chunk_cache_read_size == 0 {
}
Ok(chunk_cache_read_size)
}
}
impl<R: Read + Seek> InternalEncryptionLayerReader<R> {
/// Authenticates the final block located at the end of `inner`.
///
/// Seeks to the encrypted final content (final block without its magic),
/// decrypts it with the counter following the last data chunk and
/// `FINAL_ASSOCIATED_DATA`, then checks both the tag and the content.
/// Relies on `all_plaintext_size` being set (via `last_data_chunk_number`).
///
/// Leaves `inner` positioned at its end and replaces `self.cipher`.
fn check_last_block(&mut self) -> Result<(), Error> {
self.inner.seek(SeekFrom::End(
-((FINAL_INFO_SIZE_WOM + END_OF_ENCRYPTED_INNER_LAYER_MAGIC.len()) as i64),
))?;
self.cipher = AesGcm256::new(
&self.key,
&compute_nonce(
&self.nonce,
self.last_data_chunk_number() + 1 + FIRST_DATA_CHUNK_NUMBER,
),
FINAL_ASSOCIATED_DATA,
)?;
// Reads the encrypted content, its tag and the trailing end-of-layer magic.
let mut data_and_tag = Vec::with_capacity(FINAL_INFO_SIZE_WOM);
let data_and_tag_read = self.inner.read_to_end(&mut data_and_tag)?;
if data_and_tag_read < FINAL_INFO_SIZE_WOM {
return Err(Error::InvalidLastTag);
}
let mut tag = [0u8; TAG_LENGTH];
let tag_part = data_and_tag
.get(FINAL_BLOCK_CONTENT.len()..FINAL_INFO_SIZE_WOM)
.ok_or_else(|| {
Error::WrongReaderState(
"Invalid final block data tag part in check last block".to_owned(),
)
})?;
tag.copy_from_slice(tag_part);
// Keep only the encrypted content, decrypted in place below.
data_and_tag.truncate(FINAL_BLOCK_CONTENT.len());
let mut data = data_and_tag;
let expected_tag = self.cipher.decrypt(data.as_mut_slice());
if expected_tag.ct_eq(&tag).unwrap_u8() != 1 || data != FINAL_BLOCK_CONTENT {
Err(Error::InvalidLastTag)
} else {
Ok(())
}
}
fn initialize(&mut self) -> Result<(), Error> {
self.set_all_plaintext_size()?;
self.check_last_block()?;
self.rewind()?;
Ok(())
}
fn set_all_plaintext_size(&mut self) -> Result<(), Error> {
let input_size = self.inner.seek(SeekFrom::End(0))?;
let input_size_without_final_nor_header = input_size.saturating_sub(
END_OF_ENCRYPTED_INNER_LAYER_MAGIC.len() as u64 + FINAL_INFO_SIZE + CHUNK_HEAD_SIZE,
);
let chunk_number_at_end_of_data = input_size_without_final_nor_header / M0_CHUNK_SIZE;
let last_chunk_size = input_size_without_final_nor_header % M0_CHUNK_SIZE;
self.all_plaintext_size = if last_chunk_size == 0 {
chunk_number_at_end_of_data * NORMAL_CHUNK_PT_SIZE
} else {
chunk_number_at_end_of_data * NORMAL_CHUNK_PT_SIZE + last_chunk_size - TAG_LENGTH as u64
};
Ok(())
}
}
impl<R: InnerReaderTrait> Read for InternalEncryptionLayerReader<R> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.read_internal(buf).map_err(mla_error_to_io_error)
}
}
/// Converts a plaintext position of this layer into the matching position in
/// the inner layer, accounting for per-chunk headers and tags.
fn this_layer_position_to_inner_position(position: u64) -> u64 {
    let chunk_index = position / NORMAL_CHUNK_PT_SIZE;
    let offset_in_chunk = position % NORMAL_CHUNK_PT_SIZE;
    let chunk_start = chunk_index * M0_CHUNK_SIZE;
    chunk_start + offset_in_chunk + CHUNK_HEAD_SIZE
}
/// Converts an inner-layer position pointing inside chunk data into the
/// matching plaintext position of this layer.
///
/// Unused in the visible code. The position must not point inside a chunk
/// header: the subtraction below would underflow.
fn _inner_position_in_plaintext_to_this_layer_position(position: u64) -> u64 {
    let chunk_index = position / M0_CHUNK_SIZE;
    let offset_in_stored_chunk = position % M0_CHUNK_SIZE;
    let offset_in_plaintext = offset_in_stored_chunk - CHUNK_HEAD_SIZE;
    chunk_index * NORMAL_CHUNK_PT_SIZE + offset_in_plaintext
}
impl<R: Read + Seek> Seek for InternalEncryptionLayerReader<R> {
    /// Seeks to a plaintext position of this layer.
    ///
    /// The chunk containing the target is loaded and decrypted, then the
    /// cursor is placed at the requested offset inside it.
    ///
    /// # Errors
    /// * `Error::EndOfStream` (as `io::Error`) for `SeekFrom::End` with a positive offset;
    /// * `io::ErrorKind::InvalidInput` when a relative seek would land on a
    ///   negative or overflowing position;
    /// * I/O and decryption errors are propagated.
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        // Error for relative seeks whose target cannot be represented.
        fn invalid_seek() -> io::Error {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid seek to a negative or overflowing position",
            )
        }

        match pos {
            SeekFrom::Start(asked_pos) => {
                let inner_position_of_asked_pos = this_layer_position_to_inner_position(asked_pos);
                let asked_pos_chunk_number = inner_position_of_asked_pos / M0_CHUNK_SIZE;
                let inner_position_of_m0_chunk_start = asked_pos_chunk_number * M0_CHUNK_SIZE;
                let asked_pos_in_chunk_plaintext =
                    (inner_position_of_asked_pos % M0_CHUNK_SIZE) - CHUNK_HEAD_SIZE;

                self.inner
                    .seek(SeekFrom::Start(inner_position_of_m0_chunk_start))?;
                // `load_in_cache` relies on the position being at the chunk start.
                self.current_position_in_this_layer = asked_pos_chunk_number * NORMAL_CHUNK_PT_SIZE;

                // Refill the look-ahead with the target chunk, then promote and decrypt it.
                self.next_chunk_cache.clear();
                (&mut self.inner)
                    .take(M0_CHUNK_SIZE)
                    .read_to_end(&mut self.next_chunk_cache)?;
                self.load_in_cache()?;

                self.chunk_cache
                    .set_position(CHUNK_HEAD_SIZE + asked_pos_in_chunk_plaintext);
                self.current_position_in_this_layer += asked_pos_in_chunk_plaintext;
                Ok(asked_pos)
            }
            SeekFrom::Current(value) => {
                if value == 0 {
                    // Early out: no need to reload anything.
                    Ok(self.current_position_in_this_layer)
                } else {
                    let target = self
                        .current_position_in_this_layer
                        .checked_add_signed(value)
                        .ok_or_else(invalid_seek)?;
                    self.seek(SeekFrom::Start(target))
                }
            }
            SeekFrom::End(pos) => {
                if pos > 0 {
                    // Seeking past the end is not supported.
                    return Err(Error::EndOfStream.into());
                }
                let target = self
                    .all_plaintext_size
                    .checked_add_signed(pos)
                    .ok_or_else(invalid_seek)?;
                self.seek(SeekFrom::Start(target))
            }
        }
    }
}
/// Non-seekable encryption layer reader built with a `TruncatedReaderDecryptionMode`.
pub(crate) struct EncryptionLayerFailSafeReader<'a, R: Read> {
// Decrypting reader, always created with `Some(truncated_decryption_mode)`.
inner: InternalEncryptionLayerReader<Box<dyn 'a + LayerFailSafeReader<'a, R>>>,
}
impl<'a, R: 'a + Read> EncryptionLayerFailSafeReader<'a, R> {
    /// Builds a reader from `inner`, positioned at the first chunk header, and
    /// loads the first chunk.
    ///
    /// # Errors
    /// Propagates reader creation and first-chunk loading errors.
    fn new_skip_header(
        inner: Box<dyn 'a + LayerFailSafeReader<'a, R>>,
        config: EncryptionReaderConfig,
        truncated_decryption_mode: TruncatedReaderDecryptionMode,
    ) -> Result<Self, Error> {
        let mut reader =
            InternalEncryptionLayerReader::new(inner, config, Some(truncated_decryption_mode))?;
        reader.load_in_cache()?;
        Ok(Self { inner: reader })
    }

    /// Builds a reader from `inner`, positioned right after the layer magic.
    ///
    /// The header is parsed and the key material recovered; when
    /// `persistent_config` is `Some`, it is used instead of the header's one.
    ///
    /// # Errors
    /// Propagates header parsing, key-loading and reader creation errors.
    pub(crate) fn new_skip_magic(
        mut inner: Box<dyn 'a + LayerFailSafeReader<'a, R>>,
        mut reader_config: EncryptionReaderConfig,
        persistent_config: Option<EncryptionPersistentConfig>,
        truncated_decryption_mode: TruncatedReaderDecryptionMode,
    ) -> Result<Self, Error> {
        let (stored_config, _header_length) = read_encryption_header_after_magic(&mut inner)?;
        let effective_config = persistent_config.unwrap_or(stored_config);
        reader_config.load_persistent(effective_config)?;
        Self::new_skip_header(inner, reader_config, truncated_decryption_mode)
    }
}
// Marks the reader as a fail-safe layer; no trait item is defined here.
impl<'a, R: 'a + Read> LayerFailSafeReader<'a, R> for EncryptionLayerFailSafeReader<'a, R> {}
impl<R: Read> Read for EncryptionLayerFailSafeReader<'_, R> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.read_internal(buf).map_err(mla_error_to_io_error)
}
}
fn mla_error_to_io_error(err: Error) -> io::Error {
if let Error::IOError(e) = err {
e
} else {
io::Error::other(err.to_string())
}
}
#[cfg(test)]
mod tests {
use super::*;
use rand::SeedableRng;
use rand::distributions::{Alphanumeric, Distribution};
use std::io::{Cursor, Read, Seek, SeekFrom, Write};
use crate::crypto::aesgcm::{KEY_SIZE, NONCE_AES_SIZE};
use crate::layers::encrypt::{InternalEncryptionLayerReader, InternalEncryptionLayerWriter};
use crate::layers::raw::{RawLayerFailSafeReader, RawLayerReader, RawLayerWriter};
static FAKE_FILE: [u8; 26] = *b"abcdefghijklmnopqrstuvwxyz";
static KEY: Key = [2u8; KEY_SIZE];
static NONCE: Nonce = [3u8; NONCE_AES_SIZE];
fn encrypt_write(mut file: Vec<u8>) -> Vec<u8> {
file.write_all(CHUNK_MAGIC).unwrap();
file.write_all(&[1, 0, 0, 0, 0, 0, 0, 0]).unwrap();
let mut encrypt_w = Box::new(
InternalEncryptionLayerWriter::new(
Box::new(RawLayerWriter::new(file)),
&InternalEncryptionConfig {
key: KEY,
nonce: NONCE,
},
)
.unwrap(),
);
encrypt_w.write_all(&FAKE_FILE[..21]).unwrap();
encrypt_w.write_all(&FAKE_FILE[21..]).unwrap();
let mut out = encrypt_w.finalize().unwrap();
out.resize(out.len() - EMPTY_TAIL_OPTS_SERIALIZATION.len(), 0);
assert_eq!(
out.len(),
FAKE_FILE.len()
+ CHUNK_MAGIC.len()
+ 8
+ TAG_LENGTH
+ FINAL_INFO_USIZE
+ END_OF_ENCRYPTED_INNER_LAYER_MAGIC.len()
);
assert_ne!(out[..FAKE_FILE.len()], FAKE_FILE);
out
}
/// Round trip through the seekable reader: the decrypted output equals `FAKE_FILE`.
#[test]
fn encrypt_layer() {
    let encrypted = encrypt_write(Vec::new());
    let config = EncryptionReaderConfig {
        private_keys: Vec::new(),
        encrypt_parameters: Some((KEY, NONCE)),
    };
    let source = Box::new(RawLayerReader::new(Cursor::new(encrypted.as_slice())));
    let mut encrypt_r = InternalEncryptionLayerReader::new(source, config, None).unwrap();
    encrypt_r.initialize().unwrap();

    let mut decrypted = Vec::new();
    encrypt_r.read_to_end(&mut decrypted).unwrap();
    assert_eq!(decrypted, FAKE_FILE);
}
/// Round trip through the fail-safe reader in authenticated-only mode on a complete layer.
#[test]
fn encrypt_failsafe_layer() {
    let encrypted = encrypt_write(Vec::new());
    let config = EncryptionReaderConfig {
        private_keys: Vec::new(),
        encrypt_parameters: Some((KEY, NONCE)),
    };
    let source = Box::new(RawLayerFailSafeReader::new(encrypted.as_slice()));
    let mut encrypt_r = EncryptionLayerFailSafeReader::new_skip_header(
        source,
        config,
        TruncatedReaderDecryptionMode::OnlyAuthenticatedData,
    )
    .unwrap();

    let mut decrypted = Vec::new();
    encrypt_r.read_to_end(&mut decrypted).unwrap();
    assert_eq!(decrypted.len(), FAKE_FILE.len());
    assert_eq!(decrypted[..FAKE_FILE.len()], FAKE_FILE);
}
// A layer cut in the middle of its data can still be partially read in
// `DataEvenUnauthenticated` mode.
#[test]
fn encrypt_failsafe_truncated() {
let file = Vec::new();
let out = encrypt_write(file);
// Cut half-way through the chunk's ciphertext-and-tag region (after the chunk header).
let stop = CHUNK_HEAD_USIZE
+ (out.len()
- FINAL_INFO_USIZE
- END_OF_ENCRYPTED_INNER_LAYER_MAGIC.len()
- CHUNK_HEAD_USIZE)
/ 2;
let config = EncryptionReaderConfig {
private_keys: Vec::new(),
encrypt_parameters: Some((KEY, NONCE)),
};
let mut encrypt_r = EncryptionLayerFailSafeReader::new_skip_header(
Box::new(RawLayerFailSafeReader::new(&out[..stop])),
config,
TruncatedReaderDecryptionMode::DataEvenUnauthenticated,
)
.unwrap();
let mut output = Vec::new();
encrypt_r.read_to_end(&mut output).unwrap();
// Everything before the cut is recovered.
assert_eq!(output.as_slice(), &FAKE_FILE[..(stop - CHUNK_HEAD_USIZE)]);
}
// Compares both fail-safe modes on a layer made of two full chunks and a
// partial one whose tag (and everything after it) has been removed.
#[test]
fn failsafe_auth_vs_unauth() {
let mut file = Vec::new();
// Header of the first chunk, normally written by `EncryptionLayerWriter::new`.
file.write_all(CHUNK_MAGIC).unwrap();
file.write_all(&[1, 0, 0, 0, 0, 0, 0, 0]).unwrap();
let mut encrypt_w = Box::new(
InternalEncryptionLayerWriter::new(
Box::new(RawLayerWriter::new(file)),
&InternalEncryptionConfig {
key: KEY,
nonce: NONCE,
},
)
.unwrap(),
);
// Two full chunks plus 128 bytes.
let length = (NORMAL_CHUNK_PT_SIZE * 2 + 128) as usize;
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0);
let data: Vec<u8> = Alphanumeric.sample_iter(&mut rng).take(length).collect();
encrypt_w.write_all(&data).unwrap();
let mut out = encrypt_w.finalize().unwrap();
out.resize(out.len() - EMPTY_TAIL_OPTS_SERIALIZATION.len(), 0);
assert_eq!(
out.len(),
length
+ 3 * CHUNK_HEAD_USIZE
+ 3 * TAG_LENGTH
+ FINAL_INFO_USIZE
+ END_OF_ENCRYPTED_INNER_LAYER_MAGIC.len()
);
// Remove the last chunk's tag, the final block and the end magic.
let trunc = &out[..out.len()
- TAG_LENGTH
- FINAL_INFO_USIZE
- END_OF_ENCRYPTED_INNER_LAYER_MAGIC.len()];
let config = EncryptionReaderConfig {
private_keys: Vec::new(),
encrypt_parameters: Some((KEY, NONCE)),
};
// Authenticated-only mode: the two complete chunks are returned, then an error.
let mut encrypt_r = EncryptionLayerFailSafeReader::new_skip_header(
Box::new(RawLayerFailSafeReader::new(trunc)),
config,
TruncatedReaderDecryptionMode::OnlyAuthenticatedData,
)
.unwrap();
let mut output = Vec::new();
assert!(encrypt_r.read_to_end(&mut output).is_err());
assert_eq!(output.len(), 2 * NORMAL_CHUNK_PT_SIZE as usize);
assert_eq!(output, data[..output.len()]);
let config = EncryptionReaderConfig {
private_keys: Vec::new(),
encrypt_parameters: Some((KEY, NONCE)),
};
// Unauthenticated mode: the partial chunk is returned as well.
let mut encrypt_r = EncryptionLayerFailSafeReader::new_skip_header(
Box::new(RawLayerFailSafeReader::new(trunc)),
config,
TruncatedReaderDecryptionMode::DataEvenUnauthenticated,
)
.unwrap();
let mut output = Vec::new();
encrypt_r.read_to_end(&mut output).unwrap();
assert_eq!(output.len(), length);
assert_eq!(output, data);
}
// Reads the whole content, then seeks back inside the single chunk and reads again.
#[test]
fn seek_encrypt() {
let file = Vec::new();
let out = encrypt_write(file);
let buf = Cursor::new(out.as_slice());
let config = EncryptionReaderConfig {
private_keys: Vec::new(),
encrypt_parameters: Some((KEY, NONCE)),
};
let mut encrypt_r =
InternalEncryptionLayerReader::new(Box::new(RawLayerReader::new(buf)), config, None)
.unwrap();
encrypt_r.initialize().unwrap();
let mut output = Vec::new();
encrypt_r.read_to_end(&mut output).unwrap();
assert_eq!(output, FAKE_FILE);
// After a full read, the reported position is the plaintext length.
let pos = encrypt_r.stream_position().unwrap();
assert_eq!(pos, FAKE_FILE.len() as u64);
let pos = encrypt_r.seek(SeekFrom::Start(5)).unwrap();
assert_eq!(pos, 5);
let mut output = Vec::new();
encrypt_r.read_to_end(&mut output).unwrap();
println!("{output:?}");
assert_eq!(output.as_slice(), &FAKE_FILE[5..]);
}
// Data whose size is an exact multiple of the chunk size: full read, then
// seek to the second chunk boundary and read the rest.
#[test]
fn encrypt_op_chunk_size() {
let mut file = Vec::new();
// Header of the first chunk, normally written by `EncryptionLayerWriter::new`.
file.write_all(CHUNK_MAGIC).unwrap();
file.write_all(&[1, 0, 0, 0, 0, 0, 0, 0]).unwrap();
let mut encrypt_w = Box::new(
InternalEncryptionLayerWriter::new(
Box::new(RawLayerWriter::new(file)),
&InternalEncryptionConfig {
key: KEY,
nonce: NONCE,
},
)
.unwrap(),
);
// Exactly two full chunks.
let length = (NORMAL_CHUNK_PT_SIZE * 2) as usize;
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(0);
let data: Vec<u8> = Alphanumeric.sample_iter(&mut rng).take(length).collect();
encrypt_w.write_all(&data).unwrap();
let mut out = encrypt_w.finalize().unwrap();
out.resize(out.len() - EMPTY_TAIL_OPTS_SERIALIZATION.len(), 0);
assert_eq!(
out.len(),
length
+ 2 * CHUNK_HEAD_USIZE
+ 2 * TAG_LENGTH
+ FINAL_INFO_USIZE
+ END_OF_ENCRYPTED_INNER_LAYER_MAGIC.len()
);
assert_ne!(&out[..length], data.as_slice());
let buf = Cursor::new(out.as_slice());
let config = EncryptionReaderConfig {
private_keys: Vec::new(),
encrypt_parameters: Some((KEY, NONCE)),
};
let mut encrypt_r =
InternalEncryptionLayerReader::new(Box::new(RawLayerReader::new(buf)), config, None)
.unwrap();
encrypt_r.initialize().unwrap();
let mut output = Vec::new();
encrypt_r.read_to_end(&mut output).unwrap();
assert_eq!(output, data);
// Seek exactly to the start of the second chunk.
let pos = encrypt_r
.seek(SeekFrom::Start(NORMAL_CHUNK_PT_SIZE))
.unwrap();
assert_eq!(pos, NORMAL_CHUNK_PT_SIZE);
let mut output = Vec::new();
encrypt_r.read_to_end(&mut output).unwrap();
assert_eq!(output.as_slice(), &data[NORMAL_CHUNK_PT_SIZE as usize..]);
}
/// The commitment decrypts back to `KEY_COMMITMENT_CHAIN` and its tag verifies.
#[test]
fn build_key_commitment_chain_test() {
    let key: Key = [1u8; KEY_SIZE];
    let nonce: Nonce = [2u8; NONCE_AES_SIZE];

    let built = build_key_commitment_chain(&key, &nonce);
    assert!(built.is_ok());
    let commitment = built.unwrap();

    // Decrypt a copy with the same key and the nonce derived from counter 0.
    let mut cipher = AesGcm256::new(&key, &compute_nonce(&nonce, 0), ASSOCIATED_DATA).unwrap();
    let mut decrypted = [0u8; KEY_COMMITMENT_SIZE];
    decrypted.copy_from_slice(&commitment.key_commitment);
    let computed_tag = cipher.decrypt(&mut decrypted);

    assert_eq!(computed_tag.ct_eq(&commitment.tag).unwrap_u8(), 1);
    assert_eq!(decrypted, *KEY_COMMITMENT_CHAIN);
}
/// A freshly built commitment verifies; an all-zero one is rejected.
#[test]
fn check_key_commitment_test() {
    let key: Key = [1u8; KEY_SIZE];
    let nonce: Nonce = [2u8; NONCE_AES_SIZE];

    let built = build_key_commitment_chain(&key, &nonce);
    assert!(built.is_ok());
    let valid_commitment = built.unwrap();
    assert!(check_key_commitment(&key, &nonce, &valid_commitment).is_ok());

    let zeroed_commitment = KeyCommitmentAndTag {
        key_commitment: [0u8; KEY_COMMITMENT_SIZE],
        tag: [0u8; TAG_LENGTH],
    };
    assert!(check_key_commitment(&key, &nonce, &zeroed_commitment).is_err());
}
}