use bus::Bus;
use futures::task::Task;
use futures::{task, Async, Future, Poll};
use npnc::bounded::mpmc;
use npnc::{ConsumeError, ProduceError};
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::TrySendError;
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
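// A key-addressed work scheduler on top of futures 0.1: every key owns a
// bounded MPMC queue, and the `Scheduler` future drains those queues
// round-robin, running `worker_fn` on a dedicated thread per item with at
// most `max_worker` threads alive at once. The `Controller` handle creates
// and cancels the per-key channels.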
type TaskStore = Arc<RwLock<Option<Task>>>;
type FutItem = ();
type FutError = ();
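// Locks the given mutex, panicking if it was poisoned.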
macro_rules! lock_c {
($x:expr) => {
$x.lock().expect("Can't access channels!")
};
}
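/// Sending half of a per-key work queue, obtained from
/// `Controller::channel`. Cloneable and usable from any thread.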
#[derive(Clone)]
pub struct Sender<V> {
queue: Arc<mpmc::Producer<V>>,
task: TaskStore,
}
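// SAFETY: `Sender` holds the producer side of an npnc MPMC queue behind an
// `Arc`, plus a shared task slot. Assuming the multi-producer queue accepts
// `produce` calls from any thread (which is what it is designed for),
// moving or sharing the handle across threads is sound.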
unsafe impl<V> Send for Sender<V> where V: Send {}
impl<V> Sender<V> {
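    /// Attempts to enqueue `value` without blocking; on success the
    /// scheduler task is woken so it can pick the item up. Returns
    /// `TrySendError::Full` when the bounded queue is at capacity and
    /// `TrySendError::Disconnected` when the scheduler side is gone.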
pub fn try_send(&self, value: V) -> Result<(), TrySendError<V>> {
match self.queue.produce(value) {
Err(ProduceError::Disconnected(v)) => return Err(TrySendError::Disconnected(v)),
Err(ProduceError::Full(v)) => return Err(TrySendError::Full(v)),
_ => (),
}
        // Wake the scheduler so it picks up the freshly queued item.
        let task_l = self.task.read().expect("Can't lock task!");
        if let Some(task) = task_l.as_ref() {
            task.notify();
        }
Ok(())
}
}
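// Shared core of the scheduler: the per-key channels, the round-robin cursor
// (`position`), the live-worker count, and the task handle used to wake the
// `Scheduler` future.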
#[doc(hidden)]
struct SchedulerInner<K, V, FB, FR, R>
where
K: Sync + Send + Hash + Eq,
V: Send + Sync + 'static,
FB: Fn(V) -> R + Send + Sync + 'static,
FR: Fn(R) + Send + Sync + 'static,
{
position: AtomicUsize,
channels: Arc<Mutex<HashMap<K, Channel<V>>>>,
task: TaskStore,
workers_active: Arc<AtomicUsize>,
max_worker: usize,
worker_fn: Arc<FB>,
worker_fn_finalize: Arc<Option<FR>>,
exit_on_idle: bool,
}
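// One scheduled work queue: the consumer side of the per-key channel and a
// single-slot bus used to broadcast cancellation to in-flight workers.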
#[doc(hidden)]
struct Channel<V> {
recv: mpmc::Consumer<V>,
cancel_bus: Bus<()>,
}
impl<K, V, FB, FR, R> SchedulerInner<K, V, FB, FR, R>
where
K: Sync + Send + Hash + Eq,
V: Sync + Send + 'static,
FB: Fn(V) -> R + Send + Sync + 'static,
FR: Fn(R) + Send + Sync + 'static,
{
pub fn new(
max_worker: usize,
worker_fn: FB,
worker_fn_finalize: Option<FR>,
exit_on_idle: bool,
) -> SchedulerInner<K, V, FB, FR, R> {
SchedulerInner {
position: AtomicUsize::new(0),
channels: Arc::new(Mutex::new(HashMap::new())),
workers_active: Arc::new(AtomicUsize::new(0)),
max_worker,
task: Arc::new(RwLock::new(None)),
worker_fn: Arc::new(worker_fn),
worker_fn_finalize: Arc::new(worker_fn_finalize),
exit_on_idle,
}
}
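    // Wakes the `Scheduler` future, if it has been polled at least once.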
fn schedule(&self) {
let task_l = self.task.read().expect("Can't lock task!");
if let Some(task) = task_l.as_ref() {
task.notify();
}
}
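    /// Cancels the channel under `key`: broadcasts a cancel signal so
    /// in-flight workers skip their finalizer, then drops everything still
    /// queued. Returns `Err(())` when no channel exists for `key`.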
pub fn cancel_channel(&self, key: &K) -> Result<(), ()> {
let mut map_l = lock_c!(self.channels);
if let Some(channel) = map_l.get_mut(key) {
            // Tell any in-flight worker to skip its finalizer, then drain
            // everything still queued so it never reaches a worker.
            let _ = channel.cancel_bus.try_broadcast(());
            while channel.recv.consume().is_ok() {}
            Ok(())
} else {
Err(())
}
}
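    /// Creates a bounded queue under `key` and returns its sending half.
    /// An existing channel under the same key is silently replaced.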
pub fn create_channel(&self, key: K, bound: usize) -> Sender<V> {
let mut map_l = lock_c!(self.channels);
let (tx, rx) = mpmc::channel(bound);
map_l.insert(
key,
Channel {
recv: rx,
cancel_bus: Bus::new(1),
},
);
Sender {
queue: Arc::new(tx),
task: self.task.clone(),
}
}
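    // Core scheduling loop. Walks the channels round-robin starting at the
    // stored cursor, spawning one worker thread per item until `max_worker`
    // threads are active or every queue is drained, then parks the task so
    // senders and finishing workers can wake the future again.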
    fn poll(&self) -> Poll<FutItem, FutError> {
        let mut map_l = lock_c!(self.channels);
        // Reset the round-robin cursor if channels were removed since the
        // last poll and the stored position is out of range.
        if map_l.len() < self.position.load(Ordering::Relaxed) {
            self.position.store(0, Ordering::Relaxed);
        }
        let start_pos = self.position.load(Ordering::Relaxed);
        let mut next_start = start_pos;
        let mut pos = 0;
        let mut roundtrip = 0;
        let mut no_work = true;
        let mut idle = false;
        while self.workers_active.load(Ordering::Relaxed) < self.max_worker && !idle {
            map_l.retain(|_, channel| {
                // On the first pass, skip the channels that were already
                // served last time; they get their turn on later passes.
                if roundtrip == 0 && pos < start_pos {
                    pos += 1;
                    return true;
                }
                let mut connected = true;
                match channel.recv.consume() {
                    Ok(w) => {
                        no_work = false;
                        // Resume after this channel on the next poll.
                        next_start = pos + 1;
                        self.workers_active.fetch_add(1, Ordering::SeqCst);
                        let worker_c = self.workers_active.clone();
                        let task = task::current();
                        let work_fn = self.worker_fn.clone();
                        let work_fn_final = self.worker_fn_finalize.clone();
                        let mut cancel_recv = channel.cancel_bus.add_rx();
                        thread::spawn(move || {
                            let result: R = work_fn(w);
                            // Skip the finalizer if the channel was cancelled
                            // while this worker was running.
                            if cancel_recv.try_recv().is_err() {
                                if let Some(finalizer) = work_fn_final.as_ref() {
                                    finalizer(result);
                                }
                            }
                            worker_c.fetch_sub(1, Ordering::SeqCst);
                            // A worker slot is free again; re-poll.
                            task.notify();
                        });
                    }
                    Err(ConsumeError::Empty) => (),
                    // Drop channels whose senders have all disconnected.
                    Err(ConsumeError::Disconnected) => connected = false,
                }
                pos += 1;
                connected
            });
            pos = 0;
            // Two consecutive passes without work mean every queue is
            // currently drained; stop spinning until a sender wakes us.
            if no_work && roundtrip >= 1 {
                idle = true;
            }
            roundtrip += 1;
            no_work = true;
        }
        // Park: store the current task so senders and finishing workers can
        // wake this future again.
        let mut task_l = self.task.write().expect("Can't lock task!");
        *task_l = Some(task::current());
        drop(task_l);
        self.position.store(next_start, Ordering::Relaxed);
        if self.exit_on_idle && map_l.is_empty() {
            Ok(Async::Ready(()))
        } else {
            Ok(Async::NotReady)
        }
    }
}
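/// Cloneable handle for creating and cancelling a scheduler's channels and
/// for nudging the scheduler to clean up (`gc`).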
#[derive(Clone)]
pub struct Controller<K, V, FB, FR, R>
where
K: Sync + Send + Hash + Eq,
V: Sync + Send + 'static,
FB: Fn(V) -> R + Send + Sync + 'static,
FR: Fn(R) + Send + Sync + 'static,
{
inner: Arc<SchedulerInner<K, V, FB, FR, R>>,
}
impl<K, V, FB, FR, R> Controller<K, V, FB, FR, R>
where
K: Sync + Send + Hash + Eq,
V: Sync + Send + 'static,
FB: Fn(V) -> R + Send + Sync + 'static,
FR: Fn(R) + Send + Sync + 'static,
{
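    /// Creates a bounded work queue under `key` and returns its `Sender`.
    /// An existing channel under the same key is silently replaced.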
pub fn channel(&self, key: K, bound: usize) -> Sender<V> {
self.inner.create_channel(key, bound)
}
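    /// Cancels the channel under `key`, dropping all queued work and
    /// suppressing the finalizer for items already being processed.
    /// Returns `Err(())` when `key` has no channel.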
pub fn cancel_channel(&self, key: &K) -> Result<(), ()> {
self.inner.cancel_channel(key)
}
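    /// Wakes the scheduler so it can sweep out channels whose senders have
    /// all been dropped.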
pub fn gc(&self) {
self.inner.schedule();
}
}
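/// The scheduler future itself: spawn it on a futures-0.1 executor to drive
/// the worker threads. It resolves once `finish_on_idle` is set and every
/// channel has been removed.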
pub struct Scheduler<K, V, FB, FR, R>
where
K: Sync + Send + Hash + Eq,
V: Sync + Send + 'static,
FB: Fn(V) -> R + Send + Sync + 'static,
FR: Fn(R) + Send + Sync + 'static,
{
inner: Arc<SchedulerInner<K, V, FB, FR, R>>,
}
impl<K, V, FB, FR, R> Scheduler<K, V, FB, FR, R>
where
K: Sync + Send + Hash + Eq,
V: Sync + Send + 'static,
FB: Fn(V) -> R + Send + Sync + 'static,
FR: Fn(R) + Send + Sync + 'static,
{
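    /// Creates a scheduler that runs at most `max_worker` jobs at once,
    /// together with the `Controller` used to manage its channels.
    /// `worker_fn` runs on a fresh thread per item; `worker_fn_finalize`,
    /// if given, runs on each result unless the channel was cancelled in
    /// the meantime. With `finish_on_idle` the future resolves once all
    /// channels are gone.
    ///
    /// A minimal sketch, assuming a futures-0.1 executor such as tokio 0.1:
    ///
    /// ```ignore
    /// let (controller, scheduler) = Scheduler::new(
    ///     4,                                 // at most four worker threads
    ///     |v: u32| v * 2,                    // the work itself
    ///     Some(|r| println!("done: {}", r)), // runs with each result
    ///     true,                              // resolve once idle
    /// );
    /// let tx = controller.channel("downloads", 16);
    /// tx.try_send(21).expect("queue full or scheduler gone");
    /// drop(tx); // last sender gone: channel is removed, future resolves
    /// tokio::run(scheduler);
    /// ```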
pub fn new(
max_worker: usize,
worker_fn: FB,
worker_fn_finalize: Option<FR>,
finish_on_idle: bool,
) -> (Controller<K, V, FB, FR, R>, Scheduler<K, V, FB, FR, R>) {
let inner = Arc::new(SchedulerInner::new(
max_worker,
worker_fn,
worker_fn_finalize,
finish_on_idle,
));
(
Controller {
inner: inner.clone(),
},
Scheduler { inner },
)
}
}
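// Futures-0.1 plumbing: polling the `Scheduler` simply delegates to
// `SchedulerInner::poll`, which does the actual scheduling work.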
impl<K, V, FB, FR, R> Future for Scheduler<K, V, FB, FR, R>
where
K: Sync + Send + Hash + Eq,
V: Sync + Send + 'static,
FB: Fn(V) -> R + Send + Sync + 'static,
FR: Fn(R) + Send + Sync + 'static,
{
type Error = FutError;
type Item = FutItem;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
self.inner.poll()
}
}