Skip to main content

rover_api/
rate_limiter.rs

1use reqwest::Response;
2use std::collections::HashMap;
3use std::time::{Duration, Instant};
4use crate::{RoverApiError, RoverClient};
5
/// Rate-limit state for a single bucket, as taken from the most recent response that
/// carried the bucket headers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitBucket {
    /// Requests left in the current window (from `X-RateLimit-Remaining`).
    pub remaining: u32,
    /// Delay until reset as reported by `X-RateLimit-Reset-After` when the response was
    /// processed (not the time left now).
    pub reset_after: Duration,
    /// Local instant at which the bucket resets (processing time + `reset_after`).
    pub reset_at: Instant,
}
12
/// Snapshot of the limiter's state, as produced by `RateLimiter::get_status`.
///
/// Every field is `None` when nothing is known; the `Default` value is that empty snapshot.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RateLimitStatus {
    /// Name of the reported (most restrictive live) bucket, if any.
    pub bucket: Option<String>,
    /// Requests left in the reported bucket.
    pub remaining: Option<u32>,
    /// Reset delay of the reported bucket as received in its header.
    pub reset_after: Option<Duration>,
    /// Instant at which the reported bucket resets.
    pub reset_at: Option<Instant>,
    /// Instant to wait for before limits are lifted; see `RateLimiter::get_status` for
    /// exactly how it is derived.
    pub next_reset: Option<Instant>,
}
21
22pub struct RateLimiter {
23    buckets: HashMap<String, RateLimitBucket>,
24    global_reset: Option<Instant>,
25}
26
27impl RateLimiter {
28    pub fn new() -> Self {
29        Self {
30            buckets: HashMap::new(),
31            global_reset: None,
32        }
33    }
34
35    pub fn check_rate_limit(&mut self) -> Result<(), RoverApiError> {
36        let now = Instant::now();
37
38        // Check global rate limit
39        if let Some(global_reset) = self.global_reset {
40            if now < global_reset {
41                let retry_after = (global_reset - now).as_secs();
42                return Err(RoverApiError::RateLimit { retry_after });
43            } else {
44                self.global_reset = None;
45            }
46        }
47
48        // Clean up expired buckets
49        self.buckets.retain(|_, bucket| now < bucket.reset_at);
50
51        Ok(())
52    }
53
54    pub fn update_from_headers(&mut self, response: &Response) {
55        let headers = response.headers();
56        let now = Instant::now();
57
58        // Extract rate limit headers
59        let bucket = headers
60            .get("X-RateLimit-Bucket")
61            .and_then(|h| h.to_str().ok())
62            .map(|s| s.to_string());
63
64        let remaining = headers
65            .get("X-RateLimit-Remaining")
66            .and_then(|h| h.to_str().ok())
67            .and_then(|s| s.parse::<u32>().ok());
68
69        let reset_after = headers
70            .get("X-RateLimit-Reset-After")
71            .and_then(|h| h.to_str().ok())
72            .and_then(|s| s.parse::<f64>().ok())
73            .map(|secs| Duration::from_secs_f64(secs));
74
75        // Handle global rate limit
76        if response.status() == 429 {
77            if let Some(retry_after_header) = headers.get("Retry-After") {
78                if let Ok(retry_str) = retry_after_header.to_str() {
79                    if let Ok(retry_secs) = retry_str.parse::<u64>() {
80                        self.global_reset = Some(now + Duration::from_secs(retry_secs));
81                    }
82                }
83            }
84        }
85
86        // Update bucket information
87        if let (Some(bucket_name), Some(remaining_count), Some(reset_duration)) = 
88            (bucket, remaining, reset_after) {
89            let bucket = RateLimitBucket {
90                remaining: remaining_count,
91                reset_after: reset_duration,
92                reset_at: now + reset_duration,
93            };
94            self.buckets.insert(bucket_name, bucket);
95        }
96    }
97
98    pub fn get_status(&self) -> RateLimitStatus {
99        let now = Instant::now();
100        
101        // Find the most restrictive bucket
102        let most_restrictive = self.buckets
103            .iter()
104            .filter(|(_, bucket)| now < bucket.reset_at)
105            .min_by_key(|(_, bucket)| bucket.remaining);
106
107        if let Some((bucket_name, bucket)) = most_restrictive {
108            RateLimitStatus {
109                bucket: Some(bucket_name.clone()),
110                remaining: Some(bucket.remaining),
111                reset_after: Some(bucket.reset_after),
112                reset_at: Some(bucket.reset_at),
113                next_reset: Some(bucket.reset_at),
114            }
115        } else {
116            RateLimitStatus {
117                bucket: None,
118                remaining: None,
119                reset_after: None,
120                reset_at: None,
121                next_reset: self.global_reset,
122            }
123        }
124    }
125
126    /// Wait until rate limits are reset
127    pub async fn wait_for_reset(&self) {
128        let status = self.get_status();
129        if let Some(reset_time) = status.next_reset {
130            let now = Instant::now();
131            if now < reset_time {
132                let wait_duration = reset_time - now;
133                tokio::time::sleep(wait_duration).await;
134            }
135        }
136    }
137}
138
#[cfg(test)]
mod tests {
    // `tokio_test` was imported here but never used; `#[tokio::test]` only needs `tokio`.
    use super::*;

