use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{
Data, DataStruct, DeriveInput, Field, Fields, GenericArgument, LitStr, PathArguments, Type,
TypePath,
};
pub(crate) fn expand(input: &DeriveInput) -> syn::Result<TokenStream2> {
let struct_name = &input.ident;
let fields = match &input.data {
Data::Struct(DataStruct {
fields: Fields::Named(named),
..
}) => &named.named,
Data::Struct(_) => {
return Err(syn::Error::new_spanned(
&input.ident,
"Table can only be derived on structs with named fields",
));
}
Data::Enum(_) => {
return Err(syn::Error::new_spanned(
&input.ident,
"Table cannot be derived on enums",
));
}
Data::Union(_) => {
return Err(syn::Error::new_spanned(
&input.ident,
"Table cannot be derived on unions",
));
}
};
let struct_opts = parse_struct_opts(input)?;
let table_name = struct_opts
.table_name
.unwrap_or_else(|| to_snake_case(&struct_name.to_string()));
let col_defs = fields
.iter()
.map(|f| column_def(f, &table_name))
.collect::<syn::Result<Vec<_>>>()?;
let create_sql = format!(
"CREATE TABLE IF NOT EXISTS {} ({})",
table_name,
col_defs.join(", ")
);
let name_lit = LitStr::new(&table_name, struct_name.span());
let create_sql_lit = LitStr::new(&create_sql, struct_name.span());
#[cfg(feature = "compile-time")]
if struct_opts.register {
let field_names: Vec<String> = fields
.iter()
.filter_map(|f| {
let ident = f.ident.as_ref()?;
let col = column_name_for(f, ident).ok()?;
if col.is_empty() {
None
} else {
Some(col)
}
})
.collect();
hyperdb_compile_check::registry::register(
struct_name.to_string(),
table_name.clone(),
create_sql.clone(),
field_names,
);
}
Ok(quote! {
#[automatically_derived]
impl ::hyperdb_api::Table for #struct_name {
const NAME: &'static str = #name_lit;
const CREATE_SQL: &'static str = #create_sql_lit;
}
})
}
/// Options collected from struct-level `#[hyperdb(...)]` attributes.
struct StructOpts {
    /// Explicit table name from `table = "..."`; `None` means the name is derived from the type.
    table_name: Option<String>,
    /// Set by the bare `register` flag; only read when the `compile-time` feature is enabled.
    #[cfg_attr(
        not(feature = "compile-time"),
        allow(dead_code, reason = "only used when compile-time feature is enabled")
    )]
    register: bool,
}
fn parse_struct_opts(input: &DeriveInput) -> syn::Result<StructOpts> {
let mut table_name = None;
let mut register = false;
for attr in &input.attrs {
if !attr.path().is_ident("hyperdb") {
continue;
}
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("table") {
let s: LitStr = meta.value()?.parse()?;
table_name = Some(s.value());
Ok(())
} else if meta.path.is_ident("register") {
register = true;
Ok(())
} else if meta.path.is_ident("primary_key")
|| meta.path.is_ident("rename")
|| meta.path.is_ident("index")
{
Ok(())
} else {
Err(meta.error(format!(
"unrecognized hyperdb attribute `{}`; \
supported struct attributes: table, register",
meta.path
.get_ident()
.map_or_else(|| "?".to_string(), ToString::to_string)
)))
}
})?;
}
Ok(StructOpts {
table_name,
register,
})
}
/// Options collected from field-level `#[hyperdb(...)]` attributes.
struct FieldOpts {
    /// Column name override from `rename = "..."`; `None` means use the field identifier.
    rename: Option<String>,
    /// Value of `index = N`. Fields carrying it are emitted without `NOT NULL`.
    index: Option<usize>,
    /// Set by the bare `primary_key` flag; parsed but not yet used for schema generation.
    #[allow(dead_code, reason = "reserved for v2 schema enforcement")]
    primary_key: bool,
}
fn parse_field_opts(field: &Field) -> syn::Result<FieldOpts> {
let mut rename = None;
let mut index = None;
let mut primary_key = false;
for attr in &field.attrs {
if !attr.path().is_ident("hyperdb") {
continue;
}
attr.parse_nested_meta(|meta| {
if meta.path.is_ident("rename") {
let s: syn::LitStr = meta.value()?.parse()?;
rename = Some(s.value());
Ok(())
} else if meta.path.is_ident("index") {
let n: syn::LitInt = meta.value()?.parse()?;
index = Some(n.base10_parse::<usize>()?);
Ok(())
} else if meta.path.is_ident("primary_key") {
primary_key = true;
Ok(())
} else {
Err(meta.error(format!(
"unrecognized hyperdb attribute `{}`; \
supported field attributes: rename, index, primary_key",
meta.path
.get_ident()
.map_or_else(|| "?".to_string(), ToString::to_string)
)))
}
})?;
}
Ok(FieldOpts {
rename,
index,
primary_key,
})
}
/// Returns the column name a field is registered under for compile-time checking.
///
/// Uses `#[hyperdb(rename = "...")]` when present. Otherwise it uses the field identifier with any
/// `r#` raw-identifier prefix stripped, matching the name used in the generated CREATE TABLE.
/// Fields carrying `#[hyperdb(index = ..)]` yield an empty string, which the caller treats as
/// "no named column".
///
/// # Errors
///
/// Propagates attribute-parsing errors from `parse_field_opts`.
#[cfg(feature = "compile-time")]
fn column_name_for(field: &Field, default: &syn::Ident) -> syn::Result<String> {
    let opts = parse_field_opts(field)?;
    if opts.index.is_some() {
        return Ok(String::new());
    }
    Ok(opts.rename.unwrap_or_else(|| sql_column_ident(default)))
}

