Skip to main content

lychee_lib/ratelimit/
config.rs

1use http::{HeaderMap, HeaderName, HeaderValue};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::collections::hash_map::Iter;
5use std::time::Duration;
6
7use crate::ratelimit::HostKey;
8
/// Default number of concurrent requests per host
///
/// Returned by `default_concurrency` and used by
/// `RateLimitConfig::from_options` when no explicit value is given.
const DEFAULT_CONCURRENCY: usize = 10;

/// Default interval between requests to the same host (50 milliseconds)
///
/// Returned by `default_request_interval` and used by
/// `RateLimitConfig::from_options` when no explicit value is given.
const DEFAULT_REQUEST_INTERVAL: Duration = Duration::from_millis(50);
14
/// Global rate limiting configuration that applies as defaults to all hosts
///
/// A [`HostConfig`] may override either value for a specific host; see
/// [`HostConfig::effective_concurrency`] and
/// [`HostConfig::effective_request_interval`] for the fallback logic.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Default maximum concurrent requests per host
    ///
    /// When the field is missing from the deserialized input, the value of
    /// `default_concurrency()` is used.
    #[serde(default = "default_concurrency")]
    pub concurrency: usize,

    /// Default minimum interval between requests to the same host
    ///
    /// When the field is missing from the deserialized input, the value of
    /// `default_request_interval()` is used. (De)serialization goes through
    /// the `humantime_serde` adapter.
    #[serde(default = "default_request_interval", with = "humantime_serde")]
    pub request_interval: Duration,
}
26
27impl Default for RateLimitConfig {
28    fn default() -> Self {
29        Self {
30            concurrency: default_concurrency(),
31            request_interval: default_request_interval(),
32        }
33    }
34}
35
/// Default number of concurrent requests per host
///
/// Returns [`DEFAULT_CONCURRENCY`]. Referenced by name from the
/// `#[serde(default = "default_concurrency")]` attribute on
/// [`RateLimitConfig::concurrency`] and called by `RateLimitConfig::default`.
const fn default_concurrency() -> usize {
    DEFAULT_CONCURRENCY
}
40
/// Default interval between requests to the same host
///
/// Returns [`DEFAULT_REQUEST_INTERVAL`]. Referenced by name from the
/// `#[serde(default = "default_request_interval", ..)]` attribute on
/// [`RateLimitConfig::request_interval`] and called by
/// `RateLimitConfig::default`.
const fn default_request_interval() -> Duration {
    DEFAULT_REQUEST_INTERVAL
}
45
46impl RateLimitConfig {
47    /// Create a `RateLimitConfig` from CLI options, using defaults for missing values
48    #[must_use]
49    pub fn from_options(concurrency: Option<usize>, request_interval: Option<Duration>) -> Self {
50        Self {
51            concurrency: concurrency.unwrap_or(DEFAULT_CONCURRENCY),
52            request_interval: request_interval.unwrap_or(DEFAULT_REQUEST_INTERVAL),
53        }
54    }
55}
56
/// Per-host configuration overrides
///
/// Newtype around a map from [`HostKey`] to the [`HostConfig`] holding the
/// overrides for that host. The derived `Default` is an empty map.
#[derive(Debug, Clone, Default, PartialEq, Deserialize)]
pub struct HostConfigs(HashMap<HostKey, HostConfig>);
60
61impl HostConfigs {
62    /// Get a reference to the [`HostConfig`] associated to the [`HostKey`]
63    pub(crate) fn get(&self, key: &HostKey) -> Option<&HostConfig> {
64        self.0.get(key)
65    }
66
67    /// Get the number of [`HostConfig`]s
68    #[must_use]
69    pub fn len(&self) -> usize {
70        self.0.len()
71    }
72
73    /// Returns `true` if if there are no [`HostConfig`]s
74    #[must_use]
75    pub fn is_empty(&self) -> bool {
76        self.0.is_empty()
77    }
78
79    /// Get the iterator over all elements
80    pub(crate) fn iter(&self) -> Iter<'_, HostKey, HostConfig> {
81        self.0.iter()
82    }
83
84    /// Merge `self` with another `HostConfigs`
85    #[must_use]
86    pub fn merge(mut self, other: HostConfigs) -> HostConfigs {
87        for (key, value) in other.0 {
88            let value = if let Some(s) = self.0.remove(&key) {
89                s.merge(value)
90            } else {
91                value
92            };
93
94            self.0.insert(key, value);
95        }
96
97        self
98    }
99}
100
101impl<'a> IntoIterator for &'a HostConfigs {
102    type Item = (&'a HostKey, &'a HostConfig);
103    type IntoIter = Iter<'a, HostKey, HostConfig>;
104    fn into_iter(self) -> Self::IntoIter {
105        self.0.iter()
106    }
107}
108
109impl<const N: usize> From<[(HostKey, HostConfig); N]> for HostConfigs {
110    fn from(arr: [(HostKey, HostConfig); N]) -> Self {
111        HostConfigs(HashMap::<HostKey, HostConfig>::from_iter(arr))
112    }
113}
114
115/// Configuration for a specific host's rate limiting behavior
116#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
117#[serde(deny_unknown_fields)]
118pub struct HostConfig {
119    /// Maximum concurrent requests allowed to this host
120    pub concurrency: Option<usize>,
121
122    /// Minimum interval between requests to this host
123    #[serde(default, with = "humantime_serde")]
124    pub request_interval: Option<Duration>,
125
126    /// Custom headers to send with requests to this host
127    #[serde(default)]
128    #[serde(deserialize_with = "deserialize_headers")]
129    #[serde(serialize_with = "serialize_headers")]
130    pub headers: HeaderMap,
131}
132
133impl Default for HostConfig {
134    fn default() -> Self {
135        Self {
136            concurrency: None,
137            request_interval: None,
138            headers: HeaderMap::new(),
139        }
140    }
141}
142
143impl HostConfig {
144    /// Get the effective maximum concurrency, falling back to the global default
145    #[must_use]
146    pub fn effective_concurrency(&self, global_config: &RateLimitConfig) -> usize {
147        self.concurrency.unwrap_or(global_config.concurrency)
148    }
149
150    /// Get the effective request interval, falling back to the global default
151    #[must_use]
152    pub fn effective_request_interval(&self, global_config: &RateLimitConfig) -> Duration {
153        self.request_interval
154            .unwrap_or(global_config.request_interval)
155    }
156
157    #[must_use]
158    pub(crate) fn merge(mut self, other: Self) -> Self {
159        for (k, v) in other.headers {
160            if let Some(k) = k {
161                self.headers.append(k, v);
162            }
163        }
164
165        Self {
166            concurrency: self.concurrency.or(other.concurrency),
167            request_interval: self.request_interval.or(other.request_interval),
168            headers: self.headers,
169        }
170    }
171}
172
173/// Custom deserializer for headers from TOML config format
174fn deserialize_headers<'de, D>(deserializer: D) -> Result<HeaderMap, D::Error>
175where
176    D: serde::Deserializer<'de>,
177{
178    let map = HashMap::<String, String>::deserialize(deserializer)?;
179    let mut header_map = HeaderMap::new();
180
181    for (name, value) in map {
182        let header_name = HeaderName::from_bytes(name.as_bytes())
183            .map_err(|e| serde::de::Error::custom(format!("Invalid header name '{name}': {e}")))?;
184        let header_value = HeaderValue::from_str(&value).map_err(|e| {
185            serde::de::Error::custom(format!("Invalid header value '{value}': {e}"))
186        })?;
187        header_map.insert(header_name, header_value);
188    }
189
190    Ok(header_map)
191}
192
193/// Custom serializer for headers to TOML config format
194fn serialize_headers<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
195where
196    S: serde::Serializer,
197{
198    let map: HashMap<String, String> = headers
199        .iter()
200        .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
201        .collect();
202    map.serialize(serializer)
203}
204
/// Unit tests for the rate limit configuration types
#[cfg(test)]
mod tests {
    use super::*;

