use anyhow::{Context, Result, anyhow};
use bytes::Bytes;
use heck::ToSnakeCase;
use prost::Message;
use prost_types::{
DescriptorProto, FileDescriptorProto, ServiceDescriptorProto,
compiler::{CodeGeneratorRequest, CodeGeneratorResponse, code_generator_response::File},
};
use serde::Serialize;
use std::collections::HashMap;
use std::io::{self, Read, Write};
use actr_framework_protoc_codegen::{GeneratorRole, ModernGenerator, RemoteServiceInfo};
use actr_protocol::{ActrType, PackageName, ServiceName};
/// Which side of an actor boundary a `.proto` file describes.
///
/// `Local` files are generated as server-side code (`*_actor.rs`), `Remote` files as
/// client-side code (`*_client.rs`).
///
/// The enum is fieldless, so it is `Copy` and `Eq`; it can be compared and passed by
/// value freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtoSource {
    /// A service implemented by this actor (server side).
    Local,
    /// A service provided by another actor and consumed via a client (client side).
    Remote,
}
impl ProtoSource {
    /// Classifies `file` as a locally implemented service or a remote dependency.
    ///
    /// Resolution order:
    /// 1. The `RemoteFiles` / `LocalFiles` plugin parameters, each a colon-separated list
    ///    of paths. An entry matches when it equals the file name, or when either path is
    ///    a component-wise suffix of the other. For example, `remote/echo.proto` is matched
    ///    by both `echo.proto` and `/abs/protos/remote/echo.proto`.
    /// 2. Otherwise, a file that declares services defaults to `Local` and a file without
    ///    services defaults to `Remote`.
    ///
    /// # Errors
    /// Returns an error if the file is matched by both `RemoteFiles` and `LocalFiles`.
    pub fn from_proto_file(
        file: &FileDescriptorProto,
        params: &HashMap<String, String>,
    ) -> Result<Self> {
        let file_name = file.name();
        let in_list = |key: &str| {
            params
                .get(key)
                .is_some_and(|list| Self::list_mentions(list, file_name))
        };
        let in_remote = in_list("RemoteFiles");
        let in_local = in_list("LocalFiles");
        match (in_remote, in_local) {
            (true, true) => Err(anyhow!(
                "{}: appears in both RemoteFiles and LocalFiles; a file must belong to exactly one side.",
                file_name
            )),
            (true, false) => Ok(Self::Remote),
            (false, true) => Ok(Self::Local),
            // No explicit classification: files that declare services are implemented here.
            (false, false) if file.service.is_empty() => Ok(Self::Remote),
            (false, false) => Ok(Self::Local),
        }
    }

