Skip to main content

raps_kernel/
http.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2025 Dmytro Yemelianov
3
4//! HTTP client utilities
5//!
6//! Provides retry logic, timeouts, and HTTP client configuration.
7
8use anyhow::{Context, Result};
9use reqwest::Client;
10use std::time::Duration;
11use tokio::time::sleep;
12use url::Url;
13
/// Hosts that custom (user-supplied) API calls are permitted to target.
///
/// Every entry is an APS host. `is_allowed_url` accepts an exact match of one of
/// these hosts or any subdomain of one (a dot-separated prefix).
pub const ALLOWED_DOMAINS: &[&str] = &[
    "developer.api.autodesk.com",
    "api.userprofile.autodesk.com",
    "acc.autodesk.com",
    "developer.autodesk.com",
    "b360dm.autodesk.com",
    "cdn.derivative.autodesk.io",
];
23
24/// Check if a URL is allowed (belongs to an APS domain)
25///
26/// Returns true if the URL's host matches one of the allowed domains.
27/// Used for custom API calls to prevent credential leakage to external URLs.
28pub fn is_allowed_url(url: &str) -> bool {
29    match Url::parse(url) {
30        Ok(parsed) => {
31            if let Some(host) = parsed.host_str() {
32                // Check if host matches any allowed domain
33                ALLOWED_DOMAINS.iter().any(|domain| {
34                    host == *domain
35                        || (host.len() > domain.len()
36                            && host.ends_with(domain)
37                            && host.as_bytes()[host.len() - domain.len() - 1] == b'.')
38                })
39            } else {
40                false
41            }
42        }
43        Err(_) => false,
44    }
45}
46
/// Retry and timeout settings used to build a client and drive request retries.
///
/// Every duration-like field is expressed in whole seconds.
#[derive(Debug, Clone)]
pub struct HttpClientConfig {
    /// How many retries are allowed after the initial attempt
    pub max_retries: u32,
    /// Cap, in seconds, applied to the backoff or `Retry-After` wait before any jitter is added
    pub max_wait: u64,
    /// Starting delay, in seconds, that exponential backoff doubles from
    pub base_delay: u64,
    /// Overall per-request timeout, in seconds
    pub timeout: u64,
    /// Timeout for establishing a connection, in seconds
    pub connect_timeout: u64,
}

