use bytes::BytesMut;
use futures::task::Context;
use std::io::{BufReader, Read};
use std::path::PathBuf;
use tokio::macros::support::Pin;
use tokio::sync::mpsc::Sender as GroupChanneler;
use tokio::sync::oneshot::Receiver;
use crate::entropy_bank::{EntropyBank, SecurityLevel};
use crate::packet_vector::PacketVector;
use crate::scramble::crypt_splitter::{par_scramble_encrypt_group, GroupSenderDevice};
use crate::misc::{CryptError, TransferType};
use crate::stacked_ratchet::StackedRatchet;
use citadel_io::Mutex;
use citadel_io::{BlockingSpawn, BlockingSpawnError};
use futures::Future;
use num_integer::Integer;
use std::sync::Arc;
use std::task::Poll;
use tokio_stream::{Stream, StreamExt};
use zeroize::Zeroizing;
/// Largest number of source bytes that may be packed into one group; requests for a bigger
/// group size are rejected by `scramble_encrypt_source`.
pub const MAX_BYTES_PER_GROUP: usize = crate::scramble::crypt_splitter::MAX_BYTES_PER_GROUP;
/// Group size used when the caller does not request one (3 MiB).
const DEFAULT_BYTES_PER_GROUP: usize = 3 * 1024 * 1024;
/// A readable byte source whose total size is known before any bytes are read.
///
/// The reported length is used up front to decide how many groups the source is split into.
/// If the reader later yields fewer bytes than advertised, reading fails part-way through the
/// stream.
pub trait FixedSizedSource
: Read
    + Send
    + 'static
{
    /// Returns the total number of bytes this source will produce.
    fn length(&self) -> std::io::Result<u64>;
}
/// Files report their size from filesystem metadata.
#[cfg(feature = "filesystem")]
impl FixedSizedSource for std::fs::File {
    fn length(&self) -> std::io::Result<u64> {
        let metadata = self.metadata()?;
        Ok(metadata.len())
    }
}
/// Callback invoked while building packets, receiving the packet's [`PacketVector`], the
/// [`EntropyBank`] in use, two `u64` values and a mutable output buffer.
///
/// NOTE(review): judging by the name and the `header_size_bytes` parameter used alongside it,
/// this presumably writes the packet header into the buffer; the meaning of the two `u64`
/// arguments is defined by `par_scramble_encrypt_group` — confirm there.
pub trait HeaderInscriberFn:
    for<'a> Fn(&'a PacketVector, &'a EntropyBank, u64, u64, &'a mut BytesMut) + Send + Sync + 'static
{
}

/// Blanket implementation: every thread-safe, `'static` closure or function with the matching
/// signature is a [`HeaderInscriberFn`].
impl<T> HeaderInscriberFn for T where
    T: for<'a> Fn(&'a PacketVector, &'a EntropyBank, u64, u64, &'a mut BytesMut)
        + Send
        + Sync
        + 'static
{
}
/// Something that can be opened as a [`FixedSizedSource`] and encrypted.
///
/// `auto_impl` additionally implements this trait for `Box<T>` where `T: ObjectSource`.
#[auto_impl::auto_impl(Box)]
pub trait ObjectSource: Send + Sync + 'static {
    /// Opens the underlying data as a readable stream of known length.
    fn try_get_stream(&mut self) -> Result<Box<dyn FixedSizedSource>, CryptError>;

    /// Returns a name for the object (for file-backed sources, the file name).
    fn get_source_name(&self) -> Result<String, CryptError>;

    /// Returns a filesystem path associated with this source, if any. Defaults to `None`.
    ///
    /// NOTE(review): the name suggests callers may delete this path afterwards — confirm at
    /// the call site.
    fn delete_path(&self) -> Option<PathBuf> {
        None
    }
}
// Implements `ObjectSource` for a path-like type by opening the path as a `std::fs::File`.
// The generated impls are only compiled with the `filesystem` feature.
macro_rules! impl_file_src {
    ($value:ty) => {
        #[cfg(feature = "filesystem")]
        impl ObjectSource for $value {
            fn try_get_stream(&mut self) -> Result<Box<dyn FixedSizedSource>, CryptError> {
                match std::fs::File::open(self) {
                    Ok(file) => {
                        let stream: Box<dyn FixedSizedSource> = Box::new(file);
                        Ok(stream)
                    }
                    Err(err) => Err(CryptError::Encrypt(err.to_string())),
                }
            }

            fn get_source_name(&self) -> Result<String, CryptError> {
                // First failure: the path has no final component; second: it is not UTF-8.
                let file_name = std::path::Path::new(self)
                    .file_name()
                    .ok_or_else(|| CryptError::Encrypt("Unable to get filename".to_string()))?;
                match file_name.to_str() {
                    Some(name) => Ok(name.to_string()),
                    None => Err(CryptError::Encrypt("Unable to get filename/2".to_string())),
                }
            }

            fn delete_path(&self) -> Option<PathBuf> {
                Some(std::path::PathBuf::from(self))
            }
        }
    };
}

