mockforge_http/
rag_ai_generator.rs

1//! RAG-based AI generator implementation
2//!
3//! This module provides an implementation of the AiGenerator trait
4//! using the RAG engine from mockforge-data.
5
6use async_trait::async_trait;
7use axum::extract::State;
8use axum::http::StatusCode;
9use axum::response::Json;
10use mockforge_core::{ai_response::AiResponseConfig, openapi::response::AiGenerator, Result};
11use mockforge_data::rag::{LlmProvider, RagConfig, RagEngine};
12use serde_json::Value;
13use std::sync::Arc;
14use tracing::{debug, warn};
15
/// RAG-based AI generator that uses the mockforge-data RAG engine
///
/// Holds the engine behind an `Arc<RwLock<_>>` so the generator can be
/// shared across request handlers; a write lock is required because
/// generation mutates the engine's config (see `AiGenerator::generate`).
pub struct RagAiGenerator {
    /// The RAG engine instance
    engine: Arc<tokio::sync::RwLock<RagEngine>>,
}
21
22impl RagAiGenerator {
23    /// Create a new RAG-based AI generator
24    ///
25    /// # Arguments
26    /// * `rag_config` - Configuration for the RAG engine (provider, model, API key, etc.)
27    ///
28    /// # Returns
29    /// A new RagAiGenerator instance
30    pub fn new(rag_config: RagConfig) -> Result<Self> {
31        debug!("Creating RAG AI generator with provider: {:?}", rag_config.provider);
32
33        // Create the RAG engine
34        let engine = RagEngine::new(rag_config);
35
36        Ok(Self {
37            engine: Arc::new(tokio::sync::RwLock::new(engine)),
38        })
39    }
40
41    /// Create a RAG AI generator from environment variables
42    ///
43    /// Reads configuration from:
44    /// - `MOCKFORGE_AI_PROVIDER`: LLM provider (openai, anthropic, ollama, etc.)
45    /// - `MOCKFORGE_AI_API_KEY`: API key for the LLM provider
46    /// - `MOCKFORGE_AI_MODEL`: Model name (e.g., gpt-4, claude-3-opus)
47    /// - `MOCKFORGE_AI_ENDPOINT`: API endpoint (optional, uses provider default)
48    /// - `MOCKFORGE_AI_TEMPERATURE`: Temperature for generation (optional, default: 0.7)
49    /// - `MOCKFORGE_AI_MAX_TOKENS`: Max tokens for generation (optional, default: 1024)
50    pub fn from_env() -> Result<Self> {
51        let provider =
52            std::env::var("MOCKFORGE_AI_PROVIDER").unwrap_or_else(|_| "openai".to_string());
53
54        let provider = match provider.to_lowercase().as_str() {
55            "openai" => LlmProvider::OpenAI,
56            "anthropic" => LlmProvider::Anthropic,
57            "ollama" => LlmProvider::Ollama,
58            "openai-compatible" => LlmProvider::OpenAICompatible,
59            _ => {
60                warn!("Unknown AI provider '{}', defaulting to OpenAI", provider);
61                LlmProvider::OpenAI
62            }
63        };
64
65        let api_key = std::env::var("MOCKFORGE_AI_API_KEY").ok();
66
67        let model = std::env::var("MOCKFORGE_AI_MODEL").unwrap_or_else(|_| match provider {
68            LlmProvider::OpenAI => "gpt-3.5-turbo".to_string(),
69            LlmProvider::Anthropic => "claude-3-haiku-20240307".to_string(),
70            LlmProvider::Ollama => "llama2".to_string(),
71            LlmProvider::OpenAICompatible => "gpt-3.5-turbo".to_string(),
72        });
73
74        let api_endpoint =
75            std::env::var("MOCKFORGE_AI_ENDPOINT").unwrap_or_else(|_| match provider {
76                LlmProvider::OpenAI => "https://api.openai.com/v1/chat/completions".to_string(),
77                LlmProvider::Anthropic => "https://api.anthropic.com/v1/messages".to_string(),
78                LlmProvider::Ollama => "http://localhost:11434/api/generate".to_string(),
79                LlmProvider::OpenAICompatible => {
80                    "http://localhost:8080/v1/chat/completions".to_string()
81                }
82            });
83
84        let temperature = std::env::var("MOCKFORGE_AI_TEMPERATURE")
85            .ok()
86            .and_then(|s| s.parse::<f64>().ok())
87            .unwrap_or(0.7);
88
89        let max_tokens = std::env::var("MOCKFORGE_AI_MAX_TOKENS")
90            .ok()
91            .and_then(|s| s.parse::<usize>().ok())
92            .unwrap_or(1024);
93
94        let config = RagConfig {
95            provider,
96            api_key,
97            model,
98            api_endpoint,
99            temperature,
100            max_tokens,
101            ..Default::default()
102        };
103
104        debug!("Creating RAG AI generator from environment variables");
105        Self::new(config)
106    }
107}
108
109#[async_trait]
110impl AiGenerator for RagAiGenerator {
111    async fn generate(&self, prompt: &str, config: &AiResponseConfig) -> Result<Value> {
112        debug!("Generating AI response with RAG engine");
113
114        // Lock the engine for generation
115        let mut engine = self.engine.write().await;
116
117        // Update engine config with request-specific settings if needed
118        let mut engine_config = engine.config().clone();
119        engine_config.temperature = config.temperature as f64;
120        engine_config.max_tokens = config.max_tokens;
121
122        // Temporarily update the engine config
123        engine.update_config(engine_config);
124
125        // Generate the response using the RAG engine
126        match engine.generate_text(prompt).await {
127            Ok(response_text) => {
128                debug!("RAG engine generated response ({} chars)", response_text.len());
129
130                // Try to parse the response as JSON
131                match serde_json::from_str::<Value>(&response_text) {
132                    Ok(json_value) => Ok(json_value),
133                    Err(_) => {
134                        // If not valid JSON, try to extract JSON from the response
135                        if let Some(start) = response_text.find('{') {
136                            if let Some(end) = response_text.rfind('}') {
137                                let json_str = &response_text[start..=end];
138                                match serde_json::from_str::<Value>(json_str) {
139                                    Ok(json_value) => Ok(json_value),
140                                    Err(_) => {
141                                        // If still not valid JSON, wrap in an object
142                                        Ok(serde_json::json!({
143                                            "response": response_text,
144                                            "note": "Response was not valid JSON, wrapped in object"
145                                        }))
146                                    }
147                                }
148                            } else {
149                                Ok(serde_json::json!({
150                                    "response": response_text,
151                                    "note": "Response was not valid JSON, wrapped in object"
152                                }))
153                            }
154                        } else {
155                            Ok(serde_json::json!({
156                                "response": response_text,
157                                "note": "Response was not valid JSON, wrapped in object"
158                            }))
159                        }
160                    }
161                }
162            }
163            Err(e) => {
164                warn!("RAG engine generation failed: {}", e);
165                Err(mockforge_core::Error::Config {
166                    message: format!("RAG engine generation failed: {}", e),
167                })
168            }
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` should accept a minimal Ollama config (no API key required).
    #[test]
    fn test_rag_generator_creation() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "llama2".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            ..Default::default()
        };

        let result = RagAiGenerator::new(config);
        assert!(result.is_ok());
    }

    /// Generator construction succeeds without contacting an LLM.
    ///
    /// Exercising the JSON-fallback path of `generate` would require a
    /// mockable RAG engine; until one exists this only checks construction.
    /// No awaits are performed, so a plain `#[test]` suffices — the
    /// previous `#[tokio::test] async` spun up a runtime for nothing.
    #[test]
    fn test_generate_fallback_to_json() {
        let config = RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: "test-model".to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            ..Default::default()
        };

        let generator = RagAiGenerator::new(config);
        assert!(generator.is_ok());
    }
}