use proc_macro::TokenStream;
use quote::quote;
use std::collections::HashSet;
use proc_macro2::Ident;
use syn::punctuated::Punctuated;
use syn::{parse_macro_input, Error, Expr, FnArg, GenericArgument, ItemTrait, Lit, LitStr, Meta, MetaNameValue, Pat, PathArguments, ReturnType, Token, TraitItem, TraitItemFn, Type};
fn parse_feign_client_args(args: &Punctuated<Meta, Token![,]>) -> Result<(String, Option<String>, Option<String>), Error> {
let mut service_id = None;
let mut base_path = None;
let mut url = None;
for meta in args {
match meta {
Meta::NameValue(meta) => {
let value = parse_string_value(&meta.value)?;
if meta.path.is_ident("service_id") {
service_id = Some(value);
} else if meta.path.is_ident("base_path") {
base_path = Some(value);
} else if meta.path.is_ident("url") {
url = Some(value);
}
}
_ => return Err(Error::new_spanned(
meta,
"Expected format: #[feign_client(service_id = \"...\", base_path = \"...\", url = \"...\")]",
)),
}
}
service_id.ok_or_else(|| {
Error::new(proc_macro2::Span::call_site(), "Missing required 'service_id'")
}).map(|sid| (sid, base_path, url))
}
fn parse_string_value(expr: &Expr) -> Result<String, Error> {
match expr {
Expr::Lit(expr_lit) => {
if let Lit::Str(lit) = &expr_lit.lit {
Ok(lit.value())
} else {
Err(Error::new_spanned(expr, "Value must be a string literal"))
}
}
_ => Err(Error::new_spanned(expr, "Value must be a string literal")),
}
}
/// Expands `#[feign_client(...)]` applied to a trait.
///
/// Re-emits the trait unchanged and generates:
/// - a `{Trait}Impl` struct wrapping a `conreg_client::lb::LoadBalanceClient`;
/// - `new`, `with_timeout` and `Default` constructors for it;
/// - `impl Trait for {Trait}Impl`, whose items come from `generate_method_impls`.
///
/// Errors in the attribute arguments are reported as compile errors.
pub fn feign_client_impl(args: TokenStream, input: TokenStream) -> TokenStream {
    let macro_args = parse_macro_input!(args with Punctuated::<Meta, Token![,]>::parse_terminated);
    let trait_def = parse_macro_input!(input as ItemTrait);

    let (service_id, base_path, url) = match parse_feign_client_args(&macro_args) {
        Ok(parsed) => parsed,
        Err(err) => return err.to_compile_error().into(),
    };

    let trait_name = &trait_def.ident;
    // Reuse the trait's span so diagnostics about the generated struct point at the trait.
    let impl_struct_name = Ident::new(&format!("{trait_name}Impl"), trait_name.span());

    let methods: Vec<&TraitItemFn> = trait_def
        .items
        .iter()
        .filter_map(|item| match item {
            TraitItem::Fn(method) => Some(method),
            _ => None,
        })
        .collect();
    let method_impls = generate_method_impls(&methods, &service_id, &base_path, &url);

    let struct_def = quote! {
        pub struct #impl_struct_name {
            lb_client: conreg_client::lb::LoadBalanceClient,
        }
    };
    let constructors = quote! {
        impl #impl_struct_name {
            pub fn new(lb_client: conreg_client::lb::LoadBalanceClient) -> Self {
                Self { lb_client }
            }
            pub fn with_timeout(timeout: std::time::Duration) -> Self {
                Self {
                    lb_client: conreg_client::lb::LoadBalanceClient::new_with_connect_timeout(timeout),
                }
            }
        }
    };
    let default_impl = quote! {
        impl Default for #impl_struct_name {
            fn default() -> Self {
                Self::new(conreg_client::lb::LoadBalanceClient::new())
            }
        }
    };
    let trait_impl = quote! {
        impl #trait_name for #impl_struct_name {
            #method_impls
        }
    };

    quote! {
        #trait_def
        #struct_def
        #constructors
        #default_impl
        #trait_impl
    }
    .into()
}
fn generate_method_impls(
methods: &[&TraitItemFn],
service_id: &str,
base_path: &Option<String>,
url: &Option<String>,
) -> proc_macro2::TokenStream {
methods
.iter()
.filter_map(|method| {
extract_http_method_and_path(method).map(|(http_method, path)| {
generate_single_method_impl(
method,
&http_method,
&path,
service_id,
base_path,
url,
)
})
})
.collect::<proc_macro2::TokenStream>()
}
fn extract_http_method_and_path(method: &TraitItemFn) -> Option<(String, String)> {
for attr in &method.attrs {
if let Some(ident) = attr.path().get_ident()
&& let http_method @ ("get" | "post" | "put" | "delete" | "patch") = ident.to_string().as_str() {
let args_result = attr.parse_args_with(
Punctuated::<Meta, Token![,]>::parse_terminated
);
match args_result {
Ok(args) if args.is_empty() => {
return Some((http_method.to_uppercase(), String::new()));
}
Err(_) if attr.meta.require_list().is_err() => {
return Some((http_method.to_uppercase(), String::new()));
}
Ok(_) => {}
Err(_) => {}
}
if let Ok(lit_str) = attr.parse_args::<LitStr>() {
return Some((http_method.to_uppercase(), lit_str.value()));
}
return args_result.ok().map(|args| {
let mut path_value: Option<String> = None;
for meta in args {
if let Meta::NameValue(name_value) = meta
&& name_value.path.is_ident("path")
&& let Expr::Lit(expr_lit) = &name_value.value
&& let Lit::Str(lit_str) = &expr_lit.lit {
path_value = Some(lit_str.value());
break;
}
}
(http_method.to_uppercase(), path_value.unwrap_or_default())
});
}
}
None
}
/// Generates the trait-impl method for one HTTP-mapped trait method.
///
/// The generated method keeps the trait method's name, asyncness, generic parameters,
/// where-clause, typed parameters and return type, and always takes `&self`. Its body:
/// 1. builds the URL into `url`;
/// 2. sends the request into `response`;
/// 3. parses the response into `result`;
/// 4. returns `Ok(result)`, so the trait method is expected to return a `Result`.
///
/// `base_path`, when set, is prefixed verbatim to `path`.
fn generate_single_method_impl(
    method: &TraitItemFn,
    http_method: &str,
    path: &str,
    service_id: &str,
    base_path: &Option<String>,
    url: &Option<String>,
) -> proc_macro2::TokenStream {
    let signature = &method.sig;
    let method_name = &signature.ident;
    let asyncness = &signature.asyncness;
    let output = &signature.output;
    // Without these, a generic trait method would be "implemented" by a non-generic one,
    // which the compiler rejects.
    let generics = &signature.generics;
    let where_clause = &generics.where_clause;

    let full_path = match base_path {
        Some(base) => format!("{base}{path}"),
        None => path.to_string(),
    };

    let param_analysis = analyze_parameters(method, &full_path);
    let url_building = generate_url_building(&param_analysis, &full_path, service_id, url);
    let request_building = generate_request_building(http_method, &param_analysis);
    let parse_response = generate_response_parsing(output);

    // Only typed parameters are forwarded; the receiver is always re-emitted as `&self`.
    let params: Vec<_> = signature
        .inputs
        .iter()
        .filter_map(|input| match input {
            FnArg::Typed(pat_type) => Some(quote! { #pat_type }),
            FnArg::Receiver(_) => None,
        })
        .collect();

    quote! {
        #asyncness fn #method_name #generics (&self, #(#params),*) #output #where_clause {
            use reqwest::Method;
            let url = #url_building;
            let response = #request_building;
            #parse_response
            Ok(result)
        }
    }
}
/// How the parameters of one trait method map onto the parts of an HTTP request.
///
/// Filled in by `analyze_parameters`. Each `(Ident, Type)` pair is a parameter's name
/// together with its declared type.
#[derive(Debug, Default)]
struct ParamAnalysis {
    /// Parameters substituted into `{name}` placeholders of the URL path.
    path_params: Vec<(Ident, Type)>,
    /// Parameters appended to the URL as `name=value` query pairs.
    query_params: Vec<(Ident, Type)>,
    /// Form parameters; parameters detected as multipart forms are also pushed here.
    form_params: Vec<(Ident, Type)>,
    /// Parameter passed to the request builder's `.body(...)`.
    body_param: Option<(Ident, Type)>,
    /// Parameter passed by reference to the request builder's `.json(...)`.
    json_param: Option<(Ident, Type)>,
    /// `(header name, value)` pairs. The request generator treats a value that parses as an
    /// identifier as a parameter name, and sends any other value as a literal string.
    header_params: Vec<(String, String)>,
    /// Set when at least one parameter was detected as a multipart form.
    has_multipart_form: bool,
}
/// Returns the names found between `{` and `}` in `template`, in order of appearance.
///
/// Empty placeholders (`{}`) are skipped. An unclosed `{` takes the rest of the template as
/// its name. A nested `{` is kept as part of the name, so `{a{b}` yields `a{b`.
fn extract_params_from_template(template: &str) -> Vec<String> {
    let mut names = Vec::new();
    let mut remaining = template;
    while let Some(open) = remaining.find('{') {
        // `{` is one byte, so `open + 1` is always a char boundary.
        let after_open = &remaining[open + 1..];
        let (name, rest) = match after_open.find('}') {
            Some(close) => (&after_open[..close], &after_open[close + 1..]),
            None => (after_open, ""),
        };
        if !name.is_empty() {
            names.push(name.to_string());
        }
        remaining = rest;
    }
    names
}
/// Splits a header template of the form `Name=value` or `Name={param}` at its first `=`.
///
/// Returns the trimmed header name and the trimmed value. When the value is wrapped in
/// braces, the braces are removed so the result is the parameter name. The second element
/// is always `Some`. Returns `None` when the template contains no `=`.
fn parse_header_template(header: &str) -> Option<(String, Option<String>)> {
    let (raw_name, raw_value) = header.split_once('=')?;
    let value = raw_value.trim();
    let resolved = value
        .strip_prefix('{')
        .and_then(|inner| inner.strip_suffix('}'))
        .unwrap_or(value);
    Some((raw_name.trim().to_string(), Some(resolved.to_string())))
}
fn analyze_parameters(method: &TraitItemFn, path: &str) -> ParamAnalysis {
let mut analysis = ParamAnalysis::default();
let mut query_template_params = Vec::new();
let mut form_template_params = Vec::new();
let mut header_templates = Vec::new();
let mut body_template = None;
let mut json_template = None;
for attr in &method.attrs {
if let Some(ident) = attr.path().get_ident()
&& matches!(ident.to_string().as_str(), "get" | "post" | "put" | "delete" | "patch")
&& let Ok(args) = attr.parse_args_with(
Punctuated::<Meta, Token![,]>::parse_terminated
) {
for meta in args {
match meta {
Meta::NameValue(ref name_value) => {
extract_param_from_meta(
name_value,
&mut query_template_params,
&mut form_template_params,
&mut header_templates,
&mut body_template,
&mut json_template,
);
}
Meta::List(ref meta_list) => {
if meta_list.path.is_ident("headers")
&& let Ok(nested) = meta_list.parse_args_with(
Punctuated::<LitStr, Token![,]>::parse_terminated
) {
for lit_str in nested {
header_templates.push(lit_str.value());
}
}
}
_ => {}
}
}
}
}
let header_param_names: HashSet<_> = header_templates
.iter()
.filter_map(|h| parse_header_template(h).and_then(|(_, p)| p))
.collect();
for input in &method.sig.inputs {
if let FnArg::Typed(pat_type) = input
&& let Pat::Ident(pat_ident) = &*pat_type.pat {
let param_name = pat_ident.ident.to_string();
let param_type = &pat_type.ty;
if header_param_names.contains(¶m_name) {
continue;
}
let param_tuple = (pat_ident.ident.clone(), *param_type.clone());
if is_multipart_form_type(param_type) {
analysis.has_multipart_form = true;
analysis.form_params.push(param_tuple);
} else if path.contains(&format!("{{{}}}", param_name)) {
analysis.path_params.push(param_tuple);
} else if query_template_params.contains(¶m_name) {
analysis.query_params.push(param_tuple);
} else if let Some(ref body_template) = body_template && body_template.contains(¶m_name) {
analysis.body_param = Some(param_tuple);
} else if let Some(ref json_template) = json_template && json_template.contains(¶m_name) {
analysis.json_param = Some(param_tuple);
} else if form_template_params.contains(¶m_name) {
analysis.form_params.push(param_tuple);
}
}
}
for header_template in &header_templates {
if let Some((header_name, param_name)) = parse_header_template(header_template)
&& let Some(param) = param_name {
analysis.header_params.push((header_name, param));
}
}
analysis
}
fn extract_param_from_meta(
name_value: &MetaNameValue,
query_params: &mut Vec<String>,
form_params: &mut Vec<String>,
header_templates: &mut Vec<String>,
body_template: &mut Option<String>,
json_template: &mut Option<String>,
) {
if let Expr::Lit(expr_lit) = &name_value.value
&& let Lit::Str(lit_str) = &expr_lit.lit {
let template = lit_str.value();
if name_value.path.is_ident("query") {
*query_params = extract_params_from_template(&template);
} else if name_value.path.is_ident("form") {
*form_params = extract_params_from_template(&template);
} else if name_value.path.is_ident("body") {
*body_template = Some(template);
} else if name_value.path.is_ident("json") {
*json_template = Some(template);
} else if name_value.path.is_ident("headers") {
header_templates.push(template);
}
}
}
/// Whether `ty` is written as `reqwest::multipart::Form` or `multipart::Form`.
///
/// The check is syntactic. It compares the identifiers of the type path's segments,
/// ignoring any leading `::` and generic arguments, so type aliases and a bare `Form`
/// import are not recognised.
fn is_multipart_form_type(ty: &Type) -> bool {
    let Type::Path(type_path) = ty else {
        return false;
    };
    // Segment identifiers joined with a plain `::`. Comparing against token-printed text
    // such as "reqwest :: multipart :: Form" (with spaces) would never match this.
    let joined = type_path
        .path
        .segments
        .iter()
        .map(|seg| seg.ident.to_string())
        .collect::<Vec<_>>()
        .join("::");
    joined == "reqwest::multipart::Form" || joined == "multipart::Form"
}
/// Whether `ty` is a path type whose last segment is `Option`, such as `Option<T>` or
/// `std::option::Option<T>`. The check is purely syntactic.
fn is_option_type(ty: &Type) -> bool {
    let Type::Path(type_path) = ty else {
        return false;
    };
    match type_path.path.segments.last() {
        Some(segment) => segment.ident == "Option",
        None => false,
    }
}
/// Builds the expression that evaluates to the request URL (a `String`) at runtime.
///
/// The static prefix is assembled at macro-expansion time:
/// - With an explicit `url`, it and `full_path` are joined with exactly one `/`. An empty
///   `full_path` leaves `url` unchanged.
/// - Otherwise the prefix is `lb://{service_id}/` followed by `full_path` without its
///   leading `/`, so `/users` does not produce `lb://svc//users`.
///
/// Each path parameter's `{name}` placeholder is replaced with its `Display` output.
///
/// Query parameters are appended as `name=value` pairs:
/// - values are percent-encoded, keeping only RFC 3986 unreserved characters;
/// - `Option` parameters are skipped when `None`;
/// - the separator is `&` if the URL already contains `?`.
///
/// Generated locals use a `__feign_` prefix so they cannot shadow a method parameter named,
/// e.g., `url`.
fn generate_url_building(
    analysis: &ParamAnalysis,
    full_path: &str,
    service_id: &str,
    url: &Option<String>,
) -> proc_macro2::TokenStream {
    let relative_path = full_path.trim_start_matches('/');
    let base = match url {
        Some(base_url) if full_path.is_empty() => base_url.clone(),
        Some(base_url) => format!("{}/{}", base_url.trim_end_matches('/'), relative_path),
        None => format!("lb://{}/{}", service_id, relative_path),
    };

    let mut url_expr = quote! { ::std::string::String::from(#base) };
    for (param_name, _) in &analysis.path_params {
        let placeholder = format!("{{{}}}", param_name);
        url_expr = quote! {
            #url_expr.replace(#placeholder, &#param_name.to_string())
        };
    }

    if analysis.query_params.is_empty() {
        return url_expr;
    }

    let query_pushes = analysis.query_params.iter().map(|(param_name, param_type)| {
        // A raw identifier prints as `r#name`, and `#` would start a URL fragment.
        let raw_name = param_name.to_string();
        let key = raw_name.strip_prefix("r#").unwrap_or(&raw_name).to_string();
        if is_option_type(param_type) {
            quote! {
                if let ::std::option::Option::Some(__feign_value) = &#param_name {
                    __feign_query.push(format!("{}={}", #key, __feign_encode(&__feign_value.to_string())));
                }
            }
        } else {
            quote! {
                __feign_query.push(format!("{}={}", #key, __feign_encode(&#param_name.to_string())));
            }
        }
    });

    quote! {
        {
            // Percent-encodes everything except RFC 3986 unreserved characters, so values
            // containing `&`, `=`, `#` or spaces cannot corrupt the query string.
            let __feign_encode = |raw: &str| -> ::std::string::String {
                let mut encoded = ::std::string::String::with_capacity(raw.len());
                for byte in raw.bytes() {
                    if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
                        encoded.push(byte as char);
                    } else {
                        encoded.push_str(&format!("%{:02X}", byte));
                    }
                }
                encoded
            };
            let mut __feign_url = #url_expr;
            let mut __feign_query: ::std::vec::Vec<::std::string::String> = ::std::vec::Vec::new();
            #(#query_pushes)*
            if !__feign_query.is_empty() {
                let __feign_sep = if __feign_url.contains('?') { '&' } else { '?' };
                __feign_url = format!("{}{}{}", __feign_url, __feign_sep, __feign_query.join("&"));
            }
            __feign_url
        }
    }
}
fn generate_request_building(
http_method: &str,
analysis: &ParamAnalysis,
) -> proc_macro2::TokenStream {
let method_quote = match http_method {
"GET" => quote! { self.lb_client.get(&url).await },
"POST" => quote! { self.lb_client.post(&url).await },
"PUT" => quote! { self.lb_client.put(&url).await },
"DELETE" => quote! { self.lb_client.delete(&url).await },
"PATCH" => quote! { self.lb_client.patch(&url).await },
_ => quote! { self.lb_client.request(Method::GET, &url).await },
};
let header_quote = if !analysis.header_params.is_empty() {
let header_setters: Vec<_> = analysis
.header_params
.iter()
.map(|(header_name, param_name)| {
if let Ok(ident) = syn::parse_str::<Ident>(param_name) {
quote! {
.header(#header_name, #ident.to_string())
}
} else {
quote! {
.header(#header_name, #param_name)
}
}
})
.collect();
quote! {
#(#header_setters)*
}
} else {
quote! {}
};
let body_quote = if let Some((body_param_name, _)) = &analysis.body_param {
quote! {
.body(#body_param_name)
}
} else {
quote! {}
};
let json_quote = if let Some((json_param_name, _)) = &analysis.json_param {
quote! {
.json(&#json_param_name)
}
} else {
quote! {}
};
let form_quote = if let Some((form_param_name, _)) = &analysis.form_params.first() {
quote! {
.multipart(#form_param_name)
}
} else {
quote! {}
};
quote! {
#method_quote?
#header_quote
#body_quote
#json_quote
#form_quote
.send()
.await
.map_err(|e| {
crate::FeignError::RequestError(e.to_string())
})?
}
}
/// Emits the statement that turns `response` into a local named `result`.
///
/// A method with no return type yields `let result = ();`, and so does `Result<(), _>`.
/// For `Result<T, ...>`, `T`'s last path segment selects the conversion:
/// - `String` uses `response.text()`;
/// - `Bytes` uses `response.bytes()`;
/// - `StatusCode` uses `response.status()`;
/// - anything else is decoded with `response.json()`.
///
/// A return type that is not recognisably `Result<T, ...>` also falls back to JSON
/// decoding.
fn generate_response_parsing(output: &ReturnType) -> proc_macro2::TokenStream {
    let ReturnType::Type(_, declared) = output else {
        return quote! { let result = (); };
    };

    let ok_type = match declared.as_ref() {
        Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .filter(|segment| segment.ident == "Result")
            .and_then(|segment| match &segment.arguments {
                PathArguments::AngleBracketed(generic) => match generic.args.first() {
                    Some(GenericArgument::Type(inner)) => Some(inner),
                    _ => None,
                },
                _ => None,
            }),
        _ => None,
    };
    let Some(ok_type) = ok_type else {
        return generate_json_parsing();
    };

    match ok_type {
        Type::Tuple(tuple) if tuple.elems.is_empty() => quote! { let result = (); },
        Type::Path(type_path) => {
            let leaf = type_path
                .path
                .segments
                .last()
                .map(|segment| segment.ident.to_string());
            match leaf.as_deref() {
                Some("String") => quote! {
                    let result = response.text().await.map_err(|e| {
                        crate::FeignError::DeserializationError(e.to_string())
                    })?;
                },
                Some("Bytes") => quote! {
                    let result = response.bytes().await.map_err(|e| {
                        crate::FeignError::DeserializationError(e.to_string())
                    })?;
                },
                Some("StatusCode") => quote! { let result = response.status(); },
                _ => generate_json_parsing(),
            }
        }
        _ => generate_json_parsing(),
    }
}
/// Emits `let result = response.json().await...?;`, mapping decode failures to
/// `crate::FeignError::DeserializationError`.
fn generate_json_parsing() -> proc_macro2::TokenStream {
    let decode = quote! { response.json().await };
    quote! {
        let result = #decode.map_err(|e| {
            crate::FeignError::DeserializationError(e.to_string())
        })?;
    }
}