db-derive-impl 0.1.8

Derive proc-macro for db-derive
Documentation
use {
    darling::{
        ast::{Data, Fields},
        util::Ignored,
        FromDeriveInput,
    },
    proc_macro2::TokenStream,
    std::{
        fmt::{self, Write},
        string::ToString,
    },
    syn::{DeriveInput, GenericArgument, Generics, Ident, PathArguments, Type},
};

/// Derive input parsed by darling: the deriving struct plus its
/// `#[table(...)]` container attributes.
#[derive(Debug, darling::FromDeriveInput)]
#[darling(attributes(table), supports(struct_named))]
pub struct Table {
    /// Identifier of the deriving struct.
    ident: Ident,
    /// The struct's fields; `supports(struct_named)` restricts input to named structs.
    data: Data<Ignored, TableField>,
    /// The struct's generics, forwarded to the generated impls.
    generics: Generics,

    /// `#[table(rename = "...")]`: table name used in the schema; when absent
    /// the struct identifier is used.
    #[darling(default)]
    rename: Option<String>,

    /// `#[table(exists)]`: emit `CREATE TABLE IF NOT EXISTS` in the schema.
    #[darling(default)]
    exists: bool,

    /// `#[table(schema)]`: also generate the `schema_postgres` / `schema_sqlite`
    /// functions returning a `CREATE TABLE` statement.
    #[darling(default)]
    schema: bool,
}

/// One named field of the deriving struct, with its `#[table(...)]` attributes.
#[derive(Clone, Debug, darling::FromField)]
#[darling(attributes(table))]
pub struct TableField {
    /// Field identifier (always present for named structs).
    ident: Option<Ident>,
    /// Field type; used to infer the SQL column type.
    ty: Type,

    /// `#[table(rename = "...")]`: column name; when absent the field
    /// identifier is used.
    #[darling(default)]
    rename: Option<String>,

    /// `#[table(primary)]`: mark the column as the primary key in the schema.
    #[darling(default)]
    primary: bool,

    /// `#[table(unique)]`: mark the column `UNIQUE` (skipped when `primary` is set).
    #[darling(default)]
    unique: bool,

    /// `#[table(ignore)]`.
    // NOTE(review): parsed but not read anywhere in this file; ignored fields
    // are still part of the generated row mapping and schema — confirm intent.
    #[darling(default)]
    ignore: bool,

    /// `#[table(length = N)]`: text length, rendered as `VARCHAR(N)` in the
    /// PostgreSQL schema.
    #[darling(default)]
    length: Option<usize>,

    /// `#[table(kind = "...")]`.
    // NOTE(review): parsed but not read anywhere in this file — confirm intent.
    #[darling(default)]
    kind: Option<String>,
}

