1use prost_reflect::DescriptorPool;
7use std::collections::HashMap;
8use std::fs;
9use std::path::{Path, PathBuf};
10use std::process::Command;
11use tempfile::TempDir;
12use tracing::{debug, error, info, warn};
13
/// A gRPC service discovered in a proto file (or the built-in mock greeter).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProtoService {
    /// Fully-qualified service name, e.g. `mockforge.greeter.Greeter`.
    pub name: String,
    /// Proto package the service is declared in, e.g. `mockforge.greeter`.
    pub package: String,
    /// Unqualified service name, e.g. `Greeter`.
    pub short_name: String,
    /// RPC methods declared on the service.
    pub methods: Vec<ProtoMethod>,
}

/// A single RPC method of a [`ProtoService`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProtoMethod {
    /// Method name as declared in the proto file, e.g. `SayHello`.
    pub name: String,
    /// Fully-qualified request message type, e.g. `mockforge.greeter.HelloRequest`.
    pub input_type: String,
    /// Fully-qualified response message type, e.g. `mockforge.greeter.HelloReply`.
    pub output_type: String,
    /// `true` when the client sends a stream of requests.
    pub client_streaming: bool,
    /// `true` when the server replies with a stream of responses.
    pub server_streaming: bool,
}
41
42pub struct ProtoParser {
44 pool: DescriptorPool,
46 services: HashMap<String, ProtoService>,
48 include_paths: Vec<PathBuf>,
50 temp_dir: Option<TempDir>,
52}
53
54impl ProtoParser {
55 pub fn new() -> Self {
57 Self {
58 pool: DescriptorPool::new(),
59 services: HashMap::new(),
60 include_paths: vec![],
61 temp_dir: None,
62 }
63 }
64
65 pub fn with_include_paths(include_paths: Vec<PathBuf>) -> Self {
67 Self {
68 pool: DescriptorPool::new(),
69 services: HashMap::new(),
70 include_paths,
71 temp_dir: None,
72 }
73 }
74
75 pub async fn parse_directory(
81 &mut self,
82 proto_dir: &str,
83 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
84 info!("Parsing proto files from directory: {}", proto_dir);
85
86 let proto_path = Path::new(proto_dir);
87 if !proto_path.exists() {
88 warn!(
89 "Proto directory does not exist: {}. gRPC server will start with no services. \
90 This is normal when using only OpenAPI/HTTP.",
91 proto_dir
92 );
93 return Ok(());
94 }
95
96 let proto_files = self.discover_proto_files(proto_path)?;
98 if proto_files.is_empty() {
99 warn!("No proto files found in directory: {}", proto_dir);
100 return Ok(());
101 }
102
103 info!("Found {} proto files: {:?}", proto_files.len(), proto_files);
104
105 if proto_files.len() > 1 {
107 if let Err(e) = self.compile_protos_batch(&proto_files).await {
108 warn!("Batch compilation failed, falling back to individual compilation: {}", e);
109 for proto_file in proto_files {
111 if let Err(e) = self.parse_proto_file(&proto_file).await {
112 error!("Failed to parse proto file {}: {}", proto_file, e);
113 }
115 }
116 }
117 } else if !proto_files.is_empty() {
118 if let Err(e) = self.parse_proto_file(&proto_files[0]).await {
120 error!("Failed to parse proto file {}: {}", proto_files[0], e);
121 }
122 }
123
124 if self.pool.services().count() > 0 {
126 self.extract_services()?;
127 } else {
128 debug!("No services found in descriptor pool, keeping mock services");
129 }
130
131 info!("Successfully parsed {} services", self.services.len());
132 Ok(())
133 }
134
135 #[allow(clippy::only_used_in_recursion)]
137 fn discover_proto_files(
138 &self,
139 dir: &Path,
140 ) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
141 let mut proto_files = Vec::new();
142
143 if let Ok(entries) = fs::read_dir(dir) {
144 for entry in entries.flatten() {
145 let path = entry.path();
146
147 if path.is_dir() {
148 proto_files.extend(self.discover_proto_files(&path)?);
150 } else if path.extension().and_then(|s| s.to_str()) == Some("proto") {
151 proto_files.push(path.to_string_lossy().to_string());
153 }
154 }
155 }
156
157 Ok(proto_files)
158 }
159
160 async fn parse_proto_file(
162 &mut self,
163 proto_file: &str,
164 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
165 debug!("Parsing proto file: {}", proto_file);
166
167 if self.temp_dir.is_none() {
169 self.temp_dir = Some(TempDir::new()?);
170 }
171
172 let temp_dir = self.temp_dir.as_ref().ok_or_else(|| {
174 Box::<dyn std::error::Error + Send + Sync>::from("Temp directory not initialized")
175 })?;
176 let descriptor_path = temp_dir.path().join("descriptors.bin");
177
178 match self.compile_with_protoc(proto_file, &descriptor_path).await {
180 Ok(()) => {
181 let descriptor_bytes = fs::read(&descriptor_path)?;
183 match self.pool.decode_file_descriptor_set(&*descriptor_bytes) {
184 Ok(()) => {
185 info!("Successfully compiled and loaded proto file: {}", proto_file);
186 if self.pool.services().count() > 0 {
188 self.extract_services()?;
189 }
190 return Ok(());
191 }
192 Err(e) => {
193 warn!("Failed to decode descriptor set, falling back to mock: {}", e);
194 }
195 }
196 }
197 Err(e) => {
198 warn!(
201 "protoc not available or compilation failed (this is OK for basic usage, using fallback): {}",
202 e
203 );
204 }
205 }
206
207 if proto_file.contains("gretter.proto") || proto_file.contains("greeter.proto") {
209 debug!("Adding mock greeter service for {}", proto_file);
210 self.add_mock_greeter_service();
211 }
212
213 Ok(())
214 }
215
216 async fn compile_protos_batch(
219 &mut self,
220 proto_files: &[String],
221 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
222 if proto_files.is_empty() {
223 return Ok(());
224 }
225
226 info!("Batch compiling {} proto files", proto_files.len());
227
228 if self.temp_dir.is_none() {
230 self.temp_dir = Some(TempDir::new()?);
231 }
232
233 let temp_dir = self.temp_dir.as_ref().ok_or_else(|| {
234 Box::<dyn std::error::Error + Send + Sync>::from("Temp directory not initialized")
235 })?;
236 let descriptor_path = temp_dir.path().join("descriptors_batch.bin");
237
238 let mut cmd = Command::new("protoc");
240
241 let mut parent_dirs = std::collections::HashSet::new();
243 for proto_file in proto_files {
244 if let Some(parent_dir) = Path::new(proto_file).parent() {
245 parent_dirs.insert(parent_dir.to_path_buf());
246 }
247 }
248
249 for include_path in &self.include_paths {
251 cmd.arg("-I").arg(include_path);
252 }
253
254 for parent_dir in &parent_dirs {
256 cmd.arg("-I").arg(parent_dir);
257 }
258
259 let well_known_paths = [
261 "/usr/local/include",
262 "/usr/include",
263 "/opt/homebrew/include",
264 ];
265
266 for path in &well_known_paths {
267 if Path::new(path).exists() {
268 cmd.arg("-I").arg(path);
269 }
270 }
271
272 cmd.arg("--descriptor_set_out")
274 .arg(&descriptor_path)
275 .arg("--include_imports")
276 .arg("--include_source_info");
277
278 for proto_file in proto_files {
280 cmd.arg(proto_file);
281 }
282
283 debug!("Running batch protoc command for {} files", proto_files.len());
284
285 let output = cmd.output()?;
287
288 if !output.status.success() {
289 let stderr = String::from_utf8_lossy(&output.stderr);
290 return Err(format!("Batch protoc compilation failed: {}", stderr).into());
291 }
292
293 let descriptor_bytes = fs::read(&descriptor_path)?;
295 match self.pool.decode_file_descriptor_set(&*descriptor_bytes) {
296 Ok(()) => {
297 info!("Successfully batch compiled and loaded {} proto files", proto_files.len());
298 if self.pool.services().count() > 0 {
300 self.extract_services()?;
301 }
302 Ok(())
303 }
304 Err(e) => Err(format!("Failed to decode batch descriptor set: {}", e).into()),
305 }
306 }
307
308 async fn compile_with_protoc(
310 &self,
311 proto_file: &str,
312 output_path: &Path,
313 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
314 debug!("Compiling proto file with protoc: {}", proto_file);
315
316 let mut cmd = Command::new("protoc");
318
319 for include_path in &self.include_paths {
321 cmd.arg("-I").arg(include_path);
322 }
323
324 if let Some(parent_dir) = Path::new(proto_file).parent() {
326 cmd.arg("-I").arg(parent_dir);
327 }
328
329 let well_known_paths = [
331 "/usr/local/include",
332 "/usr/include",
333 "/opt/homebrew/include",
334 ];
335
336 for path in &well_known_paths {
337 if Path::new(path).exists() {
338 cmd.arg("-I").arg(path);
339 }
340 }
341
342 cmd.arg("--descriptor_set_out")
344 .arg(output_path)
345 .arg("--include_imports")
346 .arg("--include_source_info")
347 .arg(proto_file);
348
349 debug!("Running protoc command: {:?}", cmd);
350
351 let output = cmd.output()?;
353
354 if !output.status.success() {
355 let stderr = String::from_utf8_lossy(&output.stderr);
356 return Err(format!("protoc failed: {}", stderr).into());
357 }
358
359 info!("Successfully compiled proto file with protoc: {}", proto_file);
360 Ok(())
361 }
362
363 fn add_mock_greeter_service(&mut self) {
365 let service = ProtoService {
366 name: "mockforge.greeter.Greeter".to_string(),
367 package: "mockforge.greeter".to_string(),
368 short_name: "Greeter".to_string(),
369 methods: vec![
370 ProtoMethod {
371 name: "SayHello".to_string(),
372 input_type: "mockforge.greeter.HelloRequest".to_string(),
373 output_type: "mockforge.greeter.HelloReply".to_string(),
374 client_streaming: false,
375 server_streaming: false,
376 },
377 ProtoMethod {
378 name: "SayHelloStream".to_string(),
379 input_type: "mockforge.greeter.HelloRequest".to_string(),
380 output_type: "mockforge.greeter.HelloReply".to_string(),
381 client_streaming: false,
382 server_streaming: true,
383 },
384 ProtoMethod {
385 name: "SayHelloClientStream".to_string(),
386 input_type: "mockforge.greeter.HelloRequest".to_string(),
387 output_type: "mockforge.greeter.HelloReply".to_string(),
388 client_streaming: true,
389 server_streaming: false,
390 },
391 ProtoMethod {
392 name: "Chat".to_string(),
393 input_type: "mockforge.greeter.HelloRequest".to_string(),
394 output_type: "mockforge.greeter.HelloReply".to_string(),
395 client_streaming: true,
396 server_streaming: true,
397 },
398 ],
399 };
400
401 self.services.insert(service.name.clone(), service);
402 }
403
404 fn extract_services(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
406 debug!("Extracting services from descriptor pool");
407
408 let mock_services: HashMap<String, ProtoService> = self
410 .services
411 .drain()
412 .filter(|(name, _)| name.contains("mockforge.greeter"))
413 .collect();
414
415 self.services = mock_services;
416
417 for service_descriptor in self.pool.services() {
419 let service_name = service_descriptor.full_name().to_string();
420 let package_name = service_descriptor.parent_file().package_name().to_string();
421 let short_name = service_descriptor.name().to_string();
422
423 debug!("Found service: {} in package: {}", service_name, package_name);
424
425 let mut methods = Vec::new();
427 for method_descriptor in service_descriptor.methods() {
428 let method = ProtoMethod {
429 name: method_descriptor.name().to_string(),
430 input_type: method_descriptor.input().full_name().to_string(),
431 output_type: method_descriptor.output().full_name().to_string(),
432 client_streaming: method_descriptor.is_client_streaming(),
433 server_streaming: method_descriptor.is_server_streaming(),
434 };
435
436 debug!(
437 " Found method: {} ({} -> {})",
438 method.name, method.input_type, method.output_type
439 );
440
441 methods.push(method);
442 }
443
444 let service = ProtoService {
445 name: service_name.clone(),
446 package: package_name,
447 short_name,
448 methods,
449 };
450
451 self.services.insert(service_name, service);
452 }
453
454 info!("Extracted {} services from descriptor pool", self.services.len());
455 Ok(())
456 }
457
458 pub fn services(&self) -> &HashMap<String, ProtoService> {
460 &self.services
461 }
462
463 pub fn get_service(&self, name: &str) -> Option<&ProtoService> {
465 self.services.get(name)
466 }
467
468 pub fn pool(&self) -> &DescriptorPool {
470 &self.pool
471 }
472
473 pub fn into_pool(self) -> DescriptorPool {
475 self.pool
476 }
477}
478
479impl Default for ProtoParser {
480 fn default() -> Self {
481 Self::new()
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute path of the `proto/` fixture directory shipped with this crate.
    fn proto_dir() -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR")).join("proto")
    }

    #[tokio::test]
    async fn test_parse_proto_file() {
        let proto_path = proto_dir().join("gretter.proto");
        let proto_path = proto_path.to_str().expect("fixture path is valid UTF-8");

        let mut parser = ProtoParser::new();
        parser.parse_proto_file(proto_path).await.unwrap();

        let services = parser.services();
        assert_eq!(services.len(), 1);

        let service_name = "mockforge.greeter.Greeter";
        assert!(services.contains_key(service_name));

        let service = &services[service_name];
        assert_eq!(service.name, service_name);
        assert_eq!(service.methods.len(), 4);

        let say_hello = service.methods.iter().find(|m| m.name == "SayHello").unwrap();
        assert_eq!(say_hello.input_type, "mockforge.greeter.HelloRequest");
        assert_eq!(say_hello.output_type, "mockforge.greeter.HelloReply");
        assert!(!say_hello.client_streaming);
        assert!(!say_hello.server_streaming);

        let say_hello_stream = service.methods.iter().find(|m| m.name == "SayHelloStream").unwrap();
        assert!(!say_hello_stream.client_streaming);
        assert!(say_hello_stream.server_streaming);
    }

    #[tokio::test]
    async fn test_parse_directory() {
        let dir = proto_dir();
        let dir = dir.to_str().expect("fixture path is valid UTF-8");

        let mut parser = ProtoParser::new();
        parser.parse_directory(dir).await.unwrap();

        let services = parser.services();
        assert_eq!(services.len(), 1);

        let service_name = "mockforge.greeter.Greeter";
        assert!(services.contains_key(service_name));

        let service = &services[service_name];
        assert_eq!(service.methods.len(), 4);

        let method_names: Vec<&str> = service.methods.iter().map(|m| m.name.as_str()).collect();
        assert!(method_names.contains(&"SayHello"));
        assert!(method_names.contains(&"SayHelloStream"));
        assert!(method_names.contains(&"SayHelloClientStream"));
        assert!(method_names.contains(&"Chat"));
    }
}