1use crate::security::{Claims, RbacConfig, Result, SecurityError};
2use axum::{
3 extract::{Request, State},
4 http::StatusCode,
5 middleware::Next,
6 response::Response,
7};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10use std::sync::Arc;
11use tracing::{debug, warn};
12
13pub struct RbacManager {
15 config: RbacConfig,
16 role_permissions: HashMap<String, HashSet<String>>,
17 user_roles: HashMap<String, String>,
18 admin_users: HashSet<String>,
19}
20
/// Outcome of an authorization query made through `RbacManager`.
///
/// NOTE: this type derives `Serialize`/`Deserialize`, so field order here also fixes the
/// serialized field order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PermissionCheck {
    /// Whether access was granted.
    pub allowed: bool,
    /// Role the user resolved to for this check. It is "admin" for configured admins and
    /// "none" when RBAC is disabled (single-permission checks).
    pub user_role: String,
    /// Permission that was checked. Multi-permission checks use a ", "-joined list, or
    /// "none" when no permissions were requested.
    pub required_permission: String,
    /// Permissions the user was found to hold when the check was made.
    pub user_permissions: Vec<String>,
    /// Human-readable explanation of the decision.
    pub reason: String,
}
30
/// Declarative description of the permissions needed to perform `action` on `resource`.
///
/// NOTE(review): not referenced anywhere in this file. Presumably it is consumed by
/// configuration or routing code elsewhere, so confirm there before relying on any semantics.
/// Field order also fixes the serialized field order because of the serde derives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourcePermission {
    /// Resource name, e.g. "memory" (TODO confirm naming scheme with callers).
    pub resource: String,
    /// Action on the resource, e.g. "read" (TODO confirm).
    pub action: String,
    /// Permission strings required for the action.
    pub required_permissions: Vec<String>,
    /// Whether admins may bypass `required_permissions` (enforcement is not in this file).
    pub allow_admin_override: bool,
}
39
40impl RbacManager {
41 pub fn new(config: RbacConfig) -> Self {
42 let mut manager = Self {
43 role_permissions: HashMap::new(),
44 user_roles: HashMap::new(),
45 admin_users: HashSet::new(),
46 config: config.clone(),
47 };
48
49 if config.enabled {
50 manager.initialize_roles();
51 }
52
53 manager
54 }
55
56 fn initialize_roles(&mut self) {
57 for (role, permissions) in &self.config.roles {
59 let perm_set: HashSet<String> = permissions.iter().cloned().collect();
60 self.role_permissions.insert(role.clone(), perm_set);
61 }
62
63 for admin_user in &self.config.admin_users {
65 self.admin_users.insert(admin_user.clone());
66 }
67
68 debug!(
69 "Initialized RBAC with {} roles and {} admin users",
70 self.role_permissions.len(),
71 self.admin_users.len()
72 );
73 }
74
75 pub fn check_permission(&self, user_id: &str, permission: &str) -> PermissionCheck {
77 if !self.config.enabled {
78 return PermissionCheck {
79 allowed: true,
80 user_role: "none".to_string(),
81 required_permission: permission.to_string(),
82 user_permissions: vec!["all".to_string()],
83 reason: "RBAC disabled".to_string(),
84 };
85 }
86
87 if self.admin_users.contains(user_id) {
89 return PermissionCheck {
90 allowed: true,
91 user_role: "admin".to_string(),
92 required_permission: permission.to_string(),
93 user_permissions: vec!["admin".to_string()],
94 reason: "Admin override".to_string(),
95 };
96 }
97
98 let user_role = self.get_user_role(user_id);
100
101 let user_permissions = self.get_user_permissions(user_id);
103
104 let allowed = user_permissions.contains(&permission.to_string())
106 || user_permissions.contains(&"*".to_string());
107
108 let reason = if allowed {
109 "Permission granted".to_string()
110 } else {
111 format!("Missing required permission: {permission}")
112 };
113
114 PermissionCheck {
115 allowed,
116 user_role,
117 required_permission: permission.to_string(),
118 user_permissions,
119 reason,
120 }
121 }
122
123 pub fn check_permissions(&self, user_id: &str, permissions: &[&str]) -> PermissionCheck {
125 if permissions.is_empty() {
126 return PermissionCheck {
127 allowed: true,
128 user_role: self.get_user_role(user_id),
129 required_permission: "none".to_string(),
130 user_permissions: self.get_user_permissions(user_id),
131 reason: "No permissions required".to_string(),
132 };
133 }
134
135 for permission in permissions {
136 let check = self.check_permission(user_id, permission);
137 if !check.allowed {
138 return check;
139 }
140 }
141
142 PermissionCheck {
143 allowed: true,
144 user_role: self.get_user_role(user_id),
145 required_permission: permissions.join(", "),
146 user_permissions: self.get_user_permissions(user_id),
147 reason: "All permissions granted".to_string(),
148 }
149 }
150
151 pub fn check_resource_access(
153 &self,
154 user_id: &str,
155 resource: &str,
156 action: &str,
157 ) -> PermissionCheck {
158 let permission = format!("{resource}:{action}");
159 let check = self.check_permission(user_id, &permission);
160
161 if !check.allowed {
162 let resource_wildcard = format!("{resource}:*");
164 let wildcard_check = self.check_permission(user_id, &resource_wildcard);
165
166 if wildcard_check.allowed {
167 return PermissionCheck {
168 allowed: true,
169 user_role: wildcard_check.user_role,
170 required_permission: permission,
171 user_permissions: wildcard_check.user_permissions,
172 reason: "Wildcard permission granted".to_string(),
173 };
174 }
175 }
176
177 check
178 }
179
180 pub fn assign_role(&mut self, user_id: &str, role: &str) -> Result<()> {
182 if !self.config.enabled {
183 return Err(SecurityError::AuthorizationFailed {
184 message: "RBAC is disabled".to_string(),
185 });
186 }
187
188 if !self.role_permissions.contains_key(role) {
189 return Err(SecurityError::AuthorizationFailed {
190 message: format!("Role '{role}' does not exist"),
191 });
192 }
193
194 self.user_roles
195 .insert(user_id.to_string(), role.to_string());
196 debug!("Assigned role '{}' to user '{}'", role, user_id);
197 Ok(())
198 }
199
200 pub fn remove_user_role(&mut self, user_id: &str) -> Result<()> {
202 if !self.config.enabled {
203 return Err(SecurityError::AuthorizationFailed {
204 message: "RBAC is disabled".to_string(),
205 });
206 }
207
208 if self.user_roles.remove(user_id).is_some() {
209 debug!("Removed role from user '{}'", user_id);
210 Ok(())
211 } else {
212 Err(SecurityError::AuthorizationFailed {
213 message: format!("User '{user_id}' has no role assigned"),
214 })
215 }
216 }
217
218 pub fn add_permission_to_role(&mut self, role: &str, permission: &str) -> Result<()> {
220 if !self.config.enabled {
221 return Err(SecurityError::AuthorizationFailed {
222 message: "RBAC is disabled".to_string(),
223 });
224 }
225
226 let permissions = self.role_permissions.entry(role.to_string()).or_default();
227 permissions.insert(permission.to_string());
228 debug!("Added permission '{}' to role '{}'", permission, role);
229 Ok(())
230 }
231
232 pub fn remove_permission_from_role(&mut self, role: &str, permission: &str) -> Result<()> {
234 if !self.config.enabled {
235 return Err(SecurityError::AuthorizationFailed {
236 message: "RBAC is disabled".to_string(),
237 });
238 }
239
240 if let Some(permissions) = self.role_permissions.get_mut(role) {
241 if permissions.remove(permission) {
242 debug!("Removed permission '{}' from role '{}'", permission, role);
243 Ok(())
244 } else {
245 Err(SecurityError::AuthorizationFailed {
246 message: format!("Role '{role}' does not have permission '{permission}'"),
247 })
248 }
249 } else {
250 Err(SecurityError::AuthorizationFailed {
251 message: format!("Role '{role}' does not exist"),
252 })
253 }
254 }
255
256 pub fn create_role(&mut self, role: &str, permissions: Vec<String>) -> Result<()> {
258 if !self.config.enabled {
259 return Err(SecurityError::AuthorizationFailed {
260 message: "RBAC is disabled".to_string(),
261 });
262 }
263
264 if self.role_permissions.contains_key(role) {
265 return Err(SecurityError::AuthorizationFailed {
266 message: format!("Role '{role}' already exists"),
267 });
268 }
269
270 let perm_set: HashSet<String> = permissions.into_iter().collect();
271 self.role_permissions.insert(role.to_string(), perm_set);
272 debug!("Created new role '{}'", role);
273 Ok(())
274 }
275
276 pub fn delete_role(&mut self, role: &str) -> Result<()> {
278 if !self.config.enabled {
279 return Err(SecurityError::AuthorizationFailed {
280 message: "RBAC is disabled".to_string(),
281 });
282 }
283
284 if role == "admin" || role == "user" {
286 return Err(SecurityError::AuthorizationFailed {
287 message: "Cannot delete built-in roles".to_string(),
288 });
289 }
290
291 self.user_roles.retain(|_, user_role| user_role != role);
293
294 if self.role_permissions.remove(role).is_some() {
296 debug!("Deleted role '{}'", role);
297 Ok(())
298 } else {
299 Err(SecurityError::AuthorizationFailed {
300 message: format!("Role '{role}' does not exist"),
301 })
302 }
303 }
304
305 pub fn get_user_role(&self, user_id: &str) -> String {
307 if self.admin_users.contains(user_id) {
308 "admin".to_string()
309 } else {
310 self.user_roles
311 .get(user_id)
312 .cloned()
313 .unwrap_or_else(|| self.config.default_role.clone())
314 }
315 }
316
317 pub fn get_user_permissions(&self, user_id: &str) -> Vec<String> {
319 if !self.config.enabled {
320 return vec!["*".to_string()];
321 }
322
323 let user_role = self.get_user_role(user_id);
324
325 if let Some(permissions) = self.role_permissions.get(&user_role) {
326 permissions.iter().cloned().collect()
327 } else {
328 Vec::new()
329 }
330 }
331
332 pub fn get_roles(&self) -> HashMap<String, Vec<String>> {
334 self.role_permissions
335 .iter()
336 .map(|(role, permissions)| (role.clone(), permissions.iter().cloned().collect()))
337 .collect()
338 }
339
340 pub fn get_user_roles(&self) -> HashMap<String, String> {
342 let mut all_users = self.user_roles.clone();
343
344 for admin_user in &self.admin_users {
346 all_users.insert(admin_user.clone(), "admin".to_string());
347 }
348
349 all_users
350 }
351
352 pub fn is_admin(&self, user_id: &str) -> bool {
354 self.admin_users.contains(user_id)
355 }
356
357 pub fn add_admin(&mut self, user_id: &str) -> Result<()> {
359 if !self.config.enabled {
360 return Err(SecurityError::AuthorizationFailed {
361 message: "RBAC is disabled".to_string(),
362 });
363 }
364
365 self.admin_users.insert(user_id.to_string());
366 debug!("Added admin user: {}", user_id);
367 Ok(())
368 }
369
370 pub fn remove_admin(&mut self, user_id: &str) -> Result<()> {
372 if !self.config.enabled {
373 return Err(SecurityError::AuthorizationFailed {
374 message: "RBAC is disabled".to_string(),
375 });
376 }
377
378 if self.admin_users.remove(user_id) {
379 debug!("Removed admin user: {}", user_id);
380 Ok(())
381 } else {
382 Err(SecurityError::AuthorizationFailed {
383 message: format!("User '{user_id}' is not an admin"),
384 })
385 }
386 }
387
388 pub fn is_enabled(&self) -> bool {
389 self.config.enabled
390 }
391}
392
393pub fn require_permission(
395 permission: &'static str,
396) -> impl Fn(
397 Request,
398 Next,
399) -> std::pin::Pin<
400 Box<dyn std::future::Future<Output = std::result::Result<Response, StatusCode>> + Send>,
401> + Clone {
402 move |request: Request, next: Next| {
403 let required_permission = permission;
404 Box::pin(async move {
405 if let Some(claims) = request.extensions().get::<Claims>() {
407 if claims
409 .permissions
410 .contains(&required_permission.to_string())
411 || claims.permissions.contains(&"*".to_string())
412 || claims.role == "admin"
413 {
414 return Ok(next.run(request).await);
415 } else {
416 warn!(
417 "Access denied for user '{}': missing permission '{}'",
418 claims.sub, required_permission
419 );
420 return Err(StatusCode::FORBIDDEN);
421 }
422 }
423
424 warn!(
426 "Access denied: no authentication found for permission '{}'",
427 required_permission
428 );
429 Err(StatusCode::UNAUTHORIZED)
430 })
431 }
432}
433
434pub fn require_resource_access(
436 resource: &'static str,
437 action: &'static str,
438) -> impl Fn(
439 Request,
440 Next,
441) -> std::pin::Pin<
442 Box<dyn std::future::Future<Output = std::result::Result<Response, StatusCode>> + Send>,
443> + Clone {
444 move |request: Request, next: Next| {
445 let required_resource = resource;
446 let required_action = action;
447 Box::pin(async move {
448 if let Some(claims) = request.extensions().get::<Claims>() {
449 let permission = format!("{required_resource}:{required_action}");
450 let wildcard = format!("{required_resource}:*");
451
452 if claims.permissions.contains(&permission)
453 || claims.permissions.contains(&wildcard)
454 || claims.permissions.contains(&"*".to_string())
455 || claims.role == "admin"
456 {
457 return Ok(next.run(request).await);
458 } else {
459 warn!(
460 "Access denied for user '{}': cannot {} on {}",
461 claims.sub, required_action, required_resource
462 );
463 return Err(StatusCode::FORBIDDEN);
464 }
465 }
466
467 warn!(
468 "Access denied: no authentication found for {}:{}",
469 required_resource, required_action
470 );
471 Err(StatusCode::UNAUTHORIZED)
472 })
473 }
474}
475
476pub async fn rbac_middleware(
478 State(rbac): State<Arc<RbacManager>>,
479 request: Request,
480 next: Next,
481) -> std::result::Result<Response, StatusCode> {
482 if !rbac.is_enabled() {
483 return Ok(next.run(request).await);
484 }
485
486 if request.extensions().get::<Claims>().is_some() {
489 Ok(next.run(request).await)
490 } else {
491 warn!("Access denied: no authentication found");
492 Err(StatusCode::UNAUTHORIZED)
493 }
494}
495
#[cfg(test)]
mod tests {
    use super::*;

