use reqwest::Response;
use std::collections::HashMap;
use std::time::{Duration, Instant};
use crate::{RoverApiError, RoverClient};
/// Snapshot of one named rate-limit bucket, as last reported by the server through
/// the `X-RateLimit-*` response headers.
#[derive(Clone, Debug)]
pub struct RateLimitBucket {
    /// Requests still allowed in the current window (`X-RateLimit-Remaining`).
    pub remaining: u32,
    /// Window length reported by `X-RateLimit-Reset-After` when the snapshot was taken.
    pub reset_after: Duration,
    /// Absolute instant at which the window resets (receipt time + `reset_after`).
    pub reset_at: Instant,
}
/// Point-in-time summary of the limiter, as returned by `RateLimiter::get_status`.
///
/// The bucket-related fields are all `Some` or all `None` together: they describe a
/// single bucket, or no bucket at all.
#[derive(Clone, Debug)]
pub struct RateLimitStatus {
    /// Name of the reported bucket, if any.
    pub bucket: Option<String>,
    /// Requests left in that bucket's current window.
    pub remaining: Option<u32>,
    /// Window length the server reported for that bucket.
    pub reset_after: Option<Duration>,
    /// When that bucket's window resets.
    pub reset_at: Option<Instant>,
    /// Instant `RateLimiter::wait_for_reset` sleeps until, if any.
    pub next_reset: Option<Instant>,
}
/// Client-side rate-limit tracker.
///
/// Keeps the latest state of each named bucket reported in response headers, plus a
/// global cool-down that is set when the server answers `429 Too Many Requests`.
///
/// `Default` yields the same empty state as `RateLimiter::new()`.
#[derive(Debug, Default)]
pub struct RateLimiter {
    /// Latest snapshot per bucket name; expired entries are pruned lazily.
    buckets: HashMap<String, RateLimitBucket>,
    /// End of the global cool-down from the last `429` response, if one was recorded.
    global_reset: Option<Instant>,
}
impl RateLimiter {
    /// Creates a limiter with no tracked buckets and no active global cool-down.
    pub fn new() -> Self {
        Self {
            buckets: HashMap::new(),
            global_reset: None,
        }
    }

    /// Checks whether a request may be sent right now.
    ///
    /// As a side effect, it clears an expired global cool-down and drops every bucket
    /// whose `reset_at` has passed. Per-bucket `remaining` counts are not consulted here;
    /// only the global cool-down can reject a request.
    ///
    /// # Errors
    ///
    /// Returns [`RoverApiError::RateLimit`] while the global cool-down recorded from a
    /// `429` response is still active. `retry_after` is the remaining wait in whole
    /// seconds, rounded *up* so that a caller sleeping that long is never early.
    pub fn check_rate_limit(&mut self) -> Result<(), RoverApiError> {
        let now = Instant::now();
        if let Some(global_reset) = self.global_reset {
            if now < global_reset {
                // Truncating (as `as_secs` alone would) reports 0 for a sub-second wait,
                // which invites an immediate retry that is still rate limited.
                let retry_after = Self::ceil_secs(global_reset - now);
                return Err(RoverApiError::RateLimit { retry_after });
            }
            self.global_reset = None;
        }
        self.buckets.retain(|_, bucket| now < bucket.reset_at);
        Ok(())
    }

    /// Updates limiter state from the rate-limit headers of `response`.
    ///
    /// * A bucket entry is stored only when `X-RateLimit-Bucket`, `X-RateLimit-Remaining`
    ///   and `X-RateLimit-Reset-After` are all present and valid. It replaces any
    ///   previous entry of the same name.
    /// * On a `429` response carrying a valid `Retry-After`, a global cool-down ending
    ///   that far in the future is recorded. The value is in seconds, integer or
    ///   fractional.
    ///
    /// These values come from the remote server. Malformed, negative, non-finite or
    /// out-of-range values are ignored instead of panicking.
    pub fn update_from_headers(&mut self, response: &Response) {
        let headers = response.headers();
        let now = Instant::now();
        let header_str = |name: &str| headers.get(name).and_then(|value| value.to_str().ok());

        if response.status() == 429 {
            let cool_down_end = header_str("Retry-After")
                .and_then(Self::parse_secs)
                .and_then(|wait| now.checked_add(wait));
            if let Some(reset) = cool_down_end {
                self.global_reset = Some(reset);
            }
        }

        let bucket = header_str("X-RateLimit-Bucket").map(str::to_string);
        let remaining = header_str("X-RateLimit-Remaining").and_then(|s| s.parse::<u32>().ok());
        let reset_after = header_str("X-RateLimit-Reset-After").and_then(Self::parse_secs);

        if let (Some(bucket_name), Some(remaining), Some(reset_after)) =
            (bucket, remaining, reset_after)
        {
            // `Instant + Duration` panics on overflow; skip absurd windows instead.
            if let Some(reset_at) = now.checked_add(reset_after) {
                self.buckets.insert(
                    bucket_name,
                    RateLimitBucket {
                        remaining,
                        reset_after,
                        reset_at,
                    },
                );
            }
        }
    }

