Skip to main content

adk_auth/
middleware.rs

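//! Middleware for integrating access control with adk-core.
//!
//! This module provides a `ProtectedTool` wrapper that enforces permissions
//! before tool execution and optionally logs audit events. It also provides
//! `ProtectedToolDyn` for type-erased `Arc<dyn Tool>` values, the `ToolExt`
//! extension trait, and `AuthMiddleware` for wrapping many tools with one policy.
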
6use crate::audit::{AuditEvent, AuditOutcome, AuditSink};
7use crate::{AccessControl, Permission};
8use adk_core::{Result, Tool, ToolContext};
9use async_trait::async_trait;
10use serde_json::Value;
11use std::sync::Arc;
12
/// Implements [`Tool`] for an access-controlled wrapper type.
///
/// Every metadata method (`name`, `description`, `enhanced_description`,
/// `is_long_running`, `parameters_schema`, `response_schema`,
/// `required_scopes`) is forwarded verbatim to the inner tool. `execute` is
/// routed through `execute_protected_tool`, which authorizes the call before
/// running the inner tool.
///
/// Forms:
/// - `impl_protected_tool!(Wrapper<T>, this => <inner>)` for a generic wrapper;
///   the generated impl requires `T: Tool + Send + Sync`.
/// - `impl_protected_tool!(Wrapper, this => <inner>)` for a non-generic wrapper.
///
/// `this` is bound to `self` inside each generated method so that `<inner>` can
/// name the wrapper. `<inner>` must evaluate to a reference usable as `&dyn Tool`.
/// The wrapper must have `access_control` and `audit_sink` fields shaped like
/// those of `ProtectedTool`.
macro_rules! impl_protected_tool {
    // Internal arm that emits the impl. Both public arms delegate here, so the
    // forwarding methods are written exactly once.
    // `[$($generics)*]` is the (possibly empty) impl generic parameter list.
    // `[$($target)*]` is the implementing type.
    // This arm is listed first so `@emit` input is matched before the
    // `$wrapper:ty` arm below can try to parse it.
    (@emit [$($generics:tt)*] [$($target:tt)*], $self_ident:ident => $inner:expr) => {
        #[async_trait]
        impl<$($generics)*> Tool for $($target)* {
            fn name(&self) -> &str {
                let $self_ident = self;
                ($inner).name()
            }

            fn description(&self) -> &str {
                let $self_ident = self;
                ($inner).description()
            }

            fn enhanced_description(&self) -> String {
                let $self_ident = self;
                ($inner).enhanced_description()
            }

            fn is_long_running(&self) -> bool {
                let $self_ident = self;
                ($inner).is_long_running()
            }

            fn parameters_schema(&self) -> Option<Value> {
                let $self_ident = self;
                ($inner).parameters_schema()
            }

            fn response_schema(&self) -> Option<Value> {
                let $self_ident = self;
                ($inner).response_schema()
            }

            fn required_scopes(&self) -> &[&str] {
                let $self_ident = self;
                ($inner).required_scopes()
            }

            async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
                let $self_ident = self;
                // The permission check (and optional audit logging) happens
                // inside `execute_protected_tool`, before the inner tool runs.
                execute_protected_tool(
                    ($inner),
                    self.access_control.as_ref(),
                    self.audit_sink.as_ref(),
                    ctx,
                    args,
                )
                .await
            }
        }
    };
    // Generic wrapper, e.g. `ProtectedTool<T>`.
    ($wrapper:ident<$generic:ident>, $self_ident:ident => $inner:expr) => {
        impl_protected_tool!(
            @emit [$generic: Tool + Send + Sync] [$wrapper<$generic>],
            $self_ident => $inner
        );
    };
    // Non-generic wrapper, e.g. `ProtectedToolDyn`.
    ($wrapper:ty, $self_ident:ident => $inner:expr) => {
        impl_protected_tool!(@emit [] [$wrapper], $self_ident => $inner);
    };
}
117
118/// A tool wrapper that enforces access control and optionally logs audit events.
119///
120/// Wraps any tool and checks permissions before execution.
121///
122/// # Example
123///
124/// ```rust,ignore
125/// use adk_auth::{AccessControl, ProtectedTool, Permission, Role};
126/// use std::sync::Arc;
127///
128/// let ac = AccessControl::builder()
129///     .role(Role::new("user").allow(Permission::Tool("search".into())))
130///     .build()?;
131///
132/// let protected_search = ProtectedTool::new(search_tool, Arc::new(ac));
133/// ```
134pub struct ProtectedTool<T: Tool> {
135    inner: T,
136    access_control: Arc<AccessControl>,
137    audit_sink: Option<Arc<dyn AuditSink>>,
138}
139
140async fn authorize_tool_access(
141    tool_name: &str,
142    access_control: &AccessControl,
143    audit_sink: Option<&Arc<dyn AuditSink>>,
144    ctx: &Arc<dyn ToolContext>,
145) -> Result<()> {
146    let permission = Permission::Tool(tool_name.to_string());
147    let check_result = access_control.check(ctx.user_id(), &permission);
148
149    if let Some(sink) = audit_sink {
150        let outcome =
151            if check_result.is_ok() { AuditOutcome::Allowed } else { AuditOutcome::Denied };
152        let event = AuditEvent::tool_access(ctx.user_id(), tool_name, outcome)
153            .with_session(ctx.session_id());
154        let _ = sink.log(event).await;
155    }
156
157    check_result.map_err(|err| adk_core::AdkError::Tool(err.to_string()))
158}
159
160async fn execute_protected_tool(
161    inner: &dyn Tool,
162    access_control: &AccessControl,
163    audit_sink: Option<&Arc<dyn AuditSink>>,
164    ctx: Arc<dyn ToolContext>,
165    args: Value,
166) -> Result<Value> {
167    authorize_tool_access(inner.name(), access_control, audit_sink, &ctx).await?;
168    inner.execute(ctx, args).await
169}
170
171impl<T: Tool> ProtectedTool<T> {
172    /// Create a new protected tool.
173    pub fn new(tool: T, access_control: Arc<AccessControl>) -> Self {
174        Self { inner: tool, access_control, audit_sink: None }
175    }
176
177    /// Create a new protected tool with audit logging.
178    pub fn with_audit(
179        tool: T,
180        access_control: Arc<AccessControl>,
181        audit_sink: Arc<dyn AuditSink>,
182    ) -> Self {
183        Self { inner: tool, access_control, audit_sink: Some(audit_sink) }
184    }
185}
186
187impl_protected_tool!(ProtectedTool<T>, wrapper => &wrapper.inner);
188
189/// Extension trait for easily wrapping tools with access control.
190pub trait ToolExt: Tool + Sized {
191    /// Wrap this tool with access control.
192    fn with_access_control(self, ac: Arc<AccessControl>) -> ProtectedTool<Self> {
193        ProtectedTool::new(self, ac)
194    }
195
196    /// Wrap this tool with access control and audit logging.
197    fn with_access_control_and_audit(
198        self,
199        ac: Arc<AccessControl>,
200        audit: Arc<dyn AuditSink>,
201    ) -> ProtectedTool<Self> {
202        ProtectedTool::with_audit(self, ac, audit)
203    }
204}
205
206impl<T: Tool> ToolExt for T {}
207
208/// A collection of auth utilities for integrating with ADK.
209pub struct AuthMiddleware {
210    access_control: Arc<AccessControl>,
211    audit_sink: Option<Arc<dyn AuditSink>>,
212}
213
214impl AuthMiddleware {
215    /// Create a new auth middleware.
216    pub fn new(access_control: AccessControl) -> Self {
217        Self { access_control: Arc::new(access_control), audit_sink: None }
218    }
219
220    /// Create a new auth middleware with audit logging.
221    pub fn with_audit(access_control: AccessControl, audit_sink: impl AuditSink + 'static) -> Self {
222        Self { access_control: Arc::new(access_control), audit_sink: Some(Arc::new(audit_sink)) }
223    }
224
225    /// Get a reference to the access control.
226    pub fn access_control(&self) -> &AccessControl {
227        &self.access_control
228    }
229
230    /// Wrap a tool with access control.
231    pub fn protect<T: Tool>(&self, tool: T) -> ProtectedTool<T> {
232        match &self.audit_sink {
233            Some(sink) => {
234                ProtectedTool::with_audit(tool, self.access_control.clone(), sink.clone())
235            }
236            None => ProtectedTool::new(tool, self.access_control.clone()),
237        }
238    }
239
240    /// Wrap multiple tools with access control.
241    pub fn protect_all(&self, tools: Vec<Arc<dyn Tool>>) -> Vec<Arc<dyn Tool>> {
242        tools
243            .into_iter()
244            .map(|t| {
245                let protected = match &self.audit_sink {
246                    Some(sink) => {
247                        ProtectedToolDyn::with_audit(t, self.access_control.clone(), sink.clone())
248                    }
249                    None => ProtectedToolDyn::new(t, self.access_control.clone()),
250                };
251                Arc::new(protected) as Arc<dyn Tool>
252            })
253            .collect()
254    }
255}
256
257/// Dynamic version of [`ProtectedTool`] for `Arc<dyn Tool>`.
258pub struct ProtectedToolDyn {
259    inner: Arc<dyn Tool>,
260    access_control: Arc<AccessControl>,
261    audit_sink: Option<Arc<dyn AuditSink>>,
262}
263
264impl ProtectedToolDyn {
265    /// Create a new protected dynamic tool.
266    pub fn new(tool: Arc<dyn Tool>, access_control: Arc<AccessControl>) -> Self {
267        Self { inner: tool, access_control, audit_sink: None }
268    }
269
270    /// Create a new protected dynamic tool with audit logging.
271    pub fn with_audit(
272        tool: Arc<dyn Tool>,
273        access_control: Arc<AccessControl>,
274        audit_sink: Arc<dyn AuditSink>,
275    ) -> Self {
276        Self { inner: tool, access_control, audit_sink: Some(audit_sink) }
277    }
278}
279
280impl_protected_tool!(ProtectedToolDyn, wrapper => wrapper.inner.as_ref());
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Role;

    /// Minimal tool with a configurable name whose execution always succeeds.
    struct MockTool {
        name: String,
    }

    impl MockTool {
        fn new(name: &str) -> Self {
            Self { name: String::from(name) }
        }
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn description(&self) -> &str {
            "Mock tool"
        }

        async fn execute(&self, _ctx: Arc<dyn ToolContext>, _args: Value) -> Result<Value> {
            Ok(serde_json::json!({"result": "success"}))
        }
    }

    #[test]
    fn test_tool_ext() {
        let user_role = Role::new("user").allow(Permission::Tool("mock".into()));
        let policy = AccessControl::builder().role(user_role).build().unwrap();

        let protected = MockTool::new("mock").with_access_control(Arc::new(policy));

        // Wrapping must not alter the tool's metadata.
        assert_eq!(protected.name(), "mock");
        assert_eq!(protected.description(), "Mock tool");
    }
}