use bus::Bus;
use futures::task::Task;
use futures::{task, Async, Future, Poll};
use npnc::bounded::mpmc;
use npnc::{ConsumeError, ProduceError};
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::atomic::{AtomicUsize, Ordering};
pub use std::sync::mpsc::TrySendError;
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
/// Shared slot holding the scheduler future's task handle so senders can wake it.
type TaskStore = Arc<RwLock<Option<Task>>>;
/// Item type of the scheduler future (unit: completion carries no payload).
type FutItem = ();
/// Error type of the scheduler future (unit: errors carry no detail).
type FutError = ();
/// Locks the given `Mutex`, panicking with a uniform message if the lock is
/// poisoned. Used for the scheduler's channel map throughout this file.
macro_rules! lock_c {
    ($x:expr) => {
        $x.lock().expect("Can't access channels!")
    };
}
/// Sending half of a scheduler channel. Cloneable; all clones share the same
/// underlying queue, wakeup slot and length counter.
pub struct Sender<V> {
    // Lock-free producer side of the bounded MPMC queue.
    queue: Arc<mpmc::Producer<V>>,
    // Shared handle used to wake the scheduler future after a successful send.
    task: TaskStore,
    // Approximate number of queued items; incremented on send, decremented
    // by the scheduler when it consumes an item.
    len: Arc<AtomicUsize>,
}
impl<V> Clone for Sender<V> {
fn clone(&self) -> Self {
Sender {
queue: self.queue.clone(),
task: self.task.clone(),
len: self.len.clone(),
}
}
}
// SAFETY: every field of `Sender` is `Send` when `V: Send` (`Arc`-wrapped
// producer, task slot and atomic counter only move `V` values).
unsafe impl<V> Send for Sender<V> where V: Send {}
// SAFETY: `try_send(&self, V)` pushes owned `V` values through a shared
// reference, so a `&Sender<V>` shared across threads moves `V`s between
// threads; that is only sound when `V: Send`. The bound was previously
// missing, which allowed smuggling non-`Send` types across threads.
unsafe impl<V> Sync for Sender<V> where V: Send {}
impl<V> Sender<V> {
    /// Attempts to enqueue `value` without blocking.
    ///
    /// On success the shared length counter is incremented and the scheduler
    /// task (if one has registered) is notified so the item gets polled.
    ///
    /// # Errors
    /// Returns `TrySendError::Full` when the bounded queue is at capacity and
    /// `TrySendError::Disconnected` when the consumer side is gone; the value
    /// is handed back inside the error in both cases.
    pub fn try_send(&self, value: V) -> Result<(), TrySendError<V>> {
        match self.queue.produce(value) {
            Err(ProduceError::Disconnected(v)) => return Err(TrySendError::Disconnected(v)),
            Err(ProduceError::Full(v)) => return Err(TrySendError::Full(v)),
            Ok(_) => {
                self.len.fetch_add(1, Ordering::SeqCst);
            }
        }
        // Wake the scheduler so the freshly queued item is picked up.
        let task_l = self.task.read().expect("Can't lock task!");
        if let Some(task) = task_l.as_ref() {
            task.notify();
        }
        Ok(())
    }

    /// Number of items currently queued (approximate under concurrency).
    pub fn len(&self) -> usize {
        self.len.load(Ordering::Relaxed)
    }

