#[cfg(not(feature = "tokio-postgres"))]
mod dy_sqlx;
#[cfg(not(feature = "tokio-postgres"))]
use dy_sqlx::expand;
#[cfg(feature = "tokio-postgres")]
mod dy_tokio_postgres;
#[cfg(feature = "tokio-postgres")]
use dy_tokio_postgres::expand;
mod sql_expand;
use proc_macro::TokenStream;
use syn::{punctuated::Punctuated, parse_macro_input, Token};
use std::{collections::HashMap, sync::RwLock};
use quote::quote;
use dysql::{QueryType, get_dysql_config, STATIC_SQL_FRAGMENT_MAP, get_sql_fragment};
/// A named SQL fragment registered through the `sql!` macro,
/// e.g. `sql!("name", "SELECT ...")`. Fragments are later referenced
/// by identifier inside SQL closure bodies (see `parse_body`).
#[derive(Debug)]
struct SqlFragment {
// key under which the fragment is cached in STATIC_SQL_FRAGMENT_MAP
name: String,
// the SQL text of the fragment
value: String,
}
impl syn::parse::Parse for SqlFragment {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let name= input.parse::<syn::LitStr>()?.value();
input.parse::<syn::Token!(,)>()?;
let value= input.parse::<syn::LitStr>()?.value();
Ok(Self { name, value })
}
}
#[allow(dead_code)]
#[derive(Debug)]
/// Parsed form of a SQL closure macro invocation:
/// `|dto, cot [, "sql_name"]| [-> ret_type | -> (ret_type, dialect)] { body }`
struct SqlClosure {
// data-transfer object binding; `None` when written as `_`
dto: Option<syn::Ident>,
// true when the dto was written `&ident` (shared reference)
is_dto_ref: bool,
// true when the dto was written `&mut ident`; mutually exclusive with `is_dto_ref`
is_dto_ref_mut: bool,
// synthesized `_tmp_pg_dto` identifier (always Some; used by expansion)
tmp_pg_dto: Option<syn::Ident>,
// connection-or-transaction binding and its `&` / `&mut` flags
cot: syn::Ident, is_cot_ref: bool,
is_cot_ref_mut: bool,
// optional string literal after the cot, e.g. `|dto, conn, "my_sql"|`
sql_name: Option<String>,
// optional return type (None when written `_`) and the SQL dialect identifier
ret_type: Option<syn::Path>, dialect: syn::Ident,
// the normalized single-line SQL template text
body: String,
}
impl syn::parse::Parse for SqlClosure {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut is_dto_ref = false;
let mut is_dto_ref_mut = false;
input.parse::<syn::Token!(|)>()?;
let dto = match input.parse::<syn::Ident>() {
Ok(i) => Some(i),
Err(e) => match input.parse::<syn::Token!(_)>() {
Ok(_) => None,
Err(_) => {
match input.parse::<syn::Token!(&)>() {
Ok(_) => {
is_dto_ref = true;
match input.parse::<syn::Token!(mut)>() {
Ok(_) => {
is_dto_ref = false;
is_dto_ref_mut = true;
Some(input.parse::<syn::Ident>()?)
},
Err(_) => Some(input.parse::<syn::Ident>()?),
}
},
Err(_) => return Err(e),
}
},
},
};
let mut is_cot_ref = false;
let mut is_cot_ref_mut = false;
input.parse::<syn::Token!(,)>()?;
let cot: syn::Ident = match input.parse::<syn::Token!(&)>() {
Ok(_) => {
is_cot_ref = true;
match input.parse::<syn::Token!(mut)>() {
Ok(_) => {
is_cot_ref = false;
is_cot_ref_mut = true;
input.parse()?
},
Err(_) => input.parse()?,
}
},
Err(_) => {
input.parse()?
},
};
let mut sql_name = None;
match input.parse::<syn::Token!(|)>() {
Ok(_) => (),
Err(_) => {
input.parse::<syn::Token!(,)>()?;
sql_name = match input.parse::<syn::LitStr>() {
Ok(s) => Some(s.value()),
Err(_) => None,
};
input.parse::<syn::Token!(|)>()?;
},
}
let dialect: syn::Ident;
let ret_type:Option<syn::Path>;
match input.parse::<syn::Token!(->)>() {
Ok(_) => {
match input.parse::<syn::Path>() {
Ok(p) => {
ret_type = Some(p);
dialect = get_default_dialect(&input.span());
},
Err(_) => match parse_return_tuple(input) {
Ok(tp) => {
match tp.parse::<syn::Path>() {
Ok(p) => {
ret_type = Some(p);
tp.parse::<syn::Token!(,)>()?;
if let Err(_) = tp.parse::<syn::Token!(_)>() {
match tp.parse::<syn::Ident>() {
Ok(i) => dialect = i,
Err(_) => return Err(syn::Error::new(proc_macro2::Span::call_site(), "Need specify the dialect")),
}
} else {
dialect = get_default_dialect(&input.span());
}
},
Err(_) => match tp.parse::<syn::Token!(_)>() {
Ok(_) => {
ret_type = None;
tp.parse::<syn::Token!(,)>()?;
if let Err(_) = tp.parse::<syn::Token!(_)>() {
match tp.parse::<syn::Ident>() {
Ok(i) => dialect = i,
Err(_) => return Err(syn::Error::new(proc_macro2::Span::call_site(), "Need specify the dialect")),
}
} else {
dialect = get_default_dialect(&input.span());
}
},
Err(_) => return Err(syn::Error::new(proc_macro2::Span::call_site(), "Need specify the return type")),
},
}
},
Err(_) => return Err(syn::Error::new(proc_macro2::Span::call_site(), "Need specify the return type and dialect")),
},
}
},
Err(_) => {
ret_type = None;
dialect = get_default_dialect(&input.span());
},
};
let body = parse_body(input)?;
let body: Vec<String> = body.split('\n').into_iter().map(|f| f.trim().to_owned()).collect();
let body = body.join(" ").to_owned();
let tmp_pg_dto = Some(syn::Ident::new("_tmp_pg_dto", proc_macro2::Span::call_site()));
let sc = SqlClosure { dto, is_dto_ref, is_dto_ref_mut, tmp_pg_dto, cot, is_cot_ref, is_cot_ref_mut, sql_name, ret_type, dialect, body };
Ok(sc)
}
}
fn parse_body(input: &syn::parse::ParseBuffer) -> Result<String, syn::Error> {
let body_buf;
syn::braced!(body_buf in input);
let ts = body_buf.cursor().token_stream().into_iter();
let mut sql = String::new();
for it in ts {
match it {
proc_macro2::TokenTree::Group(_) => {
return Err(syn::Error::new(input.span(), "error not support group in sql".to_owned()));
},
proc_macro2::TokenTree::Ident(_) => {
let v: syn::Ident = body_buf.parse()?;
let sql_fragment = get_sql_fragment(&v.to_string());
if let Some(s) = sql_fragment {
sql.push_str(&s);
} else {
return Err(syn::Error::new(input.span(), "error not found sql identity".to_owned()));
}
},
proc_macro2::TokenTree::Punct(v) => {
if v.to_string() == "+" {
body_buf.parse::<Token!(+)>()?;
} else {
return Err(syn::Error::new(input.span(), "error only support '+' expr".to_owned()));
}
},
proc_macro2::TokenTree::Literal(_) => {
let rst: syn::LitStr = body_buf.parse()?;
sql.push_str(&rst.value());
},
};
}
Ok(sql)
}
/// Builds a bare one-segment `syn::Path` (no leading `::`, no generic
/// arguments) from the given identifier string, spanned at the call site.
pub(crate) fn gen_path(s: &str) -> syn::Path {
    let mut segments: Punctuated<syn::PathSegment, syn::Token![::]> = Punctuated::new();
    segments.push(syn::PathSegment {
        ident: syn::Ident::new(s, proc_macro2::Span::call_site()),
        arguments: syn::PathArguments::None,
    });
    syn::Path { leading_colon: None, segments }
}
/// Consumes a parenthesized group `( ... )` from the input and returns a
/// parse buffer positioned inside the parentheses (used for the
/// `-> (ret_type, dialect)` return clause).
fn parse_return_tuple(input: syn::parse::ParseStream) -> syn::Result<syn::parse::ParseBuffer> {
    let inner;
    syn::parenthesized!(inner in input);
    Ok(inner)
}
/// Resolves the dialect configured in `dysql` into an identifier carrying
/// the given span; used when a SQL closure does not name a dialect itself.
fn get_default_dialect(span: &proc_macro2::Span) -> syn::Ident {
// `Span` is `Copy`, so dereference instead of calling `clone()`.
match get_dysql_config().dialect {
dysql::SqlDialect::postgres => syn::Ident::new(&dysql::SqlDialect::postgres.to_string(), *span),
dysql::SqlDialect::mysql => syn::Ident::new(&dysql::SqlDialect::mysql.to_string(), *span),
dysql::SqlDialect::sqlite => syn::Ident::new(&dysql::SqlDialect::sqlite.to_string(), *span),
}
}
#[proc_macro]
pub fn fetch_all(input: TokenStream) -> TokenStream {
let st = syn::parse_macro_input!(input as SqlClosure);
if st.ret_type.is_none() { panic!("ret_type can't be null.") }
match expand(&st, QueryType::FetchAll) {
Ok(ret) => ret.into(),
Err(e) => e.into_compile_error().into(),
}
}
#[proc_macro]
pub fn fetch_one(input: TokenStream) -> TokenStream {
let st = syn::parse_macro_input!(input as SqlClosure);
if st.ret_type.is_none() { panic!("ret_type can't be null.") }
match expand(&st, QueryType::FetchOne) {
Ok(ret) => ret.into(),
Err(e) => e.into_compile_error().into(),
}
}
#[proc_macro]
pub fn fetch_scalar(input: TokenStream) -> TokenStream {
let st = syn::parse_macro_input!(input as SqlClosure);
if st.ret_type.is_none() { panic!("ret_type can't be null.") }
match expand(&st, QueryType::FetchScalar) {
Ok(ret) => ret.into(),
Err(e) => e.into_compile_error().into(),
}
}
/// `execute!` macro entry point: expands a SQL closure into code that runs
/// a statement for its side effects (no return type required).
#[proc_macro]
pub fn execute(input: TokenStream) -> TokenStream {
    let closure = syn::parse_macro_input!(input as SqlClosure);
    expand(&closure, QueryType::Execute)
        .map_or_else(|err| err.into_compile_error().into(), Into::into)
}
/// `insert!` macro entry point: expands a SQL closure into code that runs
/// an insert statement (no return type required).
#[proc_macro]
pub fn insert(input: TokenStream) -> TokenStream {
    let closure = syn::parse_macro_input!(input as SqlClosure);
    match expand(&closure, QueryType::Insert) {
        Err(err) => err.into_compile_error().into(),
        Ok(tokens) => tokens.into(),
    }
}
/// `sql!` macro entry point: registers a named SQL fragment in the
/// process-wide fragment cache so it can later be referenced by identifier
/// inside SQL closure bodies.
///
/// Expands to nothing; its only effect is the cache insert performed at
/// macro-expansion time.
#[proc_macro]
pub fn sql(input: TokenStream) -> TokenStream {
let st = parse_macro_input!(input as SqlFragment);
let cache = STATIC_SQL_FRAGMENT_MAP.get_or_init(|| {
RwLock::new(HashMap::new())
});
// `st.value` is already an owned String; the former `.to_string()` call
// performed a needless extra allocation.
cache.write().unwrap().insert(st.name, st.value);
quote!().into()
}
#[proc_macro]
pub fn page(input: TokenStream) -> TokenStream {
let st = syn::parse_macro_input!(input as SqlClosure);
if st.ret_type.is_none() { panic!("ret_type can't be null.") }
match expand(&st, QueryType::Page) {
Ok(ret) => ret.into(),
Err(e) => e.into_compile_error().into(),
}
}