    #[tokio::test]
    async fn test_client_creation() {
        let client = RoverClient::new("test_key".to_string());
        assert!(!client.api_key.is_empty());
    }

    #[test]
    fn test_rate_limiter() {
        let mut limiter = RateLimiter::new();

        // Should not be rate limited initially
        assert!(limiter.check_rate_limit().is_ok());

        let status = limiter.get_status();
        assert!(status.bucket.is_none());
        assert!(status.remaining.is_none());
    }

    #[test]
    fn active_global_limit_blocks_requests() {
        let mut limiter = RateLimiter::new();
        limiter.global_reset = Some(Instant::now() + Duration::from_secs(60));

        assert!(matches!(
            limiter.check_rate_limit(),
            Err(RoverApiError::RateLimit { .. })
        ));
        // An active limit must not be cleared by checking it.
        assert!(limiter.global_reset.is_some());
    }

    #[test]
    fn expired_state_is_cleared() {
        let mut limiter = RateLimiter::new();
        // `Instant` is monotonic, so by the time the limiter reads the clock this has passed.
        let stale = Instant::now();
        limiter.global_reset = Some(stale);
        limiter.buckets.insert(
            "stale".to_string(),
            RateLimitBucket {
                remaining: 0,
                reset_after: Duration::ZERO,
                reset_at: stale,
            },
        );

        assert!(limiter.check_rate_limit().is_ok());
        assert!(limiter.global_reset.is_none());
        assert!(limiter.buckets.is_empty());
    }

    #[test]
    fn status_reports_most_restrictive_bucket() {
        let mut limiter = RateLimiter::new();
        let reset_at = Instant::now() + Duration::from_secs(60);
        limiter.buckets.insert(
            "roomy".to_string(),
            RateLimitBucket {
                remaining: 10,
                reset_after: Duration::from_secs(60),
                reset_at,
            },
        );
        limiter.buckets.insert(
            "tight".to_string(),
            RateLimitBucket {
                remaining: 1,
                reset_after: Duration::from_secs(60),
                reset_at,
            },
        );

        let status = limiter.get_status();
        assert_eq!(status.bucket.as_deref(), Some("tight"));
        assert_eq!(status.remaining, Some(1));
        assert_eq!(status.reset_at, Some(reset_at));
        assert_eq!(status.next_reset, Some(reset_at));
    }

    #[test]
    fn test_error_types() {
        let api_error = RoverApiError::Api {
            code: "bad_request".to_string(),
            message: "Invalid request".to_string(),
            detail: None,
            context: None,
        };

        assert!(matches!(api_error, RoverApiError::Api { .. }));

        let rate_limit_error = RoverApiError::RateLimit { retry_after: 60 };
        assert!(matches!(rate_limit_error, RoverApiError::RateLimit { .. }));
    }
}