use std::sync::atomic::{AtomicBool, AtomicU64, Ordering::*};
use std::time::{Duration, Instant};
use async_trait::async_trait;
use tokio::time;
use crate::limiter::Limiter;
/// Hands out jobs to workers and tracks whether a run has finished or been
/// canceled.
///
/// Implementors supply the bookkeeping (`try_apply_job`, `complete_job`,
/// `cancel`). The provided [`Dispatcher::apply_token`] adds optional
/// rate-limit gating on top of the limiter returned by `get_limiter`.
#[async_trait]
pub(crate) trait Dispatcher: Send + Sync {
    // NOTE(review): this is an associated type *default*, which is gated
    // behind the unstable `associated_type_defaults` feature. Nothing in this
    // file refers to `Self::Limiter`; confirm it is needed before removing.
    type Limiter = Limiter;

    /// Returns `true` once the run has been canceled or has finished.
    fn is_canceled_or_done(&self) -> bool;

    /// Borrows the optional rate limiter. `None` disables the waiting
    /// performed by [`Dispatcher::apply_token`].
    fn get_limiter(&self) -> &Option<Limiter>;

    /// Waits until the rate limiter (if any) grants a token.
    ///
    /// Returns `false` if the dispatcher is canceled or done before waiting,
    /// while waiting, or immediately after the token is granted. Otherwise
    /// returns `true`.
    async fn apply_token(&self) -> bool {
        if self.is_canceled_or_done() {
            return false;
        }

        if let Some(limiter) = self.get_limiter() {
            // Keep asking the limiter until it grants a token, giving up as
            // soon as the run stops.
            while !limiter.allow_fast().await.is_ok() {
                if self.is_canceled_or_done() {
                    return false;
                }
                // tokio timers work at millisecond granularity, so this
                // pause is effectively ~1ms rather than 1µs.
                time::sleep(Duration::from_micros(1)).await;
            }
        }

        // The run may have ended while we were waiting; re-check before
        // letting the caller start a job.
        !self.is_canceled_or_done()
    }

    /// Tries to claim the next job. Returns `false` when no job should be
    /// started.
    async fn try_apply_job(&self) -> bool;

    /// Records that a previously claimed job has finished.
    fn complete_job(&self);

    /// Requests that the run stop. Requires exclusive (`&mut`) access.
    fn cancel(&mut self);
}
/// Dispatcher that hands out a fixed number (`total`) of jobs and is done
/// once that many jobs have been reported complete.
pub(crate) struct CountDispatcher {
    /// Number of jobs this dispatcher will hand out.
    total: u64,
    /// Counter advanced by `try_apply_job` when it claims a job.
    applied: AtomicU64,
    /// Counter advanced by `complete_job`.
    completed: AtomicU64,
    /// Set once `cancel` has been called.
    is_canceled: AtomicBool,
    /// Set once `completed` reaches `total`.
    is_done: AtomicBool,
    /// Optional rate limiter consulted by `apply_token` before each job.
    limiter: Option<Limiter>,
}
/// Builds the optional rate limiter shared by both dispatchers.
///
/// With `Some(rate)`, creates a [`Limiter`] for `rate` and immediately calls
/// `allow_n(rate)` on it. Presumably this drains the initial burst so the run
/// starts at the steady rate (TODO confirm against `Limiter`). With `None`,
/// no limiter is created.
fn new_limiter(rate: &Option<u16>) -> Option<Limiter> {
    rate.map(|per_period| {
        let limiter = Limiter::new(per_period);
        limiter.allow_n(usize::from(per_period));
        limiter
    })
}
impl CountDispatcher {
    /// Creates a dispatcher that hands out at most `total` jobs.
    ///
    /// When `rate` is `Some`, jobs are throttled through a limiter built by
    /// `new_limiter`. All counters start at zero and both flags start cleared.
    pub(crate) fn new(total: u64, rate: &Option<u16>) -> Self {
        let limiter = new_limiter(rate);
        Self {
            total,
            limiter,
            applied: AtomicU64::default(),
            completed: AtomicU64::default(),
            is_canceled: AtomicBool::default(),
            is_done: AtomicBool::default(),
        }
    }
}
#[async_trait]
impl Dispatcher for CountDispatcher {
    /// `true` once all `total` jobs have completed or `cancel` was called.
    fn is_canceled_or_done(&self) -> bool {
        self.is_done.load(Acquire) || self.is_canceled.load(Acquire)
    }

    fn get_limiter(&self) -> &Option<Limiter> {
        &self.limiter
    }