    /// Returns `true` if the colon-separated `list` contains an entry naming `file_name`.
    ///
    /// Entries are trimmed of surrounding whitespace, and empty entries are ignored.
    /// An empty `file_name` never matches. Without that guard, `Path::ends_with("")` is
    /// vacuously true and an unnamed file would match every entry.
    fn list_mentions(list: &str, file_name: &str) -> bool {
        if file_name.is_empty() {
            return false;
        }
        let file_path = std::path::Path::new(file_name);
        list.split(':')
            .map(str::trim)
            .filter(|entry| !entry.is_empty())
            .any(|entry| {
                entry == file_name
                    || file_path.ends_with(entry)
                    || std::path::Path::new(entry).ends_with(file_name)
            })
    }
}
/// Top-level document serialized (pretty-printed JSON) into the `actr-gen-meta.json`
/// file that `generate_code` appends to every plugin response.
#[derive(Serialize)]
struct ActrGenMetadata {
    /// Version of this plugin crate (`CARGO_PKG_VERSION` at build time).
    plugin_version: String,
    /// Target language of the generated code; `generate_code` always sets `"rust"`.
    language: &'static str,
    /// Every service classified as `ProtoSource::Local` across all input proto files.
    local_services: Vec<LocalServiceMetadata>,
    /// Every service classified as `ProtoSource::Remote` across all input proto files.
    remote_services: Vec<RemoteServiceMetadata>,
}
/// Metadata describing one locally implemented (server-side) service.
#[derive(Serialize)]
struct LocalServiceMetadata {
    /// Service name as declared in the `.proto` file.
    name: String,
    /// Proto package of the declaring file (may be empty).
    package: String,
    /// Name of the `.proto` file that declares the service.
    proto_file: String,
    /// `{Service}Handler`, the name of the generated handler trait.
    handler_interface: String,
    /// `{Service}Workload`, presumably a type emitted by the generator. Confirm against
    /// the codegen library.
    workload_type: String,
    /// `{Service}Dispatcher`, presumably a type emitted by the generator. Confirm against
    /// the codegen library.
    dispatcher_type: String,
    /// One entry per RPC method, in declaration order.
    methods: Vec<MethodMetadata>,
}
/// Metadata describing one remote (client-side) service this actor can call.
#[derive(Serialize)]
struct RemoteServiceMetadata {
    /// Service name as declared in the `.proto` file.
    name: String,
    /// Proto package of the declaring file (may be empty).
    package: String,
    /// Name of the `.proto` file that declares the service.
    proto_file: String,
    /// Actor type string: either taken verbatim from `RemoteFileActrTypes`, or built from
    /// `ActrType::to_string_repr` (format presumably `manufacturer:name:version`; verify).
    actr_type: String,
    /// `{Service}Client`, the name of the client type for this service.
    client_type: String,
    /// One entry per RPC method, in declaration order.
    methods: Vec<MethodMetadata>,
}
/// Metadata describing one RPC method of a service.
#[derive(Serialize)]
struct MethodMetadata {
    /// Method name as declared in the `.proto` file.
    name: String,
    /// `name` converted to snake_case.
    snake_name: String,
    /// Unqualified input message name (last `.`-separated segment of the type reference).
    input_type: String,
    /// Unqualified output message name (last `.`-separated segment of the type reference).
    output_type: String,
    /// `package.Service.Method`, or `Service.Method` when the file has no package.
    route_key: String,
}
fn main() -> Result<()> {
let args: Vec<String> = std::env::args().collect();
if args.len() > 1 {
match args[1].as_str() {
"--version" | "-V" => {
println!("protoc-gen-actrframework {}", env!("CARGO_PKG_VERSION"));
println!(
"actr-framework-protoc-codegen library version: {}",
env!("CARGO_PKG_VERSION")
);
return Ok(());
}
"--help" | "-h" => {
println!("protoc-gen-actrframework - Protobuf plugin for Actor-RTC framework");
println!();
println!("USAGE:");
println!(
" As protoc plugin: protoc --plugin=protoc-gen-actrframework=PATH --actrframework_out=OUT_DIR input.proto"
);
println!(" Version info: protoc-gen-actrframework --version");
println!();
println!("VERSION:");
println!(" {}", env!("CARGO_PKG_VERSION"));
return Ok(());
}
_ => {}
}
}
let mut stdin = io::stdin();
let mut buf = Vec::new();
stdin
.read_to_end(&mut buf)
.context("Failed to read from stdin")?;
let request = CodeGeneratorRequest::decode(Bytes::from(buf))
.context("Failed to decode CodeGeneratorRequest")?;
let response = generate_code(request)?;
let mut out_buf = Vec::new();
response
.encode(&mut out_buf)
.context("Failed to encode CodeGeneratorResponse")?;
io::stdout()
.write_all(&out_buf)
.context("Failed to write to stdout")?;
Ok(())
}
/// Splits the request's `parameter` string into a key/value map.
///
/// Entries are comma-separated `key=value` pairs. Only the first `=` separates key from
/// value, so a value may itself contain `=` (as `RemoteFileActrTypes` does). Keys and
/// values are trimmed, entries without any `=` are dropped, and when a key repeats, the
/// last occurrence wins.
fn parse_parameters(param_str: &str) -> HashMap<String, String> {
    param_str
        .split(',')
        .filter_map(|entry| entry.split_once('='))
        .map(|(key, value)| (key.trim().to_owned(), value.trim().to_owned()))
        .collect()
}
/// Builds the plugin response for a single `protoc` invocation.
///
/// Steps:
/// 1. Parse the request's `parameter` string with `parse_parameters`.
/// 2. Read `RemoteFileActrTypes` (a `;`-separated list of `file=actr_type` pairs) into a
///    lookup table.
/// 3. Classify every service in every input proto file with `ProtoSource::from_proto_file`.
///    Remote services are collected so local actors can be generated against them, and
///    every service is recorded in `actr-gen-meta.json`.
/// 4. Generate one Rust file per service for each file listed in `file_to_generate`.
///
/// # Errors
/// Returns an error when:
/// - a file is listed in both `LocalFiles` and `RemoteFiles`;
/// - a file to generate defines more than one service;
/// - a package or service name is invalid;
/// - the generator fails;
/// - the metadata cannot be serialized to JSON.
fn generate_code(request: CodeGeneratorRequest) -> Result<CodeGeneratorResponse> {
    let mut response = CodeGeneratorResponse {
        // 1 == FEATURE_PROTO3_OPTIONAL: tells protoc this plugin accepts proto3 `optional`.
        supported_features: Some(1u64),
        ..Default::default()
    };
    let params = parse_parameters(request.parameter.as_deref().unwrap_or(""));

    // `RemoteFileActrTypes=file_a=type_a;file_b=type_b` pins the actor type of the
    // services in each named remote file. Only the first `=` splits a mapping.
    let remote_file_to_actr_type: HashMap<String, String> = params
        .get("RemoteFileActrTypes")
        .map(|mappings| {
            mappings
                .split(';')
                .filter_map(|mapping| mapping.split_once('='))
                .map(|(file, actr_type)| (file.trim().to_string(), actr_type.trim().to_string()))
                .collect()
        })
        .unwrap_or_default();

    let mut message_types = HashMap::new();
    for file in &request.proto_file {
        collect_message_types(file, &mut message_types, file.package());
    }

    // Remote services are gathered from *all* input files (not only `file_to_generate`)
    // so that locally generated actors can reference them.
    let mut remote_services = Vec::new();
    let mut metadata = ActrGenMetadata {
        plugin_version: env!("CARGO_PKG_VERSION").to_string(),
        language: "rust",
        local_services: Vec::new(),
        remote_services: Vec::new(),
    };
    for file in &request.proto_file {
        // Classify even service-less files so a LocalFiles/RemoteFiles conflict is reported.
        let proto_source = ProtoSource::from_proto_file(file, &params)?;
        for service in &file.service {
            if proto_source == ProtoSource::Local {
                metadata
                    .local_services
                    .push(build_local_service_metadata(file, service));
                continue;
            }

            let package_name = file.package().to_string();
            let service_name = service.name().to_string();
            // An explicit mapping wins. Otherwise, derive the type from the `manufacturer`
            // parameter (falling back to the proto package), the service name, and 1.0.0.
            let actr_type = remote_file_to_actr_type
                .get(file.name())
                .cloned()
                .unwrap_or_else(|| {
                    let manufacturer = params
                        .get("manufacturer")
                        .map(String::as_str)
                        .unwrap_or(&package_name);
                    ActrType {
                        manufacturer: manufacturer.to_string(),
                        name: service_name.clone(),
                        version: "1.0.0".to_string(),
                    }
                    .to_string_repr()
                });
            let methods: Vec<String> = service
                .method
                .iter()
                .map(|m| m.name().to_string())
                .collect();
            remote_services.push(RemoteServiceInfo {
                package_name,
                service_name,
                methods,
                actr_type: actr_type.clone(),
            });
            metadata
                .remote_services
                .push(build_remote_service_metadata(file, service, actr_type));
        }
    }

    for file_name in &request.file_to_generate {
        let Some(file) = request.proto_file.iter().find(|f| f.name() == file_name) else {
            continue;
        };
        if file.service.len() > 1 {
            return Err(anyhow!(
                "{}: defines {} services, but only one service per .proto file is supported. \
Split each service into its own .proto file.",
                file_name,
                file.service.len()
            ));
        }
        for service in &file.service {
            let generated_file = generate_service_code(
                file,
                service,
                &message_types,
                &params,
                &remote_services,
            )?;
            response.file.push(generated_file);
        }
    }

    response.file.push(File {
        name: Some("actr-gen-meta.json".to_string()),
        content: Some(serde_json::to_string_pretty(&metadata)?),
        ..Default::default()
    });
    Ok(response)
}
/// Records every message type declared in `file` into `types`.
///
/// Keys are fully qualified names without a leading dot (e.g. `pkg.Outer`,
/// `pkg.Outer.Inner`). `package_prefix` is the proto package of `file`; it is empty for
/// files without a package.
///
/// Nested messages are collected at every depth, not just the first level, so
/// `pkg.A.B.C` is recorded alongside `pkg.A.B`. If two messages map to the same key, the
/// one visited later overwrites the earlier.
fn collect_message_types(
    file: &FileDescriptorProto,
    types: &mut HashMap<String, DescriptorProto>,
    package_prefix: &str,
) {
    for message in &file.message_type {
        let full_name = if package_prefix.is_empty() {
            message.name().to_string()
        } else {
            format!("{}.{}", package_prefix, message.name())
        };
        collect_nested_message_types(message, &full_name, types);
    }
}

