use std::
{
cell::OnceCell,
cmp,
collections::HashMap,
fmt,
num::NonZeroUsize,
sync::
{
Arc, Mutex, TryLockError, Weak, atomic::{AtomicBool, Ordering}, mpsc::{self, Receiver, Sender}
},
thread::JoinHandle,
time::Duration
};
use crossbeam_deque::{Injector, Steal};
use rand::random_range;
use crate::
{
AbsoluteTime,
FdTimerCom,
RelativeTime,
TimerPoll,
TimerReadRes,
error::{TimerError, TimerErrorType, TimerResult},
map_timer_err, timer_err,
timer_portable::
{
AsTimerId, PollEventType, PolledTimerFd, TimerExpMode, TimerFlags, TimerId, TimerType, poll::PollInterrupt, timer::TimerFd
}
};
/// Outcome reported by [`PeriodicTask::exec`], telling the scheduler how to proceed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PeriodicTaskResult
{
    /// Keep the current schedule unchanged.
    Ok,
    /// Stop firing the task; the worker queues a `SuspendTask` command for it.
    CancelTask,
    /// Replace the task's schedule with the given time.
    TaskReSchedule(PeriodicTaskTime),
}

/// A unit of work run by the scheduler each time its timer fires.
///
/// Implementors must be `Send + 'static` because they are executed on worker threads.
pub trait PeriodicTask: Send + 'static
{
    /// Runs the task once and reports how scheduling should continue.
    fn exec(&mut self) -> PeriodicTaskResult;
}

/// Boxed, type-erased periodic task as stored by the scheduler.
pub type PeriodicTaskHndl = Box<dyn PeriodicTask>;
/// Commands queued on the global injector and applied by whichever worker
/// currently holds the shared task table. Each carries an optional reply channel.
#[derive(Debug)]
pub(crate) enum GlobalTasks
{
    /// Register a new task under a name with its schedule and guard.
    AddTask(String, PeriodicTaskTime, Arc<PeriodicTaskGuardInner>, Option<mpsc::Sender<TimerResult<()>>>),
    /// Remove the task owned by the given guard.
    RemoveTask(Arc<PeriodicTaskGuardInner>, Option<mpsc::Sender<TimerResult<()>>>),
    /// Replace the schedule of the referenced task.
    ReschedTask(Weak<PeriodicTaskGuardInner>, PeriodicTaskTime, Option<mpsc::Sender<TimerResult<()>>>),
    /// Disarm the referenced task's timer.
    SuspendTask(Weak<PeriodicTaskGuardInner>, Option<mpsc::Sender<TimerResult<()>>>),
    /// Re-arm the referenced task's timer with its current schedule.
    ResumeTask(Weak<PeriodicTaskGuardInner>, Option<mpsc::Sender<TimerResult<()>>>),
}

/// Work item pushed onto a single worker's local queue.
#[derive(Debug)]
pub enum ThreadTask
{
    /// Execute the task referenced by the ticket once.
    TaskExec(Arc<NewTaskTicket>),
}

/// Binds a scheduled task to the worker thread chosen to execute it.
#[derive(Debug)]
pub struct NewTaskTicket
{
    /// Worker that receives executions for this ticket.
    task_thread: Arc<ThreadHandler>,
    /// Weak link to the user-owned task; executions are skipped once it is gone.
    ptgi: Weak<PeriodicTaskGuardInner>,
}
impl NewTaskTicket
{
fn new(task_thread: Arc<ThreadHandler>, ptgi: Weak<PeriodicTaskGuardInner>) -> Self
{
return
Self
{
task_thread:
task_thread,
ptgi:
ptgi
};
}
fn send_task(this: Arc<NewTaskTicket>, task_rep_count: u64, thread_hndl_cnt: usize)
{
let strong_cnt = Arc::strong_count(&this);
let thread = this.task_thread.clone();
for _ in 0..task_rep_count
{
thread.send_task(this.clone(), strong_cnt < 2 && thread_hndl_cnt > 1);
}
}
}
/// Scheduler-side record of a registered task: its timer, its schedule, and
/// weak links to the user guard and the current execution ticket.
#[derive(Debug)]
pub(crate) struct PeriodicTaskTicket
{
    /// Name the task was registered under (unique within the task table).
    task_name: String,
    /// Timer registered with the shared poll; its id keys the task table.
    sync_timer: PolledTimerFd<TimerFd>,
    /// Current schedule of the task.
    ptt: PeriodicTaskTime,
    /// Ticket handed to worker queues; it can only be upgraded while some queued
    /// or running execution still holds a strong reference to it.
    weak_ticket: Weak<NewTaskTicket>,
    /// Weak link to the user-owned guard; once it is gone the task is dropped.
    ptg: Weak<PeriodicTaskGuardInner>,
}
/// Orders tickets by timer id and breaks ties by task name.
///
/// The tie-break keeps `Ord` consistent with `PartialEq` (which compares both the
/// name and the timer id): `cmp` returns `Equal` exactly when `eq` returns `true`.
impl Ord for PeriodicTaskTicket
{
    fn cmp(&self, other: &Self) -> cmp::Ordering
    {
        return
            self
                .sync_timer
                .get_inner()
                .as_timer_id()
                .cmp(&other.sync_timer.get_inner().as_timer_id())
                .then_with(|| self.task_name.cmp(&other.task_name));
    }
}

impl PartialOrd for PeriodicTaskTicket
{
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering>
    {
        return Some(self.cmp(other));
    }
}

/// Two tickets are equal when both their task names and timer ids match.
impl PartialEq for PeriodicTaskTicket
{
    fn eq(&self, other: &Self) -> bool
    {
        return
            self.task_name == other.task_name &&
            self.sync_timer.get_inner().as_timer_id() == other.sync_timer.get_inner().as_timer_id();
    }
}

impl Eq for PeriodicTaskTicket {}
impl PeriodicTaskTicket
{
    /// Creates a `CLOCK_REALTIME` timer (close-on-exec, non-blocking) for `task_name`,
    /// registers it with `poll`, and arms it according to `ptt`.
    ///
    /// Only a weak reference to `ptg` is retained.
    ///
    /// # Errors
    /// Fails if the timer cannot be created, registered with the poll, or armed.
    fn new(task_name: String, ptt: PeriodicTaskTime, ptg: Arc<PeriodicTaskGuardInner>, poll: &TimerPoll) -> TimerResult<Self>
    {
        let raw_timer =
            TimerFd::new(task_name.clone().into(), TimerType::CLOCK_REALTIME,
                TimerFlags::TFD_CLOEXEC | TimerFlags::TFD_NONBLOCK)
                .map_err(|e|
                    map_timer_err!(TimerErrorType::TimerError(e.get_errno()), "cannot setup timer for task: '{}'", task_name)
                )?;
        let sync_timer = poll.add(raw_timer)?;

        let ticket =
            Self
            {
                ptg: Arc::downgrade(&ptg),
                weak_ticket: Weak::new(),
                ptt,
                sync_timer,
                task_name,
            };

        ticket.set_timer()?;
        Ok(ticket)
    }

