use std::fmt::Debug;
use bytes::Bytes;
use crate::bmt::SPAN_SIZE;
use crate::chunk::encryption::{EncryptedChunkRef, EncryptionKey, decrypt_chunk_data};
use crate::chunk::{BmtChunk, Chunk, ChunkAddress, ContentChunk};
use crate::store::{SyncChunkGet, SyncChunkPut};
use super::constants::{ENCRYPTED_REF_SIZE, REF_SIZE, compute_spans_inline, subspan_for_spans};
use super::error::{FileError, Result};
fn chunk_creation_error(e: crate::error::PrimitivesError) -> FileError {
match e {
crate::error::PrimitivesError::Chunk(c) => FileError::Chunk(c),
other => FileError::Store(Box::new(other)),
}
}
#[inline]
fn create_chunk<const BS: usize>(data: Bytes) -> Result<ContentChunk<BS>> {
ContentChunk::<BS>::try_from(data).map_err(chunk_creation_error)
}
/// Writes `chunk` into `store` and returns the chunk's address.
///
/// The address is copied out before the call because `put` consumes the chunk.
///
/// # Errors
/// Any store failure, wrapped via `FileError::store`.
fn store_chunk<const BS: usize, S: SyncChunkPut<BS>>(
    chunk: ContentChunk<BS>,
    store: &S,
) -> Result<ChunkAddress> {
    let stored_at: ChunkAddress = *chunk.address();
    store.put(chunk.into()).map_err(FileError::store)?;
    Ok(stored_at)
}
/// Read-side strategy for reassembling a chunk tree.
///
/// An implementation fixes the size of one child reference, the root reference type, and the
/// per-chunk context needed to decode a chunk body. The provided methods derive tree geometry
/// (fan-out, depth, subspans) from `REF_SIZE`.
pub trait JoinMode: Sized + 'static {
    /// Size in bytes of one child reference inside an intermediate chunk body.
    const REF_SIZE: usize;
    /// Reference identifying the root of a tree.
    type RootRef: Clone + Debug + Send + Sync;
    /// State passed to [`JoinMode::decode_body`] for each chunk.
    type JoinerContext: Clone + Debug + Send + Sync;

    /// How many child references fit in a body of `body_size` bytes.
    #[inline]
    fn refs_per_chunk(body_size: usize) -> usize {
        body_size / Self::REF_SIZE
    }

    /// Tree depth for `length` bytes of content, delegating to `constants::tree_depth`.
    #[inline]
    fn levels(length: u64, chunk_size: usize) -> usize {
        super::constants::tree_depth(length, chunk_size, Self::REF_SIZE)
    }

    /// Span covered by each child of a node whose total span is `span`.
    ///
    /// The span table is built for a fan-out of `BS / REF_SIZE`.
    #[inline]
    fn subspan_size<const BS: usize>(span: u64) -> u64 {
        let fan_out = BS / Self::REF_SIZE;
        subspan_for_spans::<BS>(span, &compute_spans_inline(fan_out))
    }

    /// Span of child `child_index` under a parent spanning `parent_span`, where every earlier
    /// child spans `subspan` bytes.
    ///
    /// Children before the last reference slot are capped at `subspan`; the child in the last
    /// slot (`refs_per_chunk(BS) - 1`) receives whatever remains, uncapped.
    #[inline]
    fn child_span<const BS: usize>(parent_span: u64, subspan: u64, child_index: usize) -> u64 {
        let in_last_slot = child_index == Self::refs_per_chunk(BS) - 1;
        let remaining = parent_span.saturating_sub(child_index as u64 * subspan);
        if in_last_slot {
            remaining
        } else {
            remaining.min(subspan)
        }
    }

    /// Address of the root chunk referenced by `input`.
    fn root_address(input: &Self::RootRef) -> ChunkAddress;

    /// Initialises joining from the fetched root chunk, returning the root address, the
    /// total span, and the root's decode context.
    fn init_from_chunk<const BS: usize>(
        input: Self::RootRef,
        chunk: ContentChunk<BS>,
    ) -> Result<(ChunkAddress, u64, Self::JoinerContext)>;

    /// Turns a fetched chunk into its body bytes, given its decode context and span.
    fn decode_body<const BS: usize>(
        chunk: ContentChunk<BS>,
        context: &Self::JoinerContext,
        span: u64,
    ) -> Result<Bytes>;

    /// Reads the child reference starting at byte `ref_start` of an intermediate `body`,
    /// returning the child's address and decode context.
    fn parse_child_ref(
        body: &[u8],
        ref_start: usize,
    ) -> Result<(ChunkAddress, Self::JoinerContext)>;
}
/// Fetches the root chunk named by `input` and lets the mode initialise the joiner from it.
///
/// Returns whatever [`JoinMode::init_from_chunk`] produces: the root address, total span and
/// the root's decode context.
///
/// # Errors
/// The getter's error (wrapped via `FileError::getter`), `FileError::InvalidChunkType` if the
/// root is not a content chunk, or any error from `init_from_chunk`.
pub(crate) fn joiner_init<M: JoinMode, G: SyncChunkGet<BS>, const BS: usize>(
    getter: &G,
    input: M::RootRef,
) -> Result<(ChunkAddress, u64, M::JoinerContext)> {
    let root_addr = M::root_address(&input);
    let fetched = getter.get(&root_addr).map_err(FileError::getter)?;
    let Some(root_chunk) = fetched.into_content() else {
        return Err(FileError::InvalidChunkType {
            type_name: "non-content",
        });
    };
    M::init_from_chunk::<BS>(input, root_chunk)
}
/// Async variant of [`joiner_init`]. It fetches the root chunk through an async
/// `crate::store::ChunkGet` and hands it to [`JoinMode::init_from_chunk`].
///
/// # Errors
/// The getter's error (wrapped via `FileError::getter`), `FileError::InvalidChunkType` if the
/// root is not a content chunk, or any error from `init_from_chunk`.
pub(crate) async fn joiner_init_async<
    M: JoinMode + Send + Sync,
    G: crate::store::ChunkGet<BS>,
    const BS: usize,
