// synapse-codegen 0.0.2
//
// Code generation from protobuf service definitions for Synapse.
//! Server trait generation

use crate::ServiceDef;
use anyhow::Result;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};

/// Generate the server trait for a service.
///
/// Emits one `async fn` per entry in `service.methods`. Every generated
/// method takes `&self`, a `&synapse_sdk::RequestState<()>` and the request
/// message, and returns `Result<Response, Self::Error>`. When a method carries
/// a comment it is forwarded as a `#[doc = "..."]` attribute on the generated
/// method.
///
/// Example output:
/// ```text
/// /// Generated server trait for UserService
/// #[async_trait::async_trait]
/// pub trait UserService: Send + Sync + 'static {
///     /// Error type for this service. Must be convertible to ServiceError.
///     type Error: Into<synapse_rpc::ServiceError> + Send;
///
///     async fn get_user(
///         &self,
///         state: &synapse_sdk::RequestState<()>,
///         request: GetUserRequest
///     ) -> Result<GetUserResponse, Self::Error>;
///
///     async fn create_user(
///         &self,
///         state: &synapse_sdk::RequestState<()>,
///         request: CreateUserRequest
///     ) -> Result<CreateUserResponse, Self::Error>;
/// }
/// ```
///
/// # Errors
///
/// This function has no failing path of its own at present; the `Result`
/// return type is kept so callers can use `?` uniformly.
///
/// # Panics
///
/// `format_ident!` panics if the service name or a snake-cased method name is
/// not a valid Rust identifier.
pub fn generate_server_trait(service: &ServiceDef) -> Result<TokenStream> {
    let trait_name = format_ident!("{}", service.service_name);
    let package = &service.package;

    let methods: Vec<TokenStream> = service
        .methods
        .iter()
        .map(|method| {
            let method_name = format_ident!("{}", method.method_name_snake());
            let input_type = method.input_type_path(package);
            let output_type = method.output_type_path(package);

            // Leading space mirrors what rustc produces for `/// text`.
            let comment = method.comment.as_ref().map(|c| {
                let doc = format!(" {}", c);
                quote! { #[doc = #doc] }
            });

            quote! {
                #comment
                async fn #method_name(
                    &self,
                    state: &synapse_sdk::RequestState<()>,
                    request: #input_type
                ) -> Result<#output_type, Self::Error>;
            }
        })
        .collect();

    // `quote!` does not interpolate inside `///` comments, so the doc text has
    // to be built as a string and attached through an explicit `#[doc]`
    // attribute; otherwise the output would literally read "#trait_name".
    let trait_doc = format!(" Generated server trait for {}", trait_name);

    Ok(quote! {
        #[doc = #trait_doc]
        #[async_trait::async_trait]
        pub trait #trait_name: Send + Sync + 'static {
            /// Error type for this service. Must be convertible to ServiceError.
            type Error: Into<synapse_rpc::ServiceError> + Send;

            #(#methods)*
        }
    })
}