    /// Claims one of the `total` jobs after waiting for a rate-limit token.
    ///
    /// Returns `false` if the dispatcher is canceled/done or every job has
    /// already been claimed.
    async fn try_apply_job(&self) -> bool {
        if !self.apply_token().await {
            return false;
        }
        // Increment only while below `total`, as a single atomic step. A
        // separate load followed by `fetch_add` lets `applied` overshoot
        // `total` when several workers race for the last job.
        self.applied
            .fetch_update(SeqCst, SeqCst, |claimed| {
                if claimed < self.total {
                    Some(claimed + 1)
                } else {
                    None
                }
            })
            .is_ok()
    }

    /// Records a finished job and marks the dispatcher done once `total`
    /// jobs have completed.
    fn complete_job(&self) {
        // Use this call's own post-increment value instead of re-loading the
        // counter, so the check reflects exactly one increment.
        let completed = self.completed.fetch_add(1, SeqCst) + 1;
        if completed >= self.total {
            // Storing `true` again is harmless, so no prior check is needed.
            self.is_done.store(true, SeqCst);
        }
    }

    /// Marks the dispatcher canceled. Calling it repeatedly is harmless.
    fn cancel(&mut self) {
        self.is_canceled.store(true, SeqCst);
    }
}
/// Dispatcher that keeps handing out jobs until `duration` has elapsed since
/// it was constructed.
pub(crate) struct DurationDispatcher {
    /// Number of jobs claimed via `try_apply_job`.
    total: AtomicU64,
    /// Construction time; the dispatch window ends at `start + duration`.
    start: Instant,
    /// Length of the dispatch window.
    duration: Duration,
    /// Optional rate limiter consulted by `apply_token` before each job.
    limiter: Option<Limiter>,
    /// Set once `cancel` has been called.
    is_canceled: AtomicBool,
    /// Time of the first `cancel` call.
    // NOTE(review): only the tests in this file read it; presumably it is
    // consumed by reporting code elsewhere. Confirm.
    canceled_at: Option<Instant>,
    /// Set once the dispatch window has been observed to have elapsed.
    is_done: AtomicBool,
}
impl DurationDispatcher {
    /// Creates a dispatcher whose dispatch window of length `duration` starts
    /// now.
    ///
    /// When `rate` is `Some`, jobs are throttled through a limiter built by
    /// `new_limiter`. The job counter starts at zero and both flags start
    /// cleared.
    pub(crate) fn new(duration: Duration, rate: &Option<u16>) -> Self {
        let start = Instant::now();
        let limiter = new_limiter(rate);
        Self {
            total: AtomicU64::default(),
            start,
            duration,
            limiter,
            is_canceled: AtomicBool::default(),
            canceled_at: None,
            is_done: AtomicBool::default(),
        }
    }
}
#[async_trait]
impl Dispatcher for DurationDispatcher {
    /// `true` once the dispatch window has been observed to be over, or
    /// `cancel` was called.
    fn is_canceled_or_done(&self) -> bool {
        self.is_done.load(Acquire) || self.is_canceled.load(Acquire)
    }

    fn get_limiter(&self) -> &Option<Limiter> {
        &self.limiter
    }

    /// Claims a job if the dispatch window is still open, after waiting for a
    /// rate-limit token.
    ///
    /// Returns `false` if the dispatcher is canceled/done or `duration` has
    /// elapsed since construction.
    async fn try_apply_job(&self) -> bool {
        if !self.apply_token().await {
            return false;
        }
        if self.start.elapsed() >= self.duration {
            // Record the deadline as soon as it is seen. Other workers parked
            // in `apply_token`'s limiter loop then stop, instead of waiting
            // for tokens that this check would reject anyway.
            self.is_done.store(true, SeqCst);
            return false;
        }
        self.total.fetch_add(1, SeqCst);
        true
    }

    /// Records a finished job and marks the dispatcher done if the dispatch
    /// window has elapsed.
    fn complete_job(&self) {
        if self.start.elapsed() >= self.duration {
            self.is_done.store(true, SeqCst);
        }
    }

