use crate::error::NetworkError;
use crate::proto::packet_processor::includes::Duration;
use crate::proto::session::SessionState;
use citadel_io::tokio::sync::broadcast::Sender;
use citadel_io::tokio::time::error::Error;
use citadel_io::tokio_util::time::{delay_queue, DelayQueue};
use futures::Stream;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
use crate::inner_arg::ExpectedInnerTargetMut;
use crate::proto::state_container::{StateContainer, StateContainerInner};
use citadel_crypt::ratchets::Ratchet;
use citadel_io::tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use std::sync::atomic::{AtomicUsize, Ordering};
// Reserved ticket indices. These are not referenced in this file; presumably they are used as
// `QueueWorkerTicket::Periodic(<index>, _)` by the session's built-in workers — verify at call sites.
pub const PROVISIONAL_CHECKER: usize = 0;
pub const DRILL_REKEY_WORKER: usize = PROVISIONAL_CHECKER + 1;
pub const KEEP_ALIVE_CHECKER: usize = DRILL_REKEY_WORKER + 1;
pub const FIREWALL_KEEP_ALIVE: usize = KEEP_ALIVE_CHECKER + 1;
/// Lower bound for ticket indices generated by the queue worker and its handle, keeping generated
/// tickets clear of the small reserved indices above.
pub const QUEUE_WORKER_RESERVED_INDEX: usize = 10;
/// CID stored in auto-generated oneshot tickets.
pub const RESERVED_CID_IDX: u64 = 0;
/// A task callback the session queue can store and invoke.
///
/// The callback receives mutable access to the session's state container and reports, via
/// [`QueueWorkerResult`], what the queue should do next.
pub trait QueueFunction<R: Ratchet>:
    Fn(&mut dyn ExpectedInnerTargetMut<StateContainerInner<R>>) -> QueueWorkerResult + Send + 'static
{
}

/// Blanket implementation: any `Send + 'static` closure with the matching signature qualifies.
impl<R, F> QueueFunction<R> for F
where
    R: Ratchet,
    F: Fn(&mut dyn ExpectedInnerTargetMut<StateContainerInner<R>>) -> QueueWorkerResult
        + Send
        + 'static,
{
}
/// A scheduled task: its callback, the key of its pending timer in the worker's
/// `DelayQueue`, and the period used when a periodic task asks to be re-armed.
type ScheduledTask<R> = (Box<dyn QueueFunction<R>>, delay_queue::Key, Duration);

/// Timer-driven runner for one session's queued tasks.
///
/// Tasks arrive either directly through `insert_reserved` or via a cloned
/// [`SessionQueueWorkerHandle`], and are executed against the loaded state container when
/// their timers expire.
pub struct SessionQueueWorker<R: Ratchet> {
    /// Scheduled tasks keyed by ticket.
    entries: HashMap<QueueWorkerTicket, ScheduledTask<R>>,
    /// One pending timer per scheduled ticket.
    expirations: DelayQueue<QueueWorkerTicket>,
    /// Session state handed to callbacks; installed by `load_state_container`.
    state_container: Option<StateContainer<R>>,
    /// Notified when the worker's future completes.
    sess_shutdown: Sender<()>,
    /// Most recent task waker registered while polling.
    waker: Option<Waker>,
    /// Receives tasks submitted through handles.
    rx: UnboundedReceiver<ChannelInner<R>>,
    /// Counter used to build auto-generated oneshot tickets.
    rolling_idx: usize,
}
/// Cloneable submission side of a [`SessionQueueWorker`].
///
/// Submitted tasks travel over an unbounded channel and are scheduled the next time the
/// worker is polled.
#[derive(Clone)]
pub struct SessionQueueWorkerHandle<R: Ratchet> {
    /// Counter shared by every clone; used when building tickets in `insert_ordinary`.
    rolling_idx: Arc<AtomicUsize>,
    /// Sending half of the channel drained by the worker.
    tx: UnboundedSender<ChannelInner<R>>,
}

