use futures::{Async, Future};
use nbchan::mpsc as nb_mpsc;
use num_cpus;
use std::io;
use std::sync::mpsc::TryRecvError;
use std::thread;
use std::time;
use super::Executor;
use fiber::Task;
use fiber::{self, Spawn};
use io::poll;
use sync::oneshot::{self, Link};
/// An executor that runs fibers on a fixed pool of scheduler threads, where each
/// scheduler is built on its own poller.
///
/// Fibers spawned through this executor or its handles are queued on an internal
/// channel and handed to the schedulers round-robin by `run_once`.
#[derive(Debug)]
pub struct ThreadPoolExecutor {
    // Scheduler handles plus the links used to check whether each scheduler thread
    // is still running.
    pool: SchedulerPool,
    // Poller handles plus the pool-side links to the poller threads; those threads
    // keep looping only while their end of the link reports `NotReady`.
    pollers: PollerPool,
    // Receiving end of the spawn queue; at most one task is taken per `run_once`.
    spawn_rx: nb_mpsc::Receiver<Task>,
    // Sending end of the spawn queue, cloned into every handle. Because the executor
    // itself holds a sender, `spawn_rx` cannot observe a disconnected channel.
    spawn_tx: nb_mpsc::Sender<Task>,
    // Round-robin counter choosing the scheduler that receives the next task.
    round: usize,
    // Rotating counter choosing which scheduler's link `run_once` checks next.
    steps: usize,
}
impl ThreadPoolExecutor {
    /// Creates an executor with twice as many threads per pool as `num_cpus::get()`
    /// reports.
    ///
    /// # Errors
    ///
    /// Same as [`ThreadPoolExecutor::with_thread_count`].
    pub fn new() -> io::Result<Self> {
        let thread_count = num_cpus::get() * 2;
        Self::with_thread_count(thread_count)
    }

    /// Creates an executor backed by `count` poller threads and `count` scheduler
    /// threads (one scheduler per poller).
    ///
    /// # Panics
    ///
    /// Panics if `count` is zero.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while building the poller pool.
    pub fn with_thread_count(count: usize) -> io::Result<Self> {
        assert!(count > 0);

        let pollers = PollerPool::new(count)?;
        let pool = SchedulerPool::new(&pollers);
        let (spawn_tx, spawn_rx) = nb_mpsc::channel();

        Ok(ThreadPoolExecutor {
            pool,
            pollers,
            spawn_tx,
            spawn_rx,
            round: 0,
            steps: 0,
        })
    }
}
impl Executor for ThreadPoolExecutor {
    type Handle = ThreadPoolExecutorHandle;

    /// Returns a cloneable handle whose spawned fibers are queued on this executor's
    /// spawn channel.
    fn handle(&self) -> Self::Handle {
        ThreadPoolExecutorHandle {
            spawn_tx: self.spawn_tx.clone(),
        }
    }