    /// Returns `true` when no items are queued. Added for parity with the
    /// standard collection API (clippy `len_without_is_empty`).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
#[doc(hidden)]
struct SchedulerInner<K, V, R>
where
K: Sync + Send + Hash + Eq,
V: Send + Sync + 'static,
R: 'static,
{
position: AtomicUsize,
channels: Mutex<HashMap<K, Channel<V>>>,
task: TaskStore,
workers_active: Arc<AtomicUsize>,
max_worker: AtomicUsize,
worker_fn: Arc<Box<dyn Fn(V) -> R + Send + Sync + 'static>>,
worker_fn_finalize: Arc<Option<Box<dyn Fn(R) + Send + Sync + 'static>>>,
exit_on_idle: bool,
}
#[doc(hidden)]
struct Channel<V> {
recv: mpmc::Consumer<V>,
cancel_bus: Bus<()>,
len: Arc<AtomicUsize>,
}
impl<K, V, R> SchedulerInner<K, V, R>
where
    K: Sync + Send + Hash + Eq,
    V: Sync + Send + 'static,
    R: 'static,
{
    /// Builds the shared state with no channels and no registered task.
    pub fn new(
        max_worker: usize,
        worker_fn: Box<dyn Fn(V) -> R + Send + Sync + 'static>,
        worker_fn_finalize: Option<Box<dyn Fn(R) + Send + Sync + 'static>>,
        exit_on_idle: bool,
    ) -> SchedulerInner<K, V, R> {
        SchedulerInner {
            position: AtomicUsize::new(0),
            channels: Mutex::new(HashMap::new()),
            workers_active: Arc::new(AtomicUsize::new(0)),
            max_worker: AtomicUsize::new(max_worker),
            task: Arc::new(RwLock::new(None)),
            worker_fn: Arc::new(worker_fn),
            // The parameter already has the exact field type; the previous
            // `as Option<Box<dyn Fn(R) + ...>>` cast was a no-op.
            worker_fn_finalize: Arc::new(worker_fn_finalize),
            exit_on_idle,
        }
    }

    /// Wakes the scheduler future, if it has registered its task handle.
    fn schedule(&self) {
        let task_l = self.task.read().expect("Can't lock task!");
        if let Some(task) = task_l.as_ref() {
            task.notify();
        }
    }

    /// Cancels the channel identified by `key`: signals in-flight workers to
    /// skip their finalizer and drains all queued-but-unstarted work.
    ///
    /// # Errors
    /// Returns `Err(())` when no channel with that key exists.
    pub fn cancel_channel(&self, key: &K) -> Result<(), ()> {
        let mut map_l = lock_c!(self.channels);
        if let Some(channel) = map_l.get_mut(key) {
            // Best effort: the bus has capacity 1, so a pending signal means
            // cancellation is already underway.
            let _ = channel.cancel_bus.try_broadcast(());
            loop {
                match channel.recv.consume() {
                    Ok(_) => {
                        // Keep the shared length counter in sync while
                        // draining; previously it was left stale, so
                        // `Sender::len` over-reported after a cancel.
                        channel.len.fetch_sub(1, Ordering::SeqCst);
                    }
                    _ => return Ok(()),
                }
            }
        } else {
            Err(())
        }
    }

    /// Updates the worker-thread cap; takes effect on the next poll.
    pub fn set_workers_max(&self, max: usize) {
        self.max_worker.store(max, Ordering::Relaxed);
    }

    /// Current worker-thread cap.
    pub fn get_workers_max(&self) -> usize {
        self.max_worker.load(Ordering::Relaxed)
    }

    /// Creates a bounded channel under `key` and returns its sending half.
    ///
    /// NOTE(review): inserting with an existing key silently replaces the
    /// old channel, disconnecting its previous `Sender`s — confirm intended.
    pub fn create_channel(&self, key: K, bound: usize) -> Sender<V> {
        let mut map_l = lock_c!(self.channels);
        let (tx, rx) = mpmc::channel(bound);
        let len = Arc::new(AtomicUsize::new(0));
        map_l.insert(
            key,
            Channel {
                recv: rx,
                cancel_bus: Bus::new(1),
                len: len.clone(),
            },
        );
        Sender {
            queue: Arc::new(tx),
            task: self.task.clone(),
            len,
        }
    }

    /// Core scheduling loop, called from the `Future` impl.
    ///
    /// Repeatedly sweeps all channels, spawning one worker thread per
    /// consumed item until either the worker cap is reached or a full sweep
    /// finds no work. Disconnected channels are dropped from the map during
    /// the sweep. Resolves `Ready` only when `exit_on_idle` is set and no
    /// channels or active workers remain.
    fn poll(&self) -> Poll<FutItem, FutError> {
        let mut map_l = lock_c!(self.channels);
        // Reset the fairness cursor when channels were removed.
        if map_l.len() < self.position.load(Ordering::Relaxed) {
            self.position.store(0, Ordering::Relaxed);
        }
        let start_pos = self.position.load(Ordering::Relaxed);
        let mut pos = 0;
        let mut roundtrip = 0;
        let mut no_work = true;
        let mut idle = false;
        while self.workers_active.load(Ordering::Relaxed) < self.max_worker.load(Ordering::Relaxed)
            && !idle
        {
            map_l.retain(|_, channel| {
                // On the first sweep, skip channels before the stored cursor
                // so one busy channel cannot starve the others.
                if roundtrip == 0 && pos < start_pos {
                    return true;
                }
                let mut connected = true;
                match channel.recv.consume() {
                    Ok(w) => {
                        channel.len.fetch_sub(1, Ordering::SeqCst);
                        no_work = false;
                        self.workers_active.fetch_add(1, Ordering::SeqCst);
                        let worker_c = self.workers_active.clone();
                        let task = task::current();
                        let work_fn = self.worker_fn.clone();
                        let work_fn_final = self.worker_fn_finalize.clone();
                        let mut cancel_recv = channel.cancel_bus.add_rx();
                        thread::spawn(move || {
                            let result: R = work_fn(w);
                            // Skip the finalizer when this channel was
                            // cancelled while the job ran.
                            if cancel_recv.try_recv().is_err() {
                                if let Some(finalizer) = work_fn_final.as_ref() {
                                    finalizer(result);
                                }
                            }
                            worker_c.fetch_sub(1, Ordering::SeqCst);
                            // Re-poll so the scheduler notices the free slot.
                            task.notify();
                        });
                    }
                    Err(ConsumeError::Empty) => (),
                    Err(ConsumeError::Disconnected) => connected = false,
                }
                pos += 1;
                connected
            });
            pos = 0;
            // Idle only after at least one unrestricted full sweep came up
            // empty (sweep 0 may have skipped channels before the cursor).
            if no_work && roundtrip >= 1 {
                idle = true;
            }
            roundtrip += 1;
            no_work = true;
        }
        // Register (or refresh) the task handle so senders can wake us.
        let mut task_l = self.task.write().expect("Can't lock task!");
        *task_l = Some(task::current());
        drop(task_l);
        // NOTE(review): `pos` is reset to 0 after every sweep, so this always
        // stores 0 and the round-robin cursor never advances — confirm
        // whether the pre-reset sweep position was meant to be stored here.
        self.position.store(pos, Ordering::Relaxed);
        if self.exit_on_idle && map_l.is_empty() && self.workers_active.load(Ordering::Relaxed) == 0
        {
            Ok(Async::Ready(()))
        } else {
            Ok(Async::NotReady)
        }
    }
}
#[derive(Clone)]
pub struct Controller<K, V, R>
where
K: Sync + Send + Hash + Eq,
V: Sync + Send + 'static,
R: 'static,
{
inner: Arc<SchedulerInner<K, V, R>>,
}
impl<K, V, R> Controller<K, V, R>
where
    K: Sync + Send + Hash + Eq,
    V: Sync + Send + 'static,
    R: 'static,
{
    /// Registers a channel under `key` and hands back its sending half.
    /// The requested bound is rounded up to the next power of two, which the
    /// underlying queue requires.
    pub fn channel(&self, key: K, bound: usize) -> Sender<V> {
        let capacity = bound.next_power_of_two();
        self.inner.create_channel(key, capacity)
    }

    /// Cancels the channel stored under `key`; `Err(())` when unknown.
    pub fn cancel_channel(&self, key: &K) -> Result<(), ()> {
        self.inner.cancel_channel(key)
    }

    /// Forces a scheduler wakeup so disconnected channels get collected.
    pub fn gc(&self) {
        self.inner.schedule();
    }

    /// Sets the maximum number of concurrently running worker threads.
    pub fn set_worker_max(&self, max_workers: usize) {
        self.inner.set_workers_max(max_workers);
    }

    /// Reads the current worker-thread cap.
    pub fn get_worker_max(&self) -> usize {
        self.inner.get_workers_max()
    }
}
#[must_use = "schedulers do nothing unless polled"]
pub struct Scheduler<K, V, R>
where
K: Sync + Send + Hash + Eq,
V: Sync + Send + 'static,
R: 'static,
{
inner: Arc<SchedulerInner<K, V, R>>,
}
impl<K, V, R> Scheduler<K, V, R>
where
    K: Sync + Send + Hash + Eq,
    V: Sync + Send + 'static,
    R: 'static,
{
    /// Creates a scheduler future together with a `Controller` to manage it.
    ///
    /// `worker_fn` runs on a dedicated thread per job; `worker_fn_finalize`
    /// (when given) receives each job's result unless the job's channel was
    /// cancelled. With `finish_on_idle`, the future resolves once all
    /// channels are gone and no workers remain.
    pub fn new(
        max_worker: usize,
        worker_fn: impl Fn(V) -> R + Send + Sync + 'static,
        worker_fn_finalize: Option<impl Fn(R) + Send + Sync + 'static>,
        finish_on_idle: bool,
    ) -> (Controller<K, V, R>, Scheduler<K, V, R>) {
        // Erase the concrete closure types up front; the inner state stores
        // boxed trait objects.
        let boxed_fn: Box<dyn Fn(V) -> R + Send + Sync + 'static> = Box::new(worker_fn);
        let boxed_final =
            worker_fn_finalize.map(|f| Box::new(f) as Box<dyn Fn(R) + Send + Sync + 'static>);
        let inner = Arc::new(SchedulerInner::new(
            max_worker,
            boxed_fn,
            boxed_final,
            finish_on_idle,
        ));
        let controller = Controller {
            inner: Arc::clone(&inner),
        };
        (controller, Scheduler { inner })
    }
}
impl<K, V, R> Future for Scheduler<K, V, R>
where
K: Sync + Send + Hash + Eq,
V: Sync + Send + 'static,
R: 'static,
{
type Error = FutError;
type Item = FutItem;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
self.inner.poll()
}
}
#[cfg(test)]
mod tests {
    use super::Scheduler;
    use super::*;
    use std::thread;
    use std::time::{Duration, Instant};
    use tokio::runtime::Runtime;

