mockforge_grpc/dynamic/
proto_parser.rs

//! Proto file parsing and service discovery.
//!
//! This module parses `.proto` files and extracts their service definitions
//! so that dynamic gRPC service implementations can be generated from them.
5
6use prost_reflect::DescriptorPool;
7use std::collections::HashMap;
8use std::fs;
9use std::path::{Path, PathBuf};
10use std::process::Command;
11use tempfile::TempDir;
12use tracing::{debug, error, info, warn};
13
/// A service definition extracted from a proto file (or synthesized as a mock).
///
/// Implements `PartialEq`/`Eq` so parsed results can be compared directly,
/// e.g. in tests or when detecting whether a re-parse changed anything.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProtoService {
    /// The fully-qualified service name (e.g., "mockforge.greeter.Greeter")
    pub name: String,
    /// The package the service is declared in (e.g., "mockforge.greeter")
    pub package: String,
    /// The service name without its package (e.g., "Greeter")
    pub short_name: String,
    /// The RPC methods declared by this service
    pub methods: Vec<ProtoMethod>,
}

/// A single RPC method of a [`ProtoService`].
///
/// The two streaming flags together describe the call shape: neither set is a
/// unary call, both set is a bidirectional stream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProtoMethod {
    /// The method name (e.g., "SayHello")
    pub name: String,
    /// The fully-qualified request message type
    pub input_type: String,
    /// The fully-qualified response message type
    pub output_type: String,
    /// Whether the client sends a stream of requests
    pub client_streaming: bool,
    /// Whether the server replies with a stream of responses
    pub server_streaming: bool,
}
41
42/// A proto file parser that can extract service definitions
43pub struct ProtoParser {
44    /// The descriptor pool containing parsed proto files
45    pool: DescriptorPool,
46    /// Map of service names to their definitions
47    services: HashMap<String, ProtoService>,
48    /// Include paths for proto compilation
49    include_paths: Vec<PathBuf>,
50    /// Temporary directory for compilation artifacts
51    temp_dir: Option<TempDir>,
52}
53
54impl ProtoParser {
55    /// Create a new proto parser
56    pub fn new() -> Self {
57        Self {
58            pool: DescriptorPool::new(),
59            services: HashMap::new(),
60            include_paths: vec![],
61            temp_dir: None,
62        }
63    }
64
65    /// Create a new proto parser with include paths
66    pub fn with_include_paths(include_paths: Vec<PathBuf>) -> Self {
67        Self {
68            pool: DescriptorPool::new(),
69            services: HashMap::new(),
70            include_paths,
71            temp_dir: None,
72        }
73    }
74
75    /// Parse proto files from a directory
76    pub async fn parse_directory(
77        &mut self,
78        proto_dir: &str,
79    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
80        info!("Parsing proto files from directory: {}", proto_dir);
81
82        let proto_path = Path::new(proto_dir);
83        if !proto_path.exists() {
84            return Err(format!("Proto directory does not exist: {}", proto_dir).into());
85        }
86
87        // Discover all proto files
88        let proto_files = self.discover_proto_files(proto_path)?;
89        if proto_files.is_empty() {
90            warn!("No proto files found in directory: {}", proto_dir);
91            return Ok(());
92        }
93
94        info!("Found {} proto files: {:?}", proto_files.len(), proto_files);
95
96        // Optimize: Batch compile all proto files in a single protoc invocation
97        if proto_files.len() > 1 {
98            if let Err(e) = self.compile_protos_batch(&proto_files).await {
99                warn!("Batch compilation failed, falling back to individual compilation: {}", e);
100                // Fall back to individual compilation
101                for proto_file in proto_files {
102                    if let Err(e) = self.parse_proto_file(&proto_file).await {
103                        error!("Failed to parse proto file {}: {}", proto_file, e);
104                        // Continue with other files
105                    }
106                }
107            }
108        } else if !proto_files.is_empty() {
109            // Single file - use existing method
110            if let Err(e) = self.parse_proto_file(&proto_files[0]).await {
111                error!("Failed to parse proto file {}: {}", proto_files[0], e);
112            }
113        }
114
115        // Extract services from the descriptor pool only if there are any services in the pool
116        if self.pool.services().count() > 0 {
117            self.extract_services()?;
118        } else {
119            debug!("No services found in descriptor pool, keeping mock services");
120        }
121
122        info!("Successfully parsed {} services", self.services.len());
123        Ok(())
124    }
125
126    /// Discover proto files in a directory recursively
127    #[allow(clippy::only_used_in_recursion)]
128    fn discover_proto_files(
129        &self,
130        dir: &Path,
131    ) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
132        let mut proto_files = Vec::new();
133
134        if let Ok(entries) = fs::read_dir(dir) {
135            for entry in entries.flatten() {
136                let path = entry.path();
137
138                if path.is_dir() {
139                    // Recursively search subdirectories
140                    proto_files.extend(self.discover_proto_files(&path)?);
141                } else if path.extension().and_then(|s| s.to_str()) == Some("proto") {
142                    // Found a .proto file
143                    proto_files.push(path.to_string_lossy().to_string());
144                }
145            }
146        }
147
148        Ok(proto_files)
149    }
150
151    /// Parse a single proto file using protoc compilation
152    async fn parse_proto_file(
153        &mut self,
154        proto_file: &str,
155    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
156        debug!("Parsing proto file: {}", proto_file);
157
158        // Create temporary directory for compilation artifacts if not exists
159        if self.temp_dir.is_none() {
160            self.temp_dir = Some(TempDir::new()?);
161        }
162
163        // Safe to unwrap here: we just created it above if it was None
164        let temp_dir = self.temp_dir.as_ref().ok_or_else(|| {
165            Box::<dyn std::error::Error + Send + Sync>::from("Temp directory not initialized")
166        })?;
167        let descriptor_path = temp_dir.path().join("descriptors.bin");
168
169        // Try real protoc compilation first
170        match self.compile_with_protoc(proto_file, &descriptor_path).await {
171            Ok(()) => {
172                // Load the compiled descriptor set into the pool
173                let descriptor_bytes = fs::read(&descriptor_path)?;
174                match self.pool.decode_file_descriptor_set(&*descriptor_bytes) {
175                    Ok(()) => {
176                        info!("Successfully compiled and loaded proto file: {}", proto_file);
177                        // Extract services from the descriptor pool if successful
178                        if self.pool.services().count() > 0 {
179                            self.extract_services()?;
180                        }
181                        return Ok(());
182                    }
183                    Err(e) => {
184                        warn!("Failed to decode descriptor set, falling back to mock: {}", e);
185                    }
186                }
187            }
188            Err(e) => {
189                // This is expected behavior if protoc is not installed or proto files don't require compilation
190                // MockForge will use fallback mock services, which is fine for basic usage
191                warn!(
192                    "protoc not available or compilation failed (this is OK for basic usage, using fallback): {}",
193                    e
194                );
195            }
196        }
197
198        // Fallback to mock service for testing
199        if proto_file.contains("gretter.proto") || proto_file.contains("greeter.proto") {
200            debug!("Adding mock greeter service for {}", proto_file);
201            self.add_mock_greeter_service();
202        }
203
204        Ok(())
205    }
206
207    /// Batch compile multiple proto files in a single protoc invocation
208    /// This is significantly faster than compiling files individually
209    async fn compile_protos_batch(
210        &mut self,
211        proto_files: &[String],
212    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
213        if proto_files.is_empty() {
214            return Ok(());
215        }
216
217        info!("Batch compiling {} proto files", proto_files.len());
218
219        // Create temporary directory for compilation artifacts if not exists
220        if self.temp_dir.is_none() {
221            self.temp_dir = Some(TempDir::new()?);
222        }
223
224        let temp_dir = self.temp_dir.as_ref().ok_or_else(|| {
225            Box::<dyn std::error::Error + Send + Sync>::from("Temp directory not initialized")
226        })?;
227        let descriptor_path = temp_dir.path().join("descriptors_batch.bin");
228
229        // Build protoc command
230        let mut cmd = Command::new("protoc");
231
232        // Collect unique parent directories for include paths
233        let mut parent_dirs = std::collections::HashSet::new();
234        for proto_file in proto_files {
235            if let Some(parent_dir) = Path::new(proto_file).parent() {
236                parent_dirs.insert(parent_dir.to_path_buf());
237            }
238        }
239
240        // Add include paths
241        for include_path in &self.include_paths {
242            cmd.arg("-I").arg(include_path);
243        }
244
245        // Add proto file parent directories as include paths
246        for parent_dir in &parent_dirs {
247            cmd.arg("-I").arg(parent_dir);
248        }
249
250        // Add well-known types include path (common protoc installation paths)
251        let well_known_paths = [
252            "/usr/local/include",
253            "/usr/include",
254            "/opt/homebrew/include",
255        ];
256
257        for path in &well_known_paths {
258            if Path::new(path).exists() {
259                cmd.arg("-I").arg(path);
260            }
261        }
262
263        // Set output path and format
264        cmd.arg("--descriptor_set_out")
265            .arg(&descriptor_path)
266            .arg("--include_imports")
267            .arg("--include_source_info");
268
269        // Add all proto files to compile
270        for proto_file in proto_files {
271            cmd.arg(proto_file);
272        }
273
274        debug!("Running batch protoc command for {} files", proto_files.len());
275
276        // Execute protoc
277        let output = cmd.output()?;
278
279        if !output.status.success() {
280            let stderr = String::from_utf8_lossy(&output.stderr);
281            return Err(format!("Batch protoc compilation failed: {}", stderr).into());
282        }
283
284        // Load the compiled descriptor set into the pool
285        let descriptor_bytes = fs::read(&descriptor_path)?;
286        match self.pool.decode_file_descriptor_set(&*descriptor_bytes) {
287            Ok(()) => {
288                info!("Successfully batch compiled and loaded {} proto files", proto_files.len());
289                // Extract services from the descriptor pool if successful
290                if self.pool.services().count() > 0 {
291                    self.extract_services()?;
292                }
293                Ok(())
294            }
295            Err(e) => Err(format!("Failed to decode batch descriptor set: {}", e).into()),
296        }
297    }
298
299    /// Compile proto file using protoc
300    async fn compile_with_protoc(
301        &self,
302        proto_file: &str,
303        output_path: &Path,
304    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
305        debug!("Compiling proto file with protoc: {}", proto_file);
306
307        // Build protoc command
308        let mut cmd = Command::new("protoc");
309
310        // Add include paths
311        for include_path in &self.include_paths {
312            cmd.arg("-I").arg(include_path);
313        }
314
315        // Add proto file's directory as include path
316        if let Some(parent_dir) = Path::new(proto_file).parent() {
317            cmd.arg("-I").arg(parent_dir);
318        }
319
320        // Add well-known types include path (common protoc installation paths)
321        let well_known_paths = [
322            "/usr/local/include",
323            "/usr/include",
324            "/opt/homebrew/include",
325        ];
326
327        for path in &well_known_paths {
328            if Path::new(path).exists() {
329                cmd.arg("-I").arg(path);
330            }
331        }
332
333        // Set output path and format
334        cmd.arg("--descriptor_set_out")
335            .arg(output_path)
336            .arg("--include_imports")
337            .arg("--include_source_info")
338            .arg(proto_file);
339
340        debug!("Running protoc command: {:?}", cmd);
341
342        // Execute protoc
343        let output = cmd.output()?;
344
345        if !output.status.success() {
346            let stderr = String::from_utf8_lossy(&output.stderr);
347            return Err(format!("protoc failed: {}", stderr).into());
348        }
349
350        info!("Successfully compiled proto file with protoc: {}", proto_file);
351        Ok(())
352    }
353
354    /// Add a mock greeter service (for demonstration)
355    fn add_mock_greeter_service(&mut self) {
356        let service = ProtoService {
357            name: "mockforge.greeter.Greeter".to_string(),
358            package: "mockforge.greeter".to_string(),
359            short_name: "Greeter".to_string(),
360            methods: vec![
361                ProtoMethod {
362                    name: "SayHello".to_string(),
363                    input_type: "mockforge.greeter.HelloRequest".to_string(),
364                    output_type: "mockforge.greeter.HelloReply".to_string(),
365                    client_streaming: false,
366                    server_streaming: false,
367                },
368                ProtoMethod {
369                    name: "SayHelloStream".to_string(),
370                    input_type: "mockforge.greeter.HelloRequest".to_string(),
371                    output_type: "mockforge.greeter.HelloReply".to_string(),
372                    client_streaming: false,
373                    server_streaming: true,
374                },
375                ProtoMethod {
376                    name: "SayHelloClientStream".to_string(),
377                    input_type: "mockforge.greeter.HelloRequest".to_string(),
378                    output_type: "mockforge.greeter.HelloReply".to_string(),
379                    client_streaming: true,
380                    server_streaming: false,
381                },
382                ProtoMethod {
383                    name: "Chat".to_string(),
384                    input_type: "mockforge.greeter.HelloRequest".to_string(),
385                    output_type: "mockforge.greeter.HelloReply".to_string(),
386                    client_streaming: true,
387                    server_streaming: true,
388                },
389            ],
390        };
391
392        self.services.insert(service.name.clone(), service);
393    }
394
395    /// Extract services from the descriptor pool
396    fn extract_services(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
397        debug!("Extracting services from descriptor pool");
398
399        // Clear existing services (except mock ones)
400        let mock_services: HashMap<String, ProtoService> = self
401            .services
402            .drain()
403            .filter(|(name, _)| name.contains("mockforge.greeter"))
404            .collect();
405
406        self.services = mock_services;
407
408        // Extract services from the descriptor pool
409        for service_descriptor in self.pool.services() {
410            let service_name = service_descriptor.full_name().to_string();
411            let package_name = service_descriptor.parent_file().package_name().to_string();
412            let short_name = service_descriptor.name().to_string();
413
414            debug!("Found service: {} in package: {}", service_name, package_name);
415
416            // Extract methods for this service
417            let mut methods = Vec::new();
418            for method_descriptor in service_descriptor.methods() {
419                let method = ProtoMethod {
420                    name: method_descriptor.name().to_string(),
421                    input_type: method_descriptor.input().full_name().to_string(),
422                    output_type: method_descriptor.output().full_name().to_string(),
423                    client_streaming: method_descriptor.is_client_streaming(),
424                    server_streaming: method_descriptor.is_server_streaming(),
425                };
426
427                debug!(
428                    "  Found method: {} ({} -> {})",
429                    method.name, method.input_type, method.output_type
430                );
431
432                methods.push(method);
433            }
434
435            let service = ProtoService {
436                name: service_name.clone(),
437                package: package_name,
438                short_name,
439                methods,
440            };
441
442            self.services.insert(service_name, service);
443        }
444
445        info!("Extracted {} services from descriptor pool", self.services.len());
446        Ok(())
447    }
448
449    /// Get all discovered services
450    pub fn services(&self) -> &HashMap<String, ProtoService> {
451        &self.services
452    }
453
454    /// Get a specific service by name
455    pub fn get_service(&self, name: &str) -> Option<&ProtoService> {
456        self.services.get(name)
457    }
458
459    /// Get the descriptor pool
460    pub fn pool(&self) -> &DescriptorPool {
461        &self.pool
462    }
463
464    /// Consume the parser and return the descriptor pool
465    pub fn into_pool(self) -> DescriptorPool {
466        self.pool
467    }
468}
469
470impl Default for ProtoParser {
471    fn default() -> Self {
472        Self::new()
473    }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Fully-qualified name of the service the fixture protos define.
    const GREETER_SERVICE: &str = "mockforge.greeter.Greeter";

    /// Path of the crate's `proto` fixture directory.
    fn manifest_proto_dir() -> String {
        format!("{}/proto", std::env::var("CARGO_MANIFEST_DIR").unwrap())
    }

    /// Look up a method by name, panicking if the service does not have it.
    fn method_named<'a>(service: &'a ProtoService, name: &str) -> &'a ProtoMethod {
        service.methods.iter().find(|m| m.name == name).unwrap()
    }

    #[tokio::test]
    async fn test_parse_proto_file() {
        // The fixture file is spelled `gretter.proto`.
        let proto_path = format!("{}/gretter.proto", manifest_proto_dir());

        let mut parser = ProtoParser::new();
        parser.parse_proto_file(&proto_path).await.unwrap();

        // Exactly the greeter service must be known afterwards.
        let services = parser.services();
        assert_eq!(services.len(), 1);
        assert!(services.contains_key(GREETER_SERVICE));

        let service = &services[GREETER_SERVICE];
        assert_eq!(service.name, GREETER_SERVICE);
        // SayHello, SayHelloStream, SayHelloClientStream, Chat
        assert_eq!(service.methods.len(), 4);

        // Unary method
        let say_hello = method_named(service, "SayHello");
        assert_eq!(say_hello.input_type, "mockforge.greeter.HelloRequest");
        assert_eq!(say_hello.output_type, "mockforge.greeter.HelloReply");
        assert!(!say_hello.client_streaming);
        assert!(!say_hello.server_streaming);

        // Server-streaming method
        let say_hello_stream = method_named(service, "SayHelloStream");
        assert!(!say_hello_stream.client_streaming);
        assert!(say_hello_stream.server_streaming);
    }

    #[tokio::test]
    async fn test_parse_directory() {
        let proto_dir = manifest_proto_dir();

        let mut parser = ProtoParser::new();
        parser.parse_directory(&proto_dir).await.unwrap();

        // The directory must yield exactly the greeter service.
        let services = parser.services();
        assert_eq!(services.len(), 1);
        assert!(services.contains_key(GREETER_SERVICE));

        let service = &services[GREETER_SERVICE];
        assert_eq!(service.methods.len(), 4);

        // Every expected method must be present.
        for expected in ["SayHello", "SayHelloStream", "SayHelloClientStream", "Chat"].iter() {
            assert!(service.methods.iter().any(|m| m.name == *expected));
        }
    }
}