pulseengine_mcp_auth/
validation.rs1use crate::models::{AuthContext, Role};
8use std::collections::HashMap;
9
/// Permission identifiers used by the auth layer.
///
/// Each identifier is a dot-separated string whose first segment names the
/// area it belongs to (`admin`, `device`, `system`, `mcp`).
pub mod permissions {
    // ----- admin.* -------------------------------------------------------
    /// Permission string `admin.create_key`.
    pub const ADMIN_CREATE_KEY: &str = "admin.create_key";
    /// Permission string `admin.delete_key`.
    pub const ADMIN_DELETE_KEY: &str = "admin.delete_key";
    /// Permission string `admin.list_keys`.
    pub const ADMIN_LIST_KEYS: &str = "admin.list_keys";
    /// Permission string `admin.view_audit`.
    pub const ADMIN_VIEW_AUDIT: &str = "admin.view_audit";

    // ----- device.* ------------------------------------------------------
    /// Permission string `device.read`.
    pub const DEVICE_READ: &str = "device.read";
    /// Permission string `device.control`.
    pub const DEVICE_CONTROL: &str = "device.control";

    // ----- system.* ------------------------------------------------------
    /// Permission string `system.status`.
    pub const SYSTEM_STATUS: &str = "system.status";
    /// Permission string `system.health`.
    pub const SYSTEM_HEALTH: &str = "system.health";

