// mockforge_data/rag.rs
1//! RAG (Retrieval-Augmented Generation) for enhanced data synthesis
2//!
3//! This module has been refactored into sub-modules for better organization:
4//! - config: RAG configuration and settings management
5//! - engine: Core RAG engine and retrieval logic
6//! - providers: LLM and embedding provider integrations
7//! - storage: Document storage and vector indexing
8//! - utils: Utility functions and helpers for RAG operations
9
10// Re-export sub-modules for backward compatibility
11pub mod config;
12pub mod engine;
13pub mod providers;
14pub mod storage;
15pub mod utils;
16
17// Re-export commonly used types
18pub use config::*;
19pub use providers::*;
20pub use utils::*;
21
22// Re-export engine and storage types with explicit names to avoid conflicts
23pub use engine::StorageStats as EngineStorageStats;
24pub use storage::StorageStats as StorageStorageStats;
25
26// Legacy imports for compatibility
27use crate::Result;
28use crate::{schema::SchemaDefinition, DataConfig};
29use reqwest::{Client, ClientBuilder};
30use serde::{Deserialize, Serialize};
31use serde_json::Value;
32use std::cmp::Ordering;
33use std::collections::HashMap;
34use std::time::Duration;
35use tokio::time::sleep;
36use tracing::{debug, warn};
37
/// Supported LLM providers for chat-style text generation.
///
/// Serialized in lowercase (e.g. `"openai"`, `"ollama"`) via serde for use in
/// configuration files.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum LlmProvider {
    /// OpenAI GPT models (hosted API)
    OpenAI,
    /// Anthropic Claude models (Messages API)
    Anthropic,
    /// Generic OpenAI-compatible API (e.g. self-hosted gateways)
    OpenAICompatible,
    /// Local Ollama instance
    Ollama,
}
51
/// Supported embedding providers for semantic search.
///
/// Serialized in lowercase via serde, matching [`LlmProvider`]'s convention.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProvider {
    /// OpenAI embeddings API (e.g. text-embedding-ada-002)
    OpenAI,
    /// Generic OpenAI-compatible embeddings API
    OpenAICompatible,
    /// Local Ollama instance
    Ollama,
}
63
/// RAG configuration: LLM generation, semantic-search, and HTTP settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
    /// LLM provider to dispatch generation calls to
    pub provider: LlmProvider,
    /// LLM API endpoint (full URL; also the base used to derive the
    /// embeddings endpoint when `embedding_endpoint` is unset)
    pub api_endpoint: String,
    /// API key for authentication (optional for OpenAI-compatible and Ollama)
    pub api_key: Option<String>,
    /// Model name to use for generation
    pub model: String,
    /// Maximum tokens for generation
    pub max_tokens: usize,
    /// Temperature for generation
    pub temperature: f64,
    /// Context window size
    pub context_window: usize,

    /// Whether to use semantic search instead of keyword search
    pub semantic_search_enabled: bool,
    /// Embedding provider for semantic search
    pub embedding_provider: EmbeddingProvider,
    /// Embedding model to use
    pub embedding_model: String,
    /// Embedding API endpoint (if different from LLM endpoint)
    pub embedding_endpoint: Option<String>,
    /// Similarity threshold for semantic search (0.0 to 1.0); chunks scoring
    /// below this are discarded
    pub similarity_threshold: f64,
    /// Maximum number of chunks to retrieve for semantic search
    pub max_chunks: usize,

    /// Request timeout in seconds (applied to the HTTP client)
    pub request_timeout_seconds: u64,
    /// Maximum number of retries for failed LLM requests (total attempts are
    /// `max_retries + 1`)
    pub max_retries: usize,
}
100
101impl Default for RagConfig {
102    fn default() -> Self {
103        Self {
104            provider: LlmProvider::OpenAI,
105            api_endpoint: "https://api.openai.com/v1/chat/completions".to_string(),
106            api_key: None,
107            model: "gpt-3.5-turbo".to_string(),
108            max_tokens: 1000,
109            temperature: 0.7,
110            context_window: 4000,
111            semantic_search_enabled: true,
112            embedding_provider: EmbeddingProvider::OpenAI,
113            embedding_model: "text-embedding-ada-002".to_string(),
114            embedding_endpoint: None,
115            similarity_threshold: 0.7,
116            max_chunks: 5,
117            request_timeout_seconds: 30,
118            max_retries: 3,
119        }
120    }
121}
122
/// Document chunk for RAG indexing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentChunk {
    /// Chunk ID (assigned by the engine on insertion)
    pub id: String,
    /// Content text
    pub content: String,
    /// Arbitrary metadata; string values also participate in keyword search
    pub metadata: HashMap<String, Value>,
    /// Embedding vector for semantic search; empty until embeddings are
    /// computed, and chunks with an empty embedding are skipped by search
    pub embedding: Vec<f32>,
}
135
/// Search result with similarity score, borrowing the matched chunk from the
/// engine's store.
#[derive(Debug)]
pub struct SearchResult<'a> {
    /// The document chunk
    pub chunk: &'a DocumentChunk,
    /// Cosine-similarity score (0.0 to 1.0)
    pub score: f64,
}
144
/// RAG engine for enhanced data generation.
///
/// Holds the document store, per-schema knowledge base, and a shared HTTP
/// client used for all LLM and embedding API calls.
#[derive(Debug)]
pub struct RagEngine {
    /// Configuration
    config: RagConfig,
    /// Document chunks for retrieval
    chunks: Vec<DocumentChunk>,
    /// Schema knowledge base: schema name -> human-readable fact lines
    schema_kb: HashMap<String, Vec<String>>,
    /// HTTP client for LLM API calls (configured with the request timeout)
    client: Client,
}
157
158impl RagEngine {
159    /// Create a new RAG engine
160    pub fn new(config: RagConfig) -> Self {
161        let client = ClientBuilder::new()
162            .timeout(Duration::from_secs(config.request_timeout_seconds))
163            .build()
164            .unwrap_or_else(|e| {
165                warn!("Failed to create HTTP client with timeout, using default: {}", e);
166                Client::new()
167            });
168
169        Self {
170            config,
171            chunks: Vec::new(),
172            schema_kb: HashMap::new(),
173            client,
174        }
175    }
176
177    /// Add a document to the knowledge base
178    pub fn add_document(
179        &mut self,
180        content: String,
181        metadata: HashMap<String, Value>,
182    ) -> Result<String> {
183        let id = format!("chunk_{}", self.chunks.len());
184        let chunk = DocumentChunk {
185            id: id.clone(),
186            content,
187            metadata,
188            embedding: Vec::new(), // Would compute embedding here
189        };
190
191        self.chunks.push(chunk);
192        Ok(id)
193    }
194
195    /// Add schema information to knowledge base
196    pub fn add_schema(&mut self, schema: &SchemaDefinition) -> Result<()> {
197        let mut schema_info = Vec::new();
198
199        schema_info.push(format!("Schema: {}", schema.name));
200
201        if let Some(description) = &schema.description {
202            schema_info.push(format!("Description: {}", description));
203        }
204
205        for field in &schema.fields {
206            let mut field_info = format!(
207                "Field '{}': type={}, required={}",
208                field.name, field.field_type, field.required
209            );
210
211            if let Some(description) = &field.description {
212                field_info.push_str(&format!(" - {}", description));
213            }
214
215            schema_info.push(field_info);
216        }
217
218        for (rel_name, relationship) in &schema.relationships {
219            schema_info.push(format!(
220                "Relationship '{}': {} -> {} ({:?})",
221                rel_name, schema.name, relationship.target_schema, relationship.relationship_type
222            ));
223        }
224
225        self.schema_kb.insert(schema.name.clone(), schema_info);
226        Ok(())
227    }
228
229    /// Generate data using RAG-augmented prompts
230    pub async fn generate_with_rag(
231        &self,
232        schema: &SchemaDefinition,
233        config: &DataConfig,
234    ) -> Result<Vec<Value>> {
235        if !config.rag_enabled {
236            return Err(crate::Error::generic("RAG is not enabled in config"));
237        }
238
239        // Validate RAG configuration before proceeding
240        if self.config.api_key.is_none() {
241            return Err(crate::Error::generic(
242                "RAG is enabled but no API key is configured. Please set MOCKFORGE_RAG_API_KEY or provide --rag-api-key"
243            ));
244        }
245
246        let mut results = Vec::new();
247        let mut failed_rows = 0;
248
249        // Generate prompts for each row
250        for i in 0..config.rows {
251            match self.generate_single_row_with_rag(schema, i).await {
252                Ok(data) => results.push(data),
253                Err(e) => {
254                    failed_rows += 1;
255                    warn!("Failed to generate RAG data for row {}: {}", i, e);
256
257                    // If too many rows fail, return an error
258                    if failed_rows > config.rows / 4 {
259                        // Allow up to 25% failure rate
260                        return Err(crate::Error::generic(
261                            format!("Too many RAG generation failures ({} out of {} rows failed). Check API configuration and network connectivity.", failed_rows, config.rows)
262                        ));
263                    }
264
265                    // For failed rows, generate fallback data
266                    let fallback_data = self.generate_fallback_data(schema);
267                    results.push(fallback_data);
268                }
269            }
270        }
271
272        if failed_rows > 0 {
273            warn!(
274                "RAG generation completed with {} failed rows out of {}",
275                failed_rows, config.rows
276            );
277        }
278
279        Ok(results)
280    }
281
282    /// Generate a single row using RAG
283    async fn generate_single_row_with_rag(
284        &self,
285        schema: &SchemaDefinition,
286        row_index: usize,
287    ) -> Result<Value> {
288        let prompt = self.build_generation_prompt(schema, row_index).await?;
289        let generated_data = self.call_llm(&prompt).await?;
290        self.parse_llm_response(&generated_data)
291    }
292
293    /// Generate fallback data when RAG fails
294    fn generate_fallback_data(&self, schema: &SchemaDefinition) -> Value {
295        let mut obj = serde_json::Map::new();
296
297        for field in &schema.fields {
298            let value = match field.field_type.as_str() {
299                "string" => Value::String("sample_data".to_string()),
300                "integer" | "number" => Value::Number(42.into()),
301                "boolean" => Value::Bool(true),
302                _ => Value::String("sample_data".to_string()),
303            };
304            obj.insert(field.name.clone(), value);
305        }
306
307        Value::Object(obj)
308    }
309
310    /// Build a generation prompt with retrieved context
311    async fn build_generation_prompt(
312        &self,
313        schema: &SchemaDefinition,
314        _row_index: usize,
315    ) -> Result<String> {
316        let mut prompt =
317            format!("Generate a single row of data for the '{}' schema.\n\n", schema.name);
318
319        // Add schema information
320        if let Some(schema_info) = self.schema_kb.get(&schema.name) {
321            prompt.push_str("Schema Information:\n");
322            for info in schema_info {
323                prompt.push_str(&format!("- {}\n", info));
324            }
325            prompt.push('\n');
326        }
327
328        // Retrieve relevant context from documents
329        let relevant_chunks = self.retrieve_relevant_chunks(&schema.name, 3).await?;
330        if !relevant_chunks.is_empty() {
331            prompt.push_str("Relevant Context:\n");
332            for chunk in relevant_chunks {
333                prompt.push_str(&format!("- {}\n", chunk.content));
334            }
335            prompt.push('\n');
336        }
337
338        // Add generation instructions
339        prompt.push_str("Instructions:\n");
340        prompt.push_str("- Generate realistic data that matches the schema\n");
341        prompt.push_str("- Ensure all required fields are present\n");
342        prompt.push_str("- Use appropriate data types and formats\n");
343        prompt.push_str("- Make relationships consistent if referenced\n");
344        prompt.push_str("- Output only valid JSON for a single object\n\n");
345
346        prompt.push_str("Generate the data:");
347
348        Ok(prompt)
349    }
350
351    /// Retrieve relevant document chunks using semantic search or keyword search
352    async fn retrieve_relevant_chunks(
353        &self,
354        query: &str,
355        limit: usize,
356    ) -> Result<Vec<&DocumentChunk>> {
357        if self.config.semantic_search_enabled {
358            // Use semantic search
359            let results = self.semantic_search(query, limit).await?;
360            Ok(results.into_iter().map(|r| r.chunk).collect())
361        } else {
362            // Fall back to keyword search
363            Ok(self.keyword_search(query, limit))
364        }
365    }
366
367    /// Perform keyword-based search (fallback)
368    pub fn keyword_search(&self, query: &str, limit: usize) -> Vec<&DocumentChunk> {
369        self.chunks
370            .iter()
371            .filter(|chunk| {
372                chunk.content.to_lowercase().contains(&query.to_lowercase())
373                    || chunk.metadata.values().any(|v| {
374                        if let Some(s) = v.as_str() {
375                            s.to_lowercase().contains(&query.to_lowercase())
376                        } else {
377                            false
378                        }
379                    })
380            })
381            .take(limit)
382            .collect()
383    }
384
385    /// Perform semantic search using embeddings
386    async fn semantic_search(&self, query: &str, limit: usize) -> Result<Vec<SearchResult<'_>>> {
387        // Generate embedding for the query
388        let query_embedding = self.generate_embedding(query).await?;
389
390        // Calculate similarity scores for all chunks
391        let mut results: Vec<SearchResult> = Vec::new();
392
393        for chunk in &self.chunks {
394            if chunk.embedding.is_empty() {
395                // Skip chunks without embeddings
396                continue;
397            }
398
399            let score = Self::cosine_similarity(&query_embedding, &chunk.embedding);
400            if score >= self.config.similarity_threshold {
401                results.push(SearchResult { chunk, score });
402            }
403        }
404
405        // Sort by similarity score (descending) and take top results
406        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(Ordering::Equal));
407        results.truncate(limit);
408
409        Ok(results)
410    }
411
412    /// Generate embedding for text using the configured embedding provider
413    async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>> {
414        match &self.config.embedding_provider {
415            EmbeddingProvider::OpenAI => self.generate_openai_embedding(text).await,
416            EmbeddingProvider::OpenAICompatible => {
417                self.generate_openai_compatible_embedding(text).await
418            }
419            EmbeddingProvider::Ollama => self.generate_ollama_embedding(text).await,
420        }
421    }
422
423    /// Generate embedding using OpenAI API
424    async fn generate_openai_embedding(&self, text: &str) -> Result<Vec<f32>> {
425        let api_key = self
426            .config
427            .api_key
428            .as_ref()
429            .ok_or_else(|| crate::Error::generic("OpenAI API key not configured"))?;
430
431        let endpoint = self
432            .config
433            .embedding_endpoint
434            .as_ref()
435            .unwrap_or(&self.config.api_endpoint)
436            .replace("chat/completions", "embeddings");
437
438        let request_body = serde_json::json!({
439            "model": self.config.embedding_model,
440            "input": text
441        });
442
443        debug!("Generating embedding for text with OpenAI API");
444
445        let response = self
446            .client
447            .post(&endpoint)
448            .header("Authorization", format!("Bearer {}", api_key))
449            .header("Content-Type", "application/json")
450            .json(&request_body)
451            .send()
452            .await
453            .map_err(|e| crate::Error::generic(format!("Embedding API request failed: {}", e)))?;
454
455        if !response.status().is_success() {
456            let error_text = response.text().await.unwrap_or_default();
457            return Err(crate::Error::generic(format!("Embedding API error: {}", error_text)));
458        }
459
460        let response_json: Value = response.json().await.map_err(|e| {
461            crate::Error::generic(format!("Failed to parse embedding response: {}", e))
462        })?;
463
464        if let Some(data) = response_json.get("data").and_then(|d| d.as_array()) {
465            if let Some(first_item) = data.first() {
466                if let Some(embedding) = first_item.get("embedding").and_then(|e| e.as_array()) {
467                    let embedding_vec: Vec<f32> =
468                        embedding.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
469                    return Ok(embedding_vec);
470                }
471            }
472        }
473
474        Err(crate::Error::generic("Invalid embedding response format"))
475    }
476
477    /// Generate embedding using OpenAI-compatible API
478    async fn generate_openai_compatible_embedding(&self, text: &str) -> Result<Vec<f32>> {
479        let endpoint = self
480            .config
481            .embedding_endpoint
482            .as_ref()
483            .unwrap_or(&self.config.api_endpoint)
484            .replace("chat/completions", "embeddings");
485
486        let request_body = serde_json::json!({
487            "model": self.config.embedding_model,
488            "input": text
489        });
490
491        debug!("Generating embedding for text with OpenAI-compatible API");
492
493        let mut request = self
494            .client
495            .post(&endpoint)
496            .header("Content-Type", "application/json")
497            .json(&request_body);
498
499        if let Some(api_key) = &self.config.api_key {
500            request = request.header("Authorization", format!("Bearer {}", api_key));
501        }
502
503        let response = request
504            .send()
505            .await
506            .map_err(|e| crate::Error::generic(format!("Embedding API request failed: {}", e)))?;
507
508        if !response.status().is_success() {
509            let error_text = response.text().await.unwrap_or_default();
510            return Err(crate::Error::generic(format!("Embedding API error: {}", error_text)));
511        }
512
513        let response_json: Value = response.json().await.map_err(|e| {
514            crate::Error::generic(format!("Failed to parse embedding response: {}", e))
515        })?;
516
517        if let Some(data) = response_json.get("data").and_then(|d| d.as_array()) {
518            if let Some(first_item) = data.first() {
519                if let Some(embedding) = first_item.get("embedding").and_then(|e| e.as_array()) {
520                    let embedding_vec: Vec<f32> =
521                        embedding.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
522                    return Ok(embedding_vec);
523                }
524            }
525        }
526
527        Err(crate::Error::generic("Invalid embedding response format"))
528    }
529
530    /// Generate embedding using Ollama API
531    ///
532    /// Ollama exposes embeddings via `POST /api/embeddings` with `{ model, prompt }`.
533    async fn generate_ollama_embedding(&self, text: &str) -> Result<Vec<f32>> {
534        let base_url = self.config.embedding_endpoint.as_ref().unwrap_or(&self.config.api_endpoint);
535
536        // Ollama embedding endpoint
537        let endpoint = if base_url.ends_with("/api/embeddings") {
538            base_url.clone()
539        } else {
540            format!("{}/api/embeddings", base_url.trim_end_matches('/'))
541        };
542
543        let model = &self.config.embedding_model;
544        let request_body = serde_json::json!({
545            "model": model,
546            "prompt": text
547        });
548
549        debug!("Generating embedding for text with Ollama (model: {})", model);
550
551        let response = self
552            .client
553            .post(&endpoint)
554            .header("Content-Type", "application/json")
555            .json(&request_body)
556            .send()
557            .await
558            .map_err(|e| {
559                crate::Error::generic(format!("Ollama embedding request failed: {}", e))
560            })?;
561
562        if !response.status().is_success() {
563            let error_text = response.text().await.unwrap_or_default();
564            return Err(crate::Error::generic(format!("Ollama embedding error: {}", error_text)));
565        }
566
567        let response_json: Value = response.json().await.map_err(|e| {
568            crate::Error::generic(format!("Failed to parse Ollama embedding response: {}", e))
569        })?;
570
571        // Ollama returns { "embedding": [...] }
572        if let Some(embedding) = response_json.get("embedding").and_then(|e| e.as_array()) {
573            let embedding_vec: Vec<f32> =
574                embedding.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
575            return Ok(embedding_vec);
576        }
577
578        Err(crate::Error::generic("Invalid Ollama embedding response format"))
579    }
580
581    /// Compute embeddings for all document chunks
582    pub async fn compute_embeddings(&mut self) -> Result<()> {
583        debug!("Computing embeddings for {} chunks", self.chunks.len());
584
585        // Collect chunks that need embeddings
586        let chunks_to_embed: Vec<(usize, String)> = self
587            .chunks
588            .iter()
589            .enumerate()
590            .filter(|(_, chunk)| chunk.embedding.is_empty())
591            .map(|(idx, chunk)| (idx, chunk.content.clone()))
592            .collect();
593
594        // Generate embeddings for chunks that need them
595        for (idx, content) in chunks_to_embed {
596            let embedding = self.generate_embedding(&content).await?;
597            self.chunks[idx].embedding = embedding;
598            debug!("Computed embedding for chunk {}", self.chunks[idx].id);
599        }
600
601        Ok(())
602    }
603
604    /// Calculate cosine similarity between two vectors
605    fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
606        if a.len() != b.len() || a.is_empty() {
607            return 0.0;
608        }
609
610        let mut dot_product = 0.0;
611        let mut norm_a = 0.0;
612        let mut norm_b = 0.0;
613
614        for i in 0..a.len() {
615            dot_product += a[i] as f64 * b[i] as f64;
616            norm_a += (a[i] as f64).powi(2);
617            norm_b += (b[i] as f64).powi(2);
618        }
619
620        norm_a = norm_a.sqrt();
621        norm_b = norm_b.sqrt();
622
623        if norm_a == 0.0 || norm_b == 0.0 {
624            0.0
625        } else {
626            dot_product / (norm_a * norm_b)
627        }
628    }
629
630    /// Call LLM API with provider-specific implementation and retry logic
631    async fn call_llm(&self, prompt: &str) -> Result<String> {
632        let mut last_error = None;
633
634        for attempt in 0..=self.config.max_retries {
635            match self.call_llm_single_attempt(prompt).await {
636                Ok(result) => return Ok(result),
637                Err(e) => {
638                    last_error = Some(e);
639                    if attempt < self.config.max_retries {
640                        let delay = Duration::from_millis(500 * (attempt + 1) as u64);
641                        warn!(
642                            "LLM API call failed (attempt {}), retrying in {:?}: {:?}",
643                            attempt + 1,
644                            delay,
645                            last_error
646                        );
647                        sleep(delay).await;
648                    }
649                }
650            }
651        }
652
653        Err(last_error
654            .unwrap_or_else(|| crate::Error::generic("All LLM API retry attempts failed")))
655    }
656
657    /// Single attempt to call LLM API with provider-specific implementation
658    async fn call_llm_single_attempt(&self, prompt: &str) -> Result<String> {
659        match &self.config.provider {
660            LlmProvider::OpenAI => self.call_openai(prompt).await,
661            LlmProvider::Anthropic => self.call_anthropic(prompt).await,
662            LlmProvider::OpenAICompatible => self.call_openai_compatible(prompt).await,
663            LlmProvider::Ollama => self.call_ollama(prompt).await,
664        }
665    }
666
667    /// Call OpenAI API
668    async fn call_openai(&self, prompt: &str) -> Result<String> {
669        let api_key = self
670            .config
671            .api_key
672            .as_ref()
673            .ok_or_else(|| crate::Error::generic("OpenAI API key not configured"))?;
674
675        let request_body = serde_json::json!({
676            "model": self.config.model,
677            "messages": [
678                {
679                    "role": "user",
680                    "content": prompt
681                }
682            ],
683            "max_tokens": self.config.max_tokens,
684            "temperature": self.config.temperature
685        });
686
687        debug!("Calling OpenAI API with model: {}", self.config.model);
688
689        let response = self
690            .client
691            .post(&self.config.api_endpoint)
692            .header("Authorization", format!("Bearer {}", api_key))
693            .header("Content-Type", "application/json")
694            .json(&request_body)
695            .send()
696            .await
697            .map_err(|e| crate::Error::generic(format!("OpenAI API request failed: {}", e)))?;
698
699        if !response.status().is_success() {
700            let error_text = response.text().await.unwrap_or_default();
701            return Err(crate::Error::generic(format!("OpenAI API error: {}", error_text)));
702        }
703
704        let response_json: Value = response.json().await.map_err(|e| {
705            crate::Error::generic(format!("Failed to parse OpenAI response: {}", e))
706        })?;
707
708        if let Some(choices) = response_json.get("choices").and_then(|c| c.as_array()) {
709            if let Some(choice) = choices.first() {
710                if let Some(message) = choice.get("message").and_then(|m| m.get("content")) {
711                    if let Some(content) = message.as_str() {
712                        return Ok(content.to_string());
713                    }
714                }
715            }
716        }
717
718        Err(crate::Error::generic("Invalid OpenAI response format"))
719    }
720
721    /// Call Anthropic API
722    async fn call_anthropic(&self, prompt: &str) -> Result<String> {
723        let api_key = self
724            .config
725            .api_key
726            .as_ref()
727            .ok_or_else(|| crate::Error::generic("Anthropic API key not configured"))?;
728
729        let request_body = serde_json::json!({
730            "model": self.config.model,
731            "max_tokens": self.config.max_tokens,
732            "temperature": self.config.temperature,
733            "messages": [
734                {
735                    "role": "user",
736                    "content": prompt
737                }
738            ]
739        });
740
741        debug!("Calling Anthropic API with model: {}", self.config.model);
742
743        let response = self
744            .client
745            .post(&self.config.api_endpoint)
746            .header("x-api-key", api_key)
747            .header("Content-Type", "application/json")
748            .header("anthropic-version", "2023-06-01")
749            .json(&request_body)
750            .send()
751            .await
752            .map_err(|e| crate::Error::generic(format!("Anthropic API request failed: {}", e)))?;
753
754        if !response.status().is_success() {
755            let error_text = response.text().await.unwrap_or_default();
756            return Err(crate::Error::generic(format!("Anthropic API error: {}", error_text)));
757        }
758
759        let response_json: Value = response.json().await.map_err(|e| {
760            crate::Error::generic(format!("Failed to parse Anthropic response: {}", e))
761        })?;
762
763        if let Some(content) = response_json.get("content") {
764            if let Some(content_array) = content.as_array() {
765                if let Some(first_content) = content_array.first() {
766                    if let Some(text) = first_content.get("text").and_then(|t| t.as_str()) {
767                        return Ok(text.to_string());
768                    }
769                }
770            }
771        }
772
773        Err(crate::Error::generic("Invalid Anthropic response format"))
774    }
775
776    /// Call OpenAI-compatible API
777    async fn call_openai_compatible(&self, prompt: &str) -> Result<String> {
778        let request_body = serde_json::json!({
779            "model": self.config.model,
780            "messages": [
781                {
782                    "role": "user",
783                    "content": prompt
784                }
785            ],
786            "max_tokens": self.config.max_tokens,
787            "temperature": self.config.temperature
788        });
789
790        debug!("Calling OpenAI-compatible API with model: {}", self.config.model);
791
792        let mut request = self
793            .client
794            .post(&self.config.api_endpoint)
795            .header("Content-Type", "application/json")
796            .json(&request_body);
797
798        if let Some(api_key) = &self.config.api_key {
799            request = request.header("Authorization", format!("Bearer {}", api_key));
800        }
801
802        let response = request.send().await.map_err(|e| {
803            crate::Error::generic(format!("OpenAI-compatible API request failed: {}", e))
804        })?;
805
806        if !response.status().is_success() {
807            let error_text = response.text().await.unwrap_or_default();
808            return Err(crate::Error::generic(format!(
809                "OpenAI-compatible API error: {}",
810                error_text
811            )));
812        }
813
814        let response_json: Value = response.json().await.map_err(|e| {
815            crate::Error::generic(format!("Failed to parse OpenAI-compatible response: {}", e))
816        })?;
817
818        if let Some(choices) = response_json.get("choices").and_then(|c| c.as_array()) {
819            if let Some(choice) = choices.first() {
820                if let Some(message) = choice.get("message").and_then(|m| m.get("content")) {
821                    if let Some(content) = message.as_str() {
822                        return Ok(content.to_string());
823                    }
824                }
825            }
826        }
827
828        Err(crate::Error::generic("Invalid OpenAI-compatible response format"))
829    }
830
831    /// Call Ollama API
832    async fn call_ollama(&self, prompt: &str) -> Result<String> {
833        let request_body = serde_json::json!({
834            "model": self.config.model,
835            "prompt": prompt,
836            "stream": false
837        });
838
839        debug!("Calling Ollama API with model: {}", self.config.model);
840
841        let response = self
842            .client
843            .post(&self.config.api_endpoint)
844            .header("Content-Type", "application/json")
845            .json(&request_body)
846            .send()
847            .await
848            .map_err(|e| crate::Error::generic(format!("Ollama API request failed: {}", e)))?;
849
850        if !response.status().is_success() {
851            let error_text = response.text().await.unwrap_or_default();
852            return Err(crate::Error::generic(format!("Ollama API error: {}", error_text)));
853        }
854
855        let response_json: Value = response.json().await.map_err(|e| {
856            crate::Error::generic(format!("Failed to parse Ollama response: {}", e))
857        })?;
858
859        if let Some(response_text) = response_json.get("response").and_then(|r| r.as_str()) {
860            return Ok(response_text.to_string());
861        }
862
863        Err(crate::Error::generic("Invalid Ollama response format"))
864    }
865
866    /// Parse LLM response into structured data
867    fn parse_llm_response(&self, response: &str) -> Result<Value> {
868        // Try to parse as JSON
869        match serde_json::from_str(response) {
870            Ok(value) => Ok(value),
871            Err(e) => {
872                // If direct parsing fails, try to extract JSON from the response
873                if let Some(start) = response.find('{') {
874                    if let Some(end) = response.rfind('}') {
875                        let json_str = &response[start..=end];
876                        match serde_json::from_str(json_str) {
877                            Ok(value) => Ok(value),
878                            Err(_) => Err(crate::Error::generic(format!(
879                                "Failed to parse LLM response: {}",
880                                e
881                            ))),
882                        }
883                    } else {
884                        Err(crate::Error::generic(format!(
885                            "No closing brace found in response: {}",
886                            e
887                        )))
888                    }
889                } else {
890                    Err(crate::Error::generic(format!("No JSON found in response: {}", e)))
891                }
892            }
893        }
894    }
895
    /// Update RAG configuration
    ///
    /// Replaces the engine's configuration wholesale; subsequent LLM and
    /// embedding calls use the new settings. Already-indexed chunks and
    /// schemas are left untouched.
    pub fn update_config(&mut self, config: RagConfig) {
        self.config = config;
    }
