Skip to main content

adk_auth/
scope.rs

//! Scope-based access control for tools.
//!
//! Scopes provide a declarative security model where tools declare what scopes
//! they require, and the framework automatically enforces them before execution.
//!
//! # Overview
//!
//! Unlike role-based access control (which maps users → roles → permissions),
//! scope-based access works at the tool level:
//!
//! 1. Tools declare required scopes via [`Tool::required_scopes()`]
//! 2. A [`ScopeResolver`] resolves user scopes from session state, JWT claims, or a custom provider
//! 3. The [`ScopeGuard`] checks that the user has **all** required scopes
//!
//! # Example
//!
//! ```rust,ignore
//! use adk_auth::{ScopeGuard, ContextScopeResolver};
//!
//! // Tools declare their requirements
//! let transfer = FunctionTool::new("transfer", "Transfer funds", handler)
//!     .with_scopes(&["finance:write", "verified"]);
//!
//! // Guard enforces scopes automatically
//! let guard = ScopeGuard::new(ContextScopeResolver);
//! let protected = guard.protect(transfer);
//! ```
28
29use crate::audit::{AuditEvent, AuditOutcome, AuditSink};
30use adk_core::{Result, Tool, ToolContext};
31use async_trait::async_trait;
32use serde_json::Value;
33use std::collections::HashSet;
34use std::sync::Arc;
35
36/// Resolves the set of scopes granted to the current user.
37///
38/// Implementations can pull scopes from session state, JWT claims,
39/// an external identity provider, or any other source.
40#[async_trait]
41pub trait ScopeResolver: Send + Sync {
42    /// Returns the scopes granted to the user in the given tool context.
43    async fn resolve(&self, ctx: &dyn ToolContext) -> Vec<String>;
44}
45
46/// Resolves user scopes from the `user_scopes()` method on [`ToolContext`].
47///
48/// This is the default resolver — it delegates directly to the context,
49/// which may pull scopes from JWT claims, session state, or any other source
50/// configured at the context level.
51pub struct ContextScopeResolver;
52
53#[async_trait]
54impl ScopeResolver for ContextScopeResolver {
55    async fn resolve(&self, ctx: &dyn ToolContext) -> Vec<String> {
56        ctx.user_scopes()
57    }
58}
59
/// A static resolver that always returns a fixed set of scopes.
///
/// Useful for testing or when scopes are known at configuration time.
///
/// # Example
///
/// ```rust,ignore
/// let resolver = StaticScopeResolver::new(vec!["admin", "finance:write"]);
/// // Any iterable of string-like values works, e.g. an array:
/// let resolver = StaticScopeResolver::new(["admin"]);
/// ```
#[derive(Debug, Clone)]
pub struct StaticScopeResolver {
    scopes: Vec<String>,
}

impl StaticScopeResolver {
    /// Create a resolver with a fixed set of scopes.
    ///
    /// Accepts any iterable whose items convert into `String` — a `Vec<&str>`,
    /// a `Vec<String>`, an array, and so on. The scopes are stored in the order
    /// given; duplicates are kept as-is.
    pub fn new(scopes: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self { scopes: scopes.into_iter().map(Into::into).collect() }
    }
}
79
80#[async_trait]
81impl ScopeResolver for StaticScopeResolver {
82    async fn resolve(&self, _ctx: &dyn ToolContext) -> Vec<String> {
83        self.scopes.clone()
84    }
85}
86
/// Checks whether a user's granted scopes satisfy a tool's requirements.
///
/// Every entry in `required` must appear in `granted` (exact, case-sensitive
/// string match). An empty `required` list always succeeds.
///
/// # Errors
///
/// Returns [`ScopeDenied`] when at least one required scope is absent. Its
/// `missing` field lists the absent scopes in the order they appear in
/// `required`, and its `required` field echoes the full requirement list.
pub fn check_scopes(required: &[&str], granted: &[String]) -> std::result::Result<(), ScopeDenied> {
    if required.is_empty() {
        return Ok(());
    }

    // Build the set once so each membership test is O(1) instead of a scan.
    let granted_set: HashSet<&str> = granted.iter().map(String::as_str).collect();
    let missing: Vec<String> =
        required.iter().filter(|s| !granted_set.contains(**s)).map(|s| s.to_string()).collect();

    if missing.is_empty() {
        Ok(())
    } else {
        Err(ScopeDenied { required: required.iter().map(|s| s.to_string()).collect(), missing })
    }
}

/// Error returned when a user lacks required scopes.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare denials directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScopeDenied {
    /// All scopes the tool requires.
    pub required: Vec<String>,
    /// Scopes the user is missing.
    pub missing: Vec<String>,
}

impl std::fmt::Display for ScopeDenied {
    /// Renders e.g. `missing required scopes: [b] (tool requires: [a, b])`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "missing required scopes: [{}] (tool requires: [{}])",
            self.missing.join(", "),
            self.required.join(", ")
        )
    }
}

