infiniloom_engine/embedding/
contract_detection.rs1use regex::Regex;
41use serde::{Deserialize, Serialize};
42use std::sync::OnceLock;
43
44#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
50pub struct ContractDefinition {
51 pub contract_type: ContractType,
53
54 #[serde(skip_serializing_if = "Option::is_none")]
56 pub package: Option<String>,
57
58 #[serde(skip_serializing_if = "Vec::is_empty", default)]
60 pub services: Vec<ServiceDef>,
61
62 #[serde(skip_serializing_if = "Vec::is_empty", default)]
64 pub messages: Vec<MessageDef>,
65
66 #[serde(skip_serializing_if = "Vec::is_empty", default)]
68 pub enums: Vec<EnumDef>,
69}
70
71#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
73#[serde(rename_all = "snake_case")]
74pub enum ContractType {
75 Protobuf,
77}
78
79#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
81pub struct ServiceDef {
82 pub name: String,
84
85 #[serde(skip_serializing_if = "Vec::is_empty", default)]
88 pub methods: Vec<String>,
89}
90
91#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
93pub struct MessageDef {
94 pub name: String,
96
97 #[serde(skip_serializing_if = "Vec::is_empty", default)]
100 pub fields: Vec<String>,
101
102 #[serde(skip_serializing_if = "Vec::is_empty", default)]
104 pub nested_messages: Vec<String>,
105}
106
107#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
109pub struct EnumDef {
110 pub name: String,
112
113 #[serde(skip_serializing_if = "Vec::is_empty", default)]
115 pub values: Vec<String>,
116}
117
118static PACKAGE_RE: OnceLock<Regex> = OnceLock::new();
120static MESSAGE_RE: OnceLock<Regex> = OnceLock::new();
121static FIELD_RE: OnceLock<Regex> = OnceLock::new();
122static SERVICE_RE: OnceLock<Regex> = OnceLock::new();
123static RPC_RE: OnceLock<Regex> = OnceLock::new();
124static ENUM_RE: OnceLock<Regex> = OnceLock::new();
125static ENUM_VALUE_RE: OnceLock<Regex> = OnceLock::new();
126
127fn init_patterns() {
129 PACKAGE_RE.get_or_init(|| Regex::new(r"^\s*package\s+([a-zA-Z0-9_.]+)\s*;").unwrap());
130 MESSAGE_RE.get_or_init(|| Regex::new(r"^\s*message\s+([a-zA-Z0-9_]+)\s*\{").unwrap());
131 FIELD_RE.get_or_init(|| {
132 Regex::new(
133 r"^\s*(?:optional|required|repeated)?\s*([a-zA-Z0-9_.<>]+)\s+([a-zA-Z0-9_]+)\s*=\s*\d+",
134 )
135 .unwrap()
136 });
137 SERVICE_RE.get_or_init(|| Regex::new(r"^\s*service\s+([a-zA-Z0-9_]+)\s*\{").unwrap());
138 RPC_RE.get_or_init(|| {
139 Regex::new(
140 r"^\s*rpc\s+([a-zA-Z0-9_]+)\s*\(([a-zA-Z0-9_.]+)\)\s*returns\s*\(([a-zA-Z0-9_.]+)\)",
141 )
142 .unwrap()
143 });
144 ENUM_RE.get_or_init(|| Regex::new(r"^\s*enum\s+([a-zA-Z0-9_]+)\s*\{").unwrap());
145 ENUM_VALUE_RE.get_or_init(|| Regex::new(r"^\s*([a-zA-Z0-9_]+)\s*=\s*\d+").unwrap());
146}
147
148pub fn detect_contracts(content: &str, file_path: &str) -> Option<ContractDefinition> {
174 if !file_path.ends_with(".proto") {
176 return None;
177 }
178
179 init_patterns();
181
182 let content = strip_comments(content);
184
185 let package = extract_package(&content);
187 let services = extract_services(&content);
188 let messages = extract_messages(&content);
189 let enums = extract_enums(&content);
190
191 Some(ContractDefinition {
193 contract_type: ContractType::Protobuf,
194 package,
195 services,
196 messages,
197 enums,
198 })
199}
200
201pub fn contract_tags(contract: &ContractDefinition) -> Vec<String> {
224 let mut tags = vec!["protobuf".to_owned(), "api-contract".to_owned()];
225
226 if !contract.services.is_empty() {
228 tags.push("grpc".to_owned());
229 }
230
231 tags
232}
233
/// Removes `//` line comments and `/* ... */` block comments from `content`.
///
/// Newlines are preserved (including those inside block comments), so line
/// numbers in the result match the input. Comment markers that appear inside
/// single- or double-quoted string literals are left untouched, so values
/// such as `"http://example.com"` are not truncated.
///
/// A string literal ends at its matching unescaped quote or, if it is
/// unterminated, at the end of the line. An unterminated block comment
/// consumes the rest of the input.
fn strip_comments(content: &str) -> String {
    let mut result = String::with_capacity(content.len());
    let mut chars = content.chars().peekable();

    while let Some(ch) = chars.next() {
        match ch {
            // String literal: copy verbatim up to the closing quote.
            '"' | '\'' => {
                result.push(ch);
                while let Some(c) = chars.next() {
                    result.push(c);
                    match c {
                        '\\' => {
                            // Keep the escaped character so an escaped quote
                            // does not end the literal. A newline is never
                            // swallowed: it always terminates the literal.
                            if let Some(&escaped) = chars.peek() {
                                if escaped != '\n' {
                                    result.push(escaped);
                                    chars.next();
                                }
                            }
                        }
                        '\n' => break,
                        _ if c == ch => break,
                        _ => {}
                    }
                }
            }
            // Line comment: drop everything up to, but keep, the newline.
            '/' if chars.peek() == Some(&'/') => {
                chars.next();
                for c in chars.by_ref() {
                    if c == '\n' {
                        result.push('\n');
                        break;
                    }
                }
            }
            // Block comment: drop everything through `*/`, keeping newlines.
            '/' if chars.peek() == Some(&'*') => {
                chars.next();
                // `prev` starts as a non-`*` so that `/*/` is not read as a
                // complete comment.
                let mut prev = ' ';
                for c in chars.by_ref() {
                    if prev == '*' && c == '/' {
                        break;
                    }
                    if c == '\n' {
                        result.push('\n');
                    }
                    prev = c;
                }
            }
            _ => result.push(ch),
        }
    }

    result
}
274
275fn extract_package(content: &str) -> Option<String> {
277 let package_re = PACKAGE_RE.get()?;
278 for line in content.lines() {
279 if let Some(caps) = package_re.captures(line) {
280 return Some(caps[1].to_string());
281 }
282 }
283 None
284}
285
286fn extract_services(content: &str) -> Vec<ServiceDef> {
288 let service_re = SERVICE_RE.get().unwrap();
289 let rpc_re = RPC_RE.get().unwrap();
290
291 let mut services = Vec::new();
292 let lines: Vec<&str> = content.lines().collect();
293 let mut i = 0;
294
295 while i < lines.len() {
296 let line = lines[i];
297 if let Some(caps) = service_re.captures(line) {
298 let service_name = caps[1].to_string();
299 let mut methods = Vec::new();
300
301 i += 1;
303 let mut brace_depth = 1;
304 while i < lines.len() && brace_depth > 0 {
305 let method_line = lines[i];
306
307 brace_depth += method_line.matches('{').count();
309 brace_depth -= method_line.matches('}').count();
310
311 if let Some(rpc_caps) = rpc_re.captures(method_line) {
312 let method_name = &rpc_caps[1];
313 let request = &rpc_caps[2];
314 let response = &rpc_caps[3];
315 methods.push(format!("{}({}) returns ({})", method_name, request, response));
316 }
317 i += 1;
318 }
319
320 services.push(ServiceDef { name: service_name, methods });
321 continue;
322 }
323 i += 1;
324 }
325
326 services
327}
328
329fn extract_messages(content: &str) -> Vec<MessageDef> {
331 let message_re = MESSAGE_RE.get().unwrap();
332 let field_re = FIELD_RE.get().unwrap();
333
334 let mut messages = Vec::new();
335 let lines: Vec<&str> = content.lines().collect();
336 let mut i = 0;
337
338 while i < lines.len() {
339 let line = lines[i];
340 if let Some(caps) = message_re.captures(line) {
341 let message_name = caps[1].to_string();
342 let mut fields = Vec::new();
343 let mut nested_messages = Vec::new();
344
345 i += 1;
347 let mut brace_depth = 1;
348 while i < lines.len() && brace_depth > 0 {
349 let field_line = lines[i];
350
351 if brace_depth == 1 {
354 if let Some(nested_caps) = message_re.captures(field_line) {
355 nested_messages.push(nested_caps[1].to_string());
356 } else if let Some(field_caps) = field_re.captures(field_line) {
357 let field_type = &field_caps[1];
358 let field_name = &field_caps[2];
359 fields.push(format!("{}: {}", field_name, field_type));
360 }
361 }
362
363 brace_depth += field_line.matches('{').count();
365 brace_depth -= field_line.matches('}').count();
366
367 i += 1;
368 }
369
370 messages.push(MessageDef { name: message_name, fields, nested_messages });
371 continue;
372 }
373 i += 1;
374 }
375
376 messages
377}
378
379fn extract_enums(content: &str) -> Vec<EnumDef> {
381 let enum_re = ENUM_RE.get().unwrap();
382 let enum_value_re = ENUM_VALUE_RE.get().unwrap();
383
384 let mut enums = Vec::new();
385 let lines: Vec<&str> = content.lines().collect();
386 let mut i = 0;
387
388 while i < lines.len() {
389 let line = lines[i];
390 if let Some(caps) = enum_re.captures(line) {
391 let enum_name = caps[1].to_string();
392 let mut values = Vec::new();
393
394 i += 1;
396 let mut brace_depth = 1;
397 while i < lines.len() && brace_depth > 0 {
398 let value_line = lines[i];
399
400 brace_depth += value_line.matches('{').count();
402 brace_depth -= value_line.matches('}').count();
403
404 if let Some(value_caps) = enum_value_re.captures(value_line) {
405 values.push(value_caps[1].to_string());
406 }
407 i += 1;
408 }
409
410 enums.push(EnumDef { name: enum_name, values });
411 continue;
412 }
413 i += 1;
414 }
415
416 enums
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs detection for a path that is expected to be accepted.
    fn parse(path: &str, source: &str) -> ContractDefinition {
        detect_contracts(source, path).expect("a .proto path should yield a contract")
    }

    /// Shorthand for building the owned strings stored in the definitions.
    fn owned(text: &str) -> String {
        text.to_string()
    }

    #[test]
    fn test_non_proto_file_returns_none() {
        assert!(detect_contracts("function foo() {}", "foo.js").is_none());
    }

    #[test]
    fn test_empty_proto_file() {
        let contract = parse(
            "empty.proto",
            r#"
syntax = "proto3";
"#,
        );

        assert_eq!(contract.contract_type, ContractType::Protobuf);
        assert!(contract.package.is_none());
        assert!(contract.services.is_empty());
        assert!(contract.messages.is_empty());
        assert!(contract.enums.is_empty());
    }

    #[test]
    fn test_parse_package() {
        let contract = parse(
            "user.proto",
            r#"
syntax = "proto3";
package myapp.user.v1;
"#,
        );

        assert_eq!(contract.package, Some(owned("myapp.user.v1")));
    }

    #[test]
    fn test_parse_message_with_fields() {
        let contract = parse(
            "user.proto",
            r#"
syntax = "proto3";

message User {
  string id = 1;
  string name = 2;
  int32 age = 3;
}
"#,
        );

        assert_eq!(contract.messages.len(), 1);

        let user = &contract.messages[0];
        assert_eq!(user.name, "User");
        assert_eq!(user.fields.len(), 3);
        for expected in ["id: string", "name: string", "age: int32"] {
            assert!(user.fields.contains(&owned(expected)));
        }
    }

    #[test]
    fn test_parse_service_with_rpcs() {
        let contract = parse(
            "user.proto",
            r#"
syntax = "proto3";

service UserService {
  rpc GetUser(GetUserRequest) returns (GetUserResponse);
  rpc CreateUser(CreateUserRequest) returns (CreateUserResponse);
}
"#,
        );

        assert_eq!(contract.services.len(), 1);

        let service = &contract.services[0];
        assert_eq!(service.name, "UserService");
        assert_eq!(service.methods.len(), 2);
        for expected in [
            "GetUser(GetUserRequest) returns (GetUserResponse)",
            "CreateUser(CreateUserRequest) returns (CreateUserResponse)",
        ] {
            assert!(service.methods.contains(&owned(expected)));
        }
    }

    #[test]
    fn test_parse_enum() {
        let contract = parse(
            "user.proto",
            r#"
syntax = "proto3";

enum UserRole {
  USER_ROLE_UNSPECIFIED = 0;
  USER_ROLE_ADMIN = 1;
  USER_ROLE_MEMBER = 2;
}
"#,
        );

        assert_eq!(contract.enums.len(), 1);

        let role = &contract.enums[0];
        assert_eq!(role.name, "UserRole");
        assert_eq!(role.values.len(), 3);
        for expected in ["USER_ROLE_UNSPECIFIED", "USER_ROLE_ADMIN", "USER_ROLE_MEMBER"] {
            assert!(role.values.contains(&owned(expected)));
        }
    }

    #[test]
    fn test_parse_nested_messages() {
        let contract = parse(
            "user.proto",
            r#"
syntax = "proto3";

message User {
  string id = 1;

  message Address {
    string street = 1;
    string city = 2;
  }

  Address address = 2;
}
"#,
        );

        // The nested message is folded into its parent, not listed separately.
        assert_eq!(contract.messages.len(), 1);

        let user = &contract.messages[0];
        assert_eq!(user.name, "User");
        assert_eq!(user.nested_messages.len(), 1);
        assert!(user.nested_messages.contains(&owned("Address")));
        assert!(user.fields.iter().any(|f| f.contains("address")));
    }

    #[test]
    fn test_full_proto3_file() {
        let contract = parse(
            "user.proto",
            r#"
syntax = "proto3";

package myapp.user.v1;

// User service handles user operations
service UserService {
  rpc GetUser(GetUserRequest) returns (GetUserResponse);
  rpc CreateUser(CreateUserRequest) returns (CreateUserResponse);
}

message GetUserRequest {
  string user_id = 1;
}

message GetUserResponse {
  User user = 1;
}

message User {
  string id = 1;
  string name = 2;
  string email = 3;
  UserRole role = 4;
}

enum UserRole {
  USER_ROLE_UNSPECIFIED = 0;
  USER_ROLE_ADMIN = 1;
  USER_ROLE_MEMBER = 2;
}
"#,
        );

        assert_eq!(contract.package, Some(owned("myapp.user.v1")));
        assert_eq!(contract.services.len(), 1);
        assert_eq!(contract.messages.len(), 3);
        assert_eq!(contract.enums.len(), 1);

        let service = &contract.services[0];
        assert_eq!(service.name, "UserService");
        assert_eq!(service.methods.len(), 2);

        let message_names: Vec<&str> = contract.messages.iter().map(|m| m.name.as_str()).collect();
        for expected in ["GetUserRequest", "GetUserResponse", "User"] {
            assert!(message_names.contains(&expected));
        }

        let role = &contract.enums[0];
        assert_eq!(role.name, "UserRole");
        assert_eq!(role.values.len(), 3);
    }

    #[test]
    fn test_tags_include_grpc_when_services_present() {
        let contract = ContractDefinition {
            contract_type: ContractType::Protobuf,
            package: Some(owned("myapp.v1")),
            services: vec![ServiceDef { name: owned("UserService"), methods: vec![] }],
            messages: vec![],
            enums: vec![],
        };

        let tags = contract_tags(&contract);
        assert!(tags.contains(&owned("protobuf")));
        assert!(tags.contains(&owned("grpc")));
        assert!(tags.contains(&owned("api-contract")));
    }

    #[test]
    fn test_tags_without_services() {
        let contract = ContractDefinition {
            contract_type: ContractType::Protobuf,
            package: Some(owned("myapp.v1")),
            services: vec![],
            messages: vec![MessageDef {
                name: owned("User"),
                fields: vec![],
                nested_messages: vec![],
            }],
            enums: vec![],
        };

        let tags = contract_tags(&contract);
        assert!(tags.contains(&owned("protobuf")));
        assert!(!tags.contains(&owned("grpc")));
        assert!(tags.contains(&owned("api-contract")));
    }

    #[test]
    fn test_comments_are_stripped() {
        let contract = parse(
            "user.proto",
            r#"
syntax = "proto3";

// This is a line comment
package myapp.v1;

/* This is a block comment
   spanning multiple lines */
message User {
  string id = 1; // inline comment
  string name = 2;
}
"#,
        );

        assert_eq!(contract.package, Some(owned("myapp.v1")));
        assert_eq!(contract.messages.len(), 1);
        assert_eq!(contract.messages[0].name, "User");
    }
}