raliguard/
semaphore.rs

//! Implementation of a rate-limit semaphore.
2use std::time::{Duration, Instant};
3
4
/// Rate-limiting semaphore: use it to control how often a piece of code runs.
///
/// Callers ask [`Semaphore::calc_delay`] how long they must sleep before proceeding.
/// Accesses are grouped into blocks of `access_times`; each block starts one
/// `per_period` after the previous one.
///
/// # Examples:
/// ```rust
/// use std::{thread, sync, time};
/// use raliguard::Semaphore;
///
/// // Create a semaphore with restriction `5 tasks per 1 second`
/// let original_sem = Semaphore::new(5, time::Duration::from_secs(1));
///
/// // Make it shareable between threads (or between async tasks)
/// let shared_sem = sync::Arc::new(sync::Mutex::new(original_sem));
///
/// // This counter is incremented every time a thread completes
/// let shared_done_count = sync::Arc::new(sync::Mutex::new(0));
///
/// // Spawn 15 threads
/// for _ in 0..15 {
///     let cloned_sem = shared_sem.clone();
///     let cloned_done_count = shared_done_count.clone();
///     thread::spawn(move || {
///         // Compute the required delay; the lock guard is dropped at the end of
///         // this statement, so other threads are not blocked while we sleep
///         let calculated_delay = cloned_sem.lock().unwrap().calc_delay();
///
///         // If a delay is required, sleep it
///         if let Some(delay) = calculated_delay {
///             thread::sleep(delay);
///         }
///
///         // Mark the thread as done
///         *cloned_done_count.lock().unwrap() += 1;
///     });
/// }
///
/// // Sleep 1 second (plus some millis to let threads finish incrementing)
/// thread::sleep(time::Duration::from_secs(1) + time::Duration::from_millis(50));
/// let current_done = shared_done_count.lock().unwrap();
///
/// // At most 10 threads should have completed after 1 second:
/// // the first 5 with no delay and the next 5 after 1 second
/// assert_eq!(*current_done, 10);
/// ```
#[derive(Debug, Clone)]
pub struct Semaphore {
    /// How many accesses are allowed within one `per_period` (`0` behaves like `1`).
    pub access_times: u64,
    /// Length of the period the `access_times` limit applies to.
    pub per_period: Duration,

    /// End of the block currently being filled, as an offset from `benchmark_stamp`.
    boundary: Duration,
    /// Number of accesses already assigned to the block ending at `boundary`.
    current_block_access: u64,
    /// Reference instant (set at construction) all offsets are measured from.
    benchmark_stamp: Instant,
}

impl Semaphore {
    /// Create a new semaphore.
    ///
    /// # Arguments:
    /// * `access_times` - how many times the guarded code is allowed to run per period
    ///   (a value of `0` is treated the same as `1`)
    /// * `per_period` - length of the period the limit applies to
    ///
    /// # Returns:
    /// A semaphore with no recorded accesses; the first call to
    /// [`Semaphore::calc_delay`] is never delayed.
    ///
    /// # Examples:
    ///
    /// ```rust
    /// use std::time::Duration;
    /// use raliguard::Semaphore;
    ///
    /// // Allows 5 executions per 1 second
    /// let semaphore = Semaphore::new(5, Duration::from_secs(1));
    ///
    /// // Allows 2 executions per 7 seconds
    /// let semaphore = Semaphore::new(2, Duration::from_secs(7));
    /// ```
    pub fn new(access_times: u64, per_period: Duration) -> Self {
        Semaphore {
            access_times,
            per_period,
            boundary: Duration::from_secs(0),
            current_block_access: 0,
            benchmark_stamp: Instant::now(),
        }
    }

    /// Calculate the delay the task/thread should sleep before running.
    ///
    /// Every call counts as one access and is assigned to the earliest block
    /// that still has room.
    ///
    /// Use with `std::sync::Arc` and `std::sync::Mutex`
    /// (or `tokio::sync::Mutex` in async style), and release the lock before sleeping.
    ///
    /// # Returns:
    /// `None` when the caller may run immediately, otherwise `Some(delay)` to sleep first
    /// (`Some` may hold a zero duration when the block starts exactly now).
    pub fn calc_delay(&mut self) -> Option<Duration> {
        let stamp = self.benchmark_stamp.elapsed();

        // Every scheduled block is already in the past: open a fresh block anchored at now.
        if stamp >= self.boundary {
            self.boundary = stamp + self.per_period;
            self.current_block_access = 1;
            // With a limit of one access, this first access already fills the block.
            self.close_block_if_full();
            return None;
        }

        // Add new hit
        self.current_block_access += 1;

        // The current block spans [boundary - per_period, boundary); wait until it opens.
        // `checked_sub` yields `None` when the block has already opened.
        let delay = (self.boundary - stamp).checked_sub(self.per_period);

        self.close_block_if_full();

        delay
    }

    /// Move the boundary one period forward once the current block holds
    /// `access_times` accesses (at least one), so later callers land in the next block.
    ///
    /// `>=` rather than `==` keeps this working if the public `access_times`
    /// field is lowered while a block is partially filled.
    fn close_block_if_full(&mut self) {
        if self.current_block_access >= self.access_times.max(1) {
            self.boundary += self.per_period;
            self.current_block_access = 0;
        }
    }
}