Skip to main content

infiniloom_engine/embedding/
contract_detection.rs

1//! Cross-language protocol and API contract detection
2//!
3//! This module identifies shared protocol definitions (currently protobuf)
4//! and generates metadata for enriching embedding chunks. This enables RAG
5//! systems to understand that a protobuf message like `UserRequest` is a
6//! shared contract used across multiple languages/repos.
7//!
8//! # Supported Protocols
9//!
10//! - **Protobuf** (proto2 and proto3): Messages, services, enums, nested types
11//!
12//! # Example
13//!
14//! ```rust,ignore
15//! use infiniloom_engine::embedding::contract_detection::{detect_contracts, contract_tags};
16//!
17//! let proto_content = r#"
18//! syntax = "proto3";
19//! package myapp.user.v1;
20//!
21//! service UserService {
22//!     rpc GetUser(GetUserRequest) returns (GetUserResponse);
23//! }
24//!
25//! message User {
26//!     string id = 1;
27//!     string name = 2;
28//! }
29//! "#;
30//!
31//! if let Some(contract) = detect_contracts(proto_content, "user.proto") {
32//!     let tags = contract_tags(&contract);
33//!     assert!(tags.contains(&"protobuf".to_string()));
34//!     assert!(tags.contains(&"grpc".to_string()));
35//!     assert_eq!(contract.services.len(), 1);
36//!     assert_eq!(contract.messages.len(), 1);
37//! }
38//! ```
39
40use regex::Regex;
41use serde::{Deserialize, Serialize};
42use std::sync::OnceLock;
43
/// Detected contract/protocol definition
///
/// Represents a parsed protocol definition with all its services,
/// messages, and enums. This metadata enriches embedding chunks
/// for better cross-language retrieval.
///
/// # Serialization
///
/// `package` is left out of the serialized form when it is `None`, and each
/// of the three lists is left out when it is empty. When deserializing, a
/// missing list is filled in with an empty `Vec`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ContractDefinition {
    /// Contract type ([`ContractType::Protobuf`] is the only variant for now)
    pub contract_type: ContractType,

    /// Package name (e.g., "myapp.user.v1"); `None` when the file has no
    /// `package` declaration
    #[serde(skip_serializing_if = "Option::is_none")]
    pub package: Option<String>,

    /// Service definitions found
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub services: Vec<ServiceDef>,

    /// Message definitions found
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub messages: Vec<MessageDef>,

    /// Enum definitions found
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub enums: Vec<EnumDef>,
}
70
/// Type of protocol/contract
///
/// Variants are serialized in `snake_case`, so `Protobuf` is written as
/// `"protobuf"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ContractType {
    /// Protocol Buffers (proto2 or proto3)
    Protobuf,
}
78
/// Service definition (gRPC service)
///
/// One entry is produced per `service Name {` header found in the file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ServiceDef {
    /// Service name
    pub name: String,

    /// RPC methods with signatures
    /// Format: "MethodName(RequestType) returns (ResponseType)"
    ///
    /// Omitted from the serialized form when empty; defaults to an empty
    /// list when absent on deserialization.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub methods: Vec<String>,
}
90
/// Message definition (protobuf message)
///
/// Describes one top-level `message` block: its direct fields and the names
/// of the messages declared directly inside it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MessageDef {
    /// Message name
    pub name: String,

    /// Field definitions
    /// Format: "field_name: type"
    ///
    /// Omitted from the serialized form when empty.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub fields: Vec<String>,

    /// Nested message names (one level deep)
    ///
    /// Only the names are recorded; the nested messages' own fields are not.
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub nested_messages: Vec<String>,
}
106
/// Enum definition
///
/// One entry is produced per `enum Name {` header found in the file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct EnumDef {
    /// Enum name
    pub name: String,

    /// Enum values (the value identifiers only; their numbers are not kept)
    #[serde(skip_serializing_if = "Vec::is_empty", default)]
    pub values: Vec<String>,
}
117
// Compiled regex patterns (initialized once).
//
// Every cell starts out empty and is filled by `init_patterns()`. The
// `extract_*` helpers read the cells with `OnceLock::get`, so
// `init_patterns()` has to run before any of them is called.
static PACKAGE_RE: OnceLock<Regex> = OnceLock::new(); // `package a.b.c;` line
static MESSAGE_RE: OnceLock<Regex> = OnceLock::new(); // `message Name {` header
static FIELD_RE: OnceLock<Regex> = OnceLock::new(); // message field line: `type name = N`
static SERVICE_RE: OnceLock<Regex> = OnceLock::new(); // `service Name {` header
static RPC_RE: OnceLock<Regex> = OnceLock::new(); // `rpc Name(Req) returns (Resp)` line
static ENUM_RE: OnceLock<Regex> = OnceLock::new(); // `enum Name {` header
static ENUM_VALUE_RE: OnceLock<Regex> = OnceLock::new(); // enum value line: `NAME = N`
126
127/// Initialize regex patterns
128fn init_patterns() {
129    PACKAGE_RE.get_or_init(|| Regex::new(r"^\s*package\s+([a-zA-Z0-9_.]+)\s*;").unwrap());
130    MESSAGE_RE.get_or_init(|| Regex::new(r"^\s*message\s+([a-zA-Z0-9_]+)\s*\{").unwrap());
131    FIELD_RE.get_or_init(|| {
132        Regex::new(
133            r"^\s*(?:optional|required|repeated)?\s*([a-zA-Z0-9_.<>]+)\s+([a-zA-Z0-9_]+)\s*=\s*\d+",
134        )
135        .unwrap()
136    });
137    SERVICE_RE.get_or_init(|| Regex::new(r"^\s*service\s+([a-zA-Z0-9_]+)\s*\{").unwrap());
138    RPC_RE.get_or_init(|| {
139        Regex::new(
140            r"^\s*rpc\s+([a-zA-Z0-9_]+)\s*\(([a-zA-Z0-9_.]+)\)\s*returns\s*\(([a-zA-Z0-9_.]+)\)",
141        )
142        .unwrap()
143    });
144    ENUM_RE.get_or_init(|| Regex::new(r"^\s*enum\s+([a-zA-Z0-9_]+)\s*\{").unwrap());
145    ENUM_VALUE_RE.get_or_init(|| Regex::new(r"^\s*([a-zA-Z0-9_]+)\s*=\s*\d+").unwrap());
146}
147
148/// Detect protobuf contracts in a .proto file
149///
150/// Parses protobuf files and extracts package, services, messages, and enums.
151/// Returns None if the file is not a .proto file.
152///
153/// # Arguments
154///
155/// * `content` - The file content to parse
156/// * `file_path` - The file path (must end with .proto)
157///
158/// # Example
159///
160/// ```rust,ignore
161/// let content = r#"
162/// syntax = "proto3";
163/// package myapp.v1;
164///
165/// message User {
166///     string id = 1;
167/// }
168/// "#;
169///
170/// let contract = detect_contracts(content, "user.proto");
171/// assert!(contract.is_some());
172/// ```
173pub fn detect_contracts(content: &str, file_path: &str) -> Option<ContractDefinition> {
174    // Only process .proto files
175    if !file_path.ends_with(".proto") {
176        return None;
177    }
178
179    // Initialize regex patterns
180    init_patterns();
181
182    // Strip comments
183    let content = strip_comments(content);
184
185    // Parse contract
186    let package = extract_package(&content);
187    let services = extract_services(&content);
188    let messages = extract_messages(&content);
189    let enums = extract_enums(&content);
190
191    // Return contract even if empty (file exists)
192    Some(ContractDefinition {
193        contract_type: ContractType::Protobuf,
194        package,
195        services,
196        messages,
197        enums,
198    })
199}
200
201/// Generate semantic tags from contract definitions
202///
203/// Returns tags like `["protobuf", "grpc", "api-contract"]`.
204/// Adds "grpc" tag if services are present.
205///
206/// # Example
207///
208/// ```rust,ignore
209/// let contract = ContractDefinition {
210///     contract_type: ContractType::Protobuf,
211///     package: Some("myapp.v1".to_string()),
212///     services: vec![ServiceDef {
213///         name: "UserService".to_string(),
214///         methods: vec![],
215///     }],
216///     messages: vec![],
217///     enums: vec![],
218/// };
219///
220/// let tags = contract_tags(&contract);
221/// assert!(tags.contains(&"grpc".to_string()));
222/// ```
223pub fn contract_tags(contract: &ContractDefinition) -> Vec<String> {
224    let mut tags = vec!["protobuf".to_owned(), "api-contract".to_owned()];
225
226    // Add "grpc" if services present
227    if !contract.services.is_empty() {
228        tags.push("grpc".to_owned());
229    }
230
231    tags
232}
233
/// Strip C-style (`/* ... */`) and C++-style (`// ...`) comments.
///
/// Line structure is preserved so that line-based parsing still sees the
/// same line boundaries:
///
/// * a line comment is removed up to, but not including, its newline;
/// * a block comment is replaced by a single space (so the tokens around it
///   stay separated) plus one newline for every newline it contained;
/// * an unterminated block comment runs to the end of the input.
///
/// String literals (single- or double-quoted, with backslash escapes) are
/// copied verbatim, so `//` and `/*` inside them are not mistaken for
/// comments. This matters for values such as `"http://host/path"` or
/// `"/v1/{name=shelves/*}"`. A literal that is still open at the end of its
/// line is treated as closed there, which keeps one stray quote from hiding
/// the rest of the file.
fn strip_comments(content: &str) -> String {
    let mut result = String::with_capacity(content.len());
    let mut chars = content.chars().peekable();

    while let Some(ch) = chars.next() {
        match (ch, chars.peek().copied()) {
            // String literal: copy through to the matching quote.
            ('"' | '\'', _) => {
                result.push(ch);
                while let Some(c) = chars.next() {
                    result.push(c);
                    if c == '\\' {
                        // Keep the escaped character and never treat it as
                        // the closing quote.
                        if let Some(escaped) = chars.next() {
                            result.push(escaped);
                            if escaped == '\n' {
                                break;
                            }
                        }
                    } else if c == ch || c == '\n' {
                        break;
                    }
                }
            }
            // Line comment: skip until end of line.
            ('/', Some('/')) => {
                chars.next(); // consume second '/'
                for c in chars.by_ref() {
                    if c == '\n' {
                        result.push('\n'); // preserve line breaks
                        break;
                    }
                }
            }
            // Block comment: skip until `*/`.
            ('/', Some('*')) => {
                chars.next(); // consume '*'
                // Start from a neutral character so the `*` that opened the
                // comment cannot also close it (`/*/` is not a full comment).
                let mut prev = ' ';
                for c in chars.by_ref() {
                    if prev == '*' && c == '/' {
                        break;
                    }
                    if c == '\n' {
                        result.push('\n'); // preserve line breaks
                    }
                    prev = c;
                }
                // A comment separates tokens just like whitespace does.
                result.push(' ');
            }
            _ => result.push(ch),
        }
    }

    result
}
274
275/// Extract package declaration
276fn extract_package(content: &str) -> Option<String> {
277    let package_re = PACKAGE_RE.get()?;
278    for line in content.lines() {
279        if let Some(caps) = package_re.captures(line) {
280            return Some(caps[1].to_string());
281        }
282    }
283    None
284}
285
286/// Extract service definitions
287fn extract_services(content: &str) -> Vec<ServiceDef> {
288    let service_re = SERVICE_RE.get().unwrap();
289    let rpc_re = RPC_RE.get().unwrap();
290
291    let mut services = Vec::new();
292    let lines: Vec<&str> = content.lines().collect();
293    let mut i = 0;
294
295    while i < lines.len() {
296        let line = lines[i];
297        if let Some(caps) = service_re.captures(line) {
298            let service_name = caps[1].to_string();
299            let mut methods = Vec::new();
300
301            // Parse methods until closing brace
302            i += 1;
303            let mut brace_depth = 1;
304            while i < lines.len() && brace_depth > 0 {
305                let method_line = lines[i];
306
307                // Track brace depth
308                brace_depth += method_line.matches('{').count();
309                brace_depth -= method_line.matches('}').count();
310
311                if let Some(rpc_caps) = rpc_re.captures(method_line) {
312                    let method_name = &rpc_caps[1];
313                    let request = &rpc_caps[2];
314                    let response = &rpc_caps[3];
315                    methods.push(format!("{}({}) returns ({})", method_name, request, response));
316                }
317                i += 1;
318            }
319
320            services.push(ServiceDef { name: service_name, methods });
321            continue;
322        }
323        i += 1;
324    }
325
326    services
327}
328
329/// Extract message definitions
330fn extract_messages(content: &str) -> Vec<MessageDef> {
331    let message_re = MESSAGE_RE.get().unwrap();
332    let field_re = FIELD_RE.get().unwrap();
333
334    let mut messages = Vec::new();
335    let lines: Vec<&str> = content.lines().collect();
336    let mut i = 0;
337
338    while i < lines.len() {
339        let line = lines[i];
340        if let Some(caps) = message_re.captures(line) {
341            let message_name = caps[1].to_string();
342            let mut fields = Vec::new();
343            let mut nested_messages = Vec::new();
344
345            // Parse fields and nested messages until closing brace
346            i += 1;
347            let mut brace_depth = 1;
348            while i < lines.len() && brace_depth > 0 {
349                let field_line = lines[i];
350
351                // Check for nested messages before updating brace depth
352                // (so we see them at depth 1 before the opening brace increments it)
353                if brace_depth == 1 {
354                    if let Some(nested_caps) = message_re.captures(field_line) {
355                        nested_messages.push(nested_caps[1].to_string());
356                    } else if let Some(field_caps) = field_re.captures(field_line) {
357                        let field_type = &field_caps[1];
358                        let field_name = &field_caps[2];
359                        fields.push(format!("{}: {}", field_name, field_type));
360                    }
361                }
362
363                // Track brace depth
364                brace_depth += field_line.matches('{').count();
365                brace_depth -= field_line.matches('}').count();
366
367                i += 1;
368            }
369
370            messages.push(MessageDef { name: message_name, fields, nested_messages });
371            continue;
372        }
373        i += 1;
374    }
375
376    messages
377}
378
379/// Extract enum definitions
380fn extract_enums(content: &str) -> Vec<EnumDef> {
381    let enum_re = ENUM_RE.get().unwrap();
382    let enum_value_re = ENUM_VALUE_RE.get().unwrap();
383
384    let mut enums = Vec::new();
385    let lines: Vec<&str> = content.lines().collect();
386    let mut i = 0;
387
388    while i < lines.len() {
389        let line = lines[i];
390        if let Some(caps) = enum_re.captures(line) {
391            let enum_name = caps[1].to_string();
392            let mut values = Vec::new();
393
394            // Parse values until closing brace
395            i += 1;
396            let mut brace_depth = 1;
397            while i < lines.len() && brace_depth > 0 {
398                let value_line = lines[i];
399
400                // Track brace depth
401                brace_depth += value_line.matches('{').count();
402                brace_depth -= value_line.matches('}').count();
403
404                if let Some(value_caps) = enum_value_re.captures(value_line) {
405                    values.push(value_caps[1].to_string());
406                }
407                i += 1;
408            }
409
410            enums.push(EnumDef { name: enum_name, values });
411            continue;
412        }
413        i += 1;
414    }
415
416    enums
417}
418
// Unit tests for contract detection. Each test feeds a small inline `.proto`
// document (or a hand-built `ContractDefinition`) through the public API and
// checks the extracted structure.
#[cfg(test)]
mod tests {
    use super::*;