    /// End-to-end harness: `producers` threads push the integers `0..amount`
    /// (split into disjoint ranges) into per-thread channels, the scheduler
    /// runs with `workers` worker threads, and afterwards every value must
    /// have reached the collector exactly once.
    fn run_mpmc(producers: usize, amount: usize, workers: usize, channel_size: usize) {
        let collector = Arc::new(Mutex::new(Vec::new()));
        let collectorc = collector.clone();
        let (controller, scheduler) = Scheduler::new(
            workers,
            |v| {
                v
            },
            Some(move |v| {
                let mut lock = collectorc.lock().unwrap();
                lock.push(v);
            }),
            true,
        );
        let mut runtime = Runtime::new().unwrap();
        // Kept alive until all producer threads are spawned so the
        // exit-on-idle scheduler cannot finish early.
        let tx = controller.channel(1, channel_size);
        runtime.spawn(scheduler);
        for t in 0..producers {
            let txc = controller.channel(t, channel_size);
            let start = t * (amount / producers);
            let mut end = start + (amount / producers);
            // Last producer also takes the division remainder.
            if t == producers - 1 {
                end = amount;
            }
            println!("start: {} end {}", start, end);
            thread::spawn(move || {
                for i in start..end {
                    // Spin (with a short sleep) until the bounded queue
                    // accepts the item.
                    loop {
                        if txc.try_send(i).is_ok() {
                            break;
                        }
                        thread::sleep(Duration::from_micros(10))
                    }
                }
                drop(txc);
                println!("{} finished insertion", t);
            });
        }
        drop(tx);
        // Blocks until the scheduler future resolves (all work done).
        runtime.shutdown_on_idle().wait().unwrap();
        let lock = collector.lock().unwrap();
        println!(
            "Verifying for {} inserter {} workers {} amount",
            producers, workers, amount
        );
        assert_eq!(amount, lock.len());
        for i in 0..amount {
            assert!(lock.contains(&i));
        }
    }