impl Table {
    pub fn derive(input: DeriveInput) -> TokenStream {
        let table = Table::from_derive_input(&input).unwrap();

        let column_types = if table.schema {
            Some(
                table
                    .data
                    .clone()
                    .take_struct()
                    .map(|table| {
                        table
                            .into_iter()
                            .map(|field| {
                                Self::type_to_column_type(&field.ty, None).map(|parsed| Column {
                                    name: field.rename.clone().unwrap_or_else(|| {
                                        field.ident.clone().map(|i| i.to_string()).unwrap()
                                    }),
                                    typ: parsed.typ,
                                    length: field.length,
                                    null: parsed.null,
                                    primary: field.primary,
                                    unique: field.unique,
                                })
                            })
                            .collect::<Option<Vec<_>>>()
                    })
                    .unwrap()
                    .unwrap_or_else(|| {
                        panic!(
                            "Table {} does not have a valid column type",
                            table.ident.to_string()
                        )
                    }),
            )
        } else {
            None
        };

        let fields = if let Data::Struct(fields) = &table.data {
            fields
        } else {
            panic!("derive table only works on named structs");
        };

        let columns = fields.len();

        let ident = &input.ident;

        let (impl_generics, type_generics, where_clause) = table.generics.split_for_impl();

        let rename = table.rename.unwrap_or_else(|| ident.to_string());

        let (postgresql_schema, postgresql_table) = if cfg!(feature = "postgresql") {
            let row = Self::row(&fields, None);

            let schema = if table.schema {
                let schema = Self::write_schema(
                    false,
                    &rename,
                    table.exists,
                    column_types.as_ref().unwrap(),
                )
                .unwrap();

                Some(quote::quote! {
                    fn schema_postgres() -> &'static str {
                        #schema
                    }
                })
            } else {
                None
            };

            (
                schema,
                Some(quote::quote! {
                    fn from_row_postgres(row: ::db_derive::internal::PostgreRow) -> std::result::Result<Self, ::db_derive::Error>
                    where
                        Self: Sized,
                    {
                        std::result::Result::Ok(#ident {
                            #row
                        })
                    }
                }),
            )
        } else {
            (None, None)
        };

        let (sqlite_schema, sqlite_table) = if cfg!(feature = "postgresql") {
            let row = Self::row(&fields, Some(quote::quote! { ? }));

            let schema = if table.schema {
                let schema =
                    Self::write_schema(true, &rename, table.exists, column_types.as_ref().unwrap())
                        .unwrap();

                Some(quote::quote! {
                    fn schema_sqlite() -> &'static str {
                        #schema
                    }
                })
            } else {
                None
            };

            (
                schema,
                Some(quote::quote! {
                    fn from_row_sqlite<'r>(row: &'r ::db_derive::internal::SQLiteRow<'r>) -> ::std::result::Result<Self, ::db_derive::Error>
                    where
                        Self: Sized,
                    {
                        ::std::result::Result::Ok(#ident {
                            #row
                        })
                    }
                }),
            )
        } else {
            (None, None)
        };

        quote::quote! {
            impl #impl_generics ::db_derive::prelude::Schema for #ident #type_generics #where_clause {
                #postgresql_schema

                #sqlite_schema
            }

            impl #impl_generics ::db_derive::prelude::Table for #ident #type_generics #where_clause {
                #postgresql_table

                #sqlite_table

                fn columns() -> usize {
                    #columns
                }
            }
        }
    }

    fn row(fields: &Fields<TableField>, mark: Option<TokenStream>) -> TokenStream {
        let names = fields.iter().map(|f| {
            let ident = f
                .ident
                .clone()
                .expect("no ident, derive must be a named struct");

            f.rename.clone().unwrap_or_else(|| ident.to_string())
        });

        let ids = fields.iter().map(|f| {
            f.ident
                .clone()
                .expect("no ident, derive must be a named struct")
        });

        quote::quote! {
            #( #ids: row.get(#names)#mark, )*
        }
    }

    fn write_schema(
        sqlite: bool,
        name: &str,
        exists: bool,
        columns: &[Column],
    ) -> Result<String, fmt::Error> {
        let mut buff = String::with_capacity(columns.len() * 20);

        writeln!(
            &mut buff,
            "CREATE TABLE{} {} (",
            if exists { " IF NOT EXISTS" } else { "" },
            name
        )?;

        for (i, column) in columns.iter().enumerate() {
            column.as_string(&mut buff, sqlite, i == columns.len())?;
        }

        write!(&mut buff, ");")?;

        Ok(buff)
    }

    fn type_to_column_type(typ: &Type, null: Option<bool>) -> Option<ParsedColumnType> {
        if let Type::Path(type_path) = typ {
            type_path.path.segments.first().and_then(|segment| {
                if segment.ident == "Option" {
                    // Only `Option<...>`
                    if let PathArguments::AngleBracketed(bracketed) = &segment.arguments {
                        bracketed.args.first().and_then(|arg| {
                            // Only `Option<String>` not `Option<std::string::String>`
                            // TODO: Figure out a way to allow that ^
                            if let GenericArgument::Type(typ) = arg {
                                Self::type_to_column_type(typ, Some(true))
                            } else {
                                None
                            }
                        })
                    } else {
                        None
                    }
                } else if segment.ident == "bool" {
                    Some(ParsedColumnType::new(ColumnType::Bool, null))
                } else if segment.ident == "i8" {
                    Some(ParsedColumnType::new(ColumnType::Char, null))
                } else if segment.ident == "i16" {
                    Some(ParsedColumnType::new(ColumnType::SmallInt, null))
                } else if segment.ident == "i32" {
                    Some(ParsedColumnType::new(ColumnType::Int, null))
                } else if segment.ident == "i64" {
                    Some(ParsedColumnType::new(ColumnType::BigInt, null))
                } else if segment.ident == "f32" {
                    Some(ParsedColumnType::new(ColumnType::Real, null))
                } else if segment.ident == "f64" {
                    Some(ParsedColumnType::new(ColumnType::Double, null))
                } else if segment.ident == "String" {
                    Some(ParsedColumnType::new(ColumnType::Text, null))
                } else if segment.ident == "NaiveDate" {
                    Some(ParsedColumnType::new(ColumnType::Date, null))
                } else if segment.ident == "NaiveTime" {
                    Some(ParsedColumnType::new(ColumnType::Time, null))
                } else if segment.ident == "DateTime" {
                    // Only `DateTime<...>`
                    if let PathArguments::AngleBracketed(bracketed) = &segment.arguments {
                        bracketed.args.first().and_then(|arg| {
                            // Only `DateTime<Utc>` not `DateTime<chrono::Utc>`
                            // TODO: Figure out a way to allow that ^
                            if let GenericArgument::Type(typ) = arg {
                                if let Type::Path(type_path) = typ {
                                    type_path.path.segments.first().and_then(|segment| {
                                        if segment.ident == "Utc" {
                                            Some(ParsedColumnType::new(ColumnType::Utc, null))
                                        } else {
                                            None
                                        }
                                    })
                                } else {
                                    None
                                }
                            } else {
                                None
                            }
                        })
                    } else {
                        None
                    }
                } else {
                    // Default to a string
                    // TODO: Handle custom (db-derive) types
                    Some(ParsedColumnType::new(ColumnType::Text, null))
                }
            })
        } else {
            None
        }
    }
}