    /// `RateLimitConfig::default()` exposes the module-level default values
    #[test]
    fn test_default_rate_limit_config() {
        let config = RateLimitConfig::default();
        assert_eq!(config.concurrency, 10);
        assert_eq!(config.request_interval, Duration::from_millis(50));
    }

    /// The `effective_*` getters use the host override when present and the
    /// global configuration otherwise
    #[test]
    fn test_host_config_effective_values() {
        let global_config = RateLimitConfig::default();

        // Test with no overrides
        let host_config = HostConfig::default();
        assert_eq!(host_config.effective_concurrency(&global_config), 10);
        assert_eq!(
            host_config.effective_request_interval(&global_config),
            Duration::from_millis(50)
        );

        // Test with overrides
        let host_config = HostConfig {
            concurrency: Some(5),
            request_interval: Some(Duration::from_millis(500)),
            headers: HeaderMap::new(),
        };
        assert_eq!(host_config.effective_concurrency(&global_config), 5);
        assert_eq!(
            host_config.effective_request_interval(&global_config),
            Duration::from_millis(500)
        );
    }

    /// A `RateLimitConfig` survives a TOML serialize/deserialize round trip
    #[test]
    fn test_config_serialization() {
        let config = RateLimitConfig {
            concurrency: 15,
            request_interval: Duration::from_millis(200),
        };

        let toml = toml::to_string(&config).unwrap();
        let deserialized: RateLimitConfig = toml::from_str(&toml).unwrap();

        // `RateLimitConfig` does not derive `PartialEq`, so compare field by field
        assert_eq!(config.concurrency, deserialized.concurrency);
        assert_eq!(config.request_interval, deserialized.request_interval);
    }

    /// A `HostConfig` with custom headers survives a TOML round trip
    #[test]
    fn test_headers_serialization() {
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer token123".parse().unwrap());
        headers.insert("User-Agent", "test-agent".parse().unwrap());

        let host_config = HostConfig {
            concurrency: Some(5),
            request_interval: Some(Duration::from_millis(500)),
            headers,
        };

        let toml = toml::to_string(&host_config).unwrap();
        let deserialized: HostConfig = toml::from_str(&toml).unwrap();

        assert_eq!(deserialized.concurrency, Some(5));
        assert_eq!(
            deserialized.request_interval,
            Some(Duration::from_millis(500))
        );
        assert_eq!(deserialized.headers.len(), 2);
        // Looked up with lowercase names
        assert!(deserialized.headers.contains_key("authorization"));
        assert!(deserialized.headers.contains_key("user-agent"));
    }
}