use std::collections::HashMap;
use proc_macro2::TokenStream;
use quote::quote;
use syn::{parse2, FnArg, ItemFn, Pat, Result, ReturnType};
pub fn api_handler_impl(attr: TokenStream, item: TokenStream) -> Result<TokenStream> {
let func: ItemFn = parse2(item.clone())?;
let func_name = &func.sig.ident;
let attr_str = attr.to_string();
let path = extract_attr_value(&attr_str, "path").unwrap_or("/".to_string());
let method = extract_attr_value(&attr_str, "method").unwrap_or("GET".to_string());
let description = extract_attr_value(&attr_str, "description")
.unwrap_or_else(|| format!("{} endpoint", func_name));
let params = extract_parameters(&func);
let response_type = extract_response_type(&func);
let custom_responses = extract_responses(&attr_str);
let schema_fn_name =
syn::Ident::new(&format!("{}_openapi_schema", func_name), func_name.span());
let schema = build_openapi_schema(
&path,
&method,
&description,
¶ms,
&response_type,
&custom_responses,
);
let expanded = quote! {
#func
fn #schema_fn_name() -> String {
#schema.to_string()
}
};
Ok(expanded)
}
struct ParamInfo {
name: String,
type_name: String,
}
fn extract_parameters(func: &ItemFn) -> Vec<ParamInfo> {
let mut params = Vec::new();
for input in &func.sig.inputs {
if let FnArg::Typed(pat_type) = input {
let param_name = if let Pat::Ident(pat_ident) = &*pat_type.pat {
pat_ident.ident.to_string()
} else {
continue;
};
let type_str = quote!(#pat_type.ty).to_string();
params.push(ParamInfo {
name: param_name,
type_name: type_str,
});
}
}
params
}
fn extract_response_type(func: &ItemFn) -> String {
match &func.sig.output {
ReturnType::Default => "()".to_string(),
ReturnType::Type(_, ty) => quote!(#ty).to_string(),
}
}
fn extract_responses(attr_str: &str) -> HashMap<String, String> {
let mut responses = HashMap::new();
if let Some(start) = attr_str.find("responses") {
if let Some(brace_start) = attr_str[start..].find('{') {
let after_brace = &attr_str[start + brace_start + 1..];
if let Some(brace_end) = after_brace.find('}') {
let responses_content = &after_brace[..brace_end];
for entry in responses_content.split(',') {
let parts: Vec<&str> = entry.split(':').collect();
if parts.len() == 2 {
let code = parts[0].trim().to_string();
let type_name = parts[1].trim().to_string();
responses.insert(code, type_name);
}
}
}
}
}
responses
}
fn build_openapi_schema(
path: &str,
method: &str,
description: &str,
params: &[ParamInfo],
response_type: &str,
custom_responses: &HashMap<String, String>,
) -> String {
let method_lower = method.to_lowercase();
let mut parameters_json = String::new();
let mut request_body_json = String::new();
let uses_query_params = method_lower == "get" || method_lower == "delete";
let path_params: Vec<&str> = path
.split('/')
.filter(|s| s.starts_with('{') && s.ends_with('}'))
.collect();
let mut param_entries = Vec::new();
for path_param in path_params {
let param_name = path_param.trim_matches(|c| c == '{' || c == '}');
param_entries.push(format!(
r#"{{
"name": "{}",
"in": "path",
"required": true,
"schema": {{
"type": "integer"
}}
}}"#,
param_name
));
}
if !params.is_empty() {
if uses_query_params {
for param in params {
let is_query_struct = param.type_name.contains("Query")
|| param.type_name.contains("Params")
|| (!param.type_name.contains("Option")
&& !param.type_name.starts_with("i")
&& !param.type_name.starts_with("u")
&& !param.type_name.starts_with("String"));
if is_query_struct {
param_entries.push(format!(
r##"{{
"name": "{}",
"in": "query",
"required": false,
"schema": {{
"$ref": "#/components/schemas/{}"
}}
}}"##,
param.name, param.type_name
));
} else {
let is_optional = param.type_name.contains("Option");
param_entries.push(format!(
r#"{{
"name": "{}",
"in": "query",
"required": {},
"schema": {{
"type": "string"
}}
}}"#,
param.name, !is_optional
));
}
}
} else {
let param = ¶ms[0]; request_body_json = format!(
r##",
"requestBody": {{
"required": true,
"content": {{
"application/json": {{
"schema": {{
"$ref": "#/components/schemas/{type_name}"
}}
}}
}}
}}"##,
type_name = param.type_name
);
}
}
if !param_entries.is_empty() {
parameters_json = format!(
r#",
"parameters": [
{}
]"#,
param_entries.join(",\n ")
);
}
let responses_json = if !custom_responses.is_empty() {
let mut response_entries = Vec::new();
for (code, type_name) in custom_responses {
response_entries.push(format!(
r##""{code}": {{
"description": "Response",
"content": {{
"application/json": {{
"schema": {{
"$ref": "#/components/schemas/{type_name}"
}}
}}
}}
}}"##,
code = code,
type_name = type_name
));
}
response_entries.join(",\n ")
} else {
format!(
r##""200": {{
"description": "Successful response",
"content": {{
"application/json": {{
"schema": {{
"$ref": "#/components/schemas/{type_name}"
}}
}}
}}
}}"##,
type_name = response_type
)
};
let mut schema_types = Vec::new();
for param in params {
if !schema_types.contains(¶m.type_name) {
schema_types.push(param.type_name.clone());
}
}
if !schema_types.contains(&response_type.to_string()) {
schema_types.push(response_type.to_string());
}
for type_name in custom_responses.values() {
if !schema_types.contains(type_name) {
schema_types.push(type_name.clone());
}
}
let components_json = if !schema_types.is_empty() {
let schema_defs: Vec<String> = schema_types
.iter()
.map(|type_name| {
let hint_fields = if type_name.contains("Query") {
r#", "page": {"type": "integer"}, "limit": {"type": "integer"}"#
} else {
""
};
format!(
r#""{type_name}": {{
"type": "object",
"properties": {{
"placeholder": {{"type": "string"}}{hint_fields}
}}
}}"#,
type_name = type_name,
hint_fields = hint_fields
)
})
.collect();
format!(
r#",
"components": {{
"schemas": {{
{}
}}
}}"#,
schema_defs.join(",\n ")
)
} else {
String::new()
};
format!(
r#"{{
"openapi": "3.1.0",
"info": {{
"title": "API",
"version": "1.0.0"
}},
"paths": {{
"{}": {{
"{}": {{
"description": "{}"{}{},
"responses": {{
{}
}}
}}
}}
}}{}
}}"#,
path,
method_lower,
description,
parameters_json,
request_body_json,
responses_json,
components_json
)
}
fn extract_attr_value(attr_str: &str, key: &str) -> Option<String> {
let key_pattern = format!("{} =", key);
if let Some(start_pos) = attr_str.find(&key_pattern) {
let after_key = &attr_str[start_pos + key_pattern.len()..];
let trimmed = after_key.trim_start();
if let Some(quote_char) = trimmed.chars().next() {
if quote_char == '"' || quote_char == '\'' {
if let Some(end_pos) = trimmed[1..].find(quote_char) {
return Some(trimmed[1..=end_pos].to_string());
}
}
}
}
None
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_extract_attr_value() {
assert_eq!(
extract_attr_value(r#"path = "/users", method = "POST""#, "path"),
Some("/users".to_string())
);
assert_eq!(
extract_attr_value(r#"path = "/users", method = "POST""#, "method"),
Some("POST".to_string())
);
assert_eq!(
extract_attr_value(r#"description = "Create user""#, "description"),
Some("Create user".to_string())
);
}
}