mockforge_grpc/reflection/
rag_synthesis.rs

1//! RAG-driven domain-aware data synthesis
2//!
3//! This module integrates with the MockForge RAG system to generate contextually
4//! appropriate synthetic data based on schema documentation, API specifications,
5//! and domain knowledge.
6
7use crate::reflection::schema_graph::SchemaGraph;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use tracing::{debug, info, warn};
11
12#[cfg(feature = "data-faker")]
13use mockforge_data::rag::{RagConfig, RagEngine};
14
/// Configuration for RAG-driven data synthesis.
///
/// Consumed by `RagDataSynthesizer::new`. The `Default` implementation in this
/// file yields a disabled configuration containing a single wildcard prompt
/// template registered under the key `"default"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagSynthesisConfig {
    /// Enable RAG-driven synthesis. When `false`, the synthesizer builds no
    /// RAG engine and leaves the domain context of generated entities empty.
    pub enabled: bool,
    /// RAG engine configuration. An engine is only created when this is
    /// `Some` and `enabled` is `true`.
    pub rag_config: Option<RagSynthesisRagConfig>,
    /// Domain context sources.
    // NOTE(review): not read by the synthesizer code visible in this file —
    // confirm where (or whether) these sources are loaded.
    pub context_sources: Vec<ContextSource>,
    /// Prompt templates for different entity types, keyed by template name.
    pub prompt_templates: HashMap<String, PromptTemplate>,
    /// Maximum context length for RAG queries.
    // NOTE(review): not enforced by the synthesizer code visible in this file;
    // the unit (bytes, chars, tokens) is not established here either.
    pub max_context_length: usize,
    /// Cache generated contexts for performance. When `true`, each generated
    /// entity context is stored and reused on later requests for that entity.
    pub cache_contexts: bool,
}
31
/// RAG configuration specific to synthesis (wrapper around mockforge_data::RagConfig).
///
/// Only the fields below are configurable; `initialize_rag_engine` copies them
/// into the engine configuration and hard-codes every other engine setting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagSynthesisRagConfig {
    /// API endpoint passed through to the engine's `api_endpoint`.
    pub api_endpoint: String,
    /// API key for authentication (optional).
    pub api_key: Option<String>,
    /// Model name passed through to the engine's `model`.
    pub model: String,
    /// Embedding model name passed through to the engine's `embedding_model`.
    pub embedding_model: String,
    /// Search similarity threshold.
    pub similarity_threshold: f64,
    /// Maximum documents to retrieve. Used both as the engine's `max_chunks`
    /// and as the result limit of the per-entity keyword search.
    pub max_documents: usize,
}
48
/// Source of domain context for RAG synthesis.
///
/// A plain, serializable description of one source; nothing in the code
/// visible in this file opens or fetches `path`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextSource {
    /// Source identifier
    pub id: String,
    /// Source type (documentation, examples, etc.)
    pub source_type: ContextSourceType,
    /// Path or URL to the source
    pub path: String,
    /// Weight for this source in context generation.
    // NOTE(review): scale and valid range are not established in this file.
    pub weight: f32,
    /// Whether this source is required for synthesis
    pub required: bool,
}
63
/// Types of context sources.
///
/// Serialized with snake_case variant names (for example `business_rules`,
/// `knowledge_base`) because of the `rename_all` attribute below.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContextSourceType {
    /// API documentation (OpenAPI, proto comments)
    Documentation,
    /// Example data files (JSON, YAML)
    Examples,
    /// Business rules and constraints
    BusinessRules,
    /// Domain glossary/terminology
    Glossary,
    /// External knowledge base
    KnowledgeBase,
}
79
/// Template for generating RAG prompts.
///
/// The synthesizer picks a template by entity name and fills its placeholders
/// before querying the RAG engine for a field value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTemplate {
    /// Template name/identifier
    pub name: String,
    /// Entity types this template applies to. An entry of `"*"` makes the
    /// template applicable to every entity.
    pub entity_types: Vec<String>,
    /// Template string with placeholders. The synthesizer substitutes
    /// `{entity_name}`, `{field_name}`, `{field_type}` and `{domain_context}`.
    pub template: String,
    /// Variables that can be substituted in the template.
    // NOTE(review): informational only as far as this file shows — the
    // substitution code does not consult this list.
    pub variables: Vec<String>,
    /// Examples of expected outputs
    pub examples: Vec<PromptExample>,
}
94
/// Example for prompt template: one input/output pair illustrating what a
/// filled-in template is expected to produce.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptExample {
    /// Input context, as variable-name → value pairs
    pub input: HashMap<String, String>,
    /// Expected output
    pub output: String,
    /// Description of this example
    pub description: String,
}
105
/// Context extracted for an entity using RAG.
///
/// Produced by `RagDataSynthesizer::generate_entity_context` and, when caching
/// is enabled, stored per entity name for reuse.
#[derive(Debug, Clone)]
pub struct EntityContext {
    /// Entity name
    pub entity_name: String,
    /// Domain context from RAG; empty when synthesis is disabled.
    pub domain_context: String,
    /// Related entities and their contexts, keyed by related entity name.
    /// Only populated when a schema graph has been set.
    pub related_contexts: HashMap<String, String>,
    /// Business rules applicable to this entity, derived from `domain_context`.
    pub business_rules: Vec<BusinessRule>,
    /// Example values from documentation, keyed by field name.
    pub example_values: HashMap<String, Vec<String>>,
}
120
/// A business rule extracted from context.
#[derive(Debug, Clone)]
pub struct BusinessRule {
    /// Rule description
    pub description: String,
    /// Fields this rule applies to (matched by exact field name).
    pub applies_to_fields: Vec<String>,
    /// Rule type (constraint, format, relationship, etc.)
    pub rule_type: BusinessRuleType,
    /// Rule parameters/configuration. The synthesizer reads `"format"` for
    /// `Format` rules and `"min"` / `"max"` for `Range` rules.
    pub parameters: HashMap<String, String>,
}
133
/// Types of business rules.
///
/// Only `Format` and `Range` currently produce values in
/// `RagDataSynthesizer::apply_business_rule`; the other variants are logged
/// as unhandled there.
#[derive(Debug, Clone)]
pub enum BusinessRuleType {
    /// Format constraint (email format, phone format, etc.)
    Format,
    /// Value range constraint
    Range,
    /// Relationship constraint (foreign key rules)
    Relationship,
    /// Business logic constraint
    BusinessLogic,
    /// Validation rule
    Validation,
}
148
/// RAG-driven data synthesis engine.
///
/// Holds the synthesis configuration, an optional RAG engine (only compiled in
/// with the `data-faker` feature), a per-entity context cache and an optional
/// schema graph used to look up related entities.
pub struct RagDataSynthesizer {
    /// Configuration
    config: RagSynthesisConfig,
    /// RAG engine instance; `None` when synthesis is disabled, no RAG
    /// configuration was supplied, or engine initialization failed.
    #[cfg(feature = "data-faker")]
    rag_engine: Option<RagEngine>,
    /// Cached entity contexts, keyed by entity name
    entity_contexts: HashMap<String, EntityContext>,
    /// Schema graph for relationship understanding
    schema_graph: Option<SchemaGraph>,
}
161
162impl RagDataSynthesizer {
163    /// Create a new RAG data synthesizer
164    pub fn new(config: RagSynthesisConfig) -> Self {
165        #[cfg(feature = "data-faker")]
166        let rag_engine = if config.enabled && config.rag_config.is_some() {
167            let rag_config = config.rag_config.as_ref().unwrap();
168            match Self::initialize_rag_engine(rag_config) {
169                Ok(engine) => Some(engine),
170                Err(e) => {
171                    warn!("Failed to initialize RAG engine: {}", e);
172                    None
173                }
174            }
175        } else {
176            None
177        };
178
179        Self {
180            config,
181            #[cfg(feature = "data-faker")]
182            rag_engine,
183            entity_contexts: HashMap::new(),
184            schema_graph: None,
185        }
186    }
187
188    /// Set the schema graph for relationship-aware synthesis
189    pub fn set_schema_graph(&mut self, schema_graph: SchemaGraph) {
190        let entity_count = schema_graph.entities.len();
191        self.schema_graph = Some(schema_graph);
192        info!("Schema graph set with {} entities", entity_count);
193    }
194
195    /// Generate domain context for an entity using RAG
196    pub async fn generate_entity_context(
197        &mut self,
198        entity_name: &str,
199    ) -> Result<EntityContext, Box<dyn std::error::Error + Send + Sync>> {
200        // Check cache first
201        if let Some(cached_context) = self.entity_contexts.get(entity_name) {
202            return Ok(cached_context.clone());
203        }
204
205        info!("Generating RAG context for entity: {}", entity_name);
206
207        let mut context = EntityContext {
208            entity_name: entity_name.to_string(),
209            domain_context: String::new(),
210            related_contexts: HashMap::new(),
211            business_rules: Vec::new(),
212            example_values: HashMap::new(),
213        };
214
215        // Generate base context using RAG
216        if self.config.enabled {
217            context.domain_context = self.query_rag_for_entity(entity_name).await?;
218        }
219
220        // Extract business rules from context
221        context.business_rules =
222            self.extract_business_rules(&context.domain_context, entity_name)?;
223
224        // Find example values from context sources
225        context.example_values =
226            self.extract_example_values(&context.domain_context, entity_name)?;
227
228        // Generate related entity contexts if schema graph is available
229        if let Some(schema_graph) = &self.schema_graph {
230            context.related_contexts =
231                self.generate_related_contexts(entity_name, schema_graph).await?;
232        }
233
234        // Cache the context
235        if self.config.cache_contexts {
236            self.entity_contexts.insert(entity_name.to_string(), context.clone());
237        }
238
239        Ok(context)
240    }
241
242    /// Generate contextually appropriate data for an entity field
243    pub async fn synthesize_field_data(
244        &mut self,
245        entity_name: &str,
246        field_name: &str,
247        field_type: &str,
248    ) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
249        let context = self.generate_entity_context(entity_name).await?;
250
251        // Check for example values first
252        if let Some(examples) = context.example_values.get(field_name) {
253            if !examples.is_empty() {
254                // Use a deterministic example selection based on field name hash for stability
255                let field_hash = self.hash_field_name(field_name);
256                let index = field_hash as usize % examples.len();
257                return Ok(Some(examples[index].clone()));
258            }
259        }
260
261        // Apply business rules
262        for rule in &context.business_rules {
263            if rule.applies_to_fields.contains(&field_name.to_string()) {
264                if let Some(value) = self.apply_business_rule(rule, field_name, field_type)? {
265                    return Ok(Some(value));
266                }
267            }
268        }
269
270        // Use RAG to generate contextually appropriate value
271        if self.config.enabled && !context.domain_context.is_empty() {
272            let rag_value =
273                self.generate_contextual_value(&context, field_name, field_type).await?;
274            if !rag_value.is_empty() {
275                return Ok(Some(rag_value));
276            }
277        }
278
279        Ok(None)
280    }
281
282    /// Initialize RAG engine from configuration
283    #[cfg(feature = "data-faker")]
284    fn initialize_rag_engine(
285        config: &RagSynthesisRagConfig,
286    ) -> Result<RagEngine, Box<dyn std::error::Error + Send + Sync>> {
287        let rag_config = RagConfig {
288            provider: mockforge_data::rag::LlmProvider::OpenAI,
289            api_endpoint: config.api_endpoint.clone(),
290            api_key: config.api_key.clone(),
291            model: config.model.clone(),
292            max_tokens: 1000,
293            temperature: 0.7,
294            context_window: 4000,
295            semantic_search_enabled: true,
296            embedding_provider: mockforge_data::rag::EmbeddingProvider::OpenAI,
297            embedding_model: config.embedding_model.clone(),
298            embedding_endpoint: None,
299            similarity_threshold: config.similarity_threshold,
300            max_chunks: config.max_documents,
301            request_timeout_seconds: 30,
302            max_retries: 3,
303        };
304
305        Ok(RagEngine::new(rag_config))
306    }
307
308    /// Query RAG system for entity-specific context
309    async fn query_rag_for_entity(
310        &self,
311        entity_name: &str,
312    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
313        #[cfg(feature = "data-faker")]
314        if let Some(rag_engine) = &self.rag_engine {
315            let query = format!("What is {} in this domain? What are typical values and constraints for {} entities?", entity_name, entity_name);
316
317            let chunks = rag_engine
318                .keyword_search(&query, self.config.rag_config.as_ref().unwrap().max_documents);
319            if !chunks.is_empty() {
320                let context = chunks
321                    .into_iter()
322                    .map(|chunk| &chunk.content)
323                    .cloned()
324                    .collect::<Vec<_>>()
325                    .join("\n\n");
326                return Ok(context);
327            } else {
328                warn!("No RAG results found for entity {}", entity_name);
329            }
330        }
331
332        // Fallback to basic context
333        Ok(format!("Entity: {} - A data entity in the system", entity_name))
334    }
335
336    /// Extract business rules from context text
337    fn extract_business_rules(
338        &self,
339        context: &str,
340        entity_name: &str,
341    ) -> Result<Vec<BusinessRule>, Box<dyn std::error::Error + Send + Sync>> {
342        let mut rules = Vec::new();
343
344        // Simple rule extraction - can be enhanced with NLP
345        if context.to_lowercase().contains("email") && context.to_lowercase().contains("format") {
346            rules.push(BusinessRule {
347                description: "Email fields must follow email format".to_string(),
348                applies_to_fields: vec!["email".to_string(), "email_address".to_string()],
349                rule_type: BusinessRuleType::Format,
350                parameters: {
351                    let mut params = HashMap::new();
352                    params.insert("format".to_string(), "email".to_string());
353                    params
354                },
355            });
356        }
357
358        if context.to_lowercase().contains("phone") && context.to_lowercase().contains("number") {
359            rules.push(BusinessRule {
360                description: "Phone fields must follow phone number format".to_string(),
361                applies_to_fields: vec![
362                    "phone".to_string(),
363                    "mobile".to_string(),
364                    "phone_number".to_string(),
365                ],
366                rule_type: BusinessRuleType::Format,
367                parameters: {
368                    let mut params = HashMap::new();
369                    params.insert("format".to_string(), "phone".to_string());
370                    params
371                },
372            });
373        }
374
375        debug!("Extracted {} business rules for entity {}", rules.len(), entity_name);
376        Ok(rules)
377    }
378
379    /// Extract example values from context
380    fn extract_example_values(
381        &self,
382        context: &str,
383        _entity_name: &str,
384    ) -> Result<HashMap<String, Vec<String>>, Box<dyn std::error::Error + Send + Sync>> {
385        let mut examples = HashMap::new();
386
387        // Simple example extraction - can be enhanced with regex/NLP
388        let lines: Vec<&str> = context.lines().collect();
389        for line in lines {
390            if line.contains("example:") || line.contains("e.g.") {
391                // Extract examples from line - simplified implementation
392                if line.to_lowercase().contains("email") {
393                    examples
394                        .entry("email".to_string())
395                        .or_insert_with(Vec::new)
396                        .push("user@example.com".to_string());
397                }
398                if line.to_lowercase().contains("name") {
399                    examples
400                        .entry("name".to_string())
401                        .or_insert_with(Vec::new)
402                        .push("John Doe".to_string());
403                }
404            }
405        }
406
407        Ok(examples)
408    }
409
410    /// Generate contexts for related entities
411    async fn generate_related_contexts(
412        &self,
413        entity_name: &str,
414        schema_graph: &SchemaGraph,
415    ) -> Result<HashMap<String, String>, Box<dyn std::error::Error + Send + Sync>> {
416        let mut related_contexts = HashMap::new();
417
418        if let Some(entity) = schema_graph.entities.get(entity_name) {
419            for related_entity in &entity.references {
420                if related_entity != entity_name {
421                    let related_context = self.query_rag_for_entity(related_entity).await?;
422                    related_contexts.insert(related_entity.clone(), related_context);
423                }
424            }
425        }
426
427        Ok(related_contexts)
428    }
429
430    /// Apply a business rule to generate field value
431    fn apply_business_rule(
432        &self,
433        rule: &BusinessRule,
434        field_name: &str,
435        _field_type: &str,
436    ) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
437        match rule.rule_type {
438            BusinessRuleType::Format => {
439                if let Some(format) = rule.parameters.get("format") {
440                    match format.as_str() {
441                        "email" => return Ok(Some("user@example.com".to_string())),
442                        "phone" => return Ok(Some("+1-555-0123".to_string())),
443                        _ => {}
444                    }
445                }
446            }
447            BusinessRuleType::Range => {
448                // Apply range constraints
449                if let (Some(min), Some(max)) =
450                    (rule.parameters.get("min"), rule.parameters.get("max"))
451                {
452                    if let (Ok(min_val), Ok(max_val)) = (min.parse::<i32>(), max.parse::<i32>()) {
453                        // Use deterministic value based on field name hash
454                        let field_hash = self.hash_field_name(field_name);
455                        let value = (field_hash as i32 % (max_val - min_val)) + min_val;
456                        return Ok(Some(value.to_string()));
457                    }
458                }
459            }
460            _ => {
461                debug!("Unhandled business rule type for field {}", field_name);
462            }
463        }
464
465        Ok(None)
466    }
467
468    /// Generate contextual value using RAG
469    async fn generate_contextual_value(
470        &self,
471        context: &EntityContext,
472        field_name: &str,
473        field_type: &str,
474    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
475        // Use a prompt template to generate contextually appropriate value
476        if let Some(template) = self.find_applicable_template(&context.entity_name) {
477            let prompt =
478                self.build_prompt_from_template(template, context, field_name, field_type)?;
479
480            #[cfg(feature = "data-faker")]
481            if let Some(rag_engine) = &self.rag_engine {
482                let chunks = rag_engine.keyword_search(&prompt, 1);
483                if let Some(chunk) = chunks.first() {
484                    return Ok(chunk.content.clone());
485                } else {
486                    debug!("No contextual value found for prompt: {}", prompt);
487                }
488            }
489        }
490
491        // Fallback to basic contextual generation
492        Ok(format!("contextual_{}_{}", context.entity_name.to_lowercase(), field_name))
493    }
494
495    /// Find applicable prompt template for entity
496    fn find_applicable_template(&self, entity_name: &str) -> Option<&PromptTemplate> {
497        self.config.prompt_templates.values().find(|template| {
498            template.entity_types.contains(&entity_name.to_string())
499                || template.entity_types.contains(&"*".to_string())
500        })
501    }
502
503    /// Build prompt from template
504    fn build_prompt_from_template(
505        &self,
506        template: &PromptTemplate,
507        context: &EntityContext,
508        field_name: &str,
509        field_type: &str,
510    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
511        let mut prompt = template.template.clone();
512
513        // Replace variables in template
514        prompt = prompt.replace("{entity_name}", &context.entity_name);
515        prompt = prompt.replace("{field_name}", field_name);
516        prompt = prompt.replace("{field_type}", field_type);
517        prompt = prompt.replace("{domain_context}", &context.domain_context);
518
519        Ok(prompt)
520    }
521
    /// Get configuration.
    ///
    /// Returns a shared reference to the configuration this synthesizer was
    /// constructed with.
    pub fn config(&self) -> &RagSynthesisConfig {
        &self.config
    }
