use anyhow::Result;
use heck::ToSnakeCase;
use heck::ToUpperCamelCase;
use proc_macro2::TokenStream;
use quote::format_ident;
use quote::quote;
use buffa_codegen::generated::descriptor::FileDescriptorProto;
use buffa_codegen::generated::descriptor::MethodDescriptorProto;
use buffa_codegen::generated::descriptor::ServiceDescriptorProto;
use buffa_codegen::generated::descriptor::SourceCodeInfo;
use buffa_codegen::generated::descriptor::method_options::IdempotencyLevel;
pub use buffa_codegen::GeneratedFile;
pub use buffa_codegen::generated::descriptor;
use crate::plugin::CodeGeneratorRequest;
use crate::plugin::CodeGeneratorResponse;
use crate::plugin::CodeGeneratorResponseFile;
/// Options controlling service code generation.
///
/// All fields are forwarded to buffa-codegen's `CodeGenConfig` by
/// `to_buffa_config`. The struct is `#[non_exhaustive]`, so code outside this
/// crate builds it with `Options::default()` and then sets fields.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Options {
    /// Forwarded to `CodeGenConfig::strict_utf8_mapping`; set by the
    /// `strict_utf8_mapping` plugin option. Default: `false`.
    pub strict_utf8_mapping: bool,
    /// Forwarded to `CodeGenConfig::generate_json`; cleared by the `no_json`
    /// plugin option. Default: `true`.
    pub generate_json: bool,
    /// `(proto_path, rust_path)` mappings forwarded to
    /// `CodeGenConfig::extern_paths`. Default: empty.
    pub extern_paths: Vec<(String, String)>,
}

impl Default for Options {
    fn default() -> Self {
        Options {
            generate_json: true,
            strict_utf8_mapping: false,
            extern_paths: Vec::new(),
        }
    }
}

