1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
//! A bounded executor for Tokio: a convenient way to spawn tasks while capping how many of them run concurrently.
//!
//! See: [`Rider::new`], [`Rider::spawn`] and [`Rider::shutdown`]

use std::future::Future;
use std::sync::Arc;

use tokio::sync::Semaphore;
use tokio::task::JoinSet;

/// Error returned from [`Rider::spawn`].
///
/// A spawn operation can fail only if the rider has been closed, i.e. the
/// semaphore handing out seats no longer grants permits.
///
/// The type implements [`std::fmt::Display`] and [`std::error::Error`], so it
/// can be propagated with `?` into `Box<dyn std::error::Error>` and similar
/// generic error containers.
#[derive(Debug)]
pub struct RiderError(());

impl std::fmt::Display for RiderError {
    /// Writes a short, human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("rider is closed")
    }
}

// No underlying cause to expose: the default `source()` (returning `None`) is correct.
impl std::error::Error for RiderError {}

/// Task executor that maintains a maximum number of tasks running concurrently.
///
/// Every spawned task holds one permit of an internal [`Semaphore`] for as long as it
/// runs, so at most `capacity` tasks (the value given to [`Rider::new`]) execute at once.
/// Spawned tasks are tracked in a [`JoinSet`] so that [`Rider::shutdown`] can wait for them.
///
/// NOTE: per Tokio's `JoinSet` documentation, dropping the set aborts the tasks still in
/// it, so call [`Rider::shutdown`] if the remaining tasks must run to completion.
///
/// # Example
///
/// ```rust
/// use rider::{Rider, RiderError};
///
/// #[tokio::main]
/// async fn main() -> Result<(), RiderError> {
///     let mut rider = Rider::new(10);
///     for _ in 0..100 {
///         rider
///             .spawn(async { /* do whatever you want */ })
///             .await?;
///     }
///     rider.shutdown().await;
///     Ok(())
/// }
/// ```
#[derive(Debug)]
pub struct Rider {
    /// Hands out one permit ("seat") per running task; shared with the tasks via `Arc`.
    sem: Arc<Semaphore>,
    /// Tracks every spawned task so `shutdown` can join them.
    set: JoinSet<()>,
}

impl RiderError {
    /// Instantiate [`RiderError`]
    fn closed() -> RiderError {
        RiderError(())
    }
}

impl Rider {
    /// Maximum number of tasks which a rider can hold. It is `usize::MAX >> 3`.
    ///
    /// Exceeding this limit typically results in a panic.
    pub const MAX_CAPACITY: usize = Semaphore::MAX_PERMITS;

    /// Creates a new rider with the given capacity.
    ///
    /// # Panics
    ///
    /// Panic if `capacity` is greater than [`Rider::MAX_CAPACITY`]
    ///
    /// # Example
    ///
    /// ```rust
    /// use rider::Rider;
    ///
    /// fn main() {
    ///     let rider = Rider::new(10);
    ///     // ...
    /// }
    /// ```
    pub fn new(capacity: usize) -> Rider {
        let sem = Arc::new(Semaphore::new(capacity));
        let set = JoinSet::new();
        Rider { sem, set }
    }

    /// Suspends until a seat is available and spawn the provided task on this [`Rider`].
    ///
    /// The provided future will start running in the background once the function returns.
    ///
    /// # Cancel safety
    ///
    /// A [`Semaphore`] is used under the hood, which itself uses a queue to fairly distribute permits in the order they were requested.
    /// Cancelling a call to acquire_owned makes you lose your place in the queue.
    ///
    /// # Panics
    ///
    /// This method panics if called outside a Tokio runtime.
    ///
    /// # Example
    ///
    /// ```rust
    /// use rider::{Rider, RiderError};
    ///
    /// #[tokio::main]
    /// async fn main() -> Result<(), RiderError> {
    ///     let mut rider = Rider::new(10);
    ///     for _ in 0..100 {
    ///         rider.spawn(async move {
    ///             // Distribute your work
    ///         }).await?;
    ///     }
    ///     rider.shutdown().await;
    ///     Ok(())
    /// }
    /// ```
    pub async fn spawn<F>(&mut self, task: F) -> Result<(), RiderError>
    where
        F: Future<Output = ()>,
        F: Send + 'static,
    {
        let permit = self
            .sem
            .clone()
            .acquire_owned()
            .await
            .map_err(|_| RiderError::closed())?;

        self.set.spawn(async move {
            task.await;
            drop(permit);
        });

        Ok(())
    }

    /// Closes the rider.
    /// This prevents calls to further [`Rider::spawn`] calls, and it waits for remaining tasks to complete.
    ///
    /// # Example
    ///
    /// ```rust
    /// use rider::Rider;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let mut rider = Rider::new(10);
    ///     // ...
    ///     rider.shutdown().await;
    /// }
    /// ```
    pub async fn shutdown(mut self) {
        self.sem.close();
        while let Some(handle) = self.set.join_next().await {
            handle.expect("task in rider failed");
        }
    }
}