    /// Enabled manager: default role "user" (memory:read), role "admin"
    /// (memory:read/write/delete), and one configured admin `admin@example.com`.
    fn create_test_rbac() -> RbacManager {
        let mut roles = HashMap::new();
        roles.insert("user".to_string(), vec!["memory:read".to_string()]);
        roles.insert(
            "admin".to_string(),
            vec![
                "memory:read".to_string(),
                "memory:write".to_string(),
                "memory:delete".to_string(),
            ],
        );

        let config = RbacConfig {
            enabled: true,
            default_role: "user".to_string(),
            roles,
            admin_users: vec!["admin@example.com".to_string()],
        };

        RbacManager::new(config)
    }

    /// Manager with RBAC switched off and nothing configured.
    fn create_disabled_rbac() -> RbacManager {
        RbacManager::new(RbacConfig {
            enabled: false,
            default_role: "user".to_string(),
            roles: HashMap::new(),
            admin_users: Vec::new(),
        })
    }

    #[test]
    fn test_rbac_manager_creation() {
        let rbac = create_test_rbac();
        assert!(rbac.is_enabled());
        assert_eq!(rbac.role_permissions.len(), 2);
        assert_eq!(rbac.admin_users.len(), 1);
    }

    #[test]
    fn test_permission_check_admin() {
        let rbac = create_test_rbac();

        let check = rbac.check_permission("admin@example.com", "memory:delete");
        assert!(check.allowed);
        assert_eq!(check.user_role, "admin");
        assert!(check.reason.contains("Admin override"));
    }

