use actr_protocol::ActrType;
use anyhow::{Result, anyhow};
use heck::ToSnakeCase;
use prost_types::MethodDescriptorProto;
use quote::{format_ident, quote};
use crate::payload_type_extractor::extract_payload_type_or_default;
/// A service hosted by another actor whose routes the generated server-side
/// dispatcher forwards (via route discovery + `call_raw`) instead of handling
/// them locally.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RemoteServiceInfo {
    /// Protobuf package declaring the service; first segment of its route keys.
    pub package_name: String,
    /// Service name as declared in the `.proto` file; second segment of its route keys.
    pub service_name: String,
    /// Method names; each becomes the route key `<package_name>.<service_name>.<method>`.
    pub methods: Vec<String>,
    /// Target actor type, expected in `<manufacturer>:<name>[:<version>]` form.
    pub actr_type: String,
}
/// Selects which side of a service contract the generator emits code for.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GeneratorRole {
    /// Handler trait, dispatcher, and workload wiring for implementing the service.
    ServerSide,
    /// Typed client wrapper plus a `ContextExt` extension for calling the service.
    ClientSide,
}
/// Generates actr-framework Rust glue code for a single protobuf service.
#[derive(Debug, Clone)]
pub struct ModernGenerator {
    /// Protobuf package of the service; dots become `_` in the prost module path.
    package_name: String,
    /// Service name; prefixes generated item names (`<service>Handler`, `<service>Client`, ...).
    service_name: String,
    /// Whether server-side or client-side code is produced.
    role: GeneratorRole,
}
impl ModernGenerator {
    /// Creates a generator for `service_name` declared in protobuf package
    /// `package_name`, emitting server-side or client-side code per `role`.
    pub fn new(package_name: &str, service_name: &str, role: GeneratorRole) -> Self {
        Self {
            package_name: package_name.to_string(),
            service_name: service_name.to_string(),
            role,
        }
    }

    /// Generates code for the local service's `methods` with no remote forwarding.
    ///
    /// Equivalent to [`Self::generate_with_remotes`] with an empty remote list.
    ///
    /// # Errors
    /// Returns an error if a method's input/output type name is empty, or if two
    /// methods share the same request type (each may implement `RpcRequest` once).
    pub fn generate(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
        self.generate_with_remotes(methods, &[])
    }

    /// Generates code for the local service's `methods`.
    ///
    /// On the server side, the generated dispatcher additionally forwards every
    /// route of `remote_services` to an actor of the matching `actr_type`.
    /// `remote_services` is ignored on the client side.
    ///
    /// # Errors
    /// Same as [`Self::generate`], plus an error for any remote `actr_type` that
    /// cannot be parsed as `<manufacturer>:<name>[:<version>]`.
    pub fn generate_with_remotes(
        &self,
        methods: &[MethodDescriptorProto],
        remote_services: &[RemoteServiceInfo],
    ) -> Result<String> {
        match self.role {
            GeneratorRole::ServerSide => self.generate_server_code(methods, remote_services),
            GeneratorRole::ClientSide => self.generate_client_code(methods),
        }
    }

    /// Server-side output: imports, `RpcRequest` impls, handler trait, dispatcher,
    /// workload impls and usage docs, separated by blank lines.
    fn generate_server_code(
        &self,
        methods: &[MethodDescriptorProto],
        remote_services: &[RemoteServiceInfo],
    ) -> Result<String> {
        let sections = [
            self.generate_imports(),
            self.generate_message_impls(methods)?,
            self.generate_handler_trait(methods)?,
            self.generate_router_impl(methods, remote_services)?,
            self.generate_workload_blanket_impl(methods)?,
            self.generate_usage_docs(methods)?,
        ];
        Ok(sections.join("\n\n"))
    }

    /// Client-side output: imports, `RpcRequest` impls, typed client/`ContextExt`
    /// and usage docs, separated by blank lines.
    fn generate_client_code(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
        let sections = [
            self.generate_imports(),
            self.generate_message_impls(methods)?,
            self.generate_context_extensions(methods)?,
            self.generate_client_usage_docs(methods)?,
        ];
        Ok(sections.join("\n\n"))
    }

