use crate::reflection::schema_graph::SchemaGraph;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tracing::{debug, info, warn};
#[cfg(feature = "data-faker")]
use mockforge_data::rag::{RagConfig, RagEngine};
/// Configuration for [`RagDataSynthesizer`].
///
/// The `Default` impl in this file yields a disabled configuration with a single
/// wildcard (`"*"`) prompt template named `"default"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagSynthesisConfig {
/// Master switch. When false, no RAG engine is created, the entity's own domain
/// context is not queried, and no RAG-derived contextual values are produced.
pub enabled: bool,
/// Settings used to build the RAG engine; no engine is created when this is `None`.
pub rag_config: Option<RagSynthesisRagConfig>,
/// Declared context sources.
/// NOTE(review): not read by `RagDataSynthesizer` in this module — presumably consumed elsewhere; confirm.
pub context_sources: Vec<ContextSource>,
/// Prompt templates keyed by template name; matched against entity names via
/// `PromptTemplate::entity_types`.
pub prompt_templates: HashMap<String, PromptTemplate>,
/// Intended upper bound on context length.
/// NOTE(review): not read anywhere in this module — confirm whether callers enforce it.
pub max_context_length: usize,
/// When true, built entity contexts are cached per entity name and reused.
pub cache_contexts: bool,
}
/// Connection and retrieval settings from which the RAG engine's `RagConfig` is built
/// (only consulted when the `data-faker` feature is enabled).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagSynthesisRagConfig {
/// Endpoint passed through to `RagConfig::api_endpoint`.
pub api_endpoint: String,
/// Optional API key passed through to `RagConfig::api_key`.
pub api_key: Option<String>,
/// Model name passed through to `RagConfig::model`.
pub model: String,
/// Embedding model passed through to `RagConfig::embedding_model`.
pub embedding_model: String,
/// Passed through to `RagConfig::similarity_threshold`.
pub similarity_threshold: f64,
/// Used as `RagConfig::max_chunks` and as the result limit for entity-context keyword searches.
pub max_documents: usize,
}
/// Describes an external source of domain context.
///
/// NOTE(review): none of these fields are interpreted in this module; the notes below
/// describe presumed intent and should be confirmed against the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextSource {
/// Identifier of the source.
pub id: String,
/// What kind of content the source provides.
pub source_type: ContextSourceType,
/// Location of the source (presumably a file path or URL — confirm).
pub path: String,
/// Relative weight of the source; its interpretation is not shown here.
pub weight: f32,
/// Presumably whether a missing/unreadable source should be treated as an error — confirm.
pub required: bool,
}
/// Kind of content a [`ContextSource`] provides.
///
/// Serialized in snake_case, e.g. `business_rules`, `knowledge_base`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContextSourceType {
Documentation,
Examples,
BusinessRules,
Glossary,
KnowledgeBase,
}
/// A prompt template used to derive contextual field values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTemplate {
/// Human-readable template name.
pub name: String,
/// Entity names this template applies to; the entry `"*"` matches any entity.
pub entity_types: Vec<String>,
/// Template text; the placeholders `{entity_name}`, `{field_name}`, `{field_type}` and
/// `{domain_context}` are substituted when a prompt is built.
pub template: String,
/// Declared placeholder names (informational; not read by the prompt builder in this module).
pub variables: Vec<String>,
/// Example input/output pairs (not read in this module).
pub examples: Vec<PromptExample>,
}
/// A worked example attached to a [`PromptTemplate`] (not read in this module).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptExample {
/// Presumably placeholder name → value bindings for the example — confirm.
pub input: HashMap<String, String>,
/// Expected output for the example input.
pub output: String,
/// Free-form description of the example.
pub description: String,
}
/// Context gathered for one entity by `RagDataSynthesizer::generate_entity_context`.
#[derive(Debug, Clone)]
pub struct EntityContext {
/// Name of the entity this context describes.
pub entity_name: String,
/// Text retrieved for the entity; empty when synthesis is disabled.
pub domain_context: String,
/// Context text for entities referenced by this one in the schema graph, keyed by entity name.
pub related_contexts: HashMap<String, String>,
/// Rules heuristically extracted from `domain_context`.
pub business_rules: Vec<BusinessRule>,
/// Example values per field name, heuristically extracted from `domain_context`.
pub example_values: HashMap<String, Vec<String>>,
}
/// A constraint that can be used to produce a value for specific fields.
#[derive(Debug, Clone)]
pub struct BusinessRule {
/// Human-readable description of the rule.
pub description: String,
/// Field names (exact match) the rule applies to.
pub applies_to_fields: Vec<String>,
/// Category of the rule; selects how `parameters` are interpreted.
pub rule_type: BusinessRuleType,
/// Rule-specific parameters, e.g. `format` (`email` / `phone`) for `Format` rules,
/// or integer `min` / `max` for `Range` rules.
pub parameters: HashMap<String, String>,
}
/// Category of a [`BusinessRule`].
///
/// Only `Format` and `Range` rules are acted on by the synthesizer; the remaining
/// variants currently produce no value.
#[derive(Debug, Clone)]
pub enum BusinessRuleType {
Format,
Range,
Relationship,
BusinessLogic,
Validation,
}
/// Synthesizes field values for entities using retrieved domain context, extracted
/// business rules and example values.
pub struct RagDataSynthesizer {
/// Configuration supplied at construction.
config: RagSynthesisConfig,
/// RAG engine; `None` when disabled, unconfigured, or initialization failed.
#[cfg(feature = "data-faker")]
rag_engine: Option<RagEngine>,
/// Cache of built contexts keyed by entity name (filled only when `cache_contexts` is true).
entity_contexts: HashMap<String, EntityContext>,
/// Optional schema graph used to discover related entities.
schema_graph: Option<SchemaGraph>,
}
impl RagDataSynthesizer {
pub fn new(config: RagSynthesisConfig) -> Self {
#[cfg(feature = "data-faker")]
let rag_engine = if config.enabled && config.rag_config.is_some() {
let rag_config = config.rag_config.as_ref().unwrap();
match Self::initialize_rag_engine(rag_config) {
Ok(engine) => Some(engine),
Err(e) => {
warn!("Failed to initialize RAG engine: {}", e);
None
}
}
} else {
None
};
Self {
config,
#[cfg(feature = "data-faker")]
rag_engine,
entity_contexts: HashMap::new(),
schema_graph: None,
}
}
pub fn set_schema_graph(&mut self, schema_graph: SchemaGraph) {
let entity_count = schema_graph.entities.len();
self.schema_graph = Some(schema_graph);
info!("Schema graph set with {} entities", entity_count);
}
pub async fn generate_entity_context(
&mut self,
entity_name: &str,
) -> Result<EntityContext, Box<dyn std::error::Error + Send + Sync>> {
if let Some(cached_context) = self.entity_contexts.get(entity_name) {
return Ok(cached_context.clone());
}
info!("Generating RAG context for entity: {}", entity_name);
let mut context = EntityContext {
entity_name: entity_name.to_string(),
domain_context: String::new(),
related_contexts: HashMap::new(),
business_rules: Vec::new(),
example_values: HashMap::new(),
};
if self.config.enabled {
context.domain_context = self.query_rag_for_entity(entity_name).await?;
}
context.business_rules =
self.extract_business_rules(&context.domain_context, entity_name)?;
context.example_values =
self.extract_example_values(&context.domain_context, entity_name)?;
if let Some(schema_graph) = &self.schema_graph {
context.related_contexts =
self.generate_related_contexts(entity_name, schema_graph).await?;
}
if self.config.cache_contexts {
self.entity_contexts.insert(entity_name.to_string(), context.clone());
}
Ok(context)
}
pub async fn synthesize_field_data(
&mut self,
entity_name: &str,
field_name: &str,
field_type: &str,
) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
let context = self.generate_entity_context(entity_name).await?;
if let Some(examples) = context.example_values.get(field_name) {
if !examples.is_empty() {
let field_hash = self.hash_field_name(field_name);
let index = field_hash as usize % examples.len();
return Ok(Some(examples[index].clone()));
}
}
for rule in &context.business_rules {
if rule.applies_to_fields.contains(&field_name.to_string()) {
if let Some(value) = self.apply_business_rule(rule, field_name, field_type)? {
return Ok(Some(value));
}
}
}
if self.config.enabled && !context.domain_context.is_empty() {
let rag_value =
self.generate_contextual_value(&context, field_name, field_type).await?;
if !rag_value.is_empty() {
return Ok(Some(rag_value));
}
}
Ok(None)
}
#[cfg(feature = "data-faker")]
fn initialize_rag_engine(
config: &RagSynthesisRagConfig,
) -> Result<RagEngine, Box<dyn std::error::Error + Send + Sync>> {
let rag_config = RagConfig {
provider: mockforge_data::rag::LlmProvider::OpenAI,
api_endpoint: config.api_endpoint.clone(),
api_key: config.api_key.clone(),
model: config.model.clone(),
max_tokens: 1000,
temperature: 0.7,
context_window: 4000,
semantic_search_enabled: true,
embedding_provider: mockforge_data::rag::EmbeddingProvider::OpenAI,
embedding_model: config.embedding_model.clone(),
embedding_endpoint: None,
similarity_threshold: config.similarity_threshold,
max_chunks: config.max_documents,
request_timeout_seconds: 30,
max_retries: 3,
};
Ok(RagEngine::new(rag_config))
}
async fn query_rag_for_entity(
&self,
entity_name: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
#[cfg(feature = "data-faker")]
if let Some(rag_engine) = &self.rag_engine {
let query = format!("What is {} in this domain? What are typical values and constraints for {} entities?", entity_name, entity_name);
let chunks = rag_engine
.keyword_search(&query, self.config.rag_config.as_ref().unwrap().max_documents);
if !chunks.is_empty() {
let context = chunks
.into_iter()
.map(|chunk| &chunk.content)
.cloned()
.collect::<Vec<_>>()
.join("\n\n");
return Ok(context);
} else {
warn!("No RAG results found for entity {}", entity_name);
}
}
Ok(format!("Entity: {} - A data entity in the system", entity_name))
}
fn extract_business_rules(
&self,
context: &str,
entity_name: &str,
) -> Result<Vec<BusinessRule>, Box<dyn std::error::Error + Send + Sync>> {
let mut rules = Vec::new();
if context.to_lowercase().contains("email") && context.to_lowercase().contains("format") {
rules.push(BusinessRule {
description: "Email fields must follow email format".to_string(),
applies_to_fields: vec!["email".to_string(), "email_address".to_string()],
rule_type: BusinessRuleType::Format,
parameters: {
let mut params = HashMap::new();
params.insert("format".to_string(), "email".to_string());
params
},
});
}
if context.to_lowercase().contains("phone") && context.to_lowercase().contains("number") {
rules.push(BusinessRule {
description: "Phone fields must follow phone number format".to_string(),
applies_to_fields: vec![
"phone".to_string(),
"mobile".to_string(),
"phone_number".to_string(),
],
rule_type: BusinessRuleType::Format,
parameters: {
let mut params = HashMap::new();
params.insert("format".to_string(), "phone".to_string());
params
},
});
}
debug!("Extracted {} business rules for entity {}", rules.len(), entity_name);
Ok(rules)
}
fn extract_example_values(
&self,
context: &str,
_entity_name: &str,
) -> Result<HashMap<String, Vec<String>>, Box<dyn std::error::Error + Send + Sync>> {
let mut examples = HashMap::new();
let lines: Vec<&str> = context.lines().collect();
for line in lines {
if line.contains("example:") || line.contains("e.g.") {
if line.to_lowercase().contains("email") {
examples
.entry("email".to_string())
.or_insert_with(Vec::new)
.push("user@example.com".to_string());
}
if line.to_lowercase().contains("name") {
examples
.entry("name".to_string())
.or_insert_with(Vec::new)
.push("John Doe".to_string());
}
}
}
Ok(examples)
}
async fn generate_related_contexts(
&self,
entity_name: &str,
schema_graph: &SchemaGraph,
) -> Result<HashMap<String, String>, Box<dyn std::error::Error + Send + Sync>> {
let mut related_contexts = HashMap::new();
if let Some(entity) = schema_graph.entities.get(entity_name) {
for related_entity in &entity.references {
if related_entity != entity_name {
let related_context = self.query_rag_for_entity(related_entity).await?;
related_contexts.insert(related_entity.clone(), related_context);
}
}
}
Ok(related_contexts)
}
fn apply_business_rule(
&self,
rule: &BusinessRule,
field_name: &str,
_field_type: &str,
) -> Result<Option<String>, Box<dyn std::error::Error + Send + Sync>> {
match rule.rule_type {
BusinessRuleType::Format => {
if let Some(format) = rule.parameters.get("format") {
match format.as_str() {
"email" => return Ok(Some("user@example.com".to_string())),
"phone" => return Ok(Some("+1-555-0123".to_string())),
_ => {}
}
}
}
BusinessRuleType::Range => {
if let (Some(min), Some(max)) =
(rule.parameters.get("min"), rule.parameters.get("max"))
{
if let (Ok(min_val), Ok(max_val)) = (min.parse::<i32>(), max.parse::<i32>()) {
let field_hash = self.hash_field_name(field_name);
let value = (field_hash as i32 % (max_val - min_val)) + min_val;
return Ok(Some(value.to_string()));
}
}
}
_ => {
debug!("Unhandled business rule type for field {}", field_name);
}
}
Ok(None)
}
async fn generate_contextual_value(
&self,
context: &EntityContext,
field_name: &str,
field_type: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
if let Some(template) = self.find_applicable_template(&context.entity_name) {
let prompt =
self.build_prompt_from_template(template, context, field_name, field_type)?;
#[cfg(feature = "data-faker")]
if let Some(rag_engine) = &self.rag_engine {
let chunks = rag_engine.keyword_search(&prompt, 1);
if let Some(chunk) = chunks.first() {
return Ok(chunk.content.clone());
} else {
debug!("No contextual value found for prompt: {}", prompt);
}
}
}
Ok(format!("contextual_{}_{}", context.entity_name.to_lowercase(), field_name))
}
fn find_applicable_template(&self, entity_name: &str) -> Option<&PromptTemplate> {
self.config.prompt_templates.values().find(|template| {
template.entity_types.contains(&entity_name.to_string())
|| template.entity_types.contains(&"*".to_string())
})
}
fn build_prompt_from_template(
&self,
template: &PromptTemplate,
context: &EntityContext,
field_name: &str,
field_type: &str,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
let mut prompt = template.template.clone();
prompt = prompt.replace("{entity_name}", &context.entity_name);
prompt = prompt.replace("{field_name}", field_name);
prompt = prompt.replace("{field_type}", field_type);
prompt = prompt.replace("{domain_context}", &context.domain_context);
Ok(prompt)
}
pub fn config(&self) -> &RagSynthesisConfig {
&self.config
}
pub fn is_enabled(&self) -> bool {
self.config.enabled && {
#[cfg(feature = "data-faker")]
{
self.rag_engine.is_some()
}
#[cfg(not(feature = "data-faker"))]
{
false
}
}
}
pub fn hash_field_name(&self, field_name: &str) -> u64 {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
field_name.hash(&mut hasher);
hasher.finish()
}
}
impl Default for RagSynthesisConfig {
fn default() -> Self {
let mut prompt_templates = HashMap::new();
prompt_templates.insert("default".to_string(), PromptTemplate {
name: "default".to_string(),
entity_types: vec!["*".to_string()],
template: "Generate a realistic value for {field_name} field of type {field_type} in a {entity_name} entity. Context: {domain_context}".to_string(),
variables: vec!["entity_name".to_string(), "field_name".to_string(), "field_type".to_string(), "domain_context".to_string()],
examples: vec![],
});
Self {
enabled: false,
rag_config: None,
context_sources: vec![],
prompt_templates,
max_context_length: 2000,
cache_contexts: true,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a synthesizer with the default (disabled) configuration.
    fn default_synthesizer() -> RagDataSynthesizer {
        RagDataSynthesizer::new(RagSynthesisConfig::default())
    }

    #[test]
    fn test_default_config() {
        let config = RagSynthesisConfig::default();
        assert!(!config.enabled);
        assert!(config.rag_config.is_none());
        assert!(config.prompt_templates.contains_key("default"));
        assert!(config.cache_contexts);
    }

    // Construction is synchronous, so no async runtime is needed here.
    #[test]
    fn test_synthesizer_creation() {
        let synthesizer = default_synthesizer();
        assert!(!synthesizer.is_enabled());
    }

    #[test]
    fn test_business_rule_extraction() {
        let synthesizer = default_synthesizer();
        let context = "Users must provide a valid email format. Phone numbers should be in international format.";
        let rules = synthesizer.extract_business_rules(context, "User").unwrap();
        // "email"+"format" and "phone"+"number" both appear, so both rules fire.
        assert_eq!(rules.len(), 2);
        assert!(rules.iter().all(|r| matches!(r.rule_type, BusinessRuleType::Format)));
        assert!(rules.iter().any(|r| r.applies_to_fields.iter().any(|f| f == "email")));
        assert!(rules.iter().any(|r| r.applies_to_fields.iter().any(|f| f == "phone_number")));
    }

    #[test]
    fn test_business_rule_extraction_without_keywords() {
        let synthesizer = default_synthesizer();
        let rules = synthesizer.extract_business_rules("Orders have totals.", "Order").unwrap();
        assert!(rules.is_empty());
    }

    #[test]
    fn test_format_rule_application() {
        let synthesizer = default_synthesizer();
        let rules = synthesizer.extract_business_rules("email format", "User").unwrap();
        let value = synthesizer.apply_business_rule(&rules[0], "email", "string").unwrap();
        assert_eq!(value.as_deref(), Some("user@example.com"));
    }

    #[test]
    fn test_prompt_from_default_template() {
        let synthesizer = default_synthesizer();
        let template = synthesizer
            .find_applicable_template("User")
            .expect("the default wildcard template matches any entity");
        let context = EntityContext {
            entity_name: "User".to_string(),
            domain_context: "ctx".to_string(),
            related_contexts: HashMap::new(),
            business_rules: Vec::new(),
            example_values: HashMap::new(),
        };
        let prompt =
            synthesizer.build_prompt_from_template(template, &context, "email", "string").unwrap();
        assert_eq!(
            prompt,
            "Generate a realistic value for email field of type string in a User entity. Context: ctx"
        );
    }

    #[tokio::test]
    async fn test_disabled_synthesis_yields_no_value() {
        let mut synthesizer = default_synthesizer();
        let value = synthesizer.synthesize_field_data("User", "email", "string").await.unwrap();
        assert!(value.is_none());
        // `cache_contexts` defaults to true, so the built context is retained.
        assert!(synthesizer.entity_contexts.contains_key("User"));
    }
}