    #[test]
    fn test_permission_check_user() {
        let mut rbac = create_test_rbac();
        rbac.assign_role("user@example.com", "user").unwrap();

        let read_check = rbac.check_permission("user@example.com", "memory:read");
        assert!(read_check.allowed);
        assert_eq!(read_check.user_role, "user");

        let write_check = rbac.check_permission("user@example.com", "memory:write");
        assert!(!write_check.allowed);
        assert!(write_check.reason.contains("Missing required permission"));
    }

    #[test]
    fn test_resource_access_check() {
        let mut rbac = create_test_rbac();
        rbac.assign_role("user@example.com", "user").unwrap();

        let check = rbac.check_resource_access("user@example.com", "memory", "read");
        assert!(check.allowed);

        let check = rbac.check_resource_access("user@example.com", "memory", "write");
        assert!(!check.allowed);
    }

    #[test]
    fn test_multiple_permissions_check() {
        let mut rbac = create_test_rbac();
        rbac.assign_role("admin@example.com", "admin").unwrap();

        let check = rbac.check_permissions("admin@example.com", &["memory:read", "memory:write"]);
        assert!(check.allowed);

        rbac.assign_role("user@example.com", "user").unwrap();

        let check = rbac.check_permissions("user@example.com", &["memory:read", "memory:write"]);
        assert!(!check.allowed);
    }

