use async_trait::async_trait;
use mockforge_core::{ai_response::AiResponseConfig, openapi::response::AiGenerator, Result};
use mockforge_data::rag::{LlmProvider, RagConfig, RagEngine};
use serde_json::Value;
use std::sync::Arc;
use tracing::{debug, warn};

/// An [`AiGenerator`] implementation backed by a retrieval-augmented
/// generation (RAG) engine. The engine sits behind an async `RwLock` so
/// per-request configuration changes can be applied safely.
pub struct RagAiGenerator {
    engine: Arc<tokio::sync::RwLock<RagEngine>>,
}

impl RagAiGenerator {
    /// Creates a generator from an explicit [`RagConfig`].
    pub fn new(rag_config: RagConfig) -> Result<Self> {
        debug!("Creating RAG AI generator with provider: {:?}", rag_config.provider);

        let engine = RagEngine::new(rag_config);

        Ok(Self {
            engine: Arc::new(tokio::sync::RwLock::new(engine)),
        })
    }

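    /// Builds a generator from `MOCKFORGE_AI_*` environment variables,
    /// falling back to per-provider defaults for the model and endpoint.
    ///
    /// Recognized variables (values shown are illustrative):
    ///
    /// ```text
    /// MOCKFORGE_AI_PROVIDER=ollama     # openai | anthropic | ollama | openai-compatible
    /// MOCKFORGE_AI_API_KEY=...         # optional
    /// MOCKFORGE_AI_MODEL=llama2       # defaults per provider
    /// MOCKFORGE_AI_ENDPOINT=http://localhost:11434/api/generate
    /// MOCKFORGE_AI_TEMPERATURE=0.7
    /// MOCKFORGE_AI_MAX_TOKENS=1024
    /// ```
    ///
    /// ```ignore
    /// let generator = RagAiGenerator::from_env()?;
    /// ```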
    pub fn from_env() -> Result<Self> {
        let provider =
            std::env::var("MOCKFORGE_AI_PROVIDER").unwrap_or_else(|_| "openai".to_string());

        let provider = match provider.to_lowercase().as_str() {
            "openai" => LlmProvider::OpenAI,
            "anthropic" => LlmProvider::Anthropic,
            "ollama" => LlmProvider::Ollama,
            "openai-compatible" => LlmProvider::OpenAICompatible,
            _ => {
                warn!("Unknown AI provider '{}', defaulting to OpenAI", provider);
                LlmProvider::OpenAI
            }
        };

        let api_key = std::env::var("MOCKFORGE_AI_API_KEY").ok();

        let model = std::env::var("MOCKFORGE_AI_MODEL").unwrap_or_else(|_| match provider {
            LlmProvider::OpenAI => "gpt-3.5-turbo".to_string(),
            LlmProvider::Anthropic => "claude-3-haiku-20240307".to_string(),
            LlmProvider::Ollama => "llama2".to_string(),
            LlmProvider::OpenAICompatible => "gpt-3.5-turbo".to_string(),
        });

        let api_endpoint =
            std::env::var("MOCKFORGE_AI_ENDPOINT").unwrap_or_else(|_| match provider {
                LlmProvider::OpenAI => "https://api.openai.com/v1/chat/completions".to_string(),
                LlmProvider::Anthropic => "https://api.anthropic.com/v1/messages".to_string(),
                LlmProvider::Ollama => "http://localhost:11434/api/generate".to_string(),
                LlmProvider::OpenAICompatible => {
                    "http://localhost:8080/v1/chat/completions".to_string()
                }
            });

        let temperature = std::env::var("MOCKFORGE_AI_TEMPERATURE")
            .ok()
            .and_then(|s| s.parse::<f64>().ok())
            .unwrap_or(0.7);

        let max_tokens = std::env::var("MOCKFORGE_AI_MAX_TOKENS")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap_or(1024);

        let config = RagConfig {
            provider,
            api_key,
            model,
            api_endpoint,
            temperature,
            max_tokens,
            ..Default::default()
        };

        debug!("Creating RAG AI generator from environment variables");
        Self::new(config)
    }
}

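// The `AiGenerator` implementation takes a write lock so the per-request
// `AiResponseConfig` (temperature, max tokens) can be applied to the shared
// engine before generating. Responses that are not valid JSON are wrapped in
// a `{"response": ..., "note": ...}` object instead of being treated as errors.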
#[async_trait]
impl AiGenerator for RagAiGenerator {
    async fn generate(&self, prompt: &str, config: &AiResponseConfig) -> Result<Value> {
        debug!("Generating AI response with RAG engine");

        // Apply the per-request temperature and token budget to the shared
        // engine before generating.
        let mut engine = self.engine.write().await;

        let mut engine_config = engine.config().clone();
        engine_config.temperature = config.temperature as f64;
        engine_config.max_tokens = config.max_tokens;

        engine.update_config(engine_config);
        match engine.generate_text(prompt).await {
            Ok(response_text) => {
                debug!("RAG engine generated response ({} chars)", response_text.len());

                // Try to parse the whole response as JSON first.
                if let Ok(json_value) = serde_json::from_str::<Value>(&response_text) {
                    return Ok(json_value);
                }

                // Otherwise, try the substring between the first '{' and the
                // last '}' (models often wrap JSON in prose). Guard against a
                // stray '}' appearing before the first '{'.
                let extracted = response_text
                    .find('{')
                    .zip(response_text.rfind('}'))
                    .filter(|(start, end)| start < end)
                    .and_then(|(start, end)| {
                        serde_json::from_str::<Value>(&response_text[start..=end]).ok()
                    });

                // Fall back to wrapping the raw text in a JSON object.
                Ok(extracted.unwrap_or_else(|| {
                    serde_json::json!({
                        "response": response_text,
                        "note": "Response was not valid JSON, wrapped in object"
                    })
                }))
            }
            Err(e) => {
                warn!("RAG engine generation failed: {}", e);
                Err(mockforge_core::Error::Config {
                    message: format!("RAG engine generation failed: {}", e),
                })
            }
        }
    }
}
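
// A minimal usage sketch (illustrative; assumes a provider is reachable at
// the configured endpoint):
//
//     let generator = RagAiGenerator::from_env()?;
//     let value = generator
//         .generate("Return a JSON user object", &AiResponseConfig::default())
//         .await?;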

#[cfg(test)]
mod tests {
    use super::*;

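    // These tests only exercise construction and configuration handling; none
    // of them performs a network call, so no live provider is required.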
    #[test]
    fn test_rag_generator_creation() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "llama2".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_generator_creation_openai() {
        let config = RagConfig {
            provider: LlmProvider::OpenAI,
            api_key: Some("test-api-key".to_string()),
            model: "gpt-4".to_string(),
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_generator_creation_anthropic() {
        let config = RagConfig {
            provider: LlmProvider::Anthropic,
            api_key: Some("test-api-key".to_string()),
            model: "claude-3-opus".to_string(),
            api_endpoint: "https://api.anthropic.com/v1/messages".to_string(),
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_generator_creation_openai_compatible() {
        let config = RagConfig {
            provider: LlmProvider::OpenAICompatible,
            api_key: None,
            model: "local-model".to_string(),
            api_endpoint: "http://localhost:8080/v1/chat/completions".to_string(),
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_generator_creation_with_custom_settings() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "codellama".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            temperature: 0.5,
            max_tokens: 2048,
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_generator_creation_with_low_temperature() {
        let config = RagConfig {
            provider: LlmProvider::OpenAI,
            api_key: Some("test-key".to_string()),
            model: "gpt-3.5-turbo".to_string(),
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            temperature: 0.0,
            max_tokens: 512,
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_generator_creation_with_high_temperature() {
        let config = RagConfig {
            provider: LlmProvider::OpenAI,
            api_key: Some("test-key".to_string()),
            model: "gpt-4".to_string(),
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            temperature: 1.0,
            max_tokens: 4096,
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_config_default() {
        let config = RagConfig::default();
        assert!(config.temperature >= 0.0);
        assert!(config.max_tokens > 0);
    }

    #[test]
    fn test_rag_config_clone() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: Some("secret".to_string()),
            model: "llama2".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            temperature: 0.7,
            max_tokens: 1024,
            ..Default::default()
        };

        let cloned = config.clone();
        assert_eq!(cloned.model, config.model);
        assert_eq!(cloned.api_key, config.api_key);
    }

    #[test]
    fn test_llm_provider_openai() {
        let provider = LlmProvider::OpenAI;
        let config = RagConfig {
            provider,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::OpenAI));
    }

    #[test]
    fn test_llm_provider_anthropic() {
        let provider = LlmProvider::Anthropic;
        let config = RagConfig {
            provider,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::Anthropic));
    }

    #[test]
    fn test_llm_provider_ollama() {
        let provider = LlmProvider::Ollama;
        let config = RagConfig {
            provider,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::Ollama));
    }

    #[test]
    fn test_llm_provider_openai_compatible() {
        let provider = LlmProvider::OpenAICompatible;
        let config = RagConfig {
            provider,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::OpenAICompatible));
    }

    #[tokio::test]
    async fn test_generate_fallback_to_json() {
        // Construction-only check: exercising the JSON fallback path in
        // `generate` would require a reachable provider endpoint.
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "test-model".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            ..Default::default()
        };

        let generator = RagAiGenerator::new(config);
        assert!(generator.is_ok());
    }

    #[tokio::test]
    async fn test_generator_engine_access() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "llama2".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            temperature: 0.8,
            max_tokens: 512,
            ..Default::default()
        };

        let generator = RagAiGenerator::new(config).unwrap();
        // The engine's config should reflect what was passed in.
        let engine = generator.engine.read().await;
        let engine_config = engine.config();
        assert_eq!(engine_config.model, "llama2");
    }

    #[tokio::test]
    async fn test_generator_can_be_cloned_via_arc() {
        let config = RagConfig {
            provider: LlmProvider::OpenAI,
            api_key: Some("test".to_string()),
            model: "gpt-3.5-turbo".to_string(),
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            ..Default::default()
        };

        let generator = RagAiGenerator::new(config).unwrap();
        // Cloning the inner Arc shares the same engine instance.
        let engine_clone = generator.engine.clone();
        assert!(Arc::strong_count(&engine_clone) >= 2);
    }

    #[test]
    fn test_ai_response_config_with_generator() {
        let ai_config = AiResponseConfig {
            temperature: 0.7,
            max_tokens: 1024,
            ..Default::default()
        };

        assert!((ai_config.temperature - 0.7).abs() < 0.001);
        assert_eq!(ai_config.max_tokens, 1024);
    }

    #[test]
    fn test_ai_response_config_low_temp() {
        let ai_config = AiResponseConfig {
            temperature: 0.0,
            max_tokens: 256,
            ..Default::default()
        };

        assert!((ai_config.temperature - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_ai_response_config_high_tokens() {
        let ai_config = AiResponseConfig {
            temperature: 0.5,
            max_tokens: 8192,
            ..Default::default()
        };

        assert_eq!(ai_config.max_tokens, 8192);
    }

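    // Edge cases: construction does not eagerly validate the config, so empty
    // model names, empty endpoints, and missing API keys still succeed.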
    #[test]
    fn test_generator_with_empty_model_name() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: String::new(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_generator_with_empty_endpoint() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "llama2".to_string(),
            api_endpoint: String::new(),
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_generator_with_no_api_key_openai() {
        let config = RagConfig {
            provider: LlmProvider::OpenAI,
            api_key: None,
            model: "gpt-4".to_string(),
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_multiple_generators_different_providers() {
        let openai_config = RagConfig {
            provider: LlmProvider::OpenAI,
            api_key: Some("test-key".to_string()),
            model: "gpt-4".to_string(),
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            ..Default::default()
        };

        let ollama_config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "llama2".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            ..Default::default()
        };

        let anthropic_config = RagConfig {
            provider: LlmProvider::Anthropic,
            api_key: Some("test-key".to_string()),
            model: "claude-3-haiku-20240307".to_string(),
            api_endpoint: "https://api.anthropic.com/v1/messages".to_string(),
            ..Default::default()
        };

        assert!(RagAiGenerator::new(openai_config).is_ok());
        assert!(RagAiGenerator::new(ollama_config).is_ok());
        assert!(RagAiGenerator::new(anthropic_config).is_ok());
    }

    #[tokio::test]
    async fn test_generator_engine_update() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "llama2".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            temperature: 0.7,
            max_tokens: 1024,
            ..Default::default()
        };

        let generator = RagAiGenerator::new(config).unwrap();

        // The initial config should match what was passed in.
        {
            let engine = generator.engine.read().await;
            let engine_config = engine.config();
            assert!((engine_config.temperature - 0.7).abs() < 0.001);
            assert_eq!(engine_config.max_tokens, 1024);
        }

        // Update the config through the write lock.
        {
            let mut engine = generator.engine.write().await;
            let mut new_config = engine.config().clone();
            new_config.temperature = 0.5;
            new_config.max_tokens = 2048;
            engine.update_config(new_config);
        }

        // The update should be visible to subsequent readers.
        {
            let engine = generator.engine.read().await;
            let engine_config = engine.config();
            assert!((engine_config.temperature - 0.5).abs() < 0.001);
            assert_eq!(engine_config.max_tokens, 2048);
        }
    }
}