1use anyhow::Result;
12use heck::ToSnakeCase;
13use heck::ToUpperCamelCase;
14use proc_macro2::TokenStream;
15use quote::format_ident;
16use quote::quote;
17
18use buffa_codegen::generated::descriptor::FileDescriptorProto;
19use buffa_codegen::generated::descriptor::MethodDescriptorProto;
20use buffa_codegen::generated::descriptor::ServiceDescriptorProto;
21use buffa_codegen::generated::descriptor::SourceCodeInfo;
22use buffa_codegen::generated::descriptor::method_options::IdempotencyLevel;
23
24pub use buffa_codegen::GeneratedFile;
25pub use buffa_codegen::generated::descriptor;
26
27use crate::plugin::CodeGeneratorRequest;
28use crate::plugin::CodeGeneratorResponse;
29use crate::plugin::CodeGeneratorResponseFile;
30
31#[derive(Debug, Clone)]
38#[non_exhaustive]
39pub struct Options {
40 pub strict_utf8_mapping: bool,
44 pub generate_json: bool,
49 pub extern_paths: Vec<(String, String)>,
60}
61
62impl Default for Options {
63 fn default() -> Self {
64 Self {
65 strict_utf8_mapping: false,
66 generate_json: true,
67 extern_paths: Vec::new(),
68 }
69 }
70}
71
72impl Options {
73 fn to_buffa_config(&self) -> buffa_codegen::CodeGenConfig {
74 let mut config = buffa_codegen::CodeGenConfig::default();
75 config.generate_views = true;
76 config.generate_json = self.generate_json;
77 config.strict_utf8_mapping = self.strict_utf8_mapping;
78 config.extern_paths.clone_from(&self.extern_paths);
79 config
80 }
81}
82
83fn emit_service_files(
86 proto_file: &[FileDescriptorProto],
87 file_to_generate: &[String],
88 resolver: &TypeResolver<'_>,
89) -> Result<Vec<GeneratedFile>> {
90 let mut out = Vec::new();
91 for file_name in file_to_generate {
92 let file_desc = proto_file
93 .iter()
94 .find(|f| f.name.as_deref() == Some(file_name.as_str()));
95
96 if let Some(file) = file_desc
97 && !file.service.is_empty()
98 {
99 let service_tokens = generate_connect_services(file, resolver)?;
100 let service_code = format_token_stream(&service_tokens)?;
101 out.push(GeneratedFile {
102 name: buffa_codegen::proto_path_to_rust_module(file_name),
103 content: service_code,
104 });
105 }
106 }
107 Ok(out)
108}
109
110pub fn generate_files(
127 proto_file: &[FileDescriptorProto],
128 file_to_generate: &[String],
129 options: &Options,
130) -> Result<Vec<GeneratedFile>> {
131 let config = options.to_buffa_config();
132
133 let mut files = buffa_codegen::generate(proto_file, file_to_generate, &config)
134 .map_err(|e| anyhow::anyhow!("buffa-codegen failed: {e}"))?;
135
136 let resolver = TypeResolver::new(proto_file, file_to_generate, &config, false);
137 let service_files = emit_service_files(proto_file, file_to_generate, &resolver)?;
138
139 for svc in service_files {
141 if let Some(out) = files.iter_mut().find(|g| g.name == svc.name) {
142 out.content.push('\n');
143 out.content.push_str(&svc.content);
144 }
145 }
146
147 Ok(files)
148}
149
150pub fn generate_services(
167 proto_file: &[FileDescriptorProto],
168 file_to_generate: &[String],
169 options: &Options,
170) -> Result<Vec<GeneratedFile>> {
171 let config = options.to_buffa_config();
172 let resolver = TypeResolver::new(proto_file, file_to_generate, &config, true);
173 emit_service_files(proto_file, file_to_generate, &resolver)
174}
175
176pub fn generate(request: &CodeGeneratorRequest) -> Result<CodeGeneratorResponse> {
197 let mut options = Options::default();
198
199 if let Some(ref param) = request.parameter {
200 for opt in param.split(',').map(str::trim).filter(|s| !s.is_empty()) {
201 if let Some(value) = opt.strip_prefix("buffa_module=") {
202 let rust = value.trim();
203 if rust.is_empty() {
204 anyhow::bail!(
205 "buffa_module requires a non-empty path, \
206 e.g. buffa_module=crate::proto"
207 );
208 }
209 options.extern_paths.push((".".into(), rust.to_string()));
210 } else if let Some(value) = opt.strip_prefix("extern_path=") {
211 let (proto, rust) = value.split_once('=').ok_or_else(|| {
213 anyhow::anyhow!(
214 "invalid extern_path format {value:?}, expected \
215 extern_path=.proto.pkg=::rust::path"
216 )
217 })?;
218 let proto = proto.trim();
219 let rust = rust.trim();
220 if proto.is_empty() || rust.is_empty() {
221 anyhow::bail!(
222 "invalid extern_path format {value:?}, expected \
223 extern_path=.proto.pkg=::rust::path (both sides non-empty)"
224 );
225 }
226 let mut proto = proto.to_string();
227 if !proto.starts_with('.') {
228 proto.insert(0, '.');
229 }
230 options.extern_paths.push((proto, rust.to_string()));
231 } else {
232 match opt {
233 "strict_utf8_mapping" => options.strict_utf8_mapping = true,
234 "no_json" => options.generate_json = false,
235 _ => {
236 return Err(anyhow::anyhow!(
237 "unknown plugin option: {opt:?}. Supported: \
238 buffa_module=<rust_path>, extern_path=<proto>=<rust>, \
239 strict_utf8_mapping, no_json"
240 ));
241 }
242 }
243 }
244 }
245 }
246
247 let generated = generate_services(&request.proto_file, &request.file_to_generate, &options)?;
248
249 let files: Vec<CodeGeneratorResponseFile> = generated
250 .into_iter()
251 .map(|g| CodeGeneratorResponseFile {
252 name: Some(g.name),
253 content: Some(g.content),
254 ..Default::default()
255 })
256 .collect();
257
258 Ok(CodeGeneratorResponse {
259 supported_features: Some(feature_flags()),
260 minimum_edition: Some(EDITION_2023),
261 maximum_edition: Some(EDITION_2023),
262 file: files,
263 ..Default::default()
264 })
265}
266
/// Capability bits reported in `CodeGeneratorResponse.supported_features`.
fn feature_flags() -> u64 {
    // Values of `CodeGeneratorResponse.Feature` in
    // google/protobuf/compiler/plugin.proto.
    const FEATURE_PROTO3_OPTIONAL: u64 = 1;
    const FEATURE_SUPPORTS_EDITIONS: u64 = 2;
    [FEATURE_PROTO3_OPTIONAL, FEATURE_SUPPORTS_EDITIONS]
        .into_iter()
        .fold(0, |acc, bit| acc | bit)
}

