use crate::audit::{AuditEvent, AuditOutcome, AuditSink};
use adk_core::{Result, Tool, ToolContext};
use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashSet;
use std::sync::Arc;
/// Generates a delegating `Tool` implementation for a scope-guarded wrapper.
///
/// Every metadata method (`name`, `description`, `enhanced_description`,
/// `is_long_running`, `parameters_schema`, `response_schema`,
/// `required_scopes`) forwards to the wrapped tool selected by `$inner`, while
/// `execute` is routed through `execute_scoped_tool` so the scope check runs
/// before the wrapped tool is invoked. The wrapper type must have `resolver`
/// and `audit_sink` fields.
///
/// Public forms:
/// - `impl_scoped_tool!(Wrapper<T>, this => <expr giving &inner>)` for a
///   wrapper generic over one tool type (bounded by `Tool + Send + Sync`);
/// - `impl_scoped_tool!(Wrapper, this => <expr giving &inner>)` for a concrete
///   wrapper type.
///
/// `$self_ident` is the name bound to `self` that `$inner` may refer to.
///
/// NOTE(review): only the methods listed above are forwarded; any other `Tool`
/// methods with default bodies keep those defaults — confirm that stays
/// intended if `Tool` grows new methods.
macro_rules! impl_scoped_tool {
    // Shared expansion used by both public forms. The complete `impl` (header
    // included) is emitted here instead of nesting a macro call inside the impl
    // body, because the outer `#[async_trait]` attribute would otherwise see an
    // unexpanded macro call rather than the `async fn execute` it must rewrite.
    (@impl [$($head:tt)*] $target:ty, $self_ident:ident => $inner:expr) => {
        #[async_trait]
        impl $($head)* Tool for $target {
            fn name(&self) -> &str {
                let $self_ident = self;
                ($inner).name()
            }

            fn description(&self) -> &str {
                let $self_ident = self;
                ($inner).description()
            }

            fn enhanced_description(&self) -> String {
                let $self_ident = self;
                ($inner).enhanced_description()
            }

            fn is_long_running(&self) -> bool {
                let $self_ident = self;
                ($inner).is_long_running()
            }

            fn parameters_schema(&self) -> Option<Value> {
                let $self_ident = self;
                ($inner).parameters_schema()
            }

            fn response_schema(&self) -> Option<Value> {
                let $self_ident = self;
                ($inner).response_schema()
            }

            fn required_scopes(&self) -> &[&str] {
                let $self_ident = self;
                ($inner).required_scopes()
            }

            async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
                let $self_ident = self;
                execute_scoped_tool(
                    ($inner),
                    self.resolver.as_ref(),
                    self.audit_sink.as_ref(),
                    ctx,
                    args,
                )
                .await
            }
        }
    };
    // Generic wrapper, e.g. `ScopedTool<T>`.
    ($wrapper:ident<$generic:ident>, $self_ident:ident => $inner:expr) => {
        impl_scoped_tool!(
            @impl [<$generic: Tool + Send + Sync>] $wrapper<$generic>, $self_ident => $inner
        );
    };
    // Concrete wrapper type, e.g. `ScopedToolDyn`.
    ($wrapper:ty, $self_ident:ident => $inner:expr) => {
        impl_scoped_tool!(@impl [] $wrapper, $self_ident => $inner);
    };
}
/// Supplies the scopes granted to the caller of a tool invocation.
///
/// The scope-checking wrappers in this module call `resolve` on every
/// execution of a tool whose `required_scopes()` is non-empty, and compare the
/// result against the tool's requirements.
#[async_trait]
pub trait ScopeResolver: Send + Sync {
    /// Returns the scopes granted for the invocation described by `ctx`.
    async fn resolve(&self, ctx: &dyn ToolContext) -> Vec<String>;
}
/// [`ScopeResolver`] that takes the granted scopes straight from the tool
/// context via `ToolContext::user_scopes()`.
///
/// It is a stateless unit type, so it derives the common traits and can be
/// copied or default-constructed freely.
#[derive(Debug, Clone, Copy, Default)]
pub struct ContextScopeResolver;

