1use crate::audit::{AuditEvent, AuditOutcome, AuditSink};
30use adk_core::{Result, Tool, ToolContext};
31use async_trait::async_trait;
32use serde_json::Value;
33use std::collections::HashSet;
34use std::sync::Arc;
35
/// Generates a `Tool` implementation for a scope-guarded wrapper type.
///
/// The generated impl forwards every metadata method (`name`, `description`,
/// `enhanced_description`, `is_long_running`, `parameters_schema`,
/// `response_schema`, `required_scopes`) straight to the wrapped tool, and
/// routes `execute` through `execute_scoped_tool`, handing it the wrapper's
/// `resolver` and `audit_sink` fields. The wrapper type must therefore have
/// fields with those names.
///
/// Two call forms are supported:
///
/// * `Wrapper<T>, this => expr` — the wrapper is generic over one tool type
///   `T`; the generated impl bounds it by `Tool + Send + Sync`.
/// * `Wrapper, this => expr` — a non-generic wrapper type.
///
/// Inside every generated method `this` is bound to `self`, and `expr` must use
/// it to yield a reference to the wrapped tool (for example `&this.inner`).
macro_rules! impl_scoped_tool {
    ($ty:ident<$param:ident>, $this:ident => $target:expr) => {
        #[async_trait]
        impl<$param: Tool + Send + Sync> Tool for $ty<$param> {
            fn name(&self) -> &str {
                let $this = self;
                Tool::name($target)
            }

            fn description(&self) -> &str {
                let $this = self;
                Tool::description($target)
            }

            fn enhanced_description(&self) -> String {
                let $this = self;
                Tool::enhanced_description($target)
            }

            fn is_long_running(&self) -> bool {
                let $this = self;
                Tool::is_long_running($target)
            }

            fn parameters_schema(&self) -> Option<Value> {
                let $this = self;
                Tool::parameters_schema($target)
            }

            fn response_schema(&self) -> Option<Value> {
                let $this = self;
                Tool::response_schema($target)
            }

            fn required_scopes(&self) -> &[&str] {
                let $this = self;
                Tool::required_scopes($target)
            }

            // The only non-forwarding method: the scope check happens inside
            // `execute_scoped_tool` before the wrapped tool runs.
            async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
                let $this = self;
                let target = $target;
                execute_scoped_tool(
                    target,
                    self.resolver.as_ref(),
                    self.audit_sink.as_ref(),
                    ctx,
                    args,
                )
                .await
            }
        }
    };
    ($ty:ty, $this:ident => $target:expr) => {
        #[async_trait]
        impl Tool for $ty {
            fn name(&self) -> &str {
                let $this = self;
                Tool::name($target)
            }

            fn description(&self) -> &str {
                let $this = self;
                Tool::description($target)
            }

            fn enhanced_description(&self) -> String {
                let $this = self;
                Tool::enhanced_description($target)
            }

            fn is_long_running(&self) -> bool {
                let $this = self;
                Tool::is_long_running($target)
            }

            fn parameters_schema(&self) -> Option<Value> {
                let $this = self;
                Tool::parameters_schema($target)
            }

            fn response_schema(&self) -> Option<Value> {
                let $this = self;
                Tool::response_schema($target)
            }

            fn required_scopes(&self) -> &[&str] {
                let $this = self;
                Tool::required_scopes($target)
            }

            // The only non-forwarding method: the scope check happens inside
            // `execute_scoped_tool` before the wrapped tool runs.
            async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
                let $this = self;
                let target = $target;
                execute_scoped_tool(
                    target,
                    self.resolver.as_ref(),
                    self.audit_sink.as_ref(),
                    ctx,
                    args,
                )
                .await
            }
        }
    };
}
140
/// Source of the scopes granted to the caller of a tool.
///
/// The scope-checking wrappers in this module consult a resolver each time a
/// tool that declares a non-empty set of required scopes is executed.
#[async_trait]
pub trait ScopeResolver: Send + Sync {
    /// Returns the scopes granted for the invocation described by `ctx`.
    async fn resolve(&self, ctx: &dyn ToolContext) -> Vec<String>;
}
150
/// A [`ScopeResolver`] that takes the granted scopes from the tool context
/// itself (via `ToolContext::user_scopes`).
///
/// It carries no state, so it is `Copy` and constructible via `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ContextScopeResolver;
157
#[async_trait]
impl ScopeResolver for ContextScopeResolver {
    /// Delegates to `ctx.user_scopes()`: the granted scopes are whatever the
    /// tool context reports for the current invocation.
    async fn resolve(&self, ctx: &dyn ToolContext) -> Vec<String> {
        ctx.user_scopes()
    }
}
164
/// A [`ScopeResolver`] that grants one fixed set of scopes to every caller.
///
/// Useful for tests, single-tenant deployments, or service accounts whose
/// permissions are known up front.
#[derive(Debug, Clone)]
pub struct StaticScopeResolver {
    /// Scopes returned from every `resolve` call.
    scopes: Vec<String>,
}