    /// Throughput benchmark (ignored by default): same shape as `run_mpmc`
    /// but without result verification, printing elapsed time per job.
    #[test]
    #[ignore]
    fn bench_mpsw() {
        let producers = 8;
        let amount = 1_000_000;
        let workers = 8;
        let channel_size = 1024;
        let (controller, scheduler) = Scheduler::new(
            workers,
            |v| {
                v
            },
            Some(move |v| {
            }),
            true,
        );
        let mut runtime = Runtime::new().unwrap();
        let tx = controller.channel(1, channel_size);
        runtime.spawn(scheduler);
        let start = Instant::now();
        for t in 0..producers {
            let txc = controller.channel(t, channel_size);
            let start = t * (amount / producers);
            let mut end = start + (amount / producers);
            if t == producers - 1 {
                end = amount;
            }
            println!("start: {} end {}", start, end);
            thread::spawn(move || {
                for i in start..end {
                    loop {
                        if txc.try_send(i).is_ok() {
                            break;
                        }
                        thread::sleep(Duration::from_nanos(1));
                    }
                }
                drop(txc);
                println!("{} finished insertion", t);
            });
        }
        drop(tx);
        runtime.shutdown_on_idle().wait().unwrap();
        let end_d = start.elapsed();
        let end = end_d.subsec_millis() as u64 + (end_d.as_secs() * 1_000);
        println!(
            "Took {} ms for {} entries: {}ms per job",
            end,
            amount,
            amount / end
        );
    }

