mockforge_grpc/dynamic/
proto_parser.rs

1//! Proto file parsing and service discovery
2//!
3//! This module handles parsing of .proto files and extracting service definitions
4//! to generate dynamic gRPC service implementations.
5
6use prost_reflect::DescriptorPool;
7use std::collections::HashMap;
8use std::fs;
9use std::path::{Path, PathBuf};
10use std::process::Command;
11use tempfile::TempDir;
12use tracing::{debug, error, info, warn};
13
/// A gRPC service discovered in a proto file (or registered as a built-in mock).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProtoService {
    /// Fully qualified service name, e.g. `"mockforge.greeter.Greeter"`.
    pub name: String,
    /// Proto package the service is declared in, e.g. `"mockforge.greeter"`.
    pub package: String,
    /// Unqualified service name, e.g. `"Greeter"`.
    pub short_name: String,
    /// The RPC methods exposed by this service.
    pub methods: Vec<ProtoMethod>,
}

/// A single RPC method belonging to a [`ProtoService`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProtoMethod {
    /// Method name, e.g. `"SayHello"`.
    pub name: String,
    /// Fully qualified request message type, e.g. `"mockforge.greeter.HelloRequest"`.
    pub input_type: String,
    /// Fully qualified response message type, e.g. `"mockforge.greeter.HelloReply"`.
    pub output_type: String,
    /// `true` when the client sends a stream of requests.
    pub client_streaming: bool,
    /// `true` when the server answers with a stream of responses.
    pub server_streaming: bool,
}
41
42/// A proto file parser that can extract service definitions
43pub struct ProtoParser {
44    /// The descriptor pool containing parsed proto files
45    pool: DescriptorPool,
46    /// Map of service names to their definitions
47    services: HashMap<String, ProtoService>,
48    /// Include paths for proto compilation
49    include_paths: Vec<PathBuf>,
50    /// Temporary directory for compilation artifacts
51    temp_dir: Option<TempDir>,
52}
53
54impl ProtoParser {
55    /// Create a new proto parser
56    pub fn new() -> Self {
57        Self {
58            pool: DescriptorPool::new(),
59            services: HashMap::new(),
60            include_paths: vec![],
61            temp_dir: None,
62        }
63    }
64
65    /// Create a new proto parser with include paths
66    pub fn with_include_paths(include_paths: Vec<PathBuf>) -> Self {
67        Self {
68            pool: DescriptorPool::new(),
69            services: HashMap::new(),
70            include_paths,
71            temp_dir: None,
72        }
73    }
74
75    /// Parse proto files from a directory
76    pub async fn parse_directory(
77        &mut self,
78        proto_dir: &str,
79    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
80        info!("Parsing proto files from directory: {}", proto_dir);
81
82        let proto_path = Path::new(proto_dir);
83        if !proto_path.exists() {
84            return Err(format!("Proto directory does not exist: {}", proto_dir).into());
85        }
86
87        // Discover all proto files
88        let proto_files = self.discover_proto_files(proto_path)?;
89        if proto_files.is_empty() {
90            warn!("No proto files found in directory: {}", proto_dir);
91            return Ok(());
92        }
93
94        info!("Found {} proto files: {:?}", proto_files.len(), proto_files);
95
96        // Parse each proto file
97        for proto_file in proto_files {
98            if let Err(e) = self.parse_proto_file(&proto_file).await {
99                error!("Failed to parse proto file {}: {}", proto_file, e);
100                // Continue with other files
101            }
102        }
103
104        // Extract services from the descriptor pool only if there are any services in the pool
105        if self.pool.services().count() > 0 {
106            self.extract_services()?;
107        } else {
108            debug!("No services found in descriptor pool, keeping mock services");
109        }
110
111        info!("Successfully parsed {} services", self.services.len());
112        Ok(())
113    }
114
115    /// Discover proto files in a directory recursively
116    #[allow(clippy::only_used_in_recursion)]
117    fn discover_proto_files(
118        &self,
119        dir: &Path,
120    ) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
121        let mut proto_files = Vec::new();
122
123        if let Ok(entries) = fs::read_dir(dir) {
124            for entry in entries.flatten() {
125                let path = entry.path();
126
127                if path.is_dir() {
128                    // Recursively search subdirectories
129                    proto_files.extend(self.discover_proto_files(&path)?);
130                } else if path.extension().and_then(|s| s.to_str()) == Some("proto") {
131                    // Found a .proto file
132                    proto_files.push(path.to_string_lossy().to_string());
133                }
134            }
135        }
136
137        Ok(proto_files)
138    }
139
140    /// Parse a single proto file using protoc compilation
141    async fn parse_proto_file(
142        &mut self,
143        proto_file: &str,
144    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
145        debug!("Parsing proto file: {}", proto_file);
146
147        // Create temporary directory for compilation artifacts if not exists
148        if self.temp_dir.is_none() {
149            self.temp_dir = Some(TempDir::new()?);
150        }
151
152        // Safe to unwrap here: we just created it above if it was None
153        let temp_dir = self.temp_dir.as_ref().ok_or_else(|| {
154            Box::<dyn std::error::Error + Send + Sync>::from("Temp directory not initialized")
155        })?;
156        let descriptor_path = temp_dir.path().join("descriptors.bin");
157
158        // Try real protoc compilation first
159        match self.compile_with_protoc(proto_file, &descriptor_path).await {
160            Ok(()) => {
161                // Load the compiled descriptor set into the pool
162                let descriptor_bytes = fs::read(&descriptor_path)?;
163                match self.pool.decode_file_descriptor_set(&*descriptor_bytes) {
164                    Ok(()) => {
165                        info!("Successfully compiled and loaded proto file: {}", proto_file);
166                        // Extract services from the descriptor pool if successful
167                        if self.pool.services().count() > 0 {
168                            self.extract_services()?;
169                        }
170                        return Ok(());
171                    }
172                    Err(e) => {
173                        warn!("Failed to decode descriptor set, falling back to mock: {}", e);
174                    }
175                }
176            }
177            Err(e) => {
178                // This is expected behavior if protoc is not installed or proto files don't require compilation
179                // MockForge will use fallback mock services, which is fine for basic usage
180                warn!(
181                    "protoc not available or compilation failed (this is OK for basic usage, using fallback): {}",
182                    e
183                );
184            }
185        }
186
187        // Fallback to mock service for testing
188        if proto_file.contains("gretter.proto") || proto_file.contains("greeter.proto") {
189            debug!("Adding mock greeter service for {}", proto_file);
190            self.add_mock_greeter_service();
191        }
192
193        Ok(())
194    }
195
196    /// Compile proto file using protoc
197    async fn compile_with_protoc(
198        &self,
199        proto_file: &str,
200        output_path: &Path,
201    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
202        debug!("Compiling proto file with protoc: {}", proto_file);
203
204        // Build protoc command
205        let mut cmd = Command::new("protoc");
206
207        // Add include paths
208        for include_path in &self.include_paths {
209            cmd.arg("-I").arg(include_path);
210        }
211
212        // Add proto file's directory as include path
213        if let Some(parent_dir) = Path::new(proto_file).parent() {
214            cmd.arg("-I").arg(parent_dir);
215        }
216
217        // Add well-known types include path (common protoc installation paths)
218        let well_known_paths = [
219            "/usr/local/include",
220            "/usr/include",
221            "/opt/homebrew/include",
222        ];
223
224        for path in &well_known_paths {
225            if Path::new(path).exists() {
226                cmd.arg("-I").arg(path);
227            }
228        }
229
230        // Set output path and format
231        cmd.arg("--descriptor_set_out")
232            .arg(output_path)
233            .arg("--include_imports")
234            .arg("--include_source_info")
235            .arg(proto_file);
236
237        debug!("Running protoc command: {:?}", cmd);
238
239        // Execute protoc
240        let output = cmd.output()?;
241
242        if !output.status.success() {
243            let stderr = String::from_utf8_lossy(&output.stderr);
244            return Err(format!("protoc failed: {}", stderr).into());
245        }
246
247        info!("Successfully compiled proto file with protoc: {}", proto_file);
248        Ok(())
249    }
250
251    /// Add a mock greeter service (for demonstration)
252    fn add_mock_greeter_service(&mut self) {
253        let service = ProtoService {
254            name: "mockforge.greeter.Greeter".to_string(),
255            package: "mockforge.greeter".to_string(),
256            short_name: "Greeter".to_string(),
257            methods: vec![
258                ProtoMethod {
259                    name: "SayHello".to_string(),
260                    input_type: "mockforge.greeter.HelloRequest".to_string(),
261                    output_type: "mockforge.greeter.HelloReply".to_string(),
262                    client_streaming: false,
263                    server_streaming: false,
264                },
265                ProtoMethod {
266                    name: "SayHelloStream".to_string(),
267                    input_type: "mockforge.greeter.HelloRequest".to_string(),
268                    output_type: "mockforge.greeter.HelloReply".to_string(),
269                    client_streaming: false,
270                    server_streaming: true,
271                },
272                ProtoMethod {
273                    name: "SayHelloClientStream".to_string(),
274                    input_type: "mockforge.greeter.HelloRequest".to_string(),
275                    output_type: "mockforge.greeter.HelloReply".to_string(),
276                    client_streaming: true,
277                    server_streaming: false,
278                },
279                ProtoMethod {
280                    name: "Chat".to_string(),
281                    input_type: "mockforge.greeter.HelloRequest".to_string(),
282                    output_type: "mockforge.greeter.HelloReply".to_string(),
283                    client_streaming: true,
284                    server_streaming: true,
285                },
286            ],
287        };
288
289        self.services.insert(service.name.clone(), service);
290    }
291
292    /// Extract services from the descriptor pool
293    fn extract_services(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
294        debug!("Extracting services from descriptor pool");
295
296        // Clear existing services (except mock ones)
297        let mock_services: HashMap<String, ProtoService> = self
298            .services
299            .drain()
300            .filter(|(name, _)| name.contains("mockforge.greeter"))
301            .collect();
302
303        self.services = mock_services;
304
305        // Extract services from the descriptor pool
306        for service_descriptor in self.pool.services() {
307            let service_name = service_descriptor.full_name().to_string();
308            let package_name = service_descriptor.parent_file().package_name().to_string();
309            let short_name = service_descriptor.name().to_string();
310
311            debug!("Found service: {} in package: {}", service_name, package_name);
312
313            // Extract methods for this service
314            let mut methods = Vec::new();
315            for method_descriptor in service_descriptor.methods() {
316                let method = ProtoMethod {
317                    name: method_descriptor.name().to_string(),
318                    input_type: method_descriptor.input().full_name().to_string(),
319                    output_type: method_descriptor.output().full_name().to_string(),
320                    client_streaming: method_descriptor.is_client_streaming(),
321                    server_streaming: method_descriptor.is_server_streaming(),
322                };
323
324                debug!(
325                    "  Found method: {} ({} -> {})",
326                    method.name, method.input_type, method.output_type
327                );
328
329                methods.push(method);
330            }
331
332            let service = ProtoService {
333                name: service_name.clone(),
334                package: package_name,
335                short_name,
336                methods,
337            };
338
339            self.services.insert(service_name, service);
340        }
341
342        info!("Extracted {} services from descriptor pool", self.services.len());
343        Ok(())
344    }
345
346    /// Get all discovered services
347    pub fn services(&self) -> &HashMap<String, ProtoService> {
348        &self.services
349    }
350
351    /// Get a specific service by name
352    pub fn get_service(&self, name: &str) -> Option<&ProtoService> {
353        self.services.get(name)
354    }
355
356    /// Get the descriptor pool
357    pub fn pool(&self) -> &DescriptorPool {
358        &self.pool
359    }
360
361    /// Consume the parser and return the descriptor pool
362    pub fn into_pool(self) -> DescriptorPool {
363        self.pool
364    }
365}
366
367impl Default for ProtoParser {
368    fn default() -> Self {
369        Self::new()
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Directory holding this crate's proto fixtures. It is resolved at compile
    /// time, so the tests also work when the binary is run outside cargo.
    const PROTO_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/proto");

    #[tokio::test]
    async fn test_parse_proto_file() {
        // The fixture is named `gretter.proto` (sic). The parser should expose
        // `mockforge.greeter.Greeter` for it, either compiled by protoc or through
        // the built-in mock fallback.
        let proto_path = format!("{}/gretter.proto", PROTO_DIR);

        let mut parser = ProtoParser::new();
        parser.parse_proto_file(&proto_path).await.unwrap();

        let services = parser.services();
        assert_eq!(services.len(), 1);

        let service_name = "mockforge.greeter.Greeter";
        assert!(services.contains_key(service_name));

        let service = &services[service_name];
        assert_eq!(service.name, service_name);
        assert_eq!(service.methods.len(), 4); // SayHello, SayHelloStream, SayHelloClientStream, Chat

        // Unary method.
        let say_hello = service.methods.iter().find(|m| m.name == "SayHello").unwrap();
        assert_eq!(say_hello.input_type, "mockforge.greeter.HelloRequest");
        assert_eq!(say_hello.output_type, "mockforge.greeter.HelloReply");
        assert!(!say_hello.client_streaming);
        assert!(!say_hello.server_streaming);

        // Server-streaming method.
        let say_hello_stream = service.methods.iter().find(|m| m.name == "SayHelloStream").unwrap();
        assert!(!say_hello_stream.client_streaming);
        assert!(say_hello_stream.server_streaming);
    }

    #[tokio::test]
    async fn test_parse_directory() {
        let mut parser = ProtoParser::new();
        parser.parse_directory(PROTO_DIR).await.unwrap();

        let services = parser.services();
        assert_eq!(services.len(), 1);

        let service_name = "mockforge.greeter.Greeter";
        assert!(services.contains_key(service_name));

        let service = &services[service_name];
        assert_eq!(service.methods.len(), 4);

        let method_names: Vec<&str> = service.methods.iter().map(|m| m.name.as_str()).collect();
        for expected in ["SayHello", "SayHelloStream", "SayHelloClientStream", "Chat"] {
            assert!(method_names.contains(&expected), "missing method {}", expected);
        }
    }
}