llm_config_api/
middleware.rs

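//! Security middleware for API requests
//!
//! This module provides security middleware layers including:
//! - Input validation of request paths and query strings
//! - Rate limiting
//! - Policy enforcement (IP blocking, TLS, CORS origin, request size, endpoint access)
//! - Security context creation for audit logging
//! - JSON error responses built from each error's public message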
8
9use axum::{
10    extract::{ConnectInfo, Request, State},
11    http::{HeaderMap, StatusCode},
12    middleware::Next,
13    response::{IntoResponse, Response},
14    Json,
15};
16use llm_config_security::{
17    InputValidator, PolicyEnforcer, RateLimiter, SecurityContext, SecurityError,
18};
19use serde_json::json;
20use std::net::SocketAddr;
21use std::sync::Arc;
22
23/// Security middleware state
24#[derive(Clone)]
25pub struct SecurityState {
26    pub rate_limiter: Arc<RateLimiter>,
27    pub input_validator: Arc<InputValidator>,
28    pub policy_enforcer: Arc<PolicyEnforcer>,
29}
30
31impl SecurityState {
32    /// Create a new security state with default configuration
33    pub fn new() -> Self {
34        Self {
35            rate_limiter: Arc::new(RateLimiter::new(Default::default())),
36            input_validator: Arc::new(InputValidator::default()),
37            policy_enforcer: Arc::new(PolicyEnforcer::default()),
38        }
39    }
40
41    /// Create a new security state with custom components
42    pub fn with_components(
43        rate_limiter: RateLimiter,
44        input_validator: InputValidator,
45        policy_enforcer: PolicyEnforcer,
46    ) -> Self {
47        Self {
48            rate_limiter: Arc::new(rate_limiter),
49            input_validator: Arc::new(input_validator),
50            policy_enforcer: Arc::new(policy_enforcer),
51        }
52    }
53}
54
55impl Default for SecurityState {
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61/// Rate limiting middleware
62///
63/// Checks requests against rate limits and automatically bans abusive IPs
64pub async fn rate_limit_middleware(
65    State(security): State<SecurityState>,
66    ConnectInfo(addr): ConnectInfo<SocketAddr>,
67    headers: HeaderMap,
68    request: Request,
69    next: Next,
70) -> Result<Response, SecurityResponse> {
71    let ip = addr.ip();
72
73    // Check if request has authentication (simplified - in production use proper auth)
74    let is_authenticated = headers.get("authorization").is_some();
75
76    // Check rate limit
77    security
78        .rate_limiter
79        .check_request(ip, is_authenticated)
80        .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::TOO_MANY_REQUESTS))?;
81
82    Ok(next.run(request).await)
83}
84
85/// Input validation middleware
86///
87/// Validates and sanitizes request paths and query parameters
88pub async fn input_validation_middleware(
89    State(security): State<SecurityState>,
90    request: Request,
91    next: Next,
92) -> Result<Response, SecurityResponse> {
93    let uri = request.uri();
94    let path = uri.path();
95    let query = uri.query().unwrap_or("");
96
97    // Validate path
98    security
99        .input_validator
100        .validate(path)
101        .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::BAD_REQUEST))?;
102
103    // Validate query parameters
104    if !query.is_empty() {
105        security
106            .input_validator
107            .validate(query)
108            .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::BAD_REQUEST))?;
109    }
110
111    Ok(next.run(request).await)
112}
113
114/// Policy enforcement middleware
115///
116/// Enforces security policies including IP blocking and TLS requirements
117pub async fn policy_enforcement_middleware(
118    State(security): State<SecurityState>,
119    ConnectInfo(addr): ConnectInfo<SocketAddr>,
120    headers: HeaderMap,
121    request: Request,
122    next: Next,
123) -> Result<Response, SecurityResponse> {
124    let ip = addr.ip();
125
126    // Check if IP is blocked
127    security
128        .policy_enforcer
129        .check_ip(&ip.to_string())
130        .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::FORBIDDEN))?;
131
132    // Check TLS (in production, check X-Forwarded-Proto or similar)
133    let is_tls = headers
134        .get("x-forwarded-proto")
135        .and_then(|v| v.to_str().ok())
136        .map(|v| v == "https")
137        .unwrap_or(false);
138
139    security
140        .policy_enforcer
141        .check_tls(is_tls, "1.2")
142        .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::UPGRADE_REQUIRED))?;
143
144    // Check CORS origin
145    if let Some(origin) = headers.get("origin").and_then(|v| v.to_str().ok()) {
146        security
147            .policy_enforcer
148            .check_origin(origin)
149            .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::FORBIDDEN))?;
150    }
151
152    // Check request size
153    if let Some(content_length) = headers
154        .get("content-length")
155        .and_then(|v| v.to_str().ok())
156        .and_then(|v| v.parse::<usize>().ok())
157    {
158        security
159            .policy_enforcer
160            .check_request_size(content_length)
161            .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::PAYLOAD_TOO_LARGE))?;
162    }
163
164    // Check endpoint access
165    let endpoint = request.uri().path();
166    security
167        .policy_enforcer
168        .check_endpoint(endpoint)
169        .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::FORBIDDEN))?;
170
171    Ok(next.run(request).await)
172}
173
174/// Security context middleware
175///
176/// Creates a security context for audit logging and tracking
177pub async fn security_context_middleware(
178    ConnectInfo(addr): ConnectInfo<SocketAddr>,
179    headers: HeaderMap,
180    mut request: Request,
181    next: Next,
182) -> Response {
183    let ip = addr.ip();
184
185    // Extract user ID from headers (simplified - in production use proper auth)
186    let user_id = headers
187        .get("x-user-id")
188        .and_then(|v| v.to_str().ok())
189        .unwrap_or("anonymous")
190        .to_string();
191
192    // Create security context
193    let context = SecurityContext::new(user_id, ip.to_string());
194
195    // Store context in request extensions for use in handlers
196    request.extensions_mut().insert(context);
197
198    next.run(request).await
199}
200
201/// Comprehensive security middleware
202///
203/// Combines all security checks in a single middleware
204pub async fn comprehensive_security_middleware(
205    State(security): State<SecurityState>,
206    ConnectInfo(addr): ConnectInfo<SocketAddr>,
207    headers: HeaderMap,
208    mut request: Request,
209    next: Next,
210) -> Result<Response, SecurityResponse> {
211    let ip = addr.ip();
212    let is_authenticated = headers.get("authorization").is_some();
213
214    // 1. Rate limiting
215    security
216        .rate_limiter
217        .check_request(ip, is_authenticated)
218        .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::TOO_MANY_REQUESTS))?;
219
220    // 2. Policy enforcement - IP check
221    security
222        .policy_enforcer
223        .check_ip(&ip.to_string())
224        .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::FORBIDDEN))?;
225
226    // 3. Policy enforcement - TLS check
227    let is_tls = headers
228        .get("x-forwarded-proto")
229        .and_then(|v| v.to_str().ok())
230        .map(|v| v == "https")
231        .unwrap_or(false);
232
233    security
234        .policy_enforcer
235        .check_tls(is_tls, "1.2")
236        .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::UPGRADE_REQUIRED))?;
237
238    // 4. Policy enforcement - endpoint check
239    let endpoint = request.uri().path();
240    security
241        .policy_enforcer
242        .check_endpoint(endpoint)
243        .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::FORBIDDEN))?;
244
245    // 5. Input validation
246    let uri = request.uri();
247    security
248        .input_validator
249        .validate(uri.path())
250        .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::BAD_REQUEST))?;
251
252    if let Some(query) = uri.query() {
253        security
254            .input_validator
255            .validate(query)
256            .map_err(|e| SecurityResponse::from_security_error(e, StatusCode::BAD_REQUEST))?;
257    }
258
259    // 6. Create security context
260    let user_id = headers
261        .get("x-user-id")
262        .and_then(|v| v.to_str().ok())
263        .unwrap_or("anonymous")
264        .to_string();
265
266    let context = SecurityContext::new(user_id, ip.to_string());
267    request.extensions_mut().insert(context);
268
269    Ok(next.run(request).await)
270}
271
272/// Security error response
273pub struct SecurityResponse {
274    status: StatusCode,
275    message: String,
276}
277
278impl SecurityResponse {
279    pub fn new(status: StatusCode, message: String) -> Self {
280        Self { status, message }
281    }
282
283    pub fn from_security_error(error: SecurityError, status: StatusCode) -> Self {
284        Self {
285            status,
286            message: error.public_message(),
287        }
288    }
289}
290
291impl IntoResponse for SecurityResponse {
292    fn into_response(self) -> Response {
293        let body = json!({
294            "error": self.status.canonical_reason().unwrap_or("Security Error"),
295            "message": self.message,
296        });
297
298        (self.status, Json(body)).into_response()
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use llm_config_security::{RateLimitConfig, SecurityPolicy};
    use std::net::{IpAddr, Ipv4Addr};

    /// Security state with an explicit, small rate-limit configuration.
    fn create_test_security_state() -> SecurityState {
        let rate_limiter = RateLimiter::new(RateLimitConfig {
            authenticated_rps: 100,
            unauthenticated_rps: 10,
            burst_size: 50,
            window_seconds: 60,
            ban_duration_seconds: 3600,
            ban_threshold: 10,
        });

        let input_validator = InputValidator::default();
        let policy_enforcer = PolicyEnforcer::new(SecurityPolicy::default());

        SecurityState::with_components(rate_limiter, input_validator, policy_enforcer)
    }

    #[test]
    fn test_security_state_creation() {
        let state = SecurityState::new();
        assert_eq!(Arc::strong_count(&state.rate_limiter), 1);
        assert_eq!(Arc::strong_count(&state.input_validator), 1);
        assert_eq!(Arc::strong_count(&state.policy_enforcer), 1);
    }

    #[test]
    fn test_cloned_state_shares_components() {
        let state = create_test_security_state();
        let cloned = state.clone();
        assert!(Arc::ptr_eq(&state.rate_limiter, &cloned.rate_limiter));
        assert!(Arc::ptr_eq(&state.input_validator, &cloned.input_validator));
        assert!(Arc::ptr_eq(&state.policy_enforcer, &cloned.policy_enforcer));
        assert_eq!(Arc::strong_count(&state.rate_limiter), 2);
    }

    #[test]
    fn test_first_request_is_within_rate_limit() {
        let state = create_test_security_state();
        // TEST-NET-1 address (RFC 5737): never a real client.
        let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 1));
        assert!(state.rate_limiter.check_request(ip, false).is_ok());
    }

    #[test]
    fn test_security_response() {
        let response = SecurityResponse::new(
            StatusCode::FORBIDDEN,
            "Access denied".to_string(),
        );
        assert_eq!(response.status, StatusCode::FORBIDDEN);
        assert_eq!(response.message, "Access denied");
    }

    #[test]
    fn test_security_response_into_response_keeps_status() {
        let response =
            SecurityResponse::new(StatusCode::TOO_MANY_REQUESTS, "Slow down".to_string())
                .into_response();
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
    }
}