// lychee_lib/ratelimit/host/stats.rs

1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use serde::Serialize;
5use serde::ser::SerializeStruct;
6
7/// A [`HashMap`] mapping hosts to their [`HostStats`]
8#[derive(Debug, Default, Serialize)]
9pub struct HostStatsMap(HashMap<String, HostStats>);
10
11impl HostStatsMap {
12    /// Sort host statistics by request count (descending order)
13    /// This matches the display order we want in the output
14    #[must_use]
15    pub fn sorted(&self) -> Vec<(String, HostStats)> {
16        let mut sorted_hosts: Vec<_> = self.0.clone().into_iter().collect();
17        sorted_hosts.sort_by_key(|(_, stats)| std::cmp::Reverse(stats.total_requests));
18        sorted_hosts
19    }
20}
21
22impl From<HashMap<String, HostStats>> for HostStatsMap {
23    fn from(value: HashMap<String, HostStats>) -> Self {
24        Self(value)
25    }
26}
27
/// Statistics collected for a single [`crate::ratelimit::Host`]
///
/// The counters are updated through [`HostStats::record_response`],
/// [`HostStats::record_cache_hit`] and [`HostStats::record_cache_miss`].
#[derive(Debug, Clone, Default)]
pub struct HostStats {
    /// How many requests were counted for this host
    /// (recorded responses plus cache hits)
    pub total_requests: u64,
    /// How many requests counted as successful
    /// (2xx responses plus cache hits)
    pub successful_requests: u64,
    /// How many responses signalled rate limiting (status 429)
    pub rate_limited: u64,
    /// How many responses were server errors (5xx)
    pub server_errors: u64,
    /// How many responses were client errors (4xx other than 429)
    pub client_errors: u64,
    /// When the most recent 2xx response was recorded
    pub last_success: Option<Instant>,
    /// When the most recent 429 response was recorded
    pub last_rate_limit: Option<Instant>,
    /// Duration of every recorded response, in recording order
    /// (basis for the median, average and latest request time)
    pub request_times: Vec<Duration>,
    /// How often each status code was recorded
    pub status_codes: HashMap<u16, u64>,
    /// How many lookups were answered from the cache
    pub cache_hits: u64,
    /// How many lookups were not found in the cache
    pub cache_misses: u64,
}
54
55impl HostStats {
56    /// Record a response with status code and request duration
57    pub fn record_response(&mut self, status_code: u16, request_time: Duration) {
58        self.total_requests += 1;
59
60        // Track status code
61        *self.status_codes.entry(status_code).or_insert(0) += 1;
62
63        // Categorize response
64        match status_code {
65            200..=299 => {
66                self.successful_requests += 1;
67                self.last_success = Some(Instant::now());
68            }
69            429 => {
70                self.rate_limited += 1;
71                self.last_rate_limit = Some(Instant::now());
72            }
73            400..=499 => {
74                self.client_errors += 1;
75            }
76            500..=599 => {
77                self.server_errors += 1;
78            }
79            _ => {} // Other status codes
80        }
81
82        self.request_times.push(request_time);
83    }
84
85    /// Get median request time
86    #[must_use]
87    pub fn median_request_time(&self) -> Option<Duration> {
88        if self.request_times.is_empty() {
89            return None;
90        }
91
92        let mut times = self.request_times.clone();
93        times.sort();
94        let mid = times.len() / 2;
95
96        if times.len().is_multiple_of(2) {
97            // Average of two middle values
98            Some((times[mid - 1] + times[mid]) / 2)
99        } else {
100            Some(times[mid])
101        }
102    }
103
104    /// Get error rate (percentage)
105    #[must_use]
106    pub fn error_rate(&self) -> f64 {
107        if self.total_requests == 0 {
108            return 0.0;
109        }
110        let errors = self.rate_limited + self.client_errors + self.server_errors;
111        #[allow(clippy::cast_precision_loss)]
112        let error_rate = errors as f64 / self.total_requests as f64;
113        error_rate * 100.0
114    }
115
116    /// Get the current success rate (0.0 to 1.0)
117    #[must_use]
118    pub fn success_rate(&self) -> f64 {
119        if self.total_requests == 0 {
120            1.0 // Assume success until proven otherwise
121        } else {
122            #[allow(clippy::cast_precision_loss)]
123            let success_rate = self.successful_requests as f64 / self.total_requests as f64;
124            success_rate
125        }
126    }
127
128    /// Get average request time
129    #[must_use]
130    pub fn average_request_time(&self) -> Option<Duration> {
131        if self.request_times.is_empty() {
132            return None;
133        }
134
135        let total: Duration = self.request_times.iter().sum();
136        #[allow(clippy::cast_possible_truncation)]
137        Some(total / (self.request_times.len() as u32))
138    }
139
140    /// Get the most recent request time
141    #[must_use]
142    pub fn latest_request_time(&self) -> Option<Duration> {
143        self.request_times.iter().last().copied()
144    }
145
146    /// Check if this host has been experiencing rate limiting recently
147    #[must_use]
148    pub fn is_currently_rate_limited(&self) -> bool {
149        if let Some(last_rate_limit) = self.last_rate_limit {
150            // Consider rate limited if we got a 429 in the last 60 seconds
151            last_rate_limit.elapsed() < Duration::from_secs(60)
152        } else {
153            false
154        }
155    }
156
157    /// Record a cache hit
158    pub const fn record_cache_hit(&mut self) {
159        self.cache_hits += 1;
160        // Cache hits should also count as total requests from user perspective
161        self.total_requests += 1;
162        // Cache hits are typically for successful previous requests, so count as successful
163        self.successful_requests += 1;
164    }
165
166    /// Record a cache miss
167    pub const fn record_cache_miss(&mut self) {
168        self.cache_misses += 1;
169        // Cache misses will be followed by actual requests that increment total_requests
170        // so we don't increment here to avoid double-counting
171    }
172
173    /// Get cache hit rate (0.0 to 1.0)
174    #[must_use]
175    pub fn cache_hit_rate(&self) -> f64 {
176        let total_cache_requests = self.cache_hits + self.cache_misses;
177        if total_cache_requests == 0 {
178            0.0
179        } else {
180            #[allow(clippy::cast_precision_loss)]
181            let hit_rate = self.cache_hits as f64 / total_cache_requests as f64;
182            hit_rate
183        }
184    }
185
186    /// Get human-readable summary of the stats
187    #[must_use]
188    pub fn summary(&self) -> String {
189        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
190        let success_pct = (self.success_rate() * 100.0) as u64;
191        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
192        let error_pct = self.error_rate() as u64;
193
194        let avg_time = self
195            .average_request_time()
196            .map_or_else(|| "N/A".to_string(), |d| format!("{:.0}ms", d.as_millis()));
197
198        format!(
199            "{} requests ({}% success, {}% errors), avg: {}",
200            self.total_requests, success_pct, error_pct, avg_time
201        )
202    }
203}
204
205impl Serialize for HostStats {
206    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
207    where
208        S: serde::Serializer,
209    {
210        let median_request_time_ms = self.median_request_time().map(|d| d.as_millis());
211
212        let mut s = serializer.serialize_struct("HostStats", 11)?;
213        s.serialize_field("total_requests", &self.total_requests)?;
214        s.serialize_field("successful_requests", &self.successful_requests)?;
215        s.serialize_field("success_rate", &self.success_rate())?;
216        s.serialize_field("rate_limited", &self.rate_limited)?;
217        s.serialize_field("client_errors", &self.client_errors)?;
218        s.serialize_field("server_errors", &self.server_errors)?;
219        s.serialize_field("median_request_time_ms", &median_request_time_ms)?;
220        s.serialize_field("cache_hits", &self.cache_hits)?;
221        s.serialize_field("cache_misses", &self.cache_misses)?;
222        s.serialize_field("cache_hit_rate", &self.cache_hit_rate())?;
223        s.serialize_field("status_codes", &self.status_codes)?;
224        s.end()
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Build a [`Duration`] from whole milliseconds
    fn ms(millis: u64) -> Duration {
        Duration::from_millis(millis)
    }

    /// Assert that `actual` differs from `expected` by less than `tolerance`
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_host_stats_success_rate() {
        let mut stats = HostStats::default();

        // Without any request the host is assumed to be healthy
        assert_close(stats.success_rate(), 1.0, f64::EPSILON);

        // Two successful responses keep the rate at 100%
        for millis in [100, 120] {
            stats.record_response(200, ms(millis));
        }
        assert_close(stats.success_rate(), 1.0, f64::EPSILON);

        // A rate limited response drops it to two out of three
        stats.record_response(429, ms(150));
        assert_close(stats.success_rate(), 2.0 / 3.0, 0.001);

        // A server error drops it to two out of four
        stats.record_response(500, ms(200));
        assert_close(stats.success_rate(), 0.5, f64::EPSILON);
    }

    #[test]
    fn test_host_stats_tracking() {
        let mut stats = HostStats::default();

        // Fresh stats: nothing counted, no errors
        assert_eq!(stats.total_requests, 0);
        assert_eq!(stats.successful_requests, 0);
        assert_close(stats.error_rate(), 0.0, f64::EPSILON);

        // One successful response
        stats.record_response(200, ms(100));
        assert_eq!(stats.total_requests, 1);
        assert_eq!(stats.successful_requests, 1);
        assert_close(stats.error_rate(), 0.0, f64::EPSILON);
        assert_eq!(stats.status_codes.get(&200), Some(&1));

        // One rate limited response: half of all requests failed
        stats.record_response(429, ms(200));
        assert_eq!(stats.total_requests, 2);
        assert_eq!(stats.rate_limited, 1);
        assert_close(stats.error_rate(), 50.0, f64::EPSILON);

        // One server error
        stats.record_response(500, ms(150));
        assert_eq!(stats.total_requests, 3);
        assert_eq!(stats.server_errors, 1);

        // Median of 100ms, 150ms and 200ms
        let median = stats.median_request_time();
        assert_eq!(median, Some(ms(150)));
    }

    #[test]
    fn test_summary_formatting() {
        let mut stats = HostStats::default();
        stats.record_response(200, ms(150));
        stats.record_response(500, ms(200));

        let summary = stats.summary();
        let expected_fragments = [
            "2 requests",
            "50% success",
            "50% errors",
            "175ms", // average of 150 and 200
        ];
        for fragment in expected_fragments {
            assert!(summary.contains(fragment));
        }
    }
}