Skip to main content

auth_framework/server/core/
common_http.rs

//! Common HTTP client utilities.
//!
//! This module provides shared HTTP client functionality (URL building, default
//! headers, timeouts and retries) so that server modules do not duplicate it.
5
6use crate::errors::{AuthError, Result};
7use crate::server::core::common_config::{EndpointConfig, RetryConfig};
8use reqwest::{Client, Method, RequestBuilder, Response};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::time::Duration;
12use tokio::time::{sleep, timeout};
13
14/// HTTP client wrapper with common functionality
15#[derive(Clone, Debug)]
16pub struct HttpClient {
17    client: Client,
18    config: EndpointConfig,
19    retry_config: RetryConfig,
20}
21
22impl HttpClient {
23    /// Create new HTTP client
24    pub fn new(config: EndpointConfig) -> Result<Self> {
25        let mut client_builder = Client::builder()
26            .timeout(Duration::from_secs(
27                config.timeout.connect_timeout.as_secs(),
28            ))
29            .connect_timeout(config.timeout.connect_timeout)
30            .danger_accept_invalid_certs(config.security.accept_invalid_certs);
31
32        // Add default headers
33        let mut headers = reqwest::header::HeaderMap::new();
34        for (key, value) in &config.headers {
35            let header_name = reqwest::header::HeaderName::from_bytes(key.as_bytes())
36                .map_err(|e| AuthError::config(format!("Invalid header name: {}", e)))?;
37            let header_value = reqwest::header::HeaderValue::from_str(value)
38                .map_err(|e| AuthError::config(format!("Invalid header value: {}", e)))?;
39            headers.insert(header_name, header_value);
40        }
41
42        if !headers.contains_key("user-agent") {
43            headers.insert(
44                reqwest::header::USER_AGENT,
45                reqwest::header::HeaderValue::from_static("auth-framework/0.3.0"),
46            );
47        }
48
49        client_builder = client_builder.default_headers(headers);
50
51        let client = client_builder
52            .build()
53            .map_err(|e| AuthError::config(format!("Failed to create HTTP client: {}", e)))?;
54
55        Ok(Self {
56            client,
57            config,
58            retry_config: RetryConfig::default(),
59        })
60    }
61
62    /// Set retry configuration
63    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
64        self.retry_config = retry_config;
65        self
66    }
67
68    /// Execute GET request with retries
69    pub async fn get(&self, path: &str) -> Result<Response> {
70        let url = self.build_url(path)?;
71        self.execute_with_retry(Method::GET, &url, None::<&()>)
72            .await
73    }
74
75    /// Create POST request builder (reqwest-compatible)
76    pub fn post(&self, url: &str) -> RequestBuilder {
77        self.client.post(url)
78    }
79
80    /// Execute POST request with JSON body
81    pub async fn post_json<T>(&self, path: &str, body: &T) -> Result<Response>
82    where
83        T: Serialize,
84    {
85        let url = self.build_url(path)?;
86        self.execute_with_retry(Method::POST, &url, Some(body))
87            .await
88    }
89
90    /// Execute PUT request with JSON body
91    pub async fn put_json<T>(&self, path: &str, body: &T) -> Result<Response>
92    where
93        T: Serialize,
94    {
95        let url = self.build_url(path)?;
96        self.execute_with_retry(Method::PUT, &url, Some(body)).await
97    }
98
99    /// Execute DELETE request
100    pub async fn delete(&self, path: &str) -> Result<Response> {
101        let url = self.build_url(path)?;
102        self.execute_with_retry(Method::DELETE, &url, None::<&()>)
103            .await
104    }
105
106    /// Execute form-encoded POST request
107    pub async fn post_form(
108        &self,
109        path: &str,
110        form_data: &HashMap<String, String>,
111    ) -> Result<Response> {
112        let url = self.build_url(path)?;
113
114        let mut request = self.client.request(Method::POST, &url);
115        request = request.form(form_data);
116
117        self.execute_request_with_retry(request).await
118    }
119
120    /// Execute request with custom headers
121    pub async fn request_with_headers<T>(
122        &self,
123        method: Method,
124        path: &str,
125        headers: HashMap<String, String>,
126        body: Option<&T>,
127    ) -> Result<Response>
128    where
129        T: Serialize,
130    {
131        let url = self.build_url(path)?;
132        let mut request = self.client.request(method, &url);
133
134        // Add custom headers
135        for (key, value) in headers {
136            request = request.header(key, value);
137        }
138
139        // Add body if provided
140        if let Some(body) = body {
141            request = request.json(body);
142        }
143
144        self.execute_request_with_retry(request).await
145    }
146
147    /// Build full URL from base and path
148    fn build_url(&self, path: &str) -> Result<String> {
149        let mut url = self.config.base_url.clone();
150
151        // Add API version if configured
152        if let Some(ref version) = self.config.api_version {
153            if !url.ends_with('/') {
154                url.push('/');
155            }
156            url.push_str(version);
157        }
158
159        // Add path
160        if !url.ends_with('/') && !path.starts_with('/') {
161            url.push('/');
162        }
163        url.push_str(path);
164
165        Ok(url)
166    }
167
168    /// Execute request with retry logic
169    async fn execute_with_retry<T>(
170        &self,
171        method: Method,
172        url: &str,
173        body: Option<&T>,
174    ) -> Result<Response>
175    where
176        T: Serialize,
177    {
178        let mut request = self.client.request(method, url);
179
180        if let Some(body) = body {
181            request = request.json(body);
182        }
183
184        self.execute_request_with_retry(request).await
185    }
186
187    /// Execute request with retry logic
188    async fn execute_request_with_retry(
189        &self,
190        request_builder: RequestBuilder,
191    ) -> Result<Response> {
192        let mut last_error = None;
193
194        for attempt in 0..=self.retry_config.max_attempts {
195            let request = request_builder
196                .try_clone()
197                .ok_or_else(|| AuthError::validation("Cannot clone request for retry"))?;
198
199            match timeout(self.config.timeout.read_timeout, request.send()).await {
200                Ok(Ok(response)) => {
201                    if response.status().is_success() || !self.is_retryable_error(&response) {
202                        return Ok(response);
203                    }
204                    last_error = Some(AuthError::validation(format!("HTTP {}", response.status())));
205                }
206                Ok(Err(e)) => {
207                    last_error = Some(AuthError::validation(format!("Request failed: {}", e)));
208                }
209                Err(_) => {
210                    last_error = Some(AuthError::validation("Request timeout"));
211                }
212            }
213
214            // Don't sleep after the last attempt
215            if attempt < self.retry_config.max_attempts {
216                let delay = self.calculate_retry_delay(attempt);
217                sleep(delay).await;
218            }
219        }
220
221        Err(last_error.unwrap_or_else(|| AuthError::validation("All retry attempts failed")))
222    }
223
224    /// Check if error is retryable
225    fn is_retryable_error(&self, response: &Response) -> bool {
226        match response.status().as_u16() {
227            // Retry on server errors and some client errors
228            500..=599 => true, // Server errors
229            429 => true,       // Rate limiting
230            408 => true,       // Request timeout
231            _ => false,
232        }
233    }
234
235    /// Calculate retry delay with exponential backoff and jitter
236    fn calculate_retry_delay(&self, attempt: u32) -> Duration {
237        let base_delay = self.retry_config.initial_delay.as_millis() as f64;
238        let backoff = self.retry_config.backoff_multiplier.powi(attempt as i32);
239        let delay_ms = (base_delay * backoff).min(self.retry_config.max_delay.as_millis() as f64);
240
241        // Add jitter
242        let jitter = delay_ms * self.retry_config.jitter_factor * (rand::random::<f64>() - 0.5);
243        let final_delay = (delay_ms + jitter).max(0.0) as u64;
244
245        Duration::from_millis(final_delay)
246    }
247}
248
249/// Common HTTP response handling utilities
250pub mod response {
251    use super::*;
252
253    /// Parse JSON response with error handling
254    pub async fn parse_json<T>(response: Response) -> Result<T>
255    where
256        T: for<'de> Deserialize<'de>,
257    {
258        if !response.status().is_success() {
259            let status = response.status();
260            let body = response
261                .text()
262                .await
263                .unwrap_or_else(|_| "Failed to read error response body".to_string());
264
265            return Err(AuthError::validation(format!("HTTP {} - {}", status, body)));
266        }
267
268        response
269            .json::<T>()
270            .await
271            .map_err(|e| AuthError::validation(format!("Failed to parse JSON response: {}", e)))
272    }
273
274    /// Extract response body as text
275    pub async fn extract_text(response: Response) -> Result<String> {
276        if !response.status().is_success() {
277            let status = response.status();
278            let body = response
279                .text()
280                .await
281                .unwrap_or_else(|_| "Failed to read error response body".to_string());
282
283            return Err(AuthError::validation(format!("HTTP {} - {}", status, body)));
284        }
285
286        response
287            .text()
288            .await
289            .map_err(|e| AuthError::validation(format!("Failed to read response body: {}", e)))
290    }
291
292    /// Check if response indicates success
293    pub fn is_success_status(status_code: u16) -> bool {
294        (200..300).contains(&status_code)
295    }
296
297    /// Extract error details from response
298    pub async fn extract_error_details(response: Response) -> (u16, String) {
299        let status = response.status().as_u16();
300        let body = response
301            .text()
302            .await
303            .unwrap_or_else(|_| "Unable to read response body".to_string());
304        (status, body)
305    }
306}
307
308/// OAuth-specific HTTP client utilities
309pub mod oauth {
310    use super::*;
311
312    /// Execute OAuth token exchange request
313    pub async fn token_exchange(
314        client: &HttpClient,
315        token_endpoint: &str,
316        params: &HashMap<String, String>,
317    ) -> Result<serde_json::Value> {
318        // Use relative path from base_url or full URL
319        let path = if token_endpoint.starts_with("http") {
320            // Override base_url for this request
321            return execute_absolute_url_form_post(client, token_endpoint, params).await;
322        } else {
323            token_endpoint
324        };
325
326        let response = client.post_form(path, params).await?;
327        response::parse_json(response).await
328    }
329
330    /// Execute introspection request
331    pub async fn introspect_token(
332        client: &HttpClient,
333        introspect_endpoint: &str,
334        token: &str,
335        client_id: Option<&str>,
336    ) -> Result<serde_json::Value> {
337        let mut params = HashMap::new();
338        params.insert("token".to_string(), token.to_string());
339
340        if let Some(client_id) = client_id {
341            params.insert("client_id".to_string(), client_id.to_string());
342        }
343
344        let response = client.post_form(introspect_endpoint, &params).await?;
345        response::parse_json(response).await
346    }
347
348    /// Execute JWKS fetch
349    pub async fn fetch_jwks(client: &HttpClient, jwks_uri: &str) -> Result<serde_json::Value> {
350        let response = client.get(jwks_uri).await?;
351        response::parse_json(response).await
352    }
353
354    /// Execute OAuth discovery request
355    pub async fn discover_configuration(
356        _client: &HttpClient,
357        issuer: &str,
358    ) -> Result<serde_json::Value> {
359        let discovery_url = format!(
360            "{}/.well-known/openid_configuration",
361            issuer.trim_end_matches('/')
362        );
363
364        // Create temporary client for absolute URL
365        let temp_config = EndpointConfig::new(&discovery_url);
366        let temp_client = HttpClient::new(temp_config)?;
367
368        let response = temp_client.get("").await?;
369        response::parse_json(response).await
370    }
371
372    /// Execute form POST to absolute URL
373    async fn execute_absolute_url_form_post(
374        _client: &HttpClient,
375        url: &str,
376        params: &HashMap<String, String>,
377    ) -> Result<serde_json::Value> {
378        // Create client for specific URL
379        let temp_config = EndpointConfig::new(url);
380        let temp_client = HttpClient::new(temp_config)?;
381
382        let response = temp_client.post_form("", params).await?;
383        response::parse_json(response).await
384    }
385}
386
387/// Webhook and callback utilities
388pub mod webhooks {
389    use super::*;
390
391    /// Send webhook notification
392    pub async fn send_webhook<T>(
393        client: &HttpClient,
394        webhook_url: &str,
395        payload: &T,
396        signature_key: Option<&str>,
397    ) -> Result<()>
398    where
399        T: Serialize,
400    {
401        let mut headers = HashMap::new();
402        headers.insert("Content-Type".to_string(), "application/json".to_string());
403
404        // Add signature if key provided
405        if let Some(key) = signature_key {
406            let payload_json = serde_json::to_string(payload).map_err(|e| {
407                AuthError::validation(format!("Failed to serialize payload: {}", e))
408            })?;
409            let signature = calculate_webhook_signature(&payload_json, key)?;
410            headers.insert("X-Webhook-Signature".to_string(), signature);
411        }
412
413        let response = client
414            .request_with_headers(Method::POST, webhook_url, headers, Some(payload))
415            .await?;
416
417        if !response.status().is_success() {
418            return Err(AuthError::validation(format!(
419                "Webhook failed: {}",
420                response.status()
421            )));
422        }
423
424        Ok(())
425    }
426
427    /// Calculate HMAC-SHA256 signature for webhook payload.
428    ///
429    /// Returns a signature in the form `sha256=<hex>` suitable for the
430    /// `X-Webhook-Signature` header, compatible with GitHub-style webhook
431    /// verification.
432    fn calculate_webhook_signature(payload: &str, key: &str) -> Result<String> {
433        use hmac::{Hmac, Mac};
434        use sha2::Sha256;
435
436        let mut mac = Hmac::<Sha256>::new_from_slice(key.as_bytes())
437            .map_err(|e| AuthError::crypto(format!("Invalid HMAC key: {}", e)))?;
438        mac.update(payload.as_bytes());
439        let result = mac.finalize().into_bytes();
440
441        Ok(format!("sha256={}", hex::encode(result)))
442    }
443}