    // A path that does not end in `.proto` is rejected regardless of content.
    #[test]
    fn test_non_proto_file_returns_none() {
        let content = "function foo() {}";
        let result = detect_contracts(content, "foo.js");
        assert!(result.is_none());
    }

    // A `.proto` file with no declarations still yields a (fully empty) contract.
    #[test]
    fn test_empty_proto_file() {
        let content = r#"
syntax = "proto3";
"#;
        let result = detect_contracts(content, "empty.proto");
        assert!(result.is_some());
        let contract = result.unwrap();
        assert_eq!(contract.contract_type, ContractType::Protobuf);
        assert!(contract.package.is_none());
        assert!(contract.services.is_empty());
        assert!(contract.messages.is_empty());
        assert!(contract.enums.is_empty());
    }

    // The dotted package name is captured from the `package` declaration.
    #[test]
    fn test_parse_package() {
        let content = r#"
syntax = "proto3";
package myapp.user.v1;
"#;
        let result = detect_contracts(content, "user.proto");
        assert!(result.is_some());
        let contract = result.unwrap();
        assert_eq!(contract.package, Some("myapp.user.v1".to_string()));
    }

    // Fields of a flat message are reported as "name: type".
    #[test]
    fn test_parse_message_with_fields() {
        let content = r#"
syntax = "proto3";

message User {
    string id = 1;
    string name = 2;
    int32 age = 3;
}
"#;
        let result = detect_contracts(content, "user.proto");
        assert!(result.is_some());
        let contract = result.unwrap();
        assert_eq!(contract.messages.len(), 1);

        let message = &contract.messages[0];
        assert_eq!(message.name, "User");
        assert_eq!(message.fields.len(), 3);
        assert!(message.fields.contains(&"id: string".to_string()));
        assert!(message.fields.contains(&"name: string".to_string()));
        assert!(message.fields.contains(&"age: int32".to_string()));
    }

