use crate::audit::{AuditEvent, AuditOutcome, AuditSink};
use crate::{AccessControl, Permission};
use adk_core::{Result, Tool, ToolContext};
use async_trait::async_trait;
use serde_json::Value;
use std::sync::Arc;
/// Generates a [`Tool`] impl for an access-controlled wrapper type.
///
/// Two public invocation forms:
/// - `impl_protected_tool!(Wrapper<T>, w => expr)`: a wrapper generic over one tool
///   type; the generated impl requires `T: Tool + Send + Sync`.
/// - `impl_protected_tool!(Wrapper, w => expr)`: a non-generic wrapper.
///
/// Inside each generated method `w` is bound to `self`, and `expr` (normally a borrow
/// of a field of `w`) is the inner tool that every metadata method is forwarded to.
/// `execute` hands `expr` to `execute_protected_tool` together with the wrapper's
/// `access_control` and `audit_sink` fields, so the wrapper type must declare both.
///
/// Both public forms expand through the internal `@impl` rule, so the list of
/// forwarded methods is written exactly once and cannot drift between the forms.
macro_rules! impl_protected_tool {
    // Internal rule. `[$($generics)*]` is the impl's generic header *including* the
    // angle brackets (e.g. `[<T: Tool + Send + Sync>]`), or `[]` for a plain impl.
    (@impl [$($generics:tt)*] $target:ty, $self_ident:ident => $inner:expr) => {
        #[async_trait]
        impl $($generics)* Tool for $target {
            fn name(&self) -> &str {
                let $self_ident = self;
                ($inner).name()
            }
            fn description(&self) -> &str {
                let $self_ident = self;
                ($inner).description()
            }
            fn enhanced_description(&self) -> String {
                let $self_ident = self;
                ($inner).enhanced_description()
            }
            fn is_long_running(&self) -> bool {
                let $self_ident = self;
                ($inner).is_long_running()
            }
            fn parameters_schema(&self) -> Option<Value> {
                let $self_ident = self;
                ($inner).parameters_schema()
            }
            fn response_schema(&self) -> Option<Value> {
                let $self_ident = self;
                ($inner).response_schema()
            }
            fn required_scopes(&self) -> &[&str] {
                let $self_ident = self;
                ($inner).required_scopes()
            }
            // The only non-delegating method: authorization (and optional auditing)
            // runs before the inner tool is invoked.
            async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
                let $self_ident = self;
                execute_protected_tool(
                    ($inner),
                    self.access_control.as_ref(),
                    self.audit_sink.as_ref(),
                    ctx,
                    args,
                )
                .await
            }
        }
    };
    // Generic wrapper, e.g. `ProtectedTool<T>`.
    ($wrapper:ident<$generic:ident>, $self_ident:ident => $inner:expr) => {
        impl_protected_tool!(
            @impl [<$generic: Tool + Send + Sync>] $wrapper<$generic>,
            $self_ident => $inner
        );
    };
    // Non-generic wrapper, e.g. `ProtectedToolDyn`.
    ($wrapper:ty, $self_ident:ident => $inner:expr) => {
        impl_protected_tool!(@impl [] $wrapper, $self_ident => $inner);
    };
}
/// Wraps a tool of type `T` so that every `execute` call is authorized against an
/// [`AccessControl`] policy before it is forwarded to the inner tool.
///
/// The `T: Tool` bound lives on the `impl` blocks rather than on the struct itself,
/// following the Rust API guidelines; every impl that needs it already states it.
pub struct ProtectedTool<T> {
    /// The wrapped tool; metadata and execution are delegated to it.
    inner: T,
    /// Policy consulted before every `execute`.
    access_control: Arc<AccessControl>,
    /// When `Some`, receives one tool-access audit event per `execute` attempt.
    audit_sink: Option<Arc<dyn AuditSink>>,
}
/// Decides whether the context's user may invoke the tool named `tool_name`.
///
/// When `audit_sink` is present, the decision is recorded as an `Allowed`/`Denied`
/// tool-access event tagged with the context's session id. Audit delivery is
/// best-effort: an error from `AuditSink::log` is discarded and never changes the
/// authorization result.
///
/// # Errors
/// Returns an `AdkError::tool` carrying the access-control error's message when the
/// permission check fails.
async fn authorize_tool_access(
    tool_name: &str,
    access_control: &AccessControl,
    audit_sink: Option<&Arc<dyn AuditSink>>,
    ctx: &Arc<dyn ToolContext>,
) -> Result<()> {
    let required = Permission::Tool(tool_name.to_string());
    let verdict = access_control.check(ctx.user_id(), &required);

    if let Some(sink) = audit_sink {
        let outcome = if verdict.is_err() { AuditOutcome::Denied } else { AuditOutcome::Allowed };
        let event = AuditEvent::tool_access(ctx.user_id(), tool_name, outcome)
            .with_session(ctx.session_id());
        // Deliberately ignored: a failing audit backend must not block or alter the
        // authorization decision.
        let _ = sink.log(event).await;
    }

    verdict.map_err(|err| adk_core::AdkError::tool(err.to_string()))
}
/// Authorizes the call (auditing the decision when a sink is given) and, only if
/// access is granted, runs `inner` with the original context and arguments.
///
/// # Errors
/// Returns the authorization error without invoking `inner` when access is denied;
/// otherwise returns whatever `inner.execute` returns.
async fn execute_protected_tool(
    inner: &dyn Tool,
    access_control: &AccessControl,
    audit_sink: Option<&Arc<dyn AuditSink>>,
    ctx: Arc<dyn ToolContext>,
    args: Value,
) -> Result<Value> {
    // Permissions are keyed by the inner tool's own name.
    let tool_name = inner.name();
    authorize_tool_access(tool_name, access_control, audit_sink, &ctx).await?;
    // Reached only after a successful check; ownership of ctx/args passes to the tool.
    inner.execute(ctx, args).await
}
impl<T: Tool> ProtectedTool<T> {
    /// Wraps `tool` so each `execute` call is authorized against `access_control`
    /// first. No audit sink is attached.
    pub fn new(tool: T, access_control: Arc<AccessControl>) -> Self {
        let inner = tool;
        Self { inner, audit_sink: None, access_control }
    }