impl Options {
    /// Builds the buffa-codegen configuration corresponding to these options.
    ///
    /// View generation is always switched on because the generated service
    /// code refers to the `…View` types of request and response messages.
    fn to_buffa_config(&self) -> buffa_codegen::CodeGenConfig {
        let mut buffa_config = buffa_codegen::CodeGenConfig::default();
        buffa_config.strict_utf8_mapping = self.strict_utf8_mapping;
        buffa_config.generate_json = self.generate_json;
        buffa_config.generate_views = true;
        buffa_config.extern_paths.clone_from(&self.extern_paths);
        buffa_config
    }
}
/// Generates one service-code file for each entry of `file_to_generate` whose
/// descriptor (looked up by name in `proto_file`) declares at least one service.
///
/// Entries with no matching descriptor, or whose file has no services, are
/// skipped. Output files are named with `buffa_codegen::proto_path_to_rust_module`,
/// the same naming buffa-codegen uses for the proto's module.
///
/// # Errors
/// Propagates code-generation and formatting failures.
fn emit_service_files(
    proto_file: &[FileDescriptorProto],
    file_to_generate: &[String],
    resolver: &TypeResolver<'_>,
) -> Result<Vec<GeneratedFile>> {
    let mut generated = Vec::new();
    for requested in file_to_generate {
        let Some(descriptor) = proto_file
            .iter()
            .find(|candidate| candidate.name.as_deref() == Some(requested.as_str()))
        else {
            continue;
        };
        if descriptor.service.is_empty() {
            continue;
        }
        let tokens = generate_connect_services(descriptor, resolver)?;
        let content = format_token_stream(&tokens)?;
        generated.push(GeneratedFile {
            name: buffa_codegen::proto_path_to_rust_module(requested),
            content,
        });
    }
    Ok(generated)
}
/// Generates message code (via buffa-codegen) and ConnectRPC service code for
/// `file_to_generate`, returning one file per generated Rust module.
///
/// Service code is appended, after a newline, to the buffa-generated file that
/// has the same module name. If buffa-codegen produced no file under that name,
/// the service code is returned as its own file instead of being discarded.
///
/// Types are resolved relative to the current package, so extern-path mappings
/// are not required here.
///
/// # Errors
/// Returns an error if buffa-codegen fails, or if service generation fails.
pub fn generate_files(
    proto_file: &[FileDescriptorProto],
    file_to_generate: &[String],
    options: &Options,
) -> Result<Vec<GeneratedFile>> {
    let config = options.to_buffa_config();
    let mut files = buffa_codegen::generate(proto_file, file_to_generate, &config)
        .map_err(|e| anyhow::anyhow!("buffa-codegen failed: {e}"))?;
    let resolver = TypeResolver::new(proto_file, file_to_generate, &config, false);
    for svc in emit_service_files(proto_file, file_to_generate, &resolver)? {
        match files.iter_mut().find(|g| g.name == svc.name) {
            Some(existing) => {
                existing.content.push('\n');
                existing.content.push_str(&svc.content);
            }
            // Previously this case silently dropped the generated service code.
            None => files.push(svc),
        }
    }
    Ok(files)
}
/// Generates only the ConnectRPC service code for `file_to_generate`.
///
/// Every referenced message type must resolve to an absolute path (`::…` or
/// `crate::…`), typically via an `extern_path` mapping. Otherwise generation
/// fails, because the service code is not placed next to the message code.
///
/// # Errors
/// Returns an error on unresolvable types or code-generation failures.
pub fn generate_services(
    proto_file: &[FileDescriptorProto],
    file_to_generate: &[String],
    options: &Options,
) -> Result<Vec<GeneratedFile>> {
    let buffa_config = options.to_buffa_config();
    let strict_resolver = TypeResolver::new(proto_file, file_to_generate, &buffa_config, true);
    emit_service_files(proto_file, file_to_generate, &strict_resolver)
}
pub fn generate(request: &CodeGeneratorRequest) -> Result<CodeGeneratorResponse> {
let mut options = Options::default();
if let Some(ref param) = request.parameter {
for opt in param.split(',').map(str::trim).filter(|s| !s.is_empty()) {
if let Some(value) = opt.strip_prefix("buffa_module=") {
let rust = value.trim();
if rust.is_empty() {
anyhow::bail!(
"buffa_module requires a non-empty path, \
e.g. buffa_module=crate::proto"
);
}
options.extern_paths.push((".".into(), rust.to_string()));
} else if let Some(value) = opt.strip_prefix("extern_path=") {
let (proto, rust) = value.split_once('=').ok_or_else(|| {
anyhow::anyhow!(
"invalid extern_path format {value:?}, expected \
extern_path=.proto.pkg=::rust::path"
)
})?;
let proto = proto.trim();
let rust = rust.trim();
if proto.is_empty() || rust.is_empty() {
anyhow::bail!(
"invalid extern_path format {value:?}, expected \
extern_path=.proto.pkg=::rust::path (both sides non-empty)"
);
}
let mut proto = proto.to_string();
if !proto.starts_with('.') {
proto.insert(0, '.');
}
options.extern_paths.push((proto, rust.to_string()));
} else {
match opt {
"strict_utf8_mapping" => options.strict_utf8_mapping = true,
"no_json" => options.generate_json = false,
_ => {
return Err(anyhow::anyhow!(
"unknown plugin option: {opt:?}. Supported: \
buffa_module=<rust_path>, extern_path=<proto>=<rust>, \
strict_utf8_mapping, no_json"
));
}
}
}
}
}
let generated = generate_services(&request.proto_file, &request.file_to_generate, &options)?;
let files: Vec<CodeGeneratorResponseFile> = generated
.into_iter()
.map(|g| CodeGeneratorResponseFile {
name: Some(g.name),
content: Some(g.content),
..Default::default()
})
.collect();
Ok(CodeGeneratorResponse {
supported_features: Some(feature_flags()),
minimum_edition: Some(EDITION_2023),
maximum_edition: Some(EDITION_2023),
file: files,
..Default::default()
})
}
/// Bitmask reported in `CodeGeneratorResponse.supported_features`.
///
/// Combines the plugin.proto `Feature` values for proto3 `optional` fields and
/// for editions support.
fn feature_flags() -> u64 {
    const FEATURE_PROTO3_OPTIONAL: u64 = 1;
    const FEATURE_SUPPORTS_EDITIONS: u64 = 2;
    [FEATURE_PROTO3_OPTIONAL, FEATURE_SUPPORTS_EDITIONS]
        .into_iter()
        .fold(0, |mask, feature| mask | feature)
}
/// Numeric value of `Edition.EDITION_2023` in descriptor.proto; used as both the
/// minimum and maximum edition this plugin advertises.
const EDITION_2023: i32 = 1000;
/// Pretty-prints a generated token stream as a complete Rust source file.
///
/// # Errors
/// Fails if `tokens` does not parse as a `syn::File`, which indicates a bug in
/// the generator.
fn format_token_stream(tokens: &TokenStream) -> Result<String> {
    match syn::parse2::<syn::File>(tokens.clone()) {
        Ok(ast) => Ok(prettyplease::unparse(&ast)),
        Err(e) => Err(anyhow::anyhow!("generated code failed to parse: {e}")),
    }
}
fn doc_attrs(text: &str) -> TokenStream {
let lines: Vec<String> = text
.lines()
.map(|l| {
if l.is_empty() {
String::new()
} else {
format!(" {l}")
}
})
.collect();
quote! { #(#[doc = #lines])* }
}
/// Maps fully-qualified protobuf type names (e.g. `.pkg.Msg`) to Rust type
/// paths for use in generated service code.
struct TypeResolver<'a> {
    /// buffa's code-generation context, built with
    /// `CodeGenContext::for_generate` from the descriptor set and config.
    ctx: buffa_codegen::context::CodeGenContext<'a>,
    /// When `true`, every resolved path must be absolute (`::…` or `crate::…`),
    /// and unknown types are errors instead of falling back to a bare name.
    require_extern: bool,
}
impl<'a> TypeResolver<'a> {
    /// Creates a resolver over `proto_file` for the files in `file_to_generate`.
    fn new(
        proto_file: &'a [FileDescriptorProto],
        file_to_generate: &[String],
        config: &'a buffa_codegen::CodeGenConfig,
        require_extern: bool,
    ) -> Self {
        let ctx = buffa_codegen::context::CodeGenContext::for_generate(
            proto_file,
            file_to_generate,
            config,
        );
        Self {
            ctx,
            require_extern,
        }
    }

    /// Resolves `proto_fqn` to a Rust path string, relative to `current_package`
    /// unless an extern mapping makes it absolute.
    ///
    /// Without `require_extern`, a type unknown to the context falls back to its
    /// bare (last-segment) name.
    ///
    /// NOTE(review): the literal `0` passed to `rust_type_relative` presumably
    /// means "no extra module nesting"; confirm against buffa-codegen.
    ///
    /// # Errors
    /// With `require_extern`: the type is unknown, or its path is not absolute.
    fn resolve_path(&self, proto_fqn: &str, current_package: &str) -> Result<String> {
        let Some(path) = self.ctx.rust_type_relative(proto_fqn, current_package, 0) else {
            if self.require_extern {
                anyhow::bail!(
                    "type {proto_fqn} not found in descriptor set (missing proto import?)"
                );
            }
            return Ok(bare_type_name(proto_fqn).to_string());
        };
        let is_absolute = path.starts_with("::") || path.starts_with("crate::");
        if self.require_extern && !is_absolute {
            anyhow::bail!(
                "type {proto_fqn} is not covered by any extern_path mapping. \
                 Add extern_path=.=<your_buffa_module> (e.g. \
                 extern_path=.=crate::proto) to the plugin opts."
            );
        }
        Ok(path)
    }

    /// Tokens for the owned Rust message type of `proto_fqn`.
    fn rust_type(&self, proto_fqn: &str, current_package: &str) -> Result<TokenStream> {
        self.resolve_path(proto_fqn, current_package)
            .map(|path| buffa_codegen::idents::rust_path_to_tokens(&path))
    }