    // Unary RPCs are reported as "Method(Request) returns (Response)".
    #[test]
    fn test_parse_service_with_rpcs() {
        let content = r#"
syntax = "proto3";

service UserService {
    rpc GetUser(GetUserRequest) returns (GetUserResponse);
    rpc CreateUser(CreateUserRequest) returns (CreateUserResponse);
}
"#;
        let result = detect_contracts(content, "user.proto");
        assert!(result.is_some());
        let contract = result.unwrap();
        assert_eq!(contract.services.len(), 1);

        let service = &contract.services[0];
        assert_eq!(service.name, "UserService");
        assert_eq!(service.methods.len(), 2);
        assert!(service
            .methods
            .contains(&"GetUser(GetUserRequest) returns (GetUserResponse)".to_string()));
        assert!(service
            .methods
            .contains(&"CreateUser(CreateUserRequest) returns (CreateUserResponse)".to_string()));
    }

    // Enum value identifiers are collected; their numbers are not part of the result.
    #[test]
    fn test_parse_enum() {
        let content = r#"
syntax = "proto3";

enum UserRole {
    USER_ROLE_UNSPECIFIED = 0;
    USER_ROLE_ADMIN = 1;
    USER_ROLE_MEMBER = 2;
}
"#;
        let result = detect_contracts(content, "user.proto");
        assert!(result.is_some());
        let contract = result.unwrap();
        assert_eq!(contract.enums.len(), 1);

        let enum_def = &contract.enums[0];
        assert_eq!(enum_def.name, "UserRole");
        assert_eq!(enum_def.values.len(), 3);
        assert!(enum_def
            .values
            .contains(&"USER_ROLE_UNSPECIFIED".to_string()));
        assert!(enum_def.values.contains(&"USER_ROLE_ADMIN".to_string()));
        assert!(enum_def.values.contains(&"USER_ROLE_MEMBER".to_string()));
    }