    /// Arms the timer with the ticket's current schedule.
    #[inline]
    fn set_timer(&self) -> TimerResult<()>
    {
        self.ptt.set_timer(self.sync_timer.get_inner())
    }

    /// Disarms the timer.
    #[inline]
    fn unset_timer(&self) -> TimerResult<()>
    {
        let timer_fd = self.sync_timer.get_inner();
        timer_fd
            .get_timer()
            .unset_time()
            .map_err(|e|
                map_timer_err!(TimerErrorType::TimerError(e.get_errno()), "unsetting timer '{}' returned error: {}", self.sync_timer, e)
            )
    }

    /// Stores `ptt_new` as the schedule and re-arms the timer with it.
    fn update_task_time(&mut self, ptt_new: PeriodicTaskTime) -> TimerResult<()>
    {
        self.ptt = ptt_new;
        self.set_timer()
    }

    /// Upgrades the weak guard link.
    ///
    /// # Errors
    /// `ReferenceGone` if the user-owned guard has been dropped.
    fn get_task_guard(&self) -> TimerResult<Arc<PeriodicTaskGuardInner>>
    {
        match self.ptg.upgrade()
        {
            Some(guard) => Ok(guard),
            None =>
                Err(
                    map_timer_err!(TimerErrorType::ReferenceGone, "task: '{}' reference to timer has gone",
                        self.task_name)
                ),
        }
    }

    /// Borrows the current schedule.
    #[inline]
    fn get_timer_time(&self) -> &PeriodicTaskTime
    {
        &self.ptt
    }
}
/// Handle returned to the user for a registered task.
///
/// Dropping it queues a request to remove the task from the scheduler.
#[derive(Debug)]
pub struct PeriodicTaskGuard
{
    task_name: String,
    /// `Some` for the whole lifetime of the guard; taken only in `drop`.
    guard: Option<Arc<PeriodicTaskGuardInner>>,
    spt: Arc<SyncPeriodicTasksInner>
}

impl Drop for PeriodicTaskGuard
{
    fn drop(&mut self)
    {
        // Best-effort: `drop` cannot report a failure to queue the removal.
        let remove_cmd = GlobalTasks::RemoveTask(self.guard.take().unwrap(), None);
        let _ = self.spt.send_global_cmd(remove_cmd);
    }
}
impl PeriodicTaskGuard
{
pub
fn reschedule_task(&self, ptt: PeriodicTaskTime) -> TimerResult<()>
{
let weak_ptgi = Arc::downgrade(self.guard.as_ref().unwrap());
let (snd, rcv) = mpsc::channel::<Result<(), TimerError>>();
self.spt.send_global_cmd(GlobalTasks::ReschedTask(weak_ptgi, ptt, Some(snd)))?;
return
rcv
.recv_timeout(Duration::from_secs(10))
.map_err(|e|
map_timer_err!(TimerErrorType::MpscTimeout, "reschedule_task(), task name: '{}', MPSC rcv timeout error: '{}'",
self.task_name, e)
)?;
}
pub
fn suspend_task(&self) -> TimerResult<()>
{
let weak_ptgi = Arc::downgrade(self.guard.as_ref().unwrap());
let (snd, rcv) = mpsc::channel::<Result<(), TimerError>>();
self.spt.send_global_cmd(GlobalTasks::SuspendTask(weak_ptgi, Some(snd)))?;
return
rcv
.recv_timeout(Duration::from_secs(10))
.map_err(|e|
map_timer_err!(TimerErrorType::MpscTimeout, "suspend_task(), task name: '{}', MPSC rcv timeout error: '{}'",
self.task_name, e)
)?;
}
pub
fn resume_task(&self) -> TimerResult<()>
{
let weak_ptgi = Arc::downgrade(self.guard.as_ref().unwrap());
let (snd, rcv) = mpsc::channel::<Result<(), TimerError>>();
self.spt.send_global_cmd(GlobalTasks::ResumeTask(weak_ptgi, Some(snd)))?;
return
rcv
.recv_timeout(Duration::from_secs(10))
.map_err(|e|
map_timer_err!(TimerErrorType::MpscTimeout, "resume_task(), task name: '{}', MPSC rcv timeout error: '{}'",
self.task_name, e)
)?;
}
}
/// Schedule of a periodic task, expressed either in absolute or relative time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PeriodicTaskTime
{
    /// Expiration based on an absolute wall-clock time.
    Absolute(TimerExpMode<AbsoluteTime>),
    /// Expiration based on a time relative to arming.
    Relative(TimerExpMode<RelativeTime>),
}

impl From<TimerExpMode<AbsoluteTime>> for PeriodicTaskTime
{
    fn from(abs_mode: TimerExpMode<AbsoluteTime>) -> Self
    {
        PeriodicTaskTime::Absolute(abs_mode)
    }
}

impl From<TimerExpMode<RelativeTime>> for PeriodicTaskTime
{
    fn from(rel_mode: TimerExpMode<RelativeTime>) -> Self
    {
        PeriodicTaskTime::Relative(rel_mode)
    }
}

/// Displays the wrapped expiration mode.
impl fmt::Display for PeriodicTaskTime
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        let res =
            match self
            {
                Self::Absolute(abs_mode) => write!(f, "{}", abs_mode),
                Self::Relative(rel_mode) => write!(f, "{}", rel_mode),
            };
        res
    }
}
impl PeriodicTaskTime
{
    /// One-shot schedule that fires once at `abs_time`.
    #[inline]
    pub
    fn exact_time(abs_time: AbsoluteTime) -> Self
    {
        let mode = TimerExpMode::<AbsoluteTime>::new_oneshot(abs_time);
        Self::Absolute(mode)
    }

    /// Repeating schedule that fires every `rel_time`.
    #[inline]
    pub
    fn interval(rel_time: RelativeTime) -> Self
    {
        let mode = TimerExpMode::<RelativeTime>::new_interval(rel_time);
        Self::Relative(mode)
    }

    /// Repeating schedule that first fires after `start_del_time`, then every `rel_int_time`.
    #[inline]
    pub
    fn interval_with_start_delay(start_del_time: RelativeTime, rel_int_time: RelativeTime) -> Self
    {
        let mode = TimerExpMode::<RelativeTime>::new_interval_with_init_delay(start_del_time, rel_int_time);
        Self::Relative(mode)
    }

    /// Arms `timer_fd` with this schedule.
    ///
    /// # Errors
    /// Returns a `TimerError` carrying the errno if the timer cannot be set.
    fn set_timer(&self, timer_fd: &TimerFd) -> TimerResult<()>
    {
        match *self
        {
            Self::Absolute(mode) =>
                timer_fd
                    .get_timer()
                    .set_time(mode)
                    .map_err(|e|
                        map_timer_err!(TimerErrorType::TimerError(e.get_errno()), "cannot set time '{}' for timer: '{}'", mode, timer_fd )
                    ),
            Self::Relative(mode) =>
                timer_fd
                    .get_timer()
                    .set_time(mode)
                    .map_err(|e|
                        map_timer_err!(TimerErrorType::TimerError(e.get_errno()), "cannot set time '{}' for timer: '{}'", mode, timer_fd )
                    ),
        }
    }
}
/// Shared owner of a user task: its name and the task behind a mutex so that
/// only one worker executes it at a time.
pub(crate) struct PeriodicTaskGuardInner
{
    task_name: String,
    task: Mutex<PeriodicTaskHndl>,
}