/// `Edition.EDITION_2023` from descriptor.proto. `generate` reports it as both
/// the minimum and the maximum edition this plugin accepts.
const EDITION_2023: i32 = 1000;
278
279fn format_token_stream(tokens: &TokenStream) -> Result<String> {
281 let file = syn::parse2::<syn::File>(tokens.clone())
282 .map_err(|e| anyhow::anyhow!("generated code failed to parse: {e}"))?;
283 Ok(prettyplease::unparse(&file))
284}
285
286fn doc_attrs(text: &str) -> TokenStream {
295 let lines: Vec<String> = text
296 .lines()
297 .map(|l| {
298 if l.is_empty() {
299 String::new()
300 } else {
301 format!(" {l}")
302 }
303 })
304 .collect();
305 quote! { #(#[doc = #lines])* }
306}
307
308struct TypeResolver<'a> {
321 ctx: buffa_codegen::context::CodeGenContext<'a>,
322 require_extern: bool,
328}
329
330impl<'a> TypeResolver<'a> {
331 fn new(
332 proto_file: &'a [FileDescriptorProto],
333 file_to_generate: &[String],
334 config: &'a buffa_codegen::CodeGenConfig,
335 require_extern: bool,
336 ) -> Self {
337 Self {
338 ctx: buffa_codegen::context::CodeGenContext::for_generate(
339 proto_file,
340 file_to_generate,
341 config,
342 ),
343 require_extern,
344 }
345 }
346
347 fn resolve_path(&self, proto_fqn: &str, current_package: &str) -> Result<String> {
354 match self.ctx.rust_type_relative(proto_fqn, current_package, 0) {
355 Some(path) => {
356 if self.require_extern && !path.starts_with("::") && !path.starts_with("crate::") {
357 anyhow::bail!(
358 "type {proto_fqn} is not covered by any extern_path mapping. \
359 Add extern_path=.=<your_buffa_module> (e.g. \
360 extern_path=.=crate::proto) to the plugin opts."
361 );
362 }
363 Ok(path)
364 }
365 None if self.require_extern => anyhow::bail!(
366 "type {proto_fqn} not found in descriptor set (missing proto import?)"
367 ),
368 None => Ok(bare_type_name(proto_fqn).to_string()),
369 }
370 }
371
372 fn rust_type(&self, proto_fqn: &str, current_package: &str) -> Result<TokenStream> {
374 let path = self.resolve_path(proto_fqn, current_package)?;
375 Ok(buffa_codegen::idents::rust_path_to_tokens(&path))
376 }
377
378 fn rust_view_type(&self, proto_fqn: &str, current_package: &str) -> Result<TokenStream> {
381 let path = self.resolve_path(proto_fqn, current_package)?;
382 Ok(buffa_codegen::idents::rust_path_to_tokens(&format!(
383 "{path}View"
384 )))
385 }
386}
387
/// Last dot-separated segment of a protobuf type name, ignoring an optional
/// leading `.`. For example, `.pkg.v1.Msg` gives `Msg`.
fn bare_type_name(proto_fqn: &str) -> &str {
    let relative = proto_fqn.strip_prefix('.').unwrap_or(proto_fqn);
    match relative.rsplit_once('.') {
        Some((_, last)) => last,
        None => relative,
    }
}
398
399fn generate_connect_services(
405 file: &FileDescriptorProto,
406 resolver: &TypeResolver<'_>,
407) -> Result<TokenStream> {
408 let mut tokens = TokenStream::new();
409
410 let imports = quote! {
414 use std::future::Future;
415 use std::pin::Pin;
416 use std::sync::Arc;
417
418 use ::connectrpc::{Context, ConnectError, Router, Dispatcher, view_handler_fn, view_streaming_handler_fn, view_client_streaming_handler_fn, view_bidi_streaming_handler_fn};
419 use ::connectrpc::dispatcher::codegen as __crpc_codegen;
420 use ::connectrpc::CodecFormat as __CodecFormat;
421 use buffa::bytes::Bytes as __Bytes;
422 use ::connectrpc::client::{ClientConfig, ClientTransport, CallOptions, call_unary, call_server_stream, call_client_stream, call_bidi_stream};
423 use futures::Stream;
424 use buffa::Message;
425 use buffa::view::OwnedView;
426 };
427 tokens.extend(imports);
428
429 for service in &file.service {
430 tokens.extend(generate_service(file, service, resolver)?);
431 }
432
433 Ok(tokens)
434}
435
/// Generate all code for one protobuf service in `file`.
///
/// Emits, in order:
/// - `<SERVICE>_SERVICE_NAME`: the fully-qualified service name (`pkg.Service`,
///   or just `Service` when the file has no package);
/// - the server trait `<Service>`, with one method per RPC;
/// - the `<Service>Ext` trait, blanket-implemented for every `<Service>`,
///   whose `register` adds each RPC to a `Router`;
/// - the `<Service>Server` dispatcher produced by `generate_service_server`;
/// - the `<Service>Client<T>` client, with one method pair per RPC.
///
/// # Errors
/// Propagates type-resolution errors from `resolver` for any method's input
/// or output type.
fn generate_service(
    file: &FileDescriptorProto,
    service: &ServiceDescriptorProto,
    resolver: &TypeResolver<'_>,
) -> Result<TokenStream> {
    let package = file.package.as_deref().unwrap_or("");
    let service_name = service.name.as_deref().unwrap_or("");
    // Fully-qualified wire name. There is no leading dot when the file has no
    // package.
    let full_service_name = if package.is_empty() {
        service_name.to_string()
    } else {
        format!("{package}.{service_name}")
    };
    let trait_name = format_ident!("{}", service_name.to_upper_camel_case());
    let ext_trait_name = format_ident!("{}Ext", service_name.to_upper_camel_case());
    let client_name = format_ident!("{}Client", service_name.to_upper_camel_case());
    let service_name_const = format_ident!(
        "{}_SERVICE_NAME",
        service_name.to_snake_case().to_uppercase()
    );

    // Trait docs: the proto comment, or a generic fallback, followed by fixed
    // guidance on implementing view-based handlers.
    let service_doc = get_service_comment(file, service).unwrap_or_default();
    let base_doc = if service_doc.is_empty() {
        format!("Server trait for {service_name}.")
    } else {
        service_doc
    };
    let full_doc = format!(
        "{base_doc}\n\n\
         # Implementing handlers\n\n\
         Handlers receive requests as `OwnedView<FooView<'static>>`, which gives\n\
         zero-copy borrowed access to fields (e.g. `request.name` is a `&str`\n\
         into the decoded buffer). The view can be held across `.await` points.\n\n\
         Implement methods with plain `async fn`; the returned future satisfies\n\
         the `Send` bound automatically. See the\n\
         [buffa user guide](https://github.com/anthropics/buffa/blob/main/docs/guide.md#ownedview-in-async-trait-implementations)\n\
         for zero-copy access patterns and when `to_owned_message()` is needed."
    );
    let service_doc_tokens = doc_attrs(&full_doc);

    let trait_methods: Vec<TokenStream> = service
        .method
        .iter()
        .map(|m| generate_trait_method(file, service, m, resolver, package))
        .collect::<Result<Vec<_>>>()?;

    // One `Router` builder call per RPC, chosen by streaming shape. Unary RPCs
    // marked NO_SIDE_EFFECTS use `route_view_idempotent`. Each handler clones
    // the `Arc<Self>` so the future it returns owns a reference to the impl.
    let route_registrations: Vec<TokenStream> = service
        .method
        .iter()
        .map(|m| {
            let method_name = m.name.as_deref().unwrap_or("");
            let method_snake = format_ident!("{}", method_name.to_snake_case());

            let client_streaming = m.client_streaming.unwrap_or(false);
            let server_streaming = m.server_streaming.unwrap_or(false);

            if server_streaming && !client_streaming {
                quote! {
                    .route_view_server_stream(
                        #service_name_const,
                        #method_name,
                        view_streaming_handler_fn({
                            let svc = Arc::clone(&self);
                            move |ctx, req| {
                                let svc = Arc::clone(&svc);
                                async move { svc.#method_snake(ctx, req).await }
                            }
                        }),
                    )
                }
            } else if client_streaming && !server_streaming {
                quote! {
                    .route_view_client_stream(
                        #service_name_const,
                        #method_name,
                        view_client_streaming_handler_fn({
                            let svc = Arc::clone(&self);
                            move |ctx, req| {
                                let svc = Arc::clone(&svc);
                                async move { svc.#method_snake(ctx, req).await }
                            }
                        }),
                    )
                }
            } else if client_streaming && server_streaming {
                quote! {
                    .route_view_bidi_stream(
                        #service_name_const,
                        #method_name,
                        view_bidi_streaming_handler_fn({
                            let svc = Arc::clone(&self);
                            move |ctx, req| {
                                let svc = Arc::clone(&svc);
                                async move { svc.#method_snake(ctx, req).await }
                            }
                        }),
                    )
                }
            } else {
                let is_idempotent = m
                    .options
                    .idempotency_level
                    .map(|level| level == IdempotencyLevel::NO_SIDE_EFFECTS)
                    .unwrap_or(false);

                let route_method = if is_idempotent {
                    quote! { route_view_idempotent }
                } else {
                    quote! { route_view }
                };

                quote! {
                    .#route_method(
                        #service_name_const,
                        #method_name,
                        {
                            let svc = Arc::clone(&self);
                            view_handler_fn(move |ctx, req| {
                                let svc = Arc::clone(&svc);
                                async move { svc.#method_snake(ctx, req).await }
                            })
                        },
                    )
                }
            }
        })
        .collect();

    let client_methods: Vec<TokenStream> = service
        .method
        .iter()
        .map(|m| generate_client_method(&full_service_name, m, resolver, package))
        .collect::<Result<Vec<_>>>()?;

    let service_server =
        generate_service_server(&full_service_name, &trait_name, service, resolver, package)?;

    // Snake-case name of the first RPC, used only in the client doc examples.
    let example_method = service
        .method
        .first()
        .and_then(|m| m.name.as_deref())
        .map(|n| n.to_snake_case())
        .unwrap_or_else(|| "method".to_string());

    let client_name_str = client_name.to_string();
    // Raw string: each line must start at column 0 so the rendered rustdoc
    // (and its code fences) is not indented. `{{`/`}}` escape literal braces.
    let client_doc = format!(
        r#"Client for this service.

Generic over `T: ClientTransport`. For **gRPC** (HTTP/2), use
`Http2Connection` — it has honest `poll_ready` and composes with
`tower::balance` for multi-connection load balancing. For **Connect
over HTTP/1.1** (or unknown protocol), use `HttpClient`.

# Example (gRPC / HTTP/2)

```rust,ignore
use connectrpc::client::{{Http2Connection, ClientConfig}};
use connectrpc::Protocol;

let uri: http::Uri = "http://localhost:8080".parse()?;
let conn = Http2Connection::connect_plaintext(uri.clone()).await?.shared(1024);
let config = ClientConfig::new(uri).protocol(Protocol::Grpc);

let client = {client_name_str}::new(conn, config);
let response = client.{example_method}(request).await?;
```

# Example (Connect / HTTP/1.1 or ALPN)

```rust,ignore
use connectrpc::client::{{HttpClient, ClientConfig}};

let http = HttpClient::plaintext(); // cleartext http:// only
let config = ClientConfig::new("http://localhost:8080".parse()?);

let client = {client_name_str}::new(http, config);
let response = client.{example_method}(request).await?;
```

# Working with the response

Unary calls return [`UnaryResponse<OwnedView<FooView>>`](::connectrpc::client::UnaryResponse).
The `OwnedView` derefs to the view, so field access is zero-copy:

```rust,ignore
let resp = client.{example_method}(request).await?.into_view();
let name: &str = resp.name; // borrow into the response buffer
```

If you need the owned struct (e.g. to store or pass by value), use
[`into_owned()`](::connectrpc::client::UnaryResponse::into_owned):

```rust,ignore
let owned = client.{example_method}(request).await?.into_owned();
```"#
    );
    let client_doc_tokens = doc_attrs(&client_doc);

    // Assemble the service's public surface: name constant, server trait,
    // router extension trait with its blanket impl, dispatcher, and client.
    Ok(quote! {
        pub const #service_name_const: &str = #full_service_name;

        #service_doc_tokens
        #[allow(clippy::type_complexity)]
        pub trait #trait_name: Send + Sync + 'static {
            #(#trait_methods)*
        }

        pub trait #ext_trait_name: #trait_name {
            fn register(self: Arc<Self>, router: Router) -> Router;
        }

        impl<S: #trait_name> #ext_trait_name for S {
            fn register(self: Arc<Self>, router: Router) -> Router {
                router
                    #(#route_registrations)*
            }
        }

        #service_server

        #client_doc_tokens
        #[derive(Clone)]
        pub struct #client_name<T> {
            transport: T,
            config: ClientConfig,
        }

        impl<T> #client_name<T>
        where
            T: ClientTransport,
            <T::ResponseBody as http_body::Body>::Error: std::fmt::Display,
        {
            pub fn new(transport: T, config: ClientConfig) -> Self {
                Self { transport, config }
            }

            pub fn config(&self) -> &ClientConfig {
                &self.config
            }

            pub fn config_mut(&mut self) -> &mut ClientConfig {
                &mut self.config
            }

            #(#client_methods)*
        }
    })
}
721
/// Generate `<Service>Server<T>`, a `Dispatcher` that routes requests to a
/// `<Service>` trait implementation through a `match` on the method name.
///
/// Dispatch paths are expected to look like `{full_service_name}/{Method}`.
/// A path without that prefix, an unknown method, or a method called through
/// the entry point for a different streaming shape falls through to the
/// `unimplemented_*` helpers.
///
/// # Errors
/// Propagates type-resolution errors for any method's input view type.
fn generate_service_server(
    full_service_name: &str,
    trait_name: &proc_macro2::Ident,
    service: &ServiceDescriptorProto,
    resolver: &TypeResolver<'_>,
    package: &str,
) -> Result<TokenStream> {
    let server_name = format_ident!("{}Server", trait_name);
    // Stripped from dispatch paths to get the bare method name matched below.
    let path_prefix = format!("{full_service_name}/");

    // `lookup` arms map a method name to a descriptor of its streaming shape.
    // Unary descriptors also record whether the RPC is marked NO_SIDE_EFFECTS.
    let lookup_arms: Vec<TokenStream> = service
        .method
        .iter()
        .map(|m| {
            let method_name = m.name.as_deref().unwrap_or("");
            let client_streaming = m.client_streaming.unwrap_or(false);
            let server_streaming = m.server_streaming.unwrap_or(false);
            let is_idempotent = m
                .options
                .idempotency_level
                .map(|level| level == IdempotencyLevel::NO_SIDE_EFFECTS)
                .unwrap_or(false);

            let desc = if client_streaming && server_streaming {
                quote! { __crpc_codegen::MethodDescriptor::bidi_streaming() }
            } else if client_streaming {
                quote! { __crpc_codegen::MethodDescriptor::client_streaming() }
            } else if server_streaming {
                quote! { __crpc_codegen::MethodDescriptor::server_streaming() }
            } else {
                quote! { __crpc_codegen::MethodDescriptor::unary(#is_idempotent) }
            };
            quote! { #method_name => Some(#desc), }
        })
        .collect();

    // Match arms for the four dispatcher entry points. Each RPC goes into
    // exactly one list according to its streaming shape. An arm decodes the
    // request into views, calls the trait method on a cloned `Arc` of the
    // impl, and encodes the response.
    let mut call_unary_arms: Vec<TokenStream> = Vec::new();
    let mut call_ss_arms: Vec<TokenStream> = Vec::new();
    let mut call_cs_arms: Vec<TokenStream> = Vec::new();
    let mut call_bidi_arms: Vec<TokenStream> = Vec::new();

    for m in &service.method {
        let method_name = m.name.as_deref().unwrap_or("");
        let method_snake = format_ident!("{}", method_name.to_snake_case());
        let input_view = resolver.rust_view_type(m.input_type.as_deref().unwrap_or(""), package)?;
        let cs = m.client_streaming.unwrap_or(false);
        let ss = m.server_streaming.unwrap_or(false);

        if cs && ss {
            call_bidi_arms.push(quote! {
                #method_name => {
                    let svc = Arc::clone(&self.inner);
                    Box::pin(async move {
                        let req_stream = __crpc_codegen::decode_view_request_stream::<#input_view>(requests, format);
                        let (resp_stream, ctx) = svc.#method_snake(ctx, req_stream).await?;
                        Ok((__crpc_codegen::encode_response_stream(resp_stream, format), ctx))
                    })
                }
            });
        } else if cs {
            call_cs_arms.push(quote! {
                #method_name => {
                    let svc = Arc::clone(&self.inner);
                    Box::pin(async move {
                        let req_stream = __crpc_codegen::decode_view_request_stream::<#input_view>(requests, format);
                        let (res, ctx) = svc.#method_snake(ctx, req_stream).await?;
                        let bytes = __crpc_codegen::encode_response(&res, format)?;
                        Ok((bytes, ctx))
                    })
                }
            });
        } else if ss {
            call_ss_arms.push(quote! {
                #method_name => {
                    let svc = Arc::clone(&self.inner);
                    Box::pin(async move {
                        let req = __crpc_codegen::decode_request_view::<#input_view>(request, format)?;
                        let (resp_stream, ctx) = svc.#method_snake(ctx, req).await?;
                        Ok((__crpc_codegen::encode_response_stream(resp_stream, format), ctx))
                    })
                }
            });
        } else {
            call_unary_arms.push(quote! {
                #method_name => {
                    let svc = Arc::clone(&self.inner);
                    Box::pin(async move {
                        let req = __crpc_codegen::decode_request_view::<#input_view>(request, format)?;
                        let (res, ctx) = svc.#method_snake(ctx, req).await?;
                        let bytes = __crpc_codegen::encode_response(&res, format)?;
                        Ok((bytes, ctx))
                    })
                }
            });
        }
    }

    let server_doc = format!(
        "Monomorphic dispatcher for `{trait_name}`.\n\n\
         Unlike `.register(Router)` which type-erases each method into an \
         `Arc<dyn ErasedHandler>` stored in a `HashMap`, this struct dispatches \
         via a compile-time `match` on method name: no vtable, no hash lookup.\n\n\
         # Example\n\n\
         ```rust,ignore\n\
         use connectrpc::ConnectRpcService;\n\n\
         let server = {server_name}::new(MyImpl);\n\
         let service = ConnectRpcService::new(server);\n\
         // hand `service` to axum/hyper as a fallback_service\n\
         ```"
    );
    let server_doc_tokens = doc_attrs(&server_doc);

    // In each generated `call_*` method, `let _ = (&ctx, &request(s), &format);`
    // touches every parameter. When a service has no RPC of that shape, the
    // match has only the fallback arm, and this keeps the otherwise-unused
    // parameters from triggering warnings.
    Ok(quote! {
        #server_doc_tokens
        pub struct #server_name<T> {
            inner: Arc<T>,
        }

        impl<T: #trait_name> #server_name<T> {
            pub fn new(service: T) -> Self {
                Self { inner: Arc::new(service) }
            }

            pub fn from_arc(inner: Arc<T>) -> Self {
                Self { inner }
            }
        }

        impl<T> Clone for #server_name<T> {
            fn clone(&self) -> Self {
                Self { inner: Arc::clone(&self.inner) }
            }
        }

        impl<T: #trait_name> Dispatcher for #server_name<T> {
            #[inline]
            fn lookup(&self, path: &str) -> Option<__crpc_codegen::MethodDescriptor> {
                let method = path.strip_prefix(#path_prefix)?;
                match method {
                    #(#lookup_arms)*
                    _ => None,
                }
            }

            fn call_unary(
                &self,
                path: &str,
                ctx: Context,
                request: __Bytes,
                format: __CodecFormat,
            ) -> __crpc_codegen::UnaryResult {
                let Some(method) = path.strip_prefix(#path_prefix) else {
                    return __crpc_codegen::unimplemented_unary(path);
                };
                let _ = (&ctx, &request, &format);
                match method {
                    #(#call_unary_arms)*
                    _ => __crpc_codegen::unimplemented_unary(path),
                }
            }

            fn call_server_streaming(
                &self,
                path: &str,
                ctx: Context,
                request: __Bytes,
                format: __CodecFormat,
            ) -> __crpc_codegen::StreamingResult {
                let Some(method) = path.strip_prefix(#path_prefix) else {
                    return __crpc_codegen::unimplemented_streaming(path);
                };
                let _ = (&ctx, &request, &format);
                match method {
                    #(#call_ss_arms)*
                    _ => __crpc_codegen::unimplemented_streaming(path),
                }
            }

            fn call_client_streaming(
                &self,
                path: &str,
                ctx: Context,
                requests: __crpc_codegen::RequestStream,
                format: __CodecFormat,
            ) -> __crpc_codegen::UnaryResult {
                let Some(method) = path.strip_prefix(#path_prefix) else {
                    return __crpc_codegen::unimplemented_unary(path);
                };
                let _ = (&ctx, &requests, &format);
                match method {
                    #(#call_cs_arms)*
                    _ => __crpc_codegen::unimplemented_unary(path),
                }
            }

            fn call_bidi_streaming(
                &self,
                path: &str,
                ctx: Context,
                requests: __crpc_codegen::RequestStream,
                format: __CodecFormat,
            ) -> __crpc_codegen::StreamingResult {
                let Some(method) = path.strip_prefix(#path_prefix) else {
                    return __crpc_codegen::unimplemented_streaming(path);
                };
                let _ = (&ctx, &requests, &format);
                match method {
                    #(#call_bidi_arms)*
                    _ => __crpc_codegen::unimplemented_streaming(path),
                }
            }
        }
    })
}
955
956fn generate_doc_comment(doc: &str, default: &str) -> TokenStream {
958 let comment = if doc.is_empty() { default } else { doc };
959 doc_attrs(comment)
960}
961
962fn generate_trait_method(
964 file: &FileDescriptorProto,
965 service: &ServiceDescriptorProto,
966 method: &MethodDescriptorProto,
967 resolver: &TypeResolver<'_>,
968 package: &str,
969) -> Result<TokenStream> {
970 let method_name = method.name.as_deref().unwrap_or("");
971 let method_snake = format_ident!("{}", method_name.to_snake_case());
972 let input_view_type =
973 resolver.rust_view_type(method.input_type.as_deref().unwrap_or(""), package)?;
974 let output_type = resolver.rust_type(method.output_type.as_deref().unwrap_or(""), package)?;
975
976 let method_doc = get_method_comment(file, service, method).unwrap_or_default();
978 let method_doc_tokens =
979 generate_doc_comment(&method_doc, &format!("Handle the {method_name} RPC."));
980
981 let client_streaming = method.client_streaming.unwrap_or(false);
983 let server_streaming = method.server_streaming.unwrap_or(false);
984
985 if server_streaming && !client_streaming {
986 Ok(quote! {
988 #method_doc_tokens
989 fn #method_snake(
990 &self,
991 ctx: Context,
992 request: OwnedView<#input_view_type<'static>>,
993 ) -> impl Future<Output = Result<(Pin<Box<dyn Stream<Item = Result<#output_type, ConnectError>> + Send>>, Context), ConnectError>> + Send;
994 })
995 } else if client_streaming && !server_streaming {
996 Ok(quote! {
998 #method_doc_tokens
999 fn #method_snake(
1000 &self,
1001 ctx: Context,
1002 requests: Pin<Box<dyn Stream<Item = Result<OwnedView<#input_view_type<'static>>, ConnectError>> + Send>>,
1003 ) -> impl Future<Output = Result<(#output_type, Context), ConnectError>> + Send;
1004 })
1005 } else if client_streaming && server_streaming {
1006 Ok(quote! {
1008 #method_doc_tokens
1009 fn #method_snake(
1010 &self,
1011 ctx: Context,
1012 requests: Pin<Box<dyn Stream<Item = Result<OwnedView<#input_view_type<'static>>, ConnectError>> + Send>>,
1013 ) -> impl Future<Output = Result<(Pin<Box<dyn Stream<Item = Result<#output_type, ConnectError>> + Send>>, Context), ConnectError>> + Send;
1014 })
1015 } else {
1016 Ok(quote! {
1018 #method_doc_tokens
1019 fn #method_snake(
1020 &self,
1021 ctx: Context,
1022 request: OwnedView<#input_view_type<'static>>,
1023 ) -> impl Future<Output = Result<(#output_type, Context), ConnectError>> + Send;
1024 })
1025 }
1026}
1027
1028fn generate_client_method(
1039 full_service_name: &str,
1040 method: &MethodDescriptorProto,
1041 resolver: &TypeResolver<'_>,
1042 package: &str,
1043) -> Result<TokenStream> {
1044 let method_name = method.name.as_deref().unwrap_or("");
1045 let method_snake = format_ident!("{}", method_name.to_snake_case());
1046 let method_with_opts = format_ident!("{}_with_options", method_name.to_snake_case());
1047 let input_type = resolver.rust_type(method.input_type.as_deref().unwrap_or(""), package)?;
1048 let output_view_type =
1049 resolver.rust_view_type(method.output_type.as_deref().unwrap_or(""), package)?;
1050
1051 let client_streaming = method.client_streaming.unwrap_or(false);
1052 let server_streaming = method.server_streaming.unwrap_or(false);
1053
1054 let doc = format!(
1055 " Call the {method_name} RPC. Sends a request to /{full_service_name}/{method_name}."
1056 );
1057 let doc_opts = format!(
1058 " Call the {method_name} RPC with explicit per-call options. \
1059 Options override [`ClientConfig`] defaults."
1060 );
1061
1062 let ret_ty: TokenStream;
1064 let call_body: TokenStream;
1065 let short_args: TokenStream; let opts_args: TokenStream; let short_delegate_args: TokenStream; if client_streaming && !server_streaming {
1070 ret_ty = quote! {
1072 Result<
1073 ::connectrpc::client::UnaryResponse<OwnedView<#output_view_type<'static>>>,
1074 ConnectError,
1075 >
1076 };
1077 call_body = quote! {
1078 call_client_stream(
1079 &self.transport, &self.config,
1080 #full_service_name, #method_name,
1081 requests, options,
1082 ).await
1083 };
1084 short_args = quote! { requests: impl IntoIterator<Item = #input_type> };
1085 opts_args =
1086 quote! { requests: impl IntoIterator<Item = #input_type>, options: CallOptions };
1087 short_delegate_args = quote! { requests, CallOptions::default() };
1088 } else if client_streaming && server_streaming {
1089 ret_ty = quote! {
1091 Result<
1092 ::connectrpc::client::BidiStream<
1093 T::ResponseBody, #input_type, #output_view_type<'static>
1094 >,
1095 ConnectError,
1096 >
1097 };
1098 call_body = quote! {
1099 call_bidi_stream(
1100 &self.transport, &self.config,
1101 #full_service_name, #method_name, options,
1102 ).await
1103 };
1104 short_args = quote! {};
1105 opts_args = quote! { options: CallOptions };
1106 short_delegate_args = quote! { CallOptions::default() };
1107 } else if server_streaming {
1108 ret_ty = quote! {
1110 Result<
1111 ::connectrpc::client::ServerStream<T::ResponseBody, #output_view_type<'static>>,
1112 ConnectError,
1113 >
1114 };
1115 call_body = quote! {
1116 call_server_stream(
1117 &self.transport, &self.config,
1118 #full_service_name, #method_name,
1119 request, options,
1120 ).await
1121 };
1122 short_args = quote! { request: #input_type };
1123 opts_args = quote! { request: #input_type, options: CallOptions };
1124 short_delegate_args = quote! { request, CallOptions::default() };
1125 } else {
1126 ret_ty = quote! {
1128 Result<
1129 ::connectrpc::client::UnaryResponse<OwnedView<#output_view_type<'static>>>,
1130 ConnectError,
1131 >
1132 };
1133 call_body = quote! {
1134 call_unary(
1135 &self.transport, &self.config,
1136 #full_service_name, #method_name,
1137 request, options,
1138 ).await
1139 };
1140 short_args = quote! { request: #input_type };
1141 opts_args = quote! { request: #input_type, options: CallOptions };
1142 short_delegate_args = quote! { request, CallOptions::default() };
1143 }
1144
1145 Ok(quote! {
1146 #[doc = #doc]
1147 pub async fn #method_snake(&self, #short_args) -> #ret_ty {
1148 self.#method_with_opts(#short_delegate_args).await
1149 }
1150
1151 #[doc = #doc_opts]
1152 pub async fn #method_with_opts(&self, #opts_args) -> #ret_ty {
1153 #call_body
1154 }
1155 })
1156}
1157
1158fn get_service_comment(
1160 file: &FileDescriptorProto,
1161 service: &ServiceDescriptorProto,
1162) -> Option<String> {
1163 let source_info: &SourceCodeInfo = &file.source_code_info;
1165
1166 let service_index = file.service.iter().position(|s| s.name == service.name)?;
1168
1169 let target_path = vec![6, service_index as i32];
1172
1173 find_comment(source_info, &target_path)
1174}
1175
1176fn get_method_comment(
1178 file: &FileDescriptorProto,
1179 service: &ServiceDescriptorProto,
1180 method: &MethodDescriptorProto,
1181) -> Option<String> {
1182 let source_info: &SourceCodeInfo = &file.source_code_info;
1183
1184 let (service_index, method_index) = file.service.iter().enumerate().find_map(|(si, s)| {
1187 if s.name != service.name {
1188 return None;
1189 }
1190 s.method
1191 .iter()
1192 .position(|m| m.name == method.name)
1193 .map(|mi| (si, mi))
1194 })?;
1195
1196 let target_path = vec![6, service_index as i32, 2, method_index as i32];
1200
1201 find_comment(source_info, &target_path)
1202}
1203
1204fn find_comment(source_info: &SourceCodeInfo, target_path: &[i32]) -> Option<String> {
1206 for location in &source_info.location {
1207 if location.path == target_path {
1208 let comment = location
1209 .leading_comments
1210 .as_ref()
1211 .or(location.trailing_comments.as_ref())?;
1212
1213 let cleaned: String = comment
1217 .lines()
1218 .map(|line| line.trim())
1219 .filter(|line| !line.is_empty())
1220 .collect::<Vec<_>>()
1221 .join("\n");
1222
1223 if !cleaned.is_empty() {
1224 return Some(cleaned);
1225 }
1226 }
1227 }
1228 None
1229}
1230
#[cfg(test)]
mod tests {
    use super::*;
    use buffa_codegen::generated::descriptor::DescriptorProto;

    /// `doc_attrs` output, once rendered by prettyplease, must read `/// text`.
    /// That means neither `///text` nor `///  text`.
    #[test]
    fn doc_attrs_prefixes_space_for_prettyplease() {
        let ts = quote! {
            #[allow(dead_code)]
            mod m {}
        };
        let doc = doc_attrs("Hello.\n\nSecond paragraph.");
        let combined = quote! { #doc #ts };
        let file = syn::parse2::<syn::File>(combined).unwrap();
        let out = prettyplease::unparse(&file);
        assert!(out.contains("/// Hello."), "got: {out}");
        assert!(out.contains("/// Second paragraph."), "got: {out}");
        // The blank line between the two paragraphs is rendered as an empty `///` line.
        assert!(out.contains("///\n"), "got: {out}");
        // No missing separator space...
        assert!(!out.contains("///Hello"), "got: {out}");
        // ...and no doubled one. A single-space needle here would contradict the
        // first assertion and make this test fail unconditionally.
        assert!(!out.contains("///  Hello"), "got: {out}");
    }

    /// Builds a `ping.proto` descriptor for the given `package`.
    ///
    /// The file declares `PingService` with a single `Ping` RPC from `input_type`
    /// to `output_type`. It also gets one empty message per name in `local_messages`.
    fn minimal_file(
        package: Option<&str>,
        input_type: &str,
        output_type: &str,
        local_messages: &[&str],
    ) -> FileDescriptorProto {
        let method = MethodDescriptorProto {
            name: Some("Ping".into()),
            input_type: Some(input_type.into()),
            output_type: Some(output_type.into()),
            ..Default::default()
        };
        let service = ServiceDescriptorProto {
            name: Some("PingService".into()),
            method: vec![method],
            ..Default::default()
        };
        FileDescriptorProto {
            name: Some("ping.proto".into()),
            package: package.map(|p| p.into()),
            service: vec![service],
            message_type: local_messages
                .iter()
                .map(|name| DescriptorProto {
                    name: Some((*name).into()),
                    ..Default::default()
                })
                .collect(),
            ..Default::default()
        }
    }

    /// Generates code for the first service of `files[target_idx]` and returns it
    /// as a `TokenStream` string.
    ///
    /// Types are resolved against all of `files` using the given `extern_paths`.
    /// `require_extern` is forwarded to `TypeResolver::new`. With it set and no
    /// mapping available, generation is expected to fail (see
    /// `missing_extern_path_errors`).
    fn gen_service(
        files: &[FileDescriptorProto],
        target_idx: usize,
        extern_paths: &[(String, String)],
        require_extern: bool,
    ) -> Result<String> {
        let mut config = buffa_codegen::CodeGenConfig::default();
        config.extern_paths = extern_paths.to_vec();
        let target_name = files[target_idx]
            .name
            .clone()
            .into_iter()
            .collect::<Vec<_>>();
        let resolver = TypeResolver::new(files, &target_name, &config, require_extern);
        let file = &files[target_idx];
        let service = &file.service[0];
        Ok(generate_service(file, service, &resolver)?.to_string())
    }

    #[test]
    fn service_name_with_package() {
        let file = minimal_file(
            Some("example.v1"),
            ".example.v1.PingReq",
            ".example.v1.PingResp",
            &["PingReq", "PingResp"],
        );
        let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
        assert!(code.contains("\"example.v1.PingService\""), "got: {code}");
    }

    #[test]
    fn service_name_without_package() {
        let file = minimal_file(None, ".PingReq", ".PingResp", &["PingReq", "PingResp"]);
        let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
        assert!(code.contains("\"PingService\""), "got: {code}");
        assert!(
            !code.contains("\".PingService\""),
            "must not have leading dot: {code}"
        );
    }

    #[test]
    fn same_package_types_use_bare_names() {
        let file = minimal_file(
            Some("example.v1"),
            ".example.v1.PingReq",
            ".example.v1.PingResp",
            &["PingReq", "PingResp"],
        );
        let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
        assert!(code.contains("PingReq"), "input type missing: {code}");
        assert!(code.contains("PingResp"), "output type missing: {code}");
        // Types in the service's own package must not be reached through `super`.
        assert!(
            !code.contains("super :: PingReq"),
            "unexpected super: {code}"
        );
    }

    #[test]
    fn cross_package_types_use_relative_paths() {
        // `Shared` lives in `common.v1`; the service in `example.v1` takes it as input.
        let common = FileDescriptorProto {
            name: Some("common.proto".into()),
            package: Some("common.v1".into()),
            message_type: vec![DescriptorProto {
                name: Some("Shared".into()),
                ..Default::default()
            }],
            ..Default::default()
        };
        let svc = minimal_file(
            Some("example.v1"),
            ".common.v1.Shared",
            ".example.v1.Out",
            &["Out"],
        );
        let code = gen_service(&[common, svc], 1, &[], false).unwrap();

        // Both the owned type and its view are reached via a `super::super::` path.
        assert!(
            code.contains("super :: super :: common :: v1 :: Shared"),
            "cross-package path not emitted: {code}"
        );
        assert!(
            code.contains("super :: super :: common :: v1 :: SharedView"),
            "cross-package view path not emitted: {code}"
        );
    }

    #[test]
    fn wkt_types_use_buffa_types_extern_path() {
        // No user extern paths: well-known types still map to `::buffa_types`.
        let wkt = FileDescriptorProto {
            name: Some("google/protobuf/empty.proto".into()),
            package: Some("google.protobuf".into()),
            message_type: vec![DescriptorProto {
                name: Some("Empty".into()),
                ..Default::default()
            }],
            ..Default::default()
        };
        let svc = minimal_file(
            Some("example.v1"),
            ".google.protobuf.Empty",
            ".example.v1.Out",
            &["Out"],
        );
        let code = gen_service(&[wkt, svc], 1, &[], false).unwrap();

        assert!(
            code.contains(":: buffa_types :: google :: protobuf :: Empty"),
            "WKT extern path not emitted: {code}"
        );
    }

    #[test]
    fn extern_catchall_uses_absolute_paths() {
        let file = minimal_file(
            Some("example.v1"),
            ".example.v1.PingReq",
            ".example.v1.PingResp",
            &["PingReq", "PingResp"],
        );
        let extern_paths = [(".".into(), "crate::proto".into())];
        let code = gen_service(std::slice::from_ref(&file), 0, &extern_paths, true).unwrap();
        assert!(
            code.contains("crate :: proto :: example :: v1 :: PingReq"),
            "owned type path missing: {code}"
        );
        assert!(
            code.contains("crate :: proto :: example :: v1 :: PingReqView"),
            "view type path missing: {code}"
        );
    }

    #[test]
    fn extern_catchall_with_wkt_longest_wins() {
        // The `.google.protobuf` WKT mapping is more specific than the `.` catch-all,
        // so it must take precedence. Local types still go through the catch-all.
        let wkt = FileDescriptorProto {
            name: Some("google/protobuf/empty.proto".into()),
            package: Some("google.protobuf".into()),
            message_type: vec![DescriptorProto {
                name: Some("Empty".into()),
                ..Default::default()
            }],
            ..Default::default()
        };
        let svc = minimal_file(
            Some("example.v1"),
            ".google.protobuf.Empty",
            ".example.v1.Out",
            &["Out"],
        );
        let extern_paths = [(".".into(), "crate::proto".into())];
        let code = gen_service(&[wkt, svc], 1, &extern_paths, true).unwrap();
        assert!(
            code.contains(":: buffa_types :: google :: protobuf :: Empty"),
            "WKT mapping lost to catch-all: {code}"
        );
        assert!(
            code.contains("crate :: proto :: example :: v1 :: Out"),
            "local type not routed through catch-all: {code}"
        );
    }

    #[test]
    fn missing_extern_path_errors() {
        let file = minimal_file(
            Some("example.v1"),
            ".example.v1.PingReq",
            ".example.v1.PingResp",
            &["PingReq", "PingResp"],
        );
        // `require_extern` with no mapping at all must fail with an actionable hint.
        let err = gen_service(std::slice::from_ref(&file), 0, &[], true).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("extern_path"),
            "error message lacks hint: {msg}"
        );
    }

    #[test]
    fn keyword_package_escaped() {
        // `type` is a Rust keyword, so the module segment must be emitted as `r#type`.
        let file = minimal_file(
            Some("google.type"),
            ".google.type.LatLng",
            ".google.type.LatLng",
            &["LatLng"],
        );
        let extern_paths = [(".".into(), "crate::proto".into())];
        let code = gen_service(std::slice::from_ref(&file), 0, &extern_paths, true).unwrap();
        assert!(
            code.contains("crate :: proto :: google :: r#type :: LatLng"),
            "keyword segment not escaped: {code}"
        );
    }
}