mod table_derive;
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{
parse_macro_input, spanned::Spanned, Data, DataStruct, DeriveInput, Field, Fields,
GenericArgument, LitInt, LitStr, PathArguments, Type, TypePath,
};
/// Where a derived field reads its value from in a result row.
///
/// Produced by `field_source_for` from a field's `#[hyperdb(...)]` attributes and
/// consumed by `field_assignment` to choose the row accessor that is emitted.
#[derive(Debug, Clone, PartialEq, Eq)]
enum FieldSource {
    /// Read by column name: the field's own name, or the `#[hyperdb(rename = "...")]` value.
    Name(String),
    /// Read by position, taken from `#[hyperdb(index = N)]`.
    Index(usize),
}
/// Entry point for `#[derive(Table)]`.
///
/// Parses the annotated item and delegates the expansion to `table_derive::expand`.
/// Expansion errors are emitted as compile errors (via `to_compile_error`) rather than
/// panicking, so they are reported at the offending span.
#[proc_macro_derive(Table, attributes(hyperdb))]
pub fn table_derive(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    table_derive::expand(&ast).map_or_else(|err| err.to_compile_error().into(), Into::into)
}
#[proc_macro]
pub fn query_as(input: TokenStream) -> TokenStream {
match expand_query_as(&input.into()) {
Ok(ts) => ts.into(),
Err(e) => e.to_compile_error().into(),
}
}
fn expand_query_as(input: &TokenStream2) -> syn::Result<TokenStream2> {
use syn::{parse::Parser, punctuated::Punctuated, Expr, Token};
let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
let args = parser.parse2(input.clone())?;
let mut iter = args.iter();
let ty_expr = iter.next().ok_or_else(|| {
syn::Error::new_spanned(
input,
"query_as! expects at least two arguments: query_as!(Type, \"SQL\")",
)
})?;
let ty: Type = syn::parse2(quote!(#ty_expr))?;
let sql_expr = iter.next().ok_or_else(|| {
syn::Error::new_spanned(
ty_expr,
"query_as! expects a SQL string literal as the second argument",
)
})?;
let rest: Vec<&Expr> = iter.collect();
#[cfg(feature = "compile-time")]
{
let struct_name = last_type_ident(&ty).map(ToString::to_string);
let sql_lit: Option<LitStr> = syn::parse2(quote!(#sql_expr)).ok();
if let (Some(struct_name), Some(sql_lit)) = (struct_name, sql_lit) {
let sql_str = sql_lit.value();
if let Err(e) = hyperdb_compile_check::validate_query_as(&struct_name, &sql_str) {
let msg = e.to_diagnostic();
return Ok(quote! {
::std::compile_error!(#msg)
});
}
}
}
Ok(quote! {
::hyperdb_api::QueryAs::<#ty>::new(#sql_expr, &[#(&#rest),*])
})
}
/// Returns the identifier of the last path segment of `ty` (e.g. `User` for
/// `crate::models::User`), or `None` when `ty` is not a plain path or is a qualified-self
/// path (`<T as Trait>::Assoc`).
#[cfg(feature = "compile-time")]
fn last_type_ident(ty: &Type) -> Option<&syn::Ident> {
    match ty {
        Type::Path(type_path) if type_path.qself.is_none() => {
            type_path.path.segments.last().map(|segment| &segment.ident)
        }
        _ => None,
    }
}
#[proc_macro]
pub fn query_scalar(input: TokenStream) -> TokenStream {
match expand_query_scalar(&input.into()) {
Ok(ts) => ts.into(),
Err(e) => e.to_compile_error().into(),
}
}
fn expand_query_scalar(input: &TokenStream2) -> syn::Result<TokenStream2> {
use syn::{parse::Parser, punctuated::Punctuated, Expr, Token};
let parser = Punctuated::<Expr, Token![,]>::parse_terminated;
let args = parser.parse2(input.clone())?;
let mut iter = args.iter();
let ty_expr = iter.next().ok_or_else(|| {
syn::Error::new_spanned(
input,
"query_scalar! expects at least two arguments: query_scalar!(Type, \"SQL\")",
)
})?;
let ty: Type = syn::parse2(quote!(#ty_expr))?;
let sql_expr = iter.next().ok_or_else(|| {
syn::Error::new_spanned(
ty_expr,
"query_scalar! expects a SQL string literal as the second argument",
)
})?;
let rest: Vec<&Expr> = iter.collect();
#[cfg(feature = "compile-time")]
{
let sql_lit: Option<LitStr> = syn::parse2(quote!(#sql_expr)).ok();
if let Some(sql_lit) = sql_lit {
let sql_str = sql_lit.value();
match hyperdb_compile_check::validate_scalar_sql(&sql_str) {
Ok(()) => {}
Err(e) => {
let msg = e.to_diagnostic();
return Ok(quote! { ::std::compile_error!(#msg) });
}
}
}
}
Ok(quote! {
::hyperdb_api::QueryScalar::<#ty>::new(#sql_expr, &[#(&#rest),*])
})
}
#[proc_macro_derive(FromRow, attributes(hyperdb))]
pub fn from_row_derive(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
match expand(&input) {
Ok(ts) => ts.into(),
Err(e) => e.to_compile_error().into(),
}
}
/// Generates `impl ::hyperdb_api::FromRow for <struct>` for a struct with named fields.
///
/// Each field is initialised by the row accessor chosen in `field_assignment`. The
/// struct's generics and where-clause are carried over to the impl unchanged.
///
/// # Errors
/// Returns a `syn::Error` when the input is an enum, a union, or a tuple/unit struct,
/// and propagates any error from a field's `#[hyperdb(...)]` attributes.
fn expand(input: &DeriveInput) -> syn::Result<TokenStream2> {
    let name = &input.ident;
    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
    let fields = match &input.data {
        Data::Struct(DataStruct {
            fields: Fields::Named(named),
            ..
        }) => &named.named,
        Data::Struct(_) => {
            return Err(syn::Error::new_spanned(
                &input.ident,
                "FromRow can only be derived on structs with named fields",
            ));
        }
        Data::Enum(_) => {
            return Err(syn::Error::new_spanned(
                &input.ident,
                "FromRow cannot be derived on enums",
            ));
        }
        Data::Union(_) => {
            return Err(syn::Error::new_spanned(
                &input.ident,
                "FromRow cannot be derived on unions",
            ));
        }
    };
    let assignments = fields
        .iter()
        .map(field_assignment)
        .collect::<syn::Result<Vec<_>>>()?;
    // `Ok` is fully qualified so the generated code cannot be hijacked by a user item
    // (or an imported enum variant) named `Ok` at the derive site.
    Ok(quote! {
        #[automatically_derived]
        impl #impl_generics ::hyperdb_api::FromRow for #name #ty_generics #where_clause {
            fn from_row(
                row: ::hyperdb_api::RowAccessor<'_>,
            ) -> ::hyperdb_api::Result<Self> {
                ::std::result::Result::Ok(Self {
                    #(#assignments),*
                })
            }
        }
    })
}
fn field_assignment(field: &Field) -> syn::Result<TokenStream2> {
let ident = field
.ident
.as_ref()
.ok_or_else(|| syn::Error::new_spanned(field, "tuple-struct fields are not supported"))?;
let source = field_source_for(field, ident)?;
let is_opt = is_option_type(&field.ty);
let getter = match (source, is_opt) {
(FieldSource::Name(name), true) => {
let lit = LitStr::new(&name, ident.span());
quote!(row.get_opt(#lit)?)
}
(FieldSource::Name(name), false) => {
let lit = LitStr::new(&name, ident.span());
quote!(row.get(#lit)?)
}
(FieldSource::Index(idx), true) => quote!(row.position_opt(#idx)?),
(FieldSource::Index(idx), false) => quote!(row.position(#idx)?),
};
Ok(quote! { #ident: #getter })
}
/// Resolves where a field's value is read from, based on its `#[hyperdb(...)]` attributes.
///
/// Recognised keys:
/// - `rename = "column"`: read by the given column name;
/// - `index = N`: read by position `N`;
/// - `primary_key`: accepted and ignored here. Both `Table` and `FromRow` register the
///   `hyperdb` attribute, so this key presumably exists for the `Table` derive.
///
/// Without `rename` or `index`, the field's own name (`default`) is used as the column.
///
/// # Errors
/// - An unknown key.
/// - A malformed value.
/// - A key given more than once.
/// - `rename` and `index` used together.
fn field_source_for(field: &Field, default: &syn::Ident) -> syn::Result<FieldSource> {
    let mut rename: Option<(String, proc_macro2::Span)> = None;
    let mut index: Option<(usize, proc_macro2::Span)> = None;
    for attr in &field.attrs {
        if !attr.path().is_ident("hyperdb") {
            continue;
        }
        attr.parse_nested_meta(|meta| {
            if meta.path.is_ident("rename") {
                // Reject repeats instead of letting the last one silently win.
                if rename.is_some() {
                    return Err(meta.error("duplicate `#[hyperdb(rename = ...)]` attribute"));
                }
                let s: LitStr = meta.value()?.parse()?;
                rename = Some((s.value(), meta.path.span()));
                Ok(())
            } else if meta.path.is_ident("index") {
                if index.is_some() {
                    return Err(meta.error("duplicate `#[hyperdb(index = N)]` attribute"));
                }
                let n: LitInt = meta.value()?.parse()?;
                let parsed: usize = n.base10_parse()?;
                index = Some((parsed, meta.path.span()));
                Ok(())
            } else if meta.path.is_ident("primary_key") {
                Ok(())
            } else {
                Err(meta.error(format!(
                    "unrecognized hyperdb attribute `{}`; supported attributes: rename, index, primary_key",
                    meta.path
                        .get_ident()
                        .map_or_else(|| "?".to_string(), ToString::to_string)
                )))
            }
        })?;
    }
    match (rename, index) {
        (Some(_), Some((_, idx_span))) => Err(syn::Error::new(
            idx_span,
            "`#[hyperdb(rename = ...)]` and `#[hyperdb(index = N)]` are mutually exclusive",
        )),
        (Some((name, _)), None) => Ok(FieldSource::Name(name)),
        (None, Some((idx, _))) => Ok(FieldSource::Index(idx)),
        (None, None) => Ok(FieldSource::Name(default.to_string())),
    }
}
/// Returns `true` when `ty` syntactically looks like `Option<T>`.
///
/// Only the last path segment is inspected, so `Option<T>`, `std::option::Option<T>`
/// and similar spellings all match. The segment must be named `Option` and its first
/// angle-bracketed argument must be a type. Qualified-self paths, type aliases and
/// non-path types never match.
fn is_option_type(ty: &Type) -> bool {
    let Type::Path(TypePath { qself: None, path }) = ty else {
        return false;
    };
    match path.segments.last() {
        Some(segment) if segment.ident == "Option" => match &segment.arguments {
            PathArguments::AngleBracketed(generics) => {
                matches!(generics.args.first(), Some(GenericArgument::Type(_)))
            }
            _ => false,
        },
        _ => false,
    }
}