1use prost_reflect::DescriptorPool;
7use std::collections::HashMap;
8use std::fs;
9use std::path::{Path, PathBuf};
10use std::process::Command;
11use tempfile::TempDir;
12use tracing::{debug, error, info, warn};
13
14#[derive(Debug, Clone)]
16pub struct ProtoService {
17 pub name: String,
19 pub package: String,
21 pub short_name: String,
23 pub methods: Vec<ProtoMethod>,
25}
26
/// A single RPC method of a [`ProtoService`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProtoMethod {
    /// Method name as declared in the service, e.g. `SayHello`.
    pub name: String,
    /// Fully qualified name of the request message type.
    pub input_type: String,
    /// Fully qualified name of the response message type.
    pub output_type: String,
    /// Whether the client sends a stream of requests.
    pub client_streaming: bool,
    /// Whether the server replies with a stream of responses.
    pub server_streaming: bool,
}
41
42pub struct ProtoParser {
44 pool: DescriptorPool,
46 services: HashMap<String, ProtoService>,
48 include_paths: Vec<PathBuf>,
50 temp_dir: Option<TempDir>,
52}
53
54impl ProtoParser {
55 pub fn new() -> Self {
57 Self {
58 pool: DescriptorPool::new(),
59 services: HashMap::new(),
60 include_paths: vec![],
61 temp_dir: None,
62 }
63 }
64
65 pub fn with_include_paths(include_paths: Vec<PathBuf>) -> Self {
67 Self {
68 pool: DescriptorPool::new(),
69 services: HashMap::new(),
70 include_paths,
71 temp_dir: None,
72 }
73 }
74
75 pub async fn parse_directory(
77 &mut self,
78 proto_dir: &str,
79 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
80 info!("Parsing proto files from directory: {}", proto_dir);
81
82 let proto_path = Path::new(proto_dir);
83 if !proto_path.exists() {
84 return Err(format!("Proto directory does not exist: {}", proto_dir).into());
85 }
86
87 let proto_files = self.discover_proto_files(proto_path)?;
89 if proto_files.is_empty() {
90 warn!("No proto files found in directory: {}", proto_dir);
91 return Ok(());
92 }
93
94 info!("Found {} proto files: {:?}", proto_files.len(), proto_files);
95
96 if proto_files.len() > 1 {
98 if let Err(e) = self.compile_protos_batch(&proto_files).await {
99 warn!("Batch compilation failed, falling back to individual compilation: {}", e);
100 for proto_file in proto_files {
102 if let Err(e) = self.parse_proto_file(&proto_file).await {
103 error!("Failed to parse proto file {}: {}", proto_file, e);
104 }
106 }
107 }
108 } else if !proto_files.is_empty() {
109 if let Err(e) = self.parse_proto_file(&proto_files[0]).await {
111 error!("Failed to parse proto file {}: {}", proto_files[0], e);
112 }
113 }
114
115 if self.pool.services().count() > 0 {
117 self.extract_services()?;
118 } else {
119 debug!("No services found in descriptor pool, keeping mock services");
120 }
121
122 info!("Successfully parsed {} services", self.services.len());
123 Ok(())
124 }
125
126 #[allow(clippy::only_used_in_recursion)]
128 fn discover_proto_files(
129 &self,
130 dir: &Path,
131 ) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
132 let mut proto_files = Vec::new();
133
134 if let Ok(entries) = fs::read_dir(dir) {
135 for entry in entries.flatten() {
136 let path = entry.path();
137
138 if path.is_dir() {
139 proto_files.extend(self.discover_proto_files(&path)?);
141 } else if path.extension().and_then(|s| s.to_str()) == Some("proto") {
142 proto_files.push(path.to_string_lossy().to_string());
144 }
145 }
146 }
147
148 Ok(proto_files)
149 }
150
151 async fn parse_proto_file(
153 &mut self,
154 proto_file: &str,
155 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
156 debug!("Parsing proto file: {}", proto_file);
157
158 if self.temp_dir.is_none() {
160 self.temp_dir = Some(TempDir::new()?);
161 }
162
163 let temp_dir = self.temp_dir.as_ref().ok_or_else(|| {
165 Box::<dyn std::error::Error + Send + Sync>::from("Temp directory not initialized")
166 })?;
167 let descriptor_path = temp_dir.path().join("descriptors.bin");
168
169 match self.compile_with_protoc(proto_file, &descriptor_path).await {
171 Ok(()) => {
172 let descriptor_bytes = fs::read(&descriptor_path)?;
174 match self.pool.decode_file_descriptor_set(&*descriptor_bytes) {
175 Ok(()) => {
176 info!("Successfully compiled and loaded proto file: {}", proto_file);
177 if self.pool.services().count() > 0 {
179 self.extract_services()?;
180 }
181 return Ok(());
182 }
183 Err(e) => {
184 warn!("Failed to decode descriptor set, falling back to mock: {}", e);
185 }
186 }
187 }
188 Err(e) => {
189 warn!(
192 "protoc not available or compilation failed (this is OK for basic usage, using fallback): {}",
193 e
194 );
195 }
196 }
197
198 if proto_file.contains("gretter.proto") || proto_file.contains("greeter.proto") {
200 debug!("Adding mock greeter service for {}", proto_file);
201 self.add_mock_greeter_service();
202 }
203
204 Ok(())
205 }
206
207 async fn compile_protos_batch(
210 &mut self,
211 proto_files: &[String],
212 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
213 if proto_files.is_empty() {
214 return Ok(());
215 }
216
217 info!("Batch compiling {} proto files", proto_files.len());
218
219 if self.temp_dir.is_none() {
221 self.temp_dir = Some(TempDir::new()?);
222 }
223
224 let temp_dir = self.temp_dir.as_ref().ok_or_else(|| {
225 Box::<dyn std::error::Error + Send + Sync>::from("Temp directory not initialized")
226 })?;
227 let descriptor_path = temp_dir.path().join("descriptors_batch.bin");
228
229 let mut cmd = Command::new("protoc");
231
232 let mut parent_dirs = std::collections::HashSet::new();
234 for proto_file in proto_files {
235 if let Some(parent_dir) = Path::new(proto_file).parent() {
236 parent_dirs.insert(parent_dir.to_path_buf());
237 }
238 }
239
240 for include_path in &self.include_paths {
242 cmd.arg("-I").arg(include_path);
243 }
244
245 for parent_dir in &parent_dirs {
247 cmd.arg("-I").arg(parent_dir);
248 }
249
250 let well_known_paths = [
252 "/usr/local/include",
253 "/usr/include",
254 "/opt/homebrew/include",
255 ];
256
257 for path in &well_known_paths {
258 if Path::new(path).exists() {
259 cmd.arg("-I").arg(path);
260 }
261 }
262
263 cmd.arg("--descriptor_set_out")
265 .arg(&descriptor_path)
266 .arg("--include_imports")
267 .arg("--include_source_info");
268
269 for proto_file in proto_files {
271 cmd.arg(proto_file);
272 }
273
274 debug!("Running batch protoc command for {} files", proto_files.len());
275
276 let output = cmd.output()?;
278
279 if !output.status.success() {
280 let stderr = String::from_utf8_lossy(&output.stderr);
281 return Err(format!("Batch protoc compilation failed: {}", stderr).into());
282 }
283
284 let descriptor_bytes = fs::read(&descriptor_path)?;
286 match self.pool.decode_file_descriptor_set(&*descriptor_bytes) {
287 Ok(()) => {
288 info!("Successfully batch compiled and loaded {} proto files", proto_files.len());
289 if self.pool.services().count() > 0 {
291 self.extract_services()?;
292 }
293 Ok(())
294 }
295 Err(e) => Err(format!("Failed to decode batch descriptor set: {}", e).into()),
296 }
297 }
298
299 async fn compile_with_protoc(
301 &self,
302 proto_file: &str,
303 output_path: &Path,
304 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
305 debug!("Compiling proto file with protoc: {}", proto_file);
306
307 let mut cmd = Command::new("protoc");
309
310 for include_path in &self.include_paths {
312 cmd.arg("-I").arg(include_path);
313 }
314
315 if let Some(parent_dir) = Path::new(proto_file).parent() {
317 cmd.arg("-I").arg(parent_dir);
318 }
319
320 let well_known_paths = [
322 "/usr/local/include",
323 "/usr/include",
324 "/opt/homebrew/include",
325 ];
326
327 for path in &well_known_paths {
328 if Path::new(path).exists() {
329 cmd.arg("-I").arg(path);
330 }
331 }
332
333 cmd.arg("--descriptor_set_out")
335 .arg(output_path)
336 .arg("--include_imports")
337 .arg("--include_source_info")
338 .arg(proto_file);
339
340 debug!("Running protoc command: {:?}", cmd);
341
342 let output = cmd.output()?;
344
345 if !output.status.success() {
346 let stderr = String::from_utf8_lossy(&output.stderr);
347 return Err(format!("protoc failed: {}", stderr).into());
348 }
349
350 info!("Successfully compiled proto file with protoc: {}", proto_file);
351 Ok(())
352 }
353
354 fn add_mock_greeter_service(&mut self) {
356 let service = ProtoService {
357 name: "mockforge.greeter.Greeter".to_string(),
358 package: "mockforge.greeter".to_string(),
359 short_name: "Greeter".to_string(),
360 methods: vec![
361 ProtoMethod {
362 name: "SayHello".to_string(),
363 input_type: "mockforge.greeter.HelloRequest".to_string(),
364 output_type: "mockforge.greeter.HelloReply".to_string(),
365 client_streaming: false,
366 server_streaming: false,
367 },
368 ProtoMethod {
369 name: "SayHelloStream".to_string(),
370 input_type: "mockforge.greeter.HelloRequest".to_string(),
371 output_type: "mockforge.greeter.HelloReply".to_string(),
372 client_streaming: false,
373 server_streaming: true,
374 },
375 ProtoMethod {
376 name: "SayHelloClientStream".to_string(),
377 input_type: "mockforge.greeter.HelloRequest".to_string(),
378 output_type: "mockforge.greeter.HelloReply".to_string(),
379 client_streaming: true,
380 server_streaming: false,
381 },
382 ProtoMethod {
383 name: "Chat".to_string(),
384 input_type: "mockforge.greeter.HelloRequest".to_string(),
385 output_type: "mockforge.greeter.HelloReply".to_string(),
386 client_streaming: true,
387 server_streaming: true,
388 },
389 ],
390 };
391
392 self.services.insert(service.name.clone(), service);
393 }
394
395 fn extract_services(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
397 debug!("Extracting services from descriptor pool");
398
399 let mock_services: HashMap<String, ProtoService> = self
401 .services
402 .drain()
403 .filter(|(name, _)| name.contains("mockforge.greeter"))
404 .collect();
405
406 self.services = mock_services;
407
408 for service_descriptor in self.pool.services() {
410 let service_name = service_descriptor.full_name().to_string();
411 let package_name = service_descriptor.parent_file().package_name().to_string();
412 let short_name = service_descriptor.name().to_string();
413
414 debug!("Found service: {} in package: {}", service_name, package_name);
415
416 let mut methods = Vec::new();
418 for method_descriptor in service_descriptor.methods() {
419 let method = ProtoMethod {
420 name: method_descriptor.name().to_string(),
421 input_type: method_descriptor.input().full_name().to_string(),
422 output_type: method_descriptor.output().full_name().to_string(),
423 client_streaming: method_descriptor.is_client_streaming(),
424 server_streaming: method_descriptor.is_server_streaming(),
425 };
426
427 debug!(
428 " Found method: {} ({} -> {})",
429 method.name, method.input_type, method.output_type
430 );
431
432 methods.push(method);
433 }
434
435 let service = ProtoService {
436 name: service_name.clone(),
437 package: package_name,
438 short_name,
439 methods,
440 };
441
442 self.services.insert(service_name, service);
443 }
444
445 info!("Extracted {} services from descriptor pool", self.services.len());
446 Ok(())
447 }
448
449 pub fn services(&self) -> &HashMap<String, ProtoService> {
451 &self.services
452 }
453
454 pub fn get_service(&self, name: &str) -> Option<&ProtoService> {
456 self.services.get(name)
457 }
458
459 pub fn pool(&self) -> &DescriptorPool {
461 &self.pool
462 }
463
464 pub fn into_pool(self) -> DescriptorPool {
466 self.pool
467 }
468}
469
470impl Default for ProtoParser {
471 fn default() -> Self {
472 Self::new()
473 }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Directory holding this crate's proto test fixtures.
    fn proto_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("proto")
    }

    #[tokio::test]
    async fn test_parse_proto_file() {
        let proto_path = proto_dir().join("gretter.proto");

        let mut parser = ProtoParser::new();
        parser.parse_proto_file(&proto_path.to_string_lossy()).await.unwrap();

        let services = parser.services();
        assert_eq!(services.len(), 1);

        let service_name = "mockforge.greeter.Greeter";
        assert!(services.contains_key(service_name));

        let service = &services[service_name];
        assert_eq!(service.name, service_name);
        assert_eq!(service.methods.len(), 4);

        let say_hello = service.methods.iter().find(|m| m.name == "SayHello").unwrap();
        assert_eq!(say_hello.input_type, "mockforge.greeter.HelloRequest");
        assert_eq!(say_hello.output_type, "mockforge.greeter.HelloReply");
        assert!(!say_hello.client_streaming);
        assert!(!say_hello.server_streaming);

        let say_hello_stream = service.methods.iter().find(|m| m.name == "SayHelloStream").unwrap();
        assert!(!say_hello_stream.client_streaming);
        assert!(say_hello_stream.server_streaming);
    }

    #[tokio::test]
    async fn test_parse_directory() {
        let proto_dir = proto_dir();

        let mut parser = ProtoParser::new();
        parser.parse_directory(&proto_dir.to_string_lossy()).await.unwrap();

        let services = parser.services();
        assert_eq!(services.len(), 1);

        let service_name = "mockforge.greeter.Greeter";
        assert!(services.contains_key(service_name));

        let service = &services[service_name];
        assert_eq!(service.methods.len(), 4);

        let method_names: Vec<&str> = service.methods.iter().map(|m| m.name.as_str()).collect();
        for expected in ["SayHello", "SayHelloStream", "SayHelloClientStream", "Chat"].iter() {
            assert!(method_names.contains(expected), "missing method {}", expected);
        }
    }
}