easy_sqlx_macro/
lib.rs

1use condition::create_conditions;
2use delete::{create_delete, create_delete_builder, create_delete_by_id};
3use field::create_field_wrapper;
4use heck::ToSnakeCase;
5use insert::{create_insert, create_insert_builder};
6use order::create_order_func;
7use proc_macro2::Span;
8use quote::quote;
9use select::{create_select_builder, create_select_by_id};
10use syn::{parse_macro_input, DeriveInput};
11
12mod attrs;
13mod condition;
14mod delete;
15mod field;
16mod insert;
17mod order;
18mod select;
19mod update;
20
21use attrs::{column::parse_column_attrs, table::parse_table_attrs};
22use update::{create_update, create_update_builder, create_update_by_id};
23
24/// 使用示例
25/// 定义表结构
26/// ```rust,ignore
27///     #[derive(Table)]
28///     #[table(
29///         indexes [
30///             (name = "123", columns("a", "b"))
31///         ]
32///     )]
33///     #[index(columns("ooi"))]
34///     struct Table1 {
35///         // #[col(column = "key", ignore, col_type = "abc", )]
36///         #[col(column = "key", comment = "123")]
37///         #[col(pk, autoincr, len = 100)]
38///         pub id: String,
39///         #[col(comment = "姓名", len = 20)]
40///         pub name: Option<String>,
41///         #[col(ignore)]
42///         pub t_o: chrono::NaiveTime,
43///         pub blob: Vec<u8>,
44///     }
45/// ```
46/// 同步表结构
47/// 参数 connection 为数据库连接
48/// ```rust,ignore
49///  sync_tables(connection, vec![Table1::table()], None).await?;
50/// ```
51#[proc_macro_derive(Table, attributes(table, index, col))]
52pub fn derive_table(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
53    let input: DeriveInput = parse_macro_input!(input);
54    // ident 当前枚举名称
55    let DeriveInput {
56        attrs, ident, data, ..
57    } = input;
58
59    // 列名称函数
60    let mut col_name_methods: Vec<proc_macro2::TokenStream> = Vec::new();
61    // let mut col_names: Vec<String> = Vec::new();
62    let mut cols = Vec::new();
63
64    // 列属性函数
65    let mut col_wrapper_methods: Vec<proc_macro2::TokenStream> = Vec::new();
66    // 条件属性函数
67    let mut col_conditions: Vec<proc_macro2::TokenStream> = Vec::new();
68    // 排序生成
69    let mut col_order_methods: Vec<proc_macro2::TokenStream> = Vec::new();
70
71    let mut struct_fields: Vec<syn::Field> = vec![];
72
73    let default_table_name = ident.clone().to_string().to_snake_case();
74    // 解析表属性及索引
75    let mut table = parse_table_attrs(&attrs, default_table_name)
76        .map_err(|err| panic!("{}", err))
77        .unwrap();
78
79    if let syn::Data::Struct(syn::DataStruct {
80        struct_token: _,
81        fields,
82        semi_token: _,
83    }) = data
84    {
85        for field in fields {
86            match parse_column_attrs(&field) {
87                Ok((col, rust_type, syn_type, is_vec)) => {
88                    if let Some(column) = col {
89                        if let Some(rust_type) = rust_type {
90                            if let Some(syn_type) = syn_type {
91                                let field_name = &column.name;
92
93                                if column.pk && column.nullable {
94                                    panic!("pk field must not nullable, consider remove type Option<> or remove pk props of field: {field_name}");
95                                }
96
97                                // 生成列方法名称
98                                let fn_name = syn::Ident::new(
99                                    format!("col_{}", &field_name).as_str(),
100                                    Span::call_site(),
101                                );
102
103                                // let col_name = column.get_query_column_name();
104                                // col_names.push(col_name.clone());
105
106                                // 添加列方法
107                                col_name_methods.push(quote! {
108                                    /// #col_name 列名称
109                                    pub fn #fn_name() -> easy_sqlx_core::sql::schema::column::Column {
110                                        #column
111                                    }
112                                });
113
114                                // 生成列函数
115                                if !table.query_only {
116                                    // 非 query_only 的表生成列 wrapper 函数
117                                    let wrappers = create_field_wrapper(
118                                        &column, &field, syn_type, &rust_type, is_vec,
119                                    );
120                                    col_wrapper_methods.extend(wrappers);
121                                }
122
123                                // 生成条件属性函数
124                                let conds = create_conditions(
125                                    &column, &field, syn_type, &rust_type, is_vec,
126                                );
127                                col_conditions.extend(conds);
128
129                                // 生成排序函数
130                                col_order_methods.push(create_order_func(&column));
131
132                                // 储存字段
133                                struct_fields.push(field);
134
135                                // let self_dot_name = syn::Ident::new(
136                                //     format!("$self.{}", &column.name).as_str(),
137                                //     Span::call_site(),
138                                // );
139
140                                // 添加列
141                                cols.push(column);
142                            }
143                        }
144                    }
145                }
146                Err(err) => {
147                    panic!("{}", err);
148                }
149            }
150        }
151    }
152
153    // if let Err(err) = check_col_in_table_attrs(&attrs) {
154    //     panic!("{}", err);
155    // }
156
157    table.columns = cols.clone();
158
159    if let Err(err) = table.check_indexes_columns() {
160        // 有错误
161        panic!("{}", err);
162    }
163
164    let table_name = table.name_with_schema();
165
166    let (insert, build_insert) = if table.query_only {
167        (quote! {}, quote! {})
168    } else {
169        (create_insert(&table), create_insert_builder())
170    };
171
172    let (update, build_update, update_by_id) = if table.query_only {
173        (quote! {}, quote! {}, quote! {})
174    } else {
175        (
176            create_update(&table, &ident),
177            create_update_builder(),
178            create_update_by_id(&table, &ident, &struct_fields),
179        )
180    };
181
182    let (delete, delete_by_id, build_delete) = if table.query_only {
183        (quote! {}, quote! {}, quote! {})
184    } else {
185        (
186            create_delete(&table, &ident),
187            create_delete_by_id(&table, &ident, &struct_fields),
188            create_delete_builder(),
189        )
190    };
191
192    let build_select = create_select_builder();
193    let select_by_id = create_select_by_id(&table, &ident, &struct_fields);
194
195    // 实现 comment 方法
196    let output = quote! {
197        impl #ident {
198            /// 获取数据库表名称
199            pub fn table_name() -> &'static str {
200                #table_name
201            }
202
203            /// 获取表结构定义
204            pub fn table() -> easy_sqlx_core::sql::schema::table::TableSchema {
205                #table
206            }
207
208            /// 列名称函数
209            #(#col_name_methods) *
210            // /// 获取所有列名称
211            // pub fn all_cols() -> Vec<&'static str> {
212            //     [#(#col_names), *].to_vec()
213            // }
214
215            #insert
216            #build_insert
217
218            #update
219            #build_update
220            #update_by_id
221
222            #delete
223            #delete_by_id
224            #build_delete
225
226            #(#col_order_methods) *
227
228            #build_select
229            #select_by_id
230
231            #(#col_wrapper_methods) *
232
233            #(#col_conditions) *
234            fn columns() -> Vec<easy_sqlx_core::sql::schema::column::Column> {
235                [#(#cols), *].to_vec()
236            }
237        }
238    };
239    output.into()
240}