1use prost_reflect::{DescriptorPool, FieldDescriptor, Kind, MessageDescriptor};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use tracing::info;
11
/// The result of schema extraction: entities plus the links discovered between them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaGraph {
    /// Every extracted entity, keyed by its normalized entity name (`EntityNode::name`).
    pub entities: HashMap<String, EntityNode>,
    /// All relationships found between entities that are present in `entities`.
    pub relationships: Vec<Relationship>,
    /// Foreign-key mappings keyed by the owning entity's name. Entities for which no
    /// foreign key was detected have no entry.
    pub foreign_keys: HashMap<String, Vec<ForeignKeyMapping>>,
}
22
/// One entity in the schema graph, built from a single protobuf message descriptor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntityNode {
    /// Normalized entity name derived from the message's short name.
    pub name: String,
    /// Fully-qualified protobuf message name as reported by the descriptor.
    pub full_name: String,
    /// Per-field information, in the order the descriptor yields its fields.
    pub fields: Vec<FieldInfo>,
    /// Starts out `true`; cleared once at least one relationship targets this entity.
    pub is_root: bool,
    /// `from_entity` names of the relationships that target this entity.
    pub referenced_by: Vec<String>,
    /// `to_entity` names of the relationships that originate from this entity.
    pub references: Vec<String>,
}
39
/// Description of a single message field as seen by the extractor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldInfo {
    /// Field name as declared in the descriptor.
    pub name: String,
    /// Textual field type, e.g. `"string"`, `"int64"`, `"message:<full name>"` or
    /// `"enum:<full name>"`.
    pub field_type: String,
    /// `true` when the field *name* matched one of the foreign-key naming patterns.
    pub is_foreign_key: bool,
    /// Entity this field points at, if any. It is set for naming-convention foreign keys
    /// and also for message-typed fields (for which `is_foreign_key` stays `false`).
    pub foreign_key_target: Option<String>,
    /// Whether the field is considered mandatory.
    /// NOTE(review): the extractor in this file always sets this to `true`.
    pub is_required: bool,
    /// Free-form constraint map. The extractor in this file only records
    /// `"repeated" -> "true"` for list fields.
    pub constraints: HashMap<String, String>,
}
56
/// A directed link from one entity to another, induced by a single field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Relationship {
    /// Name of the entity that owns the referencing field.
    pub from_entity: String,
    /// Name of the entity being referenced.
    pub to_entity: String,
    /// How the link was classified.
    pub relationship_type: RelationshipType,
    /// Name of the field on `from_entity` that produces the link.
    pub field_name: String,
    /// Copied from the originating field's `is_required` flag.
    pub is_required: bool,
    /// Number of target instances the field may refer to.
    pub cardinality: Cardinality,
}
73
/// Classification of a [`Relationship`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RelationshipType {
    /// The field's name matched a foreign-key naming pattern.
    ForeignKey,
    /// The field is not a naming-convention foreign key and its type string contains
    /// `"message"`.
    Embedded,
    /// Not produced by the extraction code visible in this file.
    OneToMany,
    /// Not produced by the extraction code visible in this file.
    ManyToMany,
    /// Fallback used when neither of the `ForeignKey` / `Embedded` conditions holds.
    Composition,
}
88
/// Lower and upper bound on how many targets a relationship may have.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Cardinality {
    /// Minimum number of targets.
    pub min: u32,
    /// Maximum number of targets; the extractor uses `None` for repeated fields
    /// (no upper bound) and `Some(1)` otherwise.
    pub max: Option<u32>,
}
97
/// A detected foreign key: which field points at which entity, and how sure we are.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForeignKeyMapping {
    /// Name of the field acting as the foreign key.
    pub field_name: String,
    /// Name of the entity the field refers to.
    pub target_entity: String,
    /// Heuristic confidence score; the extractor caps it at `1.0`.
    pub confidence: f64,
    /// Which heuristic produced this mapping.
    pub detection_method: ForeignKeyDetectionMethod,
}
110
/// Heuristic that identified a foreign key.
///
/// The extraction code visible in this file only ever records `NamingConvention`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ForeignKeyDetectionMethod {
    /// Detected from the field's name (e.g. a `_id` suffix).
    NamingConvention,
    /// Variant name suggests detection via an explicit schema reference; not produced here.
    SchemaReference,
    /// Variant name suggests detection via the field's message type; not produced here.
    MessageType,
    /// Variant name suggests detection via a declared constraint; not produced here.
    Constraint,
}
123
/// Builds a [`SchemaGraph`] from protobuf descriptors, using field-name heuristics to
/// spot foreign keys.
pub struct ProtoSchemaGraphExtractor {
    /// Naming patterns tried, in order, against every field name.
    foreign_key_patterns: Vec<ForeignKeyPattern>,
}
129
/// One naming-convention rule for recognising foreign-key fields.
#[derive(Debug, Clone)]
struct ForeignKeyPattern {
    /// Regex a field name must match for this rule to apply.
    pattern: regex::Regex,
    /// How to turn a matching field name into the target entity name.
    entity_extraction: EntityExtractionMethod,
    /// Base confidence assigned when this rule is the first one to match.
    confidence: f64,
}
143
/// Strategy for deriving a target entity name from a field name.
#[derive(Debug, Clone)]
enum EntityExtractionMethod {
    /// Strip the given suffix from the field name, then normalize the remainder.
    RemoveSuffix(String),
    /// Normalize the whole field name. Not constructed by the default patterns.
    #[allow(dead_code)]
    Direct,
    /// Delegate to a caller-supplied function. Not constructed by the default patterns.
    #[allow(dead_code)]
    Custom(fn(&str) -> Option<String>),
}
160
161impl ProtoSchemaGraphExtractor {
162 pub fn new() -> Self {
164 let patterns = vec![
165 ForeignKeyPattern {
166 pattern: regex::Regex::new(r"^(.+)_id$").unwrap(),
167 entity_extraction: EntityExtractionMethod::RemoveSuffix("_id".to_string()),
168 confidence: 0.9,
169 },
170 ForeignKeyPattern {
171 pattern: regex::Regex::new(r"^(.+)Id$").unwrap(),
172 entity_extraction: EntityExtractionMethod::RemoveSuffix("Id".to_string()),
173 confidence: 0.85,
174 },
175 ForeignKeyPattern {
176 pattern: regex::Regex::new(r"^(.+)_ref$").unwrap(),
177 entity_extraction: EntityExtractionMethod::RemoveSuffix("_ref".to_string()),
178 confidence: 0.8,
179 },
180 ];
181
182 Self {
183 foreign_key_patterns: patterns,
184 }
185 }
186
187 pub fn extract_from_proto(
189 &self,
190 pool: &DescriptorPool,
191 ) -> Result<SchemaGraph, Box<dyn std::error::Error + Send + Sync>> {
192 let mut entities = HashMap::new();
193 let mut relationships = Vec::new();
194 let mut foreign_keys = HashMap::new();
195
196 info!("Extracting schema graph from protobuf descriptors");
197
198 for message_descriptor in pool.all_messages() {
200 let entity = self.extract_entity_from_message(&message_descriptor)?;
201 entities.insert(entity.name.clone(), entity);
202 }
203
204 for (entity_name, entity) in &entities {
206 let fk_mappings = self.detect_foreign_keys(entity, &entities)?;
207 if !fk_mappings.is_empty() {
208 foreign_keys.insert(entity_name.clone(), fk_mappings);
209 }
210
211 let entity_relationships = self.extract_relationships(entity, &entities)?;
212 relationships.extend(entity_relationships);
213 }
214
215 let mut updated_entities = entities;
217 self.update_cross_references(&mut updated_entities, &relationships);
218
219 let graph = SchemaGraph {
220 entities: updated_entities,
221 relationships,
222 foreign_keys,
223 };
224
225 info!(
226 "Extracted schema graph with {} entities and {} relationships",
227 graph.entities.len(),
228 graph.relationships.len()
229 );
230
231 Ok(graph)
232 }
233
234 fn extract_entity_from_message(
236 &self,
237 descriptor: &MessageDescriptor,
238 ) -> Result<EntityNode, Box<dyn std::error::Error + Send + Sync>> {
239 let name = Self::extract_entity_name(descriptor.name());
240 let full_name = descriptor.full_name().to_string();
241
242 let mut fields = Vec::new();
243 for field_descriptor in descriptor.fields() {
244 let field_info = self.extract_field_info(&field_descriptor)?;
245 fields.push(field_info);
246 }
247
248 Ok(EntityNode {
249 name,
250 full_name,
251 fields,
252 is_root: true, referenced_by: Vec::new(),
254 references: Vec::new(),
255 })
256 }
257
258 fn extract_field_info(
260 &self,
261 field: &FieldDescriptor,
262 ) -> Result<FieldInfo, Box<dyn std::error::Error + Send + Sync>> {
263 let name = field.name().to_string();
264 let field_type = Self::kind_to_string(&field.kind());
265 let is_required = true; let (is_foreign_key, foreign_key_target) =
269 self.analyze_potential_foreign_key(&name, &field.kind());
270
271 let mut constraints = HashMap::new();
272 if field.is_list() {
273 constraints.insert("repeated".to_string(), "true".to_string());
274 }
275
276 Ok(FieldInfo {
277 name,
278 field_type,
279 is_foreign_key,
280 foreign_key_target,
281 is_required,
282 constraints,
283 })
284 }
285
286 fn analyze_potential_foreign_key(
288 &self,
289 field_name: &str,
290 kind: &Kind,
291 ) -> (bool, Option<String>) {
292 for pattern in &self.foreign_key_patterns {
294 if pattern.pattern.is_match(field_name) {
295 if let Some(entity_name) = self.extract_entity_name_from_field(field_name, pattern)
296 {
297 return (true, Some(entity_name));
298 }
299 }
300 }
301
302 if let Kind::Message(message_descriptor) = kind {
304 let entity_name = Self::extract_entity_name(message_descriptor.name());
305 return (false, Some(entity_name)); }
307
308 (false, None)
309 }
310
311 fn extract_entity_name_from_field(
313 &self,
314 field_name: &str,
315 pattern: &ForeignKeyPattern,
316 ) -> Option<String> {
317 match &pattern.entity_extraction {
318 EntityExtractionMethod::RemoveSuffix(suffix) => {
319 if field_name.ends_with(suffix) {
320 let base_name = &field_name[..field_name.len() - suffix.len()];
321 Some(Self::normalize_entity_name(base_name))
322 } else {
323 None
324 }
325 }
326 EntityExtractionMethod::Direct => Some(Self::normalize_entity_name(field_name)),
327 EntityExtractionMethod::Custom(func) => func(field_name),
328 }
329 }
330
331 fn detect_foreign_keys(
333 &self,
334 entity: &EntityNode,
335 all_entities: &HashMap<String, EntityNode>,
336 ) -> Result<Vec<ForeignKeyMapping>, Box<dyn std::error::Error + Send + Sync>> {
337 let mut mappings = Vec::new();
338
339 for field in &entity.fields {
340 if field.is_foreign_key {
341 if let Some(target) = &field.foreign_key_target {
342 if all_entities.contains_key(target) {
344 let confidence =
346 self.calculate_confidence_score(field, target, all_entities);
347
348 mappings.push(ForeignKeyMapping {
349 field_name: field.name.clone(),
350 target_entity: target.clone(),
351 confidence,
352 detection_method: ForeignKeyDetectionMethod::NamingConvention,
353 });
354 }
355 }
356 }
357 }
358
359 Ok(mappings)
360 }
361
362 fn calculate_confidence_score(
370 &self,
371 field: &FieldInfo,
372 target_entity: &str,
373 all_entities: &HashMap<String, EntityNode>,
374 ) -> f64 {
375 let mut confidence = 0.5; for pattern in &self.foreign_key_patterns {
379 if pattern.pattern.is_match(&field.name) {
380 confidence = pattern.confidence;
381 break;
382 }
383 }
384
385 if all_entities.contains_key(target_entity) {
387 confidence += 0.1; }
389
390 if field.field_type.contains("message") || field.field_type.contains("Message") {
392 confidence += 0.1; }
394
395 confidence.min(1.0)
397 }
398
399 fn extract_relationships(
401 &self,
402 entity: &EntityNode,
403 all_entities: &HashMap<String, EntityNode>,
404 ) -> Result<Vec<Relationship>, Box<dyn std::error::Error + Send + Sync>> {
405 let mut relationships = Vec::new();
406
407 for field in &entity.fields {
408 if let Some(target_entity) = &field.foreign_key_target {
409 if all_entities.contains_key(target_entity) {
410 let relationship_type = if field.is_foreign_key {
411 RelationshipType::ForeignKey
412 } else if field.field_type.contains("message") {
413 RelationshipType::Embedded
414 } else {
415 RelationshipType::Composition
416 };
417
418 let cardinality = if field.constraints.contains_key("repeated") {
419 Cardinality { min: 0, max: None }
420 } else {
421 Cardinality {
422 min: if field.is_required { 1 } else { 0 },
423 max: Some(1),
424 }
425 };
426
427 relationships.push(Relationship {
428 from_entity: entity.name.clone(),
429 to_entity: target_entity.clone(),
430 relationship_type,
431 field_name: field.name.clone(),
432 is_required: field.is_required,
433 cardinality,
434 });
435 }
436 }
437 }
438
439 Ok(relationships)
440 }
441
442 fn update_cross_references(
444 &self,
445 entities: &mut HashMap<String, EntityNode>,
446 relationships: &[Relationship],
447 ) {
448 let mut referenced_by_map: HashMap<String, Vec<String>> = HashMap::new();
450 let mut references_map: HashMap<String, Vec<String>> = HashMap::new();
451
452 for rel in relationships {
453 references_map
455 .entry(rel.from_entity.clone())
456 .or_default()
457 .push(rel.to_entity.clone());
458
459 referenced_by_map
461 .entry(rel.to_entity.clone())
462 .or_default()
463 .push(rel.from_entity.clone());
464 }
465
466 for (entity_name, entity) in entities.iter_mut() {
468 if let Some(refs) = references_map.get(entity_name) {
469 entity.references = refs.clone();
470 }
471
472 if let Some(referenced_by) = referenced_by_map.get(entity_name) {
473 entity.referenced_by = referenced_by.clone();
474 entity.is_root = false; }
476 }
477 }
478
479 fn kind_to_string(kind: &Kind) -> String {
481 match kind {
482 Kind::String => "string".to_string(),
483 Kind::Int32 => "int32".to_string(),
484 Kind::Int64 => "int64".to_string(),
485 Kind::Uint32 => "uint32".to_string(),
486 Kind::Uint64 => "uint64".to_string(),
487 Kind::Bool => "bool".to_string(),
488 Kind::Float => "float".to_string(),
489 Kind::Double => "double".to_string(),
490 Kind::Bytes => "bytes".to_string(),
491 Kind::Message(msg) => format!("message:{}", msg.full_name()),
492 Kind::Enum(enum_desc) => format!("enum:{}", enum_desc.full_name()),
493 _ => "unknown".to_string(),
494 }
495 }
496
    /// Derives an entity name from a message's short name by delegating to
    /// `normalize_entity_name`.
    fn extract_entity_name(message_name: &str) -> String {
        Self::normalize_entity_name(message_name)
    }
501
502 fn normalize_entity_name(name: &str) -> String {
504 name.split('_')
506 .map(|part| {
507 let mut chars: Vec<char> = part.chars().collect();
508 if let Some(first_char) = chars.first_mut() {
509 *first_char = first_char.to_uppercase().next().unwrap_or(*first_char);
510 }
511 chars.into_iter().collect::<String>()
512 })
513 .collect::<String>()
514 }
515}
516
/// The default extractor is the one returned by [`ProtoSchemaGraphExtractor::new`].
impl Default for ProtoSchemaGraphExtractor {
    fn default() -> Self {
        Self::new()
    }
}
522
#[cfg(test)]
mod tests {
    use super::*;

    /// Field names using the `_id` and `Id` suffix conventions are flagged as foreign
    /// keys pointing at the capitalised base name.
    #[test]
    fn test_foreign_key_pattern_matching() {
        let extractor = ProtoSchemaGraphExtractor::new();
        let cases = [
            ("user_id", Kind::Int32, "User"),
            ("orderId", Kind::Int64, "Order"),
        ];

        for (field_name, kind, expected_target) in cases.iter() {
            let (is_fk, target) = extractor.analyze_potential_foreign_key(field_name, kind);
            assert!(is_fk);
            assert_eq!(target, Some(expected_target.to_string()));
        }
    }

    /// Snake-case names are converted to camel case; camel-case names pass through.
    #[test]
    fn test_entity_name_normalization() {
        let cases = [
            ("user", "User"),
            ("order_item", "OrderItem"),
            ("ProductCategory", "ProductCategory"),
        ];

        for (input, expected) in cases.iter() {
            assert_eq!(
                ProtoSchemaGraphExtractor::normalize_entity_name(input),
                *expected
            );
        }
    }
}