use prost_reflect::DescriptorPool;
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::process::Command;
use tempfile::TempDir;
use tracing::{debug, error, info, warn};
/// A gRPC service definition collected by `ProtoParser`.
#[derive(Debug, Clone)]
pub struct ProtoService {
/// Fully-qualified service name, e.g. `mockforge.greeter.Greeter`.
pub name: String,
/// Name of the proto package the service is declared in, e.g. `mockforge.greeter`.
pub package: String,
/// Service name without the package prefix, e.g. `Greeter`.
pub short_name: String,
/// RPC methods declared by the service.
pub methods: Vec<ProtoMethod>,
}
/// A single RPC method belonging to a `ProtoService`.
#[derive(Debug, Clone)]
pub struct ProtoMethod {
/// Method name as declared in the service, e.g. `SayHello`.
pub name: String,
/// Fully-qualified name of the request message type.
pub input_type: String,
/// Fully-qualified name of the response message type.
pub output_type: String,
/// `true` when the client sends a stream of requests.
pub client_streaming: bool,
/// `true` when the server replies with a stream of responses.
pub server_streaming: bool,
}
/// Discovers `.proto` files, compiles them with `protoc`, and exposes the
/// services found in the resulting descriptors.
pub struct ProtoParser {
/// Descriptor pool that compiled descriptor sets are decoded into.
pool: DescriptorPool,
/// Known services, keyed by fully-qualified service name.
services: HashMap<String, ProtoService>,
/// Extra directories passed to `protoc` as `-I` include paths.
include_paths: Vec<PathBuf>,
/// Scratch directory for descriptor-set output files; created on first compile.
temp_dir: Option<TempDir>,
}
impl ProtoParser {
pub fn new() -> Self {
Self {
pool: DescriptorPool::new(),
services: HashMap::new(),
include_paths: vec![],
temp_dir: None,
}
}
/// Creates a parser that passes `include_paths` to `protoc` as additional
/// `-I` directories when compiling proto files.
///
/// The descriptor pool and service map start out empty; the scratch
/// directory is not created until the first compilation.
pub fn with_include_paths(include_paths: Vec<PathBuf>) -> Self {
    let pool = DescriptorPool::new();
    let services = HashMap::new();
    Self {
        include_paths,
        temp_dir: None,
        services,
        pool,
    }
}
pub async fn parse_directory(
&mut self,
proto_dir: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
info!("Parsing proto files from directory: {}", proto_dir);
let proto_path = Path::new(proto_dir);
if !proto_path.exists() {
info!(
"No proto directory found at {}. gRPC server will start with built-in services only.",
proto_dir
);
return Ok(());
}
let proto_files = self.discover_proto_files(proto_path)?;
if proto_files.is_empty() {
warn!("No proto files found in directory: {}", proto_dir);
return Ok(());
}
info!("Found {} proto files: {:?}", proto_files.len(), proto_files);
if proto_files.len() > 1 {
if let Err(e) = self.compile_protos_batch(&proto_files).await {
warn!("Batch compilation failed, falling back to individual compilation: {}", e);
for proto_file in proto_files {
if let Err(e) = self.parse_proto_file(&proto_file).await {
error!("Failed to parse proto file {}: {}", proto_file, e);
}
}
}
} else if !proto_files.is_empty() {
if let Err(e) = self.parse_proto_file(&proto_files[0]).await {
error!("Failed to parse proto file {}: {}", proto_files[0], e);
}
}
if self.pool.services().count() > 0 {
self.extract_services()?;
} else {
debug!("No services found in descriptor pool, keeping mock services");
}
info!("Successfully parsed {} services", self.services.len());
Ok(())
}
#[allow(clippy::only_used_in_recursion)]
fn discover_proto_files(
&self,
dir: &Path,
) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
let mut proto_files = Vec::new();
if let Ok(entries) = fs::read_dir(dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.is_dir() {
proto_files.extend(self.discover_proto_files(&path)?);
} else if path.extension().and_then(|s| s.to_str()) == Some("proto") {
proto_files.push(path.to_string_lossy().to_string());
}
}
}
Ok(proto_files)
}
/// Compiles a single proto file with `protoc` and loads its descriptors
/// into the pool.
///
/// If `protoc` fails or the descriptor set cannot be decoded, the failure is
/// logged and, for the bundled greeter proto only, a hard-coded mock service
/// is registered instead. The function still returns `Ok(())` in that case.
///
/// # Errors
///
/// Fails when the scratch directory cannot be created, when the compiled
/// descriptor file cannot be read, or when service extraction fails.
async fn parse_proto_file(
    &mut self,
    proto_file: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    debug!("Parsing proto file: {}", proto_file);

    // The scratch directory is created once and reused for later compilations.
    if self.temp_dir.is_none() {
        self.temp_dir = Some(TempDir::new()?);
    }
    let descriptor_path = self
        .temp_dir
        .as_ref()
        .map(|dir| dir.path().join("descriptors.bin"))
        .ok_or_else(|| {
            Box::<dyn std::error::Error + Send + Sync>::from("Temp directory not initialized")
        })?;

    let loaded = match self.compile_with_protoc(proto_file, &descriptor_path).await {
        Err(e) => {
            warn!(
                "protoc not available or compilation failed (this is OK for basic usage, using fallback): {}",
                e
            );
            false
        }
        Ok(()) => {
            let descriptor_bytes = fs::read(&descriptor_path)?;
            match self.pool.decode_file_descriptor_set(descriptor_bytes.as_slice()) {
                Ok(()) => true,
                Err(e) => {
                    warn!("Failed to decode descriptor set, falling back to mock: {}", e);
                    false
                }
            }
        }
    };

    if loaded {
        info!("Successfully compiled and loaded proto file: {}", proto_file);
        if self.pool.services().count() > 0 {
            self.extract_services()?;
        }
        return Ok(());
    }

    // Fallback: only the greeter proto (either spelling) has a built-in mock.
    let greeter_names = ["gretter.proto", "greeter.proto"];
    if greeter_names.iter().any(|name| proto_file.contains(*name)) {
        debug!("Adding mock greeter service for {}", proto_file);
        self.add_mock_greeter_service();
    }
    Ok(())
}
async fn compile_protos_batch(
&mut self,
proto_files: &[String],
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
if proto_files.is_empty() {
return Ok(());
}
info!("Batch compiling {} proto files", proto_files.len());
if self.temp_dir.is_none() {
self.temp_dir = Some(TempDir::new()?);
}
let temp_dir = self.temp_dir.as_ref().ok_or_else(|| {
Box::<dyn std::error::Error + Send + Sync>::from("Temp directory not initialized")
})?;
let descriptor_path = temp_dir.path().join("descriptors_batch.bin");
let mut cmd = Command::new("protoc");
let mut parent_dirs = std::collections::HashSet::new();
for proto_file in proto_files {
if let Some(parent_dir) = Path::new(proto_file).parent() {
parent_dirs.insert(parent_dir.to_path_buf());
}
}
for include_path in &self.include_paths {
cmd.arg("-I").arg(include_path);
}
for parent_dir in &parent_dirs {
cmd.arg("-I").arg(parent_dir);
}
let well_known_paths = [
"/usr/local/include",
"/usr/include",
"/opt/homebrew/include",
];
for path in &well_known_paths {
if Path::new(path).exists() {
cmd.arg("-I").arg(path);
}
}
cmd.arg("--descriptor_set_out")
.arg(&descriptor_path)
.arg("--include_imports")
.arg("--include_source_info");
for proto_file in proto_files {
cmd.arg(proto_file);
}
debug!("Running batch protoc command for {} files", proto_files.len());
let output = cmd.output()?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(format!("Batch protoc compilation failed: {}", stderr).into());
}
let descriptor_bytes = fs::read(&descriptor_path)?;
match self.pool.decode_file_descriptor_set(&*descriptor_bytes) {
Ok(()) => {
info!("Successfully batch compiled and loaded {} proto files", proto_files.len());
if self.pool.services().count() > 0 {
self.extract_services()?;
}
Ok(())
}
Err(e) => Err(format!("Failed to decode batch descriptor set: {}", e).into()),
}
}
async fn compile_with_protoc(
&self,
proto_file: &str,
output_path: &Path,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
debug!("Compiling proto file with protoc: {}", proto_file);
let mut cmd = Command::new("protoc");
for include_path in &self.include_paths {
cmd.arg("-I").arg(include_path);
}
if let Some(parent_dir) = Path::new(proto_file).parent() {
cmd.arg("-I").arg(parent_dir);
}
let well_known_paths = [
"/usr/local/include",
"/usr/include",
"/opt/homebrew/include",
];
for path in &well_known_paths {
if Path::new(path).exists() {
cmd.arg("-I").arg(path);
}
}
cmd.arg("--descriptor_set_out")
.arg(output_path)
.arg("--include_imports")
.arg("--include_source_info")
.arg(proto_file);
debug!("Running protoc command: {:?}", cmd);
let output = cmd.output()?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(format!("protoc failed: {}", stderr).into());
}
info!("Successfully compiled proto file with protoc: {}", proto_file);
Ok(())
}
/// Registers a hard-coded `mockforge.greeter.Greeter` service with four
/// methods covering every streaming combination (unary, server-streaming,
/// client-streaming, bidirectional).
///
/// An existing entry with the same service name is replaced.
fn add_mock_greeter_service(&mut self) {
    const SERVICE_NAME: &str = "mockforge.greeter.Greeter";
    const REQUEST_TYPE: &str = "mockforge.greeter.HelloRequest";
    const REPLY_TYPE: &str = "mockforge.greeter.HelloReply";
    // (method name, client_streaming, server_streaming)
    const RPCS: [(&str, bool, bool); 4] = [
        ("SayHello", false, false),
        ("SayHelloStream", false, true),
        ("SayHelloClientStream", true, false),
        ("Chat", true, true),
    ];

    let methods = RPCS
        .iter()
        .map(|&(rpc_name, client_streaming, server_streaming)| ProtoMethod {
            name: rpc_name.to_string(),
            input_type: REQUEST_TYPE.to_string(),
            output_type: REPLY_TYPE.to_string(),
            client_streaming,
            server_streaming,
        })
        .collect();

    let greeter = ProtoService {
        name: SERVICE_NAME.to_string(),
        package: "mockforge.greeter".to_string(),
        short_name: "Greeter".to_string(),
        methods,
    };
    self.services.insert(SERVICE_NAME.to_string(), greeter);
}
fn extract_services(&mut self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
debug!("Extracting services from descriptor pool");
let mock_services: HashMap<String, ProtoService> = self
.services
.drain()
.filter(|(name, _)| name.contains("mockforge.greeter"))
.collect();
self.services = mock_services;
for service_descriptor in self.pool.services() {
let service_name = service_descriptor.full_name().to_string();
let package_name = service_descriptor.parent_file().package_name().to_string();
let short_name = service_descriptor.name().to_string();
debug!("Found service: {} in package: {}", service_name, package_name);
let mut methods = Vec::new();
for method_descriptor in service_descriptor.methods() {
let method = ProtoMethod {
name: method_descriptor.name().to_string(),
input_type: method_descriptor.input().full_name().to_string(),
output_type: method_descriptor.output().full_name().to_string(),
client_streaming: method_descriptor.is_client_streaming(),
server_streaming: method_descriptor.is_server_streaming(),
};
debug!(
" Found method: {} ({} -> {})",
method.name, method.input_type, method.output_type
);
methods.push(method);
}
let service = ProtoService {
name: service_name.clone(),
package: package_name,
short_name,
methods,
};
self.services.insert(service_name, service);
}
info!("Extracted {} services from descriptor pool", self.services.len());
Ok(())
}
/// Returns all known services, keyed by fully-qualified service name.
pub fn services(&self) -> &HashMap<String, ProtoService> {
&self.services
}
/// Looks up a service by its fully-qualified name; returns `None` when no
/// service is registered under `name`.
pub fn get_service(&self, name: &str) -> Option<&ProtoService> {
self.services.get(name)
}
/// Borrows the descriptor pool holding everything decoded so far.
pub fn pool(&self) -> &DescriptorPool {
&self.pool
}
/// Consumes the parser and returns its descriptor pool.
pub fn into_pool(self) -> DescriptorPool {
self.pool
}
}
/// `ProtoParser::default()` is equivalent to `ProtoParser::new()`.
impl Default for ProtoParser {
fn default() -> Self {
Self::new()
}
}
// Unit tests for the proto data types and for `ProtoParser`.
#[cfg(test)]
mod tests {
use super::*;
// --- ProtoService: construction, field access and cloning ---
#[test]
fn test_proto_service_creation() {
let service = ProtoService {
name: "mypackage.MyService".to_string(),
package: "mypackage".to_string(),
short_name: "MyService".to_string(),
methods: vec![],
};
assert_eq!(service.name, "mypackage.MyService");
assert_eq!(service.package, "mypackage");
assert_eq!(service.short_name, "MyService");
assert!(service.methods.is_empty());
}
#[test]
fn test_proto_service_with_methods() {
let method = ProtoMethod {
name: "GetData".to_string(),
input_type: "mypackage.Request".to_string(),
output_type: "mypackage.Response".to_string(),
client_streaming: false,
server_streaming: false,
};
let service = ProtoService {
name: "mypackage.DataService".to_string(),
package: "mypackage".to_string(),
short_name: "DataService".to_string(),
methods: vec![method],
};
assert_eq!(service.methods.len(), 1);
assert_eq!(service.methods[0].name, "GetData");
}
#[test]
fn test_proto_service_clone() {
let service = ProtoService {
name: "test.Service".to_string(),
package: "test".to_string(),
short_name: "Service".to_string(),
methods: vec![ProtoMethod {
name: "Method".to_string(),
input_type: "Request".to_string(),
output_type: "Response".to_string(),
client_streaming: false,
server_streaming: false,
}],
};
let cloned = service.clone();
assert_eq!(cloned.name, service.name);
assert_eq!(cloned.methods.len(), service.methods.len());
}
// --- ProtoMethod: one test per streaming combination, plus cloning ---
#[test]
fn test_proto_method_unary() {
let method = ProtoMethod {
name: "UnaryMethod".to_string(),
input_type: "Request".to_string(),
output_type: "Response".to_string(),
client_streaming: false,
server_streaming: false,
};
assert_eq!(method.name, "UnaryMethod");
assert!(!method.client_streaming);
assert!(!method.server_streaming);
}
#[test]
fn test_proto_method_server_streaming() {
let method = ProtoMethod {
name: "StreamMethod".to_string(),
input_type: "Request".to_string(),
output_type: "Response".to_string(),
client_streaming: false,
server_streaming: true,
};
assert!(!method.client_streaming);
assert!(method.server_streaming);
}
#[test]
fn test_proto_method_client_streaming() {
let method = ProtoMethod {
name: "ClientStreamMethod".to_string(),
input_type: "Request".to_string(),
output_type: "Response".to_string(),
client_streaming: true,
server_streaming: false,
};
assert!(method.client_streaming);
assert!(!method.server_streaming);
}
#[test]
fn test_proto_method_bidi_streaming() {
let method = ProtoMethod {
name: "BidiStreamMethod".to_string(),
input_type: "Request".to_string(),
output_type: "Response".to_string(),
client_streaming: true,
server_streaming: true,
};
assert!(method.client_streaming);
assert!(method.server_streaming);
}
#[test]
fn test_proto_method_clone() {
let method = ProtoMethod {
name: "TestMethod".to_string(),
input_type: "Input".to_string(),
output_type: "Output".to_string(),
client_streaming: true,
server_streaming: true,
};
let cloned = method.clone();
assert_eq!(cloned.name, method.name);
assert_eq!(cloned.input_type, method.input_type);
assert_eq!(cloned.output_type, method.output_type);
assert_eq!(cloned.client_streaming, method.client_streaming);
assert_eq!(cloned.server_streaming, method.server_streaming);
}
// --- ProtoParser: constructors and accessors start out empty ---
#[test]
fn test_proto_parser_new() {
let parser = ProtoParser::new();
assert!(parser.services().is_empty());
}
#[test]
fn test_proto_parser_default() {
let parser = ProtoParser::default();
assert!(parser.services().is_empty());
}
#[test]
fn test_proto_parser_with_include_paths() {
let paths = vec![PathBuf::from("/usr/include"), PathBuf::from("/opt/proto")];
let parser = ProtoParser::with_include_paths(paths);
assert!(parser.services().is_empty());
}
#[test]
fn test_proto_parser_get_service_nonexistent() {
let parser = ProtoParser::new();
assert!(parser.get_service("nonexistent").is_none());
}
// Smoke tests: the pool accessors only need to be callable.
#[test]
fn test_proto_parser_pool() {
let parser = ProtoParser::new();
let _pool = parser.pool();
}
#[test]
fn test_proto_parser_into_pool() {
let parser = ProtoParser::new();
let _pool = parser.into_pool();
}
// The mock greeter registers exactly one service with four methods.
#[test]
fn test_proto_parser_add_mock_greeter_service() {
let mut parser = ProtoParser::new();
parser.add_mock_greeter_service();
let services = parser.services();
assert_eq!(services.len(), 1);
assert!(services.contains_key("mockforge.greeter.Greeter"));
let service = &services["mockforge.greeter.Greeter"];
assert_eq!(service.short_name, "Greeter");
assert_eq!(service.package, "mockforge.greeter");
assert_eq!(service.methods.len(), 4);
}
// --- File discovery, exercised against temporary directories ---
#[test]
fn test_proto_parser_discover_empty_dir() {
let temp_dir = TempDir::new().unwrap();
let parser = ProtoParser::new();
let result = parser.discover_proto_files(temp_dir.path()).unwrap();
assert!(result.is_empty());
}
#[test]
fn test_proto_parser_discover_with_proto_files() {
let temp_dir = TempDir::new().unwrap();
let proto_path = temp_dir.path().join("test.proto");
fs::write(&proto_path, "syntax = \"proto3\";").unwrap();
let parser = ProtoParser::new();
let result = parser.discover_proto_files(temp_dir.path()).unwrap();
assert_eq!(result.len(), 1);
assert!(result[0].ends_with("test.proto"));
}
// Files in nested directories must be found as well.
#[test]
fn test_proto_parser_discover_recursive() {
let temp_dir = TempDir::new().unwrap();
let sub_dir = temp_dir.path().join("subdir");
fs::create_dir(&sub_dir).unwrap();
fs::write(temp_dir.path().join("root.proto"), "").unwrap();
fs::write(sub_dir.join("nested.proto"), "").unwrap();
let parser = ProtoParser::new();
let result = parser.discover_proto_files(temp_dir.path()).unwrap();
assert_eq!(result.len(), 2);
}
// Only the `.proto` extension is picked up.
#[test]
fn test_proto_parser_discover_ignores_non_proto() {
let temp_dir = TempDir::new().unwrap();
fs::write(temp_dir.path().join("test.proto"), "").unwrap();
fs::write(temp_dir.path().join("test.txt"), "").unwrap();
fs::write(temp_dir.path().join("test.json"), "").unwrap();
let parser = ProtoParser::new();
let result = parser.discover_proto_files(temp_dir.path()).unwrap();
assert_eq!(result.len(), 1);
assert!(result[0].ends_with(".proto"));
}
// --- parse_directory: missing and empty directories are not errors ---
#[tokio::test]
async fn test_parse_nonexistent_directory() {
let mut parser = ProtoParser::new();
let result = parser.parse_directory("/nonexistent/path").await;
assert!(result.is_ok());
assert!(parser.services().is_empty());
}
#[tokio::test]
async fn test_parse_empty_directory() {
let temp_dir = TempDir::new().unwrap();
let mut parser = ProtoParser::new();
let result = parser.parse_directory(temp_dir.path().to_str().unwrap()).await;
assert!(result.is_ok());
assert!(parser.services().is_empty());
}
// --- Fixture-based tests: read `proto/gretter.proto` under CARGO_MANIFEST_DIR ---
// NOTE(review): if protoc is unavailable, parse_proto_file falls back to the
// mock greeter for this file name, so the greeter assertions presumably hold
// on either path; confirm against the fixture.
#[tokio::test]
async fn test_parse_proto_file() {
let proto_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap() + "/proto";
let proto_path = format!("{}/gretter.proto", proto_dir);
let mut parser = ProtoParser::new();
parser.parse_proto_file(&proto_path).await.unwrap();
let services = parser.services();
assert_eq!(services.len(), 1);
let service_name = "mockforge.greeter.Greeter";
assert!(services.contains_key(service_name));
let service = &services[service_name];
assert_eq!(service.name, service_name);
assert_eq!(service.methods.len(), 4);
let say_hello = service.methods.iter().find(|m| m.name == "SayHello").unwrap();
assert_eq!(say_hello.input_type, "mockforge.greeter.HelloRequest");
assert_eq!(say_hello.output_type, "mockforge.greeter.HelloReply");
assert!(!say_hello.client_streaming);
assert!(!say_hello.server_streaming);
let say_hello_stream = service.methods.iter().find(|m| m.name == "SayHelloStream").unwrap();
assert!(!say_hello_stream.client_streaming);
assert!(say_hello_stream.server_streaming);
}
// NOTE(review): asserts exactly one service, so the fixture directory is
// assumed to define only the greeter service; verify when adding protos.
#[tokio::test]
async fn test_parse_directory() {
let proto_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap() + "/proto";
let mut parser = ProtoParser::new();
parser.parse_directory(&proto_dir).await.unwrap();
let services = parser.services();
assert_eq!(services.len(), 1);
let service_name = "mockforge.greeter.Greeter";
assert!(services.contains_key(service_name));
let service = &services[service_name];
assert_eq!(service.methods.len(), 4);
let method_names: Vec<&str> = service.methods.iter().map(|m| m.name.as_str()).collect();
assert!(method_names.contains(&"SayHello"));
assert!(method_names.contains(&"SayHelloStream"));
assert!(method_names.contains(&"SayHelloClientStream"));
assert!(method_names.contains(&"Chat"));
}
}