use crate::error::NetworkError;
use crate::proto::packet_processor::includes::Duration;
use crate::proto::session::SessionState;
use futures::Stream;
use std::collections::HashMap;
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
use tokio::sync::broadcast::Sender;
use tokio::time::error::Error;
use tokio_util::time::{delay_queue, DelayQueue};
use crate::inner_arg::ExpectedInnerTargetMut;
use crate::proto::state_container::{StateContainer, StateContainerInner};
use std::sync::atomic::Ordering;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
/// Offset added to caller-supplied indices (`SessionQueueWorkerHandle::insert_ordinary`)
/// and to the rolling index of auto-generated oneshot tickets
/// (`SessionQueueWorker::insert_reserved`).
/// NOTE(review): presumably this keeps indices below 10 free for the built-in
/// task indices declared below -- confirm against the callers.
pub const QUEUE_WORKER_RESERVED_INDEX: usize = 10;
/// CID placed into oneshot tickets that the worker generates itself when the
/// caller does not supply a ticket.
pub const RESERVED_CID_IDX: u64 = 0;
// The four indices below are not referenced in this file.
// NOTE(review): judging by their names they identify built-in periodic tasks
// registered elsewhere -- confirm at the registration sites.
pub const PROVISIONAL_CHECKER: usize = 0;
pub const DRILL_REKEY_WORKER: usize = 1;
pub const KEEP_ALIVE_CHECKER: usize = 2;
pub const FIREWALL_KEEP_ALIVE: usize = 3;
/// Callback type stored by the queue worker.
///
/// A `QueueFunction` is any `Send + 'static` closure that receives mutable
/// access to the state container (through `ExpectedInnerTargetMut`) and
/// reports what should happen to its queue entry via [`QueueWorkerResult`].
/// The trait has no methods of its own; it exists so the bound can be named,
/// e.g. as `Box<dyn QueueFunction>`.
pub trait QueueFunction:
Fn(&mut dyn ExpectedInnerTargetMut<StateContainerInner>) -> QueueWorkerResult + Send + 'static
{
}
/// Callback type for fire-and-forget tasks.
///
/// Same shape as [`QueueFunction`] but without a return value: any
/// `Send + 'static` closure taking mutable access to the state container.
/// The trait has no methods of its own; it only names the bound.
pub trait QueueOneshotFunction:
Fn(&mut dyn ExpectedInnerTargetMut<StateContainerInner>) + Send + 'static
{
}
/// Blanket implementation: every `Send + 'static` closure with the matching
/// signature is automatically a [`QueueFunction`].
impl<F> QueueFunction for F where
    F: Fn(&mut dyn ExpectedInnerTargetMut<StateContainerInner>) -> QueueWorkerResult
        + Send
        + 'static
{
}
/// Blanket implementation: every `Send + 'static` closure with the matching
/// signature (and no return value) is automatically a [`QueueOneshotFunction`].
impl<F> QueueOneshotFunction for F where
    F: Fn(&mut dyn ExpectedInnerTargetMut<StateContainerInner>) + Send + 'static
{
}
/// Timer-driven task queue for a session.
///
/// Tasks are registered under a [`QueueWorkerTicket`] together with a timeout.
/// When the timeout elapses the task's callback is invoked with mutable access
/// to the state container. The worker is driven by polling it as a
/// `Stream`/`Future`.
pub struct SessionQueueWorker {
/// Registered tasks: ticket -> (callback, key of the pending timer inside
/// `expirations`, timeout the task was registered with).
entries: HashMap<QueueWorkerTicket, (Box<dyn QueueFunction>, delay_queue::Key, Duration)>,
/// Pending timers; yields a ticket once its timeout has elapsed.
expirations: DelayQueue<QueueWorkerTicket>,
/// Set through `load_state_container`; it is unwrapped when the worker is
/// polled, so it must be loaded before the first poll.
state_container: Option<StateContainer>,
/// Receives `()` when the worker's future resolves.
sess_shutdown: Sender<()>,
/// Waker captured during the most recent purge pass; used by `wake`.
waker: Option<Waker>,
/// Receiving side of the channel fed by `SessionQueueWorkerHandle`.
rx: UnboundedReceiver<ChannelInner>,
/// Incremented on every insertion; used to build tickets for tasks that are
/// inserted without an explicit ticket.
rolling_idx: usize,
}
/// Cloneable handle for submitting tasks to a [`SessionQueueWorker`].
///
/// Submissions travel over an unbounded channel and are picked up by the
/// worker the next time it is polled.
#[derive(Clone)]
pub struct SessionQueueWorkerHandle {
/// Sending side of the worker's request channel.
tx: UnboundedSender<ChannelInner>,
}
/// Insertion request sent from a handle to the worker:
/// (optional explicit ticket, timeout, callback).
type ChannelInner = (Option<QueueWorkerTicket>, Duration, Box<dyn QueueFunction>);
impl SessionQueueWorkerHandle {
pub fn insert_reserved(
&self,
key: Option<QueueWorkerTicket>,
timeout: Duration,
on_timeout: impl Fn(&mut dyn ExpectedInnerTargetMut<StateContainerInner>) -> QueueWorkerResult
+ Send
+ 'static,
) {
let _ = self.tx.send((key, timeout, Box::new(on_timeout)));
}
#[allow(dead_code)]
pub fn insert_oneshot(
&self,
call_in: Duration,
on_call: impl Fn(&mut dyn ExpectedInnerTargetMut<StateContainerInner>) + Send + 'static,
) {
self.insert_reserved(None, call_in, move |sess| {
(on_call)(sess);
QueueWorkerResult::Complete
});
}
pub fn insert_ordinary(
&self,
idx: usize,
target_cid: u64,
timeout: Duration,
on_timeout: impl Fn(&mut dyn ExpectedInnerTargetMut<StateContainerInner>) -> QueueWorkerResult
+ Send
+ 'static,
) {
self.insert_reserved(
Some(QueueWorkerTicket::Periodic(
idx + QUEUE_WORKER_RESERVED_INDEX,
target_cid,
)),
timeout,
on_timeout,
)
}
}
/// Identifies a task inside the queue worker. Both variants carry
/// `(index, cid)`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum QueueWorkerTicket {
/// Task that is removed from the worker as soon as its timer fires.
Oneshot(usize, u64),
/// Task that is rescheduled after each run until its callback asks to stop.
Periodic(usize, u64),
}
/// Outcome reported by a task callback.
///
/// For oneshot tickets only `EndSession` is acted upon; the entry is removed
/// regardless of the value returned.
pub enum QueueWorkerResult {
/// The task is finished; a periodic entry is removed from the worker.
Complete,
/// The task is not finished; a periodic entry is rescheduled using the
/// timeout it was registered with.
Incomplete,
/// The worker stops and signals session shutdown.
EndSession,
/// A periodic entry is rescheduled to fire after the given duration. The
/// timeout stored with the entry is not changed.
AdjustPeriodicity(Duration),
}
impl SessionQueueWorker {
pub fn new(sess_shutdown: Sender<()>) -> (Self, SessionQueueWorkerHandle) {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let handle = SessionQueueWorkerHandle { tx };
(
Self {
rx,
waker: None,
sess_shutdown,
rolling_idx: 0,
entries: HashMap::new(),
expirations: DelayQueue::new(),
state_container: None,
},
handle,
)
}
pub fn load_state_container(&mut self, state_container: StateContainer) {
self.state_container = Some(state_container);
}
#[allow(unused_results)]
pub fn insert_reserved(
&mut self,
key: Option<QueueWorkerTicket>,
timeout: Duration,
on_timeout: Box<dyn QueueFunction>,
) {
let key = key.unwrap_or(QueueWorkerTicket::Oneshot(
self.rolling_idx + QUEUE_WORKER_RESERVED_INDEX + 1,
RESERVED_CID_IDX,
));
let delay = self.expirations.insert(key, timeout);
if let Some(key) = self.entries.insert(key, (on_timeout, delay, timeout)) {
log::error!(target: "citadel", "Overwrote a session key: {:?}", key.1);
}
self.rolling_idx += 1;
}
pub fn insert_reserved_fn(
&mut self,
key: Option<QueueWorkerTicket>,
timeout: Duration,
on_timeout: impl Fn(&mut dyn ExpectedInnerTargetMut<StateContainerInner>) -> QueueWorkerResult
+ Send
+ 'static,
) {
self.insert_reserved(key, timeout, Box::new(on_timeout))
}
fn register_waker(&mut self, waker: &futures::task::Waker) {
self.waker = Some(waker.clone());
}
fn wake(&self) {
if let Some(waker) = self.waker.as_ref() {
waker.wake_by_ref();
}
}
#[allow(unused_results)]
fn poll_purge(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.register_waker(cx.waker());
let SessionQueueWorker {
expirations,
state_container,
entries,
..
} = &mut *self;
let mut state_container = inner_mut_state!(state_container.as_ref().unwrap());
if state_container.state.load(Ordering::Relaxed) != SessionState::Disconnected {
while let Some(res) = futures::ready!(expirations.poll_expired(cx)) {
let entry: QueueWorkerTicket = res.into_inner();
match entry {
QueueWorkerTicket::Oneshot(_, _) => {
let (fx, _, _) = entries.remove(&entry).unwrap();
if let QueueWorkerResult::EndSession = (fx)(&mut state_container) {
return Poll::Ready(Err(Error::shutdown()));
}
}
QueueWorkerTicket::Periodic(_, _) => {
let (fx, _key, duration) = entries.get(&entry).unwrap();
let next_key = match (fx)(&mut state_container) {
QueueWorkerResult::Complete => {
entries.remove(&entry);
std::mem::drop(state_container);
self.wake();
return Poll::Pending;
}
QueueWorkerResult::EndSession => {
return Poll::Ready(Err(Error::shutdown()));
}
QueueWorkerResult::AdjustPeriodicity(new_period) => {
expirations.insert(entry, new_period)
}
_ => {
let duration = *duration;
expirations.insert(entry, duration)
}
};
let (_fx, key, _duration) = entries.get_mut(&entry).unwrap();
*key = next_key;
}
}
}
Poll::Pending
} else {
Poll::Ready(Err(Error::shutdown()))
}
}
}
/// Drives the worker as a stream. The stream never yields an item; it either
/// stays pending or terminates once the purge pass reports shutdown.
impl Stream for SessionQueueWorker {
    type Item = ();

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // First move every pending insertion request into the queue.
        loop {
            let this = self.as_mut().get_mut();
            match this.rx.poll_recv(cx) {
                Poll::Ready(Some((key, timeout, on_timeout))) => {
                    this.insert_reserved(key, timeout, on_timeout)
                }
                _ => break,
            }
        }

        // Then run whatever has expired.
        match self.poll_purge(cx) {
            Poll::Pending | Poll::Ready(Ok(_)) => Poll::Pending,
            Poll::Ready(Err(_)) => Poll::Ready(None),
        }
    }
}
/// Drives the worker to completion. The future resolves (always with an
/// error) once the underlying stream terminates; before resolving it notifies
/// `sess_shutdown`.
impl futures::Future for SessionQueueWorker {
    type Output = Result<(), NetworkError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if let Poll::Ready(None) = self.as_mut().poll_next(cx) {
            // A send failure is deliberately ignored.
            let _ = self.sess_shutdown.send(());
            return Poll::Ready(Err(NetworkError::InternalError(
                "Queue handler signalled shutdown",
            )));
        }
        Poll::Pending
    }
}