dysql_macro/
lib.rs

1#![feature(proc_macro_span)]
2
//! Dysql 是一个轻量级的、在编译时生成 SQL 模板的库，它在运行时根据传入的 DTO 自动生成动态 SQL 并设置参数；
//! 在底层，Dysql 使用 sqlx、tokio-postgres、rbatis 等框架执行最终的 SQL。
5
6mod sql_fragment;
7mod sql_expand;
8
9use proc_macro::TokenStream;
10use sql_expand::SqlExpand;
11use sql_fragment::{STATIC_SQL_FRAGMENT_MAP, SqlFragment};
12use syn::{parse_macro_input, Token};
13use std::{collections::HashMap, sync::RwLock, path::PathBuf};
14use quote::quote;
15
16use sql_fragment::get_sql_fragment;
17
18/// 用于解析 dysql 所有过程宏的语句
19#[allow(dead_code)]
20#[derive(Debug)]
21pub(crate) struct DyClosure {
22    executor_info: ExecutorInfo,
23    dto_info: DtoInfo,
24    sql_name: Option<String>,
25    ret_type: Option<syn::Path>, // return type
26    body: String,
27    source_file: PathBuf,
28}
29
/// How a closure parameter (executor or DTO) is referenced in the macro input.
///
/// `Copy`/`PartialEq` are derived so a kind can be duplicated and compared directly
/// instead of being re-matched at every use site.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RefKind {
    /// Written as `&x`.
    Immutable,
    /// Written as `&mut x`.
    Mutable,
    /// Written as plain `x` (no reference prefix).
    None,
}
36
37
38#[derive(Debug)]
39struct DtoInfo {
40    src: Option<syn::Ident>,
41    ref_kind: RefKind,
42}
43
44impl DtoInfo {
45    pub fn new(src: Option<syn::Ident>, ref_kind: RefKind) -> Self {
46        Self {
47            src,
48            ref_kind,
49        }
50    }
51
52    pub fn gen_token(&self) -> proc_macro2::TokenStream {
53        if let Some(_) = self.src {
54            let mut rst = match self.ref_kind {
55                RefKind::Immutable => quote!(&),
56                RefKind::Mutable => quote!(&mut),
57                RefKind::None => quote!(),
58            };
59    
60            let dto = &self.src;
61            rst.extend(quote!(#dto));
62
63            rst.into()
64        } else {
65            quote!()
66        }
67    }
68}
69
70
71#[derive(Debug)]
72struct ExecutorInfo {
73    src: syn::Ident,
74    ref_kind: RefKind,
75    is_deref: bool,
76}
77
78impl ExecutorInfo {
79    pub fn new(src: syn::Ident, ref_kind: RefKind, is_deref: bool) -> Self {
80        Self {
81            src,
82            ref_kind,
83            is_deref,
84        }
85    }
86
87    pub fn gen_token(&self) -> proc_macro2::TokenStream {
88        let mut rst = match self.ref_kind {
89            RefKind::Immutable => quote!(&),
90            RefKind::Mutable => quote!(&mut),
91            RefKind::None => quote!(),
92        };
93
94        if self.is_deref {
95            rst.extend(quote!(*))
96        }
97
98        let executor = &self.src;
99        rst.extend(quote!(#executor));
100
101        quote!((#rst))
102    }
103}
104
105impl syn::parse::Parse for DyClosure {
106    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
107        // 加载 .env 文件中的环境变量,读取自动持久化 sql 文件的参数
108        dotenv::dotenv().ok();
109
110        // 测试是否 | 开始
111        input.parse::<syn::Token!(|)>()?;
112
113        // 解析 executor 的引用(可能为 &mut, &)
114        let executor_ref_kind: RefKind;
115        match input.parse::<syn::Token!(&)>() {
116            Err(_) => executor_ref_kind = RefKind::None,
117            Ok(_) => match input.parse::<syn::Token!(mut)>() {
118                Err(_) => executor_ref_kind = RefKind::Immutable,
119                Ok(_) => executor_ref_kind = RefKind::Mutable,
120            }
121        }
122
123        // 解析 executor (可能为 *executor, executor)
124        let is_executor_deref: bool;
125        let executor: syn::Ident;
126        match input.parse::<syn::Token!(*)>() {
127            Ok(_) => is_executor_deref = true,
128            Err(_) => is_executor_deref = false,
129        }
130        match input.parse::<syn::Ident>() {
131            Err(e) => return Err(e),
132            Ok(ex) => executor = ex,
133        }
134
135        // 测试是否 | 结束, 并解析 ',dto '(dto 可能为 _, &dto, &mut dto)
136        let sql_name: Option<String>;
137        let dto: Option<syn::Ident>;
138        let dto_ref_kind: RefKind;
139        match input.parse::<syn::Token!(|)>() {
140            Ok(_) => {
141                sql_name = None;
142                dto = None;
143                dto_ref_kind = RefKind::None;
144            },
145            Err(_) => match input.parse::<syn::Token!(,)>() {
146                Err(e) =>  return Err(e),
147                Ok(_) => {
148                    // 解析 dto
149                    match input.parse::<syn::Token!(_)>() {
150                        Ok(_) => {
151                            dto = None;
152                            dto_ref_kind = RefKind::None;
153                        },
154                        Err(_) => {
155                            match input.parse::<syn::Token!(&)>() {
156                                Ok(_) => match input.parse::<syn::Token!(mut)>(){
157                                    Ok(_) => dto_ref_kind = RefKind::Mutable,
158                                    Err(_) => dto_ref_kind = RefKind::Immutable,
159                                },
160                                Err(_) =>  dto_ref_kind = RefKind::None,
161                            }
162                            match input.parse::<syn::Ident>() {
163                                Err(e) => return Err(e),
164                                Ok(d) => dto = Some(d),
165                            }
166                        }
167                    }
168
169                    // 测试是否 | 结束,并解析 , 'sql_name |'
170                    match input.parse::<syn::Token!(|)>() {
171                        // | 结束
172                        Ok(_) => sql_name = None,
173                        // 解析是否为接下来是否为 " , sql_name |"
174                        Err(_) => match input.parse::<syn::Token!(,)>() {
175                            Err(e) => return Err(e),
176                            // 解析 sql_name
177                            Ok(_) => {
178                                match input.parse::<syn::Token!(_)>() {
179                                    Ok(_) => { sql_name = None },
180                                    Err(_) => match input.parse::<syn::LitStr>() {
181                                        Ok(s) => sql_name = Some(s.value()),
182                                        Err(_) => return Err(syn::Error::new(proc_macro2::Span::call_site(), "need specify the sql_name")),
183                                    }
184                                }
185                                // | 结束
186                                input.parse::<syn::Token!(|)>()?; 
187                            }
188                        }
189                    }
190                }
191            }
192        }
193
194        // 解析 -> 符号
195        let ret_type:Option<syn::Path>;
196        match input.parse::<syn::Token!(->)>() {
197            // 解析 ret_type
198            Ok(_) => match input.parse::<syn::Path>() {
199                Ok(p) => ret_type = Some(p),
200                Err(_) => 
201                    return Err(syn::Error::new(proc_macro2::Span::call_site(), "Need specify the return type")),
202            }
203            Err(_) => ret_type = None,
204        };
205
206        // 解析 { sql body } 
207        let body = parse_body(input)?;
208        let body: Vec<String> = body.split('\n').into_iter().map(|f| f.trim().to_owned()).collect();
209        let body = body.join(" ").to_owned();
210
211        // 获取当前被解析的文件位置
212        let span: proc_macro::Span = input.span().unwrap();
213        let source_file = span.source_file().path();
214
215        let executor_info = ExecutorInfo::new(executor, executor_ref_kind, is_executor_deref);
216        let dto_info = DtoInfo::new(dto, dto_ref_kind);
217
218        let dsf = DyClosure { executor_info, dto_info, sql_name, ret_type, body, source_file };
219        // eprintln!("{:#?}", dsf);
220
221        Ok(dsf)
222    }
223}
224
225/// 解析 sql body
226fn parse_body(input: &syn::parse::ParseBuffer) -> Result<String, syn::Error> {
227    let body_buf;
228    // 解析大括号
229    syn::braced!(body_buf in input);
230    
231    let ts = body_buf.cursor().token_stream().into_iter();
232    let mut sql = String::new();
233    for it in ts {
234        match it {
235            proc_macro2::TokenTree::Group(_) => {
236                return Err(syn::Error::new(input.span(), "error not support group in sql".to_owned()));
237            },
238            proc_macro2::TokenTree::Ident(_) => {
239                let v: syn::Ident = body_buf.parse()?;
240                let sql_fragment = get_sql_fragment(&v.to_string());
241                
242                if let Some(s) = sql_fragment {
243                    sql.push_str(&s);
244                } else {
245                    return Err(syn::Error::new(input.span(), "error not found sql identity".to_owned()));
246                }
247            },
248            proc_macro2::TokenTree::Punct(v) => {
249                if v.to_string() == "+" {
250                    body_buf.parse::<Token!(+)>()?;
251                } else {
252                    return Err(syn::Error::new(input.span(), "error only support '+' expr".to_owned()));
253                }
254            },
255            proc_macro2::TokenTree::Literal(_) => {
256                let rst: syn::LitStr = body_buf.parse()?;
257                
258                sql.push_str(&rst.value());
259            },
260        };
261    }
262
263    Ok(sql)
264}
265
266// /// 根据 s 生成 syn::Path 对象,用于 dysql 中有返回值的过程宏
267// pub(crate) fn gen_type_path(s: &str) -> syn::Path {
268//     let seg = syn::PathSegment {
269//         ident: syn::Ident::new(s, proc_macro2::Span::call_site()),
270//         arguments: syn::PathArguments::None,
271//     };
272//     let mut punct: Punctuated<syn::PathSegment, syn::Token![::]> = Punctuated::new();
273//     punct.push_value(seg);
274//     let path = syn::Path{ leading_colon: None, segments: punct };
275
276//     path
277// }
278
279/// fetch all datas that filtered by dto
280/// 
281/// # Examples
282///
283/// Basic usage:
284/// 
285/// ```ignore
286/// let mut conn = connect_db().await;
287/// 
288/// let dto = UserDto {id: None, name: None, age: 13};
289/// let rst = fetch_all!(|&conn, dto| -> User {
290///     r#"select * from test_user 
291///     where 1 = 1
292///         {{#name}}and name = :name{{/name}}
293///         {{#age}}and age > :age{{/age}}
294///     order by id"#
295/// }).unwrap();
296/// 
297/// assert_eq!(
298///     vec![
299///         User { id: 2, name: Some("zhanglan".to_owned()), age: Some(21) }, 
300///         User { id: 3, name: Some("zhangsan".to_owned()), age: Some(35) },
301///     ], 
302///     rst
303/// );
304/// ```
305#[proc_macro]
306pub fn fetch_all(input: TokenStream) -> TokenStream {
307    // 将 input 解析成 SqlClosure
308    let st = syn::parse_macro_input!(input as DyClosure);
309
310    // 必须要指定单个 item 的返回值类型
311    if st.ret_type.is_none() { panic!("return type can't be null.") }
312
313    match SqlExpand.fetch_all(&st) {
314        Ok(ret) => ret.into(),
315        Err(e) => e.into_compile_error().into(),
316    }
317}
318
319///
320/// fetch one data that filtered by dto
321/// 
322/// # Examples
323///
324/// Basic usage:
325/// 
326/// ```ignore
327/// let conn = connect_db().await;
328/// 
329/// let dto = UserDto {id: 2, name: None, age: None};
330/// let rst = fetch_one!(|&conn, dto| -> User {
331///     r#"select * from test_user 
332///     where id = :id
333///     order by id"#
334/// }).unwrap();
335/// 
336/// assert_eq!(User { id: 2, name: Some("zhanglan".to_owned()), age: Some(21) }, rst);
337/// ```
338#[proc_macro]
339pub fn fetch_one(input: TokenStream) -> TokenStream {
340    // 将 input 解析成 SqlClosure
341    let st = syn::parse_macro_input!(input as DyClosure);
342
343    // 必须要指定单个 item 的返回值类型
344    if st.ret_type.is_none() { panic!("return type can't be null.") }
345
346    match SqlExpand.fetch_one(&st) {
347        Ok(ret) => ret.into(),
348        Err(e) => e.into_compile_error().into(),
349    }
350}
351
352///
353/// Fetch a scalar value from query
354/// 
355/// # Examples
356///
357/// Basic usage:
358/// 
359/// ```ignore
360/// let conn = connect_db().await;
361/// 
362/// let rst = fetch_scalar!(|&conn| -> i64 {
363///     r#"select count (*) from test_user"#
364/// }).unwrap();
365/// assert_eq!(3, rst);
366/// ```
367#[proc_macro]
368pub fn fetch_scalar(input: TokenStream) -> TokenStream {
369    // 将 input 解析成 SqlClosure
370    let st = syn::parse_macro_input!(input as DyClosure);
371    if st.ret_type.is_none() { panic!("return type can't be null.") }
372
373    match SqlExpand.fetch_scalar(&st) {
374        Ok(ret) => ret.into(),
375        Err(e) => e.into_compile_error().into(),
376    }
377}
378
379///
380/// Execute query
381/// 
382/// # Examples
383///
384/// Basic usage:
385/// 
386/// ```ignore
387/// let mut tran = get_transaction().await.unwrap();
388/// 
389/// let dto = UserDto::new(Some(2), None, None);
390/// let rst = execute!(|&mut *tran, dto| {
391///     r#"delete from test_user where id = :id"#
392/// }).unwrap();
393/// assert_eq!(1, rst);
394/// 
395/// tran.rollback().await?;
396/// ```
397#[proc_macro]
398pub fn execute(input: TokenStream) -> TokenStream {
399    // 将 input 解析成 SqlClosure
400    let st = syn::parse_macro_input!(input as DyClosure);
401
402    match SqlExpand.execute(&st) {
403        Ok(ret) => ret.into(),
404        Err(e) => e.into_compile_error().into(),
405    }
406}
407
408///
409/// Insert data
410/// **Note:** if you use this macro under **postgres** database, you should add "returning id" at the end of sql statement by yourself.
411/// 
412/// # Examples
413///
414/// Basic usage:
415/// 
416/// ```ignore
417/// let mut tran = get_transaction().await.unwrap();
418
419/// let dto = UserDto{ id: Some(4), name: Some("lisi".to_owned()), age: Some(50) };
420/// let last_insert_id = insert!(|&mut *tran, dto| -> (_, mysql) {
421///     r#"insert into test_user (id, name, age) values (4, 'aa', 1)"#  // works for mysql and sqlite
422///     // r#"insert into test_user (id, name, age) values (4, 'aa', 1) returning id"#  // works for postgres
423/// }).unwrap();
424/// assert_eq!(4, last_insert_id);
425/// 
426/// tran.rollback().await?;
427/// ```
428#[proc_macro]
429pub fn insert(input: TokenStream) -> TokenStream {
430    // 将 input 解析成 SqlClosure
431    let st = syn::parse_macro_input!(input as DyClosure);
432    if st.ret_type.is_none() { panic!("return type can't be null.") }
433
434    match SqlExpand.insert(&st) {
435        Ok(ret) => ret.into(),
436        Err(e) => e.into_compile_error().into(),
437    }
438}
439
440///
441/// Define a global sql fragment
442/// 
443/// # Examples
444///
445/// Basic usage:
446/// 
447/// ```ignore
448/// sql!("select_sql", "select * from table1 ")
449/// 
450/// let last_insert_id = fetch_all!(|&conn, &dto| {
451///     select_sql + "where age > 10 "
452/// }).unwrap();
453/// 
454/// tran.rollback().await?;
455/// ```
456#[proc_macro]
457pub fn sql(input: TokenStream) -> TokenStream {
458    let st = parse_macro_input!(input as SqlFragment);
459    let cache = STATIC_SQL_FRAGMENT_MAP.get_or_init(|| {
460        RwLock::new(HashMap::new())
461    });
462
463    cache.write().unwrap().insert(st.name, st.value.to_string());
464
465    quote!().into()
466}
467
468/// page query
469/// 
470/// # Examples
471///
472/// Basic usage:
473/// 
474/// ```ignore
475/// let conn = connect_db().await;
476/// let dto = UserDto::new(None, None, Some(13));
477/// let mut pg_dto = PageDto::new(3, 10, &dto);
478/// 
479/// let rst = page!(|&conn, pg_dto| -> User {
480///     "select * from test_user 
481///     where 1 = 1
482///         {{#data}}
483///             {{#name}}and name = :data.name{{/name}}
484///             {{#age}}and age > :data.age{{/age}}
485///         {{/data}}
486///     order by id"
487/// }).unwrap();
488/// 
489/// assert_eq!(7, rst.total);
490/// ```
491#[proc_macro]
492pub fn page(input: TokenStream) -> TokenStream {
493    // 将 input 解析成 SqlClosure
494    let st = syn::parse_macro_input!(input as DyClosure);
495    if st.ret_type.is_none() { panic!("return type can't be null.") }
496
497    match SqlExpand.page(&st) {
498        Ok(ret) => ret.into(),
499        Err(e) => e.into_compile_error().into(),
500    }
501}