1use prost_reflect::DescriptorPool;
7use std::collections::HashMap;
8use std::fs;
9use std::path::{Path, PathBuf};
10use std::process::Command;
11use tempfile::TempDir;
12use tracing::{debug, error, info, warn};
13
/// A gRPC service declared in a loaded `.proto` file (or the built-in mock greeter).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProtoService {
    /// Fully-qualified service name, e.g. `mypackage.MyService`.
    pub name: String,
    /// Protobuf package the service's file declares, e.g. `mypackage`.
    pub package: String,
    /// Unqualified service name, e.g. `MyService`.
    pub short_name: String,
    /// RPC methods in declaration order.
    pub methods: Vec<ProtoMethod>,
}

/// A single RPC method of a [`ProtoService`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProtoMethod {
    /// Method name as declared in the service, e.g. `SayHello`.
    pub name: String,
    /// Fully-qualified request message type.
    pub input_type: String,
    /// Fully-qualified response message type.
    pub output_type: String,
    /// `true` when the client sends a stream of requests.
    pub client_streaming: bool,
    /// `true` when the server replies with a stream of responses.
    pub server_streaming: bool,
}
41
42pub struct ProtoParser {
44 pool: DescriptorPool,
46 services: HashMap<String, ProtoService>,
48 include_paths: Vec<PathBuf>,
50 temp_dir: Option<TempDir>,
52}
53
54impl ProtoParser {
55 pub fn new() -> Self {
57 Self {
58 pool: DescriptorPool::new(),
59 services: HashMap::new(),
60 include_paths: vec![],
61 temp_dir: None,
62 }
63 }
64
65 pub fn with_include_paths(include_paths: Vec<PathBuf>) -> Self {
67 Self {
68 pool: DescriptorPool::new(),
69 services: HashMap::new(),
70 include_paths,
71 temp_dir: None,
72 }
73 }
74
75 pub async fn parse_directory(
81 &mut self,
82 proto_dir: &str,
83 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
84 info!("Parsing proto files from directory: {}", proto_dir);
85
86 let proto_path = Path::new(proto_dir);
87 if !proto_path.exists() {
88 info!(
89 "No proto directory found at {}. gRPC server will start with built-in services only.",
90 proto_dir
91 );
92 return Ok(());
93 }
94
95 let proto_files = self.discover_proto_files(proto_path)?;
97 if proto_files.is_empty() {
98 warn!("No proto files found in directory: {}", proto_dir);
99 return Ok(());
100 }
101
102 info!("Found {} proto files: {:?}", proto_files.len(), proto_files);
103
104 if proto_files.len() > 1 {
106 if let Err(e) = self.compile_protos_batch(&proto_files).await {
107 warn!("Batch compilation failed, falling back to individual compilation: {}", e);
108 for proto_file in proto_files {
110 if let Err(e) = self.parse_proto_file(&proto_file).await {
111 error!("Failed to parse proto file {}: {}", proto_file, e);
112 }
114 }
115 }
116 } else if !proto_files.is_empty() {
117 if let Err(e) = self.parse_proto_file(&proto_files[0]).await {
119 error!("Failed to parse proto file {}: {}", proto_files[0], e);
120 }
121 }
122
123 if self.pool.services().count() > 0 {
125 self.extract_services()?;
126 } else {
127 debug!("No services found in descriptor pool, keeping mock services");
128 }
129
130 info!("Successfully parsed {} services", self.services.len());
131 Ok(())
132 }
133
134 #[allow(clippy::only_used_in_recursion)]
136 fn discover_proto_files(
137 &self,
138 dir: &Path,
139 ) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
140 let mut proto_files = Vec::new();
141
142 if let Ok(entries) = fs::read_dir(dir) {
143 for entry in entries.flatten() {
144 let path = entry.path();
145
146 if path.is_dir() {
147 proto_files.extend(self.discover_proto_files(&path)?);
149 } else if path.extension().and_then(|s| s.to_str()) == Some("proto") {
150 proto_files.push(path.to_string_lossy().to_string());
152 }
153 }
154 }
155
156 Ok(proto_files)
157 }
158
159 async fn parse_proto_file(
161 &mut self,
162 proto_file: &str,
163 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
164 debug!("Parsing proto file: {}", proto_file);
165
166 if self.temp_dir.is_none() {
168 self.temp_dir = Some(TempDir::new()?);
169 }
170
171 let temp_dir = self.temp_dir.as_ref().ok_or_else(|| {
173 Box::<dyn std::error::Error + Send + Sync>::from("Temp directory not initialized")
174 })?;
175 let descriptor_path = temp_dir.path().join("descriptors.bin");
176
177 match self.compile_with_protoc(proto_file, &descriptor_path).await {
179 Ok(()) => {
180 let descriptor_bytes = fs::read(&descriptor_path)?;
182 match self.pool.decode_file_descriptor_set(&*descriptor_bytes) {
183 Ok(()) => {
184 info!("Successfully compiled and loaded proto file: {}", proto_file);
185 if self.pool.services().count() > 0 {
187 self.extract_services()?;
188 }
189 return Ok(());
190 }
191 Err(e) => {
192 warn!("Failed to decode descriptor set, falling back to mock: {}", e);
193 }
194 }
195 }
196 Err(e) => {
197 warn!(
200 "protoc not available or compilation failed (this is OK for basic usage, using fallback): {}",
201 e
202 );
203 }
204 }
205
206 if proto_file.contains("gretter.proto") || proto_file.contains("greeter.proto") {
208 debug!("Adding mock greeter service for {}", proto_file);
209 self.add_mock_greeter_service();
210 }
211
212 Ok(())
213 }
214
215 async fn compile_protos_batch(
218 &mut self,
219 proto_files: &[String],
220 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
221 if proto_files.is_empty() {
222 return Ok(());
223 }
224
225 info!("Batch compiling {} proto files", proto_files.len());
226
227 if self.temp_dir.is_none() {
229 self.temp_dir = Some(TempDir::new()?);
230 }
231
232 let temp_dir = self.temp_dir.as_ref().ok_or_else(|| {
233 Box::<dyn std::error::Error + Send + Sync>::from("Temp directory not initialized")
234 })?;
235 let descriptor_path = temp_dir.path().join("descriptors_batch.bin");
236
237 let mut cmd = Command::new("protoc");
239
240 let mut parent_dirs = std::collections::HashSet::new();
242 for proto_file in proto_files {
243 if let Some(parent_dir) = Path::new(proto_file).parent() {
244 parent_dirs.insert(parent_dir.to_path_buf());
245 }
246 }
247
248 for include_path in &self.include_paths {
250 cmd.arg("-I").arg(include_path);
251 }
252
253 for parent_dir in &parent_dirs {
255 cmd.arg("-I").arg(parent_dir);
256 }
257
258 let well_known_paths = [
260 "/usr/local/include",
261 "/usr/include",
262 "/opt/homebrew/include",
263 ];
264
265 for path in &well_known_paths {
266 if Path::new(path).exists() {
267 cmd.arg("-I").arg(path);
268 }
269 }
270
271 cmd.arg("--descriptor_set_out")
273 .arg(&descriptor_path)
274 .arg("--include_imports")
275 .arg("--include_source_info");
276
277 for proto_file in proto_files {
279 cmd.arg(proto_file);
280 }
281
282 debug!("Running batch protoc command for {} files", proto_files.len());
283
284 let output = cmd.output()?;
286
287 if !output.status.success() {
288 let stderr = String::from_utf8_lossy(&output.stderr);
289 return Err(format!("Batch protoc compilation failed: {}", stderr).into());
290 }
291
292 let descriptor_bytes = fs::read(&descriptor_path)?;
294 match self.pool.decode_file_descriptor_set(&*descriptor_bytes) {
295 Ok(()) => {
296 info!("Successfully batch compiled and loaded {} proto files", proto_files.len());
297 if self.pool.services().count() > 0 {
299 self.extract_services()?;
300 }
301 Ok(())
302 }
303 Err(e) => Err(format!("Failed to decode batch descriptor set: {}", e).into()),
304 }
305 }
306
307 async fn compile_with_protoc(
309 &self,
310 proto_file: &str,
311 output_path: &Path,
312 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
313 debug!("Compiling proto file with protoc: {}", proto_file);
314
315 let mut cmd = Command::new("protoc");
317
318 for include_path in &self.include_paths {
320 cmd.arg("-I").arg(include_path);
321 }
322
323 if let Some(parent_dir) = Path::new(proto_file).parent() {
325 cmd.arg("-I").arg(parent_dir);
326 }
327
328 let well_known_paths = [
330 "/usr/local/include",
331 "/usr/include",
332 "/opt/homebrew/include",
333 ];
334
335 for path in &well_known_paths {
336 if Path::new(path).exists() {
337 cmd.arg("-I").arg(path);
338 }
339 }
340
341 cmd.arg("--descriptor_set_out")
343 .arg(output_path)
344 .arg("--include_imports")
345 .arg("--include_source_info")
346 .arg(proto_file);
347
348 debug!("Running protoc command: {:?}", cmd);
349
350 let output = cmd.output()?;
352
353 if !output.status.success() {
354 let stderr = String::from_utf8_lossy(&output.stderr);
355 return Err(format!("protoc failed: {}", stderr).into());
356 }
357
358 info!("Successfully compiled proto file with protoc: {}", proto_file);
359 Ok(())
360 }
361
362 fn add_mock_greeter_service(&mut self) {
364 let service = ProtoService {
365 name: "mockforge.greeter.Greeter".to_string(),
366 package: "mockforge.greeter".to_string(),
367 short_name: "Greeter".to_string(),
368 methods: vec![
369 ProtoMethod {
370 name: "SayHello".to_string(),
371 input_type: "mockforge.greeter.HelloRequest".to_string(),
372 output_type: "mockforge.greeter.HelloReply".to_string(),
373 client_streaming: false,
374 server_streaming: false,
375 },
376 ProtoMethod {
377 name: "SayHelloStream".to_string(),
378 input_type: "mockforge.greeter.HelloRequest".to_string(),
379 output_type: "mockforge.greeter.HelloReply".to_string(),
380 client_streaming: false,
381 server_streaming: true,
382 },
383 ProtoMethod {
384 name: "SayHelloClientStream".to_string(),
385 input_type: "mockforge.greeter.HelloRequest".to_string(),
386 output_type: "mockforge.greeter.HelloReply".to_string(),
387 client_streaming: true,
388 server_streaming: false,
389 },
390 ProtoMethod {
391 name: "Chat".to_string(),
392 input_type: "mockforge.greeter.HelloRequest".to_string(),
393 output_type: "mockforge.greeter.HelloReply".to_string(),
394 client_streaming: true,
395 server_streaming: true,
396 },
397 ],
398 };
399
400 self.services.insert(service.name.clone(), service);
401 }
402
403 fn extract_services(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
405 debug!("Extracting services from descriptor pool");
406
407 let mock_services: HashMap<String, ProtoService> = self
409 .services
410 .drain()
411 .filter(|(name, _)| name.contains("mockforge.greeter"))
412 .collect();
413
414 self.services = mock_services;
415
416 for service_descriptor in self.pool.services() {
418 let service_name = service_descriptor.full_name().to_string();
419 let package_name = service_descriptor.parent_file().package_name().to_string();
420 let short_name = service_descriptor.name().to_string();
421
422 debug!("Found service: {} in package: {}", service_name, package_name);
423
424 let mut methods = Vec::new();
426 for method_descriptor in service_descriptor.methods() {
427 let method = ProtoMethod {
428 name: method_descriptor.name().to_string(),
429 input_type: method_descriptor.input().full_name().to_string(),
430 output_type: method_descriptor.output().full_name().to_string(),
431 client_streaming: method_descriptor.is_client_streaming(),
432 server_streaming: method_descriptor.is_server_streaming(),
433 };
434
435 debug!(
436 " Found method: {} ({} -> {})",
437 method.name, method.input_type, method.output_type
438 );
439
440 methods.push(method);
441 }
442
443 let service = ProtoService {
444 name: service_name.clone(),
445 package: package_name,
446 short_name,
447 methods,
448 };
449
450 self.services.insert(service_name, service);
451 }
452
453 info!("Extracted {} services from descriptor pool", self.services.len());
454 Ok(())
455 }
456
457 pub fn services(&self) -> &HashMap<String, ProtoService> {
459 &self.services
460 }
461
462 pub fn get_service(&self, name: &str) -> Option<&ProtoService> {
464 self.services.get(name)
465 }
466
467 pub fn pool(&self) -> &DescriptorPool {
469 &self.pool
470 }
471
472 pub fn into_pool(self) -> DescriptorPool {
474 self.pool
475 }
476}
477
478impl Default for ProtoParser {
479 fn default() -> Self {
480 Self::new()
481 }
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Test shorthand for a `ProtoMethod` with the given types and streaming flags.
    fn method(
        name: &str,
        input_type: &str,
        output_type: &str,
        client_streaming: bool,
        server_streaming: bool,
    ) -> ProtoMethod {
        ProtoMethod {
            name: name.to_string(),
            input_type: input_type.to_string(),
            output_type: output_type.to_string(),
            client_streaming,
            server_streaming,
        }
    }

    /// Test shorthand for a `ProtoService`.
    fn service(
        name: &str,
        package: &str,
        short_name: &str,
        methods: Vec<ProtoMethod>,
    ) -> ProtoService {
        ProtoService {
            name: name.to_string(),
            package: package.to_string(),
            short_name: short_name.to_string(),
            methods,
        }
    }

    /// The crate's bundled `proto` fixture directory.
    fn fixture_proto_dir() -> String {
        let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap();
        format!("{}/proto", manifest_dir)
    }

    // ---- ProtoService ----

    #[test]
    fn test_proto_service_creation() {
        let svc = service("mypackage.MyService", "mypackage", "MyService", Vec::new());

        assert_eq!(svc.name, "mypackage.MyService");
        assert_eq!(svc.package, "mypackage");
        assert_eq!(svc.short_name, "MyService");
        assert!(svc.methods.is_empty());
    }

    #[test]
    fn test_proto_service_with_methods() {
        let get_data = method("GetData", "mypackage.Request", "mypackage.Response", false, false);
        let svc = service("mypackage.DataService", "mypackage", "DataService", vec![get_data]);

        assert_eq!(svc.methods.len(), 1);
        assert_eq!(svc.methods[0].name, "GetData");
    }

    #[test]
    fn test_proto_service_clone() {
        let original = service(
            "test.Service",
            "test",
            "Service",
            vec![method("Method", "Request", "Response", false, false)],
        );

        let copy = original.clone();
        assert_eq!(copy.name, original.name);
        assert_eq!(copy.methods.len(), original.methods.len());
    }

    // ---- ProtoMethod ----

    #[test]
    fn test_proto_method_unary() {
        let unary = method("UnaryMethod", "Request", "Response", false, false);

        assert_eq!(unary.name, "UnaryMethod");
        assert!(!unary.client_streaming);
        assert!(!unary.server_streaming);
    }

    #[test]
    fn test_proto_method_server_streaming() {
        let m = method("StreamMethod", "Request", "Response", false, true);
        assert_eq!((m.client_streaming, m.server_streaming), (false, true));
    }

    #[test]
    fn test_proto_method_client_streaming() {
        let m = method("ClientStreamMethod", "Request", "Response", true, false);
        assert_eq!((m.client_streaming, m.server_streaming), (true, false));
    }

    #[test]
    fn test_proto_method_bidi_streaming() {
        let m = method("BidiStreamMethod", "Request", "Response", true, true);
        assert_eq!((m.client_streaming, m.server_streaming), (true, true));
    }

    #[test]
    fn test_proto_method_clone() {
        let original = method("TestMethod", "Input", "Output", true, true);
        let copy = original.clone();

        assert_eq!(copy.name, original.name);
        assert_eq!(copy.input_type, original.input_type);
        assert_eq!(copy.output_type, original.output_type);
        assert_eq!(copy.client_streaming, original.client_streaming);
        assert_eq!(copy.server_streaming, original.server_streaming);
    }

    // ---- ProtoParser construction & accessors ----

    #[test]
    fn test_proto_parser_new() {
        assert!(ProtoParser::new().services().is_empty());
    }

    #[test]
    fn test_proto_parser_default() {
        assert!(ProtoParser::default().services().is_empty());
    }

    #[test]
    fn test_proto_parser_with_include_paths() {
        let include_paths = vec![PathBuf::from("/usr/include"), PathBuf::from("/opt/proto")];
        assert!(ProtoParser::with_include_paths(include_paths).services().is_empty());
    }

    #[test]
    fn test_proto_parser_get_service_nonexistent() {
        assert!(ProtoParser::new().get_service("nonexistent").is_none());
    }

    #[test]
    fn test_proto_parser_pool() {
        let parser = ProtoParser::new();
        let _pool = parser.pool();
    }

    #[test]
    fn test_proto_parser_into_pool() {
        let _pool = ProtoParser::new().into_pool();
    }

    #[test]
    fn test_proto_parser_add_mock_greeter_service() {
        let mut parser = ProtoParser::new();
        parser.add_mock_greeter_service();

        let services = parser.services();
        assert_eq!(services.len(), 1);

        let greeter = services
            .get("mockforge.greeter.Greeter")
            .expect("mock greeter should be registered");
        assert_eq!(greeter.short_name, "Greeter");
        assert_eq!(greeter.package, "mockforge.greeter");
        assert_eq!(greeter.methods.len(), 4);
    }

    // ---- discovery ----

    #[test]
    fn test_proto_parser_discover_empty_dir() {
        let root = TempDir::new().unwrap();
        let found = ProtoParser::new().discover_proto_files(root.path()).unwrap();
        assert!(found.is_empty());
    }

    #[test]
    fn test_proto_parser_discover_with_proto_files() {
        let root = TempDir::new().unwrap();
        fs::write(root.path().join("test.proto"), "syntax = \"proto3\";").unwrap();

        let found = ProtoParser::new().discover_proto_files(root.path()).unwrap();

        assert_eq!(found.len(), 1);
        assert!(found[0].ends_with("test.proto"));
    }

    #[test]
    fn test_proto_parser_discover_recursive() {
        let root = TempDir::new().unwrap();
        let nested_dir = root.path().join("subdir");
        fs::create_dir(&nested_dir).unwrap();
        fs::write(root.path().join("root.proto"), "").unwrap();
        fs::write(nested_dir.join("nested.proto"), "").unwrap();

        let found = ProtoParser::new().discover_proto_files(root.path()).unwrap();

        assert_eq!(found.len(), 2);
    }

    #[test]
    fn test_proto_parser_discover_ignores_non_proto() {
        let root = TempDir::new().unwrap();
        for file_name in ["test.proto", "test.txt", "test.json"] {
            fs::write(root.path().join(file_name), "").unwrap();
        }

        let found = ProtoParser::new().discover_proto_files(root.path()).unwrap();

        assert_eq!(found.len(), 1);
        assert!(found[0].ends_with(".proto"));
    }

    // ---- async parsing ----

    #[tokio::test]
    async fn test_parse_nonexistent_directory() {
        let mut parser = ProtoParser::new();
        let outcome = parser.parse_directory("/nonexistent/path").await;

        // A missing directory is tolerated and leaves the parser empty.
        assert!(outcome.is_ok());
        assert!(parser.services().is_empty());
    }

    #[tokio::test]
    async fn test_parse_empty_directory() {
        let root = TempDir::new().unwrap();
        let dir = root.path().to_str().unwrap();
        let mut parser = ProtoParser::new();

        assert!(parser.parse_directory(dir).await.is_ok());
        assert!(parser.services().is_empty());
    }

    #[tokio::test]
    async fn test_parse_proto_file() {
        let proto_path = format!("{}/gretter.proto", fixture_proto_dir());

        let mut parser = ProtoParser::new();
        parser.parse_proto_file(&proto_path).await.unwrap();

        let services = parser.services();
        assert_eq!(services.len(), 1);

        let service_name = "mockforge.greeter.Greeter";
        let greeter = services
            .get(service_name)
            .expect("greeter service should be loaded");
        assert_eq!(greeter.name, service_name);
        assert_eq!(greeter.methods.len(), 4);

        let say_hello = greeter
            .methods
            .iter()
            .find(|m| m.name == "SayHello")
            .expect("SayHello should exist");
        assert_eq!(say_hello.input_type, "mockforge.greeter.HelloRequest");
        assert_eq!(say_hello.output_type, "mockforge.greeter.HelloReply");
        assert!(!say_hello.client_streaming);
        assert!(!say_hello.server_streaming);

        let say_hello_stream = greeter
            .methods
            .iter()
            .find(|m| m.name == "SayHelloStream")
            .expect("SayHelloStream should exist");
        assert!(!say_hello_stream.client_streaming);
        assert!(say_hello_stream.server_streaming);
    }

    #[tokio::test]
    async fn test_parse_directory() {
        let proto_dir = fixture_proto_dir();

        let mut parser = ProtoParser::new();
        parser.parse_directory(&proto_dir).await.unwrap();

        let services = parser.services();
        assert_eq!(services.len(), 1);

        let greeter = services
            .get("mockforge.greeter.Greeter")
            .expect("greeter service should be loaded");
        assert_eq!(greeter.methods.len(), 4);

        for expected in ["SayHello", "SayHelloStream", "SayHelloClientStream", "Chat"] {
            assert!(
                greeter.methods.iter().any(|m| m.name == expected),
                "missing method {}",
                expected
            );
        }
    }
}