    #[test]
    fn test_check_permissions_edge_cases() {
        let mut rbac = create_test_rbac();
        rbac.assign_role("user@example.com", "user").unwrap();

        // An empty requirement list is trivially satisfied.
        let check = rbac.check_permissions("user@example.com", &[]);
        assert!(check.allowed);
        assert_eq!(check.required_permission, "none");

        // The first missing permission is the one reported.
        let check = rbac.check_permissions(
            "user@example.com",
            &["memory:read", "memory:write", "memory:delete"],
        );
        assert!(!check.allowed);
        assert_eq!(check.required_permission, "memory:write");
    }

    #[test]
    fn test_role_assignment() {
        let mut rbac = create_test_rbac();

        let result = rbac.assign_role("test@example.com", "user");
        assert!(result.is_ok());

        assert_eq!(rbac.get_user_role("test@example.com"), "user");

        let result = rbac.assign_role("test@example.com", "nonexistent");
        assert!(result.is_err());
    }

    #[test]
    fn test_remove_user_role_falls_back_to_default() {
        let mut rbac = create_test_rbac();
        rbac.assign_role("user@example.com", "admin").unwrap();
        assert_eq!(rbac.get_user_role("user@example.com"), "admin");

        rbac.remove_user_role("user@example.com").unwrap();
        assert_eq!(rbac.get_user_role("user@example.com"), "user");

        // Nothing left to remove.
        assert!(rbac.remove_user_role("user@example.com").is_err());
    }