#[async_trait]
impl ScopeResolver for ContextScopeResolver {
    async fn resolve(&self, ctx: &dyn ToolContext) -> Vec<String> {
        ctx.user_scopes()
    }
}
/// [`ScopeResolver`] that grants the same fixed list of scopes to every
/// caller, ignoring the tool context.
#[derive(Debug, Clone)]
pub struct StaticScopeResolver {
    scopes: Vec<String>,
}

impl StaticScopeResolver {
    /// Creates a resolver that always returns `scopes`, in the given order.
    pub fn new(scopes: Vec<impl Into<String>>) -> Self {
        Self { scopes: scopes.into_iter().map(Into::into).collect() }
    }
}

#[async_trait]
impl ScopeResolver for StaticScopeResolver {
    async fn resolve(&self, _ctx: &dyn ToolContext) -> Vec<String> {
        // Each call hands out an independent copy of the configured list.
        self.scopes.clone()
    }
}
/// Verifies that every scope in `required` is present in `granted`.
///
/// Matching is exact, case-sensitive string equality. An empty `required`
/// list always passes.
///
/// # Errors
///
/// Returns [`ScopeDenied`] when at least one required scope is absent. The
/// error carries the full `required` list plus the scopes that were not
/// granted, both in the order (and with any repetitions) they appear in
/// `required`.
pub fn check_scopes(
    required: &[&str],
    granted: &[String],
) -> std::result::Result<(), ScopeDenied> {
    // Nothing to enforce: skip building the lookup set.
    if required.is_empty() {
        return Ok(());
    }

    let lookup: HashSet<&str> = granted.iter().map(String::as_str).collect();
    let mut missing = Vec::new();
    for &scope in required {
        if !lookup.contains(scope) {
            missing.push(scope.to_owned());
        }
    }

    if missing.is_empty() {
        return Ok(());
    }
    Err(ScopeDenied { required: required.iter().map(|&s| s.to_owned()).collect(), missing })
}

/// Error returned by [`check_scopes`] when the caller lacks required scopes.
#[derive(Debug, Clone)]
pub struct ScopeDenied {
    /// Every scope the tool declared as required.
    pub required: Vec<String>,
    /// The subset of `required` that was not granted.
    pub missing: Vec<String>,
}

impl std::fmt::Display for ScopeDenied {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let missing = self.missing.join(", ");
        let required = self.required.join(", ");
        write!(f, "missing required scopes: [{}] (tool requires: [{}])", missing, required)
    }
}

impl std::error::Error for ScopeDenied {}
/// Factory that wraps tools so their execution is gated on a scope check.
///
/// A guard owns one [`ScopeResolver`] and, optionally, one [`AuditSink`];
/// every tool it wraps shares those same instances through `Arc`.
pub struct ScopeGuard {
    resolver: Arc<dyn ScopeResolver>,
    audit_sink: Option<Arc<dyn AuditSink>>,
}

impl ScopeGuard {
    /// Creates a guard that checks scopes with `resolver` and has no audit sink.
    pub fn new(resolver: impl ScopeResolver + 'static) -> Self {
        let resolver: Arc<dyn ScopeResolver> = Arc::new(resolver);
        Self { resolver, audit_sink: None }
    }

    /// Creates a guard that checks scopes with `resolver` and hands each
    /// allowed/denied decision to `audit_sink`.
    pub fn with_audit(
        resolver: impl ScopeResolver + 'static,
        audit_sink: impl AuditSink + 'static,
    ) -> Self {
        let resolver: Arc<dyn ScopeResolver> = Arc::new(resolver);
        let audit_sink: Arc<dyn AuditSink> = Arc::new(audit_sink);
        Self { resolver, audit_sink: Some(audit_sink) }
    }

