1use crate::audit::{AuditEvent, AuditOutcome, AuditSink};
7use crate::{AccessControl, Permission};
8use adk_core::{Result, Tool, ToolContext};
9use async_trait::async_trait;
10use serde_json::Value;
11use std::sync::Arc;
12
13pub struct ProtectedTool<T: Tool> {
30 inner: T,
31 access_control: Arc<AccessControl>,
32 audit_sink: Option<Arc<dyn AuditSink>>,
33}
34
35impl<T: Tool> ProtectedTool<T> {
36 pub fn new(tool: T, access_control: Arc<AccessControl>) -> Self {
38 Self { inner: tool, access_control, audit_sink: None }
39 }
40
41 pub fn with_audit(
43 tool: T,
44 access_control: Arc<AccessControl>,
45 audit_sink: Arc<dyn AuditSink>,
46 ) -> Self {
47 Self { inner: tool, access_control, audit_sink: Some(audit_sink) }
48 }
49}
50
51#[async_trait]
52impl<T: Tool + Send + Sync> Tool for ProtectedTool<T> {
53 fn name(&self) -> &str {
54 self.inner.name()
55 }
56
57 fn description(&self) -> &str {
58 self.inner.description()
59 }
60
61 fn enhanced_description(&self) -> String {
62 self.inner.enhanced_description()
63 }
64
65 fn is_long_running(&self) -> bool {
66 self.inner.is_long_running()
67 }
68
69 fn parameters_schema(&self) -> Option<Value> {
70 self.inner.parameters_schema()
71 }
72
73 fn response_schema(&self) -> Option<Value> {
74 self.inner.response_schema()
75 }
76
77 async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
78 let user_id = ctx.user_id();
79 let tool_name = self.name();
80 let permission = Permission::Tool(tool_name.to_string());
81
82 let check_result = self.access_control.check(user_id, &permission);
84
85 if let Some(sink) = &self.audit_sink {
87 let outcome =
88 if check_result.is_ok() { AuditOutcome::Allowed } else { AuditOutcome::Denied };
89 let event =
90 AuditEvent::tool_access(user_id, tool_name, outcome).with_session(ctx.session_id());
91
92 let _ = sink.log(event).await;
94 }
95
96 check_result.map_err(|e| adk_core::AdkError::Tool(e.to_string()))?;
98
99 self.inner.execute(ctx, args).await
101 }
102}
103
104pub trait ToolExt: Tool + Sized {
106 fn with_access_control(self, ac: Arc<AccessControl>) -> ProtectedTool<Self> {
108 ProtectedTool::new(self, ac)
109 }
110
111 fn with_access_control_and_audit(
113 self,
114 ac: Arc<AccessControl>,
115 audit: Arc<dyn AuditSink>,
116 ) -> ProtectedTool<Self> {
117 ProtectedTool::with_audit(self, ac, audit)
118 }
119}
120
121impl<T: Tool> ToolExt for T {}
122
123pub struct AuthMiddleware {
125 access_control: Arc<AccessControl>,
126 audit_sink: Option<Arc<dyn AuditSink>>,
127}
128
129impl AuthMiddleware {
130 pub fn new(access_control: AccessControl) -> Self {
132 Self { access_control: Arc::new(access_control), audit_sink: None }
133 }
134
135 pub fn with_audit(access_control: AccessControl, audit_sink: impl AuditSink + 'static) -> Self {
137 Self { access_control: Arc::new(access_control), audit_sink: Some(Arc::new(audit_sink)) }
138 }
139
140 pub fn access_control(&self) -> &AccessControl {
142 &self.access_control
143 }
144
145 pub fn protect<T: Tool>(&self, tool: T) -> ProtectedTool<T> {
147 match &self.audit_sink {
148 Some(sink) => {
149 ProtectedTool::with_audit(tool, self.access_control.clone(), sink.clone())
150 }
151 None => ProtectedTool::new(tool, self.access_control.clone()),
152 }
153 }
154
155 pub fn protect_all(&self, tools: Vec<Arc<dyn Tool>>) -> Vec<Arc<dyn Tool>> {
157 tools
158 .into_iter()
159 .map(|t| {
160 let protected = match &self.audit_sink {
161 Some(sink) => {
162 ProtectedToolDyn::with_audit(t, self.access_control.clone(), sink.clone())
163 }
164 None => ProtectedToolDyn::new(t, self.access_control.clone()),
165 };
166 Arc::new(protected) as Arc<dyn Tool>
167 })
168 .collect()
169 }
170}
171
172pub struct ProtectedToolDyn {
174 inner: Arc<dyn Tool>,
175 access_control: Arc<AccessControl>,
176 audit_sink: Option<Arc<dyn AuditSink>>,
177}
178
179impl ProtectedToolDyn {
180 pub fn new(tool: Arc<dyn Tool>, access_control: Arc<AccessControl>) -> Self {
182 Self { inner: tool, access_control, audit_sink: None }
183 }
184
185 pub fn with_audit(
187 tool: Arc<dyn Tool>,
188 access_control: Arc<AccessControl>,
189 audit_sink: Arc<dyn AuditSink>,
190 ) -> Self {
191 Self { inner: tool, access_control, audit_sink: Some(audit_sink) }
192 }
193}
194
195#[async_trait]
196impl Tool for ProtectedToolDyn {
197 fn name(&self) -> &str {
198 self.inner.name()
199 }
200
201 fn description(&self) -> &str {
202 self.inner.description()
203 }
204
205 fn enhanced_description(&self) -> String {
206 self.inner.enhanced_description()
207 }
208
209 fn is_long_running(&self) -> bool {
210 self.inner.is_long_running()
211 }
212
213 fn parameters_schema(&self) -> Option<Value> {
214 self.inner.parameters_schema()
215 }
216
217 fn response_schema(&self) -> Option<Value> {
218 self.inner.response_schema()
219 }
220
221 async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
222 let user_id = ctx.user_id();
223 let tool_name = self.name();
224 let permission = Permission::Tool(tool_name.to_string());
225
226 let check_result = self.access_control.check(user_id, &permission);
228
229 if let Some(sink) = &self.audit_sink {
231 let outcome =
232 if check_result.is_ok() { AuditOutcome::Allowed } else { AuditOutcome::Denied };
233 let event =
234 AuditEvent::tool_access(user_id, tool_name, outcome).with_session(ctx.session_id());
235
236 let _ = sink.log(event).await;
238 }
239
240 check_result.map_err(|e| adk_core::AdkError::Tool(e.to_string()))?;
242
243 self.inner.execute(ctx, args).await
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Role;

    /// Minimal tool with a configurable name and a fixed successful result.
    struct MockTool {
        name: String,
    }

    impl MockTool {
        fn new(name: &str) -> Self {
            MockTool { name: String::from(name) }
        }
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn description(&self) -> &str {
            "Mock tool"
        }

        async fn execute(&self, _ctx: Arc<dyn ToolContext>, _args: Value) -> Result<Value> {
            Ok(serde_json::json!({"result": "success"}))
        }
    }

    /// Wrapping through `ToolExt` must leave name and description untouched.
    #[test]
    fn test_tool_ext() {
        let builder = AccessControl::builder();
        let role = Role::new("user").allow(Permission::Tool("mock".into()));
        let ac = builder.role(role).build().unwrap();

        let protected = MockTool::new("mock").with_access_control(Arc::new(ac));

        assert_eq!(protected.name(), "mock");
        assert_eq!(protected.description(), "Mock tool");
    }
}