impl_file_src!(std::path::PathBuf);
impl_file_src!(&'static str);
impl_file_src!(String);
/// An in-memory [`ObjectSource`] whose bytes are held in a [`Zeroizing`] buffer.
///
/// `try_get_stream` moves the buffer out (via `Option::take`), so a given value can only be
/// streamed once.
pub struct BytesSource {
    /// The bytes to encrypt; `None` once they have been handed to a stream.
    pub inner: Option<Zeroizing<Vec<u8>>>,
}
impl Clone for BytesSource {
fn clone(&self) -> Self {
Self { inner: None }
}
}
impl ObjectSource for BytesSource {
    /// Moves the buffered bytes into a one-shot in-memory reader.
    ///
    /// # Errors
    /// Returns [`CryptError::Encrypt`] if the bytes were already taken by an earlier call (or
    /// if this value was produced by `clone`, which is always empty).
    fn try_get_stream(&mut self) -> Result<Box<dyn FixedSizedSource>, CryptError> {
        // Reader over the taken buffer; the bytes remain wrapped in `Zeroizing` while read.
        struct VecReader {
            cursor: std::io::Cursor<Zeroizing<Vec<u8>>>,
        }

        impl std::io::Read for VecReader {
            fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
                std::io::Read::read(&mut self.cursor, buf)
            }
        }

        impl FixedSizedSource for VecReader {
            fn length(&self) -> std::io::Result<u64> {
                // The wrapped buffer is never resized, so its length is the full source size.
                Ok(self.cursor.get_ref().len() as u64)
            }
        }

        match self.inner.take() {
            Some(bytes) => Ok(Box::new(VecReader {
                cursor: std::io::Cursor::new(bytes),
            })),
            None => Err(CryptError::Encrypt("Source has already been exhausted".into())),
        }
    }

    /// Returns a randomly generated `<u128>.bin` name; every call yields a fresh name.
    fn get_source_name(&self) -> Result<String, CryptError> {
        let rand_id: u128 = rand::random();
        Ok(format!("{rand_id}.bin"))
    }
}
impl<T: Into<Vec<u8>>> From<T> for BytesSource {
fn from(value: T) -> Self {
Self {
inner: Some(value.into().into()),
}
}
}
#[allow(clippy::too_many_arguments)]
pub fn scramble_encrypt_source<S: ObjectSource, F: HeaderInscriberFn, const N: usize>(
mut source: S,
max_group_size: Option<usize>,
object_id: u64,
group_sender: GroupChanneler<Result<GroupSenderDevice<N>, CryptError>>,
stop: Receiver<()>,
security_level: SecurityLevel,
hyper_ratchet: StackedRatchet,
static_aux_ratchet: StackedRatchet,
header_size_bytes: usize,
target_cid: u64,
group_id: u64,
transfer_type: TransferType,
header_inscriber: F,
) -> Result<(usize, usize, usize), CryptError> {
let source = source.try_get_stream()?;
let object_len = source
.length()
.map_err(|err| CryptError::Encrypt(err.to_string()))? as usize;
let max_bytes_per_group = max_group_size.unwrap_or(DEFAULT_BYTES_PER_GROUP);
if max_bytes_per_group > MAX_BYTES_PER_GROUP {
return Err(CryptError::Encrypt(format!(
"Maximum group size cannot be larger than {MAX_BYTES_PER_GROUP} bytes",
)));
}
let total_groups = Integer::div_ceil(&object_len, &max_bytes_per_group);
log::trace!(target: "citadel", "Will parallel_scramble_encrypt file object {}, which is {} bytes or {} MB. {} groups total", object_id, object_len, (object_len as f32)/(1024f32*1024f32), total_groups);
let reader = BufReader::with_capacity(std::cmp::min(object_len, max_bytes_per_group), source);
let buffer = Arc::new(Mutex::new(vec![
0u8;
std::cmp::min(
object_len,
max_bytes_per_group
)
]));
let file_scrambler = AsyncCryptScrambler {
total_groups,
buffer,
groups_rendered: 0,
object_id,
header_size_bytes,
target_cid,
group_id,
security_level,
hyper_ratchet,
static_aux_ratchet,
reader,
transfer_type,
file_len: object_len,
max_bytes_per_group,
read_cursor: 0,
header_inscriber: Arc::new(header_inscriber),
poll_amt: 0,
cur_task: None,
};
let handle = citadel_io::spawn(async move {
let res = tokio::select! {
res0 = stopper(stop) => res0,
res1 = file_streamer(group_sender.clone(), file_scrambler) => res1
};
if let Err(err) = res {
let _ = group_sender.try_send(Err(err));
}
});
std::mem::drop(handle);
Ok((object_len, total_groups, max_bytes_per_group))
}
/// Resolves once the caller fires `stop`.
///
/// Yields `Ok(())` when the signal is sent. If the sending half is dropped without
/// signalling, it yields [`CryptError::Encrypt`] carrying the receive error's message.
async fn stopper(stop: Receiver<()>) -> Result<(), CryptError> {
    match stop.await {
        Ok(()) => Ok(()),
        Err(recv_err) => Err(CryptError::Encrypt(recv_err.to_string())),
    }
}
/// Forwards every group produced by `file_scrambler` to `group_sender`.
///
/// # Errors
/// Returns [`CryptError::Encrypt`] if the receiving side of `group_sender` has been dropped.
///
/// Also returns it if the scrambler stopped before producing all `total_groups` groups. The
/// scrambler yields `None` both on success and after a read or encryption failure, so
/// comparing the group counts is the only way to tell a complete transfer from a truncated one.
async fn file_streamer<F: HeaderInscriberFn, R: Read, const N: usize>(
    group_sender: GroupChanneler<Result<GroupSenderDevice<N>, CryptError>>,
    mut file_scrambler: AsyncCryptScrambler<F, R, N>,
) -> Result<(), CryptError> {
    while let Some(val) = file_scrambler.next().await {
        group_sender
            .send(Ok(val))
            .await
            .map_err(|err| CryptError::Encrypt(err.to_string()))?;
    }

    // Without this check a failed read/encryption would look like a successful, shorter transfer.
    if file_scrambler.groups_rendered != file_scrambler.total_groups {
        return Err(CryptError::Encrypt(format!(
            "Encryption stream ended after {} of {} groups",
            file_scrambler.groups_rendered, file_scrambler.total_groups
        )));
    }

    Ok(())
}
/// Stream that reads the source one group at a time and encrypts each group on a blocking
/// task, yielding one [`GroupSenderDevice`] per group.
// `dead_code` is allowed because not every stored field is read on every code path.
#[allow(dead_code)]
struct AsyncCryptScrambler<F, R, const N: usize>
where
    F: HeaderInscriberFn,
    R: Read,
{
    // ---- input ----
    /// Buffered reader over the source being encrypted.
    reader: BufReader<R>,
    /// Total size of the source, in bytes.
    file_len: usize,
    /// Number of source bytes whose group has been encrypted and yielded.
    read_cursor: usize,
    /// Size of the chunk handed to the in-flight task; added to `read_cursor` when it succeeds.
    poll_amt: usize,
    /// Upper bound on the bytes read per group.
    max_bytes_per_group: usize,
    /// Staging buffer shared with the blocking task that encrypts the current chunk.
    buffer: Arc<Mutex<Vec<u8>>>,

    // ---- encryption / header parameters ----
    hyper_ratchet: StackedRatchet,
    static_aux_ratchet: StackedRatchet,
    security_level: SecurityLevel,
    transfer_type: TransferType,
    header_size_bytes: usize,
    header_inscriber: Arc<F>,
    object_id: u64,
    target_cid: u64,
    /// Group id of the first group; each later group adds `groups_rendered` to it.
    group_id: u64,

    // ---- progress ----
    total_groups: usize,
    groups_rendered: usize,
    /// Encryption task currently in flight, if any.
    cur_task: Option<BlockingSpawn<Result<GroupSenderDevice<N>, CryptError<String>>>>,
}
impl<F: HeaderInscriberFn, R: Read, const N: usize> AsyncCryptScrambler<F, R, N> {
fn poll_task(
groups_rendered: &mut usize,
read_cursor: &mut usize,
poll_amt: usize,
cur_task: &mut Option<BlockingSpawn<Result<GroupSenderDevice<N>, CryptError<String>>>>,
cx: &mut Context<'_>,
) -> Poll<Option<GroupSenderDevice<N>>> {
let res: Result<Result<GroupSenderDevice<N>, CryptError<String>>, BlockingSpawnError> =
futures::ready!(Pin::new(cur_task.as_mut().unwrap()).poll(cx));
if let Ok(Ok(sender)) = res {
*groups_rendered += 1;
*read_cursor += poll_amt;
*cur_task = None;
Poll::Ready(Some(sender))
} else {
log::error!(target: "citadel", "Unable to par_scramble_encrypt group");
Poll::Ready(None)
}
}
}
// Manual `Unpin`: the scrambler's fields are only ever reached through `&mut *self`, never
// pin-projected. This lets it be polled with `StreamExt::next` without boxing.
impl<F, R, const N: usize> Unpin for AsyncCryptScrambler<F, R, N>
where
    F: HeaderInscriberFn,
    R: Read,
{
}
impl<F: HeaderInscriberFn, R: Read, const N: usize> AsyncCryptScrambler<F, R, N> {
fn poll_scramble_next_group(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<GroupSenderDevice<N>>> {
let Self {
hyper_ratchet,
static_aux_ratchet,
file_len,
read_cursor,
buffer,
group_id,
groups_rendered,
header_size_bytes,
target_cid,
object_id,
header_inscriber,
reader,
security_level,
max_bytes_per_group,
cur_task,
transfer_type,
poll_amt,
..
} = &mut *self;
if cur_task.is_some() {
return Self::poll_task(groups_rendered, read_cursor, *poll_amt, cur_task, cx);
}
if *read_cursor != *file_len {
let remaining = *file_len - *read_cursor;
let poll_len = std::cmp::min(remaining, *max_bytes_per_group);
let mut lock = buffer.lock();
let bytes = &mut lock[..poll_len];
if reader.read_exact(bytes).is_ok() {
let group_id_input = *group_id + (*groups_rendered as u64);
std::mem::drop(lock);
let header_inscriber = header_inscriber.clone();
let buffer = buffer.clone();
let security_level = *security_level;
let hyper_ratchet = hyper_ratchet.clone();
let static_aux_ratchet = static_aux_ratchet.clone();
let header_size_bytes = *header_size_bytes;
let target_cid = *target_cid;
let object_id = *object_id;
let transfer_type = transfer_type.clone();
let task = citadel_io::spawn_blocking(move || {
par_scramble_encrypt_group(
&buffer.lock()[..poll_len],
security_level,
&hyper_ratchet,
&static_aux_ratchet,
header_size_bytes,
target_cid,
object_id,
group_id_input,
transfer_type,
|a, b, c, d, e| (header_inscriber)(a, b, c, d, e),
)
});
*cur_task = Some(task);
*poll_amt = poll_len;
Self::poll_task(groups_rendered, read_cursor, *poll_amt, cur_task, cx)
} else {
log::error!(target: "citadel", "Error polling exact amt {}", poll_len);
Poll::Ready(None)
}
} else {
log::trace!(target: "citadel", "Done rendering all groups!");
Poll::Ready(None)
}
}
}
/// Each stream item is one encrypted group. The stream ends (`None`) after the last group, or
/// early if reading or encryption fails.
impl<F, R, const N: usize> Stream for AsyncCryptScrambler<F, R, N>
where
    F: HeaderInscriberFn,
    R: Read,
{
    type Item = GroupSenderDevice<N>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        Self::poll_scramble_next_group(self, cx)
    }
}