1use std::sync::Arc;
7
8use serde::{Deserialize, Serialize};
9use serde_json::Value as JsonValue;
10
11use super::security::SecurityValidator;
12use crate::matrixrpc::{
13 ErrorCode, JsonRpcError, JsonRpcId, JsonRpcResponse, ServiceId, ToolRouter,
14};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ToolCallbackRequest {
19 pub request_id: String,
21
22 pub service_id: ServiceId,
24
25 pub token: String,
27
28 pub tool_name: String,
30
31 #[serde(default)]
33 pub params: JsonValue,
34
35 #[serde(default = "default_tool_timeout")]
37 pub timeout_ms: u64,
38
39 #[serde(default)]
41 pub require_approval: bool,
42}
43
/// Serde default for `ToolCallbackRequest::timeout_ms`: thirty seconds,
/// expressed in milliseconds.
fn default_tool_timeout() -> u64 {
    const DEFAULT_TIMEOUT_MS: u64 = 30 * 1_000;
    DEFAULT_TIMEOUT_MS
}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct ToolCallbackResult {
51 pub tool_name: String,
53
54 pub result: JsonValue,
56
57 pub status: String,
59
60 pub duration_ms: u64,
62
63 pub approval_required: bool,
65
66 #[serde(default)]
68 pub metadata: JsonValue,
69}
70
71#[derive(Debug, thiserror::Error)]
73pub enum ToolCallbackError {
74 #[error("Security validation failed: {0}")]
76 SecurityFailed(String),
77
78 #[error("Tool '{0}' not found")]
80 ToolNotFound(String),
81
82 #[error("Tool '{tool}' execution failed: {reason}")]
84 ExecutionFailed { tool: String, reason: String },
85
86 #[error("Invalid parameters for tool '{tool}': {reason}")]
88 InvalidParams { tool: String, reason: String },
89
90 #[error("Tool '{0}' timed out after {1}ms")]
92 Timeout(String, u64),
93
94 #[error("Tool '{0}' is not allowed for callback")]
96 ToolNotAllowed(String),
97
98 #[error("Internal error: {0}")]
100 Internal(String),
101}
102
/// Policy deciding which tools may be invoked through callbacks and which of
/// them must be flagged for approval.
///
/// Tool names are compared by exact string equality.
#[derive(Debug)]
#[derive(Clone)]
pub struct AllowedToolsConfig {
    /// Tools permitted without approval.
    pub always_allowed: Vec<String>,

    /// Tools permitted but flagged as requiring approval.
    pub requires_approval: Vec<String>,

    /// Tools that are always rejected.
    pub never_allowed: Vec<String>,