impl std::error::Error for ScopeDenied {}
128
129/// Declarative scope enforcement for tools.
130///
131/// Wraps tools and automatically checks that the user has all scopes
132/// declared by [`Tool::required_scopes()`] before allowing execution.
133///
134/// # Example
135///
136/// ```rust,ignore
137/// use adk_auth::{ScopeGuard, ContextScopeResolver};
138///
139/// let guard = ScopeGuard::new(ContextScopeResolver);
140///
141/// // Wrap a single tool
142/// let protected = guard.protect(my_tool);
143///
144/// // Wrap all tools in a vec
145/// let protected_tools = guard.protect_all(tools);
146/// ```
147pub struct ScopeGuard {
148    resolver: Arc<dyn ScopeResolver>,
149    audit_sink: Option<Arc<dyn AuditSink>>,
150}
151
152impl ScopeGuard {
153    /// Create a scope guard with the given resolver.
154    pub fn new(resolver: impl ScopeResolver + 'static) -> Self {
155        Self { resolver: Arc::new(resolver), audit_sink: None }
156    }
157
158    /// Create a scope guard with audit logging.
159    pub fn with_audit(
160        resolver: impl ScopeResolver + 'static,
161        audit_sink: impl AuditSink + 'static,
162    ) -> Self {
163        Self { resolver: Arc::new(resolver), audit_sink: Some(Arc::new(audit_sink)) }
164    }
165
166    /// Wrap a tool with scope enforcement.
167    ///
168    /// If the tool declares no required scopes, the wrapper is a no-op passthrough.
169    pub fn protect<T: Tool + 'static>(&self, tool: T) -> ScopedTool<T> {
170        ScopedTool {
171            inner: tool,
172            resolver: self.resolver.clone(),
173            audit_sink: self.audit_sink.clone(),
174        }
175    }
176
177    /// Wrap all tools in a vec with scope enforcement.
178    pub fn protect_all(&self, tools: Vec<Arc<dyn Tool>>) -> Vec<Arc<dyn Tool>> {
179        tools
180            .into_iter()
181            .map(|t| {
182                let wrapped = ScopedToolDyn {
183                    inner: t,
184                    resolver: self.resolver.clone(),
185                    audit_sink: self.audit_sink.clone(),
186                };
187                Arc::new(wrapped) as Arc<dyn Tool>
188            })
189            .collect()
190    }
191}
192
193/// A tool wrapper that enforces scope requirements before execution.
194pub struct ScopedTool<T: Tool> {
195    inner: T,
196    resolver: Arc<dyn ScopeResolver>,
197    audit_sink: Option<Arc<dyn AuditSink>>,
198}
199
200#[async_trait]
201impl<T: Tool + Send + Sync> Tool for ScopedTool<T> {
202    fn name(&self) -> &str {
203        self.inner.name()
204    }
205
206    fn description(&self) -> &str {
207        self.inner.description()
208    }
209
210    fn enhanced_description(&self) -> String {
211        self.inner.enhanced_description()
212    }
213
214    fn is_long_running(&self) -> bool {
215        self.inner.is_long_running()
216    }
217
218    fn parameters_schema(&self) -> Option<Value> {
219        self.inner.parameters_schema()
220    }
221
222    fn response_schema(&self) -> Option<Value> {
223        self.inner.response_schema()
224    }
225
226    fn required_scopes(&self) -> &[&str] {
227        self.inner.required_scopes()
228    }
229
230    async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
231        let required = self.inner.required_scopes();
232        if !required.is_empty() {
233            let granted = self.resolver.resolve(ctx.as_ref()).await;
234            let result = check_scopes(required, &granted);
235
236            if let Some(sink) = &self.audit_sink {
237                let outcome =
238                    if result.is_ok() { AuditOutcome::Allowed } else { AuditOutcome::Denied };
239                let event = AuditEvent::tool_access(ctx.user_id(), self.name(), outcome)
240                    .with_session(ctx.session_id());
241                let _ = sink.log(event).await;
242            }
243
244            if let Err(denied) = result {
245                tracing::warn!(
246                    tool.name = %self.name(),
247                    user.id = %ctx.user_id(),
248                    missing_scopes = ?denied.missing,
249                    "scope check failed"
250                );
251                return Err(adk_core::AdkError::Tool(denied.to_string()));
252            }
253        }
254
255        self.inner.execute(ctx, args).await
256    }
257}
258
259/// Dynamic version of [`ScopedTool`] for `Arc<dyn Tool>`.
260pub struct ScopedToolDyn {
261    inner: Arc<dyn Tool>,
262    resolver: Arc<dyn ScopeResolver>,
263    audit_sink: Option<Arc<dyn AuditSink>>,
264}
265
266#[async_trait]
267impl Tool for ScopedToolDyn {
268    fn name(&self) -> &str {
269        self.inner.name()
270    }
271
272    fn description(&self) -> &str {
273        self.inner.description()
274    }
275
276    fn enhanced_description(&self) -> String {
277        self.inner.enhanced_description()
278    }
279
280    fn is_long_running(&self) -> bool {
281        self.inner.is_long_running()
282    }
283
284    fn parameters_schema(&self) -> Option<Value> {
285        self.inner.parameters_schema()
286    }
287
288    fn response_schema(&self) -> Option<Value> {
289        self.inner.response_schema()
290    }
291
292    fn required_scopes(&self) -> &[&str] {
293        self.inner.required_scopes()
294    }
295
296    async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
297        let required = self.inner.required_scopes();
298        if !required.is_empty() {
299            let granted = self.resolver.resolve(ctx.as_ref()).await;
300            let result = check_scopes(required, &granted);
301
302            if let Some(sink) = &self.audit_sink {
303                let outcome =
304                    if result.is_ok() { AuditOutcome::Allowed } else { AuditOutcome::Denied };
305                let event = AuditEvent::tool_access(ctx.user_id(), self.name(), outcome)
306                    .with_session(ctx.session_id());
307                let _ = sink.log(event).await;
308            }
309
310            if let Err(denied) = result {
311                tracing::warn!(
312                    tool.name = %self.name(),
313                    user.id = %ctx.user_id(),
314                    missing_scopes = ?denied.missing,
315                    "scope check failed"
316                );
317                return Err(adk_core::AdkError::Tool(denied.to_string()));
318            }
319        }
320
321        self.inner.execute(ctx, args).await
322    }
323}
324
325/// Extension trait for easily wrapping tools with scope enforcement.
326pub trait ScopeToolExt: Tool + Sized {
327    /// Wrap this tool with scope enforcement using the given resolver.
328    fn with_scope_guard(self, resolver: impl ScopeResolver + 'static) -> ScopedTool<Self> {
329        ScopedTool { inner: self, resolver: Arc::new(resolver), audit_sink: None }
330    }
331}
332
333impl<T: Tool> ScopeToolExt for T {}
334
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_check_scopes_empty_required() {
        assert!(check_scopes(&[], &[]).is_ok());
        assert!(check_scopes(&[], &["admin".to_string()]).is_ok());
    }

    #[test]
    fn test_check_scopes_all_granted() {
        let granted = vec!["finance:read".to_string(), "finance:write".to_string()];
        assert!(check_scopes(&["finance:read", "finance:write"], &granted).is_ok());
    }

    #[test]
    fn test_check_scopes_subset_granted() {
        let granted =
            vec!["finance:read".to_string(), "finance:write".to_string(), "admin".to_string()];
        assert!(check_scopes(&["finance:write"], &granted).is_ok());
    }

    #[test]
    fn test_check_scopes_missing() {
        let granted = vec!["finance:read".to_string()];
        let err = check_scopes(&["finance:read", "finance:write"], &granted).unwrap_err();
        assert_eq!(err.missing, vec!["finance:write"]);
    }

    #[test]
    fn test_check_scopes_none_granted() {
        let err = check_scopes(&["admin"], &[]).unwrap_err();
        assert_eq!(err.missing, vec!["admin"]);
    }

    #[test]
    fn test_check_scopes_multiple_missing_preserves_order() {
        // `missing` follows the order of `required`; `required` is echoed in full.
        let granted = vec!["b".to_string()];
        let err = check_scopes(&["c", "b", "a"], &granted).unwrap_err();
        assert_eq!(err.missing, vec!["c", "a"]);
        assert_eq!(err.required, vec!["c", "b", "a"]);
    }

    #[test]
    fn test_check_scopes_is_case_sensitive() {
        let granted = vec!["Admin".to_string()];
        let err = check_scopes(&["admin"], &granted).unwrap_err();
        assert_eq!(err.missing, vec!["admin"]);
    }

    #[test]
    fn test_scope_denied_display() {
        let denied =
            ScopeDenied { required: vec!["a".into(), "b".into()], missing: vec!["b".into()] };
        let msg = denied.to_string();
        assert!(msg.contains("missing required scopes"));
        assert!(msg.contains("b"));
    }

    #[test]
    fn test_scope_denied_display_exact_format() {
        let denied =
            ScopeDenied { required: vec!["a".into(), "b".into()], missing: vec!["b".into()] };
        assert_eq!(denied.to_string(), "missing required scopes: [b] (tool requires: [a, b])");
    }

    #[test]
    fn test_static_scope_resolver() {
        let resolver = StaticScopeResolver::new(vec!["admin", "finance:write"]);
        assert_eq!(resolver.scopes, vec!["admin", "finance:write"]);
    }
}