use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{Data, DeriveInput, Fields, Meta, parse_macro_input, spanned::Spanned};
pub fn expand_model(attr: TokenStream, item: TokenStream) -> TokenStream {
let input_clone = item.clone();
let input = parse_macro_input!(item as DeriveInput);
match expand_model_impl(attr.into(), input, input_clone.into()) {
Ok(tokens) => tokens.into(),
Err(err) => err.to_compile_error().into(),
}
}
fn expand_model_impl(
_attr: TokenStream2,
input: DeriveInput,
_original_tokens: TokenStream2,
) -> syn::Result<TokenStream2> {
let struct_name = &input.ident;
let vis = &input.vis;
let table_name = get_table_name(&input)?;
let fields = match &input.data {
Data::Struct(data) => match &data.fields {
Fields::Named(fields) => &fields.named,
_ => {
return Err(syn::Error::new(
input.span(),
"Only named fields are supported",
));
}
},
_ => return Err(syn::Error::new(input.span(), "Only structs are supported")),
};
let field_tokens: Vec<TokenStream2> = fields
.iter()
.map(|field| {
let field_name = field.ident.as_ref().unwrap();
let field_type = &field.ty;
let type_str = quote!(#field_type).to_string();
let name = field_name.to_string();
let column_name = to_snake_case(&name);
quote! {
{
let rust_type = forge::forge_core::schema::RustType::from_type_string(#type_str);
let mut field = forge::forge_core::schema::FieldDef::new(#name, rust_type);
field.column_name = #column_name.to_string();
field
}
}
})
.collect();
let field_defs: Vec<TokenStream2> = fields
.iter()
.map(|field| {
let field_name = &field.ident;
let field_type = &field.ty;
let field_vis = &field.vis;
quote! { #field_vis #field_name: #field_type }
})
.collect();
let other_attrs: Vec<&syn::Attribute> = input
.attrs
.iter()
.filter(|attr| {
let path = attr.path();
!path.is_ident("derive") && path.segments.first().is_none_or(|s| s.ident != "forge")
})
.collect();
let expanded = quote! {
#(#other_attrs)*
#vis struct #struct_name {
#(#field_defs),*
}
impl forge::forge_core::schema::ModelMeta for #struct_name {
const TABLE_NAME: &'static str = #table_name;
fn table_def() -> forge::forge_core::schema::TableDef {
let mut table = forge::forge_core::schema::TableDef::new(#table_name, stringify!(#struct_name));
table.fields = vec![
#(#field_tokens),*
];
table
}
fn primary_key_field() -> &'static str {
"id"
}
}
};
Ok(expanded)
}
fn get_table_name(input: &DeriveInput) -> syn::Result<String> {
for attr in &input.attrs {
if attr.path().is_ident("table") {
let meta = attr.meta.clone();
if let Meta::List(list) = meta {
let tokens: TokenStream2 = list.tokens;
let tokens_str = tokens.to_string();
if tokens_str.starts_with("name")
&& let Some(value) = extract_string_value(&tokens_str)
{
return Ok(value);
}
}
}
}
let name = to_snake_case(&input.ident.to_string());
Ok(pluralize(&name))
}
/// Extracts the contents of the first double-quoted string that follows the first `=` in `s`.
///
/// Intended for stringified attribute tokens such as `name = "users"`. The scan stops at the
/// first unescaped closing quote, and the remainder must be empty or start with `,`. This
/// means `name = "users" , schema = "app"` yields `users`. Escape sequences are returned
/// verbatim (not unescaped).
///
/// Returns `None` when there is no `=`, the value is not quoted, the quote is unterminated,
/// or unexpected tokens follow the closing quote.
fn extract_string_value(s: &str) -> Option<String> {
    let (_, rhs) = s.split_once('=')?;
    let body = rhs.trim_start().strip_prefix('"')?;

    let mut escaped = false;
    let mut close = None;
    for (index, ch) in body.char_indices() {
        match (escaped, ch) {
            // The character after a backslash never terminates the literal.
            (true, _) => escaped = false,
            (false, '\\') => escaped = true,
            (false, '"') => {
                close = Some(index);
                break;
            }
            _ => {}
        }
    }
    let close = close?;

    let value = body.get(..close)?;
    // `"` is a single byte, so `close + 1` is always a char boundary.
    let rest = body.get(close + 1..)?.trim();
    (rest.is_empty() || rest.starts_with(',')).then(|| value.to_string())
}
/// Converts a CamelCase identifier to snake_case.
///
/// Every uppercase character after the first one starts a new word, so acronyms are split
/// letter by letter (`HTTPRequest` -> `h_t_t_p_request`). All other characters are copied
/// unchanged.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // `to_lowercase` may yield several chars (e.g. 'İ' -> "i\u{307}"); keep all of
            // them instead of only the first.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
/// Naively pluralizes an English noun for use as a table name.
///
/// - Nouns ending in `s`, `sh`, `ch`, `x` or `z` get `es`.
/// - A trailing `y` preceded by anything other than `a`, `e`, `o` or `u` becomes `ies`.
/// - Everything else gets `s`.
///
/// NOTE(review): no consonant doubling, so `quiz` -> `quizes`; changing that would rename
/// tables generated for existing models.
fn pluralize(s: &str) -> String {
    const SIBILANT_SUFFIXES: [&str; 5] = ["s", "sh", "ch", "x", "z"];
    const VOWEL_Y_SUFFIXES: [&str; 4] = ["ay", "ey", "oy", "uy"];

    if SIBILANT_SUFFIXES.iter().any(|&suffix| s.ends_with(suffix)) {
        return format!("{}es", s);
    }
    match s.strip_suffix('y') {
        Some(stem) if !VOWEL_Y_SUFFIXES.iter().any(|&suffix| s.ends_with(suffix)) => {
            format!("{}ies", stem)
        }
        _ => format!("{}s", s),
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Asserts `to_snake_case` maps each `(input, expected)` pair.
    fn check_snake(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(to_snake_case(input), expected);
        }
    }

    /// Asserts `pluralize` maps each `(singular, plural)` pair.
    fn check_plural(cases: &[(&str, &str)]) {
        for &(singular, plural) in cases {
            assert_eq!(pluralize(singular), plural);
        }
    }

    #[test]
    fn snake_case_simple() {
        check_snake(&[
            ("User", "user"),
            ("UserProfile", "user_profile"),
            // Each capital starts a new word, so acronyms split letter by letter.
            ("HTTPRequest", "h_t_t_p_request"),
        ]);
    }

    #[test]
    fn snake_case_already_lowercase() {
        check_snake(&[("user", "user"), ("item", "item")]);
    }

    #[test]
    fn pluralize_regular_nouns() {
        check_plural(&[
            ("user", "users"),
            ("item", "items"),
            ("product", "products"),
            ("order", "orders"),
            ("account", "accounts"),
        ]);
    }

    #[test]
    fn pluralize_sibilant_endings() {
        check_plural(&[
            ("address", "addresses"),
            ("crash", "crashes"),
            ("match", "matches"),
            ("box", "boxes"),
            // Pins current behavior: no consonant doubling.
            ("quiz", "quizes"),
        ]);
    }

    #[test]
    fn pluralize_consonant_y() {
        check_plural(&[
            ("category", "categories"),
            ("company", "companies"),
            ("policy", "policies"),
            ("entry", "entries"),
        ]);
    }

    #[test]
    fn pluralize_vowel_y() {
        check_plural(&[("key", "keys"), ("day", "days"), ("boy", "boys"), ("buy", "buys")]);
    }

    #[test]
    fn extract_string_value_valid() {
        let value = extract_string_value(r#"name = "custom_table""#);
        assert_eq!(value.as_deref(), Some("custom_table"));
    }

    #[test]
    fn extract_string_value_no_quotes() {
        assert!(extract_string_value("name = bare_value").is_none());
    }

    #[test]
    fn extract_string_value_no_equals() {
        assert!(extract_string_value(r#""just a string""#).is_none());
    }

    #[test]
    fn table_name_from_struct_name() {
        let cases = [
            ("User", "users"),
            ("UserProfile", "user_profiles"),
            ("Category", "categories"),
            ("Address", "addresses"),
            ("TodoItem", "todo_items"),
            ("OrderStatus", "order_statuses"),
        ];
        for (struct_name, expected_table) in cases {
            let table = pluralize(&to_snake_case(struct_name));
            assert_eq!(table, expected_table, "Failed for struct {struct_name}");
        }
    }
}