1use crate::audit::{AuditEvent, AuditOutcome, AuditSink};
30use adk_core::{Result, Tool, ToolContext};
31use async_trait::async_trait;
32use serde_json::Value;
33use std::collections::HashSet;
34use std::sync::Arc;
35
36#[async_trait]
41pub trait ScopeResolver: Send + Sync {
42 async fn resolve(&self, ctx: &dyn ToolContext) -> Vec<String>;
44}
45
46pub struct ContextScopeResolver;
52
53#[async_trait]
54impl ScopeResolver for ContextScopeResolver {
55 async fn resolve(&self, ctx: &dyn ToolContext) -> Vec<String> {
56 ctx.user_scopes()
57 }
58}
59
60pub struct StaticScopeResolver {
70 scopes: Vec<String>,
71}
72
73impl StaticScopeResolver {
74 pub fn new(scopes: Vec<impl Into<String>>) -> Self {
76 Self { scopes: scopes.into_iter().map(Into::into).collect() }
77 }
78}
79
80#[async_trait]
81impl ScopeResolver for StaticScopeResolver {
82 async fn resolve(&self, _ctx: &dyn ToolContext) -> Vec<String> {
83 self.scopes.clone()
84 }
85}
86
/// Verify that every scope in `required` is present in `granted`.
///
/// Matching is exact, case-sensitive string equality. An empty `required`
/// list always succeeds.
///
/// # Errors
///
/// Returns [`ScopeDenied`] carrying the complete `required` list plus each
/// required scope that is absent from `granted`, in the order requested.
pub fn check_scopes(
    required: &[&str],
    granted: &[String],
) -> std::result::Result<(), ScopeDenied> {
    // A tool that asks for nothing is always allowed.
    if required.is_empty() {
        return Ok(());
    }

    // Index the granted scopes once so each membership test is a hash lookup.
    let held: HashSet<&str> = granted.iter().map(|scope| scope.as_str()).collect();

    let mut missing = Vec::new();
    for &scope in required {
        if !held.contains(scope) {
            missing.push(scope.to_owned());
        }
    }

    if !missing.is_empty() {
        let required = required.iter().map(|&scope| scope.to_owned()).collect();
        return Err(ScopeDenied { required, missing });
    }
    Ok(())
}

/// Error produced by [`check_scopes`] when the caller lacks required scopes.
#[derive(Debug, Clone)]
pub struct ScopeDenied {
    /// Every scope the tool declared, in declaration order.
    pub required: Vec<String>,
    /// The entries of `required` that were not granted, in the same order.
    pub missing: Vec<String>,
}

impl std::fmt::Display for ScopeDenied {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let missing = self.missing.join(", ");
        let required = self.required.join(", ");
        write!(
            f,
            "missing required scopes: [{missing}] (tool requires: [{required}])",
            missing = missing,
            required = required
        )
    }
}