    /// Wraps a single, statically typed tool.
    pub fn protect<T: Tool + 'static>(&self, tool: T) -> ScopedTool<T> {
        ScopedTool {
            resolver: Arc::clone(&self.resolver),
            audit_sink: self.audit_sink.clone(),
            inner: tool,
        }
    }

    /// Wraps every tool in `tools`, preserving their order.
    pub fn protect_all(&self, tools: Vec<Arc<dyn Tool>>) -> Vec<Arc<dyn Tool>> {
        tools
            .into_iter()
            .map(|tool| -> Arc<dyn Tool> {
                Arc::new(ScopedToolDyn {
                    inner: tool,
                    resolver: Arc::clone(&self.resolver),
                    audit_sink: self.audit_sink.clone(),
                })
            })
            .collect()
    }
}
/// A tool of concrete type `T` whose execution is gated on a scope check.
///
/// Built by `ScopeGuard::protect` or `ScopeToolExt::with_scope_guard`. Its
/// `Tool` implementation forwards metadata to `inner` and runs the scope check
/// before delegating `execute`.
///
/// The `Tool` bound lives on that impl rather than on the struct, so it is not
/// duplicated on every use of the type.
pub struct ScopedTool<T> {
    inner: T,
    resolver: Arc<dyn ScopeResolver>,
    audit_sink: Option<Arc<dyn AuditSink>>,
}
/// Checks that the caller behind `ctx` holds every scope `tool` requires.
///
/// Tools that declare no required scopes are allowed immediately, without
/// consulting `resolver` or emitting an audit event. Otherwise the granted
/// scopes are resolved and compared with [`check_scopes`]. When an
/// `audit_sink` is configured, an allowed/denied `tool_access` event, tagged
/// with the session id, is recorded before the decision is returned.
///
/// Auditing is best-effort. A failure to record the event is logged as a
/// warning and never changes the access decision.
///
/// # Errors
///
/// Returns a tool error whose message is the [`ScopeDenied`] description when
/// any required scope is missing.
async fn authorize_tool_scopes(
    tool: &dyn Tool,
    resolver: &dyn ScopeResolver,
    audit_sink: Option<&Arc<dyn AuditSink>>,
    ctx: &Arc<dyn ToolContext>,
) -> Result<()> {
    let required = tool.required_scopes();
    if required.is_empty() {
        return Ok(());
    }

    let granted = resolver.resolve(ctx.as_ref()).await;
    let result = check_scopes(required, &granted);

    if let Some(sink) = audit_sink {
        let outcome = if result.is_ok() { AuditOutcome::Allowed } else { AuditOutcome::Denied };
        let event = AuditEvent::tool_access(ctx.user_id(), tool.name(), outcome)
            .with_session(ctx.session_id());
        // Previously `let _ = ...` dropped sink failures silently; surface them
        // without turning an audit outage into an authorization failure.
        if let Err(err) = sink.log(event).await {
            tracing::warn!(
                tool.name = %tool.name(),
                user.id = %ctx.user_id(),
                error = %err,
                "failed to record scope-check audit event"
            );
        }
    }

    if let Err(denied) = result {
        tracing::warn!(
            tool.name = %tool.name(),
            user.id = %ctx.user_id(),
            missing_scopes = ?denied.missing,
            "scope check failed"
        );
        return Err(adk_core::AdkError::tool(denied.to_string()));
    }
    Ok(())
}
/// Runs `inner` only after `authorize_tool_scopes` has approved the call.
///
/// A denial short-circuits with the authorization error, so the wrapped tool
/// never receives `ctx` or `args`. Otherwise the wrapped tool's own result is
/// returned unchanged.
async fn execute_scoped_tool(
    inner: &dyn Tool,
    resolver: &dyn ScopeResolver,
    audit_sink: Option<&Arc<dyn AuditSink>>,
    ctx: Arc<dyn ToolContext>,
    args: Value,
) -> Result<Value> {
    // Futures are lazy: the check does not run until it is awaited below.
    let authorization = authorize_tool_scopes(inner, resolver, audit_sink, &ctx);
    authorization.await?;
    inner.execute(ctx, args).await
}
// `Tool` for the statically typed wrapper: forward to the owned inner tool.
impl_scoped_tool!(ScopedTool<T>, this => &this.inner);
/// Type-erased counterpart of `ScopedTool` that wraps a shared `Arc<dyn Tool>`.
///
/// Produced by `ScopeGuard::protect_all`. It forwards metadata to the inner
/// tool and runs the same scope check before delegating `execute`.
pub struct ScopedToolDyn {
    inner: Arc<dyn Tool>,
    resolver: Arc<dyn ScopeResolver>,
    audit_sink: Option<Arc<dyn AuditSink>>,
}