    /// Summarises the current limit state.
    ///
    /// Reports the live (not yet reset) bucket with the fewest remaining requests. On a
    /// tie, it reports the one that resets last. `reset_after` is the window length the
    /// server reported, not the time still left.
    ///
    /// `next_reset` is the instant [`RateLimiter::wait_for_reset`] sleeps until. It is
    /// the later of that bucket's reset and the global cool-down, so waiting on it
    /// cannot end while a `429` cool-down is still in force. With no live bucket, all
    /// bucket fields are `None` and `next_reset` is the recorded global cool-down, if any.
    pub fn get_status(&self) -> RateLimitStatus {
        let now = Instant::now();
        let most_restrictive = self
            .buckets
            .iter()
            .filter(|(_, bucket)| now < bucket.reset_at)
            // `HashMap` iteration order is arbitrary. The secondary key settles ties on
            // `remaining` by preferring the bucket with the longer wait.
            .min_by_key(|(_, bucket)| (bucket.remaining, std::cmp::Reverse(bucket.reset_at)));

        match most_restrictive {
            Some((bucket_name, bucket)) => RateLimitStatus {
                bucket: Some(bucket_name.clone()),
                remaining: Some(bucket.remaining),
                reset_after: Some(bucket.reset_after),
                reset_at: Some(bucket.reset_at),
                next_reset: Some(
                    self.global_reset
                        .map_or(bucket.reset_at, |global| global.max(bucket.reset_at)),
                ),
            },
            None => RateLimitStatus {
                bucket: None,
                remaining: None,
                reset_after: None,
                reset_at: None,
                next_reset: self.global_reset,
            },
        }
    }

    /// Sleeps (with `tokio::time::sleep`) until `get_status().next_reset`.
    ///
    /// Returns immediately when there is no reset instant or it has already passed.
    pub async fn wait_for_reset(&self) {
        if let Some(reset_time) = self.get_status().next_reset {
            let wait = reset_time.saturating_duration_since(Instant::now());
            if !wait.is_zero() {
                tokio::time::sleep(wait).await;
            }
        }
    }

    /// Parses a header value holding a non-negative, finite number of seconds
    /// (e.g. `"2"` or `"0.75"`).
    ///
    /// Returns `None` for unparsable, negative, NaN, infinite or out-of-range values,
    /// where `Duration::from_secs_f64` would panic.
    fn parse_secs(raw: &str) -> Option<Duration> {
        let secs = raw.parse::<f64>().ok()?;
        Duration::try_from_secs_f64(secs).ok()
    }

    /// Converts a wait into whole seconds, rounding any fractional part up.
    fn ceil_secs(wait: Duration) -> u64 {
        wait.as_secs()
            .saturating_add(u64::from(wait.subsec_nanos() > 0))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_client_creation() {
        let client = RoverClient::new("test_key".to_string());
        assert!(!client.api_key.is_empty());
    }

    #[test]
    fn test_rate_limiter() {
        let mut limiter = RateLimiter::new();
        assert!(limiter.check_rate_limit().is_ok());
        let status = limiter.get_status();
        assert!(status.bucket.is_none());
        assert!(status.remaining.is_none());
    }

    /// An active global cool-down rejects requests and is surfaced as `next_reset`.
    #[test]
    fn test_active_global_limit_blocks() {
        let mut limiter = RateLimiter::new();
        let reset = Instant::now() + Duration::from_secs(30);
        limiter.global_reset = Some(reset);
        assert!(matches!(
            limiter.check_rate_limit(),
            Err(RoverApiError::RateLimit { .. })
        ));
        assert_eq!(limiter.get_status().next_reset, Some(reset));
    }

    /// A cool-down whose instant has passed is cleared by `check_rate_limit`.
    #[test]
    fn test_expired_global_limit_is_cleared() {
        let mut limiter = RateLimiter::new();
        limiter.global_reset = Some(Instant::now());
        assert!(limiter.check_rate_limit().is_ok());
        assert!(limiter.global_reset.is_none());
    }

    /// Buckets past their reset are hidden from `get_status` and pruned on check.
    #[test]
    fn test_expired_buckets_are_ignored_and_pruned() {
        let mut limiter = RateLimiter::new();
        let now = Instant::now();
        limiter.buckets.insert(
            "stale".to_string(),
            RateLimitBucket {
                remaining: 0,
                reset_after: Duration::ZERO,
                reset_at: now,
            },
        );
        assert!(limiter.get_status().bucket.is_none());
        assert!(limiter.check_rate_limit().is_ok());
        assert!(limiter.buckets.is_empty());
    }

    /// The live bucket with the fewest remaining requests is reported.
    #[test]
    fn test_status_reports_most_restrictive_bucket() {
        let mut limiter = RateLimiter::new();
        let now = Instant::now();
        let window = Duration::from_secs(60);
        limiter.buckets.insert(
            "wide".to_string(),
            RateLimitBucket {
                remaining: 10,
                reset_after: window,
                reset_at: now + window,
            },
        );
        limiter.buckets.insert(
            "narrow".to_string(),
            RateLimitBucket {
                remaining: 1,
                reset_after: window,
                reset_at: now + window,
            },
        );
        let status = limiter.get_status();
        assert_eq!(status.bucket.as_deref(), Some("narrow"));
        assert_eq!(status.remaining, Some(1));
        assert_eq!(status.reset_after, Some(window));
        assert_eq!(status.reset_at, Some(now + window));
    }

    #[test]
    fn test_error_types() {
        let api_error = RoverApiError::Api {
            code: "bad_request".to_string(),
            message: "Invalid request".to_string(),
            detail: None,
            context: None,
        };
        assert!(matches!(api_error, RoverApiError::Api { .. }));
        let rate_limit_error = RoverApiError::RateLimit { retry_after: 60 };
        assert!(matches!(rate_limit_error, RoverApiError::RateLimit { .. }));
    }
}