impl std::error::Error for ScopeDenied {}
128
129pub struct ScopeGuard {
148 resolver: Arc<dyn ScopeResolver>,
149 audit_sink: Option<Arc<dyn AuditSink>>,
150}
151
152impl ScopeGuard {
153 pub fn new(resolver: impl ScopeResolver + 'static) -> Self {
155 Self { resolver: Arc::new(resolver), audit_sink: None }
156 }
157
158 pub fn with_audit(
160 resolver: impl ScopeResolver + 'static,
161 audit_sink: impl AuditSink + 'static,
162 ) -> Self {
163 Self { resolver: Arc::new(resolver), audit_sink: Some(Arc::new(audit_sink)) }
164 }
165
166 pub fn protect<T: Tool + 'static>(&self, tool: T) -> ScopedTool<T> {
170 ScopedTool {
171 inner: tool,
172 resolver: self.resolver.clone(),
173 audit_sink: self.audit_sink.clone(),
174 }
175 }
176
177 pub fn protect_all(&self, tools: Vec<Arc<dyn Tool>>) -> Vec<Arc<dyn Tool>> {
179 tools
180 .into_iter()
181 .map(|t| {
182 let wrapped = ScopedToolDyn {
183 inner: t,
184 resolver: self.resolver.clone(),
185 audit_sink: self.audit_sink.clone(),
186 };
187 Arc::new(wrapped) as Arc<dyn Tool>
188 })
189 .collect()
190 }
191}
192
193pub struct ScopedTool<T: Tool> {
195 inner: T,
196 resolver: Arc<dyn ScopeResolver>,
197 audit_sink: Option<Arc<dyn AuditSink>>,
198}
199
200#[async_trait]
201impl<T: Tool + Send + Sync> Tool for ScopedTool<T> {
202 fn name(&self) -> &str {
203 self.inner.name()
204 }
205
206 fn description(&self) -> &str {
207 self.inner.description()
208 }
209
210 fn enhanced_description(&self) -> String {
211 self.inner.enhanced_description()
212 }
213
214 fn is_long_running(&self) -> bool {
215 self.inner.is_long_running()
216 }
217
218 fn parameters_schema(&self) -> Option<Value> {
219 self.inner.parameters_schema()
220 }
221
222 fn response_schema(&self) -> Option<Value> {
223 self.inner.response_schema()
224 }
225
226 fn required_scopes(&self) -> &[&str] {
227 self.inner.required_scopes()
228 }
229
230 async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
231 let required = self.inner.required_scopes();
232 if !required.is_empty() {
233 let granted = self.resolver.resolve(ctx.as_ref()).await;
234 let result = check_scopes(required, &granted);
235
236 if let Some(sink) = &self.audit_sink {
237 let outcome =
238 if result.is_ok() { AuditOutcome::Allowed } else { AuditOutcome::Denied };
239 let event = AuditEvent::tool_access(ctx.user_id(), self.name(), outcome)
240 .with_session(ctx.session_id());
241 let _ = sink.log(event).await;
242 }
243
244 if let Err(denied) = result {
245 tracing::warn!(
246 tool.name = %self.name(),
247 user.id = %ctx.user_id(),
248 missing_scopes = ?denied.missing,
249 "scope check failed"
250 );
251 return Err(adk_core::AdkError::Tool(denied.to_string()));
252 }
253 }
254
255 self.inner.execute(ctx, args).await
256 }
257}
258
259pub struct ScopedToolDyn {
261 inner: Arc<dyn Tool>,
262 resolver: Arc<dyn ScopeResolver>,
263 audit_sink: Option<Arc<dyn AuditSink>>,
264}
265
266#[async_trait]
267impl Tool for ScopedToolDyn {
268 fn name(&self) -> &str {
269 self.inner.name()
270 }
271
272 fn description(&self) -> &str {
273 self.inner.description()
274 }
275
276 fn enhanced_description(&self) -> String {
277 self.inner.enhanced_description()
278 }
279
280 fn is_long_running(&self) -> bool {
281 self.inner.is_long_running()
282 }
283
284 fn parameters_schema(&self) -> Option<Value> {
285 self.inner.parameters_schema()
286 }
287
288 fn response_schema(&self) -> Option<Value> {
289 self.inner.response_schema()
290 }
291
292 fn required_scopes(&self) -> &[&str] {
293 self.inner.required_scopes()
294 }
295
296 async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
297 let required = self.inner.required_scopes();
298 if !required.is_empty() {
299 let granted = self.resolver.resolve(ctx.as_ref()).await;
300 let result = check_scopes(required, &granted);
301
302 if let Some(sink) = &self.audit_sink {
303 let outcome =
304 if result.is_ok() { AuditOutcome::Allowed } else { AuditOutcome::Denied };
305 let event = AuditEvent::tool_access(ctx.user_id(), self.name(), outcome)
306 .with_session(ctx.session_id());
307 let _ = sink.log(event).await;
308 }
309
310 if let Err(denied) = result {
311 tracing::warn!(
312 tool.name = %self.name(),
313 user.id = %ctx.user_id(),
314 missing_scopes = ?denied.missing,
315 "scope check failed"
316 );
317 return Err(adk_core::AdkError::Tool(denied.to_string()));
318 }
319 }
320
321 self.inner.execute(ctx, args).await
322 }
323}
324
325pub trait ScopeToolExt: Tool + Sized {
327 fn with_scope_guard(self, resolver: impl ScopeResolver + 'static) -> ScopedTool<Self> {
329 ScopedTool { inner: self, resolver: Arc::new(resolver), audit_sink: None }
330 }
331}
332
333impl<T: Tool> ScopeToolExt for T {}
334
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_check_scopes_empty_required() {
        assert!(check_scopes(&[], &[]).is_ok());
        assert!(check_scopes(&[], &["admin".to_string()]).is_ok());
    }

    #[test]
    fn test_check_scopes_all_granted() {
        let granted = vec!["finance:read".to_string(), "finance:write".to_string()];
        assert!(check_scopes(&["finance:read", "finance:write"], &granted).is_ok());
    }

    #[test]
    fn test_check_scopes_subset_granted() {
        let granted =
            vec!["finance:read".to_string(), "finance:write".to_string(), "admin".to_string()];
        assert!(check_scopes(&["finance:write"], &granted).is_ok());
    }

    #[test]
    fn test_check_scopes_missing() {
        let granted = vec!["finance:read".to_string()];
        let err = check_scopes(&["finance:read", "finance:write"], &granted).unwrap_err();
        assert_eq!(err.missing, vec!["finance:write"]);
        // The error echoes the full requirement list, not only the gaps.
        assert_eq!(err.required, vec!["finance:read", "finance:write"]);
    }

    #[test]
    fn test_check_scopes_none_granted() {
        let err = check_scopes(&["admin"], &[]).unwrap_err();
        assert_eq!(err.missing, vec!["admin"]);
    }

    #[test]
    fn test_check_scopes_is_case_sensitive() {
        let granted = vec!["Admin".to_string()];
        let err = check_scopes(&["admin"], &granted).unwrap_err();
        assert_eq!(err.missing, vec!["admin"]);
    }

    #[test]
    fn test_scope_denied_display() {
        let denied =
            ScopeDenied { required: vec!["a".into(), "b".into()], missing: vec!["b".into()] };
        // Pin the exact message: a bare `contains("b")` would also pass if "b"
        // only appeared in the "tool requires" list.
        assert_eq!(denied.to_string(), "missing required scopes: [b] (tool requires: [a, b])");
    }

    #[test]
    fn test_static_scope_resolver() {
        let resolver = StaticScopeResolver::new(vec!["admin", "finance:write"]);
        assert_eq!(resolver.scopes, vec!["admin", "finance:write"]);
    }
}
385}