// Reborrow through the `Arc` to hand the macro a plain `&dyn Tool`.
impl_scoped_tool!(ScopedToolDyn, this => &*this.inner);
/// Extension trait that adds `.with_scope_guard(...)` to every `Tool`.
pub trait ScopeToolExt: Tool + Sized {
    /// Wraps `self` in a [`ScopedTool`] that checks scopes with `resolver`.
    ///
    /// No audit sink is attached. For audited wrapping, build a guard with
    /// `ScopeGuard::with_audit` and call `protect` on it.
    fn with_scope_guard(self, resolver: impl ScopeResolver + 'static) -> ScopedTool<Self> {
        let resolver: Arc<dyn ScopeResolver> = Arc::new(resolver);
        ScopedTool { inner: self, resolver, audit_sink: None }
    }
}

impl<T> ScopeToolExt for T where T: Tool {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_check_scopes_empty_required() {
        assert!(check_scopes(&[], &[]).is_ok());
        assert!(check_scopes(&[], &["admin".to_string()]).is_ok());
    }

    #[test]
    fn test_check_scopes_all_granted() {
        let granted = vec!["finance:read".to_string(), "finance:write".to_string()];
        assert!(check_scopes(&["finance:read", "finance:write"], &granted).is_ok());
    }

    #[test]
    fn test_check_scopes_subset_granted() {
        let granted =
            vec!["finance:read".to_string(), "finance:write".to_string(), "admin".to_string()];
        assert!(check_scopes(&["finance:write"], &granted).is_ok());
    }

    #[test]
    fn test_check_scopes_missing() {
        let granted = vec!["finance:read".to_string()];
        let err = check_scopes(&["finance:read", "finance:write"], &granted).unwrap_err();
        assert_eq!(err.missing, vec!["finance:write"]);
        // `required` echoes the full requirement, not just the gap.
        assert_eq!(err.required, vec!["finance:read", "finance:write"]);
    }

    #[test]
    fn test_check_scopes_none_granted() {
        let err = check_scopes(&["admin"], &[]).unwrap_err();
        assert_eq!(err.missing, vec!["admin"]);
    }

    #[test]
    fn test_check_scopes_missing_preserves_required_order() {
        let granted = vec!["b".to_string()];
        let err = check_scopes(&["c", "b", "a"], &granted).unwrap_err();
        assert_eq!(err.missing, vec!["c", "a"]);
    }

    #[test]
    fn test_check_scopes_is_case_sensitive() {
        let granted = vec!["Admin".to_string()];
        assert!(check_scopes(&["admin"], &granted).is_err());
    }

    #[test]
    fn test_scope_denied_display() {
        let denied =
            ScopeDenied { required: vec!["a".into(), "b".into()], missing: vec!["b".into()] };
        // Pin the exact message: a substring check on "b" would also pass if
        // the formatting regressed, since "b" appears in the required list too.
        assert_eq!(denied.to_string(), "missing required scopes: [b] (tool requires: [a, b])");
    }

    #[test]
    fn test_static_scope_resolver() {
        let resolver = StaticScopeResolver::new(vec!["admin", "finance:write"]);
        assert_eq!(resolver.scopes, vec!["admin", "finance:write"]);
    }

    #[test]
    fn test_static_scope_resolver_accepts_owned_strings() {
        let resolver = StaticScopeResolver::new(vec!["admin".to_string()]);
        assert_eq!(resolver.scopes, vec!["admin"]);
    }
}