use convert_case::{Case, Casing};
use proc_macro2::TokenStream;
use quote::quote;
use crate::utils::extract_simple_type_name;
/// Where a rendered type will appear in generated code; selects context-specific rendering
/// rules in [`unified_to_rust`] and [`field_assignment`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RenderContext {
    /// Constructor argument: required strings and maps are rendered as flexible
    /// `impl Into`/`impl IntoIterator` argument types.
    Constructor,
    /// Enums are rendered as their raw `i32` representation.
    Extractor,
    /// Rendered with the default rules of `unified_to_rust`.
    QueryExtractor,
    /// Rendered with the default rules of `unified_to_rust`.
    Parameter,
    /// Rendered with the default rules of `unified_to_rust`.
    ReturnType,
    /// Rendered with the default rules of `unified_to_rust`.
    FieldType,
    /// Builder setter: repeated types are not wrapped in `Vec`, optional types are not wrapped
    /// in `Option`, and assignments use the flexible builder conversions.
    BuilderMethod,
    /// Python-facing parameter: maps and repeated types are always wrapped in `Option`.
    PythonParameter,
    /// N-API-facing parameter: enums are rendered as `i32`, and maps and repeated types are
    /// always wrapped in `Option`.
    NapiParameter,
}
/// A back end that spells a [`UnifiedType`] in the type syntax of one target language.
pub trait TypeRenderer {
    /// Returns the target-language spelling of `unified_type`.
    fn render(&self, unified_type: &UnifiedType) -> String;
}
/// Renders Rust type syntax; the wrapped [`RenderContext`] selects the context-specific rules
/// applied by [`unified_to_rust`].
pub struct RustRenderer(pub RenderContext);
impl TypeRenderer for RustRenderer {
    fn render(&self, unified_type: &UnifiedType) -> String {
        let Self(context) = *self;
        unified_to_rust(unified_type, context)
    }
}
/// Renders Python type-hint syntax such as `Optional[List[str]]`.
pub struct PythonRenderer;
impl TypeRenderer for PythonRenderer {
    fn render(&self, unified_type: &UnifiedType) -> String {
        unified_to_python_type(unified_type)
    }
}
/// Renders the Rust types used on the N-API side of generated bindings (see [`unified_to_napi`]).
pub struct NapiRenderer;
impl TypeRenderer for NapiRenderer {
    fn render(&self, unified_type: &UnifiedType) -> String {
        unified_to_napi(unified_type)
    }
}
/// Renders TypeScript type annotations such as `string[] | undefined`.
pub struct TypeScriptRenderer;
impl TypeRenderer for TypeScriptRenderer {
    fn render(&self, unified_type: &UnifiedType) -> String {
        unified_to_typescript(unified_type)
    }
}
/// Renders `unified_type` as Rust type syntax for the given [`RenderContext`].
///
/// Context-specific rules:
/// - `Constructor` (non-optional values only): strings become `impl Into<String>` and maps
///   become `impl IntoIterator<Item = (K, V)>`. `K` and `V` are the key and value rendered in
///   the same context, so only string keys/values become `impl Into<String>`.
/// - `Extractor` / `NapiParameter`: enums are rendered as `i32`.
/// - `BuilderMethod`: repeated types are not wrapped in `Vec`, and optional types are not
///   wrapped in `Option`.
/// - `PythonParameter` / `NapiParameter`: maps and repeated types are always wrapped in
///   `Option`.
///
/// The `Vec<…>` wrapping for repeated types is applied before any `Option<…>` wrapping.
pub fn unified_to_rust(unified_type: &UnifiedType, context: RenderContext) -> String {
    // Required constructor arguments accept flexible `impl Trait` inputs.
    let flexible_constructor_arg =
        matches!(context, RenderContext::Constructor) && !unified_type.is_optional;
    let base_type_str = match &unified_type.base_type {
        BaseType::String => {
            if flexible_constructor_arg {
                "impl Into<String>".to_string()
            } else {
                "String".to_string()
            }
        }
        BaseType::Int32 => "i32".to_string(),
        BaseType::Int64 => "i64".to_string(),
        BaseType::Bool => "bool".to_string(),
        BaseType::Float64 => "f64".to_string(),
        BaseType::Float32 => "f32".to_string(),
        BaseType::Bytes => "Vec<u8>".to_string(),
        BaseType::Unit => "()".to_string(),
        BaseType::Message(name) => extract_simple_type_name(name),
        BaseType::Enum(name) => {
            if matches!(
                context,
                RenderContext::Extractor | RenderContext::NapiParameter
            ) {
                "i32".to_string()
            } else {
                convert_protobuf_enum_to_rust_type(&format!("TYPE_ENUM:{}", name))
            }
        }
        BaseType::OneOf(name) => extract_simple_type_name(name),
        BaseType::Map(key_type, value_type) => {
            // Render each side on its own so e.g. `map<string, int32>` becomes
            // `(impl Into<String>, i32)` rather than forcing both sides to strings, which
            // `field_assignment`'s `v.into()` could never convert into a non-string value.
            let key_str = unified_to_rust(key_type, context);
            let value_str = unified_to_rust(value_type, context);
            if flexible_constructor_arg {
                format!("impl IntoIterator<Item = ({}, {})>", key_str, value_str)
            } else {
                format!("HashMap<{}, {}>", key_str, value_str)
            }
        }
    };
    let mut result = base_type_str;
    // Builder setters take repeated values without the `Vec` wrapper.
    if unified_type.is_repeated && !matches!(context, RenderContext::BuilderMethod) {
        result = format!("Vec<{}>", result);
    }
    if should_wrap_in_option(context, unified_type) {
        result = format!("Option<{}>", result);
    }
    result
}
/// Decides whether `unified_to_rust` wraps the rendered type in `Option<…>`.
///
/// Builder methods never wrap. Python and N-API parameters wrap optional values and also every
/// map and repeated type. All other contexts wrap exactly when `ty.is_optional` is set.
fn should_wrap_in_option(ctx: RenderContext, ty: &UnifiedType) -> bool {
    match ctx {
        RenderContext::BuilderMethod => false,
        RenderContext::PythonParameter | RenderContext::NapiParameter => {
            ty.is_optional || ty.is_repeated || matches!(ty.base_type, BaseType::Map(_, _))
        }
        RenderContext::Constructor
        | RenderContext::Extractor
        | RenderContext::QueryExtractor
        | RenderContext::Parameter
        | RenderContext::ReturnType
        | RenderContext::FieldType => ty.is_optional,
    }
}
/// Builds the expression that converts the parameter `field_ident` into the value stored in
/// the generated struct field.
///
/// `BuilderMethod` delegates to `flexible_optional_field_assignment`. Otherwise:
/// - required strings are converted with `.into()`. For a repeated `Constructor` argument,
///   which is rendered as `Vec<impl Into<String>>`, the conversion is applied element-wise.
/// - enums are cast to `i32`: through `Option::map` when optional, element-wise when repeated.
/// - maps are rebuilt with `.into()` applied to every key and value.
/// - anything else is passed through unchanged.
pub fn field_assignment(
    unified_type: &UnifiedType,
    field_ident: &proc_macro2::Ident,
    ctx: &RenderContext,
) -> TokenStream {
    if matches!(ctx, RenderContext::BuilderMethod) {
        return flexible_optional_field_assignment(unified_type, field_ident);
    }
    match &unified_type.base_type {
        // `Vec<impl Into<String>>` has no `Into<Vec<String>>` impl, so a plain `.into()` does
        // not compile here; convert each element instead.
        BaseType::String
            if !unified_type.is_optional
                && unified_type.is_repeated
                && matches!(ctx, RenderContext::Constructor) =>
        {
            quote! { #field_ident.into_iter().map(|v| v.into()).collect() }
        }
        BaseType::String if !unified_type.is_optional => quote! { #field_ident.into() },
        // Casting an `Option<_>` with `as` is a compile error, so optional enums are mapped.
        BaseType::Enum(_) => match (unified_type.is_optional, unified_type.is_repeated) {
            (false, false) => quote! { #field_ident as i32 },
            (false, true) => quote! { #field_ident.into_iter().map(|v| v as i32).collect() },
            (true, false) => quote! { #field_ident.map(|v| v as i32) },
            (true, true) => quote! {
                #field_ident.map(|vs| vs.into_iter().map(|v| v as i32).collect())
            },
        },
        BaseType::Map(_, _) => quote! {
            #field_ident.into_iter().map(|(k, v)| (k.into(), v.into())).collect()
        },
        _ => quote! { #field_ident },
    }
}
/// Renders `unified_type` as a Python type hint.
///
/// Maps become `Dict[K, V]`, repeated types are wrapped in `List[…]`, and optional types are
/// then wrapped in `Optional[…]`. Message, enum and oneof types use their simple name.
pub fn unified_to_python_type(unified_type: &UnifiedType) -> String {
    let element = match &unified_type.base_type {
        BaseType::Map(key, value) => format!(
            "Dict[{}, {}]",
            unified_to_python_type(key),
            unified_to_python_type(value)
        ),
        BaseType::Message(name) | BaseType::Enum(name) | BaseType::OneOf(name) => {
            extract_simple_type_name(name)
        }
        BaseType::String => String::from("str"),
        BaseType::Int32 | BaseType::Int64 => String::from("int"),
        BaseType::Float32 | BaseType::Float64 => String::from("float"),
        BaseType::Bool => String::from("bool"),
        BaseType::Bytes => String::from("bytes"),
        BaseType::Unit => String::from("None"),
    };
    let collection = if unified_type.is_repeated {
        format!("List[{}]", element)
    } else {
        element
    };
    if unified_type.is_optional {
        format!("Optional[{}]", collection)
    } else {
        collection
    }
}
/// Renders `unified_type` as the Rust type used on the N-API side of generated bindings.
///
/// Repeated types are wrapped in `Vec<…>`. The result is then wrapped in `Option<…>` when the
/// type is optional, repeated, or a map, so maps and lists are always nullable here.
pub fn unified_to_napi(unified_type: &UnifiedType) -> String {
    let is_map = matches!(unified_type.base_type, BaseType::Map(_, _));
    let inner = match &unified_type.base_type {
        BaseType::Map(key, value) => {
            format!("HashMap<{}, {}>", unified_to_napi(key), unified_to_napi(value))
        }
        BaseType::Enum(name) => {
            convert_protobuf_enum_to_rust_type(&format!("TYPE_ENUM:{}", name))
        }
        BaseType::Message(name) | BaseType::OneOf(name) => extract_simple_type_name(name),
        BaseType::String => String::from("String"),
        BaseType::Int32 => String::from("i32"),
        BaseType::Int64 => String::from("i64"),
        BaseType::Bool => String::from("bool"),
        BaseType::Float64 => String::from("f64"),
        BaseType::Float32 => String::from("f32"),
        BaseType::Bytes => String::from("Vec<u8>"),
        BaseType::Unit => String::from("()"),
    };
    let sequenced = if unified_type.is_repeated {
        format!("Vec<{}>", inner)
    } else {
        inner
    };
    let nullable = unified_type.is_optional || unified_type.is_repeated || is_map;
    if nullable {
        format!("Option<{}>", sequenced)
    } else {
        sequenced
    }
}
/// Renders `unified_type` as a TypeScript type annotation.
///
/// All integer, floating-point and enum types become `number`, bytes become `Uint8Array`, and
/// maps become `Record<K, V>`. Repeated types get a `[]` suffix, and optional types then get a
/// ` | undefined` suffix.
pub fn unified_to_typescript(unified_type: &UnifiedType) -> String {
    let mut rendered: String = match &unified_type.base_type {
        BaseType::Map(key_type, value_type) => format!(
            "Record<{}, {}>",
            unified_to_typescript(key_type),
            unified_to_typescript(value_type)
        ),
        BaseType::Message(name) | BaseType::OneOf(name) => extract_simple_type_name(name),
        BaseType::Int32
        | BaseType::Int64
        | BaseType::Float32
        | BaseType::Float64
        | BaseType::Enum(_) => "number".into(),
        BaseType::String => "string".into(),
        BaseType::Bool => "boolean".into(),
        BaseType::Bytes => "Uint8Array".into(),
        BaseType::Unit => "void".into(),
    };
    if unified_type.is_repeated {
        rendered.push_str("[]");
    }
    if unified_type.is_optional {
        rendered.push_str(" | undefined");
    }
    rendered
}
/// Builds the field-assignment expression used by builder methods.
///
/// Optional values are converted with `.into()`; optional enums are then mapped to `i32`.
/// Required strings and numeric/bool scalars are converted with `.into()`, required enums are
/// cast to `i32`, and every other type is passed through unchanged.
fn flexible_optional_field_assignment(
    unified_type: &UnifiedType,
    field_ident: &proc_macro2::Ident,
) -> TokenStream {
    let is_enum = matches!(unified_type.base_type, BaseType::Enum(_));
    if unified_type.is_optional {
        return if is_enum {
            quote! { #field_ident.into().map(|e| e as i32) }
        } else {
            quote! { #field_ident.into() }
        };
    }
    match &unified_type.base_type {
        BaseType::Enum(_) => quote! { #field_ident as i32 },
        BaseType::String
        | BaseType::Int32
        | BaseType::Int64
        | BaseType::Bool
        | BaseType::Float64
        | BaseType::Float32 => quote! { #field_ident.into() },
        BaseType::Bytes
        | BaseType::Unit
        | BaseType::Message(_)
        | BaseType::OneOf(_)
        | BaseType::Map(_, _) => quote! { #field_ident },
    }
}
/// Maps a `TYPE_ENUM:`-prefixed protobuf enum name to the Rust type path used for it.
///
/// Leading dots are ignored. When the path segment just before the enum name starts with an
/// uppercase letter and is not `V1`, that segment is treated as an enclosing message scope and
/// the result is `snake_case_scope::EnumName`. Otherwise only the enum's own name is returned.
/// In both cases the enum name goes through `convert_enum_name_to_rust`.
/// Input without the `TYPE_ENUM:` prefix yields `"i32"`.
fn convert_protobuf_enum_to_rust_type(proto_type: &str) -> String {
    let Some(qualified) = proto_type.strip_prefix("TYPE_ENUM:") else {
        return "i32".to_string();
    };
    let qualified = qualified.trim_start_matches('.');
    let Some((parent_path, simple_name)) = qualified.rsplit_once('.') else {
        return convert_enum_name_to_rust(qualified);
    };
    // `rsplit` always yields at least one segment, so the fallback is never taken.
    let enclosing = parent_path.rsplit('.').next().unwrap_or(parent_path);
    let starts_upper = enclosing.chars().next().is_some_and(char::is_uppercase);
    let looks_like_package = enclosing
        .chars()
        .all(|c| c.is_lowercase() || c.is_numeric());
    let is_message_scope = starts_upper && enclosing != "V1" && !looks_like_package;
    if is_message_scope {
        format!(
            "{}::{}",
            enclosing.to_case(Case::Snake),
            convert_enum_name_to_rust(simple_name)
        )
    } else {
        convert_enum_name_to_rust(simple_name)
    }
}
/// Converts an all-uppercase, underscore-separated enum name (e.g. `MY_ENUM`) to PascalCase.
/// Any other name is returned unchanged.
fn convert_enum_name_to_rust(enum_name: &str) -> String {
    let is_upper_with_underscores = enum_name.contains('_')
        && enum_name.chars().all(|c| c == '_' || c.is_uppercase());
    if is_upper_with_underscores {
        enum_name.to_case(Case::Pascal)
    } else {
        enum_name.to_owned()
    }
}
/// A language-neutral description of a field or parameter type. It is turned into concrete
/// syntax by the `unified_to_*` functions and the [`TypeRenderer`] implementations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnifiedType {
    /// The element type, before any repeated/optional wrapping is applied.
    pub base_type: BaseType,
    /// Whether the value may be absent (e.g. rendered as `Option<…>` / `Optional[…]`).
    pub is_optional: bool,
    /// Whether the value is a list of `base_type` (e.g. rendered as `Vec<…>` / `List[…]`).
    pub is_repeated: bool,
}
/// The underlying kind of a [`UnifiedType`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BaseType {
    String,
    Int32,
    Int64,
    Bool,
    Float64,
    Float32,
    Bytes,
    /// A message type. The name may be a dot-qualified path; renderers use its last segment.
    Message(String),
    /// An enum type. The name may be a dot-qualified path.
    Enum(String),
    /// A oneof type. The name may be a dot-qualified path; renderers use its last segment.
    OneOf(String),
    /// A map from the boxed key type to the boxed value type.
    Map(Box<UnifiedType>, Box<UnifiedType>),
    /// The unit type `()`.
    Unit,
}
impl UnifiedType {
    /// Returns a bare identifier naming this type's base type; `is_optional` and
    /// `is_repeated` are ignored.
    ///
    /// `Message`, `Enum` and `OneOf` yield the last `.`-separated segment of their stored name,
    /// used verbatim. Maps yield `HashMap` (key and value types are dropped), and bytes yield
    /// `Bytes`.
    ///
    /// # Panics
    ///
    /// Panics for [`BaseType::Unit`], because `()` has no identifier form. `format_ident!` also
    /// panics if a `Message`/`Enum`/`OneOf` name's last segment is not a valid identifier.
    pub fn type_ident(&self) -> syn::Ident {
        let name = match &self.base_type {
            BaseType::Message(n) | BaseType::Enum(n) | BaseType::OneOf(n) => {
                n.rsplit('.').next().unwrap_or(n)
            }
            BaseType::String => "String",
            BaseType::Int32 => "i32",
            BaseType::Int64 => "i64",
            BaseType::Bool => "bool",
            BaseType::Float64 => "f64",
            BaseType::Float32 => "f32",
            BaseType::Bytes => "Bytes",
            BaseType::Map(_, _) => "HashMap",
            // Previously this reached `format_ident!("()")`, which always panics with an opaque
            // "not a valid Ident" message; fail explicitly with a descriptive one instead.
            BaseType::Unit => {
                panic!("UnifiedType::type_ident: the unit type `()` has no identifier form")
            }
        };
        quote::format_ident!("{}", name)
    }
    /// A required, non-repeated `String` type.
    pub fn string() -> Self {
        Self {
            base_type: BaseType::String,
            is_optional: false,
            is_repeated: false,
        }
    }
    /// Returns `self` marked as optional.
    pub fn optional(mut self) -> Self {
        self.is_optional = true;
        self
    }
    /// Returns `self` marked as repeated.
    pub fn repeated(mut self) -> Self {
        self.is_repeated = true;
        self
    }
    /// A required, non-repeated map type from `key` to `value`.
    pub fn map(key: UnifiedType, value: UnifiedType) -> Self {
        Self {
            base_type: BaseType::Map(Box::new(key), Box::new(value)),
            is_optional: false,
            is_repeated: false,
        }
    }
}