/// Dependency versions handed to `vercheck::Deps::new` by `BridgeGenerator::new`
/// when the `validate` feature is enabled.
#[cfg(feature = "validate")]
mod ver {
    /// Version string passed for `axum`.
    pub const AXUM_VERSION: &str = "0.8.3";

    /// Version string passed for `tonic`.
    pub const TONIC_VERSION: &str = "0.13.0";

    /// Version string passed for `http`.
    pub const HTTP_VERSION: &str = "1.3.1";
}
use heck::ToSnakeCase;
use prost_build::ServiceGenerator;
use quote::quote;
use prost::Message;
use prost_types::{
field_descriptor_proto::{Label, Type},
DescriptorProto, FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet,
};
#[cfg(feature = "validate")]
pub(crate) mod vercheck;
/// A `ServiceGenerator` that first delegates to a wrapped generator and then emits an
/// axum `Router` exposing every gRPC method as a JSON `POST` route. It can also emit
/// per-package serde helpers so enum fields round-trip as their protobuf names.
pub struct BridgeGenerator {
    /// Generator whose output is written before the bridge code for each service.
    inner: Box<dyn ServiceGenerator>,
    /// Whether enum fields get string (de)serialization helpers.
    enable_string_enums: bool,
    /// Descriptors consulted by `finalize_package` to emit the enum helpers.
    file_descriptor_set: Option<FileDescriptorSet>,
    /// Optional destination for the encoded `FileDescriptorSet`.
    descriptor_set_path: Option<std::path::PathBuf>,
}
impl BridgeGenerator {
pub fn new(inner: Box<dyn ServiceGenerator>) -> Self {
#[cfg(feature = "validate")]
{
let output =
vercheck::Deps::new(ver::AXUM_VERSION, ver::TONIC_VERSION, ver::HTTP_VERSION)
.and_then(vercheck::Deps::validate);
if let Err(err) = output {
eprintln!("g2h: {err}");
}
}
Self {
inner,
enable_string_enums: false,
file_descriptor_set: None,
descriptor_set_path: None,
}
}
pub fn build_prost_config(self) -> prost_build::Config {
let mut config = prost_build::Config::new();
config
.service_generator(Box::new(self))
.type_attribute(".", "#[derive(serde::Serialize, serde::Deserialize)]");
config
}
pub fn compile_protos(
self,
protos: &[impl AsRef<std::path::Path>],
includes: &[impl AsRef<std::path::Path>],
) -> Result<(), Box<dyn std::error::Error>> {
let file_descriptor_set = if self.enable_string_enums || self.descriptor_set_path.is_some()
{
Some(prost_build::Config::new().load_fds(protos, includes)?)
} else {
None
};
if let (Some(ref path), Some(ref fds)) = (&self.descriptor_set_path, &file_descriptor_set) {
let bytes = fds.encode_to_vec();
std::fs::write(path, bytes)?;
}
if !self.enable_string_enums {
let descriptor_path = self.descriptor_set_path.clone();
let mut config = self.build_prost_config();
if let Some(path) = descriptor_path {
config.file_descriptor_set_path(path);
}
return Ok(config.compile_protos(protos, includes)?);
}
let file_descriptor_set = file_descriptor_set.unwrap(); let mut generator = self;
generator.file_descriptor_set = Some(file_descriptor_set.clone());
let mut final_config = generator
.build_enum_config()
.build_prost_config_with_descriptors(&file_descriptor_set);
final_config.compile_protos(protos, includes)?;
Ok(())
}
fn build_enum_config(self) -> EnumConfig {
EnumConfig::new(self)
}
pub fn with_tonic_build() -> Self {
Self::new(tonic_build::configure().service_generator())
}
pub fn with_string_enums(mut self) -> Self {
self.enable_string_enums = true;
self
}
pub fn file_descriptor_set_path(mut self, path: impl AsRef<std::path::Path>) -> Self {
self.descriptor_set_path = Some(path.as_ref().to_path_buf());
self
}
fn generate_package_specific_enum_deserializer_code(
file_descriptor_set: &FileDescriptorSet,
target_package: &str,
) -> String {
let package_enum_fields =
Self::extract_package_enum_fields_static(file_descriptor_set, target_package);
if package_enum_fields.is_empty() {
return String::new();
}
let field_specific_functions =
Self::generate_field_specific_enum_functions_static(&package_enum_fields);
let field_functions_tokens: proc_macro2::TokenStream = field_specific_functions
.parse()
.expect("Generated field-specific enum functions should be valid Rust syntax");
quote! {
pub mod enum_deserializer {
use super::*;
#field_functions_tokens
}
}
.to_string()
}
fn extract_package_enum_fields_static(
file_descriptor_set: &FileDescriptorSet,
target_package: &str,
) -> Vec<(String, String, String)> {
let mut enum_fields = Vec::new();
for file in &file_descriptor_set.file {
let package = file.package();
if package != target_package {
continue;
}
for message in &file.message_type {
Self::extract_enum_fields_from_message_static(message, &mut enum_fields);
}
}
enum_fields
}
fn extract_enum_fields_from_message_static(
message: &DescriptorProto,
enum_fields: &mut Vec<(String, String, String)>,
) {
Self::extract_enum_fields_from_message_with_path_static(message, enum_fields, "");
}
fn extract_enum_fields_from_message_with_path_static(
message: &DescriptorProto,
enum_fields: &mut Vec<(String, String, String)>,
message_path: &str,
) {
let message_name = message.name();
let current_path = if message_path.is_empty() {
message_name.to_snake_case()
} else {
format!("{}_{}", message_path, message_name.to_snake_case())
};
for field in &message.field {
if field.r#type() == Type::Enum {
let field_id = format!("{}_{}", current_path, field.name().to_snake_case());
let enum_type = field.type_name().trim_start_matches('.');
let enum_path = Self::resolve_enum_path(enum_type);
let field_label = match field.label() {
Label::Optional => {
if field.proto3_optional() {
"Option"
} else {
"Single"
}
}
Label::Required => "Single",
Label::Repeated => "Repeated",
};
enum_fields.push((field_id, enum_path, field_label.to_string()));
}
}
for nested_message in &message.nested_type {
Self::extract_enum_fields_from_message_with_path_static(
nested_message,
enum_fields,
¤t_path,
);
}
}
fn resolve_enum_path(enum_type: &str) -> String {
if !enum_type.contains('.') {
return enum_type.to_string();
}
let parts: Vec<&str> = enum_type.split('.').collect();
match parts.len() {
0 | 1 => parts.last().unwrap_or(&"UnknownEnum").to_string(),
2 => {
parts[1].to_string()
}
_ => {
let enum_name = parts[parts.len() - 1];
let mut message_parts = Vec::new();
let start_idx = 1;
for &part in &parts[start_idx..parts.len() - 1] {
if Self::is_message_name(part) {
message_parts.push(part.to_snake_case());
}
}
if message_parts.is_empty() {
enum_name.to_string()
} else {
format!("{}::{}", message_parts.join("::"), enum_name)
}
}
}
}
fn is_message_name(name: &str) -> bool {
name.chars().next().is_some_and(|c| c.is_uppercase())
}
fn generate_field_specific_enum_functions_static(
enum_fields: &[(String, String, String)],
) -> String {
let mut functions = String::new();
for (field_id, enum_name, field_label) in enum_fields {
let enum_ident: proc_macro2::TokenStream = enum_name
.parse()
.unwrap_or_else(|e| panic!("Invalid enum type path '{enum_name}': {e}"));
let function_code = match field_label.as_str() {
"Single" => Self::generate_single_enum_functions(field_id, &enum_ident),
"Option" => Self::generate_option_enum_functions(field_id, &enum_ident),
"Repeated" => Self::generate_repeated_enum_functions(field_id, &enum_ident),
_ => String::new(),
};
functions.push_str(&function_code);
}
functions
}
fn generate_single_enum_functions(
field_id: &str,
enum_ident: &proc_macro2::TokenStream,
) -> String {
let serialize_fn = quote::format_ident!("serialize_{}_as_string", field_id);
let deserialize_fn = quote::format_ident!("deserialize_{}_from_string", field_id);
quote! {
#[allow(dead_code)]
pub fn #serialize_fn<S>(value: &i32, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::Serialize;
if let Ok(enum_val) = #enum_ident::try_from(*value) {
enum_val.as_str_name().serialize(serializer)
} else {
value.serialize(serializer)
}
}
#[allow(dead_code)]
pub fn #deserialize_fn<'de, D>(deserializer: D) -> Result<i32, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::Deserialize;
#[derive(Deserialize)]
#[serde(untagged)]
#[allow(dead_code)]
enum EnumOrString {
String(String),
Int(i32),
}
match EnumOrString::deserialize(deserializer)? {
EnumOrString::String(s) => {
if let Some(enum_val) = #enum_ident::from_str_name(&s) {
Ok(enum_val as i32)
} else {
Err(serde::de::Error::custom(format!("Unknown enum value for {}: {}", stringify!(#enum_ident), s)))
}
}
EnumOrString::Int(i) => Ok(i),
}
}
}.to_string()
}
fn generate_option_enum_functions(
field_id: &str,
enum_ident: &proc_macro2::TokenStream,
) -> String {
let serialize_fn = quote::format_ident!("serialize_option_{}_as_string", field_id);
let deserialize_fn = quote::format_ident!("deserialize_option_{}_from_string", field_id);
quote! {
#[allow(dead_code)]
pub fn #serialize_fn<S>(value: &Option<i32>, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::Serialize;
match value {
Some(val) => {
if let Ok(enum_val) = #enum_ident::try_from(*val) {
Some(enum_val.as_str_name()).serialize(serializer)
} else {
Some(*val).serialize(serializer)
}
}
None => None::<&str>.serialize(serializer),
}
}
#[allow(dead_code)]
pub fn #deserialize_fn<'de, D>(deserializer: D) -> Result<Option<i32>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::Deserialize;
#[derive(Deserialize)]
#[serde(untagged)]
#[allow(dead_code)]
enum OptionalEnumOrString {
String(String),
Int(i32),
None,
}
match Option::<OptionalEnumOrString>::deserialize(deserializer)? {
Some(OptionalEnumOrString::String(s)) => {
if let Some(enum_val) = #enum_ident::from_str_name(&s) {
Ok(Some(enum_val as i32))
} else {
Err(serde::de::Error::custom(format!("Unknown enum value for {}: {}", stringify!(#enum_ident), s)))
}
}
Some(OptionalEnumOrString::Int(i)) => Ok(Some(i)),
Some(OptionalEnumOrString::None) | None => Ok(None),
}
}
}.to_string()
}
fn generate_repeated_enum_functions(
field_id: &str,
enum_ident: &proc_macro2::TokenStream,
) -> String {
let serialize_fn = quote::format_ident!("serialize_repeated_{}_as_string", field_id);
let deserialize_fn = quote::format_ident!("deserialize_repeated_{}_from_string", field_id);
quote! {
#[allow(dead_code)]
pub fn #serialize_fn<S>(values: &[i32], serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::Serialize;
let string_values: Vec<_> = values.iter().map(|val| {
if let Ok(enum_val) = #enum_ident::try_from(*val) {
enum_val.as_str_name().to_string()
} else {
val.to_string()
}
}).collect();
string_values.serialize(serializer)
}
#[allow(dead_code)]
pub fn #deserialize_fn<'de, D>(deserializer: D) -> Result<Vec<i32>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::Deserialize;
#[derive(Deserialize)]
#[serde(untagged)]
#[allow(dead_code)]
enum EnumOrStringItem {
String(String),
Int(i32),
}
let items: Vec<EnumOrStringItem> = Vec::deserialize(deserializer)?;
let mut result = Vec::with_capacity(items.len());
for item in items {
match item {
EnumOrStringItem::String(s) => {
if let Some(enum_val) = #enum_ident::from_str_name(&s) {
result.push(enum_val as i32);
} else {
return Err(serde::de::Error::custom(format!("Unknown enum value for {}: {}", stringify!(#enum_ident), s)));
}
}
EnumOrStringItem::Int(i) => {
result.push(i);
}
}
}
Ok(result)
}
}.to_string()
}
}
/// Turns a [`BridgeGenerator`] into a `prost_build::Config` decorated with
/// descriptor-driven field attributes (string enums and skipped empty values).
pub struct EnumConfig {
    /// Generator that becomes the resulting config's service generator.
    generator: BridgeGenerator,
}
impl EnumConfig {
pub fn new(generator: BridgeGenerator) -> Self {
Self { generator }
}
pub fn build_prost_config_with_descriptors(
self,
file_descriptor_set: &FileDescriptorSet,
) -> prost_build::Config {
let enable_string_enums = self.generator.enable_string_enums;
let mut config = self.generator.build_prost_config();
if enable_string_enums {
config = Self::add_enum_string_support_static(config, file_descriptor_set);
}
config = Self::add_skip_nulls_support_static(config, file_descriptor_set);
config
}
fn add_enum_string_support_static(
mut config: prost_build::Config,
file_descriptor_set: &FileDescriptorSet,
) -> prost_build::Config {
for file in &file_descriptor_set.file {
config = Self::process_file_descriptor_static(config, file);
}
config
}
fn process_file_descriptor_static(
mut config: prost_build::Config,
file: &FileDescriptorProto,
) -> prost_build::Config {
for message in &file.message_type {
let package = file.package();
config = Self::process_message_descriptor_static(config, message, package);
}
config
}
fn process_message_descriptor_static(
config: prost_build::Config,
message: &DescriptorProto,
package: &str,
) -> prost_build::Config {
Self::process_message_descriptor_with_path_static(config, message, package, "")
}
fn process_message_descriptor_with_path_static(
mut config: prost_build::Config,
message: &DescriptorProto,
package: &str,
message_path: &str,
) -> prost_build::Config {
let message_name = message.name();
let current_path = if message_path.is_empty() {
message_name.to_snake_case()
} else {
format!("{}_{}", message_path, message_name.to_snake_case())
};
let is_nested = !message_path.is_empty();
for field in &message.field {
if Self::is_enum_field_static(field) {
config = Self::add_enum_deserializer_with_path_static(
config,
¤t_path,
message_name,
field,
package,
is_nested,
);
}
}
for nested_message in &message.nested_type {
config = Self::process_message_descriptor_with_path_static(
config,
nested_message,
package,
¤t_path,
);
}
config
}
fn is_enum_field_static(field: &FieldDescriptorProto) -> bool {
field.r#type() == Type::Enum
}
fn add_enum_deserializer_with_path_static(
mut config: prost_build::Config,
message_path: &str,
message_name: &str,
field: &FieldDescriptorProto,
_package: &str,
is_nested: bool,
) -> prost_build::Config {
let field_path = format!("{}.{}", message_name, field.name());
let field_id = format!("{}_{}", message_path, field.name().to_snake_case());
let enum_deserializer_path = if is_nested {
"super::enum_deserializer"
} else {
"enum_deserializer"
};
let serde_attribute = match Self::get_field_label_static(field) {
FieldLabel::Optional => {
if field.proto3_optional() {
format!("#[serde(serialize_with = \"{enum_deserializer_path}::serialize_option_{field_id}_as_string\", deserialize_with = \"{enum_deserializer_path}::deserialize_option_{field_id}_from_string\", default)]")
} else {
format!("#[serde(serialize_with = \"{enum_deserializer_path}::serialize_{field_id}_as_string\", deserialize_with = \"{enum_deserializer_path}::deserialize_{field_id}_from_string\", default)]")
}
},
FieldLabel::Required => format!("#[serde(serialize_with = \"{enum_deserializer_path}::serialize_{field_id}_as_string\", deserialize_with = \"{enum_deserializer_path}::deserialize_{field_id}_from_string\")]"),
FieldLabel::Repeated => format!("#[serde(serialize_with = \"{enum_deserializer_path}::serialize_repeated_{field_id}_as_string\", deserialize_with = \"{enum_deserializer_path}::deserialize_repeated_{field_id}_from_string\", default)]"),
};
config.field_attribute(&field_path, &serde_attribute);
config
}
fn get_field_label_static(field: &FieldDescriptorProto) -> FieldLabel {
match field.label() {
Label::Optional => FieldLabel::Optional,
Label::Required => FieldLabel::Required,
Label::Repeated => FieldLabel::Repeated,
}
}
fn add_skip_nulls_support_static(
mut config: prost_build::Config,
file_descriptor_set: &FileDescriptorSet,
) -> prost_build::Config {
for file in &file_descriptor_set.file {
for message in &file.message_type {
config = Self::process_message_skip_nulls_recursive(config, message);
}
}
config
}
fn process_message_skip_nulls_recursive(
mut config: prost_build::Config,
message: &DescriptorProto,
) -> prost_build::Config {
let message_name = message.name();
for field in &message.field {
config = Self::add_skip_null_attribute_static(config, message_name, field);
}
for nested_message in &message.nested_type {
config = Self::process_message_skip_nulls_recursive(config, nested_message);
}
config
}
fn add_skip_null_attribute_static(
mut config: prost_build::Config,
message_name: &str,
field: &FieldDescriptorProto,
) -> prost_build::Config {
const SKIP_NONE: &str = "#[serde(skip_serializing_if = \"Option::is_none\")]";
const SKIP_EMPTY: &str = "#[serde(skip_serializing_if = \"String::is_empty\")]";
let field_path = format!("{}.{}", message_name, field.name());
let skip_attribute = if field.proto3_optional()
|| (field.label() == Label::Optional && field.r#type() == Type::Message)
{
Some(SKIP_NONE)
} else if field.r#type() == Type::String && field.label() != Label::Repeated {
Some(SKIP_EMPTY)
} else {
None
};
if let Some(attribute) = skip_attribute {
config.field_attribute(&field_path, attribute);
}
config
}
pub fn generate_enum_deserializer_code(
&self,
file_descriptor_set: &FileDescriptorSet,
) -> String {
Self::generate_enum_deserializer_code_static(file_descriptor_set)
}
fn generate_enum_deserializer_code_static(file_descriptor_set: &FileDescriptorSet) -> String {
let enum_types = Self::extract_all_enum_types_static(file_descriptor_set);
let enum_list_macro = Self::generate_enum_list_macro_static(&enum_types);
let enum_serializer_macro = Self::generate_enum_serializer_macro_static(&enum_types);
let single_deserializer = Self::generate_single_enum_deserializer_static();
let option_deserializer = Self::generate_option_enum_deserializer_static();
let repeated_deserializer = Self::generate_repeated_enum_deserializer_static();
let single_serializer = Self::generate_single_enum_serializer_static();
let option_serializer = Self::generate_option_enum_serializer_static();
let repeated_serializer = Self::generate_repeated_enum_serializer_static();
let enum_list_tokens: proc_macro2::TokenStream = enum_list_macro.parse().unwrap();
let enum_serializer_tokens: proc_macro2::TokenStream =
enum_serializer_macro.parse().unwrap();
let single_deserializer_tokens: proc_macro2::TokenStream =
single_deserializer.parse().unwrap();
let option_deserializer_tokens: proc_macro2::TokenStream =
option_deserializer.parse().unwrap();
let repeated_deserializer_tokens: proc_macro2::TokenStream =
repeated_deserializer.parse().unwrap();
let single_serializer_tokens: proc_macro2::TokenStream = single_serializer.parse().unwrap();
let option_serializer_tokens: proc_macro2::TokenStream = option_serializer.parse().unwrap();
let repeated_serializer_tokens: proc_macro2::TokenStream =
repeated_serializer.parse().unwrap();
quote! {
pub mod enum_deserializer {
use super::*;
#enum_list_tokens
#enum_serializer_tokens
#single_deserializer_tokens
#option_deserializer_tokens
#repeated_deserializer_tokens
#single_serializer_tokens
#option_serializer_tokens
#repeated_serializer_tokens
}
}
.to_string()
}
fn extract_all_enum_types_static(file_descriptor_set: &FileDescriptorSet) -> Vec<String> {
let mut enum_types = Vec::new();
for file in &file_descriptor_set.file {
for enum_desc in &file.enum_type {
let enum_name = enum_desc.name();
enum_types.push(enum_name.to_string());
}
for message in &file.message_type {
enum_types.extend(Self::extract_nested_enums_static(message, ""));
}
}
enum_types
}
fn extract_nested_enums_static(message: &DescriptorProto, module_path: &str) -> Vec<String> {
let mut enum_types = Vec::new();
let message_name = message.name();
let message_module = message_name.to_snake_case();
for enum_desc in &message.enum_type {
let enum_name = enum_desc.name();
enum_types.push(format!("{module_path}{message_module}::{enum_name}"));
}
for nested_message in &message.nested_type {
let nested_path = format!("{module_path}{message_module}::");
enum_types.extend(Self::extract_nested_enums_static(
nested_message,
&nested_path,
));
}
enum_types
}
fn generate_enum_list_macro_static(enum_types: &[String]) -> String {
let enum_idents: Vec<proc_macro2::TokenStream> = enum_types
.iter()
.map(|enum_type| {
enum_type
.parse()
.unwrap_or_else(|e| panic!("Invalid enum type path '{enum_type}': {e}"))
})
.collect();
quote! {
macro_rules! try_parse_all_enums {
($s:expr) => {
{
#(
if let Some(val) = #enum_idents::from_str_name($s) {
return Some(val as i32);
}
)*
None
}
};
}
}
.to_string()
}
fn generate_enum_serializer_macro_static(enum_types: &[String]) -> String {
let enum_idents: Vec<proc_macro2::TokenStream> = enum_types
.iter()
.map(|enum_type| {
enum_type
.parse()
.unwrap_or_else(|e| panic!("Invalid enum type path '{enum_type}': {e}"))
})
.collect();
quote! {
macro_rules! try_serialize_all_enums {
($value:expr) => {
{
#(
if let Ok(enum_val) = #enum_idents::try_from($value) {
return Some(enum_val.as_str_name());
}
)*
None
}
};
}
}
.to_string()
}
fn generate_single_enum_deserializer_static() -> String {
quote! {
#[allow(dead_code)]
pub fn deserialize_enum_from_string<'de, D>(deserializer: D) -> Result<i32, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::Deserialize;
#[derive(Deserialize)]
#[serde(untagged)]
#[allow(dead_code)]
enum EnumOrString {
String(String),
Int(i32),
}
match EnumOrString::deserialize(deserializer)? {
EnumOrString::String(s) => {
fn try_parse_enum(s: &str) -> Option<i32> {
try_parse_all_enums!(s)
}
try_parse_enum(&s).ok_or_else(|| {
serde::de::Error::custom(format!("Unknown enum value: {}", s))
})
}
EnumOrString::Int(i) => Ok(i),
}
}
}
.to_string()
}
fn generate_option_enum_deserializer_static() -> String {
quote! {
#[allow(dead_code)]
pub fn deserialize_option_enum_from_string<'de, D>(deserializer: D) -> Result<Option<i32>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::Deserialize;
#[derive(Deserialize)]
#[serde(untagged)]
#[allow(dead_code)]
enum OptionalEnumOrString {
String(String),
Int(i32),
None,
}
match Option::<OptionalEnumOrString>::deserialize(deserializer)? {
Some(OptionalEnumOrString::String(s)) => {
fn try_parse_enum(s: &str) -> Option<i32> {
try_parse_all_enums!(s)
}
try_parse_enum(&s)
.map(Some)
.ok_or_else(|| serde::de::Error::custom(format!("Unknown enum value: {}", s)))
}
Some(OptionalEnumOrString::Int(i)) => Ok(Some(i)),
Some(OptionalEnumOrString::None) | None => Ok(None),
}
}
}.to_string()
}
fn generate_repeated_enum_deserializer_static() -> String {
quote! {
#[allow(dead_code)]
pub fn deserialize_repeated_enum_from_string<'de, D>(deserializer: D) -> Result<Vec<i32>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::Deserialize;
#[derive(Deserialize)]
#[serde(untagged)]
#[allow(dead_code)]
enum EnumOrStringItem {
String(String),
Int(i32),
}
let items: Vec<EnumOrStringItem> = Vec::deserialize(deserializer)?;
let mut result = Vec::with_capacity(items.len());
for item in items {
match item {
EnumOrStringItem::String(s) => {
fn try_parse_enum(s: &str) -> Option<i32> {
try_parse_all_enums!(s)
}
if let Some(enum_val) = try_parse_enum(&s) {
result.push(enum_val);
} else {
return Err(serde::de::Error::custom(format!("Unknown enum value: {}", s)));
}
}
EnumOrStringItem::Int(i) => {
result.push(i);
}
}
}
Ok(result)
}
}.to_string()
}
fn generate_single_enum_serializer_static() -> String {
quote! {
#[allow(dead_code)]
pub fn serialize_enum_as_string<S>(value: &i32, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::Serialize;
fn try_enum_to_string(value: i32) -> Option<&'static str> {
try_serialize_all_enums!(value)
}
if let Some(enum_str) = try_enum_to_string(*value) {
enum_str.serialize(serializer)
} else {
value.serialize(serializer)
}
}
}.to_string()
}
fn generate_option_enum_serializer_static() -> String {
quote! {
#[allow(dead_code)]
pub fn serialize_option_enum_as_string<S>(value: &Option<i32>, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::Serialize;
fn try_enum_to_string(value: i32) -> Option<&'static str> {
try_serialize_all_enums!(value)
}
match value {
Some(val) => {
if let Some(enum_str) = try_enum_to_string(*val) {
Some(enum_str).serialize(serializer)
} else {
Some(*val).serialize(serializer)
}
}
None => None::<&str>.serialize(serializer),
}
}
}.to_string()
}
fn generate_repeated_enum_serializer_static() -> String {
quote! {
#[allow(dead_code)]
pub fn serialize_repeated_enum_as_string<S>(values: &[i32], serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::Serialize;
fn try_enum_to_string(value: i32) -> Option<&'static str> {
try_serialize_all_enums!(value)
}
let string_values: Vec<_> = values.iter().map(|val| {
if let Some(enum_str) = try_enum_to_string(*val) {
enum_str.to_string()
} else {
val.to_string()
}
}).collect();
string_values.serialize(serializer)
}
}.to_string()
}
}
/// Field cardinality, one variant per `field_descriptor_proto::Label` variant; it
/// selects which serde helper an enum field is wired to.
#[derive(Debug)]
enum FieldLabel {
    /// Corresponds to `Label::Optional`.
    Optional,
    /// Corresponds to `Label::Required`.
    Required,
    /// Corresponds to `Label::Repeated`.
    Repeated,
}
impl prost_build::ServiceGenerator for BridgeGenerator {
fn generate(&mut self, service: prost_build::Service, buf: &mut String) {
self.inner.generate(service.clone(), buf);
let package = &service.package;
let name = &service.proto_name;
let func_name = service.name.to_string();
let ident_func_name = quote::format_ident!("{}", func_name);
let branch_names = service
.methods
.iter()
.map(|method| format!("/{package}.{name}/{}", method.proto_name))
.collect::<Vec<_>>();
let func_names = service
.methods
.iter()
.map(|method| quote::format_ident!("{}", method.name))
.collect::<Vec<_>>();
let branch_request = service
.methods
.iter()
.map(|method| quote::format_ident!("{}", method.input_type.trim_matches('"')))
.collect::<Vec<_>>();
#[cfg(feature = "doc")]
let branch_response = service
.methods
.iter()
.map(|method| quote::format_ident!("{}", method.output_type.trim_matches('"')))
.collect::<Vec<_>>();
let snake_case_name = func_name.to_snake_case();
let service_name = quote::format_ident!("{}_handler", snake_case_name);
let server_module = quote::format_ident!("{}_server", snake_case_name);
#[cfg(feature = "doc")]
let docs = quote! {
#[doc = "Axum Router for handling the gRPC service. This router is generated with the [`prost-build`] crate. This builds a web router on top of the gRPC service."]
#[doc = ""]
#[doc = ::std::concat!("Package: `", stringify!(#package), "`")]
#[doc = ""]
#[doc = ::std::concat!("Name: `", stringify!(#name), "`")]
#[doc = ""]
#[doc = "Routes:"]
#(
#[doc = ::std::concat!("- `", stringify!(#func_names), "` `::` [`", stringify!(#branch_request), "`]` -> `[`", stringify!(#branch_response), "`]")]
)*
};
#[cfg(not(feature = "doc"))]
let docs = quote! {};
let output = quote! {
#[allow(dead_code)]
#docs
pub fn #service_name<T: #server_module::#ident_func_name>(server: T) -> ::axum::Router {
use ::axum::extract::State;
use ::axum::response::IntoResponse;
use std::sync::Arc;
let router = ::axum::Router::new();
#(
let router = router.route(#branch_names, ::axum::routing::post(|State(state): State<Arc<T>>, extension: ::http::Extensions, headers: ::http::header::HeaderMap, ::axum::Json(body): ::axum::Json<#branch_request>| async move {
let metadata_map = ::tonic::metadata::MetadataMap::from_headers(headers);
let request = ::tonic::Request::from_parts(metadata_map, extension, body);
let output = <T as #server_module::#ident_func_name>::#func_names(&state, request).await;
match output {
Ok(response) => {
let (metadata_map, body, extension) = response.into_parts();
let headers = metadata_map.into_headers();
let body = ::axum::Json(body);
(headers, extension, body).into_response()
},
Err(status) => {
let code = match status.code() {
::tonic::Code::Ok => ::http::StatusCode::OK,
::tonic::Code::InvalidArgument => ::http::StatusCode::BAD_REQUEST,
::tonic::Code::NotFound => ::http::StatusCode::NOT_FOUND,
::tonic::Code::AlreadyExists | ::tonic::Code::Aborted => ::http::StatusCode::CONFLICT,
::tonic::Code::PermissionDenied => ::http::StatusCode::FORBIDDEN,
::tonic::Code::Unauthenticated => ::http::StatusCode::UNAUTHORIZED,
::tonic::Code::ResourceExhausted => ::http::StatusCode::TOO_MANY_REQUESTS,
::tonic::Code::FailedPrecondition => ::http::StatusCode::PRECONDITION_FAILED,
::tonic::Code::Unimplemented => ::http::StatusCode::NOT_IMPLEMENTED,
::tonic::Code::Unavailable => ::http::StatusCode::SERVICE_UNAVAILABLE,
::tonic::Code::DeadlineExceeded | ::tonic::Code::Cancelled => ::http::StatusCode::REQUEST_TIMEOUT,
::tonic::Code::OutOfRange => ::http::StatusCode::RANGE_NOT_SATISFIABLE,
_ => ::http::StatusCode::INTERNAL_SERVER_ERROR,
};
let error_body = ErrorResponse {
error: ErrorDetails {
code: status.code().to_string(),
message: status.message().to_string(),
}
};
let body = ::axum::Json(error_body);
(code, body).into_response()
}
}
}));
)*
router.with_state(Arc::new(server))
}
};
buf.push_str(&output.to_string());
}
fn finalize(&mut self, buf: &mut String) {
self.inner.finalize(buf);
}
fn finalize_package(&mut self, package: &str, buf: &mut String) {
self.inner.finalize_package(package, buf);
let error_structs = quote! {
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ErrorResponse {
pub error: ErrorDetails,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ErrorDetails {
pub code: String,
pub message: String,
}
};
buf.push('\n');
buf.push_str(&error_structs.to_string());
if self.enable_string_enums {
if let Some(ref file_descriptor_set) = self.file_descriptor_set {
let enum_deserializer_code = Self::generate_package_specific_enum_deserializer_code(
file_descriptor_set,
package,
);
if !enum_deserializer_code.trim().is_empty() {
buf.push('\n');
buf.push_str(&enum_deserializer_code);
}
}
}
}
}