// mockforge_http/rag_ai_generator.rs

1//! RAG-based AI generator implementation
2//!
3//! This module provides an implementation of the AiGenerator trait
4//! using the RAG engine from mockforge-data.
5
6// Uses ai_response::AiResponseConfig which stays in core.
7#![allow(deprecated)]
8
9use async_trait::async_trait;
10use mockforge_core::{ai_response::AiResponseConfig, openapi::response::AiGenerator, Result};
11use mockforge_data::rag::{LlmProvider, RagConfig, RagEngine};
12use serde_json::Value;
13use std::sync::Arc;
14use tracing::{debug, warn};
15
/// RAG-based AI generator that uses the mockforge-data RAG engine.
///
/// Implements the core `AiGenerator` trait by delegating text generation to a
/// shared `RagEngine`.
pub struct RagAiGenerator {
    /// The RAG engine instance, shared behind an async `RwLock` so one
    /// generator can be used from concurrent tasks. `generate` takes the
    /// write lock (it mutates the engine config), so generations serialize.
    engine: Arc<tokio::sync::RwLock<RagEngine>>,
}
21
22impl RagAiGenerator {
23    /// Create a new RAG-based AI generator
24    ///
25    /// # Arguments
26    /// * `rag_config` - Configuration for the RAG engine (provider, model, API key, etc.)
27    ///
28    /// # Returns
29    /// A new RagAiGenerator instance
30    pub fn new(rag_config: RagConfig) -> Result<Self> {
31        debug!("Creating RAG AI generator with provider: {:?}", rag_config.provider);
32
33        // Create the RAG engine
34        let engine = RagEngine::new(rag_config);
35
36        Ok(Self {
37            engine: Arc::new(tokio::sync::RwLock::new(engine)),
38        })
39    }
40
41    /// Create a RAG AI generator from environment variables
42    ///
43    /// Reads configuration from:
44    /// - `MOCKFORGE_AI_PROVIDER`: LLM provider (openai, anthropic, ollama, etc.)
45    /// - `MOCKFORGE_AI_API_KEY`: API key for the LLM provider
46    /// - `MOCKFORGE_AI_MODEL`: Model name (e.g., gpt-4, claude-3-opus)
47    /// - `MOCKFORGE_AI_ENDPOINT`: API endpoint (optional, uses provider default)
48    /// - `MOCKFORGE_AI_TEMPERATURE`: Temperature for generation (optional, default: 0.7)
49    /// - `MOCKFORGE_AI_MAX_TOKENS`: Max tokens for generation (optional, default: 1024)
50    pub fn from_env() -> Result<Self> {
51        let provider =
52            std::env::var("MOCKFORGE_AI_PROVIDER").unwrap_or_else(|_| "openai".to_string());
53
54        let provider = match provider.to_lowercase().as_str() {
55            "openai" => LlmProvider::OpenAI,
56            "anthropic" => LlmProvider::Anthropic,
57            "ollama" => LlmProvider::Ollama,
58            "openai-compatible" => LlmProvider::OpenAICompatible,
59            _ => {
60                warn!("Unknown AI provider '{}', defaulting to OpenAI", provider);
61                LlmProvider::OpenAI
62            }
63        };
64
65        let api_key = std::env::var("MOCKFORGE_AI_API_KEY").ok();
66
67        let model = std::env::var("MOCKFORGE_AI_MODEL").unwrap_or_else(|_| match provider {
68            LlmProvider::OpenAI => "gpt-3.5-turbo".to_string(),
69            LlmProvider::Anthropic => "claude-3-haiku-20240307".to_string(),
70            LlmProvider::Ollama => "llama2".to_string(),
71            LlmProvider::OpenAICompatible => "gpt-3.5-turbo".to_string(),
72        });
73
74        let api_endpoint =
75            std::env::var("MOCKFORGE_AI_ENDPOINT").unwrap_or_else(|_| match provider {
76                LlmProvider::OpenAI => "https://api.openai.com/v1/chat/completions".to_string(),
77                LlmProvider::Anthropic => "https://api.anthropic.com/v1/messages".to_string(),
78                LlmProvider::Ollama => "http://localhost:11434/api/generate".to_string(),
79                LlmProvider::OpenAICompatible => {
80                    "http://localhost:8080/v1/chat/completions".to_string()
81                }
82            });
83
84        let temperature = std::env::var("MOCKFORGE_AI_TEMPERATURE")
85            .ok()
86            .and_then(|s| s.parse::<f64>().ok())
87            .unwrap_or(0.7);
88
89        let max_tokens = std::env::var("MOCKFORGE_AI_MAX_TOKENS")
90            .ok()
91            .and_then(|s| s.parse::<usize>().ok())
92            .unwrap_or(1024);
93
94        let config = RagConfig {
95            provider,
96            api_key,
97            model,
98            api_endpoint,
99            temperature,
100            max_tokens,
101            ..Default::default()
102        };
103
104        debug!("Creating RAG AI generator from environment variables");
105        Self::new(config)
106    }
107}
108
109#[async_trait]
110impl AiGenerator for RagAiGenerator {
111    async fn generate(&self, prompt: &str, config: &AiResponseConfig) -> Result<Value> {
112        debug!("Generating AI response with RAG engine");
113
114        // Lock the engine for generation
115        let mut engine = self.engine.write().await;
116
117        // Update engine config with request-specific settings if needed
118        let mut engine_config = engine.config().clone();
119        engine_config.temperature = config.temperature as f64;
120        engine_config.max_tokens = config.max_tokens;
121
122        // Temporarily update the engine config
123        engine.update_config(engine_config);
124
125        // Generate the response using the RAG engine
126        match engine.generate_text(prompt).await {
127            Ok(response_text) => {
128                debug!("RAG engine generated response ({} chars)", response_text.len());
129
130                // Try to parse the response as JSON
131                match serde_json::from_str::<Value>(&response_text) {
132                    Ok(json_value) => Ok(json_value),
133                    Err(_) => {
134                        // If not valid JSON, try to extract JSON from the response
135                        if let Some(start) = response_text.find('{') {
136                            if let Some(end) = response_text.rfind('}') {
137                                let json_str = &response_text[start..=end];
138                                match serde_json::from_str::<Value>(json_str) {
139                                    Ok(json_value) => Ok(json_value),
140                                    Err(_) => {
141                                        // If still not valid JSON, wrap in an object
142                                        Ok(serde_json::json!({
143                                            "response": response_text,
144                                            "note": "Response was not valid JSON, wrapped in object"
145                                        }))
146                                    }
147                                }
148                            } else {
149                                Ok(serde_json::json!({
150                                    "response": response_text,
151                                    "note": "Response was not valid JSON, wrapped in object"
152                                }))
153                            }
154                        } else {
155                            Ok(serde_json::json!({
156                                "response": response_text,
157                                "note": "Response was not valid JSON, wrapped in object"
158                            }))
159                        }
160                    }
161                }
162            }
163            Err(e) => {
164                warn!("RAG engine generation failed: {}", e);
165                Err(mockforge_core::Error::Config {
166                    message: format!("RAG engine generation failed: {}", e),
167                })
168            }
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    // ==================== RagAiGenerator Creation Tests ====================
    // Construction never contacts a provider, so these assert only that
    // `RagAiGenerator::new` accepts each config shape.

    #[test]
    fn test_rag_generator_creation() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "llama2".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_generator_creation_openai() {
        let config = RagConfig {
            provider: LlmProvider::OpenAI,
            api_key: Some("test-api-key".to_string()),
            model: "gpt-4".to_string(),
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_generator_creation_anthropic() {
        let config = RagConfig {
            provider: LlmProvider::Anthropic,
            api_key: Some("test-api-key".to_string()),
            model: "claude-3-opus".to_string(),
            api_endpoint: "https://api.anthropic.com/v1/messages".to_string(),
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_generator_creation_openai_compatible() {
        let config = RagConfig {
            provider: LlmProvider::OpenAICompatible,
            api_key: None,
            model: "local-model".to_string(),
            api_endpoint: "http://localhost:8080/v1/chat/completions".to_string(),
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_generator_creation_with_custom_settings() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "codellama".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            temperature: 0.5,
            max_tokens: 2048,
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_generator_creation_with_low_temperature() {
        // Temperature 0.0 (deterministic sampling) is a valid boundary value.
        let config = RagConfig {
            provider: LlmProvider::OpenAI,
            api_key: Some("test-key".to_string()),
            model: "gpt-3.5-turbo".to_string(),
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            temperature: 0.0,
            max_tokens: 512,
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_generator_creation_with_high_temperature() {
        let config = RagConfig {
            provider: LlmProvider::OpenAI,
            api_key: Some("test-key".to_string()),
            model: "gpt-4".to_string(),
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            temperature: 1.0,
            max_tokens: 4096,
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    // ==================== RagConfig Tests ====================

    #[test]
    fn test_rag_config_default() {
        let config = RagConfig::default();
        // Default config should have reasonable defaults
        assert!(config.temperature >= 0.0);
        assert!(config.max_tokens > 0);
    }

    #[test]
    fn test_rag_config_clone() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: Some("secret".to_string()),
            model: "llama2".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            temperature: 0.7,
            max_tokens: 1024,
            ..Default::default()
        };

        // Clone must deep-copy owned fields such as model and api_key.
        let cloned = config.clone();
        assert_eq!(cloned.model, config.model);
        assert_eq!(cloned.api_key, config.api_key);
    }

    // ==================== LlmProvider Tests ====================
    // Sanity checks that each variant survives storage in a RagConfig.

    #[test]
    fn test_llm_provider_openai() {
        let provider = LlmProvider::OpenAI;
        let config = RagConfig {
            provider,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::OpenAI));
    }

    #[test]
    fn test_llm_provider_anthropic() {
        let provider = LlmProvider::Anthropic;
        let config = RagConfig {
            provider,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::Anthropic));
    }

    #[test]
    fn test_llm_provider_ollama() {
        let provider = LlmProvider::Ollama;
        let config = RagConfig {
            provider,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::Ollama));
    }

    #[test]
    fn test_llm_provider_openai_compatible() {
        let provider = LlmProvider::OpenAICompatible;
        let config = RagConfig {
            provider,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::OpenAICompatible));
    }

    // ==================== Generator Async Tests ====================

    #[tokio::test]
    async fn test_generate_fallback_to_json() {
        // This test verifies that non-JSON responses are wrapped properly
        // In a real scenario, this would require mocking the RAG engine

        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "test-model".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            ..Default::default()
        };

        // We can't easily test the actual generation without a real LLM,
        // but we can verify the generator was created successfully
        let generator = RagAiGenerator::new(config);
        assert!(generator.is_ok());
    }

    #[tokio::test]
    async fn test_generator_engine_access() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "llama2".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            temperature: 0.8,
            max_tokens: 512,
            ..Default::default()
        };

        let generator = RagAiGenerator::new(config).unwrap();
        // The engine is wrapped in Arc<RwLock>, verify we can access it
        let engine = generator.engine.read().await;
        let engine_config = engine.config();
        assert_eq!(engine_config.model, "llama2");
    }

    #[tokio::test]
    async fn test_generator_can_be_cloned_via_arc() {
        let config = RagConfig {
            provider: LlmProvider::OpenAI,
            api_key: Some("test".to_string()),
            model: "gpt-3.5-turbo".to_string(),
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            ..Default::default()
        };

        let generator = RagAiGenerator::new(config).unwrap();
        // Engine is Arc-wrapped, so cloning should work; count is >= 2
        // because both `generator.engine` and the clone hold a reference.
        let engine_clone = generator.engine.clone();
        assert!(Arc::strong_count(&engine_clone) >= 2);
    }

    // ==================== AiResponseConfig Tests ====================

    #[test]
    fn test_ai_response_config_with_generator() {
        // Test that we can create AiResponseConfig compatible with the generator
        let ai_config = AiResponseConfig {
            temperature: 0.7,
            max_tokens: 1024,
            ..Default::default()
        };

        // Float comparison via epsilon rather than ==.
        assert!((ai_config.temperature - 0.7).abs() < 0.001);
        assert_eq!(ai_config.max_tokens, 1024);
    }

    #[test]
    fn test_ai_response_config_low_temp() {
        let ai_config = AiResponseConfig {
            temperature: 0.0,
            max_tokens: 256,
            ..Default::default()
        };

        assert!((ai_config.temperature - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_ai_response_config_high_tokens() {
        let ai_config = AiResponseConfig {
            temperature: 0.5,
            max_tokens: 8192,
            ..Default::default()
        };

        assert_eq!(ai_config.max_tokens, 8192);
    }

    // ==================== Edge Cases ====================
    // Construction is intentionally lenient; invalid settings are expected to
    // fail at request time, not at creation time.

    #[test]
    fn test_generator_with_empty_model_name() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: String::new(), // Empty model name
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            ..Default::default()
        };

        // Should still create successfully (validation happens later)
        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_generator_with_empty_endpoint() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "llama2".to_string(),
            api_endpoint: String::new(), // Empty endpoint
            ..Default::default()
        };

        // Should still create successfully (validation happens at request time)
        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_generator_with_no_api_key_openai() {
        let config = RagConfig {
            provider: LlmProvider::OpenAI,
            api_key: None, // No API key (will fail at request time)
            model: "gpt-4".to_string(),
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            ..Default::default()
        };

        // Creation should succeed, failure happens when making requests
        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    // ==================== Integration-style Tests ====================

    #[tokio::test]
    async fn test_multiple_generators_different_providers() {
        let openai_config = RagConfig {
            provider: LlmProvider::OpenAI,
            api_key: Some("test-key".to_string()),
            model: "gpt-4".to_string(),
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            ..Default::default()
        };

        let ollama_config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "llama2".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            ..Default::default()
        };

        let anthropic_config = RagConfig {
            provider: LlmProvider::Anthropic,
            api_key: Some("test-key".to_string()),
            model: "claude-3-haiku-20240307".to_string(),
            api_endpoint: "https://api.anthropic.com/v1/messages".to_string(),
            ..Default::default()
        };

        // All three should create successfully
        assert!(RagAiGenerator::new(openai_config).is_ok());
        assert!(RagAiGenerator::new(ollama_config).is_ok());
        assert!(RagAiGenerator::new(anthropic_config).is_ok());
    }

    #[tokio::test]
    async fn test_generator_engine_update() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "llama2".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            temperature: 0.7,
            max_tokens: 1024,
            ..Default::default()
        };

        let generator = RagAiGenerator::new(config).unwrap();

        // Test that we can read and the engine has correct config.
        // Each step runs in its own scope so the lock guard drops before the
        // next lock is taken (avoids read/write deadlock on the RwLock).
        {
            let engine = generator.engine.read().await;
            let engine_config = engine.config();
            assert!((engine_config.temperature - 0.7).abs() < 0.001);
            assert_eq!(engine_config.max_tokens, 1024);
        }

        // Test that we can write to update the config
        {
            let mut engine = generator.engine.write().await;
            let mut new_config = engine.config().clone();
            new_config.temperature = 0.5;
            new_config.max_tokens = 2048;
            engine.update_config(new_config);
        }

        // Verify the update took effect
        {
            let engine = generator.engine.read().await;
            let engine_config = engine.config();
            assert!((engine_config.temperature - 0.5).abs() < 0.001);
            assert_eq!(engine_config.max_tokens, 2048);
        }
    }
}