use crate::{
events::{send_event, Event},
Result, TenxError,
};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::time::sleep;
/// Base of the exponential backoff: the delay for `Throttle::Backoff` is
/// `BACKOFF_MULTIPLIER^retries` seconds before capping.
const BACKOFF_MULTIPLIER: f64 = 2.0;
/// Upper bound, in seconds, on any single delay computed for `Throttle::Backoff`.
const MAX_BACKOFF_SECS: u64 = 60;
/// A rate-limiting signal telling a caller how it should wait before retrying.
#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
pub enum Throttle {
    /// Wait exactly this many seconds before retrying
    /// (presumably a delay supplied by the remote service — confirm at call sites).
    RetryAfter(u64),
    /// No explicit delay is known; the wait is derived from an exponential backoff schedule.
    Backoff,
}
impl std::fmt::Display for Throttle {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Throttle::RetryAfter(secs) => write!(f, "retry after {} seconds", secs),
Throttle::Backoff => write!(f, "rate limited"),
}
}
}
/// Tracks consecutive retry attempts and computes how long to wait before each one.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Throttler {
    /// Number of `Throttle::Backoff` retries taken since the counter was last cleared
    /// (by `reset` or by handling a `Throttle::RetryAfter`).
    retries: u32,
    /// Once `retries` reaches this value, `throttle_time` returns an error instead of a delay.
    max_retries: u32,
}
impl Throttler {
    /// Creates a throttler that allows up to `max_retries` consecutive backoff retries.
    pub fn new(max_retries: u32) -> Self {
        Throttler {
            retries: 0,
            max_retries,
        }
    }

    /// Computes how long to wait before the next attempt and updates the retry counter.
    ///
    /// - `Throttle::RetryAfter(secs)` yields exactly `secs` seconds and resets the counter.
    /// - `Throttle::Backoff` yields `BACKOFF_MULTIPLIER^retries` seconds, capped at
    ///   `MAX_BACKOFF_SECS`, and increments the counter.
    ///
    /// # Errors
    ///
    /// Returns `TenxError::MaxRetries` (carrying the current retry count) if the counter
    /// has already reached `max_retries`. The counter is left unchanged in that case.
    pub fn throttle_time(&mut self, t: &Throttle) -> Result<Duration> {
        if self.retries >= self.max_retries {
            return Err(TenxError::MaxRetries(u64::from(self.retries)));
        }
        Ok(match t {
            Throttle::RetryAfter(seconds) => {
                // An explicit delay supersedes our own backoff schedule, so start it over.
                self.retries = 0;
                Duration::from_secs(*seconds)
            }
            Throttle::Backoff => {
                // Clamp before casting: a bare `as i32` wraps counts above i32::MAX into
                // negative exponents, which would collapse the delay to zero seconds.
                let exponent = self.retries.min(i32::MAX as u32) as i32;
                // f64 -> u64 `as` saturates (infinity becomes u64::MAX), so after `min`
                // the result is always within the cap.
                let backoff =
                    (BACKOFF_MULTIPLIER.powi(exponent) as u64).min(MAX_BACKOFF_SECS);
                self.retries = self.retries.saturating_add(1);
                Duration::from_secs(backoff)
            }
        })
    }

    /// Clears the retry counter, e.g. after a successful request.
    pub fn reset(&mut self) {
        self.retries = 0;
    }

    /// Computes the delay for `t`, emits an `Event::Throttled` with the delay in
    /// milliseconds via `send_event`, then sleeps for that delay.
    ///
    /// # Errors
    ///
    /// Propagates the error from `throttle_time` (max retries reached) or from `send_event`.
    pub async fn throttle(
        &mut self,
        t: &Throttle,
        sender: &Option<mpsc::Sender<Event>>,
    ) -> Result<()> {
        let duration = self.throttle_time(t)?;
        // `as_millis` returns u128; saturate instead of letting `as u64` silently truncate
        // when a very large `RetryAfter` value is supplied.
        let millis = duration.as_millis().min(u128::from(u64::MAX)) as u64;
        send_event(sender, Event::Throttled(millis))?;
        sleep(duration).await;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_after() {
        let mut throttler = Throttler::new(20);
        let duration = throttler.throttle_time(&Throttle::RetryAfter(10)).unwrap();
        assert_eq!(duration, Duration::from_secs(10));
        assert_eq!(throttler.retries, 0);
    }

    #[test]
    fn test_retry_after_resets_backoff() {
        let mut throttler = Throttler::new(20);
        throttler.throttle_time(&Throttle::Backoff).unwrap();
        throttler.throttle_time(&Throttle::Backoff).unwrap();
        assert_eq!(throttler.retries, 2);

        let duration = throttler.throttle_time(&Throttle::RetryAfter(5)).unwrap();
        assert_eq!(duration, Duration::from_secs(5));
        assert_eq!(throttler.retries, 0);

        // The backoff schedule starts over after an explicit delay.
        let duration = throttler.throttle_time(&Throttle::Backoff).unwrap();
        assert_eq!(duration, Duration::from_secs(1));
    }

    #[test]
    fn test_exponential_backoff() {
        let mut throttler = Throttler::new(20);
        // Delays double with each attempt until they reach MAX_BACKOFF_SECS.
        let expected = [1, 2, 4, 8, 16, 32, MAX_BACKOFF_SECS];
        for (attempt, &secs) in expected.iter().enumerate() {
            let duration = throttler.throttle_time(&Throttle::Backoff).unwrap();
            assert_eq!(duration, Duration::from_secs(secs));
            assert_eq!(throttler.retries as usize, attempt + 1);
        }
    }

    #[test]
    fn test_backoff_cap() {
        let mut throttler = Throttler::new(20);
        throttler.retries = 10;
        let duration = throttler.throttle_time(&Throttle::Backoff).unwrap();
        assert_eq!(duration, Duration::from_secs(MAX_BACKOFF_SECS));
        assert_eq!(throttler.retries, 11);
    }

    #[test]
    fn test_max_retries() {
        let mut throttler = Throttler::new(3);
        assert!(throttler.throttle_time(&Throttle::Backoff).is_ok());
        assert!(throttler.throttle_time(&Throttle::Backoff).is_ok());
        assert!(throttler.throttle_time(&Throttle::Backoff).is_ok());
        match throttler.throttle_time(&Throttle::Backoff) {
            Err(TenxError::MaxRetries(3)) => (),
            other => panic!("Expected MaxRetries error, got {:?}", other),
        }
    }

    #[test]
    fn test_reset() {
        let mut throttler = Throttler::new(2);
        throttler.throttle_time(&Throttle::Backoff).unwrap();
        throttler.throttle_time(&Throttle::Backoff).unwrap();
        assert!(throttler.throttle_time(&Throttle::Backoff).is_err());

        throttler.reset();
        assert_eq!(throttler.retries, 0);
        let duration = throttler.throttle_time(&Throttle::Backoff).unwrap();
        assert_eq!(duration, Duration::from_secs(1));
    }

    #[test]
    fn test_display() {
        assert_eq!(Throttle::RetryAfter(7).to_string(), "retry after 7 seconds");
        assert_eq!(Throttle::Backoff.to_string(), "rate limited");
    }
}