impl StaticScopeResolver {
    /// Creates a resolver that always grants `scopes`.
    ///
    /// Accepts any iterable of string-like values: a `Vec<&str>`, a
    /// `Vec<String>`, an array such as `["admin", "finance:write"]`, or an
    /// iterator. The previous `Vec<impl Into<String>>` form still works.
    pub fn new(scopes: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self { scopes: scopes.into_iter().map(Into::into).collect() }
    }
}
184
185#[async_trait]
186impl ScopeResolver for StaticScopeResolver {
187 async fn resolve(&self, _ctx: &dyn ToolContext) -> Vec<String> {
188 self.scopes.clone()
189 }
190}
191
192pub fn check_scopes(required: &[&str], granted: &[String]) -> std::result::Result<(), ScopeDenied> {
197 if required.is_empty() {
198 return Ok(());
199 }
200
201 let granted_set: HashSet<&str> = granted.iter().map(String::as_str).collect();
202 let missing: Vec<String> =
203 required.iter().filter(|s| !granted_set.contains(**s)).map(|s| s.to_string()).collect();
204
205 if missing.is_empty() {
206 Ok(())
207 } else {
208 Err(ScopeDenied { required: required.iter().map(|s| s.to_string()).collect(), missing })
209 }
210}
211
/// Error describing a failed scope check.
///
/// When produced by `check_scopes`, `missing` is the subset of `required` that
/// the caller was not granted, in the same order as `required`.
#[derive(Debug, Clone)]
pub struct ScopeDenied {
    /// Every scope the tool declared as required.
    pub required: Vec<String>,
    /// Required scopes absent from the caller's granted set.
    pub missing: Vec<String>,
}

impl std::fmt::Display for ScopeDenied {
    /// Renders `missing required scopes: [<missing>] (tool requires: [<required>])`,
    /// with each list joined by `", "`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let missing_list = self.missing.join(", ");
        let required_list = self.required.join(", ");
        write!(
            f,
            "missing required scopes: [{}] (tool requires: [{}])",
            missing_list, required_list
        )
    }
}

