use std::{cmp::Ordering, collections::HashMap, marker::PhantomData, ops::Add, time::Duration};
use prosa_utils::msg::tvf::Tvf;
use tokio::time::{Instant, Sleep, sleep_until};
use crate::core::msg::Msg;
/// A single armed timer: a caller-chosen id plus the instant at which it expires.
///
/// Ordering and equality between timers are defined on `timeout` alone.
#[derive(Debug)]
struct PendingTimer<T: Copy> {
    /// Identifier handed back to the caller when the timer fires.
    timer_id: T,
    /// Absolute deadline of the timer.
    timeout: Instant,
}
impl<T> PendingTimer<T>
where
    T: Copy,
{
    /// Creates a timer `timer_id` that expires `timeout_duration` from now.
    ///
    /// A duration too large to be added to the current instant (e.g. `Duration::MAX`,
    /// a natural way to spell "never") does not panic: the deadline is clamped to
    /// roughly 30 years from now instead.
    pub(crate) fn new(timer_id: T, timeout_duration: Duration) -> PendingTimer<T> {
        // Far enough to never fire in practice, small enough to be representable.
        const FAR_FUTURE: Duration = Duration::from_secs(86400 * 365 * 30);

        let now = Instant::now();
        // `Instant + Duration` panics on overflow; `checked_add` lets us clamp instead.
        let timeout = now
            .checked_add(timeout_duration)
            .unwrap_or_else(|| now.add(FAR_FUTURE));
        PendingTimer { timer_id, timeout }
    }

    /// Creates a timer `timer_id` that expires at the absolute instant `timeout`.
    pub(crate) fn new_at(timer_id: T, timeout: Instant) -> PendingTimer<T> {
        PendingTimer { timer_id, timeout }
    }

    /// Returns the id this timer was armed with.
    pub(crate) fn get_timer_id(&self) -> T {
        self.timer_id
    }

    /// Returns `true` once the deadline has been reached (`timeout <= now`).
    pub(crate) fn is_expired(&self) -> bool {
        self.timeout <= Instant::now()
    }

    /// Returns a future that completes at this timer's deadline.
    pub(crate) fn sleep(&self) -> Sleep {
        sleep_until(self.timeout)
    }
}
/// Orders timers by deadline only: the earlier `timeout` compares as smaller.
impl<T: Copy> Ord for PendingTimer<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        let (mine, theirs) = (&self.timeout, &other.timeout);
        mine.cmp(theirs)
    }
}
/// Partial order delegating to the total order on deadlines; always `Some`.
impl<T: Copy> PartialOrd for PendingTimer<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Two timers are equal when they share the same deadline, whatever their ids.
impl<T: Copy> PartialEq for PendingTimer<T> {
    fn eq(&self, other: &Self) -> bool {
        other.timeout == self.timeout
    }
}
/// Deadline equality is a total equivalence relation, which `Ord` relies on.
impl<T: Copy> Eq for PendingTimer<T> {}
/// Collection of pending timers, each identified by a `Copy` id of type `T`.
///
/// `Default` is implemented by hand rather than derived: `#[derive(Default)]` adds a
/// `T: Default` bound, which made it impossible to create a `Timers<T>` (this module
/// offers no other constructor) for id types that are not `Default`.
#[derive(Debug)]
pub struct Timers<T>
where
    T: Copy,
{
    /// Timers sorted by descending deadline; the next one to fire is the last element.
    timers: Vec<PendingTimer<T>>,
}

