llm_registry_api/
middleware.rs

//! API middleware
//!
//! This module provides middleware layers and their configuration for request
//! processing: request tracing, CORS, and UUID-based request ID generation, plus
//! configuration flags for compression and request timeouts.
5
6use axum::http::{HeaderValue, Method, Request};
7use tower_http::{
8    cors::{Any, CorsLayer},
9    request_id::{MakeRequestId, RequestId},
10    trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer},
11    LatencyUnit,
12};
13use tracing::Level;
14use uuid::Uuid;
15
16/// Request ID generator using UUIDs
17#[derive(Clone, Default)]
18pub struct UuidRequestIdGenerator;
19
20impl MakeRequestId for UuidRequestIdGenerator {
21    fn make_request_id<B>(&mut self, _request: &Request<B>) -> Option<RequestId> {
22        let request_id = Uuid::new_v4().to_string();
23        Some(RequestId::new(
24            HeaderValue::from_str(&request_id).unwrap(),
25        ))
26    }
27}
28
29/// Build trace layer
30pub fn trace_layer() -> TraceLayer<tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>> {
31    TraceLayer::new_for_http()
32        .make_span_with(
33            DefaultMakeSpan::new()
34                .include_headers(true)
35                .level(Level::INFO),
36        )
37        .on_response(
38            DefaultOnResponse::new()
39                .include_headers(true)
40                .latency_unit(LatencyUnit::Millis)
41                .level(Level::INFO),
42        )
43}
44
45/// Build CORS layer
46pub fn cors_layer() -> CorsLayer {
47    CorsLayer::new()
48        // Allow requests from any origin
49        // In production, configure this based on environment
50        .allow_origin(Any)
51        // Allow common HTTP methods
52        .allow_methods([
53            Method::GET,
54            Method::POST,
55            Method::PUT,
56            Method::PATCH,
57            Method::DELETE,
58            Method::OPTIONS,
59        ])
60        // Allow common headers
61        .allow_headers(Any)
62        // Expose request ID header
63        .expose_headers([
64            axum::http::header::CONTENT_TYPE,
65            axum::http::header::HeaderName::from_static("x-request-id"),
66        ])
67        // Allow credentials
68        .allow_credentials(false)
69}
70
/// Options used to build a CORS layer via [`CorsConfig::into_layer`].
#[derive(Debug, Clone)]
pub struct CorsConfig {
    /// Origins permitted to make cross-origin requests; an empty list means any origin.
    pub allowed_origins: Vec<String>,

    /// Whether credentialed cross-origin requests are allowed.
    pub allow_credentials: bool,

    /// How long, in seconds, browsers may cache preflight responses; `None` leaves it unset.
    pub max_age_seconds: Option<u64>,
}