// Marker impl so `ScopeDenied` can be boxed or propagated as a standard error.
impl std::error::Error for ScopeDenied {}
233
234pub struct ScopeGuard {
253 resolver: Arc<dyn ScopeResolver>,
254 audit_sink: Option<Arc<dyn AuditSink>>,
255}
256
257impl ScopeGuard {
258 pub fn new(resolver: impl ScopeResolver + 'static) -> Self {
260 Self { resolver: Arc::new(resolver), audit_sink: None }
261 }
262
263 pub fn with_audit(
265 resolver: impl ScopeResolver + 'static,
266 audit_sink: impl AuditSink + 'static,
267 ) -> Self {
268 Self { resolver: Arc::new(resolver), audit_sink: Some(Arc::new(audit_sink)) }
269 }
270
271 pub fn protect<T: Tool + 'static>(&self, tool: T) -> ScopedTool<T> {
275 ScopedTool {
276 inner: tool,
277 resolver: self.resolver.clone(),
278 audit_sink: self.audit_sink.clone(),
279 }
280 }
281
282 pub fn protect_all(&self, tools: Vec<Arc<dyn Tool>>) -> Vec<Arc<dyn Tool>> {
284 tools
285 .into_iter()
286 .map(|t| {
287 let wrapped = ScopedToolDyn {
288 inner: t,
289 resolver: self.resolver.clone(),
290 audit_sink: self.audit_sink.clone(),
291 };
292 Arc::new(wrapped) as Arc<dyn Tool>
293 })
294 .collect()
295 }
296}
297
298pub struct ScopedTool<T: Tool> {
300 inner: T,
301 resolver: Arc<dyn ScopeResolver>,
302 audit_sink: Option<Arc<dyn AuditSink>>,
303}
304
305async fn authorize_tool_scopes(
306 tool: &dyn Tool,
307 resolver: &dyn ScopeResolver,
308 audit_sink: Option<&Arc<dyn AuditSink>>,
309 ctx: &Arc<dyn ToolContext>,
310) -> Result<()> {
311 let required = tool.required_scopes();
312 if required.is_empty() {
313 return Ok(());
314 }
315
316 let granted = resolver.resolve(ctx.as_ref()).await;
317 let result = check_scopes(required, &granted);
318
319 if let Some(sink) = audit_sink {
320 let outcome = if result.is_ok() { AuditOutcome::Allowed } else { AuditOutcome::Denied };
321 let event = AuditEvent::tool_access(ctx.user_id(), tool.name(), outcome)
322 .with_session(ctx.session_id());
323 let _ = sink.log(event).await;
324 }
325
326 if let Err(denied) = result {
327 tracing::warn!(
328 tool.name = %tool.name(),
329 user.id = %ctx.user_id(),
330 missing_scopes = ?denied.missing,
331 "scope check failed"
332 );
333 return Err(adk_core::AdkError::Tool(denied.to_string()));
334 }
335
336 Ok(())
337}
338
339async fn execute_scoped_tool(
340 inner: &dyn Tool,
341 resolver: &dyn ScopeResolver,
342 audit_sink: Option<&Arc<dyn AuditSink>>,
343 ctx: Arc<dyn ToolContext>,
344 args: Value,
345) -> Result<Value> {
346 authorize_tool_scopes(inner, resolver, audit_sink, &ctx).await?;
347 inner.execute(ctx, args).await
348}
349
350impl_scoped_tool!(ScopedTool<T>, wrapper => &wrapper.inner);
351
/// Type-erased counterpart of [`ScopedTool`] that wraps an `Arc<dyn Tool>`.
///
/// Produced by [`ScopeGuard::protect_all`], so lists of `Arc<dyn Tool>` can be
/// guarded without knowing their concrete types.
pub struct ScopedToolDyn {
    // The wrapped, type-erased tool.
    inner: Arc<dyn Tool>,
    // Resolves the caller's granted scopes at execution time.
    resolver: Arc<dyn ScopeResolver>,
    // Optional sink that receives allowed/denied audit events.
    audit_sink: Option<Arc<dyn AuditSink>>,
}
358
359impl_scoped_tool!(ScopedToolDyn, wrapper => wrapper.inner.as_ref());
360
361pub trait ScopeToolExt: Tool + Sized {
363 fn with_scope_guard(self, resolver: impl ScopeResolver + 'static) -> ScopedTool<Self> {
365 ScopedTool { inner: self, resolver: Arc::new(resolver), audit_sink: None }
366 }
367}
368
369impl<T: Tool> ScopeToolExt for T {}
370
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_check_scopes_empty_required() {
        assert!(check_scopes(&[], &[]).is_ok());
        assert!(check_scopes(&[], &["admin".to_string()]).is_ok());
    }

    #[test]
    fn test_check_scopes_all_granted() {
        let granted = vec!["finance:read".to_string(), "finance:write".to_string()];
        assert!(check_scopes(&["finance:read", "finance:write"], &granted).is_ok());
    }

    #[test]
    fn test_check_scopes_subset_granted() {
        let granted =
            vec!["finance:read".to_string(), "finance:write".to_string(), "admin".to_string()];
        assert!(check_scopes(&["finance:write"], &granted).is_ok());
    }

    #[test]
    fn test_check_scopes_missing() {
        let granted = vec!["finance:read".to_string()];
        let err = check_scopes(&["finance:read", "finance:write"], &granted).unwrap_err();
        assert_eq!(err.missing, vec!["finance:write"]);
        assert_eq!(err.required, vec!["finance:read", "finance:write"]);
    }

    #[test]
    fn test_check_scopes_missing_keeps_required_order() {
        let granted = vec!["b".to_string()];
        let err = check_scopes(&["c", "b", "a"], &granted).unwrap_err();
        assert_eq!(err.missing, vec!["c", "a"]);
    }

    #[test]
    fn test_check_scopes_none_granted() {
        let err = check_scopes(&["admin"], &[]).unwrap_err();
        assert_eq!(err.missing, vec!["admin"]);
    }

    #[test]
    fn test_scope_denied_display() {
        let denied =
            ScopeDenied { required: vec!["a".into(), "b".into()], missing: vec!["b".into()] };
        // Pin the whole message: `contains("b")` would also match the
        // `required` list, so it could not detect a broken `missing` section.
        assert_eq!(denied.to_string(), "missing required scopes: [b] (tool requires: [a, b])");
    }

    #[test]
    fn test_static_scope_resolver() {
        let resolver = StaticScopeResolver::new(vec!["admin", "finance:write"]);
        assert_eq!(resolver.scopes, vec!["admin", "finance:write"]);
    }
}