1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
mod parse;

use darling::FromDeriveInput;
use proc_macro::{self, TokenStream};
use quote::quote;
use syn::parse_macro_input;

use parse::ModelInput;

/// Automatically implement `Model` for your struct.
///
/// The derive generates:
///  * the code returned by `ModelInput::impl_column_consts`,
///  * a `TryFrom<&pg_worm::Row>` impl that reads every field from its column,
///  * a `Model` impl providing `select`, `select_one` and `delete`,
///  * an inherent `insert` function taking one argument per insertable field.
///
/// ## Attributes
///  * `table` - for structs:
///      - `table_name: String`: optional. Overwrites the table name.
///  * `column` - for struct fields:
///      - `dtype: String`: optional. Overwrites the postgres datatype.
///      - `unique: bool`: optional, default: `false`. Enables the `unique` constraint.
///      - `auto: bool`: optional, default: `false`. Marks the column's values as
///        autogenerated, so they do not have to be supplied when inserting.
///      - `column_name: String`: optional. Overwrites the column name.
///
/// ## Errors
///
/// Invalid attributes and structs without any fields are reported as compile
/// errors at the derive site instead of panicking inside the macro.
#[proc_macro_derive(Model, attributes(table, column))]
pub fn derive(input: TokenStream) -> TokenStream {
    let derive_input = parse_macro_input!(input as syn::DeriveInput);

    // Report attribute problems as proper (spanned) compile errors rather than
    // an opaque "proc-macro derive panicked" message.
    let opts = match ModelInput::from_derive_input(&derive_input) {
        Ok(opts) => opts,
        Err(err) => return err.write_errors().into(),
    };

    // Fail fast, before doing any further work, with a regular compile error.
    if opts.n_fields() == 0 {
        return quote!(
            compile_error!("struct must have at least one field to become a `Model`");
        )
        .into();
    }

    let ident = opts.ident();
    let table_name = opts.table_name();

    // Column names and field idents of all fields, in declaration order. The two
    // vectors are zipped together by the `#(...),*` repetition in `try_from`.
    let fields = opts.fields().collect::<Vec<_>>();
    let column_names = fields.iter().map(|f| f.column_name()).collect::<Vec<_>>();
    let field_idents = fields.into_iter().map(|f| f.ident()).collect::<Vec<_>>();

    // Fields that have to be supplied to `insert`, plus the pieces of the
    // generated INSERT statement and function signature derived from them.
    let insert_fields = opts.insert_fields().collect::<Vec<_>>();

    // "$1, $2, ..., $n" placeholders, one per insertable field
    let insert_columns_counter = (1..=insert_fields.len())
        .map(|i| format!("${i}"))
        .collect::<Vec<_>>()
        .join(", ");

    let insert_columns = insert_fields
        .iter()
        .cloned()
        .map(|f| f.column_name())
        .collect::<Vec<_>>()
        .join(", ");

    let insert_columns_idents = insert_fields
        .iter()
        .cloned()
        .map(|f| f.ident())
        .collect::<Vec<_>>();

    // Last use of `insert_fields`, so it can be consumed instead of cloned.
    let insert_column_dtypes = insert_fields
        .into_iter()
        .map(|f| f.insert_arg_type())
        .collect::<Vec<_>>();

    let table_creation_sql = opts.table_creation_sql();

    let impl_column_consts = opts.impl_column_consts();

    // Generate the needed impl code
    let output = quote!(
        #impl_column_consts

        impl<'a> TryFrom<&'a pg_worm::Row> for #ident {
            type Error = pg_worm::pg::Error;

            fn try_from(value: &'a pg_worm::Row) -> Result<#ident, Self::Error> {
                // Parse each column into the corresponding field, propagating
                // conversion errors to the caller instead of panicking.
                Ok(#ident {
                    #(#field_idents: value.try_get(#column_names)?),*
                })
            }
        }

        #[pg_worm::async_trait]
        impl Model<#ident> for #ident {
            #[inline]
            fn _table_creation_sql() -> &'static str {
                #table_creation_sql
            }

            async fn select(filter: pg_worm::Filter) -> Vec<#ident> {
                // Retrieve client. Panic if not connected
                let client = pg_worm::_get_client()
                    .expect("not connected to db");

                // Convert args to correct datatype
                let args: Vec<&(dyn pg_worm::pg::types::ToSql + Sync)> = filter
                    ._args()
                    .into_iter()
                    .map(|i| &**i as _)
                    .collect();

                // Make the query
                let rows = client
                    .query(
                        format!(
                            "SELECT * FROM {} {}",
                            #table_name,
                            if filter._stmt().is_empty() { "".to_string() }
                            else { format!("WHERE {}", filter._stmt()) }
                        ).as_str(),
                        args.as_slice()
                    ).await.unwrap();

                // Parse each result to the rust type
                rows
                    .iter()
                    .map(|r|
                        #ident::try_from(r).expect("couldn't parse data")
                    ).collect()
            }

            async fn select_one(filter: pg_worm::Filter) -> Option<#ident> {
                // Retrieve client. Panic if not connected
                let client = pg_worm::_get_client()
                    .expect("not connected to db");

                // Convert args to correct datatype
                let args: Vec<&(dyn pg_worm::pg::types::ToSql + Sync)> = filter
                    ._args()
                    .into_iter()
                    .map(|i| &**i as _)
                    .collect();

                // Make the query
                let rows = client
                    .query(
                        // Fill in table name and filter
                        format!(
                            "SELECT * FROM {} {} LIMIT 1",
                            #table_name,
                            if filter._stmt().is_empty() { "".to_string() }
                            else { format!("WHERE {}", filter._stmt()) }
                        ).as_str(),
                        // Pass filter arguments
                        args.as_slice()
                    ).await.unwrap();

                // `LIMIT 1` yields at most one row: parse it if present,
                // otherwise return `None`.
                rows
                    .first()
                    .map(|row| #ident::try_from(row).expect("couldn't parse data"))
            }

            async fn delete(filter: pg_worm::Filter) -> u64 {
                // Retrieve client. Panic if not connected
                let client = pg_worm::_get_client()
                    .expect("not connected to db");

                // Convert args to correct datatype
                let args: Vec<&(dyn pg_worm::pg::types::ToSql + Sync)> = filter
                    ._args()
                    .into_iter()
                    .map(|i| &**i as _)
                    .collect();

                // Make the query
                let rows_affected = client
                    .execute(
                        // Fill in table name and filter
                        format!(
                            "DELETE FROM {} {}",
                            #table_name,
                            if filter._stmt().is_empty() { "".to_string() }
                            else { format!("WHERE {}", filter._stmt()) }
                        ).as_str(),
                        // Pass filter arguments
                        args.as_slice()
                    ).await.unwrap();

                // Return the number of rows affected
                rows_affected
            }
        }

        impl #ident {
            /// Insert a new entity into the database.
            ///
            /// For columns which are autogenerated (like in the example below, `id`),
            /// no value has to be specified.
            ///
            /// # Errors
            ///
            /// Returns an error if no database client is available or if
            /// executing the `INSERT` statement fails.
            ///
            /// # Example
            ///
            /// ```ignore
            /// use pg_worm::Model;
            ///
            /// #[derive(Model)]
            /// struct Book {
            ///     #[column(dtype = "BIGSERIAL")]
            ///     id: i64,
            ///     #[column(dtype = "TEXT")]
            ///     title: String
            /// }
            ///
            /// async fn some_func() -> Result<(), pg_worm::Error> {
            ///     Book::insert("Foo".to_string()).await?;
            ///     Ok(())
            /// }
            /// ```
            pub async fn insert(
                #(#insert_columns_idents: #insert_column_dtypes),*
            ) -> Result<(), pg_worm::Error> {
                // Prepare sql statement
                let stmt = format!(
                    "INSERT INTO {} ({}) VALUES ({})",
                    #table_name,
                    #insert_columns,
                    #insert_columns_counter
                );

                // Retrieve the client
                let client = pg_worm::_get_client()?;

                // Execute the query
                client.execute(
                    stmt.as_str(),
                    &[
                        #(&#insert_columns_idents), *
                    ]
                ).await?;

                // Everything's fine
                Ok(())
            }
        }
    );

    output.into()
}