900
    /// Get current configuration
    ///
    /// Returns a shared borrow; use [`update_config`](Self::update_config)
    /// to change settings.
    pub fn config(&self) -> &RagConfig {
        &self.config
    }
905
    /// Get number of indexed chunks
    ///
    /// Counts the document chunks currently held in memory.
    pub fn chunk_count(&self) -> usize {
        self.chunks.len()
    }
910
    /// Get number of indexed schemas
    ///
    /// Counts the entries in the schema knowledge base.
    pub fn schema_count(&self) -> usize {
        self.schema_kb.len()
    }
915
    /// Get chunk by index
    ///
    /// Returns `None` when `index` is past the end of the indexed chunks.
    pub fn get_chunk(&self, index: usize) -> Option<&DocumentChunk> {
        self.chunks.get(index)
    }
920
    /// Check if schema exists in knowledge base
    ///
    /// Lookup is by exact name; returns `false` for unknown schemas.
    pub fn has_schema(&self, name: &str) -> bool {
        self.schema_kb.contains_key(name)
    }
925
    /// Generate text using LLM (for intelligent mock generation)
    ///
    /// Thin public wrapper over the internal `call_llm` helper so callers
    /// outside this module can run free-form prompts through the
    /// configured provider.
    pub async fn generate_text(&self, prompt: &str) -> Result<String> {
        self.call_llm(prompt).await
    }