impl Default for HttpClientConfig {
    /// Three retries; a 1 s backoff base capped at 60 s; a 120 s request timeout;
    /// and a 30 s connect timeout.
    fn default() -> Self {
        HttpClientConfig {
            max_retries: 3,
            base_delay: 1,
            max_wait: 60,
            connect_timeout: 30,
            timeout: 120,
        }
    }
}
73
74impl HttpClientConfig {
75    /// Create HTTP client with configured timeouts
76    pub fn create_client(&self) -> Result<Client> {
77        Client::builder()
78            .timeout(Duration::from_secs(self.timeout))
79            .connect_timeout(Duration::from_secs(self.connect_timeout))
80            .build()
81            .context("Failed to create HTTP client")
82    }
83
84    /// Create HTTP client config from CLI flags and environment variables
85    /// Precedence: CLI flag > environment variable > default
86    pub fn from_cli_and_env(timeout_flag: Option<u64>) -> Self {
87        let timeout = timeout_flag
88            .or_else(|| {
89                std::env::var("RAPS_TIMEOUT")
90                    .ok()
91                    .and_then(|v| v.parse().ok())
92            })
93            .unwrap_or(120);
94
95        Self {
96            timeout,
97            ..Self::default()
98        }
99    }
100}
101
/// Report whether an HTTP status code should trigger a retry.
///
/// Retryable codes are 408 (request timeout), 429 (rate limited), and the server
/// errors 500, 502, 503 and 504. Every other code is treated as final, including
/// 501 and all other 4xx codes.
pub fn is_retryable_status(status: u16) -> bool {
    const RETRYABLE: [u16; 6] = [408, 429, 500, 502, 503, 504];
    RETRYABLE.contains(&status)
}
106
107/// Calculate retry delay from response headers or exponential backoff
108///
109/// Checks the `Retry-After` header first (seconds value), then falls back
110/// to exponential backoff with jitter.
111pub fn retry_delay_from_response(
112    response: &reqwest::Response,
113    attempt: u32,
114    config: &HttpClientConfig,
115) -> Duration {
116    if let Some(retry_after) = response.headers().get("retry-after")
117        && let Ok(secs) = retry_after.to_str().unwrap_or("").parse::<u64>()
118    {
119        return Duration::from_secs(secs.min(config.max_wait));
120    }
121    calculate_delay(attempt + 1, config.base_delay, config.max_wait)
122}
123
124/// Send HTTP request with automatic retry on 429/5xx and network errors
125///
126/// Inspects the HTTP response status code and retries on retryable status codes
127/// (408, 429, 5xx). Also respects the `Retry-After` header for rate limiting.
128///
129/// The closure should return a `reqwest::RequestBuilder` (not a future),
130/// which will be rebuilt on each retry attempt.
131pub async fn send_with_retry<F>(
132    config: &HttpClientConfig,
133    build_request: F,
134) -> Result<reqwest::Response>
135where
136    F: Fn() -> reqwest::RequestBuilder,
137{
138    let mut attempt = 0;
139    let mut total_network_time = std::time::Duration::ZERO;
140    loop {
141        let start = std::time::Instant::now();
142        match build_request().send().await {
143            Ok(response) => {
144                let elapsed = start.elapsed();
145                total_network_time += elapsed;
146                let status = response.status().as_u16();
147                tracing::debug!(
148                    http.status = status,
149                    url = %response.url(),
150                    elapsed_ms = elapsed.as_millis() as u64,
151                    "HTTP response"
152                );
153                if is_retryable_status(status) && attempt < config.max_retries {
154                    let delay = retry_delay_from_response(&response, attempt, config);
155                    attempt += 1;
156                    crate::profiler::record_http_retry();
157                    tracing::warn!(
158                        http.status = status,
159                        attempt,
160                        max_retries = config.max_retries,
161                        delay_secs = delay.as_secs_f64(),
162                        "Retryable HTTP status, retrying"
163                    );
164                    sleep(delay).await;
165                    continue;
166                }
167                crate::profiler::record_http_request(total_network_time);
168                crate::api_health::record_latency(total_network_time);
169                return Ok(response);
170            }
171            Err(err) => {
172                total_network_time += start.elapsed();
173                let retriable = err.is_timeout() || err.is_connect() || err.is_request();
174                if !retriable || attempt >= config.max_retries {
175                    crate::profiler::record_http_request(total_network_time);
176                    crate::api_health::record_failure();
177                    tracing::error!(error = %err, attempt, "HTTP request failed");
178                    return Err(err).context("HTTP request failed");
179                }
180                attempt += 1;
181                crate::profiler::record_http_retry();
182                let delay = calculate_delay(attempt, config.base_delay, config.max_wait);
183                tracing::warn!(
184                    error = %err,
185                    attempt,
186                    max_retries = config.max_retries,
187                    delay_secs = delay.as_secs_f64(),
188                    "Network error, retrying"
189                );
190                sleep(delay).await;
191            }
192        }
193    }
194}
195
196/// Calculate delay with exponential backoff and jitter
197fn calculate_delay(attempt: u32, base_delay: u64, max_wait: u64) -> Duration {
198    use rand::Rng;
199
200    // Exponential backoff: base_delay * 2^attempt (saturating to avoid overflow)
201    let exponential_delay =
202        base_delay.saturating_mul(1_u64.checked_shl(attempt).unwrap_or(u64::MAX));
203
204    // Cap at max_wait
205    let capped_delay = exponential_delay.min(max_wait);
206
207    // Add jitter (random 0-25% of delay)
208    let mut rng = rand::thread_rng();
209    let jitter = if capped_delay > 0 {
210        rng.gen_range(0..=(capped_delay / 4))
211    } else {
212        0
213    };
214
215    Duration::from_secs(capped_delay.saturating_add(jitter))
216}
217
218#[cfg(test)]
219mod tests {
220    use super::*;
221
222    #[test]
223    fn test_http_config_default() {
224        let config = HttpClientConfig::default();
225        assert_eq!(config.max_retries, 3);
226        assert_eq!(config.max_wait, 60);
227        assert_eq!(config.base_delay, 1);
228        assert_eq!(config.timeout, 120);
229        assert_eq!(config.connect_timeout, 30);
230    }
231
232    #[test]
233    fn test_http_config_create_client() {
234        let config = HttpClientConfig::default();
235        let client = config.create_client();
236        assert!(client.is_ok());
237    }
238
239    #[test]
240    fn test_http_config_from_cli_flag() {
241        let config = HttpClientConfig::from_cli_and_env(Some(60));
242        assert_eq!(config.timeout, 60);
243        // Other values should be default
244        assert_eq!(config.max_retries, 3);
245    }
246
247    #[test]
248    fn test_http_config_from_env() {
249        // SAFETY: Test runs with --test-threads=1 or in isolation
250        unsafe {
251            std::env::set_var("RAPS_TIMEOUT", "90");
252        }
253        let config = HttpClientConfig::from_cli_and_env(None);
254        assert_eq!(config.timeout, 90);
255        unsafe {
256            std::env::remove_var("RAPS_TIMEOUT");
257        }
258    }
259
260    #[test]
261    fn test_http_config_cli_overrides_env() {
262        // SAFETY: Test runs with --test-threads=1 or in isolation
263        unsafe {
264            std::env::set_var("RAPS_TIMEOUT", "90");
265        }
266        let config = HttpClientConfig::from_cli_and_env(Some(45));
267        assert_eq!(config.timeout, 45);
268        unsafe {
269            std::env::remove_var("RAPS_TIMEOUT");
270        }
271    }
272
273    #[test]
274    fn test_http_config_invalid_env() {
275        // SAFETY: Test runs with --test-threads=1 or in isolation
276        unsafe {
277            std::env::set_var("RAPS_TIMEOUT", "not_a_number");
278        }
279        let config = HttpClientConfig::from_cli_and_env(None);
280        assert_eq!(config.timeout, 120); // Falls back to default
281        unsafe {
282            std::env::remove_var("RAPS_TIMEOUT");
283        }
284    }
285
286    #[test]
287    fn test_calculate_delay_exponential() {
288        // First retry: base_delay * 2^1 = 1 * 2 = 2 seconds
289        let delay1 = calculate_delay(1, 1, 60);
290        assert!(delay1.as_secs() >= 2);
291        assert!(delay1.as_secs() <= 3); // 2 + up to 25% jitter
292
293        // Second retry: base_delay * 2^2 = 1 * 4 = 4 seconds
294        let delay2 = calculate_delay(2, 1, 60);
295        assert!(delay2.as_secs() >= 4);
296        assert!(delay2.as_secs() <= 5);
297    }
298
299    #[test]
300    fn test_calculate_delay_max_wait() {
301        // Very high attempt should be capped at max_wait
302        let delay = calculate_delay(10, 1, 60);
303        assert!(delay.as_secs() <= 75); // 60 + up to 25% jitter
304    }
305
306    #[test]
307    fn test_calculate_delay_custom_base() {
308        // With base_delay of 2: 2 * 2^1 = 4 seconds
309        let delay = calculate_delay(1, 2, 60);
310        assert!(delay.as_secs() >= 4);
311        assert!(delay.as_secs() <= 5);
312    }
313
314    #[test]
315    fn test_is_allowed_url_developer_api() {
316        assert!(is_allowed_url(
317            "https://developer.api.autodesk.com/oss/v2/buckets"
318        ));
319    }
320
321    #[test]
322    fn test_is_allowed_url_userprofile() {
323        assert!(is_allowed_url(
324            "https://api.userprofile.autodesk.com/userinfo"
325        ));
326    }
327
328    #[test]
329    fn test_is_allowed_url_acc() {
330        assert!(is_allowed_url("https://acc.autodesk.com/api/projects"));
331    }
332
333    #[test]
334    fn test_is_allowed_url_with_path_and_query() {
335        assert!(is_allowed_url(
336            "https://developer.api.autodesk.com/oss/v2/buckets?limit=10&region=US"
337        ));
338    }
339
340    #[test]
341    fn test_is_allowed_url_external_rejected() {
342        assert!(!is_allowed_url("https://evil.com/steal-token"));
343    }
344
345    #[test]
346    fn test_is_allowed_url_localhost_rejected() {
347        assert!(!is_allowed_url("http://localhost:8080/api"));
348    }
349
350    #[test]
351    fn test_is_allowed_url_internal_ip_rejected() {
352        assert!(!is_allowed_url("http://192.168.1.1/api"));
353    }
354
355    #[test]
356    fn test_is_allowed_url_similar_domain_rejected() {
357        // Should not allow fake domains that look similar
358        assert!(!is_allowed_url(
359            "https://developer.api.autodesk.com.evil.com/api"
360        ));
361    }
362
363    #[test]
364    fn test_is_allowed_url_invalid_url() {
365        assert!(!is_allowed_url("not-a-valid-url"));
366    }
367
368    #[test]
369    fn test_is_allowed_url_empty() {
370        assert!(!is_allowed_url(""));
371    }
372
373    #[test]
374    fn test_is_allowed_url_subdomain() {
375        // Subdomains of allowed domains should be allowed
376        assert!(is_allowed_url("https://us.developer.api.autodesk.com/api"));
377    }
378
379    #[test]
380    fn test_is_retryable_status_429() {
381        assert!(is_retryable_status(429));
382    }
383
384    #[test]
385    fn test_is_retryable_status_408() {
386        assert!(is_retryable_status(408));
387    }
388
389    #[test]
390    fn test_is_retryable_status_5xx() {
391        assert!(is_retryable_status(500));
392        assert!(is_retryable_status(502));
393        assert!(is_retryable_status(503));
394        assert!(is_retryable_status(504));
395    }
396
397    #[test]
398    fn test_is_retryable_status_not_retryable() {
399        assert!(!is_retryable_status(200));
400        assert!(!is_retryable_status(201));
401        assert!(!is_retryable_status(400));
402        assert!(!is_retryable_status(401));
403        assert!(!is_retryable_status(403));
404        assert!(!is_retryable_status(404));
405        assert!(!is_retryable_status(409));
406        assert!(!is_retryable_status(422));
407    }
408
409    /// Helper: bind a TCP listener on a random port and return (addr, listener)
410    fn bind_test_server() -> (std::net::SocketAddr, std::net::TcpListener) {
411        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
412        let addr = listener.local_addr().unwrap();
413        (addr, listener)
414    }
415
416    fn accept_and_respond(listener: &std::net::TcpListener, raw_response: &str) {
417        use std::io::{Read, Write};
418        let (mut stream, _) = listener.accept().unwrap();
419        // Drain the request
420        let mut buf = [0u8; 4096];
421        let _ = stream.read(&mut buf);
422        stream.write_all(raw_response.as_bytes()).unwrap();
423        stream.flush().unwrap();
424    }
425
426    #[tokio::test]
427    async fn test_retry_delay_from_response_with_retry_after_header() {
428        let (addr, listener) = bind_test_server();
429        let handle = std::thread::spawn(move || {
430            accept_and_respond(
431                &listener,
432                "HTTP/1.1 429 Too Many Requests\r\nRetry-After: 5\r\nContent-Length: 0\r\n\r\n",
433            );
434        });
435
436        let client = reqwest::Client::new();
437        let response = client.get(format!("http://{}", addr)).send().await.unwrap();
438        let config = HttpClientConfig::default();
439        let delay = retry_delay_from_response(&response, 0, &config);
440        assert_eq!(delay, Duration::from_secs(5));
441        handle.join().unwrap();
442    }
443
444    #[tokio::test]
445    async fn test_retry_delay_from_response_retry_after_capped_at_max_wait() {
446        let (addr, listener) = bind_test_server();
447        let handle = std::thread::spawn(move || {
448            accept_and_respond(
449                &listener,
450                "HTTP/1.1 429 Too Many Requests\r\nRetry-After: 300\r\nContent-Length: 0\r\n\r\n",
451            );
452        });
453
454        let client = reqwest::Client::new();
455        let response = client.get(format!("http://{}", addr)).send().await.unwrap();
456        let config = HttpClientConfig {
457            max_wait: 60,
458            ..Default::default()
459        };
460        let delay = retry_delay_from_response(&response, 0, &config);
461        assert_eq!(delay, Duration::from_secs(60));
462        handle.join().unwrap();
463    }
464
465    #[tokio::test]
466    async fn test_retry_delay_from_response_fallback_to_exponential() {
467        let (addr, listener) = bind_test_server();
468        let handle = std::thread::spawn(move || {
469            accept_and_respond(
470                &listener,
471                "HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\n\r\n",
472            );
473        });
474
475        let client = reqwest::Client::new();
476        let response = client.get(format!("http://{}", addr)).send().await.unwrap();
477        let config = HttpClientConfig::default();
478        // attempt=0 -> calculate_delay(1, 1, 60) -> 1*2^1 = 2s + jitter
479        let delay = retry_delay_from_response(&response, 0, &config);
480        assert!(delay.as_secs() >= 2);
481        assert!(delay.as_secs() <= 3);
482        handle.join().unwrap();
483    }
484
485    #[tokio::test]
486    async fn test_send_with_retry_success() {
487        let (addr, listener) = bind_test_server();
488        let handle = std::thread::spawn(move || {
489            accept_and_respond(&listener, "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK");
490        });
491
492        let config = HttpClientConfig::default();
493        let client = reqwest::Client::new();
494        let url = format!("http://{}", addr);
495
496        let response = send_with_retry(&config, || client.get(&url)).await;
497        assert!(response.is_ok());
498        assert_eq!(response.unwrap().status().as_u16(), 200);
499        handle.join().unwrap();
500    }
501}