use crate::app::extract_app_meta;
use crate::context::partition_context_params;
use crate::server_attrs::{has_server_hidden, has_server_skip, validate_server_attrs};
use heck::{ToSnakeCase, ToUpperCamelCase};
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use server_less_parse::{
MethodInfo, ParamInfo, ReturnInfo, extract_methods, get_impl_name, unwrap_option_type,
unwrap_result_ok_type, unwrap_vec_type,
};
use syn::{ItemImpl, Token, parse::Parse};
/// Arguments accepted by the thrift attribute macro (`namespace = "..."`, `schema = "..."`).
#[derive(Debug, Default)]
pub(crate) struct ThriftArgs {
    /// Overrides the name emitted on the `namespace rs` line of the generated IDL.
    namespace: Option<String>,
    /// Path of a schema file to compare against; when set, `validate_schema()` and
    /// `assert_schema_matches()` are generated on the annotated type.
    schema: Option<String>,
}
/// Parses the thrift attribute arguments: a comma-separated list of `key = "string"`
/// pairs where `key` is `namespace` or `schema`.
///
/// A trailing comma is accepted. Unknown keys (reported with a did-you-mean hint),
/// repeated keys, and pairs not separated by a comma are rejected.
impl Parse for ThriftArgs {
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let mut args = ThriftArgs::default();
        while !input.is_empty() {
            let ident: syn::Ident = input.parse()?;
            input.parse::<Token![=]>()?;
            let key = ident.to_string();
            // Pick the destination field first so unknown keys fail before their value is parsed.
            let slot = match key.as_str() {
                "namespace" => &mut args.namespace,
                "schema" => &mut args.schema,
                other => {
                    const VALID: &[&str] = &["namespace", "schema"];
                    let suggestion = crate::did_you_mean(other, VALID)
                        .map(|s| format!(" — did you mean `{s}`?"))
                        .unwrap_or_default();
                    return Err(syn::Error::new(
                        ident.span(),
                        format!(
                            "unknown argument `{other}`{suggestion}. Valid arguments: namespace, schema"
                        ),
                    ));
                }
            };
            let lit: syn::LitStr = input.parse()?;
            // A silently-overwritten earlier value would hide a typo in the attribute.
            if slot.is_some() {
                return Err(syn::Error::new(
                    ident.span(),
                    format!("duplicate argument `{key}`"),
                ));
            }
            *slot = Some(lit.value());
            // Pairs must be comma-separated; the final pair may carry a trailing comma.
            if input.is_empty() {
                break;
            }
            input.parse::<Token![,]>()?;
        }
        Ok(args)
    }
}
/// Expands the thrift protocol macro over an inherent `impl` block.
///
/// Adds these associated functions to the impl's self type:
/// - `thrift_schema()`: a Thrift IDL string with a `namespace rs` line, one `service`
///   named after the type, and one `<Method>Args` struct per exposed method.
/// - `write_thrift(path)`: writes that string to `path`.
/// - `validate_schema()` / `assert_schema_matches()`: generated only when
///   `schema = "..."` was given. They compare the IDL with that file.
///
/// Methods for which `has_server_skip` or `has_server_hidden` returns true are left out.
/// The namespace is chosen in this order:
/// 1. the `namespace` argument;
/// 2. the snake_cased app name from `extract_app_meta`;
/// 3. the snake_cased type name.
///
/// The original impl block is re-emitted only when `crate::is_protocol_impl_emitter`
/// returns true for `"thrift"`.
///
/// # Errors
/// Propagates errors from `crate::reject_generic_impl`, `get_impl_name`,
/// `extract_methods` and `validate_server_attrs`.
pub(crate) fn expand_thrift(args: ThriftArgs, mut impl_block: ItemImpl) -> syn::Result<TokenStream2> {
    crate::reject_generic_impl(&impl_block)?;
    let app_meta = extract_app_meta(&mut impl_block.attrs);
    let struct_name = get_impl_name(&impl_block)?;
    let (impl_generics, _ty_generics, where_clause) = impl_block.generics.split_for_impl();
    let self_ty = &impl_block.self_ty;
    let struct_name_str = struct_name.to_string();

    let all_methods = extract_methods(&impl_block)?;
    for m in &all_methods {
        validate_server_attrs(m)?;
    }
    let methods: Vec<_> = all_methods
        .into_iter()
        .filter(|m| !has_server_skip(m) && !has_server_hidden(m))
        .collect();

    let namespace = args
        .namespace
        .or_else(|| app_meta.name.map(|n| n.to_snake_case()))
        .unwrap_or_else(|| struct_name_str.to_snake_case());
    let service_methods: Vec<String> = methods
        .iter()
        .enumerate()
        .map(|(i, m)| generate_thrift_method(m, i + 1))
        .collect();
    let structs: Vec<String> = methods.iter().flat_map(generate_thrift_structs).collect();
    let thrift_schema = format!(
        r#"namespace rs {namespace}
service {service_name} {{
{methods}
}}
{structs}
"#,
        namespace = namespace,
        service_name = struct_name_str,
        methods = service_methods.join("\n"),
        structs = structs.join("\n")
    );

    let validation_method = match args.schema.as_deref() {
        Some(schema_path) => thrift_validation_method(schema_path),
        None => quote! {},
    };
    let maybe_impl = if crate::is_protocol_impl_emitter(&impl_block, "thrift") {
        quote! { #impl_block }
    } else {
        quote! {}
    };

    // Generated code uses fully-qualified paths. User crates commonly declare
    // `type Result<T> = ...` aliases, which would break a bare `Result<(), E>`.
    Ok(quote! {
        #maybe_impl
        impl #impl_generics #self_ty #where_clause {
            /// Returns the Thrift IDL generated for this type's exposed methods.
            pub fn thrift_schema() -> &'static str {
                #thrift_schema
            }

            /// Writes the generated Thrift IDL (see `thrift_schema`) to `path`.
            ///
            /// # Errors
            /// Returns any I/O error from `std::fs::write`.
            pub fn write_thrift(
                path: impl ::std::convert::AsRef<::std::path::Path>,
            ) -> ::std::io::Result<()> {
                ::std::fs::write(path, Self::thrift_schema())
            }

            #validation_method
        }
    })
}

