mockforge_grpc/reflection/schema_graph.rs

1//! Schema relationship graph extraction
2//!
3//! This module extracts relationship graphs from proto and OpenAPI schemas,
4//! identifying foreign keys, references, and data dependencies for coherent
5//! synthetic data generation.
6
7use prost_reflect::{DescriptorPool, FieldDescriptor, Kind, MessageDescriptor};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use tracing::info;
11
12/// A graph representing relationships between schema entities
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SchemaGraph {
15    /// All entities (messages/schemas) in the graph
16    pub entities: HashMap<String, EntityNode>,
17    /// Direct relationships between entities
18    pub relationships: Vec<Relationship>,
19    /// Detected foreign key patterns
20    pub foreign_keys: HashMap<String, Vec<ForeignKeyMapping>>,
21}
22
23/// An entity node in the schema graph
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct EntityNode {
26    /// Entity name (e.g., "User", "Order")
27    pub name: String,
28    /// Full qualified name (e.g., "com.example.User")
29    pub full_name: String,
30    /// Fields in this entity
31    pub fields: Vec<FieldInfo>,
32    /// Whether this is a root entity (not referenced by others)
33    pub is_root: bool,
34    /// Entities that reference this one
35    pub referenced_by: Vec<String>,
36    /// Entities that this one references
37    pub references: Vec<String>,
38}
39
40/// Information about a field in an entity
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct FieldInfo {
43    /// Field name
44    pub name: String,
45    /// Field type (string, int32, message, etc.)
46    pub field_type: String,
47    /// Whether this field is a potential foreign key
48    pub is_foreign_key: bool,
49    /// Target entity if this is a foreign key
50    pub foreign_key_target: Option<String>,
51    /// Whether this field is required
52    pub is_required: bool,
53    /// Constraints on this field
54    pub constraints: HashMap<String, String>,
55}
56
57/// A relationship between two entities
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct Relationship {
60    /// Source entity name
61    pub from_entity: String,
62    /// Target entity name
63    pub to_entity: String,
64    /// Type of relationship
65    pub relationship_type: RelationshipType,
66    /// Field name that creates the relationship
67    pub field_name: String,
68    /// Whether this relationship is required
69    pub is_required: bool,
70    /// Cardinality constraints
71    pub cardinality: Cardinality,
72}
73
74/// Type of relationship between entities
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub enum RelationshipType {
77    /// Direct foreign key reference (user_id -> User)
78    ForeignKey,
79    /// Embedded object (address within user)
80    Embedded,
81    /// Array/repeated field relationship
82    OneToMany,
83    /// Bidirectional relationship
84    ManyToMany,
85    /// Composition relationship
86    Composition,
87}
88
89/// Cardinality constraints for relationships
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct Cardinality {
92    /// Minimum number of related entities
93    pub min: u32,
94    /// Maximum number of related entities (None = unlimited)
95    pub max: Option<u32>,
96}
97
98/// Foreign key mapping detected via naming conventions
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct ForeignKeyMapping {
101    /// Field name (e.g., "user_id")
102    pub field_name: String,
103    /// Target entity name (e.g., "User")
104    pub target_entity: String,
105    /// Confidence score (0.0 - 1.0)
106    pub confidence: f64,
107    /// Detection method used
108    pub detection_method: ForeignKeyDetectionMethod,
109}
110
111/// Methods used to detect foreign key relationships
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub enum ForeignKeyDetectionMethod {
114    /// Detected via naming convention (user_id -> User)
115    NamingConvention,
116    /// Detected via schema reference ($ref in OpenAPI)
117    SchemaReference,
118    /// Detected via field type (message type in proto)
119    MessageType,
120    /// Detected via constraint annotation
121    Constraint,
122}
123
124/// Schema graph extractor for protobuf schemas
125pub struct ProtoSchemaGraphExtractor {
126    /// Common foreign key patterns
127    foreign_key_patterns: Vec<ForeignKeyPattern>,
128}
129
130/// Pattern for detecting foreign keys via naming
131#[derive(Debug, Clone)]
132struct ForeignKeyPattern {
133    /// Regex pattern for field names
134    pattern: regex::Regex,
135    /// How to extract entity name from field name
136    entity_extraction: EntityExtractionMethod,
137    /// Confidence score for this pattern
138    ///
139    /// Confidence score indicates how reliable this pattern is for detecting relationships.
140    /// Higher scores (closer to 1.0) indicate more reliable patterns.
141    confidence: f64,
142}
143
/// Strategies for turning a field name into an entity name.
#[derive(Clone, Debug)]
enum EntityExtractionMethod {
    /// Strip the given suffix from the field name (`user_id` -> `user`)
    RemoveSuffix(String),
    /// Use the whole field name as the entity name (after normalization)
    ///
    /// TODO: not used by any built-in pattern yet
    #[allow(dead_code)] // TODO: Remove when entity extraction feature is implemented
    Direct,
    /// Delegate to a caller-supplied transform function
    ///
    /// TODO: not used by any built-in pattern yet
    #[allow(dead_code)] // TODO: Remove when custom entity extraction is implemented
    Custom(fn(&str) -> Option<String>),
}
160
161impl ProtoSchemaGraphExtractor {
162    /// Create a new proto schema graph extractor
163    pub fn new() -> Self {
164        let patterns = vec![
165            ForeignKeyPattern {
166                pattern: regex::Regex::new(r"^(.+)_id$").unwrap(),
167                entity_extraction: EntityExtractionMethod::RemoveSuffix("_id".to_string()),
168                confidence: 0.9,
169            },
170            ForeignKeyPattern {
171                pattern: regex::Regex::new(r"^(.+)Id$").unwrap(),
172                entity_extraction: EntityExtractionMethod::RemoveSuffix("Id".to_string()),
173                confidence: 0.85,
174            },
175            ForeignKeyPattern {
176                pattern: regex::Regex::new(r"^(.+)_ref$").unwrap(),
177                entity_extraction: EntityExtractionMethod::RemoveSuffix("_ref".to_string()),
178                confidence: 0.8,
179            },
180        ];
181
182        Self {
183            foreign_key_patterns: patterns,
184        }
185    }
186
187    /// Extract schema graph from protobuf descriptor pool
188    pub fn extract_from_proto(
189        &self,
190        pool: &DescriptorPool,
191    ) -> Result<SchemaGraph, Box<dyn std::error::Error + Send + Sync>> {
192        let mut entities = HashMap::new();
193        let mut relationships = Vec::new();
194        let mut foreign_keys = HashMap::new();
195
196        info!("Extracting schema graph from protobuf descriptors");
197
198        // First pass: Extract all entities and their fields
199        for message_descriptor in pool.all_messages() {
200            let entity = self.extract_entity_from_message(&message_descriptor)?;
201            entities.insert(entity.name.clone(), entity);
202        }
203
204        // Second pass: Analyze relationships and foreign keys
205        for (entity_name, entity) in &entities {
206            let fk_mappings = self.detect_foreign_keys(entity, &entities)?;
207            if !fk_mappings.is_empty() {
208                foreign_keys.insert(entity_name.clone(), fk_mappings);
209            }
210
211            let entity_relationships = self.extract_relationships(entity, &entities)?;
212            relationships.extend(entity_relationships);
213        }
214
215        // Third pass: Update cross-references
216        let mut updated_entities = entities;
217        self.update_cross_references(&mut updated_entities, &relationships);
218
219        let graph = SchemaGraph {
220            entities: updated_entities,
221            relationships,
222            foreign_keys,
223        };
224
225        info!(
226            "Extracted schema graph with {} entities and {} relationships",
227            graph.entities.len(),
228            graph.relationships.len()
229        );
230
231        Ok(graph)
232    }
233
234    /// Extract an entity from a proto message descriptor
235    fn extract_entity_from_message(
236        &self,
237        descriptor: &MessageDescriptor,
238    ) -> Result<EntityNode, Box<dyn std::error::Error + Send + Sync>> {
239        let name = Self::extract_entity_name(descriptor.name());
240        let full_name = descriptor.full_name().to_string();
241
242        let mut fields = Vec::new();
243        for field_descriptor in descriptor.fields() {
244            let field_info = self.extract_field_info(&field_descriptor)?;
245            fields.push(field_info);
246        }
247
248        Ok(EntityNode {
249            name,
250            full_name,
251            fields,
252            is_root: true, // Will be updated later
253            referenced_by: Vec::new(),
254            references: Vec::new(),
255        })
256    }
257
258    /// Extract field information from a proto field descriptor
259    fn extract_field_info(
260        &self,
261        field: &FieldDescriptor,
262    ) -> Result<FieldInfo, Box<dyn std::error::Error + Send + Sync>> {
263        let name = field.name().to_string();
264        let field_type = Self::kind_to_string(&field.kind());
265        let is_required = true; // Proto fields are required by default unless marked optional
266
267        // Check if this looks like a foreign key
268        let (is_foreign_key, foreign_key_target) =
269            self.analyze_potential_foreign_key(&name, &field.kind());
270
271        let mut constraints = HashMap::new();
272        if field.is_list() {
273            constraints.insert("repeated".to_string(), "true".to_string());
274        }
275
276        Ok(FieldInfo {
277            name,
278            field_type,
279            is_foreign_key,
280            foreign_key_target,
281            is_required,
282            constraints,
283        })
284    }
285
286    /// Analyze if a field might be a foreign key
287    fn analyze_potential_foreign_key(
288        &self,
289        field_name: &str,
290        kind: &Kind,
291    ) -> (bool, Option<String>) {
292        // Check naming patterns
293        for pattern in &self.foreign_key_patterns {
294            if pattern.pattern.is_match(field_name) {
295                if let Some(entity_name) = self.extract_entity_name_from_field(field_name, pattern)
296                {
297                    return (true, Some(entity_name));
298                }
299            }
300        }
301
302        // Check if it's a message type (embedded relationship)
303        if let Kind::Message(message_descriptor) = kind {
304            let entity_name = Self::extract_entity_name(message_descriptor.name());
305            return (false, Some(entity_name)); // Not FK, but related entity
306        }
307
308        (false, None)
309    }
310
311    /// Extract entity name from field name using pattern
312    fn extract_entity_name_from_field(
313        &self,
314        field_name: &str,
315        pattern: &ForeignKeyPattern,
316    ) -> Option<String> {
317        match &pattern.entity_extraction {
318            EntityExtractionMethod::RemoveSuffix(suffix) => {
319                if field_name.ends_with(suffix) {
320                    let base_name = &field_name[..field_name.len() - suffix.len()];
321                    Some(Self::normalize_entity_name(base_name))
322                } else {
323                    None
324                }
325            }
326            EntityExtractionMethod::Direct => Some(Self::normalize_entity_name(field_name)),
327            EntityExtractionMethod::Custom(func) => func(field_name),
328        }
329    }
330
331    /// Detect foreign keys in an entity
332    fn detect_foreign_keys(
333        &self,
334        entity: &EntityNode,
335        all_entities: &HashMap<String, EntityNode>,
336    ) -> Result<Vec<ForeignKeyMapping>, Box<dyn std::error::Error + Send + Sync>> {
337        let mut mappings = Vec::new();
338
339        for field in &entity.fields {
340            if field.is_foreign_key {
341                if let Some(target) = &field.foreign_key_target {
342                    // Check if target entity exists
343                    if all_entities.contains_key(target) {
344                        // Calculate confidence score based on pattern match and entity existence
345                        let confidence =
346                            self.calculate_confidence_score(field, target, all_entities);
347
348                        mappings.push(ForeignKeyMapping {
349                            field_name: field.name.clone(),
350                            target_entity: target.clone(),
351                            confidence,
352                            detection_method: ForeignKeyDetectionMethod::NamingConvention,
353                        });
354                    }
355                }
356            }
357        }
358
359        Ok(mappings)
360    }
361
362    /// Calculate confidence score for a detected relationship
363    ///
364    /// Confidence is calculated based on:
365    /// - Pattern match quality (higher for common patterns like _id)
366    /// - Entity existence validation (target entity exists)
367    /// - Field type compatibility (message type matches entity name)
368    /// - Naming convention strength (more specific patterns score higher)
369    fn calculate_confidence_score(
370        &self,
371        field: &FieldInfo,
372        target_entity: &str,
373        all_entities: &HashMap<String, EntityNode>,
374    ) -> f64 {
375        let mut confidence = 0.5; // Base confidence
376
377        // Find matching pattern to get pattern-specific confidence
378        for pattern in &self.foreign_key_patterns {
379            if pattern.pattern.is_match(&field.name) {
380                confidence = pattern.confidence;
381                break;
382            }
383        }
384
385        // Boost confidence if target entity exists and matches naming convention
386        if all_entities.contains_key(target_entity) {
387            confidence += 0.1; // +10% for entity existence
388        }
389
390        // Boost confidence if field type suggests a relationship
391        if field.field_type.contains("message") || field.field_type.contains("Message") {
392            confidence += 0.1; // +10% for message type
393        }
394
395        // Cap confidence at 1.0
396        confidence.min(1.0)
397    }
398
399    /// Extract relationships from an entity
400    fn extract_relationships(
401        &self,
402        entity: &EntityNode,
403        all_entities: &HashMap<String, EntityNode>,
404    ) -> Result<Vec<Relationship>, Box<dyn std::error::Error + Send + Sync>> {
405        let mut relationships = Vec::new();
406
407        for field in &entity.fields {
408            if let Some(target_entity) = &field.foreign_key_target {
409                if all_entities.contains_key(target_entity) {
410                    let relationship_type = if field.is_foreign_key {
411                        RelationshipType::ForeignKey
412                    } else if field.field_type.contains("message") {
413                        RelationshipType::Embedded
414                    } else {
415                        RelationshipType::Composition
416                    };
417
418                    let cardinality = if field.constraints.contains_key("repeated") {
419                        Cardinality { min: 0, max: None }
420                    } else {
421                        Cardinality {
422                            min: if field.is_required { 1 } else { 0 },
423                            max: Some(1),
424                        }
425                    };
426
427                    relationships.push(Relationship {
428                        from_entity: entity.name.clone(),
429                        to_entity: target_entity.clone(),
430                        relationship_type,
431                        field_name: field.name.clone(),
432                        is_required: field.is_required,
433                        cardinality,
434                    });
435                }
436            }
437        }
438
439        Ok(relationships)
440    }
441
442    /// Update cross-references between entities
443    fn update_cross_references(
444        &self,
445        entities: &mut HashMap<String, EntityNode>,
446        relationships: &[Relationship],
447    ) {
448        // Build reference maps
449        let mut referenced_by_map: HashMap<String, Vec<String>> = HashMap::new();
450        let mut references_map: HashMap<String, Vec<String>> = HashMap::new();
451
452        for rel in relationships {
453            // Track what references what
454            references_map
455                .entry(rel.from_entity.clone())
456                .or_default()
457                .push(rel.to_entity.clone());
458
459            // Track what is referenced by what
460            referenced_by_map
461                .entry(rel.to_entity.clone())
462                .or_default()
463                .push(rel.from_entity.clone());
464        }
465
466        // Update entities
467        for (entity_name, entity) in entities.iter_mut() {
468            if let Some(refs) = references_map.get(entity_name) {
469                entity.references = refs.clone();
470            }
471
472            if let Some(referenced_by) = referenced_by_map.get(entity_name) {
473                entity.referenced_by = referenced_by.clone();
474                entity.is_root = false; // Referenced entities are not root
475            }
476        }
477    }
478
479    /// Convert protobuf Kind to string representation
480    fn kind_to_string(kind: &Kind) -> String {
481        match kind {
482            Kind::String => "string".to_string(),
483            Kind::Int32 => "int32".to_string(),
484            Kind::Int64 => "int64".to_string(),
485            Kind::Uint32 => "uint32".to_string(),
486            Kind::Uint64 => "uint64".to_string(),
487            Kind::Bool => "bool".to_string(),
488            Kind::Float => "float".to_string(),
489            Kind::Double => "double".to_string(),
490            Kind::Bytes => "bytes".to_string(),
491            Kind::Message(msg) => format!("message:{}", msg.full_name()),
492            Kind::Enum(enum_desc) => format!("enum:{}", enum_desc.full_name()),
493            _ => "unknown".to_string(),
494        }
495    }
496
497    /// Extract entity name from message name (remove package, normalize)
498    fn extract_entity_name(message_name: &str) -> String {
499        Self::normalize_entity_name(message_name)
500    }
501
502    /// Normalize entity name (PascalCase, singular)
503    fn normalize_entity_name(name: &str) -> String {
504        // Convert snake_case to PascalCase
505        name.split('_')
506            .map(|part| {
507                let mut chars: Vec<char> = part.chars().collect();
508                if let Some(first_char) = chars.first_mut() {
509                    *first_char = first_char.to_uppercase().next().unwrap_or(*first_char);
510                }
511                chars.into_iter().collect::<String>()
512            })
513            .collect::<String>()
514    }
515}
516
517impl Default for ProtoSchemaGraphExtractor {
518    fn default() -> Self {
519        Self::new()
520    }
521}
522
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_foreign_key_pattern_matching() {
        let extractor = ProtoSchemaGraphExtractor::new();

        // Test standard patterns
        let (is_fk, target) = extractor.analyze_potential_foreign_key("user_id", &Kind::Int32);
        assert!(is_fk);
        assert_eq!(target, Some("User".to_string()));

        let (is_fk, target) = extractor.analyze_potential_foreign_key("orderId", &Kind::Int64);
        assert!(is_fk);
        assert_eq!(target, Some("Order".to_string()));
    }

    #[test]
    fn test_ref_suffix_and_multi_word_foreign_keys() {
        let extractor = ProtoSchemaGraphExtractor::new();

        assert_eq!(
            extractor.analyze_potential_foreign_key("user_ref", &Kind::String),
            (true, Some("User".to_string()))
        );
        assert_eq!(
            extractor.analyze_potential_foreign_key("order_item_id", &Kind::Int64),
            (true, Some("OrderItem".to_string()))
        );
    }

    #[test]
    fn test_non_foreign_key_scalars_have_no_target() {
        let extractor = ProtoSchemaGraphExtractor::new();

        assert_eq!(
            extractor.analyze_potential_foreign_key("email", &Kind::String),
            (false, None)
        );
        // A bare `id` has no entity prefix, so it is not a foreign key.
        assert_eq!(
            extractor.analyze_potential_foreign_key("id", &Kind::Int64),
            (false, None)
        );
        // Lower-case `...id` endings must not be read as the camelCase `<entity>Id`.
        assert_eq!(
            extractor.analyze_potential_foreign_key("paid", &Kind::Bool),
            (false, None)
        );
    }

    #[test]
    fn test_scalar_kind_to_string() {
        assert_eq!(ProtoSchemaGraphExtractor::kind_to_string(&Kind::Int32), "int32");
        assert_eq!(ProtoSchemaGraphExtractor::kind_to_string(&Kind::Bytes), "bytes");
        assert_eq!(ProtoSchemaGraphExtractor::kind_to_string(&Kind::String), "string");
    }

    #[test]
    fn test_entity_name_normalization() {
        assert_eq!(ProtoSchemaGraphExtractor::normalize_entity_name("user"), "User");
        assert_eq!(ProtoSchemaGraphExtractor::normalize_entity_name("order_item"), "OrderItem");
        assert_eq!(
            ProtoSchemaGraphExtractor::normalize_entity_name("ProductCategory"),
            "ProductCategory"
        );
        // Empty and doubled-underscore inputs must not panic or insert separators.
        assert_eq!(ProtoSchemaGraphExtractor::normalize_entity_name(""), "");
        assert_eq!(ProtoSchemaGraphExtractor::normalize_entity_name("order__item"), "OrderItem");
    }
}