1use crate::audit::{AuditEvent, AuditOutcome, AuditSink};
7use crate::{AccessControl, Permission};
8use adk_core::{Result, Tool, ToolContext};
9use async_trait::async_trait;
10use serde_json::Value;
11use std::sync::Arc;
12
13pub struct ProtectedTool<T: Tool> {
30 inner: T,
31 access_control: Arc<AccessControl>,
32 audit_sink: Option<Arc<dyn AuditSink>>,
33}
34
35impl<T: Tool> ProtectedTool<T> {
36 pub fn new(tool: T, access_control: Arc<AccessControl>) -> Self {
38 Self { inner: tool, access_control, audit_sink: None }
39 }
40
41 pub fn with_audit(
43 tool: T,
44 access_control: Arc<AccessControl>,
45 audit_sink: Arc<dyn AuditSink>,
46 ) -> Self {
47 Self { inner: tool, access_control, audit_sink: Some(audit_sink) }
48 }
49}
50
51#[async_trait]
52impl<T: Tool + Send + Sync> Tool for ProtectedTool<T> {
53 fn name(&self) -> &str {
54 self.inner.name()
55 }
56
57 fn description(&self) -> &str {
58 self.inner.description()
59 }
60
61 fn enhanced_description(&self) -> String {
62 self.inner.enhanced_description()
63 }
64
65 fn is_long_running(&self) -> bool {
66 self.inner.is_long_running()
67 }
68
69 fn parameters_schema(&self) -> Option<Value> {
70 self.inner.parameters_schema()
71 }
72
73 fn response_schema(&self) -> Option<Value> {
74 self.inner.response_schema()
75 }
76
77 fn required_scopes(&self) -> &[&str] {
78 self.inner.required_scopes()
79 }
80
81 async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
82 let user_id = ctx.user_id();
83 let tool_name = self.name();
84 let permission = Permission::Tool(tool_name.to_string());
85
86 let check_result = self.access_control.check(user_id, &permission);
88
89 if let Some(sink) = &self.audit_sink {
91 let outcome =
92 if check_result.is_ok() { AuditOutcome::Allowed } else { AuditOutcome::Denied };
93 let event =
94 AuditEvent::tool_access(user_id, tool_name, outcome).with_session(ctx.session_id());
95
96 let _ = sink.log(event).await;
98 }
99
100 check_result.map_err(|e| adk_core::AdkError::Tool(e.to_string()))?;
102
103 self.inner.execute(ctx, args).await
105 }
106}
107
108pub trait ToolExt: Tool + Sized {
110 fn with_access_control(self, ac: Arc<AccessControl>) -> ProtectedTool<Self> {
112 ProtectedTool::new(self, ac)
113 }
114
115 fn with_access_control_and_audit(
117 self,
118 ac: Arc<AccessControl>,
119 audit: Arc<dyn AuditSink>,
120 ) -> ProtectedTool<Self> {
121 ProtectedTool::with_audit(self, ac, audit)
122 }
123}
124
125impl<T: Tool> ToolExt for T {}
126
127pub struct AuthMiddleware {
129 access_control: Arc<AccessControl>,
130 audit_sink: Option<Arc<dyn AuditSink>>,
131}
132
133impl AuthMiddleware {
134 pub fn new(access_control: AccessControl) -> Self {
136 Self { access_control: Arc::new(access_control), audit_sink: None }
137 }
138
139 pub fn with_audit(access_control: AccessControl, audit_sink: impl AuditSink + 'static) -> Self {
141 Self { access_control: Arc::new(access_control), audit_sink: Some(Arc::new(audit_sink)) }
142 }
143
144 pub fn access_control(&self) -> &AccessControl {
146 &self.access_control
147 }
148
149 pub fn protect<T: Tool>(&self, tool: T) -> ProtectedTool<T> {
151 match &self.audit_sink {
152 Some(sink) => {
153 ProtectedTool::with_audit(tool, self.access_control.clone(), sink.clone())
154 }
155 None => ProtectedTool::new(tool, self.access_control.clone()),
156 }
157 }
158
159 pub fn protect_all(&self, tools: Vec<Arc<dyn Tool>>) -> Vec<Arc<dyn Tool>> {
161 tools
162 .into_iter()
163 .map(|t| {
164 let protected = match &self.audit_sink {
165 Some(sink) => {
166 ProtectedToolDyn::with_audit(t, self.access_control.clone(), sink.clone())
167 }
168 None => ProtectedToolDyn::new(t, self.access_control.clone()),
169 };
170 Arc::new(protected) as Arc<dyn Tool>
171 })
172 .collect()
173 }
174}
175
176pub struct ProtectedToolDyn {
178 inner: Arc<dyn Tool>,
179 access_control: Arc<AccessControl>,
180 audit_sink: Option<Arc<dyn AuditSink>>,
181}
182
183impl ProtectedToolDyn {
184 pub fn new(tool: Arc<dyn Tool>, access_control: Arc<AccessControl>) -> Self {
186 Self { inner: tool, access_control, audit_sink: None }
187 }
188
189 pub fn with_audit(
191 tool: Arc<dyn Tool>,
192 access_control: Arc<AccessControl>,
193 audit_sink: Arc<dyn AuditSink>,
194 ) -> Self {
195 Self { inner: tool, access_control, audit_sink: Some(audit_sink) }
196 }
197}
198
199#[async_trait]
200impl Tool for ProtectedToolDyn {
201 fn name(&self) -> &str {
202 self.inner.name()
203 }
204
205 fn description(&self) -> &str {
206 self.inner.description()
207 }
208
209 fn enhanced_description(&self) -> String {
210 self.inner.enhanced_description()
211 }
212
213 fn is_long_running(&self) -> bool {
214 self.inner.is_long_running()
215 }
216
217 fn parameters_schema(&self) -> Option<Value> {
218 self.inner.parameters_schema()
219 }
220
221 fn response_schema(&self) -> Option<Value> {
222 self.inner.response_schema()
223 }
224
225 fn required_scopes(&self) -> &[&str] {
226 self.inner.required_scopes()
227 }
228
229 async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
230 let user_id = ctx.user_id();
231 let tool_name = self.name();
232 let permission = Permission::Tool(tool_name.to_string());
233
234 let check_result = self.access_control.check(user_id, &permission);
236
237 if let Some(sink) = &self.audit_sink {
239 let outcome =
240 if check_result.is_ok() { AuditOutcome::Allowed } else { AuditOutcome::Denied };
241 let event =
242 AuditEvent::tool_access(user_id, tool_name, outcome).with_session(ctx.session_id());
243
244 let _ = sink.log(event).await;
246 }
247
248 check_result.map_err(|e| adk_core::AdkError::Tool(e.to_string()))?;
250
251 self.inner.execute(ctx, args).await
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Role;

    /// Minimal tool that always succeeds; used to observe what the wrappers forward.
    struct MockTool {
        name: String,
    }

    impl MockTool {
        fn new(name: &str) -> Self {
            Self { name: name.to_string() }
        }
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "Mock tool"
        }

        async fn execute(&self, _ctx: Arc<dyn ToolContext>, _args: Value) -> Result<Value> {
            Ok(serde_json::json!({"result": "success"}))
        }
    }

    /// Builds a policy with a single `user` role that is allowed to use the `mock` tool.
    fn mock_access_control() -> AccessControl {
        AccessControl::builder()
            .role(Role::new("user").allow(Permission::Tool("mock".into())))
            .build()
            .unwrap()
    }

    #[test]
    fn test_tool_ext() {
        let tool = MockTool::new("mock");
        let protected = tool.with_access_control(Arc::new(mock_access_control()));

        assert_eq!(protected.name(), "mock");
        assert_eq!(protected.description(), "Mock tool");
    }

    #[test]
    fn test_middleware_protect_forwards_metadata() {
        let middleware = AuthMiddleware::new(mock_access_control());
        let protected = middleware.protect(MockTool::new("mock"));

        assert_eq!(protected.name(), "mock");
        assert_eq!(protected.description(), "Mock tool");
    }

    #[test]
    fn test_middleware_protect_all_preserves_order_and_metadata() {
        let middleware = AuthMiddleware::new(mock_access_control());
        let first: Arc<dyn Tool> = Arc::new(MockTool::new("first"));
        let second: Arc<dyn Tool> = Arc::new(MockTool::new("second"));

        let protected = middleware.protect_all(vec![first, second]);

        let names: Vec<&str> = protected.iter().map(|t| t.name()).collect();
        assert_eq!(names, ["first", "second"]);
        assert!(protected.iter().all(|t| t.description() == "Mock tool"));
    }

    #[test]
    fn test_protect_all_empty_input() {
        let middleware = AuthMiddleware::new(mock_access_control());
        assert!(middleware.protect_all(Vec::new()).is_empty());
    }
}