1use std::collections::HashMap;
12
13use anyhow::Result;
14use heck::ToSnakeCase;
15use heck::ToUpperCamelCase;
16use proc_macro2::{Ident, TokenStream};
17use quote::format_ident;
18use quote::quote;
19
20use buffa_codegen::generated::descriptor::FileDescriptorProto;
21use buffa_codegen::generated::descriptor::MethodDescriptorProto;
22use buffa_codegen::generated::descriptor::ServiceDescriptorProto;
23use buffa_codegen::generated::descriptor::SourceCodeInfo;
24use buffa_codegen::generated::descriptor::method_options::IdempotencyLevel;
25use buffa_codegen::idents::make_field_ident;
26use buffa_codegen::idents::rust_path_to_tokens;
27
28pub use buffa_codegen::GeneratedFile;
29pub use buffa_codegen::generated::descriptor;
30
31use crate::plugin::CodeGeneratorRequest;
32use crate::plugin::CodeGeneratorResponse;
33use crate::plugin::CodeGeneratorResponseFile;
34
/// Settings that control what the Connect code generator emits.
///
/// Every field is copied into a `buffa_codegen::CodeGenConfig` by
/// `Options::to_buffa_config`. Because the struct is `#[non_exhaustive]`,
/// code outside this crate cannot build it with a struct literal; start from
/// [`Options::default`] and assign the fields you need.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Options {
    /// Forwarded to `CodeGenConfig::strict_utf8_mapping`. Defaults to `false`;
    /// the `strict_utf8_mapping` plugin option switches it on.
    pub strict_utf8_mapping: bool,
    /// Forwarded to `CodeGenConfig::generate_json`. Defaults to `true`; the
    /// `no_json` plugin option switches it off.
    pub generate_json: bool,
    /// `(proto_path, rust_path)` pairs forwarded to
    /// `CodeGenConfig::extern_paths`. Empty by default. The plugin fills it
    /// from `extern_path=<proto>=<rust>` and `buffa_module=<rust>` (the latter
    /// is stored with the proto path `"."`).
    pub extern_paths: Vec<(String, String)>,
    /// Forwarded to `CodeGenConfig::emit_register_fn`. Defaults to `true`; the
    /// `no_register_fn` plugin option switches it off.
    pub emit_register_fn: bool,
}
74
75impl Default for Options {
76 fn default() -> Self {
77 Self {
78 strict_utf8_mapping: false,
79 generate_json: true,
80 extern_paths: Vec::new(),
81 emit_register_fn: true,
82 }
83 }
84}
85
86impl Options {
87 fn to_buffa_config(&self) -> buffa_codegen::CodeGenConfig {
88 let mut config = buffa_codegen::CodeGenConfig::default();
89 config.generate_views = true;
90 config.generate_json = self.generate_json;
91 config.strict_utf8_mapping = self.strict_utf8_mapping;
92 config.extern_paths.clone_from(&self.extern_paths);
93 config.emit_register_fn = self.emit_register_fn;
94 config
95 }
96}
97
98fn emit_service_files(
101 proto_file: &[FileDescriptorProto],
102 file_to_generate: &[String],
103 resolver: &TypeResolver<'_>,
104) -> Result<Vec<GeneratedFile>> {
105 let mut out = Vec::new();
106 for file_name in file_to_generate {
107 let file_desc = proto_file
108 .iter()
109 .find(|f| f.name.as_deref() == Some(file_name.as_str()));
110
111 if let Some(file) = file_desc
112 && !file.service.is_empty()
113 {
114 let service_tokens = generate_connect_services(file, resolver)?;
115 let service_code = format_token_stream(&service_tokens)?;
116 out.push(GeneratedFile {
117 name: buffa_codegen::proto_path_to_rust_module(file_name),
118 content: service_code,
119 });
120 }
121 }
122 Ok(out)
123}
124
125pub fn generate_files(
142 proto_file: &[FileDescriptorProto],
143 file_to_generate: &[String],
144 options: &Options,
145) -> Result<Vec<GeneratedFile>> {
146 let config = options.to_buffa_config();
147
148 let mut files = buffa_codegen::generate(proto_file, file_to_generate, &config)
149 .map_err(|e| anyhow::anyhow!("buffa-codegen failed: {e}"))?;
150
151 let resolver = TypeResolver::new(proto_file, file_to_generate, &config, false);
152 let service_files = emit_service_files(proto_file, file_to_generate, &resolver)?;
153
154 for svc in service_files {
156 if let Some(out) = files.iter_mut().find(|g| g.name == svc.name) {
157 out.content.push('\n');
158 out.content.push_str(&svc.content);
159 }
160 }
161
162 Ok(files)
163}
164
165pub fn generate_services(
182 proto_file: &[FileDescriptorProto],
183 file_to_generate: &[String],
184 options: &Options,
185) -> Result<Vec<GeneratedFile>> {
186 let config = options.to_buffa_config();
187 let resolver = TypeResolver::new(proto_file, file_to_generate, &config, true);
188 emit_service_files(proto_file, file_to_generate, &resolver)
189}
190
191pub fn generate(request: &CodeGeneratorRequest) -> Result<CodeGeneratorResponse> {
216 let mut options = Options::default();
217
218 if let Some(ref param) = request.parameter {
219 for opt in param.split(',').map(str::trim).filter(|s| !s.is_empty()) {
220 if let Some(value) = opt.strip_prefix("buffa_module=") {
221 let rust = value.trim();
222 if rust.is_empty() {
223 anyhow::bail!(
224 "buffa_module requires a non-empty path, \
225 e.g. buffa_module=crate::proto"
226 );
227 }
228 options.extern_paths.push((".".into(), rust.to_string()));
229 } else if let Some(value) = opt.strip_prefix("extern_path=") {
230 let (proto, rust) = value.split_once('=').ok_or_else(|| {
232 anyhow::anyhow!(
233 "invalid extern_path format {value:?}, expected \
234 extern_path=.proto.pkg=::rust::path"
235 )
236 })?;
237 let proto = proto.trim();
238 let rust = rust.trim();
239 if proto.is_empty() || rust.is_empty() {
240 anyhow::bail!(
241 "invalid extern_path format {value:?}, expected \
242 extern_path=.proto.pkg=::rust::path (both sides non-empty)"
243 );
244 }
245 let mut proto = proto.to_string();
246 if !proto.starts_with('.') {
247 proto.insert(0, '.');
248 }
249 options.extern_paths.push((proto, rust.to_string()));
250 } else {
251 match opt {
252 "strict_utf8_mapping" => options.strict_utf8_mapping = true,
253 "no_json" => options.generate_json = false,
254 "no_register_fn" => options.emit_register_fn = false,
255 _ => {
256 return Err(anyhow::anyhow!(
257 "unknown plugin option: {opt:?}. Supported: \
258 buffa_module=<rust_path>, extern_path=<proto>=<rust>, \
259 strict_utf8_mapping, no_json, no_register_fn"
260 ));
261 }
262 }
263 }
264 }
265 }
266
267 let generated = generate_services(&request.proto_file, &request.file_to_generate, &options)?;
268
269 let files: Vec<CodeGeneratorResponseFile> = generated
270 .into_iter()
271 .map(|g| CodeGeneratorResponseFile {
272 name: Some(g.name),
273 content: Some(g.content),
274 ..Default::default()
275 })
276 .collect();
277
278 Ok(CodeGeneratorResponse {
279 supported_features: Some(feature_flags()),
280 minimum_edition: Some(EDITION_2023),
281 maximum_edition: Some(EDITION_2023),
282 file: files,
283 ..Default::default()
284 })
285}
286
/// Bitmask reported in `CodeGeneratorResponse::supported_features`.
///
/// Sets bit 0 (value 1, proto3 optional) and bit 1 (value 2, editions
/// support), so the result is always `3`.
fn feature_flags() -> u64 {
    const PROTO3_OPTIONAL_BIT: u32 = 0;
    const SUPPORTS_EDITIONS_BIT: u32 = 1;
    (1u64 << PROTO3_OPTIONAL_BIT) | (1u64 << SUPPORTS_EDITIONS_BIT)
}
294
/// Edition number that `generate` reports as both `minimum_edition` and
/// `maximum_edition` in its response.
// NOTE(review): 1000 should be the value of `EDITION_2023` in protobuf's
// `Edition` enum (descriptor.proto) — confirm against the protobuf release in use.
const EDITION_2023: i32 = 1000;
298
299fn format_token_stream(tokens: &TokenStream) -> Result<String> {
301 let file = syn::parse2::<syn::File>(tokens.clone())
302 .map_err(|e| anyhow::anyhow!("generated code failed to parse: {e}"))?;
303 Ok(prettyplease::unparse(&file))
304}
305
306fn doc_attrs(text: &str) -> TokenStream {
315 let lines: Vec<String> = text
316 .lines()
317 .map(|l| {
318 if l.is_empty() {
319 String::new()
320 } else {
321 format!(" {l}")
322 }
323 })
324 .collect();
325 quote! { #(#[doc = #lines])* }
326}
327
/// Maps fully-qualified protobuf type names to Rust type paths for the
/// generated service code.
struct TypeResolver<'a> {
    /// `buffa_codegen` context used to look up the Rust path of a proto type.
    ctx: buffa_codegen::context::CodeGenContext<'a>,
    /// When `true`, `resolve_path` fails for types that are unknown or whose
    /// path does not start with `::` or `crate::`. `generate_services` passes
    /// `true`; `generate_files` passes `false`.
    require_extern: bool,
}
349
350impl<'a> TypeResolver<'a> {
351 fn new(
352 proto_file: &'a [FileDescriptorProto],
353 file_to_generate: &[String],
354 config: &'a buffa_codegen::CodeGenConfig,
355 require_extern: bool,
356 ) -> Self {
357 Self {
358 ctx: buffa_codegen::context::CodeGenContext::for_generate(
359 proto_file,
360 file_to_generate,
361 config,
362 ),
363 require_extern,
364 }
365 }
366
367 fn resolve_path(&self, proto_fqn: &str, current_package: &str) -> Result<String> {
374 match self.ctx.rust_type_relative(proto_fqn, current_package, 0) {
375 Some(path) => {
376 if self.require_extern && !path.starts_with("::") && !path.starts_with("crate::") {
377 anyhow::bail!(
378 "type {proto_fqn} is not covered by any extern_path mapping. \
379 Add extern_path=.=<your_buffa_module> (e.g. \
380 extern_path=.=crate::proto) to the plugin opts."
381 );
382 }
383 Ok(path)
384 }
385 None if self.require_extern => anyhow::bail!(
386 "type {proto_fqn} not found in descriptor set (missing proto import?)"
387 ),
388 None => Ok(bare_type_name(proto_fqn).to_string()),
389 }
390 }
391
392 fn rust_type(&self, proto_fqn: &str, current_package: &str) -> Result<TokenStream> {
394 let path = self.resolve_path(proto_fqn, current_package)?;
395 Ok(rust_path_to_tokens(&path))
396 }
397
398 fn rust_view_type(&self, proto_fqn: &str, current_package: &str) -> Result<TokenStream> {
401 let path = self.resolve_path(proto_fqn, current_package)?;
402 Ok(rust_path_to_tokens(&format!("{path}View")))
403 }
404}
405
/// Returns the text after the last `.` in `proto_fqn`, or the whole input
/// when it contains no dot (e.g. `.pkg.Msg` → `Msg`, `Msg` → `Msg`).
fn bare_type_name(proto_fqn: &str) -> &str {
    match proto_fqn.rsplit_once('.') {
        Some((_, tail)) => tail,
        None => proto_fqn,
    }
}
416
417fn generate_connect_services(
423 file: &FileDescriptorProto,
424 resolver: &TypeResolver<'_>,
425) -> Result<TokenStream> {
426 let mut tokens = TokenStream::new();
427
428 for service in &file.service {
434 tokens.extend(generate_service(file, service, resolver)?);
435 }
436
437 Ok(tokens)
438}
439
440fn check_method_collisions(service_name: &str, service: &ServiceDescriptorProto) -> Result<()> {
449 let mut seen: HashMap<String, String> = HashMap::new();
450 for m in &service.method {
451 let proto_name = m.name.as_deref().unwrap_or("");
452 let snake = proto_name.to_snake_case();
453 let with_opts = format!("{snake}_with_options");
454 for ident in [snake.as_str(), with_opts.as_str()] {
455 if let Some(prev) = seen.get(ident) {
456 anyhow::bail!(
457 "service {service_name}: RPC methods {prev:?} and {proto_name:?} \
458 both generate Rust identifier `{ident}`; rename one in the proto"
459 );
460 }
461 }
462 seen.insert(snake, proto_name.to_string());
463 seen.insert(with_opts, proto_name.to_string());
464 }
465 Ok(())
466}
467
/// Generates all Rust items for one protobuf service.
///
/// The returned tokens contain, in order: the `<SERVICE>_SERVICE_NAME`
/// constant, the handler trait, the `<Service>Ext` trait plus a blanket impl
/// whose `register` adds one route per RPC to a `::connectrpc::Router`, the
/// `<Service>Server` dispatcher (from `generate_service_server`), and the
/// generic `<Service>Client` struct with one pair of methods per RPC.
///
/// # Errors
///
/// Fails when two RPCs map to the same Rust identifier
/// (`check_method_collisions`) or when a request/response type cannot be
/// resolved by `resolver`.
fn generate_service(
    file: &FileDescriptorProto,
    service: &ServiceDescriptorProto,
    resolver: &TypeResolver<'_>,
) -> Result<TokenStream> {
    let package = file.package.as_deref().unwrap_or("");
    let service_name = service.name.as_deref().unwrap_or("");
    check_method_collisions(service_name, service)?;
    // Fully-qualified proto service name; no leading dot, and no package
    // prefix when the file has no package.
    let full_service_name = if package.is_empty() {
        service_name.to_string()
    } else {
        format!("{package}.{service_name}")
    };
    let service_upper = service_name.to_upper_camel_case();
    // `Self` is a reserved word in Rust and cannot name a trait, so that one
    // case gets a trailing underscore.
    let trait_name = if service_upper == "Self" {
        format_ident!("Self_")
    } else {
        format_ident!("{}", service_upper)
    };
    let ext_trait_name = format_ident!("{}Ext", service_upper);
    let client_name = format_ident!("{}Client", service_upper);
    let server_name = format_ident!("{}Server", service_upper);
    let service_name_const = format_ident!(
        "{}_SERVICE_NAME",
        service_name.to_snake_case().to_uppercase()
    );

    // Trait docs: the proto comment when present, otherwise a generic
    // sentence, followed by a fixed section on implementing handlers.
    let service_doc = get_service_comment(file, service).unwrap_or_default();
    let base_doc = if service_doc.is_empty() {
        format!("Server trait for {service_name}.")
    } else {
        service_doc
    };
    let full_doc = format!(
        "{base_doc}\n\n\
         # Implementing handlers\n\n\
         Handlers receive requests as `OwnedView<FooView<'static>>`, which gives\n\
         zero-copy borrowed access to fields (e.g. `request.name` is a `&str`\n\
         into the decoded buffer). The view can be held across `.await` points.\n\n\
         Implement methods with plain `async fn`; the returned future satisfies\n\
         the `Send` bound automatically. See the\n\
         [buffa user guide](https://github.com/anthropics/buffa/blob/main/docs/guide.md#ownedview-in-async-trait-implementations)\n\
         for zero-copy access patterns and when `to_owned_message()` is needed."
    );
    let service_doc_tokens = doc_attrs(&full_doc);

    // One trait method declaration per RPC.
    let trait_methods: Vec<TokenStream> = service
        .method
        .iter()
        .map(|m| generate_trait_method(file, service, m, resolver, package))
        .collect::<Result<Vec<_>>>()?;

    // One chained `.route_*(...)` call per RPC, selected by streaming kind.
    let route_registrations: Vec<TokenStream> = service
        .method
        .iter()
        .map(|m| {
            let method_name = m.name.as_deref().unwrap_or("");
            let method_snake = make_field_ident(&method_name.to_snake_case());

            let client_streaming = m.client_streaming.unwrap_or(false);
            let server_streaming = m.server_streaming.unwrap_or(false);

            if server_streaming && !client_streaming {
                quote! {
                    .route_view_server_stream(
                        #service_name_const,
                        #method_name,
                        ::connectrpc::view_streaming_handler_fn({
                            let svc = ::std::sync::Arc::clone(&self);
                            move |ctx, req| {
                                let svc = ::std::sync::Arc::clone(&svc);
                                async move { svc.#method_snake(ctx, req).await }
                            }
                        }),
                    )
                }
            } else if client_streaming && !server_streaming {
                quote! {
                    .route_view_client_stream(
                        #service_name_const,
                        #method_name,
                        ::connectrpc::view_client_streaming_handler_fn({
                            let svc = ::std::sync::Arc::clone(&self);
                            move |ctx, req| {
                                let svc = ::std::sync::Arc::clone(&svc);
                                async move { svc.#method_snake(ctx, req).await }
                            }
                        }),
                    )
                }
            } else if client_streaming && server_streaming {
                quote! {
                    .route_view_bidi_stream(
                        #service_name_const,
                        #method_name,
                        ::connectrpc::view_bidi_streaming_handler_fn({
                            let svc = ::std::sync::Arc::clone(&self);
                            move |ctx, req| {
                                let svc = ::std::sync::Arc::clone(&svc);
                                async move { svc.#method_snake(ctx, req).await }
                            }
                        }),
                    )
                }
            } else {
                // Unary RPCs declared with idempotency level NO_SIDE_EFFECTS
                // are registered through `route_view_idempotent`; all other
                // unary RPCs use `route_view`.
                let is_idempotent = m
                    .options
                    .idempotency_level
                    .map(|level| level == IdempotencyLevel::NO_SIDE_EFFECTS)
                    .unwrap_or(false);

                let route_method = if is_idempotent {
                    quote! { route_view_idempotent }
                } else {
                    quote! { route_view }
                };

                quote! {
                    .#route_method(
                        #service_name_const,
                        #method_name,
                        {
                            let svc = ::std::sync::Arc::clone(&self);
                            ::connectrpc::view_handler_fn(move |ctx, req| {
                                let svc = ::std::sync::Arc::clone(&svc);
                                async move { svc.#method_snake(ctx, req).await }
                            })
                        },
                    )
                }
            }
        })
        .collect();

    // Two client methods per RPC: the plain call and its `_with_options` form.
    let client_methods: Vec<TokenStream> = service
        .method
        .iter()
        .map(|m| {
            generate_client_method(
                &service_name_const,
                &full_service_name,
                m,
                resolver,
                package,
            )
        })
        .collect::<Result<Vec<_>>>()?;

    let service_server = generate_service_server(
        &full_service_name,
        &trait_name,
        &server_name,
        service,
        resolver,
        package,
    )?;

    // Method name shown in the client doc examples: the first RPC of the
    // service, or the placeholder `method` when the service has no RPCs.
    let example_method = service
        .method
        .first()
        .and_then(|m| m.name.as_deref())
        .map(|n| make_field_ident(&n.to_snake_case()).to_string())
        .unwrap_or_else(|| "method".to_string());

    let client_name_str = client_name.to_string();
    // Raw string: its lines start at column 0 on purpose, because the text is
    // emitted verbatim as doc lines. `{{`/`}}` are escaped braces for `format!`.
    let client_doc = format!(
        r#"Client for this service.

Generic over `T: ClientTransport`. For **gRPC** (HTTP/2), use
`Http2Connection` — it has honest `poll_ready` and composes with
`tower::balance` for multi-connection load balancing. For **Connect
over HTTP/1.1** (or unknown protocol), use `HttpClient`.

# Example (gRPC / HTTP/2)

```rust,ignore
use connectrpc::client::{{Http2Connection, ClientConfig}};
use connectrpc::Protocol;

let uri: http::Uri = "http://localhost:8080".parse()?;
let conn = Http2Connection::connect_plaintext(uri.clone()).await?.shared(1024);
let config = ClientConfig::new(uri).protocol(Protocol::Grpc);

let client = {client_name_str}::new(conn, config);
let response = client.{example_method}(request).await?;
```

# Example (Connect / HTTP/1.1 or ALPN)

```rust,ignore
use connectrpc::client::{{HttpClient, ClientConfig}};

let http = HttpClient::plaintext(); // cleartext http:// only
let config = ClientConfig::new("http://localhost:8080".parse()?);

let client = {client_name_str}::new(http, config);
let response = client.{example_method}(request).await?;
```

# Working with the response

Unary calls return [`UnaryResponse<OwnedView<FooView>>`](::connectrpc::client::UnaryResponse).
The `OwnedView` derefs to the view, so field access is zero-copy:

```rust,ignore
let resp = client.{example_method}(request).await?.into_view();
let name: &str = resp.name; // borrow into the response buffer
```

If you need the owned struct (e.g. to store or pass by value), use
[`into_owned()`](::connectrpc::client::UnaryResponse::into_owned):

```rust,ignore
let owned = client.{example_method}(request).await?.into_owned();
```"#
    );
    let client_doc_tokens = doc_attrs(&client_doc);

    Ok(quote! {
        pub const #service_name_const: &str = #full_service_name;

        #service_doc_tokens
        #[allow(clippy::type_complexity)]
        pub trait #trait_name: Send + Sync + 'static {
            #(#trait_methods)*
        }

        pub trait #ext_trait_name: #trait_name {
            fn register(self: ::std::sync::Arc<Self>, router: ::connectrpc::Router) -> ::connectrpc::Router;
        }

        impl<S: #trait_name> #ext_trait_name for S {
            fn register(self: ::std::sync::Arc<Self>, router: ::connectrpc::Router) -> ::connectrpc::Router {
                router
                    #(#route_registrations)*
            }
        }

        #service_server

        #client_doc_tokens
        #[derive(Clone)]
        pub struct #client_name<T> {
            transport: T,
            config: ::connectrpc::client::ClientConfig,
        }

        impl<T> #client_name<T>
        where
            T: ::connectrpc::client::ClientTransport,
            <T::ResponseBody as ::http_body::Body>::Error: ::std::fmt::Display,
        {
            pub fn new(transport: T, config: ::connectrpc::client::ClientConfig) -> Self {
                Self { transport, config }
            }

            pub fn config(&self) -> &::connectrpc::client::ClientConfig {
                &self.config
            }

            pub fn config_mut(&mut self) -> &mut ::connectrpc::client::ClientConfig {
                &mut self.config
            }

            #(#client_methods)*
        }
    })
}
776
/// Generates the `<Service>Server<T>` struct and its `::connectrpc::Dispatcher`
/// implementation for one service.
///
/// The dispatcher wraps the handler in an `Arc`, strips the
/// `"<full_service_name>/"` prefix from the incoming path, and matches the
/// remaining method name. `lookup` returns a `MethodDescriptor` per RPC; each
/// `call_*` method contains arms only for RPCs of that streaming kind and
/// falls back to the matching `unimplemented_*` helper.
///
/// # Errors
///
/// Fails when the input type of an RPC cannot be resolved by `resolver`.
fn generate_service_server(
    full_service_name: &str,
    trait_name: &proc_macro2::Ident,
    server_name: &proc_macro2::Ident,
    service: &ServiceDescriptorProto,
    resolver: &TypeResolver<'_>,
    package: &str,
) -> Result<TokenStream> {
    // Prefix removed from the request path before matching on the method name.
    let path_prefix = format!("{full_service_name}/");

    // `lookup` arms: method name -> descriptor for its streaming kind. Only
    // the unary descriptor carries the idempotency flag.
    let lookup_arms: Vec<TokenStream> = service
        .method
        .iter()
        .map(|m| {
            let method_name = m.name.as_deref().unwrap_or("");
            let client_streaming = m.client_streaming.unwrap_or(false);
            let server_streaming = m.server_streaming.unwrap_or(false);
            let is_idempotent = m
                .options
                .idempotency_level
                .map(|level| level == IdempotencyLevel::NO_SIDE_EFFECTS)
                .unwrap_or(false);

            let desc = if client_streaming && server_streaming {
                quote! { ::connectrpc::dispatcher::codegen::MethodDescriptor::bidi_streaming() }
            } else if client_streaming {
                quote! { ::connectrpc::dispatcher::codegen::MethodDescriptor::client_streaming() }
            } else if server_streaming {
                quote! { ::connectrpc::dispatcher::codegen::MethodDescriptor::server_streaming() }
            } else {
                quote! { ::connectrpc::dispatcher::codegen::MethodDescriptor::unary(#is_idempotent) }
            };
            quote! { #method_name => Some(#desc), }
        })
        .collect();

    // Match arms for the four `call_*` dispatcher methods, bucketed by
    // streaming kind (unary, server-streaming, client-streaming, bidi).
    let mut call_unary_arms: Vec<TokenStream> = Vec::new();
    let mut call_ss_arms: Vec<TokenStream> = Vec::new();
    let mut call_cs_arms: Vec<TokenStream> = Vec::new();
    let mut call_bidi_arms: Vec<TokenStream> = Vec::new();

    for m in &service.method {
        let method_name = m.name.as_deref().unwrap_or("");
        let method_snake = make_field_ident(&method_name.to_snake_case());
        let input_view = resolver.rust_view_type(m.input_type.as_deref().unwrap_or(""), package)?;
        let cs = m.client_streaming.unwrap_or(false);
        let ss = m.server_streaming.unwrap_or(false);

        if cs && ss {
            call_bidi_arms.push(quote! {
                #method_name => {
                    let svc = ::std::sync::Arc::clone(&self.inner);
                    Box::pin(async move {
                        let req_stream = ::connectrpc::dispatcher::codegen::decode_view_request_stream::<#input_view>(requests, format);
                        let (resp_stream, ctx) = svc.#method_snake(ctx, req_stream).await?;
                        Ok((::connectrpc::dispatcher::codegen::encode_response_stream(resp_stream, format), ctx))
                    })
                }
            });
        } else if cs {
            call_cs_arms.push(quote! {
                #method_name => {
                    let svc = ::std::sync::Arc::clone(&self.inner);
                    Box::pin(async move {
                        let req_stream = ::connectrpc::dispatcher::codegen::decode_view_request_stream::<#input_view>(requests, format);
                        let (res, ctx) = svc.#method_snake(ctx, req_stream).await?;
                        let bytes = ::connectrpc::dispatcher::codegen::encode_response(&res, format)?;
                        Ok((bytes, ctx))
                    })
                }
            });
        } else if ss {
            call_ss_arms.push(quote! {
                #method_name => {
                    let svc = ::std::sync::Arc::clone(&self.inner);
                    Box::pin(async move {
                        let req = ::connectrpc::dispatcher::codegen::decode_request_view::<#input_view>(request, format)?;
                        let (resp_stream, ctx) = svc.#method_snake(ctx, req).await?;
                        Ok((::connectrpc::dispatcher::codegen::encode_response_stream(resp_stream, format), ctx))
                    })
                }
            });
        } else {
            call_unary_arms.push(quote! {
                #method_name => {
                    let svc = ::std::sync::Arc::clone(&self.inner);
                    Box::pin(async move {
                        let req = ::connectrpc::dispatcher::codegen::decode_request_view::<#input_view>(request, format)?;
                        let (res, ctx) = svc.#method_snake(ctx, req).await?;
                        let bytes = ::connectrpc::dispatcher::codegen::encode_response(&res, format)?;
                        Ok((bytes, ctx))
                    })
                }
            });
        }
    }

    // Doc text attached to the generated server struct.
    let server_doc = format!(
        "Monomorphic dispatcher for `{trait_name}`.\n\n\
         Unlike `.register(Router)` which type-erases each method into an \
         `Arc<dyn ErasedHandler>` stored in a `HashMap`, this struct dispatches \
         via a compile-time `match` on method name: no vtable, no hash lookup.\n\n\
         # Example\n\n\
         ```rust,ignore\n\
         use connectrpc::ConnectRpcService;\n\n\
         let server = {server_name}::new(MyImpl);\n\
         let service = ConnectRpcService::new(server);\n\
         // hand `service` to axum/hyper as a fallback_service\n\
         ```"
    );
    let server_doc_tokens = doc_attrs(&server_doc);

    // In each generated `call_*` body, the `let _ = (&ctx, ...)` statement
    // references the parameters so they count as used even when that method
    // has no arms for this service.
    Ok(quote! {
        #server_doc_tokens
        pub struct #server_name<T> {
            inner: ::std::sync::Arc<T>,
        }

        impl<T: #trait_name> #server_name<T> {
            pub fn new(service: T) -> Self {
                Self { inner: ::std::sync::Arc::new(service) }
            }

            pub fn from_arc(inner: ::std::sync::Arc<T>) -> Self {
                Self { inner }
            }
        }

        impl<T> Clone for #server_name<T> {
            fn clone(&self) -> Self {
                Self { inner: ::std::sync::Arc::clone(&self.inner) }
            }
        }

        impl<T: #trait_name> ::connectrpc::Dispatcher for #server_name<T> {
            #[inline]
            fn lookup(&self, path: &str) -> Option<::connectrpc::dispatcher::codegen::MethodDescriptor> {
                let method = path.strip_prefix(#path_prefix)?;
                match method {
                    #(#lookup_arms)*
                    _ => None,
                }
            }

            fn call_unary(
                &self,
                path: &str,
                ctx: ::connectrpc::Context,
                request: ::buffa::bytes::Bytes,
                format: ::connectrpc::CodecFormat,
            ) -> ::connectrpc::dispatcher::codegen::UnaryResult {
                let Some(method) = path.strip_prefix(#path_prefix) else {
                    return ::connectrpc::dispatcher::codegen::unimplemented_unary(path);
                };
                let _ = (&ctx, &request, &format);
                match method {
                    #(#call_unary_arms)*
                    _ => ::connectrpc::dispatcher::codegen::unimplemented_unary(path),
                }
            }

            fn call_server_streaming(
                &self,
                path: &str,
                ctx: ::connectrpc::Context,
                request: ::buffa::bytes::Bytes,
                format: ::connectrpc::CodecFormat,
            ) -> ::connectrpc::dispatcher::codegen::StreamingResult {
                let Some(method) = path.strip_prefix(#path_prefix) else {
                    return ::connectrpc::dispatcher::codegen::unimplemented_streaming(path);
                };
                let _ = (&ctx, &request, &format);
                match method {
                    #(#call_ss_arms)*
                    _ => ::connectrpc::dispatcher::codegen::unimplemented_streaming(path),
                }
            }

            fn call_client_streaming(
                &self,
                path: &str,
                ctx: ::connectrpc::Context,
                requests: ::connectrpc::dispatcher::codegen::RequestStream,
                format: ::connectrpc::CodecFormat,
            ) -> ::connectrpc::dispatcher::codegen::UnaryResult {
                let Some(method) = path.strip_prefix(#path_prefix) else {
                    return ::connectrpc::dispatcher::codegen::unimplemented_unary(path);
                };
                let _ = (&ctx, &requests, &format);
                match method {
                    #(#call_cs_arms)*
                    _ => ::connectrpc::dispatcher::codegen::unimplemented_unary(path),
                }
            }

            fn call_bidi_streaming(
                &self,
                path: &str,
                ctx: ::connectrpc::Context,
                requests: ::connectrpc::dispatcher::codegen::RequestStream,
                format: ::connectrpc::CodecFormat,
            ) -> ::connectrpc::dispatcher::codegen::StreamingResult {
                let Some(method) = path.strip_prefix(#path_prefix) else {
                    return ::connectrpc::dispatcher::codegen::unimplemented_streaming(path);
                };
                let _ = (&ctx, &requests, &format);
                match method {
                    #(#call_bidi_arms)*
                    _ => ::connectrpc::dispatcher::codegen::unimplemented_streaming(path),
                }
            }
        }
    })
}
1010
1011fn generate_doc_comment(doc: &str, default: &str) -> TokenStream {
1013 let comment = if doc.is_empty() { default } else { doc };
1014 doc_attrs(comment)
1015}
1016
1017fn generate_trait_method(
1019 file: &FileDescriptorProto,
1020 service: &ServiceDescriptorProto,
1021 method: &MethodDescriptorProto,
1022 resolver: &TypeResolver<'_>,
1023 package: &str,
1024) -> Result<TokenStream> {
1025 let method_name = method.name.as_deref().unwrap_or("");
1026 let method_snake = make_field_ident(&method_name.to_snake_case());
1027 let input_view_type =
1028 resolver.rust_view_type(method.input_type.as_deref().unwrap_or(""), package)?;
1029 let output_type = resolver.rust_type(method.output_type.as_deref().unwrap_or(""), package)?;
1030
1031 let method_doc = get_method_comment(file, service, method).unwrap_or_default();
1033 let method_doc_tokens =
1034 generate_doc_comment(&method_doc, &format!("Handle the {method_name} RPC."));
1035
1036 let client_streaming = method.client_streaming.unwrap_or(false);
1038 let server_streaming = method.server_streaming.unwrap_or(false);
1039
1040 if server_streaming && !client_streaming {
1041 Ok(quote! {
1043 #method_doc_tokens
1044 fn #method_snake(
1045 &self,
1046 ctx: ::connectrpc::Context,
1047 request: ::buffa::view::OwnedView<#input_view_type<'static>>,
1048 ) -> impl ::std::future::Future<Output = Result<(::std::pin::Pin<Box<dyn ::futures::Stream<Item = Result<#output_type, ::connectrpc::ConnectError>> + Send>>, ::connectrpc::Context), ::connectrpc::ConnectError>> + Send;
1049 })
1050 } else if client_streaming && !server_streaming {
1051 Ok(quote! {
1053 #method_doc_tokens
1054 fn #method_snake(
1055 &self,
1056 ctx: ::connectrpc::Context,
1057 requests: ::std::pin::Pin<Box<dyn ::futures::Stream<Item = Result<::buffa::view::OwnedView<#input_view_type<'static>>, ::connectrpc::ConnectError>> + Send>>,
1058 ) -> impl ::std::future::Future<Output = Result<(#output_type, ::connectrpc::Context), ::connectrpc::ConnectError>> + Send;
1059 })
1060 } else if client_streaming && server_streaming {
1061 Ok(quote! {
1063 #method_doc_tokens
1064 fn #method_snake(
1065 &self,
1066 ctx: ::connectrpc::Context,
1067 requests: ::std::pin::Pin<Box<dyn ::futures::Stream<Item = Result<::buffa::view::OwnedView<#input_view_type<'static>>, ::connectrpc::ConnectError>> + Send>>,
1068 ) -> impl ::std::future::Future<Output = Result<(::std::pin::Pin<Box<dyn ::futures::Stream<Item = Result<#output_type, ::connectrpc::ConnectError>> + Send>>, ::connectrpc::Context), ::connectrpc::ConnectError>> + Send;
1069 })
1070 } else {
1071 Ok(quote! {
1073 #method_doc_tokens
1074 fn #method_snake(
1075 &self,
1076 ctx: ::connectrpc::Context,
1077 request: ::buffa::view::OwnedView<#input_view_type<'static>>,
1078 ) -> impl ::std::future::Future<Output = Result<(#output_type, ::connectrpc::Context), ::connectrpc::ConnectError>> + Send;
1079 })
1080 }
1081}
1082
1083fn generate_client_method(
1094 service_name_const: &Ident,
1095 full_service_name: &str,
1096 method: &MethodDescriptorProto,
1097 resolver: &TypeResolver<'_>,
1098 package: &str,
1099) -> Result<TokenStream> {
1100 let method_name = method.name.as_deref().unwrap_or("");
1101 let method_snake = make_field_ident(&method_name.to_snake_case());
1102 let method_with_opts = format_ident!("{}_with_options", method_name.to_snake_case());
1103 let input_type = resolver.rust_type(method.input_type.as_deref().unwrap_or(""), package)?;
1104 let output_view_type =
1105 resolver.rust_view_type(method.output_type.as_deref().unwrap_or(""), package)?;
1106
1107 let client_streaming = method.client_streaming.unwrap_or(false);
1108 let server_streaming = method.server_streaming.unwrap_or(false);
1109
1110 let doc = format!(
1111 " Call the {method_name} RPC. Sends a request to /{full_service_name}/{method_name}."
1112 );
1113 let doc_opts = format!(
1114 " Call the {method_name} RPC with explicit per-call options. \
1115 Options override [`connectrpc::client::ClientConfig`] defaults."
1116 );
1117
1118 let ret_ty: TokenStream;
1120 let call_body: TokenStream;
1121 let short_args: TokenStream; let opts_args: TokenStream; let short_delegate_args: TokenStream; if client_streaming && !server_streaming {
1126 ret_ty = quote! {
1128 Result<
1129 ::connectrpc::client::UnaryResponse<::buffa::view::OwnedView<#output_view_type<'static>>>,
1130 ::connectrpc::ConnectError,
1131 >
1132 };
1133 call_body = quote! {
1134 ::connectrpc::client::call_client_stream(
1135 &self.transport, &self.config,
1136 #service_name_const, #method_name,
1137 requests, options,
1138 ).await
1139 };
1140 short_args = quote! { requests: impl IntoIterator<Item = #input_type> };
1141 opts_args = quote! { requests: impl IntoIterator<Item = #input_type>, options: ::connectrpc::client::CallOptions };
1142 short_delegate_args = quote! { requests, ::connectrpc::client::CallOptions::default() };
1143 } else if client_streaming && server_streaming {
1144 ret_ty = quote! {
1146 Result<
1147 ::connectrpc::client::BidiStream<
1148 T::ResponseBody, #input_type, #output_view_type<'static>
1149 >,
1150 ::connectrpc::ConnectError,
1151 >
1152 };
1153 call_body = quote! {
1154 ::connectrpc::client::call_bidi_stream(
1155 &self.transport, &self.config,
1156 #service_name_const, #method_name, options,
1157 ).await
1158 };
1159 short_args = quote! {};
1160 opts_args = quote! { options: ::connectrpc::client::CallOptions };
1161 short_delegate_args = quote! { ::connectrpc::client::CallOptions::default() };
1162 } else if server_streaming {
1163 ret_ty = quote! {
1165 Result<
1166 ::connectrpc::client::ServerStream<T::ResponseBody, #output_view_type<'static>>,
1167 ::connectrpc::ConnectError,
1168 >
1169 };
1170 call_body = quote! {
1171 ::connectrpc::client::call_server_stream(
1172 &self.transport, &self.config,
1173 #service_name_const, #method_name,
1174 request, options,
1175 ).await
1176 };
1177 short_args = quote! { request: #input_type };
1178 opts_args = quote! { request: #input_type, options: ::connectrpc::client::CallOptions };
1179 short_delegate_args = quote! { request, ::connectrpc::client::CallOptions::default() };
1180 } else {
1181 ret_ty = quote! {
1183 Result<
1184 ::connectrpc::client::UnaryResponse<::buffa::view::OwnedView<#output_view_type<'static>>>,
1185 ::connectrpc::ConnectError,
1186 >
1187 };
1188 call_body = quote! {
1189 ::connectrpc::client::call_unary(
1190 &self.transport, &self.config,
1191 #service_name_const, #method_name,
1192 request, options,
1193 ).await
1194 };
1195 short_args = quote! { request: #input_type };
1196 opts_args = quote! { request: #input_type, options: ::connectrpc::client::CallOptions };
1197 short_delegate_args = quote! { request, ::connectrpc::client::CallOptions::default() };
1198 }
1199
1200 Ok(quote! {
1201 #[doc = #doc]
1202 pub async fn #method_snake(&self, #short_args) -> #ret_ty {
1203 self.#method_with_opts(#short_delegate_args).await
1204 }
1205
1206 #[doc = #doc_opts]
1207 pub async fn #method_with_opts(&self, #opts_args) -> #ret_ty {
1208 #call_body
1209 }
1210 })
1211}
1212
1213fn get_service_comment(
1215 file: &FileDescriptorProto,
1216 service: &ServiceDescriptorProto,
1217) -> Option<String> {
1218 let source_info: &SourceCodeInfo = &file.source_code_info;
1220
1221 let service_index = file.service.iter().position(|s| s.name == service.name)?;
1223
1224 let target_path = vec![6, service_index as i32];
1227
1228 find_comment(source_info, &target_path)
1229}
1230
1231fn get_method_comment(
1233 file: &FileDescriptorProto,
1234 service: &ServiceDescriptorProto,
1235 method: &MethodDescriptorProto,
1236) -> Option<String> {
1237 let source_info: &SourceCodeInfo = &file.source_code_info;
1238
1239 let (service_index, method_index) = file.service.iter().enumerate().find_map(|(si, s)| {
1242 if s.name != service.name {
1243 return None;
1244 }
1245 s.method
1246 .iter()
1247 .position(|m| m.name == method.name)
1248 .map(|mi| (si, mi))
1249 })?;
1250
1251 let target_path = vec![6, service_index as i32, 2, method_index as i32];
1255
1256 find_comment(source_info, &target_path)
1257}
1258
1259fn find_comment(source_info: &SourceCodeInfo, target_path: &[i32]) -> Option<String> {
1261 for location in &source_info.location {
1262 if location.path == target_path {
1263 let comment = location
1264 .leading_comments
1265 .as_ref()
1266 .or(location.trailing_comments.as_ref())?;
1267
1268 let cleaned: String = comment
1272 .lines()
1273 .map(|line| line.trim())
1274 .filter(|line| !line.is_empty())
1275 .collect::<Vec<_>>()
1276 .join("\n");
1277
1278 if !cleaned.is_empty() {
1279 return Some(cleaned);
1280 }
1281 }
1282 }
1283 None
1284}
1285
#[cfg(test)]
mod tests {
    //! Unit tests for service code generation: naming, type-path resolution,
    //! keyword escaping, collision detection and option plumbing.

    use super::*;
    use buffa_codegen::generated::descriptor::DescriptorProto;

    /// `doc_attrs` output, once pretty-printed, must carry exactly one space
    /// after `///` and keep the blank separator line between paragraphs.
    #[test]
    fn doc_attrs_prefixes_space_for_prettyplease() {
        let ts = quote! {
            #[allow(dead_code)]
            mod m {}
        };
        let doc = doc_attrs("Hello.\n\nSecond paragraph.");
        let combined = quote! { #doc #ts };
        let file = syn::parse2::<syn::File>(combined).unwrap();
        let out = prettyplease::unparse(&file);
        // Each paragraph is rendered with a single space after `///`.
        assert!(out.contains("/// Hello."), "got: {out}");
        assert!(out.contains("/// Second paragraph."), "got: {out}");
        // The empty line between paragraphs becomes a bare `///`.
        assert!(out.contains("///\n"), "got: {out}");
        // Neither a missing space ...
        assert!(!out.contains("///Hello"), "got: {out}");
        // ... nor a doubled one. (A single-space needle here would contradict
        // the positive `"/// Hello."` assertion above.)
        assert!(!out.contains("///  Hello"), "got: {out}");
    }

    /// Builds a one-service, one-method file whose method is named `Ping`.
    ///
    /// See [`minimal_file_with_method`] for the meaning of the arguments.
    fn minimal_file(
        package: Option<&str>,
        input_type: &str,
        output_type: &str,
        local_messages: &[&str],
    ) -> FileDescriptorProto {
        minimal_file_with_method(package, "Ping", input_type, output_type, local_messages)
    }

    /// Builds `ping.proto` containing service `PingService` with a single
    /// method `method_name`, plus one empty message per entry in
    /// `local_messages`.
    fn minimal_file_with_method(
        package: Option<&str>,
        method_name: &str,
        input_type: &str,
        output_type: &str,
        local_messages: &[&str],
    ) -> FileDescriptorProto {
        let method = MethodDescriptorProto {
            name: Some(method_name.into()),
            input_type: Some(input_type.into()),
            output_type: Some(output_type.into()),
            ..Default::default()
        };
        let service = ServiceDescriptorProto {
            name: Some("PingService".into()),
            method: vec![method],
            ..Default::default()
        };
        FileDescriptorProto {
            name: Some("ping.proto".into()),
            package: package.map(|p| p.into()),
            service: vec![service],
            message_type: local_messages
                .iter()
                .map(|name| DescriptorProto {
                    name: Some((*name).into()),
                    ..Default::default()
                })
                .collect(),
            ..Default::default()
        }
    }

    /// Builds `ping.proto` with service `PingService` exposing one method per
    /// entry in `method_names`; every method takes and returns
    /// `.{package}.Empty`, which is the only message declared in the file.
    fn minimal_file_with_methods(package: &str, method_names: &[&str]) -> FileDescriptorProto {
        let methods = method_names
            .iter()
            .map(|n| MethodDescriptorProto {
                name: Some((*n).into()),
                input_type: Some(format!(".{package}.Empty")),
                output_type: Some(format!(".{package}.Empty")),
                ..Default::default()
            })
            .collect();
        let service = ServiceDescriptorProto {
            name: Some("PingService".into()),
            method: methods,
            ..Default::default()
        };
        FileDescriptorProto {
            name: Some("ping.proto".into()),
            package: Some(package.into()),
            service: vec![service],
            message_type: vec![DescriptorProto {
                name: Some("Empty".into()),
                ..Default::default()
            }],
            ..Default::default()
        }
    }

    /// Runs `generate_service` for the first service of `files[target_idx]`
    /// and returns the generated tokens as a string.
    ///
    /// The resolver is built over all `files`, with the target file's name as
    /// the only file to generate, using `extern_paths` and `require_extern`.
    fn gen_service(
        files: &[FileDescriptorProto],
        target_idx: usize,
        extern_paths: &[(String, String)],
        require_extern: bool,
    ) -> Result<String> {
        let mut config = buffa_codegen::CodeGenConfig::default();
        config.extern_paths = extern_paths.to_vec();
        // `name` is optional; an unnamed file yields an empty target list.
        let target_name = files[target_idx]
            .name
            .clone()
            .into_iter()
            .collect::<Vec<_>>();
        let resolver = TypeResolver::new(files, &target_name, &config, require_extern);
        let file = &files[target_idx];
        let service = &file.service[0];
        Ok(generate_service(file, service, &resolver)?.to_string())
    }

    /// Panics if `formatted` (which must parse as a Rust file) contains any
    /// top-level `use` item; `label` identifies the source in the message.
    fn assert_no_top_level_use(formatted: &str, label: &str) {
        let parsed: syn::File = syn::parse_str(formatted).expect("formatted code parses");
        let offenders: Vec<String> = parsed
            .items
            .iter()
            .filter_map(|item| match item {
                syn::Item::Use(u) => Some(quote!(#u).to_string()),
                _ => None,
            })
            .collect();
        assert!(
            offenders.is_empty(),
            "{label} contains top-level use statement(s): {offenders:?}\nFull source:\n{formatted}"
        );
    }

    /// The fully-qualified service name includes the proto package.
    #[test]
    fn service_name_with_package() {
        let file = minimal_file(
            Some("example.v1"),
            ".example.v1.PingReq",
            ".example.v1.PingResp",
            &["PingReq", "PingResp"],
        );
        let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
        assert!(code.contains("\"example.v1.PingService\""), "got: {code}");
    }

    /// Without a package the service name must not gain a leading dot.
    #[test]
    fn service_name_without_package() {
        let file = minimal_file(None, ".PingReq", ".PingResp", &["PingReq", "PingResp"]);
        let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
        assert!(code.contains("\"PingService\""), "got: {code}");
        assert!(
            !code.contains("\".PingService\""),
            "must not have leading dot: {code}"
        );
    }

    /// Types from the service's own package are referenced without `super`.
    #[test]
    fn same_package_types_use_bare_names() {
        let file = minimal_file(
            Some("example.v1"),
            ".example.v1.PingReq",
            ".example.v1.PingResp",
            &["PingReq", "PingResp"],
        );
        let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
        assert!(code.contains("PingReq"), "input type missing: {code}");
        assert!(code.contains("PingResp"), "output type missing: {code}");
        assert!(
            !code.contains("super :: PingReq"),
            "unexpected super: {code}"
        );
    }

    /// Types from another package are reached through `super::super::...`.
    #[test]
    fn cross_package_types_use_relative_paths() {
        let common = FileDescriptorProto {
            name: Some("common.proto".into()),
            package: Some("common.v1".into()),
            message_type: vec![DescriptorProto {
                name: Some("Shared".into()),
                ..Default::default()
            }],
            ..Default::default()
        };
        let svc = minimal_file(
            Some("example.v1"),
            ".common.v1.Shared",
            ".example.v1.Out",
            &["Out"],
        );
        let code = gen_service(&[common, svc], 1, &[], false).unwrap();

        assert!(
            code.contains("super :: super :: common :: v1 :: Shared"),
            "cross-package path not emitted: {code}"
        );
        assert!(
            code.contains("super :: super :: common :: v1 :: SharedView"),
            "cross-package view path not emitted: {code}"
        );
    }

    /// Well-known types resolve to `::buffa_types` even with no explicit
    /// extern paths configured.
    #[test]
    fn wkt_types_use_buffa_types_extern_path() {
        let wkt = FileDescriptorProto {
            name: Some("google/protobuf/empty.proto".into()),
            package: Some("google.protobuf".into()),
            message_type: vec![DescriptorProto {
                name: Some("Empty".into()),
                ..Default::default()
            }],
            ..Default::default()
        };
        let svc = minimal_file(
            Some("example.v1"),
            ".google.protobuf.Empty",
            ".example.v1.Out",
            &["Out"],
        );
        let code = gen_service(&[wkt, svc], 1, &[], false).unwrap();

        assert!(
            code.contains(":: buffa_types :: google :: protobuf :: Empty"),
            "WKT extern path not emitted: {code}"
        );
    }

    /// A catch-all `.` extern path roots both owned and view types at the
    /// configured Rust module.
    #[test]
    fn extern_catchall_uses_absolute_paths() {
        let file = minimal_file(
            Some("example.v1"),
            ".example.v1.PingReq",
            ".example.v1.PingResp",
            &["PingReq", "PingResp"],
        );
        let extern_paths = [(".".into(), "crate::proto".into())];
        let code = gen_service(std::slice::from_ref(&file), 0, &extern_paths, true).unwrap();
        assert!(
            code.contains("crate :: proto :: example :: v1 :: PingReq"),
            "owned type path missing: {code}"
        );
        assert!(
            code.contains("crate :: proto :: example :: v1 :: PingReqView"),
            "view type path missing: {code}"
        );
    }

    /// The well-known-type mapping still wins over a catch-all `.` mapping,
    /// while local types go through the catch-all.
    #[test]
    fn extern_catchall_with_wkt_longest_wins() {
        let wkt = FileDescriptorProto {
            name: Some("google/protobuf/empty.proto".into()),
            package: Some("google.protobuf".into()),
            message_type: vec![DescriptorProto {
                name: Some("Empty".into()),
                ..Default::default()
            }],
            ..Default::default()
        };
        let svc = minimal_file(
            Some("example.v1"),
            ".google.protobuf.Empty",
            ".example.v1.Out",
            &["Out"],
        );
        let extern_paths = [(".".into(), "crate::proto".into())];
        let code = gen_service(&[wkt, svc], 1, &extern_paths, true).unwrap();
        assert!(
            code.contains(":: buffa_types :: google :: protobuf :: Empty"),
            "WKT mapping lost to catch-all: {code}"
        );
        assert!(
            code.contains("crate :: proto :: example :: v1 :: Out"),
            "local type not routed through catch-all: {code}"
        );
    }

    /// With `require_extern` set and no mapping, generation fails with a
    /// message that mentions `extern_path`.
    #[test]
    fn missing_extern_path_errors() {
        let file = minimal_file(
            Some("example.v1"),
            ".example.v1.PingReq",
            ".example.v1.PingResp",
            &["PingReq", "PingResp"],
        );
        let err = gen_service(std::slice::from_ref(&file), 0, &[], true).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("extern_path"),
            "error message lacks hint: {msg}"
        );
    }

    /// A package segment that is a Rust keyword (`type`) is emitted as a raw
    /// identifier.
    #[test]
    fn keyword_package_escaped() {
        let file = minimal_file(
            Some("google.type"),
            ".google.type.LatLng",
            ".google.type.LatLng",
            &["LatLng"],
        );
        let extern_paths = [(".".into(), "crate::proto".into())];
        let code = gen_service(std::slice::from_ref(&file), 0, &extern_paths, true).unwrap();
        assert!(
            code.contains("crate :: proto :: google :: r#type :: LatLng"),
            "keyword segment not escaped: {code}"
        );
    }

    /// A method whose snake-case name is a keyword (`move`) becomes `r#move`;
    /// the `_with_options` variant needs no escaping.
    #[test]
    fn keyword_method_escaped() {
        let file = minimal_file_with_method(
            Some("example.v1"),
            "Move",
            ".example.v1.Empty",
            ".example.v1.Empty",
            &["Empty"],
        );
        let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
        assert!(
            code.contains("fn r#move"),
            "keyword method not escaped: {code}"
        );
        assert!(
            code.contains("move_with_options"),
            "suffixed variant should not need escaping: {code}"
        );
        assert!(code.contains("client.r#move(request)"));
        syn::parse_str::<syn::File>(&code).expect("generated code parses");
    }

    /// A method named `Self` cannot be a raw identifier, so it is suffixed
    /// with `_` instead.
    #[test]
    fn path_keyword_method_suffixed() {
        let file = minimal_file_with_method(
            Some("example.v1"),
            "Self",
            ".example.v1.Empty",
            ".example.v1.Empty",
            &["Empty"],
        );
        let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
        assert!(
            code.contains("fn self_"),
            "path-keyword method not suffixed: {code}"
        );
        assert!(code.contains("self_with_options"));
        syn::parse_str::<syn::File>(&code).expect("generated code parses");
    }

    /// A service named `Self` yields a suffixed trait name, while the derived
    /// `Ext`/`Client`/`Server` names stay unsuffixed.
    #[test]
    fn service_name_keyword_suffixed() {
        let mut file = minimal_file(
            Some("example.v1"),
            ".example.v1.Empty",
            ".example.v1.Empty",
            &["Empty"],
        );
        file.service[0].name = Some("Self".into());
        let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
        assert!(code.contains("trait Self_ "), "trait not suffixed: {code}");
        assert!(code.contains("trait SelfExt"));
        assert!(code.contains("struct SelfClient"));
        assert!(code.contains("struct SelfServer"));
        syn::parse_str::<syn::File>(&code).expect("generated code parses");
    }

    /// Two proto methods mapping to the same snake-case identifier produce an
    /// error naming the service, both methods and the clashing identifier.
    #[test]
    fn method_snake_collision_errors() {
        let file = minimal_file_with_methods("example.v1", &["GetFoo", "get_foo"]);
        let err = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("PingService"), "missing service name: {msg}");
        assert!(msg.contains("\"GetFoo\""), "missing first method: {msg}");
        assert!(msg.contains("\"get_foo\""), "missing second method: {msg}");
        assert!(msg.contains("`get_foo`"), "missing rust ident: {msg}");
    }

    /// A method whose name equals another method's `_with_options` variant is
    /// reported as a collision.
    #[test]
    fn method_with_options_collision_errors() {
        let file = minimal_file_with_methods("example.v1", &["Ping", "PingWithOptions"]);
        let err = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("\"Ping\""), "missing first method: {msg}");
        assert!(
            msg.contains("\"PingWithOptions\""),
            "missing second method: {msg}"
        );
        assert!(
            msg.contains("`ping_with_options`"),
            "missing rust ident: {msg}"
        );
    }

    /// Methods with distinct identifiers generate parseable code.
    #[test]
    fn distinct_methods_do_not_collide() {
        let file = minimal_file_with_methods("example.v1", &["GetFoo", "GetBar"]);
        let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
        syn::parse_str::<syn::File>(&code).expect("generated code parses");
    }

    /// `emit_register_fn` defaults to `true` and is forwarded to the buffa
    /// config.
    #[test]
    fn options_default_emits_register_fn() {
        let opts = Options::default();
        assert!(opts.emit_register_fn);
        let cfg = opts.to_buffa_config();
        assert!(cfg.emit_register_fn);
    }

    /// Turning `emit_register_fn` off is forwarded to the buffa config.
    #[test]
    fn options_emit_register_fn_false_disables_buffa_register_fn() {
        let opts = Options {
            emit_register_fn: false,
            ..Options::default()
        };
        let cfg = opts.to_buffa_config();
        assert!(!cfg.emit_register_fn);
    }

    /// `generate_files` emits `fn register_types` by default and omits it when
    /// `emit_register_fn` is `false`.
    #[test]
    fn generate_files_emit_register_fn_false_suppresses_register_types() {
        let file = FileDescriptorProto {
            name: Some("ping.proto".into()),
            package: Some("example.v1".into()),
            message_type: vec![DescriptorProto {
                name: Some("PingReq".into()),
                ..Default::default()
            }],
            ..Default::default()
        };

        let with_fn = generate_files(
            std::slice::from_ref(&file),
            &["ping.proto".into()],
            &Options::default(),
        )
        .unwrap();
        assert_eq!(with_fn.len(), 1);
        assert!(
            with_fn[0].content.contains("fn register_types"),
            "expected register_types in default output: {}",
            with_fn[0].content
        );

        let without_fn = generate_files(
            std::slice::from_ref(&file),
            &["ping.proto".into()],
            &Options {
                emit_register_fn: false,
                ..Options::default()
            },
        )
        .unwrap();
        assert_eq!(without_fn.len(), 1);
        assert!(
            !without_fn[0].content.contains("fn register_types"),
            "register_types should be suppressed: {}",
            without_fn[0].content
        );
    }

    /// The plugin parameter string accepts `no_register_fn`.
    #[test]
    fn plugin_no_register_fn_parses() {
        let request = CodeGeneratorRequest {
            parameter: Some("buffa_module=crate::proto,no_register_fn".into()),
            file_to_generate: vec![],
            proto_file: vec![],
            ..Default::default()
        };
        generate(&request).expect("no_register_fn should be a recognized plugin option");
    }

    /// Generated service code must not contain top-level `use` items.
    #[test]
    fn no_top_level_use_statements_in_generated_code() {
        let file = minimal_file(
            Some("example.v1"),
            ".example.v1.PingReq",
            ".example.v1.PingResp",
            &["PingReq", "PingResp"],
        );
        let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
        let formatted = format_token_stream(&code.parse::<TokenStream>().unwrap()).unwrap();
        assert_no_top_level_use(&formatted, "generated code");
    }

    /// Output for two service files in the same package parses on its own and
    /// when concatenated, and neither contains top-level `use` items.
    #[test]
    fn multi_service_include_no_e0252() {
        let file_a = {
            let method = MethodDescriptorProto {
                name: Some("Ping".into()),
                input_type: Some(".svc.v1.PingReq".into()),
                output_type: Some(".svc.v1.PingResp".into()),
                ..Default::default()
            };
            let service = ServiceDescriptorProto {
                name: Some("Alpha".into()),
                method: vec![method],
                ..Default::default()
            };
            FileDescriptorProto {
                name: Some("alpha.proto".into()),
                package: Some("svc.v1".into()),
                service: vec![service],
                message_type: vec![
                    DescriptorProto {
                        name: Some("PingReq".into()),
                        ..Default::default()
                    },
                    DescriptorProto {
                        name: Some("PingResp".into()),
                        ..Default::default()
                    },
                ],
                ..Default::default()
            }
        };
        let file_b = {
            let method = MethodDescriptorProto {
                name: Some("Pong".into()),
                input_type: Some(".svc.v1.PongReq".into()),
                output_type: Some(".svc.v1.PongResp".into()),
                ..Default::default()
            };
            let service = ServiceDescriptorProto {
                name: Some("Beta".into()),
                method: vec![method],
                ..Default::default()
            };
            FileDescriptorProto {
                name: Some("beta.proto".into()),
                package: Some("svc.v1".into()),
                service: vec![service],
                message_type: vec![
                    DescriptorProto {
                        name: Some("PongReq".into()),
                        ..Default::default()
                    },
                    DescriptorProto {
                        name: Some("PongResp".into()),
                        ..Default::default()
                    },
                ],
                ..Default::default()
            }
        };

        let files = vec![file_a, file_b];
        let config = buffa_codegen::CodeGenConfig::default();
        let targets = vec!["alpha.proto".to_string(), "beta.proto".to_string()];
        let resolver = TypeResolver::new(&files, &targets, &config, false);

        let code_a = generate_connect_services(&files[0], &resolver).unwrap();
        let code_b = generate_connect_services(&files[1], &resolver).unwrap();

        let formatted_a = format_token_stream(&code_a).unwrap();
        let formatted_b = format_token_stream(&code_b).unwrap();

        syn::parse_str::<syn::File>(&formatted_a).expect("service A should parse independently");
        syn::parse_str::<syn::File>(&formatted_b).expect("service B should parse independently");

        // Both outputs side by side in one module, as an includer would see.
        let combined = format!("{formatted_a}\n{formatted_b}");
        syn::parse_str::<syn::File>(&combined)
            .expect("combined services should parse without E0252");

        assert_no_top_level_use(&formatted_a, "service A");
        assert_no_top_level_use(&formatted_b, "service B");
    }
}