//! RAG (Retrieval-Augmented Generation) for enhanced data synthesis
//!
//! This module has been refactored into sub-modules for better organization:
//! - config: RAG configuration and settings management
//! - engine: Core RAG engine and retrieval logic
//! - providers: LLM and embedding provider integrations
//! - storage: Document storage and vector indexing
//! - utils: Utility functions and helpers for RAG operations

// Re-export sub-modules for backward compatibility
pub mod config;
pub mod engine;
pub mod providers;
pub mod storage;
pub mod utils;

// Re-export commonly used types
pub use config::*;
pub use providers::*;
pub use utils::*;

// Re-export engine and storage types with explicit names to avoid conflicts
pub use engine::StorageStats as EngineStorageStats;
pub use storage::StorageStats as StorageStorageStats;
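//
// Example (illustrative; the path assumes this module is exposed as
// `mockforge_data::rag`): downstream code can disambiguate the two stats
// types via the renamed re-exports:
//
//     use mockforge_data::rag::{EngineStorageStats, StorageStorageStats};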

// Legacy imports for compatibility
use crate::Result;
use crate::{schema::SchemaDefinition, DataConfig};
use reqwest::{Client, ClientBuilder};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::cmp::Ordering;
use std::collections::HashMap;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, warn};

/// Supported LLM providers
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    /// OpenAI GPT models
    OpenAI,
    /// Anthropic Claude models
    Anthropic,
    /// Generic OpenAI-compatible API
    OpenAICompatible,
    /// Local Ollama instance
    Ollama,
}

/// Supported embedding providers
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProvider {
    /// OpenAI text-embedding-ada-002
    OpenAI,
    /// Generic OpenAI-compatible embeddings API
    OpenAICompatible,
}

/// RAG configuration
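///
/// # Example
///
/// A minimal sketch using struct-update syntax. The Ollama endpoint and
/// model name below are illustrative assumptions for a local setup, not
/// shipped defaults:
///
/// ```ignore
/// let config = RagConfig {
///     provider: LlmProvider::Ollama,
///     api_endpoint: "http://localhost:11434/api/generate".to_string(),
///     api_key: None,
///     model: "llama3".to_string(),
///     ..RagConfig::default()
/// };
/// ```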
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// LLM provider
    pub provider: LlmProvider,
    /// LLM API endpoint
    pub api_endpoint: String,
    /// API key for authentication
    pub api_key: Option<String>,
    /// Model name to use
    pub model: String,
    /// Maximum tokens for generation
    pub max_tokens: usize,
    /// Temperature for generation
    pub temperature: f64,
    /// Context window size
    pub context_window: usize,

    /// Whether to use semantic search instead of keyword search
    pub semantic_search_enabled: bool,
    /// Embedding provider for semantic search
    pub embedding_provider: EmbeddingProvider,
    /// Embedding model to use
    pub embedding_model: String,
    /// Embedding API endpoint (if different from LLM endpoint)
    pub embedding_endpoint: Option<String>,
    /// Similarity threshold for semantic search (0.0 to 1.0)
    pub similarity_threshold: f64,
    /// Maximum number of chunks to retrieve for semantic search
    pub max_chunks: usize,

    /// Request timeout in seconds
    pub request_timeout_seconds: u64,
    /// Maximum number of retries for failed requests
    pub max_retries: usize,
}

impl Default for RagConfig {
    fn default() -> Self {
        Self {
            provider: LlmProvider::OpenAI,
            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
            api_key: None,
            model: "gpt-3.5-turbo".to_string(),
            max_tokens: 1000,
            temperature: 0.7,
            context_window: 4000,
            semantic_search_enabled: true,
            embedding_provider: EmbeddingProvider::OpenAI,
            embedding_model: "text-embedding-ada-002".to_string(),
            embedding_endpoint: None,
            similarity_threshold: 0.7,
            max_chunks: 5,
            request_timeout_seconds: 30,
            max_retries: 3,
        }
    }
}

/// Document chunk for RAG indexing
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    /// Chunk ID
    pub id: String,
    /// Content text
    pub content: String,
    /// Metadata
    pub metadata: HashMap<String, Value>,
    /// Embedding vector for semantic search
    pub embedding: Vec<f32>,
}

/// Search result with similarity score
#[derive(Debug)]
pub struct SearchResult<'a> {
    /// The document chunk
    pub chunk: &'a DocumentChunk,
    /// Similarity score (0.0 to 1.0)
    pub score: f64,
}

/// RAG engine for enhanced data generation
#[derive(Debug)]
pub struct RagEngine {
    /// Configuration
    config: RagConfig,
    /// Document chunks for retrieval
    chunks: Vec<DocumentChunk>,
    /// Schema knowledge base
    schema_kb: HashMap<String, Vec<String>>,
    /// HTTP client for LLM API calls
    client: Client,
}

impl RagEngine {
    /// Create a new RAG engine
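    ///
    /// # Example
    ///
    /// A hedged sketch of typical setup (the crate path is an assumption):
    ///
    /// ```ignore
    /// use mockforge_data::rag::{RagConfig, RagEngine};
    ///
    /// let mut engine = RagEngine::new(RagConfig::default());
    /// engine.add_document(
    ///     "Orders reference customers by ID".to_string(),
    ///     std::collections::HashMap::new(),
    /// )?;
    /// assert_eq!(engine.chunk_count(), 1);
    /// ```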
    pub fn new(config: RagConfig) -> Self {
        let client = ClientBuilder::new()
            .timeout(Duration::from_secs(config.request_timeout_seconds))
            .build()
            .unwrap_or_else(|e| {
                warn!("Failed to create HTTP client with timeout, using default: {}", e);
                Client::new()
            });

        Self {
            config,
            chunks: Vec::new(),
            schema_kb: HashMap::new(),
            client,
        }
    }

    /// Add a document to the knowledge base
    pub fn add_document(
        &mut self,
        content: String,
        metadata: HashMap<String, Value>,
    ) -> Result<String> {
        let id = format!("chunk_{}", self.chunks.len());
        let chunk = DocumentChunk {
            id: id.clone(),
            content,
            metadata,
            embedding: Vec::new(), // Populated later by compute_embeddings()
        };

        self.chunks.push(chunk);
        Ok(id)
    }

    /// Add schema information to the knowledge base
    pub fn add_schema(&mut self, schema: &SchemaDefinition) -> Result<()> {
        let mut schema_info = Vec::new();

        schema_info.push(format!("Schema: {}", schema.name));

        if let Some(description) = &schema.description {
            schema_info.push(format!("Description: {}", description));
        }

        for field in &schema.fields {
            let mut field_info = format!(
                "Field '{}': type={}, required={}",
                field.name, field.field_type, field.required
            );

            if let Some(description) = &field.description {
                field_info.push_str(&format!(" - {}", description));
            }

            schema_info.push(field_info);
        }

        for (rel_name, relationship) in &schema.relationships {
            schema_info.push(format!(
                "Relationship '{}': {} -> {} ({:?})",
                rel_name, schema.name, relationship.target_schema, relationship.relationship_type
            ));
        }

        self.schema_kb.insert(schema.name.clone(), schema_info);
        Ok(())
    }

    /// Generate data using RAG-augmented prompts
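    ///
    /// Produces one JSON object per requested row. Individual row failures
    /// fall back to placeholder data; if more than 25% of rows fail, the
    /// whole call returns an error.
    ///
    /// # Example
    ///
    /// A sketch only (assumes a schema, a `DataConfig` with `rag_enabled`
    /// set, and a configured API key):
    ///
    /// ```ignore
    /// let rows = engine.generate_with_rag(&schema, &data_config).await?;
    /// assert_eq!(rows.len(), data_config.rows);
    /// ```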
    pub async fn generate_with_rag(
        &self,
        schema: &SchemaDefinition,
        config: &DataConfig,
    ) -> Result<Vec<Value>> {
        if !config.rag_enabled {
            return Err(crate::Error::generic("RAG is not enabled in config"));
        }

        // Validate RAG configuration before proceeding
        if self.config.api_key.is_none() {
            return Err(crate::Error::generic(
                "RAG is enabled but no API key is configured. Please set MOCKFORGE_RAG_API_KEY or provide --rag-api-key"
            ));
        }

        let mut results = Vec::new();
        let mut failed_rows = 0;

        // Generate prompts for each row
        for i in 0..config.rows {
            match self.generate_single_row_with_rag(schema, i).await {
                Ok(data) => results.push(data),
                Err(e) => {
                    failed_rows += 1;
                    warn!("Failed to generate RAG data for row {}: {}", i, e);

                    // If too many rows fail, return an error
                    if failed_rows > config.rows / 4 {
                        // Allow up to 25% failure rate
                        return Err(crate::Error::generic(
                            format!("Too many RAG generation failures ({} out of {} rows failed). Check API configuration and network connectivity.", failed_rows, config.rows)
                        ));
                    }

                    // For failed rows, generate fallback data
                    let fallback_data = self.generate_fallback_data(schema);
                    results.push(fallback_data);
                }
            }
        }

        if failed_rows > 0 {
            warn!(
                "RAG generation completed with {} failed rows out of {}",
                failed_rows, config.rows
            );
        }

        Ok(results)
    }

    /// Generate a single row using RAG
    async fn generate_single_row_with_rag(
        &self,
        schema: &SchemaDefinition,
        row_index: usize,
    ) -> Result<Value> {
        let prompt = self.build_generation_prompt(schema, row_index).await?;
        let generated_data = self.call_llm(&prompt).await?;
        self.parse_llm_response(&generated_data)
    }

    /// Generate fallback data when RAG fails
    fn generate_fallback_data(&self, schema: &SchemaDefinition) -> Value {
        let mut obj = serde_json::Map::new();

        for field in &schema.fields {
            let value = match field.field_type.as_str() {
                "string" => Value::String("sample_data".to_string()),
                "integer" | "number" => Value::Number(42.into()),
                "boolean" => Value::Bool(true),
                _ => Value::String("sample_data".to_string()),
            };
            obj.insert(field.name.clone(), value);
        }

        Value::Object(obj)
    }

    /// Build a generation prompt with retrieved context
    async fn build_generation_prompt(
        &self,
        schema: &SchemaDefinition,
        _row_index: usize,
    ) -> Result<String> {
        let mut prompt =
            format!("Generate a single row of data for the '{}' schema.\n\n", schema.name);

        // Add schema information
        if let Some(schema_info) = self.schema_kb.get(&schema.name) {
            prompt.push_str("Schema Information:\n");
            for info in schema_info {
                prompt.push_str(&format!("- {}\n", info));
            }
            prompt.push('\n');
        }

        // Retrieve relevant context from documents
        let relevant_chunks = self.retrieve_relevant_chunks(&schema.name, 3).await?;
        if !relevant_chunks.is_empty() {
            prompt.push_str("Relevant Context:\n");
            for chunk in relevant_chunks {
                prompt.push_str(&format!("- {}\n", chunk.content));
            }
            prompt.push('\n');
        }

        // Add generation instructions
        prompt.push_str("Instructions:\n");
        prompt.push_str("- Generate realistic data that matches the schema\n");
        prompt.push_str("- Ensure all required fields are present\n");
        prompt.push_str("- Use appropriate data types and formats\n");
        prompt.push_str("- Make relationships consistent if referenced\n");
        prompt.push_str("- Output only valid JSON for a single object\n\n");

        prompt.push_str("Generate the data:");

        Ok(prompt)
    }

    /// Retrieve relevant document chunks using semantic search or keyword search
    async fn retrieve_relevant_chunks(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<&DocumentChunk>> {
        if self.config.semantic_search_enabled {
            // Use semantic search
            let results = self.semantic_search(query, limit).await?;
            Ok(results.into_iter().map(|r| r.chunk).collect())
        } else {
            // Fall back to keyword search
            Ok(self.keyword_search(query, limit))
        }
    }

    /// Perform keyword-based search (fallback)
    pub fn keyword_search(&self, query: &str, limit: usize) -> Vec<&DocumentChunk> {
        self.chunks
            .iter()
            .filter(|chunk| {
                chunk.content.to_lowercase().contains(&query.to_lowercase())
                    || chunk.metadata.values().any(|v| {
                        if let Some(s) = v.as_str() {
                            s.to_lowercase().contains(&query.to_lowercase())
                        } else {
                            false
                        }
                    })
            })
            .take(limit)
            .collect()
    }

    /// Perform semantic search using embeddings
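    ///
    /// Chunks without a precomputed embedding are skipped, so callers should
    /// run [`Self::compute_embeddings`] first if documents were added without
    /// embeddings.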
    async fn semantic_search(&self, query: &str, limit: usize) -> Result<Vec<SearchResult<'_>>> {
        // Generate embedding for the query
        let query_embedding = self.generate_embedding(query).await?;

        // Calculate similarity scores for all chunks
        let mut results: Vec<SearchResult> = Vec::new();

        for chunk in &self.chunks {
            if chunk.embedding.is_empty() {
                // Skip chunks without embeddings
                continue;
            }

            let score = Self::cosine_similarity(&query_embedding, &chunk.embedding);
            if score >= self.config.similarity_threshold {
                results.push(SearchResult { chunk, score });
            }
        }

        // Sort by similarity score (descending) and take top results
        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
        results.truncate(limit);

        Ok(results)
    }

    /// Generate embedding for text using the configured embedding provider
    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
        match &self.config.embedding_provider {
            EmbeddingProvider::OpenAI => self.generate_openai_embedding(text).await,
            EmbeddingProvider::OpenAICompatible => {
                self.generate_openai_compatible_embedding(text).await
            }
        }
    }

    /// Generate embedding using OpenAI API
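    ///
    /// Unless `embedding_endpoint` is set, the URL is derived from the LLM
    /// endpoint by rewriting `chat/completions` to `embeddings` (for example,
    /// `https://api.openai.com/v1/chat/completions` becomes
    /// `https://api.openai.com/v1/embeddings`).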
    async fn generate_openai_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::Error::generic("OpenAI API key not configured"))?;

        let endpoint = self
            .config
            .embedding_endpoint
            .as_ref()
            .unwrap_or(&self.config.api_endpoint)
            .replace("chat/completions", "embeddings");

        let request_body = serde_json::json!({
            "model": self.config.embedding_model,
            "input": text
        });

        debug!("Generating embedding for text with OpenAI API");

        let response = self
            .client
            .post(&endpoint)
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| crate::Error::generic(format!("Embedding API request failed: {}", e)))?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(crate::Error::generic(format!("Embedding API error: {}", error_text)));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            crate::Error::generic(format!("Failed to parse embedding response: {}", e))
        })?;

        if let Some(data) = response_json.get("data").and_then(|d| d.as_array()) {
            if let Some(first_item) = data.first() {
                if let Some(embedding) = first_item.get("embedding").and_then(|e| e.as_array()) {
                    let embedding_vec: Vec<f32> =
                        embedding.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
                    return Ok(embedding_vec);
                }
            }
        }

        Err(crate::Error::generic("Invalid embedding response format"))
    }

    /// Generate embedding using OpenAI-compatible API
    async fn generate_openai_compatible_embedding(&self, text: &str) -> Result<Vec<f32>> {
        let endpoint = self
            .config
            .embedding_endpoint
            .as_ref()
            .unwrap_or(&self.config.api_endpoint)
            .replace("chat/completions", "embeddings");

        let request_body = serde_json::json!({
            "model": self.config.embedding_model,
            "input": text
        });

        debug!("Generating embedding for text with OpenAI-compatible API");

        let mut request = self
            .client
            .post(&endpoint)
            .header("Content-Type", "application/json")
            .json(&request_body);

        if let Some(api_key) = &self.config.api_key {
            request = request.header("Authorization", format!("Bearer {}", api_key));
        }

        let response = request
            .send()
            .await
            .map_err(|e| crate::Error::generic(format!("Embedding API request failed: {}", e)))?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(crate::Error::generic(format!("Embedding API error: {}", error_text)));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            crate::Error::generic(format!("Failed to parse embedding response: {}", e))
        })?;

        if let Some(data) = response_json.get("data").and_then(|d| d.as_array()) {
            if let Some(first_item) = data.first() {
                if let Some(embedding) = first_item.get("embedding").and_then(|e| e.as_array()) {
                    let embedding_vec: Vec<f32> =
                        embedding.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
                    return Ok(embedding_vec);
                }
            }
        }

        Err(crate::Error::generic("Invalid embedding response format"))
    }

    /// Compute embeddings for all document chunks
    pub async fn compute_embeddings(&mut self) -> Result<()> {
        debug!("Computing embeddings for {} chunks", self.chunks.len());

        // Collect chunks that need embeddings
        let chunks_to_embed: Vec<(usize, String)> = self
            .chunks
            .iter()
            .enumerate()
            .filter(|(_, chunk)| chunk.embedding.is_empty())
            .map(|(idx, chunk)| (idx, chunk.content.clone()))
            .collect();

        // Generate embeddings for chunks that need them
        for (idx, content) in chunks_to_embed {
            let embedding = self.generate_embedding(&content).await?;
            self.chunks[idx].embedding = embedding;
            debug!("Computed embedding for chunk {}", self.chunks[idx].id);
        }

        Ok(())
    }

    /// Calculate cosine similarity between two vectors
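    ///
    /// Uses the standard formula `cos(θ) = (a · b) / (‖a‖ ‖b‖)`. For example,
    /// orthogonal vectors like `[1.0, 0.0]` and `[0.0, 1.0]` score `0.0`,
    /// while parallel vectors like `[1.0, 2.0]` and `[2.0, 4.0]` score `1.0`.
    /// Mismatched lengths, empty input, and zero vectors all return `0.0`.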
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
        if a.len() != b.len() || a.is_empty() {
            return 0.0;
        }

        let mut dot_product = 0.0;
        let mut norm_a = 0.0;
        let mut norm_b = 0.0;

        for i in 0..a.len() {
            dot_product += a[i] as f64 * b[i] as f64;
            norm_a += (a[i] as f64).powi(2);
            norm_b += (b[i] as f64).powi(2);
        }

        norm_a = norm_a.sqrt();
        norm_b = norm_b.sqrt();

        if norm_a == 0.0 || norm_b == 0.0 {
            0.0
        } else {
            dot_product / (norm_a * norm_b)
        }
    }

    /// Call LLM API with provider-specific implementation and retry logic
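    ///
    /// Failed attempts are retried up to `max_retries` times with a linearly
    /// increasing backoff (500 ms, then 1 s, then 1.5 s, and so on).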
    async fn call_llm(&self, prompt: &str) -> Result<String> {
        let mut last_error = None;

        for attempt in 0..=self.config.max_retries {
            match self.call_llm_single_attempt(prompt).await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    last_error = Some(e);
                    if attempt < self.config.max_retries {
                        let delay = Duration::from_millis(500 * (attempt + 1) as u64);
                        warn!(
                            "LLM API call failed (attempt {}), retrying in {:?}: {:?}",
                            attempt + 1,
                            delay,
                            last_error
                        );
                        sleep(delay).await;
                    }
                }
            }
        }

        Err(last_error
            .unwrap_or_else(|| crate::Error::generic("All LLM API retry attempts failed")))
    }

    /// Single attempt to call LLM API with provider-specific implementation
    async fn call_llm_single_attempt(&self, prompt: &str) -> Result<String> {
        match &self.config.provider {
            LlmProvider::OpenAI => self.call_openai(prompt).await,
            LlmProvider::Anthropic => self.call_anthropic(prompt).await,
            LlmProvider::OpenAICompatible => self.call_openai_compatible(prompt).await,
            LlmProvider::Ollama => self.call_ollama(prompt).await,
        }
    }

    /// Call OpenAI API
    async fn call_openai(&self, prompt: &str) -> Result<String> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::Error::generic("OpenAI API key not configured"))?;

        let request_body = serde_json::json!({
            "model": self.config.model,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature
        });

        debug!("Calling OpenAI API with model: {}", self.config.model);

        let response = self
            .client
            .post(&self.config.api_endpoint)
            .header("Authorization", format!("Bearer {}", api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| crate::Error::generic(format!("OpenAI API request failed: {}", e)))?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(crate::Error::generic(format!("OpenAI API error: {}", error_text)));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            crate::Error::generic(format!("Failed to parse OpenAI response: {}", e))
        })?;

        if let Some(choices) = response_json.get("choices").and_then(|c| c.as_array()) {
            if let Some(choice) = choices.first() {
                if let Some(message) = choice.get("message").and_then(|m| m.get("content")) {
                    if let Some(content) = message.as_str() {
                        return Ok(content.to_string());
                    }
                }
            }
        }

        Err(crate::Error::generic("Invalid OpenAI response format"))
    }

    /// Call Anthropic API
    async fn call_anthropic(&self, prompt: &str) -> Result<String> {
        let api_key = self
            .config
            .api_key
            .as_ref()
            .ok_or_else(|| crate::Error::generic("Anthropic API key not configured"))?;

        let request_body = serde_json::json!({
            "model": self.config.model,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ]
        });

        debug!("Calling Anthropic API with model: {}", self.config.model);

        let response = self
            .client
            .post(&self.config.api_endpoint)
            .header("x-api-key", api_key)
            .header("Content-Type", "application/json")
            .header("anthropic-version", "2023-06-01")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| crate::Error::generic(format!("Anthropic API request failed: {}", e)))?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(crate::Error::generic(format!("Anthropic API error: {}", error_text)));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            crate::Error::generic(format!("Failed to parse Anthropic response: {}", e))
        })?;

        if let Some(content) = response_json.get("content") {
            if let Some(content_array) = content.as_array() {
                if let Some(first_content) = content_array.first() {
                    if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
                        return Ok(text.to_string());
                    }
                }
            }
        }

        Err(crate::Error::generic("Invalid Anthropic response format"))
    }

    /// Call OpenAI-compatible API
    async fn call_openai_compatible(&self, prompt: &str) -> Result<String> {
        let request_body = serde_json::json!({
            "model": self.config.model,
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature
        });

        debug!("Calling OpenAI-compatible API with model: {}", self.config.model);

        let mut request = self
            .client
            .post(&self.config.api_endpoint)
            .header("Content-Type", "application/json")
            .json(&request_body);

        if let Some(api_key) = &self.config.api_key {
            request = request.header("Authorization", format!("Bearer {}", api_key));
        }

        let response = request.send().await.map_err(|e| {
            crate::Error::generic(format!("OpenAI-compatible API request failed: {}", e))
        })?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(crate::Error::generic(format!(
                "OpenAI-compatible API error: {}",
                error_text
            )));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            crate::Error::generic(format!("Failed to parse OpenAI-compatible response: {}", e))
        })?;

        if let Some(choices) = response_json.get("choices").and_then(|c| c.as_array()) {
            if let Some(choice) = choices.first() {
                if let Some(message) = choice.get("message").and_then(|m| m.get("content")) {
                    if let Some(content) = message.as_str() {
                        return Ok(content.to_string());
                    }
                }
            }
        }

        Err(crate::Error::generic("Invalid OpenAI-compatible response format"))
    }

    /// Call Ollama API
    async fn call_ollama(&self, prompt: &str) -> Result<String> {
        let request_body = serde_json::json!({
            "model": self.config.model,
            "prompt": prompt,
            "stream": false
        });

        debug!("Calling Ollama API with model: {}", self.config.model);

        let response = self
            .client
            .post(&self.config.api_endpoint)
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| crate::Error::generic(format!("Ollama API request failed: {}", e)))?;

        if !response.status().is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(crate::Error::generic(format!("Ollama API error: {}", error_text)));
        }

        let response_json: Value = response.json().await.map_err(|e| {
            crate::Error::generic(format!("Failed to parse Ollama response: {}", e))
        })?;

        if let Some(response_text) = response_json.get("response").and_then(|r| r.as_str()) {
            return Ok(response_text.to_string());
        }

        Err(crate::Error::generic("Invalid Ollama response format"))
    }

    /// Parse LLM response into structured data
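    ///
    /// Tries a direct JSON parse first; if that fails, falls back to
    /// extracting the substring between the first `{` and the last `}`.
    /// For example, the response `Here you go: {"name": "Alice"}` parses
    /// to `{"name": "Alice"}`.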
    fn parse_llm_response(&self, response: &str) -> Result<Value> {
        // Try to parse as JSON
        match serde_json::from_str(response) {
            Ok(value) => Ok(value),
            Err(e) => {
                // If direct parsing fails, try to extract JSON from the response
                if let Some(start) = response.find('{') {
                    if let Some(end) = response.rfind('}') {
                        let json_str = &response[start..=end];
                        match serde_json::from_str(json_str) {
                            Ok(value) => Ok(value),
                            Err(_) => Err(crate::Error::generic(format!(
                                "Failed to parse LLM response: {}",
                                e
                            ))),
                        }
                    } else {
                        Err(crate::Error::generic(format!(
                            "No closing brace found in response: {}",
                            e
                        )))
                    }
                } else {
                    Err(crate::Error::generic(format!("No JSON found in response: {}", e)))
                }
            }
        }
    }

    /// Update RAG configuration
    pub fn update_config(&mut self, config: RagConfig) {
        self.config = config;
    }

    /// Get current configuration
    pub fn config(&self) -> &RagConfig {
        &self.config
    }

    /// Get number of indexed chunks
    pub fn chunk_count(&self) -> usize {
        self.chunks.len()
    }

    /// Get number of indexed schemas
    pub fn schema_count(&self) -> usize {
        self.schema_kb.len()
    }

    /// Get chunk by index
    pub fn get_chunk(&self, index: usize) -> Option<&DocumentChunk> {
        self.chunks.get(index)
    }

    /// Check if schema exists in knowledge base
    pub fn has_schema(&self, name: &str) -> bool {
        self.schema_kb.contains_key(name)
    }

    /// Generate text using LLM (for intelligent mock generation)
    pub async fn generate_text(&self, prompt: &str) -> Result<String> {
        self.call_llm(prompt).await
    }
}

impl Default for RagEngine {
    fn default() -> Self {
        Self::new(RagConfig::default())
    }
}

/// RAG-enhanced data generation utilities
pub mod rag_utils {
    use super::*;

    /// Create a RAG engine with common business domain knowledge
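    ///
    /// # Example
    ///
    /// ```ignore
    /// let engine = rag_utils::create_business_rag_engine()?;
    /// assert_eq!(engine.chunk_count(), 3); // customer, product, and order docs
    /// ```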
    pub fn create_business_rag_engine() -> Result<RagEngine> {
        let mut engine = RagEngine::default();

        // Add common business knowledge
        engine.add_document(
            "Customer data typically includes personal information like name, email, phone, and address. Customers usually have unique identifiers and account creation dates.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("customer".to_string())),
                ("type".to_string(), Value::String("general".to_string())),
            ]),
        )?;

        engine.add_document(
            "Product information includes name, description, price, category, and stock status. Products should have unique SKUs or IDs.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("product".to_string())),
                ("type".to_string(), Value::String("general".to_string())),
            ]),
        )?;

        engine.add_document(
            "Order data contains customer references, product lists, total amounts, status, and timestamps. Orders should maintain referential integrity with customers and products.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("order".to_string())),
                ("type".to_string(), Value::String("general".to_string())),
            ]),
        )?;

        Ok(engine)
    }

    /// Create a RAG engine with technical domain knowledge
    pub fn create_technical_rag_engine() -> Result<RagEngine> {
        let mut engine = RagEngine::default();

        // Add technical knowledge
        engine.add_document(
            "API endpoints should follow RESTful conventions with proper HTTP methods. GET for retrieval, POST for creation, PUT for updates, DELETE for removal.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("api".to_string())),
                ("type".to_string(), Value::String("technical".to_string())),
            ]),
        )?;

        engine.add_document(
            "Database records typically have auto-incrementing primary keys, created_at and updated_at timestamps, and foreign key relationships.".to_string(),
            HashMap::from([
                ("domain".to_string(), Value::String("database".to_string())),
                ("type".to_string(), Value::String("technical".to_string())),
            ]),
        )?;

        Ok(engine)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_provider_variants() {
        let openai = LlmProvider::OpenAI;
        let anthropic = LlmProvider::Anthropic;
        let compatible = LlmProvider::OpenAICompatible;
        let ollama = LlmProvider::Ollama;

        assert!(matches!(openai, LlmProvider::OpenAI));
        assert!(matches!(anthropic, LlmProvider::Anthropic));
        assert!(matches!(compatible, LlmProvider::OpenAICompatible));
        assert!(matches!(ollama, LlmProvider::Ollama));
    }

    #[test]
    fn test_embedding_provider_variants() {
        let openai = EmbeddingProvider::OpenAI;
        let compatible = EmbeddingProvider::OpenAICompatible;

        assert!(matches!(openai, EmbeddingProvider::OpenAI));
        assert!(matches!(compatible, EmbeddingProvider::OpenAICompatible));
    }

    #[test]
    fn test_rag_config_default() {
        let config = RagConfig::default();

        assert!(config.max_tokens > 0);
        assert!(config.temperature >= 0.0 && config.temperature <= 1.0);
        assert!(config.context_window > 0);
    }