1use std::time::Duration;
7
8use serde_json::Value;
9
10use rdap_types::error::{RdapError, Result};
11use rdap_security::SsrfGuard;
12
/// Tunable settings for a [`Fetcher`].
///
/// All fields are public so callers can use struct-update syntax on top of
/// [`FetcherConfig::default`].
#[derive(Debug, Clone)]
pub struct FetcherConfig {
    /// Request timeout handed to the HTTP client; also reported (in
    /// milliseconds) in timeout errors.
    pub timeout: Duration,
    /// Value sent as the `User-Agent` of every request.
    pub user_agent: String,
    /// Attempt budget consulted by `Fetcher::fetch` when deciding whether a
    /// retryable failure may be tried again.
    pub max_attempts: u32,
    /// Delay before the first retry; later retries double it.
    pub initial_backoff: Duration,
    /// Upper bound applied to every retry delay.
    pub max_backoff: Duration,
    /// Whether connections should be kept around for reuse; see
    /// `Fetcher::with_config` for how this is applied to the HTTP client.
    pub reuse_connections: bool,
    /// Cap on pooled idle connections kept per host.
    pub max_connections_per_host: usize,
}
31
32impl Default for FetcherConfig {
33 fn default() -> Self {
34 Self {
35 timeout: Duration::from_secs(10),
36 user_agent: format!(
37 "rdapify/{} (https://rdapify.com)",
38 env!("CARGO_PKG_VERSION")
39 ),
40 max_attempts: 3,
41 initial_backoff: Duration::from_millis(500),
42 max_backoff: Duration::from_secs(8),
43 reuse_connections: true,
44 max_connections_per_host: 10,
45 }
46 }
47}
48
49#[derive(Debug, Clone)]
51pub struct Fetcher {
52 client: reqwest::Client,
53 ssrf: SsrfGuard,
54 config: FetcherConfig,
55}
56
57impl Fetcher {
58 pub fn new(ssrf: SsrfGuard) -> Result<Self> {
60 Self::with_config(ssrf, FetcherConfig::default())
61 }
62
63 pub fn with_config(ssrf: SsrfGuard, config: FetcherConfig) -> Result<Self> {
65 let tcp_keepalive = if config.reuse_connections {
66 Some(Duration::from_secs(60))
67 } else {
68 None
69 };
70
71 let client = reqwest::Client::builder()
72 .timeout(config.timeout)
73 .user_agent(&config.user_agent)
74 .use_rustls_tls()
75 .gzip(true)
76 .tcp_keepalive(tcp_keepalive)
77 .pool_max_idle_per_host(config.max_connections_per_host)
78 .build()
79 .map_err(RdapError::Network)?;
80
81 Ok(Self {
82 client,
83 ssrf,
84 config,
85 })
86 }
87
88 pub async fn fetch(&self, url: &str) -> Result<Value> {
93 self.ssrf.validate(url)?;
94
95 let mut attempt = 0u32;
96 loop {
97 attempt += 1;
98 match self.do_fetch(url).await {
99 Ok(value) => return Ok(value),
100 Err(err) if attempt < self.config.max_attempts && is_retryable(&err) => {
101 let delay = backoff(
102 attempt,
103 self.config.initial_backoff,
104 self.config.max_backoff,
105 );
106 tokio::time::sleep(delay).await;
107 }
108 Err(err) => return Err(err),
109 }
110 }
111 }
112
113 async fn do_fetch(&self, url: &str) -> Result<Value> {
114 let response = self
115 .client
116 .get(url)
117 .header("Accept", "application/rdap+json, application/json")
118 .send()
119 .await
120 .map_err(|e| {
121 if e.is_timeout() {
122 RdapError::Timeout {
123 millis: self.config.timeout.as_millis() as u64,
124 url: url.to_string(),
125 }
126 } else {
127 RdapError::Network(e)
128 }
129 })?;
130
131 let status = response.status();
132
133 if !status.is_success() {
134 return Err(RdapError::HttpStatus {
135 status: status.as_u16(),
136 url: url.to_string(),
137 });
138 }
139
140 response
141 .json::<Value>()
142 .await
143 .map_err(|e| RdapError::ParseError {
144 reason: e.to_string(),
145 })
146 }
147
148 pub fn reqwest_client(&self) -> reqwest::Client {
150 self.client.clone()
151 }
152}
153
154fn is_retryable(err: &RdapError) -> bool {
157 match err {
158 RdapError::Network(_) | RdapError::Timeout { .. } => true,
159 RdapError::HttpStatus { status, .. } => {
160 matches!(status, 429 | 500 | 502 | 503 | 504)
161 }
162 _ => false,
163 }
164}
165
/// Computes the delay to wait after the given 1-based `attempt`.
///
/// The delay is `initial * 2^(attempt - 1)`, measured in whole milliseconds
/// and never larger than `max`. An `attempt` of `0` is treated like `1`.
///
/// All arithmetic saturates, so arbitrarily large attempt numbers or base
/// delays yield `max` instead of panicking (debug builds) or wrapping around
/// to a tiny or zero delay (release builds).
fn backoff(attempt: u32, initial: Duration, max: Duration) -> Duration {
    let exponent = attempt.saturating_sub(1);
    let factor = 2u128.saturating_pow(exponent);
    let millis = initial.as_millis().saturating_mul(factor);
    // Clamp into u64 range first so the narrowing cast below is lossless.
    let millis = millis.min(u128::from(u64::MAX)) as u64;
    Duration::from_millis(millis).min(max)
}
170
#[cfg(test)]
mod tests {
    use super::{backoff, is_retryable, Fetcher, FetcherConfig};
    use rdap_security::{SsrfConfig, SsrfGuard};
    use rdap_types::error::RdapError;
    use std::time::Duration;