>(
    getter: &G,
    input: M::RootRef,
) -> Result<(ChunkAddress, u64, M::JoinerContext)> {
    let root_addr = M::root_address(&input);
    let fetched = getter.get(&root_addr).await.map_err(FileError::getter)?;
    let Some(root_chunk) = fetched.into_content() else {
        return Err(FileError::InvalidChunkType {
            type_name: "non-content",
        });
    };
    M::init_from_chunk::<BS>(input, root_chunk)
}
/// Fetches the chunk at `address` and decodes its body with `context` and `span`.
///
/// # Errors
/// The getter's error (wrapped via `FileError::getter`), `FileError::InvalidChunkType` if the
/// chunk is not a content chunk, or any error from [`JoinMode::decode_body`].
#[inline]
pub(crate) fn read_chunk_body<M: JoinMode, G: SyncChunkGet<BS>, const BS: usize>(
    getter: &G,
    address: &ChunkAddress,
    context: &M::JoinerContext,
    span: u64,
) -> Result<Bytes> {
    let fetched = getter.get(address).map_err(FileError::getter)?;
    let Some(content) = fetched.into_content() else {
        return Err(FileError::InvalidChunkType {
            type_name: "non-content",
        });
    };
    M::decode_body::<BS>(content, context, span)
}
/// Async variant of [`read_chunk_body`]. It fetches the chunk at `address` and decodes its
/// body with `context` and `span`.
///
/// # Errors
/// The getter's error (wrapped via `FileError::getter`), `FileError::InvalidChunkType` if the
/// chunk is not a content chunk, or any error from [`JoinMode::decode_body`].
pub(crate) async fn read_chunk_body_async<
    M: JoinMode + Send + Sync,
    G: crate::store::ChunkGet<BS>,
    const BS: usize,
