1use super::error::AuthorizationError;
6use super::response::AuthResponse;
7use crate::auth::Authenticatable;
8use std::any::{Any, TypeId};
9use std::collections::HashMap;
10use std::sync::{Arc, RwLock};
11
12type AbilityCallback =
14 Box<dyn Fn(&dyn Authenticatable, Option<&dyn Any>) -> AuthResponse + Send + Sync>;
15
16type BeforeCallback = Box<dyn Fn(&dyn Authenticatable, &str) -> Option<bool> + Send + Sync>;
18
19static GATE_REGISTRY: RwLock<Option<GateRegistry>> = RwLock::new(None);
21
22struct GateRegistry {
24 abilities: HashMap<String, AbilityCallback>,
26 before_hooks: Vec<BeforeCallback>,
28 policies: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
30}
31
32impl GateRegistry {
33 fn new() -> Self {
34 Self {
35 abilities: HashMap::new(),
36 before_hooks: Vec::new(),
37 policies: HashMap::new(),
38 }
39 }
40}
41
/// Entry point for authorization checks.
///
/// `Gate` carries no data of its own; every operation is an associated
/// function that works on a process-wide registry of ability callbacks and
/// before hooks.
pub struct Gate;
63
64impl Gate {
65 pub fn init() {
69 let mut registry = GATE_REGISTRY.write().unwrap();
70 if registry.is_none() {
71 *registry = Some(GateRegistry::new());
72 }
73 }
74
75 pub fn define<F>(ability: &str, callback: F)
87 where
88 F: Fn(&dyn Authenticatable, Option<&dyn Any>) -> AuthResponse + Send + Sync + 'static,
89 {
90 Self::init();
91 let mut registry = GATE_REGISTRY.write().unwrap();
92 if let Some(ref mut reg) = *registry {
93 reg.abilities
94 .insert(ability.to_string(), Box::new(callback));
95 }
96 }
97
98 pub fn before<F>(callback: F)
117 where
118 F: Fn(&dyn Authenticatable, &str) -> Option<bool> + Send + Sync + 'static,
119 {
120 Self::init();
121 let mut registry = GATE_REGISTRY.write().unwrap();
122 if let Some(ref mut reg) = *registry {
123 reg.before_hooks.push(Box::new(callback));
124 }
125 }
126
127 pub fn allows(ability: &str, resource: Option<&dyn Any>) -> bool {
131 crate::auth::Auth::id().is_some() && Self::allows_for_user_id(ability, resource)
132 }
133
134 pub fn denies(ability: &str, resource: Option<&dyn Any>) -> bool {
136 !Self::allows(ability, resource)
137 }
138
139 pub fn authorize(ability: &str, resource: Option<&dyn Any>) -> Result<(), AuthorizationError> {
152 if crate::auth::Auth::id().is_none() {
153 return Err(AuthorizationError::new(ability).with_status(401));
154 }
155
156 if Self::allows_for_user_id(ability, resource) {
157 Ok(())
158 } else {
159 Err(AuthorizationError::new(ability))
160 }
161 }
162
163 pub fn allows_for<U: Authenticatable>(
165 user: &U,
166 ability: &str,
167 resource: Option<&dyn Any>,
168 ) -> bool {
169 Self::inspect(user, ability, resource).allowed()
170 }
171
172 pub fn authorize_for<U: Authenticatable>(
174 user: &U,
175 ability: &str,
176 resource: Option<&dyn Any>,
177 ) -> Result<(), AuthorizationError> {
178 let response = Self::inspect(user, ability, resource);
179 if response.allowed() {
180 Ok(())
181 } else {
182 let mut error = AuthorizationError::new(ability);
183 if let Some(msg) = response.message() {
184 error.message = Some(msg.to_string());
185 }
186 error.status = response.status();
187 Err(error)
188 }
189 }
190
191 pub fn check_for<U: Authenticatable>(
193 user: &U,
194 ability: &str,
195 resource: Option<&dyn Any>,
196 ) -> AuthResponse {
197 Self::inspect(user, ability, resource)
198 }
199
200 pub fn inspect(
204 user: &dyn Authenticatable,
205 ability: &str,
206 resource: Option<&dyn Any>,
207 ) -> AuthResponse {
208 let registry = GATE_REGISTRY.read().unwrap();
209 let reg = match &*registry {
210 Some(r) => r,
211 None => return AuthResponse::deny_silent(),
212 };
213
214 for hook in ®.before_hooks {
216 if let Some(result) = hook(user, ability) {
217 return result.into();
218 }
219 }
220
221 if let Some(callback) = reg.abilities.get(ability) {
223 return callback(user, resource);
224 }
225
226 AuthResponse::deny_silent()
228 }
229
/// Fail-closed placeholder behind [`Gate::allows`] and [`Gate::authorize`].
///
/// Ability callbacks and before hooks need the user object, which is not
/// available in this synchronous path, so nothing can be evaluated here and
/// every request is denied. The previous implementation took the registry
/// lock and looked the ability up, but returned `false` on every path; that
/// dead lookup is removed so the unconditional denial is explicit.
///
/// Use [`Gate::allows_for`] / [`Gate::authorize_for`] (explicit user) or
/// [`Gate::user_allows`] / [`Gate::user_authorize`] (async lookup) when the
/// callbacks must actually run.
fn allows_for_user_id(_ability: &str, _resource: Option<&dyn Any>) -> bool {
    false
}
249
250 pub fn has_policy_for<M: 'static>() -> bool {
252 let registry = GATE_REGISTRY.read().unwrap();
253 registry
254 .as_ref()
255 .map(|r| r.policies.contains_key(&TypeId::of::<M>()))
256 .unwrap_or(false)
257 }
258
/// Test helper: replaces the global registry with a fresh, empty one so each
/// test starts without abilities, hooks or policies from earlier tests.
#[cfg(test)]
pub fn flush() {
    let mut slot = GATE_REGISTRY.write().unwrap();
    let fresh = GateRegistry::new();
    *slot = Some(fresh);
}
265
/// Test helper: serialises tests that touch the global registry.
///
/// Hold the returned guard for the duration of a test so that concurrently
/// running tests do not interleave their registry mutations.
///
/// A test that fails (panics) while holding the guard poisons the mutex. The
/// mutex protects no data (`()`), so the poison flag carries no information;
/// the guard is recovered from the error instead of unwrapping, which keeps
/// one failing test from cascading into failures of every later test.
#[cfg(test)]
pub fn test_lock() -> std::sync::MutexGuard<'static, ()> {
    use std::sync::Mutex;
    static TEST_LOCK: Mutex<()> = Mutex::new(());
    TEST_LOCK
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
}
273}
274
275impl Gate {
279 pub async fn user_allows(ability: &str, resource: Option<&dyn Any>) -> bool {
283 match Self::resolve_user_and_check(ability, resource).await {
284 Ok(response) => response.allowed(),
285 Err(_) => false,
286 }
287 }
288
289 pub async fn user_authorize(
293 ability: &str,
294 resource: Option<&dyn Any>,
295 ) -> Result<(), AuthorizationError> {
296 let response = Self::resolve_user_and_check(ability, resource).await?;
297 if response.allowed() {
298 Ok(())
299 } else {
300 let mut error = AuthorizationError::new(ability);
301 if let Some(msg) = response.message() {
302 error.message = Some(msg.to_string());
303 }
304 error.status = response.status();
305 Err(error)
306 }
307 }
308
309 async fn resolve_user_and_check(
311 ability: &str,
312 resource: Option<&dyn Any>,
313 ) -> Result<AuthResponse, AuthorizationError> {
314 let user = crate::auth::Auth::user()
315 .await
316 .map_err(|_| AuthorizationError::new(ability).with_status(401))?
317 .ok_or_else(|| AuthorizationError::new(ability).with_status(401))?;
318
319 Ok(Self::inspect(user.as_ref(), ability, resource))
320 }
321}
322
323#[cfg(test)]
324mod tests {
325 use super::*;
326 use std::any::Any;
327
/// Minimal user fixture for the gate tests.
#[derive(Debug, Clone)]
struct TestUser {
    /// Value reported by `auth_identifier`.
    id: i64,
    /// Flag the test callbacks branch on.
    is_admin: bool,
}
333
334 impl Authenticatable for TestUser {
335 fn auth_identifier(&self) -> i64 {
336 self.id
337 }
338
339 fn as_any(&self) -> &dyn Any {
340 self
341 }
342 }
343
/// An ability registered with `Gate::define` is evaluated against the
/// concrete user: admins pass, everyone else is denied.
#[test]
fn test_define_and_check() {
    // Serialise with the other tests; they all mutate the same global registry.
    let _guard = Gate::test_lock();
    Gate::flush();

    Gate::define("test-ability", |user, _| {
        user.as_any()
            .downcast_ref::<TestUser>()
            .map(|u| u.is_admin.into())
            .unwrap_or_else(AuthResponse::deny_silent)
    });

    let admin = TestUser {
        id: 1,
        is_admin: true,
    };
    let member = TestUser {
        id: 2,
        is_admin: false,
    };

    assert!(Gate::allows_for(&admin, "test-ability", None));
    assert!(!Gate::allows_for(&member, "test-ability", None));
}
368
/// A before hook that returns `Some(true)` overrides an ability callback that
/// always denies; users the hook abstains on still get the callback's denial.
#[test]
fn test_before_hook() {
    // Serialise with the other tests; they all mutate the same global registry.
    let _guard = Gate::test_lock();
    Gate::flush();

    // Admins are waved through; for everyone else the hook abstains.
    Gate::before(|user, _| match user.as_any().downcast_ref::<TestUser>() {
        Some(u) if u.is_admin => Some(true),
        _ => None,
    });

    Gate::define("restricted", |_, _| AuthResponse::deny("Always denied"));

    let admin = TestUser {
        id: 1,
        is_admin: true,
    };
    let member = TestUser {
        id: 2,
        is_admin: false,
    };

    // The hook short-circuits for the admin.
    assert!(Gate::allows_for(&admin, "restricted", None));
    // The hook abstains, so the always-deny callback decides.
    assert!(!Gate::allows_for(&member, "restricted", None));
}
400
/// `Gate::authorize_for` returns `Ok` for granted abilities and propagates the
/// denial message of the ability callback into the error.
#[test]
fn test_authorize_for() {
    // Serialise with the other tests; they all mutate the same global registry.
    let _guard = Gate::test_lock();
    Gate::flush();

    Gate::define("view-posts", |_, _| AuthResponse::allow());
    Gate::define("admin-only", |user, _| {
        user.as_any()
            .downcast_ref::<TestUser>()
            .map(|u| {
                if u.is_admin {
                    AuthResponse::allow()
                } else {
                    AuthResponse::deny("Admin access required")
                }
            })
            .unwrap_or_else(AuthResponse::deny_silent)
    });

    let admin = TestUser {
        id: 1,
        is_admin: true,
    };
    let member = TestUser {
        id: 2,
        is_admin: false,
    };

    // Open ability: everyone passes.
    assert!(Gate::authorize_for(&admin, "view-posts", None).is_ok());
    assert!(Gate::authorize_for(&member, "view-posts", None).is_ok());
    assert!(Gate::authorize_for(&admin, "admin-only", None).is_ok());

    // Restricted ability: the callback's message ends up on the error.
    let err = Gate::authorize_for(&member, "admin-only", None).unwrap_err();
    assert_eq!(err.message, Some("Admin access required".to_string()));
}
438
/// An ability nobody registered is denied.
#[test]
fn test_undefined_ability() {
    // Serialise with the other tests; they all mutate the same global registry.
    let _guard = Gate::test_lock();
    Gate::flush();

    let nobody = TestUser {
        id: 1,
        is_admin: false,
    };

    let granted = Gate::allows_for(&nobody, "undefined-ability", None);
    assert!(!granted);
}
452}