1use parking_lot::RwLock;
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8
9#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
11pub struct Role {
12 pub name: String,
13 pub permissions: HashSet<Permission>,
14}
15
16#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
18pub struct Permission {
19 pub resource: String,
20 pub action: String,
21}
22
23impl Permission {
24 pub fn new(resource: impl Into<String>, action: impl Into<String>) -> Self {
25 Self {
26 resource: resource.into(),
27 action: action.into(),
28 }
29 }
30}
31
32#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct User {
35 pub id: String,
36 pub roles: HashSet<String>,
37}
38
39pub struct AccessController {
41 roles: RwLock<HashMap<String, Role>>,
43 user_roles: RwLock<HashMap<String, HashSet<String>>>,
45}
46
47impl AccessController {
48 pub fn new() -> Self {
49 Self {
50 roles: RwLock::new(Self::default_roles()),
51 user_roles: RwLock::new(HashMap::new()),
52 }
53 }
54
55 fn default_roles() -> HashMap<String, Role> {
57 let mut roles = HashMap::new();
58
59 let admin_perms: HashSet<Permission> = [Permission::new("*", "*")].into_iter().collect();
61 roles.insert(
62 "admin".to_string(),
63 Role {
64 name: "admin".to_string(),
65 permissions: admin_perms,
66 },
67 );
68
69 let user_perms: HashSet<Permission> = [
71 Permission::new("session", "read"),
72 Permission::new("session", "write"),
73 Permission::new("tool", "execute"),
74 Permission::new("agent", "run"),
75 ]
76 .into_iter()
77 .collect();
78 roles.insert(
79 "user".to_string(),
80 Role {
81 name: "user".to_string(),
82 permissions: user_perms,
83 },
84 );
85
86 let guest_perms: HashSet<Permission> =
88 [Permission::new("session", "read")].into_iter().collect();
89 roles.insert(
90 "guest".to_string(),
91 Role {
92 name: "guest".to_string(),
93 permissions: guest_perms,
94 },
95 );
96
97 roles
98 }
99
100 pub fn check(&self, user_id: &str, resource: &str, action: &str) -> bool {
102 let user_roles = self.user_roles.read();
103
104 let roles = user_roles.get(user_id).cloned().unwrap_or_else(|| {
106 HashSet::from(["guest".to_string()])
108 });
109
110 let role_map = self.roles.read();
111
112 for role_name in roles {
114 if let Some(role) = role_map.get(&role_name) {
115 for perm in &role.permissions {
116 if (perm.resource == "*" || perm.resource == resource)
118 && (perm.action == "*" || perm.action == action)
119 {
120 return true;
121 }
122 }
123 }
124 }
125
126 false
127 }
128
129 pub fn add_role(&self, user_id: &str, role_name: &str) {
131 let mut user_roles = self.user_roles.write();
132 user_roles
133 .entry(user_id.to_string())
134 .or_default()
135 .insert(role_name.to_string());
136 }
137
138 pub fn remove_role(&self, user_id: &str, role_name: &str) {
140 let mut user_roles = self.user_roles.write();
141 if let Some(roles) = user_roles.get_mut(user_id) {
142 roles.remove(role_name);
143 if roles.is_empty() {
145 user_roles.remove(user_id);
146 }
147 }
148 }
149
150 pub fn create_role(&self, role: Role) {
152 let mut roles = self.roles.write();
153 roles.insert(role.name.clone(), role);
154 }
155
156 pub fn get_permissions(&self, user_id: &str) -> HashSet<Permission> {
158 let user_roles = self.user_roles.read();
159 let roles = user_roles.get(user_id).cloned().unwrap_or_default();
160 let role_map = self.roles.read();
161
162 let mut permissions = HashSet::new();
163 for role_name in roles {
164 if let Some(role) = role_map.get(&role_name) {
165 permissions.extend(role.permissions.clone());
166 }
167 }
168 permissions
169 }
170}
171
172impl Default for AccessController {
173 fn default() -> Self {
174 Self::new()
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_roles() {
        let controller = AccessController::new();

        controller.add_role("admin_user", "admin");
        assert!(controller.check("admin_user", "any_resource", "any_action"));

        controller.add_role("normal_user", "user");
        assert!(controller.check("normal_user", "tool", "execute"));
        assert!(!controller.check("normal_user", "admin", "create"));
    }

    #[test]
    fn test_guest_default() {
        let controller = AccessController::new();

        // Users with no assignment are evaluated as `guest`.
        assert!(controller.check("unknown_user", "session", "read"));
        assert!(!controller.check("unknown_user", "session", "write"));
    }

    #[test]
    fn test_remove_role() {
        let controller = AccessController::new();

        controller.add_role("test_user", "admin");
        assert!(controller.check("test_user", "any", "any"));

        controller.remove_role("test_user", "admin");
        assert!(!controller.check("test_user", "any", "any"));
        // With no roles left, the user falls back to `guest`.
        assert!(controller.check("test_user", "session", "read"));
    }

    #[test]
    fn test_create_custom_role() {
        let controller = AccessController::new();

        let custom_role = Role {
            name: "custom".to_string(),
            permissions: HashSet::from([
                Permission::new("custom_resource", "read"),
                Permission::new("custom_resource", "write"),
            ]),
        };
        controller.create_role(custom_role);

        controller.add_role("custom_user", "custom");
        assert!(controller.check("custom_user", "custom_resource", "read"));
        assert!(controller.check("custom_user", "custom_resource", "write"));
        assert!(!controller.check("custom_user", "other_resource", "read"));
    }

    #[test]
    fn test_get_permissions() {
        let controller = AccessController::new();

        controller.add_role("multi_user", "user");
        controller.add_role("multi_user", "guest");

        let permissions = controller.get_permissions("multi_user");
        // `session:read` is held by both roles but appears once in the union.
        assert!(permissions.contains(&Permission::new("session", "read")));
        assert!(permissions.contains(&Permission::new("session", "write")));
        assert!(permissions.contains(&Permission::new("tool", "execute")));
    }

    #[test]
    fn test_permission_new() {
        let perm = Permission::new("resource", "action");
        assert_eq!(perm.resource, "resource");
        assert_eq!(perm.action, "action");
    }

    #[test]
    fn test_multiple_roles_same_user() {
        let controller = AccessController::new();

        controller.add_role("power_user", "user");
        controller.add_role("power_user", "admin");

        // The admin wildcard wins even though `user` alone would deny this.
        assert!(controller.check("power_user", "super_secret", "delete"));
    }

    #[test]
    fn test_role_serialization() {
        let role = Role {
            name: "test".to_string(),
            permissions: HashSet::from([Permission::new("r", "a")]),
        };
        let json = serde_json::to_string(&role).unwrap();
        assert!(json.contains("test"));

        // Deserializing must reproduce an identical role.
        let decoded: Role = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, role);
    }

    #[test]
    fn test_permission_hash_equality() {
        let p1 = Permission::new("resource", "action");
        let p2 = Permission::new("resource", "action");
        let set: HashSet<Permission> = HashSet::from([p1, p2]);
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;
        use std::thread;

        let controller = Arc::new(AccessController::new());
        controller.add_role("user1", "admin");

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let c = Arc::clone(&controller);
                thread::spawn(move || {
                    let user = format!("user{}", i);
                    c.check(&user, "session", "read")
                })
            })
            .collect();

        for handle in handles {
            // user1 is admin; every other user is a guest, and both may read sessions.
            assert!(handle.join().unwrap());
        }
    }
}