mockforge_http/rag_ai_generator.rs

//! RAG-based AI generator implementation
//!
//! This module provides an implementation of the AiGenerator trait
//! using the RAG engine from mockforge-data.
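//!
//! # Example
//!
//! A minimal end-to-end sketch. The fence is marked `ignore` because the
//! crate-level import path for `RagAiGenerator` and a reachable LLM endpoint
//! are assumptions, not guarantees:
//!
//! ```ignore
//! use mockforge_core::{ai_response::AiResponseConfig, openapi::response::AiGenerator};
//!
//! // Configure via the MOCKFORGE_AI_* environment variables, then:
//! let generator = RagAiGenerator::from_env()?;
//! let ai_config = AiResponseConfig::default();
//! let json = generator.generate("Return a JSON user object", &ai_config).await?;
//! ```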

use async_trait::async_trait;
use mockforge_core::{ai_response::AiResponseConfig, openapi::response::AiGenerator, Result};
use mockforge_data::rag::{LlmProvider, RagConfig, RagEngine};
use serde_json::Value;
use std::sync::Arc;
use tracing::{debug, warn};

/// RAG-based AI generator that uses the mockforge-data RAG engine
pub struct RagAiGenerator {
    /// The RAG engine instance
    engine: Arc<tokio::sync::RwLock<RagEngine>>,
}

impl RagAiGenerator {
    /// Create a new RAG-based AI generator
    ///
    /// # Arguments
    /// * `rag_config` - Configuration for the RAG engine (provider, model, API key, etc.)
    ///
    /// # Returns
    /// A new RagAiGenerator instance
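    ///
    /// # Example
    ///
    /// A minimal sketch pointing at a local Ollama endpoint (field values are
    /// illustrative; marked `ignore` because the import path is assumed):
    ///
    /// ```ignore
    /// use mockforge_data::rag::{LlmProvider, RagConfig};
    ///
    /// let config = RagConfig {
    ///     provider: LlmProvider::Ollama,
    ///     model: "llama2".to_string(),
    ///     api_endpoint: "http://localhost:11434/api/generate".to_string(),
    ///     ..Default::default()
    /// };
    /// let generator = RagAiGenerator::new(config)?;
    /// ```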
    pub fn new(rag_config: RagConfig) -> Result<Self> {
        debug!("Creating RAG AI generator with provider: {:?}", rag_config.provider);

        // Create the RAG engine
        let engine = RagEngine::new(rag_config);

        Ok(Self {
            engine: Arc::new(tokio::sync::RwLock::new(engine)),
        })
    }

    /// Create a RAG AI generator from environment variables
    ///
    /// Reads configuration from:
    /// - `MOCKFORGE_AI_PROVIDER`: LLM provider: openai, anthropic, ollama, or openai-compatible (optional, default: openai)
    /// - `MOCKFORGE_AI_API_KEY`: API key for the LLM provider (optional)
    /// - `MOCKFORGE_AI_MODEL`: Model name (e.g., gpt-4, claude-3-opus; optional, uses a per-provider default)
    /// - `MOCKFORGE_AI_ENDPOINT`: API endpoint (optional, uses the provider default)
    /// - `MOCKFORGE_AI_TEMPERATURE`: Temperature for generation (optional, default: 0.7)
    /// - `MOCKFORGE_AI_MAX_TOKENS`: Max tokens for generation (optional, default: 1024)
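    ///
    /// # Example
    ///
    /// A typical environment for a local Ollama setup (values are
    /// illustrative, not defaults):
    ///
    /// ```text
    /// MOCKFORGE_AI_PROVIDER=ollama
    /// MOCKFORGE_AI_MODEL=llama2
    /// MOCKFORGE_AI_ENDPOINT=http://localhost:11434/api/generate
    /// ```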
    pub fn from_env() -> Result<Self> {
        let provider =
            std::env::var("MOCKFORGE_AI_PROVIDER").unwrap_or_else(|_| "openai".to_string());

        let provider = match provider.to_lowercase().as_str() {
            "openai" => LlmProvider::OpenAI,
            "anthropic" => LlmProvider::Anthropic,
            "ollama" => LlmProvider::Ollama,
            "openai-compatible" => LlmProvider::OpenAICompatible,
            _ => {
                warn!("Unknown AI provider '{}', defaulting to OpenAI", provider);
                LlmProvider::OpenAI
            }
        };

        let api_key = std::env::var("MOCKFORGE_AI_API_KEY").ok();

        let model = std::env::var("MOCKFORGE_AI_MODEL").unwrap_or_else(|_| match provider {
            LlmProvider::OpenAI => "gpt-3.5-turbo".to_string(),
            LlmProvider::Anthropic => "claude-3-haiku-20240307".to_string(),
            LlmProvider::Ollama => "llama2".to_string(),
            LlmProvider::OpenAICompatible => "gpt-3.5-turbo".to_string(),
        });

        let api_endpoint =
            std::env::var("MOCKFORGE_AI_ENDPOINT").unwrap_or_else(|_| match provider {
                LlmProvider::OpenAI => "https://api.openai.com/v1/chat/completions".to_string(),
                LlmProvider::Anthropic => "https://api.anthropic.com/v1/messages".to_string(),
                LlmProvider::Ollama => "http://localhost:11434/api/generate".to_string(),
                LlmProvider::OpenAICompatible => {
                    "http://localhost:8080/v1/chat/completions".to_string()
                }
            });

        let temperature = std::env::var("MOCKFORGE_AI_TEMPERATURE")
            .ok()
            .and_then(|s| s.parse::<f64>().ok())
            .unwrap_or(0.7);

        let max_tokens = std::env::var("MOCKFORGE_AI_MAX_TOKENS")
            .ok()
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap_or(1024);

        let config = RagConfig {
            provider,
            api_key,
            model,
            api_endpoint,
            temperature,
            max_tokens,
            ..Default::default()
        };

        debug!("Creating RAG AI generator from environment variables");
        Self::new(config)
    }
}

#[async_trait]
impl AiGenerator for RagAiGenerator {
    async fn generate(&self, prompt: &str, config: &AiResponseConfig) -> Result<Value> {
        debug!("Generating AI response with RAG engine");

        // Lock the engine for the duration of the generation call
        let mut engine = self.engine.write().await;

        // Apply the request-specific temperature and token limit. Note that
        // this updates the shared engine config, not a per-request copy.
        let mut engine_config = engine.config().clone();
        engine_config.temperature = config.temperature as f64;
        engine_config.max_tokens = config.max_tokens;
        engine.update_config(engine_config);

        // Generate the response using the RAG engine
        let response_text = match engine.generate_text(prompt).await {
            Ok(text) => text,
            Err(e) => {
                warn!("RAG engine generation failed: {}", e);
                return Err(mockforge_core::Error::Config {
                    message: format!("RAG engine generation failed: {}", e),
                });
            }
        };

        debug!("RAG engine generated response ({} chars)", response_text.len());

        // First, try to parse the whole response as JSON
        if let Ok(json_value) = serde_json::from_str::<Value>(&response_text) {
            return Ok(json_value);
        }

        // Otherwise, try to extract a JSON object embedded in the response
        // (e.g. a `{...}` payload surrounded by prose). Guard against a '}'
        // that precedes the first '{', which would make the slice invalid.
        if let (Some(start), Some(end)) = (response_text.find('{'), response_text.rfind('}')) {
            if start < end {
                if let Ok(json_value) = serde_json::from_str::<Value>(&response_text[start..=end]) {
                    return Ok(json_value);
                }
            }
        }

        // As a last resort, wrap the raw text so callers always receive JSON
        Ok(serde_json::json!({
            "response": response_text,
            "note": "Response was not valid JSON, wrapped in object"
        }))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // ==================== RagAiGenerator Creation Tests ====================

    #[test]
    fn test_rag_generator_creation() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "llama2".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_generator_creation_openai() {
        let config = RagConfig {
            provider: LlmProvider::OpenAI,
            api_key: Some("test-api-key".to_string()),
            model: "gpt-4".to_string(),
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_generator_creation_anthropic() {
        let config = RagConfig {
            provider: LlmProvider::Anthropic,
            api_key: Some("test-api-key".to_string()),
            model: "claude-3-opus".to_string(),
            api_endpoint: "https://api.anthropic.com/v1/messages".to_string(),
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_generator_creation_openai_compatible() {
        let config = RagConfig {
            provider: LlmProvider::OpenAICompatible,
            api_key: None,
            model: "local-model".to_string(),
            api_endpoint: "http://localhost:8080/v1/chat/completions".to_string(),
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_generator_creation_with_custom_settings() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "codellama".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            temperature: 0.5,
            max_tokens: 2048,
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_generator_creation_with_low_temperature() {
        let config = RagConfig {
            provider: LlmProvider::OpenAI,
            api_key: Some("test-key".to_string()),
            model: "gpt-3.5-turbo".to_string(),
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            temperature: 0.0,
            max_tokens: 512,
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rag_generator_creation_with_high_temperature() {
        let config = RagConfig {
            provider: LlmProvider::OpenAI,
            api_key: Some("test-key".to_string()),
            model: "gpt-4".to_string(),
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            temperature: 1.0,
            max_tokens: 4096,
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    // ==================== RagConfig Tests ====================

    #[test]
    fn test_rag_config_default() {
        let config = RagConfig::default();
        // Default config should have reasonable defaults
        assert!(config.temperature >= 0.0);
        assert!(config.max_tokens > 0);
    }

    #[test]
    fn test_rag_config_clone() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: Some("secret".to_string()),
            model: "llama2".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            temperature: 0.7,
            max_tokens: 1024,
            ..Default::default()
        };

        let cloned = config.clone();
        assert_eq!(cloned.model, config.model);
        assert_eq!(cloned.api_key, config.api_key);
    }

    // ==================== LlmProvider Tests ====================

    #[test]
    fn test_llm_provider_openai() {
        let provider = LlmProvider::OpenAI;
        let config = RagConfig {
            provider,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::OpenAI));
    }

    #[test]
    fn test_llm_provider_anthropic() {
        let provider = LlmProvider::Anthropic;
        let config = RagConfig {
            provider,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::Anthropic));
    }

    #[test]
    fn test_llm_provider_ollama() {
        let provider = LlmProvider::Ollama;
        let config = RagConfig {
            provider,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::Ollama));
    }

    #[test]
    fn test_llm_provider_openai_compatible() {
        let provider = LlmProvider::OpenAICompatible;
        let config = RagConfig {
            provider,
            ..Default::default()
        };
        assert!(matches!(config.provider, LlmProvider::OpenAICompatible));
    }

    // ==================== Generator Async Tests ====================

    #[test]
    fn test_generate_fallback_to_json() {
        // Exercising the non-JSON fallback path end to end would require
        // mocking the RAG engine (or a live LLM), so this test only verifies
        // that a generator intended for that scenario is created successfully.
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "test-model".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            ..Default::default()
        };

        let generator = RagAiGenerator::new(config);
        assert!(generator.is_ok());
    }

    #[tokio::test]
    async fn test_generator_engine_access() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "llama2".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            temperature: 0.8,
            max_tokens: 512,
            ..Default::default()
        };

        let generator = RagAiGenerator::new(config).unwrap();
        // The engine is wrapped in Arc<RwLock>; verify we can access it
        let engine = generator.engine.read().await;
        let engine_config = engine.config();
        assert_eq!(engine_config.model, "llama2");
    }

    #[test]
    fn test_generator_can_be_cloned_via_arc() {
        let config = RagConfig {
            provider: LlmProvider::OpenAI,
            api_key: Some("test".to_string()),
            model: "gpt-3.5-turbo".to_string(),
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            ..Default::default()
        };

        let generator = RagAiGenerator::new(config).unwrap();
        // Engine is Arc-wrapped, so cloning the handle should work
        let engine_clone = generator.engine.clone();
        assert!(Arc::strong_count(&engine_clone) >= 2);
    }

    // ==================== AiResponseConfig Tests ====================

    #[test]
    fn test_ai_response_config_with_generator() {
        // Test that we can create an AiResponseConfig compatible with the generator
        let ai_config = AiResponseConfig {
            temperature: 0.7,
            max_tokens: 1024,
            ..Default::default()
        };

        assert!((ai_config.temperature - 0.7).abs() < 0.001);
        assert_eq!(ai_config.max_tokens, 1024);
    }

    #[test]
    fn test_ai_response_config_low_temp() {
        let ai_config = AiResponseConfig {
            temperature: 0.0,
            max_tokens: 256,
            ..Default::default()
        };

        assert!((ai_config.temperature - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_ai_response_config_high_tokens() {
        let ai_config = AiResponseConfig {
            temperature: 0.5,
            max_tokens: 8192,
            ..Default::default()
        };

        assert_eq!(ai_config.max_tokens, 8192);
    }

    // ==================== Edge Cases ====================

    #[test]
    fn test_generator_with_empty_model_name() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: String::new(), // Empty model name
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            ..Default::default()
        };

        // Should still create successfully (validation happens later)
        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_generator_with_empty_endpoint() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "llama2".to_string(),
            api_endpoint: String::new(), // Empty endpoint
            ..Default::default()
        };

        // Should still create successfully (validation happens at request time)
        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    #[test]
    fn test_generator_with_no_api_key_openai() {
        let config = RagConfig {
            provider: LlmProvider::OpenAI,
            api_key: None, // No API key (will fail at request time)
            model: "gpt-4".to_string(),
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            ..Default::default()
        };

        // Creation should succeed; failure happens when making requests
        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    // ==================== Integration-style Tests ====================

    #[test]
    fn test_multiple_generators_different_providers() {
        let openai_config = RagConfig {
            provider: LlmProvider::OpenAI,
            api_key: Some("test-key".to_string()),
            model: "gpt-4".to_string(),
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            ..Default::default()
        };

        let ollama_config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "llama2".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            ..Default::default()
        };

        let anthropic_config = RagConfig {
            provider: LlmProvider::Anthropic,
            api_key: Some("test-key".to_string()),
            model: "claude-3-haiku-20240307".to_string(),
            api_endpoint: "https://api.anthropic.com/v1/messages".to_string(),
            ..Default::default()
        };

        // All three should create successfully
        assert!(RagAiGenerator::new(openai_config).is_ok());
        assert!(RagAiGenerator::new(ollama_config).is_ok());
        assert!(RagAiGenerator::new(anthropic_config).is_ok());
    }

    #[tokio::test]
    async fn test_generator_engine_update() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "llama2".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            temperature: 0.7,
            max_tokens: 1024,
            ..Default::default()
        };

        let generator = RagAiGenerator::new(config).unwrap();

        // The engine should start with the configured settings
        {
            let engine = generator.engine.read().await;
            let engine_config = engine.config();
            assert!((engine_config.temperature - 0.7).abs() < 0.001);
            assert_eq!(engine_config.max_tokens, 1024);
        }

        // A write lock allows the config to be updated in place
        {
            let mut engine = generator.engine.write().await;
            let mut new_config = engine.config().clone();
            new_config.temperature = 0.5;
            new_config.max_tokens = 2048;
            engine.update_config(new_config);
        }

        // Verify the update took effect
        {
            let engine = generator.engine.read().await;
            let engine_config = engine.config();
            assert!((engine_config.temperature - 0.5).abs() < 0.001);
            assert_eq!(engine_config.max_tokens, 2048);
        }
    }
}