lychee_lib/ratelimit/
config.rs1use http::{HeaderMap, HeaderName, HeaderValue};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::collections::hash_map::Iter;
5use std::time::Duration;
6
7use crate::ratelimit::HostKey;
8
/// Fallback for `RateLimitConfig::concurrency` when no value is supplied.
const DEFAULT_CONCURRENCY: usize = 10;

/// Fallback for `RateLimitConfig::request_interval` when no value is
/// supplied (50 milliseconds).
const DEFAULT_REQUEST_INTERVAL: Duration = Duration::from_millis(50);
14
15#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
17pub struct RateLimitConfig {
18 #[serde(default = "default_concurrency")]
20 pub concurrency: usize,
21
22 #[serde(default = "default_request_interval", with = "humantime_serde")]
24 pub request_interval: Duration,
25}
26
27impl Default for RateLimitConfig {
28 fn default() -> Self {
29 Self {
30 concurrency: default_concurrency(),
31 request_interval: default_request_interval(),
32 }
33 }
34}
35
36const fn default_concurrency() -> usize {
38 DEFAULT_CONCURRENCY
39}
40
41const fn default_request_interval() -> Duration {
43 DEFAULT_REQUEST_INTERVAL
44}
45
46impl RateLimitConfig {
47 #[must_use]
49 pub fn from_options(concurrency: Option<usize>, request_interval: Option<Duration>) -> Self {
50 Self {
51 concurrency: concurrency.unwrap_or(DEFAULT_CONCURRENCY),
52 request_interval: request_interval.unwrap_or(DEFAULT_REQUEST_INTERVAL),
53 }
54 }
55}
56
57#[derive(Debug, Clone, Default, PartialEq, Deserialize)]
59pub struct HostConfigs(HashMap<HostKey, HostConfig>);
60
61impl HostConfigs {
62 pub(crate) fn get(&self, key: &HostKey) -> Option<&HostConfig> {
64 self.0.get(key)
65 }
66
67 #[must_use]
69 pub fn len(&self) -> usize {
70 self.0.len()
71 }
72
73 #[must_use]
75 pub fn is_empty(&self) -> bool {
76 self.0.is_empty()
77 }
78
79 pub(crate) fn iter(&self) -> Iter<'_, HostKey, HostConfig> {
81 self.0.iter()
82 }
83
84 #[must_use]
86 pub fn merge(mut self, other: HostConfigs) -> HostConfigs {
87 for (key, value) in other.0 {
88 let value = if let Some(s) = self.0.remove(&key) {
89 s.merge(value)
90 } else {
91 value
92 };
93
94 self.0.insert(key, value);
95 }
96
97 self
98 }
99}
100
101impl<'a> IntoIterator for &'a HostConfigs {
102 type Item = (&'a HostKey, &'a HostConfig);
103 type IntoIter = Iter<'a, HostKey, HostConfig>;
104 fn into_iter(self) -> Self::IntoIter {
105 self.0.iter()
106 }
107}
108
109impl<const N: usize> From<[(HostKey, HostConfig); N]> for HostConfigs {
110 fn from(arr: [(HostKey, HostConfig); N]) -> Self {
111 HostConfigs(HashMap::<HostKey, HostConfig>::from_iter(arr))
112 }
113}
114
115#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
117#[serde(deny_unknown_fields)]
118pub struct HostConfig {
119 pub concurrency: Option<usize>,
121
122 #[serde(default, with = "humantime_serde")]
124 pub request_interval: Option<Duration>,
125
126 #[serde(default)]
128 #[serde(deserialize_with = "deserialize_headers")]
129 #[serde(serialize_with = "serialize_headers")]
130 pub headers: HeaderMap,
131}
132
133impl Default for HostConfig {
134 fn default() -> Self {
135 Self {
136 concurrency: None,
137 request_interval: None,
138 headers: HeaderMap::new(),
139 }
140 }
141}
142
143impl HostConfig {
144 #[must_use]
146 pub fn effective_concurrency(&self, global_config: &RateLimitConfig) -> usize {
147 self.concurrency.unwrap_or(global_config.concurrency)
148 }
149
150 #[must_use]
152 pub fn effective_request_interval(&self, global_config: &RateLimitConfig) -> Duration {
153 self.request_interval
154 .unwrap_or(global_config.request_interval)
155 }
156
157 #[must_use]
158 pub(crate) fn merge(mut self, other: Self) -> Self {
159 for (k, v) in other.headers {
160 if let Some(k) = k {
161 self.headers.append(k, v);
162 }
163 }
164
165 Self {
166 concurrency: self.concurrency.or(other.concurrency),
167 request_interval: self.request_interval.or(other.request_interval),
168 headers: self.headers,
169 }
170 }
171}
172
173fn deserialize_headers<'de, D>(deserializer: D) -> Result<HeaderMap, D::Error>
175where
176 D: serde::Deserializer<'de>,
177{
178 let map = HashMap::<String, String>::deserialize(deserializer)?;
179 let mut header_map = HeaderMap::new();
180
181 for (name, value) in map {
182 let header_name = HeaderName::from_bytes(name.as_bytes())
183 .map_err(|e| serde::de::Error::custom(format!("Invalid header name '{name}': {e}")))?;
184 let header_value = HeaderValue::from_str(&value).map_err(|e| {
185 serde::de::Error::custom(format!("Invalid header value '{value}': {e}"))
186 })?;
187 header_map.insert(header_name, header_value);
188 }
189
190 Ok(header_map)
191}
192
193fn serialize_headers<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
195where
196 S: serde::Serializer,
197{
198 let map: HashMap<String, String> = headers
199 .iter()
200 .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
201 .collect();
202 map.serialize(serializer)
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// `RateLimitConfig::default()` carries the module defaults.
    #[test]
    fn test_default_rate_limit_config() {
        let RateLimitConfig {
            concurrency,
            request_interval,
        } = RateLimitConfig::default();

        assert_eq!(concurrency, 10);
        assert_eq!(request_interval, Duration::from_millis(50));
    }

    /// Host values win when set; otherwise the global values apply.
    #[test]
    fn test_host_config_effective_values() {
        let global = RateLimitConfig::default();

        // Nothing overridden: both lookups fall through to the global config.
        let unset = HostConfig::default();
        assert_eq!(unset.effective_concurrency(&global), 10);
        assert_eq!(
            unset.effective_request_interval(&global),
            Duration::from_millis(50)
        );

        // Both values overridden: the host-specific ones are returned.
        let overridden = HostConfig {
            concurrency: Some(5),
            request_interval: Some(Duration::from_millis(500)),
            headers: HeaderMap::new(),
        };
        assert_eq!(overridden.effective_concurrency(&global), 5);
        assert_eq!(
            overridden.effective_request_interval(&global),
            Duration::from_millis(500)
        );
    }

    /// A `RateLimitConfig` survives a TOML round trip unchanged.
    #[test]
    fn test_config_serialization() {
        let original = RateLimitConfig {
            concurrency: 15,
            request_interval: Duration::from_millis(200),
        };

        let encoded = toml::to_string(&original).unwrap();
        let decoded: RateLimitConfig = toml::from_str(&encoded).unwrap();

        assert_eq!(original.concurrency, decoded.concurrency);
        assert_eq!(original.request_interval, decoded.request_interval);
    }

    /// A `HostConfig` with headers survives a TOML round trip; the header
    /// names are looked up in lower case afterwards.
    #[test]
    fn test_headers_serialization() {
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer token123".parse().unwrap());
        headers.insert("User-Agent", "test-agent".parse().unwrap());

        let original = HostConfig {
            concurrency: Some(5),
            request_interval: Some(Duration::from_millis(500)),
            headers,
        };

        let encoded = toml::to_string(&original).unwrap();
        let decoded: HostConfig = toml::from_str(&encoded).unwrap();

        assert_eq!(decoded.concurrency, Some(5));
        assert_eq!(
            decoded.request_interval,
            Some(Duration::from_millis(500))
        );
        assert_eq!(decoded.headers.len(), 2);
        for expected in ["authorization", "user-agent"] {
            assert!(decoded.headers.contains_key(expected));
        }
    }
}
279}