    /// Multi-producer, multi-worker correctness.
    #[test]
    fn verify_mpsw() {
        run_mpmc(8, 10_000, 4, 32);
    }

    /// Single-producer, single-worker correctness.
    #[test]
    fn verify_spsw() {
        run_mpmc(1, 1_000, 1, 32);
    }

    /// Tiny channel (size 2) forces constant back-pressure on senders.
    #[test]
    fn verify_spmw_overload() {
        for i in 2..30 {
            run_mpmc(1, 1_000, i, 2);
        }
    }

    /// Varying worker counts with an ample channel.
    #[test]
    fn verify_spmw() {
        for i in 2..30 {
            run_mpmc(1, 1_000, i, 1024);
        }
    }

    /// More workers than jobs: scheduler must still terminate cleanly.
    #[test]
    fn verify_spmw_underload() {
        for i in 2..30 {
            run_mpmc(1, 30, i, 1024);
        }
    }

    // Deliberately non-Clone payload (contains a Mutex) to prove the
    // scheduler only ever moves values, never clones them.
    #[allow(dead_code)]
    struct TestNonClonable {
        a: Option<Mutex<()>>,
    }

    impl TestNonClonable {
        pub fn new() -> TestNonClonable {
            TestNonClonable { a: None }
        }
    }

