Skip to main content

lychee_lib/ratelimit/
config.rs

1use http::{HeaderMap, HeaderName, HeaderValue};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::time::Duration;
5
6use crate::ratelimit::HostKey;
7
/// How many requests may run concurrently against one host when no other
/// value is configured.
const DEFAULT_CONCURRENCY: usize = 10;

/// How long to wait between two requests to the same host when no other
/// value is configured.
const DEFAULT_REQUEST_INTERVAL: Duration = Duration::from_millis(50);
13
14/// Global rate limiting configuration that applies as defaults to all hosts
15#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
16pub struct RateLimitConfig {
17    /// Default maximum concurrent requests per host
18    #[serde(default = "default_concurrency")]
19    pub concurrency: usize,
20
21    /// Default minimum interval between requests to the same host
22    #[serde(default = "default_request_interval", with = "humantime_serde")]
23    pub request_interval: Duration,
24}
25
26impl Default for RateLimitConfig {
27    fn default() -> Self {
28        Self {
29            concurrency: default_concurrency(),
30            request_interval: default_request_interval(),
31        }
32    }
33}
34
35/// Default number of concurrent requests per host
36const fn default_concurrency() -> usize {
37    DEFAULT_CONCURRENCY
38}
39
40/// Default interval between requests to the same host
41const fn default_request_interval() -> Duration {
42    DEFAULT_REQUEST_INTERVAL
43}
44
45impl RateLimitConfig {
46    /// Create a `RateLimitConfig` from CLI options, using defaults for missing values
47    #[must_use]
48    pub fn from_options(concurrency: Option<usize>, request_interval: Option<Duration>) -> Self {
49        Self {
50            concurrency: concurrency.unwrap_or(DEFAULT_CONCURRENCY),
51            request_interval: request_interval.unwrap_or(DEFAULT_REQUEST_INTERVAL),
52        }
53    }
54}
55
56/// Per-host configuration overrides
57pub type HostConfigs = HashMap<HostKey, HostConfig>;
58
59/// Configuration for a specific host's rate limiting behavior
60#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
61#[serde(deny_unknown_fields)]
62pub struct HostConfig {
63    /// Maximum concurrent requests allowed to this host
64    pub concurrency: Option<usize>,
65
66    /// Minimum interval between requests to this host
67    #[serde(default, with = "humantime_serde")]
68    pub request_interval: Option<Duration>,
69
70    /// Custom headers to send with requests to this host
71    #[serde(default)]
72    #[serde(deserialize_with = "deserialize_headers")]
73    #[serde(serialize_with = "serialize_headers")]
74    pub headers: HeaderMap,
75}
76
77impl Default for HostConfig {
78    fn default() -> Self {
79        Self {
80            concurrency: None,
81            request_interval: None,
82            headers: HeaderMap::new(),
83        }
84    }
85}
86
87impl HostConfig {
88    /// Get the effective maximum concurrency, falling back to the global default
89    #[must_use]
90    pub fn effective_concurrency(&self, global_config: &RateLimitConfig) -> usize {
91        self.concurrency.unwrap_or(global_config.concurrency)
92    }
93
94    /// Get the effective request interval, falling back to the global default
95    #[must_use]
96    pub fn effective_request_interval(&self, global_config: &RateLimitConfig) -> Duration {
97        self.request_interval
98            .unwrap_or(global_config.request_interval)
99    }
100}
101
102/// Custom deserializer for headers from TOML config format
103fn deserialize_headers<'de, D>(deserializer: D) -> Result<HeaderMap, D::Error>
104where
105    D: serde::Deserializer<'de>,
106{
107    let map = HashMap::<String, String>::deserialize(deserializer)?;
108    let mut header_map = HeaderMap::new();
109
110    for (name, value) in map {
111        let header_name = HeaderName::from_bytes(name.as_bytes())
112            .map_err(|e| serde::de::Error::custom(format!("Invalid header name '{name}': {e}")))?;
113        let header_value = HeaderValue::from_str(&value).map_err(|e| {
114            serde::de::Error::custom(format!("Invalid header value '{value}': {e}"))
115        })?;
116        header_map.insert(header_name, header_value);
117    }
118
119    Ok(header_map)
120}
121
122/// Custom serializer for headers to TOML config format
123fn serialize_headers<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
124where
125    S: serde::Serializer,
126{
127    let map: HashMap<String, String> = headers
128        .iter()
129        .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
130        .collect();
131    map.serialize(serializer)
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    /// `Default` must expose the documented built-in values.
    #[test]
    fn test_default_rate_limit_config() {
        let config = RateLimitConfig::default();
        assert_eq!(config.concurrency, 10);
        assert_eq!(config.request_interval, Duration::from_millis(50));
    }

    /// `from_options` fills in defaults for `None` and keeps `Some` values.
    #[test]
    fn test_from_options() {
        let config = RateLimitConfig::from_options(None, None);
        assert_eq!(config.concurrency, 10);
        assert_eq!(config.request_interval, Duration::from_millis(50));

        let config = RateLimitConfig::from_options(Some(3), Some(Duration::from_secs(1)));
        assert_eq!(config.concurrency, 3);
        assert_eq!(config.request_interval, Duration::from_secs(1));
    }

    /// Host overrides win when present; otherwise the global values apply.
    #[test]
    fn test_host_config_effective_values() {
        let global_config = RateLimitConfig::default();

        // Test with no overrides
        let host_config = HostConfig::default();
        assert_eq!(host_config.effective_concurrency(&global_config), 10);
        assert_eq!(
            host_config.effective_request_interval(&global_config),
            Duration::from_millis(50)
        );

        // Test with overrides
        let host_config = HostConfig {
            concurrency: Some(5),
            request_interval: Some(Duration::from_millis(500)),
            headers: HeaderMap::new(),
        };
        assert_eq!(host_config.effective_concurrency(&global_config), 5);
        assert_eq!(
            host_config.effective_request_interval(&global_config),
            Duration::from_millis(500)
        );
    }

    /// A global config survives a TOML round trip unchanged.
    #[test]
    fn test_config_serialization() {
        let config = RateLimitConfig {
            concurrency: 15,
            request_interval: Duration::from_millis(200),
        };

        let toml = toml::to_string(&config).unwrap();
        let deserialized: RateLimitConfig = toml::from_str(&toml).unwrap();

        assert_eq!(config.concurrency, deserialized.concurrency);
        assert_eq!(config.request_interval, deserialized.request_interval);
    }

    /// A host config, including header names AND values, survives a TOML
    /// round trip unchanged.
    #[test]
    fn test_headers_serialization() {
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer token123".parse().unwrap());
        headers.insert("User-Agent", "test-agent".parse().unwrap());

        let host_config = HostConfig {
            concurrency: Some(5),
            request_interval: Some(Duration::from_millis(500)),
            headers,
        };

        let toml = toml::to_string(&host_config).unwrap();
        let deserialized: HostConfig = toml::from_str(&toml).unwrap();

        assert_eq!(deserialized.concurrency, Some(5));
        assert_eq!(
            deserialized.request_interval,
            Some(Duration::from_millis(500))
        );
        assert_eq!(deserialized.headers.len(), 2);
        assert!(deserialized.headers.contains_key("authorization"));
        assert!(deserialized.headers.contains_key("user-agent"));

        // Checking names alone would let a serializer that drops values pass.
        assert_eq!(deserialized.headers["authorization"], "Bearer token123");
        assert_eq!(deserialized.headers["user-agent"], "test-agent");
        assert_eq!(deserialized, host_config);
    }

    /// A header name containing a space is not a valid HTTP header name and
    /// must be reported as a deserialization error.
    #[test]
    fn test_invalid_header_name_is_rejected() {
        let toml = "[headers]\n\"bad name\" = \"value\"\n";
        assert!(toml::from_str::<HostConfig>(toml).is_err());
    }
}