Skip to main content

adk_auth/
middleware.rs

1//! Middleware for integrating access control with adk-core.
2//!
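//! This module provides [`ProtectedTool`] (and its type-erased counterpart
//! [`ProtectedToolDyn`]), which enforce permissions before tool execution and
//! optionally log audit events, plus the [`ToolExt`] and [`AuthMiddleware`]
//! helpers for applying that wrapping to one or many tools.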
5
6use crate::audit::{AuditEvent, AuditOutcome, AuditSink};
7use crate::{AccessControl, Permission};
8use adk_core::{Result, Tool, ToolContext};
9use async_trait::async_trait;
10use serde_json::Value;
11use std::sync::Arc;
12
13/// A tool wrapper that enforces access control and optionally logs audit events.
14///
15/// Wraps any tool and checks permissions before execution.
16///
17/// # Example
18///
19/// ```rust,ignore
20/// use adk_auth::{AccessControl, ProtectedTool, Permission, Role};
21/// use std::sync::Arc;
22///
23/// let ac = AccessControl::builder()
24///     .role(Role::new("user").allow(Permission::Tool("search".into())))
25///     .build()?;
26///
27/// let protected_search = ProtectedTool::new(search_tool, Arc::new(ac));
28/// ```
29pub struct ProtectedTool<T: Tool> {
30    inner: T,
31    access_control: Arc<AccessControl>,
32    audit_sink: Option<Arc<dyn AuditSink>>,
33}
34
35impl<T: Tool> ProtectedTool<T> {
36    /// Create a new protected tool.
37    pub fn new(tool: T, access_control: Arc<AccessControl>) -> Self {
38        Self { inner: tool, access_control, audit_sink: None }
39    }
40
41    /// Create a new protected tool with audit logging.
42    pub fn with_audit(
43        tool: T,
44        access_control: Arc<AccessControl>,
45        audit_sink: Arc<dyn AuditSink>,
46    ) -> Self {
47        Self { inner: tool, access_control, audit_sink: Some(audit_sink) }
48    }
49}
50
51#[async_trait]
52impl<T: Tool + Send + Sync> Tool for ProtectedTool<T> {
53    fn name(&self) -> &str {
54        self.inner.name()
55    }
56
57    fn description(&self) -> &str {
58        self.inner.description()
59    }
60
61    fn enhanced_description(&self) -> String {
62        self.inner.enhanced_description()
63    }
64
65    fn is_long_running(&self) -> bool {
66        self.inner.is_long_running()
67    }
68
69    fn parameters_schema(&self) -> Option<Value> {
70        self.inner.parameters_schema()
71    }
72
73    fn response_schema(&self) -> Option<Value> {
74        self.inner.response_schema()
75    }
76
77    fn required_scopes(&self) -> &[&str] {
78        self.inner.required_scopes()
79    }
80
81    async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
82        let user_id = ctx.user_id();
83        let tool_name = self.name();
84        let permission = Permission::Tool(tool_name.to_string());
85
86        // Check permission
87        let check_result = self.access_control.check(user_id, &permission);
88
89        // Log audit event if sink is configured
90        if let Some(sink) = &self.audit_sink {
91            let outcome =
92                if check_result.is_ok() { AuditOutcome::Allowed } else { AuditOutcome::Denied };
93            let event =
94                AuditEvent::tool_access(user_id, tool_name, outcome).with_session(ctx.session_id());
95
96            // Log asynchronously (don't block on audit failure)
97            let _ = sink.log(event).await;
98        }
99
100        // Return error if access denied
101        check_result.map_err(|e| adk_core::AdkError::Tool(e.to_string()))?;
102
103        // Execute the tool
104        self.inner.execute(ctx, args).await
105    }
106}
107
108/// Extension trait for easily wrapping tools with access control.
109pub trait ToolExt: Tool + Sized {
110    /// Wrap this tool with access control.
111    fn with_access_control(self, ac: Arc<AccessControl>) -> ProtectedTool<Self> {
112        ProtectedTool::new(self, ac)
113    }
114
115    /// Wrap this tool with access control and audit logging.
116    fn with_access_control_and_audit(
117        self,
118        ac: Arc<AccessControl>,
119        audit: Arc<dyn AuditSink>,
120    ) -> ProtectedTool<Self> {
121        ProtectedTool::with_audit(self, ac, audit)
122    }
123}
124
125impl<T: Tool> ToolExt for T {}
126
127/// A collection of auth utilities for integrating with ADK.
128pub struct AuthMiddleware {
129    access_control: Arc<AccessControl>,
130    audit_sink: Option<Arc<dyn AuditSink>>,
131}
132
133impl AuthMiddleware {
134    /// Create a new auth middleware.
135    pub fn new(access_control: AccessControl) -> Self {
136        Self { access_control: Arc::new(access_control), audit_sink: None }
137    }
138
139    /// Create a new auth middleware with audit logging.
140    pub fn with_audit(access_control: AccessControl, audit_sink: impl AuditSink + 'static) -> Self {
141        Self { access_control: Arc::new(access_control), audit_sink: Some(Arc::new(audit_sink)) }
142    }
143
144    /// Get a reference to the access control.
145    pub fn access_control(&self) -> &AccessControl {
146        &self.access_control
147    }
148
149    /// Wrap a tool with access control.
150    pub fn protect<T: Tool>(&self, tool: T) -> ProtectedTool<T> {
151        match &self.audit_sink {
152            Some(sink) => {
153                ProtectedTool::with_audit(tool, self.access_control.clone(), sink.clone())
154            }
155            None => ProtectedTool::new(tool, self.access_control.clone()),
156        }
157    }
158
159    /// Wrap multiple tools with access control.
160    pub fn protect_all(&self, tools: Vec<Arc<dyn Tool>>) -> Vec<Arc<dyn Tool>> {
161        tools
162            .into_iter()
163            .map(|t| {
164                let protected = match &self.audit_sink {
165                    Some(sink) => {
166                        ProtectedToolDyn::with_audit(t, self.access_control.clone(), sink.clone())
167                    }
168                    None => ProtectedToolDyn::new(t, self.access_control.clone()),
169                };
170                Arc::new(protected) as Arc<dyn Tool>
171            })
172            .collect()
173    }
174}
175
176/// Dynamic version of [`ProtectedTool`] for `Arc<dyn Tool>`.
177pub struct ProtectedToolDyn {
178    inner: Arc<dyn Tool>,
179    access_control: Arc<AccessControl>,
180    audit_sink: Option<Arc<dyn AuditSink>>,
181}
182
183impl ProtectedToolDyn {
184    /// Create a new protected dynamic tool.
185    pub fn new(tool: Arc<dyn Tool>, access_control: Arc<AccessControl>) -> Self {
186        Self { inner: tool, access_control, audit_sink: None }
187    }
188
189    /// Create a new protected dynamic tool with audit logging.
190    pub fn with_audit(
191        tool: Arc<dyn Tool>,
192        access_control: Arc<AccessControl>,
193        audit_sink: Arc<dyn AuditSink>,
194    ) -> Self {
195        Self { inner: tool, access_control, audit_sink: Some(audit_sink) }
196    }
197}
198
199#[async_trait]
200impl Tool for ProtectedToolDyn {
201    fn name(&self) -> &str {
202        self.inner.name()
203    }
204
205    fn description(&self) -> &str {
206        self.inner.description()
207    }
208
209    fn enhanced_description(&self) -> String {
210        self.inner.enhanced_description()
211    }
212
213    fn is_long_running(&self) -> bool {
214        self.inner.is_long_running()
215    }
216
217    fn parameters_schema(&self) -> Option<Value> {
218        self.inner.parameters_schema()
219    }
220
221    fn response_schema(&self) -> Option<Value> {
222        self.inner.response_schema()
223    }
224
225    fn required_scopes(&self) -> &[&str] {
226        self.inner.required_scopes()
227    }
228
229    async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
230        let user_id = ctx.user_id();
231        let tool_name = self.name();
232        let permission = Permission::Tool(tool_name.to_string());
233
234        // Check permission
235        let check_result = self.access_control.check(user_id, &permission);
236
237        // Log audit event if sink is configured
238        if let Some(sink) = &self.audit_sink {
239            let outcome =
240                if check_result.is_ok() { AuditOutcome::Allowed } else { AuditOutcome::Denied };
241            let event =
242                AuditEvent::tool_access(user_id, tool_name, outcome).with_session(ctx.session_id());
243
244            // Log asynchronously (don't block on audit failure)
245            let _ = sink.log(event).await;
246        }
247
248        // Return error if access denied
249        check_result.map_err(|e| adk_core::AdkError::Tool(e.to_string()))?;
250
251        // Execute the tool
252        self.inner.execute(ctx, args).await
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Role;

    /// Minimal tool used to observe delegation through the wrappers.
    struct MockTool {
        name: String,
    }

    impl MockTool {
        fn new(name: &str) -> Self {
            Self { name: name.to_string() }
        }
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "Mock tool"
        }

        async fn execute(&self, _ctx: Arc<dyn ToolContext>, _args: Value) -> Result<Value> {
            Ok(serde_json::json!({"result": "success"}))
        }
    }

    /// Policy granting the `user` role access to the `mock` tool.
    fn mock_access_control() -> AccessControl {
        AccessControl::builder()
            .role(Role::new("user").allow(Permission::Tool("mock".into())))
            .build()
            .unwrap()
    }

    #[test]
    fn test_tool_ext() {
        let tool = MockTool::new("mock");
        let protected = tool.with_access_control(Arc::new(mock_access_control()));

        assert_eq!(protected.name(), "mock");
        assert_eq!(protected.description(), "Mock tool");
    }

    #[test]
    fn test_middleware_protect_delegates_metadata() {
        let middleware = AuthMiddleware::new(mock_access_control());
        let protected = middleware.protect(MockTool::new("mock"));

        assert_eq!(protected.name(), "mock");
        assert_eq!(protected.description(), "Mock tool");
    }

    #[test]
    fn test_middleware_protect_all_preserves_order_and_metadata() {
        let middleware = AuthMiddleware::new(mock_access_control());
        let tools = vec![
            Arc::new(MockTool::new("first")) as Arc<dyn Tool>,
            Arc::new(MockTool::new("second")) as Arc<dyn Tool>,
        ];

        let protected = middleware.protect_all(tools);
        let names: Vec<&str> = protected.iter().map(|t| t.name()).collect();

        assert_eq!(names, ["first", "second"]);
        assert!(protected.iter().all(|t| t.description() == "Mock tool"));
    }

    #[test]
    fn test_middleware_protect_all_empty() {
        let middleware = AuthMiddleware::new(mock_access_control());

        assert!(middleware.protect_all(Vec::new()).is_empty());
    }
}