    // A message declared inside another is listed under `nested_messages` of the
    // outer message and is not reported as a separate top-level message.
    #[test]
    fn test_parse_nested_messages() {
        let content = r#"
syntax = "proto3";

message User {
    string id = 1;

    message Address {
        string street = 1;
        string city = 2;
    }

    Address address = 2;
}
"#;
        let result = detect_contracts(content, "user.proto");
        assert!(result.is_some());
        let contract = result.unwrap();
        assert_eq!(contract.messages.len(), 1);

        let message = &contract.messages[0];
        assert_eq!(message.name, "User");
        assert_eq!(message.nested_messages.len(), 1);
        assert!(message.nested_messages.contains(&"Address".to_string()));
        // The outer message should have the Address field
        assert!(message.fields.iter().any(|f| f.contains("address")));
    }

    // End-to-end check on a file that combines package, service, messages and enum.
    #[test]
    fn test_full_proto3_file() {
        let content = r#"
syntax = "proto3";

package myapp.user.v1;

// User service handles user operations
service UserService {
    rpc GetUser(GetUserRequest) returns (GetUserResponse);
    rpc CreateUser(CreateUserRequest) returns (CreateUserResponse);
}

message GetUserRequest {
    string user_id = 1;
}

message GetUserResponse {
    User user = 1;
}

message User {
    string id = 1;
    string name = 2;
    string email = 3;
    UserRole role = 4;
}

enum UserRole {
    USER_ROLE_UNSPECIFIED = 0;
    USER_ROLE_ADMIN = 1;
    USER_ROLE_MEMBER = 2;
}
"#;
        let result = detect_contracts(content, "user.proto");
        assert!(result.is_some());
        let contract = result.unwrap();

        assert_eq!(contract.package, Some("myapp.user.v1".to_string()));
        assert_eq!(contract.services.len(), 1);
        assert_eq!(contract.messages.len(), 3);
        assert_eq!(contract.enums.len(), 1);

        // Verify service
        let service = &contract.services[0];
        assert_eq!(service.name, "UserService");
        assert_eq!(service.methods.len(), 2);

        // Verify messages
        let message_names: Vec<&str> = contract.messages.iter().map(|m| m.name.as_str()).collect();
        assert!(message_names.contains(&"GetUserRequest"));
        assert!(message_names.contains(&"GetUserResponse"));
        assert!(message_names.contains(&"User"));

        // Verify enum
        let enum_def = &contract.enums[0];
        assert_eq!(enum_def.name, "UserRole");
        assert_eq!(enum_def.values.len(), 3);
    }

