use std::{
collections::{HashMap, VecDeque},
sync::mpsc::{channel, RecvTimeoutError},
time::{Duration, Instant},
};
use crate::task::{
event::TaskEvent, FinishedTask, RunningTask, Task, TaskId, TaskInfo, TaskStatus,
};
use super::{start_task, Executor};
use anyhow::{bail, Result};
/// A pending timeout check: which task is being watched and the instant
/// after which it counts as timed out if it is still running.
pub(crate) struct TimeoutSituation {
    /// Instant at (or after) which the task is considered timed out.
    deadline: Instant,
    /// Info of the watched task; reported back when the deadline passes.
    tinfo: TaskInfo,
}
/// An executor that runs tasks with bounded concurrency and reports tasks
/// that are still running after their deadline has passed.
pub struct TimeoutExecutor {
    /// Maximum number of tasks allowed to run at the same time.
    capacity: usize,
    /// Timeout instant supplied when the executor was constructed.
    timeout: Option<Instant>,
    /// Deadline entries for started tasks, in the order the tasks were started.
    timeout_queue: VecDeque<TimeoutSituation>,
}
impl TimeoutExecutor {
    /// Creates an executor that runs at most `thread_count` tasks at once,
    /// using `default_deadline_60_seconds()` as its timeout.
    ///
    /// `thread_count` must be non-zero; this is only checked in debug builds.
    pub fn new_with_thread_count(thread_count: usize) -> Self {
        Self::new_with_thread_count_and_timeout(thread_count, default_deadline_60_seconds())
    }

    /// Creates an executor that runs at most `thread_count` tasks at once,
    /// using `timeout` as its timeout.
    ///
    /// `thread_count` must be non-zero; this is only checked in debug builds.
    pub fn new_with_thread_count_and_timeout(thread_count: usize, timeout: Instant) -> Self {
        debug_assert!(
            thread_count > 0,
            "At least one thread is required to execute the task."
        );
        TimeoutExecutor {
            timeout_queue: VecDeque::default(),
            capacity: thread_count,
            timeout: Some(timeout),
        }
    }

    /// Pops every queue entry whose deadline has passed and returns the info
    /// of those tasks that are still present in `running_tasks`.
    ///
    /// Entries belonging to tasks that already finished are dropped silently.
    /// Scanning stops at the first entry whose deadline is still in the
    /// future, so the queue is expected to be ordered by deadline.
    fn all_timed_out_tasks_info(
        &mut self,
        running_tasks: &HashMap<TaskId, RunningTask>,
    ) -> Vec<TaskInfo> {
        let now = Instant::now();
        let mut timed_out = Vec::new();
        while let Some(entry) = self.timeout_queue.front() {
            if entry.deadline > now {
                break;
            }
            let entry = self
                .timeout_queue
                .pop_front()
                .expect("front() returned an entry, so the queue is not empty");
            if running_tasks.contains_key(&entry.tinfo.tid()) {
                timed_out.push(entry.tinfo);
            }
        }
        timed_out
    }

    /// Returns how long to wait until the deadline at the front of the queue,
    /// or `None` when the queue is empty. A deadline that has already passed
    /// yields a zero duration.
    fn deadline(&self) -> Option<Duration> {
        self.timeout_queue
            .front()
            .map(|entry| entry.deadline.saturating_duration_since(Instant::now()))
    }
}
impl Executor for TimeoutExecutor {
fn run_all_tasks<T, F>(mut self, tasks: &[T], notify_what_happened: F) -> Result<()>
where
T: Task + Clone + Sync + Send + 'static,
F: Fn(TaskEvent) -> Result<()>,
{
let (tx, rx) = channel::<FinishedTask>();
let mut waiting_tasks = VecDeque::from(tasks.to_vec());
let mut running_tasks = HashMap::<TaskId, RunningTask>::default();
let mut running_count = 0;
while running_count > 0 || !waiting_tasks.is_empty() {
while running_count < self.concurrency_capacity() && !waiting_tasks.is_empty() {
let task = waiting_tasks.pop_front().unwrap();
let tid = task.info().tid();
let tinfo = task.info();
let deadline = if let Some(timeout) = self.timeout {
timeout
} else {
default_deadline_60_seconds()
};
let event = TaskEvent::wait(task.info());
notify_what_happened(event)?;
let join_handle = start_task(task, tx.clone());
running_tasks.insert(tid, RunningTask { join_handle });
self.timeout_queue
.push_back(TimeoutSituation { tinfo, deadline });
running_count += 1;
}
let mut res;
loop {
if let Some(timeout) = self.deadline() {
res = rx.recv_timeout(timeout);
for tid in self.all_timed_out_tasks_info(&running_tasks) {
notify_what_happened(TaskEvent::time_out(tid, timeout))?;
}
if res.is_ok() {
break;
};
} else {
res = rx.recv().map_err(|_| RecvTimeoutError::Disconnected);
break;
}
}
let mut finished_task = res.unwrap();
let running_task = match running_tasks.remove(&finished_task.tinfo().tid()) {
Some(rs) => rs,
None => {
panic!(
"size id {}, {}",
running_tasks.len(),
finished_task.tinfo()
)
}
};
running_task.join(&mut finished_task);
let fail = match finished_task.status() {
TaskStatus::Bug(_) => {
std::mem::forget(rx);
bail!(
"The task {} has completed, but the thread has failed",
finished_task.tinfo()
);
}
_ => false,
};
let event = TaskEvent::finished(finished_task.tinfo(), finished_task);
notify_what_happened(event)?;
running_count -= 1;
if fail {
std::mem::forget(rx);
return Ok(());
}
}
Ok(())
}
fn concurrency_capacity(&self) -> usize {
self.capacity
}
}
/// Returns the instant 60 seconds from now, used as the executor's default
/// timeout.
pub(crate) fn default_deadline_60_seconds() -> Instant {
    // Default timeout in seconds. A fn-local item is never visible outside
    // this function, so it carries no visibility qualifier.
    const TIMEOUT_SECS: u64 = 60;
    Instant::now() + Duration::from_secs(TIMEOUT_SECS)
}