1use crate::reflection::schema_graph::SchemaGraph;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use tracing::{debug, info, warn};
11
12#[cfg(feature = "data-faker")]
13use mockforge_data::rag::{RagConfig, RagEngine};
14
/// Configuration for RAG-driven data synthesis, consumed by [`RagDataSynthesizer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagSynthesisConfig {
    /// Master switch. When `false`, no RAG engine is built and the entity's own
    /// domain context is left empty.
    pub enabled: bool,
    /// Backend settings; a RAG engine is only built when this is `Some` (and `enabled`).
    pub rag_config: Option<RagSynthesisRagConfig>,
    /// Knowledge sources for synthesis.
    /// NOTE(review): not read anywhere in this file — confirm where it is consumed.
    pub context_sources: Vec<ContextSource>,
    /// Prompt templates keyed by name; selected per entity via `PromptTemplate::entity_types`.
    pub prompt_templates: HashMap<String, PromptTemplate>,
    /// Maximum context length (2000 in the `Default` impl).
    /// NOTE(review): not enforced in this file, and the unit (bytes/chars/tokens) is unconfirmed.
    pub max_context_length: usize,
    /// When `true`, each generated `EntityContext` is cached per entity name.
    pub cache_contexts: bool,
}
31
/// Connection and model settings for the RAG backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagSynthesisRagConfig {
    /// API endpoint, forwarded to the engine's `RagConfig::api_endpoint`.
    pub api_endpoint: String,
    /// Optional API key, forwarded as `RagConfig::api_key`.
    pub api_key: Option<String>,
    /// LLM model name, forwarded as `RagConfig::model`.
    pub model: String,
    /// Embedding model name, forwarded as `RagConfig::embedding_model`.
    pub embedding_model: String,
    /// Similarity threshold, forwarded as `RagConfig::similarity_threshold`.
    pub similarity_threshold: f64,
    /// Document limit: forwarded as `RagConfig::max_chunks` and also used as the
    /// result limit for per-entity keyword searches.
    pub max_documents: usize,
}
48
/// A source of domain knowledge for synthesis.
///
/// NOTE(review): none of these fields are read in this file; the field notes below
/// are inferred from names and should be confirmed against the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextSource {
    /// Identifier of the source.
    pub id: String,
    /// Kind of knowledge the source provides.
    pub source_type: ContextSourceType,
    /// Location of the source — presumably a file path or URL; confirm.
    pub path: String,
    /// Relative weight — presumably used to rank sources against each other; confirm.
    pub weight: f32,
    /// Presumably whether a missing/unreadable source should be an error; confirm.
    pub required: bool,
}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66#[serde(rename_all = "snake_case")]
67pub enum ContextSourceType {
68 Documentation,
70 Examples,
72 BusinessRules,
74 Glossary,
76 KnowledgeBase,
78}
79
/// A prompt template used to request a contextual value for an entity field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTemplate {
    /// Template name (the built-in default template is named `"default"`).
    pub name: String,
    /// Entity names this template applies to; the entry `"*"` matches every entity.
    pub entity_types: Vec<String>,
    /// Template text. The placeholders `{entity_name}`, `{field_name}`,
    /// `{field_type}` and `{domain_context}` are substituted when the prompt is built.
    pub template: String,
    /// Placeholder names used by `template`.
    /// NOTE(review): informational only — not read by this file.
    pub variables: Vec<String>,
    /// Few-shot examples for the template.
    /// NOTE(review): not read by this file.
    pub examples: Vec<PromptExample>,
}
94
/// An example input/output pair attached to a [`PromptTemplate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptExample {
    /// Input values for the template — presumably keyed by placeholder name; confirm.
    pub input: HashMap<String, String>,
    /// Expected output for `input`.
    pub output: String,
    /// Human-readable description of the example.
    pub description: String,
}
105
/// Context gathered for one entity and used to synthesize its field values.
#[derive(Debug, Clone)]
pub struct EntityContext {
    /// Name of the entity this context describes.
    pub entity_name: String,
    /// Text retrieved from RAG (or a fallback description); empty when synthesis is disabled.
    pub domain_context: String,
    /// Context text for entities referenced in the schema graph, keyed by entity name.
    pub related_contexts: HashMap<String, String>,
    /// Rules extracted from `domain_context`.
    pub business_rules: Vec<BusinessRule>,
    /// Example values extracted from `domain_context`, keyed by field name.
    pub example_values: HashMap<String, Vec<String>>,
}
120
/// A rule constraining the values of one or more fields.
#[derive(Debug, Clone)]
pub struct BusinessRule {
    /// Human-readable description of the rule.
    pub description: String,
    /// Field names the rule applies to (compared exactly against the field name).
    pub applies_to_fields: Vec<String>,
    /// Kind of rule; determines which `parameters` are read.
    pub rule_type: BusinessRuleType,
    /// Rule parameters, e.g. `format` for `Format` rules or `min`/`max` for `Range` rules.
    pub parameters: HashMap<String, String>,
}
133
/// Kind of a business rule.
///
/// The enum is fieldless, so it is `Copy` and supports `==` and hashing directly
/// instead of requiring `matches!` at every comparison site.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BusinessRuleType {
    /// Value must follow a named format (`parameters["format"]`, e.g. `email`, `phone`).
    Format,
    /// Numeric value bounded by `parameters["min"]` and `parameters["max"]`.
    Range,
    /// Relationship to another entity or field.
    Relationship,
    /// Domain-specific business logic.
    BusinessLogic,
    /// General validation constraint.
    Validation,
}
148
/// Synthesizes field values for entities from RAG-retrieved domain context,
/// extracted business rules and prompt templates.
pub struct RagDataSynthesizer {
    /// Synthesis configuration.
    config: RagSynthesisConfig,
    /// RAG engine; `None` when synthesis is disabled, no `rag_config` is set, or
    /// engine initialization failed.
    #[cfg(feature = "data-faker")]
    rag_engine: Option<RagEngine>,
    /// Cache of generated contexts keyed by entity name (filled when `cache_contexts` is set).
    entity_contexts: HashMap<String, EntityContext>,
    /// Optional schema graph used to discover related entities.
    schema_graph: Option<SchemaGraph>,
}
161
162impl RagDataSynthesizer {
163 pub fn new(config: RagSynthesisConfig) -> Self {
165 #[cfg(feature = "data-faker")]
166 let rag_engine = if config.enabled && config.rag_config.is_some() {
167 let rag_config = config.rag_config.as_ref().unwrap();
168 match Self::initialize_rag_engine(rag_config) {
169 Ok(engine) => Some(engine),
170 Err(e) => {
171 warn!("Failed to initialize RAG engine: {}", e);
172 None
173 }
174 }
175 } else {
176 None
177 };
178
179 Self {
180 config,
181 #[cfg(feature = "data-faker")]
182 rag_engine,
183 entity_contexts: HashMap::new(),
184 schema_graph: None,
185 }
186 }
187
188 pub fn set_schema_graph(&mut self, schema_graph: SchemaGraph) {
190 let entity_count = schema_graph.entities.len();
191 self.schema_graph = Some(schema_graph);
192 info!("Schema graph set with {} entities", entity_count);
193 }
194
195 pub async fn generate_entity_context(
197 &mut self,
198 entity_name: &str,
199 ) -> Result<EntityContext, Box<dyn std::error::Error + Send + Sync>> {
200 if let Some(cached_context) = self.entity_contexts.get(entity_name) {
202 return Ok(cached_context.clone());
203 }
204
205 info!("Generating RAG context for entity: {}", entity_name);
206
207 let mut context = EntityContext {
208 entity_name: entity_name.to_string(),
209 domain_context: String::new(),
210 related_contexts: HashMap::new(),
211 business_rules: Vec::new(),
212 example_values: HashMap::new(),
213 };
214
215 if self.config.enabled {
217 context.domain_context = self.query_rag_for_entity(entity_name).await?;
218 }
219
220 context.business_rules =
222 self.extract_business_rules(&context.domain_context, entity_name)?;
223
224 context.example_values =
226 self.extract_example_values(&context.domain_context, entity_name)?;
227
228 if let Some(schema_graph) = &self.schema_graph {
230 context.related_contexts =
231 self.generate_related_contexts(entity_name, schema_graph).await?;
232 }
233
234 if self.config.cache_contexts {
236 self.entity_contexts.insert(entity_name.to_string(), context.clone());
237 }
238
239 Ok(context)
240 }
241
242 pub async fn synthesize_field_data(
244 &mut self,
245 entity_name: &str,
246 field_name: &str,
247 field_type: &str,
248 ) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
249 let context = self.generate_entity_context(entity_name).await?;
250
251 if let Some(examples) = context.example_values.get(field_name) {
253 if !examples.is_empty() {
254 let field_hash = self.hash_field_name(field_name);
256 let index = field_hash as usize % examples.len();
257 return Ok(Some(examples[index].clone()));
258 }
259 }
260
261 for rule in &context.business_rules {
263 if rule.applies_to_fields.contains(&field_name.to_string()) {
264 if let Some(value) = self.apply_business_rule(rule, field_name, field_type)? {
265 return Ok(Some(value));
266 }
267 }
268 }
269
270 if self.config.enabled && !context.domain_context.is_empty() {
272 let rag_value =
273 self.generate_contextual_value(&context, field_name, field_type).await?;
274 if !rag_value.is_empty() {
275 return Ok(Some(rag_value));
276 }
277 }
278
279 Ok(None)
280 }
281
282 #[cfg(feature = "data-faker")]
284 fn initialize_rag_engine(
285 config: &RagSynthesisRagConfig,
286 ) -> Result<RagEngine, Box<dyn std::error::Error + Send + Sync>> {
287 let rag_config = RagConfig {
288 provider: mockforge_data::rag::LlmProvider::OpenAI,
289 api_endpoint: config.api_endpoint.clone(),
290 api_key: config.api_key.clone(),
291 model: config.model.clone(),
292 max_tokens: 1000,
293 temperature: 0.7,
294 context_window: 4000,
295 semantic_search_enabled: true,
296 embedding_provider: mockforge_data::rag::EmbeddingProvider::OpenAI,
297 embedding_model: config.embedding_model.clone(),
298 embedding_endpoint: None,
299 similarity_threshold: config.similarity_threshold,
300 max_chunks: config.max_documents,
301 request_timeout_seconds: 30,
302 max_retries: 3,
303 };
304
305 Ok(RagEngine::new(rag_config))
306 }
307
308 async fn query_rag_for_entity(
310 &self,
311 entity_name: &str,
312 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
313 #[cfg(feature = "data-faker")]
314 if let Some(rag_engine) = &self.rag_engine {
315 let query = format!("What is {} in this domain? What are typical values and constraints for {} entities?", entity_name, entity_name);
316
317 let chunks = rag_engine
318 .keyword_search(&query, self.config.rag_config.as_ref().unwrap().max_documents);
319 if !chunks.is_empty() {
320 let context = chunks
321 .into_iter()
322 .map(|chunk| &chunk.content)
323 .cloned()
324 .collect::<Vec<_>>()
325 .join("\n\n");
326 return Ok(context);
327 } else {
328 warn!("No RAG results found for entity {}", entity_name);
329 }
330 }
331
332 Ok(format!("Entity: {} - A data entity in the system", entity_name))
334 }
335
336 fn extract_business_rules(
338 &self,
339 context: &str,
340 entity_name: &str,
341 ) -> Result<Vec<BusinessRule>, Box<dyn std::error::Error + Send + Sync>> {
342 let mut rules = Vec::new();
343
344 if context.to_lowercase().contains("email") && context.to_lowercase().contains("format") {
346 rules.push(BusinessRule {
347 description: "Email fields must follow email format".to_string(),
348 applies_to_fields: vec!["email".to_string(), "email_address".to_string()],
349 rule_type: BusinessRuleType::Format,
350 parameters: {
351 let mut params = HashMap::new();
352 params.insert("format".to_string(), "email".to_string());
353 params
354 },
355 });
356 }
357
358 if context.to_lowercase().contains("phone") && context.to_lowercase().contains("number") {
359 rules.push(BusinessRule {
360 description: "Phone fields must follow phone number format".to_string(),
361 applies_to_fields: vec![
362 "phone".to_string(),
363 "mobile".to_string(),
364 "phone_number".to_string(),
365 ],
366 rule_type: BusinessRuleType::Format,
367 parameters: {
368 let mut params = HashMap::new();
369 params.insert("format".to_string(), "phone".to_string());
370 params
371 },
372 });
373 }
374
375 debug!("Extracted {} business rules for entity {}", rules.len(), entity_name);
376 Ok(rules)
377 }
378
379 fn extract_example_values(
381 &self,
382 context: &str,
383 _entity_name: &str,
384 ) -> Result<HashMap<String, Vec<String>>, Box<dyn std::error::Error + Send + Sync>> {
385 let mut examples = HashMap::new();
386
387 let lines: Vec<&str> = context.lines().collect();
389 for line in lines {
390 if line.contains("example:") || line.contains("e.g.") {
391 if line.to_lowercase().contains("email") {
393 examples
394 .entry("email".to_string())
395 .or_insert_with(Vec::new)
396 .push("user@example.com".to_string());
397 }
398 if line.to_lowercase().contains("name") {
399 examples
400 .entry("name".to_string())
401 .or_insert_with(Vec::new)
402 .push("John Doe".to_string());
403 }
404 }
405 }
406
407 Ok(examples)
408 }
409
410 async fn generate_related_contexts(
412 &self,
413 entity_name: &str,
414 schema_graph: &SchemaGraph,
415 ) -> Result<HashMap<String, String>, Box<dyn std::error::Error + Send + Sync>> {
416 let mut related_contexts = HashMap::new();
417
418 if let Some(entity) = schema_graph.entities.get(entity_name) {
419 for related_entity in &entity.references {
420 if related_entity != entity_name {
421 let related_context = self.query_rag_for_entity(related_entity).await?;
422 related_contexts.insert(related_entity.clone(), related_context);
423 }
424 }
425 }
426
427 Ok(related_contexts)
428 }
429
430 fn apply_business_rule(
432 &self,
433 rule: &BusinessRule,
434 field_name: &str,
435 _field_type: &str,
436 ) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
437 match rule.rule_type {
438 BusinessRuleType::Format => {
439 if let Some(format) = rule.parameters.get("format") {
440 match format.as_str() {
441 "email" => return Ok(Some("user@example.com".to_string())),
442 "phone" => return Ok(Some("+1-555-0123".to_string())),
443 _ => {}
444 }
445 }
446 }
447 BusinessRuleType::Range => {
448 if let (Some(min), Some(max)) =
450 (rule.parameters.get("min"), rule.parameters.get("max"))
451 {
452 if let (Ok(min_val), Ok(max_val)) = (min.parse::<i32>(), max.parse::<i32>()) {
453 let field_hash = self.hash_field_name(field_name);
455 let value = (field_hash as i32 % (max_val - min_val)) + min_val;
456 return Ok(Some(value.to_string()));
457 }
458 }
459 }
460 _ => {
461 debug!("Unhandled business rule type for field {}", field_name);
462 }
463 }
464
465 Ok(None)
466 }
467
468 async fn generate_contextual_value(
470 &self,
471 context: &EntityContext,
472 field_name: &str,
473 field_type: &str,
474 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
475 if let Some(template) = self.find_applicable_template(&context.entity_name) {
477 let prompt =
478 self.build_prompt_from_template(template, context, field_name, field_type)?;
479
480 #[cfg(feature = "data-faker")]
481 if let Some(rag_engine) = &self.rag_engine {
482 let chunks = rag_engine.keyword_search(&prompt, 1);
483 if let Some(chunk) = chunks.first() {
484 return Ok(chunk.content.clone());
485 } else {
486 debug!("No contextual value found for prompt: {}", prompt);
487 }
488 }
489 }
490
491 Ok(format!("contextual_{}_{}", context.entity_name.to_lowercase(), field_name))
493 }
494
495 fn find_applicable_template(&self, entity_name: &str) -> Option<&PromptTemplate> {
497 self.config.prompt_templates.values().find(|template| {
498 template.entity_types.contains(&entity_name.to_string())
499 || template.entity_types.contains(&"*".to_string())
500 })
501 }
502
503 fn build_prompt_from_template(
505 &self,
506 template: &PromptTemplate,
507 context: &EntityContext,
508 field_name: &str,
509 field_type: &str,
510 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
511 let mut prompt = template.template.clone();
512
513 prompt = prompt.replace("{entity_name}", &context.entity_name);
515 prompt = prompt.replace("{field_name}", field_name);
516 prompt = prompt.replace("{field_type}", field_type);
517 prompt = prompt.replace("{domain_context}", &context.domain_context);
518
519 Ok(prompt)
520 }
521
522 pub fn config(&self) -> &RagSynthesisConfig {
524 &self.config
525 }
526
527 pub fn is_enabled(&self) -> bool {
529 self.config.enabled && {
530 #[cfg(feature = "data-faker")]
531 {
532 self.rag_engine.is_some()
533 }
534 #[cfg(not(feature = "data-faker"))]
535 {
536 false
537 }
538 }
539 }
540
541 pub fn hash_field_name(&self, field_name: &str) -> u64 {
543 use std::collections::hash_map::DefaultHasher;
544 use std::hash::{Hash, Hasher};
545
546 let mut hasher = DefaultHasher::new();
547 field_name.hash(&mut hasher);
548 hasher.finish()
549 }
550}
551
552impl Default for RagSynthesisConfig {
553 fn default() -> Self {
554 let mut prompt_templates = HashMap::new();
555
556 prompt_templates.insert("default".to_string(), PromptTemplate {
558 name: "default".to_string(),
559 entity_types: vec!["*".to_string()],
560 template: "Generate a realistic value for {field_name} field of type {field_type} in a {entity_name} entity. Context: {domain_context}".to_string(),
561 variables: vec!["entity_name".to_string(), "field_name".to_string(), "field_type".to_string(), "domain_context".to_string()],
562 examples: vec![],
563 });
564
565 Self {
566 enabled: false,
567 rag_config: None,
568 context_sources: vec![],
569 prompt_templates,
570 max_context_length: 2000,
571 cache_contexts: true,
572 }
573 }
574}
575
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config is disabled, caches contexts and ships the `"default"` template.
    #[test]
    fn test_default_config() {
        let defaults = RagSynthesisConfig::default();
        assert!(!defaults.enabled);
        assert!(defaults.prompt_templates.contains_key("default"));
        assert!(defaults.cache_contexts);
    }

    /// A synthesizer built from the default (disabled) config reports itself disabled.
    #[tokio::test]
    async fn test_synthesizer_creation() {
        let synthesizer = RagDataSynthesizer::new(RagSynthesisConfig::default());
        assert!(!synthesizer.is_enabled());
    }

    /// Text mentioning email and phone formats yields at least one `Format` rule.
    #[test]
    fn test_business_rule_extraction() {
        let synthesizer = RagDataSynthesizer::new(RagSynthesisConfig::default());

        let text = "Users must provide a valid email format. Phone numbers should be in international format.";
        let extracted = synthesizer.extract_business_rules(text, "User").unwrap();

        let has_format_rule =
            extracted.iter().any(|rule| matches!(rule.rule_type, BusinessRuleType::Format));
        assert!(!extracted.is_empty());
        assert!(has_format_rule);
    }
}