#![allow(dead_code)]
use crate::{CrossbeamReceiver, CrossbeamSender, WorkerPoolStatus};
use async_std::{
prelude::*,
sync::{channel, Receiver, Sender, TryRecvError, TrySendError},
};
use std::{collections::VecDeque, fmt::Debug};
// NOTE(review): placeholder constant that is not referenced by any code visible
// here (the crate-level `#![allow(dead_code)]` keeps it from warning). Its
// intended purpose is unknown — TODO confirm or remove.
const FIX_ME: usize = 128;
/// A pool of async workers fed from a task queue.
///
/// Each worker runs `task` on one job input, may send results through a shared
/// channel, and reports its completion back to the pool. The pool itself is
/// driven by repeatedly calling `work`.
pub struct WorkerPool<In, Out, F> {
/// Job function; called once per started worker to produce that worker's future.
task: fn(Job<In, Out>) -> F,
/// Target/max worker counts and the optional default job input.
config: WorkerPoolConfig<In>,
/// Workers spawned and not yet reported finished via a `WorkerEvent`.
cur_workers: usize,
/// Pending job inputs; pushed at the back, handed out from the front before the default job.
queue: VecDeque<In>,
/// Result channel: each job gets a clone of the sender, the pool drains the receiver.
workers_channel: (Sender<Out>, Receiver<Out>),
/// Stop-request channel shared by every job; whichever worker polls first takes a request.
close_channel: (Sender<()>, Receiver<()>),
/// Completion events sent from spawned worker tasks back to the pool.
worker_events: (CrossbeamSender<WorkerEvent>, CrossbeamReceiver<WorkerEvent>),
/// Commands queued via `command` and applied the next time `work` runs.
command_events: (CrossbeamSender<WorkerPoolCommand>, CrossbeamReceiver<WorkerPoolCommand>),
/// Stop requests counted as pending; `cur_workers()` subtracts this from
/// `cur_workers`, and a `WorkerStopped` event decrements it.
outstanding_stops: usize,
}
/// Settings for a `WorkerPool`, created with `WorkerPoolConfig::new` and adjusted
/// through its builder-style setters.
// `derive(Copy)` only yields `Copy` when `In: Copy` (derive adds that bound).
#[derive(Debug, Copy, Clone)]
pub struct WorkerPoolConfig<In> {
/// Desired number of running workers (8 by default).
target_workers: usize,
/// Job input used when the queue is empty; with `None`, no worker is started
/// unless a task has been pushed.
default_job: Option<In>,
/// Capacity of the pool's result and stop channels (1024 by default).
max_workers: usize,
}
/// Completion notice sent from a spawned worker task back to its pool.
#[derive(Debug, Copy, Clone)]
enum WorkerEvent {
/// The worker's job future resolved to `JobStatus::Done`.
WorkerDone,
/// The worker's job future resolved to `JobStatus::Stopped`.
WorkerStopped,
}
/// A control message for a `WorkerPool`, queued with `WorkerPool::command` and
/// applied on the pool's next `work` call.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum WorkerPoolCommand {
    /// Request that the pool's running workers stop.
    Stop,
    /// Set the target number of workers; a value of 0 is treated as 1.
    SetWorkerCount(usize),
}
/// One unit of work handed to a worker's job function.
pub struct Job<In, Out> {
/// Input for this job: a pushed task, or a clone of the pool's default job.
pub task: In,
/// Stop-request channel shared by all jobs of the pool; poll it with `stop_requested`.
pub close: Receiver<()>,
/// Sender for this job's results; the pool drains the receiving end in `work`.
pub results: Sender<Out>,
}
/// Final status that a job's future resolves to.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum JobStatus {
    /// The job finished; the pool records a completed worker.
    Done,
    /// The job exited early; the pool records a stopped worker.
    Stopped,
    /// Not a valid final status: the pool panics if a job future resolves to it.
    Working,
}
impl<In, Out, F> WorkerPool<In, Out, F>
where
In: Send + Sync + Clone + 'static,
Out: Send + Sync + 'static,
F: Future<Output = JobStatus> + Send + 'static,
{
pub fn new(task: fn(Job<In, Out>) -> F) -> Self {
Self::new_with_config(task, WorkerPoolConfig::default())
}
pub fn new_with_config(task: fn(Job<In, Out>) -> F, config: WorkerPoolConfig<In>) -> Self {
Self {
workers_channel: channel(config.max_workers),
close_channel: channel(config.max_workers),
command_events: crossbeam_channel::unbounded(),
worker_events: crossbeam_channel::unbounded(),
queue: Default::default(),
outstanding_stops: 0,
cur_workers: 0,
config,
task,
}
}
pub fn cur_workers(&self) -> usize {
self.cur_workers - self.outstanding_stops
}
pub fn target_workers(&self) -> usize {
self.config.target_workers
}
pub fn at_target_worker_count(&self) -> bool {
self.cur_workers() == self.target_workers()
}
pub fn working(&self) -> bool {
self.cur_workers() > 0
}
pub fn set_target_workers(&mut self, n: usize) {
self.config.target_workers = n;
}
pub fn push(&mut self, task: In) {
self.queue.push_back(task);
}
pub fn command(&mut self, command: WorkerPoolCommand) {
self.command_events.0.send(command).expect("failed to send command");
}
pub fn work(&mut self) -> WorkerPoolStatus<Out> {
self.process_pool_commands();
self.process_worker_events();
self.balance_workers();
match self.workers_channel.1.try_recv() {
Ok(out) => WorkerPoolStatus::Ready(out),
Err(e) => match e {
TryRecvError::Empty => WorkerPoolStatus::Working,
TryRecvError::Disconnected => WorkerPoolStatus::Done,
},
};
WorkerPoolStatus::Working
}
fn process_pool_commands(&mut self) {
while let Ok(command) = self.command_events.1.try_recv() {
match command {
WorkerPoolCommand::Stop => {
for _ in 0..self.config.target_workers {
self.send_stop_work_message();
}
}
WorkerPoolCommand::SetWorkerCount(n) => {
let n = match n {
0 => 1,
n => n,
};
self.config.target_workers = n;
}
}
}
}
fn process_worker_events(&mut self) {
while let Ok(event) = self.worker_events.1.try_recv() {
match event {
WorkerEvent::WorkerDone => {
self.cur_workers -= 1;
}
WorkerEvent::WorkerStopped => {
self.cur_workers -= 1;
self.outstanding_stops -= 1;
}
}
}
}
fn start_worker(&mut self) {
let task = self.get_task();
if task.is_none() {
return;
}
let work_send = self.workers_channel.0.clone();
let close_recv = self.close_channel.1.clone();
let event_send = self.worker_events.0.clone();
let job = Job::new(task.unwrap(), close_recv, work_send);
let fut = (self.task)(job);
async_std::task::spawn(async move {
let status = fut.await;
let message = match status {
JobStatus::Done => WorkerEvent::WorkerDone,
JobStatus::Stopped => WorkerEvent::WorkerStopped,
JobStatus::Working => panic!("worker stopped while running, unexpected state"),
};
event_send.send(message).expect("failed to send WorkerEvent");
});
self.cur_workers += 1;
}
fn get_task(&mut self) -> Option<In> {
if self.queue.is_empty() {
return match &self.config.default_job {
None => None,
Some(default) => Some(default.clone()),
};
} else {
Some(self.queue.pop_front().unwrap())
}
}
fn balance_workers(&mut self) {
if self.cur_workers() < self.target_workers() {
self.start_worker();
} else if self.cur_workers() > self.target_workers() {
self.send_stop_work_message();
}
}
fn send_stop_work_message(&mut self) {
loop {
match self.close_channel.0.try_send(()) {
Ok(_) => break,
Err(e) => match e {
TrySendError::Full(_) => {}
TrySendError::Disconnected(_) => panic!("foo"),
},
}
}
}
}
impl<In, Out> Job<In, Out> {
pub fn new(task: In, close: Receiver<()>, results: Sender<Out>) -> Self {
Self { task, close, results }
}
pub fn stop_requested(&self) -> bool {
match self.close.try_recv() {
Ok(_) => true,
Err(_) => false,
}
}
}
impl<In> Default for WorkerPoolConfig<In> {
fn default() -> Self {
Self { target_workers: 8, default_job: None, max_workers: 1024 }
}
}
impl<In> WorkerPoolConfig<In> {
    /// Returns the default configuration (identical to `Default::default`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the job input the pool uses whenever its task queue is empty.
    pub fn default_job(&mut self, job: In) -> &mut Self {
        self.default_job = Some(job);
        self
    }

