use {
darling::{
ast::{Data, Fields},
util::Ignored,
FromDeriveInput,
},
proc_macro2::TokenStream,
std::{
fmt::{self, Write},
string::ToString,
},
syn::{DeriveInput, GenericArgument, Generics, Ident, PathArguments, Type},
};
/// Container-level options parsed by darling from `#[table(...)]` on the
/// struct that derives `Table`.
#[derive(Debug, darling::FromDeriveInput)]
#[darling(attributes(table), supports(struct_named))]
pub struct Table {
    /// Name of the deriving struct.
    ident: Ident,
    /// The struct's fields; `supports(struct_named)` limits input to named structs.
    data: Data<Ignored, TableField>,
    /// Generic parameters, reproduced on the generated impl headers.
    generics: Generics,
    /// `#[table(rename = "...")]`: SQL table name; the struct name is used when absent.
    #[darling(default)]
    rename: Option<String>,
    /// `#[table(exists)]`: emit `CREATE TABLE IF NOT EXISTS` instead of `CREATE TABLE`.
    #[darling(default)]
    exists: bool,
    /// `#[table(schema)]`: additionally generate the `schema_*` methods.
    #[darling(default)]
    schema: bool,
}
/// Per-field options parsed by darling from `#[table(...)]` on a struct field.
#[derive(Clone, Debug, darling::FromField)]
#[darling(attributes(table))]
pub struct TableField {
    /// Field name; always present for the named structs this derive accepts.
    ident: Option<Ident>,
    /// Field type, mapped to a SQL column type when a schema is generated.
    ty: Type,
    /// `#[table(rename = "...")]`: column name used in the schema and in `row.get`;
    /// defaults to the field name.
    #[darling(default)]
    rename: Option<String>,
    /// `#[table(primary)]`: mark the column as the primary key.
    #[darling(default)]
    primary: bool,
    /// `#[table(unique)]`: add a UNIQUE constraint (skipped for primary columns).
    #[darling(default)]
    unique: bool,
    /// `#[table(ignore)]`: parsed, but not read anywhere in this file.
    #[darling(default)]
    ignore: bool,
    /// `#[table(length = N)]`: length limit for PostgreSQL text (VARCHAR) columns.
    #[darling(default)]
    length: Option<usize>,
    /// `#[table(kind = "...")]`: parsed, but not read anywhere in this file.
    #[darling(default)]
    kind: Option<String>,
}
impl Table {
pub fn derive(input: DeriveInput) -> TokenStream {
let table = Table::from_derive_input(&input).unwrap();
let column_types = if table.schema {
Some(
table
.data
.clone()
.take_struct()
.map(|table| {
table
.into_iter()
.map(|field| {
Self::type_to_column_type(&field.ty, None).map(|parsed| Column {
name: field.rename.clone().unwrap_or_else(|| {
field.ident.clone().map(|i| i.to_string()).unwrap()
}),
typ: parsed.typ,
length: field.length,
null: parsed.null,
primary: field.primary,
unique: field.unique,
})
})
.collect::<Option<Vec<_>>>()
})
.unwrap()
.unwrap_or_else(|| {
panic!(
"Table {} does not have a valid column type",
table.ident.to_string()
)
}),
)
} else {
None
};
let fields = if let Data::Struct(fields) = &table.data {
fields
} else {
panic!("derive table only works on named structs");
};
let columns = fields.len();
let ident = &input.ident;
let (impl_generics, type_generics, where_clause) = table.generics.split_for_impl();
let rename = table.rename.unwrap_or_else(|| ident.to_string());
let (postgresql_schema, postgresql_table) = if cfg!(feature = "postgresql") {
let row = Self::row(&fields, None);
let schema = if table.schema {
let schema = Self::write_schema(
false,
&rename,
table.exists,
column_types.as_ref().unwrap(),
)
.unwrap();
Some(quote::quote! {
fn schema_postgres() -> &'static str {
#schema
}
})
} else {
None
};
(
schema,
Some(quote::quote! {
fn from_row_postgres(row: ::db_derive::internal::PostgreRow) -> std::result::Result<Self, ::db_derive::Error>
where
Self: Sized,
{
std::result::Result::Ok(#ident {
#row
})
}
}),
)
} else {
(None, None)
};
let (sqlite_schema, sqlite_table) = if cfg!(feature = "postgresql") {
let row = Self::row(&fields, Some(quote::quote! { ? }));
let schema = if table.schema {
let schema =
Self::write_schema(true, &rename, table.exists, column_types.as_ref().unwrap())
.unwrap();
Some(quote::quote! {
fn schema_sqlite() -> &'static str {
#schema
}
})
} else {
None
};
(
schema,
Some(quote::quote! {
fn from_row_sqlite<'r>(row: &'r ::db_derive::internal::SQLiteRow<'r>) -> ::std::result::Result<Self, ::db_derive::Error>
where
Self: Sized,
{
::std::result::Result::Ok(#ident {
#row
})
}
}),
)
} else {
(None, None)
};
quote::quote! {
impl #impl_generics ::db_derive::prelude::Schema for #ident #type_generics #where_clause {
#postgresql_schema
#sqlite_schema
}
impl #impl_generics ::db_derive::prelude::Table for #ident #type_generics #where_clause {
#postgresql_table
#sqlite_table
fn columns() -> usize {
#columns
}
}
}
}
fn row(fields: &Fields<TableField>, mark: Option<TokenStream>) -> TokenStream {
let names = fields.iter().map(|f| {
let ident = f
.ident
.clone()
.expect("no ident, derive must be a named struct");
f.rename.clone().unwrap_or_else(|| ident.to_string())
});
let ids = fields.iter().map(|f| {
f.ident
.clone()
.expect("no ident, derive must be a named struct")
});
quote::quote! {
#( #ids: row.get(#names)#mark, )*
}
}
fn write_schema(
sqlite: bool,
name: &str,
exists: bool,
columns: &[Column],
) -> Result<String, fmt::Error> {
let mut buff = String::with_capacity(columns.len() * 20);
writeln!(
&mut buff,
"CREATE TABLE{} {} (",
if exists { " IF NOT EXISTS" } else { "" },
name
)?;
for (i, column) in columns.iter().enumerate() {
column.as_string(&mut buff, sqlite, i == columns.len())?;
}
write!(&mut buff, ");")?;
Ok(buff)
}
fn type_to_column_type(typ: &Type, null: Option<bool>) -> Option<ParsedColumnType> {
if let Type::Path(type_path) = typ {
type_path.path.segments.first().and_then(|segment| {
if segment.ident == "Option" {
if let PathArguments::AngleBracketed(bracketed) = &segment.arguments {
bracketed.args.first().and_then(|arg| {
if let GenericArgument::Type(typ) = arg {
Self::type_to_column_type(typ, Some(true))
} else {
None
}
})
} else {
None
}
} else if segment.ident == "bool" {
Some(ParsedColumnType::new(ColumnType::Bool, null))
} else if segment.ident == "i8" {
Some(ParsedColumnType::new(ColumnType::Char, null))
} else if segment.ident == "i16" {
Some(ParsedColumnType::new(ColumnType::SmallInt, null))
} else if segment.ident == "i32" {
Some(ParsedColumnType::new(ColumnType::Int, null))
} else if segment.ident == "i64" {
Some(ParsedColumnType::new(ColumnType::BigInt, null))
} else if segment.ident == "f32" {
Some(ParsedColumnType::new(ColumnType::Real, null))
} else if segment.ident == "f64" {
Some(ParsedColumnType::new(ColumnType::Double, null))
} else if segment.ident == "String" {
Some(ParsedColumnType::new(ColumnType::Text, null))
} else if segment.ident == "NaiveDate" {
Some(ParsedColumnType::new(ColumnType::Date, null))
} else if segment.ident == "NaiveTime" {
Some(ParsedColumnType::new(ColumnType::Time, null))
} else if segment.ident == "DateTime" {
if let PathArguments::AngleBracketed(bracketed) = &segment.arguments {
bracketed.args.first().and_then(|arg| {
if let GenericArgument::Type(typ) = arg {
if let Type::Path(type_path) = typ {
type_path.path.segments.first().and_then(|segment| {
if segment.ident == "Utc" {
Some(ParsedColumnType::new(ColumnType::Utc, null))
} else {
None
}
})
} else {
None
}
} else {
None
}
})
} else {
None
}
} else {
Some(ParsedColumnType::new(ColumnType::Text, null))
}
})
} else {
None
}
}
}
/// One column of a generated `CREATE TABLE` statement.
#[derive(Debug)]
struct Column {
    /// SQL column name.
    name: String,
    /// Column type inferred from the field's Rust type.
    typ: ColumnType,
    /// Optional length limit, used for PostgreSQL text (VARCHAR) columns.
    length: Option<usize>,
    /// Whether NULL is allowed (`NOT NULL` is emitted when this is false).
    null: bool,
    /// Whether the column is marked primary.
    primary: bool,
    /// Whether the column carries a UNIQUE constraint.
    unique: bool,
}
impl Column {
    /// Appends this column's definition to the body of a `CREATE TABLE` statement.
    ///
    /// Writes, in order:
    /// - `<name> <type>`;
    /// - ` NOT NULL` unless the column is nullable;
    /// - ` PRIMARY KEY` for primary columns;
    /// - ` UNIQUE` for unique non-primary columns;
    /// - `,` unless `last` is set;
    /// - a trailing newline.
    ///
    /// `sqlite` selects SQLite type names instead of PostgreSQL ones.
    ///
    /// # Errors
    ///
    /// Returns any error produced by writing into `buff`.
    fn as_string(&self, buff: &mut impl fmt::Write, sqlite: bool, last: bool) -> fmt::Result {
        write!(buff, "{} ", self.name)?;
        let sql_type = match (self.typ, sqlite) {
            (ColumnType::Bool, _) => "BOOLEAN",
            // `i8` fields. SQLite needs integer affinity: a TEXT column would
            // store the number as text and it could not be read back as an
            // integer. PostgreSQL's one-byte integer type is the quoted `"char"`;
            // bare `CHAR` means `character(1)`, which is a text type.
            (ColumnType::Char, true) => "INTEGER",
            (ColumnType::Char, false) => "\"char\"",
            (ColumnType::SmallInt, true) => "INTEGER",
            (ColumnType::SmallInt, false) => "SMALLINT",
            (ColumnType::Int, _) => "INTEGER",
            (ColumnType::BigInt, true) => "INTEGER",
            (ColumnType::BigInt, false) => "BIGINT",
            (ColumnType::Real, _) => "REAL",
            (ColumnType::Double, true) => "REAL",
            (ColumnType::Double, false) => "DOUBLE PRECISION",
            (ColumnType::Text, true) => "TEXT",
            (ColumnType::Text, false) => "VARCHAR",
            (ColumnType::Date, true) => "TEXT",
            (ColumnType::Date, false) => "DATE",
            (ColumnType::Time, true) => "TEXT",
            (ColumnType::Time, false) => "TIME",
            (ColumnType::Utc, true) => "TEXT",
            (ColumnType::Utc, false) => "TIMESTAMP WITH TIME ZONE",
        };
        write!(buff, "{}", sql_type)?;
        // Only PostgreSQL text columns take a length limit.
        if let (ColumnType::Text, false, Some(len)) = (self.typ, sqlite, self.length) {
            write!(buff, "({})", len)?;
        }
        if !self.null {
            write!(buff, " NOT NULL")?;
        }
        if self.primary {
            write!(buff, " PRIMARY KEY")?;
        }
        // A primary key is already unique; repeating UNIQUE would be redundant.
        if self.unique && !self.primary {
            write!(buff, " UNIQUE")?;
        }
        if !last {
            write!(buff, ",")?;
        }
        writeln!(buff)?;
        Ok(())
    }
}
/// Backend-neutral column type, inferred from a field's Rust type.
#[derive(Clone, Copy, Debug)]
enum ColumnType {
    /// From `bool`.
    Bool,
    /// From `i8`.
    Char,
    /// From `i16`.
    SmallInt,
    /// From `i32`.
    Int,
    /// From `i64`.
    BigInt,
    /// From `f32`.
    Real,
    /// From `f64`.
    Double,
    /// From `String`, and the fallback for path types without a dedicated mapping.
    Text,
    /// From `NaiveDate`.
    Date,
    /// From `NaiveTime`.
    Time,
    /// From `DateTime<Utc>`.
    Utc,
}
/// Outcome of mapping one Rust field type: the column type plus nullability.
#[derive(Debug)]
struct ParsedColumnType {
    /// Column type of the field, after unwrapping any `Option`.
    typ: ColumnType,
    /// `true` when the field type was wrapped in `Option`.
    null: bool,
}
impl ParsedColumnType {
    /// Creates a parsed column of type `typ`.
    ///
    /// `null` is `Some(true)` when an enclosing `Option` made the column
    /// nullable; `None` defaults to `false` (NOT NULL).
    fn new(typ: ColumnType, null: Option<bool>) -> Self {
        ParsedColumnType {
            typ,
            null: null.unwrap_or(false),
        }
    }
}