mockforge_grpc/reflection/
schema_graph.rs

//! Schema relationship graph extraction
//!
//! This module extracts relationship graphs from proto and OpenAPI schemas,
//! identifying foreign keys, references, and data dependencies for coherent
//! synthetic data generation.
6
7use prost_reflect::{DescriptorPool, FieldDescriptor, Kind, MessageDescriptor};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use tracing::info;
11
12/// A graph representing relationships between schema entities
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SchemaGraph {
15    /// All entities (messages/schemas) in the graph
16    pub entities: HashMap<String, EntityNode>,
17    /// Direct relationships between entities
18    pub relationships: Vec<Relationship>,
19    /// Detected foreign key patterns
20    pub foreign_keys: HashMap<String, Vec<ForeignKeyMapping>>,
21}
22
23/// An entity node in the schema graph
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct EntityNode {
26    /// Entity name (e.g., "User", "Order")
27    pub name: String,
28    /// Full qualified name (e.g., "com.example.User")
29    pub full_name: String,
30    /// Fields in this entity
31    pub fields: Vec<FieldInfo>,
32    /// Whether this is a root entity (not referenced by others)
33    pub is_root: bool,
34    /// Entities that reference this one
35    pub referenced_by: Vec<String>,
36    /// Entities that this one references
37    pub references: Vec<String>,
38}
39
40/// Information about a field in an entity
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct FieldInfo {
43    /// Field name
44    pub name: String,
45    /// Field type (string, int32, message, etc.)
46    pub field_type: String,
47    /// Whether this field is a potential foreign key
48    pub is_foreign_key: bool,
49    /// Target entity if this is a foreign key
50    pub foreign_key_target: Option<String>,
51    /// Whether this field is required
52    pub is_required: bool,
53    /// Constraints on this field
54    pub constraints: HashMap<String, String>,
55}
56
57/// A relationship between two entities
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct Relationship {
60    /// Source entity name
61    pub from_entity: String,
62    /// Target entity name
63    pub to_entity: String,
64    /// Type of relationship
65    pub relationship_type: RelationshipType,
66    /// Field name that creates the relationship
67    pub field_name: String,
68    /// Whether this relationship is required
69    pub is_required: bool,
70    /// Cardinality constraints
71    pub cardinality: Cardinality,
72}
73
74/// Type of relationship between entities
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub enum RelationshipType {
77    /// Direct foreign key reference (user_id -> User)
78    ForeignKey,
79    /// Embedded object (address within user)
80    Embedded,
81    /// Array/repeated field relationship
82    OneToMany,
83    /// Bidirectional relationship
84    ManyToMany,
85    /// Composition relationship
86    Composition,
87}
88
89/// Cardinality constraints for relationships
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct Cardinality {
92    /// Minimum number of related entities
93    pub min: u32,
94    /// Maximum number of related entities (None = unlimited)
95    pub max: Option<u32>,
96}
97
98/// Foreign key mapping detected via naming conventions
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct ForeignKeyMapping {
101    /// Field name (e.g., "user_id")
102    pub field_name: String,
103    /// Target entity name (e.g., "User")
104    pub target_entity: String,
105    /// Confidence score (0.0 - 1.0)
106    pub confidence: f64,
107    /// Detection method used
108    pub detection_method: ForeignKeyDetectionMethod,
109}
110
111/// Methods used to detect foreign key relationships
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub enum ForeignKeyDetectionMethod {
114    /// Detected via naming convention (user_id -> User)
115    NamingConvention,
116    /// Detected via schema reference ($ref in OpenAPI)
117    SchemaReference,
118    /// Detected via field type (message type in proto)
119    MessageType,
120    /// Detected via constraint annotation
121    Constraint,
122}
123
124/// Schema graph extractor for protobuf schemas
125pub struct ProtoSchemaGraphExtractor {
126    /// Common foreign key patterns
127    foreign_key_patterns: Vec<ForeignKeyPattern>,
128}
129
130/// Pattern for detecting foreign keys via naming
131#[derive(Debug, Clone)]
132struct ForeignKeyPattern {
133    /// Regex pattern for field names
134    pattern: regex::Regex,
135    /// How to extract entity name from field name
136    entity_extraction: EntityExtractionMethod,
137    /// Confidence score for this pattern
138    #[allow(dead_code)] // Used in future relationship analysis
139    confidence: f64,
140}
141
/// Strategies for turning a field name into an entity name.
#[derive(Clone, Debug)]
enum EntityExtractionMethod {
    /// Strip the given suffix from the field name (user_id -> user).
    RemoveSuffix(String),
    /// Use the field name itself.
    #[allow(dead_code)] // Used in future entity extraction
    Direct,
    /// Delegate to a caller-supplied function; it may decline by returning `None`.
    #[allow(dead_code)] // Used in future entity extraction
    Custom(fn(&str) -> Option<String>),
}
154
155impl ProtoSchemaGraphExtractor {
156    /// Create a new proto schema graph extractor
157    pub fn new() -> Self {
158        let patterns = vec![
159            ForeignKeyPattern {
160                pattern: regex::Regex::new(r"^(.+)_id$").unwrap(),
161                entity_extraction: EntityExtractionMethod::RemoveSuffix("_id".to_string()),
162                confidence: 0.9,
163            },
164            ForeignKeyPattern {
165                pattern: regex::Regex::new(r"^(.+)Id$").unwrap(),
166                entity_extraction: EntityExtractionMethod::RemoveSuffix("Id".to_string()),
167                confidence: 0.85,
168            },
169            ForeignKeyPattern {
170                pattern: regex::Regex::new(r"^(.+)_ref$").unwrap(),
171                entity_extraction: EntityExtractionMethod::RemoveSuffix("_ref".to_string()),
172                confidence: 0.8,
173            },
174        ];
175
176        Self {
177            foreign_key_patterns: patterns,
178        }
179    }
180
181    /// Extract schema graph from protobuf descriptor pool
182    pub fn extract_from_proto(
183        &self,
184        pool: &DescriptorPool,
185    ) -> Result<SchemaGraph, Box<dyn std::error::Error + Send + Sync>> {
186        let mut entities = HashMap::new();
187        let mut relationships = Vec::new();
188        let mut foreign_keys = HashMap::new();
189
190        info!("Extracting schema graph from protobuf descriptors");
191
192        // First pass: Extract all entities and their fields
193        for message_descriptor in pool.all_messages() {
194            let entity = self.extract_entity_from_message(&message_descriptor)?;
195            entities.insert(entity.name.clone(), entity);
196        }
197
198        // Second pass: Analyze relationships and foreign keys
199        for (entity_name, entity) in &entities {
200            let fk_mappings = self.detect_foreign_keys(entity, &entities)?;
201            if !fk_mappings.is_empty() {
202                foreign_keys.insert(entity_name.clone(), fk_mappings);
203            }
204
205            let entity_relationships = self.extract_relationships(entity, &entities)?;
206            relationships.extend(entity_relationships);
207        }
208
209        // Third pass: Update cross-references
210        let mut updated_entities = entities;
211        self.update_cross_references(&mut updated_entities, &relationships);
212
213        let graph = SchemaGraph {
214            entities: updated_entities,
215            relationships,
216            foreign_keys,
217        };
218
219        info!(
220            "Extracted schema graph with {} entities and {} relationships",
221            graph.entities.len(),
222            graph.relationships.len()
223        );
224
225        Ok(graph)
226    }
227
228    /// Extract an entity from a proto message descriptor
229    fn extract_entity_from_message(
230        &self,
231        descriptor: &MessageDescriptor,
232    ) -> Result<EntityNode, Box<dyn std::error::Error + Send + Sync>> {
233        let name = Self::extract_entity_name(descriptor.name());
234        let full_name = descriptor.full_name().to_string();
235
236        let mut fields = Vec::new();
237        for field_descriptor in descriptor.fields() {
238            let field_info = self.extract_field_info(&field_descriptor)?;
239            fields.push(field_info);
240        }
241
242        Ok(EntityNode {
243            name,
244            full_name,
245            fields,
246            is_root: true, // Will be updated later
247            referenced_by: Vec::new(),
248            references: Vec::new(),
249        })
250    }
251
252    /// Extract field information from a proto field descriptor
253    fn extract_field_info(
254        &self,
255        field: &FieldDescriptor,
256    ) -> Result<FieldInfo, Box<dyn std::error::Error + Send + Sync>> {
257        let name = field.name().to_string();
258        let field_type = Self::kind_to_string(&field.kind());
259        let is_required = true; // Proto fields are required by default unless marked optional
260
261        // Check if this looks like a foreign key
262        let (is_foreign_key, foreign_key_target) =
263            self.analyze_potential_foreign_key(&name, &field.kind());
264
265        let mut constraints = HashMap::new();
266        if field.is_list() {
267            constraints.insert("repeated".to_string(), "true".to_string());
268        }
269
270        Ok(FieldInfo {
271            name,
272            field_type,
273            is_foreign_key,
274            foreign_key_target,
275            is_required,
276            constraints,
277        })
278    }
279
280    /// Analyze if a field might be a foreign key
281    fn analyze_potential_foreign_key(
282        &self,
283        field_name: &str,
284        kind: &Kind,
285    ) -> (bool, Option<String>) {
286        // Check naming patterns
287        for pattern in &self.foreign_key_patterns {
288            if pattern.pattern.is_match(field_name) {
289                if let Some(entity_name) = self.extract_entity_name_from_field(field_name, pattern)
290                {
291                    return (true, Some(entity_name));
292                }
293            }
294        }
295
296        // Check if it's a message type (embedded relationship)
297        if let Kind::Message(message_descriptor) = kind {
298            let entity_name = Self::extract_entity_name(message_descriptor.name());
299            return (false, Some(entity_name)); // Not FK, but related entity
300        }
301
302        (false, None)
303    }
304
305    /// Extract entity name from field name using pattern
306    fn extract_entity_name_from_field(
307        &self,
308        field_name: &str,
309        pattern: &ForeignKeyPattern,
310    ) -> Option<String> {
311        match &pattern.entity_extraction {
312            EntityExtractionMethod::RemoveSuffix(suffix) => {
313                if field_name.ends_with(suffix) {
314                    let base_name = &field_name[..field_name.len() - suffix.len()];
315                    Some(Self::normalize_entity_name(base_name))
316                } else {
317                    None
318                }
319            }
320            EntityExtractionMethod::Direct => Some(Self::normalize_entity_name(field_name)),
321            EntityExtractionMethod::Custom(func) => func(field_name),
322        }
323    }
324
325    /// Detect foreign keys in an entity
326    fn detect_foreign_keys(
327        &self,
328        entity: &EntityNode,
329        all_entities: &HashMap<String, EntityNode>,
330    ) -> Result<Vec<ForeignKeyMapping>, Box<dyn std::error::Error + Send + Sync>> {
331        let mut mappings = Vec::new();
332
333        for field in &entity.fields {
334            if field.is_foreign_key {
335                if let Some(target) = &field.foreign_key_target {
336                    // Check if target entity exists
337                    if all_entities.contains_key(target) {
338                        mappings.push(ForeignKeyMapping {
339                            field_name: field.name.clone(),
340                            target_entity: target.clone(),
341                            confidence: 0.9, // High confidence for detected patterns
342                            detection_method: ForeignKeyDetectionMethod::NamingConvention,
343                        });
344                    }
345                }
346            }
347        }
348
349        Ok(mappings)
350    }
351
352    /// Extract relationships from an entity
353    fn extract_relationships(
354        &self,
355        entity: &EntityNode,
356        all_entities: &HashMap<String, EntityNode>,
357    ) -> Result<Vec<Relationship>, Box<dyn std::error::Error + Send + Sync>> {
358        let mut relationships = Vec::new();
359
360        for field in &entity.fields {
361            if let Some(target_entity) = &field.foreign_key_target {
362                if all_entities.contains_key(target_entity) {
363                    let relationship_type = if field.is_foreign_key {
364                        RelationshipType::ForeignKey
365                    } else if field.field_type.contains("message") {
366                        RelationshipType::Embedded
367                    } else {
368                        RelationshipType::Composition
369                    };
370
371                    let cardinality = if field.constraints.contains_key("repeated") {
372                        Cardinality { min: 0, max: None }
373                    } else {
374                        Cardinality {
375                            min: if field.is_required { 1 } else { 0 },
376                            max: Some(1),
377                        }
378                    };
379
380                    relationships.push(Relationship {
381                        from_entity: entity.name.clone(),
382                        to_entity: target_entity.clone(),
383                        relationship_type,
384                        field_name: field.name.clone(),
385                        is_required: field.is_required,
386                        cardinality,
387                    });
388                }
389            }
390        }
391
392        Ok(relationships)
393    }
394
395    /// Update cross-references between entities
396    fn update_cross_references(
397        &self,
398        entities: &mut HashMap<String, EntityNode>,
399        relationships: &[Relationship],
400    ) {
401        // Build reference maps
402        let mut referenced_by_map: HashMap<String, Vec<String>> = HashMap::new();
403        let mut references_map: HashMap<String, Vec<String>> = HashMap::new();
404
405        for rel in relationships {
406            // Track what references what
407            references_map
408                .entry(rel.from_entity.clone())
409                .or_default()
410                .push(rel.to_entity.clone());
411
412            // Track what is referenced by what
413            referenced_by_map
414                .entry(rel.to_entity.clone())
415                .or_default()
416                .push(rel.from_entity.clone());
417        }
418
419        // Update entities
420        for (entity_name, entity) in entities.iter_mut() {
421            if let Some(refs) = references_map.get(entity_name) {
422                entity.references = refs.clone();
423            }
424
425            if let Some(referenced_by) = referenced_by_map.get(entity_name) {
426                entity.referenced_by = referenced_by.clone();
427                entity.is_root = false; // Referenced entities are not root
428            }
429        }
430    }
431
432    /// Convert protobuf Kind to string representation
433    fn kind_to_string(kind: &Kind) -> String {
434        match kind {
435            Kind::String => "string".to_string(),
436            Kind::Int32 => "int32".to_string(),
437            Kind::Int64 => "int64".to_string(),
438            Kind::Uint32 => "uint32".to_string(),
439            Kind::Uint64 => "uint64".to_string(),
440            Kind::Bool => "bool".to_string(),
441            Kind::Float => "float".to_string(),
442            Kind::Double => "double".to_string(),
443            Kind::Bytes => "bytes".to_string(),
444            Kind::Message(msg) => format!("message:{}", msg.full_name()),
445            Kind::Enum(enum_desc) => format!("enum:{}", enum_desc.full_name()),
446            _ => "unknown".to_string(),
447        }
448    }
449
450    /// Extract entity name from message name (remove package, normalize)
451    fn extract_entity_name(message_name: &str) -> String {
452        Self::normalize_entity_name(message_name)
453    }
454
455    /// Normalize entity name (PascalCase, singular)
456    fn normalize_entity_name(name: &str) -> String {
457        // Convert snake_case to PascalCase
458        name.split('_')
459            .map(|part| {
460                let mut chars: Vec<char> = part.chars().collect();
461                if let Some(first_char) = chars.first_mut() {
462                    *first_char = first_char.to_uppercase().next().unwrap_or(*first_char);
463                }
464                chars.into_iter().collect::<String>()
465            })
466            .collect::<String>()
467    }
468}
469
470impl Default for ProtoSchemaGraphExtractor {
471    fn default() -> Self {
472        Self::new()
473    }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Field names with the `_id` and `Id` suffixes are flagged as foreign
    /// keys and resolve to the PascalCase entity name.
    #[test]
    fn test_foreign_key_pattern_matching() {
        let extractor = ProtoSchemaGraphExtractor::new();

        let cases = [
            ("user_id", Kind::Int32, "User"),
            ("orderId", Kind::Int64, "Order"),
        ];

        for (field_name, kind, expected_target) in &cases {
            let (is_fk, target) = extractor.analyze_potential_foreign_key(field_name, kind);
            assert!(is_fk);
            assert_eq!(target, Some(expected_target.to_string()));
        }
    }

    /// snake_case input becomes PascalCase; PascalCase input is left as is.
    #[test]
    fn test_entity_name_normalization() {
        let cases = [
            ("user", "User"),
            ("order_item", "OrderItem"),
            ("ProductCategory", "ProductCategory"),
        ];

        for (input, expected) in &cases {
            assert_eq!(ProtoSchemaGraphExtractor::normalize_entity_name(input), *expected);
        }
    }
}