use async_trait::async_trait;
use mockforge_core::{ai_response::AiResponseConfig, openapi::response::AiGenerator, Result};
use mockforge_data::rag::{LlmProvider, RagConfig, RagEngine};
use serde_json::Value;
use std::sync::Arc;
use tracing::{debug, warn};

13pub struct RagAiGenerator {
15 engine: Arc<tokio::sync::RwLock<RagEngine>>,
17}
18
19impl RagAiGenerator {
20 pub fn new(rag_config: RagConfig) -> Result<Self> {
28 debug!("Creating RAG AI generator with provider: {:?}", rag_config.provider);
29
30 let engine = RagEngine::new(rag_config);
32
33 Ok(Self {
34 engine: Arc::new(tokio::sync::RwLock::new(engine)),
35 })
36 }
37
38 pub fn from_env() -> Result<Self> {
48 let provider =
49 std::env::var("MOCKFORGE_AI_PROVIDER").unwrap_or_else(|_| "openai".to_string());
50
51 let provider = match provider.to_lowercase().as_str() {
52 "openai" => LlmProvider::OpenAI,
53 "anthropic" => LlmProvider::Anthropic,
54 "ollama" => LlmProvider::Ollama,
55 "openai-compatible" => LlmProvider::OpenAICompatible,
56 _ => {
57 warn!("Unknown AI provider '{}', defaulting to OpenAI", provider);
58 LlmProvider::OpenAI
59 }
60 };
61
62 let api_key = std::env::var("MOCKFORGE_AI_API_KEY").ok();
63
64 let model = std::env::var("MOCKFORGE_AI_MODEL").unwrap_or_else(|_| match provider {
65 LlmProvider::OpenAI => "gpt-3.5-turbo".to_string(),
66 LlmProvider::Anthropic => "claude-3-haiku-20240307".to_string(),
67 LlmProvider::Ollama => "llama2".to_string(),
68 LlmProvider::OpenAICompatible => "gpt-3.5-turbo".to_string(),
69 });
70
71 let api_endpoint =
72 std::env::var("MOCKFORGE_AI_ENDPOINT").unwrap_or_else(|_| match provider {
73 LlmProvider::OpenAI => "https://api.openai.com/v1/chat/completions".to_string(),
74 LlmProvider::Anthropic => "https://api.anthropic.com/v1/messages".to_string(),
75 LlmProvider::Ollama => "http://localhost:11434/api/generate".to_string(),
76 LlmProvider::OpenAICompatible => {
77 "http://localhost:8080/v1/chat/completions".to_string()
78 }
79 });
80
81 let temperature = std::env::var("MOCKFORGE_AI_TEMPERATURE")
82 .ok()
83 .and_then(|s| s.parse::<f64>().ok())
84 .unwrap_or(0.7);
85
86 let max_tokens = std::env::var("MOCKFORGE_AI_MAX_TOKENS")
87 .ok()
88 .and_then(|s| s.parse::<usize>().ok())
89 .unwrap_or(1024);
90
91 let config = RagConfig {
92 provider,
93 api_key,
94 model,
95 api_endpoint,
96 temperature,
97 max_tokens,
98 ..Default::default()
99 };
100
101 debug!("Creating RAG AI generator from environment variables");
102 Self::new(config)
103 }
104}
105
106#[async_trait]
107impl AiGenerator for RagAiGenerator {
108 async fn generate(&self, prompt: &str, config: &AiResponseConfig) -> Result<Value> {
109 debug!("Generating AI response with RAG engine");
110
111 let mut engine = self.engine.write().await;
113
114 let mut engine_config = engine.config().clone();
116 engine_config.temperature = config.temperature as f64;
117 engine_config.max_tokens = config.max_tokens;
118
119 engine.update_config(engine_config);
121
122 match engine.generate_text(prompt).await {
124 Ok(response_text) => {
125 debug!("RAG engine generated response ({} chars)", response_text.len());
126
127 match serde_json::from_str::<Value>(&response_text) {
129 Ok(json_value) => Ok(json_value),
130 Err(_) => {
131 if let Some(start) = response_text.find('{') {
133 if let Some(end) = response_text.rfind('}') {
134 let json_str = &response_text[start..=end];
135 match serde_json::from_str::<Value>(json_str) {
136 Ok(json_value) => Ok(json_value),
137 Err(_) => {
138 Ok(serde_json::json!({
140 "response": response_text,
141 "note": "Response was not valid JSON, wrapped in object"
142 }))
143 }
144 }
145 } else {
146 Ok(serde_json::json!({
147 "response": response_text,
148 "note": "Response was not valid JSON, wrapped in object"
149 }))
150 }
151 } else {
152 Ok(serde_json::json!({
153 "response": response_text,
154 "note": "Response was not valid JSON, wrapped in object"
155 }))
156 }
157 }
158 }
159 }
160 Err(e) => {
161 warn!("RAG engine generation failed: {}", e);
162 Err(e)
163 }
164 }
165 }
166}
167
168#[cfg(test)]
169mod tests {
170 use super::*;
171
172 #[test]
173 fn test_rag_generator_creation() {
174 let config = RagConfig {
175 provider: LlmProvider::Ollama,
176 api_key: None,
177 model: "llama2".to_string(),
178 api_endpoint: "http://localhost:11434/api/generate".to_string(),
179 ..Default::default()
180 };
181
182 let result = RagAiGenerator::new(config);
183 assert!(result.is_ok());
184 }
185
186 #[tokio::test]
187 async fn test_generate_fallback_to_json() {
188 let config = RagConfig {
192 provider: LlmProvider::Ollama,
193 api_key: None,
194 model: "test-model".to_string(),
195 api_endpoint: "http://localhost:11434/api/generate".to_string(),
196 ..Default::default()
197 };
198
199 let generator = RagAiGenerator::new(config);
202 assert!(generator.is_ok());
203 }
204}