    #[test]
    fn test_role_creation_and_deletion() {
        let mut rbac = create_test_rbac();

        let result = rbac.create_role(
            "moderator",
            vec!["memory:read".to_string(), "memory:moderate".to_string()],
        );
        assert!(result.is_ok());

        assert!(rbac.role_permissions.contains_key("moderator"));

        let result = rbac.delete_role("moderator");
        assert!(result.is_ok());

        assert!(!rbac.role_permissions.contains_key("moderator"));

        let result = rbac.delete_role("admin");
        assert!(result.is_err());
    }

    #[test]
    fn test_delete_role_unassigns_users() {
        let mut rbac = create_test_rbac();
        rbac.create_role("moderator", vec!["memory:moderate".to_string()])
            .unwrap();
        rbac.assign_role("mod@example.com", "moderator").unwrap();
        assert!(rbac.check_permission("mod@example.com", "memory:moderate").allowed);

        rbac.delete_role("moderator").unwrap();
        assert_eq!(rbac.get_user_role("mod@example.com"), "user");
        assert!(!rbac.check_permission("mod@example.com", "memory:moderate").allowed);

        // Deleting twice reports a missing role.
        assert!(rbac.delete_role("moderator").is_err());
    }

    #[test]
    fn test_role_and_permission_errors() {
        let mut rbac = create_test_rbac();

        assert!(rbac.create_role("user", Vec::new()).is_err());
        assert!(rbac
            .remove_permission_from_role("user", "memory:write")
            .is_err());
        assert!(rbac
            .remove_permission_from_role("ghost", "memory:read")
            .is_err());
    }