    /// Same as [`ProtectedTool::new`], but every allow/deny decision is also
    /// reported to `audit_sink`.
    pub fn with_audit(
        tool: T,
        access_control: Arc<AccessControl>,
        audit_sink: Arc<dyn AuditSink>,
    ) -> Self {
        let mut guarded = Self::new(tool, access_control);
        guarded.audit_sink = Some(audit_sink);
        guarded
    }
}

// `wrapper` names `self` inside the generated impl; every `Tool` method is routed to
// the wrapped tool through `&wrapper.inner`.
impl_protected_tool!(ProtectedTool<T>, wrapper => &wrapper.inner);
/// Extension methods for wrapping any sized [`Tool`] in a [`ProtectedTool`].
pub trait ToolExt: Tool + Sized {
    /// Wraps `self` so `execute` is first checked against `ac`; no audit events are
    /// emitted.
    fn with_access_control(self, ac: Arc<AccessControl>) -> ProtectedTool<Self> {
        ProtectedTool::<Self>::new(self, ac)
    }

    /// Like [`ToolExt::with_access_control`], additionally reporting each
    /// allow/deny decision to `audit`.
    fn with_access_control_and_audit(
        self,
        ac: Arc<AccessControl>,
        audit: Arc<dyn AuditSink>,
    ) -> ProtectedTool<Self> {
        ProtectedTool::<Self>::with_audit(self, ac, audit)
    }
}

