Skip to main content

oxide_sql_derive/
lib.rs

1//! Derive macros for type-safe SQL table definitions.
2//!
3//! This crate provides the `#[derive(Table)]` macro for defining database tables
4//! with compile-time checked column names.
5
6use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::{format_ident, quote};
9use syn::{parse_macro_input, Attribute, Data, DeriveInput, Expr, Fields, Ident, Lit, Meta, Type};
10
11/// Derives the `Table` trait for a struct, generating type-safe column accessors.
12///
13/// # Attributes
14///
15/// - `#[table(name = "table_name")]` - Specifies the SQL table name (optional,
16///   defaults to snake_case of struct name)
17///
18/// # Field Attributes
19///
20/// - `#[column(primary_key)]` - Marks the field as primary key
21/// - `#[column(name = "column_name")]` - Specifies the SQL column name (optional,
22///   defaults to field name)
23/// - `#[column(nullable)]` - Marks the column as nullable
24///
25/// # Generated Items
26///
27/// For a struct `User`, this macro generates:
28///
29/// - `UserTable` - A type implementing `Table` trait with table metadata
30/// - `UserColumns` - A module containing column types (`Id`, `Name`, etc.)
31/// - Column accessor methods on `UserTable`
32#[proc_macro_derive(Table, attributes(table, column))]
33pub fn derive_table(input: TokenStream) -> TokenStream {
34    let input = parse_macro_input!(input as DeriveInput);
35    derive_table_impl(input)
36        .unwrap_or_else(|e| e.to_compile_error())
37        .into()
38}
39
40fn derive_table_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
41    let struct_name = &input.ident;
42    let table_name = get_table_name(&input.attrs, struct_name)?;
43
44    let fields = match &input.data {
45        Data::Struct(data) => match &data.fields {
46            Fields::Named(fields) => &fields.named,
47            _ => {
48                return Err(syn::Error::new_spanned(
49                    &input,
50                    "Table derive only supports structs with named fields",
51                ));
52            }
53        },
54        _ => {
55            return Err(syn::Error::new_spanned(
56                &input,
57                "Table derive only supports structs",
58            ));
59        }
60    };
61
62    // Collect field information
63    let mut column_infos: Vec<ColumnInfo> = Vec::new();
64    for field in fields {
65        let field_name = field.ident.as_ref().unwrap();
66        let field_type = &field.ty;
67        let column_attrs = parse_column_attrs(&field.attrs)?;
68
69        column_infos.push(ColumnInfo {
70            field_name: field_name.clone(),
71            field_type: field_type.clone(),
72            column_name: column_attrs.name.unwrap_or_else(|| field_name.to_string()),
73            is_primary_key: column_attrs.primary_key,
74            is_nullable: column_attrs.nullable,
75        });
76    }
77
78    // Generate column type names (PascalCase)
79    let column_type_names: Vec<Ident> = column_infos
80        .iter()
81        .map(|c| format_ident!("{}", to_pascal_case(&c.field_name.to_string())))
82        .collect();
83
84    // Generate the table struct name
85    let table_struct_name = format_ident!("{}Table", struct_name);
86    let columns_mod_name = format_ident!("{}Columns", struct_name);
87
88    // Generate column structs
89    let column_structs: Vec<TokenStream2> = column_infos
90        .iter()
91        .zip(column_type_names.iter())
92        .map(|(info, type_name)| {
93            let column_name = &info.column_name;
94            let field_type = &info.field_type;
95            let is_nullable = info.is_nullable;
96            let is_primary_key = info.is_primary_key;
97
98            quote! {
99                /// Column type for compile-time checked queries.
100                #[derive(Debug, Clone, Copy)]
101                pub struct #type_name;
102
103                impl ::oxide_sql_core::schema::Column for #type_name {
104                    type Table = super::#table_struct_name;
105                    type Type = #field_type;
106
107                    const NAME: &'static str = #column_name;
108                    const NULLABLE: bool = #is_nullable;
109                    const PRIMARY_KEY: bool = #is_primary_key;
110                }
111
112                impl ::oxide_sql_core::schema::TypedColumn<#field_type> for #type_name {}
113            }
114        })
115        .collect();
116
117    // Generate column accessor methods
118    let column_accessors: Vec<TokenStream2> = column_infos
119        .iter()
120        .zip(column_type_names.iter())
121        .map(|(info, type_name)| {
122            let method_name = &info.field_name;
123            quote! {
124                /// Returns the column type for type-safe queries.
125                #[inline]
126                pub const fn #method_name() -> #columns_mod_name::#type_name {
127                    #columns_mod_name::#type_name
128                }
129            }
130        })
131        .collect();
132
133    // Generate list of all column names
134    let all_column_names: Vec<&str> = column_infos
135        .iter()
136        .map(|c| c.column_name.as_str())
137        .collect();
138
139    // Find primary key column
140    let primary_key_column = column_infos
141        .iter()
142        .find(|c| c.is_primary_key)
143        .map(|c| &c.column_name);
144
145    let primary_key_impl = if let Some(pk) = primary_key_column {
146        quote! {
147            const PRIMARY_KEY: Option<&'static str> = Some(#pk);
148        }
149    } else {
150        quote! {
151            const PRIMARY_KEY: Option<&'static str> = None;
152        }
153    };
154
155    let expanded = quote! {
156        /// Column types for `#struct_name` table.
157        #[allow(non_snake_case)]
158        pub mod #columns_mod_name {
159            #(#column_structs)*
160        }
161
162        /// Table metadata for `#struct_name`.
163        #[derive(Debug, Clone, Copy)]
164        pub struct #table_struct_name;
165
166        impl ::oxide_sql_core::schema::Table for #table_struct_name {
167            type Row = #struct_name;
168
169            const NAME: &'static str = #table_name;
170            const COLUMNS: &'static [&'static str] = &[#(#all_column_names),*];
171            #primary_key_impl
172        }
173
174        impl #table_struct_name {
175            /// Returns the table name.
176            #[inline]
177            pub const fn table_name() -> &'static str {
178                #table_name
179            }
180
181            #(#column_accessors)*
182        }
183
184        impl #struct_name {
185            /// Returns the table metadata type.
186            pub fn table() -> #table_struct_name {
187                #table_struct_name
188            }
189
190            #(#column_accessors)*
191        }
192    };
193
194    Ok(expanded)
195}
196
197struct ColumnInfo {
198    field_name: Ident,
199    field_type: Type,
200    column_name: String,
201    is_primary_key: bool,
202    is_nullable: bool,
203}
204
/// Options parsed out of a field's `#[column(...)]` attributes.
struct ColumnAttrs {
    /// `true` once a `primary_key` option has been seen.
    primary_key: bool,
    /// `true` once a `nullable` option has been seen.
    nullable: bool,
    /// Explicit SQL column name from `name = "..."`, if one was given.
    name: Option<String>,
}
210
211fn get_table_name(attrs: &[Attribute], struct_name: &Ident) -> syn::Result<String> {
212    for attr in attrs {
213        if attr.path().is_ident("table") {
214            let mut table_name = None;
215            attr.parse_nested_meta(|meta| {
216                if meta.path.is_ident("name") {
217                    let value: Expr = meta.value()?.parse()?;
218                    if let Expr::Lit(lit) = value {
219                        if let Lit::Str(s) = lit.lit {
220                            table_name = Some(s.value());
221                        }
222                    }
223                }
224                Ok(())
225            })?;
226            if let Some(name) = table_name {
227                return Ok(name);
228            }
229        }
230    }
231    // Default to snake_case of struct name
232    Ok(to_snake_case(&struct_name.to_string()))
233}
234
235fn parse_column_attrs(attrs: &[Attribute]) -> syn::Result<ColumnAttrs> {
236    let mut result = ColumnAttrs {
237        name: None,
238        primary_key: false,
239        nullable: false,
240    };
241
242    for attr in attrs {
243        if attr.path().is_ident("column") {
244            // Handle empty attribute like #[column]
245            if matches!(attr.meta, Meta::Path(_)) {
246                continue;
247            }
248
249            attr.parse_nested_meta(|meta| {
250                if meta.path.is_ident("primary_key") {
251                    result.primary_key = true;
252                } else if meta.path.is_ident("nullable") {
253                    result.nullable = true;
254                } else if meta.path.is_ident("name") {
255                    let value: Expr = meta.value()?.parse()?;
256                    if let Expr::Lit(lit) = value {
257                        if let Lit::Str(s) = lit.lit {
258                            result.name = Some(s.value());
259                        }
260                    }
261                }
262                Ok(())
263            })?;
264        }
265    }
266
267    Ok(result)
268}
269
/// Converts a PascalCase/camelCase identifier into snake_case (`UserProfile` → `user_profile`).
///
/// A `_` is inserted before an uppercase character when it starts a new word. That happens
/// after a lowercase letter or digit (`User2Profile` → `user2_profile`). It also happens at
/// the end of a run of capitals that is followed by a lowercase letter
/// (`HTTPServer` → `http_server`).
///
/// Runs of capitals therefore stay together (`UserID` → `user_id`) rather than being split
/// letter by letter. No `_` is added next to an existing one. Lowercasing uses full Unicode
/// case mapping, so `Été` becomes `été`.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + 4);
    for (i, &c) in chars.iter().enumerate() {
        if c.is_uppercase() {
            let prev = i.checked_sub(1).and_then(|j| chars.get(j)).copied();
            let next = chars.get(i + 1).copied();
            let starts_word = match prev {
                None | Some('_') => false,
                // Inside a run of capitals only the last one (followed by a lowercase
                // letter) begins a new word: the `S` in `HTTPServer`.
                Some(p) if p.is_uppercase() => matches!(next, Some(n) if n.is_lowercase()),
                Some(_) => true,
            };
            if starts_word {
                result.push('_');
            }
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
284
/// Converts a snake_case identifier into PascalCase (`user_id` → `UserId`).
///
/// Underscores are dropped. The first character, and each character following an
/// underscore, is uppercased using full Unicode case mapping, so `état` becomes `État`.
/// All other characters are copied unchanged.
///
/// The result is not guaranteed to be a valid identifier. It is empty for `__` and starts
/// with a digit for `_1`, so callers that need an identifier must validate it.
fn to_pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut capitalize_next = true;
    for c in s.chars() {
        if c == '_' {
            capitalize_next = true;
        } else if capitalize_next {
            // `to_uppercase` may yield several chars (e.g. `ß` → `SS`).
            result.extend(c.to_uppercase());
            capitalize_next = false;
        } else {
            result.push(c);
        }
    }
    result
}