// mockforge_data/rag.rs

//! RAG (Retrieval-Augmented Generation) for enhanced data synthesis
//!
//! This module has been refactored into sub-modules for better organization:
//! - config: RAG configuration and settings management
//! - engine: Core RAG engine and retrieval logic
//! - providers: LLM and embedding provider integrations
//! - storage: Document storage and vector indexing
//! - utils: Utility functions and helpers for RAG operations
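//!
//! # Example
//!
//! A minimal sketch of the intended flow, assuming the module is exposed as
//! `mockforge_data::rag` (the exact public path may differ):
//!
//! ```rust,ignore
//! use std::collections::HashMap;
//! use mockforge_data::rag::{RagConfig, RagEngine};
//!
//! // Build an engine with default settings (OpenAI endpoints, no API key).
//! let mut engine = RagEngine::new(RagConfig::default());
//!
//! // Index a document; embeddings are computed later by compute_embeddings().
//! engine
//!     .add_document("Customers have a name, email, and signup date.".to_string(), HashMap::new())
//!     .expect("indexing failed");
//!
//! // Keyword search works without embeddings or network access.
//! let hits = engine.keyword_search("email", 5);
//! assert_eq!(hits.len(), 1);
//! ```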

// Re-export sub-modules for backward compatibility
pub mod config;
pub mod engine;
pub mod providers;
pub mod storage;
pub mod utils;

// Re-export commonly used types
pub use config::*;
pub use providers::*;
pub use utils::*;

// Re-export engine and storage types with explicit names to avoid conflicts
pub use engine::StorageStats as EngineStorageStats;
pub use storage::StorageStats as StorageStorageStats;

// Legacy imports for compatibility
use crate::{schema::SchemaDefinition, DataConfig};
use mockforge_core::Result;
use reqwest::{Client, ClientBuilder};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::cmp::Ordering;
use std::collections::HashMap;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, warn};

/// Supported LLM providers
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    /// OpenAI GPT models
    OpenAI,
    /// Anthropic Claude models
    Anthropic,
    /// Generic OpenAI-compatible API
    OpenAICompatible,
    /// Local Ollama instance
    Ollama,
}

/// Supported embedding providers
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProvider {
    /// OpenAI text-embedding-ada-002
    OpenAI,
    /// Generic OpenAI-compatible embeddings API
    OpenAICompatible,
}

/// RAG configuration
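///
/// Defaults target OpenAI; individual fields can be overridden with struct
/// update syntax. A sketch for a local Ollama instance (the endpoint and
/// model name below are illustrative, not shipped defaults):
///
/// ```rust,ignore
/// let config = RagConfig {
///     provider: LlmProvider::Ollama,
///     api_endpoint: "http://localhost:11434/api/generate".to_string(),
///     api_key: None,
///     model: "llama3".to_string(),
///     ..RagConfig::default()
/// };
/// ```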
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// LLM provider
    pub provider: LlmProvider,
    /// LLM API endpoint
    pub api_endpoint: String,
    /// API key for authentication
    pub api_key: Option<String>,
    /// Model name to use
    pub model: String,
    /// Maximum tokens for generation
    pub max_tokens: usize,
    /// Temperature for generation
    pub temperature: f64,
    /// Context window size
    pub context_window: usize,

    /// Whether to use semantic search instead of keyword search
    pub semantic_search_enabled: bool,
    /// Embedding provider for semantic search
    pub embedding_provider: EmbeddingProvider,
    /// Embedding model to use
    pub embedding_model: String,
    /// Embedding API endpoint (if different from LLM endpoint)
    pub embedding_endpoint: Option<String>,
    /// Similarity threshold for semantic search (0.0 to 1.0)
    pub similarity_threshold: f64,
    /// Maximum number of chunks to retrieve for semantic search
    pub max_chunks: usize,

    /// Request timeout in seconds
    pub request_timeout_seconds: u64,
    /// Maximum number of retries for failed requests
    pub max_retries: usize,
}

impl Default for RagConfig {
    fn default() -> Self {
        Self {
            provider: LlmProvider::OpenAI,
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            api_key: None,
            model: "gpt-3.5-turbo".to_string(),
            max_tokens: 1000,
            temperature: 0.7,
            context_window: 4000,
            semantic_search_enabled: true,
            embedding_provider: EmbeddingProvider::OpenAI,
            embedding_model: "text-embedding-ada-002".to_string(),
            embedding_endpoint: None,
            similarity_threshold: 0.7,
            max_chunks: 5,
            request_timeout_seconds: 30,
            max_retries: 3,
        }
    }
}

/// Document chunk for RAG indexing
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    /// Chunk ID
    pub id: String,
    /// Content text
    pub content: String,
    /// Metadata
    pub metadata: HashMap<String, Value>,
    /// Embedding vector for semantic search
    pub embedding: Vec<f32>,
}

/// Search result with similarity score
#[derive(Debug)]
pub struct SearchResult<'a> {
    /// The document chunk
    pub chunk: &'a DocumentChunk,
    /// Similarity score (0.0 to 1.0)
    pub score: f64,
}

/// RAG engine for enhanced data generation
#[derive(Debug)]
pub struct RagEngine {
    /// Configuration
    config: RagConfig,
    /// Document chunks for retrieval
    chunks: Vec<DocumentChunk>,
    /// Schema knowledge base
    schema_kb: HashMap<String, Vec<String>>,
    /// HTTP client for LLM API calls
    client: Client,
}

impl RagEngine {
    /// Create a new RAG engine
    pub fn new(config: RagConfig) -> Self {
        let client = ClientBuilder::new()
            .timeout(Duration::from_secs(config.request_timeout_seconds))
            .build()
            .unwrap_or_else(|e| {
                warn!("Failed to create HTTP client with timeout, using default: {}", e);
                Client::new()
            });

        Self {
            config,
            chunks: Vec::new(),
            schema_kb: HashMap::new(),
            client,
        }
    }

    /// Add a document to the knowledge base
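    ///
    /// Returns the generated chunk ID. A sketch (the metadata key is
    /// illustrative):
    ///
    /// ```rust,ignore
    /// let id = engine.add_document(
    ///     "Orders reference a customer and one or more products.".to_string(),
    ///     HashMap::from([("domain".to_string(), Value::String("order".to_string()))]),
    /// )?;
    /// assert_eq!(id, "chunk_0");
    /// ```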
    pub fn add_document(
        &mut self,
        content: String,
        metadata: HashMap<String, Value>,
    ) -> Result<String> {
        let id = format!("chunk_{}", self.chunks.len());
        let chunk = DocumentChunk {
            id: id.clone(),
            content,
            metadata,
            embedding: Vec::new(), // populated later by compute_embeddings()
        };

        self.chunks.push(chunk);
        Ok(id)
    }

    /// Add schema information to knowledge base
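    ///
    /// Flattens the schema's name, description, fields, and relationships into
    /// plain-text facts keyed by the schema name, for later prompt building.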
    pub fn add_schema(&mut self, schema: &SchemaDefinition) -> Result<()> {
        let mut schema_info = Vec::new();

        schema_info.push(format!("Schema: {}", schema.name));

        if let Some(description) = &schema.description {
            schema_info.push(format!("Description: {}", description));
        }

        for field in &schema.fields {
            let mut field_info = format!(
                "Field '{}': type={}, required={}",
                field.name, field.field_type, field.required
            );

            if let Some(description) = &field.description {
                field_info.push_str(&format!(" - {}", description));
            }

            schema_info.push(field_info);
        }

        for (rel_name, relationship) in &schema.relationships {
            schema_info.push(format!(
                "Relationship '{}': {} -> {} ({:?})",
                rel_name, schema.name, relationship.target_schema, relationship.relationship_type
            ));
        }

        self.schema_kb.insert(schema.name.clone(), schema_info);
        Ok(())
    }

    /// Generate data using RAG-augmented prompts
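    ///
    /// Requires `config.rag_enabled` and a configured API key. Rows that fail
    /// generation receive typed placeholder data; if more than a quarter of
    /// rows fail, the whole call errors out.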
    pub async fn generate_with_rag(
        &self,
        schema: &SchemaDefinition,
        config: &DataConfig,
    ) -> Result<Vec<Value>> {
        if !config.rag_enabled {
            return Err(mockforge_core::Error::generic("RAG is not enabled in config"));
        }

        // Validate RAG configuration before proceeding
        if self.config.api_key.is_none() {
            return Err(mockforge_core::Error::generic(
                "RAG is enabled but no API key is configured. Please set MOCKFORGE_RAG_API_KEY or provide --rag-api-key"
            ));
        }

        let mut results = Vec::new();
        let mut failed_rows = 0;

        // Generate prompts for each row
        for i in 0..config.rows {
            match self.generate_single_row_with_rag(schema, i).await {
                Ok(data) => results.push(data),
                Err(e) => {
                    failed_rows += 1;
                    warn!("Failed to generate RAG data for row {}: {}", i, e);

                    // If too many rows fail, return an error
                    if failed_rows > config.rows / 4 {
                        // Allow up to 25% failure rate
                        return Err(mockforge_core::Error::generic(
                            format!("Too many RAG generation failures ({} out of {} rows failed). Check API configuration and network connectivity.", failed_rows, config.rows)
                        ));
                    }

                    // For failed rows, generate fallback data
                    let fallback_data = self.generate_fallback_data(schema);
                    results.push(fallback_data);
                }
            }
        }

        if failed_rows > 0 {
            warn!(
                "RAG generation completed with {} failed rows out of {}",
                failed_rows, config.rows
            );
        }

        Ok(results)
    }

    /// Generate a single row using RAG
    async fn generate_single_row_with_rag(
        &self,
        schema: &SchemaDefinition,
        row_index: usize,
    ) -> Result<Value> {
        let prompt = self.build_generation_prompt(schema, row_index).await?;
        let generated_data = self.call_llm(&prompt).await?;
        self.parse_llm_response(&generated_data)
    }

    /// Generate fallback data when RAG fails
    fn generate_fallback_data(&self, schema: &SchemaDefinition) -> Value {
        let mut obj = serde_json::Map::new();

        for field in &schema.fields {
            let value = match field.field_type.as_str() {
                "string" => Value::String("sample_data".to_string()),
                "integer" | "number" => Value::Number(42.into()),
                "boolean" => Value::Bool(true),
                _ => Value::String("sample_data".to_string()),
            };
            obj.insert(field.name.clone(), value);
        }

        Value::Object(obj)
    }

    /// Build a generation prompt with retrieved context
    async fn build_generation_prompt(
        &self,
        schema: &SchemaDefinition,
        _row_index: usize,
    ) -> Result<String> {
        let mut prompt =
            format!("Generate a single row of data for the '{}' schema.\n\n", schema.name);

        // Add schema information
        if let Some(schema_info) = self.schema_kb.get(&schema.name) {
            prompt.push_str("Schema Information:\n");
            for info in schema_info {
                prompt.push_str(&format!("- {}\n", info));
            }
            prompt.push('\n');
        }

        // Retrieve relevant context from documents
        let relevant_chunks = self.retrieve_relevant_chunks(&schema.name, 3).await?;
        if !relevant_chunks.is_empty() {
            prompt.push_str("Relevant Context:\n");
            for chunk in relevant_chunks {
                prompt.push_str(&format!("- {}\n", chunk.content));
            }
            prompt.push('\n');
        }

        // Add generation instructions
        prompt.push_str("Instructions:\n");
        prompt.push_str("- Generate realistic data that matches the schema\n");
        prompt.push_str("- Ensure all required fields are present\n");
        prompt.push_str("- Use appropriate data types and formats\n");
        prompt.push_str("- Make relationships consistent if referenced\n");
        prompt.push_str("- Output only valid JSON for a single object\n\n");

        prompt.push_str("Generate the data:");

        Ok(prompt)
    }

    /// Retrieve relevant document chunks using semantic search or keyword search
    async fn retrieve_relevant_chunks(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<&DocumentChunk>> {
        if self.config.semantic_search_enabled {
            // Use semantic search
            let results = self.semantic_search(query, limit).await?;
            Ok(results.into_iter().map(|r| r.chunk).collect())
        } else {
            // Fall back to keyword search
            Ok(self.keyword_search(query, limit))
        }
    }

    /// Perform keyword-based search (fallback)
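    ///
    /// Matching is a case-insensitive substring test against chunk content and
    /// string metadata values. A sketch:
    ///
    /// ```rust,ignore
    /// let hits = engine.keyword_search("customer", 3);
    /// for chunk in hits {
    ///     println!("{}: {}", chunk.id, chunk.content);
    /// }
    /// ```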
    pub fn keyword_search(&self, query: &str, limit: usize) -> Vec<&DocumentChunk> {
        let query_lower = query.to_lowercase();
        self.chunks
            .iter()
            .filter(|chunk| {
                chunk.content.to_lowercase().contains(&query_lower)
                    || chunk.metadata.values().any(|v| {
                        v.as_str().map_or(false, |s| s.to_lowercase().contains(&query_lower))
                    })
            })
            .take(limit)
            .collect()
    }

    /// Perform semantic search using embeddings
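    ///
    /// Embeds the query, keeps chunks scoring at or above
    /// `similarity_threshold`, sorts by descending score, and truncates to
    /// `limit`. Chunks without precomputed embeddings are skipped.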
    async fn semantic_search(&self, query: &str, limit: usize) -> Result<Vec<SearchResult<'_>>> {
        // Generate embedding for the query
        let query_embedding = self.generate_embedding(query).await?;

        // Calculate similarity scores for all chunks
        let mut results: Vec<SearchResult> = Vec::new();

        for chunk in &self.chunks {
            if chunk.embedding.is_empty() {
                // Skip chunks without embeddings
                continue;
            }

            let score = Self::cosine_similarity(&query_embedding, &chunk.embedding);
            if score >= self.config.similarity_threshold {
                results.push(SearchResult { chunk, score });
            }
        }

        // Sort by similarity score (descending) and take top results
        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
        results.truncate(limit);

        Ok(results)
    }

    /// Generate embedding for text using the configured embedding provider
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        match &self.config.embedding_provider {
            EmbeddingProvider::OpenAI => self.generate_openai_embedding(text).await,
            EmbeddingProvider::OpenAICompatible => {
                self.generate_openai_compatible_embedding(text).await
            }
        }
    }

    /// Generate embedding using OpenAI API
    async fn generate_openai_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| mockforge_core::Error::generic("OpenAI API key not configured"))?;

        let endpoint = self
            .config
            .embedding_endpoint
            .as_ref()
            .unwrap_or(&self.config.api_endpoint)
            .replace("chat/completions", "embeddings");

        let request_body = serde_json::json!({
            "model": self.config.embedding_model,
            "input": text
        });

        debug!("Generating embedding for text with OpenAI API");

        let response = self
            .client
            .post(&endpoint)
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| {
                mockforge_core::Error::generic(format!("Embedding API request failed: {}", e))
            })?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(mockforge_core::Error::generic(format!(
                "Embedding API error: {}",
                error_text
            )));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            mockforge_core::Error::generic(format!("Failed to parse embedding response: {}", e))
        })?;

        if let Some(data) = response_json.get("data").and_then(|d| d.as_array()) {
            if let Some(first_item) = data.first() {
                if let Some(embedding) = first_item.get("embedding").and_then(|e| e.as_array()) {
                    let embedding_vec: Vec<f32> =
                        embedding.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
                    return Ok(embedding_vec);
                }
            }
        }

        Err(mockforge_core::Error::generic("Invalid embedding response format"))
    }

    /// Generate embedding using OpenAI-compatible API
    async fn generate_openai_compatible_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let endpoint = self
            .config
            .embedding_endpoint
            .as_ref()
            .unwrap_or(&self.config.api_endpoint)
            .replace("chat/completions", "embeddings");

        let request_body = serde_json::json!({
            "model": self.config.embedding_model,
            "input": text
        });

        debug!("Generating embedding for text with OpenAI-compatible API");

        let mut request = self
            .client
            .post(&endpoint)
            .header("Content-Type", "application/json")
            .json(&request_body);

        if let Some(api_key) = &self.config.api_key {
            request = request.header("Authorization", format!("Bearer {}", api_key));
        }

        let response = request.send().await.map_err(|e| {
            mockforge_core::Error::generic(format!("Embedding API request failed: {}", e))
        })?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(mockforge_core::Error::generic(format!(
                "Embedding API error: {}",
                error_text
            )));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            mockforge_core::Error::generic(format!("Failed to parse embedding response: {}", e))
        })?;

        if let Some(data) = response_json.get("data").and_then(|d| d.as_array()) {
            if let Some(first_item) = data.first() {
                if let Some(embedding) = first_item.get("embedding").and_then(|e| e.as_array()) {
                    let embedding_vec: Vec<f32> =
                        embedding.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
                    return Ok(embedding_vec);
                }
            }
        }

        Err(mockforge_core::Error::generic("Invalid embedding response format"))
    }

    /// Compute embeddings for all document chunks
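    ///
    /// Only chunks whose embedding is still empty are sent to the embedding
    /// API; already-embedded chunks are skipped. A sketch (network access and
    /// a configured API key are assumed):
    ///
    /// ```rust,ignore
    /// engine.add_document("Some reference text".to_string(), HashMap::new())?;
    /// engine.compute_embeddings().await?;
    /// ```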
    pub async fn compute_embeddings(&mut self) -> Result<()> {
        debug!("Computing embeddings for {} chunks", self.chunks.len());

        // Collect chunks that need embeddings
        let chunks_to_embed: Vec<(usize, String)> = self
            .chunks
            .iter()
            .enumerate()
            .filter(|(_, chunk)| chunk.embedding.is_empty())
            .map(|(idx, chunk)| (idx, chunk.content.clone()))
            .collect();

        // Generate embeddings for chunks that need them
        for (idx, content) in chunks_to_embed {
            let embedding = self.generate_embedding(&content).await?;
            self.chunks[idx].embedding = embedding;
            debug!("Computed embedding for chunk {}", self.chunks[idx].id);
        }

        Ok(())
    }

    /// Calculate cosine similarity between two vectors
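    ///
    /// Computes `a · b / (|a| * |b|)`, returning 0.0 for mismatched lengths or
    /// zero-magnitude inputs. For example, a = [1, 0] and b = [1, 1] give
    /// 1 / (1 * sqrt(2)) ≈ 0.707.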
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
        if a.len() != b.len() || a.is_empty() {
            return 0.0;
        }

        let mut dot_product = 0.0;
        let mut norm_a = 0.0;
        let mut norm_b = 0.0;

        for i in 0..a.len() {
            dot_product += a[i] as f64 * b[i] as f64;
            norm_a += (a[i] as f64).powi(2);
            norm_b += (b[i] as f64).powi(2);
        }

        norm_a = norm_a.sqrt();
        norm_b = norm_b.sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            0.0
        } else {
            dot_product / (norm_a * norm_b)
        }
    }

    /// Call LLM API with provider-specific implementation and retry logic
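    ///
    /// Retries use a linear backoff: 500 ms after the first failure, 1 s after
    /// the second, and so on, up to `max_retries` additional attempts.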
    async fn call_llm(&self, prompt: &str) -> Result<String> {
        let mut last_error = None;

        for attempt in 0..=self.config.max_retries {
            match self.call_llm_single_attempt(prompt).await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    last_error = Some(e);
                    if attempt < self.config.max_retries {
                        let delay = Duration::from_millis(500 * (attempt + 1) as u64);
                        warn!(
                            "LLM API call failed (attempt {}), retrying in {:?}: {:?}",
                            attempt + 1,
                            delay,
                            last_error
                        );
                        sleep(delay).await;
                    }
                }
            }
        }

        Err(last_error
            .unwrap_or_else(|| mockforge_core::Error::generic("All LLM API retry attempts failed")))
    }

    /// Single attempt to call LLM API with provider-specific implementation
    async fn call_llm_single_attempt(&self, prompt: &str) -> Result<String> {
        match &self.config.provider {
            LlmProvider::OpenAI => self.call_openai(prompt).await,
            LlmProvider::Anthropic => self.call_anthropic(prompt).await,
            LlmProvider::OpenAICompatible => self.call_openai_compatible(prompt).await,
            LlmProvider::Ollama => self.call_ollama(prompt).await,
        }
    }

    /// Call OpenAI API
    async fn call_openai(&self, prompt: &str) -> Result<String> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| mockforge_core::Error::generic("OpenAI API key not configured"))?;

        let request_body = serde_json::json!({
            "model": self.config.model,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature
        });

        debug!("Calling OpenAI API with model: {}", self.config.model);

        let response = self
            .client
            .post(&self.config.api_endpoint)
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| {
                mockforge_core::Error::generic(format!("OpenAI API request failed: {}", e))
            })?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(mockforge_core::Error::generic(format!(
                "OpenAI API error: {}",
                error_text
            )));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            mockforge_core::Error::generic(format!("Failed to parse OpenAI response: {}", e))
        })?;

        if let Some(choices) = response_json.get("choices").and_then(|c| c.as_array()) {
            if let Some(choice) = choices.first() {
                if let Some(message) = choice.get("message").and_then(|m| m.get("content")) {
                    if let Some(content) = message.as_str() {
                        return Ok(content.to_string());
                    }
                }
            }
        }

        Err(mockforge_core::Error::generic("Invalid OpenAI response format"))
    }

    /// Call Anthropic API
    async fn call_anthropic(&self, prompt: &str) -> Result<String> {
        let api_key =
            self.config.api_key.as_ref().ok_or_else(|| {
                mockforge_core::Error::generic("Anthropic API key not configured")
            })?;

        let request_body = serde_json::json!({
            "model": self.config.model,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ]
        });

        debug!("Calling Anthropic API with model: {}", self.config.model);

        let response = self
            .client
            .post(&self.config.api_endpoint)
            .header("x-api-key", api_key)
            .header("Content-Type", "application/json")
            .header("anthropic-version", "2023-06-01")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| {
                mockforge_core::Error::generic(format!("Anthropic API request failed: {}", e))
            })?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(mockforge_core::Error::generic(format!(
                "Anthropic API error: {}",
                error_text
            )));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            mockforge_core::Error::generic(format!("Failed to parse Anthropic response: {}", e))
        })?;

        if let Some(content) = response_json.get("content") {
            if let Some(content_array) = content.as_array() {
                if let Some(first_content) = content_array.first() {
                    if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
                        return Ok(text.to_string());
                    }
                }
            }
        }

        Err(mockforge_core::Error::generic("Invalid Anthropic response format"))
    }

    /// Call OpenAI-compatible API
    async fn call_openai_compatible(&self, prompt: &str) -> Result<String> {
        let request_body = serde_json::json!({
            "model": self.config.model,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature
        });

        debug!("Calling OpenAI-compatible API with model: {}", self.config.model);

        let mut request = self
            .client
            .post(&self.config.api_endpoint)
            .header("Content-Type", "application/json")
            .json(&request_body);

        if let Some(api_key) = &self.config.api_key {
            request = request.header("Authorization", format!("Bearer {}", api_key));
        }

        let response = request.send().await.map_err(|e| {
            mockforge_core::Error::generic(format!("OpenAI-compatible API request failed: {}", e))
        })?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(mockforge_core::Error::generic(format!(
                "OpenAI-compatible API error: {}",
                error_text
            )));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            mockforge_core::Error::generic(format!(
                "Failed to parse OpenAI-compatible response: {}",
                e
            ))
        })?;

        if let Some(choices) = response_json.get("choices").and_then(|c| c.as_array()) {
            if let Some(choice) = choices.first() {
                if let Some(message) = choice.get("message").and_then(|m| m.get("content")) {
                    if let Some(content) = message.as_str() {
                        return Ok(content.to_string());
                    }
                }
            }
        }

        Err(mockforge_core::Error::generic("Invalid OpenAI-compatible response format"))
    }

    /// Call Ollama API
    async fn call_ollama(&self, prompt: &str) -> Result<String> {
        let request_body = serde_json::json!({
            "model": self.config.model,
            "prompt": prompt,
            "stream": false
        });

        debug!("Calling Ollama API with model: {}", self.config.model);

        let response = self
            .client
            .post(&self.config.api_endpoint)
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| {
                mockforge_core::Error::generic(format!("Ollama API request failed: {}", e))
            })?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(mockforge_core::Error::generic(format!(
                "Ollama API error: {}",
                error_text
            )));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            mockforge_core::Error::generic(format!("Failed to parse Ollama response: {}", e))
        })?;

        if let Some(response_text) = response_json.get("response").and_then(|r| r.as_str()) {
            return Ok(response_text.to_string());
        }

        Err(mockforge_core::Error::generic("Invalid Ollama response format"))
    }

    /// Parse LLM response into structured data
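    ///
    /// Tries a direct JSON parse first, then falls back to extracting the
    /// outermost `{...}` span, which tolerates prose around the payload:
    ///
    /// ```rust,ignore
    /// let value = engine.parse_llm_response("Here you go: {\"name\": \"Ada\"}")?;
    /// assert_eq!(value["name"], "Ada");
    /// ```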
    fn parse_llm_response(&self, response: &str) -> Result<Value> {
        // Try to parse as JSON
        match serde_json::from_str(response) {
            Ok(value) => Ok(value),
            Err(e) => {
                // If direct parsing fails, try to extract JSON from the response
                if let Some(start) = response.find('{') {
                    if let Some(end) = response.rfind('}') {
                        let json_str = &response[start..=end];
                        match serde_json::from_str(json_str) {
                            Ok(value) => Ok(value),
                            Err(_) => Err(mockforge_core::Error::generic(format!(
                                "Failed to parse LLM response: {}",
                                e
                            ))),
                        }
                    } else {
                        Err(mockforge_core::Error::generic(format!(
                            "No closing brace found in response: {}",
                            e
                        )))
                    }
                } else {
                    Err(mockforge_core::Error::generic(format!("No JSON found in response: {}", e)))
                }
            }
        }
    }

    /// Update RAG configuration
    pub fn update_config(&mut self, config: RagConfig) {
        self.config = config;
    }

    /// Get current configuration
    pub fn config(&self) -> &RagConfig {
        &self.config
    }

    /// Get number of indexed chunks
    pub fn chunk_count(&self) -> usize {
        self.chunks.len()
    }

    /// Get number of indexed schemas
    pub fn schema_count(&self) -> usize {
        self.schema_kb.len()
    }

    /// Get chunk by index
    pub fn get_chunk(&self, index: usize) -> Option<&DocumentChunk> {
        self.chunks.get(index)
    }

    /// Check if schema exists in knowledge base
    pub fn has_schema(&self, name: &str) -> bool {
        self.schema_kb.contains_key(name)
    }

    /// Generate text using LLM (for intelligent mock generation)
    pub async fn generate_text(&self, prompt: &str) -> Result<String> {
        self.call_llm(prompt).await
    }
}

impl Default for RagEngine {
    fn default() -> Self {
        Self::new(RagConfig::default())
    }
}

/// RAG-enhanced data generation utilities
pub mod rag_utils {
    use super::*;

    /// Create a RAG engine with common business domain knowledge
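    ///
    /// A usage sketch (the engine still needs provider credentials before
    /// `generate_with_rag` will succeed):
    ///
    /// ```rust,ignore
    /// let engine = rag_utils::create_business_rag_engine()?;
    /// assert_eq!(engine.chunk_count(), 3);
    /// ```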
    pub fn create_business_rag_engine() -> Result<RagEngine> {
        let mut engine = RagEngine::default();

        // Add common business knowledge
        engine.add_document(
            "Customer data typically includes personal information like name, email, phone, and address. Customers usually have unique identifiers and account creation dates.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("customer".to_string())),
                ("type".to_string(), Value::String("general".to_string())),
            ]),
        )?;

        engine.add_document(
            "Product information includes name, description, price, category, and stock status. Products should have unique SKUs or IDs.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("product".to_string())),
                ("type".to_string(), Value::String("general".to_string())),
            ]),
        )?;

        engine.add_document(
            "Order data contains customer references, product lists, total amounts, status, and timestamps. Orders should maintain referential integrity with customers and products.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("order".to_string())),
                ("type".to_string(), Value::String("general".to_string())),
            ]),
        )?;

        Ok(engine)
    }

    /// Create a RAG engine with technical domain knowledge
    pub fn create_technical_rag_engine() -> Result<RagEngine> {
        let mut engine = RagEngine::default();

        // Add technical knowledge
        engine.add_document(
            "API endpoints should follow RESTful conventions with proper HTTP methods. GET for retrieval, POST for creation, PUT for updates, DELETE for removal.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("api".to_string())),
                ("type".to_string(), Value::String("technical".to_string())),
            ]),
        )?;

        engine.add_document(
            "Database records typically have auto-incrementing primary keys, created_at and updated_at timestamps, and foreign key relationships.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("database".to_string())),
                ("type".to_string(), Value::String("technical".to_string())),
            ]),
        )?;

        Ok(engine)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_provider_variants() {
        let openai = LlmProvider::OpenAI;
        let anthropic = LlmProvider::Anthropic;
        let compatible = LlmProvider::OpenAICompatible;
        let ollama = LlmProvider::Ollama;

        assert!(matches!(openai, LlmProvider::OpenAI));
        assert!(matches!(anthropic, LlmProvider::Anthropic));
        assert!(matches!(compatible, LlmProvider::OpenAICompatible));
        assert!(matches!(ollama, LlmProvider::Ollama));
    }

    #[test]
    fn test_embedding_provider_variants() {
        let openai = EmbeddingProvider::OpenAI;
        let compatible = EmbeddingProvider::OpenAICompatible;

        assert!(matches!(openai, EmbeddingProvider::OpenAI));
        assert!(matches!(compatible, EmbeddingProvider::OpenAICompatible));
    }

    #[test]
    fn test_rag_config_default() {
        let config = RagConfig::default();

        assert!(config.max_tokens > 0);
        assert!(config.temperature >= 0.0 && config.temperature <= 1.0);
        assert!(config.context_window > 0);
    }
}
1002}