// Blanket impl: every (implicitly sized) `Tool` gets the extension methods.
impl<T> ToolExt for T where T: Tool {}
/// Factory for access-controlled tool wrappers that all share one policy and an
/// optional audit sink.
pub struct AuthMiddleware {
    // Cloned (as an `Arc`) into every wrapper built by `protect` / `protect_all`.
    access_control: Arc<AccessControl>,
    // When `Some`, wrappers built by this middleware report allow/deny decisions here.
    audit_sink: Option<Arc<dyn AuditSink>>,
}
impl AuthMiddleware {
pub fn new(access_control: AccessControl) -> Self {
Self { access_control: Arc::new(access_control), audit_sink: None }
}
pub fn with_audit(access_control: AccessControl, audit_sink: impl AuditSink + 'static) -> Self {
Self { access_control: Arc::new(access_control), audit_sink: Some(Arc::new(audit_sink)) }
}
pub fn access_control(&self) -> &AccessControl {
&self.access_control
}
pub fn protect<T: Tool>(&self, tool: T) -> ProtectedTool<T> {
match &self.audit_sink {
Some(sink) => {
ProtectedTool::with_audit(tool, self.access_control.clone(), sink.clone())
}
None => ProtectedTool::new(tool, self.access_control.clone()),
}
}
pub fn protect_all(&self, tools: Vec<Arc<dyn Tool>>) -> Vec<Arc<dyn Tool>> {
tools
.into_iter()
.map(|t| {
let protected = match &self.audit_sink {
Some(sink) => {
ProtectedToolDyn::with_audit(t, self.access_control.clone(), sink.clone())
}
None => ProtectedToolDyn::new(t, self.access_control.clone()),
};
Arc::new(protected) as Arc<dyn Tool>
})
.collect()
}
}
/// Type-erased counterpart of [`ProtectedTool`]: guards an `Arc<dyn Tool>` so it can
/// sit in heterogeneous tool lists such as those returned by
/// [`AuthMiddleware::protect_all`].
pub struct ProtectedToolDyn {
    // The wrapped tool; all `Tool` metadata and execution are delegated to it.
    inner: Arc<dyn Tool>,
    // Policy consulted before every `execute`.
    access_control: Arc<AccessControl>,
    // When `Some`, receives one tool-access audit event per `execute` attempt.
    audit_sink: Option<Arc<dyn AuditSink>>,
}
impl ProtectedToolDyn {
    /// Wraps a type-erased `tool` so each `execute` call is authorized against
    /// `access_control` first. No audit sink is attached.
    pub fn new(tool: Arc<dyn Tool>, access_control: Arc<AccessControl>) -> Self {
        Self { access_control, inner: tool, audit_sink: None }
    }

    /// Same as [`ProtectedToolDyn::new`], but every allow/deny decision is also
    /// reported to `audit_sink`.
    pub fn with_audit(
        tool: Arc<dyn Tool>,
        access_control: Arc<AccessControl>,
        audit_sink: Arc<dyn AuditSink>,
    ) -> Self {
        let audit_sink = Some(audit_sink);
        Self { inner: tool, access_control, audit_sink }
    }
}

// `wrapper` names `self` inside the generated impl; `as_ref()` turns the stored
// `Arc<dyn Tool>` into the `&dyn Tool` every forwarded method is called on.
impl_protected_tool!(ProtectedToolDyn, wrapper => wrapper.inner.as_ref());
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Role;

    /// Minimal tool stub: configurable name, fixed description, always-succeeding
    /// `execute`.
    struct MockTool {
        name: String,
    }

    impl MockTool {
        fn new(name: &str) -> Self {
            let name = String::from(name);
            Self { name }
        }
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            self.name.as_str()
        }
        fn description(&self) -> &str {
            "Mock tool"
        }
        async fn execute(&self, _ctx: Arc<dyn ToolContext>, _args: Value) -> Result<Value> {
            let payload = serde_json::json!({"result": "success"});
            Ok(payload)
        }
    }

    #[test]
    fn test_tool_ext() {
        let user_role = Role::new("user").allow(Permission::Tool("mock".into()));
        let policy = AccessControl::builder().role(user_role).build().unwrap();
        let protected = MockTool::new("mock").with_access_control(Arc::new(policy));
        // Metadata accessors must pass straight through to the wrapped tool.
        assert_eq!(protected.name(), "mock");
        assert_eq!(protected.description(), "Mock tool");
    }
}