db_code_macro 0.1.5

Generate Db code
Documentation
use std::{fs, io::Write, ops::Add, path};

use quote::{format_ident, quote};

use super::{db_meta::TableMeta, kits::to_snake_name};
use crate::kits::T_VERSION;

pub fn generate_dao(tm: &TableMeta) {
    let de = tm.ident_name.as_ref().expect("the derive_input is None");
    let m_name = format_ident!("{}", &de);
    let dao_name = format_ident!("{}Dao", &de);
    let table_name = to_snake_name(&tm.type_name);
    let test_name = syn::Ident::new(&to_snake_name(&dao_name.to_string()), m_name.span());
    let db_name = to_snake_name(&dao_name.to_string()).add(".db");
    let no_id: Vec<proc_macro2::TokenStream> = {
        tm.col_names[1..]
            .iter()
            .map(|col| {
                let id = format_ident!("T_{}", col.to_uppercase());
                // let id = syn::Ident::new(&v, m_name.span());
                quote! { #dao_name::#id}
            })
            .collect()
    };
    let no_id_len = no_id.len();

    let const_col_names: Vec<proc_macro2::TokenStream> = {
        tm.col_names
            .iter()
            .map(|col| {
                let id = format_ident!("T_{}", col.to_uppercase());
                // let id = syn::Ident::new(&v, m_name.span());
                quote! {pub const #id : &'static str = #col}
            })
            .collect()
    };

    let add_bind = {
        let mut t = Vec::new();
        let id = syn::Ident::new(&tm.col_names[0], m_name.span());
        t.push(quote!(q.bind(&m.#id)));
        tm.col_names[1..].iter().for_each(|col| {
            // let id = syn::Ident::new(col, m_name.span());
            let id = format_ident!("{}", col);
            t.push(quote!(.bind(&m.#id)))
        });
        t
    };

    let update_bind = {
        let mut t = Vec::new();
        let id = syn::Ident::new(&tm.col_names[1], m_name.span());
        t.push(quote!(q.bind(&m.#id)));
        tm.col_names[2..].iter().for_each(|col| {
            // let id = syn::Ident::new(col, m_name.span());
            let id = format_ident!("{}", col);
            t.push(quote!(.bind(&m.#id)));
        });
        // let id = syn::Ident::new(&tm.col_names[0], m_name.span());
        let id = format_ident!("{}", &tm.col_names[0]);
        t.push(quote!(.bind(&m.#id)));
        t
    };
    let update_bind_ol = {
        let mut t = Vec::new();
        // let id = syn::Ident::new(&tm.col_names[1], m_name.span());
        let id = format_ident!("{}", &tm.col_names[1]);
        t.push(quote!(q.bind(&m.#id)));
        tm.col_names[2..].iter().for_each(|col| {
            if col != T_VERSION {
                // let id = syn::Ident::new(col, m_name.span());
                let id = format_ident!("{}", col);
                t.push(quote!(.bind(&m.#id)));
            }
        });
        // let id = syn::Ident::new(&tm.col_names[0], m_name.span());
        let id = format_ident!("{}", &tm.col_names[0]);
        t.push(quote!(.bind(&m.#id)));
        // let id = syn::Ident::new("version", m_name.span());
        let id = format_ident!("{}", T_VERSION);
        t.push(quote!(.bind(m.#id)));
        t
    };

    let gen_code = quote!(
    use std::sync::Arc;

    use sqlx::{Pool, Sqlite, SqlitePool};
    use sqlx::query::Query;
    use sqlx::sqlite::SqliteArguments;

    use db_code::dao::{Dao, KitsDb, Times};

    use crate::#m_name;

    #[derive(Debug)]
    pub struct #dao_name {
        pool: Arc<Pool<Sqlite>>,
    }

    impl #dao_name {
        pub const TT: &'static str = #table_name;
        #(#const_col_names;)*

        pub(super) const fn columns_no_id() -> [&'static str; #no_id_len] {
            [
                #(#no_id,)*
            ]
        }

        pub(super) fn columns() -> String {
            format!("{},{}", #dao_name::T_ID, #dao_name::columns_no_id().join(","))
        }

        fn sql_get() -> String {
            format!("SELECT {} FROM {} WHERE {} = ?", #dao_name::columns(), #dao_name::TT, #dao_name::T_ID)
        }

        fn sql_add() -> String {
            let vs = ["?"; #dao_name::columns_no_id().len() + 1];
            format!("insert into {}({}) values ({})",#dao_name::TT, #dao_name::columns(), vs.join(","))
        }

        fn sql_remove() -> String {
            format!("delete from {} where {} = ?", #dao_name::TT, #dao_name::T_ID)
        }

        fn sql_remove_all() -> String {
            format!("delete from {}", #dao_name::TT)
        }

        fn sql_update() -> String {
            let vs = #dao_name::columns_no_id().join(" = ?, ") + " = ?";
            format!("update {} set {} where {} = ?", #dao_name::TT, vs, #dao_name::T_ID)
        }

        fn sql_update_ol() -> String {
            let mut vs = #dao_name::columns_no_id().join(" = ?, ") + " = ?";
            vs = vs.replace("version = ?","version = version + 1");
            format!("update {} set {} where {} = ? and {} = ?", #dao_name::TT, vs, #dao_name::T_ID, #dao_name::T_VERSION)
        }

        fn sql_list() -> String {
            format!("select {} from {}", #dao_name::columns(), #dao_name::TT)
        }
        pub(super) fn _bind_add<'a>(m: &'a #m_name, q: Query<'a, Sqlite, SqliteArguments<'a>>) -> Query<'a, Sqlite, SqliteArguments<'a>> {
            #(#add_bind)*
        }
        pub(super) fn _bind_update<'a>(m: &'a #m_name, q: Query<'a, Sqlite, SqliteArguments<'a>>) -> Query<'a, Sqlite, SqliteArguments<'a>> {
            #(#update_bind)*
        }
        pub(super) fn _bind_update_ol<'a>(m: &'a #m_name, q: Query<'a, Sqlite, SqliteArguments<'a>>) -> Query<'a, Sqlite, SqliteArguments<'a>> {
            #(#update_bind_ol)*
        }
    }

    impl Dao<#m_name> for #dao_name  {
        fn pool(&self) -> &SqlitePool {
            self.pool.as_ref()
        }

        fn new(pool: Arc<SqlitePool>) -> Self {
            Self  {
                pool
            }
        }

        async fn add(&self, m: &mut #m_name) -> Result<u64, sqlx::Error> {
            if m.id.is_empty() {
                m.id = KitsDb::uuid();
            }
            if m.update_ts < 1 {
                m.update_ts = Times::ts_now();
            }
            let sql = Self::sql_add();
            let re = Self::_bind_add(m, sqlx::query(&sql)).execute(self.pool()).await?;
            Ok(re.rows_affected())
        }

        async fn remove(&self, id: &str) -> Result<u64, sqlx::Error> {
            let sql = Self::sql_remove();
            let re = sqlx::query(&sql).bind(id).execute(self.pool()).await?;
            Ok(re.rows_affected())
        }

        async fn remove_all(&self) -> Result<u64, sqlx::Error> {
            let sql = Self::sql_remove_all();
            let re = sqlx::query(&sql).execute(self.pool()).await?;
            Ok(re.rows_affected())
        }

        async fn update(&self, m: &mut #m_name) -> Result<u64, sqlx::Error> {
            let sql = Self::sql_update();
            m.update_ts = Times::ts_now();
            let re = Self::_bind_update(m, sqlx::query(&sql)).execute(self.pool()).await?;
            Ok(re.rows_affected())
        }

        async fn update_ol(&self, m: &mut #m_name) -> Result<u64, sqlx::Error> {
            let sql = Self::sql_update_ol();
            m.update_ts = Times::ts_now();
            let re = Self::_bind_update_ol(m, sqlx::query(&sql)).execute(self.pool()).await?;
            Ok(re.rows_affected())
        }

        async fn get(&self, id: &str) -> Result<Option<#m_name>, sqlx::Error> {
            let sql = Self::sql_get();
            let condition = sqlx::query_as::<_, #m_name>(&sql).bind(id).fetch_optional(self.pool()).await?;
            Ok(condition)
        }

        async fn list(&self) -> Result<Vec<#m_name>, sqlx::Error> {
            let sql = Self::sql_list();
            let rows = sqlx::query_as(&sql).fetch_all(self.pool()).await?;
            Ok(rows)
        }
    }

    #[cfg(test)]
    mod tests {
        use self::super::#dao_name;
        use crate::#m_name;
        use db_code::dao::{Dao, KitsDb};

        #[tokio::test]
        async fn #test_name() {
            let pool = KitsDb::new_with_name(#db_name, "init/sql_.sql")
                .await
                .expect("");
            let dao_ = #dao_name::new(pool);
            let mut m = #m_name::default();
            {
                //清除数据,可以反复测试
                dao_.remove_all().await.expect("");
                dao_.remove(&m.id).await.expect("");
            }
            {
                let get_m = dao_.get(&m.id).await.expect("");
                assert_eq!(true, get_m.is_none());
            }

            m.update_ts = 1;

            {
                //add
                let re = dao_.add(&mut m).await.expect("");
                assert_eq!(re, 1);
                let get_m = dao_.get(&m.id).await.expect("").expect("");
                assert_eq!(m.id, get_m.id);
                assert_eq!(m.version, get_m.version);
                assert_eq!(m.update_ts, get_m.update_ts);
            }

            {
                //update
                let re = dao_.update(&mut m).await.expect("");
                assert_eq!(1, re);
                let get_m = dao_.get(&m.id).await.expect("").expect("");
                assert_eq!(m.id, get_m.id);
                assert_eq!(m.version, get_m.version);
                assert_eq!(m.update_ts, get_m.update_ts);
            }

            {
                //update ol
                let re = dao_.update_ol(&mut m).await.expect("");
                assert_eq!(1, re);
                let get_m = dao_.get(&m.id).await.expect("").expect("");
                assert_eq!(m.id, get_m.id);
                assert_eq!(m.version + 1, get_m.version);
                assert_eq!(m.update_ts, get_m.update_ts);
            }

            {
                //list
                let ms = dao_.list().await.expect("");
                assert_eq!(1, ms.len());
                let new_m = &ms[0];
                assert_eq!(m.id, new_m.id);
            }

            {
                //remove
                let re = dao_.remove(&m.id).await.expect("");
                assert_eq!(1, re);
                let old = dao_.get(&m.id).await.expect("");
                assert_eq!(true, old.is_none());
            }
        }
    }
        );

    let file_name = get_dap_path(&to_snake_name(&tm.type_name).add("_dao.rs"));
    let file_meta = fs::metadata(file_name.clone());
    let file_op = match file_meta {
        Ok(meta) => {
            if meta.len() < 10 {
                let file = fs::File::create(file_name).expect("fs::File::create(file_name)");
                Some(file)
            } else {
                None
            }
        }
        Err(_) => {
            let file = fs::File::create(file_name).expect("fs::File::create(file_name)");
            Some(file)
        }
    };
    if let Some(mut file) = file_op {
        let file_str = syn::parse_file(&gen_code.to_string()).expect("");
        let format_str = prettyplease::unparse(&file_str);
        let _ = file.write_all(format_str.as_bytes());
    }
}

/// Returns the path of the generated DAO file `short_name` and makes sure its
/// directory exists.
///
/// The directory is `$CARGO_MANIFEST_DIR/src/dao` when that environment variable is
/// set, otherwise the relative directory `dao`. Missing parent directories are
/// created too.
///
/// Directory-creation errors are deliberately ignored here. If the directory really
/// cannot be created, the caller's subsequent file creation reports the failure.
///
/// # Panics
/// Panics if the resulting path is not valid UTF-8.
fn get_dap_path(short_name: &str) -> String {
    const CARGO_MANIFEST_DIR: &str = "CARGO_MANIFEST_DIR";

    let dao_dir = match std::env::var(CARGO_MANIFEST_DIR) {
        Ok(manifest) => path::Path::new(&manifest).join("src").join("dao"),
        Err(_) => path::PathBuf::from("dao"),
    };
    // `create_dir_all` is a no-op when the directory already exists and, unlike
    // `create_dir`, also creates a missing `src/` parent.
    let _ = fs::create_dir_all(&dao_dir);

    dao_dir
        .join(short_name)
        .to_str()
        .expect("dao file path is not valid UTF-8")
        .to_owned()
}