1mod table_derive;
49
50use proc_macro::TokenStream;
51use proc_macro2::TokenStream as TokenStream2;
52use quote::quote;
53use syn::{
54 parse_macro_input, spanned::Spanned, Data, DataStruct, DeriveInput, Field, Fields,
55 GenericArgument, LitInt, LitStr, PathArguments, Type, TypePath,
56};
57
/// Where a derived `FromRow` field reads its value from.
enum FieldSource {
    /// Read the value by column name (generated as `row.get` / `row.get_opt`).
    Name(String),
    /// Read the value by position (generated as `row.position` / `row.position_opt`).
    Index(usize),
}
65
66#[proc_macro_derive(Table, attributes(hyperdb))]
84pub fn table_derive(input: TokenStream) -> TokenStream {
85 let input = parse_macro_input!(input as DeriveInput);
86 match table_derive::expand(&input) {
87 Ok(ts) => ts.into(),
88 Err(e) => e.to_compile_error().into(),
89 }
90}
91
92#[proc_macro]
117pub fn query_as(input: TokenStream) -> TokenStream {
118 match expand_query_as(&input.into()) {
119 Ok(ts) => ts.into(),
120 Err(e) => e.to_compile_error().into(),
121 }
122}
123
124fn expand_query_as(input: &TokenStream2) -> syn::Result<TokenStream2> {
125 use syn::{parse::Parser, punctuated::Punctuated, Expr, Token};
126
127 let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
129 let args = parser.parse2(input.clone())?;
130 let mut iter = args.iter();
131
132 let ty_expr = iter.next().ok_or_else(|| {
133 syn::Error::new_spanned(
134 input,
135 "query_as! expects at least two arguments: query_as!(Type, \"SQL\")",
136 )
137 })?;
138
139 let ty: Type = syn::parse2(quote!(#ty_expr))?;
141
142 let sql_expr = iter.next().ok_or_else(|| {
143 syn::Error::new_spanned(
144 ty_expr,
145 "query_as! expects a SQL string literal as the second argument",
146 )
147 })?;
148
149 let rest: Vec<&Expr> = iter.collect();
151
152 #[cfg(feature = "compile-time")]
157 {
158 let struct_name = last_type_ident(&ty).map(ToString::to_string);
159 let sql_lit: Option<LitStr> = syn::parse2(quote!(#sql_expr)).ok();
160 if let (Some(struct_name), Some(sql_lit)) = (struct_name, sql_lit) {
161 let sql_str = sql_lit.value();
162 if let Err(e) = hyperdb_compile_check::validate_query_as(&struct_name, &sql_str) {
163 let msg = e.to_diagnostic();
164 return Ok(quote! {
165 ::std::compile_error!(#msg)
166 });
167 }
168 }
169 }
170
171 Ok(quote! {
172 ::hyperdb_api::QueryAs::<#ty>::new(#sql_expr, &[#(&#rest),*])
173 })
174}
175
/// Returns the identifier of the final path segment of `ty`.
///
/// Yields `None` when `ty` is not a path type, when it is a qualified path
/// (`<T as Trait>::Assoc`), or when the path has no segments.
#[cfg(feature = "compile-time")]
fn last_type_ident(ty: &Type) -> Option<&syn::Ident> {
    match ty {
        Type::Path(syn::TypePath { qself: None, path }) => {
            path.segments.last().map(|segment| &segment.ident)
        }
        _ => None,
    }
}
185
186#[proc_macro]
197pub fn query_scalar(input: TokenStream) -> TokenStream {
198 match expand_query_scalar(&input.into()) {
199 Ok(ts) => ts.into(),
200 Err(e) => e.to_compile_error().into(),
201 }
202}
203
204fn expand_query_scalar(input: &TokenStream2) -> syn::Result<TokenStream2> {
205 use syn::{parse::Parser, punctuated::Punctuated, Expr, Token};
206
207 let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
208 let args = parser.parse2(input.clone())?;
209 let mut iter = args.iter();
210
211 let ty_expr = iter.next().ok_or_else(|| {
212 syn::Error::new_spanned(
213 input,
214 "query_scalar! expects at least two arguments: query_scalar!(Type, \"SQL\")",
215 )
216 })?;
217
218 let ty: Type = syn::parse2(quote!(#ty_expr))?;
219
220 let sql_expr = iter.next().ok_or_else(|| {
221 syn::Error::new_spanned(
222 ty_expr,
223 "query_scalar! expects a SQL string literal as the second argument",
224 )
225 })?;
226
227 let rest: Vec<&Expr> = iter.collect();
228
229 #[cfg(feature = "compile-time")]
231 {
232 let sql_lit: Option<LitStr> = syn::parse2(quote!(#sql_expr)).ok();
233 if let Some(sql_lit) = sql_lit {
234 let sql_str = sql_lit.value();
235 match hyperdb_compile_check::validate_scalar_sql(&sql_str) {
239 Ok(()) => {}
240 Err(e) => {
241 let msg = e.to_diagnostic();
242 return Ok(quote! { ::std::compile_error!(#msg) });
243 }
244 }
245 }
246 }
247
248 Ok(quote! {
249 ::hyperdb_api::QueryScalar::<#ty>::new(#sql_expr, &[#(&#rest),*])
250 })
251}
252
253#[proc_macro_derive(FromRow, attributes(hyperdb))]
257pub fn from_row_derive(input: TokenStream) -> TokenStream {
258 let input = parse_macro_input!(input as DeriveInput);
259 match expand(&input) {
260 Ok(ts) => ts.into(),
261 Err(e) => e.to_compile_error().into(),
262 }
263}
264
265fn expand(input: &DeriveInput) -> syn::Result<TokenStream2> {
266 let name = &input.ident;
267 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
268
269 let fields = match &input.data {
270 Data::Struct(DataStruct {
271 fields: Fields::Named(named),
272 ..
273 }) => &named.named,
274 Data::Struct(_) => {
275 return Err(syn::Error::new_spanned(
276 &input.ident,
277 "FromRow can only be derived on structs with named fields",
278 ));
279 }
280 Data::Enum(_) => {
281 return Err(syn::Error::new_spanned(
282 &input.ident,
283 "FromRow cannot be derived on enums",
284 ));
285 }
286 Data::Union(_) => {
287 return Err(syn::Error::new_spanned(
288 &input.ident,
289 "FromRow cannot be derived on unions",
290 ));
291 }
292 };
293
294 let assignments = fields
295 .iter()
296 .map(field_assignment)
297 .collect::<syn::Result<Vec<_>>>()?;
298
299 Ok(quote! {
300 #[automatically_derived]
301 impl #impl_generics ::hyperdb_api::FromRow for #name #ty_generics #where_clause {
302 fn from_row(
303 row: ::hyperdb_api::RowAccessor<'_>,
304 ) -> ::hyperdb_api::Result<Self> {
305 Ok(Self {
306 #(#assignments),*
307 })
308 }
309 }
310 })
311}
312
313fn field_assignment(field: &Field) -> syn::Result<TokenStream2> {
316 let ident = field
317 .ident
318 .as_ref()
319 .ok_or_else(|| syn::Error::new_spanned(field, "tuple-struct fields are not supported"))?;
320 let source = field_source_for(field, ident)?;
321 let is_opt = is_option_type(&field.ty);
322
323 let getter = match (source, is_opt) {
324 (FieldSource::Name(name), true) => {
325 let lit = LitStr::new(&name, ident.span());
326 quote!(row.get_opt(#lit)?)
327 }
328 (FieldSource::Name(name), false) => {
329 let lit = LitStr::new(&name, ident.span());
330 quote!(row.get(#lit)?)
331 }
332 (FieldSource::Index(idx), true) => quote!(row.position_opt(#idx)?),
333 (FieldSource::Index(idx), false) => quote!(row.position(#idx)?),
334 };
335
336 Ok(quote! { #ident: #getter })
337}
338
339fn field_source_for(field: &Field, default: &syn::Ident) -> syn::Result<FieldSource> {
343 let mut rename: Option<(String, proc_macro2::Span)> = None;
344 let mut index: Option<(usize, proc_macro2::Span)> = None;
345
346 for attr in &field.attrs {
347 if !attr.path().is_ident("hyperdb") {
348 continue;
349 }
350 attr.parse_nested_meta(|meta| {
351 if meta.path.is_ident("rename") {
352 let s: LitStr = meta.value()?.parse()?;
353 rename = Some((s.value(), meta.path.span()));
354 Ok(())
355 } else if meta.path.is_ident("index") {
356 let n: LitInt = meta.value()?.parse()?;
357 let parsed: usize = n.base10_parse()?;
358 index = Some((parsed, meta.path.span()));
359 Ok(())
360 } else if meta.path.is_ident("primary_key") {
361 Ok(())
363 } else {
364 Err(meta.error(format!(
365 "unrecognized hyperdb attribute `{}`; supported attributes: rename, index",
366 meta.path
367 .get_ident()
368 .map_or_else(|| "?".to_string(), ToString::to_string)
369 )))
370 }
371 })?;
372 }
373
374 match (rename, index) {
375 (Some(_), Some((_, idx_span))) => Err(syn::Error::new(
376 idx_span,
377 "`#[hyperdb(rename = ...)]` and `#[hyperdb(index = N)]` are mutually exclusive",
378 )),
379 (Some((name, _)), None) => Ok(FieldSource::Name(name)),
380 (None, Some((idx, _))) => Ok(FieldSource::Index(idx)),
381 (None, None) => Ok(FieldSource::Name(default.to_string())),
382 }
383}
384
385fn is_option_type(ty: &Type) -> bool {
387 let Type::Path(TypePath { path, qself: None }) = ty else {
388 return false;
389 };
390 let Some(last) = path.segments.last() else {
391 return false;
392 };
393 if last.ident != "Option" {
394 return false;
395 }
396 matches!(
397 last.arguments,
398 PathArguments::AngleBracketed(ref args)
399 if matches!(args.args.first(), Some(GenericArgument::Type(_)))
400 )
401}