    /// Compiles and runs with a payload type that is not `Clone`.
    #[test]
    fn verify_non_clone() {
        let workers = 2;
        let producers = 2;
        let amount = 100;
        let channel_size = 32;
        let collector = Arc::new(Mutex::new(Vec::new()));
        let collectorc = collector.clone();
        let (controller, scheduler) = Scheduler::new(
            workers,
            |v| {
                v
            },
            Some(move |v| {
                let mut lock = collectorc.lock().unwrap();
                lock.push(v);
            }),
            true,
        );
        let mut runtime = Runtime::new().unwrap();
        let tx = controller.channel(1, channel_size);
        // Exercise `Sender: Clone` even though the payload is not `Clone`.
        let _ = tx.clone();
        runtime.spawn(scheduler);
        for t in 0..producers {
            let txc = controller.channel(t, channel_size);
            let start = t * (amount / producers);
            let mut end = start + (amount / producers);
            if t == producers - 1 {
                end = amount;
            }
            println!("start: {} end {}", start, end);
            thread::spawn(move || {
                for i in start..end {
                    loop {
                        if txc.try_send(TestNonClonable::new()).is_ok() {
                            break;
                        }
                        thread::sleep(Duration::from_micros(10))
                    }
                }
                drop(txc);
                println!("{} finished insertion", t);
            });
        }
        drop(tx);
        runtime.shutdown_on_idle().wait().unwrap();
    }

    /// Same as `run_mpmc` but inline, exercising `Sender` handles moved
    /// directly into producer threads (Send + Sync usage).
    #[test]
    fn verify_sync_sender() {
        let workers = 2;
        let producers = 4;
        let amount = 1_000;
        let channel_size = 32;
        let collector = Arc::new(Mutex::new(Vec::new()));
        let collectorc = collector.clone();
        let (controller, scheduler) = Scheduler::new(
            workers,
            |v| v,
            Some(move |v| {
                let mut lock = collectorc.lock().unwrap();
                lock.push(v);
            }),
            true,
        );
        let mut runtime = Runtime::new().unwrap();
        let tx = controller.channel(1, channel_size);
        runtime.spawn(scheduler);
        for t in 0..producers {
            let txc = controller.channel(t, channel_size);
            let start = t * (amount / producers);
            let mut end = start + (amount / producers);
            if t == producers - 1 {
                end = amount;
            }
            println!("start: {} end {}", start, end);
            thread::spawn(move || {
                for i in start..end {
                    loop {
                        if txc.try_send(i).is_ok() {
                            break;
                        }
                        thread::sleep(Duration::from_micros(10))
                    }
                }
                drop(txc);
                println!("{} finished insertion", t);
            });
        }
        drop(tx);
        runtime.shutdown_on_idle().wait().unwrap();
        let lock = collector.lock().unwrap();
        println!(
            "Verifying for {} inserter {} workers {} amount",
            producers, workers, amount
        );
        assert_eq!(amount, lock.len());
        for i in 0..amount {
            assert!(lock.contains(&i));
        }
    }

    /// Checks `Sender::len` tracking: counts up on send, and reads 0 again
    /// once the scheduler has drained everything (signalled via the
    /// finalizer observing the LIMIT sentinel).
    #[test]
    fn verify_length() {
        let exit = Arc::new(AtomicUsize::new(0));
        let exit_c = exit.clone();
        const LIMIT: isize = 1000;
        let (controller, scheduler) = Scheduler::new(
            1,
            |v| v,
            Some(move |v| {
                println!("{}", v);
                if v == LIMIT {
                    println!("Killing..");
                    exit_c.fetch_add(1, Ordering::SeqCst);
                }
            }),
            true,
        );
        let tx = controller.channel(1, 1024);
        assert_eq!(0, tx.len());
        // Scheduler not yet spawned, so this item must sit in the queue.
        tx.try_send(1).unwrap();
        assert_eq!(1, tx.len());
        let mut runtime = Runtime::new().unwrap();
        runtime.spawn(scheduler);
        for i in 0..=LIMIT {
            tx.try_send(i).unwrap();
        }
        // Wait until the finalizer has seen the last value (LIMIT).
        while exit.load(Ordering::Relaxed) == 0 {
            thread::sleep(Duration::from_millis(15));
        }
        assert_eq!(0, tx.len());
        drop(tx);
    }
}