/// Generate the router type that adapts a service trait implementation into
/// an `RpcHandler`.
///
/// The generated `create` function wraps the service in an `Arc`, registers
/// one `synapse_rpc::FunctionHandler` per method on a
/// `synapse_rpc::MethodRouter`, and returns the built router together with a
/// `synapse_rpc::InterfaceRegistration` describing the interface.
///
/// Each generated handler:
/// 1. builds a `synapse_sdk::RequestState` from the incoming request,
/// 2. decodes the payload with `prost::Message::decode`, answering with an
///    `InvalidRequest` / `400` error response if decoding fails,
/// 3. awaits the service method, then either encodes the response with
///    `prost` or converts the service error into a `ServiceError` response.
///
/// Example output:
/// ```text
/// /// Generated router for UserService
/// pub struct UserServiceRouter<T: UserService> {
///     _phantom: std::marker::PhantomData<T>,
/// }
///
/// impl<T: UserService> UserServiceRouter<T> {
///     pub fn create(service: T) -> (
///         synapse_rpc::InterfaceRegistration,
///         std::sync::Arc<dyn synapse_rpc::RpcHandler>,
///     ) {
///         // Builds a MethodRouter with all methods plus the registration
///     }
/// }
/// ```
///
/// # Errors
///
/// This function has no failing path of its own at present; the `Result`
/// return type is kept so callers can use `?` uniformly.
///
/// # Panics
///
/// `format_ident!` panics if the service name or a snake-cased method name is
/// not a valid Rust identifier.
pub fn generate_router_impl(service: &ServiceDef) -> Result<TokenStream> {
    let trait_name = format_ident!("{}", service.service_name);
    let router_name = format_ident!("{}Router", service.service_name);
    let package = &service.package;

    let interface_id = service.interface_id_expr();

    // One block expression per method; each evaluates to the router with the
    // method's handler added, so `create` can thread `router` through them.
    let method_handlers: Vec<TokenStream> = service
        .methods
        .iter()
        .map(|method| {
            let method_name_snake = format_ident!("{}", method.method_name_snake());
            let method_id = method.method_id_expr();
            let input_type = method.input_type_path(package);

            quote! {
                {
                    let service = Arc::clone(&service);
                    let handler = synapse_rpc::FunctionHandler::new(move |req: synapse_proto::RpcRequest| {
                        // Fresh handle per call so the future owns its service reference.
                        let service = Arc::clone(&service);
                        Box::pin(async move {
                            // Create request state
                            let state = synapse_sdk::RequestState::new(&req);

                            // Deserialize request using prost
                            let request: #input_type = match prost::Message::decode(&req.payload[..]) {
                                Ok(r) => r,
                                Err(e) => {
                                    return synapse_rpc::error_response(
                                        synapse_proto::RpcStatus::InvalidRequest,
                                        400,
                                        format!("Invalid request: {}", e),
                                    );
                                }
                            };

                            // Call service method with state
                            match service.#method_name_snake(&state, request).await {
                                Ok(response) => {
                                    // Serialize response using prost
                                    let payload = prost::Message::encode_to_vec(&response);
                                    synapse_rpc::ok_response(bytes::Bytes::from(payload))
                                }
                                Err(err) => {
                                    // Convert custom error to ServiceError
                                    let service_err: synapse_rpc::ServiceError = err.into();
                                    synapse_rpc::error_response(service_err.status, service_err.code, &service_err.message)
                                }
                            }
                        })
                    });
                    router.method(#method_id, std::sync::Arc::new(handler))
                }
            }
        })
        .collect();

    let method_ids: Vec<TokenStream> = service
        .methods
        .iter()
        .map(|method| {
            let method_id = method.method_id_expr();
            quote! { #method_id }
        })
        .collect();

    let method_names: Vec<&str> = service
        .methods
        .iter()
        .map(|method| method.name.as_str())
        .collect();

    // `quote!` does not interpolate inside `///` comments, so the doc text has
    // to be built as a string and attached through an explicit `#[doc]`
    // attribute; otherwise the output would literally read "#trait_name".
    let router_doc = format!(" Generated router for {}", trait_name);

    Ok(quote! {
        #[doc = #router_doc]
        pub struct #router_name<T: #trait_name> {
            _phantom: std::marker::PhantomData<T>,
        }

        impl<T: #trait_name> #router_name<T> {
            /// Create a router from a service implementation
            pub fn create(service: T) -> (synapse_rpc::InterfaceRegistration, std::sync::Arc<dyn synapse_rpc::RpcHandler>) {
                use std::sync::Arc;
                let service = Arc::new(service);

                // Create method router
                let mut router = synapse_rpc::MethodRouter::new();

                // Add method handlers
                #(
                    router = #method_handlers;
                )*

                let router = router.build();

                // Create registration
                let registration = synapse_rpc::InterfaceRegistration {
                    interface_id: #interface_id,
                    interface_version: 1_000_000, // v1.0.0
                    method_ids: [#(#method_ids),*].into_iter().collect(),
                    method_names: vec![#(#method_names.to_string()),*],
                    instance_id: synapse_primitives::InstanceId::new_random(),
                    service_name: stringify!(#trait_name).to_string(),
                    interface_name: concat!(#package, ".", stringify!(#trait_name)).to_string(),
                };

                (registration, router)
            }
        }
    })
}