/// Inserts `message` under `full_name`, then recurses into each of its nested types.
fn collect_nested_message_types(
    message: &DescriptorProto,
    full_name: &str,
    types: &mut HashMap<String, DescriptorProto>,
) {
    types.insert(full_name.to_string(), message.clone());
    for nested in &message.nested_type {
        let nested_name = format!("{}.{}", full_name, nested.name());
        collect_nested_message_types(nested, &nested_name, types);
    }
}
/// Generates the Rust source file for a single service.
///
/// Local services are generated server-side into `<stem>_actor.rs`, and the generator is
/// also given `remote_services` (via `generate_with_remotes`). Remote services are
/// generated client-side into `<stem>_client.rs`. `<stem>` is the snake_cased stem of the
/// proto file name, falling back to the service name.
///
/// # Errors
/// Fails on an invalid package or service name, a LocalFiles/RemoteFiles conflict, or a
/// generator error.
fn generate_service_code(
    file: &FileDescriptorProto,
    service: &ServiceDescriptorProto,
    _message_types: &HashMap<String, DescriptorProto>,
    params: &HashMap<String, String>,
    remote_services: &[RemoteServiceInfo],
) -> Result<File> {
    let (package_name, service_name) = (file.package(), service.name());

    // Validate the names before any code is generated.
    PackageName::new(package_name.to_owned())
        .map_err(|err| anyhow!("Invalid proto package name '{}': {}", package_name, err))?;
    ServiceName::new(service_name.to_owned())
        .map_err(|err| anyhow!("Invalid proto service name '{}': {}", service_name, err))?;

    let source = ProtoSource::from_proto_file(file, params)?;
    let (role, suffix) = match source {
        ProtoSource::Local => (GeneratorRole::ServerSide, "_actor"),
        ProtoSource::Remote => (GeneratorRole::ClientSide, "_client"),
    };

    let stem = std::path::Path::new(file.name())
        .file_stem()
        .and_then(|os_stem| os_stem.to_str())
        .unwrap_or(service_name)
        .to_snake_case();

    let generator = ModernGenerator::new(package_name, service_name, role);
    let code = match source {
        ProtoSource::Local => generator.generate_with_remotes(&service.method, remote_services)?,
        ProtoSource::Remote => generator.generate(&service.method)?,
    };

    Ok(File {
        name: Some(format!("{stem}{suffix}.rs")),
        content: Some(code),
        ..Default::default()
    })
}
/// Builds the `actr-gen-meta.json` entry for a locally implemented service.
///
/// The generated type names are derived from the service name with the `Handler`,
/// `Workload` and `Dispatcher` suffixes.
fn build_local_service_metadata(
    file: &FileDescriptorProto,
    service: &ServiceDescriptorProto,
) -> LocalServiceMetadata {
    let service_name = service.name();

    let mut methods = Vec::with_capacity(service.method.len());
    for method in &service.method {
        methods.push(build_method_metadata(file, service, method));
    }

    LocalServiceMetadata {
        name: service_name.to_owned(),
        package: file.package().to_owned(),
        proto_file: file.name().to_owned(),
        handler_interface: format!("{service_name}Handler"),
        workload_type: format!("{service_name}Workload"),
        dispatcher_type: format!("{service_name}Dispatcher"),
        methods,
    }
}
/// Builds the `actr-gen-meta.json` entry for a remote service.
///
/// `actr_type` is the already-resolved actor type string. The client type name is the
/// service name with a `Client` suffix.
fn build_remote_service_metadata(
    file: &FileDescriptorProto,
    service: &ServiceDescriptorProto,
    actr_type: String,
) -> RemoteServiceMetadata {
    let service_name = service.name();

    let mut methods = Vec::with_capacity(service.method.len());
    for method in &service.method {
        methods.push(build_method_metadata(file, service, method));
    }

    RemoteServiceMetadata {
        name: service_name.to_owned(),
        package: file.package().to_owned(),
        proto_file: file.name().to_owned(),
        actr_type,
        client_type: format!("{service_name}Client"),
        methods,
    }
}
/// Builds the metadata entry for one RPC method.
///
/// `route_key` is `package.Service.Method`, or `Service.Method` when the proto file
/// declares no package. Input and output types are reduced to their unqualified names.
fn build_method_metadata(
    file: &FileDescriptorProto,
    service: &ServiceDescriptorProto,
    method: &prost_types::MethodDescriptorProto,
) -> MethodMetadata {
    let method_name = method.name();
    let qualified_service = match file.package() {
        "" => service.name().to_string(),
        package => format!("{}.{}", package, service.name()),
    };

    MethodMetadata {
        name: method_name.to_string(),
        snake_name: method_name.to_snake_case(),
        input_type: short_type_name(method.input_type()),
        output_type: short_type_name(method.output_type()),
        route_key: format!("{}.{}", qualified_service, method_name),
    }
}
/// Returns the unqualified message name from a protobuf type reference.
///
/// Descriptor type references are dot-separated (e.g. `.echo.EchoRequest`). Leading dots
/// are dropped and only the text after the last remaining `.` is kept. A reference with
/// no dot is returned unchanged.
fn short_type_name(raw: &str) -> String {
    let qualified = raw.trim_start_matches('.');
    let last_segment = match qualified.rsplit_once('.') {
        Some((_, tail)) => tail,
        None => qualified,
    };
    last_segment.to_owned()
}
#[cfg(test)]
mod tests {
    use super::*;
    use prost_types::MethodDescriptorProto;

