use std::collections::HashMap;

use prost_reflect::{DescriptorPool, FieldDescriptor, Kind, MessageDescriptor};
use serde::{Deserialize, Serialize};
use tracing::{info, warn};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct SchemaGraph {
15 pub entities: HashMap<String, EntityNode>,
17 pub relationships: Vec<Relationship>,
19 pub foreign_keys: HashMap<String, Vec<ForeignKeyMapping>>,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct EntityNode {
26 pub name: String,
28 pub full_name: String,
30 pub fields: Vec<FieldInfo>,
32 pub is_root: bool,
34 pub referenced_by: Vec<String>,
36 pub references: Vec<String>,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct FieldInfo {
43 pub name: String,
45 pub field_type: String,
47 pub is_foreign_key: bool,
49 pub foreign_key_target: Option<String>,
51 pub is_required: bool,
53 pub constraints: HashMap<String, String>,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct Relationship {
60 pub from_entity: String,
62 pub to_entity: String,
64 pub relationship_type: RelationshipType,
66 pub field_name: String,
68 pub is_required: bool,
70 pub cardinality: Cardinality,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub enum RelationshipType {
77 ForeignKey,
79 Embedded,
81 OneToMany,
83 ManyToMany,
85 Composition,
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct Cardinality {
92 pub min: u32,
94 pub max: Option<u32>,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct ForeignKeyMapping {
101 pub field_name: String,
103 pub target_entity: String,
105 pub confidence: f64,
107 pub detection_method: ForeignKeyDetectionMethod,
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
113pub enum ForeignKeyDetectionMethod {
114 NamingConvention,
116 SchemaReference,
118 MessageType,
120 Constraint,
122}
123
124pub struct ProtoSchemaGraphExtractor {
126 foreign_key_patterns: Vec<ForeignKeyPattern>,
128}
129
130#[derive(Debug, Clone)]
132struct ForeignKeyPattern {
133 pattern: regex::Regex,
135 entity_extraction: EntityExtractionMethod,
137 #[allow(dead_code)] confidence: f64,
140}
141
/// Strategy for turning a foreign-key-looking field name into an entity name.
#[derive(Clone, Debug)]
enum EntityExtractionMethod {
    /// Strip this literal suffix, then normalize what remains.
    RemoveSuffix(String),
    /// Normalize the entire field name.
    // Not used by the built-in patterns; allowed rather than removed.
    #[allow(dead_code)]
    Direct,
    /// Delegate the extraction to a plain function.
    // Not used by the built-in patterns; allowed rather than removed.
    #[allow(dead_code)]
    Custom(fn(&str) -> Option<String>),
}
154
155impl ProtoSchemaGraphExtractor {
156 pub fn new() -> Self {
158 let patterns = vec![
159 ForeignKeyPattern {
160 pattern: regex::Regex::new(r"^(.+)_id$").unwrap(),
161 entity_extraction: EntityExtractionMethod::RemoveSuffix("_id".to_string()),
162 confidence: 0.9,
163 },
164 ForeignKeyPattern {
165 pattern: regex::Regex::new(r"^(.+)Id$").unwrap(),
166 entity_extraction: EntityExtractionMethod::RemoveSuffix("Id".to_string()),
167 confidence: 0.85,
168 },
169 ForeignKeyPattern {
170 pattern: regex::Regex::new(r"^(.+)_ref$").unwrap(),
171 entity_extraction: EntityExtractionMethod::RemoveSuffix("_ref".to_string()),
172 confidence: 0.8,
173 },
174 ];
175
176 Self {
177 foreign_key_patterns: patterns,
178 }
179 }
180
181 pub fn extract_from_proto(
183 &self,
184 pool: &DescriptorPool,
185 ) -> Result<SchemaGraph, Box<dyn std::error::Error + Send + Sync>> {
186 let mut entities = HashMap::new();
187 let mut relationships = Vec::new();
188 let mut foreign_keys = HashMap::new();
189
190 info!("Extracting schema graph from protobuf descriptors");
191
192 for message_descriptor in pool.all_messages() {
194 let entity = self.extract_entity_from_message(&message_descriptor)?;
195 entities.insert(entity.name.clone(), entity);
196 }
197
198 for (entity_name, entity) in &entities {
200 let fk_mappings = self.detect_foreign_keys(entity, &entities)?;
201 if !fk_mappings.is_empty() {
202 foreign_keys.insert(entity_name.clone(), fk_mappings);
203 }
204
205 let entity_relationships = self.extract_relationships(entity, &entities)?;
206 relationships.extend(entity_relationships);
207 }
208
209 let mut updated_entities = entities;
211 self.update_cross_references(&mut updated_entities, &relationships);
212
213 let graph = SchemaGraph {
214 entities: updated_entities,
215 relationships,
216 foreign_keys,
217 };
218
219 info!(
220 "Extracted schema graph with {} entities and {} relationships",
221 graph.entities.len(),
222 graph.relationships.len()
223 );
224
225 Ok(graph)
226 }
227
228 fn extract_entity_from_message(
230 &self,
231 descriptor: &MessageDescriptor,
232 ) -> Result<EntityNode, Box<dyn std::error::Error + Send + Sync>> {
233 let name = Self::extract_entity_name(descriptor.name());
234 let full_name = descriptor.full_name().to_string();
235
236 let mut fields = Vec::new();
237 for field_descriptor in descriptor.fields() {
238 let field_info = self.extract_field_info(&field_descriptor)?;
239 fields.push(field_info);
240 }
241
242 Ok(EntityNode {
243 name,
244 full_name,
245 fields,
246 is_root: true, referenced_by: Vec::new(),
248 references: Vec::new(),
249 })
250 }
251
252 fn extract_field_info(
254 &self,
255 field: &FieldDescriptor,
256 ) -> Result<FieldInfo, Box<dyn std::error::Error + Send + Sync>> {
257 let name = field.name().to_string();
258 let field_type = Self::kind_to_string(&field.kind());
259 let is_required = true; let (is_foreign_key, foreign_key_target) =
263 self.analyze_potential_foreign_key(&name, &field.kind());
264
265 let mut constraints = HashMap::new();
266 if field.is_list() {
267 constraints.insert("repeated".to_string(), "true".to_string());
268 }
269
270 Ok(FieldInfo {
271 name,
272 field_type,
273 is_foreign_key,
274 foreign_key_target,
275 is_required,
276 constraints,
277 })
278 }
279
280 fn analyze_potential_foreign_key(
282 &self,
283 field_name: &str,
284 kind: &Kind,
285 ) -> (bool, Option<String>) {
286 for pattern in &self.foreign_key_patterns {
288 if pattern.pattern.is_match(field_name) {
289 if let Some(entity_name) = self.extract_entity_name_from_field(field_name, pattern)
290 {
291 return (true, Some(entity_name));
292 }
293 }
294 }
295
296 if let Kind::Message(message_descriptor) = kind {
298 let entity_name = Self::extract_entity_name(message_descriptor.name());
299 return (false, Some(entity_name)); }
301
302 (false, None)
303 }
304
305 fn extract_entity_name_from_field(
307 &self,
308 field_name: &str,
309 pattern: &ForeignKeyPattern,
310 ) -> Option<String> {
311 match &pattern.entity_extraction {
312 EntityExtractionMethod::RemoveSuffix(suffix) => {
313 if field_name.ends_with(suffix) {
314 let base_name = &field_name[..field_name.len() - suffix.len()];
315 Some(Self::normalize_entity_name(base_name))
316 } else {
317 None
318 }
319 }
320 EntityExtractionMethod::Direct => Some(Self::normalize_entity_name(field_name)),
321 EntityExtractionMethod::Custom(func) => func(field_name),
322 }
323 }
324
325 fn detect_foreign_keys(
327 &self,
328 entity: &EntityNode,
329 all_entities: &HashMap<String, EntityNode>,
330 ) -> Result<Vec<ForeignKeyMapping>, Box<dyn std::error::Error + Send + Sync>> {
331 let mut mappings = Vec::new();
332
333 for field in &entity.fields {
334 if field.is_foreign_key {
335 if let Some(target) = &field.foreign_key_target {
336 if all_entities.contains_key(target) {
338 mappings.push(ForeignKeyMapping {
339 field_name: field.name.clone(),
340 target_entity: target.clone(),
341 confidence: 0.9, detection_method: ForeignKeyDetectionMethod::NamingConvention,
343 });
344 }
345 }
346 }
347 }
348
349 Ok(mappings)
350 }
351
352 fn extract_relationships(
354 &self,
355 entity: &EntityNode,
356 all_entities: &HashMap<String, EntityNode>,
357 ) -> Result<Vec<Relationship>, Box<dyn std::error::Error + Send + Sync>> {
358 let mut relationships = Vec::new();
359
360 for field in &entity.fields {
361 if let Some(target_entity) = &field.foreign_key_target {
362 if all_entities.contains_key(target_entity) {
363 let relationship_type = if field.is_foreign_key {
364 RelationshipType::ForeignKey
365 } else if field.field_type.contains("message") {
366 RelationshipType::Embedded
367 } else {
368 RelationshipType::Composition
369 };
370
371 let cardinality = if field.constraints.contains_key("repeated") {
372 Cardinality { min: 0, max: None }
373 } else {
374 Cardinality {
375 min: if field.is_required { 1 } else { 0 },
376 max: Some(1),
377 }
378 };
379
380 relationships.push(Relationship {
381 from_entity: entity.name.clone(),
382 to_entity: target_entity.clone(),
383 relationship_type,
384 field_name: field.name.clone(),
385 is_required: field.is_required,
386 cardinality,
387 });
388 }
389 }
390 }
391
392 Ok(relationships)
393 }
394
395 fn update_cross_references(
397 &self,
398 entities: &mut HashMap<String, EntityNode>,
399 relationships: &[Relationship],
400 ) {
401 let mut referenced_by_map: HashMap<String, Vec<String>> = HashMap::new();
403 let mut references_map: HashMap<String, Vec<String>> = HashMap::new();
404
405 for rel in relationships {
406 references_map
408 .entry(rel.from_entity.clone())
409 .or_default()
410 .push(rel.to_entity.clone());
411
412 referenced_by_map
414 .entry(rel.to_entity.clone())
415 .or_default()
416 .push(rel.from_entity.clone());
417 }
418
419 for (entity_name, entity) in entities.iter_mut() {
421 if let Some(refs) = references_map.get(entity_name) {
422 entity.references = refs.clone();
423 }
424
425 if let Some(referenced_by) = referenced_by_map.get(entity_name) {
426 entity.referenced_by = referenced_by.clone();
427 entity.is_root = false; }
429 }
430 }
431
432 fn kind_to_string(kind: &Kind) -> String {
434 match kind {
435 Kind::String => "string".to_string(),
436 Kind::Int32 => "int32".to_string(),
437 Kind::Int64 => "int64".to_string(),
438 Kind::Uint32 => "uint32".to_string(),
439 Kind::Uint64 => "uint64".to_string(),
440 Kind::Bool => "bool".to_string(),
441 Kind::Float => "float".to_string(),
442 Kind::Double => "double".to_string(),
443 Kind::Bytes => "bytes".to_string(),
444 Kind::Message(msg) => format!("message:{}", msg.full_name()),
445 Kind::Enum(enum_desc) => format!("enum:{}", enum_desc.full_name()),
446 _ => "unknown".to_string(),
447 }
448 }
449
450 fn extract_entity_name(message_name: &str) -> String {
452 Self::normalize_entity_name(message_name)
453 }
454
455 fn normalize_entity_name(name: &str) -> String {
457 name.split('_')
459 .map(|part| {
460 let mut chars: Vec<char> = part.chars().collect();
461 if let Some(first_char) = chars.first_mut() {
462 *first_char = first_char.to_uppercase().next().unwrap_or(*first_char);
463 }
464 chars.into_iter().collect::<String>()
465 })
466 .collect::<String>()
467 }
468}
469
470impl Default for ProtoSchemaGraphExtractor {
471 fn default() -> Self {
472 Self::new()
473 }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// `<name>_id` and `<name>Id` fields are flagged as foreign keys to `<Name>`.
    #[test]
    fn test_foreign_key_pattern_matching() {
        let extractor = ProtoSchemaGraphExtractor::new();
        let cases = [
            ("user_id", Kind::Int32, "User"),
            ("orderId", Kind::Int64, "Order"),
        ];

        for (field_name, kind, expected_target) in cases {
            let (is_fk, target) = extractor.analyze_potential_foreign_key(field_name, &kind);
            assert!(is_fk);
            assert_eq!(target.as_deref(), Some(expected_target));
        }
    }

    /// snake_case becomes UpperCamelCase; UpperCamelCase is left untouched.
    #[test]
    fn test_entity_name_normalization() {
        let expectations = [
            ("user", "User"),
            ("order_item", "OrderItem"),
            ("ProductCategory", "ProductCategory"),
        ];

        for (input, expected) in expectations {
            assert_eq!(ProtoSchemaGraphExtractor::normalize_entity_name(input), expected);
        }
    }
}