/// Renders one column definition, `<name> <SQL type>[ NOT NULL]`, for the CREATE TABLE statement.
///
/// The column name is the `rename` attribute if present, else the field identifier with any
/// `r#` prefix stripped. `Option<T>` fields, and fields with an `index` attribute, are left
/// nullable; all others get ` NOT NULL`. `_table_name` is currently unused.
///
/// # Errors
///
/// Fails on:
/// - tuple-struct (unnamed) fields;
/// - malformed `hyperdb` attributes;
/// - field types with no SQL mapping.
fn column_def(field: &Field, _table_name: &str) -> syn::Result<String> {
    let ident = field
        .ident
        .as_ref()
        .ok_or_else(|| syn::Error::new_spanned(field, "tuple-struct fields are not supported"))?;
    let opts = parse_field_opts(field)?;
    let col_name = opts.rename.unwrap_or_else(|| sql_column_ident(ident));
    let (inner_ty, nullable) = unwrap_option(&field.ty);
    let sql_type = rust_type_to_sql(field, inner_ty)?;
    // NOTE(review): index-addressed fields are deliberately left nullable; rationale not visible here.
    let nullability = if nullable || opts.index.is_some() {
        ""
    } else {
        " NOT NULL"
    };
    Ok(format!("{col_name} {sql_type}{nullability}"))
}

