1use async_trait::async_trait;
14use mockforge_data::rag::{LlmProvider, RagConfig, RagEngine};
15use mockforge_foundation::ai_response::AiResponseConfig;
16use mockforge_foundation::{Error, Result};
17use mockforge_openapi::response::AiGenerator;
18use serde_json::Value;
19use std::sync::Arc;
20use tracing::{debug, warn};
21
22pub struct RagAiGenerator {
24 engine: Arc<tokio::sync::RwLock<RagEngine>>,
26}
27
28impl RagAiGenerator {
29 pub fn new(rag_config: RagConfig) -> Result<Self> {
37 debug!("Creating RAG AI generator with provider: {:?}", rag_config.provider);
38
39 let engine = RagEngine::new(rag_config);
41
42 Ok(Self {
43 engine: Arc::new(tokio::sync::RwLock::new(engine)),
44 })
45 }
46
47 pub fn from_env() -> Result<Self> {
57 let provider =
58 std::env::var("MOCKFORGE_AI_PROVIDER").unwrap_or_else(|_| "openai".to_string());
59
60 let provider = match provider.to_lowercase().as_str() {
61 "openai" => LlmProvider::OpenAI,
62 "anthropic" => LlmProvider::Anthropic,
63 "ollama" => LlmProvider::Ollama,
64 "openai-compatible" => LlmProvider::OpenAICompatible,
65 _ => {
66 warn!("Unknown AI provider '{}', defaulting to OpenAI", provider);
67 LlmProvider::OpenAI
68 }
69 };
70
71 let api_key = std::env::var("MOCKFORGE_AI_API_KEY").ok();
72
73 let model = std::env::var("MOCKFORGE_AI_MODEL").unwrap_or_else(|_| match provider {
74 LlmProvider::OpenAI => "gpt-3.5-turbo".to_string(),
75 LlmProvider::Anthropic => "claude-3-haiku-20240307".to_string(),
76 LlmProvider::Ollama => "llama2".to_string(),
77 LlmProvider::OpenAICompatible => "gpt-3.5-turbo".to_string(),
78 });
79
80 let api_endpoint =
81 std::env::var("MOCKFORGE_AI_ENDPOINT").unwrap_or_else(|_| match provider {
82 LlmProvider::OpenAI => "https://api.openai.com/v1/chat/completions".to_string(),
83 LlmProvider::Anthropic => "https://api.anthropic.com/v1/messages".to_string(),
84 LlmProvider::Ollama => "http://localhost:11434/api/generate".to_string(),
85 LlmProvider::OpenAICompatible => {
86 "http://localhost:8080/v1/chat/completions".to_string()
87 }
88 });
89
90 let temperature = std::env::var("MOCKFORGE_AI_TEMPERATURE")
91 .ok()
92 .and_then(|s| s.parse::<f64>().ok())
93 .unwrap_or(0.7);
94
95 let max_tokens = std::env::var("MOCKFORGE_AI_MAX_TOKENS")
96 .ok()
97 .and_then(|s| s.parse::<usize>().ok())
98 .unwrap_or(1024);
99
100 let config = RagConfig {
101 provider,
102 api_key,
103 model,
104 api_endpoint,
105 temperature,
106 max_tokens,
107 ..Default::default()
108 };
109
110 debug!("Creating RAG AI generator from environment variables");
111 Self::new(config)
112 }
113}
114
115#[async_trait]
116impl AiGenerator for RagAiGenerator {
117 async fn generate(&self, prompt: &str, config: &AiResponseConfig) -> Result<Value> {
118 debug!("Generating AI response with RAG engine");
119
120 let mut engine = self.engine.write().await;
122
123 let mut engine_config = engine.config().clone();
125 engine_config.temperature = config.temperature as f64;
126 engine_config.max_tokens = config.max_tokens;
127
128 engine.update_config(engine_config);
130
131 match engine.generate_text(prompt).await {
133 Ok(response_text) => {
134 debug!("RAG engine generated response ({} chars)", response_text.len());
135
136 match serde_json::from_str::<Value>(&response_text) {
138 Ok(json_value) => Ok(json_value),
139 Err(_) => {
140 if let Some(start) = response_text.find('{') {
142 if let Some(end) = response_text.rfind('}') {
143 let json_str = &response_text[start..=end];
144 match serde_json::from_str::<Value>(json_str) {
145 Ok(json_value) => Ok(json_value),
146 Err(_) => {
147 Ok(serde_json::json!({
149 "response": response_text,
150 "note": "Response was not valid JSON, wrapped in object"
151 }))
152 }
153 }
154 } else {
155 Ok(serde_json::json!({
156 "response": response_text,
157 "note": "Response was not valid JSON, wrapped in object"
158 }))
159 }
160 } else {
161 Ok(serde_json::json!({
162 "response": response_text,
163 "note": "Response was not valid JSON, wrapped in object"
164 }))
165 }
166 }
167 }
168 }
169 Err(e) => {
170 warn!("RAG engine generation failed: {}", e);
171 Err(Error::Config {
172 message: format!("RAG engine generation failed: {}", e),
173 })
174 }
175 }
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Default Ollama text-generation endpoint used by the fixtures.
    const OLLAMA_ENDPOINT: &str = "http://localhost:11434/api/generate";
    /// OpenAI chat-completions endpoint used by the fixtures.
    const OPENAI_ENDPOINT: &str = "https://api.openai.com/v1/chat/completions";
    /// Anthropic messages endpoint used by the fixtures.
    const ANTHROPIC_ENDPOINT: &str = "https://api.anthropic.com/v1/messages";
    /// Local OpenAI-compatible endpoint used by the fixtures.
    const OPENAI_COMPATIBLE_ENDPOINT: &str = "http://localhost:8080/v1/chat/completions";

    /// Builds a `RagConfig` with the given connection settings and defaults
    /// for every other field.
    fn rag_config(
        provider: LlmProvider,
        api_key: Option<&str>,
        model: &str,
        api_endpoint: &str,
    ) -> RagConfig {
        RagConfig {
            provider,
            api_key: api_key.map(String::from),
            model: model.to_string(),
            api_endpoint: api_endpoint.to_string(),
            ..Default::default()
        }
    }

    // --- Construction for each provider -----------------------------------

    #[test]
    fn test_rag_generator_creation() {
        let config = rag_config(LlmProvider::Ollama, None, "llama2", OLLAMA_ENDPOINT);

        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[test]
    fn test_rag_generator_creation_openai() {
        let config =
            rag_config(LlmProvider::OpenAI, Some("test-api-key"), "gpt-4", OPENAI_ENDPOINT);

        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[test]
    fn test_rag_generator_creation_anthropic() {
        let config = rag_config(
            LlmProvider::Anthropic,
            Some("test-api-key"),
            "claude-3-opus",
            ANTHROPIC_ENDPOINT,
        );

        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[test]
    fn test_rag_generator_creation_openai_compatible() {
        let config = rag_config(
            LlmProvider::OpenAICompatible,
            None,
            "local-model",
            OPENAI_COMPATIBLE_ENDPOINT,
        );

        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[test]
    fn test_rag_generator_creation_with_custom_settings() {
        let config = RagConfig {
            temperature: 0.5,
            max_tokens: 2048,
            ..rag_config(LlmProvider::Ollama, None, "codellama", OLLAMA_ENDPOINT)
        };

        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[test]
    fn test_rag_generator_creation_with_low_temperature() {
        let config = RagConfig {
            temperature: 0.0,
            max_tokens: 512,
            ..rag_config(LlmProvider::OpenAI, Some("test-key"), "gpt-3.5-turbo", OPENAI_ENDPOINT)
        };

        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[test]
    fn test_rag_generator_creation_with_high_temperature() {
        let config = RagConfig {
            temperature: 1.0,
            max_tokens: 4096,
            ..rag_config(LlmProvider::OpenAI, Some("test-key"), "gpt-4", OPENAI_ENDPOINT)
        };

        assert!(RagAiGenerator::new(config).is_ok());
    }

    // --- RagConfig basics --------------------------------------------------

    #[test]
    fn test_rag_config_default() {
        let defaults = RagConfig::default();
        assert!(defaults.temperature >= 0.0);
        assert!(defaults.max_tokens > 0);
    }

    #[test]
    fn test_rag_config_clone() {
        let config = RagConfig {
            temperature: 0.7,
            max_tokens: 1024,
            ..rag_config(LlmProvider::Ollama, Some("secret"), "llama2", OLLAMA_ENDPOINT)
        };

        let cloned = config.clone();
        assert_eq!(cloned.model, config.model);
        assert_eq!(cloned.api_key, config.api_key);
    }

    // --- Provider variants survive struct-update construction ---------------

    #[test]
    fn test_llm_provider_openai() {
        let config = RagConfig {
            provider: LlmProvider::OpenAI,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::OpenAI));
    }

    #[test]
    fn test_llm_provider_anthropic() {
        let config = RagConfig {
            provider: LlmProvider::Anthropic,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::Anthropic));
    }

    #[test]
    fn test_llm_provider_ollama() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::Ollama));
    }

    #[test]
    fn test_llm_provider_openai_compatible() {
        let config = RagConfig {
            provider: LlmProvider::OpenAICompatible,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::OpenAICompatible));
    }

    // --- Generator state ---------------------------------------------------

    #[tokio::test]
    async fn test_generate_fallback_to_json() {
        // Only checks construction; no request is sent to the endpoint.
        let config = rag_config(LlmProvider::Ollama, None, "test-model", OLLAMA_ENDPOINT);

        let generator = RagAiGenerator::new(config);
        assert!(generator.is_ok());
    }

    #[tokio::test]
    async fn test_generator_engine_access() {
        let config = RagConfig {
            temperature: 0.8,
            max_tokens: 512,
            ..rag_config(LlmProvider::Ollama, None, "llama2", OLLAMA_ENDPOINT)
        };

        let generator = RagAiGenerator::new(config).unwrap();
        let guard = generator.engine.read().await;
        assert_eq!(guard.config().model, "llama2");
    }

    #[tokio::test]
    async fn test_generator_can_be_cloned_via_arc() {
        let config =
            rag_config(LlmProvider::OpenAI, Some("test"), "gpt-3.5-turbo", OPENAI_ENDPOINT);

        let generator = RagAiGenerator::new(config).unwrap();
        let shared_engine = generator.engine.clone();
        // One count for the generator's field, one for the clone above.
        assert!(Arc::strong_count(&shared_engine) >= 2);
    }

    // --- AiResponseConfig values -------------------------------------------

    #[test]
    fn test_ai_response_config_with_generator() {
        let ai_config = AiResponseConfig {
            temperature: 0.7,
            max_tokens: 1024,
            ..Default::default()
        };

        assert!((ai_config.temperature - 0.7).abs() < 0.001);
        assert_eq!(ai_config.max_tokens, 1024);
    }

    #[test]
    fn test_ai_response_config_low_temp() {
        let ai_config = AiResponseConfig {
            temperature: 0.0,
            max_tokens: 256,
            ..Default::default()
        };

        assert!((ai_config.temperature - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_ai_response_config_high_tokens() {
        let ai_config = AiResponseConfig {
            temperature: 0.5,
            max_tokens: 8192,
            ..Default::default()
        };

        assert_eq!(ai_config.max_tokens, 8192);
    }

    // --- Edge cases accepted by the constructor -----------------------------

    #[test]
    fn test_generator_with_empty_model_name() {
        let config = rag_config(LlmProvider::Ollama, None, "", OLLAMA_ENDPOINT);

        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[test]
    fn test_generator_with_empty_endpoint() {
        let config = rag_config(LlmProvider::Ollama, None, "llama2", "");

        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[test]
    fn test_generator_with_no_api_key_openai() {
        let config = rag_config(LlmProvider::OpenAI, None, "gpt-4", OPENAI_ENDPOINT);

        assert!(RagAiGenerator::new(config).is_ok());
    }

    // --- Multiple generators and config updates ------------------------------

    #[tokio::test]
    async fn test_multiple_generators_different_providers() {
        let openai_config =
            rag_config(LlmProvider::OpenAI, Some("test-key"), "gpt-4", OPENAI_ENDPOINT);
        let ollama_config = rag_config(LlmProvider::Ollama, None, "llama2", OLLAMA_ENDPOINT);
        let anthropic_config = rag_config(
            LlmProvider::Anthropic,
            Some("test-key"),
            "claude-3-haiku-20240307",
            ANTHROPIC_ENDPOINT,
        );

        assert!(RagAiGenerator::new(openai_config).is_ok());
        assert!(RagAiGenerator::new(ollama_config).is_ok());
        assert!(RagAiGenerator::new(anthropic_config).is_ok());
    }

    #[tokio::test]
    async fn test_generator_engine_update() {
        let config = RagConfig {
            temperature: 0.7,
            max_tokens: 1024,
            ..rag_config(LlmProvider::Ollama, None, "llama2", OLLAMA_ENDPOINT)
        };

        let generator = RagAiGenerator::new(config).unwrap();

        // Initial values are the ones passed to the constructor.
        {
            let guard = generator.engine.read().await;
            let current = guard.config();
            assert!((current.temperature - 0.7).abs() < 0.001);
            assert_eq!(current.max_tokens, 1024);
        }

        // Replace the configuration through the write guard.
        {
            let mut guard = generator.engine.write().await;
            let mut updated = guard.config().clone();
            updated.temperature = 0.5;
            updated.max_tokens = 2048;
            guard.update_config(updated);
        }

        // A later read sees the replaced values.
        {
            let guard = generator.engine.read().await;
            let current = guard.config();
            assert!((current.temperature - 0.5).abs() < 0.001);
            assert_eq!(current.max_tokens, 2048);
        }
    }
}