    /// When `true`, tools not named in any list are permitted as well.
    pub allow_all: bool,
}
118
119impl Default for AllowedToolsConfig {
120 fn default() -> Self {
121 Self {
122 always_allowed: vec![
124 "read".to_string(), "grep".to_string(), "glob".to_string(),
125 "codegraph_search".to_string(), "codegraph_node".to_string(),
126 "codegraph_context".to_string(), "codegraph_callers".to_string(),
127 "codegraph_callees".to_string(),
128 ],
129 requires_approval: vec![
131 "write".to_string(), "edit".to_string(), "bash".to_string(),
132 "tool_search".to_string(),
133 ],
134 never_allowed: vec![
136 "delete".to_string(), "rm".to_string(), "format".to_string(),
137 "sudo".to_string(),
138 ],
139 allow_all: false,
140 }
141 }
142}
143
144pub struct ToolCallbackHandler {
148 security: Arc<SecurityValidator>,
150
151 tool_router: Arc<ToolRouter>,
153
154 allowed_tools: AllowedToolsConfig,
156
157 default_timeout_ms: u64,
159}
160
161impl ToolCallbackHandler {
162 pub fn new(security: Arc<SecurityValidator>, tool_router: Arc<ToolRouter>) -> Self {
164 Self {
165 security,
166 tool_router,
167 allowed_tools: AllowedToolsConfig::default(),
168 default_timeout_ms: 30_000,
169 }
170 }
171
172 pub fn with_allowed_tools(mut self, config: AllowedToolsConfig) -> Self {
174 self.allowed_tools = config;
175 self
176 }
177
178 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
180 self.default_timeout_ms = timeout_ms;
181 self
182 }
183
184 pub async fn handle(&self, request: ToolCallbackRequest) -> Result<ToolCallbackResult, ToolCallbackError> {
186 let validation = self
188 .security
189 .validate(&request.token, &request.service_id, &request.request_id, "tool")
190 .await;
191
192 if !validation.is_valid {
193 return Err(ToolCallbackError::SecurityFailed(
194 validation.error.unwrap_or_else(|| "Unknown security error".to_string()),
195 ));
196 }
197
198 let (approval_required, allowed) = self.check_tool_allowed(&request.tool_name);
200
201 if !allowed {
202 return Err(ToolCallbackError::ToolNotAllowed(request.tool_name));
203 }
204
205 let route_result = self
207 .tool_router
208 .route(
209 &request.tool_name,
210 request.params.clone(),
211 JsonRpcId::generate(),
212 )
213 .await
214 .map_err(|e| match e {
215 crate::matrixrpc::ToolRouterError::ToolNotFound(tool) => {
216 ToolCallbackError::ToolNotFound(tool)
217 }
218 _ => ToolCallbackError::Internal(e.to_string()),
219 })?;
220
221 let result = ToolCallbackResult {
225 tool_name: request.tool_name.clone(),
226 result: serde_json::json!({
227 "status": "success",
228 "message": format!("Tool '{}' executed successfully", request.tool_name),
229 }),
230 status: "success".to_string(),
231 duration_ms: 100,
232 approval_required,
233 metadata: serde_json::json!({
234 "request_id": request.request_id,
235 "service_id": request.service_id.to_string(),
236 "route": {
237 "service_id": route_result.service_id.to_string(),
238 },
239 }),
240 };
241
242 Ok(result)
243 }
244
245 fn check_tool_allowed(&self, tool_name: &str) -> (bool, bool) {
247 if self.allowed_tools.never_allowed.contains(&tool_name.to_string()) {
249 return (false, false);
250 }
251
252 if self.allowed_tools.allow_all {
254 return (false, true);
255 }
256
257 if self.allowed_tools.always_allowed.contains(&tool_name.to_string()) {
259 return (false, true);
260 }
261
262 if self.allowed_tools.requires_approval.contains(&tool_name.to_string()) {
264 return (true, true);
265 }
266
267 (false, false)
269 }
270
271 pub fn create_error_response(&self, error: ToolCallbackError, id: JsonRpcId) -> JsonRpcResponse {
273 let (code, message, data) = match error {
274 ToolCallbackError::SecurityFailed(msg) => (
275 ErrorCode::PERMISSION_DENIED,
276 "Security validation failed".to_string(),
277 Some(serde_json::json!({ "reason": msg })),
278 ),
279 ToolCallbackError::ToolNotFound(tool) => (
280 ErrorCode::RESOURCE_NOT_FOUND,
281 format!("Tool '{}' not found", tool),
282 None,
283 ),
284 ToolCallbackError::ExecutionFailed { tool, reason } => (
285 ErrorCode::INTERNAL_ERROR,
286 "Tool execution failed".to_string(),
287 Some(serde_json::json!({ "tool": tool, "reason": reason })),
288 ),
289 ToolCallbackError::InvalidParams { tool, reason } => (
290 ErrorCode::INVALID_PARAMS,
291 "Invalid tool parameters".to_string(),
292 Some(serde_json::json!({ "tool": tool, "reason": reason })),
293 ),
294 ToolCallbackError::Timeout(tool, ms) => (
295 ErrorCode::TIMEOUT_ERROR,
296 "Tool timed out".to_string(),
297 Some(serde_json::json!({ "tool": tool, "timeout_ms": ms })),
298 ),
299 ToolCallbackError::ToolNotAllowed(tool) => (
300 ErrorCode::PERMISSION_DENIED,
301 format!("Tool '{}' is not allowed for callback", tool),
302 None,
303 ),
304 ToolCallbackError::Internal(msg) => (
305 ErrorCode::INTERNAL_ERROR,
306 msg,
307 None,
308 ),
309 };
310
311 JsonRpcResponse::error(
312 id,
313 JsonRpcError::with_data(code, message, data.unwrap_or(JsonValue::Null)),
314 )
315 }
316
317 pub fn list_allowed_tools(&self) -> Vec<String> {
319 let mut tools = self.allowed_tools.always_allowed.clone();
320 tools.extend(self.allowed_tools.requires_approval.clone());
321 tools
322 }
323
324 pub async fn tool_exists(&self, tool_name: &str) -> bool {
326 self.tool_router.has_tool(tool_name).await
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::matrixrpc::RegistryService;

    /// A fresh validator plus a router backed by an empty registry, built in
    /// the same order each test used to build them inline.
    fn fresh_parts() -> (Arc<SecurityValidator>, Arc<ToolRouter>) {
        let validator = Arc::new(SecurityValidator::new());
        let registry = Arc::new(RegistryService::new());
        let router = Arc::new(ToolRouter::new(registry));
        (validator, router)
    }

    /// A handler with default policy and timeout built from `fresh_parts`.
    fn fresh_handler() -> ToolCallbackHandler {
        let (validator, router) = fresh_parts();
        ToolCallbackHandler::new(validator, router)
    }

    /// A request for the `read` tool with empty params and default options.
    fn read_request(request_id: String, service_id: ServiceId, token: String) -> ToolCallbackRequest {
        ToolCallbackRequest {
            request_id,
            service_id,
            token,
            tool_name: "read".to_string(),
            params: serde_json::json!({}),
            timeout_ms: 30_000,
            require_approval: false,
        }
    }

    #[tokio::test]
    async fn test_tool_callback_handler_creation() {
        let handler = fresh_handler();
        assert_eq!(handler.default_timeout_ms, 30_000);
    }

    #[test]
    fn test_allowed_tools_config_default() {
        let config = AllowedToolsConfig::default();
        let has = |list: &[String], name: &str| list.iter().any(|entry| entry == name);

        assert!(has(&config.always_allowed, "read"));
        assert!(has(&config.requires_approval, "write"));
        assert!(has(&config.never_allowed, "delete"));
        assert!(!config.allow_all);
    }

    #[test]
    fn test_check_tool_allowed() {
        let handler = fresh_handler();
        // (tool, (approval_required, allowed))
        let expectations = [
            ("read", (false, true)),
            ("write", (true, true)),
            ("delete", (false, false)),
            ("unknown", (false, false)),
        ];
        for (tool, expected) in expectations {
            assert_eq!(handler.check_tool_allowed(tool), expected, "tool `{}`", tool);
        }
    }

    #[tokio::test]
    async fn test_tool_callback_security_validation() {
        let (security, tool_router) = fresh_parts();

        tool_router
            .register_tool(
                ServiceId::new("test-service"),
                crate::matrixrpc::ToolDefinition {
                    name: "read".to_string(),
                    service_id: ServiceId::new("test-service"),
                    description: None,
                    risk_level: None,
                    timeout_ms: None,
                },
            )
            .await;

        let handler = ToolCallbackHandler::new(Arc::clone(&security), tool_router);

        let service_id = ServiceId::new("callback-service");
        let request_id = String::from("req-001");
        let token = security
            .generate_token(service_id.clone(), request_id.clone(), vec!["tool".to_string()])
            .await
            .expect("token generation should succeed");

        let outcome = handler.handle(read_request(request_id, service_id, token)).await;
        assert!(matches!(
            outcome,
            Ok(_) | Err(ToolCallbackError::ToolNotFound(_))
        ));
    }

    #[tokio::test]
    async fn test_tool_callback_invalid_token() {
        let handler = fresh_handler();
        let request = read_request(
            "req-001".to_string(),
            ServiceId::new("test-service"),
            "invalid_token".to_string(),
        );

        let outcome = handler.handle(request).await;
        assert!(matches!(outcome, Err(ToolCallbackError::SecurityFailed(_))));
    }

    #[test]
    fn test_list_allowed_tools() {
        let tools = fresh_handler().list_allowed_tools();
        for expected in ["read", "write"] {
            assert!(
                tools.iter().any(|tool| tool == expected),
                "missing `{}`",
                expected
            );
        }
    }
}
460}