rusqlite_model_derive/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::parse_macro_input;
4use syn::DeriveInput;
5
6#[proc_macro_derive(Model, attributes(sql_type))]
7pub fn derive_model(input: TokenStream) -> TokenStream {
8    let ast = parse_macro_input!(input as DeriveInput);
9
10    let name = &ast.ident;
11    let fields = match &ast.data {
12        syn::Data::Struct(data) => match &data.fields {
13            syn::Fields::Named(fields) => fields.named.iter(),
14            _ => panic!("expected named fields"),
15        },
16        _ => panic!("expected struct"),
17    }
18    .collect();
19
20    let modelimpl = impl_model(&name, &fields);
21    let tryfromrow = impl_tryfrom(&fields);
22
23    let mut ts = TokenStream::new();
24    ts.extend(modelimpl);
25    ts.extend(tryfromrow);
26    ts
27}
28
29fn impl_model(name: &syn::Ident, fields: &Vec<&syn::Field>) -> TokenStream {
30    let param_holders = vec![quote!(?); fields.len()];
31    let mut table_name = format!("{}s", name);
32    table_name.make_ascii_lowercase();
33
34    let field_names: Vec<_> = fields
35        .iter()
36        .map(|f| f.ident.as_ref().expect("expected named field"))
37        .collect();
38
39    let sql_types = fields.iter().map(|f| {
40        for attr in &f.attrs {
41            if let Some(attrident) = attr.path.get_ident() {
42                if attrident.to_string() == "sql_type" {
43                    return attr.parse_args().expect("failed to read sql_type");
44                }
45            }
46        }
47
48        match &f.ty {
49            syn::Type::Path(tp) => match tp.path.get_ident() {
50                Some(ident) => match ident.to_string().as_ref() {
51                    "bool" => quote!(BOOL NOT NULL),
52                    "String" => quote!(TEXT NOT NULL),
53                    _ => panic!("Don't know how to map to SQL type {}", ident.to_string()),
54                },
55                None => panic!("Unsupported type path"),
56            },
57            _ => panic!("Don't know how to map to SQL type {:?}", f.ty),
58        }
59    });
60
61    let gen = quote! {
62        impl Model<'_> for #name {
63
64            fn create_table(conn: &rusqlite::Connection) -> rusqlite::Result<usize> {
65                conn.execute(
66                    std::stringify!(
67                        CREATE TABLE #table_name (
68                            #(#field_names #sql_types),*
69                        )
70                    ),
71                    rusqlite::NO_PARAMS,
72                )
73            }
74
75            fn drop_table(conn: &rusqlite::Connection) -> rusqlite::Result<usize> {
76            conn.execute(std::stringify!(
77                    DROP TABLE IF EXISTS #table_name
78                ),
79                rusqlite::NO_PARAMS,
80            )
81            }
82
83            fn insert(self, conn: &rusqlite::Connection) -> rusqlite::Result<usize> {
84                let mut stmt = conn
85                    .prepare(std::stringify!(
86                        INSERT INTO #table_name (#(#field_names),*) VALUES (#(#param_holders),*)
87                    ))
88                    .unwrap();
89
90                println!("{:?}", stmt);
91                stmt.execute(self.into_params())
92            }
93
94            fn into_params(self) -> std::vec::IntoIter<Box<dyn rusqlite::ToSql>> {
95                let ret: Vec<Box<dyn rusqlite::ToSql>> = vec![
96                    #(Box::new(self.#field_names),)*
97                ];
98                ret.into_iter()
99            }
100
101        }
102    };
103
104    gen.into()
105}
106
107fn impl_tryfrom(fields: &Vec<&syn::Field>) -> TokenStream {
108    let field_names = fields
109        .iter()
110        .map(|f| f.ident.as_ref().expect("expected named field"));
111
112    let gen = quote! {
113        impl<'a> std::convert::TryFrom<&'a rusqlite::Row<'a>> for Transaction {
114            type Error = rusqlite::Error;
115
116            fn try_from(row: &rusqlite::Row<'_>) -> rusqlite::Result<Self> {
117                Ok(Self {
118                    #(#field_names: row.get(row.column_index(std::stringify!(#field_names))?)?,)*
119                })
120            }
121        }
122    };
123
124    gen.into()
125}