/// Channel message from a handle to its worker: optional ticket, delay, and callback.
type ChannelInner<R> = (Option<QueueWorkerTicket>, Duration, Box<dyn QueueFunction<R>>);
impl<R: Ratchet> SessionQueueWorkerHandle<R> {
    /// Submits `on_timeout` to the worker, to run after `timeout`.
    ///
    /// With `key == None` the worker generates a fresh oneshot ticket. Submitting an explicit
    /// ticket that is already scheduled replaces the existing task.
    pub fn insert_reserved(
        &self,
        key: Option<QueueWorkerTicket>,
        timeout: Duration,
        on_timeout: impl Fn(&mut dyn ExpectedInnerTargetMut<StateContainerInner<R>>) -> QueueWorkerResult
            + Send
            + 'static,
    ) {
        // Best-effort: a failed send is deliberately ignored (presumably the worker has
        // already been dropped, in which case there is nothing left to schedule on).
        let _ = self.tx.send((key, timeout, Box::new(on_timeout)));
    }

    /// Runs `on_call` once after `call_in`, under an auto-generated oneshot ticket.
    #[allow(dead_code)]
    pub fn insert_oneshot(
        &self,
        call_in: Duration,
        on_call: impl Fn(&mut dyn ExpectedInnerTargetMut<StateContainerInner<R>>) + Send + 'static,
    ) {
        self.insert_reserved(None, call_in, move |sess| {
            (on_call)(sess);
            QueueWorkerResult::Complete
        });
    }

    /// Schedules a periodic task for `target_cid` under a freshly generated `Periodic` ticket.
    ///
    /// The first run happens after `timeout`; subsequent scheduling follows the
    /// [`QueueWorkerResult`] returned by `on_timeout`. Every call produces a distinct ticket,
    /// so a new task never replaces one that is already scheduled.
    ///
    /// `idx` is accepted for API compatibility but does not influence the ticket.
    pub fn insert_ordinary(
        &self,
        idx: usize,
        target_cid: u64,
        timeout: Duration,
        on_timeout: impl Fn(&mut dyn ExpectedInnerTargetMut<StateContainerInner<R>>) -> QueueWorkerResult
            + Send
            + 'static,
    ) {
        // The former `idx + QUEUE_WORKER_RESERVED_INDEX + counter` scheme let different
        // (idx, counter) pairs (e.g. (1, 5) and (2, 4)) map to the same ticket for the same
        // cid, silently overwriting an already-scheduled task. The shared counter alone is
        // strictly increasing across all clones, so it yields a unique index, and the offset
        // keeps it out of the reserved range.
        let _ = idx;
        let slot = QUEUE_WORKER_RESERVED_INDEX + self.rolling_idx.fetch_add(1, Ordering::SeqCst);
        self.insert_reserved(
            Some(QueueWorkerTicket::Periodic(slot, target_cid)),
            timeout,
            on_timeout,
        )
    }
}
/// Identifies a scheduled task as `(index, cid)`.
///
/// `Oneshot` tasks are removed from the queue after they run once; `Periodic` tasks are
/// re-armed or removed according to the [`QueueWorkerResult`] they return.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum QueueWorkerTicket {
    Oneshot(usize, u64),
    Periodic(usize, u64),
}