>(
    getter: &G,
    address: &ChunkAddress,
    context: &M::JoinerContext,
    span: u64,
) -> Result<Bytes> {
    // The borrowed `address` and `context` are captured by the future for its whole lifetime,
    // so there is no need to copy the address or clone the context. In encrypted mode the
    // context is an `EncryptionKey`, so cloning it would also duplicate key material.
    let any = getter.get(address).await.map_err(FileError::getter)?;
    let chunk = any.into_content().ok_or(FileError::InvalidChunkType {
        type_name: "non-content",
    })?;
    M::decode_body::<BS>(chunk, context, span)
}
/// Write-side counterpart of [`JoinMode`]. It turns chunk payloads into stored chunks and
/// returns the reference bytes a parent chunk embeds for each of them.
pub trait SplitMode: JoinMode {
    /// Fixed-size reference bytes for one chunk, as embedded in its parent.
    type RefBytes: AsRef<[u8]> + AsMut<[u8]> + Clone + Debug + Send + Sync;

    /// Builds (without storing) the chunk for `data`, along with its reference bytes.
    fn prepare_chunk<const BS: usize>(data: Vec<u8>) -> Result<(ContentChunk<BS>, Self::RefBytes)>;

    /// Prepares the chunk for `data`, writes it to `store`, and returns its reference bytes.
    ///
    /// # Errors
    /// Errors from `prepare_chunk`, or store failures wrapped via `FileError::store`.
    #[inline]
    fn process_chunk<const BS: usize, S: SyncChunkPut<BS>>(
        data: Vec<u8>,
        store: &S,
    ) -> Result<Self::RefBytes> {
        let (prepared, refs) = Self::prepare_chunk::<BS>(data)?;
        store.put(prepared.into()).map_err(FileError::store)?;
        Ok(refs)
    }

    /// Stores the chunk that represents empty input and returns its root reference.
    fn process_empty<const BS: usize, S: SyncChunkPut<BS>>(store: &S) -> Result<Self::RootRef>;

    /// Decodes the root reference from `buffer`.
    fn extract_root(buffer: &[u8]) -> Result<Self::RootRef>;
}
/// Join/split mode for unencrypted trees, where references are bare chunk addresses.
#[derive(Debug)]
pub struct PlainMode;
impl JoinMode for PlainMode {
    const REF_SIZE: usize = REF_SIZE;
    type RootRef = ChunkAddress;
    type JoinerContext = ();

    /// A plain root reference is the root chunk address itself.
    #[inline]
    fn root_address(input: &ChunkAddress) -> ChunkAddress {
        *input
    }

    /// Takes the total span from the root chunk. Plain mode needs no decode context.
    fn init_from_chunk<const BS: usize>(
        root: ChunkAddress,
        chunk: ContentChunk<BS>,
    ) -> Result<(ChunkAddress, u64, ())> {
        let span = chunk.span();
        Ok((root, span, ()))
    }

    /// Plain bodies are stored unencoded, so the chunk's data is returned unchanged.
    #[inline]
    fn decode_body<const BS: usize>(
        chunk: ContentChunk<BS>,
        _context: &(),
        _span: u64,
    ) -> Result<Bytes> {
        Ok(chunk.data().clone())
    }

    /// Reads the `REF_SIZE`-byte child address starting at `ref_start` in `body`.
    ///
    /// # Errors
    /// `FileError::InvalidReference` if `body` does not hold a full reference at `ref_start`.
    /// The range is bounds-checked rather than indexed, so a short or malformed body is
    /// reported as an error instead of panicking.
    #[inline]
    fn parse_child_ref(body: &[u8], ref_start: usize) -> Result<(ChunkAddress, ())> {
        let child_addr_bytes: [u8; 32] = body
            .get(ref_start..)
            .and_then(|rest| rest.get(..REF_SIZE))
            .and_then(|s| s.try_into().ok())
            .ok_or(FileError::InvalidReference { level: 0 })?;
        Ok((ChunkAddress::from(child_addr_bytes), ()))
    }
}
impl SplitMode for PlainMode {
    type RefBytes = [u8; REF_SIZE];

    /// Builds the chunk for `data`. Its reference bytes are simply its address.
    #[inline]
    fn prepare_chunk<const BS: usize>(data: Vec<u8>) -> Result<(ContentChunk<BS>, [u8; REF_SIZE])> {
        let chunk = create_chunk::<BS>(data.into())?;
        let address: ChunkAddress = *chunk.address();
        Ok((chunk, address.into()))
    }