    /// Marks the dispatcher canceled. Only the first call records
    /// `canceled_at`.
    fn cancel(&mut self) {
        if !self.is_canceled.load(Acquire) {
            self.is_canceled.store(true, SeqCst);
            self.canceled_at = Some(Instant::now());
        }
    }
}
/// Unit tests for the limiter factory and both dispatcher implementations,
/// including the cancellation, zero-total and expired-window paths.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_limiter_with_rate() {
        let rate = Some(10);
        let limiter = new_limiter(&rate);
        assert!(limiter.is_some());
    }

    #[test]
    fn test_new_limiter_without_rate() {
        let rate: Option<u16> = None;
        let limiter = new_limiter(&rate);
        assert!(limiter.is_none());
    }

    #[test]
    fn test_count_dispatcher_new() {
        let dispatcher = CountDispatcher::new(100, &Some(10));
        assert_eq!(dispatcher.applied.load(Acquire), 0);
        assert_eq!(dispatcher.completed.load(Acquire), 0);
        assert!(!dispatcher.is_canceled.load(Acquire));
        assert!(!dispatcher.is_done.load(Acquire));
    }

    #[test]
    fn test_count_dispatcher_is_canceled_or_done() {
        let dispatcher = CountDispatcher::new(100, &None);
        assert!(!dispatcher.is_canceled_or_done());
        dispatcher.is_done.store(true, SeqCst);
        assert!(dispatcher.is_canceled_or_done());
    }

    #[tokio::test]
    async fn test_count_dispatcher_try_apply_job() {
        let dispatcher = CountDispatcher::new(5, &None);
        for i in 0..5 {
            let result = dispatcher.try_apply_job().await;
            assert!(result, "Job {} should succeed", i);
        }
        let result = dispatcher.try_apply_job().await;
        assert!(!result);
        // A rejected claim must not advance the counter past `total`.
        assert_eq!(dispatcher.applied.load(Acquire), 5);
    }

    #[tokio::test]
    async fn test_count_dispatcher_zero_total() {
        let dispatcher = CountDispatcher::new(0, &None);
        assert!(!dispatcher.try_apply_job().await);
        assert_eq!(dispatcher.applied.load(Acquire), 0);
    }

    #[tokio::test]
    async fn test_count_dispatcher_try_apply_job_after_cancel() {
        let mut dispatcher = CountDispatcher::new(5, &None);
        dispatcher.cancel();
        assert!(!dispatcher.try_apply_job().await);
        assert_eq!(dispatcher.applied.load(Acquire), 0);
    }

    #[test]
    fn test_count_dispatcher_complete_job() {
        let dispatcher = CountDispatcher::new(5, &None);
        for _ in 0..5 {
            dispatcher.complete_job();
        }
        assert_eq!(dispatcher.completed.load(Acquire), 5);
        assert!(dispatcher.is_done.load(Acquire));
    }

    #[test]
    fn test_count_dispatcher_cancel() {
        let mut dispatcher = CountDispatcher::new(100, &None);
        assert!(!dispatcher.is_canceled.load(Acquire));
        dispatcher.cancel();
        assert!(dispatcher.is_canceled.load(Acquire));
    }

    #[test]
    fn test_duration_dispatcher_new() {
        let duration = Duration::from_secs(60);
        let dispatcher = DurationDispatcher::new(duration, &Some(10));
        assert_eq!(dispatcher.total.load(Acquire), 0);
        assert!(!dispatcher.is_canceled.load(Acquire));
        assert!(!dispatcher.is_done.load(Acquire));
    }

    #[test]
    fn test_duration_dispatcher_is_canceled_or_done() {
        let duration = Duration::from_secs(60);
        let dispatcher = DurationDispatcher::new(duration, &None);
        assert!(!dispatcher.is_canceled_or_done());
        dispatcher.is_done.store(true, SeqCst);
        assert!(dispatcher.is_canceled_or_done());
    }

    #[tokio::test]
    async fn test_duration_dispatcher_try_apply_job() {
        let duration = Duration::from_secs(1);
        let dispatcher = DurationDispatcher::new(duration, &None);
        let result = dispatcher.try_apply_job().await;
        assert!(result);
        assert_eq!(dispatcher.total.load(Acquire), 1);
    }

    #[tokio::test]
    async fn test_duration_dispatcher_try_apply_job_after_deadline() {
        // A zero-length window is already over when the first job is tried.
        let dispatcher = DurationDispatcher::new(Duration::ZERO, &None);
        assert!(!dispatcher.try_apply_job().await);
        assert_eq!(dispatcher.total.load(Acquire), 0);
    }

    #[tokio::test]
    async fn test_duration_dispatcher_try_apply_job_after_cancel() {
        let mut dispatcher = DurationDispatcher::new(Duration::from_secs(60), &None);
        dispatcher.cancel();
        assert!(!dispatcher.try_apply_job().await);
        assert_eq!(dispatcher.total.load(Acquire), 0);
    }

    #[test]
    fn test_duration_dispatcher_complete_job() {
        let duration = Duration::from_secs(60);
        let dispatcher = DurationDispatcher::new(duration, &None);
        dispatcher.complete_job();
        assert!(!dispatcher.is_done.load(Acquire));
    }

    #[test]
    fn test_duration_dispatcher_cancel() {
        let duration = Duration::from_secs(60);
        let mut dispatcher = DurationDispatcher::new(duration, &None);
        assert!(!dispatcher.is_canceled.load(Acquire));
        assert!(dispatcher.canceled_at.is_none());
        dispatcher.cancel();
        assert!(dispatcher.is_canceled.load(Acquire));
        assert!(dispatcher.canceled_at.is_some());
    }
}