#[cfg(not(target_arch = "wasm32"))]
mod std_task;
mod ticket;
#[cfg(target_arch = "wasm32")]
mod wasm_task;
#[cfg(target_arch = "wasm32")]
pub use gloo_worker;
use std::collections::{HashMap, VecDeque};
pub use ticket::Ticket;
#[cfg(target_arch = "wasm32")]
pub use wasm_task::WebWorker;
/// A unit of work that a [`Task`] worker can execute.
///
/// Inputs and outputs must be serde-serializable and `Send`, presumably so
/// they can be moved across a thread or web-worker boundary
/// (NOTE(review): confirm against `std_task` / `wasm_task`). `Default` is
/// required, presumably so the backend can construct its own instance —
/// TODO confirm in the backend modules.
pub trait Function: Sized + Default + 'static {
    /// Payload handed to the worker for one job.
    type Input: Send + serde::de::DeserializeOwned + serde::Serialize;
    /// Value the worker produces for one job.
    type Output: Send + serde::de::DeserializeOwned + serde::Serialize;
    /// Handles a single input and returns its output. Takes `&mut self`, so an
    /// implementation may keep state between calls.
    fn call(&mut self, input: Self::Input) -> Self::Output;
}
/// A single worker wrapping the platform backend: `std_task::TaskStd` on
/// non-wasm targets, `wasm_task::TaskWasm` on wasm32.
pub struct Task<F: Function> {
    /// Number of inputs passed to `enqueue` whose outputs have not yet been
    /// returned by `check`.
    task_count: usize,
    /// Platform-specific backend (native build); see the `std_task` module.
    #[cfg(not(target_arch = "wasm32"))]
    task: std_task::TaskStd<F>,
    /// Platform-specific backend (wasm32 build); see the `wasm_task` module.
    #[cfg(target_arch = "wasm32")]
    task: wasm_task::TaskWasm<F>,
}
impl<F: Function> std::fmt::Debug for Task<F> {
    /// Prints the outstanding-job counter. The backend has no `Debug` bound,
    /// so a fixed placeholder string is shown in its place.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let placeholder = "[Cannot be shown]";
        let mut out = formatter.debug_struct("Task");
        out.field("task_count", &self.task_count);
        out.field("task", &placeholder);
        out.finish()
    }
}
impl<F: Function> Task<F> {
    /// Creates a worker named `task_name` on the platform backend
    /// (`std_task` on native targets, `wasm_task` on wasm32).
    #[must_use]
    pub fn new(task_name: &str) -> Self {
        #[cfg(not(target_arch = "wasm32"))]
        let backend = std_task::TaskStd::new(task_name);
        #[cfg(target_arch = "wasm32")]
        let backend = wasm_task::TaskWasm::new(task_name);
        Self {
            task_count: 0,
            task: backend,
        }
    }

    /// Sends `msg` to the backend and counts it as outstanding.
    pub fn enqueue(&mut self, msg: F::Input) {
        self.task_count += 1;
        self.task.enqueue(msg);
    }

    /// Returns `true` while at least one enqueued input has not yet had its
    /// output returned by [`Task::check`].
    #[must_use]
    pub fn task_is_ongoing(&self) -> bool {
        self.task_count != 0
    }

    /// Asks the backend for a finished output. Returns `None` when the backend
    /// reports none; otherwise the outstanding counter is decremented and the
    /// output returned.
    #[must_use]
    pub fn check(&mut self) -> Option<F::Output> {
        let finished = self.task.check()?;
        self.task_count -= 1;
        Some(finished)
    }
}
/// A fixed set of [`Task`] workers that schedules jobs in submission order
/// and hands results back by [`Ticket`].
#[derive(Debug)]
pub struct TaskPool<F: Function> {
    /// One slot per worker: the ticket of the job it is currently running
    /// (`None` while idle) and the worker itself.
    tasks: Vec<(Option<Ticket>, Task<F>)>,
    /// Jobs waiting for an idle worker, oldest first.
    to_start: VecDeque<(Ticket, F::Input)>,
    /// Finished results that have not yet been collected via `check`/`wait_for`.
    done: HashMap<Ticket, F::Output>,
    /// Produces the tickets returned by `enqueue`.
    ticket_generator: ticket::TicketGenerator,
}
impl<F: Function> TaskPool<F> {
    /// Creates a pool of `task_count` workers, each built with
    /// `Task::new(task_name)`.
    ///
    /// A pool with zero workers accepts jobs but can never run them.
    #[must_use]
    pub fn new(task_name: &str, task_count: usize) -> Self {
        Self {
            tasks: (0..task_count)
                .map(|_| (None, Task::new(task_name)))
                .collect(),
            to_start: Default::default(),
            done: Default::default(),
            ticket_generator: Default::default(),
        }
    }