/// Debug output shows only the task name (the boxed task has no `Debug` bound).
impl fmt::Debug for PeriodicTaskGuardInner
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        let mut dbg = f.debug_struct("PeriodicTaskGuardInner");
        dbg.field("task_name", &self.task_name);
        dbg.finish()
    }
}

impl fmt::Display for PeriodicTaskGuardInner
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        return write!(f, "task name: '{}'", self.task_name);
    }
}

impl PeriodicTaskGuardInner
{
    /// Wraps `task_inst` under `task_name`. Currently never fails.
    fn new(task_name: String, task_inst: PeriodicTaskHndl) -> TimerResult<Self>
    {
        let task = Mutex::new(task_inst);
        Ok(PeriodicTaskGuardInner { task_name, task })
    }
}
/// State owned by one worker thread.
struct ThreadWorker
{
    /// OS thread name, also used in error messages.
    thread_name: String,
    /// Global command queue, drained by whichever worker holds `spti`.
    global_task_injector: Arc<Injector<GlobalTasks>>,
    /// This worker's own queue of task executions.
    local_thread_inj: Arc<Injector<ThreadTask>>,
    /// Strong owner of the run flag; the `ThreadHandler` holds only a weak reference.
    thread_run_flag: Arc<AtomicBool>,
    /// Task table shared by all workers.
    spti: Arc<Mutex<SharedPeriodicTasks>>,
    /// Index of the last pool thread picked when assigning a new ticket (round-robin).
    thread_last_id: usize,
    /// Interruptor for the shared timer poll.
    poll_int: PollInterrupt,
    /// Error report channel, present when error reporting is enabled.
    tx_err: Option<Sender<TimerError>>,
}
impl ThreadWorker
{
    /// Spawns a worker thread named `thread_name` and returns the handler used to
    /// feed and control it.
    ///
    /// The worker owns its local queue's consumer side and the strong reference to
    /// its run flag; the returned handler gets the queue and a weak flag reference.
    ///
    /// # Errors
    /// Returns a `SpawnError` if the OS thread cannot be spawned.
    fn new(
        thread_name: String,
        global_task_injector: Arc<Injector<GlobalTasks>>,
        spti: Arc<Mutex<SharedPeriodicTasks>>,
        poll_int: PollInterrupt,
        tx_err: Option<Sender<TimerError>>
    ) -> TimerResult<ThreadHandler>
    {
        let local_thread_inj = Arc::new(Injector::<ThreadTask>::new());
        let thread_run_flag = Arc::new(AtomicBool::new(true));
        // The worker holds the only strong reference, so this weak reference stops
        // upgrading once the worker thread returns.
        let thread_run_flag_weak = Arc::downgrade(&thread_run_flag);

        let worker =
            ThreadWorker
            {
                thread_name:
                    thread_name.clone(),
                global_task_injector:
                    global_task_injector,
                local_thread_inj:
                    local_thread_inj.clone(),
                thread_run_flag:
                    thread_run_flag,
                spti:
                    spti,
                thread_last_id:
                    0,
                poll_int:
                    poll_int,
                tx_err:
                    tx_err
            };

        let thread_hndl =
            std::thread::Builder::new()
                .name(thread_name)
                .spawn(|| worker.worker())
                .map_err(|e|
                    map_timer_err!(TimerErrorType::SpawnError(e.kind()), "{}", e)
                )?;

        return Ok( ThreadHandler::new(thread_hndl, local_thread_inj, thread_run_flag_weak) );
    }

