1use prost_reflect::{DescriptorPool, FieldDescriptor, Kind, MessageDescriptor};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use tracing::info;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SchemaGraph {
15 pub entities: HashMap<String, EntityNode>,
17 pub relationships: Vec<Relationship>,
19 pub foreign_keys: HashMap<String, Vec<ForeignKeyMapping>>,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct EntityNode {
26 pub name: String,
28 pub full_name: String,
30 pub fields: Vec<FieldInfo>,
32 pub is_root: bool,
34 pub referenced_by: Vec<String>,
36 pub references: Vec<String>,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct FieldInfo {
43 pub name: String,
45 pub field_type: String,
47 pub is_foreign_key: bool,
49 pub foreign_key_target: Option<String>,
51 pub is_required: bool,
53 pub constraints: HashMap<String, String>,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct Relationship {
60 pub from_entity: String,
62 pub to_entity: String,
64 pub relationship_type: RelationshipType,
66 pub field_name: String,
68 pub is_required: bool,
70 pub cardinality: Cardinality,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub enum RelationshipType {
77 ForeignKey,
79 Embedded,
81 OneToMany,
83 ManyToMany,
85 Composition,
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct Cardinality {
92 pub min: u32,
94 pub max: Option<u32>,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct ForeignKeyMapping {
101 pub field_name: String,
103 pub target_entity: String,
105 pub confidence: f64,
107 pub detection_method: ForeignKeyDetectionMethod,
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
113pub enum ForeignKeyDetectionMethod {
114 NamingConvention,
116 SchemaReference,
118 MessageType,
120 Constraint,
122}
123
124pub struct ProtoSchemaGraphExtractor {
126 foreign_key_patterns: Vec<ForeignKeyPattern>,
128}
129
/// One naming-convention rule used to recognise foreign-key fields.
#[derive(Debug, Clone)]
struct ForeignKeyPattern {
    /// Regular expression matched against the field name.
    pattern: regex::Regex,
    /// How to turn a matching field name into the referenced entity name.
    entity_extraction: EntityExtractionMethod,
    /// Base confidence assigned when this rule is the first to match a field.
    confidence: f64,
}
143
/// Strategy for deriving an entity name from a field name that matched a
/// [`ForeignKeyPattern`].
#[derive(Debug, Clone)]
enum EntityExtractionMethod {
    /// Strip the given suffix from the field name, then normalize the remainder.
    RemoveSuffix(String),
    /// Normalize the whole field name.
    // `dead_code` is allowed because no built-in pattern constructs this variant.
    #[allow(dead_code)]
    Direct,
    /// Delegate to a custom function; `None` means "no entity for this field".
    // `dead_code` is allowed because no built-in pattern constructs this variant.
    #[allow(dead_code)]
    Custom(fn(&str) -> Option<String>),
}
160
161impl ProtoSchemaGraphExtractor {
162 pub fn new() -> Self {
164 let patterns = vec![
165 ForeignKeyPattern {
166 pattern: regex::Regex::new(r"^(.+)_id$").unwrap(),
167 entity_extraction: EntityExtractionMethod::RemoveSuffix("_id".to_string()),
168 confidence: 0.9,
169 },
170 ForeignKeyPattern {
171 pattern: regex::Regex::new(r"^(.+)Id$").unwrap(),
172 entity_extraction: EntityExtractionMethod::RemoveSuffix("Id".to_string()),
173 confidence: 0.85,
174 },
175 ForeignKeyPattern {
176 pattern: regex::Regex::new(r"^(.+)_ref$").unwrap(),
177 entity_extraction: EntityExtractionMethod::RemoveSuffix("_ref".to_string()),
178 confidence: 0.8,
179 },
180 ];
181
182 Self {
183 foreign_key_patterns: patterns,
184 }
185 }
186
187 pub fn extract_from_proto(
189 &self,
190 pool: &DescriptorPool,
191 ) -> Result<SchemaGraph, Box<dyn std::error::Error + Send + Sync>> {
192 let mut entities = HashMap::new();
193 let mut relationships = Vec::new();
194 let mut foreign_keys = HashMap::new();
195
196 info!("Extracting schema graph from protobuf descriptors");
197
198 for message_descriptor in pool.all_messages() {
200 let entity = self.extract_entity_from_message(&message_descriptor)?;
201 entities.insert(entity.name.clone(), entity);
202 }
203
204 for (entity_name, entity) in &entities {
206 let fk_mappings = self.detect_foreign_keys(entity, &entities)?;
207 if !fk_mappings.is_empty() {
208 foreign_keys.insert(entity_name.clone(), fk_mappings);
209 }
210
211 let entity_relationships = self.extract_relationships(entity, &entities)?;
212 relationships.extend(entity_relationships);
213 }
214
215 let mut updated_entities = entities;
217 self.update_cross_references(&mut updated_entities, &relationships);
218
219 let graph = SchemaGraph {
220 entities: updated_entities,
221 relationships,
222 foreign_keys,
223 };
224
225 info!(
226 "Extracted schema graph with {} entities and {} relationships",
227 graph.entities.len(),
228 graph.relationships.len()
229 );
230
231 Ok(graph)
232 }
233
234 fn extract_entity_from_message(
236 &self,
237 descriptor: &MessageDescriptor,
238 ) -> Result<EntityNode, Box<dyn std::error::Error + Send + Sync>> {
239 let name = Self::extract_entity_name(descriptor.name());
240 let full_name = descriptor.full_name().to_string();
241
242 let mut fields = Vec::new();
243 for field_descriptor in descriptor.fields() {
244 let field_info = self.extract_field_info(&field_descriptor)?;
245 fields.push(field_info);
246 }
247
248 Ok(EntityNode {
249 name,
250 full_name,
251 fields,
252 is_root: true, referenced_by: Vec::new(),
254 references: Vec::new(),
255 })
256 }
257
258 fn extract_field_info(
260 &self,
261 field: &FieldDescriptor,
262 ) -> Result<FieldInfo, Box<dyn std::error::Error + Send + Sync>> {
263 let name = field.name().to_string();
264 let field_type = Self::kind_to_string(&field.kind());
265 let is_required = true; let (is_foreign_key, foreign_key_target) =
269 self.analyze_potential_foreign_key(&name, &field.kind());
270
271 let mut constraints = HashMap::new();
272 if field.is_list() {
273 constraints.insert("repeated".to_string(), "true".to_string());
274 }
275
276 Ok(FieldInfo {
277 name,
278 field_type,
279 is_foreign_key,
280 foreign_key_target,
281 is_required,
282 constraints,
283 })
284 }
285
286 fn analyze_potential_foreign_key(
288 &self,
289 field_name: &str,
290 kind: &Kind,
291 ) -> (bool, Option<String>) {
292 for pattern in &self.foreign_key_patterns {
294 if pattern.pattern.is_match(field_name) {
295 if let Some(entity_name) = self.extract_entity_name_from_field(field_name, pattern)
296 {
297 return (true, Some(entity_name));
298 }
299 }
300 }
301
302 if let Kind::Message(message_descriptor) = kind {
304 let entity_name = Self::extract_entity_name(message_descriptor.name());
305 return (false, Some(entity_name)); }
307
308 (false, None)
309 }
310
311 fn extract_entity_name_from_field(
313 &self,
314 field_name: &str,
315 pattern: &ForeignKeyPattern,
316 ) -> Option<String> {
317 match &pattern.entity_extraction {
318 EntityExtractionMethod::RemoveSuffix(suffix) => {
319 if field_name.ends_with(suffix) {
320 let base_name = &field_name[..field_name.len() - suffix.len()];
321 Some(Self::normalize_entity_name(base_name))
322 } else {
323 None
324 }
325 }
326 EntityExtractionMethod::Direct => Some(Self::normalize_entity_name(field_name)),
327 EntityExtractionMethod::Custom(func) => func(field_name),
328 }
329 }
330
331 fn detect_foreign_keys(
333 &self,
334 entity: &EntityNode,
335 all_entities: &HashMap<String, EntityNode>,
336 ) -> Result<Vec<ForeignKeyMapping>, Box<dyn std::error::Error + Send + Sync>> {
337 let mut mappings = Vec::new();
338
339 for field in &entity.fields {
340 if field.is_foreign_key {
341 if let Some(target) = &field.foreign_key_target {
342 if all_entities.contains_key(target) {
344 let confidence = self.calculate_confidence_score(field, target, all_entities);
346
347 mappings.push(ForeignKeyMapping {
348 field_name: field.name.clone(),
349 target_entity: target.clone(),
350 confidence,
351 detection_method: ForeignKeyDetectionMethod::NamingConvention,
352 });
353 }
354 }
355 }
356 }
357
358 Ok(mappings)
359 }
360
361 fn calculate_confidence_score(
369 &self,
370 field: &FieldInfo,
371 target_entity: &str,
372 all_entities: &HashMap<String, EntityNode>,
373 ) -> f64 {
374 let mut confidence = 0.5; for pattern in &self.foreign_key_patterns {
378 if pattern.pattern.is_match(&field.name) {
379 confidence = pattern.confidence;
380 break;
381 }
382 }
383
384 if all_entities.contains_key(target_entity) {
386 confidence += 0.1; }
388
389 if field.field_type.contains("message") || field.field_type.contains("Message") {
391 confidence += 0.1; }
393
394 confidence.min(1.0)
396 }
397
398 fn extract_relationships(
400 &self,
401 entity: &EntityNode,
402 all_entities: &HashMap<String, EntityNode>,
403 ) -> Result<Vec<Relationship>, Box<dyn std::error::Error + Send + Sync>> {
404 let mut relationships = Vec::new();
405
406 for field in &entity.fields {
407 if let Some(target_entity) = &field.foreign_key_target {
408 if all_entities.contains_key(target_entity) {
409 let relationship_type = if field.is_foreign_key {
410 RelationshipType::ForeignKey
411 } else if field.field_type.contains("message") {
412 RelationshipType::Embedded
413 } else {
414 RelationshipType::Composition
415 };
416
417 let cardinality = if field.constraints.contains_key("repeated") {
418 Cardinality { min: 0, max: None }
419 } else {
420 Cardinality {
421 min: if field.is_required { 1 } else { 0 },
422 max: Some(1),
423 }
424 };
425
426 relationships.push(Relationship {
427 from_entity: entity.name.clone(),
428 to_entity: target_entity.clone(),
429 relationship_type,
430 field_name: field.name.clone(),
431 is_required: field.is_required,
432 cardinality,
433 });
434 }
435 }
436 }
437
438 Ok(relationships)
439 }
440
441 fn update_cross_references(
443 &self,
444 entities: &mut HashMap<String, EntityNode>,
445 relationships: &[Relationship],
446 ) {
447 let mut referenced_by_map: HashMap<String, Vec<String>> = HashMap::new();
449 let mut references_map: HashMap<String, Vec<String>> = HashMap::new();
450
451 for rel in relationships {
452 references_map
454 .entry(rel.from_entity.clone())
455 .or_default()
456 .push(rel.to_entity.clone());
457
458 referenced_by_map
460 .entry(rel.to_entity.clone())
461 .or_default()
462 .push(rel.from_entity.clone());
463 }
464
465 for (entity_name, entity) in entities.iter_mut() {
467 if let Some(refs) = references_map.get(entity_name) {
468 entity.references = refs.clone();
469 }
470
471 if let Some(referenced_by) = referenced_by_map.get(entity_name) {
472 entity.referenced_by = referenced_by.clone();
473 entity.is_root = false; }
475 }
476 }
477
478 fn kind_to_string(kind: &Kind) -> String {
480 match kind {
481 Kind::String => "string".to_string(),
482 Kind::Int32 => "int32".to_string(),
483 Kind::Int64 => "int64".to_string(),
484 Kind::Uint32 => "uint32".to_string(),
485 Kind::Uint64 => "uint64".to_string(),
486 Kind::Bool => "bool".to_string(),
487 Kind::Float => "float".to_string(),
488 Kind::Double => "double".to_string(),
489 Kind::Bytes => "bytes".to_string(),
490 Kind::Message(msg) => format!("message:{}", msg.full_name()),
491 Kind::Enum(enum_desc) => format!("enum:{}", enum_desc.full_name()),
492 _ => "unknown".to_string(),
493 }
494 }
495
/// Derives the entity name for a protobuf message from its short message name.
///
/// Thin wrapper over `normalize_entity_name`, so message names and names
/// recovered from foreign-key fields are normalized the same way.
fn extract_entity_name(message_name: &str) -> String {
    Self::normalize_entity_name(message_name)
}
500
/// Converts a snake_case (or already PascalCase) name into PascalCase.
///
/// The name is split on `_`, the first character of every part is
/// upper-cased, the rest of the part is kept verbatim, and the parts are
/// concatenated. Empty parts (leading, trailing or doubled underscores)
/// contribute nothing.
///
/// The full upper-case mapping of the first character is used, so characters
/// whose upper-case form spans several code points (e.g. `ß` -> `SS`) are not
/// truncated.
fn normalize_entity_name(name: &str) -> String {
    // The result is at most slightly longer than the input, so reserving
    // `name.len()` avoids reallocation in the common ASCII case.
    let mut normalized = String::with_capacity(name.len());

    for part in name.split('_') {
        let mut rest = part.chars();
        if let Some(first) = rest.next() {
            normalized.extend(first.to_uppercase());
            normalized.push_str(rest.as_str());
        }
    }

    normalized
}
514}
515
/// The default extractor is the one built by `ProtoSchemaGraphExtractor::new`.
impl Default for ProtoSchemaGraphExtractor {
    /// Equivalent to calling `ProtoSchemaGraphExtractor::new()`.
    fn default() -> Self {
        Self::new()
    }
}
521
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that a scalar field called `field_name` is classified as a
    /// foreign key pointing at `expected_entity`.
    fn assert_foreign_key(
        extractor: &ProtoSchemaGraphExtractor,
        field_name: &str,
        kind: Kind,
        expected_entity: &str,
    ) {
        let (is_fk, target) = extractor.analyze_potential_foreign_key(field_name, &kind);
        assert!(is_fk, "{field_name} should be detected as a foreign key");
        assert_eq!(target, Some(expected_entity.to_string()));
    }

    #[test]
    fn test_foreign_key_pattern_matching() {
        let extractor = ProtoSchemaGraphExtractor::new();

        // One case per built-in naming rule: `_id`, `Id` and `_ref`.
        assert_foreign_key(&extractor, "user_id", Kind::Int32, "User");
        assert_foreign_key(&extractor, "orderId", Kind::Int64, "Order");
        assert_foreign_key(&extractor, "customer_ref", Kind::String, "Customer");
    }

    #[test]
    fn test_plain_scalar_field_is_not_a_foreign_key() {
        let extractor = ProtoSchemaGraphExtractor::new();

        // No naming rule matches and the field is not message-typed.
        let (is_fk, target) = extractor.analyze_potential_foreign_key("name", &Kind::String);
        assert!(!is_fk);
        assert_eq!(target, None);
    }

    #[test]
    fn test_entity_name_normalization() {
        assert_eq!(ProtoSchemaGraphExtractor::normalize_entity_name("user"), "User");
        assert_eq!(ProtoSchemaGraphExtractor::normalize_entity_name("order_item"), "OrderItem");
        assert_eq!(
            ProtoSchemaGraphExtractor::normalize_entity_name("ProductCategory"),
            "ProductCategory"
        );
        // Degenerate input stays empty instead of panicking.
        assert_eq!(ProtoSchemaGraphExtractor::normalize_entity_name(""), "");
    }
}