pulseengine_mcp_auth/
validation.rs

1//! Authentication validation utilities
2//!
3//! This module provides helper functions for extracting authentication
4//! information from requests, validating permissions, and handling
5//! session management.
6
7use crate::models::{AuthContext, Role};
8use std::collections::HashMap;
9
/// Permission constants for common operations
///
/// Every constant is a dotted string of the form `<namespace>.<action>`
/// (the MCP ones use `mcp.<kind>.<action>`). Use these constants instead of
/// repeating the string literals at call sites.
pub mod permissions {
    // `admin.*` namespace
    pub const ADMIN_CREATE_KEY: &str = "admin.create_key";
    pub const ADMIN_DELETE_KEY: &str = "admin.delete_key";
    pub const ADMIN_LIST_KEYS: &str = "admin.list_keys";
    pub const ADMIN_VIEW_AUDIT: &str = "admin.view_audit";

    // `device.*` namespace
    pub const DEVICE_READ: &str = "device.read";
    pub const DEVICE_CONTROL: &str = "device.control";

    // `system.*` namespace
    pub const SYSTEM_STATUS: &str = "system.status";
    pub const SYSTEM_HEALTH: &str = "system.health";

    // `mcp.tools.*` and `mcp.resources.*` namespaces
    pub const MCP_TOOLS_LIST: &str = "mcp.tools.list";
    pub const MCP_TOOLS_EXECUTE: &str = "mcp.tools.execute";
    pub const MCP_RESOURCES_LIST: &str = "mcp.resources.list";
    pub const MCP_RESOURCES_READ: &str = "mcp.resources.read";
}
28
/// Determine the client address from proxy-style request headers.
///
/// The headers `x-forwarded-for`, `x-real-ip` and `x-client-ip` are consulted
/// in that order. For each one that is present, only the text before the first
/// comma is considered (trimmed of surrounding whitespace); the first
/// non-empty candidate wins.
///
/// Header names are looked up exactly as spelled above, i.e. the map is
/// expected to use lowercase keys.
///
/// # Returns
///
/// The selected address text, or the literal `"unknown"` when no header
/// yields a non-empty value. The value is not checked to be a valid IP.
///
/// NOTE(review): these headers come from the request itself; whether they can
/// be trusted depends on the proxy setup in front of this service.
pub fn extract_client_ip(headers: &HashMap<String, String>) -> String {
    // Ordered by preference; earlier entries win.
    const CANDIDATE_HEADERS: [&str; 3] = ["x-forwarded-for", "x-real-ip", "x-client-ip"];

    CANDIDATE_HEADERS
        .iter()
        .filter_map(|name| headers.get(*name))
        // A forwarded chain is comma-separated; only its first hop is used.
        .map(|raw| raw.split(',').next().unwrap_or(raw.as_str()).trim())
        .find(|candidate| !candidate.is_empty())
        .map(|candidate| candidate.to_string())
        .unwrap_or_else(|| "unknown".to_string())
}
45
46/// Helper function to extract API key from request headers or query parameters
47pub fn extract_api_key(headers: &HashMap<String, String>, query: Option<&str>) -> Option<String> {
48    // Try Authorization header with Bearer token
49    if let Some(auth_header) = headers.get("authorization") {
50        if let Some(token) = auth_header.strip_prefix("Bearer ") {
51            return Some(token.to_string());
52        }
53    }
54
55    // Try X-API-Key header
56    if let Some(api_key_header) = headers.get("x-api-key") {
57        return Some(api_key_header.clone());
58    }
59
60    // Try query parameter
61    if let Some(query_string) = query {
62        for param in query_string.split('&') {
63            if let Some((key, value)) = param.split_once('=') {
64                if key == "api_key" {
65                    return Some(urlencoding::decode(value).unwrap_or_default().to_string());
66                }
67            }
68        }
69    }
70
71    None
72}
73
74/// Check if a session has the required permission
75pub fn check_permission(
76    context: &AuthContext,
77    permission: &str,
78    session_timeout_minutes: u64,
79) -> bool {
80    // Check if session is still valid
81    if !is_session_valid(context, session_timeout_minutes) {
82        return false;
83    }
84
85    // Check role-based permission
86    context.has_permission(permission)
87}
88
/// Check if a session is still valid based on timeout
///
/// Placeholder: both arguments are currently ignored and the function always
/// returns `true`, so no session is ever treated as expired here.
pub fn is_session_valid(_context: &AuthContext, _session_timeout_minutes: u64) -> bool {
    // For now, sessions don't have explicit timestamps in AuthContext
    // This can be enhanced when we add session tracking
    true
}
95
96/// Validate that a string is a valid UUID
97pub fn is_valid_uuid(uuid_str: &str) -> bool {
98    uuid::Uuid::parse_str(uuid_str).is_ok()
99}
100
/// Report whether `ip_str` is a syntactically valid IPv4 or IPv6 address.
///
/// The string must parse as a [`std::net::IpAddr`] exactly as given; no
/// trimming is performed, so surrounding whitespace makes it invalid.
pub fn is_valid_ip_address(ip_str: &str) -> bool {
    let parsed: Result<std::net::IpAddr, _> = ip_str.parse();
    parsed.is_ok()
}
105
106/// Validate that a role has permission for a specific device
107pub fn validate_device_permission(role: &Role, device_id: &str) -> bool {
108    match role {
109        Role::Admin => true,    // Admin has access to all devices
110        Role::Operator => true, // Operator has access to all devices
111        Role::Monitor => true,  // Monitor can read all devices
112        Role::Device { allowed_devices } => allowed_devices.contains(&device_id.to_string()),
113        Role::Custom { permissions } => {
114            // Check if custom role has device-specific permission
115            permissions
116                .iter()
117                .any(|perm| perm == "device.*" || perm == &format!("device.{}", device_id))
118        }
119    }
120}
121
122/// Generate a secure random key for API keys
123pub fn generate_secure_key(prefix: &str) -> String {
124    let random_part = uuid::Uuid::new_v4().to_string().replace('-', "");
125    format!("{}_{}", prefix, random_part)
126}
127
/// Strip every character that is not on the allow-list from `input`.
///
/// Kept characters are anything `char::is_alphanumeric` accepts (this
/// includes non-ASCII letters and digits) plus `-`, `_` and `.`. All other
/// characters are dropped; the order of the remaining ones is preserved.
pub fn sanitize_input(input: &str) -> String {
    let mut cleaned = String::with_capacity(input.len());

    for ch in input.chars() {
        let allowed = ch.is_alphanumeric() || matches!(ch, '-' | '_' | '.');
        if allowed {
            cleaned.push(ch);
        }
    }

    cleaned
}
136
/// Check `input` against basic length and character rules.
///
/// Checks run in this order and the first failure is reported:
///
/// 1. `input` must not be empty.
/// 2. `input.len()` (length in bytes) must not exceed `max_length`.
/// 3. Unless `allow_special` is set, every character must be alphanumeric
///    (per `char::is_alphanumeric`) or one of `-`, `_`, `.`.
///
/// # Errors
///
/// Returns a human-readable message describing the violated rule; for rule 3
/// the message names the first offending character.
pub fn validate_input_format(
    input: &str,
    max_length: usize,
    allow_special: bool,
) -> Result<(), String> {
    if input.is_empty() {
        return Err("Input cannot be empty".to_string());
    }

    if input.len() > max_length {
        return Err(format!("Input too long (max: {})", max_length));
    }

    // With special characters allowed there is nothing left to inspect.
    if allow_special {
        return Ok(());
    }

    let first_offender = input
        .chars()
        .find(|&ch| !(ch.is_alphanumeric() || matches!(ch, '-' | '_' | '.')));

    match first_offender {
        Some(ch) => Err(format!("Invalid character: '{}'", ch)),
        None => Ok(()),
    }
}
161
/// Read the rate-limit pair from `x-ratelimit-limit` and
/// `x-ratelimit-remaining`.
///
/// Both headers must be present (lowercase keys) and each value must parse
/// as a `u32` exactly as given - no trimming is applied.
///
/// # Returns
///
/// `Some((limit, remaining))` when both values are available, otherwise
/// `None`.
pub fn extract_rate_limit_info(headers: &HashMap<String, String>) -> Option<(u32, u32)> {
    let read_count = |name: &str| -> Option<u32> {
        headers
            .get(name)
            .and_then(|raw| raw.parse::<u32>().ok())
    };

    read_count("x-ratelimit-limit").zip(read_count("x-ratelimit-remaining"))
}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Role;

    /// Builds an owned header map from borrowed `(name, value)` pairs.
    fn header_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_string(), value.to_string()))
            .collect()
    }

    // The first hop of a comma-separated forwarded chain is returned.
    #[test]
    fn test_extract_client_ip() {
        let request_headers = header_map(&[("x-forwarded-for", "192.168.1.1, 10.0.0.1")]);

        assert_eq!(extract_client_ip(&request_headers), "192.168.1.1");
    }

    // A bearer credential in `authorization` is used as the key.
    #[test]
    fn test_extract_api_key_from_bearer() {
        let request_headers = header_map(&[("authorization", "Bearer test_key_123")]);

        let found = extract_api_key(&request_headers, None);
        assert_eq!(found, Some("test_key_123".to_string()));
    }

    // The `x-api-key` header is used when no bearer credential exists.
    #[test]
    fn test_extract_api_key_from_header() {
        let request_headers = header_map(&[("x-api-key", "test_key_123")]);

        let found = extract_api_key(&request_headers, None);
        assert_eq!(found, Some("test_key_123".to_string()));
    }

    // Without headers, the `api_key` query parameter is used.
    #[test]
    fn test_extract_api_key_from_query() {
        let request_headers = header_map(&[]);
        let query_string = "param1=value1&api_key=test_key_123&param2=value2";

        let found = extract_api_key(&request_headers, Some(query_string));
        assert_eq!(found, Some("test_key_123".to_string()));
    }

    // Admin reaches every device; a device role only its listed devices.
    #[test]
    fn test_validate_device_permission() {
        let admin = Role::Admin;
        let restricted = Role::Device {
            allowed_devices: vec!["device1".to_string(), "device2".to_string()],
        };

        assert!(validate_device_permission(&admin, "any_device"));
        assert!(validate_device_permission(&restricted, "device1"));
        assert!(!validate_device_permission(&restricted, "device3"));
    }

    #[test]
    fn test_is_valid_uuid() {
        let well_formed = "550e8400-e29b-41d4-a716-446655440000";

        assert!(is_valid_uuid(well_formed));
        assert!(!is_valid_uuid("invalid-uuid"));
        assert!(!is_valid_uuid(""));
    }

    #[test]
    fn test_is_valid_ip_address() {
        // Both address families are accepted.
        assert!(is_valid_ip_address("192.168.1.1"));
        assert!(is_valid_ip_address("::1"));

        // Free text and out-of-range octets are rejected.
        assert!(!is_valid_ip_address("invalid-ip"));
        assert!(!is_valid_ip_address("999.999.999.999"));
    }

    #[test]
    fn test_sanitize_input() {
        // Allowed characters pass through untouched.
        assert_eq!(sanitize_input("hello_world-123.txt"), "hello_world-123.txt");

        // Everything else is removed.
        assert_eq!(sanitize_input("hello<script>"), "helloscript");
        assert_eq!(sanitize_input("test;DROP TABLE"), "testDROPTABLE");
    }

    #[test]
    fn test_validate_input_format() {
        let accepted = validate_input_format("valid_input", 20, false);
        assert!(accepted.is_ok());

        let empty = validate_input_format("", 20, false);
        assert!(empty.is_err());

        let too_long = validate_input_format("very_long_input_exceeding_limit", 10, false);
        assert!(too_long.is_err());

        // `@` is rejected in strict mode and accepted when specials are allowed.
        let strict = validate_input_format("invalid@char", 20, false);
        assert!(strict.is_err());

        let lenient = validate_input_format("invalid@char", 20, true);
        assert!(lenient.is_ok());
    }
}