    /// Worker thread main loop.
    ///
    /// Each iteration:
    /// 1. runs every execution queued on this thread's local queue (no shared lock held);
    /// 2. tries to take the shared task table; the holder applies all pending global
    ///    commands, then polls the timers and dispatches fired tasks to pool threads;
    /// 3. if another worker holds the table, parks for up to 2 s when idle.
    ///
    /// Returns `Ok(())` once the run flag is cleared.
    ///
    /// # Errors
    /// Propagates poll errors and task-table lookup errors, and fails if the shared
    /// task-table mutex is poisoned (a worker panicked while holding it).
    fn worker(mut self) -> TimerResult<()>
    {
        // Workers start parked; `SyncPeriodicTasks::new` unparks one of them and the
        // rest are woken through their `ThreadHandler`.
        std::thread::park();

        while self.thread_run_flag.load(Ordering::Acquire) == true
        {
            // Phase 1: execute tasks queued for this thread.
            while let Steal::Success(task) = self.local_thread_inj.steal()
            {
                match task
                {
                    ThreadTask::TaskExec(task_exec) =>
                    {
                        // The user guard has been dropped: nothing to run.
                        let Some(ptgi) = task_exec.ptgi.upgrade()
                        else
                        {
                            continue;
                        };

                        // NOTE(review): if a task's `exec` panics, this worker thread
                        // unwinds and the task mutex is poisoned, so a later `unwrap`
                        // here panics as well.
                        match ptgi.task.lock().unwrap().exec()
                        {
                            PeriodicTaskResult::Ok =>
                            {},
                            PeriodicTaskResult::CancelTask =>
                            {
                                self
                                    .global_task_injector
                                    .push(GlobalTasks::SuspendTask(task_exec.ptgi.clone(), None));
                                // Presumably wakes the worker blocked in the shared poll so
                                // it applies the queued command promptly.
                                let _ = self.poll_int.aquire().map(|v| v.interrupt_drop());
                            },
                            PeriodicTaskResult::TaskReSchedule(ptt) =>
                            {
                                self
                                    .global_task_injector
                                    .push(GlobalTasks::ReschedTask(task_exec.ptgi.clone(), ptt, None));
                                let _ = self.poll_int.aquire().map(|v| v.interrupt_drop());
                            }
                        }

                        // Release this entry's ticket reference; `NewTaskTicket::send_task`
                        // inspects the ticket's strong count.
                        drop(task_exec);
                    }
                }
            }

            // Phase 2: try to become the holder of the shared task table.
            match self.spti.try_lock()
            {
                Ok(mut task_token) =>
                {
                    let thread_hndl_cnt = task_token.thread_pool.get().unwrap().len();

                    // Apply all pending global commands.
                    while let Steal::Success(task) = self.global_task_injector.steal()
                    {
                        match task
                        {
                            GlobalTasks::AddTask(task_name, ptt, ptg, opt_err_ret) =>
                            {
                                if task_token.contains_task(&task_name) == true
                                {
                                    if let Some(err_ret) = opt_err_ret
                                    {
                                        let err_msg =
                                            map_timer_err!(TimerErrorType::Duplicate,
                                                "thread: '{}', task: '{}' already exists", self.thread_name, task_name);
                                        let _ = err_ret.send(Err(err_msg));
                                    }
                                    continue;
                                }

                                let period_task_ticket =
                                    PeriodicTaskTicket::new(task_name.clone(), ptt, ptg, &task_token.timers_poll);
                                if let Err(e) = period_task_ticket
                                {
                                    if let Some(err_ret) = opt_err_ret
                                    {
                                        let _ = err_ret.send(Err(e));
                                    }
                                    continue;
                                }

                                let period_task_ticket = period_task_ticket.unwrap();
                                task_token.insert_task(period_task_ticket);

                                if let Some(err_ret) = opt_err_ret
                                {
                                    let _ = err_ret.send(Ok(()));
                                }
                            },
                            GlobalTasks::RemoveTask(ptg_arc, opt_err_ret) =>
                            {
                                let res = task_token.remove_task(&ptg_arc.task_name);
                                if let Err(e) = res
                                {
                                    if let Some(err_ret) = opt_err_ret
                                    {
                                        let err_msg =
                                            map_timer_err!(e.get_error_type(),
                                                "thread: '{}', {}", self.thread_name, e.get_error_msg());
                                        let _ = err_ret.send(Err(err_msg));
                                    }
                                    continue;
                                }

                                let ptt_ref = res.unwrap();
                                if let Err(e) = ptt_ref.unset_timer()
                                {
                                    if let Some(err_ret) = opt_err_ret.as_ref()
                                    {
                                        let _ = err_ret.send(Err(e));
                                    }
                                    continue;
                                }

                                drop(ptt_ref);
                                drop(ptg_arc);

                                // Acknowledge success so a caller waiting on the reply
                                // channel does not block until its timeout.
                                if let Some(err_ret) = opt_err_ret
                                {
                                    let _ = err_ret.send(Ok(()));
                                }
                            },
                            GlobalTasks::ReschedTask( ptg_weak, ptt, opt_err_ret ) =>
                            {
                                let Some(ptg_arc) = ptg_weak.upgrade()
                                else
                                {
                                    continue;
                                };

                                let res_task = task_token.get_task_by_name(&ptg_arc.task_name);
                                let res =
                                    match res_task
                                    {
                                        Ok(task) =>
                                        {
                                            let _ = task.unset_timer();
                                            let res = task.update_task_time(ptt);
                                            if let Err(e) = res
                                            {
                                                if let Some(err_ret) = opt_err_ret.as_ref()
                                                {
                                                    let _ = err_ret.send(Err(e));
                                                }
                                                continue;
                                            }

                                            Ok(())
                                        },
                                        Err(err) =>
                                        {
                                            Err(err)
                                        }
                                    };

                                if let Some(err_ret) = opt_err_ret
                                {
                                    let _ = err_ret.send(res);
                                }
                            },
                            GlobalTasks::SuspendTask(ptg_weak, opt_err_ret) =>
                            {
                                let Some(ptg_arc) = ptg_weak.upgrade()
                                else
                                {
                                    continue;
                                };

                                let res_task = task_token.get_task_by_name(&ptg_arc.task_name);
                                if let Err(e) = res_task
                                {
                                    if let Some(err_ret) = opt_err_ret.as_ref()
                                    {
                                        let _ = err_ret.send(Err(e));
                                    }
                                    continue;
                                }

                                let task = res_task.unwrap();
                                let res = task.unset_timer();

                                if let Some(err_ret) = opt_err_ret
                                {
                                    let _ = err_ret.send(res);
                                }
                            },
                            GlobalTasks::ResumeTask(ptg_weak, opt_err_ret) =>
                            {
                                let Some(ptg_arc) = ptg_weak.upgrade()
                                else
                                {
                                    continue;
                                };

                                let res_task = task_token.get_task_by_name(&ptg_arc.task_name);
                                if let Err(e) = res_task
                                {
                                    if let Some(err_ret) = opt_err_ret.as_ref()
                                    {
                                        let _ = err_ret.send(Err(e));
                                    }
                                    continue;
                                }

                                let res = res_task.unwrap().set_timer();

                                if let Some(err_ret) = opt_err_ret.as_ref()
                                {
                                    let _ = err_ret.send(res);
                                }
                            }
                        }
                    }

                    // Phase 3: wait for timer events (5000 is passed to `poll` as the
                    // timeout; units assumed to be ms — TODO confirm).
                    let Some(res) =
                        task_token
                            .timers_poll
                            .poll(Some(5000))?
                    else
                    {
                        continue;
                    };

                    for event in res
                    {
                        match event
                        {
                            PollEventType::TimerRes(timer_fd, timer_res) =>
                            {
                                let task_by_id =
                                    task_token
                                        .get_task_by_timer_id(timer_fd)
                                        .map_err(|e|
                                            map_timer_err!(e.get_error_type(), "thread: '{}', {}", self.thread_name, e.get_error_msg())
                                        )
                                        .map(|task|
                                            (task.ptg.clone(), task.ptt.clone(), task.weak_ticket.upgrade(), task.task_name.clone())
                                        );

                                let Ok((ptg, ptt, ticket_arc_opt, task_name)) = task_by_id
                                else
                                {
                                    if let Some(tx_err) = self.tx_err.as_ref()
                                    {
                                        let _ = tx_err.send(task_by_id.err().unwrap());
                                    }
                                    continue;
                                };

                                // The user guard is gone: drop the task from the table.
                                let Some(ptg_arc) = ptg.upgrade()
                                else
                                {
                                    if let Err(e) = task_token.remove_task(&task_name)
                                    {
                                        if let Some(tx_err) = self.tx_err.as_ref()
                                        {
                                            let _ = tx_err.send(e);
                                        }
                                    }
                                    continue;
                                };

                                let overflow_cnt: u64 =
                                    match timer_res
                                    {
                                        TimerReadRes::Ok(overfl) =>
                                        {
                                            overfl
                                        },
                                        TimerReadRes::Cancelled =>
                                        {
                                            // Re-arm with the same schedule via the global queue.
                                            self
                                                .global_task_injector
                                                .push(
                                                    GlobalTasks::ReschedTask(ptg.clone(), ptt.clone(), None)
                                                );
                                            continue;
                                        },
                                        TimerReadRes::WouldBlock =>
                                        {
                                            // Panicking here poisons the shared table mutex;
                                            // the other workers then exit via the Poisoned arm.
                                            panic!("assertion trap: timer retuned WouldBlock, {}", ptg_arc);
                                        }
                                    };

                                // Reuse the live ticket, or bind a new one to the next
                                // pool thread in round-robin order.
                                let ticket =
                                    match ticket_arc_opt
                                    {
                                        Some(ticket) =>
                                            ticket,
                                        None =>
                                        {
                                            let task_thread =
                                                {
                                                    self.thread_last_id = (self.thread_last_id + 1) % thread_hndl_cnt;
                                                    task_token.clone_thread_handler(self.thread_last_id)
                                                };

                                            let ticket =
                                                Arc::new(NewTaskTicket::new(task_thread, ptg.clone()));

                                            let task =
                                                task_token
                                                    .get_task_by_timer_id(timer_fd)
                                                    .map_err(|e|
                                                        map_timer_err!(e.get_error_type(), "thread: '{}', {}", self.thread_name, e.get_error_msg())
                                                    )?;

                                            task.weak_ticket = Arc::downgrade(&ticket);

                                            ticket
                                        }
                                    };

                                // One execution per expiration counted by the timer.
                                NewTaskTicket::send_task(ticket, overflow_cnt, thread_hndl_cnt);
                            },
                            _ =>
                            {
                            },
                        }
                    }
                },
                Err(TryLockError::WouldBlock) =>
                {
                    if self.thread_run_flag.load(Ordering::Acquire) == false
                    {
                        return Ok(());
                    }

                    if self.local_thread_inj.is_empty() == true
                    {
                        std::thread::park_timeout(Duration::from_secs(2));
                    }
                },
                Err(TryLockError::Poisoned(_)) =>
                {
                    // Previously this case fell through and the worker busy-spun forever.
                    // The shared state may be inconsistent, so stop this worker instead.
                    return Err(
                        map_timer_err!(TimerErrorType::ExternalError,
                            "thread: '{}', shared task table mutex is poisoned", self.thread_name)
                    );
                }
            }
        }

        return Ok(());
    }
}
/// Owner-side handle to one worker thread.
#[derive(Debug)]
struct ThreadHandler
{
    /// Join handle of the worker thread.
    hndl: JoinHandle<TimerResult<()>>,
    /// Producer side of the worker's local task queue.
    task_injector: Arc<Injector<ThreadTask>>,
    /// Weak reference to the worker's run flag; it stops upgrading once the worker
    /// (the strong owner) has exited.
    thread_flag: Weak<AtomicBool>,
}