    /// Tokens for the borrowed view type of `proto_fqn`, i.e. the owned path
    /// with a `View` suffix.
    fn rust_view_type(&self, proto_fqn: &str, current_package: &str) -> Result<TokenStream> {
        let view_path = self.resolve_path(proto_fqn, current_package)? + "View";
        Ok(buffa_codegen::idents::rust_path_to_tokens(&view_path))
    }
}
/// Returns the last dot-separated segment of a (possibly dot-prefixed) proto
/// type name, e.g. `.example.v1.PingReq` → `PingReq`.
fn bare_type_name(proto_fqn: &str) -> &str {
    let unprefixed = proto_fqn.strip_prefix('.').unwrap_or(proto_fqn);
    match unprefixed.rfind('.') {
        Some(last_dot) => &unprefixed[last_dot + 1..],
        None => unprefixed,
    }
}
/// Generates the service code for every service declared in `file`, preceded
/// by the `use` items the generated code relies on.
///
/// Services are emitted in declaration order. The first service that fails to
/// generate aborts the whole file.
fn generate_connect_services(
    file: &FileDescriptorProto,
    resolver: &TypeResolver<'_>,
) -> Result<TokenStream> {
    let services = file
        .service
        .iter()
        .map(|service| generate_service(file, service, resolver))
        .collect::<Result<Vec<_>>>()?;
    Ok(quote! {
        use std::future::Future;
        use std::pin::Pin;
        use std::sync::Arc;
        use ::connectrpc::{Context, ConnectError, Router, Dispatcher, view_handler_fn, view_streaming_handler_fn, view_client_streaming_handler_fn, view_bidi_streaming_handler_fn};
        use ::connectrpc::dispatcher::codegen as __crpc_codegen;
        use ::connectrpc::CodecFormat as __CodecFormat;
        use buffa::bytes::Bytes as __Bytes;
        use ::connectrpc::client::{ClientConfig, ClientTransport, CallOptions, call_unary, call_server_stream, call_client_stream, call_bidi_stream};
        use futures::Stream;
        use buffa::Message;
        use buffa::view::OwnedView;
        #(#services)*
    })
}
/// Generates every Rust item for one protobuf service:
///
/// - `pub const <SERVICE>_SERVICE_NAME: &str`, the fully-qualified service name;
/// - the server trait (named after the service) with one method per RPC;
/// - a `<Service>Ext` trait plus a blanket impl whose `register` adds each RPC
///   to a `Router` using the `route_view*` builder matching its streaming shape;
/// - the `<Service>Server<T>` dispatcher (from `generate_service_server`);
/// - the `<Service>Client<T>` struct with two call methods per RPC (from
///   `generate_client_method`).
///
/// `file` provides the package (used for the wire name and for relative type
/// resolution) and the source comments used as documentation.
///
/// # Errors
/// Propagates type-resolution failures from `resolver`.
fn generate_service(
file: &FileDescriptorProto,
service: &ServiceDescriptorProto,
resolver: &TypeResolver<'_>,
) -> Result<TokenStream> {
let package = file.package.as_deref().unwrap_or("");
let service_name = service.name.as_deref().unwrap_or("");
// Wire name `package.Service`; no leading dot when the file has no package.
let full_service_name = if package.is_empty() {
service_name.to_string()
} else {
format!("{package}.{service_name}")
};
let trait_name = format_ident!("{}", service_name.to_upper_camel_case());
let ext_trait_name = format_ident!("{}Ext", service_name.to_upper_camel_case());
let client_name = format_ident!("{}Client", service_name.to_upper_camel_case());
// e.g. `PingService` -> `PING_SERVICE_SERVICE_NAME`.
let service_name_const = format_ident!(
"{}_SERVICE_NAME",
service_name.to_snake_case().to_uppercase()
);
// Trait docs: the proto comment (or a generic fallback), followed by a fixed
// section on implementing handlers.
let service_doc = get_service_comment(file, service).unwrap_or_default();
let base_doc = if service_doc.is_empty() {
format!("Server trait for {service_name}.")
} else {
service_doc
};
let full_doc = format!(
"{base_doc}\n\n\
# Implementing handlers\n\n\
Handlers receive requests as `OwnedView<FooView<'static>>`, which gives\n\
zero-copy borrowed access to fields (e.g. `request.name` is a `&str`\n\
into the decoded buffer). The view can be held across `.await` points.\n\n\
Implement methods with plain `async fn`; the returned future satisfies\n\
the `Send` bound automatically. See the\n\
[buffa user guide](https://github.com/anthropics/buffa/blob/main/docs/guide.md#ownedview-in-async-trait-implementations)\n\
for zero-copy access patterns and when `to_owned_message()` is needed."
);
let service_doc_tokens = doc_attrs(&full_doc);
// One trait method declaration per RPC.
let trait_methods: Vec<TokenStream> = service
.method
.iter()
.map(|m| generate_trait_method(file, service, m, resolver, package))
.collect::<Result<Vec<_>>>()?;
// `register` body: one router builder call per RPC, chosen by streaming shape.
// Each handler closure clones the `Arc<Self>` so the returned future owns it.
let route_registrations: Vec<TokenStream> = service
.method
.iter()
.map(|m| {
let method_name = m.name.as_deref().unwrap_or("");
let method_snake = format_ident!("{}", method_name.to_snake_case());
let client_streaming = m.client_streaming.unwrap_or(false);
let server_streaming = m.server_streaming.unwrap_or(false);
if server_streaming && !client_streaming {
quote! {
.route_view_server_stream(
#service_name_const,
#method_name,
view_streaming_handler_fn({
let svc = Arc::clone(&self);
move |ctx, req| {
let svc = Arc::clone(&svc);
async move { svc.#method_snake(ctx, req).await }
}
}),
)
}
} else if client_streaming && !server_streaming {
quote! {
.route_view_client_stream(
#service_name_const,
#method_name,
view_client_streaming_handler_fn({
let svc = Arc::clone(&self);
move |ctx, req| {
let svc = Arc::clone(&svc);
async move { svc.#method_snake(ctx, req).await }
}
}),
)
}
} else if client_streaming && server_streaming {
quote! {
.route_view_bidi_stream(
#service_name_const,
#method_name,
view_bidi_streaming_handler_fn({
let svc = Arc::clone(&self);
move |ctx, req| {
let svc = Arc::clone(&svc);
async move { svc.#method_snake(ctx, req).await }
}
}),
)
}
} else {
// Unary: RPCs marked `idempotency_level = NO_SIDE_EFFECTS` are
// registered with `route_view_idempotent`, all others with `route_view`.
let is_idempotent = m
.options
.idempotency_level
.map(|level| level == IdempotencyLevel::NO_SIDE_EFFECTS)
.unwrap_or(false);
let route_method = if is_idempotent {
quote! { route_view_idempotent }
} else {
quote! { route_view }
};
quote! {
.#route_method(
#service_name_const,
#method_name,
{
let svc = Arc::clone(&self);
view_handler_fn(move |ctx, req| {
let svc = Arc::clone(&svc);
async move { svc.#method_snake(ctx, req).await }
})
},
)
}
}
})
.collect();
// Two inherent client methods (`foo` and `foo_with_options`) per RPC.
let client_methods: Vec<TokenStream> = service
.method
.iter()
.map(|m| generate_client_method(&full_service_name, m, resolver, package))
.collect::<Result<Vec<_>>>()?;
let service_server =
generate_service_server(&full_service_name, &trait_name, service, resolver, package)?;
// Method used in the client doc examples; "method" when the service has no RPCs.
let example_method = service
.method
.first()
.and_then(|m| m.name.as_deref())
.map(|n| n.to_snake_case())
.unwrap_or_else(|| "method".to_string());
let client_name_str = client_name.to_string();
// Client rustdoc. `{{` / `}}` are escaped braces inside this format string.
let client_doc = format!(
r#"Client for this service.
Generic over `T: ClientTransport`. For **gRPC** (HTTP/2), use
`Http2Connection` — it has honest `poll_ready` and composes with
`tower::balance` for multi-connection load balancing. For **Connect
over HTTP/1.1** (or unknown protocol), use `HttpClient`.
# Example (gRPC / HTTP/2)
```rust,ignore
use connectrpc::client::{{Http2Connection, ClientConfig}};
use connectrpc::Protocol;
let uri: http::Uri = "http://localhost:8080".parse()?;
let conn = Http2Connection::connect_plaintext(uri.clone()).await?.shared(1024);
let config = ClientConfig::new(uri).protocol(Protocol::Grpc);
let client = {client_name_str}::new(conn, config);
let response = client.{example_method}(request).await?;
```
# Example (Connect / HTTP/1.1 or ALPN)
```rust,ignore
use connectrpc::client::{{HttpClient, ClientConfig}};
let http = HttpClient::plaintext(); // cleartext http:// only
let config = ClientConfig::new("http://localhost:8080".parse()?);
let client = {client_name_str}::new(http, config);
let response = client.{example_method}(request).await?;
```
# Working with the response
Unary calls return [`UnaryResponse<OwnedView<FooView>>`](::connectrpc::client::UnaryResponse).
The `OwnedView` derefs to the view, so field access is zero-copy:
```rust,ignore
let resp = client.{example_method}(request).await?.into_view();
let name: &str = resp.name; // borrow into the response buffer
```
If you need the owned struct (e.g. to store or pass by value), use
[`into_owned()`](::connectrpc::client::UnaryResponse::into_owned):
```rust,ignore
let owned = client.{example_method}(request).await?.into_owned();
```"#
);
let client_doc_tokens = doc_attrs(&client_doc);
// Assemble the items: name constant, server trait, Ext trait + blanket impl,
// dispatcher, then the client struct and its methods.
Ok(quote! {
pub const #service_name_const: &str = #full_service_name;
#service_doc_tokens
#[allow(clippy::type_complexity)]
pub trait #trait_name: Send + Sync + 'static {
#(#trait_methods)*
}
pub trait #ext_trait_name: #trait_name {
fn register(self: Arc<Self>, router: Router) -> Router;
}
impl<S: #trait_name> #ext_trait_name for S {
fn register(self: Arc<Self>, router: Router) -> Router {
router
#(#route_registrations)*
}
}
#service_server
#client_doc_tokens
#[derive(Clone)]
pub struct #client_name<T> {
transport: T,
config: ClientConfig,
}
impl<T> #client_name<T>
where
T: ClientTransport,
<T::ResponseBody as http_body::Body>::Error: std::fmt::Display,
{
pub fn new(transport: T, config: ClientConfig) -> Self {
Self { transport, config }
}
pub fn config(&self) -> &ClientConfig {
&self.config
}
pub fn config_mut(&mut self) -> &mut ClientConfig {
&mut self.config
}
#(#client_methods)*
}
})
}
/// Generates `<Trait>Server<T>`: a `Dispatcher` implementation that wraps an
/// `Arc<T>` and routes request paths of the form `"<full_service_name>/<Method>"`
/// to the matching trait method through a `match` on the method name.
///
/// `lookup` reports a `MethodDescriptor`: the streaming shape, plus the
/// idempotency flag for unary RPCs. Each of the four `call_*` methods has arms
/// only for RPCs of its own shape. A path with the wrong prefix, an unknown
/// method, or a method of another shape falls through to
/// `unimplemented_unary` / `unimplemented_streaming`.
///
/// # Errors
/// Propagates failures resolving an RPC's input view type via `resolver`.
fn generate_service_server(
full_service_name: &str,
trait_name: &proc_macro2::Ident,
service: &ServiceDescriptorProto,
resolver: &TypeResolver<'_>,
package: &str,
) -> Result<TokenStream> {
let server_name = format_ident!("{}Server", trait_name);
let path_prefix = format!("{full_service_name}/");
// `lookup` arms: `"Method" => Some(descriptor)` for each RPC.
let lookup_arms: Vec<TokenStream> = service
.method
.iter()
.map(|m| {
let method_name = m.name.as_deref().unwrap_or("");
let client_streaming = m.client_streaming.unwrap_or(false);
let server_streaming = m.server_streaming.unwrap_or(false);
let is_idempotent = m
.options
.idempotency_level
.map(|level| level == IdempotencyLevel::NO_SIDE_EFFECTS)
.unwrap_or(false);
let desc = if client_streaming && server_streaming {
quote! { __crpc_codegen::MethodDescriptor::bidi_streaming() }
} else if client_streaming {
quote! { __crpc_codegen::MethodDescriptor::client_streaming() }
} else if server_streaming {
quote! { __crpc_codegen::MethodDescriptor::server_streaming() }
} else {
quote! { __crpc_codegen::MethodDescriptor::unary(#is_idempotent) }
};
quote! { #method_name => Some(#desc), }
})
.collect();
// Partition the `call_*` match arms by streaming shape. Each arm decodes the
// request (a view, or a stream of views), awaits the trait method and
// encodes the response (a message, or a stream of messages).
let mut call_unary_arms: Vec<TokenStream> = Vec::new();
let mut call_ss_arms: Vec<TokenStream> = Vec::new();
let mut call_cs_arms: Vec<TokenStream> = Vec::new();
let mut call_bidi_arms: Vec<TokenStream> = Vec::new();
for m in &service.method {
let method_name = m.name.as_deref().unwrap_or("");
let method_snake = format_ident!("{}", method_name.to_snake_case());
let input_view = resolver.rust_view_type(m.input_type.as_deref().unwrap_or(""), package)?;
let cs = m.client_streaming.unwrap_or(false);
let ss = m.server_streaming.unwrap_or(false);
if cs && ss {
call_bidi_arms.push(quote! {
#method_name => {
let svc = Arc::clone(&self.inner);
Box::pin(async move {
let req_stream = __crpc_codegen::decode_view_request_stream::<#input_view>(requests, format);
let (resp_stream, ctx) = svc.#method_snake(ctx, req_stream).await?;
Ok((__crpc_codegen::encode_response_stream(resp_stream, format), ctx))
})
}
});
} else if cs {
call_cs_arms.push(quote! {
#method_name => {
let svc = Arc::clone(&self.inner);
Box::pin(async move {
let req_stream = __crpc_codegen::decode_view_request_stream::<#input_view>(requests, format);
let (res, ctx) = svc.#method_snake(ctx, req_stream).await?;
let bytes = __crpc_codegen::encode_response(&res, format)?;
Ok((bytes, ctx))
})
}
});
} else if ss {
call_ss_arms.push(quote! {
#method_name => {
let svc = Arc::clone(&self.inner);
Box::pin(async move {
let req = __crpc_codegen::decode_request_view::<#input_view>(request, format)?;
let (resp_stream, ctx) = svc.#method_snake(ctx, req).await?;
Ok((__crpc_codegen::encode_response_stream(resp_stream, format), ctx))
})
}
});
} else {
call_unary_arms.push(quote! {
#method_name => {
let svc = Arc::clone(&self.inner);
Box::pin(async move {
let req = __crpc_codegen::decode_request_view::<#input_view>(request, format)?;
let (res, ctx) = svc.#method_snake(ctx, req).await?;
let bytes = __crpc_codegen::encode_response(&res, format)?;
Ok((bytes, ctx))
})
}
});
}
}
// Rustdoc for the generated server struct.
let server_doc = format!(
"Monomorphic dispatcher for `{trait_name}`.\n\n\
Unlike `.register(Router)` which type-erases each method into an \
`Arc<dyn ErasedHandler>` stored in a `HashMap`, this struct dispatches \
via a compile-time `match` on method name: no vtable, no hash lookup.\n\n\
# Example\n\n\
```rust,ignore\n\
use connectrpc::ConnectRpcService;\n\n\
let server = {server_name}::new(MyImpl);\n\
let service = ConnectRpcService::new(server);\n\
// hand `service` to axum/hyper as a fallback_service\n\
```"
);
let server_doc_tokens = doc_attrs(&server_doc);
Ok(quote! {
#server_doc_tokens
pub struct #server_name<T> {
inner: Arc<T>,
}
impl<T: #trait_name> #server_name<T> {
pub fn new(service: T) -> Self {
Self { inner: Arc::new(service) }
}
pub fn from_arc(inner: Arc<T>) -> Self {
Self { inner }
}
}
// Manual impl: cloning only bumps the Arc, so no `T: Clone` bound is needed.
impl<T> Clone for #server_name<T> {
fn clone(&self) -> Self {
Self { inner: Arc::clone(&self.inner) }
}
}
impl<T: #trait_name> Dispatcher for #server_name<T> {
#[inline]
fn lookup(&self, path: &str) -> Option<__crpc_codegen::MethodDescriptor> {
let method = path.strip_prefix(#path_prefix)?;
match method {
#(#lookup_arms)*
_ => None,
}
}
fn call_unary(
&self,
path: &str,
ctx: Context,
request: __Bytes,
format: __CodecFormat,
) -> __crpc_codegen::UnaryResult {
let Some(method) = path.strip_prefix(#path_prefix) else {
return __crpc_codegen::unimplemented_unary(path);
};
// Marks the parameters as used when no RPC has this shape and the
// match below has only the fallback arm.
let _ = (&ctx, &request, &format);
match method {
#(#call_unary_arms)*
_ => __crpc_codegen::unimplemented_unary(path),
}
}
fn call_server_streaming(
&self,
path: &str,
ctx: Context,
request: __Bytes,
format: __CodecFormat,
) -> __crpc_codegen::StreamingResult {
let Some(method) = path.strip_prefix(#path_prefix) else {
return __crpc_codegen::unimplemented_streaming(path);
};
let _ = (&ctx, &request, &format);
match method {
#(#call_ss_arms)*
_ => __crpc_codegen::unimplemented_streaming(path),
}
}
fn call_client_streaming(
&self,
path: &str,
ctx: Context,
requests: __crpc_codegen::RequestStream,
format: __CodecFormat,
) -> __crpc_codegen::UnaryResult {
let Some(method) = path.strip_prefix(#path_prefix) else {
return __crpc_codegen::unimplemented_unary(path);
};
let _ = (&ctx, &requests, &format);
match method {
#(#call_cs_arms)*
_ => __crpc_codegen::unimplemented_unary(path),
}
}
fn call_bidi_streaming(
&self,
path: &str,
ctx: Context,
requests: __crpc_codegen::RequestStream,
format: __CodecFormat,
) -> __crpc_codegen::StreamingResult {
let Some(method) = path.strip_prefix(#path_prefix) else {
return __crpc_codegen::unimplemented_streaming(path);
};
let _ = (&ctx, &requests, &format);
match method {
#(#call_bidi_arms)*
_ => __crpc_codegen::unimplemented_streaming(path),
}
}
}
})
}
/// Emits `doc` as doc attributes, or `default` when `doc` is empty.
fn generate_doc_comment(doc: &str, default: &str) -> TokenStream {
    let chosen = match doc {
        "" => default,
        text => text,
    };
    doc_attrs(chosen)
}
/// Generates the server-trait method declaration for one RPC.
///
/// The request parameter is `request: OwnedView<InView<'static>>`, or, for
/// client- and bidi-streaming RPCs, `requests`: a boxed `Send` stream of
/// `Result<OwnedView<…>, ConnectError>`. The response is the owned output type,
/// or, for server- and bidi-streaming RPCs, a boxed `Send` stream of
/// `Result<Out, ConnectError>`. Every method returns a `Send` future that
/// resolves to `(response, Context)`.
///
/// Docs come from the proto comment, falling back to "Handle the <Rpc> RPC.".
///
/// # Errors
/// Propagates type-resolution failures from `resolver`.
fn generate_trait_method(
    file: &FileDescriptorProto,
    service: &ServiceDescriptorProto,
    method: &MethodDescriptorProto,
    resolver: &TypeResolver<'_>,
    package: &str,
) -> Result<TokenStream> {
    let rpc_name = method.name.as_deref().unwrap_or("");
    let fn_ident = format_ident!("{}", rpc_name.to_snake_case());
    let request_view =
        resolver.rust_view_type(method.input_type.as_deref().unwrap_or(""), package)?;
    let response_ty = resolver.rust_type(method.output_type.as_deref().unwrap_or(""), package)?;
    let proto_doc = get_method_comment(file, service, method).unwrap_or_default();
    let docs = generate_doc_comment(&proto_doc, &format!("Handle the {rpc_name} RPC."));

    let request_param = if method.client_streaming.unwrap_or(false) {
        quote! {
            requests: Pin<Box<dyn Stream<Item = Result<OwnedView<#request_view<'static>>, ConnectError>> + Send>>
        }
    } else {
        quote! { request: OwnedView<#request_view<'static>> }
    };
    let response = if method.server_streaming.unwrap_or(false) {
        quote! { Pin<Box<dyn Stream<Item = Result<#response_ty, ConnectError>> + Send>> }
    } else {
        quote! { #response_ty }
    };

    Ok(quote! {
        #docs
        fn #fn_ident(
            &self,
            ctx: Context,
            #request_param,
        ) -> impl Future<Output = Result<(#response, Context), ConnectError>> + Send;
    })
}
/// Generates the two client methods for one RPC: `<rpc>(…)`, which delegates
/// with `CallOptions::default()`, and `<rpc>_with_options(…, options)`, which
/// performs the call through the matching `call_*` helper.
///
/// Shapes by streaming kind:
/// - unary: `request: In` → `UnaryResponse<OwnedView<OutView<'static>>>`
/// - server streaming: `request: In` → `ServerStream<T::ResponseBody, OutView<'static>>`
/// - client streaming: `requests: impl IntoIterator<Item = In>` → `UnaryResponse<…>`
/// - bidi streaming: no message argument → `BidiStream<T::ResponseBody, In, OutView<'static>>`
///
/// # Errors
/// Propagates type-resolution failures from `resolver`.
fn generate_client_method(
    full_service_name: &str,
    method: &MethodDescriptorProto,
    resolver: &TypeResolver<'_>,
    package: &str,
) -> Result<TokenStream> {
    let rpc = method.name.as_deref().unwrap_or("");
    let rpc_snake = rpc.to_snake_case();
    let plain_fn = format_ident!("{}", rpc_snake);
    let options_fn = format_ident!("{}_with_options", rpc_snake);
    let request_ty = resolver.rust_type(method.input_type.as_deref().unwrap_or(""), package)?;
    let response_view =
        resolver.rust_view_type(method.output_type.as_deref().unwrap_or(""), package)?;

    // Leading spaces make prettyplease render `/// Call …` rather than `///Call …`.
    let plain_doc =
        format!(" Call the {rpc} RPC. Sends a request to /{full_service_name}/{rpc}.");
    let options_doc = format!(
        " Call the {rpc} RPC with explicit per-call options. \
         Options override [`ClientConfig`] defaults."
    );

    // Unary and client-streaming calls share the same response wrapper.
    let unary_response = quote! {
        Result<
            ::connectrpc::client::UnaryResponse<OwnedView<#response_view<'static>>>,
            ConnectError,
        >
    };

    let (ret_ty, call_body, plain_args, options_args, delegate_args) = match (
        method.client_streaming.unwrap_or(false),
        method.server_streaming.unwrap_or(false),
    ) {
        (true, false) => (
            unary_response,
            quote! {
                call_client_stream(
                    &self.transport, &self.config,
                    #full_service_name, #rpc,
                    requests, options,
                ).await
            },
            quote! { requests: impl IntoIterator<Item = #request_ty> },
            quote! { requests: impl IntoIterator<Item = #request_ty>, options: CallOptions },
            quote! { requests, CallOptions::default() },
        ),
        (true, true) => (
            quote! {
                Result<
                    ::connectrpc::client::BidiStream<
                        T::ResponseBody, #request_ty, #response_view<'static>
                    >,
                    ConnectError,
                >
            },
            quote! {
                call_bidi_stream(
                    &self.transport, &self.config,
                    #full_service_name, #rpc, options,
                ).await
            },
            quote! {},
            quote! { options: CallOptions },
            quote! { CallOptions::default() },
        ),
        (false, true) => (
            quote! {
                Result<
                    ::connectrpc::client::ServerStream<T::ResponseBody, #response_view<'static>>,
                    ConnectError,
                >
            },
            quote! {
                call_server_stream(
                    &self.transport, &self.config,
                    #full_service_name, #rpc,
                    request, options,
                ).await
            },
            quote! { request: #request_ty },
            quote! { request: #request_ty, options: CallOptions },
            quote! { request, CallOptions::default() },
        ),
        (false, false) => (
            unary_response,
            quote! {
                call_unary(
                    &self.transport, &self.config,
                    #full_service_name, #rpc,
                    request, options,
                ).await
            },
            quote! { request: #request_ty },
            quote! { request: #request_ty, options: CallOptions },
            quote! { request, CallOptions::default() },
        ),
    };

    Ok(quote! {
        #[doc = #plain_doc]
        pub async fn #plain_fn(&self, #plain_args) -> #ret_ty {
            self.#options_fn(#delegate_args).await
        }
        #[doc = #options_doc]
        pub async fn #options_fn(&self, #options_args) -> #ret_ty {
            #call_body
        }
    })
}
/// Returns the proto comment attached to `service`, located by its position in
/// `file.service` (the first service with the same name).
fn get_service_comment(
    file: &FileDescriptorProto,
    service: &ServiceDescriptorProto,
) -> Option<String> {
    // Field number of `FileDescriptorProto.service` in descriptor.proto.
    const FILE_SERVICE_FIELD: i32 = 6;
    let service_index = file.service.iter().position(|s| s.name == service.name)?;
    let source_info: &SourceCodeInfo = &file.source_code_info;
    find_comment(source_info, &[FILE_SERVICE_FIELD, service_index as i32])
}
/// Returns the proto comment attached to `method` within `service`.
///
/// Services and methods are matched by name. The first service named like
/// `service` that contains a method named like `method` determines the
/// `SourceCodeInfo` path.
fn get_method_comment(
    file: &FileDescriptorProto,
    service: &ServiceDescriptorProto,
    method: &MethodDescriptorProto,
) -> Option<String> {
    // Field numbers in descriptor.proto: `FileDescriptorProto.service` and
    // `ServiceDescriptorProto.method`.
    const FILE_SERVICE_FIELD: i32 = 6;
    const SERVICE_METHOD_FIELD: i32 = 2;
    let source_info: &SourceCodeInfo = &file.source_code_info;
    for (service_index, candidate) in file.service.iter().enumerate() {
        if candidate.name != service.name {
            continue;
        }
        if let Some(method_index) = candidate.method.iter().position(|m| m.name == method.name) {
            return find_comment(
                source_info,
                &[
                    FILE_SERVICE_FIELD,
                    service_index as i32,
                    SERVICE_METHOD_FIELD,
                    method_index as i32,
                ],
            );
        }
    }
    None
}
/// Returns the cleaned comment attached to the descriptor element at `target_path`.
///
/// `target_path` is a `SourceCodeInfo.Location.path` (alternating field numbers
/// and indices, e.g. `[6, i]` for the i-th service of a file). Every location
/// with exactly that path is considered. For each one the leading comment is
/// tried first, then the trailing comment, and the first comment that is
/// non-blank after cleaning wins.
///
/// Cleaning does the following:
/// - trims every line;
/// - drops blank lines at the start and end;
/// - collapses each run of interior blank lines into a single empty line.
///
/// Keeping paragraph breaks matters because rustdoc uses the first paragraph as
/// the item summary.
fn find_comment(source_info: &SourceCodeInfo, target_path: &[i32]) -> Option<String> {
    /// Normalizes raw comment text; `None` if nothing but whitespace remains.
    fn clean(raw: &str) -> Option<String> {
        let mut cleaned = String::new();
        let mut pending_break = false;
        for line in raw.lines().map(str::trim) {
            if line.is_empty() {
                // Only a blank line *between* content lines becomes a break.
                pending_break = !cleaned.is_empty();
                continue;
            }
            if !cleaned.is_empty() {
                cleaned.push_str(if pending_break { "\n\n" } else { "\n" });
            }
            cleaned.push_str(line);
            pending_break = false;
        }
        (!cleaned.is_empty()).then_some(cleaned)
    }

    source_info
        .location
        .iter()
        .filter(|location| location.path == target_path)
        .flat_map(|location| {
            [
                location.leading_comments.as_deref(),
                location.trailing_comments.as_deref(),
            ]
        })
        .flatten()
        .find_map(clean)
}
#[cfg(test)]
mod tests {
    use super::*;
    use buffa_codegen::generated::descriptor::DescriptorProto;

    #[test]
    fn doc_attrs_prefixes_space_for_prettyplease() {
        let ts = quote! {
            #[allow(dead_code)]
            mod m {}
        };
        let doc = doc_attrs("Hello.\n\nSecond paragraph.");
        let combined = quote! { #doc #ts };
        let file = syn::parse2::<syn::File>(combined).unwrap();
        let out = prettyplease::unparse(&file);
        assert!(out.contains("/// Hello."), "got: {out}");
        assert!(out.contains("/// Second paragraph."), "got: {out}");
        // The empty middle line must render as a bare `///` paragraph break.
        assert!(out.contains("///\n"), "got: {out}");
        // Exactly one space after `///`: neither missing nor doubled.
        assert!(!out.contains("///Hello"), "got: {out}");
        assert!(!out.contains("///  Hello"), "got: {out}");
    }

    /// A single-file descriptor with one `PingService.Ping` RPC and the given
    /// top-level messages.
    fn minimal_file(
        package: Option<&str>,
        input_type: &str,
        output_type: &str,
        local_messages: &[&str],
    ) -> FileDescriptorProto {
        let method = MethodDescriptorProto {
            name: Some("Ping".into()),
            input_type: Some(input_type.into()),
            output_type: Some(output_type.into()),
            ..Default::default()
        };
        let service = ServiceDescriptorProto {
            name: Some("PingService".into()),
            method: vec![method],
            ..Default::default()
        };
        FileDescriptorProto {
            name: Some("ping.proto".into()),
            package: package.map(|p| p.into()),
            service: vec![service],
            message_type: local_messages
                .iter()
                .map(|name| DescriptorProto {
                    name: Some((*name).into()),
                    ..Default::default()
                })
                .collect(),
            ..Default::default()
        }
    }

    /// Generates the first service of `files[target_idx]` and returns the raw
    /// token-stream string.
    fn gen_service(
        files: &[FileDescriptorProto],
        target_idx: usize,
        extern_paths: &[(String, String)],
        require_extern: bool,
    ) -> Result<String> {
        let mut config = buffa_codegen::CodeGenConfig::default();
        config.extern_paths = extern_paths.to_vec();
        let target_name = files[target_idx]
            .name
            .clone()
            .into_iter()
            .collect::<Vec<_>>();
        let resolver = TypeResolver::new(files, &target_name, &config, require_extern);
        let file = &files[target_idx];
        let service = &file.service[0];
        Ok(generate_service(file, service, &resolver)?.to_string())
    }

    #[test]
    fn service_name_with_package() {
        let file = minimal_file(
            Some("example.v1"),
            ".example.v1.PingReq",
            ".example.v1.PingResp",
            &["PingReq", "PingResp"],
        );
        let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
        assert!(code.contains("\"example.v1.PingService\""), "got: {code}");
    }

    #[test]
    fn service_name_without_package() {
        let file = minimal_file(None, ".PingReq", ".PingResp", &["PingReq", "PingResp"]);
        let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
        assert!(code.contains("\"PingService\""), "got: {code}");
        assert!(
            !code.contains("\".PingService\""),
            "must not have leading dot: {code}"
        );
    }

    #[test]
    fn same_package_types_use_bare_names() {
        let file = minimal_file(
            Some("example.v1"),
            ".example.v1.PingReq",
            ".example.v1.PingResp",
            &["PingReq", "PingResp"],
        );
        let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
        assert!(code.contains("PingReq"), "input type missing: {code}");
        assert!(code.contains("PingResp"), "output type missing: {code}");
        assert!(
            !code.contains("super :: PingReq"),
            "unexpected super: {code}"
        );
    }

    #[test]
    fn cross_package_types_use_relative_paths() {
        let common = FileDescriptorProto {
            name: Some("common.proto".into()),
            package: Some("common.v1".into()),
            message_type: vec![DescriptorProto {
                name: Some("Shared".into()),
                ..Default::default()
            }],
            ..Default::default()
        };
        let svc = minimal_file(
            Some("example.v1"),
            ".common.v1.Shared",
            ".example.v1.Out",
            &["Out"],
        );
        let code = gen_service(&[common, svc], 1, &[], false).unwrap();
        assert!(
            code.contains("super :: super :: common :: v1 :: Shared"),
            "cross-package path not emitted: {code}"
        );
        assert!(
            code.contains("super :: super :: common :: v1 :: SharedView"),
            "cross-package view path not emitted: {code}"
        );
    }

    #[test]
    fn wkt_types_use_buffa_types_extern_path() {
        let wkt = FileDescriptorProto {
            name: Some("google/protobuf/empty.proto".into()),
            package: Some("google.protobuf".into()),
            message_type: vec![DescriptorProto {
                name: Some("Empty".into()),
                ..Default::default()
            }],
            ..Default::default()
        };
        let svc = minimal_file(
            Some("example.v1"),
            ".google.protobuf.Empty",
            ".example.v1.Out",
            &["Out"],
        );
        let code = gen_service(&[wkt, svc], 1, &[], false).unwrap();
        assert!(
            code.contains(":: buffa_types :: google :: protobuf :: Empty"),
            "WKT extern path not emitted: {code}"
        );
    }

    #[test]
    fn extern_catchall_uses_absolute_paths() {
        let file = minimal_file(
            Some("example.v1"),
            ".example.v1.PingReq",
            ".example.v1.PingResp",
            &["PingReq", "PingResp"],
        );
        let extern_paths = [(".".into(), "crate::proto".into())];
        let code = gen_service(std::slice::from_ref(&file), 0, &extern_paths, true).unwrap();
        assert!(
            code.contains("crate :: proto :: example :: v1 :: PingReq"),
            "owned type path missing: {code}"
        );
        assert!(
            code.contains("crate :: proto :: example :: v1 :: PingReqView"),
            "view type path missing: {code}"
        );
    }

    #[test]
    fn extern_catchall_with_wkt_longest_wins() {
        let wkt = FileDescriptorProto {
            name: Some("google/protobuf/empty.proto".into()),
            package: Some("google.protobuf".into()),
            message_type: vec![DescriptorProto {
                name: Some("Empty".into()),
                ..Default::default()
            }],
            ..Default::default()
        };
        let svc = minimal_file(
            Some("example.v1"),
            ".google.protobuf.Empty",
            ".example.v1.Out",
            &["Out"],
        );
        let extern_paths = [(".".into(), "crate::proto".into())];
        let code = gen_service(&[wkt, svc], 1, &extern_paths, true).unwrap();
        assert!(
            code.contains(":: buffa_types :: google :: protobuf :: Empty"),
            "WKT mapping lost to catch-all: {code}"
        );
        assert!(
            code.contains("crate :: proto :: example :: v1 :: Out"),
            "local type not routed through catch-all: {code}"
        );
    }

    #[test]
    fn missing_extern_path_errors() {
        let file = minimal_file(
            Some("example.v1"),
            ".example.v1.PingReq",
            ".example.v1.PingResp",
            &["PingReq", "PingResp"],
        );
        let err = gen_service(std::slice::from_ref(&file), 0, &[], true).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("extern_path"),
            "error message lacks hint: {msg}"
        );
    }

    #[test]
    fn keyword_package_escaped() {
        let file = minimal_file(
            Some("google.type"),
            ".google.type.LatLng",
            ".google.type.LatLng",
            &["LatLng"],
        );
        let extern_paths = [(".".into(), "crate::proto".into())];
        let code = gen_service(std::slice::from_ref(&file), 0, &extern_paths, true).unwrap();
        assert!(
            code.contains("crate :: proto :: google :: r#type :: LatLng"),
            "keyword segment not escaped: {code}"
        );
    }
}