1use std::collections::HashMap;
12
13use anyhow::Result;
14use heck::ToSnakeCase;
15use heck::ToUpperCamelCase;
16use proc_macro2::{Ident, TokenStream};
17use quote::format_ident;
18use quote::quote;
19
20use buffa_codegen::generated::descriptor::FileDescriptorProto;
21use buffa_codegen::generated::descriptor::MethodDescriptorProto;
22use buffa_codegen::generated::descriptor::ServiceDescriptorProto;
23use buffa_codegen::generated::descriptor::SourceCodeInfo;
24use buffa_codegen::generated::descriptor::method_options::IdempotencyLevel;
25use buffa_codegen::idents::make_field_ident;
26use buffa_codegen::idents::rust_path_to_tokens;
27
28pub use buffa_codegen::generated::descriptor;
29pub use buffa_codegen::{CodeGenConfig, GeneratedFile, GeneratedFileKind};
30
31use crate::plugin::CodeGeneratorRequest;
32use crate::plugin::CodeGeneratorResponse;
33use crate::plugin::CodeGeneratorResponseFile;
34
/// Options controlling Connect service code generation.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Options {
    /// Configuration forwarded to `buffa_codegen`. The generators in this
    /// file work on a clone of it with `generate_views` forced to `true`.
    pub buffa: CodeGenConfig,

    /// When `true`, the generated client struct and its `impl` block are
    /// emitted behind `#[cfg(feature = "client")]`.
    pub gate_client_feature: bool,
}
76
77impl Default for Options {
78 fn default() -> Self {
79 let mut buffa = CodeGenConfig::default();
80 buffa.generate_json = true;
81 Self {
82 buffa,
83 gate_client_feature: false,
84 }
85 }
86}
87
88impl Options {
89 fn to_buffa_config(&self) -> CodeGenConfig {
92 let mut config = self.buffa.clone();
93 config.generate_views = true;
94 config
95 }
96}
97
98fn emit_service_files(
101 proto_file: &[FileDescriptorProto],
102 file_to_generate: &[String],
103 resolver: &TypeResolver<'_>,
104 gate_client_feature: bool,
105) -> Result<Vec<GeneratedFile>> {
106 let mut out = Vec::new();
107 let mut batch = BatchState {
115 colliding_aliases: collect_alias_collisions(proto_file, file_to_generate),
116 gate_client_feature,
117 ..BatchState::default()
118 };
119 for file_name in file_to_generate {
120 let file_desc = proto_file
121 .iter()
122 .find(|f| f.name.as_deref() == Some(file_name.as_str()));
123
124 if let Some(file) = file_desc
125 && !file.service.is_empty()
126 {
127 let service_tokens = generate_connect_services(file, resolver, &mut batch)?;
128 let service_code = format_token_stream(&service_tokens)?;
129 out.push(GeneratedFile {
137 name: format!(
138 "{}.__connect.rs",
139 buffa_codegen::proto_path_to_stem(file_name)
140 ),
141 package: file.package.clone().unwrap_or_default(),
142 kind: GeneratedFileKind::Companion,
143 content: service_code,
144 });
145 }
146 }
147 Ok(out)
148}
149
150pub fn generate_files(
176 proto_file: &[FileDescriptorProto],
177 file_to_generate: &[String],
178 options: &Options,
179) -> Result<Vec<GeneratedFile>> {
180 let config = options.to_buffa_config();
181
182 let mut files = buffa_codegen::generate(proto_file, file_to_generate, &config)
183 .map_err(|e| anyhow::anyhow!("buffa-codegen failed: {e}"))?;
184
185 let resolver = TypeResolver::new(proto_file, file_to_generate, &config, false);
186 let service_files = emit_service_files(
187 proto_file,
188 file_to_generate,
189 &resolver,
190 options.gate_client_feature,
191 )?;
192
193 if config.file_per_package {
194 inline_companions_into_package_mods(&mut files, service_files);
202 } else {
203 buffa_codegen::apply_companions(&mut files, service_files);
210
211 debug_assert!(
218 files.iter().all(|f| {
219 f.kind != GeneratedFileKind::Companion
220 || files.iter().any(|g| {
221 g.kind == GeneratedFileKind::PackageMod
222 && g.content.contains(&format!("include!(\"{}\")", f.name))
223 })
224 }),
225 "a companion service file was not wired into any package stitcher"
226 );
227 }
228
229 Ok(files)
230}
231
232fn inline_companions_into_package_mods(
251 files: &mut [GeneratedFile],
254 companions: Vec<GeneratedFile>,
255) {
256 debug_assert!(
260 companions.iter().all(|c| files
261 .iter()
262 .any(|f| f.kind == GeneratedFileKind::PackageMod && f.package == c.package)),
263 "a companion service file's package has no PackageMod to inline into"
264 );
265 for comp in companions {
266 if let Some(pkg_mod) = files
267 .iter_mut()
268 .find(|f| f.kind == GeneratedFileKind::PackageMod && f.package == comp.package)
269 {
270 pkg_mod.content.push('\n');
271 pkg_mod.content.push_str(&comp.content);
272 }
273 }
274}
275
276pub fn generate_services(
311 proto_file: &[FileDescriptorProto],
312 file_to_generate: &[String],
313 options: &Options,
314) -> Result<Vec<GeneratedFile>> {
315 use std::collections::BTreeMap;
316
317 let config = options.to_buffa_config();
318 let resolver = TypeResolver::new(proto_file, file_to_generate, &config, true);
319 let mut files = emit_service_files(
320 proto_file,
321 file_to_generate,
322 &resolver,
323 options.gate_client_feature,
324 )?;
325
326 if config.file_per_package {
327 let mut by_package: BTreeMap<String, String> = BTreeMap::new();
332 for f in files {
333 let entry = by_package.entry(f.package).or_insert_with(|| {
334 String::from("// @generated by connectrpc-codegen. DO NOT EDIT.\n")
335 });
336 entry.push('\n');
337 entry.push_str(&f.content);
338 }
339 return Ok(by_package
340 .into_iter()
341 .map(|(package, content)| GeneratedFile {
342 name: buffa_codegen::package_to_filename(&package),
343 package,
344 kind: GeneratedFileKind::PackageMod,
345 content,
346 })
347 .collect());
348 }
349
350 let mut by_package: BTreeMap<String, Vec<String>> = BTreeMap::new();
356 for f in &files {
357 by_package
358 .entry(f.package.clone())
359 .or_default()
360 .push(f.name.clone());
361 }
362 for (package, names) in by_package {
363 let mut content = String::from("// @generated by connectrpc-codegen. DO NOT EDIT.\n");
364 for n in &names {
365 content.push_str(&format!("include!({n:?});\n"));
367 }
368 files.push(GeneratedFile {
369 name: buffa_codegen::package_to_mod_filename(&package),
370 package,
371 kind: GeneratedFileKind::PackageMod,
372 content,
373 });
374 }
375
376 Ok(files)
377}
378
379pub fn generate(request: &CodeGeneratorRequest) -> Result<CodeGeneratorResponse> {
484 let mut options = Options::default();
485
486 if let Some(ref param) = request.parameter {
487 for opt in param.split(',').map(str::trim).filter(|s| !s.is_empty()) {
488 if let Some(value) = opt.strip_prefix("buffa_module=") {
489 let rust = value.trim();
490 if rust.is_empty() {
491 anyhow::bail!(
492 "buffa_module requires a non-empty path, \
493 e.g. buffa_module=crate::proto"
494 );
495 }
496 options
497 .buffa
498 .extern_paths
499 .push((".".into(), rust.to_string()));
500 } else if let Some(value) = opt.strip_prefix("extern_path=") {
501 let (proto, rust) = value.split_once('=').ok_or_else(|| {
503 anyhow::anyhow!(
504 "invalid extern_path format {value:?}, expected \
505 extern_path=.proto.pkg=::rust::path"
506 )
507 })?;
508 let proto = proto.trim();
509 let rust = rust.trim();
510 if proto.is_empty() || rust.is_empty() {
511 anyhow::bail!(
512 "invalid extern_path format {value:?}, expected \
513 extern_path=.proto.pkg=::rust::path (both sides non-empty)"
514 );
515 }
516 let mut proto = proto.to_string();
517 if !proto.starts_with('.') {
518 proto.insert(0, '.');
519 }
520 options.buffa.extern_paths.push((proto, rust.to_string()));
521 } else {
522 match opt {
523 "file_per_package" => options.buffa.file_per_package = true,
524 "strict_utf8_mapping" => options.buffa.strict_utf8_mapping = true,
525 "no_json" => options.buffa.generate_json = false,
526 "no_register_fn" => options.buffa.emit_register_fn = false,
527 "gate_client_feature" => options.gate_client_feature = true,
528 _ => {
529 return Err(anyhow::anyhow!(
530 "unknown plugin option: {opt:?}. Supported: \
531 buffa_module=<rust_path>, extern_path=<proto>=<rust>, \
532 file_per_package, strict_utf8_mapping, no_json, \
533 no_register_fn, gate_client_feature"
534 ));
535 }
536 }
537 }
538 }
539 }
540
541 let generated = generate_services(&request.proto_file, &request.file_to_generate, &options)?;
542
543 let files: Vec<CodeGeneratorResponseFile> = generated
544 .into_iter()
545 .map(|g| CodeGeneratorResponseFile {
546 name: Some(g.name),
547 content: Some(g.content),
548 ..Default::default()
549 })
550 .collect();
551
552 Ok(CodeGeneratorResponse {
553 supported_features: Some(feature_flags()),
554 minimum_edition: Some(EDITION_2023),
555 maximum_edition: Some(EDITION_2023),
556 file: files,
557 ..Default::default()
558 })
559}
560
/// Returns the bitmask reported in `supported_features`: proto3 optional
/// fields (bit 0) combined with editions support (bit 1).
fn feature_flags() -> u64 {
    const PROTO3_OPTIONAL: u64 = 1 << 0;
    const SUPPORTS_EDITIONS: u64 = 1 << 1;
    PROTO3_OPTIONAL | SUPPORTS_EDITIONS
}
568
/// Edition number advertised as both `minimum_edition` and `maximum_edition`
/// in the plugin response.
/// NOTE(review): 1000 presumably mirrors `Edition.EDITION_2023` in
/// `descriptor.proto` — confirm before changing.
const EDITION_2023: i32 = 1000;
572
573fn format_token_stream(tokens: &TokenStream) -> Result<String> {
575 let file = syn::parse2::<syn::File>(tokens.clone())
576 .map_err(|e| anyhow::anyhow!("generated code failed to parse: {e}"))?;
577 Ok(prettyplease::unparse(&file))
578}
579
580fn doc_attrs(text: &str) -> TokenStream {
589 let lines: Vec<String> = text
590 .lines()
591 .map(|l| {
592 if l.is_empty() {
593 String::new()
594 } else {
595 format!(" {l}")
596 }
597 })
598 .collect();
599 quote! { #(#[doc = #lines])* }
600}
601
/// Resolves protobuf fully-qualified type names to Rust type paths for the
/// generated service code.
struct TypeResolver<'a> {
    /// buffa code-generation context used for every path lookup.
    ctx: buffa_codegen::context::CodeGenContext<'a>,
    /// When `true`, a type missing from the context, or one whose resolved
    /// path does not start with `::` or `crate::`, is reported as an error
    /// instead of falling back to a bare or relative name.
    require_extern: bool,
}
623
624impl<'a> TypeResolver<'a> {
625 fn new(
626 proto_file: &'a [FileDescriptorProto],
627 file_to_generate: &[String],
628 config: &'a buffa_codegen::CodeGenConfig,
629 require_extern: bool,
630 ) -> Self {
631 Self {
632 ctx: buffa_codegen::context::CodeGenContext::for_generate(
633 proto_file,
634 file_to_generate,
635 config,
636 ),
637 require_extern,
638 }
639 }
640
641 fn resolve_path(&self, proto_fqn: &str, current_package: &str) -> Result<String> {
648 match self.ctx.rust_type_relative(proto_fqn, current_package, 0) {
649 Some(path) => {
650 self.check_extern_coverage(proto_fqn, &path)?;
651 Ok(path)
652 }
653 None => self.fallback_unresolved(proto_fqn).map(str::to_string),
654 }
655 }
656
657 fn check_extern_coverage(&self, proto_fqn: &str, path_prefix: &str) -> Result<()> {
661 if self.require_extern
662 && !path_prefix.starts_with("::")
663 && !path_prefix.starts_with("crate::")
664 {
665 anyhow::bail!(
666 "type {proto_fqn} is not covered by any extern_path mapping. \
667 Add extern_path=.=<your_buffa_module> (e.g. \
668 extern_path=.=crate::proto) to the plugin opts."
669 );
670 }
671 Ok(())
672 }
673
674 fn fallback_unresolved<'f>(&self, proto_fqn: &'f str) -> Result<&'f str> {
678 if self.require_extern {
679 anyhow::bail!("type {proto_fqn} not found in descriptor set (missing proto import?)");
680 }
681 Ok(bare_type_name(proto_fqn))
682 }
683
684 fn rust_type(&self, proto_fqn: &str, current_package: &str) -> Result<TokenStream> {
686 let path = self.resolve_path(proto_fqn, current_package)?;
687 Ok(rust_path_to_tokens(&path))
688 }
689
690 fn rust_view_type(&self, proto_fqn: &str, current_package: &str) -> Result<TokenStream> {
697 use buffa_codegen::context::SENTINEL_MOD;
698 let (to_package, within) =
699 match self
700 .ctx
701 .rust_type_relative_split(proto_fqn, current_package, 0)
702 {
703 Some(s) => {
704 self.check_extern_coverage(proto_fqn, &s.to_package)?;
705 (s.to_package, s.within_package)
706 }
707 None => (
708 String::new(),
709 self.fallback_unresolved(proto_fqn)?.to_string(),
710 ),
711 };
712 let prefix = if to_package.is_empty() {
713 format!("{SENTINEL_MOD}::view")
714 } else {
715 format!("{to_package}::{SENTINEL_MOD}::view")
716 };
717 Ok(rust_path_to_tokens(&format!("{prefix}::{within}View")))
718 }
719}
720
/// Returns the last dot-separated segment of a protobuf type name, ignoring
/// one leading dot (`".pkg.Outer.Inner"` → `"Inner"`).
fn bare_type_name(proto_fqn: &str) -> &str {
    let unrooted = proto_fqn.strip_prefix('.').unwrap_or(proto_fqn);
    match unrooted.rsplit_once('.') {
        Some((_, last)) => last,
        None => unrooted,
    }
}
731
/// Mutable state shared across all files of one generation batch.
#[derive(Default)]
struct BatchState {
    /// Output type names (proto FQNs) already visited when emitting
    /// `Encodable` view impls, so each is handled at most once per batch.
    encodable_seen: std::collections::BTreeSet<String>,
    /// `(package, proto FQN)` pairs for which an owned-view alias has already
    /// been emitted.
    alias_seen: std::collections::BTreeSet<(String, String)>,
    /// `(package, alias identifier)` pairs where two different proto types
    /// would produce the same alias; aliases for these are not emitted.
    colliding_aliases: std::collections::BTreeSet<(String, String)>,
    /// When `true`, generated clients are gated on `feature = "client"`.
    gate_client_feature: bool,
}
767
768fn generate_connect_services(
769 file: &FileDescriptorProto,
770 resolver: &TypeResolver<'_>,
771 batch: &mut BatchState,
772) -> Result<TokenStream> {
773 let mut tokens = TokenStream::new();
774
775 tokens.extend(generate_owned_view_aliases(file, resolver, batch)?);
785 tokens.extend(generate_encodable_view_impls(file, resolver, batch)?);
786
787 for service in &file.service {
788 tokens.extend(generate_service(file, service, resolver, batch)?);
789 }
790
791 Ok(tokens)
792}
793
794fn owned_view_alias_ident(fqn: &str) -> Ident {
797 format_ident!("Owned{}View", bare_type_name(fqn).to_upper_camel_case())
798}
799
800fn alias_collides(batch: &BatchState, current_package: &str, proto_fqn: &str) -> bool {
808 let alias = owned_view_alias_ident(proto_fqn).to_string();
809 batch
810 .colliding_aliases
811 .contains(&(current_package.to_string(), alias))
812}
813
814fn router_stream_items_tokens(
821 resolver: &TypeResolver<'_>,
822 method: &MethodDescriptorProto,
823 package: &str,
824) -> TokenStream {
825 let input_fqn = method.input_type.as_deref().unwrap_or("");
826 let input_owned = resolver
830 .rust_type(input_fqn, package)
831 .expect("rust_type failed for streaming input type");
832 quote! {
833 let req = ::connectrpc::dispatcher::codegen::into_stream_messages::<#input_owned>(req);
834 }
835}
836
837fn stream_items_doc(method: &MethodDescriptorProto) -> TokenStream {
844 let mut doc = quote! {
845 #[doc = ""]
846 #[doc = " Each `requests` item is a [`StreamMessage`](::connectrpc::StreamMessage):"]
847 #[doc = " it owns its buffer, is `Send + 'static`, and exposes zero-copy"]
848 #[doc = " accessor methods (`item.name()`), `.view()`, and"]
849 #[doc = " `.to_owned_message()`."]
850 };
851 if method.input_type == method.output_type {
852 doc.extend(quote! {
853 #[doc = " Items can be yielded back unchanged"]
854 #[doc = " (`StreamMessage<M>` implements `Encodable<M>`)."]
855 });
856 }
857 doc
858}
859
860fn stream_item_arg(
863 resolver: &TypeResolver<'_>,
864 method: &MethodDescriptorProto,
865 package: &str,
866) -> Result<TokenStream> {
867 let input_fqn = method.input_type.as_deref().unwrap_or("");
868 let input_owned = resolver.rust_type(input_fqn, package)?;
869 Ok(quote! { ::connectrpc::StreamMessage<#input_owned> })
870}
871
872fn collect_alias_collisions(
884 proto_file: &[FileDescriptorProto],
885 file_to_generate: &[String],
886) -> std::collections::BTreeSet<(String, String)> {
887 use std::collections::BTreeMap;
888 let mut first_seen: BTreeMap<(String, String), String> = BTreeMap::new();
891 let mut colliding: std::collections::BTreeSet<(String, String)> =
892 std::collections::BTreeSet::new();
893
894 for file_name in file_to_generate {
895 let Some(file) = proto_file
896 .iter()
897 .find(|f| f.name.as_deref() == Some(file_name.as_str()))
898 else {
899 continue;
900 };
901 let package = file.package.clone().unwrap_or_default();
902 for service in &file.service {
903 for m in &service.method {
904 for fqn in [m.input_type.as_deref(), m.output_type.as_deref()]
905 .into_iter()
906 .flatten()
907 {
908 let alias = owned_view_alias_ident(fqn).to_string();
909 let key = (package.clone(), alias);
910 match first_seen.get(&key) {
911 Some(prev) if prev != fqn => {
912 colliding.insert(key);
913 }
914 Some(_) => {} None => {
916 first_seen.insert(key, fqn.to_string());
917 }
918 }
919 }
920 }
921 }
922 }
923 colliding
924}
925
926fn generate_owned_view_aliases(
943 file: &FileDescriptorProto,
944 resolver: &TypeResolver<'_>,
945 batch: &mut BatchState,
946) -> Result<TokenStream> {
947 let package = file.package.as_deref().unwrap_or("");
948 let mut out = TokenStream::new();
949 for service in &file.service {
950 for m in &service.method {
951 for fqn in [m.input_type.as_deref(), m.output_type.as_deref()]
952 .into_iter()
953 .flatten()
954 {
955 if alias_collides(batch, package, fqn) {
956 continue;
957 }
958 if !batch
959 .alias_seen
960 .insert((package.to_string(), fqn.to_string()))
961 {
962 continue;
963 }
964 let alias = owned_view_alias_ident(fqn);
965 let view = resolver.rust_view_type(fqn, package)?;
966 let doc = format!(
967 "Shorthand for `OwnedView<{}View<'static>>`.",
968 bare_type_name(fqn).to_upper_camel_case()
969 );
970 out.extend(quote! {
971 #[doc = #doc]
972 pub type #alias = ::buffa::view::OwnedView<#view<'static>>;
973 });
974 }
975 }
976 }
977 Ok(out)
978}
979
980fn generate_encodable_view_impls(
996 file: &FileDescriptorProto,
997 resolver: &TypeResolver<'_>,
998 batch: &mut BatchState,
999) -> Result<TokenStream> {
1000 let package = file.package.as_deref().unwrap_or("");
1001 let mut out = TokenStream::new();
1002 for service in &file.service {
1003 for m in &service.method {
1004 let fqn = m.output_type.as_deref().unwrap_or("");
1005 if !batch.encodable_seen.insert(fqn.to_string()) {
1006 continue;
1007 }
1008 let path = resolver.resolve_path(fqn, package)?;
1009 if path.starts_with("::") {
1012 continue;
1013 }
1014 let owned = resolver.rust_type(fqn, package)?;
1015 let view = resolver.rust_view_type(fqn, package)?;
1016 out.extend(quote! {
1017 impl ::connectrpc::Encodable<#owned> for #view<'_> {
1018 fn encode(&self, codec: ::connectrpc::CodecFormat)
1019 -> ::std::result::Result<::buffa::bytes::Bytes, ::connectrpc::ConnectError>
1020 {
1021 ::connectrpc::__codegen::encode_view_body(self, codec)
1022 }
1023 }
1024 impl ::connectrpc::Encodable<#owned> for ::buffa::view::OwnedView<#view<'static>> {
1025 fn encode(&self, codec: ::connectrpc::CodecFormat)
1026 -> ::std::result::Result<::buffa::bytes::Bytes, ::connectrpc::ConnectError>
1027 {
1028 ::connectrpc::__codegen::encode_view_body(self.reborrow(), codec)
1029 }
1030 }
1031 });
1032 }
1033 }
1034 Ok(out)
1035}
1036
1037fn check_method_collisions(service_name: &str, service: &ServiceDescriptorProto) -> Result<()> {
1046 let mut seen: HashMap<String, String> = HashMap::new();
1047 for m in &service.method {
1048 let proto_name = m.name.as_deref().unwrap_or("");
1049 let snake = proto_name.to_snake_case();
1050 let with_opts = format!("{snake}_with_options");
1051 for ident in [snake.as_str(), with_opts.as_str()] {
1052 if let Some(prev) = seen.get(ident) {
1053 anyhow::bail!(
1054 "service {service_name}: RPC methods {prev:?} and {proto_name:?} \
1055 both generate Rust identifier `{ident}`; rename one in the proto"
1056 );
1057 }
1058 }
1059 seen.insert(snake, proto_name.to_string());
1060 seen.insert(with_opts, proto_name.to_string());
1061 }
1062 Ok(())
1063}
1064
1065fn generate_service(
1066 file: &FileDescriptorProto,
1067 service: &ServiceDescriptorProto,
1068 resolver: &TypeResolver<'_>,
1069 batch: &BatchState,
1070) -> Result<TokenStream> {
1071 let package = file.package.as_deref().unwrap_or("");
1072 let service_name = service.name.as_deref().unwrap_or("");
1073 check_method_collisions(service_name, service)?;
1074 let full_service_name = if package.is_empty() {
1077 service_name.to_string()
1078 } else {
1079 format!("{package}.{service_name}")
1080 };
1081 let service_upper = service_name.to_upper_camel_case();
1082 let trait_name = if service_upper == "Self" {
1086 format_ident!("Self_")
1087 } else {
1088 format_ident!("{}", service_upper)
1089 };
1090 let ext_trait_name = format_ident!("{}Ext", service_upper);
1091 let client_name = format_ident!("{}Client", service_upper);
1092 let server_name = format_ident!("{}Server", service_upper);
1093 let service_name_const = format_ident!(
1094 "{}_SERVICE_NAME",
1095 service_name.to_snake_case().to_uppercase()
1096 );
1097
1098 let service_doc = get_service_comment(file, service).unwrap_or_default();
1100 let base_doc = if service_doc.is_empty() {
1101 format!("Server trait for {service_name}.")
1102 } else {
1103 service_doc
1104 };
1105 let full_doc = format!(
1106 "{base_doc}\n\n\
1107 # Implementing handlers\n\n\
1108 Implement methods with plain `async fn`; the returned future satisfies\n\
1109 the `Send` bound automatically.\n\n\
1110 **Unary and server-streaming requests** arrive as\n\
1111 [`ServiceRequest<'_, Req>`](::connectrpc::ServiceRequest): a zero-copy\n\
1112 view of the request plus its body, valid for the duration of the call.\n\
1113 Fields are read directly (`request.name` is a `&str` into the decoded\n\
1114 buffer) and the borrow may be held across `.await` points. Anything\n\
1115 that must outlive the call — `tokio::spawn`, channels, server state,\n\
1116 or data captured by a returned response stream — takes owned data:\n\
1117 call `request.to_owned_message()` (or copy the specific fields)\n\
1118 first.\n\n\
1119 **Client-streaming and bidi requests** arrive as\n\
1120 `ServiceStream<`[`StreamMessage<Req>`](::connectrpc::StreamMessage)`>`.\n\
1121 Each item owns its decoded buffer and is `Send + 'static`, so items\n\
1122 can be buffered or moved into spawned tasks; read fields zero-copy\n\
1123 through the generated accessor methods (`item.name()`) or `.view()`,\n\
1124 convert with `.to_owned_message()`, or yield an item back unchanged —\n\
1125 `StreamMessage<M>` implements `Encodable<M>`.\n\n\
1126 Request types resolved through `extern_path` (e.g. well-known types\n\
1127 from another crate) use the same wrappers; the crate that owns the\n\
1128 type must be generated with buffa ≥ 0.7.0 and views enabled so the\n\
1129 backing `HasMessageView` impl exists.\n\n\
1130 The `impl Encodable<Out>` return bound accepts the owned `Out`, the\n\
1131 generated `OutView<'_>` / `OwnedOutView`,\n\
1132 [`MaybeBorrowed`](::connectrpc::MaybeBorrowed), or\n\
1133 [`PreEncoded`](::connectrpc::PreEncoded) for handlers that encode a\n\
1134 non-`'static` view internally and pass the bytes across the handler\n\
1135 boundary. View bodies are not emitted for output types mapped via\n\
1136 `extern_path` (the impl would be an orphan); return owned for\n\
1137 WKT/extern outputs.\n\n\
1138 Server-streaming and bidi-streaming methods return\n\
1139 `ServiceStream<impl Encodable<Out> + Send + use<Self>>`. The\n\
1140 `use<Self>` precise-capturing clause excludes `&self`'s lifetime and\n\
1141 the request's lifetime (unary methods use `use<'a, Self>` and may\n\
1142 borrow from `&self`), so stream items must be `'static` and cannot\n\
1143 borrow from the request. To stream view-encoded data, encode each\n\
1144 item inside the stream body and yield\n\
1145 [`PreEncoded`](::connectrpc::PreEncoded) — see its `# Streaming\n\
1146 example` doc."
1147 );
1148 let service_doc_tokens = doc_attrs(&full_doc);
1149
1150 let trait_methods: Vec<TokenStream> = service
1152 .method
1153 .iter()
1154 .map(|m| generate_trait_method(file, service, m, resolver, package))
1155 .collect::<Result<Vec<_>>>()?;
1156
1157 let route_registrations: Vec<TokenStream> = service
1159 .method
1160 .iter()
1161 .map(|m| {
1162 let method_name = m.name.as_deref().unwrap_or("");
1163 let method_snake = make_field_ident(&method_name.to_snake_case());
1164 let spec_const = method_spec_const_ident(service, method_name);
1168
1169 let client_streaming = m.client_streaming.unwrap_or(false);
1170 let server_streaming = m.server_streaming.unwrap_or(false);
1171
1172 let route_call = if server_streaming && !client_streaming {
1173 let output_type = resolver
1178 .rust_type(m.output_type.as_deref().unwrap_or(""), package)
1179 .unwrap();
1180 let input_fqn = m.input_type.as_deref().unwrap_or("");
1181 let input_view = resolver.rust_view_type(input_fqn, package).unwrap();
1182 let input_owned = resolver.rust_type(input_fqn, package).unwrap();
1183 let call_handler = quote! {
1184 let sreq = ::connectrpc::ServiceRequest::<#input_owned>::from_parts(req.reborrow(), req.bytes());
1185 svc.#method_snake(ctx, sreq).await
1186 };
1187 quote! {
1188 .route_view_server_stream::<_, _, #output_type>(
1189 #service_name_const,
1190 #method_name,
1191 ::connectrpc::view_streaming_handler_fn({
1192 let svc = ::std::sync::Arc::clone(&self);
1193 move |ctx, req: ::buffa::view::OwnedView<#input_view<'static>>| {
1194 let svc = ::std::sync::Arc::clone(&svc);
1195 async move {
1196 #call_handler
1199 }
1200 }
1201 }),
1202 )
1203 }
1204 } else if client_streaming && !server_streaming {
1205 let output_type = resolver
1207 .rust_type(m.output_type.as_deref().unwrap_or(""), package)
1208 .unwrap();
1209 let into_items = router_stream_items_tokens(resolver, m, package);
1210 quote! {
1211 .route_view_client_stream(
1212 #service_name_const,
1213 #method_name,
1214 ::connectrpc::view_client_streaming_handler_fn({
1215 let svc = ::std::sync::Arc::clone(&self);
1216 move |ctx, req, format| {
1217 let svc = ::std::sync::Arc::clone(&svc);
1218 async move {
1219 #into_items
1220 svc.#method_snake(ctx, req).await?.encode::<#output_type>(format)
1221 }
1222 }
1223 }),
1224 )
1225 }
1226 } else if client_streaming && server_streaming {
1227 let output_type = resolver
1230 .rust_type(m.output_type.as_deref().unwrap_or(""), package)
1231 .unwrap();
1232 let into_items = router_stream_items_tokens(resolver, m, package);
1233 quote! {
1234 .route_view_bidi_stream::<_, _, #output_type>(
1235 #service_name_const,
1236 #method_name,
1237 ::connectrpc::view_bidi_streaming_handler_fn({
1238 let svc = ::std::sync::Arc::clone(&self);
1239 move |ctx, req| {
1240 let svc = ::std::sync::Arc::clone(&svc);
1241 async move {
1242 #into_items
1243 svc.#method_snake(ctx, req).await
1244 }
1245 }
1246 }),
1247 )
1248 }
1249 } else {
1250 let is_idempotent = m
1252 .options
1253 .idempotency_level
1254 .map(|level| level == IdempotencyLevel::NO_SIDE_EFFECTS)
1255 .unwrap_or(false);
1256
1257 let route_method = if is_idempotent {
1258 quote! { route_view_idempotent }
1259 } else {
1260 quote! { route_view }
1261 };
1262 let output_type = resolver
1263 .rust_type(m.output_type.as_deref().unwrap_or(""), package)
1264 .unwrap();
1265 let input_fqn = m.input_type.as_deref().unwrap_or("");
1269 let input_view = resolver.rust_view_type(input_fqn, package).unwrap();
1270 let input_owned = resolver.rust_type(input_fqn, package).unwrap();
1271 let call_handler = quote! {
1272 let sreq = ::connectrpc::ServiceRequest::<#input_owned>::from_parts(req.reborrow(), req.bytes());
1273 svc.#method_snake(ctx, sreq).await?.encode::<#output_type>(format)
1274 };
1275
1276 quote! {
1277 .#route_method(
1278 #service_name_const,
1279 #method_name,
1280 {
1281 let svc = ::std::sync::Arc::clone(&self);
1282 ::connectrpc::view_handler_fn(move |ctx, req: ::buffa::view::OwnedView<#input_view<'static>>, format| {
1283 let svc = ::std::sync::Arc::clone(&svc);
1284 async move {
1285 #call_handler
1288 }
1289 })
1290 },
1291 )
1292 }
1293 };
1294
1295 quote! {
1296 #route_call
1297 .with_spec(#spec_const)
1298 }
1299 })
1300 .collect();
1301
1302 let client_methods: Vec<TokenStream> = service
1304 .method
1305 .iter()
1306 .map(|m| {
1307 generate_client_method(
1308 &service_name_const,
1309 &full_service_name,
1310 m,
1311 resolver,
1312 package,
1313 )
1314 })
1315 .collect::<Result<Vec<_>>>()?;
1316
1317 let service_server = generate_service_server(
1319 &full_service_name,
1320 &trait_name,
1321 &server_name,
1322 service,
1323 resolver,
1324 package,
1325 )?;
1326
1327 let example_method = service
1329 .method
1330 .first()
1331 .and_then(|m| m.name.as_deref())
1332 .map(|n| make_field_ident(&n.to_snake_case()).to_string())
1333 .unwrap_or_else(|| "method".to_string());
1334
1335 let client_name_str = client_name.to_string();
1337 let client_doc = format!(
1338 r#"Client for this service.
1339
1340Generic over `T: ClientTransport`. For **gRPC** (HTTP/2), use
1341`Http2Connection` — it has honest `poll_ready` and composes with
1342`tower::balance` for multi-connection load balancing. For **Connect
1343over HTTP/1.1** (or unknown protocol), use `HttpClient`.
1344
1345# Example (gRPC / HTTP/2)
1346
1347```rust,ignore
1348use connectrpc::client::{{Http2Connection, ClientConfig}};
1349use connectrpc::Protocol;
1350
1351let uri: http::Uri = "http://localhost:8080".parse()?;
1352let conn = Http2Connection::connect_plaintext(uri.clone()).await?.shared(1024);
1353let config = ClientConfig::new(uri).with_protocol(Protocol::Grpc);
1354
1355let client = {client_name_str}::new(conn, config);
1356let response = client.{example_method}(request).await?;
1357```
1358
1359# Example (Connect / HTTP/1.1 or ALPN)
1360
1361```rust,ignore
1362use connectrpc::client::{{HttpClient, ClientConfig}};
1363
1364let http = HttpClient::plaintext(); // cleartext http:// only
1365let config = ClientConfig::new("http://localhost:8080".parse()?);
1366
1367let client = {client_name_str}::new(http, config);
1368let response = client.{example_method}(request).await?;
1369```
1370
1371# Working with the response
1372
1373Unary calls return [`UnaryResponse<OwnedView<FooView>>`](::connectrpc::client::UnaryResponse).
1374[`view()`](::connectrpc::client::UnaryResponse::view) borrows the response
1375message, so field access is zero-copy:
1376
1377```rust,ignore
1378let resp = client.{example_method}(request).await?;
1379let name: &str = resp.view().name; // borrow into the response buffer
1380```
1381
1382If you need the owned struct (e.g. to store or pass by value), use
1383[`into_owned()`](::connectrpc::client::UnaryResponse::into_owned):
1384
1385```rust,ignore
1386let owned = client.{example_method}(request).await?.into_owned();
1387```
1388
1389[`into_view()`](::connectrpc::client::UnaryResponse::into_view) keeps the
1390zero-copy decoded body (an `OwnedView`) without copying; field access on it
1391goes through `.reborrow()`. Streaming responses yield one `OwnedView` per
1392received message from `.message().await` — bind `msg.reborrow()` for field
1393access, or convert with `.to_owned_message()`."#
1394 );
1395 let client_doc_tokens = doc_attrs(&client_doc);
1396 let client_cfg_attr: TokenStream = if batch.gate_client_feature {
1404 quote! { #[cfg(feature = "client")] }
1405 } else {
1406 TokenStream::new()
1407 };
1408
1409 let spec_consts = generate_spec_consts(&full_service_name, service);
1413
1414 Ok(quote! {
1415 pub const #service_name_const: &str = #full_service_name;
1421
1422 #(#spec_consts)*
1423
1424 #service_doc_tokens
1425 #[allow(clippy::type_complexity)]
1426 pub trait #trait_name: Send + Sync + 'static {
1427 #(#trait_methods)*
1428 }
1429
1430 pub trait #ext_trait_name: #trait_name {
1443 fn register(self: ::std::sync::Arc<Self>, router: ::connectrpc::Router) -> ::connectrpc::Router;
1448 }
1449
1450 impl<S: #trait_name> #ext_trait_name for S {
1451 fn register(self: ::std::sync::Arc<Self>, router: ::connectrpc::Router) -> ::connectrpc::Router {
1452 router
1453 #(#route_registrations)*
1454 }
1455 }
1456
1457 #service_server
1458
1459 #client_doc_tokens
1460 #client_cfg_attr
1461 #[derive(Clone)]
1462 pub struct #client_name<T> {
1463 transport: T,
1464 config: ::connectrpc::client::ClientConfig,
1465 }
1466
1467 #client_cfg_attr
1468 impl<T> #client_name<T>
1469 where
1470 T: ::connectrpc::client::ClientTransport,
1471 <T::ResponseBody as ::http_body::Body>::Error: ::std::fmt::Display,
1472 {
1473 pub fn new(transport: T, config: ::connectrpc::client::ClientConfig) -> Self {
1475 Self { transport, config }
1476 }
1477
1478 pub fn config(&self) -> &::connectrpc::client::ClientConfig {
1480 &self.config
1481 }
1482
1483 pub fn config_mut(&mut self) -> &mut ::connectrpc::client::ClientConfig {
1485 &mut self.config
1486 }
1487
1488 #(#client_methods)*
1489 }
1490 })
1491}
1492
1493fn method_spec_const_ident(service: &ServiceDescriptorProto, method_name: &str) -> Ident {
1500 let service_name = service.name.as_deref().unwrap_or("");
1501 format_ident!(
1502 "{}_{}_SPEC",
1503 service_name.to_snake_case().to_uppercase(),
1504 method_name.to_snake_case().to_uppercase()
1505 )
1506}
1507
1508fn generate_spec_consts(
1517 full_service_name: &str,
1518 service: &ServiceDescriptorProto,
1519) -> Vec<TokenStream> {
1520 service
1521 .method
1522 .iter()
1523 .map(|m| {
1524 let method_name = m.name.as_deref().unwrap_or("");
1525 let spec_const = method_spec_const_ident(service, method_name);
1526 let procedure = format!("/{full_service_name}/{method_name}");
1527 let cs = m.client_streaming.unwrap_or(false);
1528 let ss = m.server_streaming.unwrap_or(false);
1529 let stream_type = match (cs, ss) {
1530 (true, true) => quote! { ::connectrpc::StreamType::BidiStream },
1531 (true, false) => quote! { ::connectrpc::StreamType::ClientStream },
1532 (false, true) => quote! { ::connectrpc::StreamType::ServerStream },
1533 (false, false) => quote! { ::connectrpc::StreamType::Unary },
1534 };
1535 let idempotency_level = match m.options.idempotency_level {
1536 Some(IdempotencyLevel::NO_SIDE_EFFECTS) => {
1537 quote! { ::connectrpc::IdempotencyLevel::NoSideEffects }
1538 }
1539 Some(IdempotencyLevel::IDEMPOTENT) => {
1540 quote! { ::connectrpc::IdempotencyLevel::Idempotent }
1541 }
1542 _ => quote! { ::connectrpc::IdempotencyLevel::Unknown },
1543 };
1544 let doc = format!(
1545 "Static [`Spec`](::connectrpc::Spec) for the server-side `{method_name}` RPC.\n\n\
1546 The dispatcher surfaces this on\n\
1547 [`RequestContext::spec`](::connectrpc::RequestContext::spec)."
1548 );
1549 let doc_tokens = doc_attrs(&doc);
1550 quote! {
1551 #doc_tokens
1552 pub const #spec_const: ::connectrpc::Spec =
1553 ::connectrpc::Spec::server(#procedure, #stream_type)
1554 .with_idempotency_level(#idempotency_level);
1555 }
1556 })
1557 .collect()
1558}
1559
1560fn generate_service_server(
1567 full_service_name: &str,
1568 trait_name: &proc_macro2::Ident,
1569 server_name: &proc_macro2::Ident,
1570 service: &ServiceDescriptorProto,
1571 resolver: &TypeResolver<'_>,
1572 package: &str,
1573) -> Result<TokenStream> {
1574 let path_prefix = format!("{full_service_name}/");
1576
1577 let lookup_arms: Vec<TokenStream> = service
1579 .method
1580 .iter()
1581 .map(|m| {
1582 let method_name = m.name.as_deref().unwrap_or("");
1583 let client_streaming = m.client_streaming.unwrap_or(false);
1584 let server_streaming = m.server_streaming.unwrap_or(false);
1585 let is_idempotent = m
1586 .options
1587 .idempotency_level
1588 .map(|level| level == IdempotencyLevel::NO_SIDE_EFFECTS)
1589 .unwrap_or(false);
1590 let spec_const = method_spec_const_ident(service, method_name);
1591
1592 let desc = if client_streaming && server_streaming {
1593 quote! { ::connectrpc::dispatcher::codegen::MethodDescriptor::bidi_streaming() }
1594 } else if client_streaming {
1595 quote! { ::connectrpc::dispatcher::codegen::MethodDescriptor::client_streaming() }
1596 } else if server_streaming {
1597 quote! { ::connectrpc::dispatcher::codegen::MethodDescriptor::server_streaming() }
1598 } else {
1599 quote! { ::connectrpc::dispatcher::codegen::MethodDescriptor::unary(#is_idempotent) }
1600 };
1601 quote! { #method_name => Some(#desc.with_spec(#spec_const)), }
1602 })
1603 .collect();
1604
1605 let mut call_unary_arms: Vec<TokenStream> = Vec::new();
1610 let mut call_ss_arms: Vec<TokenStream> = Vec::new();
1611 let mut call_cs_arms: Vec<TokenStream> = Vec::new();
1612 let mut call_bidi_arms: Vec<TokenStream> = Vec::new();
1613
1614 for m in &service.method {
1615 let method_name = m.name.as_deref().unwrap_or("");
1616 let method_snake = make_field_ident(&method_name.to_snake_case());
1617 let input_view = resolver.rust_view_type(m.input_type.as_deref().unwrap_or(""), package)?;
1618 let output_type = resolver.rust_type(m.output_type.as_deref().unwrap_or(""), package)?;
1619 let cs = m.client_streaming.unwrap_or(false);
1620 let ss = m.server_streaming.unwrap_or(false);
1621
1622 let stream_decode = {
1625 let input_fqn = m.input_type.as_deref().unwrap_or("");
1626 let input_owned = resolver.rust_type(input_fqn, package)?;
1627 quote! { ::connectrpc::dispatcher::codegen::decode_message_request_stream::<#input_owned>(requests, format) }
1628 };
1629
1630 if cs && ss {
1631 call_bidi_arms.push(quote! {
1633 #method_name => {
1634 let svc = ::std::sync::Arc::clone(&self.inner);
1635 Box::pin(async move {
1636 let req_stream = #stream_decode;
1637 let resp = svc.#method_snake(ctx, req_stream).await?;
1638 Ok(resp.map_body(|s| ::connectrpc::dispatcher::codegen::encode_response_stream::<#output_type, _, _>(s, format)))
1639 })
1640 }
1641 });
1642 } else if cs {
1643 call_cs_arms.push(quote! {
1645 #method_name => {
1646 let svc = ::std::sync::Arc::clone(&self.inner);
1647 Box::pin(async move {
1648 let req_stream = #stream_decode;
1649 svc.#method_snake(ctx, req_stream).await?.encode::<#output_type>(format)
1650 })
1651 }
1652 });
1653 } else if ss {
1654 let input_fqn = m.input_type.as_deref().unwrap_or("");
1656 let input_owned = resolver.rust_type(input_fqn, package)?;
1657 let call_handler = quote! {
1658 let req = ::connectrpc::ServiceRequest::<#input_owned>::from_parts(&req, &body);
1659 let resp = svc.#method_snake(ctx, req).await?;
1660 };
1661 call_ss_arms.push(quote! {
1662 #method_name => {
1663 let svc = ::std::sync::Arc::clone(&self.inner);
1664 Box::pin(async move {
1665 let body = ::connectrpc::dispatcher::codegen::request_proto_bytes::<#input_owned>(request, format)?;
1668 let req: #input_view<'_> = ::connectrpc::dispatcher::codegen::decode_borrowed_request_view(&body)?;
1669 #call_handler
1670 Ok(resp.map_body(|s| ::connectrpc::dispatcher::codegen::encode_response_stream::<#output_type, _, _>(s, format)))
1671 })
1672 }
1673 });
1674 } else {
1675 let input_fqn = m.input_type.as_deref().unwrap_or("");
1677 let input_owned = resolver.rust_type(input_fqn, package)?;
1678 let call_handler = quote! {
1679 let req = ::connectrpc::ServiceRequest::<#input_owned>::from_parts(&req, &body);
1680 svc.#method_snake(ctx, req).await?.encode::<#output_type>(format)
1681 };
1682 call_unary_arms.push(quote! {
1683 #method_name => {
1684 let svc = ::std::sync::Arc::clone(&self.inner);
1685 Box::pin(async move {
1686 let body = ::connectrpc::dispatcher::codegen::request_proto_bytes::<#input_owned>(request.encoded()?, format)?;
1693 let req: #input_view<'_> = ::connectrpc::dispatcher::codegen::decode_borrowed_request_view(&body)?;
1694 #call_handler
1695 })
1696 }
1697 });
1698 }
1699 }
1700
1701 let server_doc = format!(
1702 "Monomorphic dispatcher for `{trait_name}`.\n\n\
1703 Unlike `.register(Router)` which type-erases each method into an \
1704 `Arc<dyn ErasedHandler>` stored in a `HashMap`, this struct dispatches \
1705 via a compile-time `match` on method name: no vtable, no hash lookup.\n\n\
1706 # Example\n\n\
1707 ```rust,ignore\n\
1708 use connectrpc::ConnectRpcService;\n\n\
1709 let server = {server_name}::new(MyImpl);\n\
1710 let service = ConnectRpcService::new(server);\n\
1711 // hand `service` to axum/hyper as a fallback_service\n\
1712 ```"
1713 );
1714 let server_doc_tokens = doc_attrs(&server_doc);
1715
1716 Ok(quote! {
1717 #server_doc_tokens
1718 pub struct #server_name<T> {
1719 inner: ::std::sync::Arc<T>,
1720 }
1721
1722 impl<T: #trait_name> #server_name<T> {
1723 pub fn new(service: T) -> Self {
1725 Self { inner: ::std::sync::Arc::new(service) }
1726 }
1727
1728 pub fn from_arc(inner: ::std::sync::Arc<T>) -> Self {
1730 Self { inner }
1731 }
1732 }
1733
1734 impl<T> Clone for #server_name<T> {
1735 fn clone(&self) -> Self {
1736 Self { inner: ::std::sync::Arc::clone(&self.inner) }
1737 }
1738 }
1739
1740 impl<T: #trait_name> ::connectrpc::Dispatcher for #server_name<T> {
1741 #[inline]
1742 fn lookup(&self, path: &str) -> Option<::connectrpc::dispatcher::codegen::MethodDescriptor> {
1743 let method = path.strip_prefix(#path_prefix)?;
1744 match method {
1745 #(#lookup_arms)*
1746 _ => None,
1747 }
1748 }
1749
1750 fn call_unary(
1751 &self,
1752 path: &str,
1753 ctx: ::connectrpc::RequestContext,
1754 request: ::connectrpc::Payload,
1755 format: ::connectrpc::CodecFormat,
1756 ) -> ::connectrpc::dispatcher::codegen::UnaryResult {
1757 let Some(method) = path.strip_prefix(#path_prefix) else {
1758 return ::connectrpc::dispatcher::codegen::unimplemented_unary(path);
1759 };
1760 let _ = (&ctx, &request, &format);
1762 match method {
1763 #(#call_unary_arms)*
1764 _ => ::connectrpc::dispatcher::codegen::unimplemented_unary(path),
1765 }
1766 }
1767
1768 fn call_server_streaming(
1769 &self,
1770 path: &str,
1771 ctx: ::connectrpc::RequestContext,
1772 request: ::buffa::bytes::Bytes,
1773 format: ::connectrpc::CodecFormat,
1774 ) -> ::connectrpc::dispatcher::codegen::StreamingResult {
1775 let Some(method) = path.strip_prefix(#path_prefix) else {
1776 return ::connectrpc::dispatcher::codegen::unimplemented_streaming(path);
1777 };
1778 let _ = (&ctx, &request, &format);
1779 match method {
1780 #(#call_ss_arms)*
1781 _ => ::connectrpc::dispatcher::codegen::unimplemented_streaming(path),
1782 }
1783 }
1784
1785 fn call_client_streaming(
1786 &self,
1787 path: &str,
1788 ctx: ::connectrpc::RequestContext,
1789 requests: ::connectrpc::dispatcher::codegen::RequestStream,
1790 format: ::connectrpc::CodecFormat,
1791 ) -> ::connectrpc::dispatcher::codegen::UnaryResult {
1792 let Some(method) = path.strip_prefix(#path_prefix) else {
1793 return ::connectrpc::dispatcher::codegen::unimplemented_unary(path);
1794 };
1795 let _ = (&ctx, &requests, &format);
1796 match method {
1797 #(#call_cs_arms)*
1798 _ => ::connectrpc::dispatcher::codegen::unimplemented_unary(path),
1799 }
1800 }
1801
1802 fn call_bidi_streaming(
1803 &self,
1804 path: &str,
1805 ctx: ::connectrpc::RequestContext,
1806 requests: ::connectrpc::dispatcher::codegen::RequestStream,
1807 format: ::connectrpc::CodecFormat,
1808 ) -> ::connectrpc::dispatcher::codegen::StreamingResult {
1809 let Some(method) = path.strip_prefix(#path_prefix) else {
1810 return ::connectrpc::dispatcher::codegen::unimplemented_streaming(path);
1811 };
1812 let _ = (&ctx, &requests, &format);
1813 match method {
1814 #(#call_bidi_arms)*
1815 _ => ::connectrpc::dispatcher::codegen::unimplemented_streaming(path),
1816 }
1817 }
1818 }
1819 })
1820}
1821
1822fn generate_doc_comment(doc: &str, default: &str) -> TokenStream {
1824 let comment = if doc.is_empty() { default } else { doc };
1825 doc_attrs(comment)
1826}
1827
1828fn generate_trait_method(
1830 file: &FileDescriptorProto,
1831 service: &ServiceDescriptorProto,
1832 method: &MethodDescriptorProto,
1833 resolver: &TypeResolver<'_>,
1834 package: &str,
1835) -> Result<TokenStream> {
1836 let method_name = method.name.as_deref().unwrap_or("");
1837 let method_snake = make_field_ident(&method_name.to_snake_case());
1838 let output_type = resolver.rust_type(method.output_type.as_deref().unwrap_or(""), package)?;
1839
1840 let method_doc = get_method_comment(file, service, method).unwrap_or_default();
1842 let method_doc_tokens =
1843 generate_doc_comment(&method_doc, &format!("Handle the {method_name} RPC."));
1844
1845 let client_streaming = method.client_streaming.unwrap_or(false);
1847 let server_streaming = method.server_streaming.unwrap_or(false);
1848
1849 let borrow_doc = quote! {
1850 #[doc = ""]
1851 #[doc = " `'a` lets the response body borrow from `&self` (e.g. server-resident state)."]
1852 };
1853
1854 if server_streaming && !client_streaming {
1855 let input_fqn = method.input_type.as_deref().unwrap_or("");
1866 let input_owned = resolver.rust_type(input_fqn, package)?;
1867 let request_param = quote! { ::connectrpc::ServiceRequest<'_, #input_owned> };
1868 let request_doc = quote! {
1869 #[doc = ""]
1870 #[doc = " `request` is borrowed from the request body and is valid for the"]
1871 #[doc = " duration of the call (until the response stream is returned);"]
1872 #[doc = " message fields are read directly on it (zero-copy). Data the"]
1873 #[doc = " returned stream needs must be copied out or converted via"]
1874 #[doc = " `.to_owned_message()`."]
1875 };
1876 Ok(quote! {
1877 #method_doc_tokens
1878 #request_doc
1879 fn #method_snake(
1880 &self,
1881 ctx: ::connectrpc::RequestContext,
1882 request: #request_param,
1883 ) -> impl ::std::future::Future<Output = ::connectrpc::ServiceResult<::connectrpc::ServiceStream<impl ::connectrpc::Encodable<#output_type> + Send + use<Self>>>> + Send;
1884 })
1885 } else if client_streaming && !server_streaming {
1886 let stream_item_arg = stream_item_arg(resolver, method, package)?;
1892 let items_doc = stream_items_doc(method);
1893 Ok(quote! {
1894 #method_doc_tokens
1895 #borrow_doc
1896 #items_doc
1897 fn #method_snake<'a>(
1898 &'a self,
1899 ctx: ::connectrpc::RequestContext,
1900 requests: ::connectrpc::ServiceStream<#stream_item_arg>,
1901 ) -> impl ::std::future::Future<Output = ::connectrpc::ServiceResult<impl ::connectrpc::Encodable<#output_type> + Send + use<'a, Self>>> + Send;
1902 })
1903 } else if client_streaming && server_streaming {
1904 let stream_item_arg = stream_item_arg(resolver, method, package)?;
1908 let items_doc = stream_items_doc(method);
1909 Ok(quote! {
1910 #method_doc_tokens
1911 #items_doc
1912 fn #method_snake(
1913 &self,
1914 ctx: ::connectrpc::RequestContext,
1915 requests: ::connectrpc::ServiceStream<#stream_item_arg>,
1916 ) -> impl ::std::future::Future<Output = ::connectrpc::ServiceResult<::connectrpc::ServiceStream<impl ::connectrpc::Encodable<#output_type> + Send + use<Self>>>> + Send;
1917 })
1918 } else {
1919 let input_fqn = method.input_type.as_deref().unwrap_or("");
1930 let input_owned = resolver.rust_type(input_fqn, package)?;
1931 let request_param = quote! { ::connectrpc::ServiceRequest<'_, #input_owned> };
1932 let request_doc = quote! {
1933 #[doc = ""]
1934 #[doc = " `request` is borrowed from the request body and is valid for the"]
1935 #[doc = " duration of the call; message fields are read directly on it"]
1936 #[doc = " (zero-copy). The response cannot borrow from `request` — use"]
1937 #[doc = " `.to_owned_message()` (or copy the specific fields) for anything"]
1938 #[doc = " returned, stored, or moved into `tokio::spawn`."]
1939 };
1940 Ok(quote! {
1941 #method_doc_tokens
1942 #borrow_doc
1943 #request_doc
1944 fn #method_snake<'a>(
1945 &'a self,
1946 ctx: ::connectrpc::RequestContext,
1947 request: #request_param,
1948 ) -> impl ::std::future::Future<Output = ::connectrpc::ServiceResult<impl ::connectrpc::Encodable<#output_type> + Send + use<'a, Self>>> + Send;
1949 })
1950 }
1951}
1952
1953fn generate_client_method(
1964 service_name_const: &Ident,
1965 full_service_name: &str,
1966 method: &MethodDescriptorProto,
1967 resolver: &TypeResolver<'_>,
1968 package: &str,
1969) -> Result<TokenStream> {
1970 let method_name = method.name.as_deref().unwrap_or("");
1971 let method_snake = make_field_ident(&method_name.to_snake_case());
1972 let method_with_opts = format_ident!("{}_with_options", method_name.to_snake_case());
1973 let input_type = resolver.rust_type(method.input_type.as_deref().unwrap_or(""), package)?;
1974 let output_view_type =
1975 resolver.rust_view_type(method.output_type.as_deref().unwrap_or(""), package)?;
1976
1977 let client_streaming = method.client_streaming.unwrap_or(false);
1978 let server_streaming = method.server_streaming.unwrap_or(false);
1979
1980 let doc = format!(
1981 " Call the {method_name} RPC. Sends a request to /{full_service_name}/{method_name}."
1982 );
1983 let doc_opts = format!(
1984 " Call the {method_name} RPC with explicit per-call options. \
1985 Options override [`ClientConfig`](::connectrpc::client::ClientConfig) defaults."
1986 );
1987
1988 let ret_ty: TokenStream;
1990 let call_body: TokenStream;
1991 let short_args: TokenStream; let opts_args: TokenStream; let short_delegate_args: TokenStream; if client_streaming && !server_streaming {
1996 ret_ty = quote! {
1998 Result<
1999 ::connectrpc::client::UnaryResponse<::buffa::view::OwnedView<#output_view_type<'static>>>,
2000 ::connectrpc::ConnectError,
2001 >
2002 };
2003 call_body = quote! {
2004 ::connectrpc::client::call_client_stream(
2005 &self.transport, &self.config,
2006 #service_name_const, #method_name,
2007 requests, options,
2008 ).await
2009 };
2010 short_args = quote! { requests: impl IntoIterator<Item = #input_type> };
2011 opts_args = quote! { requests: impl IntoIterator<Item = #input_type>, options: ::connectrpc::client::CallOptions };
2012 short_delegate_args = quote! { requests, ::connectrpc::client::CallOptions::default() };
2013 } else if client_streaming && server_streaming {
2014 ret_ty = quote! {
2016 Result<
2017 ::connectrpc::client::BidiStream<
2018 T::ResponseBody, #input_type, #output_view_type<'static>
2019 >,
2020 ::connectrpc::ConnectError,
2021 >
2022 };
2023 call_body = quote! {
2024 ::connectrpc::client::call_bidi_stream(
2025 &self.transport, &self.config,
2026 #service_name_const, #method_name, options,
2027 ).await
2028 };
2029 short_args = quote! {};
2030 opts_args = quote! { options: ::connectrpc::client::CallOptions };
2031 short_delegate_args = quote! { ::connectrpc::client::CallOptions::default() };
2032 } else if server_streaming {
2033 ret_ty = quote! {
2035 Result<
2036 ::connectrpc::client::ServerStream<T::ResponseBody, #output_view_type<'static>>,
2037 ::connectrpc::ConnectError,
2038 >
2039 };
2040 call_body = quote! {
2041 ::connectrpc::client::call_server_stream(
2042 &self.transport, &self.config,
2043 #service_name_const, #method_name,
2044 request, options,
2045 ).await
2046 };
2047 short_args = quote! { request: #input_type };
2048 opts_args = quote! { request: #input_type, options: ::connectrpc::client::CallOptions };
2049 short_delegate_args = quote! { request, ::connectrpc::client::CallOptions::default() };
2050 } else {
2051 ret_ty = quote! {
2053 Result<
2054 ::connectrpc::client::UnaryResponse<::buffa::view::OwnedView<#output_view_type<'static>>>,
2055 ::connectrpc::ConnectError,
2056 >
2057 };
2058 call_body = quote! {
2059 ::connectrpc::client::call_unary(
2060 &self.transport, &self.config,
2061 #service_name_const, #method_name,
2062 request, options,
2063 ).await
2064 };
2065 short_args = quote! { request: #input_type };
2066 opts_args = quote! { request: #input_type, options: ::connectrpc::client::CallOptions };
2067 short_delegate_args = quote! { request, ::connectrpc::client::CallOptions::default() };
2068 }
2069
2070 Ok(quote! {
2071 #[doc = #doc]
2072 pub async fn #method_snake(&self, #short_args) -> #ret_ty {
2073 self.#method_with_opts(#short_delegate_args).await
2074 }
2075
2076 #[doc = #doc_opts]
2077 pub async fn #method_with_opts(&self, #opts_args) -> #ret_ty {
2078 #call_body
2079 }
2080 })
2081}
2082
2083fn get_service_comment(
2085 file: &FileDescriptorProto,
2086 service: &ServiceDescriptorProto,
2087) -> Option<String> {
2088 let source_info: &SourceCodeInfo = &file.source_code_info;
2090
2091 let service_index = file.service.iter().position(|s| s.name == service.name)?;
2093
2094 let target_path = vec![6, service_index as i32];
2097
2098 find_comment(source_info, &target_path)
2099}
2100
2101fn get_method_comment(
2103 file: &FileDescriptorProto,
2104 service: &ServiceDescriptorProto,
2105 method: &MethodDescriptorProto,
2106) -> Option<String> {
2107 let source_info: &SourceCodeInfo = &file.source_code_info;
2108
2109 let (service_index, method_index) = file.service.iter().enumerate().find_map(|(si, s)| {
2112 if s.name != service.name {
2113 return None;
2114 }
2115 s.method
2116 .iter()
2117 .position(|m| m.name == method.name)
2118 .map(|mi| (si, mi))
2119 })?;
2120
2121 let target_path = vec![6, service_index as i32, 2, method_index as i32];
2125
2126 find_comment(source_info, &target_path)
2127}
2128
2129fn find_comment(source_info: &SourceCodeInfo, target_path: &[i32]) -> Option<String> {
2131 for location in &source_info.location {
2132 if location.path == target_path {
2133 let comment = location
2134 .leading_comments
2135 .as_ref()
2136 .or(location.trailing_comments.as_ref())?;
2137
2138 let cleaned: String = comment
2142 .lines()
2143 .map(|line| line.trim())
2144 .filter(|line| !line.is_empty())
2145 .collect::<Vec<_>>()
2146 .join("\n");
2147
2148 if !cleaned.is_empty() {
2149 return Some(cleaned);
2150 }
2151 }
2152 }
2153 None
2154}
2155
2156#[cfg(test)]
2157mod tests {
2158 use super::*;
2159 use buffa_codegen::generated::descriptor::DescriptorProto;
2160 use quote::ToTokens;
2161
2162 #[test]
2163 fn doc_attrs_prefixes_space_for_prettyplease() {
2164 let ts = quote! {
2167 #[allow(dead_code)]
2168 mod m {}
2169 };
2170 let doc = doc_attrs("Hello.\n\nSecond paragraph.");
2171 let combined = quote! { #doc #ts };
2172 let file = syn::parse2::<syn::File>(combined).unwrap();
2173 let out = prettyplease::unparse(&file);
2174 assert!(out.contains("/// Hello."), "got: {out}");
2176 assert!(out.contains("/// Second paragraph."), "got: {out}");
2177 assert!(out.contains("///\n"), "got: {out}");
2179 assert!(!out.contains("///Hello"), "got: {out}");
2181 assert!(!out.contains("/// Hello"), "got: {out}");
2182 }
2183
2184 fn minimal_file(
2189 package: Option<&str>,
2190 input_type: &str,
2191 output_type: &str,
2192 local_messages: &[&str],
2193 ) -> FileDescriptorProto {
2194 minimal_file_with_method(package, "Ping", input_type, output_type, local_messages)
2195 }
2196
2197 fn minimal_file_with_method(
2200 package: Option<&str>,
2201 method_name: &str,
2202 input_type: &str,
2203 output_type: &str,
2204 local_messages: &[&str],
2205 ) -> FileDescriptorProto {
2206 let method = MethodDescriptorProto {
2207 name: Some(method_name.into()),
2208 input_type: Some(input_type.into()),
2209 output_type: Some(output_type.into()),
2210 ..Default::default()
2211 };
2212 let service = ServiceDescriptorProto {
2213 name: Some("PingService".into()),
2214 method: vec![method],
2215 ..Default::default()
2216 };
2217 FileDescriptorProto {
2218 name: Some("ping.proto".into()),
2219 package: package.map(|p| p.into()),
2220 service: vec![service],
2221 message_type: local_messages
2222 .iter()
2223 .map(|name| DescriptorProto {
2224 name: Some((*name).into()),
2225 ..Default::default()
2226 })
2227 .collect(),
2228 ..Default::default()
2229 }
2230 }
2231
2232 fn minimal_file_with_methods(package: &str, method_names: &[&str]) -> FileDescriptorProto {
2236 let methods = method_names
2237 .iter()
2238 .map(|n| MethodDescriptorProto {
2239 name: Some((*n).into()),
2240 input_type: Some(format!(".{package}.Empty")),
2241 output_type: Some(format!(".{package}.Empty")),
2242 ..Default::default()
2243 })
2244 .collect();
2245 let service = ServiceDescriptorProto {
2246 name: Some("PingService".into()),
2247 method: methods,
2248 ..Default::default()
2249 };
2250 FileDescriptorProto {
2251 name: Some("ping.proto".into()),
2252 package: Some(package.into()),
2253 service: vec![service],
2254 message_type: vec![DescriptorProto {
2255 name: Some("Empty".into()),
2256 ..Default::default()
2257 }],
2258 ..Default::default()
2259 }
2260 }
2261
2262 fn gen_service(
2271 files: &[FileDescriptorProto],
2272 target_idx: usize,
2273 extern_paths: &[(String, String)],
2274 require_extern: bool,
2275 ) -> Result<String> {
2276 let mut config = buffa_codegen::CodeGenConfig::default();
2277 config.extern_paths = extern_paths.to_vec();
2278 let target_name = files[target_idx]
2279 .name
2280 .clone()
2281 .into_iter()
2282 .collect::<Vec<_>>();
2283 let resolver = TypeResolver::new(files, &target_name, &config, require_extern);
2284 let file = &files[target_idx];
2285 let service = &file.service[0];
2286 let batch = BatchState {
2287 colliding_aliases: collect_alias_collisions(files, &target_name),
2288 ..BatchState::default()
2289 };
2290 Ok(generate_service(file, service, &resolver, &batch)?.to_string())
2291 }
2292
2293 fn assert_no_top_level_use(formatted: &str, label: &str) {
2298 let parsed: syn::File = syn::parse_str(formatted).expect("formatted code parses");
2299 let offenders: Vec<String> = parsed
2300 .items
2301 .iter()
2302 .filter_map(|item| match item {
2303 syn::Item::Use(u) => Some(quote!(#u).to_string()),
2304 _ => None,
2305 })
2306 .collect();
2307 assert!(
2308 offenders.is_empty(),
2309 "{label} contains top-level use statement(s): {offenders:?}\nFull source:\n{formatted}"
2310 );
2311 }
2312
2313 fn gen_file(
2314 files: &[FileDescriptorProto],
2315 target_idx: usize,
2316 extern_paths: &[(String, String)],
2317 require_extern: bool,
2318 ) -> Result<String> {
2319 let mut config = buffa_codegen::CodeGenConfig::default();
2320 config.extern_paths = extern_paths.to_vec();
2321 let target_name = files[target_idx]
2322 .name
2323 .clone()
2324 .into_iter()
2325 .collect::<Vec<_>>();
2326 let resolver = TypeResolver::new(files, &target_name, &config, require_extern);
2327 let mut batch = BatchState {
2328 colliding_aliases: collect_alias_collisions(files, &target_name),
2329 ..BatchState::default()
2330 };
2331 Ok(generate_connect_services(&files[target_idx], &resolver, &mut batch)?.to_string())
2332 }
2333
2334 #[test]
2335 fn unary_response_body_captures_self_lifetime() {
2336 let file = minimal_file(
2337 Some("example.v1"),
2338 ".example.v1.PingReq",
2339 ".example.v1.PingResp",
2340 &["PingReq", "PingResp"],
2341 );
2342 let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
2343 assert!(code.contains("< 'a >"), "trait method missing 'a: {code}");
2344 assert!(code.contains("& 'a self"), "missing &'a self: {code}");
2345 assert!(
2346 code.contains("use < 'a , Self >"),
2347 "missing use<'a, Self> capture: {code}"
2348 );
2349 assert!(
2350 !code.contains("'static + use"),
2351 "'static bound on body should be dropped: {code}"
2352 );
2353 }
2354
2355 #[test]
2356 fn owned_view_aliases_emitted_for_input_and_output() {
2357 let file = minimal_file(
2358 Some("example.v1"),
2359 ".example.v1.PingReq",
2360 ".example.v1.PingResp",
2361 &["PingReq", "PingResp"],
2362 );
2363 let code = gen_file(std::slice::from_ref(&file), 0, &[], false).unwrap();
2364 assert!(
2365 code.contains("pub type OwnedPingReqView = :: buffa :: view :: OwnedView"),
2366 "missing OwnedPingReqView alias: {code}"
2367 );
2368 assert!(
2369 code.contains("pub type OwnedPingRespView = :: buffa :: view :: OwnedView"),
2370 "missing OwnedPingRespView alias: {code}"
2371 );
2372 assert!(
2377 code.contains("request : :: connectrpc :: ServiceRequest < '_"),
2378 "unary trait method should take request: ServiceRequest<'_, PingReq>: {code}"
2379 );
2380 assert!(
2384 !code.contains("impl :: connectrpc :: HasMessageView for"),
2385 "connect-codegen must not emit view-family impls (buffa does): {code}"
2386 );
2387 }
2388
2389 #[test]
2390 fn cross_package_input_collision_suppresses_alias_for_both_sides() {
2391 let v1 = FileDescriptorProto {
2399 name: Some("api/v1/foo/bar/foobar.proto".into()),
2400 package: Some("api.v1.foo.bar".into()),
2401 message_type: vec![DescriptorProto {
2402 name: Some("MyMessage".into()),
2403 ..Default::default()
2404 }],
2405 ..Default::default()
2406 };
2407 let v2 = minimal_file(
2408 Some("api.v2.foo.bar"),
2409 ".api.v1.foo.bar.MyMessage",
2410 ".api.v2.foo.bar.MyMessage",
2411 &["MyMessage"],
2412 );
2413 let code = gen_file(&[v1, v2], 1, &[], false).unwrap();
2414
2415 let alias_count = code.matches("pub type OwnedMyMessageView").count();
2418 assert_eq!(
2419 alias_count, 0,
2420 "expected zero OwnedMyMessageView aliases when both sides collide; got {alias_count}: {code}"
2421 );
2422
2423 assert!(
2426 !code.contains("request : OwnedMyMessageView"),
2427 "colliding input must not reference the suppressed alias: {code}"
2428 );
2429 assert!(
2432 code.contains("request : :: connectrpc :: ServiceRequest < '_"),
2433 "colliding unary input should still use ServiceRequest: {code}"
2434 );
2435 }
2436
2437 #[test]
2438 fn cross_package_input_without_collision_keeps_alias() {
2439 let wkt = FileDescriptorProto {
2446 name: Some("google/protobuf/empty.proto".into()),
2447 package: Some("google.protobuf".into()),
2448 message_type: vec![DescriptorProto {
2449 name: Some("Empty".into()),
2450 ..Default::default()
2451 }],
2452 ..Default::default()
2453 };
2454 let svc = minimal_file(
2455 Some("example.v1"),
2456 ".google.protobuf.Empty",
2457 ".example.v1.PingResp",
2458 &["PingResp"],
2459 );
2460 let code = gen_file(&[wkt, svc], 1, &[], false).unwrap();
2461 assert!(
2462 code.contains("pub type OwnedEmptyView = :: buffa :: view :: OwnedView"),
2463 "WKT cross-package input should keep its alias: {code}"
2464 );
2465 assert!(
2471 code.contains(
2472 "request : :: connectrpc :: ServiceRequest < '_ , :: buffa_types :: google :: protobuf :: Empty >"
2473 ),
2474 "extern unary input should use ServiceRequest over the extern owned type: {code}"
2475 );
2476 }
2477
2478 #[test]
2479 fn collision_inlines_in_all_streaming_method_shapes() {
2480 let v1 = FileDescriptorProto {
2486 name: Some("api/v1/foo/bar/foobar.proto".into()),
2487 package: Some("api.v1.foo.bar".into()),
2488 message_type: vec![DescriptorProto {
2489 name: Some("MyMessage".into()),
2490 ..Default::default()
2491 }],
2492 ..Default::default()
2493 };
2494 let v2 = FileDescriptorProto {
2495 name: Some("api/v2/foo/bar/foobar.proto".into()),
2496 package: Some("api.v2.foo.bar".into()),
2497 message_type: vec![DescriptorProto {
2498 name: Some("MyMessage".into()),
2499 ..Default::default()
2500 }],
2501 service: vec![ServiceDescriptorProto {
2502 name: Some("FooBar".into()),
2503 method: vec![
2504 MethodDescriptorProto {
2505 name: Some("Unary".into()),
2506 input_type: Some(".api.v1.foo.bar.MyMessage".into()),
2507 output_type: Some(".api.v2.foo.bar.MyMessage".into()),
2508 ..Default::default()
2509 },
2510 MethodDescriptorProto {
2511 name: Some("ServerStream".into()),
2512 input_type: Some(".api.v1.foo.bar.MyMessage".into()),
2513 output_type: Some(".api.v2.foo.bar.MyMessage".into()),
2514 server_streaming: Some(true),
2515 ..Default::default()
2516 },
2517 MethodDescriptorProto {
2518 name: Some("ClientStream".into()),
2519 input_type: Some(".api.v1.foo.bar.MyMessage".into()),
2520 output_type: Some(".api.v2.foo.bar.MyMessage".into()),
2521 client_streaming: Some(true),
2522 ..Default::default()
2523 },
2524 MethodDescriptorProto {
2525 name: Some("Bidi".into()),
2526 input_type: Some(".api.v1.foo.bar.MyMessage".into()),
2527 output_type: Some(".api.v2.foo.bar.MyMessage".into()),
2528 client_streaming: Some(true),
2529 server_streaming: Some(true),
2530 ..Default::default()
2531 },
2532 ],
2533 ..Default::default()
2534 }],
2535 ..Default::default()
2536 };
2537 let code = gen_file(&[v1, v2], 1, &[], false).unwrap();
2538
2539 assert!(
2541 !code.contains("OwnedMyMessageView"),
2542 "no method shape should reference the suppressed alias: {code}"
2543 );
2544
2545 assert!(
2548 code.matches("request : :: connectrpc :: ServiceRequest < '_")
2549 .count()
2550 >= 2,
2551 "unary and server-streaming should take the borrowed ServiceRequest form: {code}"
2552 );
2553 assert!(
2556 code.matches(
2557 "requests : :: connectrpc :: ServiceStream < :: connectrpc :: StreamMessage <"
2558 )
2559 .count()
2560 >= 2,
2561 "client-streaming and bidi should both take StreamMessage items: {code}"
2562 );
2563 }
2564
/// Response-streaming methods must expose `Encodable<Resp>` stream items, and the generated
/// dispatcher/router calls must name the response type explicitly via turbofish.
#[test]
fn streaming_methods_use_encodable_item_type() {
    // Two empty messages shared by both RPCs.
    let message_type = ["Req", "Resp"]
        .into_iter()
        .map(|name| DescriptorProto {
            name: Some(name.into()),
            ..Default::default()
        })
        .collect();
    // Both methods stream responses over the same request/response pair.
    let response_streaming = |name: &str| MethodDescriptorProto {
        name: Some(name.into()),
        input_type: Some(".ex.v1.Req".into()),
        output_type: Some(".ex.v1.Resp".into()),
        server_streaming: Some(true),
        ..Default::default()
    };
    let server_stream = response_streaming("ServerStream");
    // The bidi method additionally streams requests.
    let mut bidi = response_streaming("Bidi");
    bidi.client_streaming = Some(true);

    let file = FileDescriptorProto {
        name: Some("ex/v1/svc.proto".into()),
        package: Some("ex.v1".into()),
        message_type,
        service: vec![ServiceDescriptorProto {
            name: Some("Svc".into()),
            method: vec![server_stream, bidi],
            ..Default::default()
        }],
        ..Default::default()
    };
    let code = gen_file(&[file], 0, &[], false).unwrap();

    // Handler return type: one occurrence per response-streaming method.
    let encodable_items = code
        .matches(":: connectrpc :: ServiceStream < impl :: connectrpc :: Encodable < Resp > + Send + use < Self >>")
        .count();
    assert_eq!(
        encodable_items, 2,
        "server-streaming and bidi should both use the Encodable item type: {code}"
    );

    // Dispatcher arms: one turbofished call per response-streaming method.
    let turbofished_encodes = code
        .matches("encode_response_stream :: < Resp , _ , _ >")
        .count();
    assert_eq!(
        turbofished_encodes, 2,
        "dispatcher arms must turbofish Res to encode_response_stream: {code}"
    );

    // Router registration helpers.
    let server_stream_route = "route_view_server_stream :: < _ , _ , Resp >";
    assert!(
        code.contains(server_stream_route),
        "route_view_server_stream must turbofish Res: {code}"
    );
    let bidi_route = "route_view_bidi_stream :: < _ , _ , Resp >";
    assert!(
        code.contains(bidi_route),
        "route_view_bidi_stream must turbofish Res: {code}"
    );
}
2638
/// A local output type gets `Encodable` impls for both its view and the `OwnedView`
/// wrapper, while the input type gets none.
#[test]
fn encodable_view_impls_emitted_per_output_type() {
    let file = minimal_file(
        Some("example.v1"),
        ".example.v1.PingReq",
        ".example.v1.PingResp",
        &["PingReq", "PingResp"],
    );
    let code = gen_file(&[file], 0, &[], false).unwrap();

    let borrowed_view_impl =
        ":: connectrpc :: Encodable < PingResp > for __buffa :: view :: PingRespView";
    assert!(
        code.contains(borrowed_view_impl),
        "missing Encodable<PingResp> for PingRespView: {code}"
    );

    let owned_view_impl =
        ":: connectrpc :: Encodable < PingResp > for :: buffa :: view :: OwnedView";
    assert!(
        code.contains(owned_view_impl),
        "missing Encodable<PingResp> for OwnedView<PingRespView>: {code}"
    );

    // Request types are only ever decoded, so no impl may target them.
    assert!(!code.contains("Encodable < PingReq >"), "got: {code}");
}
2663
/// When the only output type is `google.protobuf.Empty`, the generated file must not
/// contain any `encode_view_body` call.
#[test]
fn encodable_view_impls_skipped_for_extern_output() {
    let service_file = minimal_file(
        Some("example.v1"),
        ".example.v1.PingReq",
        ".google.protobuf.Empty",
        &["PingReq"],
    );
    // Descriptor for the well-known type referenced as the output.
    let empty = DescriptorProto {
        name: Some("Empty".into()),
        ..Default::default()
    };
    let wkt = FileDescriptorProto {
        name: Some("google/protobuf/empty.proto".into()),
        package: Some("google.protobuf".into()),
        message_type: vec![empty],
        ..Default::default()
    };

    // Index 1 selects the service file as the generation target.
    let code = gen_file(&[wkt, service_file], 1, &[], false).unwrap();
    let emits_encode_body = code.contains("encode_view_body");
    assert!(
        !emits_encode_body,
        "extern output type must not get Encodable impl: {code}"
    );
}
2691
/// Two service files in different packages both return `common.v1.Reply`; across all
/// emitted companion files each `Encodable<Reply>` impl must appear exactly once.
#[test]
fn encodable_view_impls_deduped_across_files() {
    // Message-only file owning the shared response type.
    let reply = DescriptorProto {
        name: Some("Reply".into()),
        ..Default::default()
    };
    let common = FileDescriptorProto {
        name: Some("common.proto".into()),
        package: Some("common.v1".into()),
        message_type: vec![reply],
        ..Default::default()
    };
    // Service file template: local `Req` in, shared `Reply` out.
    let svc = |name: &str, pkg: &str| {
        let call = MethodDescriptorProto {
            name: Some("Call".into()),
            input_type: Some(format!(".{pkg}.Req")),
            output_type: Some(".common.v1.Reply".into()),
            ..Default::default()
        };
        FileDescriptorProto {
            name: Some(name.into()),
            package: Some(pkg.into()),
            message_type: vec![DescriptorProto {
                name: Some("Req".into()),
                ..Default::default()
            }],
            service: vec![ServiceDescriptorProto {
                name: Some("S".into()),
                method: vec![call],
                ..Default::default()
            }],
            ..Default::default()
        }
    };
    let files = vec![common, svc("a.proto", "a.v1"), svc("b.proto", "b.v1")];

    let generated = generate_files(
        &files,
        &["a.proto".into(), "b.proto".into()],
        &Options::default(),
    )
    .unwrap();

    // Gather the companion files in emission order.
    let mut companions = Vec::new();
    for f in &generated {
        if f.kind == GeneratedFileKind::Companion {
            companions.push(f);
        }
    }
    let mut companion_names: Vec<&str> = Vec::with_capacity(companions.len());
    companion_names.extend(companions.iter().map(|f| f.name.as_str()));
    companion_names.sort_unstable();
    assert_eq!(companion_names, ["a.__connect.rs", "b.__connect.rs"]);

    // Each companion must be pulled in by the stitcher of its own package.
    for c in &companions {
        let stitcher = generated
            .iter()
            .find(|g| g.kind == GeneratedFileKind::PackageMod && g.package == c.package)
            .expect("each companion's package must have a stitcher");
        let include_line = format!("include!(\"{}\")", c.name);
        assert!(
            stitcher.content.contains(&include_line),
            "stitcher for {} must include companion {}",
            c.package,
            c.name
        );
    }

    let combined: String = companions.iter().map(|f| f.content.as_str()).collect();

    let view_impl = "impl ::connectrpc::Encodable<super::super::common::v1::Reply>\nfor super::super::common::v1::__buffa::view::ReplyView<'_>";
    let owned_view_impl = "impl ::connectrpc::Encodable<super::super::common::v1::Reply>\nfor ::buffa::view::OwnedView<";
    let view_impl_count = combined.matches(view_impl).count();
    assert_eq!(
        view_impl_count, 1,
        "Encodable<Reply> for ReplyView<'_> must appear once: {combined}"
    );
    let owned_view_impl_count = combined.matches(owned_view_impl).count();
    assert_eq!(
        owned_view_impl_count, 1,
        "Encodable<Reply> for OwnedView<ReplyView> must appear once: {combined}"
    );
}
2774
2775 fn file_per_package_fixture() -> Vec<FileDescriptorProto> {
2780 let common = FileDescriptorProto {
2781 name: Some("common.proto".into()),
2782 package: Some("common.v1".into()),
2783 message_type: vec![DescriptorProto {
2784 name: Some("Reply".into()),
2785 ..Default::default()
2786 }],
2787 ..Default::default()
2788 };
2789 let svc = |proto_name: &str, pkg: &str, svc_name: &str, req: &str| FileDescriptorProto {
2794 name: Some(proto_name.into()),
2795 package: Some(pkg.into()),
2796 message_type: vec![DescriptorProto {
2797 name: Some(req.into()),
2798 ..Default::default()
2799 }],
2800 service: vec![ServiceDescriptorProto {
2801 name: Some(svc_name.into()),
2802 method: vec![MethodDescriptorProto {
2803 name: Some("Call".into()),
2804 input_type: Some(format!(".{pkg}.{req}")),
2805 output_type: Some(".common.v1.Reply".into()),
2806 ..Default::default()
2807 }],
2808 ..Default::default()
2809 }],
2810 ..Default::default()
2811 };
2812 vec![
2813 common,
2814 svc("a/x.proto", "a.v1", "XService", "XReq"),
2815 svc("a/y.proto", "a.v1", "YService", "YReq"),
2816 svc("b/z.proto", "b.v1", "ZService", "ZReq"),
2817 ]
2818 }
2819
/// With `file_per_package` enabled, `generate_files` must emit no companion files: the
/// service code is inlined into one `PackageMod` per package, and the shared
/// `Encodable<Reply>` impls are emitted once across the whole batch.
#[test]
fn generate_files_file_per_package_inlines_companions() {
    let mut options = Options::default();
    options.buffa.file_per_package = true;
    let files = file_per_package_fixture();

    let generated = generate_files(
        &files,
        &["a/x.proto".into(), "a/y.proto".into(), "b/z.proto".into()],
        &options,
    )
    .unwrap();

    // No sibling companion output, by kind or by file name.
    let no_companion_kind = generated
        .iter()
        .all(|f| f.kind != GeneratedFileKind::Companion);
    assert!(
        no_companion_kind,
        "file_per_package must not emit sibling Companion files"
    );
    let no_connect_names = generated
        .iter()
        .all(|f| !f.name.ends_with(".__connect.rs"));
    assert!(
        no_connect_names,
        "file_per_package must not emit `<stem>.__connect.rs` files"
    );

    // Package a.v1 carries both of its services and nothing from b.v1.
    let pkg_a = generated
        .iter()
        .find(|f| f.kind == GeneratedFileKind::PackageMod && f.package == "a.v1")
        .expect("a.v1 PackageMod must exist");
    assert!(
        pkg_a.content.contains("pub trait XService"),
        "a.v1 missing XService"
    );
    assert!(
        pkg_a.content.contains("pub trait YService"),
        "a.v1 missing YService"
    );
    assert!(
        !pkg_a.content.contains("pub trait ZService"),
        "a.v1 must not inline ZService"
    );
    assert!(
        !pkg_a.content.contains("__connect.rs"),
        "a.v1 PackageMod must not include! a connect file: {}",
        pkg_a.content
    );

    // Package b.v1 carries only its own service.
    let pkg_b = generated
        .iter()
        .find(|f| f.kind == GeneratedFileKind::PackageMod && f.package == "b.v1")
        .expect("b.v1 PackageMod must exist");
    assert!(
        pkg_b.content.contains("pub trait ZService"),
        "b.v1 missing ZService"
    );
    assert!(
        !pkg_b.content.contains("pub trait XService"),
        "b.v1 must not inline XService"
    );

    let mut package_mod_count = 0;
    for f in &generated {
        if f.kind == GeneratedFileKind::PackageMod {
            package_mod_count += 1;
        }
    }
    assert_eq!(
        package_mod_count, 2,
        "expected exactly two PackageMods: {generated:#?}"
    );

    // Count the impls over the concatenation of every emitted file.
    let combined: String = generated.iter().map(|f| f.content.as_str()).collect();
    let reply_impls = combined
        .matches("impl ::connectrpc::Encodable<super::super::common::v1::Reply>")
        .count();
    assert_eq!(
        reply_impls, 2,
        "Encodable<Reply> impls must be deduplicated across packages \
         (1 for ReplyView, 1 for OwnedView<ReplyView>): {combined}"
    );
}
2906
/// With `file_per_package` enabled, `generate_services` must emit exactly one
/// self-contained `PackageMod` file per package, named `<dotted.pkg>.rs`.
#[test]
fn generate_services_file_per_package_emits_one_file_per_package() {
    let mut options = Options::default();
    options.buffa.file_per_package = true;
    // Catch-all extern path mapping every proto package under `crate::proto`.
    options
        .buffa
        .extern_paths
        .push((".".into(), "crate::proto".into()));
    let files = file_per_package_fixture();

    let generated = generate_services(
        &files,
        &["a/x.proto".into(), "a/y.proto".into(), "b/z.proto".into()],
        &options,
    )
    .unwrap();

    let file_count = generated.len();
    assert_eq!(
        file_count, 2,
        "expected exactly two output files: {generated:#?}"
    );
    let only_package_mods = generated
        .iter()
        .all(|f| f.kind == GeneratedFileKind::PackageMod);
    assert!(only_package_mods, "all output files must be PackageMod");
    let no_stitcher = generated.iter().all(|f| !f.name.ends_with(".mod.rs"));
    assert!(
        no_stitcher,
        "file_per_package must not emit a separate stitcher"
    );
    let no_includes = generated.iter().all(|f| !f.content.contains("include!"));
    assert!(
        no_includes,
        "file_per_package output must not include! sibling files"
    );

    let mut names = Vec::with_capacity(generated.len());
    for f in &generated {
        names.push(f.name.as_str());
    }
    names.sort_unstable();
    assert_eq!(
        names,
        ["a.v1.rs", "b.v1.rs"],
        "filenames must be `<dotted.pkg>.rs` to match buffa's file_per_package convention"
    );

    // Each package file holds exactly its own services.
    let a = generated.iter().find(|f| f.package == "a.v1").unwrap();
    assert!(a.content.contains("pub trait XService"));
    assert!(a.content.contains("pub trait YService"));
    let b = generated.iter().find(|f| f.package == "b.v1").unwrap();
    assert!(b.content.contains("pub trait ZService"));
    assert!(!b.content.contains("pub trait XService"));
}
2961
/// Without `file_per_package`, `generate_services` emits one companion per proto file
/// plus one `.mod.rs` stitcher per package that `include!`s its companions.
#[test]
fn generate_services_file_per_package_default_layout_unchanged() {
    let mut options = Options::default();
    options
        .buffa
        .extern_paths
        .push((".".into(), "crate::proto".into()));
    let files = file_per_package_fixture();

    let generated = generate_services(
        &files,
        &["a/x.proto".into(), "a/y.proto".into(), "b/z.proto".into()],
        &options,
    )
    .unwrap();

    // Split the output file names by kind in a single pass.
    let mut companions: Vec<&str> = Vec::new();
    let mut stitchers: Vec<&str> = Vec::new();
    for f in &generated {
        if f.kind == GeneratedFileKind::Companion {
            companions.push(f.name.as_str());
        }
        if f.kind == GeneratedFileKind::PackageMod {
            stitchers.push(f.name.as_str());
        }
    }

    companions.sort_unstable();
    assert_eq!(
        companions,
        ["a.x.__connect.rs", "a.y.__connect.rs", "b.z.__connect.rs"],
        "default layout emits one companion per proto"
    );
    stitchers.sort_unstable();
    assert_eq!(
        stitchers,
        ["a.v1.mod.rs", "b.v1.mod.rs"],
        "default layout emits one stitcher per package"
    );

    // The a.v1 stitcher must include both of its package's companions.
    let a_stitcher = generated.iter().find(|f| f.name == "a.v1.mod.rs").unwrap();
    assert!(
        a_stitcher
            .content
            .contains(r#"include!("a.x.__connect.rs");"#)
    );
    assert!(
        a_stitcher
            .content
            .contains(r#"include!("a.y.__connect.rs");"#)
    );
}
3015
/// The generated code must contain the fully-qualified service name as a string literal
/// when the file declares a package.
#[test]
fn service_name_with_package() {
    let file = minimal_file(
        Some("example.v1"),
        ".example.v1.PingReq",
        ".example.v1.PingResp",
        &["PingReq", "PingResp"],
    );
    let code = gen_service(&[file], 0, &[], false).unwrap();
    let qualified_literal = "\"example.v1.PingService\"";
    assert!(code.contains(qualified_literal), "got: {code}");
}
3027
/// For a file without a package, the service name literal is the bare service name and
/// must not start with a dot.
#[test]
fn service_name_without_package() {
    let message_names = ["PingReq", "PingResp"];
    let file = minimal_file(None, ".PingReq", ".PingResp", &message_names);
    let code = gen_service(&[file], 0, &[], false).unwrap();

    assert!(code.contains("\"PingService\""), "got: {code}");
    let has_leading_dot = code.contains("\".PingService\"");
    assert!(!has_leading_dot, "must not have leading dot: {code}");
}
3039
/// Types from the service's own package must be referenced without a `super ::` prefix.
#[test]
fn same_package_types_use_bare_names() {
    let file = minimal_file(
        Some("example.v1"),
        ".example.v1.PingReq",
        ".example.v1.PingResp",
        &["PingReq", "PingResp"],
    );
    let code = gen_service(&[file], 0, &[], false).unwrap();

    // Both message names must show up somewhere in the output.
    assert!(code.contains("PingReq"), "input type missing: {code}");
    assert!(code.contains("PingResp"), "output type missing: {code}");
    // The request type must not be reached through `super`.
    let uses_super = code.contains("super :: PingReq");
    assert!(!uses_super, "unexpected super: {code}");
}
3058
/// A request type from another package must be referenced through a `super :: super`
/// relative path, for both the owned type and its view.
#[test]
fn cross_package_types_use_relative_paths() {
    let shared = DescriptorProto {
        name: Some("Shared".into()),
        ..Default::default()
    };
    let common = FileDescriptorProto {
        name: Some("common.proto".into()),
        package: Some("common.v1".into()),
        message_type: vec![shared],
        ..Default::default()
    };
    // Service in `example.v1` whose input lives in `common.v1`.
    let svc = minimal_file(
        Some("example.v1"),
        ".common.v1.Shared",
        ".example.v1.Out",
        &["Out"],
    );
    let code = gen_service(&[common, svc], 1, &[], false).unwrap();

    let owned_path = "super :: super :: common :: v1 :: Shared";
    assert!(
        code.contains(owned_path),
        "cross-package path not emitted: {code}"
    );
    let view_path = "super :: super :: common :: v1 :: __buffa :: view :: SharedView";
    assert!(
        code.contains(view_path),
        "cross-package view path not emitted: {code}"
    );
}
3092
/// For a nested input message (`Outer.Inner`), both the owned path and the view path
/// must go through the `outer` module.
#[test]
fn nested_message_view_type_mirrors_owned_module_nesting() {
    // `Outer` with a nested `Inner`, plus a flat `Out`.
    let inner = DescriptorProto {
        name: Some("Inner".into()),
        ..Default::default()
    };
    let outer = DescriptorProto {
        name: Some("Outer".into()),
        nested_type: vec![inner],
        ..Default::default()
    };
    let out = DescriptorProto {
        name: Some("Out".into()),
        ..Default::default()
    };
    let ping = MethodDescriptorProto {
        name: Some("Ping".into()),
        input_type: Some(".example.v1.Outer.Inner".into()),
        output_type: Some(".example.v1.Out".into()),
        ..Default::default()
    };
    let file = FileDescriptorProto {
        name: Some("nested.proto".into()),
        package: Some("example.v1".into()),
        message_type: vec![outer, out],
        service: vec![ServiceDescriptorProto {
            name: Some("NestedService".into()),
            method: vec![ping],
            ..Default::default()
        }],
        ..Default::default()
    };
    let code = gen_service(&[file], 0, &[], false).unwrap();

    let nested_view = "__buffa :: view :: outer :: InnerView";
    assert!(
        code.contains(nested_view),
        "nested view path not emitted: {code}"
    );
    let nested_owned = "outer :: Inner";
    assert!(
        code.contains(nested_owned),
        "nested owned path not emitted: {code}"
    );
}
3139
/// With no extern paths supplied by the caller, `google.protobuf.Empty` must be emitted
/// under `::buffa_types::google::protobuf`.
#[test]
fn wkt_types_use_buffa_types_extern_path() {
    let svc = minimal_file(
        Some("example.v1"),
        ".google.protobuf.Empty",
        ".example.v1.Out",
        &["Out"],
    );
    let empty = DescriptorProto {
        name: Some("Empty".into()),
        ..Default::default()
    };
    let wkt = FileDescriptorProto {
        name: Some("google/protobuf/empty.proto".into()),
        package: Some("google.protobuf".into()),
        message_type: vec![empty],
        ..Default::default()
    };
    let code = gen_service(&[wkt, svc], 1, &[], false).unwrap();

    let wkt_path = ":: buffa_types :: google :: protobuf :: Empty";
    assert!(
        code.contains(wkt_path),
        "WKT extern path not emitted: {code}"
    );
}
3167
/// With a catch-all `"."` extern path, both the owned type and its view must be emitted
/// as absolute paths under the configured root.
#[test]
fn extern_catchall_uses_absolute_paths() {
    let extern_paths = [(".".into(), "crate::proto".into())];
    let file = minimal_file(
        Some("example.v1"),
        ".example.v1.PingReq",
        ".example.v1.PingResp",
        &["PingReq", "PingResp"],
    );
    let code = gen_service(std::slice::from_ref(&file), 0, &extern_paths, true).unwrap();

    let owned_path = "crate :: proto :: example :: v1 :: PingReq";
    assert!(code.contains(owned_path), "owned type path missing: {code}");
    let view_path = "crate :: proto :: example :: v1 :: __buffa :: view :: PingReqView";
    assert!(code.contains(view_path), "view type path missing: {code}");
}
3187
/// With a catch-all extern path configured, `google.protobuf.Empty` must still resolve
/// under `::buffa_types`, while the local type goes through the catch-all root.
#[test]
fn extern_catchall_with_wkt_longest_wins() {
    let extern_paths = [(".".into(), "crate::proto".into())];
    let svc = minimal_file(
        Some("example.v1"),
        ".google.protobuf.Empty",
        ".example.v1.Out",
        &["Out"],
    );
    let empty = DescriptorProto {
        name: Some("Empty".into()),
        ..Default::default()
    };
    let wkt = FileDescriptorProto {
        name: Some("google/protobuf/empty.proto".into()),
        package: Some("google.protobuf".into()),
        message_type: vec![empty],
        ..Default::default()
    };
    let code = gen_service(&[wkt, svc], 1, &extern_paths, true).unwrap();

    let wkt_path = ":: buffa_types :: google :: protobuf :: Empty";
    assert!(
        code.contains(wkt_path),
        "WKT mapping lost to catch-all: {code}"
    );
    let local_path = "crate :: proto :: example :: v1 :: Out";
    assert!(
        code.contains(local_path),
        "local type not routed through catch-all: {code}"
    );
}
3218
/// Calling `gen_service` with the final flag set but no extern paths must fail, and the
/// error text must mention `extern_path`.
#[test]
fn missing_extern_path_errors() {
    let file = minimal_file(
        Some("example.v1"),
        ".example.v1.PingReq",
        ".example.v1.PingResp",
        &["PingReq", "PingResp"],
    );
    let result = gen_service(&[file], 0, &[], true);
    let msg = result.unwrap_err().to_string();
    let has_hint = msg.contains("extern_path");
    assert!(has_hint, "error message lacks hint: {msg}");
}
3234
/// A package segment that is a Rust keyword (`type`) must be emitted as a raw identifier.
#[test]
fn keyword_package_escaped() {
    let extern_paths = [(".".into(), "crate::proto".into())];
    let file = minimal_file(
        Some("google.type"),
        ".google.type.LatLng",
        ".google.type.LatLng",
        &["LatLng"],
    );
    let code = gen_service(std::slice::from_ref(&file), 0, &extern_paths, true).unwrap();

    let escaped_path = "crate :: proto :: google :: r#type :: LatLng";
    assert!(
        code.contains(escaped_path),
        "keyword segment not escaped: {code}"
    );
}
3251
/// A method named `Move` must be emitted as `r#move`, its `_with_options` variant needs
/// no escaping, and the whole output must still parse as a Rust file.
#[test]
fn keyword_method_escaped() {
    let file = minimal_file_with_method(
        Some("example.v1"),
        "Move",
        ".example.v1.Empty",
        ".example.v1.Empty",
        &["Empty"],
    );
    let code = gen_service(&[file], 0, &[], false).unwrap();

    let has_raw_fn = code.contains("fn r#move");
    assert!(has_raw_fn, "keyword method not escaped: {code}");
    let has_suffixed_variant = code.contains("move_with_options");
    assert!(
        has_suffixed_variant,
        "suffixed variant should not need escaping: {code}"
    );
    assert!(code.contains("client.r#move(request)"));
    syn::parse_str::<syn::File>(&code).expect("generated code parses");
}
3276
/// A method named `Self` must be emitted as `self_` (with a `self_with_options`
/// variant), and the output must still parse as a Rust file.
#[test]
fn path_keyword_method_suffixed() {
    let file = minimal_file_with_method(
        Some("example.v1"),
        "Self",
        ".example.v1.Empty",
        ".example.v1.Empty",
        &["Empty"],
    );
    let code = gen_service(&[file], 0, &[], false).unwrap();

    let has_suffixed_fn = code.contains("fn self_");
    assert!(has_suffixed_fn, "path-keyword method not suffixed: {code}");
    assert!(code.contains("self_with_options"));
    syn::parse_str::<syn::File>(&code).expect("generated code parses");
}
3299
/// A service named `Self` must produce a `Self_` trait, while the derived `SelfExt`,
/// `SelfClient` and `SelfServer` names are emitted unchanged; the output must parse.
#[test]
fn service_name_keyword_suffixed() {
    let mut file = minimal_file(
        Some("example.v1"),
        ".example.v1.Empty",
        ".example.v1.Empty",
        &["Empty"],
    );
    // Rename the fixture's only service to the path keyword.
    file.service[0].name = Some("Self".into());
    let code = gen_service(&[file], 0, &[], false).unwrap();

    let has_suffixed_trait = code.contains("trait Self_ ");
    assert!(has_suffixed_trait, "trait not suffixed: {code}");
    assert!(code.contains("trait SelfExt"));
    assert!(code.contains("struct SelfClient"));
    assert!(code.contains("struct SelfServer"));
    syn::parse_str::<syn::File>(&code).expect("generated code parses");
}
3319
/// Two proto methods that snake-case to the same Rust identifier must be rejected with
/// an error naming the service, both proto methods, and the colliding identifier.
#[test]
fn method_snake_collision_errors() {
    let method_names = ["GetFoo", "get_foo"];
    let file = minimal_file_with_methods("example.v1", &method_names);
    let result = gen_service(&[file], 0, &[], false);
    let msg = result.unwrap_err().to_string();

    assert!(msg.contains("PingService"), "missing service name: {msg}");
    assert!(msg.contains("\"GetFoo\""), "missing first method: {msg}");
    assert!(msg.contains("\"get_foo\""), "missing second method: {msg}");
    assert!(msg.contains("`get_foo`"), "missing rust ident: {msg}");
}
3332
/// A method named `PingWithOptions` next to `Ping` must be rejected with an error
/// naming both proto methods and the colliding `ping_with_options` identifier.
#[test]
fn method_with_options_collision_errors() {
    let method_names = ["Ping", "PingWithOptions"];
    let file = minimal_file_with_methods("example.v1", &method_names);
    let result = gen_service(&[file], 0, &[], false);
    let msg = result.unwrap_err().to_string();

    assert!(msg.contains("\"Ping\""), "missing first method: {msg}");
    let names_second = msg.contains("\"PingWithOptions\"");
    assert!(names_second, "missing second method: {msg}");
    let names_ident = msg.contains("`ping_with_options`");
    assert!(names_ident, "missing rust ident: {msg}");
}
3350
/// Methods with distinct snake-case names generate successfully and the output parses.
#[test]
fn distinct_methods_do_not_collide() {
    let method_names = ["GetFoo", "GetBar"];
    let file = minimal_file_with_methods("example.v1", &method_names);
    let code = gen_service(&[file], 0, &[], false).unwrap();
    syn::parse_str::<syn::File>(&code).expect("generated code parses");
}
3357
/// Checks the buffa config derived from `Options::default()`: JSON and views are on,
/// `emit_register_fn` is on, and `strict_utf8_mapping` is off.
#[test]
fn options_default_buffa_config() {
    let defaults = Options::default();
    let cfg = defaults.to_buffa_config();
    assert!(cfg.generate_json, "connectrpc enables JSON by default");
    assert!(cfg.generate_views);
    assert!(cfg.emit_register_fn);
    assert!(!cfg.strict_utf8_mapping);
}
3366
/// Caller-set buffa fields pass through `to_buffa_config` unchanged, except
/// `generate_views`, which is always switched back on.
#[test]
fn options_buffa_passthrough_forces_views() {
    let mut opts = Options::default();
    // Turn off one field that should pass through and one that should be overridden.
    opts.buffa.generate_views = false;
    opts.buffa.emit_register_fn = false;

    let cfg = opts.to_buffa_config();
    assert!(!cfg.emit_register_fn);
    assert!(cfg.generate_views, "generate_views must be forced on");
}
3376
/// `fn register_types` appears in the package stitcher by default and disappears when
/// `emit_register_fn` is turned off.
#[test]
fn generate_files_emit_register_fn_false_suppresses_register_types() {
    /// Returns a copy of the first `PackageMod` file's content.
    fn stitcher(files: &[GeneratedFile]) -> String {
        let package_mod = files
            .iter()
            .find(|f| f.kind == GeneratedFileKind::PackageMod)
            .expect("PackageMod file emitted");
        package_mod.content.clone()
    }

    // Message-only file: no services are needed to observe the stitcher.
    let ping_req = DescriptorProto {
        name: Some("PingReq".into()),
        ..Default::default()
    };
    let file = FileDescriptorProto {
        name: Some("ping.proto".into()),
        package: Some("example.v1".into()),
        message_type: vec![ping_req],
        ..Default::default()
    };

    // Default options: the function is present.
    let with_fn = generate_files(
        std::slice::from_ref(&file),
        &["ping.proto".into()],
        &Options::default(),
    )
    .unwrap();
    let mod_rs = stitcher(&with_fn);
    let has_register_fn = mod_rs.contains("fn register_types");
    assert!(
        has_register_fn,
        "expected register_types in default output: {mod_rs}"
    );

    // Opt out: the function is gone.
    let mut opts = Options::default();
    opts.buffa.emit_register_fn = false;
    let without_fn =
        generate_files(std::slice::from_ref(&file), &["ping.proto".into()], &opts).unwrap();
    let mod_rs = stitcher(&without_fn);
    let has_register_fn = mod_rs.contains("fn register_types");
    assert!(
        !has_register_fn,
        "register_types should be suppressed: {mod_rs}"
    );
}
3424
/// The plugin entry point must accept `no_register_fn` in its parameter string.
#[test]
fn plugin_no_register_fn_parses() {
    // An empty request is enough: only parameter parsing is being exercised.
    let parameter = "buffa_module=crate::proto,no_register_fn";
    let request = CodeGeneratorRequest {
        parameter: Some(parameter.into()),
        file_to_generate: Vec::new(),
        proto_file: Vec::new(),
        ..Default::default()
    };
    generate(&request).expect("no_register_fn should be a recognized plugin option");
}
3437
3438 fn format_minimal_service(gate_client_feature: bool) -> String {
3443 let file = minimal_file(
3444 Some("example.v1"),
3445 ".example.v1.PingReq",
3446 ".example.v1.PingResp",
3447 &["PingReq", "PingResp"],
3448 );
3449 let config = buffa_codegen::CodeGenConfig::default();
3450 let target = file.name.clone().into_iter().collect::<Vec<_>>();
3451 let resolver = TypeResolver::new(std::slice::from_ref(&file), &target, &config, false);
3452 let service = &file.service[0];
3453 let batch = BatchState {
3454 colliding_aliases: collect_alias_collisions(std::slice::from_ref(&file), &target),
3455 gate_client_feature,
3456 ..BatchState::default()
3457 };
3458 format_token_stream(&generate_service(&file, service, &resolver, &batch).unwrap()).unwrap()
3459 }
3460
/// With `gate_client_feature` off, the formatted output must contain no
/// `#[cfg(feature = ...)]` attribute at all.
#[test]
fn default_emission_has_no_client_cfg() {
    let out = format_minimal_service(false);
    let has_feature_cfg = out.contains("#[cfg(feature =");
    assert!(
        !has_feature_cfg,
        "default emission must not emit any cfg attr — external \
         consumers should not need to declare a `client` Cargo \
         feature unless they explicitly opt in via the \
         `gate_client_feature` plugin option:\n{out}"
    );
}
3475
/// With `gate_client_feature` on, the formatted output must contain exactly two
/// `#[cfg(feature = "client")]` attributes.
#[test]
fn client_items_gated_when_opt_in() {
    let out = format_minimal_service(true);
    let client_cfg = "#[cfg(feature = \"client\")]";
    let cfg_count = out.matches(client_cfg).count();
    assert_eq!(
        cfg_count, 2,
        "expected exactly two #[cfg(feature = \"client\")] attrs (one on \
         `pub struct PingServiceClient`, one on its `impl<T>` block); got \
         {cfg_count}:\n{out}"
    );
}
3491
/// Even with gating enabled, the text immediately before each server-side item must
/// not end with the client cfg attribute.
#[test]
fn server_items_never_carry_client_cfg() {
    let out = format_minimal_service(true);
    let server_markers = [
        "pub trait PingService",
        "pub trait PingServiceExt",
        "pub struct PingServiceServer",
        "pub const PING_SERVICE_SERVICE_NAME",
    ];
    for marker in server_markers {
        // Locate the first occurrence; its absence is itself a failure.
        let Some(idx) = out.find(marker) else {
            panic!("expected `{marker}` in output:\n{out}")
        };
        let before_item = out[..idx].trim_end();
        let directly_gated = before_item.ends_with("#[cfg(feature = \"client\")]");
        assert!(
            !directly_gated,
            "`{marker}` must not be preceded by a client cfg attr — \
             server-side items are always compiled in:\n{out}"
        );
    }
}
3515
/// Parses the gated output and runs the item scanner over it: no item that mentions
/// `::connectrpc::client` may be left without the client cfg attribute.
#[test]
fn no_ungated_client_references() {
    let out = format_minimal_service(true);
    let parsed: syn::File = syn::parse_str(&out).expect("output parses");

    let mut offenders: Vec<String> = Vec::new();
    scan_items_for_ungated_client_refs(&parsed.items, false, &mut offenders);
    let all_gated = offenders.is_empty();
    assert!(
        all_gated,
        "every item that mentions `::connectrpc::client::*` must be \
         prefixed with `#[cfg(feature = \"client\")]`. Offenders:\n{}\n\nFull output:\n{out}",
        offenders.join("\n")
    );
}
3546
3547 fn is_client_feature_cfg(attr: &syn::Attribute) -> bool {
3551 attr.path().is_ident("cfg")
3552 && attr
3553 .to_token_stream()
3554 .to_string()
3555 .contains("feature = \"client\"")
3556 }
3557
3558 fn mentions_connectrpc_client(ts: TokenStream) -> bool {
3563 let rendered = format_token_stream(&ts).unwrap_or_default();
3564 rendered.contains("::connectrpc::client::") || rendered.contains("connectrpc :: client ::")
3565 }
3566
3567 fn scan_items_for_ungated_client_refs(
3584 items: &[syn::Item],
3585 ancestor_gated: bool,
3586 offenders: &mut Vec<String>,
3587 ) {
3588 for item in items {
3589 let (attrs, ident): (&[syn::Attribute], String) = match item {
3595 syn::Item::Struct(s) => (&s.attrs, s.ident.to_string()),
3596 syn::Item::Impl(i) => (
3597 &i.attrs,
3598 format!("impl-block for {}", ToTokens::to_token_stream(&i.self_ty)),
3599 ),
3600 syn::Item::Fn(f) => (&f.attrs, f.sig.ident.to_string()),
3601 syn::Item::Trait(t) => (&t.attrs, t.ident.to_string()),
3602 syn::Item::Const(c) => (&c.attrs, c.ident.to_string()),
3603 syn::Item::Type(t) => (&t.attrs, t.ident.to_string()),
3604 syn::Item::Static(s) => (&s.attrs, s.ident.to_string()),
3605 syn::Item::Use(u) => (&u.attrs, "use-item".to_string()),
3606 syn::Item::ExternCrate(e) => (&e.attrs, e.ident.to_string()),
3607 syn::Item::Macro(m) => (
3608 &m.attrs,
3609 m.ident
3610 .as_ref()
3611 .map(syn::Ident::to_string)
3612 .unwrap_or_else(|| "macro-item".to_string()),
3613 ),
3614 syn::Item::ForeignMod(f) => (&f.attrs, "extern-block".to_string()),
3615 syn::Item::Union(u) => (&u.attrs, u.ident.to_string()),
3616 syn::Item::TraitAlias(t) => (&t.attrs, t.ident.to_string()),
3617 syn::Item::Enum(e) => (&e.attrs, e.ident.to_string()),
3618 syn::Item::Mod(m) => {
3619 let self_gated = m.attrs.iter().any(is_client_feature_cfg);
3620 let gated = ancestor_gated || self_gated;
3621 if let Some((_brace, children)) = &m.content {
3622 scan_items_for_ungated_client_refs(children, gated, offenders);
3623 }
3624 continue;
3627 }
3628 _ => (&[][..], "<unrecognized item>".to_string()),
3632 };
3633 let self_gated = attrs.iter().any(is_client_feature_cfg);
3634 let gated = ancestor_gated || self_gated;
3635 if gated {
3636 continue;
3637 }
3638 if mentions_connectrpc_client(ToTokens::to_token_stream(item)) {
3639 offenders.push(format!(
3640 "ungated reference to ::connectrpc::client in `{ident}`"
3641 ));
3642 }
3643 }
3644 }
3645
/// Exercises the scanner's module handling: a gated parent covers its children, an
/// ungated parent reports the child (not itself), and a self-gated child is accepted.
#[test]
fn ungated_scanner_handles_nested_modules() {
    /// Parses `src` as a file and returns the scanner's findings for it.
    fn scan_source(src: &str) -> Vec<String> {
        let parsed: syn::File = syn::parse_str(src).unwrap();
        let mut found = Vec::new();
        scan_items_for_ungated_client_refs(&parsed.items, false, &mut found);
        found
    }

    // Case 1: the cfg sits on the enclosing module.
    let offenders = scan_source(
        r#"
            #[cfg(feature = "client")]
            pub mod gated_parent {
                pub struct WithClientRef {
                    field: ::connectrpc::client::ClientConfig,
                }
            }
        "#,
    );
    assert!(
        offenders.is_empty(),
        "parent-level cfg must cover children: {offenders:?}"
    );

    // Case 2: nothing is gated; only the inner struct is reported.
    let offenders = scan_source(
        r#"
            pub mod ungated_parent {
                pub struct WithClientRef {
                    field: ::connectrpc::client::ClientConfig,
                }
            }
        "#,
    );
    assert_eq!(
        offenders.len(),
        1,
        "exactly one offender expected (the inner struct), not the wrapping \
         module: {offenders:?}"
    );
    assert!(
        offenders[0].contains("WithClientRef"),
        "offender should name the inner struct: {:?}",
        offenders[0]
    );

    // Case 3: the cfg sits on the child inside an ungated module.
    let offenders = scan_source(
        r#"
            pub mod outer {
                #[cfg(feature = "client")]
                pub struct GatedClient {
                    field: ::connectrpc::client::ClientConfig,
                }
            }
        "#,
    );
    assert!(
        offenders.is_empty(),
        "self-gating child inside ungated module must be OK: {offenders:?}"
    );
}
3716
/// The scanner must flag ungated `use` and `static` items that mention
/// `::connectrpc::client`, and must accept a gated `use`.
#[test]
fn ungated_scanner_catches_use_and_static_items() {
    /// Parses `src` as a file and returns the scanner's findings for it.
    fn scan_source(src: &str) -> Vec<String> {
        let parsed: syn::File = syn::parse_str(src).unwrap();
        let mut found = Vec::new();
        scan_items_for_ungated_client_refs(&parsed.items, false, &mut found);
        found
    }

    // Ungated `use`.
    let offenders = scan_source("use ::connectrpc::client::ClientConfig;");
    assert_eq!(
        offenders.len(),
        1,
        "ungated `use ::connectrpc::client::*` must be flagged: {offenders:?}"
    );

    // Gated `use`.
    let offenders =
        scan_source("#[cfg(feature = \"client\")] use ::connectrpc::client::ClientConfig;");
    assert!(
        offenders.is_empty(),
        "gated `use ::connectrpc::client::*` must NOT be flagged: {offenders:?}"
    );

    // Ungated `static` whose initializer mentions the path inside a macro call.
    let offenders =
        scan_source("static FOO: &str = stringify!(::connectrpc::client::ClientConfig);");
    assert_eq!(
        offenders.len(),
        1,
        "ungated `static FOO` mentioning ::connectrpc::client must be flagged: \
         {offenders:?}"
    );
}
3759
/// Guards the literal spelling of the client cfg attribute in the formatted output,
/// which the count-based gating tests search for as plain text.
///
/// The failure message points at `client_items_gated_when_opt_in`, the test that greps
/// for this exact pattern.
#[test]
fn client_cfg_round_trips_through_prettyplease() {
    let out = format_minimal_service(true);
    assert!(
        out.contains("#[cfg(feature = \"client\")]"),
        "prettyplease no longer renders the cfg attr as expected; \
         update the grep pattern in client_items_gated_when_opt_in:\n{out}"
    );
}
3776
/// A file with two services and gating enabled must produce four client cfg attributes
/// in total, and both client structs must be present in the output.
#[test]
fn multi_service_in_one_file_each_client_is_gated() {
    // Service template: a single unary `Ping` over the shared message pair.
    let make_service = |name: &str| {
        let ping = MethodDescriptorProto {
            name: Some("Ping".into()),
            input_type: Some(".example.v1.PingReq".into()),
            output_type: Some(".example.v1.PingResp".into()),
            ..Default::default()
        };
        ServiceDescriptorProto {
            name: Some(name.into()),
            method: vec![ping],
            ..Default::default()
        }
    };
    let message_type = ["PingReq", "PingResp"]
        .into_iter()
        .map(|name| DescriptorProto {
            name: Some(name.into()),
            ..Default::default()
        })
        .collect();
    let file = FileDescriptorProto {
        name: Some("two.proto".into()),
        package: Some("example.v1".into()),
        service: vec![make_service("Alpha"), make_service("Beta")],
        message_type,
        ..Default::default()
    };

    let files = std::slice::from_ref(&file);
    let target = vec!["two.proto".to_string()];
    let config = buffa_codegen::CodeGenConfig::default();
    let resolver = TypeResolver::new(files, &target, &config, false);
    let mut batch = BatchState {
        colliding_aliases: collect_alias_collisions(files, &target),
        gate_client_feature: true,
        ..BatchState::default()
    };
    let ts = generate_connect_services(&file, &resolver, &mut batch).unwrap();
    let out = format_token_stream(&ts).unwrap();

    let cfg_count = out.matches("#[cfg(feature = \"client\")]").count();
    assert_eq!(
        cfg_count, 4,
        "expected 4 client cfg attrs (2 per service * 2 services); got \
         {cfg_count}:\n{out}"
    );

    for client_struct in ["pub struct AlphaClient", "pub struct BetaClient"] {
        let Some(idx) = out.find(client_struct) else {
            panic!("expected `{client_struct}` in output:\n{out}")
        };
        let prefix = &out[..idx];
        // NOTE(review): `prefix.contains` searches everything before the struct, so a
        // cfg attr on any earlier item also satisfies this check.
        let derive_directly_before = prefix.trim_end().ends_with("#[derive(Clone)]");
        let has_cfg =
            derive_directly_before || prefix.contains("#[cfg(feature = \"client\")]");
        assert!(
            has_cfg,
            "`{client_struct}` must have a client cfg attr in its \
             attribute cluster:\n{out}"
        );
    }
}
3838
/// A request whose parameter string includes the bare `gate_client_feature`
/// flag (and no files to generate) must make `generate` return `Ok`.
#[test]
fn plugin_accepts_gate_client_feature_flag() {
    const PARAMETER: &str = "buffa_module=crate::proto,gate_client_feature";

    let request = CodeGeneratorRequest {
        parameter: Some(PARAMETER.into()),
        file_to_generate: Vec::new(),
        proto_file: Vec::new(),
        ..Default::default()
    };

    let outcome = generate(&request);
    outcome.expect("gate_client_feature should be a recognized plugin option");
}
3850
/// The `client_feature=<value>` parameter form must be rejected by
/// `generate`, with an error that both names the option and states that it
/// is unknown.
#[test]
fn plugin_rejects_old_client_feature_value_form() {
    const LEGACY_PARAMETER: &str = "buffa_module=crate::proto,client_feature=client";

    let request = CodeGeneratorRequest {
        parameter: Some(LEGACY_PARAMETER.into()),
        file_to_generate: Vec::new(),
        proto_file: Vec::new(),
        ..Default::default()
    };

    let msg = generate(&request)
        .expect_err("legacy `client_feature=…` option must now fail as unknown")
        .to_string();

    // (substring the error text must contain, why it must be there)
    let required = [
        ("client_feature", "error should name the offending option"),
        ("unknown plugin option", "error should say the option is unknown"),
    ];
    for (needle, why) in required {
        assert!(msg.contains(needle), "{why}: {msg}");
    }
}
3876
/// With the `file_per_package` parameter, three input files from the
/// `file_per_package_fixture()` descriptors must yield exactly two output
/// files, `a.v1.rs` and `b.v1.rs`, and neither may contain `include!`.
#[test]
fn plugin_file_per_package_collapses_output() {
    let targets = ["a/x.proto", "a/y.proto", "b/z.proto"];
    let request = CodeGeneratorRequest {
        parameter: Some("buffa_module=crate::proto,file_per_package".into()),
        file_to_generate: targets.into_iter().map(Into::into).collect(),
        proto_file: file_per_package_fixture(),
        ..Default::default()
    };
    let response = generate(&request).expect("file_per_package should parse and generate");

    // Sort the emitted names so the comparison does not depend on emission
    // order; unnamed entries are ignored.
    let mut names = Vec::with_capacity(response.file.len());
    for emitted in &response.file {
        if let Some(name) = emitted.name.as_deref() {
            names.push(name);
        }
    }
    names.sort_unstable();
    assert_eq!(
        names,
        ["a.v1.rs", "b.v1.rs"],
        "expected one file per package: {names:?}"
    );

    // Missing content is treated as an empty string.
    let contents = response
        .file
        .iter()
        .map(|emitted| emitted.content.as_deref().unwrap_or_default());
    for content in contents {
        assert!(
            !content.contains("include!"),
            "file_per_package output must be self-contained: {content}"
        );
    }
}
3907
/// Generates service code for a `minimal_file` fixture in package
/// `example.v1`, formats it, and hands the result to
/// `assert_no_top_level_use`.
#[test]
fn no_top_level_use_statements_in_generated_code() {
    let package = Some("example.v1");
    let messages = ["PingReq", "PingResp"];
    let file = minimal_file(
        package,
        ".example.v1.PingReq",
        ".example.v1.PingResp",
        &messages,
    );

    let files = std::slice::from_ref(&file);
    let code = gen_service(files, 0, &[], false).unwrap();
    let tokens: TokenStream = code.parse().unwrap();
    let formatted = format_token_stream(&tokens).unwrap();

    assert_no_top_level_use(&formatted, "generated code");
}
3923
/// Two proto files in the same package (`svc.v1`), each with one service,
/// generated in one batch.
///
/// Each formatted output must parse on its own, the two concatenated must
/// parse together, and neither may trip `assert_no_top_level_use`.
#[test]
fn multi_service_include_no_e0252() {
    /// Builds a `svc.v1` file holding one service with one method, plus the
    /// method's request and response messages (both otherwise empty).
    fn one_service_file(
        file_name: &str,
        service_name: &str,
        method_name: &str,
        [request, response]: [&str; 2],
    ) -> FileDescriptorProto {
        let empty_message = |name: &str| DescriptorProto {
            name: Some(name.into()),
            ..Default::default()
        };
        let rpc = MethodDescriptorProto {
            name: Some(method_name.into()),
            input_type: Some(format!(".svc.v1.{request}")),
            output_type: Some(format!(".svc.v1.{response}")),
            ..Default::default()
        };
        FileDescriptorProto {
            name: Some(file_name.into()),
            package: Some("svc.v1".into()),
            service: vec![ServiceDescriptorProto {
                name: Some(service_name.into()),
                method: vec![rpc],
                ..Default::default()
            }],
            message_type: vec![empty_message(request), empty_message(response)],
            ..Default::default()
        }
    }

    let files = vec![
        one_service_file("alpha.proto", "Alpha", "Ping", ["PingReq", "PingResp"]),
        one_service_file("beta.proto", "Beta", "Pong", ["PongReq", "PongResp"]),
    ];
    let config = buffa_codegen::CodeGenConfig::default();
    let targets = vec!["alpha.proto".to_string(), "beta.proto".to_string()];
    let resolver = TypeResolver::new(&files, &targets, &config, false);

    // One batch state is shared by both files, in generation order.
    let mut batch = BatchState {
        colliding_aliases: collect_alias_collisions(&files, &targets),
        ..BatchState::default()
    };
    let mut render = |file: &FileDescriptorProto| {
        let tokens = generate_connect_services(file, &resolver, &mut batch).unwrap();
        format_token_stream(&tokens).unwrap()
    };
    let formatted_a = render(&files[0]);
    let formatted_b = render(&files[1]);
    let combined = format!("{formatted_a}\n{formatted_b}");

    // `syn` parsing is a syntax-level check only.
    let parse_checks = [
        (&formatted_a, "service A should parse independently"),
        (&formatted_b, "service B should parse independently"),
        (&combined, "combined services should parse without E0252"),
    ];
    for (source, failure) in parse_checks {
        syn::parse_str::<syn::File>(source).expect(failure);
    }

    for (source, label) in [(&formatted_a, "service A"), (&formatted_b, "service B")] {
        assert_no_top_level_use(source, label);
    }
}
4015
/// Exercises `method_spec_const_ident` and `generate_spec_consts` on a
/// four-method `EchoService` covering every client/server streaming
/// combination.
///
/// Expects one const per method, and checks each rendered const for the
/// stream type and idempotency fragments listed in the table below.
#[test]
fn generate_spec_consts_per_method() {
    use buffa_codegen::generated::descriptor::MethodOptions;

    // (method name, client_streaming, server_streaming, idempotency level)
    let method_table = [
        ("Say", false, false, Some(IdempotencyLevel::NO_SIDE_EFFECTS)),
        ("Subscribe", false, true, Some(IdempotencyLevel::IDEMPOTENT)),
        ("Upload", true, false, None),
        ("Chat", true, true, None),
    ];
    let service = ServiceDescriptorProto {
        name: Some("EchoService".into()),
        method: method_table
            .into_iter()
            .map(
                |(name, client_streaming, server_streaming, idempotency_level)| {
                    MethodDescriptorProto {
                        name: Some(name.into()),
                        input_type: Some(".pkg.Req".into()),
                        output_type: Some(".pkg.Resp".into()),
                        client_streaming: Some(client_streaming),
                        server_streaming: Some(server_streaming),
                        options: MethodOptions {
                            idempotency_level,
                            ..Default::default()
                        }
                        .into(),
                        ..Default::default()
                    }
                },
            )
            .collect(),
        ..Default::default()
    };

    let say_ident = method_spec_const_ident(&service, "Say");
    assert_eq!(say_ident.to_string(), "ECHO_SERVICE_SAY_SPEC");

    let consts = generate_spec_consts("pkg.EchoService", &service);
    assert_eq!(consts.len(), 4, "one const per method");

    // Round-trip each const through `syn` + `prettyplease` so the substring
    // checks run against formatted source text.
    let render = |tokens: &TokenStream| {
        let parsed = syn::parse2::<syn::File>(tokens.clone()).expect("const should parse");
        prettyplease::unparse(&parsed)
    };

    // Fragments each rendered const must contain, in method order.
    let expected: [&[&str]; 4] = [
        &[
            "pub const ECHO_SERVICE_SAY_SPEC",
            r#""/pkg.EchoService/Say""#,
            "StreamType::Unary",
            "IdempotencyLevel::NoSideEffects",
        ],
        &["StreamType::ServerStream", "IdempotencyLevel::Idempotent"],
        &["StreamType::ClientStream", "IdempotencyLevel::Unknown"],
        &["StreamType::BidiStream"],
    ];
    for (tokens, fragments) in consts.iter().zip(expected) {
        let rendered = render(tokens);
        for &fragment in fragments {
            assert!(rendered.contains(fragment), "{rendered}");
        }
    }
}
4085}