    /// Route key `<package>.<service>.<method>` of a local method; shared by the
    /// `RpcRequest::route_key` impls and the dispatcher's match arms.
    fn route_key(&self, method: &MethodDescriptorProto) -> String {
        format!("{}.{}.{}", self.package_name, self.service_name, method.name())
    }

    /// Prefixes `item` with the `async_trait` attributes used on every generated
    /// async trait / trait impl: `Send` futures natively, `?Send` on wasm32.
    fn with_async_trait_attrs(item: impl std::fmt::Display) -> String {
        format!(
            "#[cfg_attr(not(target_arch = \"wasm32\"), async_trait)]\n\
             #[cfg_attr(target_arch = \"wasm32\", async_trait(?Send))]\n\
             {item}"
        )
    }

    /// File header plus the `use` declarations shared by server and client output.
    ///
    /// The prost-generated message module is assumed to be a sibling module named
    /// after the package with `.` replaced by `_`.
    fn generate_imports(&self) -> String {
        let proto_module = self.package_name.replace('.', "_");
        // An outer `#[allow]` covers only the single item that follows it, so each
        // `use` carries its own: client-side output leaves several imports unused.
        format!(
            r#"// Auto-generated code - DO NOT EDIT
// Generated by actr-cli's protoc-gen-actrframework plugin
#[allow(unused_imports)]
use async_trait::async_trait;
#[allow(unused_imports)]
use bytes::Bytes;
#[allow(unused_imports)]
use prost::Message as ProstMessage;
#[allow(unused_imports)]
use actr_framework::{{Context, Dest, MessageDispatcher, Workload}};
#[allow(unused_imports)]
use actr_protocol::{{ActrId, ActorResult, RpcRequest, RpcEnvelope, PayloadType}};
// Import protobuf message types (generated by prost)
#[allow(unused_imports)]
use super::{proto_module}::*;
"#
        )
    }