    /// Stores a chunk built from empty data and returns its address.
    fn process_empty<const BS: usize, S: SyncChunkPut<BS>>(store: &S) -> Result<ChunkAddress> {
        let empty = ContentChunk::<BS>::new(Bytes::new()).map_err(chunk_creation_error)?;
        store_chunk::<BS, S>(empty, store)
    }

    /// Reads the root address from the first `REF_SIZE` bytes of `buffer`.
    ///
    /// # Errors
    /// `FileError::InvalidReference` if `buffer` is shorter than `REF_SIZE`.
    fn extract_root(buffer: &[u8]) -> Result<ChunkAddress> {
        let Some(prefix) = buffer.get(..REF_SIZE) else {
            return Err(FileError::InvalidReference { level: 0 });
        };
        let root: [u8; 32] = prefix
            .try_into()
            .map_err(|_| FileError::InvalidReference { level: 0 })?;
        Ok(ChunkAddress::from(root))
    }
}
/// Join/split mode for encrypted trees. Each reference is a chunk address followed by the
/// key that decrypts that chunk.
#[derive(Debug)]
pub struct EncryptedMode;
impl EncryptedMode {
    /// Number of plaintext bytes to decrypt from a chunk body whose subtree spans `span`.
    ///
    /// A leaf chunk (`span <= BS`) holds exactly `span` data bytes. An intermediate chunk holds
    /// one `ENCRYPTED_REF_SIZE` reference per child, where the child count is `span` divided
    /// (rounding up) by the subspan. The result is capped at the body size `BS`.
    fn decrypt_data_length<const BS: usize>(span: u64) -> usize {
        if span <= BS as u64 {
            // Fits in usize: bounded by BS.
            return span as usize;
        }
        let sub = Self::subspan_size::<BS>(span);
        let num_children = span.div_ceil(sub);
        // `span` comes from decrypted chunk data. The multiply and clamp are done in u64, and
        // the result is only narrowed afterwards, so a large span can neither truncate (on
        // 32-bit targets) nor overflow. In-range values give the same result as before.
        num_children
            .saturating_mul(ENCRYPTED_REF_SIZE as u64)
            .min(BS as u64) as usize
    }
}
impl JoinMode for EncryptedMode {
const REF_SIZE: usize = ENCRYPTED_REF_SIZE;
type RootRef = EncryptedChunkRef;
type JoinerContext = EncryptionKey;
fn root_address(input: &EncryptedChunkRef) -> ChunkAddress {
*input.address()
}
fn init_from_chunk<const BS: usize>(
root_ref: EncryptedChunkRef,
chunk: ContentChunk<BS>,
) -> Result<(ChunkAddress, u64, EncryptionKey)> {
let encrypted_data: Bytes = chunk.into();
let span_buf = decrypt_span::<BS>(&encrypted_data, root_ref.key())?;
let span = u64::from_le_bytes(span_buf);
let (address, key) = root_ref.into_parts();
Ok((address, span, key))
}
fn decode_body<const BS: usize>(
chunk: ContentChunk<BS>,
key: &EncryptionKey,
span: u64,
) -> Result<Bytes> {
let encrypted_data: Bytes = chunk.into();
let data_length = Self::decrypt_data_length::<BS>(span);
let decrypted = decrypt_chunk_data::<BS>(&encrypted_data, key, data_length)?;
Ok(Bytes::from(decrypted).slice(SPAN_SIZE..))
}
fn parse_child_ref(body: &[u8], ref_start: usize) -> Result<(ChunkAddress, EncryptionKey)> {
let ref_end = ref_start + ENCRYPTED_REF_SIZE;
let child_addr_bytes: [u8; 32] = body[ref_start..ref_start + 32]
.try_into()
.map_err(|_| FileError::InvalidReference { level: 0 })?;
let child_key = EncryptionKey::try_from(&body[ref_start + 32..ref_end])?;
Ok((ChunkAddress::from(child_addr_bytes), child_key))
}
}
#[cfg(feature = "encryption")]
impl SplitMode for EncryptedMode {
    type RefBytes = [u8; ENCRYPTED_REF_SIZE];

