1use std::collections::HashMap;
12
13use anyhow::Result;
14use heck::ToSnakeCase;
15use heck::ToUpperCamelCase;
16use proc_macro2::{Ident, TokenStream};
17use quote::format_ident;
18use quote::quote;
19
20use buffa_codegen::generated::descriptor::FileDescriptorProto;
21use buffa_codegen::generated::descriptor::MethodDescriptorProto;
22use buffa_codegen::generated::descriptor::ServiceDescriptorProto;
23use buffa_codegen::generated::descriptor::SourceCodeInfo;
24use buffa_codegen::generated::descriptor::method_options::IdempotencyLevel;
25use buffa_codegen::idents::make_field_ident;
26use buffa_codegen::idents::rust_path_to_tokens;
27
28pub use buffa_codegen::generated::descriptor;
29pub use buffa_codegen::{CodeGenConfig, GeneratedFile, GeneratedFileKind};
30
31use crate::plugin::CodeGeneratorRequest;
32use crate::plugin::CodeGeneratorResponse;
33use crate::plugin::CodeGeneratorResponseFile;
34
/// Options controlling ConnectRPC code generation.
///
/// The `Default` impl starts from buffa's default config with JSON support
/// turned on (`generate_json = true`).
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Options {
    /// Configuration forwarded to `buffa_codegen`. View generation
    /// (`generate_views`) is always forced on before the config is used,
    /// regardless of the value set here.
    pub buffa: CodeGenConfig,
}
61
62impl Default for Options {
63 fn default() -> Self {
64 let mut buffa = CodeGenConfig::default();
65 buffa.generate_json = true;
66 Self { buffa }
67 }
68}
69
70impl Options {
71 fn to_buffa_config(&self) -> CodeGenConfig {
74 let mut config = self.buffa.clone();
75 config.generate_views = true;
76 config
77 }
78}
79
80fn emit_service_files(
83 proto_file: &[FileDescriptorProto],
84 file_to_generate: &[String],
85 resolver: &TypeResolver<'_>,
86) -> Result<Vec<GeneratedFile>> {
87 let mut out = Vec::new();
88 let mut batch = BatchState {
96 colliding_aliases: collect_alias_collisions(proto_file, file_to_generate),
97 ..BatchState::default()
98 };
99 for file_name in file_to_generate {
100 let file_desc = proto_file
101 .iter()
102 .find(|f| f.name.as_deref() == Some(file_name.as_str()));
103
104 if let Some(file) = file_desc
105 && !file.service.is_empty()
106 {
107 let service_tokens = generate_connect_services(file, resolver, &mut batch)?;
108 let service_code = format_token_stream(&service_tokens)?;
109 out.push(GeneratedFile {
117 name: format!(
118 "{}.__connect.rs",
119 buffa_codegen::proto_path_to_stem(file_name)
120 ),
121 package: file.package.clone().unwrap_or_default(),
122 kind: GeneratedFileKind::Companion,
123 content: service_code,
124 });
125 }
126 }
127 Ok(out)
128}
129
130pub fn generate_files(
156 proto_file: &[FileDescriptorProto],
157 file_to_generate: &[String],
158 options: &Options,
159) -> Result<Vec<GeneratedFile>> {
160 let config = options.to_buffa_config();
161
162 let mut files = buffa_codegen::generate(proto_file, file_to_generate, &config)
163 .map_err(|e| anyhow::anyhow!("buffa-codegen failed: {e}"))?;
164
165 let resolver = TypeResolver::new(proto_file, file_to_generate, &config, false);
166 let service_files = emit_service_files(proto_file, file_to_generate, &resolver)?;
167
168 if config.file_per_package {
169 inline_companions_into_package_mods(&mut files, service_files);
177 } else {
178 buffa_codegen::apply_companions(&mut files, service_files);
185
186 debug_assert!(
193 files.iter().all(|f| {
194 f.kind != GeneratedFileKind::Companion
195 || files.iter().any(|g| {
196 g.kind == GeneratedFileKind::PackageMod
197 && g.content.contains(&format!("include!(\"{}\")", f.name))
198 })
199 }),
200 "a companion service file was not wired into any package stitcher"
201 );
202 }
203
204 Ok(files)
205}
206
207fn inline_companions_into_package_mods(
226 files: &mut [GeneratedFile],
229 companions: Vec<GeneratedFile>,
230) {
231 debug_assert!(
235 companions.iter().all(|c| files
236 .iter()
237 .any(|f| f.kind == GeneratedFileKind::PackageMod && f.package == c.package)),
238 "a companion service file's package has no PackageMod to inline into"
239 );
240 for comp in companions {
241 if let Some(pkg_mod) = files
242 .iter_mut()
243 .find(|f| f.kind == GeneratedFileKind::PackageMod && f.package == comp.package)
244 {
245 pkg_mod.content.push('\n');
246 pkg_mod.content.push_str(&comp.content);
247 }
248 }
249}
250
251pub fn generate_services(
286 proto_file: &[FileDescriptorProto],
287 file_to_generate: &[String],
288 options: &Options,
289) -> Result<Vec<GeneratedFile>> {
290 use std::collections::BTreeMap;
291
292 let config = options.to_buffa_config();
293 let resolver = TypeResolver::new(proto_file, file_to_generate, &config, true);
294 let mut files = emit_service_files(proto_file, file_to_generate, &resolver)?;
295
296 if config.file_per_package {
297 let mut by_package: BTreeMap<String, String> = BTreeMap::new();
302 for f in files {
303 let entry = by_package.entry(f.package).or_insert_with(|| {
304 String::from("// @generated by connectrpc-codegen. DO NOT EDIT.\n")
305 });
306 entry.push('\n');
307 entry.push_str(&f.content);
308 }
309 return Ok(by_package
310 .into_iter()
311 .map(|(package, content)| GeneratedFile {
312 name: buffa_codegen::package_to_filename(&package),
313 package,
314 kind: GeneratedFileKind::PackageMod,
315 content,
316 })
317 .collect());
318 }
319
320 let mut by_package: BTreeMap<String, Vec<String>> = BTreeMap::new();
326 for f in &files {
327 by_package
328 .entry(f.package.clone())
329 .or_default()
330 .push(f.name.clone());
331 }
332 for (package, names) in by_package {
333 let mut content = String::from("// @generated by connectrpc-codegen. DO NOT EDIT.\n");
334 for n in &names {
335 content.push_str(&format!("include!({n:?});\n"));
337 }
338 files.push(GeneratedFile {
339 name: buffa_codegen::package_to_mod_filename(&package),
340 package,
341 kind: GeneratedFileKind::PackageMod,
342 content,
343 });
344 }
345
346 Ok(files)
347}
348
349pub fn generate(request: &CodeGeneratorRequest) -> Result<CodeGeneratorResponse> {
427 let mut options = Options::default();
428
429 if let Some(ref param) = request.parameter {
430 for opt in param.split(',').map(str::trim).filter(|s| !s.is_empty()) {
431 if let Some(value) = opt.strip_prefix("buffa_module=") {
432 let rust = value.trim();
433 if rust.is_empty() {
434 anyhow::bail!(
435 "buffa_module requires a non-empty path, \
436 e.g. buffa_module=crate::proto"
437 );
438 }
439 options
440 .buffa
441 .extern_paths
442 .push((".".into(), rust.to_string()));
443 } else if let Some(value) = opt.strip_prefix("extern_path=") {
444 let (proto, rust) = value.split_once('=').ok_or_else(|| {
446 anyhow::anyhow!(
447 "invalid extern_path format {value:?}, expected \
448 extern_path=.proto.pkg=::rust::path"
449 )
450 })?;
451 let proto = proto.trim();
452 let rust = rust.trim();
453 if proto.is_empty() || rust.is_empty() {
454 anyhow::bail!(
455 "invalid extern_path format {value:?}, expected \
456 extern_path=.proto.pkg=::rust::path (both sides non-empty)"
457 );
458 }
459 let mut proto = proto.to_string();
460 if !proto.starts_with('.') {
461 proto.insert(0, '.');
462 }
463 options.buffa.extern_paths.push((proto, rust.to_string()));
464 } else {
465 match opt {
466 "file_per_package" => options.buffa.file_per_package = true,
467 "strict_utf8_mapping" => options.buffa.strict_utf8_mapping = true,
468 "no_json" => options.buffa.generate_json = false,
469 "no_register_fn" => options.buffa.emit_register_fn = false,
470 _ => {
471 return Err(anyhow::anyhow!(
472 "unknown plugin option: {opt:?}. Supported: \
473 buffa_module=<rust_path>, extern_path=<proto>=<rust>, \
474 file_per_package, strict_utf8_mapping, no_json, \
475 no_register_fn"
476 ));
477 }
478 }
479 }
480 }
481 }
482
483 let generated = generate_services(&request.proto_file, &request.file_to_generate, &options)?;
484
485 let files: Vec<CodeGeneratorResponseFile> = generated
486 .into_iter()
487 .map(|g| CodeGeneratorResponseFile {
488 name: Some(g.name),
489 content: Some(g.content),
490 ..Default::default()
491 })
492 .collect();
493
494 Ok(CodeGeneratorResponse {
495 supported_features: Some(feature_flags()),
496 minimum_edition: Some(EDITION_2023),
497 maximum_edition: Some(EDITION_2023),
498 file: files,
499 ..Default::default()
500 })
501}
502
/// Bit set of optional protoc features this plugin supports. It is reported in
/// `CodeGeneratorResponse.supported_features`.
fn feature_flags() -> u64 {
    const FEATURE_PROTO3_OPTIONAL: u64 = 1;
    const FEATURE_SUPPORTS_EDITIONS: u64 = 2;

    let mut flags = 0;
    flags |= FEATURE_PROTO3_OPTIONAL;
    flags |= FEATURE_SUPPORTS_EDITIONS;
    flags
}
510
/// Numeric value of `Edition::EDITION_2023` (descriptor.proto). It is
/// advertised to protoc as both the minimum and maximum supported edition.
const EDITION_2023: i32 = 1_000;
514
515fn format_token_stream(tokens: &TokenStream) -> Result<String> {
517 let file = syn::parse2::<syn::File>(tokens.clone())
518 .map_err(|e| anyhow::anyhow!("generated code failed to parse: {e}"))?;
519 Ok(prettyplease::unparse(&file))
520}
521
522fn doc_attrs(text: &str) -> TokenStream {
531 let lines: Vec<String> = text
532 .lines()
533 .map(|l| {
534 if l.is_empty() {
535 String::new()
536 } else {
537 format!(" {l}")
538 }
539 })
540 .collect();
541 quote! { #(#[doc = #lines])* }
542}
543
/// Resolves protobuf fully-qualified type names to Rust type paths for the
/// generated service code.
struct TypeResolver<'a> {
    /// buffa's code-generation context, which performs the actual lookups.
    ctx: buffa_codegen::context::CodeGenContext<'a>,
    /// When `true`, two cases become errors instead of falling back to a bare
    /// type name or relative path:
    /// - a type that cannot be found;
    /// - a type whose resolved path starts with neither `::` nor `crate::`.
    require_extern: bool,
}
565
566impl<'a> TypeResolver<'a> {
567 fn new(
568 proto_file: &'a [FileDescriptorProto],
569 file_to_generate: &[String],
570 config: &'a buffa_codegen::CodeGenConfig,
571 require_extern: bool,
572 ) -> Self {
573 Self {
574 ctx: buffa_codegen::context::CodeGenContext::for_generate(
575 proto_file,
576 file_to_generate,
577 config,
578 ),
579 require_extern,
580 }
581 }
582
583 fn resolve_path(&self, proto_fqn: &str, current_package: &str) -> Result<String> {
590 match self.ctx.rust_type_relative(proto_fqn, current_package, 0) {
591 Some(path) => {
592 self.check_extern_coverage(proto_fqn, &path)?;
593 Ok(path)
594 }
595 None => self.fallback_unresolved(proto_fqn).map(str::to_string),
596 }
597 }
598
599 fn check_extern_coverage(&self, proto_fqn: &str, path_prefix: &str) -> Result<()> {
603 if self.require_extern
604 && !path_prefix.starts_with("::")
605 && !path_prefix.starts_with("crate::")
606 {
607 anyhow::bail!(
608 "type {proto_fqn} is not covered by any extern_path mapping. \
609 Add extern_path=.=<your_buffa_module> (e.g. \
610 extern_path=.=crate::proto) to the plugin opts."
611 );
612 }
613 Ok(())
614 }
615
616 fn fallback_unresolved<'f>(&self, proto_fqn: &'f str) -> Result<&'f str> {
620 if self.require_extern {
621 anyhow::bail!("type {proto_fqn} not found in descriptor set (missing proto import?)");
622 }
623 Ok(bare_type_name(proto_fqn))
624 }
625
626 fn rust_type(&self, proto_fqn: &str, current_package: &str) -> Result<TokenStream> {
628 let path = self.resolve_path(proto_fqn, current_package)?;
629 Ok(rust_path_to_tokens(&path))
630 }
631
632 fn rust_view_type(&self, proto_fqn: &str, current_package: &str) -> Result<TokenStream> {
639 use buffa_codegen::context::SENTINEL_MOD;
640 let (to_package, within) =
641 match self
642 .ctx
643 .rust_type_relative_split(proto_fqn, current_package, 0)
644 {
645 Some(s) => {
646 self.check_extern_coverage(proto_fqn, &s.to_package)?;
647 (s.to_package, s.within_package)
648 }
649 None => (
650 String::new(),
651 self.fallback_unresolved(proto_fqn)?.to_string(),
652 ),
653 };
654 let prefix = if to_package.is_empty() {
655 format!("{SENTINEL_MOD}::view")
656 } else {
657 format!("{to_package}::{SENTINEL_MOD}::view")
658 };
659 Ok(rust_path_to_tokens(&format!("{prefix}::{within}View")))
660 }
661}
662
/// Last dot-separated segment of a (possibly dot-prefixed) protobuf name.
/// For example, `.foo.bar.Baz` becomes `Baz`.
fn bare_type_name(proto_fqn: &str) -> &str {
    let trimmed = proto_fqn.strip_prefix('.').unwrap_or(proto_fqn);
    match trimmed.rfind('.') {
        Some(last_dot) => &trimmed[last_dot + 1..],
        None => trimmed,
    }
}
673
/// State shared across every file in one code-generation batch. It is used to
/// avoid emitting the same generated item more than once.
#[derive(Default)]
struct BatchState {
    /// Output-type FQNs whose `Encodable` impls have already been emitted.
    encodable_seen: std::collections::BTreeSet<String>,
    /// `(package, type FQN)` pairs whose `Owned…View` alias has already been
    /// emitted.
    alias_seen: std::collections::BTreeSet<(String, String)>,
    /// `(package, alias ident)` pairs claimed by two or more distinct type
    /// FQNs. No alias is emitted for these; the full
    /// `OwnedView<…View<'static>>` type is spelled out instead.
    colliding_aliases: std::collections::BTreeSet<(String, String)>,
}
703
704fn generate_connect_services(
705 file: &FileDescriptorProto,
706 resolver: &TypeResolver<'_>,
707 batch: &mut BatchState,
708) -> Result<TokenStream> {
709 let mut tokens = TokenStream::new();
710
711 tokens.extend(generate_owned_view_aliases(file, resolver, batch)?);
717 tokens.extend(generate_encodable_view_impls(file, resolver, batch)?);
718
719 for service in &file.service {
720 tokens.extend(generate_service(file, service, resolver, batch)?);
721 }
722
723 Ok(tokens)
724}
725
726fn owned_view_alias_ident(fqn: &str) -> Ident {
729 format_ident!("Owned{}View", bare_type_name(fqn).to_upper_camel_case())
730}
731
732fn alias_collides(batch: &BatchState, current_package: &str, proto_fqn: &str) -> bool {
740 let alias = owned_view_alias_ident(proto_fqn).to_string();
741 batch
742 .colliding_aliases
743 .contains(&(current_package.to_string(), alias))
744}
745
746fn owned_view_input_arg_type(
753 resolver: &TypeResolver<'_>,
754 batch: &BatchState,
755 proto_fqn: &str,
756 current_package: &str,
757) -> Result<TokenStream> {
758 if alias_collides(batch, current_package, proto_fqn) {
759 let view = resolver.rust_view_type(proto_fqn, current_package)?;
760 Ok(quote!(::buffa::view::OwnedView<#view<'static>>))
761 } else {
762 let alias = owned_view_alias_ident(proto_fqn);
763 Ok(quote!(#alias))
764 }
765}
766
767fn collect_alias_collisions(
779 proto_file: &[FileDescriptorProto],
780 file_to_generate: &[String],
781) -> std::collections::BTreeSet<(String, String)> {
782 use std::collections::BTreeMap;
783 let mut first_seen: BTreeMap<(String, String), String> = BTreeMap::new();
786 let mut colliding: std::collections::BTreeSet<(String, String)> =
787 std::collections::BTreeSet::new();
788
789 for file_name in file_to_generate {
790 let Some(file) = proto_file
791 .iter()
792 .find(|f| f.name.as_deref() == Some(file_name.as_str()))
793 else {
794 continue;
795 };
796 let package = file.package.clone().unwrap_or_default();
797 for service in &file.service {
798 for m in &service.method {
799 for fqn in [m.input_type.as_deref(), m.output_type.as_deref()]
800 .into_iter()
801 .flatten()
802 {
803 let alias = owned_view_alias_ident(fqn).to_string();
804 let key = (package.clone(), alias);
805 match first_seen.get(&key) {
806 Some(prev) if prev != fqn => {
807 colliding.insert(key);
808 }
809 Some(_) => {} None => {
811 first_seen.insert(key, fqn.to_string());
812 }
813 }
814 }
815 }
816 }
817 }
818 colliding
819}
820
821fn generate_owned_view_aliases(
839 file: &FileDescriptorProto,
840 resolver: &TypeResolver<'_>,
841 batch: &mut BatchState,
842) -> Result<TokenStream> {
843 let package = file.package.as_deref().unwrap_or("");
844 let mut out = TokenStream::new();
845 for service in &file.service {
846 for m in &service.method {
847 for fqn in [m.input_type.as_deref(), m.output_type.as_deref()]
848 .into_iter()
849 .flatten()
850 {
851 if alias_collides(batch, package, fqn) {
852 continue;
853 }
854 if !batch
855 .alias_seen
856 .insert((package.to_string(), fqn.to_string()))
857 {
858 continue;
859 }
860 let alias = owned_view_alias_ident(fqn);
861 let view = resolver.rust_view_type(fqn, package)?;
862 let doc = format!(
863 "Shorthand for `OwnedView<{}View<'static>>`.",
864 bare_type_name(fqn).to_upper_camel_case()
865 );
866 out.extend(quote! {
867 #[doc = #doc]
868 pub type #alias = ::buffa::view::OwnedView<#view<'static>>;
869 });
870 }
871 }
872 }
873 Ok(out)
874}
875
876fn generate_encodable_view_impls(
892 file: &FileDescriptorProto,
893 resolver: &TypeResolver<'_>,
894 batch: &mut BatchState,
895) -> Result<TokenStream> {
896 let package = file.package.as_deref().unwrap_or("");
897 let mut out = TokenStream::new();
898 for service in &file.service {
899 for m in &service.method {
900 let fqn = m.output_type.as_deref().unwrap_or("");
901 if !batch.encodable_seen.insert(fqn.to_string()) {
902 continue;
903 }
904 let path = resolver.resolve_path(fqn, package)?;
905 if path.starts_with("::") {
908 continue;
909 }
910 let owned = resolver.rust_type(fqn, package)?;
911 let view = resolver.rust_view_type(fqn, package)?;
912 out.extend(quote! {
913 impl ::connectrpc::Encodable<#owned> for #view<'_> {
914 fn encode(&self, codec: ::connectrpc::CodecFormat)
915 -> ::std::result::Result<::buffa::bytes::Bytes, ::connectrpc::ConnectError>
916 {
917 ::connectrpc::__codegen::encode_view_body(self, codec)
918 }
919 }
920 impl ::connectrpc::Encodable<#owned> for ::buffa::view::OwnedView<#view<'static>> {
921 fn encode(&self, codec: ::connectrpc::CodecFormat)
922 -> ::std::result::Result<::buffa::bytes::Bytes, ::connectrpc::ConnectError>
923 {
924 ::connectrpc::__codegen::encode_view_body(&**self, codec)
925 }
926 }
927 });
928 }
929 }
930 Ok(out)
931}
932
933fn check_method_collisions(service_name: &str, service: &ServiceDescriptorProto) -> Result<()> {
942 let mut seen: HashMap<String, String> = HashMap::new();
943 for m in &service.method {
944 let proto_name = m.name.as_deref().unwrap_or("");
945 let snake = proto_name.to_snake_case();
946 let with_opts = format!("{snake}_with_options");
947 for ident in [snake.as_str(), with_opts.as_str()] {
948 if let Some(prev) = seen.get(ident) {
949 anyhow::bail!(
950 "service {service_name}: RPC methods {prev:?} and {proto_name:?} \
951 both generate Rust identifier `{ident}`; rename one in the proto"
952 );
953 }
954 }
955 seen.insert(snake, proto_name.to_string());
956 seen.insert(with_opts, proto_name.to_string());
957 }
958 Ok(())
959}
960
961fn generate_service(
962 file: &FileDescriptorProto,
963 service: &ServiceDescriptorProto,
964 resolver: &TypeResolver<'_>,
965 batch: &BatchState,
966) -> Result<TokenStream> {
967 let package = file.package.as_deref().unwrap_or("");
968 let service_name = service.name.as_deref().unwrap_or("");
969 check_method_collisions(service_name, service)?;
970 let full_service_name = if package.is_empty() {
973 service_name.to_string()
974 } else {
975 format!("{package}.{service_name}")
976 };
977 let service_upper = service_name.to_upper_camel_case();
978 let trait_name = if service_upper == "Self" {
982 format_ident!("Self_")
983 } else {
984 format_ident!("{}", service_upper)
985 };
986 let ext_trait_name = format_ident!("{}Ext", service_upper);
987 let client_name = format_ident!("{}Client", service_upper);
988 let server_name = format_ident!("{}Server", service_upper);
989 let service_name_const = format_ident!(
990 "{}_SERVICE_NAME",
991 service_name.to_snake_case().to_uppercase()
992 );
993
994 let service_doc = get_service_comment(file, service).unwrap_or_default();
996 let base_doc = if service_doc.is_empty() {
997 format!("Server trait for {service_name}.")
998 } else {
999 service_doc
1000 };
1001 let full_doc = format!(
1002 "{base_doc}\n\n\
1003 # Implementing handlers\n\n\
1004 Handlers receive requests as `OwnedFooView` (an alias for\n\
1005 `OwnedView<FooView<'static>>`), which gives zero-copy borrowed access\n\
1006 to fields (e.g. `request.name` is a `&str` into the decoded buffer).\n\
1007 The view can be held across `.await` points. When two RPC types in\n\
1008 the same package would alias to the same `Owned<…>View` name (e.g.\n\
1009 a local message plus an imported one with the same short name), the\n\
1010 alias is suppressed for both and the request type is spelled as\n\
1011 `OwnedView<…View<'static>>` directly in the trait signature.\n\n\
1012 Implement methods with plain `async fn`; the returned future satisfies\n\
1013 the `Send` bound automatically. See the\n\
1014 [buffa user guide](https://github.com/anthropics/buffa/blob/main/docs/guide.md#ownedview-in-async-trait-implementations)\n\
1015 for zero-copy access patterns and when `to_owned_message()` is needed.\n\n\
1016 The `impl Encodable<Out>` return bound accepts the owned `Out`, the\n\
1017 generated `OutView<'_>` / `OwnedOutView`,\n\
1018 [`MaybeBorrowed`](::connectrpc::MaybeBorrowed), or\n\
1019 [`PreEncoded`](::connectrpc::PreEncoded) for handlers that encode a\n\
1020 non-`'static` view internally and pass the bytes across the handler\n\
1021 boundary. View bodies are not emitted for output types mapped via\n\
1022 `extern_path` (the impl would be an orphan); return owned for\n\
1023 WKT/extern outputs.\n\n\
1024 Server-streaming and bidi-streaming methods return\n\
1025 `ServiceStream<impl Encodable<Out> + Send + use<Self>>`. The\n\
1026 `use<Self>` precise-capturing clause excludes `&self`'s lifetime\n\
1027 (unary methods use `use<'a, Self>` and may borrow), so stream items\n\
1028 must be `'static`. To stream view-encoded data, encode each item\n\
1029 inside the stream body and yield\n\
1030 [`PreEncoded`](::connectrpc::PreEncoded) — see its `# Streaming\n\
1031 example` doc."
1032 );
1033 let service_doc_tokens = doc_attrs(&full_doc);
1034
1035 let trait_methods: Vec<TokenStream> = service
1037 .method
1038 .iter()
1039 .map(|m| generate_trait_method(file, service, m, resolver, batch, package))
1040 .collect::<Result<Vec<_>>>()?;
1041
1042 let route_registrations: Vec<TokenStream> = service
1044 .method
1045 .iter()
1046 .map(|m| {
1047 let method_name = m.name.as_deref().unwrap_or("");
1048 let method_snake = make_field_ident(&method_name.to_snake_case());
1049
1050 let client_streaming = m.client_streaming.unwrap_or(false);
1051 let server_streaming = m.server_streaming.unwrap_or(false);
1052
1053 if server_streaming && !client_streaming {
1054 let output_type = resolver
1059 .rust_type(m.output_type.as_deref().unwrap_or(""), package)
1060 .unwrap();
1061 quote! {
1062 .route_view_server_stream::<_, _, #output_type>(
1063 #service_name_const,
1064 #method_name,
1065 ::connectrpc::view_streaming_handler_fn({
1066 let svc = ::std::sync::Arc::clone(&self);
1067 move |ctx, req| {
1068 let svc = ::std::sync::Arc::clone(&svc);
1069 async move { svc.#method_snake(ctx, req).await }
1070 }
1071 }),
1072 )
1073 }
1074 } else if client_streaming && !server_streaming {
1075 let output_type = resolver
1077 .rust_type(m.output_type.as_deref().unwrap_or(""), package)
1078 .unwrap();
1079 quote! {
1080 .route_view_client_stream(
1081 #service_name_const,
1082 #method_name,
1083 ::connectrpc::view_client_streaming_handler_fn({
1084 let svc = ::std::sync::Arc::clone(&self);
1085 move |ctx, req, format| {
1086 let svc = ::std::sync::Arc::clone(&svc);
1087 async move {
1088 svc.#method_snake(ctx, req).await?.encode::<#output_type>(format)
1089 }
1090 }
1091 }),
1092 )
1093 }
1094 } else if client_streaming && server_streaming {
1095 let output_type = resolver
1098 .rust_type(m.output_type.as_deref().unwrap_or(""), package)
1099 .unwrap();
1100 quote! {
1101 .route_view_bidi_stream::<_, _, #output_type>(
1102 #service_name_const,
1103 #method_name,
1104 ::connectrpc::view_bidi_streaming_handler_fn({
1105 let svc = ::std::sync::Arc::clone(&self);
1106 move |ctx, req| {
1107 let svc = ::std::sync::Arc::clone(&svc);
1108 async move { svc.#method_snake(ctx, req).await }
1109 }
1110 }),
1111 )
1112 }
1113 } else {
1114 let is_idempotent = m
1116 .options
1117 .idempotency_level
1118 .map(|level| level == IdempotencyLevel::NO_SIDE_EFFECTS)
1119 .unwrap_or(false);
1120
1121 let route_method = if is_idempotent {
1122 quote! { route_view_idempotent }
1123 } else {
1124 quote! { route_view }
1125 };
1126 let output_type = resolver
1127 .rust_type(m.output_type.as_deref().unwrap_or(""), package)
1128 .unwrap();
1129
1130 quote! {
1131 .#route_method(
1132 #service_name_const,
1133 #method_name,
1134 {
1135 let svc = ::std::sync::Arc::clone(&self);
1136 ::connectrpc::view_handler_fn(move |ctx, req, format| {
1137 let svc = ::std::sync::Arc::clone(&svc);
1138 async move {
1139 svc.#method_snake(ctx, req).await?.encode::<#output_type>(format)
1140 }
1141 })
1142 },
1143 )
1144 }
1145 }
1146 })
1147 .collect();
1148
1149 let client_methods: Vec<TokenStream> = service
1151 .method
1152 .iter()
1153 .map(|m| {
1154 generate_client_method(
1155 &service_name_const,
1156 &full_service_name,
1157 m,
1158 resolver,
1159 package,
1160 )
1161 })
1162 .collect::<Result<Vec<_>>>()?;
1163
1164 let service_server = generate_service_server(
1166 &full_service_name,
1167 &trait_name,
1168 &server_name,
1169 service,
1170 resolver,
1171 package,
1172 )?;
1173
1174 let example_method = service
1176 .method
1177 .first()
1178 .and_then(|m| m.name.as_deref())
1179 .map(|n| make_field_ident(&n.to_snake_case()).to_string())
1180 .unwrap_or_else(|| "method".to_string());
1181
1182 let client_name_str = client_name.to_string();
1184 let client_doc = format!(
1185 r#"Client for this service.
1186
1187Generic over `T: ClientTransport`. For **gRPC** (HTTP/2), use
1188`Http2Connection` — it has honest `poll_ready` and composes with
1189`tower::balance` for multi-connection load balancing. For **Connect
1190over HTTP/1.1** (or unknown protocol), use `HttpClient`.
1191
1192# Example (gRPC / HTTP/2)
1193
1194```rust,ignore
1195use connectrpc::client::{{Http2Connection, ClientConfig}};
1196use connectrpc::Protocol;
1197
1198let uri: http::Uri = "http://localhost:8080".parse()?;
1199let conn = Http2Connection::connect_plaintext(uri.clone()).await?.shared(1024);
1200let config = ClientConfig::new(uri).with_protocol(Protocol::Grpc);
1201
1202let client = {client_name_str}::new(conn, config);
1203let response = client.{example_method}(request).await?;
1204```
1205
1206# Example (Connect / HTTP/1.1 or ALPN)
1207
1208```rust,ignore
1209use connectrpc::client::{{HttpClient, ClientConfig}};
1210
1211let http = HttpClient::plaintext(); // cleartext http:// only
1212let config = ClientConfig::new("http://localhost:8080".parse()?);
1213
1214let client = {client_name_str}::new(http, config);
1215let response = client.{example_method}(request).await?;
1216```
1217
1218# Working with the response
1219
1220Unary calls return [`UnaryResponse<OwnedView<FooView>>`](::connectrpc::client::UnaryResponse).
1221The `OwnedView` derefs to the view, so field access is zero-copy:
1222
1223```rust,ignore
1224let resp = client.{example_method}(request).await?.into_view();
1225let name: &str = resp.name; // borrow into the response buffer
1226```
1227
1228If you need the owned struct (e.g. to store or pass by value), use
1229[`into_owned()`](::connectrpc::client::UnaryResponse::into_owned):
1230
1231```rust,ignore
1232let owned = client.{example_method}(request).await?.into_owned();
1233```"#
1234 );
1235 let client_doc_tokens = doc_attrs(&client_doc);
1236
1237 Ok(quote! {
1238 pub const #service_name_const: &str = #full_service_name;
1244
1245 #service_doc_tokens
1246 #[allow(clippy::type_complexity)]
1247 pub trait #trait_name: Send + Sync + 'static {
1248 #(#trait_methods)*
1249 }
1250
1251 pub trait #ext_trait_name: #trait_name {
1264 fn register(self: ::std::sync::Arc<Self>, router: ::connectrpc::Router) -> ::connectrpc::Router;
1269 }
1270
1271 impl<S: #trait_name> #ext_trait_name for S {
1272 fn register(self: ::std::sync::Arc<Self>, router: ::connectrpc::Router) -> ::connectrpc::Router {
1273 router
1274 #(#route_registrations)*
1275 }
1276 }
1277
1278 #service_server
1279
1280 #client_doc_tokens
1281 #[derive(Clone)]
1282 pub struct #client_name<T> {
1283 transport: T,
1284 config: ::connectrpc::client::ClientConfig,
1285 }
1286
1287 impl<T> #client_name<T>
1288 where
1289 T: ::connectrpc::client::ClientTransport,
1290 <T::ResponseBody as ::http_body::Body>::Error: ::std::fmt::Display,
1291 {
1292 pub fn new(transport: T, config: ::connectrpc::client::ClientConfig) -> Self {
1294 Self { transport, config }
1295 }
1296
1297 pub fn config(&self) -> &::connectrpc::client::ClientConfig {
1299 &self.config
1300 }
1301
1302 pub fn config_mut(&mut self) -> &mut ::connectrpc::client::ClientConfig {
1304 &mut self.config
1305 }
1306
1307 #(#client_methods)*
1308 }
1309 })
1310}
1311
1312fn generate_service_server(
1319 full_service_name: &str,
1320 trait_name: &proc_macro2::Ident,
1321 server_name: &proc_macro2::Ident,
1322 service: &ServiceDescriptorProto,
1323 resolver: &TypeResolver<'_>,
1324 package: &str,
1325) -> Result<TokenStream> {
1326 let path_prefix = format!("{full_service_name}/");
1328
1329 let lookup_arms: Vec<TokenStream> = service
1331 .method
1332 .iter()
1333 .map(|m| {
1334 let method_name = m.name.as_deref().unwrap_or("");
1335 let client_streaming = m.client_streaming.unwrap_or(false);
1336 let server_streaming = m.server_streaming.unwrap_or(false);
1337 let is_idempotent = m
1338 .options
1339 .idempotency_level
1340 .map(|level| level == IdempotencyLevel::NO_SIDE_EFFECTS)
1341 .unwrap_or(false);
1342
1343 let desc = if client_streaming && server_streaming {
1344 quote! { ::connectrpc::dispatcher::codegen::MethodDescriptor::bidi_streaming() }
1345 } else if client_streaming {
1346 quote! { ::connectrpc::dispatcher::codegen::MethodDescriptor::client_streaming() }
1347 } else if server_streaming {
1348 quote! { ::connectrpc::dispatcher::codegen::MethodDescriptor::server_streaming() }
1349 } else {
1350 quote! { ::connectrpc::dispatcher::codegen::MethodDescriptor::unary(#is_idempotent) }
1351 };
1352 quote! { #method_name => Some(#desc), }
1353 })
1354 .collect();
1355
1356 let mut call_unary_arms: Vec<TokenStream> = Vec::new();
1361 let mut call_ss_arms: Vec<TokenStream> = Vec::new();
1362 let mut call_cs_arms: Vec<TokenStream> = Vec::new();
1363 let mut call_bidi_arms: Vec<TokenStream> = Vec::new();
1364
1365 for m in &service.method {
1366 let method_name = m.name.as_deref().unwrap_or("");
1367 let method_snake = make_field_ident(&method_name.to_snake_case());
1368 let input_view = resolver.rust_view_type(m.input_type.as_deref().unwrap_or(""), package)?;
1369 let output_type = resolver.rust_type(m.output_type.as_deref().unwrap_or(""), package)?;
1370 let cs = m.client_streaming.unwrap_or(false);
1371 let ss = m.server_streaming.unwrap_or(false);
1372
1373 if cs && ss {
1374 call_bidi_arms.push(quote! {
1376 #method_name => {
1377 let svc = ::std::sync::Arc::clone(&self.inner);
1378 Box::pin(async move {
1379 let req_stream = ::connectrpc::dispatcher::codegen::decode_view_request_stream::<#input_view>(requests, format);
1380 let resp = svc.#method_snake(ctx, req_stream).await?;
1381 Ok(resp.map_body(|s| ::connectrpc::dispatcher::codegen::encode_response_stream::<#output_type, _, _>(s, format)))
1382 })
1383 }
1384 });
1385 } else if cs {
1386 call_cs_arms.push(quote! {
1388 #method_name => {
1389 let svc = ::std::sync::Arc::clone(&self.inner);
1390 Box::pin(async move {
1391 let req_stream = ::connectrpc::dispatcher::codegen::decode_view_request_stream::<#input_view>(requests, format);
1392 svc.#method_snake(ctx, req_stream).await?.encode::<#output_type>(format)
1393 })
1394 }
1395 });
1396 } else if ss {
1397 call_ss_arms.push(quote! {
1399 #method_name => {
1400 let svc = ::std::sync::Arc::clone(&self.inner);
1401 Box::pin(async move {
1402 let req = ::connectrpc::dispatcher::codegen::decode_request_view::<#input_view>(request, format)?;
1403 let resp = svc.#method_snake(ctx, req).await?;
1404 Ok(resp.map_body(|s| ::connectrpc::dispatcher::codegen::encode_response_stream::<#output_type, _, _>(s, format)))
1405 })
1406 }
1407 });
1408 } else {
1409 call_unary_arms.push(quote! {
1411 #method_name => {
1412 let svc = ::std::sync::Arc::clone(&self.inner);
1413 Box::pin(async move {
1414 let req = ::connectrpc::dispatcher::codegen::decode_request_view::<#input_view>(request, format)?;
1415 svc.#method_snake(ctx, req).await?.encode::<#output_type>(format)
1416 })
1417 }
1418 });
1419 }
1420 }
1421
1422 let server_doc = format!(
1423 "Monomorphic dispatcher for `{trait_name}`.\n\n\
1424 Unlike `.register(Router)` which type-erases each method into an \
1425 `Arc<dyn ErasedHandler>` stored in a `HashMap`, this struct dispatches \
1426 via a compile-time `match` on method name: no vtable, no hash lookup.\n\n\
1427 # Example\n\n\
1428 ```rust,ignore\n\
1429 use connectrpc::ConnectRpcService;\n\n\
1430 let server = {server_name}::new(MyImpl);\n\
1431 let service = ConnectRpcService::new(server);\n\
1432 // hand `service` to axum/hyper as a fallback_service\n\
1433 ```"
1434 );
1435 let server_doc_tokens = doc_attrs(&server_doc);
1436
1437 Ok(quote! {
1438 #server_doc_tokens
1439 pub struct #server_name<T> {
1440 inner: ::std::sync::Arc<T>,
1441 }
1442
1443 impl<T: #trait_name> #server_name<T> {
1444 pub fn new(service: T) -> Self {
1446 Self { inner: ::std::sync::Arc::new(service) }
1447 }
1448
1449 pub fn from_arc(inner: ::std::sync::Arc<T>) -> Self {
1451 Self { inner }
1452 }
1453 }
1454
1455 impl<T> Clone for #server_name<T> {
1456 fn clone(&self) -> Self {
1457 Self { inner: ::std::sync::Arc::clone(&self.inner) }
1458 }
1459 }
1460
1461 impl<T: #trait_name> ::connectrpc::Dispatcher for #server_name<T> {
1462 #[inline]
1463 fn lookup(&self, path: &str) -> Option<::connectrpc::dispatcher::codegen::MethodDescriptor> {
1464 let method = path.strip_prefix(#path_prefix)?;
1465 match method {
1466 #(#lookup_arms)*
1467 _ => None,
1468 }
1469 }
1470
1471 fn call_unary(
1472 &self,
1473 path: &str,
1474 ctx: ::connectrpc::RequestContext,
1475 request: ::buffa::bytes::Bytes,
1476 format: ::connectrpc::CodecFormat,
1477 ) -> ::connectrpc::dispatcher::codegen::UnaryResult {
1478 let Some(method) = path.strip_prefix(#path_prefix) else {
1479 return ::connectrpc::dispatcher::codegen::unimplemented_unary(path);
1480 };
1481 let _ = (&ctx, &request, &format);
1483 match method {
1484 #(#call_unary_arms)*
1485 _ => ::connectrpc::dispatcher::codegen::unimplemented_unary(path),
1486 }
1487 }
1488
1489 fn call_server_streaming(
1490 &self,
1491 path: &str,
1492 ctx: ::connectrpc::RequestContext,
1493 request: ::buffa::bytes::Bytes,
1494 format: ::connectrpc::CodecFormat,
1495 ) -> ::connectrpc::dispatcher::codegen::StreamingResult {
1496 let Some(method) = path.strip_prefix(#path_prefix) else {
1497 return ::connectrpc::dispatcher::codegen::unimplemented_streaming(path);
1498 };
1499 let _ = (&ctx, &request, &format);
1500 match method {
1501 #(#call_ss_arms)*
1502 _ => ::connectrpc::dispatcher::codegen::unimplemented_streaming(path),
1503 }
1504 }
1505
1506 fn call_client_streaming(
1507 &self,
1508 path: &str,
1509 ctx: ::connectrpc::RequestContext,
1510 requests: ::connectrpc::dispatcher::codegen::RequestStream,
1511 format: ::connectrpc::CodecFormat,
1512 ) -> ::connectrpc::dispatcher::codegen::UnaryResult {
1513 let Some(method) = path.strip_prefix(#path_prefix) else {
1514 return ::connectrpc::dispatcher::codegen::unimplemented_unary(path);
1515 };
1516 let _ = (&ctx, &requests, &format);
1517 match method {
1518 #(#call_cs_arms)*
1519 _ => ::connectrpc::dispatcher::codegen::unimplemented_unary(path),
1520 }
1521 }
1522
1523 fn call_bidi_streaming(
1524 &self,
1525 path: &str,
1526 ctx: ::connectrpc::RequestContext,
1527 requests: ::connectrpc::dispatcher::codegen::RequestStream,
1528 format: ::connectrpc::CodecFormat,
1529 ) -> ::connectrpc::dispatcher::codegen::StreamingResult {
1530 let Some(method) = path.strip_prefix(#path_prefix) else {
1531 return ::connectrpc::dispatcher::codegen::unimplemented_streaming(path);
1532 };
1533 let _ = (&ctx, &requests, &format);
1534 match method {
1535 #(#call_bidi_arms)*
1536 _ => ::connectrpc::dispatcher::codegen::unimplemented_streaming(path),
1537 }
1538 }
1539 }
1540 })
1541}
1542
1543fn generate_doc_comment(doc: &str, default: &str) -> TokenStream {
1545 let comment = if doc.is_empty() { default } else { doc };
1546 doc_attrs(comment)
1547}
1548
1549fn generate_trait_method(
1551 file: &FileDescriptorProto,
1552 service: &ServiceDescriptorProto,
1553 method: &MethodDescriptorProto,
1554 resolver: &TypeResolver<'_>,
1555 batch: &BatchState,
1556 package: &str,
1557) -> Result<TokenStream> {
1558 let method_name = method.name.as_deref().unwrap_or("");
1559 let method_snake = make_field_ident(&method_name.to_snake_case());
1560 let input_arg = owned_view_input_arg_type(
1561 resolver,
1562 batch,
1563 method.input_type.as_deref().unwrap_or(""),
1564 package,
1565 )?;
1566 let output_type = resolver.rust_type(method.output_type.as_deref().unwrap_or(""), package)?;
1567
1568 let method_doc = get_method_comment(file, service, method).unwrap_or_default();
1570 let method_doc_tokens =
1571 generate_doc_comment(&method_doc, &format!("Handle the {method_name} RPC."));
1572
1573 let client_streaming = method.client_streaming.unwrap_or(false);
1575 let server_streaming = method.server_streaming.unwrap_or(false);
1576
1577 let borrow_doc = quote! {
1578 #[doc = ""]
1579 #[doc = " `'a` lets the response body borrow from `&self` (e.g. server-resident state)."]
1580 };
1581
1582 if server_streaming && !client_streaming {
1583 Ok(quote! {
1591 #method_doc_tokens
1592 fn #method_snake(
1593 &self,
1594 ctx: ::connectrpc::RequestContext,
1595 request: #input_arg,
1596 ) -> impl ::std::future::Future<Output = ::connectrpc::ServiceResult<::connectrpc::ServiceStream<impl ::connectrpc::Encodable<#output_type> + Send + use<Self>>>> + Send;
1597 })
1598 } else if client_streaming && !server_streaming {
1599 Ok(quote! {
1601 #method_doc_tokens
1602 #borrow_doc
1603 fn #method_snake<'a>(
1604 &'a self,
1605 ctx: ::connectrpc::RequestContext,
1606 requests: ::connectrpc::ServiceStream<#input_arg>,
1607 ) -> impl ::std::future::Future<Output = ::connectrpc::ServiceResult<impl ::connectrpc::Encodable<#output_type> + Send + use<'a, Self>>> + Send;
1608 })
1609 } else if client_streaming && server_streaming {
1610 Ok(quote! {
1613 #method_doc_tokens
1614 fn #method_snake(
1615 &self,
1616 ctx: ::connectrpc::RequestContext,
1617 requests: ::connectrpc::ServiceStream<#input_arg>,
1618 ) -> impl ::std::future::Future<Output = ::connectrpc::ServiceResult<::connectrpc::ServiceStream<impl ::connectrpc::Encodable<#output_type> + Send + use<Self>>>> + Send;
1619 })
1620 } else {
1621 Ok(quote! {
1623 #method_doc_tokens
1624 #borrow_doc
1625 fn #method_snake<'a>(
1626 &'a self,
1627 ctx: ::connectrpc::RequestContext,
1628 request: #input_arg,
1629 ) -> impl ::std::future::Future<Output = ::connectrpc::ServiceResult<impl ::connectrpc::Encodable<#output_type> + Send + use<'a, Self>>> + Send;
1630 })
1631 }
1632}
1633
1634fn generate_client_method(
1645 service_name_const: &Ident,
1646 full_service_name: &str,
1647 method: &MethodDescriptorProto,
1648 resolver: &TypeResolver<'_>,
1649 package: &str,
1650) -> Result<TokenStream> {
1651 let method_name = method.name.as_deref().unwrap_or("");
1652 let method_snake = make_field_ident(&method_name.to_snake_case());
1653 let method_with_opts = format_ident!("{}_with_options", method_name.to_snake_case());
1654 let input_type = resolver.rust_type(method.input_type.as_deref().unwrap_or(""), package)?;
1655 let output_view_type =
1656 resolver.rust_view_type(method.output_type.as_deref().unwrap_or(""), package)?;
1657
1658 let client_streaming = method.client_streaming.unwrap_or(false);
1659 let server_streaming = method.server_streaming.unwrap_or(false);
1660
1661 let doc = format!(
1662 " Call the {method_name} RPC. Sends a request to /{full_service_name}/{method_name}."
1663 );
1664 let doc_opts = format!(
1665 " Call the {method_name} RPC with explicit per-call options. \
1666 Options override [`ClientConfig`](::connectrpc::client::ClientConfig) defaults."
1667 );
1668
1669 let ret_ty: TokenStream;
1671 let call_body: TokenStream;
1672 let short_args: TokenStream; let opts_args: TokenStream; let short_delegate_args: TokenStream; if client_streaming && !server_streaming {
1677 ret_ty = quote! {
1679 Result<
1680 ::connectrpc::client::UnaryResponse<::buffa::view::OwnedView<#output_view_type<'static>>>,
1681 ::connectrpc::ConnectError,
1682 >
1683 };
1684 call_body = quote! {
1685 ::connectrpc::client::call_client_stream(
1686 &self.transport, &self.config,
1687 #service_name_const, #method_name,
1688 requests, options,
1689 ).await
1690 };
1691 short_args = quote! { requests: impl IntoIterator<Item = #input_type> };
1692 opts_args = quote! { requests: impl IntoIterator<Item = #input_type>, options: ::connectrpc::client::CallOptions };
1693 short_delegate_args = quote! { requests, ::connectrpc::client::CallOptions::default() };
1694 } else if client_streaming && server_streaming {
1695 ret_ty = quote! {
1697 Result<
1698 ::connectrpc::client::BidiStream<
1699 T::ResponseBody, #input_type, #output_view_type<'static>
1700 >,
1701 ::connectrpc::ConnectError,
1702 >
1703 };
1704 call_body = quote! {
1705 ::connectrpc::client::call_bidi_stream(
1706 &self.transport, &self.config,
1707 #service_name_const, #method_name, options,
1708 ).await
1709 };
1710 short_args = quote! {};
1711 opts_args = quote! { options: ::connectrpc::client::CallOptions };
1712 short_delegate_args = quote! { ::connectrpc::client::CallOptions::default() };
1713 } else if server_streaming {
1714 ret_ty = quote! {
1716 Result<
1717 ::connectrpc::client::ServerStream<T::ResponseBody, #output_view_type<'static>>,
1718 ::connectrpc::ConnectError,
1719 >
1720 };
1721 call_body = quote! {
1722 ::connectrpc::client::call_server_stream(
1723 &self.transport, &self.config,
1724 #service_name_const, #method_name,
1725 request, options,
1726 ).await
1727 };
1728 short_args = quote! { request: #input_type };
1729 opts_args = quote! { request: #input_type, options: ::connectrpc::client::CallOptions };
1730 short_delegate_args = quote! { request, ::connectrpc::client::CallOptions::default() };
1731 } else {
1732 ret_ty = quote! {
1734 Result<
1735 ::connectrpc::client::UnaryResponse<::buffa::view::OwnedView<#output_view_type<'static>>>,
1736 ::connectrpc::ConnectError,
1737 >
1738 };
1739 call_body = quote! {
1740 ::connectrpc::client::call_unary(
1741 &self.transport, &self.config,
1742 #service_name_const, #method_name,
1743 request, options,
1744 ).await
1745 };
1746 short_args = quote! { request: #input_type };
1747 opts_args = quote! { request: #input_type, options: ::connectrpc::client::CallOptions };
1748 short_delegate_args = quote! { request, ::connectrpc::client::CallOptions::default() };
1749 }
1750
1751 Ok(quote! {
1752 #[doc = #doc]
1753 pub async fn #method_snake(&self, #short_args) -> #ret_ty {
1754 self.#method_with_opts(#short_delegate_args).await
1755 }
1756
1757 #[doc = #doc_opts]
1758 pub async fn #method_with_opts(&self, #opts_args) -> #ret_ty {
1759 #call_body
1760 }
1761 })
1762}
1763
1764fn get_service_comment(
1766 file: &FileDescriptorProto,
1767 service: &ServiceDescriptorProto,
1768) -> Option<String> {
1769 let source_info: &SourceCodeInfo = &file.source_code_info;
1771
1772 let service_index = file.service.iter().position(|s| s.name == service.name)?;
1774
1775 let target_path = vec![6, service_index as i32];
1778
1779 find_comment(source_info, &target_path)
1780}
1781
1782fn get_method_comment(
1784 file: &FileDescriptorProto,
1785 service: &ServiceDescriptorProto,
1786 method: &MethodDescriptorProto,
1787) -> Option<String> {
1788 let source_info: &SourceCodeInfo = &file.source_code_info;
1789
1790 let (service_index, method_index) = file.service.iter().enumerate().find_map(|(si, s)| {
1793 if s.name != service.name {
1794 return None;
1795 }
1796 s.method
1797 .iter()
1798 .position(|m| m.name == method.name)
1799 .map(|mi| (si, mi))
1800 })?;
1801
1802 let target_path = vec![6, service_index as i32, 2, method_index as i32];
1806
1807 find_comment(source_info, &target_path)
1808}
1809
1810fn find_comment(source_info: &SourceCodeInfo, target_path: &[i32]) -> Option<String> {
1812 for location in &source_info.location {
1813 if location.path == target_path {
1814 let comment = location
1815 .leading_comments
1816 .as_ref()
1817 .or(location.trailing_comments.as_ref())?;
1818
1819 let cleaned: String = comment
1823 .lines()
1824 .map(|line| line.trim())
1825 .filter(|line| !line.is_empty())
1826 .collect::<Vec<_>>()
1827 .join("\n");
1828
1829 if !cleaned.is_empty() {
1830 return Some(cleaned);
1831 }
1832 }
1833 }
1834 None
1835}
1836
1837#[cfg(test)]
1838mod tests {
1839 use super::*;
1840 use buffa_codegen::generated::descriptor::DescriptorProto;
1841
1842 #[test]
1843 fn doc_attrs_prefixes_space_for_prettyplease() {
1844 let ts = quote! {
1847 #[allow(dead_code)]
1848 mod m {}
1849 };
1850 let doc = doc_attrs("Hello.\n\nSecond paragraph.");
1851 let combined = quote! { #doc #ts };
1852 let file = syn::parse2::<syn::File>(combined).unwrap();
1853 let out = prettyplease::unparse(&file);
1854 assert!(out.contains("/// Hello."), "got: {out}");
1856 assert!(out.contains("/// Second paragraph."), "got: {out}");
1857 assert!(out.contains("///\n"), "got: {out}");
1859 assert!(!out.contains("///Hello"), "got: {out}");
1861 assert!(!out.contains("/// Hello"), "got: {out}");
1862 }
1863
1864 fn minimal_file(
1869 package: Option<&str>,
1870 input_type: &str,
1871 output_type: &str,
1872 local_messages: &[&str],
1873 ) -> FileDescriptorProto {
1874 minimal_file_with_method(package, "Ping", input_type, output_type, local_messages)
1875 }
1876
1877 fn minimal_file_with_method(
1880 package: Option<&str>,
1881 method_name: &str,
1882 input_type: &str,
1883 output_type: &str,
1884 local_messages: &[&str],
1885 ) -> FileDescriptorProto {
1886 let method = MethodDescriptorProto {
1887 name: Some(method_name.into()),
1888 input_type: Some(input_type.into()),
1889 output_type: Some(output_type.into()),
1890 ..Default::default()
1891 };
1892 let service = ServiceDescriptorProto {
1893 name: Some("PingService".into()),
1894 method: vec![method],
1895 ..Default::default()
1896 };
1897 FileDescriptorProto {
1898 name: Some("ping.proto".into()),
1899 package: package.map(|p| p.into()),
1900 service: vec![service],
1901 message_type: local_messages
1902 .iter()
1903 .map(|name| DescriptorProto {
1904 name: Some((*name).into()),
1905 ..Default::default()
1906 })
1907 .collect(),
1908 ..Default::default()
1909 }
1910 }
1911
1912 fn minimal_file_with_methods(package: &str, method_names: &[&str]) -> FileDescriptorProto {
1916 let methods = method_names
1917 .iter()
1918 .map(|n| MethodDescriptorProto {
1919 name: Some((*n).into()),
1920 input_type: Some(format!(".{package}.Empty")),
1921 output_type: Some(format!(".{package}.Empty")),
1922 ..Default::default()
1923 })
1924 .collect();
1925 let service = ServiceDescriptorProto {
1926 name: Some("PingService".into()),
1927 method: methods,
1928 ..Default::default()
1929 };
1930 FileDescriptorProto {
1931 name: Some("ping.proto".into()),
1932 package: Some(package.into()),
1933 service: vec![service],
1934 message_type: vec![DescriptorProto {
1935 name: Some("Empty".into()),
1936 ..Default::default()
1937 }],
1938 ..Default::default()
1939 }
1940 }
1941
1942 fn gen_service(
1951 files: &[FileDescriptorProto],
1952 target_idx: usize,
1953 extern_paths: &[(String, String)],
1954 require_extern: bool,
1955 ) -> Result<String> {
1956 let mut config = buffa_codegen::CodeGenConfig::default();
1957 config.extern_paths = extern_paths.to_vec();
1958 let target_name = files[target_idx]
1959 .name
1960 .clone()
1961 .into_iter()
1962 .collect::<Vec<_>>();
1963 let resolver = TypeResolver::new(files, &target_name, &config, require_extern);
1964 let file = &files[target_idx];
1965 let service = &file.service[0];
1966 let batch = BatchState {
1967 colliding_aliases: collect_alias_collisions(files, &target_name),
1968 ..BatchState::default()
1969 };
1970 Ok(generate_service(file, service, &resolver, &batch)?.to_string())
1971 }
1972
1973 fn assert_no_top_level_use(formatted: &str, label: &str) {
1978 let parsed: syn::File = syn::parse_str(formatted).expect("formatted code parses");
1979 let offenders: Vec<String> = parsed
1980 .items
1981 .iter()
1982 .filter_map(|item| match item {
1983 syn::Item::Use(u) => Some(quote!(#u).to_string()),
1984 _ => None,
1985 })
1986 .collect();
1987 assert!(
1988 offenders.is_empty(),
1989 "{label} contains top-level use statement(s): {offenders:?}\nFull source:\n{formatted}"
1990 );
1991 }
1992
1993 fn gen_file(
1994 files: &[FileDescriptorProto],
1995 target_idx: usize,
1996 extern_paths: &[(String, String)],
1997 require_extern: bool,
1998 ) -> Result<String> {
1999 let mut config = buffa_codegen::CodeGenConfig::default();
2000 config.extern_paths = extern_paths.to_vec();
2001 let target_name = files[target_idx]
2002 .name
2003 .clone()
2004 .into_iter()
2005 .collect::<Vec<_>>();
2006 let resolver = TypeResolver::new(files, &target_name, &config, require_extern);
2007 let mut batch = BatchState {
2008 colliding_aliases: collect_alias_collisions(files, &target_name),
2009 ..BatchState::default()
2010 };
2011 Ok(generate_connect_services(&files[target_idx], &resolver, &mut batch)?.to_string())
2012 }
2013
2014 #[test]
2015 fn unary_response_body_captures_self_lifetime() {
2016 let file = minimal_file(
2017 Some("example.v1"),
2018 ".example.v1.PingReq",
2019 ".example.v1.PingResp",
2020 &["PingReq", "PingResp"],
2021 );
2022 let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
2023 assert!(code.contains("< 'a >"), "trait method missing 'a: {code}");
2024 assert!(code.contains("& 'a self"), "missing &'a self: {code}");
2025 assert!(
2026 code.contains("use < 'a , Self >"),
2027 "missing use<'a, Self> capture: {code}"
2028 );
2029 assert!(
2030 !code.contains("'static + use"),
2031 "'static bound on body should be dropped: {code}"
2032 );
2033 }
2034
2035 #[test]
2036 fn owned_view_aliases_emitted_for_input_and_output() {
2037 let file = minimal_file(
2038 Some("example.v1"),
2039 ".example.v1.PingReq",
2040 ".example.v1.PingResp",
2041 &["PingReq", "PingResp"],
2042 );
2043 let code = gen_file(std::slice::from_ref(&file), 0, &[], false).unwrap();
2044 assert!(
2045 code.contains("pub type OwnedPingReqView = :: buffa :: view :: OwnedView"),
2046 "missing OwnedPingReqView alias: {code}"
2047 );
2048 assert!(
2049 code.contains("pub type OwnedPingRespView = :: buffa :: view :: OwnedView"),
2050 "missing OwnedPingRespView alias: {code}"
2051 );
2052 assert!(
2054 code.contains("request : OwnedPingReqView ,"),
2055 "trait method should take request: OwnedPingReqView: {code}"
2056 );
2057 }
2058
2059 #[test]
2060 fn cross_package_input_collision_suppresses_alias_for_both_sides() {
2061 let v1 = FileDescriptorProto {
2069 name: Some("api/v1/foo/bar/foobar.proto".into()),
2070 package: Some("api.v1.foo.bar".into()),
2071 message_type: vec![DescriptorProto {
2072 name: Some("MyMessage".into()),
2073 ..Default::default()
2074 }],
2075 ..Default::default()
2076 };
2077 let v2 = minimal_file(
2078 Some("api.v2.foo.bar"),
2079 ".api.v1.foo.bar.MyMessage",
2080 ".api.v2.foo.bar.MyMessage",
2081 &["MyMessage"],
2082 );
2083 let code = gen_file(&[v1, v2], 1, &[], false).unwrap();
2084
2085 let alias_count = code.matches("pub type OwnedMyMessageView").count();
2088 assert_eq!(
2089 alias_count, 0,
2090 "expected zero OwnedMyMessageView aliases when both sides collide; got {alias_count}: {code}"
2091 );
2092
2093 assert!(
2096 !code.contains("request : OwnedMyMessageView"),
2097 "colliding input must not reference the suppressed alias: {code}"
2098 );
2099 assert!(
2100 code.contains("request : :: buffa :: view :: OwnedView <"),
2101 "colliding input should be inlined as OwnedView<…<'static>>: {code}"
2102 );
2103 }
2104
2105 #[test]
2106 fn cross_package_input_without_collision_keeps_alias() {
2107 let wkt = FileDescriptorProto {
2114 name: Some("google/protobuf/empty.proto".into()),
2115 package: Some("google.protobuf".into()),
2116 message_type: vec![DescriptorProto {
2117 name: Some("Empty".into()),
2118 ..Default::default()
2119 }],
2120 ..Default::default()
2121 };
2122 let svc = minimal_file(
2123 Some("example.v1"),
2124 ".google.protobuf.Empty",
2125 ".example.v1.PingResp",
2126 &["PingResp"],
2127 );
2128 let code = gen_file(&[wkt, svc], 1, &[], false).unwrap();
2129 assert!(
2130 code.contains("pub type OwnedEmptyView = :: buffa :: view :: OwnedView"),
2131 "WKT cross-package input should keep its alias: {code}"
2132 );
2133 assert!(
2134 code.contains("request : OwnedEmptyView ,"),
2135 "trait method should still use OwnedEmptyView for non-colliding cross-package input: {code}"
2136 );
2137 }
2138
2139 #[test]
2140 fn collision_inlines_in_all_streaming_method_shapes() {
2141 let v1 = FileDescriptorProto {
2147 name: Some("api/v1/foo/bar/foobar.proto".into()),
2148 package: Some("api.v1.foo.bar".into()),
2149 message_type: vec![DescriptorProto {
2150 name: Some("MyMessage".into()),
2151 ..Default::default()
2152 }],
2153 ..Default::default()
2154 };
2155 let v2 = FileDescriptorProto {
2156 name: Some("api/v2/foo/bar/foobar.proto".into()),
2157 package: Some("api.v2.foo.bar".into()),
2158 message_type: vec![DescriptorProto {
2159 name: Some("MyMessage".into()),
2160 ..Default::default()
2161 }],
2162 service: vec![ServiceDescriptorProto {
2163 name: Some("FooBar".into()),
2164 method: vec![
2165 MethodDescriptorProto {
2166 name: Some("Unary".into()),
2167 input_type: Some(".api.v1.foo.bar.MyMessage".into()),
2168 output_type: Some(".api.v2.foo.bar.MyMessage".into()),
2169 ..Default::default()
2170 },
2171 MethodDescriptorProto {
2172 name: Some("ServerStream".into()),
2173 input_type: Some(".api.v1.foo.bar.MyMessage".into()),
2174 output_type: Some(".api.v2.foo.bar.MyMessage".into()),
2175 server_streaming: Some(true),
2176 ..Default::default()
2177 },
2178 MethodDescriptorProto {
2179 name: Some("ClientStream".into()),
2180 input_type: Some(".api.v1.foo.bar.MyMessage".into()),
2181 output_type: Some(".api.v2.foo.bar.MyMessage".into()),
2182 client_streaming: Some(true),
2183 ..Default::default()
2184 },
2185 MethodDescriptorProto {
2186 name: Some("Bidi".into()),
2187 input_type: Some(".api.v1.foo.bar.MyMessage".into()),
2188 output_type: Some(".api.v2.foo.bar.MyMessage".into()),
2189 client_streaming: Some(true),
2190 server_streaming: Some(true),
2191 ..Default::default()
2192 },
2193 ],
2194 ..Default::default()
2195 }],
2196 ..Default::default()
2197 };
2198 let code = gen_file(&[v1, v2], 1, &[], false).unwrap();
2199
2200 assert!(
2202 !code.contains("OwnedMyMessageView"),
2203 "no method shape should reference the suppressed alias: {code}"
2204 );
2205
2206 assert!(
2210 code.matches("request : :: buffa :: view :: OwnedView <")
2211 .count()
2212 >= 2,
2213 "unary and server-streaming should both inline the request type: {code}"
2214 );
2215 assert!(
2216 code.matches(
2217 "requests : :: connectrpc :: ServiceStream < :: buffa :: view :: OwnedView <"
2218 )
2219 .count()
2220 >= 2,
2221 "client-streaming and bidi should both inline the streamed request type: {code}"
2222 );
2223 }
2224
2225 #[test]
2226 fn streaming_methods_use_encodable_item_type() {
2227 let file = FileDescriptorProto {
2235 name: Some("ex/v1/svc.proto".into()),
2236 package: Some("ex.v1".into()),
2237 message_type: vec![
2238 DescriptorProto {
2239 name: Some("Req".into()),
2240 ..Default::default()
2241 },
2242 DescriptorProto {
2243 name: Some("Resp".into()),
2244 ..Default::default()
2245 },
2246 ],
2247 service: vec![ServiceDescriptorProto {
2248 name: Some("Svc".into()),
2249 method: vec![
2250 MethodDescriptorProto {
2251 name: Some("ServerStream".into()),
2252 input_type: Some(".ex.v1.Req".into()),
2253 output_type: Some(".ex.v1.Resp".into()),
2254 server_streaming: Some(true),
2255 ..Default::default()
2256 },
2257 MethodDescriptorProto {
2258 name: Some("Bidi".into()),
2259 input_type: Some(".ex.v1.Req".into()),
2260 output_type: Some(".ex.v1.Resp".into()),
2261 client_streaming: Some(true),
2262 server_streaming: Some(true),
2263 ..Default::default()
2264 },
2265 ],
2266 ..Default::default()
2267 }],
2268 ..Default::default()
2269 };
2270 let code = gen_file(std::slice::from_ref(&file), 0, &[], false).unwrap();
2271
2272 assert_eq!(
2274 code.matches(":: connectrpc :: ServiceStream < impl :: connectrpc :: Encodable < Resp > + Send + use < Self >>")
2275 .count(),
2276 2,
2277 "server-streaming and bidi should both use the Encodable item type: {code}"
2278 );
2279
2280 assert_eq!(
2282 code.matches("encode_response_stream :: < Resp , _ , _ >")
2283 .count(),
2284 2,
2285 "dispatcher arms must turbofish Res to encode_response_stream: {code}"
2286 );
2287
2288 assert!(
2290 code.contains("route_view_server_stream :: < _ , _ , Resp >"),
2291 "route_view_server_stream must turbofish Res: {code}"
2292 );
2293 assert!(
2294 code.contains("route_view_bidi_stream :: < _ , _ , Resp >"),
2295 "route_view_bidi_stream must turbofish Res: {code}"
2296 );
2297 }
2298
2299 #[test]
2300 fn encodable_view_impls_emitted_per_output_type() {
2301 let file = minimal_file(
2302 Some("example.v1"),
2303 ".example.v1.PingReq",
2304 ".example.v1.PingResp",
2305 &["PingReq", "PingResp"],
2306 );
2307 let code = gen_file(std::slice::from_ref(&file), 0, &[], false).unwrap();
2308 assert!(
2309 code.contains(
2310 ":: connectrpc :: Encodable < PingResp > for __buffa :: view :: PingRespView"
2311 ),
2312 "missing Encodable<PingResp> for PingRespView: {code}"
2313 );
2314 assert!(
2315 code.contains(
2316 ":: connectrpc :: Encodable < PingResp > for :: buffa :: view :: OwnedView"
2317 ),
2318 "missing Encodable<PingResp> for OwnedView<PingRespView>: {code}"
2319 );
2320 assert!(!code.contains("Encodable < PingReq >"), "got: {code}");
2322 }
2323
2324 #[test]
2325 fn encodable_view_impls_skipped_for_extern_output() {
2326 let wkt = FileDescriptorProto {
2329 name: Some("google/protobuf/empty.proto".into()),
2330 package: Some("google.protobuf".into()),
2331 message_type: vec![DescriptorProto {
2332 name: Some("Empty".into()),
2333 ..Default::default()
2334 }],
2335 ..Default::default()
2336 };
2337 let file = minimal_file(
2338 Some("example.v1"),
2339 ".example.v1.PingReq",
2340 ".google.protobuf.Empty",
2341 &["PingReq"],
2342 );
2343 let code = gen_file(&[wkt, file], 1, &[], false).unwrap();
2344 assert!(
2347 !code.contains("encode_view_body"),
2348 "extern output type must not get Encodable impl: {code}"
2349 );
2350 }
2351
2352 #[test]
2353 fn encodable_view_impls_deduped_across_files() {
2354 let common = FileDescriptorProto {
2359 name: Some("common.proto".into()),
2360 package: Some("common.v1".into()),
2361 message_type: vec![DescriptorProto {
2362 name: Some("Reply".into()),
2363 ..Default::default()
2364 }],
2365 ..Default::default()
2366 };
2367 let svc = |name: &str, pkg: &str| FileDescriptorProto {
2368 name: Some(name.into()),
2369 package: Some(pkg.into()),
2370 message_type: vec![DescriptorProto {
2371 name: Some("Req".into()),
2372 ..Default::default()
2373 }],
2374 service: vec![ServiceDescriptorProto {
2375 name: Some("S".into()),
2376 method: vec![MethodDescriptorProto {
2377 name: Some("Call".into()),
2378 input_type: Some(format!(".{pkg}.Req")),
2379 output_type: Some(".common.v1.Reply".into()),
2380 ..Default::default()
2381 }],
2382 ..Default::default()
2383 }],
2384 ..Default::default()
2385 };
2386 let files = vec![common, svc("a.proto", "a.v1"), svc("b.proto", "b.v1")];
2387
2388 let generated = generate_files(
2389 &files,
2390 &["a.proto".into(), "b.proto".into()],
2391 &Options::default(),
2392 )
2393 .unwrap();
2394
2395 let companions: Vec<_> = generated
2398 .iter()
2399 .filter(|f| f.kind == GeneratedFileKind::Companion)
2400 .collect();
2401 let mut companion_names: Vec<&str> = companions.iter().map(|f| f.name.as_str()).collect();
2402 companion_names.sort_unstable();
2403 assert_eq!(companion_names, ["a.__connect.rs", "b.__connect.rs"]);
2404 for c in &companions {
2405 let stitcher = generated
2406 .iter()
2407 .find(|g| g.kind == GeneratedFileKind::PackageMod && g.package == c.package)
2408 .expect("each companion's package must have a stitcher");
2409 assert!(
2410 stitcher
2411 .content
2412 .contains(&format!("include!(\"{}\")", c.name)),
2413 "stitcher for {} must include companion {}",
2414 c.package,
2415 c.name
2416 );
2417 }
2418
2419 let combined: String = companions.iter().map(|f| f.content.as_str()).collect();
2420
2421 let view_impl = "impl ::connectrpc::Encodable<super::super::common::v1::Reply>\nfor super::super::common::v1::__buffa::view::ReplyView<'_>";
2422 let owned_view_impl = "impl ::connectrpc::Encodable<super::super::common::v1::Reply>\nfor ::buffa::view::OwnedView<";
2423 assert_eq!(
2424 combined.matches(view_impl).count(),
2425 1,
2426 "Encodable<Reply> for ReplyView<'_> must appear once: {combined}"
2427 );
2428 assert_eq!(
2429 combined.matches(owned_view_impl).count(),
2430 1,
2431 "Encodable<Reply> for OwnedView<ReplyView> must appear once: {combined}"
2432 );
2433 }
2434
2435 fn file_per_package_fixture() -> Vec<FileDescriptorProto> {
2440 let common = FileDescriptorProto {
2441 name: Some("common.proto".into()),
2442 package: Some("common.v1".into()),
2443 message_type: vec![DescriptorProto {
2444 name: Some("Reply".into()),
2445 ..Default::default()
2446 }],
2447 ..Default::default()
2448 };
2449 let svc = |proto_name: &str, pkg: &str, svc_name: &str, req: &str| FileDescriptorProto {
2454 name: Some(proto_name.into()),
2455 package: Some(pkg.into()),
2456 message_type: vec![DescriptorProto {
2457 name: Some(req.into()),
2458 ..Default::default()
2459 }],
2460 service: vec![ServiceDescriptorProto {
2461 name: Some(svc_name.into()),
2462 method: vec![MethodDescriptorProto {
2463 name: Some("Call".into()),
2464 input_type: Some(format!(".{pkg}.{req}")),
2465 output_type: Some(".common.v1.Reply".into()),
2466 ..Default::default()
2467 }],
2468 ..Default::default()
2469 }],
2470 ..Default::default()
2471 };
2472 vec![
2473 common,
2474 svc("a/x.proto", "a.v1", "XService", "XReq"),
2475 svc("a/y.proto", "a.v1", "YService", "YReq"),
2476 svc("b/z.proto", "b.v1", "ZService", "ZReq"),
2477 ]
2478 }
2479
2480 #[test]
2481 fn generate_files_file_per_package_inlines_companions() {
2482 let files = file_per_package_fixture();
2483 let mut options = Options::default();
2484 options.buffa.file_per_package = true;
2485
2486 let generated = generate_files(
2487 &files,
2488 &["a/x.proto".into(), "a/y.proto".into(), "b/z.proto".into()],
2489 &options,
2490 )
2491 .unwrap();
2492
2493 assert!(
2495 !generated
2496 .iter()
2497 .any(|f| f.kind == GeneratedFileKind::Companion),
2498 "file_per_package must not emit sibling Companion files"
2499 );
2500 assert!(
2501 !generated.iter().any(|f| f.name.ends_with(".__connect.rs")),
2502 "file_per_package must not emit `<stem>.__connect.rs` files"
2503 );
2504
2505 let a = generated
2507 .iter()
2508 .find(|f| f.kind == GeneratedFileKind::PackageMod && f.package == "a.v1")
2509 .expect("a.v1 PackageMod must exist");
2510 assert!(
2511 a.content.contains("pub trait XService"),
2512 "a.v1 missing XService"
2513 );
2514 assert!(
2515 a.content.contains("pub trait YService"),
2516 "a.v1 missing YService"
2517 );
2518 assert!(
2519 !a.content.contains("pub trait ZService"),
2520 "a.v1 must not inline ZService"
2521 );
2522 assert!(
2523 !a.content.contains("__connect.rs"),
2524 "a.v1 PackageMod must not include! a connect file: {}",
2525 a.content
2526 );
2527
2528 let b = generated
2529 .iter()
2530 .find(|f| f.kind == GeneratedFileKind::PackageMod && f.package == "b.v1")
2531 .expect("b.v1 PackageMod must exist");
2532 assert!(
2533 b.content.contains("pub trait ZService"),
2534 "b.v1 missing ZService"
2535 );
2536 assert!(
2537 !b.content.contains("pub trait XService"),
2538 "b.v1 must not inline XService"
2539 );
2540
2541 let pkg_mods = generated
2544 .iter()
2545 .filter(|f| f.kind == GeneratedFileKind::PackageMod)
2546 .count();
2547 assert_eq!(
2548 pkg_mods, 2,
2549 "expected exactly two PackageMods: {generated:#?}"
2550 );
2551
2552 let combined: String = generated.iter().map(|f| f.content.as_str()).collect();
2557 assert_eq!(
2558 combined
2559 .matches("impl ::connectrpc::Encodable<super::super::common::v1::Reply>")
2560 .count(),
2561 2,
2562 "Encodable<Reply> impls must be deduplicated across packages \
2563 (1 for ReplyView, 1 for OwnedView<ReplyView>): {combined}"
2564 );
2565 }
2566
2567 #[test]
2568 fn generate_services_file_per_package_emits_one_file_per_package() {
2569 let files = file_per_package_fixture();
2570 let mut options = Options::default();
2571 options.buffa.file_per_package = true;
2572 options
2573 .buffa
2574 .extern_paths
2575 .push((".".into(), "crate::proto".into()));
2576
2577 let generated = generate_services(
2578 &files,
2579 &["a/x.proto".into(), "a/y.proto".into(), "b/z.proto".into()],
2580 &options,
2581 )
2582 .unwrap();
2583
2584 assert_eq!(
2587 generated.len(),
2588 2,
2589 "expected exactly two output files: {generated:#?}"
2590 );
2591 assert!(
2592 generated
2593 .iter()
2594 .all(|f| f.kind == GeneratedFileKind::PackageMod),
2595 "all output files must be PackageMod"
2596 );
2597 assert!(
2598 !generated.iter().any(|f| f.name.ends_with(".mod.rs")),
2599 "file_per_package must not emit a separate stitcher"
2600 );
2601 assert!(
2602 !generated.iter().any(|f| f.content.contains("include!")),
2603 "file_per_package output must not include! sibling files"
2604 );
2605
2606 let mut names: Vec<&str> = generated.iter().map(|f| f.name.as_str()).collect();
2607 names.sort_unstable();
2608 assert_eq!(
2609 names,
2610 ["a.v1.rs", "b.v1.rs"],
2611 "filenames must be `<dotted.pkg>.rs` to match buffa's file_per_package convention"
2612 );
2613
2614 let a = generated.iter().find(|f| f.package == "a.v1").unwrap();
2615 assert!(a.content.contains("pub trait XService"));
2616 assert!(a.content.contains("pub trait YService"));
2617 let b = generated.iter().find(|f| f.package == "b.v1").unwrap();
2618 assert!(b.content.contains("pub trait ZService"));
2619 assert!(!b.content.contains("pub trait XService"));
2620 }
2621
2622 #[test]
2623 fn generate_services_file_per_package_default_layout_unchanged() {
2624 let files = file_per_package_fixture();
2627 let mut options = Options::default();
2628 options
2629 .buffa
2630 .extern_paths
2631 .push((".".into(), "crate::proto".into()));
2632
2633 let generated = generate_services(
2634 &files,
2635 &["a/x.proto".into(), "a/y.proto".into(), "b/z.proto".into()],
2636 &options,
2637 )
2638 .unwrap();
2639
2640 let mut companions: Vec<&str> = generated
2641 .iter()
2642 .filter(|f| f.kind == GeneratedFileKind::Companion)
2643 .map(|f| f.name.as_str())
2644 .collect();
2645 companions.sort_unstable();
2646 assert_eq!(
2647 companions,
2648 ["a.x.__connect.rs", "a.y.__connect.rs", "b.z.__connect.rs"],
2649 "default layout emits one companion per proto"
2650 );
2651 let mut stitchers: Vec<&str> = generated
2652 .iter()
2653 .filter(|f| f.kind == GeneratedFileKind::PackageMod)
2654 .map(|f| f.name.as_str())
2655 .collect();
2656 stitchers.sort_unstable();
2657 assert_eq!(
2658 stitchers,
2659 ["a.v1.mod.rs", "b.v1.mod.rs"],
2660 "default layout emits one stitcher per package"
2661 );
2662 let a_stitcher = generated.iter().find(|f| f.name == "a.v1.mod.rs").unwrap();
2664 assert!(
2665 a_stitcher
2666 .content
2667 .contains(r#"include!("a.x.__connect.rs");"#)
2668 );
2669 assert!(
2670 a_stitcher
2671 .content
2672 .contains(r#"include!("a.y.__connect.rs");"#)
2673 );
2674 }
2675
2676 #[test]
2677 fn service_name_with_package() {
2678 let file = minimal_file(
2679 Some("example.v1"),
2680 ".example.v1.PingReq",
2681 ".example.v1.PingResp",
2682 &["PingReq", "PingResp"],
2683 );
2684 let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
2685 assert!(code.contains("\"example.v1.PingService\""), "got: {code}");
2686 }
2687
2688 #[test]
2689 fn service_name_without_package() {
2690 let file = minimal_file(None, ".PingReq", ".PingResp", &["PingReq", "PingResp"]);
2692 let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
2693 assert!(code.contains("\"PingService\""), "got: {code}");
2694 assert!(
2695 !code.contains("\".PingService\""),
2696 "must not have leading dot: {code}"
2697 );
2698 }
2699
2700 #[test]
2701 fn same_package_types_use_bare_names() {
2702 let file = minimal_file(
2703 Some("example.v1"),
2704 ".example.v1.PingReq",
2705 ".example.v1.PingResp",
2706 &["PingReq", "PingResp"],
2707 );
2708 let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
2709 assert!(code.contains("PingReq"), "input type missing: {code}");
2711 assert!(code.contains("PingResp"), "output type missing: {code}");
2712 assert!(
2714 !code.contains("super :: PingReq"),
2715 "unexpected super: {code}"
2716 );
2717 }
2718
2719 #[test]
2720 fn cross_package_types_use_relative_paths() {
2721 let common = FileDescriptorProto {
2725 name: Some("common.proto".into()),
2726 package: Some("common.v1".into()),
2727 message_type: vec![DescriptorProto {
2728 name: Some("Shared".into()),
2729 ..Default::default()
2730 }],
2731 ..Default::default()
2732 };
2733 let svc = minimal_file(
2734 Some("example.v1"),
2735 ".common.v1.Shared",
2736 ".example.v1.Out",
2737 &["Out"],
2738 );
2739 let code = gen_service(&[common, svc], 1, &[], false).unwrap();
2740
2741 assert!(
2744 code.contains("super :: super :: common :: v1 :: Shared"),
2745 "cross-package path not emitted: {code}"
2746 );
2747 assert!(
2748 code.contains("super :: super :: common :: v1 :: __buffa :: view :: SharedView"),
2749 "cross-package view path not emitted: {code}"
2750 );
2751 }
2752
2753 #[test]
2754 fn nested_message_view_type_mirrors_owned_module_nesting() {
2755 let file = FileDescriptorProto {
2760 name: Some("nested.proto".into()),
2761 package: Some("example.v1".into()),
2762 message_type: vec![
2763 DescriptorProto {
2764 name: Some("Outer".into()),
2765 nested_type: vec![DescriptorProto {
2766 name: Some("Inner".into()),
2767 ..Default::default()
2768 }],
2769 ..Default::default()
2770 },
2771 DescriptorProto {
2772 name: Some("Out".into()),
2773 ..Default::default()
2774 },
2775 ],
2776 service: vec![ServiceDescriptorProto {
2777 name: Some("NestedService".into()),
2778 method: vec![MethodDescriptorProto {
2779 name: Some("Ping".into()),
2780 input_type: Some(".example.v1.Outer.Inner".into()),
2781 output_type: Some(".example.v1.Out".into()),
2782 ..Default::default()
2783 }],
2784 ..Default::default()
2785 }],
2786 ..Default::default()
2787 };
2788 let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
2789
2790 assert!(
2791 code.contains("__buffa :: view :: outer :: InnerView"),
2792 "nested view path not emitted: {code}"
2793 );
2794 assert!(
2795 code.contains("outer :: Inner"),
2796 "nested owned path not emitted: {code}"
2797 );
2798 }
2799
2800 #[test]
2801 fn wkt_types_use_buffa_types_extern_path() {
2802 let wkt = FileDescriptorProto {
2806 name: Some("google/protobuf/empty.proto".into()),
2807 package: Some("google.protobuf".into()),
2808 message_type: vec![DescriptorProto {
2809 name: Some("Empty".into()),
2810 ..Default::default()
2811 }],
2812 ..Default::default()
2813 };
2814 let svc = minimal_file(
2815 Some("example.v1"),
2816 ".google.protobuf.Empty",
2817 ".example.v1.Out",
2818 &["Out"],
2819 );
2820 let code = gen_service(&[wkt, svc], 1, &[], false).unwrap();
2821
2822 assert!(
2823 code.contains(":: buffa_types :: google :: protobuf :: Empty"),
2824 "WKT extern path not emitted: {code}"
2825 );
2826 }
2827
2828 #[test]
2829 fn extern_catchall_uses_absolute_paths() {
2830 let file = minimal_file(
2831 Some("example.v1"),
2832 ".example.v1.PingReq",
2833 ".example.v1.PingResp",
2834 &["PingReq", "PingResp"],
2835 );
2836 let extern_paths = [(".".into(), "crate::proto".into())];
2837 let code = gen_service(std::slice::from_ref(&file), 0, &extern_paths, true).unwrap();
2838 assert!(
2839 code.contains("crate :: proto :: example :: v1 :: PingReq"),
2840 "owned type path missing: {code}"
2841 );
2842 assert!(
2843 code.contains("crate :: proto :: example :: v1 :: __buffa :: view :: PingReqView"),
2844 "view type path missing: {code}"
2845 );
2846 }
2847
2848 #[test]
2849 fn extern_catchall_with_wkt_longest_wins() {
2850 let wkt = FileDescriptorProto {
2853 name: Some("google/protobuf/empty.proto".into()),
2854 package: Some("google.protobuf".into()),
2855 message_type: vec![DescriptorProto {
2856 name: Some("Empty".into()),
2857 ..Default::default()
2858 }],
2859 ..Default::default()
2860 };
2861 let svc = minimal_file(
2862 Some("example.v1"),
2863 ".google.protobuf.Empty",
2864 ".example.v1.Out",
2865 &["Out"],
2866 );
2867 let extern_paths = [(".".into(), "crate::proto".into())];
2868 let code = gen_service(&[wkt, svc], 1, &extern_paths, true).unwrap();
2869 assert!(
2870 code.contains(":: buffa_types :: google :: protobuf :: Empty"),
2871 "WKT mapping lost to catch-all: {code}"
2872 );
2873 assert!(
2874 code.contains("crate :: proto :: example :: v1 :: Out"),
2875 "local type not routed through catch-all: {code}"
2876 );
2877 }
2878
2879 #[test]
2880 fn missing_extern_path_errors() {
2881 let file = minimal_file(
2882 Some("example.v1"),
2883 ".example.v1.PingReq",
2884 ".example.v1.PingResp",
2885 &["PingReq", "PingResp"],
2886 );
2887 let err = gen_service(std::slice::from_ref(&file), 0, &[], true).unwrap_err();
2888 let msg = err.to_string();
2889 assert!(
2890 msg.contains("extern_path"),
2891 "error message lacks hint: {msg}"
2892 );
2893 }
2894
2895 #[test]
2896 fn keyword_package_escaped() {
2897 let file = minimal_file(
2899 Some("google.type"),
2900 ".google.type.LatLng",
2901 ".google.type.LatLng",
2902 &["LatLng"],
2903 );
2904 let extern_paths = [(".".into(), "crate::proto".into())];
2905 let code = gen_service(std::slice::from_ref(&file), 0, &extern_paths, true).unwrap();
2906 assert!(
2907 code.contains("crate :: proto :: google :: r#type :: LatLng"),
2908 "keyword segment not escaped: {code}"
2909 );
2910 }
2911
2912 #[test]
2913 fn keyword_method_escaped() {
2914 let file = minimal_file_with_method(
2917 Some("example.v1"),
2918 "Move",
2919 ".example.v1.Empty",
2920 ".example.v1.Empty",
2921 &["Empty"],
2922 );
2923 let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
2924 assert!(
2925 code.contains("fn r#move"),
2926 "keyword method not escaped: {code}"
2927 );
2928 assert!(
2929 code.contains("move_with_options"),
2930 "suffixed variant should not need escaping: {code}"
2931 );
2932 assert!(code.contains("client.r#move(request)"));
2934 syn::parse_str::<syn::File>(&code).expect("generated code parses");
2935 }
2936
2937 #[test]
2938 fn path_keyword_method_suffixed() {
2939 let file = minimal_file_with_method(
2942 Some("example.v1"),
2943 "Self",
2944 ".example.v1.Empty",
2945 ".example.v1.Empty",
2946 &["Empty"],
2947 );
2948 let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
2949 assert!(
2950 code.contains("fn self_"),
2951 "path-keyword method not suffixed: {code}"
2952 );
2953 assert!(code.contains("self_with_options"));
2957 syn::parse_str::<syn::File>(&code).expect("generated code parses");
2958 }
2959
2960 #[test]
2961 fn service_name_keyword_suffixed() {
2962 let mut file = minimal_file(
2966 Some("example.v1"),
2967 ".example.v1.Empty",
2968 ".example.v1.Empty",
2969 &["Empty"],
2970 );
2971 file.service[0].name = Some("Self".into());
2972 let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
2973 assert!(code.contains("trait Self_ "), "trait not suffixed: {code}");
2974 assert!(code.contains("trait SelfExt"));
2975 assert!(code.contains("struct SelfClient"));
2976 assert!(code.contains("struct SelfServer"));
2977 syn::parse_str::<syn::File>(&code).expect("generated code parses");
2978 }
2979
2980 #[test]
2981 fn method_snake_collision_errors() {
2982 let file = minimal_file_with_methods("example.v1", &["GetFoo", "get_foo"]);
2985 let err = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap_err();
2986 let msg = err.to_string();
2987 assert!(msg.contains("PingService"), "missing service name: {msg}");
2988 assert!(msg.contains("\"GetFoo\""), "missing first method: {msg}");
2989 assert!(msg.contains("\"get_foo\""), "missing second method: {msg}");
2990 assert!(msg.contains("`get_foo`"), "missing rust ident: {msg}");
2991 }
2992
2993 #[test]
2994 fn method_with_options_collision_errors() {
2995 let file = minimal_file_with_methods("example.v1", &["Ping", "PingWithOptions"]);
2998 let err = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap_err();
2999 let msg = err.to_string();
3000 assert!(msg.contains("\"Ping\""), "missing first method: {msg}");
3001 assert!(
3002 msg.contains("\"PingWithOptions\""),
3003 "missing second method: {msg}"
3004 );
3005 assert!(
3006 msg.contains("`ping_with_options`"),
3007 "missing rust ident: {msg}"
3008 );
3009 }
3010
3011 #[test]
3012 fn distinct_methods_do_not_collide() {
3013 let file = minimal_file_with_methods("example.v1", &["GetFoo", "GetBar"]);
3014 let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
3015 syn::parse_str::<syn::File>(&code).expect("generated code parses");
3016 }
3017
3018 #[test]
3019 fn options_default_buffa_config() {
3020 let cfg = Options::default().to_buffa_config();
3021 assert!(cfg.generate_json, "connectrpc enables JSON by default");
3022 assert!(cfg.generate_views);
3023 assert!(cfg.emit_register_fn);
3024 assert!(!cfg.strict_utf8_mapping);
3025 }
3026
3027 #[test]
3028 fn options_buffa_passthrough_forces_views() {
3029 let mut opts = Options::default();
3030 opts.buffa.emit_register_fn = false;
3031 opts.buffa.generate_views = false;
3032 let cfg = opts.to_buffa_config();
3033 assert!(!cfg.emit_register_fn);
3034 assert!(cfg.generate_views, "generate_views must be forced on");
3035 }
3036
3037 #[test]
3038 fn generate_files_emit_register_fn_false_suppresses_register_types() {
3039 let file = FileDescriptorProto {
3042 name: Some("ping.proto".into()),
3043 package: Some("example.v1".into()),
3044 message_type: vec![DescriptorProto {
3045 name: Some("PingReq".into()),
3046 ..Default::default()
3047 }],
3048 ..Default::default()
3049 };
3050
3051 let stitcher = |files: &[GeneratedFile]| {
3054 files
3055 .iter()
3056 .find(|f| f.kind == GeneratedFileKind::PackageMod)
3057 .expect("PackageMod file emitted")
3058 .content
3059 .clone()
3060 };
3061
3062 let with_fn = generate_files(
3063 std::slice::from_ref(&file),
3064 &["ping.proto".into()],
3065 &Options::default(),
3066 )
3067 .unwrap();
3068 let mod_rs = stitcher(&with_fn);
3069 assert!(
3070 mod_rs.contains("fn register_types"),
3071 "expected register_types in default output: {mod_rs}"
3072 );
3073
3074 let mut opts = Options::default();
3075 opts.buffa.emit_register_fn = false;
3076 let without_fn =
3077 generate_files(std::slice::from_ref(&file), &["ping.proto".into()], &opts).unwrap();
3078 let mod_rs = stitcher(&without_fn);
3079 assert!(
3080 !mod_rs.contains("fn register_types"),
3081 "register_types should be suppressed: {mod_rs}"
3082 );
3083 }
3084
3085 #[test]
3086 fn plugin_no_register_fn_parses() {
3087 let request = CodeGeneratorRequest {
3088 parameter: Some("buffa_module=crate::proto,no_register_fn".into()),
3089 file_to_generate: vec![],
3090 proto_file: vec![],
3091 ..Default::default()
3092 };
3093 generate(&request).expect("no_register_fn should be a recognized plugin option");
3096 }
3097
3098 #[test]
3099 fn plugin_file_per_package_collapses_output() {
3100 let request = CodeGeneratorRequest {
3103 parameter: Some("buffa_module=crate::proto,file_per_package".into()),
3104 file_to_generate: vec!["a/x.proto".into(), "a/y.proto".into(), "b/z.proto".into()],
3105 proto_file: file_per_package_fixture(),
3106 ..Default::default()
3107 };
3108 let response = generate(&request).expect("file_per_package should parse and generate");
3109 let mut names: Vec<&str> = response
3110 .file
3111 .iter()
3112 .filter_map(|f| f.name.as_deref())
3113 .collect();
3114 names.sort_unstable();
3115 assert_eq!(
3116 names,
3117 ["a.v1.rs", "b.v1.rs"],
3118 "expected one file per package: {names:?}"
3119 );
3120 for f in &response.file {
3121 let content = f.content.as_deref().unwrap_or_default();
3122 assert!(
3123 !content.contains("include!"),
3124 "file_per_package output must be self-contained: {content}"
3125 );
3126 }
3127 }
3128
3129 #[test]
3130 fn no_top_level_use_statements_in_generated_code() {
3131 let file = minimal_file(
3135 Some("example.v1"),
3136 ".example.v1.PingReq",
3137 ".example.v1.PingResp",
3138 &["PingReq", "PingResp"],
3139 );
3140 let code = gen_service(std::slice::from_ref(&file), 0, &[], false).unwrap();
3141 let formatted = format_token_stream(&code.parse::<TokenStream>().unwrap()).unwrap();
3142 assert_no_top_level_use(&formatted, "generated code");
3143 }
3144
3145 #[test]
3146 fn multi_service_include_no_e0252() {
3147 let file_a = {
3150 let method = MethodDescriptorProto {
3151 name: Some("Ping".into()),
3152 input_type: Some(".svc.v1.PingReq".into()),
3153 output_type: Some(".svc.v1.PingResp".into()),
3154 ..Default::default()
3155 };
3156 let service = ServiceDescriptorProto {
3157 name: Some("Alpha".into()),
3158 method: vec![method],
3159 ..Default::default()
3160 };
3161 FileDescriptorProto {
3162 name: Some("alpha.proto".into()),
3163 package: Some("svc.v1".into()),
3164 service: vec![service],
3165 message_type: vec![
3166 DescriptorProto {
3167 name: Some("PingReq".into()),
3168 ..Default::default()
3169 },
3170 DescriptorProto {
3171 name: Some("PingResp".into()),
3172 ..Default::default()
3173 },
3174 ],
3175 ..Default::default()
3176 }
3177 };
3178 let file_b = {
3179 let method = MethodDescriptorProto {
3180 name: Some("Pong".into()),
3181 input_type: Some(".svc.v1.PongReq".into()),
3182 output_type: Some(".svc.v1.PongResp".into()),
3183 ..Default::default()
3184 };
3185 let service = ServiceDescriptorProto {
3186 name: Some("Beta".into()),
3187 method: vec![method],
3188 ..Default::default()
3189 };
3190 FileDescriptorProto {
3191 name: Some("beta.proto".into()),
3192 package: Some("svc.v1".into()),
3193 service: vec![service],
3194 message_type: vec![
3195 DescriptorProto {
3196 name: Some("PongReq".into()),
3197 ..Default::default()
3198 },
3199 DescriptorProto {
3200 name: Some("PongResp".into()),
3201 ..Default::default()
3202 },
3203 ],
3204 ..Default::default()
3205 }
3206 };
3207
3208 let files = vec![file_a, file_b];
3209 let config = buffa_codegen::CodeGenConfig::default();
3210 let targets = vec!["alpha.proto".to_string(), "beta.proto".to_string()];
3211 let resolver = TypeResolver::new(&files, &targets, &config, false);
3212
3213 let mut batch = BatchState {
3214 colliding_aliases: collect_alias_collisions(&files, &targets),
3215 ..BatchState::default()
3216 };
3217 let code_a = generate_connect_services(&files[0], &resolver, &mut batch).unwrap();
3218 let code_b = generate_connect_services(&files[1], &resolver, &mut batch).unwrap();
3219
3220 let formatted_a = format_token_stream(&code_a).unwrap();
3221 let formatted_b = format_token_stream(&code_b).unwrap();
3222
3223 syn::parse_str::<syn::File>(&formatted_a).expect("service A should parse independently");
3225 syn::parse_str::<syn::File>(&formatted_b).expect("service B should parse independently");
3226
3227 let combined = format!("{formatted_a}\n{formatted_b}");
3229 syn::parse_str::<syn::File>(&combined)
3230 .expect("combined services should parse without E0252");
3231
3232 assert_no_top_level_use(&formatted_a, "service A");
3234 assert_no_top_level_use(&formatted_b, "service B");
3235 }
3236}