lychee_lib/ratelimit/
config.rs1use http::{HeaderMap, HeaderName, HeaderValue};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::time::Duration;
5
6use crate::ratelimit::HostKey;
7
/// Concurrency limit used when none is configured.
const DEFAULT_CONCURRENCY: usize = 10;

/// Interval between requests (50 ms) used when none is configured.
const DEFAULT_REQUEST_INTERVAL: Duration = Duration::from_millis(50);

14#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
16pub struct RateLimitConfig {
17 #[serde(default = "default_concurrency")]
19 pub concurrency: usize,
20
21 #[serde(default = "default_request_interval", with = "humantime_serde")]
23 pub request_interval: Duration,
24}
25
26impl Default for RateLimitConfig {
27 fn default() -> Self {
28 Self {
29 concurrency: default_concurrency(),
30 request_interval: default_request_interval(),
31 }
32 }
33}
34
35const fn default_concurrency() -> usize {
37 DEFAULT_CONCURRENCY
38}
39
40const fn default_request_interval() -> Duration {
42 DEFAULT_REQUEST_INTERVAL
43}
44
45impl RateLimitConfig {
46 #[must_use]
48 pub fn from_options(concurrency: Option<usize>, request_interval: Option<Duration>) -> Self {
49 Self {
50 concurrency: concurrency.unwrap_or(DEFAULT_CONCURRENCY),
51 request_interval: request_interval.unwrap_or(DEFAULT_REQUEST_INTERVAL),
52 }
53 }
54}
55
56pub type HostConfigs = HashMap<HostKey, HostConfig>;
58
59#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
61#[serde(deny_unknown_fields)]
62pub struct HostConfig {
63 pub concurrency: Option<usize>,
65
66 #[serde(default, with = "humantime_serde")]
68 pub request_interval: Option<Duration>,
69
70 #[serde(default)]
72 #[serde(deserialize_with = "deserialize_headers")]
73 #[serde(serialize_with = "serialize_headers")]
74 pub headers: HeaderMap,
75}
76
77impl Default for HostConfig {
78 fn default() -> Self {
79 Self {
80 concurrency: None,
81 request_interval: None,
82 headers: HeaderMap::new(),
83 }
84 }
85}
86
87impl HostConfig {
88 #[must_use]
90 pub fn effective_concurrency(&self, global_config: &RateLimitConfig) -> usize {
91 self.concurrency.unwrap_or(global_config.concurrency)
92 }
93
94 #[must_use]
96 pub fn effective_request_interval(&self, global_config: &RateLimitConfig) -> Duration {
97 self.request_interval
98 .unwrap_or(global_config.request_interval)
99 }
100}
101
102fn deserialize_headers<'de, D>(deserializer: D) -> Result<HeaderMap, D::Error>
104where
105 D: serde::Deserializer<'de>,
106{
107 let map = HashMap::<String, String>::deserialize(deserializer)?;
108 let mut header_map = HeaderMap::new();
109
110 for (name, value) in map {
111 let header_name = HeaderName::from_bytes(name.as_bytes())
112 .map_err(|e| serde::de::Error::custom(format!("Invalid header name '{name}': {e}")))?;
113 let header_value = HeaderValue::from_str(&value).map_err(|e| {
114 serde::de::Error::custom(format!("Invalid header value '{value}': {e}"))
115 })?;
116 header_map.insert(header_name, header_value);
117 }
118
119 Ok(header_map)
120}
121
122fn serialize_headers<S>(headers: &HeaderMap, serializer: S) -> Result<S::Ok, S::Error>
124where
125 S: serde::Serializer,
126{
127 let map: HashMap<String, String> = headers
128 .iter()
129 .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
130 .collect();
131 map.serialize(serializer)
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_rate_limit_config() {
        let config = RateLimitConfig::default();
        assert_eq!(config.concurrency, 10);
        assert_eq!(config.request_interval, Duration::from_millis(50));
    }

    #[test]
    fn test_rate_limit_config_missing_fields_use_defaults() {
        let config: RateLimitConfig = toml::from_str("").unwrap();
        assert_eq!(config.concurrency, 10);
        assert_eq!(config.request_interval, Duration::from_millis(50));
    }

    #[test]
    fn test_host_config_effective_values() {
        let global_config = RateLimitConfig::default();

        // No overrides: the global values apply.
        let host_config = HostConfig::default();
        assert_eq!(host_config.effective_concurrency(&global_config), 10);
        assert_eq!(
            host_config.effective_request_interval(&global_config),
            Duration::from_millis(50)
        );

        // Explicit overrides win over the global values.
        let host_config = HostConfig {
            concurrency: Some(5),
            request_interval: Some(Duration::from_millis(500)),
            headers: HeaderMap::new(),
        };
        assert_eq!(host_config.effective_concurrency(&global_config), 5);
        assert_eq!(
            host_config.effective_request_interval(&global_config),
            Duration::from_millis(500)
        );
    }

    #[test]
    fn test_config_serialization() {
        let config = RateLimitConfig {
            concurrency: 15,
            request_interval: Duration::from_millis(200),
        };

        let toml = toml::to_string(&config).unwrap();
        let deserialized: RateLimitConfig = toml::from_str(&toml).unwrap();

        assert_eq!(config.concurrency, deserialized.concurrency);
        assert_eq!(config.request_interval, deserialized.request_interval);
    }

    #[test]
    fn test_host_config_missing_fields_use_defaults() {
        let config: HostConfig = toml::from_str("").unwrap();
        assert_eq!(config, HostConfig::default());
    }

    #[test]
    fn test_host_config_humantime_interval() {
        let config: HostConfig = toml::from_str("request_interval = \"1s\"").unwrap();
        assert_eq!(config.request_interval, Some(Duration::from_secs(1)));
    }

    #[test]
    fn test_host_config_rejects_unknown_fields() {
        assert!(toml::from_str::<HostConfig>("unknown_field = 1").is_err());
    }

    #[test]
    fn test_host_config_rejects_invalid_header_name() {
        let input = "[headers]\n\"bad header\" = \"value\"\n";
        assert!(toml::from_str::<HostConfig>(input).is_err());
    }

    #[test]
    fn test_host_config_rejects_invalid_header_value() {
        let input = "[headers]\nx-test = \"bad\\nvalue\"\n";
        assert!(toml::from_str::<HostConfig>(input).is_err());
    }

    #[test]
    fn test_headers_serialization() {
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", "Bearer token123".parse().unwrap());
        headers.insert("User-Agent", "test-agent".parse().unwrap());

        let host_config = HostConfig {
            concurrency: Some(5),
            request_interval: Some(Duration::from_millis(500)),
            headers,
        };

        let toml = toml::to_string(&host_config).unwrap();
        let deserialized: HostConfig = toml::from_str(&toml).unwrap();

        assert_eq!(deserialized.concurrency, Some(5));
        assert_eq!(
            deserialized.request_interval,
            Some(Duration::from_millis(500))
        );
        assert_eq!(deserialized.headers.len(), 2);
        assert!(deserialized.headers.contains_key("authorization"));
        assert!(deserialized.headers.contains_key("user-agent"));
    }
}