rover_api/
rate_limiter.rs1use reqwest::Response;
2use std::collections::HashMap;
3use std::time::{Duration, Instant};
4use crate::{RoverApiError, RoverClient};
5
/// Last-known state of one server-side rate-limit bucket, recorded from the
/// `X-RateLimit-*` headers of a response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitBucket {
    /// Requests left in the current window (`X-RateLimit-Remaining`).
    pub remaining: u32,
    /// Window length reported by `X-RateLimit-Reset-After` when this entry was recorded.
    pub reset_after: Duration,
    /// Local deadline at which the window resets: time of recording plus `reset_after`.
    pub reset_at: Instant,
}
12
/// Snapshot of a rate limiter's state, as reported by `RateLimiter::get_status`.
///
/// The bucket-related fields are all `None` when no unexpired bucket is tracked;
/// `Default` yields a snapshot with every field `None`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RateLimitStatus {
    /// Name of the reported bucket, if any.
    pub bucket: Option<String>,
    /// Requests remaining in the reported bucket.
    pub remaining: Option<u32>,
    /// Window length last reported for the reported bucket.
    pub reset_after: Option<Duration>,
    /// Instant at which the reported bucket's window resets.
    pub reset_at: Option<Instant>,
    /// Instant `RateLimiter::wait_for_reset` sleeps until, if any.
    pub next_reset: Option<Instant>,
}
21
22pub struct RateLimiter {
23 buckets: HashMap<String, RateLimitBucket>,
24 global_reset: Option<Instant>,
25}
26
27impl RateLimiter {
28 pub fn new() -> Self {
29 Self {
30 buckets: HashMap::new(),
31 global_reset: None,
32 }
33 }
34
35 pub fn check_rate_limit(&mut self) -> Result<(), RoverApiError> {
36 let now = Instant::now();
37
38 if let Some(global_reset) = self.global_reset {
40 if now < global_reset {
41 let retry_after = (global_reset - now).as_secs();
42 return Err(RoverApiError::RateLimit { retry_after });
43 } else {
44 self.global_reset = None;
45 }
46 }
47
48 self.buckets.retain(|_, bucket| now < bucket.reset_at);
50
51 Ok(())
52 }
53
54 pub fn update_from_headers(&mut self, response: &Response) {
55 let headers = response.headers();
56 let now = Instant::now();
57
58 let bucket = headers
60 .get("X-RateLimit-Bucket")
61 .and_then(|h| h.to_str().ok())
62 .map(|s| s.to_string());
63
64 let remaining = headers
65 .get("X-RateLimit-Remaining")
66 .and_then(|h| h.to_str().ok())
67 .and_then(|s| s.parse::<u32>().ok());
68
69 let reset_after = headers
70 .get("X-RateLimit-Reset-After")
71 .and_then(|h| h.to_str().ok())
72 .and_then(|s| s.parse::<f64>().ok())
73 .map(|secs| Duration::from_secs_f64(secs));
74
75 if response.status() == 429 {
77 if let Some(retry_after_header) = headers.get("Retry-After") {
78 if let Ok(retry_str) = retry_after_header.to_str() {
79 if let Ok(retry_secs) = retry_str.parse::<u64>() {
80 self.global_reset = Some(now + Duration::from_secs(retry_secs));
81 }
82 }
83 }
84 }
85
86 if let (Some(bucket_name), Some(remaining_count), Some(reset_duration)) =
88 (bucket, remaining, reset_after) {
89 let bucket = RateLimitBucket {
90 remaining: remaining_count,
91 reset_after: reset_duration,
92 reset_at: now + reset_duration,
93 };
94 self.buckets.insert(bucket_name, bucket);
95 }
96 }
97
98 pub fn get_status(&self) -> RateLimitStatus {
99 let now = Instant::now();
100
101 let most_restrictive = self.buckets
103 .iter()
104 .filter(|(_, bucket)| now < bucket.reset_at)
105 .min_by_key(|(_, bucket)| bucket.remaining);
106
107 if let Some((bucket_name, bucket)) = most_restrictive {
108 RateLimitStatus {
109 bucket: Some(bucket_name.clone()),
110 remaining: Some(bucket.remaining),
111 reset_after: Some(bucket.reset_after),
112 reset_at: Some(bucket.reset_at),
113 next_reset: Some(bucket.reset_at),
114 }
115 } else {
116 RateLimitStatus {
117 bucket: None,
118 remaining: None,
119 reset_after: None,
120 reset_at: None,
121 next_reset: self.global_reset,
122 }
123 }
124 }
125
126 pub async fn wait_for_reset(&self) {
128 let status = self.get_status();
129 if let Some(reset_time) = status.next_reset {
130 let now = Instant::now();
131 if now < reset_time {
132 let wait_duration = reset_time - now;
133 tokio::time::sleep(wait_duration).await;
134 }
135 }
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bucket with the given remaining count and reset deadline.
    fn bucket(remaining: u32, reset_at: Instant) -> RateLimitBucket {
        RateLimitBucket {
            remaining,
            reset_after: Duration::from_secs(1),
            reset_at,
        }
    }

    #[tokio::test]
    async fn test_client_creation() {
        let client = RoverClient::new("test_key".to_string());
        assert!(!client.api_key.is_empty());
    }

    #[test]
    fn test_rate_limiter() {
        let mut limiter = RateLimiter::new();

        assert!(limiter.check_rate_limit().is_ok());

        let status = limiter.get_status();
        assert!(status.bucket.is_none());
        assert!(status.remaining.is_none());
    }

    #[test]
    fn active_global_reset_blocks_requests() {
        let mut limiter = RateLimiter::new();
        limiter.global_reset = Some(Instant::now() + Duration::from_secs(30));

        assert!(matches!(
            limiter.check_rate_limit(),
            Err(RoverApiError::RateLimit { .. })
        ));
        // Blocking must not discard the cooldown.
        assert!(limiter.global_reset.is_some());
    }

    #[test]
    fn expired_state_is_cleared() {
        let mut limiter = RateLimiter::new();
        // Captured before the check, so the check's own `now` is never earlier.
        let already_due = Instant::now();
        limiter.global_reset = Some(already_due);
        limiter.buckets.insert("stale".to_string(), bucket(0, already_due));

        assert!(limiter.check_rate_limit().is_ok());
        assert!(limiter.global_reset.is_none());
        assert!(limiter.buckets.is_empty());
    }

    #[test]
    fn status_reports_bucket_with_fewest_remaining() {
        let mut limiter = RateLimiter::new();
        let later = Instant::now() + Duration::from_secs(60);
        limiter.buckets.insert("roomy".to_string(), bucket(10, later));
        limiter.buckets.insert("tight".to_string(), bucket(1, later));

        let status = limiter.get_status();
        assert_eq!(status.bucket.as_deref(), Some("tight"));
        assert_eq!(status.remaining, Some(1));
        assert_eq!(status.reset_at, Some(later));
        assert_eq!(status.next_reset, Some(later));
    }

    #[test]
    fn test_error_types() {
        let api_error = RoverApiError::Api {
            code: "bad_request".to_string(),
            message: "Invalid request".to_string(),
            detail: None,
            context: None,
        };

        assert!(matches!(api_error, RoverApiError::Api { .. }));

        let rate_limit_error = RoverApiError::RateLimit { retry_after: 60 };
        assert!(matches!(rate_limit_error, RoverApiError::RateLimit { .. }));
    }
}