impl Default for CorsConfig {
    /// Any origin, no credentials, and a one-hour preflight cache.
    fn default() -> Self {
        const ONE_HOUR_SECS: u64 = 60 * 60;

        Self {
            max_age_seconds: Some(ONE_HOUR_SECS),
            allow_credentials: false,
            allowed_origins: Vec::new(),
        }
    }
}
93
94impl CorsConfig {
95    /// Build CORS layer from config
96    pub fn into_layer(self) -> CorsLayer {
97        let mut layer = CorsLayer::new()
98            .allow_methods([
99                Method::GET,
100                Method::POST,
101                Method::PUT,
102                Method::PATCH,
103                Method::DELETE,
104                Method::OPTIONS,
105            ])
106            .allow_headers(Any)
107            .expose_headers([
108                axum::http::header::CONTENT_TYPE,
109                axum::http::header::HeaderName::from_static("x-request-id"),
110            ])
111            .allow_credentials(self.allow_credentials);
112
113        // Configure origins
114        if self.allowed_origins.is_empty() {
115            layer = layer.allow_origin(Any);
116        } else {
117            // Parse origins
118            let origins: Vec<HeaderValue> = self
119                .allowed_origins
120                .iter()
121                .filter_map(|o| o.parse().ok())
122                .collect();
123            layer = layer.allow_origin(origins);
124        }
125
126        // Configure max age
127        if let Some(max_age) = self.max_age_seconds {
128            layer = layer.max_age(std::time::Duration::from_secs(max_age));
129        }
130
131        layer
132    }
133}
134
135/// Middleware configuration
136#[derive(Debug, Clone)]
137pub struct MiddlewareConfig {
138    /// CORS configuration
139    pub cors: CorsConfig,
140
141    /// Enable compression
142    pub enable_compression: bool,
143
144    /// Enable request tracing
145    pub enable_tracing: bool,
146
147    /// Request timeout in seconds
148    pub request_timeout_seconds: Option<u64>,
149}
150
151impl Default for MiddlewareConfig {
152    fn default() -> Self {
153        Self {
154            cors: CorsConfig::default(),
155            enable_compression: true,
156            enable_tracing: true,
157            request_timeout_seconds: Some(30),
158        }
159    }
160}
161
162impl MiddlewareConfig {
163    /// Create a new middleware config
164    pub fn new() -> Self {
165        Self::default()
166    }
167
168    /// Set CORS config
169    pub fn with_cors(mut self, cors: CorsConfig) -> Self {
170        self.cors = cors;
171        self
172    }
173
174    /// Enable/disable compression
175    pub fn with_compression(mut self, enable: bool) -> Self {
176        self.enable_compression = enable;
177        self
178    }
179
180    /// Enable/disable tracing
181    pub fn with_tracing(mut self, enable: bool) -> Self {
182        self.enable_tracing = enable;
183        self
184    }
185
186    /// Set request timeout
187    pub fn with_timeout(mut self, timeout_seconds: u64) -> Self {
188        self.request_timeout_seconds = Some(timeout_seconds);
189        self
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_uuid_request_id_generator() {
        let mut generator = UuidRequestIdGenerator;
        let request = Request::new(());

        let first = generator
            .make_request_id(&request)
            .expect("generator should always produce a request id");
        let second = generator
            .make_request_id(&request)
            .expect("generator should always produce a request id");

        // The header value must be a version-4 UUID string.
        let text = first
            .header_value()
            .to_str()
            .expect("request id should be visible ASCII");
        let parsed = Uuid::parse_str(text).expect("request id should parse as a UUID");
        assert_eq!(parsed.get_version_num(), 4);

        // Each call generates a fresh ID.
        assert_ne!(first.header_value(), second.header_value());
    }

    #[test]
    fn test_cors_config_default() {
        let config = CorsConfig::default();
        assert!(config.allowed_origins.is_empty());
        assert!(!config.allow_credentials);
        assert_eq!(config.max_age_seconds, Some(3600));
    }

    #[test]
    fn test_middleware_config_default() {
        let config = MiddlewareConfig::default();
        assert!(config.enable_compression);
        assert!(config.enable_tracing);
        assert_eq!(config.request_timeout_seconds, Some(30));
    }

    #[test]
    fn test_middleware_config_builder() {
        let config = MiddlewareConfig::new()
            .with_compression(false)
            .with_tracing(false)
            .with_timeout(60);

        assert!(!config.enable_compression);
        assert!(!config.enable_tracing);
        assert_eq!(config.request_timeout_seconds, Some(60));
    }

    #[test]
    fn test_middleware_config_with_cors() {
        let cors = CorsConfig {
            allowed_origins: vec!["https://example.com".to_string()],
            allow_credentials: true,
            max_age_seconds: None,
        };
        let config = MiddlewareConfig::new().with_cors(cors);

        assert_eq!(
            config.cors.allowed_origins,
            vec!["https://example.com".to_string()]
        );
        assert!(config.cors.allow_credentials);
        assert_eq!(config.cors.max_age_seconds, None);
    }
}