/// Outcome reported by a task callback.
///
/// For oneshot tasks only `EndSession` has an effect; the task is removed regardless.
pub enum QueueWorkerResult {
    /// Periodic task: stop and remove it from the queue.
    Complete,
    /// Periodic task: run again after its stored period.
    Incomplete,
    /// Stop the queue worker, which then signals session shutdown.
    EndSession,
    /// Periodic task: run again after the given delay (applies to the next run only; the
    /// stored period is left unchanged).
    AdjustPeriodicity(Duration),
}
impl<R: Ratchet> SessionQueueWorker<R> {
pub fn new(sess_shutdown: Sender<()>) -> (Self, SessionQueueWorkerHandle<R>) {
let (tx, rx) = citadel_io::tokio::sync::mpsc::unbounded_channel();
let handle = SessionQueueWorkerHandle {
tx,
rolling_idx: Arc::new(Default::default()),
};
(
Self {
rx,
waker: None,
sess_shutdown,
rolling_idx: 0,
entries: HashMap::new(),
expirations: DelayQueue::new(),
state_container: None,
},
handle,
)
}
pub fn load_state_container(&mut self, state_container: StateContainer<R>) {
self.state_container = Some(state_container);
}
#[allow(unused_results)]
pub fn insert_reserved(
&mut self,
key_orig: Option<QueueWorkerTicket>,
timeout: Duration,
on_timeout: Box<dyn QueueFunction<R>>,
) {
let key_new = key_orig.unwrap_or(QueueWorkerTicket::Oneshot(
self.rolling_idx + QUEUE_WORKER_RESERVED_INDEX + 1,
RESERVED_CID_IDX,
));
let delay = self.expirations.insert(key_new, timeout);
if let Some(key) = self.entries.insert(key_new, (on_timeout, delay, timeout)) {
log::warn!(target: "citadel", "Overwrote a session key: {:?} || Original: {key_orig:?}", key.1);
}
self.rolling_idx += 1;
}
pub fn insert_reserved_fn(
&mut self,
key: Option<QueueWorkerTicket>,
timeout: Duration,
on_timeout: impl Fn(&mut dyn ExpectedInnerTargetMut<StateContainerInner<R>>) -> QueueWorkerResult
+ Send
+ 'static,
) {
self.insert_reserved(key, timeout, Box::new(on_timeout))
}
fn register_waker(&mut self, waker: &Waker) {
self.waker = Some(waker.clone());
}
fn wake(&self) {
if let Some(waker) = self.waker.as_ref() {
waker.wake_by_ref();
}
}
#[allow(unused_results)]
fn poll_purge(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.register_waker(cx.waker());
let SessionQueueWorker {
expirations,
state_container,
entries,
..
} = &mut *self;
let mut state_container = inner_mut_state!(state_container.as_ref().unwrap());
if state_container.state.get() != SessionState::Disconnected {
while let Some(res) = futures::ready!(expirations.poll_expired(cx)) {
let entry: QueueWorkerTicket = res.into_inner();
match entry {
QueueWorkerTicket::Oneshot(_, _) => {
let (fx, _, _) = entries.remove(&entry).unwrap();
if let QueueWorkerResult::EndSession = (fx)(&mut state_container) {
return Poll::Ready(Err(Error::shutdown()));
}
}
QueueWorkerTicket::Periodic(_, _) => {
let (fx, _key, duration) = entries.get(&entry).unwrap();
let next_key = match fx(&mut state_container) {
QueueWorkerResult::Complete => {
entries.remove(&entry);
drop(state_container);
self.wake();
return Poll::Pending;
}
QueueWorkerResult::EndSession => {
return Poll::Ready(Err(Error::shutdown()));
}
QueueWorkerResult::AdjustPeriodicity(new_period) => {
expirations.insert(entry, new_period)
}
_ => {
let duration = *duration;
expirations.insert(entry, duration)
}
};
let (_fx, key, _duration) = entries.get_mut(&entry).unwrap();
*key = next_key;
}
}
}
Poll::Pending
} else {
Poll::Ready(Err(Error::shutdown()))
}
}
}
impl<R: Ratchet> Stream for SessionQueueWorker<R> {
    type Item = ();

    /// Schedules everything currently waiting in the handle channel, then runs expired tasks.
    ///
    /// Never yields `Some`: the stream stays pending until `poll_purge` reports an error,
    /// at which point it ends with `None`.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let this = self.as_mut().get_mut();
        loop {
            match this.rx.poll_recv(cx) {
                Poll::Ready(Some((ticket, delay, task))) => this.insert_reserved(ticket, delay, task),
                _ => break,
            }
        }

        match self.poll_purge(cx) {
            Poll::Ready(Err(_)) => Poll::Ready(None),
            Poll::Ready(Ok(())) | Poll::Pending => Poll::Pending,
        }
    }
}
impl<R: Ratchet> futures::Future for SessionQueueWorker<R> {
    type Output = Result<(), NetworkError>;

    /// Drives the worker's stream; once it ends, notifies `sess_shutdown` and resolves with
    /// an error describing the shutdown.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let stream_ended = futures::ready!(self.as_mut().poll_next(cx)).is_none();
        if !stream_ended {
            return Poll::Pending;
        }

        // Best-effort notification; the send result is deliberately ignored.
        let _ = self.sess_shutdown.send(());
        Poll::Ready(Err(NetworkError::InternalError(
            "Queue handler signalled shutdown",
        )))
    }
}