//! # mpmc-scheduler
//!
//! A fair, per-channel cancellable, multi-mpmc task scheduler running on top of tokio.
//!
//! It bundles multiple mpmc channels together and schedules incoming work fairly across them, limited to a configurable maximum number of concurrently running workers.
//!
//! ## Example
//!
//! ```rust
//! use mpmc_scheduler::Scheduler;
//! use tokio::runtime::Runtime;
//!
//! let (controller, scheduler) = Scheduler::new(
//!     4,
//!     |v| {
//!         println!("Processing {}", v);
//!         v
//!     },
//!     Some(|r| println!("Finalizing {}", r)),
//!     true
//! );
//!
//! let mut runtime = Runtime::new().unwrap();
//!
//! let tx = controller.channel(1, 4);
//!
//! runtime.spawn(scheduler);
//!
//! for i in 0..4 {
//!     tx.try_send(i).unwrap();
//! }
//!
//! drop(tx); // drop tx so scheduler & runtime shut down
//!
//! runtime.shutdown_on_idle().wait().unwrap();
//! ```
//!
//! ## Details
//!
//! You can think of it as a round-robin scheduler for rate-limited workers that all run the same function.
//!
//! ```text
//! o-                  -x
//!   \                /
//! o--|--Scheduler --|--x
//!   /                \
//! o-                  -x
//! ```
//!
//! In this picture there are n producers `o` and m workers `x`.
//! We want to handle all incoming work from the producers in a fair manner: if
//! one producer has 20 jobs queued and another has 2, both are served equally in a round-robin fashion.
//!
//! Each channel queue can be cleared, dropping all jobs that are not yet scheduled.
//! To also allow stopping (expensive) jobs that are already running, the work can be split
//! into two functions. Take, for example, HTTP requests whose results are stored: if we
//! cancel before the store step runs, we suppress the finalization of all outstanding jobs
//! of that channel in addition to dropping the queued ones.
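//!
//! A minimal sketch of this two-phase split (the URL-fetching closures are
//! illustrative placeholders, not part of this crate's API):
//!
//! ```rust,no_run
//! use mpmc_scheduler::Scheduler;
//!
//! let (controller, scheduler) = Scheduler::new(
//!     4,
//!     |url: String| format!("response for {}", url),     // cancellable "main" work
//!     Some(|resp: String| println!("storing {}", resp)), // skipped for cancelled jobs
//!     true,
//! );
//! let tx = controller.channel("downloads", 8); // bound must be a power of two
//! tx.try_send(String::from("https://example.com")).unwrap();
//! // Drop all queued jobs and suppress finalization of the running ones:
//! controller.cancel_channel(&"downloads").unwrap();
//! ```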
//!
//! Closed channels are detected and removed from the scheduler during a schedule tick.
//! You can trigger a tick manually by calling `gc` on the controller.
//!
//! ## Limitations
//! - mpmc-scheduler can only be used with its own producer channels, as other channel types lack the required traits; futures' mpsc channels also don't work, since they do not wake up the scheduler.
//!
//! - The channel bound has to be a power of two.
//!
//! - You can only define one work-handler function per `Scheduler` and it cannot be changed afterwards.
//!
use bus::Bus;
use futures::task::Task;
use futures::{task, Async, Future, Poll};
use npnc::bounded::mpmc;
use npnc::{ConsumeError, ProduceError};

use std::collections::HashMap;
use std::hash::Hash;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::TrySendError;
use std::sync::{Arc, Mutex, RwLock};
use std::thread;

type TaskStore = Arc<RwLock<Option<Task>>>;
type FutItem = ();
type FutError = ();

macro_rules! lock_c {
    ($x:expr) => {
        $x.lock().expect("Can't access channels!")
    };
}

/// Sender/Producer for one channel
#[derive(Clone)]
pub struct Sender<V> {
    queue: Arc<mpmc::Producer<V>>,
    task: TaskStore,
}

unsafe impl<V> Send for Sender<V> where V: Send {}

impl<V> Sender<V> {
    /// Try to send a new job.
    /// Does not block in the normal case, but may block for a bounded, short amount of time.
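    ///
    /// A minimal sketch of handling both failure modes (assuming `tx` is a
    /// `Sender<i32>` obtained from `Controller::channel`):
    ///
    /// ```rust,ignore
    /// use std::sync::mpsc::TrySendError;
    ///
    /// match tx.try_send(1) {
    ///     Ok(()) => (),
    ///     Err(TrySendError::Full(job)) => eprintln!("queue full, dropping job {}", job),
    ///     Err(TrySendError::Disconnected(job)) => eprintln!("scheduler gone, dropping job {}", job),
    /// }
    /// ```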
    pub fn try_send(&self, value: V) -> Result<(), TrySendError<V>> {
        match self.queue.produce(value) {
            Err(ProduceError::Disconnected(v)) => return Err(TrySendError::Disconnected(v)),
            Err(ProduceError::Full(v)) => return Err(TrySendError::Full(v)),
            _ => (),
        }
        let task_l = self.task.read().expect("Can't lock task!");
        if let Some(task) = task_l.as_ref() {
            task.notify();
        }
        Ok(())
    }
}

/// Inner scheduler, shared across controller & worker
#[doc(hidden)]
struct SchedulerInner<K, V, FB, FR, R>
where
    K: Sync + Send + Hash + Eq,
    V: Send + Sync + 'static,
    FB: Fn(V) -> R + Send + Sync + 'static,
    FR: Fn(R) + Send + Sync + 'static,
{
    position: AtomicUsize,
    // TODO: evaluate concurrent_hashmap
    channels: Arc<Mutex<HashMap<K, Channel<V>>>>,
    task: TaskStore,
    workers_active: Arc<AtomicUsize>,
    max_worker: usize,
    worker_fn: Arc<FB>,
    worker_fn_finalize: Arc<Option<FR>>,
    exit_on_idle: bool,
}

/// One channel, consisting of the receiver side and the cancel bus
/// used to stop running jobs
#[doc(hidden)]
struct Channel<V> {
    recv: mpmc::Consumer<V>,
    cancel_bus: Bus<()>,
}

impl<K, V, FB, FR, R> SchedulerInner<K, V, FB, FR, R>
where
    K: Sync + Send + Hash + Eq,
    V: Sync + Send + 'static,
    FB: Fn(V) -> R + Send + Sync + 'static,
    FR: Fn(R) + Send + Sync + 'static,
{
    pub fn new(
        max_worker: usize,
        worker_fn: FB,
        worker_fn_finalize: Option<FR>,
        exit_on_idle: bool,
    ) -> SchedulerInner<K, V, FB, FR, R> {
        SchedulerInner {
            position: AtomicUsize::new(0),
            channels: Arc::new(Mutex::new(HashMap::new())),
            workers_active: Arc::new(AtomicUsize::new(0)),
            max_worker,
            task: Arc::new(RwLock::new(None)),
            worker_fn: Arc::new(worker_fn),
            worker_fn_finalize: Arc::new(worker_fn_finalize),
            exit_on_idle,
        }
    }

    /// Trigger polling wakeup, used by GC call
    fn schedule(&self) {
        let task_l = self.task.read().expect("Can't lock task!");
        if let Some(task) = task_l.as_ref() {
            task.notify();
        }
    }

    /// Clear queue for specific channel & cancel workers
    pub fn cancel_channel(&self, key: &K) -> Result<(), ()> {
        let mut map_l = lock_c!(self.channels);
        if let Some(channel) = map_l.get_mut(key) {
            // if we're not able to send a cancel command, then probably
            // no work is running and/or a stop command was already sent
            let _ = channel.cancel_bus.try_broadcast(());
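            // Drain all jobs that are still queued.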
            loop {
                match channel.recv.consume() {
                    Ok(_) => (), // more messages
                    _ => return Ok(()),
                }
            }
        } else {
            Err(())
        }
    }

    /// Create channel
    pub fn create_channel(&self, key: K, bound: usize) -> Sender<V> {
        let mut map_l = lock_c!(self.channels);

        let (tx, rx) = mpmc::channel(bound);
        map_l.insert(
            key,
            Channel {
                recv: rx,
                cancel_bus: Bus::new(1),
            },
        );
        Sender {
            queue: Arc::new(tx),
            task: self.task.clone(),
        }
    }

    /// Inner poll method, only to be called by future handler
    fn poll(&self) -> Poll<FutItem, FutError> {
        let mut map_l = lock_c!(self.channels);
        if map_l.len() < self.position.load(Ordering::Relaxed) {
            self.position.store(0, Ordering::Relaxed);
        }

        let start_pos = self.position.load(Ordering::Relaxed);
        let mut pos = 0;

        let mut worker_counter = 0;
        let mut roundtrip = 0;
        let mut no_work = true;
        let mut idle = false;

        while self.workers_active.load(Ordering::Relaxed) < self.max_worker && !idle {
            map_l.retain(|_, channel| {
                // skip to position from last poll
                if roundtrip == 0 && pos < start_pos {
                    return true;
                }
                let mut connected = true;
                match channel.recv.consume() {
                    Ok(w) => {
                        no_work = false;
                        self.workers_active.fetch_add(1, Ordering::SeqCst);
                        worker_counter += 1;
                        let worker_c = self.workers_active.clone();
                        let task = task::current();
                        let work_fn = self.worker_fn.clone();
                        let work_fn_final = self.worker_fn_finalize.clone();
                        let mut cancel_recv = channel.cancel_bus.add_rx();
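                        // Run the job on its own thread; the finalizer is skipped
                        // if a cancel was broadcast while the job was running.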
                        thread::spawn(move || {
                            let result: R = work_fn(w);
                            if cancel_recv.try_recv().is_err() {
                                if let Some(finalizer) = work_fn_final.as_ref() {
                                    finalizer(result);
                                }
                            }
                            worker_c.fetch_sub(1, Ordering::SeqCst);
                            task.notify();
                        });
                    }
                    Err(ConsumeError::Empty) => (),
                    Err(ConsumeError::Disconnected) => connected = false,
                }
                pos += 1;
                connected
            });
            pos = 0;

            if no_work && roundtrip >= 1 {
                idle = true;
            }
            roundtrip += 1;

            no_work = true;
        }
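        // Store the current task handle so producers can wake this future again.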
        let mut task_l = self.task.write().expect("Can't lock task!");
        *task_l = Some(task::current());
        drop(task_l);
        self.position.store(pos, Ordering::Relaxed);
        if self.exit_on_idle && map_l.len() == 0 {
            Ok(Async::Ready(()))
        } else {
            Ok(Async::NotReady)
        }
    }
}

/// The Controller is a non-producing handle to the scheduler.
/// It allows creation of new channels as well as clearing of queues.
#[derive(Clone)]
pub struct Controller<K, V, FB, FR, R>
where
    K: Sync + Send + Hash + Eq,
    V: Sync + Send + 'static,
    FB: Fn(V) -> R + Send + Sync + 'static,
    FR: Fn(R) + Send + Sync + 'static,
{
    inner: Arc<SchedulerInner<K, V, FB, FR, R>>,
}

impl<K, V, FB, FR, R> Controller<K, V, FB, FR, R>
where
    K: Sync + Send + Hash + Eq,
    V: Sync + Send + 'static,
    FB: Fn(V) -> R + Send + Sync + 'static,
    FR: Fn(R) + Send + Sync + 'static,
{
    /// Create a new channel, returning the producer side.
    /// The channel bound has to be a power of two!
    /// May block if a clear or a scheduling tick is currently running.
    pub fn channel(&self, key: K, bound: usize) -> Sender<V> {
        self.inner.create_channel(key, bound)
    }

    /// Clear the queue of the specified channel & cancel its running jobs, where supported.
    ///
    /// May block if `channel` is called or a schedule tick is running.
    /// Note that for a queue with bound n, this has O(n) worst-case complexity.
    ///
    /// Returns Err if the specified channel is invalid.
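    ///
    /// A minimal usage sketch (assuming `controller` was created via `Scheduler::new`
    /// and a channel with key `1` may or may not exist):
    ///
    /// ```rust,ignore
    /// if controller.cancel_channel(&1).is_err() {
    ///     eprintln!("no such channel: 1");
    /// }
    /// ```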
    pub fn cancel_channel(&self, key: &K) -> Result<(), ()> {
        self.inner.cancel_channel(key)
    }

    /// Manually trigger a schedule tick. Normally not required, but if you drop a lot of
    /// channels and no job is inserted or completed for a while, you may call this to clean them up.
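    ///
    /// A minimal sketch (assuming `controller` and `tx` from the crate-level example):
    ///
    /// ```rust,ignore
    /// drop(tx); // the channel is closed now, but the scheduler won't notice on its own
    /// controller.gc(); // trigger a tick so the dead channel is detected and removed
    /// ```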
    pub fn gc(&self) {
        self.inner.schedule();
    }
}

// no clone, don't allow for things such as 2x spawn()
/// Scheduler
pub struct Scheduler<K, V, FB, FR, R>
where
    K: Sync + Send + Hash + Eq,
    V: Sync + Send + 'static,
    FB: Fn(V) -> R + Send + Sync + 'static,
    FR: Fn(R) + Send + Sync + 'static,
{
    inner: Arc<SchedulerInner<K, V, FB, FR, R>>,
}

impl<K, V, FB, FR, R> Scheduler<K, V, FB, FR, R>
where
    K: Sync + Send + Hash + Eq,
    V: Sync + Send + 'static,
    FB: Fn(V) -> R + Send + Sync + 'static,
    FR: Fn(R) + Send + Sync + 'static,
{
    /// Create a new scheduler with the specified maximum number of workers.
    ///
    /// * `max_worker` - the maximum number of jobs to run concurrently
    /// * `worker_fn` - the function that handles the "main" work
    /// * `worker_fn_finalize` - the optional "finish" function, which is not called for cancelled jobs
    /// * `finish_on_idle` - if true, the scheduler resolves (dropping out of the tokio runtime) when no channels are left at the next schedule tick
    ///
    /// When `finish_on_idle` is set, you should create at least one channel before spawning the scheduler on the runtime.
    pub fn new(
        max_worker: usize,
        worker_fn: FB,
        worker_fn_finalize: Option<FR>,
        finish_on_idle: bool,
    ) -> (Controller<K, V, FB, FR, R>, Scheduler<K, V, FB, FR, R>) {
        let inner = Arc::new(SchedulerInner::new(
            max_worker,
            worker_fn,
            worker_fn_finalize,
            finish_on_idle,
        ));
        (
            Controller {
                inner: inner.clone(),
            },
            Scheduler { inner },
        )
    }
}

impl<K, V, FB, FR, R> Future for Scheduler<K, V, FB, FR, R>
where
    K: Sync + Send + Hash + Eq,
    V: Sync + Send + 'static,
    FB: Fn(V) -> R + Send + Sync + 'static,
    FR: Fn(R) + Send + Sync + 'static,
{
    // The future will never yield an error
    type Error = FutError;
    type Item = FutItem;

    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        self.inner.poll()
    }
}