impl<T> Default for Timers<T>
where
    T: Copy,
{
    /// Creates an empty set of timers.
    fn default() -> Self {
        Timers { timers: Vec::new() }
    }
}
impl<T> Timers<T>
where
    T: Copy,
{
    /// Returns the number of armed timers.
    pub fn len(&self) -> usize {
        self.timers.len()
    }

    /// Returns `true` when no timer is armed.
    pub fn is_empty(&self) -> bool {
        self.timers.is_empty()
    }

    /// Inserts `timer`, keeping `self.timers` sorted by descending deadline so that
    /// the timer expiring first is always the last element (cheap `pop`).
    ///
    /// Timers sharing the same deadline fire in insertion order: the new timer is
    /// placed in front of (further from the end than) existing equal ones.
    fn push_timer(&mut self, timer: PendingTimer<T>) {
        // The descending order means the vec is partitioned by "expires strictly later
        // than `timer`", so a binary search finds the insertion point.
        let index = self.timers.partition_point(|pending| pending > &timer);
        self.timers.insert(index, timer);
    }

    /// Arms a timer `timer_id` that expires `timeout_duration` from now.
    pub fn push(&mut self, timer_id: T, timeout_duration: Duration) {
        self.push_timer(PendingTimer::new(timer_id, timeout_duration));
    }

    /// Arms a timer `timer_id` that expires at the absolute instant `timeout`.
    pub fn push_at(&mut self, timer_id: T, timeout: Instant) {
        self.push_timer(PendingTimer::new_at(timer_id, timeout));
    }

    /// Waits for the earliest timer to expire, then removes it and returns its id.
    ///
    /// Returns `None` immediately when no timer is armed. The timer is only removed
    /// after the wait completes, so dropping this future (e.g. a losing
    /// `tokio::select!` branch) leaves the timer armed.
    pub async fn pull(&mut self) -> Option<T> {
        let timer = self.timers.last()?;
        if !timer.is_expired() {
            timer.sleep().await;
        }
        self.timers.pop().map(|t| t.get_timer_id())
    }

    /// Keeps only the timers whose id satisfies `f`, preserving their order.
    pub fn retain<F>(&mut self, mut f: F)
    where
        F: FnMut(T) -> bool,
    {
        self.timers.retain(|t| f(t.timer_id));
    }

    /// Removes and returns the timer that expires first, if any.
    fn pop(&mut self) -> Option<PendingTimer<T>> {
        self.timers.pop()
    }

    /// Returns the timer that expires first, if any, without removing it.
    fn last(&self) -> Option<&PendingTimer<T>> {
        self.timers.last()
    }
}
#[derive(Debug)]
pub struct PendingMsgs<T, M>
where
T: Msg<M>,
M: Sized + Clone + Tvf,
{
pending_messages: HashMap<u64, T>,
timers: Timers<u64>,
phantom: PhantomData<M>,
}
impl<T, M> PendingMsgs<T, M>
where
T: Msg<M>,
M: Sized + Clone + Tvf,
{
pub fn len(&self) -> usize {
self.pending_messages.len()
}
pub fn is_empty(&self) -> bool {
self.pending_messages.is_empty()
}
pub fn push(&mut self, msg: T, timeout: Duration) {
self.push_with_id(msg.get_id(), msg, timeout);
}
pub fn push_with_id(&mut self, id: u64, msg: T, timeout: Duration) {
self.timers.push(id, timeout);
self.pending_messages.insert(id, msg);
}
pub fn pull_msg(&mut self, msg_id: u64) -> Option<T> {
if let Some(msg) = self.pending_messages.remove(&msg_id) {
return Some(msg);
}
None
}
pub async fn pull(&mut self) -> Option<T> {
while let Some(timer) = self.timers.last() {
if self.pending_messages.contains_key(&timer.get_timer_id()) {
if !timer.is_expired() {
timer.sleep().await;
}
if let Some(time) = self.timers.pop() {
return self.pull_msg(time.get_timer_id());
} else {
return None;
}
} else {
self.timers.pop();
}
}
None
}
}
impl<T, M> Default for PendingMsgs<T, M>
where
T: Msg<M>,
M: Sized + Clone + Tvf,
{
fn default() -> Self {
PendingMsgs::<T, M> {
pending_messages: Default::default(),
timers: Default::default(),
phantom: PhantomData,
}
}
}
#[cfg(test)]
mod tests {
    // Makes this crate reachable under the name `prosa`. Presumably needed because the
    // `prosa_macros` expansions (`#[proc]`, `#[settings]`) emit `prosa::...` paths — confirm.
    extern crate self as prosa;

    use std::time::Duration;

    use prosa_macros::{proc, settings};
    use prosa_utils::msg::{simple_string_tvf::SimpleStringTvf, tvf::Tvf};
    use serde::Serialize;
    use tokio::time::timeout;

    use crate::core::{
        error::BusError,
        main::{MainProc, MainRunnable},
        msg::{InternalMsg, Msg, RequestMsg},
        proc::{ProcBusParam, ProcConfig},
    };

    use super::{PendingMsgs, Timers};

    // Minimal processor used to exercise `Timers` and `PendingMsgs` inside a processor
    // event loop.
    #[proc]
    pub(crate) struct TestProc {}

    #[proc]
    impl TestProc<SimpleStringTvf> {
        // Timer scenario: register the processor and a "TEST" service, send one request
        // to that service, arm a 100 ms timer (id 1) when a request arrives, and succeed
        // once that timer is pulled.
        async fn timers_run(&mut self) -> Result<(), BusError> {
            self.proc.add_proc().await?;
            self.proc
                .add_service_proc(vec![String::from("TEST")])
                .await?;
            let mut pending_timer: Timers<u64> = Default::default();
            loop {
                tokio::select! {
                    Some(msg) = self.internal_rx_queue.recv() => {
                        match msg {
                            // Exactly one request is expected; it arms the timer.
                            InternalMsg::Request(_) => {
                                assert_eq!(0, pending_timer.len());
                                pending_timer.push(1, Duration::from_millis(100));
                                assert_eq!(1, pending_timer.len());
                            },
                            // When the service table lists "TEST", send a request to it
                            // (presumably delivered back to this processor, which
                            // registered the service — confirm).
                            InternalMsg::Service(table) => {
                                if let Some(service) = table.get_proc_service("TEST") {
                                    service.proc_queue.send(InternalMsg::Request(RequestMsg::new(String::from("TEST"), Default::default(), self.proc.get_service_queue().clone()))).await.unwrap();
                                }
                            },
                            _ => return Err(BusError::ProcComm(self.get_proc_id(), 0, String::from("Wrong message"))),
                        }
                    },
                    // Only polled while a timer is armed.
                    Some(timer_id) = pending_timer.pull(), if !pending_timer.is_empty() => {
                        assert_eq!(0, pending_timer.len());
                        assert_eq!(1, timer_id);
                        self.proc.remove_proc(None).await?;
                        return Ok(())
                    },
                }
            }
        }

