use std::collections::{HashMap, HashSet};
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;

use prost_reflect::DescriptorPool;
use tempfile::TempDir;
use tracing::{debug, error, info, warn};

14#[derive(Debug, Clone)]
16pub struct ProtoService {
17 pub name: String,
19 pub package: String,
21 pub short_name: String,
23 pub methods: Vec<ProtoMethod>,
25}
26
/// A single RPC method belonging to a `ProtoService`.
#[derive(Clone, Debug)]
pub struct ProtoMethod {
    /// Method name as declared in the service, e.g. `SayHello`.
    pub name: String,
    /// Fully-qualified request message type, e.g. `mockforge.greeter.HelloRequest`.
    pub input_type: String,
    /// Fully-qualified response message type, e.g. `mockforge.greeter.HelloReply`.
    pub output_type: String,
    /// `true` when the client sends a stream of request messages.
    pub client_streaming: bool,
    /// `true` when the server replies with a stream of response messages.
    pub server_streaming: bool,
}

42pub struct ProtoParser {
44 pool: DescriptorPool,
46 services: HashMap<String, ProtoService>,
48 include_paths: Vec<PathBuf>,
50 temp_dir: Option<TempDir>,
52}
53
54impl ProtoParser {
55 pub fn new() -> Self {
57 Self {
58 pool: DescriptorPool::new(),
59 services: HashMap::new(),
60 include_paths: vec![],
61 temp_dir: None,
62 }
63 }
64
65 pub fn with_include_paths(include_paths: Vec<PathBuf>) -> Self {
67 Self {
68 pool: DescriptorPool::new(),
69 services: HashMap::new(),
70 include_paths,
71 temp_dir: None,
72 }
73 }
74
75 pub async fn parse_directory(
77 &mut self,
78 proto_dir: &str,
79 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
80 info!("Parsing proto files from directory: {}", proto_dir);
81
82 let proto_path = Path::new(proto_dir);
83 if !proto_path.exists() {
84 return Err(format!("Proto directory does not exist: {}", proto_dir).into());
85 }
86
87 let proto_files = self.discover_proto_files(proto_path)?;
89 if proto_files.is_empty() {
90 warn!("No proto files found in directory: {}", proto_dir);
91 return Ok(());
92 }
93
94 info!("Found {} proto files: {:?}", proto_files.len(), proto_files);
95
96 for proto_file in proto_files {
98 if let Err(e) = self.parse_proto_file(&proto_file).await {
99 error!("Failed to parse proto file {}: {}", proto_file, e);
100 }
102 }
103
104 if self.pool.services().count() > 0 {
106 self.extract_services()?;
107 } else {
108 debug!("No services found in descriptor pool, keeping mock services");
109 }
110
111 info!("Successfully parsed {} services", self.services.len());
112 Ok(())
113 }
114
115 #[allow(clippy::only_used_in_recursion)]
117 fn discover_proto_files(
118 &self,
119 dir: &Path,
120 ) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
121 let mut proto_files = Vec::new();
122
123 if let Ok(entries) = fs::read_dir(dir) {
124 for entry in entries.flatten() {
125 let path = entry.path();
126
127 if path.is_dir() {
128 proto_files.extend(self.discover_proto_files(&path)?);
130 } else if path.extension().and_then(|s| s.to_str()) == Some("proto") {
131 proto_files.push(path.to_string_lossy().to_string());
133 }
134 }
135 }
136
137 Ok(proto_files)
138 }
139
140 async fn parse_proto_file(
142 &mut self,
143 proto_file: &str,
144 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
145 debug!("Parsing proto file: {}", proto_file);
146
147 if self.temp_dir.is_none() {
149 self.temp_dir = Some(TempDir::new()?);
150 }
151
152 let temp_dir = self.temp_dir.as_ref().ok_or_else(|| {
154 Box::<dyn std::error::Error + Send + Sync>::from("Temp directory not initialized")
155 })?;
156 let descriptor_path = temp_dir.path().join("descriptors.bin");
157
158 match self.compile_with_protoc(proto_file, &descriptor_path).await {
160 Ok(()) => {
161 let descriptor_bytes = fs::read(&descriptor_path)?;
163 match self.pool.decode_file_descriptor_set(&*descriptor_bytes) {
164 Ok(()) => {
165 info!("Successfully compiled and loaded proto file: {}", proto_file);
166 if self.pool.services().count() > 0 {
168 self.extract_services()?;
169 }
170 return Ok(());
171 }
172 Err(e) => {
173 warn!("Failed to decode descriptor set, falling back to mock: {}", e);
174 }
175 }
176 }
177 Err(e) => {
178 warn!(
181 "protoc not available or compilation failed (this is OK for basic usage, using fallback): {}",
182 e
183 );
184 }
185 }
186
187 if proto_file.contains("gretter.proto") || proto_file.contains("greeter.proto") {
189 debug!("Adding mock greeter service for {}", proto_file);
190 self.add_mock_greeter_service();
191 }
192
193 Ok(())
194 }
195
196 async fn compile_with_protoc(
198 &self,
199 proto_file: &str,
200 output_path: &Path,
201 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
202 debug!("Compiling proto file with protoc: {}", proto_file);
203
204 let mut cmd = Command::new("protoc");
206
207 for include_path in &self.include_paths {
209 cmd.arg("-I").arg(include_path);
210 }
211
212 if let Some(parent_dir) = Path::new(proto_file).parent() {
214 cmd.arg("-I").arg(parent_dir);
215 }
216
217 let well_known_paths = [
219 "/usr/local/include",
220 "/usr/include",
221 "/opt/homebrew/include",
222 ];
223
224 for path in &well_known_paths {
225 if Path::new(path).exists() {
226 cmd.arg("-I").arg(path);
227 }
228 }
229
230 cmd.arg("--descriptor_set_out")
232 .arg(output_path)
233 .arg("--include_imports")
234 .arg("--include_source_info")
235 .arg(proto_file);
236
237 debug!("Running protoc command: {:?}", cmd);
238
239 let output = cmd.output()?;
241
242 if !output.status.success() {
243 let stderr = String::from_utf8_lossy(&output.stderr);
244 return Err(format!("protoc failed: {}", stderr).into());
245 }
246
247 info!("Successfully compiled proto file with protoc: {}", proto_file);
248 Ok(())
249 }
250
251 fn add_mock_greeter_service(&mut self) {
253 let service = ProtoService {
254 name: "mockforge.greeter.Greeter".to_string(),
255 package: "mockforge.greeter".to_string(),
256 short_name: "Greeter".to_string(),
257 methods: vec![
258 ProtoMethod {
259 name: "SayHello".to_string(),
260 input_type: "mockforge.greeter.HelloRequest".to_string(),
261 output_type: "mockforge.greeter.HelloReply".to_string(),
262 client_streaming: false,
263 server_streaming: false,
264 },
265 ProtoMethod {
266 name: "SayHelloStream".to_string(),
267 input_type: "mockforge.greeter.HelloRequest".to_string(),
268 output_type: "mockforge.greeter.HelloReply".to_string(),
269 client_streaming: false,
270 server_streaming: true,
271 },
272 ProtoMethod {
273 name: "SayHelloClientStream".to_string(),
274 input_type: "mockforge.greeter.HelloRequest".to_string(),
275 output_type: "mockforge.greeter.HelloReply".to_string(),
276 client_streaming: true,
277 server_streaming: false,
278 },
279 ProtoMethod {
280 name: "Chat".to_string(),
281 input_type: "mockforge.greeter.HelloRequest".to_string(),
282 output_type: "mockforge.greeter.HelloReply".to_string(),
283 client_streaming: true,
284 server_streaming: true,
285 },
286 ],
287 };
288
289 self.services.insert(service.name.clone(), service);
290 }
291
292 fn extract_services(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
294 debug!("Extracting services from descriptor pool");
295
296 let mock_services: HashMap<String, ProtoService> = self
298 .services
299 .drain()
300 .filter(|(name, _)| name.contains("mockforge.greeter"))
301 .collect();
302
303 self.services = mock_services;
304
305 for service_descriptor in self.pool.services() {
307 let service_name = service_descriptor.full_name().to_string();
308 let package_name = service_descriptor.parent_file().package_name().to_string();
309 let short_name = service_descriptor.name().to_string();
310
311 debug!("Found service: {} in package: {}", service_name, package_name);
312
313 let mut methods = Vec::new();
315 for method_descriptor in service_descriptor.methods() {
316 let method = ProtoMethod {
317 name: method_descriptor.name().to_string(),
318 input_type: method_descriptor.input().full_name().to_string(),
319 output_type: method_descriptor.output().full_name().to_string(),
320 client_streaming: method_descriptor.is_client_streaming(),
321 server_streaming: method_descriptor.is_server_streaming(),
322 };
323
324 debug!(
325 " Found method: {} ({} -> {})",
326 method.name, method.input_type, method.output_type
327 );
328
329 methods.push(method);
330 }
331
332 let service = ProtoService {
333 name: service_name.clone(),
334 package: package_name,
335 short_name,
336 methods,
337 };
338
339 self.services.insert(service_name, service);
340 }
341
342 info!("Extracted {} services from descriptor pool", self.services.len());
343 Ok(())
344 }
345
346 pub fn services(&self) -> &HashMap<String, ProtoService> {
348 &self.services
349 }
350
351 pub fn get_service(&self, name: &str) -> Option<&ProtoService> {
353 self.services.get(name)
354 }
355
356 pub fn pool(&self) -> &DescriptorPool {
358 &self.pool
359 }
360
361 pub fn into_pool(self) -> DescriptorPool {
363 self.pool
364 }
365}
366
367impl Default for ProtoParser {
368 fn default() -> Self {
369 Self::new()
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Directory containing the sample `.proto` fixtures shipped with the crate.
    const PROTO_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/proto");

    /// Fully-qualified name of the greeter service the fixtures should yield.
    const GREETER: &str = "mockforge.greeter.Greeter";

    /// A single fixture file yields exactly the greeter service with four
    /// methods, whether it came from `protoc` or from the mock fallback.
    #[tokio::test]
    async fn test_parse_proto_file() {
        let proto_path = format!("{}/gretter.proto", PROTO_DIR);

        let mut parser = ProtoParser::new();
        parser
            .parse_proto_file(&proto_path)
            .await
            .expect("parsing the greeter fixture should succeed");

        let services = parser.services();
        assert_eq!(services.len(), 1);
        assert!(services.contains_key(GREETER));

        let service = &services[GREETER];
        assert_eq!(service.name, GREETER);
        assert_eq!(service.methods.len(), 4);

        let say_hello = service
            .methods
            .iter()
            .find(|method| method.name == "SayHello")
            .expect("SayHello should be present");
        assert_eq!(say_hello.input_type, "mockforge.greeter.HelloRequest");
        assert_eq!(say_hello.output_type, "mockforge.greeter.HelloReply");
        assert!(!say_hello.client_streaming);
        assert!(!say_hello.server_streaming);

        let say_hello_stream = service
            .methods
            .iter()
            .find(|method| method.name == "SayHelloStream")
            .expect("SayHelloStream should be present");
        assert!(!say_hello_stream.client_streaming);
        assert!(say_hello_stream.server_streaming);
    }

    /// Scanning the fixture directory yields the greeter service exactly once,
    /// with all four expected methods.
    #[tokio::test]
    async fn test_parse_directory() {
        let mut parser = ProtoParser::new();
        parser
            .parse_directory(PROTO_DIR)
            .await
            .expect("parsing the fixture directory should succeed");

        let services = parser.services();
        assert_eq!(services.len(), 1);
        assert!(services.contains_key(GREETER));

        let service = &services[GREETER];
        assert_eq!(service.methods.len(), 4);

        for expected in ["SayHello", "SayHelloStream", "SayHelloClientStream", "Chat"] {
            assert!(
                service.methods.iter().any(|method| method.name == expected),
                "missing method {}",
                expected
            );
        }
    }
}