    // ----- mcp.* (tools and resources) -----------------------------------
    /// Permission string `mcp.tools.list`.
    pub const MCP_TOOLS_LIST: &str = "mcp.tools.list";
    /// Permission string `mcp.tools.execute`.
    pub const MCP_TOOLS_EXECUTE: &str = "mcp.tools.execute";
    /// Permission string `mcp.resources.list`.
    pub const MCP_RESOURCES_LIST: &str = "mcp.resources.list";
    /// Permission string `mcp.resources.read`.
    pub const MCP_RESOURCES_READ: &str = "mcp.resources.read";
}
28
/// Best-effort client IP lookup from proxy-style request headers.
///
/// Headers are consulted in the order `x-forwarded-for`, `x-real-ip`,
/// `x-client-ip` (exact, lowercase key lookup). For each one present, only the
/// first comma-separated entry is used, with surrounding whitespace trimmed.
/// The first non-empty result wins; if none is found, `"unknown"` is returned.
///
/// NOTE(review): these headers can be supplied by the client unless a trusted
/// proxy overwrites them — treat the result as informational (logging, rate
/// limiting hints), not as an authenticated identity.
pub fn extract_client_ip(headers: &HashMap<String, String>) -> String {
    const CANDIDATE_HEADERS: [&str; 3] = ["x-forwarded-for", "x-real-ip", "x-client-ip"];

    CANDIDATE_HEADERS
        .iter()
        .filter_map(|name| headers.get(*name))
        // A forwarded list looks like "client, proxy1, proxy2"; keep the head.
        .map(|value| value.split(',').next().unwrap_or(value).trim())
        .find(|candidate| !candidate.is_empty())
        .map_or_else(|| "unknown".to_string(), str::to_string)
}
45
46pub fn extract_api_key(headers: &HashMap<String, String>, query: Option<&str>) -> Option<String> {
48 if let Some(auth_header) = headers.get("authorization") {
50 if let Some(token) = auth_header.strip_prefix("Bearer ") {
51 return Some(token.to_string());
52 }
53 }
54
55 if let Some(api_key_header) = headers.get("x-api-key") {
57 return Some(api_key_header.clone());
58 }
59
60 if let Some(query_string) = query {
62 for param in query_string.split('&') {
63 if let Some((key, value)) = param.split_once('=') {
64 if key == "api_key" {
65 return Some(urlencoding::decode(value).unwrap_or_default().to_string());
66 }
67 }
68 }
69 }
70
71 None
72}
73
74pub fn check_permission(
76 context: &AuthContext,
77 permission: &str,
78 session_timeout_minutes: u64,
79) -> bool {
80 if !is_session_valid(context, session_timeout_minutes) {
82 return false;
83 }
84
85 context.has_permission(permission)
87}
88
89pub fn is_session_valid(_context: &AuthContext, _session_timeout_minutes: u64) -> bool {
91 true
94}
95
96pub fn is_valid_uuid(uuid_str: &str) -> bool {
98 uuid::Uuid::parse_str(uuid_str).is_ok()
99}
100
/// Returns `true` if `ip_str` parses as an IPv4 or IPv6 address
/// (`std::net::IpAddr`).
///
/// The string must be the bare address: surrounding whitespace, ports or
/// out-of-range octets cause a `false` result.
pub fn is_valid_ip_address(ip_str: &str) -> bool {
    <std::net::IpAddr as std::str::FromStr>::from_str(ip_str).is_ok()
}
105
106pub fn validate_device_permission(role: &Role, device_id: &str) -> bool {
108 match role {
109 Role::Admin => true, Role::Operator => true, Role::Monitor => true, Role::Device { allowed_devices } => allowed_devices.contains(&device_id.to_string()),
113 Role::Custom { permissions } => {
114 permissions
116 .iter()
117 .any(|perm| perm == "device.*" || perm == &format!("device.{}", device_id))
118 }
119 }
120}
121
122pub fn generate_secure_key(prefix: &str) -> String {
124 let random_part = uuid::Uuid::new_v4().to_string().replace('-', "");
125 format!("{}_{}", prefix, random_part)
126}
127
/// Returns a copy of `input` keeping only alphanumeric characters (Unicode
/// aware, per `char::is_alphanumeric`) and the punctuation `-`, `_` and `.`.
/// Every other character is silently dropped.
pub fn sanitize_input(input: &str) -> String {
    let mut cleaned = String::with_capacity(input.len());
    for ch in input.chars() {
        let keep = ch.is_alphanumeric() || matches!(ch, '-' | '_' | '.');
        if keep {
            cleaned.push(ch);
        }
    }
    cleaned
}
136
/// Validates `input` against basic format rules, returning a human-readable
/// error message for the first rule it breaks.
///
/// Checks, in order:
/// 1. `input` must not be empty.
/// 2. `input.len()` — a **byte** length — must not exceed `max_length`.
/// 3. Unless `allow_special` is set, every character must be alphanumeric
///    (Unicode aware) or one of `-`, `_`, `.`; the first offending character
///    is reported.
///
/// # Errors
/// Returns `Err(String)` describing the failed check.
pub fn validate_input_format(
    input: &str,
    max_length: usize,
    allow_special: bool,
) -> Result<(), String> {
    if input.is_empty() {
        return Err("Input cannot be empty".to_string());
    }

    if input.len() > max_length {
        return Err(format!("Input too long (max: {})", max_length));
    }

    if allow_special {
        return Ok(());
    }

    let offending = input
        .chars()
        .find(|&ch| !(ch.is_alphanumeric() || "-_.".contains(ch)));

    match offending {
        Some(ch) => Err(format!("Invalid character: '{}'", ch)),
        None => Ok(()),
    }
}
161
/// Reads `(limit, remaining)` from the `x-ratelimit-limit` and
/// `x-ratelimit-remaining` headers.
///
/// Returns `None` unless both headers are present and both parse as `u32`.
pub fn extract_rate_limit_info(headers: &HashMap<String, String>) -> Option<(u32, u32)> {
    let header_u32 = |name: &str| headers.get(name)?.parse::<u32>().ok();
    header_u32("x-ratelimit-limit").zip(header_u32("x-ratelimit-remaining"))
}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Role;

    #[test]
    fn test_extract_client_ip() {
        let mut headers = HashMap::new();
        headers.insert(
            "x-forwarded-for".to_string(),
            "192.168.1.1, 10.0.0.1".to_string(),
        );

        let ip = extract_client_ip(&headers);
        assert_eq!(ip, "192.168.1.1");
    }

    #[test]
    fn test_extract_client_ip_falls_back_to_unknown() {
        let headers = HashMap::new();
        assert_eq!(extract_client_ip(&headers), "unknown");
    }

    #[test]
    fn test_extract_api_key_from_bearer() {
        let mut headers = HashMap::new();
        headers.insert(
            "authorization".to_string(),
            "Bearer test_key_123".to_string(),
        );

        let key = extract_api_key(&headers, None);
        assert_eq!(key, Some("test_key_123".to_string()));
    }

    #[test]
    fn test_extract_api_key_from_header() {
        let mut headers = HashMap::new();
        headers.insert("x-api-key".to_string(), "test_key_123".to_string());

        let key = extract_api_key(&headers, None);
        assert_eq!(key, Some("test_key_123".to_string()));
    }

    #[test]
    fn test_extract_api_key_from_query() {
        let headers = HashMap::new();
        // The trailing `&param2=value2` checks that the key value stops at the next `&`.
        let query = "param1=value1&api_key=test_key_123&param2=value2";

        let key = extract_api_key(&headers, Some(query));
        assert_eq!(key, Some("test_key_123".to_string()));
    }

    #[test]
    fn test_extract_api_key_missing() {
        let headers = HashMap::new();
        assert_eq!(extract_api_key(&headers, Some("param1=value1")), None);
        assert_eq!(extract_api_key(&headers, None), None);
    }

    #[test]
    fn test_validate_device_permission() {
        let admin_role = Role::Admin;
        let device_role = Role::Device {
            allowed_devices: vec!["device1".to_string(), "device2".to_string()],
        };

        assert!(validate_device_permission(&admin_role, "any_device"));
        assert!(validate_device_permission(&device_role, "device1"));
        assert!(!validate_device_permission(&device_role, "device3"));
    }

    #[test]
    fn test_is_valid_uuid() {
        assert!(is_valid_uuid("550e8400-e29b-41d4-a716-446655440000"));
        assert!(!is_valid_uuid("invalid-uuid"));
        assert!(!is_valid_uuid(""));
    }

    #[test]
    fn test_is_valid_ip_address() {
        assert!(is_valid_ip_address("192.168.1.1"));
        assert!(is_valid_ip_address("::1"));
        assert!(!is_valid_ip_address("invalid-ip"));
        assert!(!is_valid_ip_address("999.999.999.999"));
    }

    #[test]
    fn test_sanitize_input() {
        assert_eq!(sanitize_input("hello_world-123.txt"), "hello_world-123.txt");
        assert_eq!(sanitize_input("hello<script>"), "helloscript");
        assert_eq!(sanitize_input("test;DROP TABLE"), "testDROPTABLE");
    }

    #[test]
    fn test_validate_input_format() {
        assert!(validate_input_format("valid_input", 20, false).is_ok());
        assert!(validate_input_format("", 20, false).is_err());
        assert!(validate_input_format("very_long_input_exceeding_limit", 10, false).is_err());
        assert!(validate_input_format("invalid@char", 20, false).is_err());
        assert!(validate_input_format("invalid@char", 20, true).is_ok());
    }

    #[test]
    fn test_extract_rate_limit_info() {
        let mut headers = HashMap::new();
        headers.insert("x-ratelimit-limit".to_string(), "100".to_string());
        // Both headers are required.
        assert_eq!(extract_rate_limit_info(&headers), None);

        headers.insert("x-ratelimit-remaining".to_string(), "42".to_string());
        assert_eq!(extract_rate_limit_info(&headers), Some((100, 42)));
    }
}