1use crate::audit::{AuditEvent, AuditOutcome, AuditSink};
7use crate::{AccessControl, Permission};
8use adk_core::{Result, Tool, ToolContext};
9use async_trait::async_trait;
10use serde_json::Value;
11use std::sync::Arc;
12
/// Generates a `Tool` implementation for an access-controlled wrapper type.
///
/// These metadata methods are forwarded unchanged to the wrapped tool:
/// - `name`, `description`, `enhanced_description`
/// - `is_long_running`
/// - `parameters_schema`, `response_schema`
/// - `required_scopes`
///
/// `execute` is routed through `execute_protected_tool`. That function authorizes the call,
/// auditing it when a sink is configured, before delegating to the inner tool.
///
/// Invocation forms:
/// - `impl_protected_tool!(Wrapper<T>, this => <inner-tool expr>)` is for a wrapper generic
///   over one tool type. The generated impl requires `T: Tool + Send + Sync`.
/// - `impl_protected_tool!(Wrapper, this => <inner-tool expr>)` is for a non-generic wrapper.
///
/// `this` is bound to `self`, so the inner-tool expression can refer to it.
///
/// The wrapper must have two fields:
/// - `access_control`, whose `as_ref()` yields `&AccessControl`;
/// - `audit_sink`, whose `as_ref()` yields `Option<&Arc<dyn AuditSink>>`.
///
/// These are the argument types `execute_protected_tool` takes.
macro_rules! impl_protected_tool {
    // Shared expansion used by both public forms, so the forwarding code exists only once.
    // `[$($impl_generics)*]` is the (possibly empty) generic parameter list placed right after
    // `impl`. This arm comes first so the `@impl` marker is matched before any `ty` fragment.
    (@impl [$($impl_generics:tt)*] $target:ty, $self_ident:ident => $inner:expr) => {
        #[async_trait]
        impl $($impl_generics)* Tool for $target {
            fn name(&self) -> &str {
                let $self_ident = self;
                ($inner).name()
            }

            fn description(&self) -> &str {
                let $self_ident = self;
                ($inner).description()
            }

            fn enhanced_description(&self) -> String {
                let $self_ident = self;
                ($inner).enhanced_description()
            }

            fn is_long_running(&self) -> bool {
                let $self_ident = self;
                ($inner).is_long_running()
            }

            fn parameters_schema(&self) -> Option<Value> {
                let $self_ident = self;
                ($inner).parameters_schema()
            }

            fn response_schema(&self) -> Option<Value> {
                let $self_ident = self;
                ($inner).response_schema()
            }

            fn required_scopes(&self) -> &[&str] {
                let $self_ident = self;
                ($inner).required_scopes()
            }

            async fn execute(&self, ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
                let $self_ident = self;
                // Authorization (and auditing) happens before the inner tool runs.
                execute_protected_tool(
                    ($inner),
                    self.access_control.as_ref(),
                    self.audit_sink.as_ref(),
                    ctx,
                    args,
                )
                .await
            }
        }
    };
    ($wrapper:ident<$generic:ident>, $self_ident:ident => $inner:expr) => {
        impl_protected_tool!(
            @impl [<$generic: Tool + Send + Sync>] $wrapper<$generic>, $self_ident => $inner
        );
    };
    ($wrapper:ty, $self_ident:ident => $inner:expr) => {
        impl_protected_tool!(@impl [] $wrapper, $self_ident => $inner);
    };
}
117
118pub struct ProtectedTool<T: Tool> {
135 inner: T,
136 access_control: Arc<AccessControl>,
137 audit_sink: Option<Arc<dyn AuditSink>>,
138}
139
140async fn authorize_tool_access(
141 tool_name: &str,
142 access_control: &AccessControl,
143 audit_sink: Option<&Arc<dyn AuditSink>>,
144 ctx: &Arc<dyn ToolContext>,
145) -> Result<()> {
146 let permission = Permission::Tool(tool_name.to_string());
147 let check_result = access_control.check(ctx.user_id(), &permission);
148
149 if let Some(sink) = audit_sink {
150 let outcome =
151 if check_result.is_ok() { AuditOutcome::Allowed } else { AuditOutcome::Denied };
152 let event = AuditEvent::tool_access(ctx.user_id(), tool_name, outcome)
153 .with_session(ctx.session_id());
154 let _ = sink.log(event).await;
155 }
156
157 check_result.map_err(|err| adk_core::AdkError::Tool(err.to_string()))
158}
159
160async fn execute_protected_tool(
161 inner: &dyn Tool,
162 access_control: &AccessControl,
163 audit_sink: Option<&Arc<dyn AuditSink>>,
164 ctx: Arc<dyn ToolContext>,
165 args: Value,
166) -> Result<Value> {
167 authorize_tool_access(inner.name(), access_control, audit_sink, &ctx).await?;
168 inner.execute(ctx, args).await
169}
170
171impl<T: Tool> ProtectedTool<T> {
172 pub fn new(tool: T, access_control: Arc<AccessControl>) -> Self {
174 Self { inner: tool, access_control, audit_sink: None }
175 }
176
177 pub fn with_audit(
179 tool: T,
180 access_control: Arc<AccessControl>,
181 audit_sink: Arc<dyn AuditSink>,
182 ) -> Self {
183 Self { inner: tool, access_control, audit_sink: Some(audit_sink) }
184 }
185}
186
187impl_protected_tool!(ProtectedTool<T>, wrapper => &wrapper.inner);
188
189pub trait ToolExt: Tool + Sized {
191 fn with_access_control(self, ac: Arc<AccessControl>) -> ProtectedTool<Self> {
193 ProtectedTool::new(self, ac)
194 }
195
196 fn with_access_control_and_audit(
198 self,
199 ac: Arc<AccessControl>,
200 audit: Arc<dyn AuditSink>,
201 ) -> ProtectedTool<Self> {
202 ProtectedTool::with_audit(self, ac, audit)
203 }
204}
205
206impl<T: Tool> ToolExt for T {}
207
208pub struct AuthMiddleware {
210 access_control: Arc<AccessControl>,
211 audit_sink: Option<Arc<dyn AuditSink>>,
212}
213
214impl AuthMiddleware {
215 pub fn new(access_control: AccessControl) -> Self {
217 Self { access_control: Arc::new(access_control), audit_sink: None }
218 }
219
220 pub fn with_audit(access_control: AccessControl, audit_sink: impl AuditSink + 'static) -> Self {
222 Self { access_control: Arc::new(access_control), audit_sink: Some(Arc::new(audit_sink)) }
223 }
224
225 pub fn access_control(&self) -> &AccessControl {
227 &self.access_control
228 }
229
230 pub fn protect<T: Tool>(&self, tool: T) -> ProtectedTool<T> {
232 match &self.audit_sink {
233 Some(sink) => {
234 ProtectedTool::with_audit(tool, self.access_control.clone(), sink.clone())
235 }
236 None => ProtectedTool::new(tool, self.access_control.clone()),
237 }
238 }
239
240 pub fn protect_all(&self, tools: Vec<Arc<dyn Tool>>) -> Vec<Arc<dyn Tool>> {
242 tools
243 .into_iter()
244 .map(|t| {
245 let protected = match &self.audit_sink {
246 Some(sink) => {
247 ProtectedToolDyn::with_audit(t, self.access_control.clone(), sink.clone())
248 }
249 None => ProtectedToolDyn::new(t, self.access_control.clone()),
250 };
251 Arc::new(protected) as Arc<dyn Tool>
252 })
253 .collect()
254 }
255}
256
257pub struct ProtectedToolDyn {
259 inner: Arc<dyn Tool>,
260 access_control: Arc<AccessControl>,
261 audit_sink: Option<Arc<dyn AuditSink>>,
262}
263
264impl ProtectedToolDyn {
265 pub fn new(tool: Arc<dyn Tool>, access_control: Arc<AccessControl>) -> Self {
267 Self { inner: tool, access_control, audit_sink: None }
268 }
269
270 pub fn with_audit(
272 tool: Arc<dyn Tool>,
273 access_control: Arc<AccessControl>,
274 audit_sink: Arc<dyn AuditSink>,
275 ) -> Self {
276 Self { inner: tool, access_control, audit_sink: Some(audit_sink) }
277 }
278}
279
280impl_protected_tool!(ProtectedToolDyn, wrapper => wrapper.inner.as_ref());
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Role;

    /// Minimal tool used to observe what the protected wrappers forward.
    struct MockTool {
        name: String,
    }

    impl MockTool {
        fn new(name: &str) -> Self {
            Self { name: name.to_string() }
        }
    }

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "Mock tool"
        }

        async fn execute(&self, _ctx: Arc<dyn ToolContext>, _args: Value) -> Result<Value> {
            Ok(serde_json::json!({"result": "success"}))
        }
    }

    /// Builds a policy with a single `user` role that may call the `mock` tool.
    fn mock_access_control() -> AccessControl {
        AccessControl::builder()
            .role(Role::new("user").allow(Permission::Tool("mock".into())))
            .build()
            .unwrap()
    }

    #[test]
    fn test_tool_ext() {
        let tool = MockTool::new("mock");
        let protected = tool.with_access_control(Arc::new(mock_access_control()));

        assert_eq!(protected.name(), "mock");
        assert_eq!(protected.description(), "Mock tool");
    }

    #[test]
    fn test_middleware_protect_forwards_metadata() {
        let middleware = AuthMiddleware::new(mock_access_control());
        let protected = middleware.protect(MockTool::new("mock"));

        assert_eq!(protected.name(), "mock");
        assert_eq!(protected.description(), "Mock tool");
    }

    #[test]
    fn test_protect_all_preserves_order_and_metadata() {
        let middleware = AuthMiddleware::new(mock_access_control());
        let first: Arc<dyn Tool> = Arc::new(MockTool::new("first"));
        let second: Arc<dyn Tool> = Arc::new(MockTool::new("second"));

        let protected = middleware.protect_all(vec![first, second]);

        let names: Vec<&str> = protected.iter().map(|tool| tool.name()).collect();
        assert_eq!(names, ["first", "second"]);
        assert!(protected.iter().all(|tool| tool.description() == "Mock tool"));
    }
}