930}
931
impl Default for RagEngine {
    /// Build an engine from the default [`RagConfig`].
    fn default() -> Self {
        Self::new(RagConfig::default())
    }
}
937
938/// RAG-enhanced data generation utilities
939pub mod rag_utils {
940    use super::*;
941
942    /// Create a RAG engine with common business domain knowledge
943    pub fn create_business_rag_engine() -> Result<RagEngine> {
944        let mut engine = RagEngine::default();
945
946        // Add common business knowledge
947        engine.add_document(
948            "Customer data typically includes personal information like name, email, phone, and address. Customers usually have unique identifiers and account creation dates.".to_string(),
949            HashMap::from([
950                ("domain".to_string(), Value::String("customer".to_string())),
951                ("type".to_string(), Value::String("general".to_string())),
952            ]),
953        )?;
954
955        engine.add_document(
956            "Product information includes name, description, price, category, and stock status. Products should have unique SKUs or IDs.".to_string(),
957            HashMap::from([
958                ("domain".to_string(), Value::String("product".to_string())),
959                ("type".to_string(), Value::String("general".to_string())),
960            ]),
961        )?;
962
963        engine.add_document(
964            "Order data contains customer references, product lists, total amounts, status, and timestamps. Orders should maintain referential integrity with customers and products.".to_string(),
965            HashMap::from([
966                ("domain".to_string(), Value::String("order".to_string())),
967                ("type".to_string(), Value::String("general".to_string())),
968            ]),
969        )?;
970
971        Ok(engine)
972    }
973
974    /// Create a RAG engine with technical domain knowledge
975    pub fn create_technical_rag_engine() -> Result<RagEngine> {
976        let mut engine = RagEngine::default();
977
978        // Add technical knowledge
979        engine.add_document(
980            "API endpoints should follow RESTful conventions with proper HTTP methods. GET for retrieval, POST for creation, PUT for updates, DELETE for removal.".to_string(),
981            HashMap::from([
982                ("domain".to_string(), Value::String("api".to_string())),
983                ("type".to_string(), Value::String("technical".to_string())),
984            ]),
985        )?;
986
987        engine.add_document(
988            "Database records typically have auto-incrementing primary keys, created_at and updated_at timestamps, and foreign key relationships.".to_string(),
989            HashMap::from([
990                ("domain".to_string(), Value::String("database".to_string())),
991                ("type".to_string(), Value::String("technical".to_string())),
992            ]),
993        )?;
994
995        Ok(engine)
996    }
997}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every LLM provider variant can be constructed and matched.
    #[test]
    fn test_llm_provider_variants() {
        let openai = LlmProvider::OpenAI;
        let anthropic = LlmProvider::Anthropic;
        let compatible = LlmProvider::OpenAICompatible;
        let ollama = LlmProvider::Ollama;

        assert!(matches!(openai, LlmProvider::OpenAI));
        assert!(matches!(anthropic, LlmProvider::Anthropic));
        assert!(matches!(compatible, LlmProvider::OpenAICompatible));
        assert!(matches!(ollama, LlmProvider::Ollama));
    }

    /// Every embedding provider variant can be constructed and matched.
    /// (Previously the `Ollama` variant was left uncovered.)
    #[test]
    fn test_embedding_provider_variants() {
        let openai = EmbeddingProvider::OpenAI;
        let compatible = EmbeddingProvider::OpenAICompatible;
        let ollama = EmbeddingProvider::Ollama;

        assert!(matches!(openai, EmbeddingProvider::OpenAI));
        assert!(matches!(compatible, EmbeddingProvider::OpenAICompatible));
        assert!(matches!(ollama, EmbeddingProvider::Ollama));
    }

    /// Default configuration values fall inside sane ranges.
    #[test]
    fn test_rag_config_default() {
        let config = RagConfig::default();

        assert!(config.max_tokens > 0);
        assert!(config.temperature >= 0.0 && config.temperature <= 1.0);
        assert!(config.context_window > 0);
    }
}