// matrixcode_core/matrixrpc/callback/handler.rs

//! Callback Handler
//!
//! Central handler for all callback types (AI, tool, context).
//! Coordinates callback requests coming from external extension services.

6use std::sync::Arc;
7
8use serde::{Deserialize, Serialize};
9use serde_json::Value as JsonValue;
10
11use super::ai::{AiCallbackHandler, AiCallbackRequest, AiCallbackError};
12use super::tool::{ToolCallbackHandler, ToolCallbackRequest, ToolCallbackError};
13use super::context::{ContextCallbackHandler, ContextCallbackRequest, ContextCallbackError};
14use super::security::{SecurityValidator, ValidationResult};
15use crate::matrixrpc::{
16    ErrorCode, JsonRpcError, JsonRpcId, JsonRpcRequest, JsonRpcResponse, ServiceId,
17    ToolRouter, NodeRouter,
18};
19
20/// Callback type enumeration
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
22#[serde(rename_all = "lowercase")]
23pub enum CallbackType {
24    /// AI callback - call AI provider
25    Ai,
26    /// Tool callback - execute a tool
27    Tool,
28    /// Context callback - access workflow context
29    Context,
30}
31
32impl CallbackType {
33    /// Get the method name for this callback type
34    pub fn method_name(&self) -> &'static str {
35        match self {
36            CallbackType::Ai => "callback.ai",
37            CallbackType::Tool => "callback.tool",
38            CallbackType::Context => "callback.context",
39        }
40    }
41
42    /// Parse from method name
43    pub fn from_method(method: &str) -> Option<Self> {
44        match method {
45            "callback.ai" => Some(CallbackType::Ai),
46            "callback.tool" => Some(CallbackType::Tool),
47            "callback.context" => Some(CallbackType::Context),
48            _ => None,
49        }
50    }
51}
52
53/// Unified callback result
54#[derive(Debug, Clone)]
55pub enum CallbackResult {
56    /// AI callback result
57    Ai(super::ai::AiCallbackResult),
58    /// Tool callback result
59    Tool(super::tool::ToolCallbackResult),
60    /// Context callback result
61    Context(super::context::ContextCallbackResult),
62}
63
64impl CallbackResult {
65    /// Convert to JSON value
66    pub fn to_json(&self) -> JsonValue {
67        match self {
68            CallbackResult::Ai(result) => serde_json::to_value(result).unwrap_or(JsonValue::Null),
69            CallbackResult::Tool(result) => serde_json::to_value(result).unwrap_or(JsonValue::Null),
70            CallbackResult::Context(result) => serde_json::to_value(result).unwrap_or(JsonValue::Null),
71        }
72    }
73
74    /// Get callback type
75    pub fn callback_type(&self) -> CallbackType {
76        match self {
77            CallbackResult::Ai(_) => CallbackType::Ai,
78            CallbackResult::Tool(_) => CallbackType::Tool,
79            CallbackResult::Context(_) => CallbackType::Context,
80        }
81    }
82}
83
84/// Callback error
85#[derive(Debug, thiserror::Error)]
86pub enum CallbackError {
87    /// Security validation failed
88    #[error("Security validation failed: {0}")]
89    SecurityFailed(String),
90
91    /// Invalid callback type
92    #[error("Invalid callback type: {0}")]
93    InvalidType(String),
94
95    /// AI callback error
96    #[error("AI callback error: {0}")]
97    Ai(#[from] AiCallbackError),
98
99    /// Tool callback error
100    #[error("Tool callback error: {0}")]
101    Tool(#[from] ToolCallbackError),
102
103    /// Context callback error
104    #[error("Context callback error: {0}")]
105    Context(#[from] ContextCallbackError),
106
107    /// Missing required field
108    #[error("Missing required field: {0}")]
109    MissingField(String),
110
111    /// Invalid JSON-RPC request
112    #[error("Invalid JSON-RPC request: {0}")]
113    InvalidRequest(String),
114
115    /// Internal error
116    #[error("Internal error: {0}")]
117    Internal(String),
118}
119
/// Tunables for the callback handler.
#[derive(Clone, Debug)]
pub struct CallbackConfig {
    /// Accept `callback.ai` requests.
    pub enable_ai: bool,

    /// Accept `callback.tool` requests.
    pub enable_tool: bool,

    /// Accept `callback.context` requests.
    pub enable_context: bool,

    /// Timeout in milliseconds, used when a request does not supply its own
    /// `timeout_ms`.
    pub default_timeout_ms: u64,

    /// Intended cap on concurrently running callbacks.
    ///
    /// NOTE(review): not read by the handler in handler.rs — confirm whether
    /// it is enforced anywhere else.
    pub max_concurrent: u32,

    /// Request verbose logging.
    ///
    /// NOTE(review): not read by the handler in handler.rs — confirm.
    pub detailed_logging: bool,
}
141
142impl Default for CallbackConfig {
143    fn default() -> Self {
144        Self {
145            enable_ai: true,
146            enable_tool: true,
147            enable_context: true,
148            default_timeout_ms: 60_000,
149            max_concurrent: 10,
150            detailed_logging: false,
151        }
152    }
153}
154
155/// Callback Handler
156///
157/// Central handler for all callback requests from external extension services.
158/// Coordinates AI, tool, and context callback handlers.
159pub struct CallbackHandler {
160    /// Security validator
161    security: Arc<SecurityValidator>,
162
163    /// AI callback handler
164    ai_handler: Arc<AiCallbackHandler>,
165
166    /// Tool callback handler
167    tool_handler: Arc<ToolCallbackHandler>,
168
169    /// Context callback handler
170    context_handler: Arc<ContextCallbackHandler>,
171
172    /// Configuration
173    config: CallbackConfig,
174}
175
176impl CallbackHandler {
177    /// Create a new callback handler
178    pub fn new(
179        security: Arc<SecurityValidator>,
180        tool_router: Arc<ToolRouter>,
181        node_router: Arc<NodeRouter>,
182    ) -> Self {
183        Self::with_config(
184            security,
185            tool_router,
186            node_router,
187            CallbackConfig::default(),
188        )
189    }
190
191    /// Create a new callback handler with configuration
192    pub fn with_config(
193        security: Arc<SecurityValidator>,
194        tool_router: Arc<ToolRouter>,
195        _node_router: Arc<NodeRouter>,
196        config: CallbackConfig,
197    ) -> Self {
198        let ai_handler = Arc::new(AiCallbackHandler::new(security.clone()));
199        let tool_handler = Arc::new(ToolCallbackHandler::new(security.clone(), tool_router));
200        let context_handler = Arc::new(ContextCallbackHandler::new(security.clone()));
201
202        Self {
203            security,
204            ai_handler,
205            tool_handler,
206            context_handler,
207            config,
208        }
209    }
210
211    /// Handle a JSON-RPC callback request
212    pub async fn handle_request(&self, request: JsonRpcRequest) -> Result<CallbackResult, CallbackError> {
213        // Determine callback type from method
214        let callback_type = CallbackType::from_method(&request.method)
215            .ok_or_else(|| CallbackError::InvalidType(request.method.clone()))?;
216
217        // Check if callback type is enabled
218        self.check_callback_enabled(callback_type)?;
219
220        // Get request ID
221        let request_id = request.id.clone().unwrap_or_default();
222
223        // Handle based on type
224        match callback_type {
225            CallbackType::Ai => {
226                self.handle_ai_callback(request.params, request_id).await
227            }
228            CallbackType::Tool => {
229                self.handle_tool_callback(request.params, request_id).await
230            }
231            CallbackType::Context => {
232                self.handle_context_callback(request.params, request_id).await
233            }
234        }
235    }
236
237    /// Check if a callback type is enabled
238    fn check_callback_enabled(&self, callback_type: CallbackType) -> Result<(), CallbackError> {
239        let enabled = match callback_type {
240            CallbackType::Ai => self.config.enable_ai,
241            CallbackType::Tool => self.config.enable_tool,
242            CallbackType::Context => self.config.enable_context,
243        };
244
245        if enabled {
246            Ok(())
247        } else {
248            Err(CallbackError::InvalidType(format!(
249                "{} callbacks are disabled",
250                callback_type.method_name()
251            )))
252        }
253    }
254
255    /// Handle AI callback request
256    async fn handle_ai_callback(
257        &self,
258        params: Option<JsonValue>,
259        _request_id: JsonRpcId,
260    ) -> Result<CallbackResult, CallbackError> {
261        let params = params.ok_or_else(|| CallbackError::MissingField("params".to_string()))?;
262
263        // Extract required fields
264        let request_id = params
265            .get("request_id")
266            .and_then(|v| v.as_str())
267            .ok_or_else(|| CallbackError::MissingField("request_id".to_string()))?;
268
269        let service_id = params
270            .get("service_id")
271            .and_then(|v| v.as_str())
272            .ok_or_else(|| CallbackError::MissingField("service_id".to_string()))?;
273
274        let token = params
275            .get("token")
276            .and_then(|v| v.as_str())
277            .ok_or_else(|| CallbackError::MissingField("token".to_string()))?;
278
279        let prompt = params
280            .get("prompt")
281            .and_then(|v| v.as_str())
282            .ok_or_else(|| CallbackError::MissingField("prompt".to_string()))?;
283
284        // Build AI callback request
285        let ai_request = AiCallbackRequest {
286            request_id: request_id.to_string(),
287            service_id: ServiceId::new(service_id),
288            token: token.to_string(),
289            prompt: prompt.to_string(),
290            context: params.get("context").cloned().unwrap_or(JsonValue::Null),
291            model_config: params
292                .get("model_config")
293                .and_then(|v| serde_json::from_value(v.clone()).ok())
294                .unwrap_or_default(),
295            timeout_ms: params
296                .get("timeout_ms")
297                .and_then(|v| v.as_u64())
298                .unwrap_or(self.config.default_timeout_ms),
299        };
300
301        let result = self.ai_handler.handle(ai_request).await?;
302        Ok(CallbackResult::Ai(result))
303    }
304
305    /// Handle tool callback request
306    async fn handle_tool_callback(
307        &self,
308        params: Option<JsonValue>,
309        _request_id: JsonRpcId,
310    ) -> Result<CallbackResult, CallbackError> {
311        let params = params.ok_or_else(|| CallbackError::MissingField("params".to_string()))?;
312
313        // Extract required fields
314        let request_id = params
315            .get("request_id")
316            .and_then(|v| v.as_str())
317            .ok_or_else(|| CallbackError::MissingField("request_id".to_string()))?;
318
319        let service_id = params
320            .get("service_id")
321            .and_then(|v| v.as_str())
322            .ok_or_else(|| CallbackError::MissingField("service_id".to_string()))?;
323
324        let token = params
325            .get("token")
326            .and_then(|v| v.as_str())
327            .ok_or_else(|| CallbackError::MissingField("token".to_string()))?;
328
329        let tool_name = params
330            .get("tool_name")
331            .and_then(|v| v.as_str())
332            .ok_or_else(|| CallbackError::MissingField("tool_name".to_string()))?;
333
334        // Build tool callback request
335        let tool_request = ToolCallbackRequest {
336            request_id: request_id.to_string(),
337            service_id: ServiceId::new(service_id),
338            token: token.to_string(),
339            tool_name: tool_name.to_string(),
340            params: params.get("params").cloned().unwrap_or(JsonValue::Null),
341            timeout_ms: params
342                .get("timeout_ms")
343                .and_then(|v| v.as_u64())
344                .unwrap_or(self.config.default_timeout_ms),
345            require_approval: params
346                .get("require_approval")
347                .and_then(|v| v.as_bool())
348                .unwrap_or(false),
349        };
350
351        let result = self.tool_handler.handle(tool_request).await?;
352        Ok(CallbackResult::Tool(result))
353    }
354
355    /// Handle context callback request
356    async fn handle_context_callback(
357        &self,
358        params: Option<JsonValue>,
359        _request_id: JsonRpcId,
360    ) -> Result<CallbackResult, CallbackError> {
361        let params = params.ok_or_else(|| CallbackError::MissingField("params".to_string()))?;
362
363        // Extract required fields
364        let request_id = params
365            .get("request_id")
366            .and_then(|v| v.as_str())
367            .ok_or_else(|| CallbackError::MissingField("request_id".to_string()))?;
368
369        let service_id = params
370            .get("service_id")
371            .and_then(|v| v.as_str())
372            .ok_or_else(|| CallbackError::MissingField("service_id".to_string()))?;
373
374        let token = params
375            .get("token")
376            .and_then(|v| v.as_str())
377            .ok_or_else(|| CallbackError::MissingField("token".to_string()))?;
378
379        // Build context callback request
380        let context_request = ContextCallbackRequest {
381            request_id: request_id.to_string(),
382            service_id: ServiceId::new(service_id),
383            token: token.to_string(),
384            operation: params
385                .get("operation")
386                .and_then(|v| serde_json::from_value(v.clone()).ok())
387                .unwrap_or_default(),
388            key: params.get("key").and_then(|v| v.as_str()).map(|s| s.to_string()),
389            value: params.get("value").cloned(),
390            namespace: params.get("namespace").and_then(|v| v.as_str()).map(|s| s.to_string()),
391        };
392
393        let result = self.context_handler.handle(context_request).await?;
394        Ok(CallbackResult::Context(result))
395    }
396
397    /// Create a JSON-RPC response for a callback result
398    pub fn create_success_response(&self, id: JsonRpcId, result: CallbackResult) -> JsonRpcResponse {
399        JsonRpcResponse::success(id, result.to_json())
400    }
401
402    /// Create a JSON-RPC error response for callback failures
403    pub fn create_error_response(&self, error: CallbackError, id: JsonRpcId) -> JsonRpcResponse {
404        let (code, message, data) = match error {
405            CallbackError::SecurityFailed(msg) => (
406                ErrorCode::PERMISSION_DENIED,
407                "Security validation failed".to_string(),
408                Some(serde_json::json!({ "reason": msg })),
409            ),
410            CallbackError::InvalidType(t) => (
411                ErrorCode::METHOD_NOT_FOUND,
412                format!("Invalid callback type: {}", t),
413                None,
414            ),
415            CallbackError::Ai(ai_error) => {
416                return self.ai_handler.create_error_response(ai_error, id);
417            }
418            CallbackError::Tool(tool_error) => {
419                return self.tool_handler.create_error_response(tool_error, id);
420            }
421            CallbackError::Context(context_error) => {
422                return self.context_handler.create_error_response(context_error, id);
423            }
424            CallbackError::MissingField(field) => (
425                ErrorCode::INVALID_PARAMS,
426                format!("Missing required field: {}", field),
427                None,
428            ),
429            CallbackError::InvalidRequest(msg) => (
430                ErrorCode::INVALID_REQUEST,
431                msg,
432                None,
433            ),
434            CallbackError::Internal(msg) => (
435                ErrorCode::INTERNAL_ERROR,
436                msg,
437                None,
438            ),
439        };
440
441        JsonRpcResponse::error(
442            id,
443            JsonRpcError::with_data(code, message, data.unwrap_or(JsonValue::Null)),
444        )
445    }
446
447    /// Generate a callback token for a service
448    pub async fn generate_token(
449        &self,
450        service_id: ServiceId,
451        request_id: String,
452        allowed_types: Vec<String>,
453    ) -> Result<String, CallbackError> {
454        self.security
455            .generate_token(service_id, request_id, allowed_types)
456            .await
457            .map_err(|e| CallbackError::SecurityFailed(e.to_string()))
458    }
459
460    /// Validate a callback request (pre-validation)
461    pub async fn validate(
462        &self,
463        token: &str,
464        service_id: &ServiceId,
465        request_id: &str,
466        callback_type: &str,
467    ) -> ValidationResult {
468        self.security.validate(token, service_id, request_id, callback_type).await
469    }
470
471    /// Get the security validator
472    pub fn security(&self) -> Arc<SecurityValidator> {
473        self.security.clone()
474    }
475
476    /// Get the AI handler
477    pub fn ai_handler(&self) -> Arc<AiCallbackHandler> {
478        self.ai_handler.clone()
479    }
480
481    /// Get the tool handler
482    pub fn tool_handler(&self) -> Arc<ToolCallbackHandler> {
483        self.tool_handler.clone()
484    }
485
486    /// Get the context handler
487    pub fn context_handler(&self) -> Arc<ContextCallbackHandler> {
488        self.context_handler.clone()
489    }
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::ai::AiCallbackResult;
    use crate::matrixrpc::RegistryService;

    /// Validator, routers and handler produced by the test fixtures.
    type TestHandlers = (
        Arc<SecurityValidator>,
        Arc<ToolRouter>,
        Arc<NodeRouter>,
        Arc<CallbackHandler>,
    );

    /// Build a fresh validator, routers and callback handler with `config`.
    fn create_test_handlers_with_config(config: CallbackConfig) -> TestHandlers {
        let security = Arc::new(SecurityValidator::new());
        let registry = Arc::new(RegistryService::new());
        let tool_router = Arc::new(ToolRouter::new(registry.clone()));
        let node_router = Arc::new(NodeRouter::new(registry));
        let callback = Arc::new(CallbackHandler::with_config(
            security.clone(),
            tool_router.clone(),
            node_router.clone(),
            config,
        ));

        (security, tool_router, node_router, callback)
    }

    /// Build the test fixtures with the default configuration.
    fn create_test_handlers() -> TestHandlers {
        create_test_handlers_with_config(CallbackConfig::default())
    }

    #[tokio::test]
    async fn test_callback_handler_creation() {
        let (_, _, _, handler) = create_test_handlers();
        assert!(handler.config.enable_ai);
        assert!(handler.config.enable_tool);
        assert!(handler.config.enable_context);
    }

    #[test]
    fn test_callback_type_detection() {
        assert_eq!(
            CallbackType::from_method("callback.ai"),
            Some(CallbackType::Ai)
        );
        assert_eq!(
            CallbackType::from_method("callback.tool"),
            Some(CallbackType::Tool)
        );
        assert_eq!(
            CallbackType::from_method("callback.context"),
            Some(CallbackType::Context)
        );
        assert_eq!(CallbackType::from_method("unknown"), None);
    }

    #[test]
    fn test_method_name_round_trip() {
        for ty in [CallbackType::Ai, CallbackType::Tool, CallbackType::Context] {
            assert_eq!(CallbackType::from_method(ty.method_name()), Some(ty));
        }
    }

    #[tokio::test]
    async fn test_invalid_callback_type() {
        let (_, _, _, handler) = create_test_handlers();

        let request = JsonRpcRequest::new("unknown.method");
        let result = handler.handle_request(request).await;

        assert!(matches!(result, Err(CallbackError::InvalidType(_))));
    }

    #[tokio::test]
    async fn test_disabled_callback_type_is_rejected() {
        let (_, _, _, handler) = create_test_handlers_with_config(CallbackConfig {
            enable_ai: false,
            ..CallbackConfig::default()
        });

        // Disabled types are rejected before params are inspected.
        let result = handler
            .handle_request(JsonRpcRequest::new("callback.ai"))
            .await;

        assert!(matches!(result, Err(CallbackError::InvalidType(_))));
    }

    #[tokio::test]
    async fn test_missing_params() {
        let (_, _, _, handler) = create_test_handlers();

        let request = JsonRpcRequest::new("callback.ai"); // No params
        let result = handler.handle_request(request).await;

        assert!(matches!(result, Err(CallbackError::MissingField(_))));
    }

    #[tokio::test]
    async fn test_missing_required_field_is_named() {
        let (_, _, _, handler) = create_test_handlers();

        let request = JsonRpcRequest::new("callback.ai").params(serde_json::json!({
            "request_id": "req-001",
            "service_id": "test-service",
            "prompt": "Test prompt"
        }));
        let result = handler.handle_request(request).await;

        assert!(matches!(
            &result,
            Err(CallbackError::MissingField(field)) if field == "token"
        ));
    }

    #[tokio::test]
    async fn test_generate_token() {
        let (_, _, _, handler) = create_test_handlers();

        let service_id = ServiceId::new("test-service");
        let token = handler
            .generate_token(
                service_id,
                "req-001".to_string(),
                vec!["ai".to_string(), "tool".to_string()],
            )
            .await
            .unwrap();

        assert!(token.starts_with("cb_"));
    }

    #[tokio::test]
    async fn test_ai_callback_with_valid_token() {
        let (security, _, _, handler) = create_test_handlers();

        let service_id = ServiceId::new("test-service");
        let request_id = "req-001".to_string();
        let token = security
            .generate_token(service_id.clone(), request_id.clone(), vec!["ai".to_string()])
            .await
            .unwrap();

        let request = JsonRpcRequest::new("callback.ai").params(serde_json::json!({
            "request_id": request_id,
            "service_id": service_id.to_string(),
            "token": token,
            "prompt": "Test prompt"
        }));

        let result = handler.handle_request(request).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_tool_callback_with_invalid_token() {
        let (_, _, _, handler) = create_test_handlers();

        let request = JsonRpcRequest::new("callback.tool").params(serde_json::json!({
            "request_id": "req-001",
            "service_id": "test-service",
            "token": "invalid_token",
            "tool_name": "read"
        }));

        let result = handler.handle_request(request).await;
        assert!(matches!(result, Err(CallbackError::Tool(_))));
    }

    #[test]
    fn test_callback_result_to_json() {
        let ai_result = AiCallbackResult {
            content: "Test response".to_string(),
            model: "claude-sonnet-4".to_string(),
            input_tokens: 100,
            output_tokens: 50,
            duration_ms: 500,
            metadata: serde_json::json!({}),
        };

        let result = CallbackResult::Ai(ai_result);
        assert_eq!(result.callback_type(), CallbackType::Ai);

        let json = result.to_json();
        assert_eq!(
            json.get("content").and_then(JsonValue::as_str),
            Some("Test response")
        );
    }
}