/// Builds `validate_schema()` and `assert_schema_matches()` for the schema file at
/// `schema_path`. `include_str!` resolves the path relative to the macro call site.
///
/// The generated check normalizes both schemas line by line. It trims each line and
/// drops blank lines and lines starting with `#` or `//`. It then reports lines that
/// appear on only one side; line order is not compared.
fn thrift_validation_method(schema_path: &str) -> TokenStream2 {
    quote! {
        /// Compares the generated Thrift IDL with the schema file named by the `schema`
        /// argument. Leading/trailing whitespace, blank lines, comment lines and line
        /// order are ignored.
        ///
        /// # Errors
        /// Returns a `SchemaValidationError` that lists expected lines missing from the
        /// generated schema and generated lines absent from the file.
        pub fn validate_schema() -> ::std::result::Result<(), ::server_less::SchemaValidationError> {
            let expected = ::std::include_str!(#schema_path);
            let generated = Self::thrift_schema();
            fn normalize(s: &str) -> ::std::vec::Vec<::std::string::String> {
                s.lines()
                    .map(|l| l.trim().to_string())
                    .filter(|l| !l.is_empty() && !l.starts_with('#') && !l.starts_with("//"))
                    .collect()
            }
            let expected_lines = normalize(expected);
            let generated_lines = normalize(generated);
            let mut error = ::server_less::SchemaValidationError::new("Thrift");
            for line in &expected_lines {
                if !generated_lines.contains(line) {
                    error.add_missing(line.clone());
                }
            }
            for line in &generated_lines {
                if !expected_lines.contains(line) {
                    error.add_extra(line.clone());
                }
            }
            if error.has_differences() {
                ::std::result::Result::Err(error)
            } else {
                ::std::result::Result::Ok(())
            }
        }

        /// Asserts that the generated Thrift IDL matches the schema file.
        ///
        /// # Panics
        /// Panics with the `Display` output of the error if `validate_schema` fails.
        pub fn assert_schema_matches() {
            if let ::std::result::Result::Err(err) = Self::validate_schema() {
                ::std::panic!("{}", err);
            }
        }
    }
}
/// Renders one function line of the Thrift `service` block:
/// `<ret> <name>(1: <Name>Args args);`, preceded by the method's docs as `//` lines.
///
/// Thrift function declarations carry no ordinal of their own; field ids belong to
/// the parameters. So the position passed as `_index` is not emitted, and the single
/// `<Name>Args` parameter always uses field id 1.
fn generate_thrift_method(method: &MethodInfo, _index: usize) -> String {
    let method_name = method.name_str().to_snake_case();
    let args_name = format!("{}Args", method.name_str().to_upper_camel_case());
    let result_type = get_thrift_return_type(&method.return_info);
    // Prefix every doc line, not just the first. Otherwise continuation lines of a
    // multi-line doc would land in the IDL as bare, unparsable text.
    let doc: String = method
        .docs
        .as_ref()
        .map(|d| d.lines().map(|line| format!(" // {}\n", line)).collect())
        .unwrap_or_default();
    format!("{doc} {result_type} {method_name}(1: {args_name} args);")
}
/// Thrift function return type for a method: `void` when the method returns unit,
/// otherwise the Thrift mapping of its declared return type.
fn get_thrift_return_type(ret: &ReturnInfo) -> String {
    match ret.is_unit {
        true => String::from("void"),
        false => rust_type_to_thrift(&ret.ty),
    }
}
/// Builds the `<Method>Args` Thrift struct that carries a method's parameters.
///
/// Only the second half of `partition_context_params` becomes fields. The first half
/// is presumably context/injected params, so verify against `crate::context`. If
/// partitioning fails, every parameter is used. Field ids are 1-based positions.
fn generate_thrift_structs(method: &MethodInfo) -> Vec<String> {
    let struct_name = format!("{}Args", method.name_str().to_upper_camel_case());
    let fallback = (None, method.params.iter().collect());
    let (_, wire_params) = partition_context_params(&method.params).unwrap_or(fallback);

    let mut field_lines = Vec::new();
    for (position, param) in wire_params.iter().enumerate() {
        field_lines.push(generate_thrift_field(param, position + 1));
    }

    let body = field_lines.join("\n");
    vec![format!("struct {} {{\n{}\n}}", struct_name, body)]
}
fn generate_thrift_field(param: &ParamInfo, index: usize) -> String {
let name = param.name_str().to_snake_case();
let ty = if let Some(inner) = unwrap_option_type(¶m.ty) {
inner.clone()
} else {
param.ty.clone()
};
let thrift_type = rust_type_to_thrift(&Some(ty));
let optional = if param.is_optional { "optional " } else { "" };
format!(" {}: {}{} {};", index, optional, thrift_type, name)
}
/// Maps an optional Rust type to Thrift; an absent type is rendered as `void`.
fn rust_type_to_thrift(ty: &Option<syn::Type>) -> String {
    match ty {
        Some(inner) => rust_type_to_thrift_ty(inner),
        None => "void".to_string(),
    }
}
/// Maps a Rust type to the closest Thrift IDL type.
///
/// Wrappers are peeled first:
/// - references, parenthesized and macro-grouped types become their inner type;
/// - `Result<T, _>` and `Option<T>` become `T`;
/// - `Box`/`Arc`/`Rc`/`Cow<T>` become `T`.
///
/// Then:
/// - `Vec<u8>` / `[u8]` map to `binary`; any other `Vec<T>` maps to `list<T>`.
/// - `HashMap`/`BTreeMap<K, V>` map to `map<K, V>` and `HashSet`/`BTreeSet<T>` to
///   `set<T>`. A type argument that is not spelled out becomes `string`.
/// - `String`, `str` and `char` map to `string`, and `bool` to `bool`.
/// - Signed ints map directly, with `i8` becoming `byte`. Thrift has no unsigned
///   ints, so each unsigned int widens to a signed Thrift int: `u64`/`usize` become
///   `i64`, which cannot hold values above `i64::MAX`.
/// - `f32`/`f64` map to `double`.
/// - Anything else (user structs, tuples, ...) maps to `binary`.
fn rust_type_to_thrift_ty(ty: &syn::Type) -> String {
    // Syntax-only wrappers carry the same data as the type inside them. Types
    // forwarded through `macro_rules!` arrive wrapped in invisible `Group`s.
    match ty {
        syn::Type::Reference(r) => return rust_type_to_thrift_ty(&r.elem),
        syn::Type::Paren(p) => return rust_type_to_thrift_ty(&p.elem),
        syn::Type::Group(g) => return rust_type_to_thrift_ty(&g.elem),
        _ => {}
    }
    if let Some(ok) = unwrap_result_ok_type(ty) {
        return rust_type_to_thrift_ty(ok);
    }
    if let Some(inner) = unwrap_option_type(ty) {
        return rust_type_to_thrift_ty(inner);
    }
    if let Some(inner) = unwrap_vec_type(ty) {
        if is_u8_type(inner) {
            return "binary".to_string();
        }
        return format!("list<{}>", rust_type_to_thrift_ty(inner));
    }
    if let syn::Type::Slice(ts) = ty
        && is_u8_type(&ts.elem)
    {
        return "binary".to_string();
    }
    let syn::Type::Path(tp) = ty else {
        // Tuples, arrays, fn pointers, trait objects, ...: no Thrift equivalent.
        return "binary".to_string();
    };
    let Some(segment) = tp.path.segments.last() else {
        return "binary".to_string();
    };

    let type_args = type_arguments(segment);
    // Key/value/element types that are not spelled out (e.g. hidden behind an alias)
    // default to `string`.
    let arg_or_string = |i: usize| {
        type_args
            .get(i)
            .map_or_else(|| "string".to_string(), |t| rust_type_to_thrift_ty(t))
    };
    let scalar = match segment.ident.to_string().as_str() {
        "HashMap" | "BTreeMap" => {
            return format!("map<{}, {}>", arg_or_string(0), arg_or_string(1));
        }
        "HashSet" | "BTreeSet" => return format!("set<{}>", arg_or_string(0)),
        // Smart pointers serialize as their pointee.
        "Box" | "Arc" | "Rc" | "Cow" => match type_args.first() {
            Some(inner) => return rust_type_to_thrift_ty(inner),
            None => "binary",
        },
        "String" | "str" | "char" => "string",
        "bool" => "bool",
        "i8" => "byte",
        "i16" | "u8" => "i16",
        "i32" | "u16" => "i32",
        "i64" | "isize" | "u32" | "u64" | "usize" => "i64",
        "f32" | "f64" => "double",
        // User-defined types and anything unrecognized.
        _ => "binary",
    };
    scalar.to_string()
}

/// Type arguments of a path segment (e.g. `K`, `V` in `HashMap<K, V>`), skipping
/// lifetime and const arguments.
fn type_arguments(segment: &syn::PathSegment) -> Vec<&syn::Type> {
    match &segment.arguments {
        syn::PathArguments::AngleBracketed(ab) => ab
            .args
            .iter()
            .filter_map(|arg| match arg {
                syn::GenericArgument::Type(t) => Some(t),
                _ => None,
            })
            .collect(),
        _ => Vec::new(),
    }
}

/// True when `ty` is a path whose last segment is `u8` (e.g. `u8`, `core::primitive::u8`).
fn is_u8_type(ty: &syn::Type) -> bool {
    matches!(
        ty,
        syn::Type::Path(tp) if tp.path.segments.last().is_some_and(|s| s.ident == "u8")
    )
}