// llm_registry_api/middleware.rs
use axum::http::{HeaderValue, Method, Request};
use tower_http::{
    cors::{AllowHeaders, AllowOrigin, Any, CorsLayer},
    request_id::{MakeRequestId, RequestId},
    trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer},
    LatencyUnit,
};
use tracing::Level;
use uuid::Uuid;
15
16#[derive(Clone, Default)]
18pub struct UuidRequestIdGenerator;
19
20impl MakeRequestId for UuidRequestIdGenerator {
21 fn make_request_id<B>(&mut self, _request: &Request<B>) -> Option<RequestId> {
22 let request_id = Uuid::new_v4().to_string();
23 Some(RequestId::new(
24 HeaderValue::from_str(&request_id).unwrap(),
25 ))
26 }
27}
28
29pub fn trace_layer() -> TraceLayer<tower_http::classify::SharedClassifier<tower_http::classify::ServerErrorsAsFailures>> {
31 TraceLayer::new_for_http()
32 .make_span_with(
33 DefaultMakeSpan::new()
34 .include_headers(true)
35 .level(Level::INFO),
36 )
37 .on_response(
38 DefaultOnResponse::new()
39 .include_headers(true)
40 .latency_unit(LatencyUnit::Millis)
41 .level(Level::INFO),
42 )
43}
44
45pub fn cors_layer() -> CorsLayer {
47 CorsLayer::new()
48 .allow_origin(Any)
51 .allow_methods([
53 Method::GET,
54 Method::POST,
55 Method::PUT,
56 Method::PATCH,
57 Method::DELETE,
58 Method::OPTIONS,
59 ])
60 .allow_headers(Any)
62 .expose_headers([
64 axum::http::header::CONTENT_TYPE,
65 axum::http::header::HeaderName::from_static("x-request-id"),
66 ])
67 .allow_credentials(false)
69}
70
/// Declarative CORS settings, turned into a `CorsLayer` by `CorsConfig::into_layer`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CorsConfig {
    /// Origins allowed to make cross-origin requests, e.g. `"https://app.example.com"`.
    /// An empty list allows any origin.
    pub allowed_origins: Vec<String>,

    /// Whether cross-origin requests may include credentials
    /// (`Access-Control-Allow-Credentials`).
    pub allow_credentials: bool,

    /// How long, in seconds, browsers may cache preflight responses.
    /// `None` means no `Access-Control-Max-Age` header is sent.
    pub max_age_seconds: Option<u64>,
}
83
84impl Default for CorsConfig {
85 fn default() -> Self {
86 Self {
87 allowed_origins: vec![],
88 allow_credentials: false,
89 max_age_seconds: Some(3600),
90 }
91 }
92}
93
94impl CorsConfig {
95 pub fn into_layer(self) -> CorsLayer {
97 let mut layer = CorsLayer::new()
98 .allow_methods([
99 Method::GET,
100 Method::POST,
101 Method::PUT,
102 Method::PATCH,
103 Method::DELETE,
104 Method::OPTIONS,
105 ])
106 .allow_headers(Any)
107 .expose_headers([
108 axum::http::header::CONTENT_TYPE,
109 axum::http::header::HeaderName::from_static("x-request-id"),
110 ])
111 .allow_credentials(self.allow_credentials);
112
113 if self.allowed_origins.is_empty() {
115 layer = layer.allow_origin(Any);
116 } else {
117 let origins: Vec<HeaderValue> = self
119 .allowed_origins
120 .iter()
121 .filter_map(|o| o.parse().ok())
122 .collect();
123 layer = layer.allow_origin(origins);
124 }
125
126 if let Some(max_age) = self.max_age_seconds {
128 layer = layer.max_age(std::time::Duration::from_secs(max_age));
129 }
130
131 layer
132 }
133}
134
135#[derive(Debug, Clone)]
137pub struct MiddlewareConfig {
138 pub cors: CorsConfig,
140
141 pub enable_compression: bool,
143
144 pub enable_tracing: bool,
146
147 pub request_timeout_seconds: Option<u64>,
149}
150
151impl Default for MiddlewareConfig {
152 fn default() -> Self {
153 Self {
154 cors: CorsConfig::default(),
155 enable_compression: true,
156 enable_tracing: true,
157 request_timeout_seconds: Some(30),
158 }
159 }
160}
161
162impl MiddlewareConfig {
163 pub fn new() -> Self {
165 Self::default()
166 }
167
168 pub fn with_cors(mut self, cors: CorsConfig) -> Self {
170 self.cors = cors;
171 self
172 }
173
174 pub fn with_compression(mut self, enable: bool) -> Self {
176 self.enable_compression = enable;
177 self
178 }
179
180 pub fn with_tracing(mut self, enable: bool) -> Self {
182 self.enable_tracing = enable;
183 self
184 }
185
186 pub fn with_timeout(mut self, timeout_seconds: u64) -> Self {
188 self.request_timeout_seconds = Some(timeout_seconds);
189 self
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// The generator must always produce an id, the id must be a v4 UUID, and
    /// consecutive ids must differ.
    #[test]
    fn test_uuid_request_id_generator() {
        let mut generator = UuidRequestIdGenerator::default();
        let request = Request::new(());

        let first = generator
            .make_request_id(&request)
            .expect("generator should always produce a request id");
        let second = generator
            .make_request_id(&request)
            .expect("generator should always produce a request id");

        let first_str = first
            .header_value()
            .to_str()
            .expect("request id should be ASCII");
        let second_str = second
            .header_value()
            .to_str()
            .expect("request id should be ASCII");

        let parsed = Uuid::parse_str(first_str).expect("request id should be a UUID");
        assert_eq!(parsed.get_version_num(), 4);
        assert_ne!(first_str, second_str);
    }

    #[test]
    fn test_cors_config_default() {
        let config = CorsConfig::default();
        assert!(config.allowed_origins.is_empty());
        assert!(!config.allow_credentials);
        assert_eq!(config.max_age_seconds, Some(3600));
    }

    #[test]
    fn test_middleware_config_default() {
        let config = MiddlewareConfig::default();
        assert!(config.enable_compression);
        assert!(config.enable_tracing);
        assert_eq!(config.request_timeout_seconds, Some(30));
    }

    #[test]
    fn test_middleware_config_builder() {
        let config = MiddlewareConfig::new()
            .with_compression(false)
            .with_tracing(false)
            .with_timeout(60);

        assert!(!config.enable_compression);
        assert!(!config.enable_tracing);
        assert_eq!(config.request_timeout_seconds, Some(60));
    }

    #[test]
    fn test_middleware_config_with_cors() {
        let cors = CorsConfig {
            allowed_origins: vec!["https://example.com".to_string()],
            allow_credentials: true,
            max_age_seconds: None,
        };
        let config = MiddlewareConfig::new().with_cors(cors);

        assert_eq!(
            config.cors.allowed_origins,
            vec!["https://example.com".to_string()]
        );
        assert!(config.cors.allow_credentials);
        assert_eq!(config.cors.max_age_seconds, None);
    }
}