        // Pending-message scenario: same setup, but the request carries a TVF with
        // field 1 = "good". It is stored in `PendingMsgs` with a 100 ms timeout, and the
        // test succeeds once that same message comes back out as timed out.
        async fn pending_msgs_run(&mut self) -> Result<(), BusError> {
            self.proc.add_proc().await?;
            self.proc
                .add_service_proc(vec![String::from("TEST")])
                .await?;
            let mut pending_msg: PendingMsgs<RequestMsg<SimpleStringTvf>, SimpleStringTvf> =
                Default::default();
            loop {
                tokio::select! {
                    Some(msg) = self.internal_rx_queue.recv() => {
                        match msg {
                            InternalMsg::Request(msg) => {
                                assert_eq!(0, pending_msg.len());
                                pending_msg.push(msg, Duration::from_millis(100));
                                assert_eq!(1, pending_msg.len());
                            },
                            InternalMsg::Service(table) => {
                                if let Some(service) = table.get_proc_service("TEST") {
                                    let mut msg: SimpleStringTvf = Default::default();
                                    msg.put_string(1, "good");
                                    service.proc_queue.send(InternalMsg::Request(RequestMsg::new(String::from("TEST"), msg, self.proc.get_service_queue().clone()))).await.unwrap();
                                }
                            },
                            _ => return Err(BusError::ProcComm(self.get_proc_id(), 0, String::from("Wrong message"))),
                        }
                    },
                    // Only polled while a message is pending.
                    Some(msg) = pending_msg.pull(), if !pending_msg.is_empty() => {
                        assert_eq!(0, pending_msg.len());
                        assert_eq!(String::from("good"), msg.get_data()?.get_string(1)?.into_owned());
                        self.proc.remove_proc(None).await?;
                        return Ok(())
                    },
                }
            }
        }

        // Runs `timers_run` with a 200 ms budget (twice the 100 ms timer). If the budget
        // is exceeded, the result is `BusError::InternalQueue`; otherwise `Ok(())` is
        // returned. Note that an `Err` returned by `timers_run` itself within the budget
        // is not propagated.
        pub(crate) async fn timers_timeout_run(&mut self) -> Result<(), BusError> {
            if timeout(Duration::from_millis(200), self.timers_run())
                .await
                .is_err()
            {
                Err(BusError::InternalQueue(String::from(
                    "Timer is not working",
                )))
            } else {
                Ok(())
            }
        }

        // Same 200 ms budget wrapper around `pending_msgs_run`; as above, only the
        // elapsed-timeout case is reported as an error.
        pub(crate) async fn pending_msgs_timeout_run(&mut self) -> Result<(), BusError> {
            if timeout(Duration::from_millis(200), self.pending_msgs_run())
                .await
                .is_err()
            {
                Err(BusError::InternalQueue(String::from(
                    "pending msgs is not working",
                )))
            } else {
                Ok(())
            }
        }
    }

    // End-to-end test: spawn the main processor, run the timer scenario and then the
    // pending-message scenario on two separate `TestProc` instances (created with
    // 1/"test1" and 2/"test2"), then stop the bus and wait for the main task.
    #[tokio::test]
    async fn test_pending() {
        #[settings]
        #[derive(Default, Debug, Serialize)]
        struct DummySettings {}

        let (bus, main) = MainProc::<SimpleStringTvf>::create(&DummySettings::default(), Some(2));
        let main_task = tokio::spawn(main.run());
        assert_eq!(
            Ok(()),
            TestProc::<SimpleStringTvf>::create_raw(1, "test1".to_string(), bus.clone())
                .timers_timeout_run()
                .await
        );
        assert_eq!(
            Ok(()),
            TestProc::<SimpleStringTvf>::create_raw(2, "test2".to_string(), bus.clone())
                .pending_msgs_timeout_run()
                .await
        );
        bus.stop("ProSA unit test end".into()).await.unwrap();
        main_task.await.unwrap();
    }
}