Skip to main content

mockforge_intelligence/
rag_ai_generator.rs

1//! RAG-based AI generator implementation
2//!
3//! This module provides an implementation of the AiGenerator trait
4//! using the RAG engine from mockforge-data.
5//!
6//! Moved from `mockforge_http::rag_ai_generator` under #555 phase 10 —
7//! both foreign deps (`AiResponseConfig` and the `AiGenerator` trait)
8//! were originally in `mockforge-core` but have been promoted to
9//! `mockforge-foundation::ai_response` and `mockforge-openapi::response`
10//! respectively, so the move no longer requires a `mockforge-core` dep
11//! and stays cycle-safe with the Issue #562 cycle-break.
12
13use async_trait::async_trait;
14use mockforge_data::rag::{LlmProvider, RagConfig, RagEngine};
15use mockforge_foundation::ai_response::AiResponseConfig;
16use mockforge_foundation::{Error, Result};
17use mockforge_openapi::response::AiGenerator;
18use serde_json::Value;
19use std::sync::Arc;
20use tracing::{debug, warn};
21
22/// RAG-based AI generator that uses the mockforge-data RAG engine
23pub struct RagAiGenerator {
24    /// The RAG engine instance
25    engine: Arc<tokio::sync::RwLock<RagEngine>>,
26}
27
28impl RagAiGenerator {
29    /// Create a new RAG-based AI generator
30    ///
31    /// # Arguments
32    /// * `rag_config` - Configuration for the RAG engine (provider, model, API key, etc.)
33    ///
34    /// # Returns
35    /// A new RagAiGenerator instance
36    pub fn new(rag_config: RagConfig) -> Result<Self> {
37        debug!("Creating RAG AI generator with provider: {:?}", rag_config.provider);
38
39        // Create the RAG engine
40        let engine = RagEngine::new(rag_config);
41
42        Ok(Self {
43            engine: Arc::new(tokio::sync::RwLock::new(engine)),
44        })
45    }
46
47    /// Create a RAG AI generator from environment variables
48    ///
49    /// Reads configuration from:
50    /// - `MOCKFORGE_AI_PROVIDER`: LLM provider (openai, anthropic, ollama, etc.)
51    /// - `MOCKFORGE_AI_API_KEY`: API key for the LLM provider
52    /// - `MOCKFORGE_AI_MODEL`: Model name (e.g., gpt-4, claude-3-opus)
53    /// - `MOCKFORGE_AI_ENDPOINT`: API endpoint (optional, uses provider default)
54    /// - `MOCKFORGE_AI_TEMPERATURE`: Temperature for generation (optional, default: 0.7)
55    /// - `MOCKFORGE_AI_MAX_TOKENS`: Max tokens for generation (optional, default: 1024)
56    pub fn from_env() -> Result<Self> {
57        let provider =
58            std::env::var("MOCKFORGE_AI_PROVIDER").unwrap_or_else(|_| "openai".to_string());
59
60        let provider = match provider.to_lowercase().as_str() {
61            "openai" => LlmProvider::OpenAI,
62            "anthropic" => LlmProvider::Anthropic,
63            "ollama" => LlmProvider::Ollama,
64            "openai-compatible" => LlmProvider::OpenAICompatible,
65            _ => {
66                warn!("Unknown AI provider '{}', defaulting to OpenAI", provider);
67                LlmProvider::OpenAI
68            }
69        };
70
71        let api_key = std::env::var("MOCKFORGE_AI_API_KEY").ok();
72
73        let model = std::env::var("MOCKFORGE_AI_MODEL").unwrap_or_else(|_| match provider {
74            LlmProvider::OpenAI => "gpt-3.5-turbo".to_string(),
75            LlmProvider::Anthropic => "claude-3-haiku-20240307".to_string(),
76            LlmProvider::Ollama => "llama2".to_string(),
77            LlmProvider::OpenAICompatible => "gpt-3.5-turbo".to_string(),
78        });
79
80        let api_endpoint =
81            std::env::var("MOCKFORGE_AI_ENDPOINT").unwrap_or_else(|_| match provider {
82                LlmProvider::OpenAI => "https://api.openai.com/v1/chat/completions".to_string(),
83                LlmProvider::Anthropic => "https://api.anthropic.com/v1/messages".to_string(),
84                LlmProvider::Ollama => "http://localhost:11434/api/generate".to_string(),
85                LlmProvider::OpenAICompatible => {
86                    "http://localhost:8080/v1/chat/completions".to_string()
87                }
88            });
89
90        let temperature = std::env::var("MOCKFORGE_AI_TEMPERATURE")
91            .ok()
92            .and_then(|s| s.parse::<f64>().ok())
93            .unwrap_or(0.7);
94
95        let max_tokens = std::env::var("MOCKFORGE_AI_MAX_TOKENS")
96            .ok()
97            .and_then(|s| s.parse::<usize>().ok())
98            .unwrap_or(1024);
99
100        let config = RagConfig {
101            provider,
102            api_key,
103            model,
104            api_endpoint,
105            temperature,
106            max_tokens,
107            ..Default::default()
108        };
109
110        debug!("Creating RAG AI generator from environment variables");
111        Self::new(config)
112    }
113}
114
115#[async_trait]
116impl AiGenerator for RagAiGenerator {
117    async fn generate(&self, prompt: &str, config: &AiResponseConfig) -> Result<Value> {
118        debug!("Generating AI response with RAG engine");
119
120        // Lock the engine for generation
121        let mut engine = self.engine.write().await;
122
123        // Update engine config with request-specific settings if needed
124        let mut engine_config = engine.config().clone();
125        engine_config.temperature = config.temperature as f64;
126        engine_config.max_tokens = config.max_tokens;
127
128        // Temporarily update the engine config
129        engine.update_config(engine_config);
130
131        // Generate the response using the RAG engine
132        match engine.generate_text(prompt).await {
133            Ok(response_text) => {
134                debug!("RAG engine generated response ({} chars)", response_text.len());
135
136                // Try to parse the response as JSON
137                match serde_json::from_str::<Value>(&response_text) {
138                    Ok(json_value) => Ok(json_value),
139                    Err(_) => {
140                        // If not valid JSON, try to extract JSON from the response
141                        if let Some(start) = response_text.find('{') {
142                            if let Some(end) = response_text.rfind('}') {
143                                let json_str = &response_text[start..=end];
144                                match serde_json::from_str::<Value>(json_str) {
145                                    Ok(json_value) => Ok(json_value),
146                                    Err(_) => {
147                                        // If still not valid JSON, wrap in an object
148                                        Ok(serde_json::json!({
149                                            "response": response_text,
150                                            "note": "Response was not valid JSON, wrapped in object"
151                                        }))
152                                    }
153                                }
154                            } else {
155                                Ok(serde_json::json!({
156                                    "response": response_text,
157                                    "note": "Response was not valid JSON, wrapped in object"
158                                }))
159                            }
160                        } else {
161                            Ok(serde_json::json!({
162                                "response": response_text,
163                                "note": "Response was not valid JSON, wrapped in object"
164                            }))
165                        }
166                    }
167                }
168            }
169            Err(e) => {
170                warn!("RAG engine generation failed: {}", e);
171                Err(Error::Config {
172                    message: format!("RAG engine generation failed: {}", e),
173                })
174            }
175        }
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    const OLLAMA_URL: &str = "http://localhost:11434/api/generate";
    const OPENAI_URL: &str = "https://api.openai.com/v1/chat/completions";
    const ANTHROPIC_URL: &str = "https://api.anthropic.com/v1/messages";
    const COMPATIBLE_URL: &str = "http://localhost:8080/v1/chat/completions";

    /// Build a `RagConfig` with the given connection settings; every other
    /// field comes from `RagConfig::default()`.
    fn connection(
        provider: LlmProvider,
        api_key: Option<&str>,
        model: &str,
        api_endpoint: &str,
    ) -> RagConfig {
        RagConfig {
            provider,
            api_key: api_key.map(String::from),
            model: model.to_string(),
            api_endpoint: api_endpoint.to_string(),
            ..Default::default()
        }
    }

    /// Override the sampling settings of an existing config.
    fn with_sampling(base: RagConfig, temperature: f64, max_tokens: usize) -> RagConfig {
        RagConfig {
            temperature,
            max_tokens,
            ..base
        }
    }

    /// Default config with only the provider replaced.
    fn provider_only(provider: LlmProvider) -> RagConfig {
        RagConfig {
            provider,
            ..Default::default()
        }
    }

    // ---------- RagAiGenerator creation ----------

    #[test]
    fn test_rag_generator_creation() {
        let cfg = connection(LlmProvider::Ollama, None, "llama2", OLLAMA_URL);
        assert!(RagAiGenerator::new(cfg).is_ok());
    }

    #[test]
    fn test_rag_generator_creation_openai() {
        let cfg = connection(LlmProvider::OpenAI, Some("test-api-key"), "gpt-4", OPENAI_URL);
        assert!(RagAiGenerator::new(cfg).is_ok());
    }

    #[test]
    fn test_rag_generator_creation_anthropic() {
        let cfg = connection(
            LlmProvider::Anthropic,
            Some("test-api-key"),
            "claude-3-opus",
            ANTHROPIC_URL,
        );
        assert!(RagAiGenerator::new(cfg).is_ok());
    }

    #[test]
    fn test_rag_generator_creation_openai_compatible() {
        let cfg = connection(LlmProvider::OpenAICompatible, None, "local-model", COMPATIBLE_URL);
        assert!(RagAiGenerator::new(cfg).is_ok());
    }

    #[test]
    fn test_rag_generator_creation_with_custom_settings() {
        let cfg = with_sampling(
            connection(LlmProvider::Ollama, None, "codellama", OLLAMA_URL),
            0.5,
            2048,
        );
        assert!(RagAiGenerator::new(cfg).is_ok());
    }

    #[test]
    fn test_rag_generator_creation_with_low_temperature() {
        let cfg = with_sampling(
            connection(LlmProvider::OpenAI, Some("test-key"), "gpt-3.5-turbo", OPENAI_URL),
            0.0,
            512,
        );
        assert!(RagAiGenerator::new(cfg).is_ok());
    }

    #[test]
    fn test_rag_generator_creation_with_high_temperature() {
        let cfg = with_sampling(
            connection(LlmProvider::OpenAI, Some("test-key"), "gpt-4", OPENAI_URL),
            1.0,
            4096,
        );
        assert!(RagAiGenerator::new(cfg).is_ok());
    }

    // ---------- RagConfig ----------

    #[test]
    fn test_rag_config_default() {
        // The defaults must at least be usable values.
        let defaults = RagConfig::default();
        assert!(defaults.temperature >= 0.0);
        assert!(defaults.max_tokens > 0);
    }

    #[test]
    fn test_rag_config_clone() {
        let original = with_sampling(
            connection(LlmProvider::Ollama, Some("secret"), "llama2", OLLAMA_URL),
            0.7,
            1024,
        );

        let copy = original.clone();
        assert_eq!(copy.model, original.model);
        assert_eq!(copy.api_key, original.api_key);
    }

    // ---------- LlmProvider ----------

    #[test]
    fn test_llm_provider_openai() {
        let cfg = provider_only(LlmProvider::OpenAI);
        assert!(matches!(cfg.provider, LlmProvider::OpenAI));
    }

    #[test]
    fn test_llm_provider_anthropic() {
        let cfg = provider_only(LlmProvider::Anthropic);
        assert!(matches!(cfg.provider, LlmProvider::Anthropic));
    }

    #[test]
    fn test_llm_provider_ollama() {
        let cfg = provider_only(LlmProvider::Ollama);
        assert!(matches!(cfg.provider, LlmProvider::Ollama));
    }

    #[test]
    fn test_llm_provider_openai_compatible() {
        let cfg = provider_only(LlmProvider::OpenAICompatible);
        assert!(matches!(cfg.provider, LlmProvider::OpenAICompatible));
    }

    // ---------- Async generator tests ----------

    #[tokio::test]
    async fn test_generate_fallback_to_json() {
        // Exercising the non-JSON wrapping path would need a mocked RAG
        // engine; without a real LLM this only checks that construction
        // succeeds.
        let cfg = connection(LlmProvider::Ollama, None, "test-model", OLLAMA_URL);
        assert!(RagAiGenerator::new(cfg).is_ok());
    }

    #[tokio::test]
    async fn test_generator_engine_access() {
        let cfg = with_sampling(
            connection(LlmProvider::Ollama, None, "llama2", OLLAMA_URL),
            0.8,
            512,
        );
        let generator = RagAiGenerator::new(cfg).unwrap();

        // The engine sits behind Arc<RwLock>; a read guard must expose its config.
        let guard = generator.engine.read().await;
        assert_eq!(guard.config().model, "llama2");
    }

    #[tokio::test]
    async fn test_generator_can_be_cloned_via_arc() {
        let cfg = connection(LlmProvider::OpenAI, Some("test"), "gpt-3.5-turbo", OPENAI_URL);
        let generator = RagAiGenerator::new(cfg).unwrap();

        // A second handle to the Arc-wrapped engine raises the strong count.
        let second_handle = Arc::clone(&generator.engine);
        assert!(Arc::strong_count(&second_handle) >= 2);
    }

    // ---------- AiResponseConfig ----------

    #[test]
    fn test_ai_response_config_with_generator() {
        // An AiResponseConfig carrying the settings the generator consumes.
        let request = AiResponseConfig {
            temperature: 0.7,
            max_tokens: 1024,
            ..Default::default()
        };

        assert!((request.temperature - 0.7).abs() < 0.001);
        assert_eq!(request.max_tokens, 1024);
    }

    #[test]
    fn test_ai_response_config_low_temp() {
        let request = AiResponseConfig {
            temperature: 0.0,
            max_tokens: 256,
            ..Default::default()
        };

        assert!((request.temperature - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_ai_response_config_high_tokens() {
        let request = AiResponseConfig {
            temperature: 0.5,
            max_tokens: 8192,
            ..Default::default()
        };

        assert_eq!(request.max_tokens, 8192);
    }

    // ---------- Edge cases ----------

    #[test]
    fn test_generator_with_empty_model_name() {
        // An empty model name is accepted at construction (validation happens later).
        let cfg = connection(LlmProvider::Ollama, None, "", OLLAMA_URL);
        assert!(RagAiGenerator::new(cfg).is_ok());
    }

    #[test]
    fn test_generator_with_empty_endpoint() {
        // An empty endpoint is accepted at construction (validation happens at request time).
        let cfg = connection(LlmProvider::Ollama, None, "llama2", "");
        assert!(RagAiGenerator::new(cfg).is_ok());
    }

    #[test]
    fn test_generator_with_no_api_key_openai() {
        // A missing API key does not block construction; failure happens when making requests.
        let cfg = connection(LlmProvider::OpenAI, None, "gpt-4", OPENAI_URL);
        assert!(RagAiGenerator::new(cfg).is_ok());
    }

    // ---------- Integration-style tests ----------

    #[tokio::test]
    async fn test_multiple_generators_different_providers() {
        let openai = connection(LlmProvider::OpenAI, Some("test-key"), "gpt-4", OPENAI_URL);
        let ollama = connection(LlmProvider::Ollama, None, "llama2", OLLAMA_URL);
        let anthropic = connection(
            LlmProvider::Anthropic,
            Some("test-key"),
            "claude-3-haiku-20240307",
            ANTHROPIC_URL,
        );

        // Each provider's generator must construct independently.
        assert!(RagAiGenerator::new(openai).is_ok());
        assert!(RagAiGenerator::new(ollama).is_ok());
        assert!(RagAiGenerator::new(anthropic).is_ok());
    }

    #[tokio::test]
    async fn test_generator_engine_update() {
        let cfg = with_sampling(
            connection(LlmProvider::Ollama, None, "llama2", OLLAMA_URL),
            0.7,
            1024,
        );
        let generator = RagAiGenerator::new(cfg).unwrap();

        // Initial state: the engine reports the settings it was built with.
        {
            let guard = generator.engine.read().await;
            let current = guard.config();
            assert!((current.temperature - 0.7).abs() < 0.001);
            assert_eq!(current.max_tokens, 1024);
        }

        // Swap in new sampling settings under a write lock.
        {
            let mut guard = generator.engine.write().await;
            let mut replacement = guard.config().clone();
            replacement.temperature = 0.5;
            replacement.max_tokens = 2048;
            guard.update_config(replacement);
        }

        // A fresh read must observe the updated settings.
        {
            let guard = generator.engine.read().await;
            let current = guard.config();
            assert!((current.temperature - 0.5).abs() < 0.001);
            assert_eq!(current.max_tokens, 2048);
        }
    }
}