1use axum::{
10 extract::{ConnectInfo, Request, State},
11 http::{HeaderMap, StatusCode},
12 middleware::Next,
13 response::{IntoResponse, Response},
14 Json,
15};
16use llm_config_security::{
17 InputValidator, PolicyEnforcer, RateLimiter, SecurityContext, SecurityError,
18};
19use serde_json::json;
20use std::net::SocketAddr;
21use std::sync::Arc;
22
23#[derive(Clone)]
25pub struct SecurityState {
26 pub rate_limiter: Arc<RateLimiter>,
27 pub input_validator: Arc<InputValidator>,
28 pub policy_enforcer: Arc<PolicyEnforcer>,
29}
30
31impl SecurityState {
32 pub fn new() -> Self {
34 Self {
35 rate_limiter: Arc::new(RateLimiter::new(Default::default())),
36 input_validator: Arc::new(InputValidator::default()),
37 policy_enforcer: Arc::new(PolicyEnforcer::default()),
38 }
39 }
40
41 pub fn with_components(
43 rate_limiter: RateLimiter,
44 input_validator: InputValidator,
45 policy_enforcer: PolicyEnforcer,
46 ) -> Self {
47 Self {
48 rate_limiter: Arc::new(rate_limiter),
49 input_validator: Arc::new(input_validator),
50 policy_enforcer: Arc::new(policy_enforcer),
51 }
52 }
53}
54
55impl Default for SecurityState {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61pub async fn rate_limit_middleware(
65 State(security): State<SecurityState>,
66 ConnectInfo(addr): ConnectInfo<SocketAddr>,
67 headers: HeaderMap,
68 request: Request,
69 next: Next,
70) -> Result<Response, SecurityResponse> {
71 let ip = addr.ip();
72
73 let is_authenticated = headers.get("authorization").is_some();
75
76 security
78 .rate_limiter
79 .check_request(ip, is_authenticated)
80 .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::TOO_MANY_REQUESTS))?;
81
82 Ok(next.run(request).await)
83}
84
85pub async fn input_validation_middleware(
89 State(security): State<SecurityState>,
90 request: Request,
91 next: Next,
92) -> Result<Response, SecurityResponse> {
93 let uri = request.uri();
94 let path = uri.path();
95 let query = uri.query().unwrap_or("");
96
97 security
99 .input_validator
100 .validate(path)
101 .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::BAD_REQUEST))?;
102
103 if !query.is_empty() {
105 security
106 .input_validator
107 .validate(query)
108 .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::BAD_REQUEST))?;
109 }
110
111 Ok(next.run(request).await)
112}
113
114pub async fn policy_enforcement_middleware(
118 State(security): State<SecurityState>,
119 ConnectInfo(addr): ConnectInfo<SocketAddr>,
120 headers: HeaderMap,
121 request: Request,
122 next: Next,
123) -> Result<Response, SecurityResponse> {
124 let ip = addr.ip();
125
126 security
128 .policy_enforcer
129 .check_ip(&ip.to_string())
130 .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::FORBIDDEN))?;
131
132 let is_tls = headers
134 .get("x-forwarded-proto")
135 .and_then(|v| v.to_str().ok())
136 .map(|v| v == "https")
137 .unwrap_or(false);
138
139 security
140 .policy_enforcer
141 .check_tls(is_tls, "1.2")
142 .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::UPGRADE_REQUIRED))?;
143
144 if let Some(origin) = headers.get("origin").and_then(|v| v.to_str().ok()) {
146 security
147 .policy_enforcer
148 .check_origin(origin)
149 .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::FORBIDDEN))?;
150 }
151
152 if let Some(content_length) = headers
154 .get("content-length")
155 .and_then(|v| v.to_str().ok())
156 .and_then(|v| v.parse::<usize>().ok())
157 {
158 security
159 .policy_enforcer
160 .check_request_size(content_length)
161 .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::PAYLOAD_TOO_LARGE))?;
162 }
163
164 let endpoint = request.uri().path();
166 security
167 .policy_enforcer
168 .check_endpoint(endpoint)
169 .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::FORBIDDEN))?;
170
171 Ok(next.run(request).await)
172}
173
174pub async fn security_context_middleware(
178 ConnectInfo(addr): ConnectInfo<SocketAddr>,
179 headers: HeaderMap,
180 mut request: Request,
181 next: Next,
182) -> Response {
183 let ip = addr.ip();
184
185 let user_id = headers
187 .get("x-user-id")
188 .and_then(|v| v.to_str().ok())
189 .unwrap_or("anonymous")
190 .to_string();
191
192 let context = SecurityContext::new(user_id, ip.to_string());
194
195 request.extensions_mut().insert(context);
197
198 next.run(request).await
199}
200
201pub async fn comprehensive_security_middleware(
205 State(security): State<SecurityState>,
206 ConnectInfo(addr): ConnectInfo<SocketAddr>,
207 headers: HeaderMap,
208 mut request: Request,
209 next: Next,
210) -> Result<Response, SecurityResponse> {
211 let ip = addr.ip();
212 let is_authenticated = headers.get("authorization").is_some();
213
214 security
216 .rate_limiter
217 .check_request(ip, is_authenticated)
218 .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::TOO_MANY_REQUESTS))?;
219
220 security
222 .policy_enforcer
223 .check_ip(&ip.to_string())
224 .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::FORBIDDEN))?;
225
226 let is_tls = headers
228 .get("x-forwarded-proto")
229 .and_then(|v| v.to_str().ok())
230 .map(|v| v == "https")
231 .unwrap_or(false);
232
233 security
234 .policy_enforcer
235 .check_tls(is_tls, "1.2")
236 .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::UPGRADE_REQUIRED))?;
237
238 let endpoint = request.uri().path();
240 security
241 .policy_enforcer
242 .check_endpoint(endpoint)
243 .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::FORBIDDEN))?;
244
245 let uri = request.uri();
247 security
248 .input_validator
249 .validate(uri.path())
250 .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::BAD_REQUEST))?;
251
252 if let Some(query) = uri.query() {
253 security
254 .input_validator
255 .validate(query)
256 .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::BAD_REQUEST))?;
257 }
258
259 let user_id = headers
261 .get("x-user-id")
262 .and_then(|v| v.to_str().ok())
263 .unwrap_or("anonymous")
264 .to_string();
265
266 let context = SecurityContext::new(user_id, ip.to_string());
267 request.extensions_mut().insert(context);
268
269 Ok(next.run(request).await)
270}
271
272pub struct SecurityResponse {
274 status: StatusCode,
275 message: String,
276}
277
278impl SecurityResponse {
279 pub fn new(status: StatusCode, message: String) -> Self {
280 Self { status, message }
281 }
282
283 pub fn from_security_error(error: SecurityError, status: StatusCode) -> Self {
284 Self {
285 status,
286 message: error.public_message(),
287 }
288 }
289}
290
291impl IntoResponse for SecurityResponse {
292 fn into_response(self) -> Response {
293 let body = json!({
294 "error": self.status.canonical_reason().unwrap_or("Security Error"),
295 "message": self.message,
296 });
297
298 (self.status, Json(body)).into_response()
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use llm_config_security::{RateLimitConfig, SecurityPolicy};

    /// Builds a state with explicit limits via `SecurityState::with_components`.
    fn create_test_security_state() -> SecurityState {
        let rate_limiter = RateLimiter::new(RateLimitConfig {
            authenticated_rps: 100,
            unauthenticated_rps: 10,
            burst_size: 50,
            window_seconds: 60,
            ban_duration_seconds: 3600,
            ban_threshold: 10,
        });

        let input_validator = InputValidator::default();
        let policy_enforcer = PolicyEnforcer::new(SecurityPolicy::default());

        SecurityState::with_components(rate_limiter, input_validator, policy_enforcer)
    }

    #[test]
    fn test_security_state_creation() {
        let state = SecurityState::new();
        assert_eq!(Arc::strong_count(&state.rate_limiter), 1);
        assert_eq!(Arc::strong_count(&state.input_validator), 1);
        assert_eq!(Arc::strong_count(&state.policy_enforcer), 1);
    }

    #[test]
    fn test_security_state_clone_shares_components() {
        let state = create_test_security_state();
        let cloned = state.clone();

        // Cloning must share the same components rather than copying them.
        assert!(Arc::ptr_eq(&state.rate_limiter, &cloned.rate_limiter));
        assert!(Arc::ptr_eq(&state.input_validator, &cloned.input_validator));
        assert!(Arc::ptr_eq(&state.policy_enforcer, &cloned.policy_enforcer));
        assert_eq!(Arc::strong_count(&state.rate_limiter), 2);
    }

    #[test]
    fn test_security_response() {
        let response = SecurityResponse::new(StatusCode::FORBIDDEN, "Access denied".to_string());
        assert_eq!(response.status, StatusCode::FORBIDDEN);
        assert_eq!(response.message, "Access denied");
    }

    #[test]
    fn test_security_response_into_response_keeps_status() {
        let response =
            SecurityResponse::new(StatusCode::PAYLOAD_TOO_LARGE, "too big".to_string())
                .into_response();
        assert_eq!(response.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }
}