1#![allow(deprecated)]
8
9use async_trait::async_trait;
10use mockforge_core::{ai_response::AiResponseConfig, openapi::response::AiGenerator, Result};
11use mockforge_data::rag::{LlmProvider, RagConfig, RagEngine};
12use serde_json::Value;
13use std::sync::Arc;
14use tracing::{debug, warn};
15
/// [`AiGenerator`] implementation backed by a shared [`RagEngine`].
///
/// Construct with [`RagAiGenerator::new`] or [`RagAiGenerator::from_env`].
pub struct RagAiGenerator {
    // Engine behind an async RwLock: `generate` takes a write lock to apply
    // per-request temperature/max_tokens overrides before generating.
    engine: Arc<tokio::sync::RwLock<RagEngine>>,
}
21
22impl RagAiGenerator {
23 pub fn new(rag_config: RagConfig) -> Result<Self> {
31 debug!("Creating RAG AI generator with provider: {:?}", rag_config.provider);
32
33 let engine = RagEngine::new(rag_config);
35
36 Ok(Self {
37 engine: Arc::new(tokio::sync::RwLock::new(engine)),
38 })
39 }
40
41 pub fn from_env() -> Result<Self> {
51 let provider =
52 std::env::var("MOCKFORGE_AI_PROVIDER").unwrap_or_else(|_| "openai".to_string());
53
54 let provider = match provider.to_lowercase().as_str() {
55 "openai" => LlmProvider::OpenAI,
56 "anthropic" => LlmProvider::Anthropic,
57 "ollama" => LlmProvider::Ollama,
58 "openai-compatible" => LlmProvider::OpenAICompatible,
59 _ => {
60 warn!("Unknown AI provider '{}', defaulting to OpenAI", provider);
61 LlmProvider::OpenAI
62 }
63 };
64
65 let api_key = std::env::var("MOCKFORGE_AI_API_KEY").ok();
66
67 let model = std::env::var("MOCKFORGE_AI_MODEL").unwrap_or_else(|_| match provider {
68 LlmProvider::OpenAI => "gpt-3.5-turbo".to_string(),
69 LlmProvider::Anthropic => "claude-3-haiku-20240307".to_string(),
70 LlmProvider::Ollama => "llama2".to_string(),
71 LlmProvider::OpenAICompatible => "gpt-3.5-turbo".to_string(),
72 });
73
74 let api_endpoint =
75 std::env::var("MOCKFORGE_AI_ENDPOINT").unwrap_or_else(|_| match provider {
76 LlmProvider::OpenAI => "https://api.openai.com/v1/chat/completions".to_string(),
77 LlmProvider::Anthropic => "https://api.anthropic.com/v1/messages".to_string(),
78 LlmProvider::Ollama => "http://localhost:11434/api/generate".to_string(),
79 LlmProvider::OpenAICompatible => {
80 "http://localhost:8080/v1/chat/completions".to_string()
81 }
82 });
83
84 let temperature = std::env::var("MOCKFORGE_AI_TEMPERATURE")
85 .ok()
86 .and_then(|s| s.parse::<f64>().ok())
87 .unwrap_or(0.7);
88
89 let max_tokens = std::env::var("MOCKFORGE_AI_MAX_TOKENS")
90 .ok()
91 .and_then(|s| s.parse::<usize>().ok())
92 .unwrap_or(1024);
93
94 let config = RagConfig {
95 provider,
96 api_key,
97 model,
98 api_endpoint,
99 temperature,
100 max_tokens,
101 ..Default::default()
102 };
103
104 debug!("Creating RAG AI generator from environment variables");
105 Self::new(config)
106 }
107}
108
109#[async_trait]
110impl AiGenerator for RagAiGenerator {
111 async fn generate(&self, prompt: &str, config: &AiResponseConfig) -> Result<Value> {
112 debug!("Generating AI response with RAG engine");
113
114 let mut engine = self.engine.write().await;
116
117 let mut engine_config = engine.config().clone();
119 engine_config.temperature = config.temperature as f64;
120 engine_config.max_tokens = config.max_tokens;
121
122 engine.update_config(engine_config);
124
125 match engine.generate_text(prompt).await {
127 Ok(response_text) => {
128 debug!("RAG engine generated response ({} chars)", response_text.len());
129
130 match serde_json::from_str::<Value>(&response_text) {
132 Ok(json_value) => Ok(json_value),
133 Err(_) => {
134 if let Some(start) = response_text.find('{') {
136 if let Some(end) = response_text.rfind('}') {
137 let json_str = &response_text[start..=end];
138 match serde_json::from_str::<Value>(json_str) {
139 Ok(json_value) => Ok(json_value),
140 Err(_) => {
141 Ok(serde_json::json!({
143 "response": response_text,
144 "note": "Response was not valid JSON, wrapped in object"
145 }))
146 }
147 }
148 } else {
149 Ok(serde_json::json!({
150 "response": response_text,
151 "note": "Response was not valid JSON, wrapped in object"
152 }))
153 }
154 } else {
155 Ok(serde_json::json!({
156 "response": response_text,
157 "note": "Response was not valid JSON, wrapped in object"
158 }))
159 }
160 }
161 }
162 }
163 Err(e) => {
164 warn!("RAG engine generation failed: {}", e);
165 Err(mockforge_core::Error::Config {
166 message: format!("RAG engine generation failed: {}", e),
167 })
168 }
169 }
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    // Endpoint constants shared across tests.
    const OLLAMA_URL: &str = "http://localhost:11434/api/generate";
    const OPENAI_URL: &str = "https://api.openai.com/v1/chat/completions";
    const ANTHROPIC_URL: &str = "https://api.anthropic.com/v1/messages";

    /// Build a `RagConfig` for tests; everything not listed falls back to
    /// `Default`. Removes the config-literal duplication across tests.
    fn test_config(
        provider: LlmProvider,
        api_key: Option<&str>,
        model: &str,
        api_endpoint: &str,
    ) -> RagConfig {
        RagConfig {
            provider,
            api_key: api_key.map(str::to_string),
            model: model.to_string(),
            api_endpoint: api_endpoint.to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_rag_generator_creation() {
        let config = test_config(LlmProvider::Ollama, None, "llama2", OLLAMA_URL);
        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[test]
    fn test_rag_generator_creation_openai() {
        let config = test_config(LlmProvider::OpenAI, Some("test-api-key"), "gpt-4", OPENAI_URL);
        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[test]
    fn test_rag_generator_creation_anthropic() {
        let config = test_config(
            LlmProvider::Anthropic,
            Some("test-api-key"),
            "claude-3-opus",
            ANTHROPIC_URL,
        );
        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[test]
    fn test_rag_generator_creation_openai_compatible() {
        let config = test_config(
            LlmProvider::OpenAICompatible,
            None,
            "local-model",
            "http://localhost:8080/v1/chat/completions",
        );
        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[test]
    fn test_rag_generator_creation_with_custom_settings() {
        let config = RagConfig {
            temperature: 0.5,
            max_tokens: 2048,
            ..test_config(LlmProvider::Ollama, None, "codellama", OLLAMA_URL)
        };
        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[test]
    fn test_rag_generator_creation_with_low_temperature() {
        let config = RagConfig {
            temperature: 0.0,
            max_tokens: 512,
            ..test_config(LlmProvider::OpenAI, Some("test-key"), "gpt-3.5-turbo", OPENAI_URL)
        };
        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[test]
    fn test_rag_generator_creation_with_high_temperature() {
        let config = RagConfig {
            temperature: 1.0,
            max_tokens: 4096,
            ..test_config(LlmProvider::OpenAI, Some("test-key"), "gpt-4", OPENAI_URL)
        };
        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[test]
    fn test_rag_config_default() {
        let config = RagConfig::default();
        assert!(config.temperature >= 0.0);
        assert!(config.max_tokens > 0);
    }

    #[test]
    fn test_rag_config_clone() {
        let config = RagConfig {
            temperature: 0.7,
            max_tokens: 1024,
            ..test_config(LlmProvider::Ollama, Some("secret"), "llama2", OLLAMA_URL)
        };

        let cloned = config.clone();
        assert_eq!(cloned.model, config.model);
        assert_eq!(cloned.api_key, config.api_key);
    }

    #[test]
    fn test_llm_provider_openai() {
        let config = RagConfig {
            provider: LlmProvider::OpenAI,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::OpenAI));
    }

    #[test]
    fn test_llm_provider_anthropic() {
        let config = RagConfig {
            provider: LlmProvider::Anthropic,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::Anthropic));
    }

    #[test]
    fn test_llm_provider_ollama() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::Ollama));
    }

    #[test]
    fn test_llm_provider_openai_compatible() {
        let config = RagConfig {
            provider: LlmProvider::OpenAICompatible,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::OpenAICompatible));
    }

    #[tokio::test]
    async fn test_generate_fallback_to_json() {
        let config = test_config(LlmProvider::Ollama, None, "test-model", OLLAMA_URL);
        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[tokio::test]
    async fn test_generator_engine_access() {
        let config = RagConfig {
            temperature: 0.8,
            max_tokens: 512,
            ..test_config(LlmProvider::Ollama, None, "llama2", OLLAMA_URL)
        };

        let generator = RagAiGenerator::new(config).unwrap();
        let engine = generator.engine.read().await;
        assert_eq!(engine.config().model, "llama2");
    }

    #[tokio::test]
    async fn test_generator_can_be_cloned_via_arc() {
        let config = test_config(LlmProvider::OpenAI, Some("test"), "gpt-3.5-turbo", OPENAI_URL);

        let generator = RagAiGenerator::new(config).unwrap();
        let engine_clone = generator.engine.clone();
        assert!(Arc::strong_count(&engine_clone) >= 2);
    }

    #[test]
    fn test_ai_response_config_with_generator() {
        let ai_config = AiResponseConfig {
            temperature: 0.7,
            max_tokens: 1024,
            ..Default::default()
        };

        assert!((ai_config.temperature - 0.7).abs() < 0.001);
        assert_eq!(ai_config.max_tokens, 1024);
    }

    #[test]
    fn test_ai_response_config_low_temp() {
        let ai_config = AiResponseConfig {
            temperature: 0.0,
            max_tokens: 256,
            ..Default::default()
        };

        assert!((ai_config.temperature - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_ai_response_config_high_tokens() {
        let ai_config = AiResponseConfig {
            temperature: 0.5,
            max_tokens: 8192,
            ..Default::default()
        };

        assert_eq!(ai_config.max_tokens, 8192);
    }

    #[test]
    fn test_generator_with_empty_model_name() {
        let config = test_config(LlmProvider::Ollama, None, "", OLLAMA_URL);
        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[test]
    fn test_generator_with_empty_endpoint() {
        let config = test_config(LlmProvider::Ollama, None, "llama2", "");
        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[test]
    fn test_generator_with_no_api_key_openai() {
        let config = test_config(LlmProvider::OpenAI, None, "gpt-4", OPENAI_URL);
        assert!(RagAiGenerator::new(config).is_ok());
    }

    #[tokio::test]
    async fn test_multiple_generators_different_providers() {
        let openai_config =
            test_config(LlmProvider::OpenAI, Some("test-key"), "gpt-4", OPENAI_URL);
        let ollama_config = test_config(LlmProvider::Ollama, None, "llama2", OLLAMA_URL);
        let anthropic_config = test_config(
            LlmProvider::Anthropic,
            Some("test-key"),
            "claude-3-haiku-20240307",
            ANTHROPIC_URL,
        );

        assert!(RagAiGenerator::new(openai_config).is_ok());
        assert!(RagAiGenerator::new(ollama_config).is_ok());
        assert!(RagAiGenerator::new(anthropic_config).is_ok());
    }

    #[tokio::test]
    async fn test_generator_engine_update() {
        let config = RagConfig {
            temperature: 0.7,
            max_tokens: 1024,
            ..test_config(LlmProvider::Ollama, None, "llama2", OLLAMA_URL)
        };

        let generator = RagAiGenerator::new(config).unwrap();

        // Initial values flow through from the constructor config.
        {
            let engine = generator.engine.read().await;
            let engine_config = engine.config();
            assert!((engine_config.temperature - 0.7).abs() < 0.001);
            assert_eq!(engine_config.max_tokens, 1024);
        }

        // A write lock lets callers swap in an updated config.
        {
            let mut engine = generator.engine.write().await;
            let mut new_config = engine.config().clone();
            new_config.temperature = 0.5;
            new_config.max_tokens = 2048;
            engine.update_config(new_config);
        }

        // The update is visible to subsequent readers.
        {
            let engine = generator.engine.read().await;
            let engine_config = engine.config();
            assert!((engine_config.temperature - 0.5).abs() < 0.001);
            assert_eq!(engine_config.max_tokens, 2048);
        }
    }
}