matrixcode_core/matrixrpc/callback/
ai.rs1use std::sync::Arc;
7
8use serde::{Deserialize, Serialize};
9use serde_json::Value as JsonValue;
10
11use super::security::SecurityValidator;
12use crate::matrixrpc::{ErrorCode, JsonRpcError, JsonRpcId, JsonRpcResponse, ServiceId};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct AiCallbackRequest {
17 pub request_id: String,
19
20 pub service_id: ServiceId,
22
23 pub token: String,
25
26 pub prompt: String,
28
29 #[serde(default)]
31 pub context: JsonValue,
32
33 #[serde(default)]
35 pub model_config: AiModelConfig,
36
37 #[serde(default = "default_ai_timeout")]
39 pub timeout_ms: u64,
40}
41
/// Serde default for the request timeout: 60 seconds, expressed in milliseconds.
fn default_ai_timeout() -> u64 {
    const SECONDS: u64 = 60;
    const MILLIS_PER_SECOND: u64 = 1_000;

    SECONDS * MILLIS_PER_SECOND
}
45
46#[derive(Debug, Clone, Default, Serialize, Deserialize)]
48pub struct AiModelConfig {
49 #[serde(default)]
51 pub model: Option<String>,
52
53 #[serde(default)]
55 pub temperature: Option<f32>,
56
57 #[serde(default)]
59 pub max_tokens: Option<u32>,
60
61 #[serde(default)]
63 pub system_prompt: Option<String>,
64
65 #[serde(default)]
67 pub stop_sequences: Option<Vec<String>>,
68
69 #[serde(default)]
71 pub stream: bool,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct AiCallbackResult {
77 pub content: String,
79
80 pub model: String,
82
83 pub input_tokens: u32,
85
86 pub output_tokens: u32,
88
89 pub duration_ms: u64,
91
92 #[serde(default)]
94 pub metadata: JsonValue,
95}
96
97#[derive(Debug, thiserror::Error)]
99pub enum AiCallbackError {
100 #[error("Security validation failed: {0}")]
102 SecurityFailed(String),
103
104 #[error("AI provider not available")]
106 ProviderNotAvailable,
107
108 #[error("Invalid prompt: {0}")]
110 InvalidPrompt(String),
111
112 #[error("Model '{0}' not found")]
114 ModelNotFound(String),
115
116 #[error("AI request timed out after {0}ms")]
118 Timeout(u64),
119
120 #[error("Provider error: {0}")]
122 ProviderError(String),
123
124 #[error("AI rate limit exceeded")]
126 RateLimitExceeded,
127
128 #[error("Internal error: {0}")]
130 Internal(String),
131}
132
133pub struct AiCallbackHandler {
137 security: Arc<SecurityValidator>,
139
140 default_model: String,
142
143 default_timeout_ms: u64,
145
146 max_tokens_limit: u32,
148}
149
150impl AiCallbackHandler {
151 pub fn new(security: Arc<SecurityValidator>) -> Self {
153 Self {
154 security,
155 default_model: "claude-sonnet-4".to_string(),
156 default_timeout_ms: 60_000,
157 max_tokens_limit: 4096,
158 }
159 }
160
161 pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
163 self.default_model = model.into();
164 self
165 }
166
167 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
169 self.default_timeout_ms = timeout_ms;
170 self
171 }
172
173 pub fn with_max_tokens_limit(mut self, limit: u32) -> Self {
175 self.max_tokens_limit = limit;
176 self
177 }
178
179 pub async fn handle(&self, request: AiCallbackRequest) -> Result<AiCallbackResult, AiCallbackError> {
181 let validation = self
183 .security
184 .validate(&request.token, &request.service_id, &request.request_id, "ai")
185 .await;
186
187 if !validation.is_valid {
188 return Err(AiCallbackError::SecurityFailed(
189 validation.error.unwrap_or_else(|| "Unknown security error".to_string()),
190 ));
191 }
192
193 if request.prompt.is_empty() {
195 return Err(AiCallbackError::InvalidPrompt("Prompt cannot be empty".to_string()));
196 }
197
198 let model = request
200 .model_config
201 .model
202 .clone()
203 .unwrap_or_else(|| self.default_model.clone());
204
205 let _max_tokens = request
207 .model_config
208 .max_tokens
209 .unwrap_or(1024)
210 .min(self.max_tokens_limit);
211
212 let _timeout = request.timeout_ms.max(self.default_timeout_ms);
214
215 let result = AiCallbackResult {
219 content: format!("AI response to: {}", &request.prompt[..100.min(request.prompt.len())]),
220 model,
221 input_tokens: request.prompt.len() as u32 / 4,
222 output_tokens: 100,
223 duration_ms: 500,
224 metadata: serde_json::json!({
225 "request_id": request.request_id,
226 "service_id": request.service_id.to_string(),
227 "temperature": request.model_config.temperature.unwrap_or(0.7),
228 }),
229 };
230
231 Ok(result)
232 }
233
234 pub fn create_error_response(&self, error: AiCallbackError, id: JsonRpcId) -> JsonRpcResponse {
236 let (code, message, data) = match error {
237 AiCallbackError::SecurityFailed(msg) => (
238 ErrorCode::PERMISSION_DENIED,
239 "Security validation failed".to_string(),
240 Some(serde_json::json!({ "reason": msg })),
241 ),
242 AiCallbackError::ProviderNotAvailable => (
243 ErrorCode::RESOURCE_NOT_FOUND,
244 "AI provider not available".to_string(),
245 None,
246 ),
247 AiCallbackError::InvalidPrompt(msg) => (
248 ErrorCode::INVALID_PARAMS,
249 "Invalid prompt".to_string(),
250 Some(serde_json::json!({ "reason": msg })),
251 ),
252 AiCallbackError::ModelNotFound(model) => (
253 ErrorCode::RESOURCE_NOT_FOUND,
254 format!("Model '{}' not found", model),
255 None,
256 ),
257 AiCallbackError::Timeout(ms) => (
258 ErrorCode::TIMEOUT_ERROR,
259 "AI request timed out".to_string(),
260 Some(serde_json::json!({ "timeout_ms": ms })),
261 ),
262 AiCallbackError::ProviderError(msg) => (
263 ErrorCode::INTERNAL_ERROR,
264 "Provider error".to_string(),
265 Some(serde_json::json!({ "error": msg })),
266 ),
267 AiCallbackError::RateLimitExceeded => (
268 ErrorCode::PERMISSION_DENIED,
269 "AI rate limit exceeded".to_string(),
270 None,
271 ),
272 AiCallbackError::Internal(msg) => (
273 ErrorCode::INTERNAL_ERROR,
274 msg,
275 None,
276 ),
277 };
278
279 JsonRpcResponse::error(
280 id,
281 JsonRpcError::with_data(code, message, data.unwrap_or(JsonValue::Null)),
282 )
283 }
284
285 pub fn is_model_available(&self, model: &str) -> bool {
287 matches!(
289 model,
290 "claude-opus-4" | "claude-sonnet-4" | "claude-haiku-4" | "claude-3-opus" | "claude-3-sonnet" | "claude-3-haiku"
291 )
292 }
293
294 pub fn get_available_models(&self) -> Vec<String> {
296 vec![
297 "claude-opus-4".to_string(),
298 "claude-sonnet-4".to_string(),
299 "claude-haiku-4".to_string(),
300 ]
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh validator shared between a test and the handler under test.
    fn new_validator() -> Arc<SecurityValidator> {
        Arc::new(SecurityValidator::new())
    }

    /// Handler with the built-in defaults and its own validator.
    fn new_handler() -> AiCallbackHandler {
        AiCallbackHandler::new(new_validator())
    }

    /// Request with a fixed prompt, empty context, default model config and a
    /// 60 s timeout; only the identity fields vary between tests.
    fn request_with(request_id: String, service_id: ServiceId, token: String) -> AiCallbackRequest {
        AiCallbackRequest {
            request_id,
            service_id,
            token,
            prompt: "Test prompt".to_string(),
            context: serde_json::json!({}),
            model_config: AiModelConfig::default(),
            timeout_ms: 60_000,
        }
    }

    #[tokio::test]
    async fn test_ai_callback_handler_creation() {
        let handler = new_handler();

        assert_eq!(handler.default_model, "claude-sonnet-4");
        assert_eq!(handler.default_timeout_ms, 60_000);
    }

    #[tokio::test]
    async fn test_ai_callback_with_custom_config() {
        let handler = new_handler()
            .with_default_model("claude-opus-4")
            .with_timeout(30_000)
            .with_max_tokens_limit(2048);

        assert_eq!(handler.default_model, "claude-opus-4");
        assert_eq!(handler.default_timeout_ms, 30_000);
        assert_eq!(handler.max_tokens_limit, 2048);
    }

    #[test]
    fn test_ai_model_config_default() {
        let AiModelConfig {
            model,
            temperature,
            max_tokens,
            ..
        } = AiModelConfig::default();

        assert!(model.is_none());
        assert!(temperature.is_none());
        assert!(max_tokens.is_none());
    }

    #[test]
    fn test_ai_callback_result_serialization() {
        let result = AiCallbackResult {
            content: "Test response".to_string(),
            model: "claude-sonnet-4".to_string(),
            input_tokens: 100,
            output_tokens: 50,
            duration_ms: 500,
            metadata: serde_json::json!({}),
        };

        let serialized = serde_json::to_string(&result).unwrap();
        assert!(serialized.contains("Test response"));
    }

    #[test]
    fn test_is_model_available() {
        let handler = new_handler();

        for known in ["claude-sonnet-4", "claude-opus-4"] {
            assert!(handler.is_model_available(known));
        }
        assert!(!handler.is_model_available("unknown-model"));
    }

    #[test]
    fn test_get_available_models() {
        let listed = new_handler().get_available_models();

        for expected in ["claude-sonnet-4", "claude-opus-4"] {
            assert!(listed.contains(&expected.to_string()));
        }
    }

    #[tokio::test]
    async fn test_ai_callback_security_validation() {
        let security = new_validator();
        let handler = AiCallbackHandler::new(security.clone());

        let service_id = ServiceId::new("test-service");
        let request_id = "req-001".to_string();
        // The token is issued for exactly this service/request pair with "ai" access.
        let token = security
            .generate_token(service_id.clone(), request_id.clone(), vec!["ai".to_string()])
            .await
            .unwrap();

        let outcome = handler.handle(request_with(request_id, service_id, token)).await;
        assert!(outcome.is_ok());

        let response = outcome.unwrap();
        assert!(!response.content.is_empty());
    }

    #[tokio::test]
    async fn test_ai_callback_invalid_token() {
        let handler = new_handler();
        let request = request_with(
            "req-001".to_string(),
            ServiceId::new("test-service"),
            "invalid_token".to_string(),
        );

        let outcome = handler.handle(request).await;
        assert!(matches!(outcome, Err(AiCallbackError::SecurityFailed(_))));
    }
}