Skip to main content

adk_auth/
scope.rs

//! Scope-based access control for tools.
//!
//! Scopes provide a declarative security model: each tool declares the scopes
//! it requires, and the framework enforces them automatically before the tool
//! executes.
//!
//! # Overview
//!
//! Unlike role-based access control (which maps users → roles → permissions),
//! scope-based access control works at the level of individual tools:
//!
//! 1. Tools declare required scopes via [`Tool::required_scopes()`].
//! 2. User scopes are resolved by a [`ScopeResolver`] — from session state, JWT
//!    claims, or a custom provider.
//! 3. The [`ScopeGuard`] checks that the user has **all** required scopes;
//!    tools that declare no scopes run without a check.
//!
//! # Example
//!
//! ```rust,ignore
//! use adk_auth::{ScopeGuard, ContextScopeResolver};
//!
//! // Tools declare their requirements
//! let transfer = FunctionTool::new("transfer", "Transfer funds", handler)
//!     .with_scopes(&["finance:write", "verified"]);
//!
//! // Guard enforces scopes automatically
//! let guard = ScopeGuard::new(ContextScopeResolver);
//! let protected = guard.protect(transfer);
//! ```
28
29use crate::audit::{AuditEvent, AuditOutcome, AuditSink};
30use adk_core::{Result, Tool, ToolContext};
31use async_trait::async_trait;
32use serde_json::Value;
33use std::collections::HashSet;
34use std::sync::Arc;
35
/// Implements `Tool` for a scope-enforcing wrapper type.
///
/// Every metadata method (`name`, `description`, `enhanced_description`,
/// `is_long_running`, `parameters_schema`, `response_schema`,
/// `required_scopes`) is forwarded to the wrapped tool. The wrapped tool is
/// obtained by evaluating `$inner` with `$self_ident` bound to `self`.
/// `execute` is routed through `execute_scoped_tool` using the wrapper's
/// `resolver` and `audit_sink` fields, so the scope check runs before the
/// inner tool is invoked.
///
/// Accepted forms:
/// - `Wrapper<T>, this => expr`: generic wrapper; adds `T: Tool + Send + Sync`.
/// - `Wrapper, this => expr`: concrete wrapper type.
///
/// Both public forms expand through the internal `@impl` arm, so the list of
/// forwarded methods exists in exactly one place and cannot drift between
/// wrappers.
macro_rules! impl_scoped_tool {
    // Internal arm: `$header` is the complete `impl ... Tool for ...` header.
    // It must come first so the `@impl` marker is matched before the public arms.
    (@impl [$($header:tt)*], $self_ident:ident => $inner:expr) => {
        #[async_trait]
        $($header)* {
            fn name(&self) -> &str {
                let $self_ident = self;
                ($inner).name()
            }

            fn description(&self) -> &str {
                let $self_ident = self;
                ($inner).description()
            }

            fn enhanced_description(&self) -> String {
                let $self_ident = self;
                ($inner).enhanced_description()
            }

            fn is_long_running(&self) -> bool {
                let $self_ident = self;
                ($inner).is_long_running()
            }

            fn parameters_schema(&self) -> Option<Value> {
                let $self_ident = self;
                ($inner).parameters_schema()
            }

            fn response_schema(&self) -> Option<Value> {
                let $self_ident = self;
                ($inner).response_schema()
            }

            fn required_scopes(&self) -> &[&str] {
                let $self_ident = self;
                ($inner).required_scopes()
            }

            async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
                let $self_ident = self;
                execute_scoped_tool(
                    ($inner),
                    self.resolver.as_ref(),
                    self.audit_sink.as_ref(),
                    ctx,
                    args,
                )
                .await
            }
        }
    };
    ($wrapper:ident<$generic:ident>, $self_ident:ident => $inner:expr) => {
        impl_scoped_tool!(
            @impl [impl<$generic: Tool + Send + Sync> Tool for $wrapper<$generic>],
            $self_ident => $inner
        );
    };
    ($wrapper:ty, $self_ident:ident => $inner:expr) => {
        impl_scoped_tool!(@impl [impl Tool for $wrapper], $self_ident => $inner);
    };
}
140
141/// Resolves the set of scopes granted to the current user.
142///
143/// Implementations can pull scopes from session state, JWT claims,
144/// an external identity provider, or any other source.
145#[async_trait]
146pub trait ScopeResolver: Send + Sync {
147    /// Returns the scopes granted to the user in the given tool context.
148    async fn resolve(&self, ctx: &dyn ToolContext) -> Vec<String>;
149}
150
151/// Resolves user scopes from the `user_scopes()` method on [`ToolContext`].
152///
153/// This is the default resolver — it delegates directly to the context,
154/// which may pull scopes from JWT claims, session state, or any other source
155/// configured at the context level.
156pub struct ContextScopeResolver;
157
158#[async_trait]
159impl ScopeResolver for ContextScopeResolver {
160    async fn resolve(&self, ctx: &dyn ToolContext) -> Vec<String> {
161        ctx.user_scopes()
162    }
163}
164
165/// A static resolver that always returns a fixed set of scopes.
166///
167/// Useful for testing or when scopes are known at configuration time.
168///
169/// # Example
170///
171/// ```rust,ignore
172/// let resolver = StaticScopeResolver::new(vec!["admin", "finance:write"]);
173/// ```
174pub struct StaticScopeResolver {
175    scopes: Vec<String>,
176}
177
178impl StaticScopeResolver {
179    /// Create a resolver with a fixed set of scopes.
180    pub fn new(scopes: Vec<impl Into<String>>) -> Self {
181        Self { scopes: scopes.into_iter().map(Into::into).collect() }
182    }
183}
184
185#[async_trait]
186impl ScopeResolver for StaticScopeResolver {
187    async fn resolve(&self, _ctx: &dyn ToolContext) -> Vec<String> {
188        self.scopes.clone()
189    }
190}
191
/// Checks whether a user's granted scopes satisfy a tool's requirements.
///
/// Every entry in `required` must appear in `granted`. Matching is an exact,
/// case-sensitive string comparison. An empty `required` list always
/// succeeds.
///
/// # Errors
///
/// Returns [`ScopeDenied`] when at least one required scope is absent. Its
/// `missing` field lists the absent scopes in the order they appear in
/// `required`, and its `required` field echoes the full requirement list.
pub fn check_scopes(required: &[&str], granted: &[String]) -> std::result::Result<(), ScopeDenied> {
    if required.is_empty() {
        return Ok(());
    }

    // Build the lookup set once so each requirement is an O(1) membership test.
    let granted: HashSet<&str> = granted.iter().map(String::as_str).collect();
    let missing: Vec<String> = required
        .iter()
        .copied()
        .filter(|scope| !granted.contains(scope))
        .map(str::to_owned)
        .collect();

    if missing.is_empty() {
        return Ok(());
    }
    Err(ScopeDenied { required: required.iter().copied().map(str::to_owned).collect(), missing })
}

