mockforge_grpc/reflection/
schema_graph.rs

1//! Schema relationship graph extraction
2//!
3//! This module extracts relationship graphs from proto and OpenAPI schemas,
4//! identifying foreign keys, references, and data dependencies for coherent
5//! synthetic data generation.
6
7use prost_reflect::{DescriptorPool, FieldDescriptor, Kind, MessageDescriptor};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use tracing::info;
11
12/// A graph representing relationships between schema entities
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SchemaGraph {
15    /// All entities (messages/schemas) in the graph
16    pub entities: HashMap<String, EntityNode>,
17    /// Direct relationships between entities
18    pub relationships: Vec<Relationship>,
19    /// Detected foreign key patterns
20    pub foreign_keys: HashMap<String, Vec<ForeignKeyMapping>>,
21}
22
23/// An entity node in the schema graph
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct EntityNode {
26    /// Entity name (e.g., "User", "Order")
27    pub name: String,
28    /// Full qualified name (e.g., "com.example.User")
29    pub full_name: String,
30    /// Fields in this entity
31    pub fields: Vec<FieldInfo>,
32    /// Whether this is a root entity (not referenced by others)
33    pub is_root: bool,
34    /// Entities that reference this one
35    pub referenced_by: Vec<String>,
36    /// Entities that this one references
37    pub references: Vec<String>,
38}
39
40/// Information about a field in an entity
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct FieldInfo {
43    /// Field name
44    pub name: String,
45    /// Field type (string, int32, message, etc.)
46    pub field_type: String,
47    /// Whether this field is a potential foreign key
48    pub is_foreign_key: bool,
49    /// Target entity if this is a foreign key
50    pub foreign_key_target: Option<String>,
51    /// Whether this field is required
52    pub is_required: bool,
53    /// Constraints on this field
54    pub constraints: HashMap<String, String>,
55}
56
57/// A relationship between two entities
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct Relationship {
60    /// Source entity name
61    pub from_entity: String,
62    /// Target entity name
63    pub to_entity: String,
64    /// Type of relationship
65    pub relationship_type: RelationshipType,
66    /// Field name that creates the relationship
67    pub field_name: String,
68    /// Whether this relationship is required
69    pub is_required: bool,
70    /// Cardinality constraints
71    pub cardinality: Cardinality,
72}
73
74/// Type of relationship between entities
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub enum RelationshipType {
77    /// Direct foreign key reference (user_id -> User)
78    ForeignKey,
79    /// Embedded object (address within user)
80    Embedded,
81    /// Array/repeated field relationship
82    OneToMany,
83    /// Bidirectional relationship
84    ManyToMany,
85    /// Composition relationship
86    Composition,
87}
88
89/// Cardinality constraints for relationships
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct Cardinality {
92    /// Minimum number of related entities
93    pub min: u32,
94    /// Maximum number of related entities (None = unlimited)
95    pub max: Option<u32>,
96}
97
98/// Foreign key mapping detected via naming conventions
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct ForeignKeyMapping {
101    /// Field name (e.g., "user_id")
102    pub field_name: String,
103    /// Target entity name (e.g., "User")
104    pub target_entity: String,
105    /// Confidence score (0.0 - 1.0)
106    pub confidence: f64,
107    /// Detection method used
108    pub detection_method: ForeignKeyDetectionMethod,
109}
110
111/// Methods used to detect foreign key relationships
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub enum ForeignKeyDetectionMethod {
114    /// Detected via naming convention (user_id -> User)
115    NamingConvention,
116    /// Detected via schema reference ($ref in OpenAPI)
117    SchemaReference,
118    /// Detected via field type (message type in proto)
119    MessageType,
120    /// Detected via constraint annotation
121    Constraint,
122}
123
124/// Schema graph extractor for protobuf schemas
125pub struct ProtoSchemaGraphExtractor {
126    /// Common foreign key patterns
127    foreign_key_patterns: Vec<ForeignKeyPattern>,
128}
129
130/// Pattern for detecting foreign keys via naming
131#[derive(Debug, Clone)]
132struct ForeignKeyPattern {
133    /// Regex pattern for field names
134    pattern: regex::Regex,
135    /// How to extract entity name from field name
136    entity_extraction: EntityExtractionMethod,
137    /// Confidence score for this pattern
138    ///
139    /// Confidence score indicates how reliable this pattern is for detecting relationships.
140    /// Higher scores (closer to 1.0) indicate more reliable patterns.
141    confidence: f64,
142}
143
/// Strategy for turning a foreign-key field name into a target entity name.
#[derive(Debug, Clone)]
enum EntityExtractionMethod {
    /// Strip the given suffix, then normalize the remainder (`user_id` -> `User`)
    RemoveSuffix(String),
    /// Normalize the whole field name without stripping anything
    ///
    /// Not used by the built-in patterns yet.
    #[allow(dead_code)] // TODO: Remove when entity extraction feature is implemented
    Direct,
    /// Delegate to a caller-supplied transform; returning `None` yields no entity
    ///
    /// Not used by the built-in patterns yet.
    #[allow(dead_code)] // TODO: Remove when custom entity extraction is implemented
    Custom(fn(&str) -> Option<String>),
}
160
161impl ProtoSchemaGraphExtractor {
162    /// Create a new proto schema graph extractor
163    pub fn new() -> Self {
164        let patterns = vec![
165            ForeignKeyPattern {
166                pattern: regex::Regex::new(r"^(.+)_id$").unwrap(),
167                entity_extraction: EntityExtractionMethod::RemoveSuffix("_id".to_string()),
168                confidence: 0.9,
169            },
170            ForeignKeyPattern {
171                pattern: regex::Regex::new(r"^(.+)Id$").unwrap(),
172                entity_extraction: EntityExtractionMethod::RemoveSuffix("Id".to_string()),
173                confidence: 0.85,
174            },
175            ForeignKeyPattern {
176                pattern: regex::Regex::new(r"^(.+)_ref$").unwrap(),
177                entity_extraction: EntityExtractionMethod::RemoveSuffix("_ref".to_string()),
178                confidence: 0.8,
179            },
180        ];
181
182        Self {
183            foreign_key_patterns: patterns,
184        }
185    }
186
187    /// Extract schema graph from protobuf descriptor pool
188    pub fn extract_from_proto(
189        &self,
190        pool: &DescriptorPool,
191    ) -> Result<SchemaGraph, Box<dyn std::error::Error + Send + Sync>> {
192        let mut entities = HashMap::new();
193        let mut relationships = Vec::new();
194        let mut foreign_keys = HashMap::new();
195
196        info!("Extracting schema graph from protobuf descriptors");
197
198        // First pass: Extract all entities and their fields
199        for message_descriptor in pool.all_messages() {
200            let entity = self.extract_entity_from_message(&message_descriptor)?;
201            entities.insert(entity.name.clone(), entity);
202        }
203
204        // Second pass: Analyze relationships and foreign keys
205        for (entity_name, entity) in &entities {
206            let fk_mappings = self.detect_foreign_keys(entity, &entities)?;
207            if !fk_mappings.is_empty() {
208                foreign_keys.insert(entity_name.clone(), fk_mappings);
209            }
210
211            let entity_relationships = self.extract_relationships(entity, &entities)?;
212            relationships.extend(entity_relationships);
213        }
214
215        // Third pass: Update cross-references
216        let mut updated_entities = entities;
217        self.update_cross_references(&mut updated_entities, &relationships);
218
219        let graph = SchemaGraph {
220            entities: updated_entities,
221            relationships,
222            foreign_keys,
223        };
224
225        info!(
226            "Extracted schema graph with {} entities and {} relationships",
227            graph.entities.len(),
228            graph.relationships.len()
229        );
230
231        Ok(graph)
232    }
233
234    /// Extract an entity from a proto message descriptor
235    fn extract_entity_from_message(
236        &self,
237        descriptor: &MessageDescriptor,
238    ) -> Result<EntityNode, Box<dyn std::error::Error + Send + Sync>> {
239        let name = Self::extract_entity_name(descriptor.name());
240        let full_name = descriptor.full_name().to_string();
241
242        let mut fields = Vec::new();
243        for field_descriptor in descriptor.fields() {
244            let field_info = self.extract_field_info(&field_descriptor)?;
245            fields.push(field_info);
246        }
247
248        Ok(EntityNode {
249            name,
250            full_name,
251            fields,
252            is_root: true, // Will be updated later
253            referenced_by: Vec::new(),
254            references: Vec::new(),
255        })
256    }
257
258    /// Extract field information from a proto field descriptor
259    fn extract_field_info(
260        &self,
261        field: &FieldDescriptor,
262    ) -> Result<FieldInfo, Box<dyn std::error::Error + Send + Sync>> {
263        let name = field.name().to_string();
264        let field_type = Self::kind_to_string(&field.kind());
265        let is_required = true; // Proto fields are required by default unless marked optional
266
267        // Check if this looks like a foreign key
268        let (is_foreign_key, foreign_key_target) =
269            self.analyze_potential_foreign_key(&name, &field.kind());
270
271        let mut constraints = HashMap::new();
272        if field.is_list() {
273            constraints.insert("repeated".to_string(), "true".to_string());
274        }
275
276        Ok(FieldInfo {
277            name,
278            field_type,
279            is_foreign_key,
280            foreign_key_target,
281            is_required,
282            constraints,
283        })
284    }
285
286    /// Analyze if a field might be a foreign key
287    fn analyze_potential_foreign_key(
288        &self,
289        field_name: &str,
290        kind: &Kind,
291    ) -> (bool, Option<String>) {
292        // Check naming patterns
293        for pattern in &self.foreign_key_patterns {
294            if pattern.pattern.is_match(field_name) {
295                if let Some(entity_name) = self.extract_entity_name_from_field(field_name, pattern)
296                {
297                    return (true, Some(entity_name));
298                }
299            }
300        }
301
302        // Check if it's a message type (embedded relationship)
303        if let Kind::Message(message_descriptor) = kind {
304            let entity_name = Self::extract_entity_name(message_descriptor.name());
305            return (false, Some(entity_name)); // Not FK, but related entity
306        }
307
308        (false, None)
309    }
310
311    /// Extract entity name from field name using pattern
312    fn extract_entity_name_from_field(
313        &self,
314        field_name: &str,
315        pattern: &ForeignKeyPattern,
316    ) -> Option<String> {
317        match &pattern.entity_extraction {
318            EntityExtractionMethod::RemoveSuffix(suffix) => {
319                if field_name.ends_with(suffix) {
320                    let base_name = &field_name[..field_name.len() - suffix.len()];
321                    Some(Self::normalize_entity_name(base_name))
322                } else {
323                    None
324                }
325            }
326            EntityExtractionMethod::Direct => Some(Self::normalize_entity_name(field_name)),
327            EntityExtractionMethod::Custom(func) => func(field_name),
328        }
329    }
330
331    /// Detect foreign keys in an entity
332    fn detect_foreign_keys(
333        &self,
334        entity: &EntityNode,
335        all_entities: &HashMap<String, EntityNode>,
336    ) -> Result<Vec<ForeignKeyMapping>, Box<dyn std::error::Error + Send + Sync>> {
337        let mut mappings = Vec::new();
338
339        for field in &entity.fields {
340            if field.is_foreign_key {
341                if let Some(target) = &field.foreign_key_target {
342                    // Check if target entity exists
343                    if all_entities.contains_key(target) {
344                        // Calculate confidence score based on pattern match and entity existence
345                        let confidence = self.calculate_confidence_score(field, target, all_entities);
346                        
347                        mappings.push(ForeignKeyMapping {
348                            field_name: field.name.clone(),
349                            target_entity: target.clone(),
350                            confidence,
351                            detection_method: ForeignKeyDetectionMethod::NamingConvention,
352                        });
353                    }
354                }
355            }
356        }
357
358        Ok(mappings)
359    }
360
361    /// Calculate confidence score for a detected relationship
362    /// 
363    /// Confidence is calculated based on:
364    /// - Pattern match quality (higher for common patterns like _id)
365    /// - Entity existence validation (target entity exists)
366    /// - Field type compatibility (message type matches entity name)
367    /// - Naming convention strength (more specific patterns score higher)
368    fn calculate_confidence_score(
369        &self,
370        field: &FieldInfo,
371        target_entity: &str,
372        all_entities: &HashMap<String, EntityNode>,
373    ) -> f64 {
374        let mut confidence = 0.5; // Base confidence
375
376        // Find matching pattern to get pattern-specific confidence
377        for pattern in &self.foreign_key_patterns {
378            if pattern.pattern.is_match(&field.name) {
379                confidence = pattern.confidence;
380                break;
381            }
382        }
383
384        // Boost confidence if target entity exists and matches naming convention
385        if all_entities.contains_key(target_entity) {
386            confidence += 0.1; // +10% for entity existence
387        }
388
389        // Boost confidence if field type suggests a relationship
390        if field.field_type.contains("message") || field.field_type.contains("Message") {
391            confidence += 0.1; // +10% for message type
392        }
393
394        // Cap confidence at 1.0
395        confidence.min(1.0)
396    }
397
398    /// Extract relationships from an entity
399    fn extract_relationships(
400        &self,
401        entity: &EntityNode,
402        all_entities: &HashMap<String, EntityNode>,
403    ) -> Result<Vec<Relationship>, Box<dyn std::error::Error + Send + Sync>> {
404        let mut relationships = Vec::new();
405
406        for field in &entity.fields {
407            if let Some(target_entity) = &field.foreign_key_target {
408                if all_entities.contains_key(target_entity) {
409                    let relationship_type = if field.is_foreign_key {
410                        RelationshipType::ForeignKey
411                    } else if field.field_type.contains("message") {
412                        RelationshipType::Embedded
413                    } else {
414                        RelationshipType::Composition
415                    };
416
417                    let cardinality = if field.constraints.contains_key("repeated") {
418                        Cardinality { min: 0, max: None }
419                    } else {
420                        Cardinality {
421                            min: if field.is_required { 1 } else { 0 },
422                            max: Some(1),
423                        }
424                    };
425
426                    relationships.push(Relationship {
427                        from_entity: entity.name.clone(),
428                        to_entity: target_entity.clone(),
429                        relationship_type,
430                        field_name: field.name.clone(),
431                        is_required: field.is_required,
432                        cardinality,
433                    });
434                }
435            }
436        }
437
438        Ok(relationships)
439    }
440
441    /// Update cross-references between entities
442    fn update_cross_references(
443        &self,
444        entities: &mut HashMap<String, EntityNode>,
445        relationships: &[Relationship],
446    ) {
447        // Build reference maps
448        let mut referenced_by_map: HashMap<String, Vec<String>> = HashMap::new();
449        let mut references_map: HashMap<String, Vec<String>> = HashMap::new();
450
451        for rel in relationships {
452            // Track what references what
453            references_map
454                .entry(rel.from_entity.clone())
455                .or_default()
456                .push(rel.to_entity.clone());
457
458            // Track what is referenced by what
459            referenced_by_map
460                .entry(rel.to_entity.clone())
461                .or_default()
462                .push(rel.from_entity.clone());
463        }
464
465        // Update entities
466        for (entity_name, entity) in entities.iter_mut() {
467            if let Some(refs) = references_map.get(entity_name) {
468                entity.references = refs.clone();
469            }
470
471            if let Some(referenced_by) = referenced_by_map.get(entity_name) {
472                entity.referenced_by = referenced_by.clone();
473                entity.is_root = false; // Referenced entities are not root
474            }
475        }
476    }
477
478    /// Convert protobuf Kind to string representation
479    fn kind_to_string(kind: &Kind) -> String {
480        match kind {
481            Kind::String => "string".to_string(),
482            Kind::Int32 => "int32".to_string(),
483            Kind::Int64 => "int64".to_string(),
484            Kind::Uint32 => "uint32".to_string(),
485            Kind::Uint64 => "uint64".to_string(),
486            Kind::Bool => "bool".to_string(),
487            Kind::Float => "float".to_string(),
488            Kind::Double => "double".to_string(),
489            Kind::Bytes => "bytes".to_string(),
490            Kind::Message(msg) => format!("message:{}", msg.full_name()),
491            Kind::Enum(enum_desc) => format!("enum:{}", enum_desc.full_name()),
492            _ => "unknown".to_string(),
493        }
494    }
495
496    /// Extract entity name from message name (remove package, normalize)
497    fn extract_entity_name(message_name: &str) -> String {
498        Self::normalize_entity_name(message_name)
499    }
500
501    /// Normalize entity name (PascalCase, singular)
502    fn normalize_entity_name(name: &str) -> String {
503        // Convert snake_case to PascalCase
504        name.split('_')
505            .map(|part| {
506                let mut chars: Vec<char> = part.chars().collect();
507                if let Some(first_char) = chars.first_mut() {
508                    *first_char = first_char.to_uppercase().next().unwrap_or(*first_char);
509                }
510                chars.into_iter().collect::<String>()
511            })
512            .collect::<String>()
513    }
514}
515
516impl Default for ProtoSchemaGraphExtractor {
517    fn default() -> Self {
518        Self::new()
519    }
520}
521
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_foreign_key_pattern_matching() {
        let extractor = ProtoSchemaGraphExtractor::new();

        // snake_case `_id` suffix
        let (is_fk, target) = extractor.analyze_potential_foreign_key("user_id", &Kind::Int32);
        assert!(is_fk);
        assert_eq!(target, Some("User".to_string()));

        // camelCase `Id` suffix
        let (is_fk, target) = extractor.analyze_potential_foreign_key("orderId", &Kind::Int64);
        assert!(is_fk);
        assert_eq!(target, Some("Order".to_string()));
    }

    #[test]
    fn test_foreign_key_ref_suffix_and_multi_word_names() {
        let extractor = ProtoSchemaGraphExtractor::new();

        let (is_fk, target) = extractor.analyze_potential_foreign_key("author_ref", &Kind::String);
        assert!(is_fk);
        assert_eq!(target, Some("Author".to_string()));

        let (is_fk, target) =
            extractor.analyze_potential_foreign_key("order_item_id", &Kind::Int64);
        assert!(is_fk);
        assert_eq!(target, Some("OrderItem".to_string()));
    }

    #[test]
    fn test_non_foreign_key_scalar_fields() {
        let extractor = ProtoSchemaGraphExtractor::new();

        // Bare suffixes need a non-empty prefix; `valid` ends in lowercase `id`.
        for name in ["name", "id", "_id", "Id", "valid"] {
            let (is_fk, target) = extractor.analyze_potential_foreign_key(name, &Kind::String);
            assert!(!is_fk, "{} should not be a foreign key", name);
            assert_eq!(target, None);
        }
    }

    #[test]
    fn test_entity_name_normalization() {
        assert_eq!(ProtoSchemaGraphExtractor::normalize_entity_name("user"), "User");
        assert_eq!(ProtoSchemaGraphExtractor::normalize_entity_name("order_item"), "OrderItem");
        assert_eq!(
            ProtoSchemaGraphExtractor::normalize_entity_name("ProductCategory"),
            "ProductCategory"
        );
        assert_eq!(ProtoSchemaGraphExtractor::normalize_entity_name(""), "");
        assert_eq!(ProtoSchemaGraphExtractor::normalize_entity_name("order__item"), "OrderItem");
    }

    #[test]
    fn test_scalar_kind_to_string() {
        assert_eq!(ProtoSchemaGraphExtractor::kind_to_string(&Kind::String), "string");
        assert_eq!(ProtoSchemaGraphExtractor::kind_to_string(&Kind::Int64), "int64");
        assert_eq!(ProtoSchemaGraphExtractor::kind_to_string(&Kind::Bool), "bool");
    }
}