    // A contract with at least one service carries the "grpc" tag.
    #[test]
    fn test_tags_include_grpc_when_services_present() {
        let contract = ContractDefinition {
            contract_type: ContractType::Protobuf,
            package: Some("myapp.v1".to_string()),
            services: vec![ServiceDef { name: "UserService".to_string(), methods: vec![] }],
            messages: vec![],
            enums: vec![],
        };

        let tags = contract_tags(&contract);
        assert!(tags.contains(&"protobuf".to_string()));
        assert!(tags.contains(&"grpc".to_string()));
        assert!(tags.contains(&"api-contract".to_string()));
    }

    // A message-only contract keeps the base tags but has no "grpc" tag.
    #[test]
    fn test_tags_without_services() {
        let contract = ContractDefinition {
            contract_type: ContractType::Protobuf,
            package: Some("myapp.v1".to_string()),
            services: vec![],
            messages: vec![MessageDef {
                name: "User".to_string(),
                fields: vec![],
                nested_messages: vec![],
            }],
            enums: vec![],
        };

        let tags = contract_tags(&contract);
        assert!(tags.contains(&"protobuf".to_string()));
        assert!(!tags.contains(&"grpc".to_string()));
        assert!(tags.contains(&"api-contract".to_string()));
    }

    // Line, block and trailing inline comments must not disturb parsing of the
    // declarations around them.
    #[test]
    fn test_comments_are_stripped() {
        let content = r#"
syntax = "proto3";

// This is a line comment
package myapp.v1;

/* This is a block comment
   spanning multiple lines */
message User {
    string id = 1; // inline comment
    string name = 2;
}
"#;
        let result = detect_contracts(content, "user.proto");
        assert!(result.is_some());
        let contract = result.unwrap();

        assert_eq!(contract.package, Some("myapp.v1".to_string()));
        assert_eq!(contract.messages.len(), 1);
        assert_eq!(contract.messages[0].name, "User");
    }
}