    /// Encrypts `data` under a key from `EncryptionKey::generate` and builds the ciphertext
    /// chunk. The returned reference is the chunk address followed by that key.
    fn prepare_chunk<const BS: usize>(
        data: Vec<u8>,
    ) -> Result<(ContentChunk<BS>, [u8; ENCRYPTED_REF_SIZE])> {
        use crate::chunk::encryption::encrypt_chunk;

        let chunk_key = EncryptionKey::generate();
        let sealed = encrypt_chunk::<BS>(&data, &chunk_key)?;
        let chunk = create_chunk::<BS>(Bytes::from(sealed))?;

        let mut reference = [0u8; ENCRYPTED_REF_SIZE];
        let (addr_slot, key_slot) = reference.split_at_mut(REF_SIZE);
        addr_slot.copy_from_slice(chunk.address().as_bytes());
        key_slot.copy_from_slice(chunk_key.as_bytes());
        Ok((chunk, reference))
    }

    /// Stores an encrypted chunk whose plaintext is a zero span (eight little-endian zero
    /// bytes). Returns its address together with the key.
    fn process_empty<const BS: usize, S: SyncChunkPut<BS>>(store: &S) -> Result<EncryptedChunkRef> {
        use crate::chunk::encryption::encrypt_chunk;

        let chunk_key = EncryptionKey::generate();
        let zero_span: Vec<u8> = 0u64.to_le_bytes().to_vec();
        let sealed = encrypt_chunk::<BS>(&zero_span, &chunk_key)?;
        let empty_chunk = create_chunk::<BS>(Bytes::from(sealed))?;
        let address = store_chunk::<BS, S>(empty_chunk, store)?;
        Ok(EncryptedChunkRef::new(address, chunk_key))
    }

    /// Reads the encrypted root reference from the first `ENCRYPTED_REF_SIZE` bytes of
    /// `buffer`.
    ///
    /// # Errors
    /// `FileError::InvalidReference` if `buffer` is too short or the bytes do not form a valid
    /// `EncryptedChunkRef`.
    fn extract_root(buffer: &[u8]) -> Result<EncryptedChunkRef> {
        match buffer.get(..ENCRYPTED_REF_SIZE) {
            Some(prefix) => EncryptedChunkRef::try_from(prefix)
                .map_err(|_| FileError::InvalidReference { level: 0 }),
            None => Err(FileError::InvalidReference { level: 0 }),
        }
    }
}
/// Decrypts only the span prefix of an encrypted chunk.
///
/// `encrypted_data` must be exactly `SPAN_SIZE + BODY_SIZE` bytes long. The first
/// `SPAN_SIZE` bytes are transcrypted with `key`, using `BODY_SIZE / EncryptionKey::SIZE` as
/// the counter.
///
/// # Errors
/// `EncryptionError::DataTooShort` (wrapped in `FileError::Encryption`) on any length
/// mismatch, or the error reported by `transcrypt`.
fn decrypt_span<const BODY_SIZE: usize>(
    encrypted_data: &[u8],
    key: &EncryptionKey,
) -> Result<[u8; SPAN_SIZE]> {
    use crate::chunk::encryption::{EncryptionError, transcrypt};

    let required_len = SPAN_SIZE + BODY_SIZE;
    let actual_len = encrypted_data.len();
    // NOTE(review): over-long input is rejected as well but is still reported as
    // `DataTooShort { min }`. Consider a dedicated length-mismatch variant.
    if actual_len != required_len {
        return Err(FileError::Encryption(EncryptionError::DataTooShort {
            len: actual_len,
            min: required_len,
        }));
    }

    let counter = (BODY_SIZE / EncryptionKey::SIZE) as u32;
    let mut span_plain = [0u8; SPAN_SIZE];
    transcrypt(key, counter, &encrypted_data[..SPAN_SIZE], &mut span_plain)
        .map_err(FileError::Encryption)?;
    Ok(span_plain)
}