1use std::sync::Arc;
7
8use serde::{Deserialize, Serialize};
9use serde_json::Value as JsonValue;
10
11use super::ai::{AiCallbackHandler, AiCallbackRequest, AiCallbackError};
12use super::tool::{ToolCallbackHandler, ToolCallbackRequest, ToolCallbackError};
13use super::context::{ContextCallbackHandler, ContextCallbackRequest, ContextCallbackError};
14use super::security::{SecurityValidator, ValidationResult};
15use crate::matrixrpc::{
16 ErrorCode, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse, ServiceId,
17 ToolRouter, NodeRouter,
18};
19
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
22#[serde(rename_all = "lowercase")]
23pub enum CallbackType {
24 Ai,
26 Tool,
28 Context,
30}
31
32impl CallbackType {
33 pub fn method_name(&self) -> &'static str {
35 match self {
36 CallbackType::Ai => "callback.ai",
37 CallbackType::Tool => "callback.tool",
38 CallbackType::Context => "callback.context",
39 }
40 }
41
42 pub fn from_method(method: &str) -> Option<Self> {
44 match method {
45 "callback.ai" => Some(CallbackType::Ai),
46 "callback.tool" => Some(CallbackType::Tool),
47 "callback.context" => Some(CallbackType::Context),
48 _ => None,
49 }
50 }
51}
52
53#[derive(Debug, Clone)]
55pub enum CallbackResult {
56 Ai(super::ai::AiCallbackResult),
58 Tool(super::tool::ToolCallbackResult),
60 Context(super::context::ContextCallbackResult),
62}
63
64impl CallbackResult {
65 pub fn to_json(&self) -> JsonValue {
67 match self {
68 CallbackResult::Ai(result) => serde_json::to_value(result).unwrap_or(JsonValue::Null),
69 CallbackResult::Tool(result) => serde_json::to_value(result).unwrap_or(JsonValue::Null),
70 CallbackResult::Context(result) => serde_json::to_value(result).unwrap_or(JsonValue::Null),
71 }
72 }
73
74 pub fn callback_type(&self) -> CallbackType {
76 match self {
77 CallbackResult::Ai(_) => CallbackType::Ai,
78 CallbackResult::Tool(_) => CallbackType::Tool,
79 CallbackResult::Context(_) => CallbackType::Context,
80 }
81 }
82}
83
84#[derive(Debug, thiserror::Error)]
86pub enum CallbackError {
87 #[error("Security validation failed: {0}")]
89 SecurityFailed(String),
90
91 #[error("Invalid callback type: {0}")]
93 InvalidType(String),
94
95 #[error("AI callback error: {0}")]
97 Ai(#[from] AiCallbackError),
98
99 #[error("Tool callback error: {0}")]
101 Tool(#[from] ToolCallbackError),
102
103 #[error("Context callback error: {0}")]
105 Context(#[from] ContextCallbackError),
106
107 #[error("Missing required field: {0}")]
109 MissingField(String),
110
111 #[error("Invalid JSON-RPC request: {0}")]
113 InvalidRequest(String),
114
115 #[error("Internal error: {0}")]
117 Internal(String),
118}
119
/// Tunables for `CallbackHandler`.
///
/// The `enable_*` flags gate which callback kinds are accepted.
/// `default_timeout_ms` is the fallback used by AI and tool callbacks whose
/// params do not specify `timeout_ms`.
#[derive(Debug, Clone)]
pub struct CallbackConfig {
    /// Accept `callback.ai` requests.
    pub enable_ai: bool,

    /// Accept `callback.tool` requests.
    pub enable_tool: bool,

    /// Accept `callback.context` requests.
    pub enable_context: bool,

    /// Timeout in milliseconds applied when a request omits `timeout_ms`.
    pub default_timeout_ms: u64,

    /// Concurrency limit.
    /// NOTE(review): not read by the handler code in this file; confirm where
    /// (or whether) it is enforced.
    pub max_concurrent: u32,

    /// Verbose-logging toggle.
    /// NOTE(review): not read by the handler code in this file; confirm where
    /// it is consumed.
    pub detailed_logging: bool,
}
141
142impl Default for CallbackConfig {
143 fn default() -> Self {
144 Self {
145 enable_ai: true,
146 enable_tool: true,
147 enable_context: true,
148 default_timeout_ms: 60_000,
149 max_concurrent: 10,
150 detailed_logging: false,
151 }
152 }
153}
154
155pub struct CallbackHandler {
160 security: Arc<SecurityValidator>,
162
163 ai_handler: Arc<AiCallbackHandler>,
165
166 tool_handler: Arc<ToolCallbackHandler>,
168
169 context_handler: Arc<ContextCallbackHandler>,
171
172 config: CallbackConfig,
174}
175
176impl CallbackHandler {
177 pub fn new(
179 security: Arc<SecurityValidator>,
180 tool_router: Arc<ToolRouter>,
181 node_router: Arc<NodeRouter>,
182 ) -> Self {
183 Self::with_config(
184 security,
185 tool_router,
186 node_router,
187 CallbackConfig::default(),
188 )
189 }
190
191 pub fn with_config(
193 security: Arc<SecurityValidator>,
194 tool_router: Arc<ToolRouter>,
195 _node_router: Arc<NodeRouter>,
196 config: CallbackConfig,
197 ) -> Self {
198 let ai_handler = Arc::new(AiCallbackHandler::new(security.clone()));
199 let tool_handler = Arc::new(ToolCallbackHandler::new(security.clone(), tool_router));
200 let context_handler = Arc::new(ContextCallbackHandler::new(security.clone()));
201
202 Self {
203 security,
204 ai_handler,
205 tool_handler,
206 context_handler,
207 config,
208 }
209 }
210
211 pub async fn handle_request(&self, request: JsonRpcRequest) -> Result<CallbackResult, CallbackError> {
213 let callback_type = CallbackType::from_method(&request.method)
215 .ok_or_else(|| CallbackError::InvalidType(request.method.clone()))?;
216
217 self.check_callback_enabled(callback_type)?;
219
220 let request_id = request.id.clone().unwrap_or_default();
222
223 match callback_type {
225 CallbackType::Ai => {
226 self.handle_ai_callback(request.params, request_id).await
227 }
228 CallbackType::Tool => {
229 self.handle_tool_callback(request.params, request_id).await
230 }
231 CallbackType::Context => {
232 self.handle_context_callback(request.params, request_id).await
233 }
234 }
235 }
236
237 fn check_callback_enabled(&self, callback_type: CallbackType) -> Result<(), CallbackError> {
239 let enabled = match callback_type {
240 CallbackType::Ai => self.config.enable_ai,
241 CallbackType::Tool => self.config.enable_tool,
242 CallbackType::Context => self.config.enable_context,
243 };
244
245 if enabled {
246 Ok(())
247 } else {
248 Err(CallbackError::InvalidType(format!(
249 "{} callbacks are disabled",
250 callback_type.method_name()
251 )))
252 }
253 }
254
255 async fn handle_ai_callback(
257 &self,
258 params: Option<JsonValue>,
259 _request_id: JsonRpcId,
260 ) -> Result<CallbackResult, CallbackError> {
261 let params = params.ok_or_else(|| CallbackError::MissingField("params".to_string()))?;
262
263 let request_id = params
265 .get("request_id")
266 .and_then(|v| v.as_str())
267 .ok_or_else(|| CallbackError::MissingField("request_id".to_string()))?;
268
269 let service_id = params
270 .get("service_id")
271 .and_then(|v| v.as_str())
272 .ok_or_else(|| CallbackError::MissingField("service_id".to_string()))?;
273
274 let token = params
275 .get("token")
276 .and_then(|v| v.as_str())
277 .ok_or_else(|| CallbackError::MissingField("token".to_string()))?;
278
279 let prompt = params
280 .get("prompt")
281 .and_then(|v| v.as_str())
282 .ok_or_else(|| CallbackError::MissingField("prompt".to_string()))?;
283
284 let ai_request = AiCallbackRequest {
286 request_id: request_id.to_string(),
287 service_id: ServiceId::new(service_id),
288 token: token.to_string(),
289 prompt: prompt.to_string(),
290 context: params.get("context").cloned().unwrap_or(JsonValue::Null),
291 model_config: params
292 .get("model_config")
293 .and_then(|v| serde_json::from_value(v.clone()).ok())
294 .unwrap_or_default(),
295 timeout_ms: params
296 .get("timeout_ms")
297 .and_then(|v| v.as_u64())
298 .unwrap_or(self.config.default_timeout_ms),
299 };
300
301 let result = self.ai_handler.handle(ai_request).await?;
302 Ok(CallbackResult::Ai(result))
303 }
304
305 async fn handle_tool_callback(
307 &self,
308 params: Option<JsonValue>,
309 _request_id: JsonRpcId,
310 ) -> Result<CallbackResult, CallbackError> {
311 let params = params.ok_or_else(|| CallbackError::MissingField("params".to_string()))?;
312
313 let request_id = params
315 .get("request_id")
316 .and_then(|v| v.as_str())
317 .ok_or_else(|| CallbackError::MissingField("request_id".to_string()))?;
318
319 let service_id = params
320 .get("service_id")
321 .and_then(|v| v.as_str())
322 .ok_or_else(|| CallbackError::MissingField("service_id".to_string()))?;
323
324 let token = params
325 .get("token")
326 .and_then(|v| v.as_str())
327 .ok_or_else(|| CallbackError::MissingField("token".to_string()))?;
328
329 let tool_name = params
330 .get("tool_name")
331 .and_then(|v| v.as_str())
332 .ok_or_else(|| CallbackError::MissingField("tool_name".to_string()))?;
333
334 let tool_request = ToolCallbackRequest {
336 request_id: request_id.to_string(),
337 service_id: ServiceId::new(service_id),
338 token: token.to_string(),
339 tool_name: tool_name.to_string(),
340 params: params.get("params").cloned().unwrap_or(JsonValue::Null),
341 timeout_ms: params
342 .get("timeout_ms")
343 .and_then(|v| v.as_u64())
344 .unwrap_or(self.config.default_timeout_ms),
345 require_approval: params
346 .get("require_approval")
347 .and_then(|v| v.as_bool())
348 .unwrap_or(false),
349 };
350
351 let result = self.tool_handler.handle(tool_request).await?;
352 Ok(CallbackResult::Tool(result))
353 }
354
355 async fn handle_context_callback(
357 &self,
358 params: Option<JsonValue>,
359 _request_id: JsonRpcId,
360 ) -> Result<CallbackResult, CallbackError> {
361 let params = params.ok_or_else(|| CallbackError::MissingField("params".to_string()))?;
362
363 let request_id = params
365 .get("request_id")
366 .and_then(|v| v.as_str())
367 .ok_or_else(|| CallbackError::MissingField("request_id".to_string()))?;
368
369 let service_id = params
370 .get("service_id")
371 .and_then(|v| v.as_str())
372 .ok_or_else(|| CallbackError::MissingField("service_id".to_string()))?;
373
374 let token = params
375 .get("token")
376 .and_then(|v| v.as_str())
377 .ok_or_else(|| CallbackError::MissingField("token".to_string()))?;
378
379 let context_request = ContextCallbackRequest {
381 request_id: request_id.to_string(),
382 service_id: ServiceId::new(service_id),
383 token: token.to_string(),
384 operation: params
385 .get("operation")
386 .and_then(|v| serde_json::from_value(v.clone()).ok())
387 .unwrap_or_default(),
388 key: params.get("key").and_then(|v| v.as_str()).map(|s| s.to_string()),
389 value: params.get("value").cloned(),
390 namespace: params.get("namespace").and_then(|v| v.as_str()).map(|s| s.to_string()),
391 };
392
393 let result = self.context_handler.handle(context_request).await?;
394 Ok(CallbackResult::Context(result))
395 }
396
397 pub fn create_success_response(&self, id: JsonRpcId, result: CallbackResult) -> JsonRpcResponse {
399 JsonRpcResponse::success(id, result.to_json())
400 }
401
402 pub fn create_error_response(&self, error: CallbackError, id: JsonRpcId) -> JsonRpcResponse {
404 let (code, message, data) = match error {
405 CallbackError::SecurityFailed(msg) => (
406 ErrorCode::PERMISSION_DENIED,
407 "Security validation failed".to_string(),
408 Some(serde_json::json!({ "reason": msg })),
409 ),
410 CallbackError::InvalidType(t) => (
411 ErrorCode::METHOD_NOT_FOUND,
412 format!("Invalid callback type: {}", t),
413 None,
414 ),
415 CallbackError::Ai(ai_error) => {
416 return self.ai_handler.create_error_response(ai_error, id);
417 }
418 CallbackError::Tool(tool_error) => {
419 return self.tool_handler.create_error_response(tool_error, id);
420 }
421 CallbackError::Context(context_error) => {
422 return self.context_handler.create_error_response(context_error, id);
423 }
424 CallbackError::MissingField(field) => (
425 ErrorCode::INVALID_PARAMS,
426 format!("Missing required field: {}", field),
427 None,
428 ),
429 CallbackError::InvalidRequest(msg) => (
430 ErrorCode::INVALID_REQUEST,
431 msg,
432 None,
433 ),
434 CallbackError::Internal(msg) => (
435 ErrorCode::INTERNAL_ERROR,
436 msg,
437 None,
438 ),
439 };
440
441 JsonRpcResponse::error(
442 id,
443 JsonRpcError::with_data(code, message, data.unwrap_or(JsonValue::Null)),
444 )
445 }
446
447 pub async fn generate_token(
449 &self,
450 service_id: ServiceId,
451 request_id: String,
452 allowed_types: Vec<String>,
453 ) -> Result<String, CallbackError> {
454 self.security
455 .generate_token(service_id, request_id, allowed_types)
456 .await
457 .map_err(|e| CallbackError::SecurityFailed(e.to_string()))
458 }
459
460 pub async fn validate(
462 &self,
463 token: &str,
464 service_id: &ServiceId,
465 request_id: &str,
466 callback_type: &str,
467 ) -> ValidationResult {
468 self.security.validate(token, service_id, request_id, callback_type).await
469 }
470
471 pub fn security(&self) -> Arc<SecurityValidator> {
473 self.security.clone()
474 }
475
476 pub fn ai_handler(&self) -> Arc<AiCallbackHandler> {
478 self.ai_handler.clone()
479 }
480
481 pub fn tool_handler(&self) -> Arc<ToolCallbackHandler> {
483 self.tool_handler.clone()
484 }
485
486 pub fn context_handler(&self) -> Arc<ContextCallbackHandler> {
488 self.context_handler.clone()
489 }
490}
491
#[cfg(test)]
mod tests {
    use super::super::ai::AiCallbackResult;
    use super::*;
    use crate::matrixrpc::{NodeRouter, RegistryService, ToolRouter};

    type Fixture = (
        Arc<SecurityValidator>,
        Arc<ToolRouter>,
        Arc<NodeRouter>,
        Arc<CallbackHandler>,
    );

    /// Builds a `CallbackHandler` with `config` over freshly constructed dependencies.
    fn create_test_handlers_with_config(config: CallbackConfig) -> Fixture {
        let security = Arc::new(SecurityValidator::new());
        let registry = Arc::new(RegistryService::new());
        let tool_router = Arc::new(ToolRouter::new(registry.clone()));
        let node_router = Arc::new(NodeRouter::new(registry));
        let callback = Arc::new(CallbackHandler::with_config(
            security.clone(),
            tool_router.clone(),
            node_router.clone(),
            config,
        ));

        (security, tool_router, node_router, callback)
    }

    /// Builds a `CallbackHandler` with the default config.
    fn create_test_handlers() -> Fixture {
        create_test_handlers_with_config(CallbackConfig::default())
    }

    #[tokio::test]
    async fn test_callback_handler_creation() {
        let (_, _, _, handler) = create_test_handlers();
        assert!(handler.config.enable_ai);
        assert!(handler.config.enable_tool);
        assert!(handler.config.enable_context);
    }

    #[test]
    fn test_callback_type_detection() {
        assert_eq!(
            CallbackType::from_method("callback.ai"),
            Some(CallbackType::Ai)
        );
        assert_eq!(
            CallbackType::from_method("callback.tool"),
            Some(CallbackType::Tool)
        );
        assert_eq!(
            CallbackType::from_method("callback.context"),
            Some(CallbackType::Context)
        );
        assert_eq!(CallbackType::from_method("unknown"), None);

        // method_name and from_method must stay inverse of each other.
        for &kind in [CallbackType::Ai, CallbackType::Tool, CallbackType::Context].iter() {
            assert_eq!(CallbackType::from_method(kind.method_name()), Some(kind));
        }
    }

    #[tokio::test]
    async fn test_invalid_callback_type() {
        let (_, _, _, handler) = create_test_handlers();

        let request = JsonRpcRequest::new("unknown.method");
        let result = handler.handle_request(request).await;

        assert!(matches!(result, Err(CallbackError::InvalidType(_))));
    }

    #[tokio::test]
    async fn test_disabled_callback_is_rejected() {
        let config = CallbackConfig {
            enable_ai: false,
            ..CallbackConfig::default()
        };
        let (_, _, _, handler) = create_test_handlers_with_config(config);

        // No params: the enable check must fire before params are inspected.
        let request = JsonRpcRequest::new("callback.ai");
        let result = handler.handle_request(request).await;

        assert!(matches!(result, Err(CallbackError::InvalidType(_))));
    }

    #[tokio::test]
    async fn test_missing_params() {
        let (_, _, _, handler) = create_test_handlers();

        let request = JsonRpcRequest::new("callback.ai");
        let result = handler.handle_request(request).await;

        assert!(matches!(result, Err(CallbackError::MissingField(_))));
    }

    #[tokio::test]
    async fn test_generate_token() {
        let (_, _, _, handler) = create_test_handlers();

        let service_id = ServiceId::new("test-service");
        let token = handler
            .generate_token(
                service_id,
                "req-001".to_string(),
                vec!["ai".to_string(), "tool".to_string()],
            )
            .await
            .unwrap();

        assert!(token.starts_with("cb_"));
    }

    #[tokio::test]
    async fn test_ai_callback_with_valid_token() {
        let (security, _, _, handler) = create_test_handlers();

        let service_id = ServiceId::new("test-service");
        let request_id = "req-001".to_string();
        let token = security
            .generate_token(service_id.clone(), request_id.clone(), vec!["ai".to_string()])
            .await
            .unwrap();

        let request = JsonRpcRequest::new("callback.ai").params(serde_json::json!({
            "request_id": request_id,
            "service_id": service_id.to_string(),
            "token": token,
            "prompt": "Test prompt"
        }));

        let result = handler.handle_request(request).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_tool_callback_with_invalid_token() {
        let (_, _, _, handler) = create_test_handlers();

        let request = JsonRpcRequest::new("callback.tool").params(serde_json::json!({
            "request_id": "req-001",
            "service_id": "test-service",
            "token": "invalid_token",
            "tool_name": "read"
        }));

        let result = handler.handle_request(request).await;
        assert!(matches!(result, Err(CallbackError::Tool(_))));
    }

    #[test]
    fn test_callback_result_to_json() {
        let ai_result = AiCallbackResult {
            content: "Test response".to_string(),
            model: "claude-sonnet-4".to_string(),
            input_tokens: 100,
            output_tokens: 50,
            duration_ms: 500,
            metadata: serde_json::json!({}),
        };

        let result = CallbackResult::Ai(ai_result);
        let json = result.to_json();

        assert!(json.get("content").is_some());
    }
}