use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::Arc;
use chrono::{DateTime, Utc};
use tokio::sync::Mutex;
/// A pool of rate-limited API keys from which a currently usable key can be
/// requested.
///
/// The key list lives behind a shared `tokio::sync::Mutex`, so access to the
/// individual keys is serialized through that lock.
#[derive(Default)]
pub struct APIKeyPool {
    /// Registered keys, kept in the order in which they were added.
    api_keys: Arc<Mutex<Vec<APIKey>>>,
}
impl APIKeyPool {
    /// Creates an empty pool.
    pub fn new() -> Self {
        Self {
            api_keys: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Registers `key` with the pool; keys are later scanned in the order added.
    ///
    /// Only `&self` is required because the key list is guarded by an async
    /// mutex. The pool can therefore be shared (e.g. via `Arc<APIKeyPool>`)
    /// between tasks. Callers holding `&mut APIKeyPool` still work unchanged.
    pub async fn add_key(&self, key: APIKey) {
        self.api_keys.lock().await.push(key);
    }

    /// Returns the first key, in registration order, whose rate limit currently
    /// permits another use, and records that use against the key.
    ///
    /// The pool lock is held for the entire scan. Checking a key
    /// (`is_ready`) and recording its use (`use_key`) therefore cannot
    /// interleave with another caller's poll.
    ///
    /// Returns `None` when the pool is empty or every key is exhausted.
    pub async fn poll_for_key(&self) -> Option<String> {
        let mut keys = self.api_keys.lock().await;
        for key in keys.iter_mut() {
            if key.is_ready().await {
                return Some(key.use_key().await);
            }
        }
        None
    }
}
/// A single API key, the rate limit it is subject to, and a record of when it
/// was recently handed out.
pub struct APIKey {
    /// The key value returned when this key is used.
    key: String,
    /// The limit governing how often this key may be used.
    policy: RateLimitPolicy,
    /// Timestamps of recent uses. Wrapping each entry in `Reverse` turns the
    /// max-heap into a min-heap, so `peek`/`pop` operate on the oldest use.
    times: Arc<Mutex<BinaryHeap<Reverse<DateTime<Utc>>>>>,
}
impl APIKey {
    /// Creates a key governed by `policy`.
    ///
    /// The usage history is pre-sized for `policy.count` timestamps. `use_key`
    /// evicts before pushing once that many are stored, so this is the most it
    /// holds.
    pub fn new(key: &str, policy: RateLimitPolicy) -> Self {
        let times = Arc::new(Mutex::new(BinaryHeap::with_capacity(policy.count)));
        Self {
            key: String::from(key),
            policy,
            times,
        }
    }

    /// Returns an owned copy of the key string.
    fn get_key(&self) -> String {
        self.key.clone()
    }

    /// Reports whether the key may be used right now.
    ///
    /// The key is ready when fewer than `policy.count` uses are recorded, or
    /// when the oldest recorded use lies strictly more than `policy.per` in the
    /// past. A policy with `count == 0` is never ready.
    ///
    /// The history is inspected under a single lock acquisition, so the length
    /// check and the oldest-timestamp check observe the same state.
    async fn is_ready(&self) -> bool {
        let times = self.times.lock().await;
        if times.len() < self.policy.count {
            return true;
        }
        match times.peek() {
            // Equivalent to `oldest < now - per`. However, `now - per` panics
            // when `per` pushes the result outside chrono's supported range.
            // The difference of two representable timestamps always fits in a
            // `Duration`.
            Some(Reverse(oldest)) => Utc::now().signed_duration_since(*oldest) > self.policy.per,
            None => false,
        }
    }

    /// Records a use of this key at the current time and returns the key string.
    ///
    /// When the history already holds `policy.count` entries, the oldest one is
    /// evicted first, keeping the history bounded by the policy's count.
    /// Callers are expected to have checked `is_ready` beforehand. Otherwise a
    /// use that is still inside the window may be evicted.
    async fn use_key(&mut self) -> String {
        // One lock for the whole evict-then-record sequence so it is atomic.
        let mut times = self.times.lock().await;
        if times.len() >= self.policy.count {
            times.pop();
        }
        times.push(Reverse(Utc::now()));
        drop(times);
        self.get_key()
    }
}
/// A rate limit of at most `count` uses within any trailing window of length
/// `per`.
///
/// `Debug`, `PartialEq` and `Eq` are derived so the policy can be logged,
/// compared, and embedded in other types that derive them.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct RateLimitPolicy {
    /// Maximum number of uses allowed inside one window.
    pub count: usize,
    /// Length of the window.
    pub per: chrono::Duration,
}
impl RateLimitPolicy {
pub fn new(count: usize, per: chrono::Duration) -> Self {
Self { count, per }
}
}