    /// Emits one `impl RpcRequest for <Input>` per method, binding the response
    /// type, route key and payload type.
    ///
    /// # Errors
    /// Fails if a type name is empty, or if two methods share a request type:
    /// the trait is implemented on the request type itself, so a second impl for
    /// the same type would not compile.
    fn generate_message_impls(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
        use std::collections::HashSet;

        let mut bound_request_types: HashSet<String> = HashSet::with_capacity(methods.len());
        let mut impls = Vec::with_capacity(methods.len());
        for method in methods {
            let input_type = self.extract_message_type(method.input_type())?;
            let output_type = self.extract_message_type(method.output_type())?;
            if !bound_request_types.insert(input_type.clone()) {
                return Err(anyhow!(
                    "Request type '{}' is used by more than one method of {}.{} (again by '{}'); \
                     each request type can implement RpcRequest for only one route",
                    input_type,
                    self.package_name,
                    self.service_name,
                    method.name()
                ));
            }
            let route_key = self.route_key(method);
            let payload_type = extract_payload_type_or_default(method);
            let payload_type_code = payload_type.as_rust_variant();
            impls.push(format!(
                r#"/// RpcRequest trait implementation - associates Request and Response types
///
/// This enables type-safe RPC calls with automatic response type inference:
/// ```rust,ignore
/// let response: {output_type} = ctx.call(&target, request).await?;
/// ```
impl RpcRequest for {input_type} {{
    type Response = {output_type};
    fn route_key() -> &'static str {{
        "{route_key}"
    }}
    fn payload_type() -> PayloadType {{
        {payload_type_code}
    }}
}}"#
            ));
        }
        Ok(impls.join("\n\n"))
    }

    /// Emits the user-implemented `<Service>Handler` trait: one async method per
    /// RPC (snake_case name), generic over the caller's `Context`.
    fn generate_handler_trait(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
        let handler_trait_ident = format_ident!("{}Handler", self.service_name);
        let mut method_sigs = Vec::with_capacity(methods.len());
        for method in methods {
            let method_ident = format_ident!("{}", method.name().to_snake_case());
            let input_ident = format_ident!("{}", self.extract_message_type(method.input_type())?);
            let output_ident =
                format_ident!("{}", self.extract_message_type(method.output_type())?);
            method_sigs.push(quote! {
                async fn #method_ident<C: Context>(
                    &self,
                    req: #input_ident,
                    ctx: &C,
                ) -> ActorResult<#output_ident>;
            });
        }
        let handler_trait = quote! {
            pub trait #handler_trait_ident: actr_framework::MaybeSendSync + 'static {
                #(#method_sigs)*
            }
        };
        Ok(Self::with_async_trait_attrs(handler_trait))
    }

    /// Emits `<Service>Workload<T>`, the zero-sized `<Service>Dispatcher<T>` and
    /// its `MessageDispatcher` impl.
    ///
    /// Local routes decode the payload and call the handler. Routes of
    /// `remote_services` are forwarded unchanged to an actor discovered for the
    /// remote's `actr_type`. Unknown routes yield `ActrError::UnknownRoute`.
    fn generate_router_impl(
        &self,
        methods: &[MethodDescriptorProto],
        remote_services: &[RemoteServiceInfo],
    ) -> Result<String> {
        use std::collections::BTreeMap;

        let router_ident = format_ident!("{}Dispatcher", self.service_name);
        let workload_ident = format_ident!("{}Workload", self.service_name);
        let handler_trait_ident = format_ident!("{}Handler", self.service_name);

        let mut match_arms = Vec::new();

        // Local methods: decode the request and invoke the user's handler.
        for method in methods {
            let route_key = self.route_key(method);
            let method_ident = format_ident!("{}", method.name().to_snake_case());
            let input_ident = format_ident!("{}", self.extract_message_type(method.input_type())?);
            match_arms.push(quote! {
                #route_key => {
                    let payload = envelope.payload.as_ref()
                        .ok_or_else(|| actr_protocol::ActrError::DecodeFailure(
                            "Missing payload in RpcEnvelope".to_string()
                        ))?;
                    let req = #input_ident::decode(&**payload)
                        .map_err(|e| actr_protocol::ActrError::DecodeFailure(
                            format!("Failed to decode {}: {}", stringify!(#input_ident), e)
                        ))?;
                    let resp = workload.0.#method_ident(req, ctx).await?;
                    Ok(resp.encode_to_vec().into())
                }
            });
        }

        // Remote services, grouped by target actor type. A BTreeMap (not a HashMap)
        // keeps the emitted arms in a stable order, so regenerating from the same
        // inputs produces byte-identical output.
        let mut services_by_actr_type: BTreeMap<&str, Vec<&RemoteServiceInfo>> = BTreeMap::new();
        for remote_service in remote_services {
            services_by_actr_type
                .entry(remote_service.actr_type.as_str())
                .or_default()
                .push(remote_service);
        }

        for (actr_type_str, services) in services_by_actr_type {
            // Validate first so a malformed actr_type is reported even if it has no methods.
            let parsed = ActrType::from_string_repr(actr_type_str).map_err(|e| {
                anyhow!(
                    "Invalid remote actr_type '{}': expected <manufacturer>:<name>[:<version>] ({})",
                    actr_type_str,
                    e
                )
            })?;

            let route_keys: Vec<String> = services
                .iter()
                .flat_map(|service| {
                    service.methods.iter().map(move |method| {
                        format!("{}.{}.{}", service.package_name, service.service_name, method)
                    })
                })
                .collect();
            // An arm with no patterns would expand to a bare `=> { ... }`, which is
            // not valid Rust; there is nothing to forward anyway.
            if route_keys.is_empty() {
                continue;
            }

            // Honour a version spelled out in the actr_type string; keep the
            // historical "1.0.0" default when it is omitted.
            let version = actr_type_str
                .split(':')
                .nth(2)
                .filter(|v| !v.is_empty())
                .unwrap_or("1.0.0");
            let manufacturer = parsed.manufacturer;
            let name = parsed.name;

            match_arms.push(quote! {
                #(#route_keys)|* => {
                    let target_type = actr_protocol::ActrType {
                        manufacturer: #manufacturer.to_string(),
                        name: #name.to_string(),
                        version: #version.to_string(),
                    };
                    let target_id = ctx.discover_route_candidate(&target_type).await?;
                    ctx.call_raw(
                        &target_id,
                        envelope.route_key.as_str(),
                        envelope.payload.clone().unwrap_or_default(),
                    ).await
                }
            });
        }

        let workload_struct = quote! {
            pub struct #workload_ident<T: #handler_trait_ident>(pub T);
            impl<T: #handler_trait_ident> #workload_ident<T> {
                pub fn new(handler: T) -> Self {
                    Self(handler)
                }
            }
        };
        let router_struct = quote! {
            pub struct #router_ident<T: #handler_trait_ident>(std::marker::PhantomData<T>);
        };
        let router_impl = Self::with_async_trait_attrs(quote! {
            impl<T: #handler_trait_ident> MessageDispatcher for #router_ident<T> {
                type Workload = #workload_ident<T>;
                async fn dispatch<C: Context>(
                    workload: &Self::Workload,
                    envelope: RpcEnvelope,
                    ctx: &C,
                ) -> ActorResult<Bytes> {
                    match envelope.route_key.as_str() {
                        #(#match_arms,)*
                        _ => Err(actr_protocol::ActrError::UnknownRoute(
                            envelope.route_key.to_string()
                        ))
                    }
                }
            }
        });
        Ok(format!("{workload_struct}\n{router_struct}\n{router_impl}"))
    }

    /// Emits the `Workload` and `actr_framework::ServiceHandler` impls tying
    /// `<Service>Workload<T>` to its dispatcher. `_methods` is unused.
    fn generate_workload_blanket_impl(&self, _methods: &[MethodDescriptorProto]) -> Result<String> {
        let router_ident = format_ident!("{}Dispatcher", self.service_name);
        let workload_ident = format_ident!("{}Workload", self.service_name);
        let handler_trait_ident = format_ident!("{}Handler", self.service_name);
        Ok(quote! {
            impl<T: #handler_trait_ident> Workload for #workload_ident<T> {
                type Dispatcher = #router_ident<T>;
            }
            impl<T: #handler_trait_ident> actr_framework::ServiceHandler for #workload_ident<T> {
                type Workload = Self;
            }
        }
        .to_string())
    }

    /// Emits the typed `<Service>Client` wrapper (one `call`-based method per RPC)
    /// and a `ContextExt` trait, blanket-implemented for every `Context`, that
    /// exposes it as `ctx.<service_snake>()`.
    fn generate_context_extensions(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
        let client_ident = format_ident!("{}Client", self.service_name);
        let mut client_methods = Vec::with_capacity(methods.len());
        for method in methods {
            let method_ident = format_ident!("{}", method.name().to_snake_case());
            let input_ident = format_ident!("{}", self.extract_message_type(method.input_type())?);
            let output_ident =
                format_ident!("{}", self.extract_message_type(method.output_type())?);
            client_methods.push(quote! {
                pub async fn #method_ident(
                    &self,
                    target: ActrId,
                    req: #input_ident,
                ) -> ActorResult<#output_ident> {
                    self.ctx.call(&Dest::from(target), req).await
                }
            });
        }
        let extension_method_ident = format_ident!("{}", self.service_name.to_snake_case());
        Ok(quote! {
            pub struct #client_ident<'a, C: Context> {
                ctx: &'a C,
            }
            impl<'a, C: Context> #client_ident<'a, C> {
                #(#client_methods)*
            }
            pub trait ContextExt {
                fn #extension_method_ident(&self) -> #client_ident<'_, Self> where Self: Sized + Context;
            }
            impl<T: Context> ContextExt for T {
                fn #extension_method_ident(&self) -> #client_ident<'_, Self> {
                    #client_ident { ctx: self }
                }
            }
        }
        .to_string())
    }

    /// Server-side usage guide, emitted as a block comment. The example method
    /// (built from the first RPC, if any) mirrors the generated handler signature.
    fn generate_usage_docs(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
        let service_name = &self.service_name;
        let handler_trait = format!("{service_name}Handler");
        let example_method = match methods.first() {
            Some(method) => {
                let method_name = method.name().to_snake_case();
                let input_type = self.extract_message_type(method.input_type())?;
                let output_type = self.extract_message_type(method.output_type())?;
                // `Context` is a trait, so the example must be generic over it
                // (`ctx: &Context` would not compile), exactly like the handler trait.
                format!(
                    r#"    async fn {method_name}<C: Context>(&self, req: {input_type}, ctx: &C) -> ActorResult<{output_type}> {{
        // Implement business logic
        Ok({output_type}::default())
    }}"#
                )
            }
            None => "    // Implement methods...".to_string(),
        };
        Ok(format!(
            r#"/*
## Usage Example
### 1. Implement Business Logic
```rust
use actr_framework::Context;
use actr_protocol::ActorResult;
use async_trait::async_trait;
pub struct MyService {{
    // Business state
}}
#[async_trait]
impl {handler_trait} for MyService {{
{example_method}
}}
```
### 2. Register the Entry Point
```rust
actr_framework::entry!({service_name}Workload<MyService>);
```
## Architecture
- **{handler_trait}**: user-implemented business logic interface
- **{service_name}Dispatcher**: zero-sized type static dispatcher (auto-generated)
- **{service_name}Workload<T>**: generated wrapper that satisfies orphan rules
Users only need to implement {handler_trait}; the framework auto-provides routing and workload capabilities.
*/
"#
        ))
    }

    /// Client-side usage guide, emitted as a block comment, showing a call through
    /// `ContextExt` using the first RPC (or a placeholder name when there is none).
    fn generate_client_usage_docs(&self, methods: &[MethodDescriptorProto]) -> Result<String> {
        let service_name_snake = self.service_name.to_snake_case();
        let method_name_snake = methods.first().map_or_else(
            || "unknown_method".to_string(),
            |m| m.name().to_snake_case(),
        );
        Ok(format!(
            r#"/*
## Dependency Usage Example
```rust
use actr_framework::Context;
use actr_protocol::{{ActorResult, ActrId}};
async fn call_remote_service(ctx: &impl Context, target: ActrId) -> ActorResult<()> {{
    use super::ContextExt;
    // Build the request message for the method being called
    let request = Default::default();
    // Type-safe remote call
    let _response = ctx.{service_name_snake}()
        .{method_name_snake}(target, request)
        .await?;
    Ok(())
}}
```
## Compile-time Routing
All remote calls determine the target service and method at compile time — no runtime lookup needed.
*/
"#
        ))
    }

    /// Returns the unqualified message name from a protobuf type reference,
    /// e.g. `.pkg.v1.EchoRequest` -> `EchoRequest`.
    ///
    /// # Errors
    /// Fails when no name remains (e.g. `""`, `"."`, `"pkg."`): an empty name would
    /// otherwise make `format_ident!` panic or produce invalid generated code.
    fn extract_message_type(&self, type_name: &str) -> Result<String> {
        let simple_name = type_name
            .trim_start_matches('.')
            .rsplit('.')
            .next()
            .unwrap_or_default();
        if simple_name.is_empty() {
            return Err(anyhow!(
                "Invalid protobuf message type '{}': expected a name such as '.pkg.Message'",
                type_name
            ));
        }
        Ok(simple_name.to_string())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use prost_types::MethodDescriptorProto;

    /// A unary `Echo` method in package `test.v1`, with no options.
    fn echo_method() -> MethodDescriptorProto {
        MethodDescriptorProto {
            name: Some("Echo".to_string()),
            input_type: Some(".test.v1.EchoRequest".to_string()),
            output_type: Some(".test.v1.EchoResponse".to_string()),
            options: None,
            ..Default::default()
        }
    }

    /// A generator for `test.v1.TestService` in the given role.
    fn test_service(role: GeneratorRole) -> ModernGenerator {
        ModernGenerator::new("test.v1", "TestService", role)
    }

    #[test]
    fn test_extract_message_type() {
        let generator = test_service(GeneratorRole::ServerSide);
        let cases = [
            (".test.v1.EchoRequest", "EchoRequest"),
            ("test.v1.EchoResponse", "EchoResponse"),
            ("SimpleMessage", "SimpleMessage"),
        ];
        for (type_name, expected) in cases {
            assert_eq!(generator.extract_message_type(type_name).unwrap(), expected);
        }
    }

    #[test]
    fn test_generate_message_impls_includes_payload_type() {
        let generator = test_service(GeneratorRole::ServerSide);
        let generated = generator.generate_message_impls(&[echo_method()]).unwrap();
        eprintln!("Generated code:\n{generated}");
        for (needle, why) in [
            ("fn payload_type", "Should contain 'fn payload_type'"),
            ("PayloadType", "Should contain 'PayloadType'"),
            ("RpcReliable", "Should contain 'RpcReliable'"),
        ] {
            assert!(generated.contains(needle), "{}", why);
        }
    }

    #[test]
    fn test_generate_imports_includes_payload_type() {
        let header = test_service(GeneratorRole::ServerSide).generate_imports();
        assert!(header.contains("PayloadType"));
        assert!(header.contains(
            "use actr_protocol::{ActrId, ActorResult, RpcRequest, RpcEnvelope, PayloadType}"
        ));
        assert!(
            header.contains("use actr_framework::{Context, Dest, MessageDispatcher, Workload}")
        );
    }

    #[test]
    fn test_generate_client_code() {
        let outcome = test_service(GeneratorRole::ClientSide).generate(&[echo_method()]);
        assert!(outcome.is_ok());
        let code = outcome.unwrap();
        assert!(code.contains("impl RpcRequest for EchoRequest"));
        assert!(code.contains("fn payload_type() -> PayloadType"));
        assert!(code.contains("use actr_framework::{Context, Dest, MessageDispatcher, Workload}"));
        assert!(code.contains(
            "use actr_protocol::{ActrId, ActorResult, RpcRequest, RpcEnvelope, PayloadType}"
        ));
    }

    #[test]
    fn test_generate_server_code() {
        let outcome = test_service(GeneratorRole::ServerSide).generate(&[echo_method()]);
        assert!(outcome.is_ok());
        let code = outcome.unwrap();
        assert!(code.contains("pub trait TestServiceHandler"));
        assert!(code.contains("pub struct TestServiceDispatcher"));
        assert!(code.contains("fn payload_type() -> PayloadType"));
    }

    #[test]
    fn test_generate_server_code_with_no_local_methods() {
        let bridge = ModernGenerator::new("test.v1", "BridgeService", GeneratorRole::ServerSide);
        let code = bridge.generate(&[]).unwrap();
        for expected in [
            "pub trait BridgeServiceHandler",
            "pub struct BridgeServiceWorkload",
            "pub struct BridgeServiceDispatcher",
            "UnknownRoute",
        ] {
            assert!(code.contains(expected));
        }
    }

    #[test]
    fn test_generate_server_code_with_remote_forwarding_and_no_local_methods() {
        let app = ModernGenerator::new("demo.app", "DemoClientApp", GeneratorRole::ServerSide);
        let echo_remote = RemoteServiceInfo {
            package_name: "echo".to_string(),
            service_name: "EchoService".to_string(),
            methods: vec!["Echo".to_string()],
            actr_type: "acme:EchoService:1.0.0".to_string(),
        };
        let code = app.generate_with_remotes(&[], &[echo_remote]).unwrap();
        for expected in [
            "pub trait DemoClientAppHandler",
            "pub struct DemoClientAppWorkload",
            "\"echo.EchoService.Echo\"",
            "manufacturer",
            "\"acme\"",
            "name",
            "\"EchoService\"",
            "discover_route_candidate",
            "call_raw",
        ] {
            assert!(code.contains(expected));
        }
    }
}