    #[test]
    fn test_permission_management() {
        let mut rbac = create_test_rbac();

        let result = rbac.add_permission_to_role("user", "memory:moderate");
        assert!(result.is_ok());

        let permissions = rbac.role_permissions.get("user").unwrap();
        assert!(permissions.contains("memory:moderate"));

        let result = rbac.remove_permission_from_role("user", "memory:moderate");
        assert!(result.is_ok());

        let permissions = rbac.role_permissions.get("user").unwrap();
        assert!(!permissions.contains("memory:moderate"));
    }

    #[test]
    fn test_admin_management() {
        let mut rbac = create_test_rbac();

        assert!(rbac.is_admin("admin@example.com"));

        let result = rbac.add_admin("newadmin@example.com");
        assert!(result.is_ok());
        assert!(rbac.is_admin("newadmin@example.com"));

        let result = rbac.remove_admin("newadmin@example.com");
        assert!(result.is_ok());
        assert!(!rbac.is_admin("newadmin@example.com"));

        let result = rbac.remove_admin("notadmin@example.com");
        assert!(result.is_err());
    }

    #[test]
    fn test_default_role() {
        let rbac = create_test_rbac();

        let role = rbac.get_user_role("unknown@example.com");
        assert_eq!(role, "user");
    }

    #[test]
    fn test_disabled_rbac() {
        let rbac = create_disabled_rbac();
        assert!(!rbac.is_enabled());

        let check = rbac.check_permission("anyone", "anything");
        assert!(check.allowed);
        assert!(check.reason.contains("RBAC disabled"));
    }

    #[test]
    fn test_disabled_rbac_rejects_mutations() {
        let mut rbac = create_disabled_rbac();

        assert!(rbac.assign_role("user@example.com", "user").is_err());
        assert!(rbac.remove_user_role("user@example.com").is_err());
        assert!(rbac.create_role("moderator", Vec::new()).is_err());
        assert!(rbac.delete_role("moderator").is_err());
        assert!(rbac.add_permission_to_role("user", "memory:read").is_err());
        assert!(rbac
            .remove_permission_from_role("user", "memory:read")
            .is_err());
        assert!(rbac.add_admin("admin@example.com").is_err());
        assert!(rbac.remove_admin("admin@example.com").is_err());
    }

    #[test]
    fn test_wildcard_permissions() {
        let mut rbac = create_test_rbac();

        rbac.add_permission_to_role("user", "memory:*").unwrap();
        rbac.assign_role("user@example.com", "user").unwrap();

        let check = rbac.check_resource_access("user@example.com", "memory", "write");
        assert!(check.allowed);
        assert!(check.reason.contains("Wildcard permission"));

        let check = rbac.check_resource_access("user@example.com", "memory", "delete");
        assert!(check.allowed);
    }

    #[test]
    fn test_global_wildcard_permission() {
        let mut rbac = create_test_rbac();

        rbac.add_permission_to_role("user", "*").unwrap();
        rbac.assign_role("user@example.com", "user").unwrap();

        assert!(rbac.check_permission("user@example.com", "anything:at_all").allowed);
        assert!(
            rbac.check_resource_access("user@example.com", "memory", "delete")
                .allowed
        );
    }

    #[test]
    fn test_get_roles_and_users() {
        let mut rbac = create_test_rbac();
        rbac.assign_role("user1@example.com", "user").unwrap();
        rbac.assign_role("user2@example.com", "admin").unwrap();

        let roles = rbac.get_roles();
        assert_eq!(roles.len(), 2);
        assert!(roles.contains_key("user"));
        assert!(roles.contains_key("admin"));

        // Two explicit assignments plus the configured admin.
        let user_roles = rbac.get_user_roles();
        assert_eq!(user_roles.len(), 3);
        assert_eq!(
            user_roles.get("admin@example.com"),
            Some(&"admin".to_string())
        );
    }
}