impl ThreadHandler
{
    fn new(hndl: JoinHandle<TimerResult<()>>, task_injector: Arc<Injector<ThreadTask>>, thread_flag: Weak<AtomicBool>) -> Self
    {
        Self { hndl, task_injector, thread_flag }
    }

    /// Clears the worker's run flag, if the worker is still alive.
    fn stop(&self)
    {
        let Some(flag) = self.thread_flag.upgrade() else { return };
        flag.store(false, Ordering::Release);
    }

    /// Wakes the worker thread if it is parked.
    fn unpark(&self)
    {
        self.hndl.thread().unpark();
    }

    /// Queues `task` on the worker's local queue, optionally waking the worker.
    fn send_task(&self, task: Arc<NewTaskTicket>, unpark: bool)
    {
        self.task_injector.push(ThreadTask::TaskExec(task));
        if unpark
        {
            self.unpark();
        }
    }

    /// Discards every entry remaining in the worker's local queue.
    fn clean_local_queue(&self)
    {
        while self.task_injector.steal().is_success() {}
    }
}
#[derive(Debug)]
pub struct SharedPeriodicTasks
{
thread_pool: OnceCell<Arc<Vec<Arc<ThreadHandler>>>>,
tasks_by_timer_fd: HashMap<TimerId, PeriodicTaskTicket>,
tasks_name_to_timer_id: HashMap<String, TimerId>,
timers_poll: TimerPoll,
}
impl SharedPeriodicTasks
{
fn new() -> TimerResult<Self>
{
return Ok(
Self
{
thread_pool:
OnceCell::default(),
tasks_by_timer_fd:
HashMap::new(),
tasks_name_to_timer_id:
HashMap::new(),
timers_poll:
TimerPoll::new()?
}
);
}
fn get_task_by_timer_id(&mut self, timer_fd: TimerId) -> TimerResult<&mut PeriodicTaskTicket>
{
return
self
.tasks_by_timer_fd
.get_mut(&timer_fd)
.ok_or_else(||
map_timer_err!(TimerErrorType::NotFound, "task fd: '{}' was found but task was not found",
timer_fd)
);
}
fn get_task_by_name(&mut self, task_name: &str) -> TimerResult<&mut PeriodicTaskTicket>
{
let res =
self
.tasks_name_to_timer_id
.get(task_name)
.ok_or_else(||
map_timer_err!(TimerErrorType::NotFound, "task: '{}' was not found", task_name)
)
.map(|v|
self
.tasks_by_timer_fd
.get_mut(v)
.ok_or_else(||
map_timer_err!(TimerErrorType::NotFound, "task: '{}' fd: '{}' was found but task was not found",
task_name, v)
)
)??;
return Ok(res);
}
fn clone_thread_handler(&self, thread_last_id: usize) -> Arc<ThreadHandler>
{
let thread_local_hnd =
self
.thread_pool
.get()
.unwrap()[thread_last_id]
.clone();
return thread_local_hnd;
}
fn remove_task(&mut self, task_name: &str) -> TimerResult<PeriodicTaskTicket>
{
let Some(task_timer_fd) = self.tasks_name_to_timer_id.remove(task_name)
else
{
timer_err!(TimerErrorType::NotFound, "task: '{}' was not found", task_name);
};
let Some(ptt) = self.tasks_by_timer_fd.remove(&task_timer_fd)
else
{
timer_err!(TimerErrorType::NotFound, "task: '{}' fd: '{}' was found but task was not found",
task_name, task_timer_fd);
};
return Ok(ptt);
}
fn contains_task(&self, task_name: &str) -> bool
{
return self.tasks_name_to_timer_id.contains_key(task_name);
}
fn insert_task(&mut self, period_task_ticket: PeriodicTaskTicket)
{
let timer_id = period_task_ticket.sync_timer.get_inner().as_timer_id();
self.tasks_name_to_timer_id.insert(period_task_ticket.task_name.clone(), timer_id);
self.tasks_by_timer_fd.insert(timer_id, period_task_ticket);
return;
}
}
#[derive(Debug)]
pub struct SyncPeriodicTasksInner
{
poll_int: PollInterrupt,
task_injector: Arc<Injector<GlobalTasks>>,
error_journal: Option<Mutex<Receiver<TimerError>>>,
}
impl SyncPeriodicTasksInner
{
fn send_global_cmd(&self, glob: GlobalTasks) -> TimerResult<()>
{
let poll_int =
self.poll_int.aquire()?;
self.task_injector.push(glob);
poll_int.interrupt_drop()?;
return Ok(());
}
fn clear_global_queue(&self)
{
while let Steal::Success(_) = self.task_injector.steal() {}
return;
}
}
/// Adapts a closure into a [`PeriodicTask`].
///
/// The boxed closure is required to be `Send`, which makes this struct `Send`
/// automatically. The former `unsafe impl Send` was therefore unnecessary, and it
/// would have been unsound for a non-`Send` closure.
struct SyncCallableOper
{
    op: Box<dyn FnMut() -> PeriodicTaskResult + Send>,
}

