astrid_capabilities/
validator.rs1use astrid_core::Permission;
6use astrid_crypto::PublicKey;
7
8use crate::error::{CapabilityError, CapabilityResult};
9use crate::store::CapabilityStore;
10use crate::token::CapabilityToken;
11
12#[derive(Debug, Clone)]
14pub enum AuthorizationResult {
15 Authorized {
17 token: Box<CapabilityToken>,
19 },
20 RequiresApproval {
22 resource: String,
24 permission: Permission,
26 },
27}
28
29impl AuthorizationResult {
30 #[must_use]
32 pub fn is_authorized(&self) -> bool {
33 matches!(self, Self::Authorized { .. })
34 }
35
36 #[must_use]
38 pub fn token(&self) -> Option<&CapabilityToken> {
39 match self {
40 Self::Authorized { token } => Some(token),
41 Self::RequiresApproval { .. } => None,
42 }
43 }
44}
45
46pub struct CapabilityValidator<'a> {
48 store: &'a CapabilityStore,
49 trusted_issuers: Vec<PublicKey>,
50}
51
52impl<'a> CapabilityValidator<'a> {
53 #[must_use]
55 pub fn new(store: &'a CapabilityStore) -> Self {
56 Self {
57 store,
58 trusted_issuers: Vec::new(),
59 }
60 }
61
62 #[must_use]
64 pub fn trust_issuer(mut self, issuer: PublicKey) -> Self {
65 self.trusted_issuers.push(issuer);
66 self
67 }
68
69 #[must_use]
71 pub fn check(&self, resource: &str, permission: Permission) -> AuthorizationResult {
72 if let Some(token) = self.store.find_capability(resource, permission) {
73 if self.validate_token(&token).is_ok() {
75 return AuthorizationResult::Authorized {
76 token: Box::new(token),
77 };
78 }
79 }
80
81 AuthorizationResult::RequiresApproval {
82 resource: resource.to_string(),
83 permission,
84 }
85 }
86
87 pub fn validate_token(&self, token: &CapabilityToken) -> CapabilityResult<()> {
94 if token.is_expired() {
96 return Err(CapabilityError::TokenExpired {
97 token_id: token.id.to_string(),
98 });
99 }
100
101 token.verify_signature()?;
103
104 if !self.trusted_issuers.is_empty() && !self.trusted_issuers.contains(&token.issuer) {
106 return Err(CapabilityError::InvalidSignature);
107 }
108
109 Ok(())
110 }
111
112 pub fn validate_by_id(&self, token_id: &astrid_core::TokenId) -> CapabilityResult<()> {
118 let token = self
119 .store
120 .get(token_id)?
121 .ok_or_else(|| CapabilityError::TokenNotFound {
122 token_id: token_id.to_string(),
123 })?;
124
125 self.validate_token(&token)
126 }
127}
128
/// Test-only collection of `(resource, permission)` pairs to check as a batch.
#[cfg(test)]
pub(crate) struct MultiPermissionCheck {
    /// Requested pairs, in the order they were added.
    checks: Vec<(String, Permission)>,
}
134
#[cfg(test)]
impl MultiPermissionCheck {
    /// Creates an empty batch.
    #[must_use]
    pub(crate) fn new() -> Self {
        let checks = Vec::new();
        Self { checks }
    }

    /// Queues a `(resource, permission)` pair and returns the batch for chaining.
    #[must_use]
    pub(crate) fn add(mut self, resource: impl Into<String>, permission: Permission) -> Self {
        let entry = (resource.into(), permission);
        self.checks.push(entry);
        self
    }

    /// Runs `validator.check` on every queued pair, in insertion order, and
    /// returns each pair together with its result.
    #[must_use]
    pub(crate) fn check_all(
        &self,
        validator: &CapabilityValidator<'_>,
    ) -> Vec<(String, Permission, AuthorizationResult)> {
        let mut outcomes = Vec::with_capacity(self.checks.len());
        for (resource, permission) in &self.checks {
            let outcome = validator.check(resource, *permission);
            outcomes.push((resource.clone(), *permission, outcome));
        }
        outcomes
    }

