1use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::{format_ident, quote};
9use syn::{parse_macro_input, Attribute, Data, DeriveInput, Expr, Fields, Ident, Lit, Meta, Type};
10
11#[proc_macro_derive(Table, attributes(table, column))]
33pub fn derive_table(input: TokenStream) -> TokenStream {
34 let input = parse_macro_input!(input as DeriveInput);
35 derive_table_impl(input)
36 .unwrap_or_else(|e| e.to_compile_error())
37 .into()
38}
39
40fn derive_table_impl(input: DeriveInput) -> syn::Result<TokenStream2> {
41 let struct_name = &input.ident;
42 let table_name = get_table_name(&input.attrs, struct_name)?;
43
44 let fields = match &input.data {
45 Data::Struct(data) => match &data.fields {
46 Fields::Named(fields) => &fields.named,
47 _ => {
48 return Err(syn::Error::new_spanned(
49 &input,
50 "Table derive only supports structs with named fields",
51 ));
52 }
53 },
54 _ => {
55 return Err(syn::Error::new_spanned(
56 &input,
57 "Table derive only supports structs",
58 ));
59 }
60 };
61
62 let mut column_infos: Vec<ColumnInfo> = Vec::new();
64 for field in fields {
65 let field_name = field.ident.as_ref().unwrap();
66 let field_type = &field.ty;
67 let column_attrs = parse_column_attrs(&field.attrs)?;
68
69 column_infos.push(ColumnInfo {
70 field_name: field_name.clone(),
71 field_type: field_type.clone(),
72 column_name: column_attrs.name.unwrap_or_else(|| field_name.to_string()),
73 is_primary_key: column_attrs.primary_key,
74 is_nullable: column_attrs.nullable,
75 });
76 }
77
78 let column_type_names: Vec<Ident> = column_infos
80 .iter()
81 .map(|c| format_ident!("{}", to_pascal_case(&c.field_name.to_string())))
82 .collect();
83
84 let table_struct_name = format_ident!("{}Table", struct_name);
86 let columns_mod_name = format_ident!("{}Columns", struct_name);
87
88 let column_structs: Vec<TokenStream2> = column_infos
90 .iter()
91 .zip(column_type_names.iter())
92 .map(|(info, type_name)| {
93 let column_name = &info.column_name;
94 let field_type = &info.field_type;
95 let is_nullable = info.is_nullable;
96 let is_primary_key = info.is_primary_key;
97
98 quote! {
99 #[derive(Debug, Clone, Copy)]
101 pub struct #type_name;
102
103 impl ::oxide_sql_core::schema::Column for #type_name {
104 type Table = super::#table_struct_name;
105 type Type = #field_type;
106
107 const NAME: &'static str = #column_name;
108 const NULLABLE: bool = #is_nullable;
109 const PRIMARY_KEY: bool = #is_primary_key;
110 }
111
112 impl ::oxide_sql_core::schema::TypedColumn<#field_type> for #type_name {}
113 }
114 })
115 .collect();
116
117 let column_accessors: Vec<TokenStream2> = column_infos
119 .iter()
120 .zip(column_type_names.iter())
121 .map(|(info, type_name)| {
122 let method_name = &info.field_name;
123 quote! {
124 #[inline]
126 pub const fn #method_name() -> #columns_mod_name::#type_name {
127 #columns_mod_name::#type_name
128 }
129 }
130 })
131 .collect();
132
133 let all_column_names: Vec<&str> = column_infos
135 .iter()
136 .map(|c| c.column_name.as_str())
137 .collect();
138
139 let primary_key_column = column_infos
141 .iter()
142 .find(|c| c.is_primary_key)
143 .map(|c| &c.column_name);
144
145 let primary_key_impl = if let Some(pk) = primary_key_column {
146 quote! {
147 const PRIMARY_KEY: Option<&'static str> = Some(#pk);
148 }
149 } else {
150 quote! {
151 const PRIMARY_KEY: Option<&'static str> = None;
152 }
153 };
154
155 let expanded = quote! {
156 #[allow(non_snake_case)]
158 pub mod #columns_mod_name {
159 #(#column_structs)*
160 }
161
162 #[derive(Debug, Clone, Copy)]
164 pub struct #table_struct_name;
165
166 impl ::oxide_sql_core::schema::Table for #table_struct_name {
167 type Row = #struct_name;
168
169 const NAME: &'static str = #table_name;
170 const COLUMNS: &'static [&'static str] = &[#(#all_column_names),*];
171 #primary_key_impl
172 }
173
174 impl #table_struct_name {
175 #[inline]
177 pub const fn table_name() -> &'static str {
178 #table_name
179 }
180
181 #(#column_accessors)*
182 }
183
184 impl #struct_name {
185 pub fn table() -> #table_struct_name {
187 #table_struct_name
188 }
189
190 #(#column_accessors)*
191 }
192 };
193
194 Ok(expanded)
195}
196
197struct ColumnInfo {
198 field_name: Ident,
199 field_type: Type,
200 column_name: String,
201 is_primary_key: bool,
202 is_nullable: bool,
203}
204
/// Options parsed from a field's `#[column(...)]` attributes.
struct ColumnAttrs {
    /// `true` when `primary_key` was listed.
    primary_key: bool,
    /// `true` when `nullable` was listed.
    nullable: bool,
    /// Value of `name = "..."`, when given.
    name: Option<String>,
}
210
211fn get_table_name(attrs: &[Attribute], struct_name: &Ident) -> syn::Result<String> {
212 for attr in attrs {
213 if attr.path().is_ident("table") {
214 let mut table_name = None;
215 attr.parse_nested_meta(|meta| {
216 if meta.path.is_ident("name") {
217 let value: Expr = meta.value()?.parse()?;
218 if let Expr::Lit(lit) = value {
219 if let Lit::Str(s) = lit.lit {
220 table_name = Some(s.value());
221 }
222 }
223 }
224 Ok(())
225 })?;
226 if let Some(name) = table_name {
227 return Ok(name);
228 }
229 }
230 }
231 Ok(to_snake_case(&struct_name.to_string()))
233}
234
235fn parse_column_attrs(attrs: &[Attribute]) -> syn::Result<ColumnAttrs> {
236 let mut result = ColumnAttrs {
237 name: None,
238 primary_key: false,
239 nullable: false,
240 };
241
242 for attr in attrs {
243 if attr.path().is_ident("column") {
244 if matches!(attr.meta, Meta::Path(_)) {
246 continue;
247 }
248
249 attr.parse_nested_meta(|meta| {
250 if meta.path.is_ident("primary_key") {
251 result.primary_key = true;
252 } else if meta.path.is_ident("nullable") {
253 result.nullable = true;
254 } else if meta.path.is_ident("name") {
255 let value: Expr = meta.value()?.parse()?;
256 if let Expr::Lit(lit) = value {
257 if let Lit::Str(s) = lit.lit {
258 result.name = Some(s.value());
259 }
260 }
261 }
262 Ok(())
263 })?;
264 }
265 }
266
267 Ok(result)
268}
269
/// Converts a PascalCase or camelCase identifier to snake_case for default
/// table names (`UserAccount` -> `user_account`).
///
/// Every uppercase character except the first is preceded by `_`, and all
/// uppercase characters are lowered. Runs of capitals are therefore split per
/// letter (`HTTPServer` -> `h_t_t_p_server`). This is intentionally unchanged,
/// because altering it would rename existing default tables.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // `is_uppercase` is Unicode-aware, so lower with the matching Unicode
            // mapping; `to_ascii_lowercase` would leave e.g. `É` uppercase.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
284
/// Converts a snake_case identifier to PascalCase (`user_id` -> `UserId`).
///
/// Underscores are dropped. The first character of each underscore-separated
/// segment is upper-cased with ASCII rules, and every other character is copied
/// unchanged. So `fooBar` becomes `FooBar` and repeated underscores collapse
/// (`a__b` -> `AB`).
fn to_pascal_case(s: &str) -> String {
    s.split('_')
        .flat_map(|segment| {
            let mut chars = segment.chars();
            let head = chars.next().map(|c| c.to_ascii_uppercase());
            head.into_iter().chain(chars)
        })
        .collect()
}