mockforge_http/
rag_ai_generator.rs1use async_trait::async_trait;
7use axum::extract::State;
8use axum::http::StatusCode;
9use axum::response::Json;
10use mockforge_core::{ai_response::AiResponseConfig, openapi::response::AiGenerator, Result};
11use mockforge_data::rag::{LlmProvider, RagConfig, RagEngine};
12use serde_json::Value;
13use std::sync::Arc;
14use tracing::{debug, warn};
15
16pub struct RagAiGenerator {
18 engine: Arc<tokio::sync::RwLock<RagEngine>>,
20}
21
22impl RagAiGenerator {
23 pub fn new(rag_config: RagConfig) -> Result<Self> {
31 debug!("Creating RAG AI generator with provider: {:?}", rag_config.provider);
32
33 let engine = RagEngine::new(rag_config);
35
36 Ok(Self {
37 engine: Arc::new(tokio::sync::RwLock::new(engine)),
38 })
39 }
40
41 pub fn from_env() -> Result<Self> {
51 let provider =
52 std::env::var("MOCKFORGE_AI_PROVIDER").unwrap_or_else(|_| "openai".to_string());
53
54 let provider = match provider.to_lowercase().as_str() {
55 "openai" => LlmProvider::OpenAI,
56 "anthropic" => LlmProvider::Anthropic,
57 "ollama" => LlmProvider::Ollama,
58 "openai-compatible" => LlmProvider::OpenAICompatible,
59 _ => {
60 warn!("Unknown AI provider '{}', defaulting to OpenAI", provider);
61 LlmProvider::OpenAI
62 }
63 };
64
65 let api_key = std::env::var("MOCKFORGE_AI_API_KEY").ok();
66
67 let model = std::env::var("MOCKFORGE_AI_MODEL").unwrap_or_else(|_| match provider {
68 LlmProvider::OpenAI => "gpt-3.5-turbo".to_string(),
69 LlmProvider::Anthropic => "claude-3-haiku-20240307".to_string(),
70 LlmProvider::Ollama => "llama2".to_string(),
71 LlmProvider::OpenAICompatible => "gpt-3.5-turbo".to_string(),
72 });
73
74 let api_endpoint =
75 std::env::var("MOCKFORGE_AI_ENDPOINT").unwrap_or_else(|_| match provider {
76 LlmProvider::OpenAI => "https://api.openai.com/v1/chat/completions".to_string(),
77 LlmProvider::Anthropic => "https://api.anthropic.com/v1/messages".to_string(),
78 LlmProvider::Ollama => "http://localhost:11434/api/generate".to_string(),
79 LlmProvider::OpenAICompatible => {
80 "http://localhost:8080/v1/chat/completions".to_string()
81 }
82 });
83
84 let temperature = std::env::var("MOCKFORGE_AI_TEMPERATURE")
85 .ok()
86 .and_then(|s| s.parse::<f64>().ok())
87 .unwrap_or(0.7);
88
89 let max_tokens = std::env::var("MOCKFORGE_AI_MAX_TOKENS")
90 .ok()
91 .and_then(|s| s.parse::<usize>().ok())
92 .unwrap_or(1024);
93
94 let config = RagConfig {
95 provider,
96 api_key,
97 model,
98 api_endpoint,
99 temperature,
100 max_tokens,
101 ..Default::default()
102 };
103
104 debug!("Creating RAG AI generator from environment variables");
105 Self::new(config)
106 }
107}
108
109#[async_trait]
110impl AiGenerator for RagAiGenerator {
111 async fn generate(&self, prompt: &str, config: &AiResponseConfig) -> Result<Value> {
112 debug!("Generating AI response with RAG engine");
113
114 let mut engine = self.engine.write().await;
116
117 let mut engine_config = engine.config().clone();
119 engine_config.temperature = config.temperature as f64;
120 engine_config.max_tokens = config.max_tokens;
121
122 engine.update_config(engine_config);
124
125 match engine.generate_text(prompt).await {
127 Ok(response_text) => {
128 debug!("RAG engine generated response ({} chars)", response_text.len());
129
130 match serde_json::from_str::<Value>(&response_text) {
132 Ok(json_value) => Ok(json_value),
133 Err(_) => {
134 if let Some(start) = response_text.find('{') {
136 if let Some(end) = response_text.rfind('}') {
137 let json_str = &response_text[start..=end];
138 match serde_json::from_str::<Value>(json_str) {
139 Ok(json_value) => Ok(json_value),
140 Err(_) => {
141 Ok(serde_json::json!({
143 "response": response_text,
144 "note": "Response was not valid JSON, wrapped in object"
145 }))
146 }
147 }
148 } else {
149 Ok(serde_json::json!({
150 "response": response_text,
151 "note": "Response was not valid JSON, wrapped in object"
152 }))
153 }
154 } else {
155 Ok(serde_json::json!({
156 "response": response_text,
157 "note": "Response was not valid JSON, wrapped in object"
158 }))
159 }
160 }
161 }
162 }
163 Err(e) => {
164 warn!("RAG engine generation failed: {}", e);
165 Err(mockforge_core::Error::Config {
166 message: format!("RAG engine generation failed: {}", e),
167 })
168 }
169 }
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an Ollama configuration for `model` using the local generate endpoint.
    fn ollama_config(model: &str) -> RagConfig {
        RagConfig {
            provider: LlmProvider::Ollama,
            api_key: None,
            model: model.to_string(),
            api_endpoint: "http://localhost:11434/api/generate".to_string(),
            ..Default::default()
        }
    }

    /// Constructing a generator from an explicit configuration succeeds.
    #[test]
    fn test_rag_generator_creation() {
        let outcome = RagAiGenerator::new(ollama_config("llama2"));
        assert!(outcome.is_ok());
    }

    /// Only construction is checked here; no generation request is issued.
    #[tokio::test]
    async fn test_generate_fallback_to_json() {
        let outcome = RagAiGenerator::new(ollama_config("test-model"));
        assert!(outcome.is_ok());
    }
}