    /// Runs one scheduling pass over all workers:
    /// - a busy worker is asked for its result; if one is available it is
    ///   stored under the job's ticket and the worker becomes idle;
    /// - an idle worker (including one freed in this same pass) picks up the
    ///   oldest queued job, if any.
    ///
    /// # Panics
    ///
    /// If a finished job's ticket is already present among the uncollected
    /// results, which would mean two jobs shared a ticket.
    pub fn progress(&mut self) {
        for (slot, task) in &mut self.tasks {
            if slot.is_some() {
                if let Some(output) = task.check() {
                    let ticket = slot
                        .take()
                        .expect("slot was just checked to be occupied");
                    let previous = self.done.insert(ticket, output);
                    assert!(previous.is_none(), "Ticket is already in list of done jobs");
                }
            }
            if slot.is_none() {
                if let Some((ticket, input)) = self.to_start.pop_front() {
                    *slot = Some(ticket);
                    task.enqueue(input);
                }
            }
        }
    }

    /// Queues `input` and immediately runs a scheduling pass, so the job is
    /// started right away when a worker is idle.
    ///
    /// Returns the ticket to pass to [`TaskPool::check`] or
    /// [`TaskPool::wait_for`].
    #[must_use]
    pub fn enqueue(&mut self, input: F::Input) -> Ticket {
        let (ticket, ticket_internal) = self.ticket_generator.next();
        self.to_start.push_back((ticket_internal, input));
        self.progress();
        ticket
    }

    /// Runs a scheduling pass, then reports the state of `ticket`'s job.
    ///
    /// Returns [`JobState::Done`] with the output (removing it from the pool)
    /// when the result is available. Otherwise it returns
    /// [`JobState::Ongoing`], handing the ticket back so it can be polled
    /// again. A ticket whose result was already collected is never found
    /// again, so it is reported as `Ongoing` from then on.
    #[must_use]
    pub fn check(&mut self, ticket: Ticket) -> JobState<F::Output> {
        self.progress();
        if let Some(output) = self.done.remove(&ticket) {
            JobState::Done(output)
        } else {
            JobState::Ongoing(ticket)
        }
    }

    /// Repeatedly polls [`TaskPool::check`] until `ticket`'s job finishes and
    /// returns its output.
    ///
    /// This is a busy-wait. On native targets the thread yields between polls.
    /// It polls in a loop rather than recursing, so a long-running job cannot
    /// exhaust the stack. NOTE(review): on wasm32, if worker results are only
    /// delivered through the browser event loop, this may never return —
    /// confirm against `wasm_task`.
    ///
    /// # Panics
    ///
    /// If the pool has no workers, since the job could then never complete.
    #[must_use]
    pub fn wait_for(&mut self, ticket: Ticket) -> F::Output {
        assert!(
            !self.tasks.is_empty(),
            "TaskPool has no workers, so wait_for would never return"
        );
        let mut pending = ticket;
        loop {
            match self.check(pending) {
                JobState::Done(output) => return output,
                JobState::Ongoing(ticket) => {
                    pending = ticket;
                    // Let worker threads make progress instead of spinning flat out.
                    #[cfg(not(target_arch = "wasm32"))]
                    std::thread::yield_now();
                }
            }
        }
    }
}
/// Outcome of polling a [`TaskPool`] for a ticket.
#[derive(Debug)]
pub enum JobState<Output> {
    /// The result is not available yet; the ticket is handed back so it can be
    /// polled again.
    Ongoing(Ticket),
    /// The job finished; its output has been removed from the pool.
    Done(Output),
}
#[cfg(not(target_arch = "wasm32"))]
#[test]
fn test_task_pool_std() {
    /// Minimal `Function` used to exercise the pool: maps `x` to `x + 1`.
    #[derive(Default)]
    struct DummyFunction;
    impl Function for DummyFunction {
        type Input = u32;
        type Output = u64;
        fn call(&mut self, input: Self::Input) -> Self::Output {
            increment(input)
        }
    }
    // Widen before adding so `u32::MAX` cannot overflow.
    fn increment(x: u32) -> u64 {
        u64::from(x) + 1
    }
    // More jobs than workers, so some of them must wait in the queue.
    let mut task_pool = TaskPool::<DummyFunction>::new("dummy_thread", 3);
    let n = 10;
    let tickets: Vec<_> = (0..n).map(|i| task_pool.enqueue(i)).collect();
    for (i, ticket) in (0..n).zip(tickets) {
        assert_eq!(increment(i), task_pool.wait_for(ticket));
    }
}