Skip to main content

matrixcode_core/matrixrpc/callback/
ai.rs

1//! AI Callback Handler
2//!
3//! Handles AI callback requests from external services.
4//! Enables external nodes to call MatrixCode's AI providers.
5
6use std::sync::Arc;
7
8use serde::{Deserialize, Serialize};
9use serde_json::Value as JsonValue;
10
11use super::security::SecurityValidator;
12use crate::matrixrpc::{ErrorCode, JsonRpcError, JsonRpcId, JsonRpcResponse, ServiceId};
13
14/// AI callback request
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct AiCallbackRequest {
17    /// Request ID from original node execution
18    pub request_id: String,
19
20    /// Service ID making the callback
21    pub service_id: ServiceId,
22
23    /// Security token
24    pub token: String,
25
26    /// Prompt to send to AI
27    pub prompt: String,
28
29    /// Context data
30    #[serde(default)]
31    pub context: JsonValue,
32
33    /// Model configuration
34    #[serde(default)]
35    pub model_config: AiModelConfig,
36
37    /// Timeout in milliseconds
38    #[serde(default = "default_ai_timeout")]
39    pub timeout_ms: u64,
40}
41
/// Serde fallback for the request timeout: 60 seconds, expressed in milliseconds.
fn default_ai_timeout() -> u64 {
    const SECONDS: u64 = 60;
    const MILLIS_PER_SECOND: u64 = 1_000;
    SECONDS * MILLIS_PER_SECOND
}
45
46/// AI model configuration
47#[derive(Debug, Clone, Default, Serialize, Deserialize)]
48pub struct AiModelConfig {
49    /// Model name (e.g., "claude-sonnet-4")
50    #[serde(default)]
51    pub model: Option<String>,
52
53    /// Temperature (0.0 - 1.0)
54    #[serde(default)]
55    pub temperature: Option<f32>,
56
57    /// Max tokens to generate
58    #[serde(default)]
59    pub max_tokens: Option<u32>,
60
61    /// System prompt override
62    #[serde(default)]
63    pub system_prompt: Option<String>,
64
65    /// Stop sequences
66    #[serde(default)]
67    pub stop_sequences: Option<Vec<String>>,
68
69    /// Enable streaming
70    #[serde(default)]
71    pub stream: bool,
72}
73
74/// AI callback result
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct AiCallbackResult {
77    /// Response content
78    pub content: String,
79
80    /// Model used
81    pub model: String,
82
83    /// Input tokens
84    pub input_tokens: u32,
85
86    /// Output tokens
87    pub output_tokens: u32,
88
89    /// Duration in milliseconds
90    pub duration_ms: u64,
91
92    /// Additional metadata
93    #[serde(default)]
94    pub metadata: JsonValue,
95}
96
97/// AI callback error
98#[derive(Debug, thiserror::Error)]
99pub enum AiCallbackError {
100    /// Security validation failed
101    #[error("Security validation failed: {0}")]
102    SecurityFailed(String),
103
104    /// Provider not available
105    #[error("AI provider not available")]
106    ProviderNotAvailable,
107
108    /// Invalid prompt
109    #[error("Invalid prompt: {0}")]
110    InvalidPrompt(String),
111
112    /// Model not found
113    #[error("Model '{0}' not found")]
114    ModelNotFound(String),
115
116    /// Timeout
117    #[error("AI request timed out after {0}ms")]
118    Timeout(u64),
119
120    /// Provider error
121    #[error("Provider error: {0}")]
122    ProviderError(String),
123
124    /// Rate limit exceeded
125    #[error("AI rate limit exceeded")]
126    RateLimitExceeded,
127
128    /// Internal error
129    #[error("Internal error: {0}")]
130    Internal(String),
131}
132
133/// AI Callback Handler
134///
135/// Handles AI callback requests from external extension services.
136pub struct AiCallbackHandler {
137    /// Security validator
138    security: Arc<SecurityValidator>,
139
140    /// Default model
141    default_model: String,
142
143    /// Default timeout
144    default_timeout_ms: u64,
145
146    /// Maximum tokens allowed
147    max_tokens_limit: u32,
148}
149
150impl AiCallbackHandler {
151    /// Create a new AI callback handler
152    pub fn new(security: Arc<SecurityValidator>) -> Self {
153        Self {
154            security,
155            default_model: "claude-sonnet-4".to_string(),
156            default_timeout_ms: 60_000,
157            max_tokens_limit: 4096,
158        }
159    }
160
161    /// Set default model
162    pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
163        self.default_model = model.into();
164        self
165    }
166
167    /// Set default timeout
168    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
169        self.default_timeout_ms = timeout_ms;
170        self
171    }
172
173    /// Set max tokens limit
174    pub fn with_max_tokens_limit(mut self, limit: u32) -> Self {
175        self.max_tokens_limit = limit;
176        self
177    }
178
179    /// Handle an AI callback request
180    pub async fn handle(&self, request: AiCallbackRequest) -> Result<AiCallbackResult, AiCallbackError> {
181        // Validate security
182        let validation = self
183            .security
184            .validate(&request.token, &request.service_id, &request.request_id, "ai")
185            .await;
186
187        if !validation.is_valid {
188            return Err(AiCallbackError::SecurityFailed(
189                validation.error.unwrap_or_else(|| "Unknown security error".to_string()),
190            ));
191        }
192
193        // Validate prompt
194        if request.prompt.is_empty() {
195            return Err(AiCallbackError::InvalidPrompt("Prompt cannot be empty".to_string()));
196        }
197
198        // Get model to use
199        let model = request
200            .model_config
201            .model
202            .clone()
203            .unwrap_or_else(|| self.default_model.clone());
204
205        // Get max tokens with limit
206        let _max_tokens = request
207            .model_config
208            .max_tokens
209            .unwrap_or(1024)
210            .min(self.max_tokens_limit);
211
212        // Get timeout
213        let _timeout = request.timeout_ms.max(self.default_timeout_ms);
214
215        // Build response
216        // Note: In real implementation, this would call the actual AI provider
217        // Here we provide a mock implementation for testing
218        let result = AiCallbackResult {
219            content: format!("AI response to: {}", &request.prompt[..100.min(request.prompt.len())]),
220            model,
221            input_tokens: request.prompt.len() as u32 / 4,
222            output_tokens: 100,
223            duration_ms: 500,
224            metadata: serde_json::json!({
225                "request_id": request.request_id,
226                "service_id": request.service_id.to_string(),
227                "temperature": request.model_config.temperature.unwrap_or(0.7),
228            }),
229        };
230
231        Ok(result)
232    }
233
234    /// Create a JSON-RPC error response for AI callback failures
235    pub fn create_error_response(&self, error: AiCallbackError, id: JsonRpcId) -> JsonRpcResponse {
236        let (code, message, data) = match error {
237            AiCallbackError::SecurityFailed(msg) => (
238                ErrorCode::PERMISSION_DENIED,
239                "Security validation failed".to_string(),
240                Some(serde_json::json!({ "reason": msg })),
241            ),
242            AiCallbackError::ProviderNotAvailable => (
243                ErrorCode::RESOURCE_NOT_FOUND,
244                "AI provider not available".to_string(),
245                None,
246            ),
247            AiCallbackError::InvalidPrompt(msg) => (
248                ErrorCode::INVALID_PARAMS,
249                "Invalid prompt".to_string(),
250                Some(serde_json::json!({ "reason": msg })),
251            ),
252            AiCallbackError::ModelNotFound(model) => (
253                ErrorCode::RESOURCE_NOT_FOUND,
254                format!("Model '{}' not found", model),
255                None,
256            ),
257            AiCallbackError::Timeout(ms) => (
258                ErrorCode::TIMEOUT_ERROR,
259                "AI request timed out".to_string(),
260                Some(serde_json::json!({ "timeout_ms": ms })),
261            ),
262            AiCallbackError::ProviderError(msg) => (
263                ErrorCode::INTERNAL_ERROR,
264                "Provider error".to_string(),
265                Some(serde_json::json!({ "error": msg })),
266            ),
267            AiCallbackError::RateLimitExceeded => (
268                ErrorCode::PERMISSION_DENIED,
269                "AI rate limit exceeded".to_string(),
270                None,
271            ),
272            AiCallbackError::Internal(msg) => (
273                ErrorCode::INTERNAL_ERROR,
274                msg,
275                None,
276            ),
277        };
278
279        JsonRpcResponse::error(
280            id,
281            JsonRpcError::with_data(code, message, data.unwrap_or(JsonValue::Null)),
282        )
283    }
284
285    /// Check if a model is available
286    pub fn is_model_available(&self, model: &str) -> bool {
287        // Check against known models
288        matches!(
289            model,
290            "claude-opus-4" | "claude-sonnet-4" | "claude-haiku-4" | "claude-3-opus" | "claude-3-sonnet" | "claude-3-haiku"
291        )
292    }
293
294    /// Get available models
295    pub fn get_available_models(&self) -> Vec<String> {
296        vec![
297            "claude-opus-4".to_string(),
298            "claude-sonnet-4".to_string(),
299            "claude-haiku-4".to_string(),
300        ]
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler wired to a brand-new validator, for tests that never need the
    /// validator itself.
    fn fresh_handler() -> AiCallbackHandler {
        AiCallbackHandler::new(Arc::new(SecurityValidator::new()))
    }

    /// Request with the fixed test prompt, default model config and a 60 s timeout.
    fn sample_request(request_id: String, service_id: ServiceId, token: String) -> AiCallbackRequest {
        AiCallbackRequest {
            request_id,
            service_id,
            token,
            prompt: "Test prompt".to_string(),
            context: serde_json::json!({}),
            model_config: AiModelConfig::default(),
            timeout_ms: 60_000,
        }
    }

    #[tokio::test]
    async fn test_ai_callback_handler_creation() {
        let handler = fresh_handler();

        assert_eq!(handler.default_model, "claude-sonnet-4");
        assert_eq!(handler.default_timeout_ms, 60_000);
    }

    #[tokio::test]
    async fn test_ai_callback_with_custom_config() {
        let handler = fresh_handler()
            .with_default_model("claude-opus-4")
            .with_timeout(30_000)
            .with_max_tokens_limit(2048);

        assert_eq!(handler.default_model, "claude-opus-4");
        assert_eq!(handler.default_timeout_ms, 30_000);
        assert_eq!(handler.max_tokens_limit, 2048);
    }

    #[test]
    fn test_ai_model_config_default() {
        let config = AiModelConfig::default();

        assert!(config.model.is_none());
        assert!(config.temperature.is_none());
        assert!(config.max_tokens.is_none());
    }

    #[test]
    fn test_ai_callback_result_serialization() {
        let result = AiCallbackResult {
            content: "Test response".to_string(),
            model: "claude-sonnet-4".to_string(),
            input_tokens: 100,
            output_tokens: 50,
            duration_ms: 500,
            metadata: serde_json::json!({}),
        };

        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("Test response"));
    }

    #[test]
    fn test_is_model_available() {
        let handler = fresh_handler();

        assert!(handler.is_model_available("claude-sonnet-4"));
        assert!(handler.is_model_available("claude-opus-4"));
        assert!(!handler.is_model_available("unknown-model"));
    }

    #[test]
    fn test_get_available_models() {
        let models = fresh_handler().get_available_models();

        assert!(models.contains(&"claude-sonnet-4".to_string()));
        assert!(models.contains(&"claude-opus-4".to_string()));
    }

    #[tokio::test]
    async fn test_ai_callback_security_validation() {
        // The validator is shared so the token it issues is known to the handler.
        let security = Arc::new(SecurityValidator::new());
        let handler = AiCallbackHandler::new(security.clone());

        // Generate token
        let service_id = ServiceId::new("test-service");
        let request_id = "req-001".to_string();
        let token = security
            .generate_token(service_id.clone(), request_id.clone(), vec!["ai".to_string()])
            .await
            .unwrap();

        let result = handler.handle(sample_request(request_id, service_id, token)).await;
        assert!(result.is_ok());

        let response = result.unwrap();
        assert!(!response.content.is_empty());
    }

    #[tokio::test]
    async fn test_ai_callback_invalid_token() {
        let handler = fresh_handler();

        let request = sample_request(
            "req-001".to_string(),
            ServiceId::new("test-service"),
            "invalid_token".to_string(),
        );

        let result = handler.handle(request).await;
        assert!(matches!(result, Err(AiCallbackError::SecurityFailed(_))));
    }
}