/// One column of a generated `CREATE TABLE` statement.
#[derive(Debug)]
struct Column {
    /// Column name (the field's `rename`, or its identifier).
    name: String,
    /// Column category inferred from the field's Rust type.
    typ: ColumnType,
    /// Optional text length, rendered for PostgreSQL text columns.
    length: Option<usize>,
    /// Whether NULL is allowed (set for `Option<...>` fields); otherwise `NOT NULL` is written.
    null: bool,
    /// Whether the field was marked `primary`.
    primary: bool,
    /// Whether the field was marked `unique`; not written when `primary` is set.
    unique: bool,
}

impl Column {
    /// Appends this column's definition line to `buff`.
    ///
    /// Writes `<name> <type>[ NOT NULL][ PRIMARY KEY | UNIQUE]`, then `,` unless
    /// `last` is set, then a newline. `sqlite` selects SQLite type names,
    /// otherwise PostgreSQL type names are used; `length` is only rendered for
    /// PostgreSQL text columns (`VARCHAR(n)`).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `buff`.
    fn as_string(&self, buff: &mut impl fmt::Write, sqlite: bool, last: bool) -> fmt::Result {
        write!(buff, "{} ", self.name)?;

        buff.write_str(self.sql_type(sqlite))?;

        if let (ColumnType::Text, false, Some(len)) = (self.typ, sqlite, self.length) {
            write!(buff, "({})", len)?;
        }

        if !self.null {
            buff.write_str(" NOT NULL")?;
        }

        // A primary key is already unique, so `UNIQUE` would be redundant.
        if self.primary {
            buff.write_str(" PRIMARY KEY")?;
        } else if self.unique {
            buff.write_str(" UNIQUE")?;
        }

        if !last {
            buff.write_char(',')?;
        }

        buff.write_char('\n')
    }

    /// SQL type name for this column in the selected dialect (without length).
    fn sql_type(&self, sqlite: bool) -> &'static str {
        match self.typ {
            ColumnType::Bool => "BOOLEAN",
            // A TEXT-affinity SQLite column converts stored integers to text,
            // which can then no longer be read back as an integer; use INTEGER.
            ColumnType::Char if sqlite => "INTEGER",
            // PostgreSQL's single-byte `"char"` (plain `CHAR` is `character(1)`).
            ColumnType::Char => "\"char\"",
            ColumnType::SmallInt if sqlite => "INTEGER",
            ColumnType::SmallInt => "SMALLINT",
            ColumnType::Int => "INTEGER",
            ColumnType::BigInt if sqlite => "INTEGER",
            ColumnType::BigInt => "BIGINT",
            ColumnType::Real => "REAL",
            ColumnType::Double if sqlite => "REAL",
            ColumnType::Double => "DOUBLE PRECISION",
            ColumnType::Text if sqlite => "TEXT",
            ColumnType::Text => "VARCHAR",
            // SQLite has no dedicated date/time storage class.
            ColumnType::Date | ColumnType::Time | ColumnType::Utc if sqlite => "TEXT",
            ColumnType::Date => "DATE",
            ColumnType::Time => "TIME",
            ColumnType::Utc => "TIMESTAMP WITH TIME ZONE",
        }
    }
}

/// Column categories inferred from Rust field types; the trailing comment on
/// each variant names the Rust type that maps to it.
#[derive(Clone, Copy, Debug)]
enum ColumnType {
    Bool,     // bool
    Char,     // i8
    SmallInt, // i16
    Int,      // i32
    BigInt,   // i64
    Real,     // f32
    Double,   // f64
    Text,     // String (also the fallback for unrecognised path types)
    Date,     // NaiveDate
    Time,     // NaiveTime
    Utc,      // DateTime<Utc>
}

/// Result of inferring a column type from a field's Rust type.
#[derive(Debug)]
struct ParsedColumnType {
    /// Inferred column category.
    typ: ColumnType,
    /// Whether the column is nullable (the field type was wrapped in `Option`).
    null: bool,
}

impl ParsedColumnType {
    /// Builds a parsed column type; an absent `null` means the column is NOT NULL.
    fn new(typ: ColumnType, null: Option<bool>) -> Self {
        ParsedColumnType {
            typ,
            // A constant default needs no lazily evaluated closure.
            null: null.unwrap_or(false),
        }
    }
}