    /// A service in `LocalFiles` is generated as an actor that references the remote
    /// `EchoService`. It must use the actor type given in `RemoteFileActrTypes`
    /// (`custom:EchoAlias:1.0.0`) rather than one derived from `manufacturer`.
    #[test]
    fn generate_code_uses_remote_file_actr_types_for_local_bridge() {
        let local_proto = FileDescriptorProto {
            name: Some("local.proto".to_string()),
            package: Some("demo".to_string()),
            service: vec![ServiceDescriptorProto {
                name: Some("DemoClientApp".to_string()),
                method: vec![],
                ..Default::default()
            }],
            ..Default::default()
        };
        let echo_method = MethodDescriptorProto {
            name: Some("Echo".to_string()),
            input_type: Some(".echo.EchoRequest".to_string()),
            output_type: Some(".echo.EchoResponse".to_string()),
            ..Default::default()
        };
        let remote_proto = FileDescriptorProto {
            name: Some("remote/echo.proto".to_string()),
            package: Some("echo".to_string()),
            service: vec![ServiceDescriptorProto {
                name: Some("EchoService".to_string()),
                method: vec![echo_method],
                ..Default::default()
            }],
            ..Default::default()
        };
        let request = CodeGeneratorRequest {
            file_to_generate: vec!["local.proto".to_string(), "remote/echo.proto".to_string()],
            parameter: Some(
                "manufacturer=acme,LocalFiles=local.proto,RemoteFiles=remote/echo.proto,RemoteFileActrTypes=remote/echo.proto=custom:EchoAlias:1.0.0"
                    .to_string(),
            ),
            proto_file: vec![local_proto, remote_proto],
            ..Default::default()
        };

        let response = generate_code(request).unwrap();
        let actor_source = response
            .file
            .iter()
            .find(|generated| generated.name.as_deref() == Some("local_actor.rs"))
            .and_then(|generated| generated.content.as_deref())
            .expect("local_actor.rs should be generated");

        let expected_fragments = [
            "pub trait DemoClientAppHandler",
            "\"echo.EchoService.Echo\"",
            "manufacturer",
            "\"custom\"",
            "name",
            "\"EchoAlias\"",
        ];
        for fragment in expected_fragments {
            assert!(
                actor_source.contains(fragment),
                "local_actor.rs is missing {fragment:?}"
            );
        }
    }
}