    /// Performs one step of the executor.
    ///
    /// Each call:
    /// 1. takes at most one queued fiber and hands it to a scheduler chosen
    ///    round-robin, or sleeps for 1 ms if the queue is empty;
    /// 2. checks one scheduler thread and its poller thread, rotating the checked
    ///    index on every call.
    ///
    /// # Errors
    ///
    /// Returns an `io::ErrorKind::Other` error if polling the link to the checked
    /// scheduler thread or poller thread yields an error. This presumably means that
    /// thread has exited or failed; the poller thread, for instance, calls
    /// `exit(Err(..))` on its link when polling fails.
    fn run_once(&mut self) -> io::Result<()> {
        match self.spawn_rx.try_recv() {
            Err(TryRecvError::Empty) => {
                // Nothing to dispatch; back off briefly instead of spinning.
                thread::sleep(time::Duration::from_millis(1));
            }
            // `self.spawn_tx` lives as long as `self`, so the channel cannot be
            // disconnected here.
            Err(TryRecvError::Disconnected) => unreachable!(),
            Ok(task) => {
                let i = self.round % self.pool.schedulers.len();
                self.pool.schedulers[i].spawn_boxed(task.0);
                self.round = self.round.wrapping_add(1);
            }
        }

        self.steps = self.steps.wrapping_add(1);
        let i = self.steps % self.pool.schedulers.len();
        if self.pool.links[i].poll().is_err() {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                format!("The {}-th scheduler thread is aborted", i),
            ));
        }
        // `SchedulerPool::new` creates exactly one scheduler per poller, so `i` is
        // also a valid poller index. Without this check a failed poller thread would
        // go unnoticed, and fibers waiting on its I/O events would never be woken.
        if self.pollers.links[i].poll().is_err() {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                format!("The {}-th poller thread is aborted", i),
            ));
        }
        Ok(())
    }
}
impl Spawn for ThreadPoolExecutor {
fn spawn_boxed(&self, fiber: Box<dyn Future<Item = (), Error = ()> + Send>) {
self.handle().spawn_boxed(fiber)
}
}
/// A cloneable handle for spawning fibers onto a [`ThreadPoolExecutor`].
///
/// Spawned fibers are sent over the executor's spawn channel; the executor hands
/// them to a scheduler thread in `run_once`.
#[derive(Debug, Clone)]
pub struct ThreadPoolExecutorHandle {
    // Sending end of the executor's spawn queue.
    spawn_tx: nb_mpsc::Sender<Task>,
}
impl Spawn for ThreadPoolExecutorHandle {
    /// Enqueues `fiber` on the executor's spawn channel.
    ///
    /// This is best-effort: if the send fails, the failure is ignored and the fiber
    /// is never dispatched.
    fn spawn_boxed(&self, fiber: Box<dyn Future<Item = (), Error = ()> + Send>) {
        let task = Task(fiber);
        match self.spawn_tx.send(task) {
            Ok(_) => {}
            // Undeliverable (e.g. the receiving executor is gone); drop the task.
            Err(_undelivered) => {}
        }
    }
}
/// A set of pollers, each driven by its own dedicated thread.
#[derive(Debug)]
struct PollerPool {
    // Handles to the pollers, in creation order.
    pollers: Vec<poll::PollerHandle>,
    // Pool-side ends of the links to the poller threads, index-aligned with
    // `pollers`. A poller thread calls `exit(Err(e))` on its end when polling fails,
    // and keeps looping only while polling its end reports `NotReady`.
    links: Vec<Link<(), io::Error>>,
}
impl PollerPool {
    /// Creates `pool_size` pollers and starts one thread per poller.
    ///
    /// Each thread repeatedly calls `poll` on its poller with a 1 ms timeout for as
    /// long as polling its end of the link reports `Async::NotReady`. If a poll
    /// fails, the thread reports the error through the link and stops.
    ///
    /// # Errors
    ///
    /// Returns an error if creating a poller fails, or if the OS cannot spawn a
    /// poller thread. In either case, the links created so far are dropped with the
    /// partially built pool.
    pub fn new(pool_size: usize) -> io::Result<Self> {
        let mut pollers = Vec::with_capacity(pool_size);
        let mut links = Vec::with_capacity(pool_size);
        for _ in 0..pool_size {
            let (link0, mut link1) = oneshot::link();
            let mut poller = poll::Poller::new()?;
            links.push(link0);
            pollers.push(poller.handle());
            // `thread::spawn` panics when the OS refuses to create a thread; the
            // builder lets that failure surface through this function's `io::Result`.
            thread::Builder::new().spawn(move || {
                while let Ok(Async::NotReady) = link1.poll() {
                    let timeout = time::Duration::from_millis(1);
                    if let Err(e) = poller.poll(Some(timeout)) {
                        // Hand the error to the pool-side link and stop this thread.
                        link1.exit(Err(e));
                        return;
                    }
                }
            })?;
        }
        Ok(PollerPool { pollers, links })
    }
}
/// A set of fiber schedulers, each run on its own dedicated thread.
#[derive(Debug)]
struct SchedulerPool {
    // Handles used to spawn fibers onto each scheduler.
    schedulers: Vec<fiber::SchedulerHandle>,
    // Pool-side ends of the links to the scheduler threads, index-aligned with
    // `schedulers`. A scheduler thread keeps running only while polling its end of
    // the link reports `NotReady`.
    links: Vec<Link<(), ()>>,
}
impl SchedulerPool {
    /// Starts one scheduler thread for every poller in `poller_pool`.
    ///
    /// The n-th scheduler is built on a clone of the n-th poller handle, so the
    /// resulting `schedulers` and `links` line up index-for-index with
    /// `poller_pool.pollers`. Each thread keeps calling `run_once(true)` on its
    /// scheduler for as long as polling its end of the link reports
    /// `Async::NotReady`.
    ///
    /// # Panics
    ///
    /// Panics if a thread cannot be spawned, since `std::thread::spawn` panics in
    /// that case.
    pub fn new(poller_pool: &PollerPool) -> Self {
        let thread_count = poller_pool.pollers.len();
        let mut pool = SchedulerPool {
            schedulers: Vec::with_capacity(thread_count),
            links: Vec::with_capacity(thread_count),
        };

        for poller_handle in poller_pool.pollers.iter().cloned() {
            let mut scheduler = fiber::Scheduler::new(poller_handle);
            let (pool_side, mut thread_side) = oneshot::link();
            pool.schedulers.push(scheduler.handle());
            pool.links.push(pool_side);

            thread::spawn(move || {
                while let Ok(Async::NotReady) = thread_side.poll() {
                    scheduler.run_once(true);
                }
            });
        }

        pool
    }
}