    /// Builds the error the fetcher reports for a non-success HTTP status.
    fn http_status_error(status: u16) -> RdapError {
        RdapError::HttpStatus {
            status,
            url: "https://example.com/".to_string(),
        }
    }

    /// Fetcher with the SSRF guard switched off and a single attempt, so the
    /// mock-server tests can reach a local address and never sleep on retries.
    fn disabled_ssrf_fetcher() -> Fetcher {
        let guard_config = SsrfConfig {
            enabled: false,
            ..Default::default()
        };
        let fetcher_config = FetcherConfig {
            max_attempts: 1,
            ..Default::default()
        };

        Fetcher::with_config(SsrfGuard::with_config(guard_config), fetcher_config).unwrap()
    }

    #[test]
    fn backoff_grows_exponentially() {
        let base = Duration::from_millis(500);
        let cap = Duration::from_secs(8);

        // Expected delay in milliseconds for attempts 1 through 6.
        let expected_millis = [500u64, 1000, 2000, 4000, 8000, 8000];
        for (attempt, millis) in (1u32..).zip(expected_millis) {
            assert_eq!(backoff(attempt, base, cap), Duration::from_millis(millis));
        }
    }

    #[test]
    fn backoff_saturates_on_very_large_attempt() {
        let cap = Duration::from_secs(30);
        assert_eq!(backoff(64, Duration::from_millis(1), cap), cap);
    }

    #[test]
    fn retryable_http_statuses() {
        for code in [429u16, 500, 502, 503, 504] {
            assert!(is_retryable(&http_status_error(code)));
        }
    }

    #[test]
    fn non_retryable_http_statuses() {
        for code in [400u16, 401, 403, 404, 422] {
            assert!(!is_retryable(&http_status_error(code)));
        }
    }

    #[test]
    fn default_config_values() {
        let defaults = FetcherConfig::default();
        assert_eq!(defaults.timeout, Duration::from_secs(10));
        assert_eq!(defaults.max_attempts, 3);
        assert!(defaults.user_agent.starts_with("rdapify/"));
    }

    #[tokio::test]
    async fn fetch_rejects_ssrf_before_network() {
        let fetcher = Fetcher::new(SsrfGuard::new()).unwrap();
        let outcome = fetcher.fetch("https://192.168.1.1/rdap").await;
        assert!(matches!(outcome, Err(RdapError::SsrfBlocked { .. })));
    }

    #[tokio::test]
    async fn fetch_rejects_http_scheme() {
        let fetcher = Fetcher::new(SsrfGuard::new()).unwrap();
        let outcome = fetcher.fetch("http://example.com/rdap").await;
        assert!(matches!(outcome, Err(RdapError::InsecureScheme { .. })));
    }

    #[tokio::test]
    async fn fetch_returns_parsed_json_on_200() {
        let mut server = mockito::Server::new_async().await;
        let endpoint = server
            .mock("GET", "/rdap/domain")
            .with_status(200)
            .with_header("content-type", "application/rdap+json")
            .with_body(r#"{"objectClassName":"domain","ldhName":"EXAMPLE.COM"}"#)
            .create_async()
            .await;

        let target = format!("{}/rdap/domain", server.url());
        let fetcher = disabled_ssrf_fetcher();
        let body = fetcher.fetch(&target).await.unwrap();

        assert_eq!(body["ldhName"], "EXAMPLE.COM");
        endpoint.assert_async().await;
    }

    #[tokio::test]
    async fn fetch_returns_http_status_error_on_404() {
        let mut server = mockito::Server::new_async().await;
        let endpoint = server
            .mock("GET", "/rdap/missing")
            .with_status(404)
            .with_body("{}")
            .create_async()
            .await;

        let target = format!("{}/rdap/missing", server.url());
        let fetcher = disabled_ssrf_fetcher();
        let outcome = fetcher.fetch(&target).await;

        assert!(matches!(
            outcome,
            Err(RdapError::HttpStatus { status: 404, .. })
        ));
        endpoint.assert_async().await;
    }
}