1use prost_reflect::DescriptorPool;
7use std::collections::HashMap;
8use std::fs;
9use std::path::{Path, PathBuf};
10use std::process::Command;
11use tempfile::TempDir;
12use tracing::{debug, error, info, warn};
13
14#[derive(Debug, Clone)]
16pub struct ProtoService {
17 pub name: String,
19 pub package: String,
21 pub short_name: String,
23 pub methods: Vec<ProtoMethod>,
25}
26
/// A single RPC method belonging to a [`ProtoService`].
#[derive(Debug, Clone)]
pub struct ProtoMethod {
    /// Method name as declared in the service, e.g. `SayHello`.
    pub name: String,
    /// Fully-qualified name of the request message type.
    pub input_type: String,
    /// Fully-qualified name of the response message type.
    pub output_type: String,
    /// `true` when the client sends a stream of requests.
    pub client_streaming: bool,
    /// `true` when the server replies with a stream of responses.
    pub server_streaming: bool,
}
41
42pub struct ProtoParser {
44 pool: DescriptorPool,
46 services: HashMap<String, ProtoService>,
48 include_paths: Vec<PathBuf>,
50 temp_dir: Option<TempDir>,
52}
53
54impl ProtoParser {
55 pub fn new() -> Self {
57 Self {
58 pool: DescriptorPool::new(),
59 services: HashMap::new(),
60 include_paths: vec![],
61 temp_dir: None,
62 }
63 }
64
65 pub fn with_include_paths(include_paths: Vec<PathBuf>) -> Self {
67 Self {
68 pool: DescriptorPool::new(),
69 services: HashMap::new(),
70 include_paths,
71 temp_dir: None,
72 }
73 }
74
75 pub async fn parse_directory(
81 &mut self,
82 proto_dir: &str,
83 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
84 info!("Parsing proto files from directory: {}", proto_dir);
85
86 let proto_path = Path::new(proto_dir);
87 if !proto_path.exists() {
88 warn!(
89 "Proto directory does not exist: {}. gRPC server will start with no services. \
90 This is normal when using only OpenAPI/HTTP.",
91 proto_dir
92 );
93 return Ok(());
94 }
95
96 let proto_files = self.discover_proto_files(proto_path)?;
98 if proto_files.is_empty() {
99 warn!("No proto files found in directory: {}", proto_dir);
100 return Ok(());
101 }
102
103 info!("Found {} proto files: {:?}", proto_files.len(), proto_files);
104
105 if proto_files.len() > 1 {
107 if let Err(e) = self.compile_protos_batch(&proto_files).await {
108 warn!("Batch compilation failed, falling back to individual compilation: {}", e);
109 for proto_file in proto_files {
111 if let Err(e) = self.parse_proto_file(&proto_file).await {
112 error!("Failed to parse proto file {}: {}", proto_file, e);
113 }
115 }
116 }
117 } else if !proto_files.is_empty() {
118 if let Err(e) = self.parse_proto_file(&proto_files[0]).await {
120 error!("Failed to parse proto file {}: {}", proto_files[0], e);
121 }
122 }
123
124 if self.pool.services().count() > 0 {
126 self.extract_services()?;
127 } else {
128 debug!("No services found in descriptor pool, keeping mock services");
129 }
130
131 info!("Successfully parsed {} services", self.services.len());
132 Ok(())
133 }
134
135 #[allow(clippy::only_used_in_recursion)]
137 fn discover_proto_files(
138 &self,
139 dir: &Path,
140 ) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
141 let mut proto_files = Vec::new();
142
143 if let Ok(entries) = fs::read_dir(dir) {
144 for entry in entries.flatten() {
145 let path = entry.path();
146
147 if path.is_dir() {
148 proto_files.extend(self.discover_proto_files(&path)?);
150 } else if path.extension().and_then(|s| s.to_str()) == Some("proto") {
151 proto_files.push(path.to_string_lossy().to_string());
153 }
154 }
155 }
156
157 Ok(proto_files)
158 }
159
160 async fn parse_proto_file(
162 &mut self,
163 proto_file: &str,
164 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
165 debug!("Parsing proto file: {}", proto_file);
166
167 if self.temp_dir.is_none() {
169 self.temp_dir = Some(TempDir::new()?);
170 }
171
172 let temp_dir = self.temp_dir.as_ref().ok_or_else(|| {
174 Box::<dyn std::error::Error + Send + Sync>::from("Temp directory not initialized")
175 })?;
176 let descriptor_path = temp_dir.path().join("descriptors.bin");
177
178 match self.compile_with_protoc(proto_file, &descriptor_path).await {
180 Ok(()) => {
181 let descriptor_bytes = fs::read(&descriptor_path)?;
183 match self.pool.decode_file_descriptor_set(&*descriptor_bytes) {
184 Ok(()) => {
185 info!("Successfully compiled and loaded proto file: {}", proto_file);
186 if self.pool.services().count() > 0 {
188 self.extract_services()?;
189 }
190 return Ok(());
191 }
192 Err(e) => {
193 warn!("Failed to decode descriptor set, falling back to mock: {}", e);
194 }
195 }
196 }
197 Err(e) => {
198 warn!(
201 "protoc not available or compilation failed (this is OK for basic usage, using fallback): {}",
202 e
203 );
204 }
205 }
206
207 if proto_file.contains("gretter.proto") || proto_file.contains("greeter.proto") {
209 debug!("Adding mock greeter service for {}", proto_file);
210 self.add_mock_greeter_service();
211 }
212
213 Ok(())
214 }
215
216 async fn compile_protos_batch(
219 &mut self,
220 proto_files: &[String],
221 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
222 if proto_files.is_empty() {
223 return Ok(());
224 }
225
226 info!("Batch compiling {} proto files", proto_files.len());
227
228 if self.temp_dir.is_none() {
230 self.temp_dir = Some(TempDir::new()?);
231 }
232
233 let temp_dir = self.temp_dir.as_ref().ok_or_else(|| {
234 Box::<dyn std::error::Error + Send + Sync>::from("Temp directory not initialized")
235 })?;
236 let descriptor_path = temp_dir.path().join("descriptors_batch.bin");
237
238 let mut cmd = Command::new("protoc");
240
241 let mut parent_dirs = std::collections::HashSet::new();
243 for proto_file in proto_files {
244 if let Some(parent_dir) = Path::new(proto_file).parent() {
245 parent_dirs.insert(parent_dir.to_path_buf());
246 }
247 }
248
249 for include_path in &self.include_paths {
251 cmd.arg("-I").arg(include_path);
252 }
253
254 for parent_dir in &parent_dirs {
256 cmd.arg("-I").arg(parent_dir);
257 }
258
259 let well_known_paths = [
261 "/usr/local/include",
262 "/usr/include",
263 "/opt/homebrew/include",
264 ];
265
266 for path in &well_known_paths {
267 if Path::new(path).exists() {
268 cmd.arg("-I").arg(path);
269 }
270 }
271
272 cmd.arg("--descriptor_set_out")
274 .arg(&descriptor_path)
275 .arg("--include_imports")
276 .arg("--include_source_info");
277
278 for proto_file in proto_files {
280 cmd.arg(proto_file);
281 }
282
283 debug!("Running batch protoc command for {} files", proto_files.len());
284
285 let output = cmd.output()?;
287
288 if !output.status.success() {
289 let stderr = String::from_utf8_lossy(&output.stderr);
290 return Err(format!("Batch protoc compilation failed: {}", stderr).into());
291 }
292
293 let descriptor_bytes = fs::read(&descriptor_path)?;
295 match self.pool.decode_file_descriptor_set(&*descriptor_bytes) {
296 Ok(()) => {
297 info!("Successfully batch compiled and loaded {} proto files", proto_files.len());
298 if self.pool.services().count() > 0 {
300 self.extract_services()?;
301 }
302 Ok(())
303 }
304 Err(e) => Err(format!("Failed to decode batch descriptor set: {}", e).into()),
305 }
306 }
307
308 async fn compile_with_protoc(
310 &self,
311 proto_file: &str,
312 output_path: &Path,
313 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
314 debug!("Compiling proto file with protoc: {}", proto_file);
315
316 let mut cmd = Command::new("protoc");
318
319 for include_path in &self.include_paths {
321 cmd.arg("-I").arg(include_path);
322 }
323
324 if let Some(parent_dir) = Path::new(proto_file).parent() {
326 cmd.arg("-I").arg(parent_dir);
327 }
328
329 let well_known_paths = [
331 "/usr/local/include",
332 "/usr/include",
333 "/opt/homebrew/include",
334 ];
335
336 for path in &well_known_paths {
337 if Path::new(path).exists() {
338 cmd.arg("-I").arg(path);
339 }
340 }
341
342 cmd.arg("--descriptor_set_out")
344 .arg(output_path)
345 .arg("--include_imports")
346 .arg("--include_source_info")
347 .arg(proto_file);
348
349 debug!("Running protoc command: {:?}", cmd);
350
351 let output = cmd.output()?;
353
354 if !output.status.success() {
355 let stderr = String::from_utf8_lossy(&output.stderr);
356 return Err(format!("protoc failed: {}", stderr).into());
357 }
358
359 info!("Successfully compiled proto file with protoc: {}", proto_file);
360 Ok(())
361 }
362
363 fn add_mock_greeter_service(&mut self) {
365 let service = ProtoService {
366 name: "mockforge.greeter.Greeter".to_string(),
367 package: "mockforge.greeter".to_string(),
368 short_name: "Greeter".to_string(),
369 methods: vec![
370 ProtoMethod {
371 name: "SayHello".to_string(),
372 input_type: "mockforge.greeter.HelloRequest".to_string(),
373 output_type: "mockforge.greeter.HelloReply".to_string(),
374 client_streaming: false,
375 server_streaming: false,
376 },
377 ProtoMethod {
378 name: "SayHelloStream".to_string(),
379 input_type: "mockforge.greeter.HelloRequest".to_string(),
380 output_type: "mockforge.greeter.HelloReply".to_string(),
381 client_streaming: false,
382 server_streaming: true,
383 },
384 ProtoMethod {
385 name: "SayHelloClientStream".to_string(),
386 input_type: "mockforge.greeter.HelloRequest".to_string(),
387 output_type: "mockforge.greeter.HelloReply".to_string(),
388 client_streaming: true,
389 server_streaming: false,
390 },
391 ProtoMethod {
392 name: "Chat".to_string(),
393 input_type: "mockforge.greeter.HelloRequest".to_string(),
394 output_type: "mockforge.greeter.HelloReply".to_string(),
395 client_streaming: true,
396 server_streaming: true,
397 },
398 ],
399 };
400
401 self.services.insert(service.name.clone(), service);
402 }
403
404 fn extract_services(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
406 debug!("Extracting services from descriptor pool");
407
408 let mock_services: HashMap<String, ProtoService> = self
410 .services
411 .drain()
412 .filter(|(name, _)| name.contains("mockforge.greeter"))
413 .collect();
414
415 self.services = mock_services;
416
417 for service_descriptor in self.pool.services() {
419 let service_name = service_descriptor.full_name().to_string();
420 let package_name = service_descriptor.parent_file().package_name().to_string();
421 let short_name = service_descriptor.name().to_string();
422
423 debug!("Found service: {} in package: {}", service_name, package_name);
424
425 let mut methods = Vec::new();
427 for method_descriptor in service_descriptor.methods() {
428 let method = ProtoMethod {
429 name: method_descriptor.name().to_string(),
430 input_type: method_descriptor.input().full_name().to_string(),
431 output_type: method_descriptor.output().full_name().to_string(),
432 client_streaming: method_descriptor.is_client_streaming(),
433 server_streaming: method_descriptor.is_server_streaming(),
434 };
435
436 debug!(
437 " Found method: {} ({} -> {})",
438 method.name, method.input_type, method.output_type
439 );
440
441 methods.push(method);
442 }
443
444 let service = ProtoService {
445 name: service_name.clone(),
446 package: package_name,
447 short_name,
448 methods,
449 };
450
451 self.services.insert(service_name, service);
452 }
453
454 info!("Extracted {} services from descriptor pool", self.services.len());
455 Ok(())
456 }
457
458 pub fn services(&self) -> &HashMap<String, ProtoService> {
460 &self.services
461 }
462
463 pub fn get_service(&self, name: &str) -> Option<&ProtoService> {
465 self.services.get(name)
466 }
467
468 pub fn pool(&self) -> &DescriptorPool {
470 &self.pool
471 }
472
473 pub fn into_pool(self) -> DescriptorPool {
475 self.pool
476 }
477}
478
479impl Default for ProtoParser {
480 fn default() -> Self {
481 Self::new()
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ProtoMethod` fixture from borrowed strings.
    fn method(
        name: &str,
        input_type: &str,
        output_type: &str,
        client_streaming: bool,
        server_streaming: bool,
    ) -> ProtoMethod {
        ProtoMethod {
            name: name.to_string(),
            input_type: input_type.to_string(),
            output_type: output_type.to_string(),
            client_streaming,
            server_streaming,
        }
    }

    /// Builds a `ProtoService` fixture from borrowed strings.
    fn service(
        name: &str,
        package: &str,
        short_name: &str,
        methods: Vec<ProtoMethod>,
    ) -> ProtoService {
        ProtoService {
            name: name.to_string(),
            package: package.to_string(),
            short_name: short_name.to_string(),
            methods,
        }
    }

    /// Path of the crate's `proto` directory, taken from `CARGO_MANIFEST_DIR`.
    fn manifest_proto_dir() -> String {
        format!("{}/proto", std::env::var("CARGO_MANIFEST_DIR").unwrap())
    }

    // ---- ProtoService ----

    #[test]
    fn test_proto_service_creation() {
        let svc = service("mypackage.MyService", "mypackage", "MyService", Vec::new());

        assert_eq!(svc.name, "mypackage.MyService");
        assert_eq!(svc.package, "mypackage");
        assert_eq!(svc.short_name, "MyService");
        assert!(svc.methods.is_empty());
    }

    #[test]
    fn test_proto_service_with_methods() {
        let get_data = method("GetData", "mypackage.Request", "mypackage.Response", false, false);
        let svc = service("mypackage.DataService", "mypackage", "DataService", vec![get_data]);

        assert_eq!(svc.methods.len(), 1);
        assert_eq!(svc.methods[0].name, "GetData");
    }

    #[test]
    fn test_proto_service_clone() {
        let original = service(
            "test.Service",
            "test",
            "Service",
            vec![method("Method", "Request", "Response", false, false)],
        );

        let copy = original.clone();
        assert_eq!(copy.name, original.name);
        assert_eq!(copy.methods.len(), original.methods.len());
    }

    // ---- ProtoMethod ----

    #[test]
    fn test_proto_method_unary() {
        let unary = method("UnaryMethod", "Request", "Response", false, false);

        assert_eq!(unary.name, "UnaryMethod");
        assert!(!unary.client_streaming);
        assert!(!unary.server_streaming);
    }

    #[test]
    fn test_proto_method_server_streaming() {
        let streaming = method("StreamMethod", "Request", "Response", false, true);

        assert!(!streaming.client_streaming);
        assert!(streaming.server_streaming);
    }

    #[test]
    fn test_proto_method_client_streaming() {
        let streaming = method("ClientStreamMethod", "Request", "Response", true, false);

        assert!(streaming.client_streaming);
        assert!(!streaming.server_streaming);
    }

    #[test]
    fn test_proto_method_bidi_streaming() {
        let bidi = method("BidiStreamMethod", "Request", "Response", true, true);

        assert!(bidi.client_streaming);
        assert!(bidi.server_streaming);
    }

    #[test]
    fn test_proto_method_clone() {
        let original = method("TestMethod", "Input", "Output", true, true);

        let copy = original.clone();
        assert_eq!(copy.name, original.name);
        assert_eq!(copy.input_type, original.input_type);
        assert_eq!(copy.output_type, original.output_type);
        assert_eq!(copy.client_streaming, original.client_streaming);
        assert_eq!(copy.server_streaming, original.server_streaming);
    }

    // ---- ProtoParser construction and accessors ----

    #[test]
    fn test_proto_parser_new() {
        assert!(ProtoParser::new().services().is_empty());
    }

    #[test]
    fn test_proto_parser_default() {
        assert!(ProtoParser::default().services().is_empty());
    }

    #[test]
    fn test_proto_parser_with_include_paths() {
        let include_paths = vec![PathBuf::from("/usr/include"), PathBuf::from("/opt/proto")];
        let parser = ProtoParser::with_include_paths(include_paths);
        assert!(parser.services().is_empty());
    }

    #[test]
    fn test_proto_parser_get_service_nonexistent() {
        assert!(ProtoParser::new().get_service("nonexistent").is_none());
    }

    #[test]
    fn test_proto_parser_pool() {
        // Only checks that the accessor can be called on a fresh parser.
        let parser = ProtoParser::new();
        let _pool = parser.pool();
    }

    #[test]
    fn test_proto_parser_into_pool() {
        // Only checks that a fresh parser can be consumed into its pool.
        let _pool = ProtoParser::new().into_pool();
    }

    #[test]
    fn test_proto_parser_add_mock_greeter_service() {
        let mut parser = ProtoParser::new();
        parser.add_mock_greeter_service();

        let services = parser.services();
        assert_eq!(services.len(), 1);
        assert!(services.contains_key("mockforge.greeter.Greeter"));

        let greeter = &services["mockforge.greeter.Greeter"];
        assert_eq!(greeter.short_name, "Greeter");
        assert_eq!(greeter.package, "mockforge.greeter");
        assert_eq!(greeter.methods.len(), 4);
    }

    // ---- File discovery ----

    #[test]
    fn test_proto_parser_discover_empty_dir() {
        let dir = TempDir::new().unwrap();

        let found = ProtoParser::new().discover_proto_files(dir.path()).unwrap();
        assert!(found.is_empty());
    }

    #[test]
    fn test_proto_parser_discover_with_proto_files() {
        let dir = TempDir::new().unwrap();
        fs::write(dir.path().join("test.proto"), "syntax = \"proto3\";").unwrap();

        let found = ProtoParser::new().discover_proto_files(dir.path()).unwrap();

        assert_eq!(found.len(), 1);
        assert!(found[0].ends_with("test.proto"));
    }

    #[test]
    fn test_proto_parser_discover_recursive() {
        let dir = TempDir::new().unwrap();
        let nested_dir = dir.path().join("subdir");
        fs::create_dir(&nested_dir).unwrap();
        fs::write(dir.path().join("root.proto"), "").unwrap();
        fs::write(nested_dir.join("nested.proto"), "").unwrap();

        let found = ProtoParser::new().discover_proto_files(dir.path()).unwrap();

        assert_eq!(found.len(), 2);
    }

    #[test]
    fn test_proto_parser_discover_ignores_non_proto() {
        let dir = TempDir::new().unwrap();
        for file_name in ["test.proto", "test.txt", "test.json"].iter() {
            fs::write(dir.path().join(file_name), "").unwrap();
        }

        let found = ProtoParser::new().discover_proto_files(dir.path()).unwrap();

        assert_eq!(found.len(), 1);
        assert!(found[0].ends_with(".proto"));
    }

    // ---- Directory and file parsing ----

    #[tokio::test]
    async fn test_parse_nonexistent_directory() {
        let mut parser = ProtoParser::new();
        let outcome = parser.parse_directory("/nonexistent/path").await;

        // A missing directory is tolerated and leaves the parser empty.
        assert!(outcome.is_ok());
        assert!(parser.services().is_empty());
    }

    #[tokio::test]
    async fn test_parse_empty_directory() {
        let dir = TempDir::new().unwrap();
        let mut parser = ProtoParser::new();

        let outcome = parser.parse_directory(dir.path().to_str().unwrap()).await;

        // A directory without proto files is tolerated as well.
        assert!(outcome.is_ok());
        assert!(parser.services().is_empty());
    }

    #[tokio::test]
    async fn test_parse_proto_file() {
        let proto_path = format!("{}/gretter.proto", manifest_proto_dir());

        let mut parser = ProtoParser::new();
        parser.parse_proto_file(&proto_path).await.unwrap();

        let services = parser.services();
        assert_eq!(services.len(), 1);

        let service_name = "mockforge.greeter.Greeter";
        assert!(services.contains_key(service_name));

        let greeter = &services[service_name];
        assert_eq!(greeter.name, service_name);
        assert_eq!(greeter.methods.len(), 4);

        let say_hello = greeter.methods.iter().find(|m| m.name == "SayHello").unwrap();
        assert_eq!(say_hello.input_type, "mockforge.greeter.HelloRequest");
        assert_eq!(say_hello.output_type, "mockforge.greeter.HelloReply");
        assert!(!say_hello.client_streaming);
        assert!(!say_hello.server_streaming);

        let say_hello_stream =
            greeter.methods.iter().find(|m| m.name == "SayHelloStream").unwrap();
        assert!(!say_hello_stream.client_streaming);
        assert!(say_hello_stream.server_streaming);
    }

    #[tokio::test]
    async fn test_parse_directory() {
        let proto_dir = manifest_proto_dir();

        let mut parser = ProtoParser::new();
        parser.parse_directory(&proto_dir).await.unwrap();

        let services = parser.services();
        assert_eq!(services.len(), 1);

        let service_name = "mockforge.greeter.Greeter";
        assert!(services.contains_key(service_name));

        let greeter = &services[service_name];
        assert_eq!(greeter.methods.len(), 4);

        let method_names: Vec<&str> = greeter.methods.iter().map(|m| m.name.as_str()).collect();
        for expected in ["SayHello", "SayHelloStream", "SayHelloClientStream", "Chat"].iter() {
            assert!(method_names.contains(expected));
        }
    }
}