/// Error returned when a user lacks required scopes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScopeDenied {
    /// All scopes the tool requires.
    pub required: Vec<String>,
    /// Scopes the user is missing, in the tool's declaration order.
    pub missing: Vec<String>,
}

impl std::fmt::Display for ScopeDenied {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "missing required scopes: [{}] (tool requires: [{}])",
            self.missing.join(", "),
            self.required.join(", ")
        )
    }
}

impl std::error::Error for ScopeDenied {}
233
234/// Declarative scope enforcement for tools.
235///
236/// Wraps tools and automatically checks that the user has all scopes
237/// declared by [`Tool::required_scopes()`] before allowing execution.
238///
239/// # Example
240///
241/// ```rust,ignore
242/// use adk_auth::{ScopeGuard, ContextScopeResolver};
243///
244/// let guard = ScopeGuard::new(ContextScopeResolver);
245///
246/// // Wrap a single tool
247/// let protected = guard.protect(my_tool);
248///
249/// // Wrap all tools in a vec
250/// let protected_tools = guard.protect_all(tools);
251/// ```
252pub struct ScopeGuard {
253    resolver: Arc<dyn ScopeResolver>,
254    audit_sink: Option<Arc<dyn AuditSink>>,
255}
256
257impl ScopeGuard {
258    /// Create a scope guard with the given resolver.
259    pub fn new(resolver: impl ScopeResolver + 'static) -> Self {
260        Self { resolver: Arc::new(resolver), audit_sink: None }
261    }
262
263    /// Create a scope guard with audit logging.
264    pub fn with_audit(
265        resolver: impl ScopeResolver + 'static,
266        audit_sink: impl AuditSink + 'static,
267    ) -> Self {
268        Self { resolver: Arc::new(resolver), audit_sink: Some(Arc::new(audit_sink)) }
269    }
270
271    /// Wrap a tool with scope enforcement.
272    ///
273    /// If the tool declares no required scopes, the wrapper is a no-op passthrough.
274    pub fn protect<T: Tool + 'static>(&self, tool: T) -> ScopedTool<T> {
275        ScopedTool {
276            inner: tool,
277            resolver: self.resolver.clone(),
278            audit_sink: self.audit_sink.clone(),
279        }
280    }
281
282    /// Wrap all tools in a vec with scope enforcement.
283    pub fn protect_all(&self, tools: Vec<Arc<dyn Tool>>) -> Vec<Arc<dyn Tool>> {
284        tools
285            .into_iter()
286            .map(|t| {
287                let wrapped = ScopedToolDyn {
288                    inner: t,
289                    resolver: self.resolver.clone(),
290                    audit_sink: self.audit_sink.clone(),
291                };
292                Arc::new(wrapped) as Arc<dyn Tool>
293            })
294            .collect()
295    }
296}
297
298/// A tool wrapper that enforces scope requirements before execution.
299pub struct ScopedTool<T: Tool> {
300    inner: T,
301    resolver: Arc<dyn ScopeResolver>,
302    audit_sink: Option<Arc<dyn AuditSink>>,
303}
304
305async fn authorize_tool_scopes(
306    tool: &dyn Tool,
307    resolver: &dyn ScopeResolver,
308    audit_sink: Option<&Arc<dyn AuditSink>>,
309    ctx: &Arc<dyn ToolContext>,
310) -> Result<()> {
311    let required = tool.required_scopes();
312    if required.is_empty() {
313        return Ok(());
314    }
315
316    let granted = resolver.resolve(ctx.as_ref()).await;
317    let result = check_scopes(required, &granted);
318
319    if let Some(sink) = audit_sink {
320        let outcome = if result.is_ok() { AuditOutcome::Allowed } else { AuditOutcome::Denied };
321        let event = AuditEvent::tool_access(ctx.user_id(), tool.name(), outcome)
322            .with_session(ctx.session_id());
323        let _ = sink.log(event).await;
324    }
325
326    if let Err(denied) = result {
327        tracing::warn!(
328            tool.name = %tool.name(),
329            user.id = %ctx.user_id(),
330            missing_scopes = ?denied.missing,
331            "scope check failed"
332        );
333        return Err(adk_core::AdkError::Tool(denied.to_string()));
334    }
335
336    Ok(())
337}
338
339async fn execute_scoped_tool(
340    inner: &dyn Tool,
341    resolver: &dyn ScopeResolver,
342    audit_sink: Option<&Arc<dyn AuditSink>>,
343    ctx: Arc<dyn ToolContext>,
344    args: Value,
345) -> Result<Value> {
346    authorize_tool_scopes(inner, resolver, audit_sink, &ctx).await?;
347    inner.execute(ctx, args).await
348}
349
350impl_scoped_tool!(ScopedTool<T>, wrapper => &wrapper.inner);
351
352/// Dynamic version of [`ScopedTool`] for `Arc<dyn Tool>`.
353pub struct ScopedToolDyn {
354    inner: Arc<dyn Tool>,
355    resolver: Arc<dyn ScopeResolver>,
356    audit_sink: Option<Arc<dyn AuditSink>>,
357}
358
359impl_scoped_tool!(ScopedToolDyn, wrapper => wrapper.inner.as_ref());
360
361/// Extension trait for easily wrapping tools with scope enforcement.
362pub trait ScopeToolExt: Tool + Sized {
363    /// Wrap this tool with scope enforcement using the given resolver.
364    fn with_scope_guard(self, resolver: impl ScopeResolver + 'static) -> ScopedTool<Self> {
365        ScopedTool { inner: self, resolver: Arc::new(resolver), audit_sink: None }
366    }
367}
368
369impl<T: Tool> ScopeToolExt for T {}
370
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_check_scopes_empty_required() {
        assert!(check_scopes(&[], &[]).is_ok());
        assert!(check_scopes(&[], &["admin".to_string()]).is_ok());
    }

    #[test]
    fn test_check_scopes_all_granted() {
        let granted = vec!["finance:read".to_string(), "finance:write".to_string()];
        assert!(check_scopes(&["finance:read", "finance:write"], &granted).is_ok());
    }

    #[test]
    fn test_check_scopes_subset_granted() {
        let granted =
            vec!["finance:read".to_string(), "finance:write".to_string(), "admin".to_string()];
        assert!(check_scopes(&["finance:write"], &granted).is_ok());
    }

    #[test]
    fn test_check_scopes_missing() {
        let granted = vec!["finance:read".to_string()];
        let err = check_scopes(&["finance:read", "finance:write"], &granted).unwrap_err();
        assert_eq!(err.missing, vec!["finance:write"]);
    }

    #[test]
    fn test_check_scopes_none_granted() {
        let err = check_scopes(&["admin"], &[]).unwrap_err();
        assert_eq!(err.missing, vec!["admin"]);
    }

    #[test]
    fn test_check_scopes_missing_preserves_required_order() {
        let granted = vec!["b".to_string()];
        let err = check_scopes(&["c", "b", "a"], &granted).unwrap_err();
        assert_eq!(err.missing, vec!["c", "a"]);
        assert_eq!(err.required, vec!["c", "b", "a"]);
    }

    #[test]
    fn test_check_scopes_is_case_sensitive() {
        let granted = vec!["Admin".to_string()];
        assert!(check_scopes(&["admin"], &granted).is_err());
    }

    #[test]
    fn test_scope_denied_display() {
        let denied =
            ScopeDenied { required: vec!["a".into(), "b".into()], missing: vec!["b".into()] };
        let msg = denied.to_string();
        assert!(msg.contains("missing required scopes"));
        assert!(msg.contains("b"));
    }

    #[test]
    fn test_scope_denied_display_exact() {
        let denied =
            ScopeDenied { required: vec!["a".into(), "b".into()], missing: vec!["b".into()] };
        assert_eq!(denied.to_string(), "missing required scopes: [b] (tool requires: [a, b])");
    }

    #[test]
    fn test_static_scope_resolver() {
        let resolver = StaticScopeResolver::new(vec!["admin", "finance:write"]);
        assert_eq!(resolver.scopes, vec!["admin", "finance:write"]);
    }

    #[test]
    fn test_static_scope_resolver_owned_strings() {
        let resolver = StaticScopeResolver::new(vec!["admin".to_string()]);
        assert_eq!(resolver.scopes, vec!["admin"]);
    }
}