    /// Sets the desired number of concurrently running workers.
    pub fn target_workers(&mut self, n: usize) -> &mut Self {
        self.target_workers = n;
        self
    }

    /// Sets the capacity of the pool's result and stop-request channels.
    ///
    /// A value of 0 is treated as 1, because async-std's bounded `channel`
    /// panics on a zero capacity when the pool is constructed.
    pub fn max_workers(&mut self, n: usize) -> &mut Self {
        self.max_workers = n.max(1);
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_std::task;
    use futures_await_test::async_test;
    use std::time::Duration;

    /// Test job: starting from `seed`, doubles the value up to `rounds` times.
    ///
    /// Each intermediate value is sent as a result, with a 100 ms pause between
    /// rounds. The job checks for a stop request before each round and exits
    /// early if one is found, still reporting `Done`.
    async fn double(job: Job<(usize, usize), usize>) -> JobStatus {
        let (seed, rounds) = job.task;
        let mut value = seed;
        let mut remaining = rounds;
        while remaining > 0 && !job.stop_requested() {
            value *= 2;
            job.results.send(value).await;
            task::sleep(Duration::from_millis(100)).await;
            remaining -= 1;
        }
        JobStatus::Done
    }

    #[async_test]
    async fn pool_new() {
        let _pool = WorkerPool::new(double);
    }

    #[async_test]
    async fn pool_new_with_config() {
        let mut config = WorkerPoolConfig::new();
        config.target_workers(4).default_job((2, 10));
        let _pool = WorkerPool::new_with_config(double, config);
    }
}