/// Converts a Rust identifier to the column name used in SQL, stripping any `r#` prefix.
///
/// The identifier's `Display` keeps the raw marker, so a field written `r#type` would otherwise
/// become the invalid SQL column name `r#type`.
fn sql_column_ident(ident: &syn::Ident) -> String {
    syn::ext::IdentExt::unraw(ident).to_string()
}
fn rust_type_to_sql<'a>(field: &Field, ty: &'a Type) -> syn::Result<&'a str> {
let type_name = last_path_ident(ty).map(ToString::to_string);
match type_name.as_deref() {
Some("i16") => Ok("SMALLINT"),
Some("i32") => Ok("INTEGER"),
Some("i64") => Ok("BIGINT"),
Some("f32") => Ok("REAL"),
Some("f64") => Ok("DOUBLE PRECISION"),
Some("bool") => Ok("BOOLEAN"),
Some("String") => Ok("TEXT"),
Some("Vec") => {
if is_vec_u8(ty) {
Ok("BYTES")
} else {
Err(syn::Error::new_spanned(
field,
format!(
"unsupported field type `{}` for derive(Table): \
only Vec<u8> is supported (maps to BYTES); \
other Vec<T> types have no Hyper SQL equivalent. \
Use a manual `impl Table` for this field.",
quote::quote!(#ty)
),
))
}
}
Some("NaiveDate") => Ok("DATE"),
Some("NaiveDateTime") => Ok("TIMESTAMP"),
Some("NaiveTime") => Ok("TIME"),
Some("DateTime") => Ok("TIMESTAMPTZ"),
Some("Numeric") => Ok("NUMERIC"),
_ => Err(syn::Error::new_spanned(
field,
format!(
"unsupported field type `{}` for derive(Table); \
supported: i16, i32, i64, f32, f64, bool, String, Vec<u8>, \
NaiveDate, NaiveDateTime, NaiveTime, DateTime<Utc>, Numeric. \
Use a manual `impl Table` for custom types.",
quote::quote!(#ty)
),
)),
}
}
/// Peels one layer of `Option<...>` off a type.
///
/// Returns `(T, true)` for an unqualified path type whose last segment is `Option` and whose
/// first generic argument is a type; this covers `Option<T>` and `std::option::Option<T>`. Returns
/// `(ty, false)` for anything else, including qualified-self paths and `Option` without a type
/// argument.
fn unwrap_option(ty: &Type) -> (&Type, bool) {
    let inner = match ty {
        Type::Path(TypePath { path, qself: None }) => path
            .segments
            .last()
            .filter(|segment| segment.ident == "Option")
            .and_then(|segment| match &segment.arguments {
                PathArguments::AngleBracketed(generics) => generics.args.first(),
                _ => None,
            })
            .and_then(|arg| match arg {
                GenericArgument::Type(wrapped) => Some(wrapped),
                _ => None,
            }),
        _ => None,
    };
    match inner {
        Some(wrapped) => (wrapped, true),
        None => (ty, false),
    }
}
/// Reports whether `ty` is `Vec<u8>`.
///
/// The last path segment must be `Vec`, and its first generic argument must be exactly the
/// single-segment path `u8`. Qualified-self paths never match.
fn is_vec_u8(ty: &Type) -> bool {
    let Type::Path(TypePath {
        path: outer,
        qself: None,
    }) = ty
    else {
        return false;
    };
    let element = outer.segments.last().and_then(|segment| {
        if segment.ident != "Vec" {
            return None;
        }
        match &segment.arguments {
            PathArguments::AngleBracketed(generics) => generics.args.first(),
            _ => None,
        }
    });
    match element {
        Some(GenericArgument::Type(Type::Path(TypePath {
            path: elem_path,
            qself: None,
        }))) => elem_path.is_ident("u8"),
        _ => false,
    }
}
/// Returns the identifier of the final path segment of `ty` (e.g. `NaiveDate` for
/// `chrono::NaiveDate`).
///
/// Returns `None` for non-path types, for qualified-self paths such as `<T as Trait>::Assoc`, and
/// for paths with no segments.
fn last_path_ident(ty: &Type) -> Option<&syn::Ident> {
    match ty {
        Type::Path(TypePath { qself: None, path }) => {
            path.segments.last().map(|segment| &segment.ident)
        }
        _ => None,
    }
}
/// Converts an `UpperCamelCase` identifier to `snake_case` for use as a default table name.
///
/// Every uppercase character except the first is preceded by `_`, and all characters are
/// lowercased. Acronyms are therefore split letter by letter (`HTTPResponse` becomes
/// `h_t_t_p_response`). Input that is already snake_case passes through unchanged.
fn to_snake_case(s: &str) -> String {
    let mut snake = String::with_capacity(s.len() + 4);
    let mut chars = s.chars();
    // The leading character never gets a separator, even if uppercase.
    if let Some(first) = chars.next() {
        snake.extend(first.to_lowercase());
    }
    for ch in chars {
        if ch.is_uppercase() {
            snake.push('_');
        }
        snake.extend(ch.to_lowercase());
    }
    snake
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn snake_case_conversion() {
        let cases = [
            ("User", "user"),
            ("UserOrder", "user_order"),
            // Acronyms are split letter by letter.
            ("HTTPResponse", "h_t_t_p_response"),
            ("already_snake", "already_snake"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_snake_case(input), expected, "input: {input}");
        }
    }
}