mockforge_http/
rag_ai_generator.rs

1//! RAG-based AI generator implementation
2//!
3//! This module provides an implementation of the AiGenerator trait
4//! using the RAG engine from mockforge-data.
5
6use async_trait::async_trait;
7use mockforge_core::{ai_response::AiResponseConfig, openapi::response::AiGenerator, Result};
8use mockforge_data::rag::{LlmProvider, RagConfig, RagEngine};
9use serde_json::Value;
10use std::sync::Arc;
11use tracing::{debug, warn};
12
/// RAG-based AI generator that uses the mockforge-data RAG engine
///
/// Implements the `AiGenerator` trait (see the impl below) by delegating
/// text generation to a shared `RagEngine`.
pub struct RagAiGenerator {
    /// The RAG engine instance.
    ///
    /// Wrapped in `Arc` so the generator can be shared cheaply across tasks,
    /// and in a `tokio::sync::RwLock` because `generate` takes a write lock
    /// to apply per-request config (temperature, max_tokens) before calling
    /// the engine.
    engine: Arc<tokio::sync::RwLock<RagEngine>>,
}
18
19impl RagAiGenerator {
20    /// Create a new RAG-based AI generator
21    ///
22    /// # Arguments
23    /// * `rag_config` - Configuration for the RAG engine (provider, model, API key, etc.)
24    ///
25    /// # Returns
26    /// A new RagAiGenerator instance
27    pub fn new(rag_config: RagConfig) -> Result<Self> {
28        debug!("Creating RAG AI generator with provider: {:?}", rag_config.provider);
29
30        // Create the RAG engine
31        let engine = RagEngine::new(rag_config);
32
33        Ok(Self {
34            engine: Arc::new(tokio::sync::RwLock::new(engine)),
35        })
36    }
37
38    /// Create a RAG AI generator from environment variables
39    ///
40    /// Reads configuration from:
41    /// - `MOCKFORGE_AI_PROVIDER`: LLM provider (openai, anthropic, ollama, etc.)
42    /// - `MOCKFORGE_AI_API_KEY`: API key for the LLM provider
43    /// - `MOCKFORGE_AI_MODEL`: Model name (e.g., gpt-4, claude-3-opus)
44    /// - `MOCKFORGE_AI_ENDPOINT`: API endpoint (optional, uses provider default)
45    /// - `MOCKFORGE_AI_TEMPERATURE`: Temperature for generation (optional, default: 0.7)
46    /// - `MOCKFORGE_AI_MAX_TOKENS`: Max tokens for generation (optional, default: 1024)
47    pub fn from_env() -> Result<Self> {
48        let provider =
49            std::env::var("MOCKFORGE_AI_PROVIDER").unwrap_or_else(|_| "openai".to_string());
50
51        let provider = match provider.to_lowercase().as_str() {
52            "openai" => LlmProvider::OpenAI,
53            "anthropic" => LlmProvider::Anthropic,
54            "ollama" => LlmProvider::Ollama,
55            "openai-compatible" => LlmProvider::OpenAICompatible,
56            _ => {
57                warn!("Unknown AI provider '{}', defaulting to OpenAI", provider);
58                LlmProvider::OpenAI
59            }
60        };
61
62        let api_key = std::env::var("MOCKFORGE_AI_API_KEY").ok();
63
64        let model = std::env::var("MOCKFORGE_AI_MODEL").unwrap_or_else(|_| match provider {
65            LlmProvider::OpenAI => "gpt-3.5-turbo".to_string(),
66            LlmProvider::Anthropic => "claude-3-haiku-20240307".to_string(),
67            LlmProvider::Ollama => "llama2".to_string(),
68            LlmProvider::OpenAICompatible => "gpt-3.5-turbo".to_string(),
69        });
70
71        let api_endpoint =
72            std::env::var("MOCKFORGE_AI_ENDPOINT").unwrap_or_else(|_| match provider {
73                LlmProvider::OpenAI => "https://api.openai.com/v1/chat/completions".to_string(),
74                LlmProvider::Anthropic => "https://api.anthropic.com/v1/messages".to_string(),
75                LlmProvider::Ollama => "http://localhost:11434/api/generate".to_string(),
76                LlmProvider::OpenAICompatible => {
77                    "http://localhost:8080/v1/chat/completions".to_string()
78                }
79            });
80
81        let temperature = std::env::var("MOCKFORGE_AI_TEMPERATURE")
82            .ok()
83            .and_then(|s| s.parse::<f64>().ok())
84            .unwrap_or(0.7);
85
86        let max_tokens = std::env::var("MOCKFORGE_AI_MAX_TOKENS")
87            .ok()
88            .and_then(|s| s.parse::<usize>().ok())
89            .unwrap_or(1024);
90
91        let config = RagConfig {
92            provider,
93            api_key,
94            model,
95            api_endpoint,
96            temperature,
97            max_tokens,
98            ..Default::default()
99        };
100
101        debug!("Creating RAG AI generator from environment variables");
102        Self::new(config)
103    }
104}
105
106#[async_trait]
107impl AiGenerator for RagAiGenerator {
108    async fn generate(&self, prompt: &str, config: &AiResponseConfig) -> Result<Value> {
109        debug!("Generating AI response with RAG engine");
110
111        // Lock the engine for generation
112        let mut engine = self.engine.write().await;
113
114        // Update engine config with request-specific settings if needed
115        let mut engine_config = engine.config().clone();
116        engine_config.temperature = config.temperature as f64;
117        engine_config.max_tokens = config.max_tokens;
118
119        // Temporarily update the engine config
120        engine.update_config(engine_config);
121
122        // Generate the response using the RAG engine
123        match engine.generate_text(prompt).await {
124            Ok(response_text) => {
125                debug!("RAG engine generated response ({} chars)", response_text.len());
126
127                // Try to parse the response as JSON
128                match serde_json::from_str::<Value>(&response_text) {
129                    Ok(json_value) => Ok(json_value),
130                    Err(_) => {
131                        // If not valid JSON, try to extract JSON from the response
132                        if let Some(start) = response_text.find('{') {
133                            if let Some(end) = response_text.rfind('}') {
134                                let json_str = &response_text[start..=end];
135                                match serde_json::from_str::<Value>(json_str) {
136                                    Ok(json_value) => Ok(json_value),
137                                    Err(_) => {
138                                        // If still not valid JSON, wrap in an object
139                                        Ok(serde_json::json!({
140                                            "response": response_text,
141                                            "note": "Response was not valid JSON, wrapped in object"
142                                        }))
143                                    }
144                                }
145                            } else {
146                                Ok(serde_json::json!({
147                                    "response": response_text,
148                                    "note": "Response was not valid JSON, wrapped in object"
149                                }))
150                            }
151                        } else {
152                            Ok(serde_json::json!({
153                                "response": response_text,
154                                "note": "Response was not valid JSON, wrapped in object"
155                            }))
156                        }
157                    }
158                }
159            }
160            Err(e) => {
161                warn!("RAG engine generation failed: {}", e);
162                Err(e)
163            }
164        }
165    }
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing a generator from an explicit config should succeed.
    #[test]
    fn test_rag_generator_creation() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "llama2".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    /// Construction with an arbitrary model name should also succeed.
    ///
    /// NOTE(review): exercising the JSON-fallback path of `generate` would
    /// require mocking the RAG engine; this test only covers construction.
    /// A plain `#[test]` suffices — the previous `#[tokio::test] async fn`
    /// spun up a runtime without ever awaiting anything.
    #[test]
    fn test_generate_fallback_to_json() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "test-model".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            ..Default::default()
        };

        let generator = RagAiGenerator::new(config);
        assert!(generator.is_ok());
    }
}