use crate::ServiceDef;
use anyhow::Result;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
/// Generates the async server trait that service implementors must provide.
///
/// The emitted trait is named after `service.service_name` and annotated with
/// `#[async_trait::async_trait]`. It requires `Send + Sync + 'static` and
/// declares an associated `Error` type that must convert into
/// `synapse_rpc::ServiceError`. Each entry in `service.methods` becomes an
/// `async fn` (snake-cased name) that takes the request state and the decoded
/// input message and returns the output message or `Self::Error`. A method's
/// comment, when present, is attached as a `#[doc]` attribute.
///
/// All std items in the emitted code are fully qualified (`::std::...`),
/// because the tokens are expanded inside the consumer's module. That module
/// may shadow names like `Result` (e.g. a one-parameter `Result<T>` alias).
///
/// # Errors
///
/// The visible implementation always returns `Ok`; the `Result` return type
/// is kept so callers can use `?` uniformly with other generators.
pub fn generate_server_trait(service: &ServiceDef) -> Result<TokenStream> {
    let trait_name = format_ident!("{}", service.service_name);
    let package = &service.package;
    let methods: Vec<TokenStream> = service
        .methods
        .iter()
        .map(|method| {
            let method_name = format_ident!("{}", method.method_name_snake());
            let input_type = method.input_type_path(package);
            let output_type = method.output_type_path(package);
            // Prefix with a space so the rendered doc matches `/// comment` style.
            let comment = method.comment.as_ref().map(|c| {
                let doc = format!(" {}", c);
                quote! { #[doc = #doc] }
            });
            quote! {
                #comment
                async fn #method_name(
                    &self,
                    state: &synapse_sdk::RequestState<()>,
                    request: #input_type
                ) -> ::std::result::Result<#output_type, Self::Error>;
            }
        })
        .collect();
    Ok(quote! {
        #[async_trait::async_trait]
        pub trait #trait_name: ::std::marker::Send + ::std::marker::Sync + 'static {
            type Error: ::std::convert::Into<synapse_rpc::ServiceError> + ::std::marker::Send;
            #(#methods)*
        }
    })
}
/// Generates the `<Service>Router` type, which adapts an implementation of the
/// generated server trait into an RPC handler plus its interface registration.
///
/// The emitted `create(service: T)` wraps the service in an `Arc` and builds a
/// `synapse_rpc::MethodRouter` with one handler per method. Each handler:
/// 1. decodes the request payload with `prost`, answering
///    `RpcStatus::InvalidRequest` / `400` on decode failure;
/// 2. calls the corresponding trait method;
/// 3. encodes the response with `prost`, or converts the method's error into
///    a `synapse_rpc::ServiceError`.
///
/// `create` returns the `InterfaceRegistration` together with the built router.
///
/// Std items in the emitted code are fully qualified (`::std::...`), so the
/// handler tokens do not depend on any `use` being in scope at the splice site.
///
/// # Errors
///
/// The visible implementation always returns `Ok`; the `Result` return type
/// is kept so callers can use `?` uniformly with other generators.
pub fn generate_router_impl(service: &ServiceDef) -> Result<TokenStream> {
    let trait_name = format_ident!("{}", service.service_name);
    let router_name = format_ident!("{}Router", service.service_name);
    let package = &service.package;
    let interface_id = service.interface_id_expr();
    let method_handlers: Vec<TokenStream> = service
        .methods
        .iter()
        .map(|method| {
            let method_name_snake = format_ident!("{}", method.method_name_snake());
            let method_id = method.method_id_expr();
            let input_type = method.input_type_path(package);
            // Each handler is a block expression that consumes the current `router`
            // binding and evaluates to the builder returned by `router.method(..)`.
            quote! {
                {
                    let service = ::std::sync::Arc::clone(&service);
                    let handler = synapse_rpc::FunctionHandler::new(move |req: synapse_proto::RpcRequest| {
                        // Per-request clone so the returned future owns its own handle.
                        let service = ::std::sync::Arc::clone(&service);
                        ::std::boxed::Box::pin(async move {
                            let state = synapse_sdk::RequestState::new(&req);
                            let request: #input_type = match prost::Message::decode(&req.payload[..]) {
                                Ok(r) => r,
                                Err(e) => {
                                    return synapse_rpc::error_response(
                                        synapse_proto::RpcStatus::InvalidRequest,
                                        400,
                                        format!("Invalid request: {}", e),
                                    );
                                }
                            };
                            match service.#method_name_snake(&state, request).await {
                                Ok(response) => {
                                    let payload = prost::Message::encode_to_vec(&response);
                                    synapse_rpc::ok_response(bytes::Bytes::from(payload))
                                }
                                Err(err) => {
                                    let service_err: synapse_rpc::ServiceError = err.into();
                                    synapse_rpc::error_response(service_err.status, service_err.code, &service_err.message)
                                }
                            }
                        })
                    });
                    router.method(#method_id, ::std::sync::Arc::new(handler))
                }
            }
        })
        .collect();
    let method_ids: Vec<TokenStream> = service
        .methods
        .iter()
        .map(|method| {
            let method_id = method.method_id_expr();
            quote! { #method_id }
        })
        .collect();
    let method_names: Vec<&str> = service
        .methods
        .iter()
        .map(|method| method.name.as_str())
        .collect();
    Ok(quote! {
        pub struct #router_name<T> {
            _phantom: ::std::marker::PhantomData<T>,
        }
        impl<T: #trait_name> #router_name<T> {
            pub fn create(service: T) -> (synapse_rpc::InterfaceRegistration, ::std::sync::Arc<dyn synapse_rpc::RpcHandler>) {
                let service = ::std::sync::Arc::new(service);
                // Rebind by shadowing rather than `let mut` so a service with no
                // methods does not trigger an `unused_mut` warning in generated code.
                let router = synapse_rpc::MethodRouter::new();
                #(
                    let router = #method_handlers;
                )*
                let router = router.build();
                let registration = synapse_rpc::InterfaceRegistration {
                    interface_id: #interface_id,
                    // NOTE(review): hard-coded version; its encoding is not defined
                    // here — confirm against synapse_rpc before changing.
                    interface_version: 1_000_000,
                    method_ids: [#(#method_ids),*].into_iter().collect(),
                    method_names: vec![#(#method_names.to_string()),*],
                    instance_id: synapse_primitives::InstanceId::new_random(),
                    service_name: stringify!(#trait_name).to_string(),
                    interface_name: concat!(#package, ".", stringify!(#trait_name)).to_string(),
                };
                (registration, router)
            }
        }
    })
}