    /// Returns `true` if every queued pair is authorized (vacuously true when
    /// the batch is empty). Stops at the first pair that is not authorized.
    #[must_use]
    pub(crate) fn all_authorized(&self, validator: &CapabilityValidator<'_>) -> bool {
        for (resource, permission) in &self.checks {
            if !validator.check(resource, *permission).is_authorized() {
                return false;
            }
        }
        true
    }

    /// Returns the queued pairs that are not authorized, preserving insertion order.
    #[must_use]
    pub(crate) fn needs_approval(
        &self,
        validator: &CapabilityValidator<'_>,
    ) -> Vec<(String, Permission)> {
        let mut pending = Vec::new();
        for (resource, permission) in &self.checks {
            if !validator.check(resource, *permission).is_authorized() {
                pending.push((resource.clone(), *permission));
            }
        }
        pending
    }
}
188
#[cfg(test)]
impl Default for MultiPermissionCheck {
    /// Same as [`MultiPermissionCheck::new`]: a batch with no queued checks.
    fn default() -> Self {
        MultiPermissionCheck::new()
    }
}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pattern::ResourcePattern;
    use crate::token::{AuditEntryId, TokenScope};
    use astrid_crypto::KeyPair;

    /// Generates a fresh key pair for a test.
    fn test_keypair() -> KeyPair {
        KeyPair::generate()
    }

    /// Builds a session-scoped token granting `Permission::Invoke` on exactly
    /// `resource`, created with `keypair`.
    fn invoke_token(resource: &str, keypair: &KeyPair) -> CapabilityToken {
        CapabilityToken::create(
            ResourcePattern::exact(resource).unwrap(),
            vec![Permission::Invoke],
            TokenScope::Session,
            keypair.key_id(),
            AuditEntryId::new(),
            keypair,
            None,
        )
    }

    #[test]
    fn test_authorization_check() {
        let store = CapabilityStore::in_memory();
        let keypair = test_keypair();
        store
            .add(invoke_token("mcp://test:tool", &keypair))
            .unwrap();

        let validator = CapabilityValidator::new(&store);

        // The stored resource is authorized.
        let granted = validator.check("mcp://test:tool", Permission::Invoke);
        assert!(granted.is_authorized());

        // A different resource is not.
        let denied = validator.check("mcp://test:other", Permission::Invoke);
        assert!(!denied.is_authorized());
    }

    #[test]
    fn test_trusted_issuer() {
        let store = CapabilityStore::in_memory();
        let keypair = test_keypair();
        let other_keypair = test_keypair();

        let token = invoke_token("mcp://test:tool", &keypair);
        store.add(token.clone()).unwrap();

        // A validator that trusts only an unrelated key rejects the token.
        let distrusting =
            CapabilityValidator::new(&store).trust_issuer(other_keypair.export_public_key());
        assert!(distrusting.validate_token(&token).is_err());

        // A validator that trusts the creating key accepts it.
        let trusting = CapabilityValidator::new(&store).trust_issuer(keypair.export_public_key());
        assert!(trusting.validate_token(&token).is_ok());
    }

    #[test]
    fn test_multi_permission_check() {
        let store = CapabilityStore::in_memory();
        let keypair = test_keypair();
        store
            .add(invoke_token("mcp://test:read", &keypair))
            .unwrap();

        let validator = CapabilityValidator::new(&store);

        let check = MultiPermissionCheck::new()
            .add("mcp://test:read", Permission::Invoke)
            .add("mcp://test:write", Permission::Invoke);

        // Only the read capability exists, so the batch as a whole fails.
        assert!(!check.all_authorized(&validator));

        let needs = check.needs_approval(&validator);
        assert_eq!(needs.len(), 1);
        assert_eq!(needs[0].0, "mcp://test:write");
    }
}