// aspect_std/authorization.rs
use std::collections::HashSet;
use std::fmt;
use std::sync::Arc;

use aspect_core::{Aspect, JoinPoint};

7#[derive(Clone)]
30pub struct AuthorizationAspect {
31 required_roles: Arc<HashSet<String>>,
32 role_provider: Arc<dyn Fn() -> HashSet<String> + Send + Sync>,
33 mode: AuthMode,
34}
35
/// How an `AuthorizationAspect` compares its required roles with the
/// caller's current roles.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AuthMode {
    /// Every required role must be present among the caller's roles.
    RequireAll,
    /// At least one required role must be present among the caller's roles.
    RequireAny,
}

45impl AuthorizationAspect {
46 pub fn require_role<F>(role: &str, role_provider: F) -> Self
59 where
60 F: Fn() -> HashSet<String> + Send + Sync + 'static,
61 {
62 let mut roles = HashSet::new();
63 roles.insert(role.to_string());
64
65 Self {
66 required_roles: Arc::new(roles),
67 role_provider: Arc::new(role_provider),
68 mode: AuthMode::RequireAll,
69 }
70 }
71
72 pub fn require_roles<F>(roles: &[&str], role_provider: F, mode: AuthMode) -> Self
88 where
89 F: Fn() -> HashSet<String> + Send + Sync + 'static,
90 {
91 let role_set: HashSet<String> = roles.iter().map(|r| r.to_string()).collect();
92
93 Self {
94 required_roles: Arc::new(role_set),
95 role_provider: Arc::new(role_provider),
96 mode,
97 }
98 }
99
100 fn check_authorization(&self) -> Result<(), String> {
102 let current_roles = (self.role_provider)();
103
104 let authorized = match self.mode {
105 AuthMode::RequireAll => {
106 self.required_roles.iter().all(|r| current_roles.contains(r))
108 }
109 AuthMode::RequireAny => {
110 self.required_roles.iter().any(|r| current_roles.contains(r))
112 }
113 };
114
115 if authorized {
116 Ok(())
117 } else {
118 let required: Vec<_> = self.required_roles.iter().cloned().collect();
119 let mode_str = match self.mode {
120 AuthMode::RequireAll => "all",
121 AuthMode::RequireAny => "any",
122 };
123 Err(format!(
124 "Access denied: requires {} of roles {:?}",
125 mode_str, required
126 ))
127 }
128 }
129}
130
131impl Aspect for AuthorizationAspect {
132 fn before(&self, ctx: &JoinPoint) {
133 if let Err(msg) = self.check_authorization() {
134 panic!("Authorization failed for {}: {}", ctx.function_name, msg);
135 }
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned role set from string literals, the shape a real role
    /// provider returns. Takes a slice to mirror `require_roles`' parameter.
    fn mock_roles(roles: &[&str]) -> HashSet<String> {
        roles.iter().map(|&s| s.to_owned()).collect()
    }

    #[test]
    fn test_require_role_success() {
        let auth = AuthorizationAspect::require_role("admin", || mock_roles(&["admin"]));

        assert!(auth.check_authorization().is_ok());
    }

    #[test]
    fn test_require_role_failure() {
        let auth = AuthorizationAspect::require_role("admin", || mock_roles(&["user"]));

        assert!(auth.check_authorization().is_err());
    }

    #[test]
    fn test_require_all_success() {
        let auth = AuthorizationAspect::require_roles(
            &["admin", "moderator"],
            || mock_roles(&["admin", "moderator", "user"]),
            AuthMode::RequireAll,
        );

        assert!(auth.check_authorization().is_ok());
    }

    #[test]
    fn test_require_all_failure() {
        let auth = AuthorizationAspect::require_roles(
            &["admin", "moderator"],
            || mock_roles(&["admin"]),
            AuthMode::RequireAll,
        );

        assert!(auth.check_authorization().is_err());
    }

    #[test]
    fn test_require_any_success() {
        let auth = AuthorizationAspect::require_roles(
            &["admin", "moderator"],
            || mock_roles(&["moderator"]),
            AuthMode::RequireAny,
        );

        assert!(auth.check_authorization().is_ok());
    }

    #[test]
    fn test_require_any_failure() {
        let auth = AuthorizationAspect::require_roles(
            &["admin", "moderator"],
            || mock_roles(&["user"]),
            AuthMode::RequireAny,
        );

        assert!(auth.check_authorization().is_err());
    }

    #[test]
    fn test_empty_roles() {
        let auth = AuthorizationAspect::require_role("admin", || mock_roles(&[]));

        assert!(auth.check_authorization().is_err());
    }

    #[test]
    fn test_multiple_roles_user() {
        let auth = AuthorizationAspect::require_role("admin", || {
            mock_roles(&["user", "moderator", "admin"])
        });

        assert!(auth.check_authorization().is_ok());
    }

    #[test]
    fn test_denial_message_format() {
        // Single-role requirements keep the expected text independent of set order.
        let all = AuthorizationAspect::require_role("admin", || mock_roles(&["user"]));
        assert_eq!(
            all.check_authorization().unwrap_err(),
            "Access denied: requires all of roles [\"admin\"]"
        );

        let any = AuthorizationAspect::require_roles(
            &["admin"],
            || mock_roles(&[]),
            AuthMode::RequireAny,
        );
        assert_eq!(
            any.check_authorization().unwrap_err(),
            "Access denied: requires any of roles [\"admin\"]"
        );
    }
}