526
527    /// Check if RAG synthesis is enabled and available
528    pub fn is_enabled(&self) -> bool {
529        self.config.enabled && {
530            #[cfg(feature = "data-faker")]
531            {
532                self.rag_engine.is_some()
533            }
534            #[cfg(not(feature = "data-faker"))]
535            {
536                false
537            }
538        }
539    }
540
541    /// Generate a deterministic hash for a field name for stable data generation
542    pub fn hash_field_name(&self, field_name: &str) -> u64 {
543        use std::collections::hash_map::DefaultHasher;
544        use std::hash::{Hash, Hasher};
545
546        let mut hasher = DefaultHasher::new();
547        field_name.hash(&mut hasher);
548        hasher.finish()
549    }
550}
551
552impl Default for RagSynthesisConfig {
553    fn default() -> Self {
554        let mut prompt_templates = HashMap::new();
555
556        // Default template for all entities
557        prompt_templates.insert("default".to_string(), PromptTemplate {
558            name: "default".to_string(),
559            entity_types: vec!["*".to_string()],
560            template: "Generate a realistic value for {field_name} field of type {field_type} in a {entity_name} entity. Context: {domain_context}".to_string(),
561            variables: vec!["entity_name".to_string(), "field_name".to_string(), "field_type".to_string(), "domain_context".to_string()],
562            examples: vec![],
563        });
564
565        Self {
566            enabled: false,
567            rag_config: None,
568            context_sources: vec![],
569            prompt_templates,
570            max_context_length: 2000,
571            cache_contexts: true,
572        }
573    }
574}
575
// Unit tests covering the default configuration, synthesizer construction
// and keyword-based business rule extraction.
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration is disabled, caches contexts, and ships the
    /// wildcard template under the "default" key.
    #[test]
    fn test_default_config() {
        let config = RagSynthesisConfig::default();
        assert!(!config.enabled);
        assert!(config.prompt_templates.contains_key("default"));
        assert!(config.cache_contexts);
    }

    /// A synthesizer built from the default (disabled) configuration reports
    /// itself as not enabled.
    // NOTE(review): nothing is awaited here, so the async runtime attribute
    // looks unnecessary — confirm before changing.
    #[tokio::test]
    async fn test_synthesizer_creation() {
        let config = RagSynthesisConfig::default();
        let synthesizer = RagDataSynthesizer::new(config);
        assert!(!synthesizer.is_enabled());
    }

    /// Text mentioning "email" + "format" produces at least one `Format` rule.
    #[test]
    fn test_business_rule_extraction() {
        let config = RagSynthesisConfig::default();
        let synthesizer = RagDataSynthesizer::new(config);

        let context = "Users must provide a valid email format. Phone numbers should be in international format.";
        let rules = synthesizer.extract_business_rules(context, "User").unwrap();

        assert!(!rules.is_empty());
        assert!(rules.iter().any(|r| matches!(r.rule_type, BusinessRuleType::Format)));
    }
}
606}