impl PeriodicTask for SyncCallableOper
{
    /// Invokes the wrapped closure once.
    fn exec(&mut self) -> PeriodicTaskResult
    {
        return (self.op)();
    }
}
/// Public handle to the periodic task scheduler and its worker pool.
///
/// NOTE(review): the type derives `Clone`, but its `Drop` shuts down the worker
/// pool shared by all clones.
#[derive(Debug, Clone)]
pub struct SyncPeriodicTasks
{
    /// Worker handlers; `Some` until taken in `drop`.
    threads: Option<Arc<Vec<Arc<ThreadHandler>>>>,
    /// State shared with the task guards.
    inner: Arc<SyncPeriodicTasksInner>,
}
impl Drop for SyncPeriodicTasks
{
/// Stops and joins the worker pool.
///
/// Steps: discard queued global commands, clear every worker's run flag and
/// unpark it, interrupt the shared poll, then make up to 5 attempts (500 ms
/// apart) to become the sole owner of the handler list and join each worker.
///
/// The shared task table also keeps a clone of the handler list (stored in
/// `new`). Presumably that clone is only released after every worker has exited
/// and dropped its `spti` reference, which is why ownership is retried.
///
/// NOTE(review): `SyncPeriodicTasks` derives `Clone`; dropping any clone runs
/// this shutdown for the pool shared by all clones.
/// NOTE(review): if sole ownership is not obtained within the 5 attempts, the
/// workers are left unjoined and nothing is reported.
///
/// # Panics
/// Panics if `threads` was already taken, or if a `ThreadHandler` is still
/// referenced elsewhere after the list has been unwrapped.
fn drop(&mut self)
{
self.inner.clear_global_queue();
let mut threads = self.threads.take().unwrap();
// Ask every worker to stop and wake any that are parked.
for thread in threads.iter()
{
thread.stop();
thread.unpark();
}
// Interrupt the shared poll so the worker holding the task table re-checks its flag.
let _ = self.inner.poll_int.aquire().map(|v| v.interrupt_drop());
for _ in 0..5
{
let threads_unwr =
match Arc::try_unwrap(threads)
{
Ok(r) => r,
Err(e) =>
{
// Another owner still holds the list; take it back and retry later.
threads = e;
std::thread::sleep(Duration::from_millis(500));
continue;
}
};
for thread in threads_unwr
{
// Queued entries hold tickets that hold `Arc<ThreadHandler>`; drop them
// first so this handler can be uniquely owned.
thread.clean_local_queue();
let Some(thread) = Arc::into_inner(thread)
else
{
panic!("assertion trap: ~SyncPeriodicTasks, a reference to ThreadHandler left somewhere");
};
// The worker's own result is ignored.
let _ = thread.hndl.join();
}
break;
}
}
}
impl SyncPeriodicTasks
{
pub
fn new(threads_cnt: NonZeroUsize, error_report: bool) -> TimerResult<Self>
{
let spti = SharedPeriodicTasks::new()?;
let poll_int = spti.timers_poll.get_poll_interruptor();
let spti = Arc::new(Mutex::new(spti));
let task_injector = Arc::new(Injector::<GlobalTasks>::new());
let mut thread_hndls: Vec<Arc<ThreadHandler>> = Vec::with_capacity(threads_cnt.get());
let err_journal =
if error_report == true
{
Some(mpsc::channel::<TimerError>())
}
else
{
None
};
for i in 0..threads_cnt.get()
{
let handler =
ThreadWorker::new(
format!("timer_exec/{}s", i),
task_injector.clone(),
spti.clone(),
poll_int.clone(),
err_journal.as_ref().map(|c| c.0.clone())
)?;
thread_hndls.push(Arc::new(handler));
}
let thread_hndls = Arc::new(thread_hndls);
let spti_lock = spti.lock().unwrap();
spti_lock.thread_pool.get_or_init(|| thread_hndls.clone());
let thread =
spti_lock
.thread_pool
.get()
.unwrap()
.get(random_range(0..threads_cnt.get()))
.unwrap()
.clone();
drop(spti_lock);
thread.unpark();
let inner =
SyncPeriodicTasksInner
{
poll_int:
poll_int,
task_injector:
task_injector,
error_journal:
err_journal.map(|j| Mutex::new(j.1))
};
return Ok(
Self
{
threads: Some(thread_hndls),
inner: Arc::new(inner),
}
);
}
pub
fn add<T>(&self, task_name: impl Into<String>, task: T, task_time: PeriodicTaskTime) -> TimerResult<PeriodicTaskGuard>
where T: PeriodicTask
{
let task_int: PeriodicTaskHndl = Box::new(task);
let task_name_str: String = task_name.into();
let period_task_guard =
Arc::new(PeriodicTaskGuardInner::new(task_name_str.clone(), task_int)?);
let (mpsc_send, mpsc_recv) = mpsc::channel();
self.inner.send_global_cmd(GlobalTasks::AddTask(task_name_str.clone(), task_time, period_task_guard.clone(), Some(mpsc_send)) )?;
let _ =
mpsc_recv
.recv()
.map_err(|e|
map_timer_err!(TimerErrorType::ExternalError, "mpsc error: {}", e)
)??;
let ret =
PeriodicTaskGuard
{
task_name: task_name_str,
guard: Some(period_task_guard),
spt: self.inner.clone()
};
return Ok(ret);
}
pub
fn add_closure<F>(&self, task_name: impl Into<String>, task_time: PeriodicTaskTime, clo: F) -> TimerResult<PeriodicTaskGuard>
where F: 'static + FnMut() -> PeriodicTaskResult + Send
{
let closure_task = SyncCallableOper{ op: Box::new(clo) };
return self.add(task_name, closure_task, task_time);
}
pub
fn check_thread_status(&self) -> Option<String>
{
for thread in self.threads.as_ref().unwrap().iter()
{
if let None = thread.thread_flag.upgrade()
{
return Some(thread.hndl.thread().name().unwrap().to_string());
}
}
return None;
}
pub
fn read_error(&self) -> Option<TimerError>
{
let Some(rx) = self.inner.error_journal.as_ref()
else { return None };
let Ok(err) =
rx
.lock()
.unwrap_or_else(|e| e.into_inner())
.recv_timeout(Duration::from_secs(0))
else { return None };
return Some(err);
}
}
// Integration-style tests for the scheduler. They spawn real worker threads and
// timerfds and assert on wall-clock timing, so they are sensitive to system load.
#[cfg(test)]
mod tests
{
use core::fmt;
use std::{sync::mpsc::{self, RecvTimeoutError, Sender}, time::{Duration, Instant}};
use crate::{periodic_task::sync_tasks::{PeriodicTask, PeriodicTaskResult, PeriodicTaskTime, SyncPeriodicTasks}, AbsoluteTime, RelativeTime};
// Task that sends its value `a1` on every execution and keeps its schedule.
#[derive(Debug)]
struct TaskStruct1
{
a1: u64,
s: Sender<u64>,
}
impl TaskStruct1
{
fn new(a1: u64, s: Sender<u64>) -> Self
{
return Self{ a1: a1, s };
}
}
impl PeriodicTask for TaskStruct1
{
fn exec(&mut self) -> PeriodicTaskResult
{
println!("taskstruct1 val: {}", self.a1);
let _ = self.s.send(self.a1);
return PeriodicTaskResult::Ok;
}
}
// Task that sends its value and then reschedules itself as a one-shot 2 s later.
#[derive(Debug)]
struct TaskStruct2
{
a1: u64,
s: Sender<u64>,
}
impl TaskStruct2
{
fn new(a1: u64, s: Sender<u64>) -> Self
{
return Self{ a1: a1, s };
}
}
impl PeriodicTask for TaskStruct2
{
fn exec(&mut self) -> PeriodicTaskResult
{
println!("taskstruct2 val: {}", self.a1);
self.s.send(self.a1).unwrap();
return PeriodicTaskResult::TaskReSchedule(PeriodicTaskTime::exact_time(AbsoluteTime::now() + RelativeTime::new_time(2, 0)));
}
}
// Lets any `Debug` closure be used as a task that always keeps its schedule.
impl<F> PeriodicTask for F
where F: 'static + FnMut() + Send + fmt::Debug
{
fn exec(&mut self) -> PeriodicTaskResult
{
(self)();
return PeriodicTaskResult::Ok;
}
}
// One-shot closure task fires once ~3 s after being added; error reporting enabled.
#[test]
fn ttt()
{
let s =
SyncPeriodicTasks::new(1.try_into().unwrap(), true).unwrap();
let task1_ptt =
PeriodicTaskTime
::exact_time(AbsoluteTime::now() + RelativeTime::new_time(3, 0));
let (send, recv) = mpsc::channel::<u64>();
let task1_guard =
s.add_closure("task2", task1_ptt,
move ||
{
println!("test output");
send.send(2).unwrap();
return PeriodicTaskResult::Ok;
}
).unwrap();
println!("added");
let val = recv.recv_timeout(Duration::from_millis(4000));
if val.is_err() == true
{
let e = s.read_error();
println!("ERROR, {:?}",e);
assert_eq!(true, false, "{:?}", e);
}
println!("val: {:?}", val);
assert_eq!(Ok(2), val);
drop(task1_guard);
}
// One-shot struct task: only checks that a value arrives (blocking `recv`, no timeout).
#[test]
fn test1_absolute_simple()
{
let s = SyncPeriodicTasks::new(1.try_into().unwrap(), false).unwrap();
let (send, recv) = mpsc::channel::<u64>();
let task1 = TaskStruct1::new(2, send);
let task1_ptt = PeriodicTaskTime::exact_time(AbsoluteTime::now() + RelativeTime::new_time(3, 0));
let task1_guard = s.add("task1", task1, task1_ptt).unwrap();
println!("added");
let val = recv.recv();
println!("{:?}", val);
drop(task1_guard);
}
// 1 s interval task: three consecutive firings, each roughly 1 s apart.
#[test]
fn test1_relative_simple()
{
let s = SyncPeriodicTasks::new(1.try_into().unwrap(), false).unwrap();
let (send, recv) = mpsc::channel::<u64>();
let task1 = TaskStruct1::new(2, send);
let task1_ptt = PeriodicTaskTime::interval(RelativeTime::new_time(1, 0));
let task1_guard = s.add("task1", task1, task1_ptt).unwrap();
let mut s = Instant::now();
for i in 0..3
{
let val = recv.recv().unwrap();
let e = s.elapsed();
s = Instant::now();
println!("{}: {:?} {:?} {}", i, val, e, e.as_micros());
// NOTE(review): the upper bound 10001200 us is ~10 s; 1001200 (~1.0012 s) may
// have been intended. As written it barely constrains the interval.
assert!(999000 < e.as_micros() && e.as_micros() < 10001200);
assert_eq!(val, 2);
}
drop(task1_guard);
std::thread::sleep(Duration::from_millis(100));
return;
}
// Interval task whose `exec` reschedules itself to a one-shot 2 s later each time.
#[test]
fn test1_relative_resched_to_abs()
{
let s = SyncPeriodicTasks::new(1.try_into().unwrap(), false).unwrap();
let (send, recv) = mpsc::channel::<u64>();
let task1 = TaskStruct2::new(2, send);
let task1_ptt = PeriodicTaskTime::interval(RelativeTime::new_time(1, 0));
let task1_guard = s.add("task1", task1, task1_ptt).unwrap();
// First firing: from the original 1 s interval.
let s = Instant::now();
match recv.recv_timeout(Duration::from_millis(1150))
{
Ok(rcv_a) =>
{
let e = s.elapsed();
println!("{:?} {}", e, e.as_micros());
assert_eq!(rcv_a, 2);
assert!(990051 < e.as_micros() && e.as_micros() < 1020551);
},
// NOTE(review): "tineout" is a typo for "timeout" in these panic messages.
Err(RecvTimeoutError::Timeout) =>
panic!("tineout"),
Err(e) =>
panic!("{}", e),
}
// Second firing: from the self-reschedule (~2 s).
let s = Instant::now();
match recv.recv_timeout(Duration::from_millis(2100))
{
Ok(rcv_a) =>
{
let e = s.elapsed();
println!("{:?} {}", e, e.as_micros());
assert_eq!(rcv_a, 2);
assert!(1999642 < e.as_micros() && e.as_micros() < 2008342);
},
Err(RecvTimeoutError::Timeout) =>
panic!("tineout"),
Err(e) =>
panic!("{}", e),
}
// Third firing: rescheduled again (~2 s); tighter upper bound than above.
let s = Instant::now();
match recv.recv_timeout(Duration::from_millis(2100))
{
Ok(rcv_a) =>
{
let e = s.elapsed();
println!("{:?} {}", e, e.as_micros());
assert_eq!(rcv_a, 2);
assert!(1999642 < e.as_micros() && e.as_micros() < 2003342);
},
Err(RecvTimeoutError::Timeout) =>
panic!("tineout"),
Err(e) =>
panic!("{}", e),
}
drop(task1_guard);
std::thread::sleep(Duration::from_millis(100));
return;
}
// Interval task rescheduled through the guard to a one-shot: fires once more ~2 s
// later, then stays silent.
#[test]
fn test1_relative_simple_resched()
{
let s = SyncPeriodicTasks::new(1.try_into().unwrap(), false).unwrap();
let (send, recv) = mpsc::channel::<u64>();
let task1 = TaskStruct1::new(2, send);
let task1_ptt =
PeriodicTaskTime::interval(
RelativeTime::new_time(1, 0)
);
let task1_guard = s.add("task1", task1, task1_ptt).unwrap();
let mut s = Instant::now();
for i in 0..3
{
let val = recv.recv().unwrap();
let e = s.elapsed();
s = Instant::now();
println!("{}: {:?} {:?} {}", i, val, e, e.as_micros());
// NOTE(review): same ~10 s upper bound as in `test1_relative_simple`.
assert!(990000 < e.as_micros() && e.as_micros() < 10001200);
assert_eq!(val, 2);
}
task1_guard
.reschedule_task(
PeriodicTaskTime::exact_time(AbsoluteTime::now() + RelativeTime::new_time(2, 0))
)
.unwrap();
s = Instant::now();
let val = recv.recv().unwrap();
let e = s.elapsed();
println!("resched: {:?} {:?} {}", val, e, e.as_micros());
assert!(1990000 < e.as_micros() && e.as_micros() < 2003560);
let val = recv.recv_timeout(Duration::from_secs(3));
assert_eq!(val.is_err(), true);
assert_eq!(val.err().unwrap(), RecvTimeoutError::Timeout);
drop(task1_guard);
std::thread::sleep(Duration::from_millis(100));
return;
}
// Three interval tasks on one worker. Counts firings over ~5.1 s windows: all
// running, then with task3 suspended, then after dropping every guard.
#[test]
fn test1_relative_simple_cancel()
{
let s = SyncPeriodicTasks::new(1.try_into().unwrap(), false).unwrap();
let (send, recv) = mpsc::channel::<u64>();
let task1 = TaskStruct1::new(0, send.clone());
let task1_ptt =
PeriodicTaskTime::interval(RelativeTime::new_time(1, 0));
let task2 = TaskStruct1::new(1, send.clone());
let task2_ptt =
PeriodicTaskTime::interval(RelativeTime::new_time(2, 0));
let task3 = TaskStruct1::new(2, send);
let task3_ptt =
PeriodicTaskTime::interval(RelativeTime::new_time(0, 500_000_000));
let task1_guard = s.add("task1", task1, task1_ptt).unwrap();
let task2_guard = s.add("task2", task2, task2_ptt).unwrap();
let task3_guard = s.add("task3", task3, task3_ptt).unwrap();
let mut a_cnt: [u8; 3] = [0_u8; 3];
let end = AbsoluteTime::now() + RelativeTime::new_time(5, 100_000_000);
while AbsoluteTime::now() < end
{
match recv.recv_timeout(Duration::from_millis(1))
{
Ok(rcv_a) =>
a_cnt[rcv_a as usize] += 1,
Err(RecvTimeoutError::Timeout) =>
continue,
Err(e) =>
panic!("{}", e),
}
}
assert_eq!(a_cnt[0], 5);
assert_eq!(a_cnt[1], 2);
assert_eq!(a_cnt[2], 10);
task3_guard.suspend_task().unwrap();
let end = AbsoluteTime::now() + RelativeTime::new_time(5, 100_000_000);
while AbsoluteTime::now() < end
{
match recv.recv_timeout(Duration::from_millis(1))
{
Ok(rcv_a) =>
a_cnt[rcv_a as usize] += 1,
Err(RecvTimeoutError::Timeout) =>
continue,
Err(e) =>
panic!("{}", e),
}
}
assert_eq!(a_cnt[0] > 5, true);
assert_eq!(a_cnt[1] > 2, true);
// A firing already in flight when the suspend lands may add one more.
assert!((a_cnt[2] == 10 || a_cnt[2] == 11));
drop(task1_guard);
drop(task2_guard);
drop(task3_guard);
// Once every task is dropped, all senders are gone and the channel disconnects
// before the deadline.
let end = AbsoluteTime::now() + RelativeTime::new_time(5, 100_000_000);
while AbsoluteTime::now() < end
{
match recv.recv_timeout(Duration::from_millis(1))
{
Ok(rcv_a) =>
a_cnt[rcv_a as usize] += 1,
Err(RecvTimeoutError::Timeout) =>
continue,
Err(_) =>
break,
}
}
assert_eq!(AbsoluteTime::now() < end, true);
return;
}
// Five tasks on two workers: checks firing counts, a guard-driven reschedule of the
// one-shot task, and that dropping guards stops their tasks.
#[test]
fn test2_multithread_1()
{
let s = SyncPeriodicTasks::new(2.try_into().unwrap(), false).unwrap();
let (send, recv) = mpsc::channel::<u64>();
let task1 = TaskStruct1::new(0, send.clone());
let task1_ptt =
PeriodicTaskTime::interval(RelativeTime::new_time(1, 0));
let task2 = TaskStruct1::new(1, send.clone());
let task2_ptt =
PeriodicTaskTime::interval(RelativeTime::new_time(2, 0));
let task3 = TaskStruct1::new(2, send.clone());
let task3_ptt =
PeriodicTaskTime::interval(RelativeTime::new_time(0, 500_000_000));
let task4 = TaskStruct1::new(3, send.clone());
let task4_ptt =
PeriodicTaskTime::interval(RelativeTime::new_time(0, 200_000_000));
let task5 = TaskStruct1::new(4, send.clone());
let task5_ptt =
PeriodicTaskTime::exact_time(AbsoluteTime::now() + RelativeTime::new_time(5, 0));
let task1_guard = s.add("task1", task1, task1_ptt).unwrap();
let task2_guard = s.add("task2", task2, task2_ptt).unwrap();
let task3_guard = s.add("task3", task3, task3_ptt).unwrap();
let task4_guard = s.add("task4", task4, task4_ptt).unwrap();
let task5_guard = s.add("task5", task5, task5_ptt).unwrap();
let mut a_cnt: [u8; 5] = [0_u8; 5];
let end = AbsoluteTime::now() + RelativeTime::new_time(5, 500_000_000);
while AbsoluteTime::now() < end
{
match recv.recv_timeout(Duration::from_millis(1))
{
Ok(rcv_a) =>
a_cnt[rcv_a as usize] += 1,
Err(RecvTimeoutError::Timeout) =>
continue,
Err(e) =>
panic!("{}", e),
}
}
println!("{:?}", a_cnt);
assert!(a_cnt[0] == 5);
assert!(a_cnt[1] == 2);
assert!((a_cnt[2] == 10 || a_cnt[2] == 11));
assert!(a_cnt[3] == 27);
assert!(a_cnt[4] == 1);
task5_guard.reschedule_task(PeriodicTaskTime::exact_time(AbsoluteTime::now() + RelativeTime::new_time(0, 500_000_000))).unwrap();
let end = AbsoluteTime::now() + RelativeTime::new_time(0, 600_000_000);
while AbsoluteTime::now() < end
{
match recv.recv_timeout(Duration::from_millis(1))
{
Ok(rcv_a) =>
a_cnt[rcv_a as usize] += 1,
Err(RecvTimeoutError::Timeout) =>
continue,
Err(e) =>
panic!("{}", e),
}
}
println!("{:?}", a_cnt);
assert!(a_cnt[4] == 2);
drop(task5_guard);
drop(task4_guard);
drop(task3_guard);
drop(task2_guard);
// Only task1 is still alive: over ~2 s it should fire twice more (5 + 1 + 2 = 8).
let end = AbsoluteTime::now() + RelativeTime::new_time(2, 1000);
while AbsoluteTime::now() < end
{
match recv.recv_timeout(Duration::from_millis(1))
{
Ok(rcv_a) =>
a_cnt[rcv_a as usize] += 1,
Err(RecvTimeoutError::Timeout) =>
continue,
Err(e) =>
panic!("{}", e),
}
}
println!("{:?}", a_cnt);
assert_eq!(a_cnt[4], 2);
assert_eq!(a_cnt[0], 8);
drop(task1_guard);
std::thread::sleep(Duration::from_millis(10));
return;
}
}