db_proc_macro 0.1.0

Procedural macros for the db_core project
Documentation
use proc_macro::TokenStream;
use syn::{parse_macro_input, DeriveInput, Data, Fields, Type};
use quote::quote;
use crate::create_crate_ident;

pub fn derive_to_sql_values(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as DeriveInput);
    let struct_name = input.ident;

    let mut field_conversions = Vec::new();

    if let Data::Struct(data) = input.data {
        if let Fields::Named(fields) = data.fields {
            for field in fields.named.iter() {
                let ident = &field.ident;
                let name_str = ident.as_ref().unwrap().to_string();
                let ty = &field.ty;

                let conversion = if type_is_option(ty) {
                    let inner_ty = extract_option_inner(ty).unwrap();
                    generate_conversion(&name_str, ident, Some(inner_ty))
                } else {
                    generate_conversion(&name_str, ident, Some(ty))
                };

                field_conversions.push(conversion);
            }
        }
    }

    let expanded = quote! {
        impl #struct_name {
            pub fn to_sql_values(&self) -> std::collections::HashMap<String, SqlValue> {
                let mut map = std::collections::HashMap::new();
                #(#field_conversions)*
                map
            }
        }
    };
    TokenStream::from(expanded)
}

// --------------------- 辅助函数 ---------------------

fn type_is_option(ty: &Type) -> bool {
    match ty {
        Type::Path(tp) => tp.path.segments.iter().any(|s| s.ident == "Option"),
        _ => false,
    }
}

fn extract_option_inner(ty: &Type) -> Option<&Type> {
    if let Type::Path(tp) = ty {
        if let Some(seg) = tp.path.segments.iter().find(|s| s.ident == "Option") {
            if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
                if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
                    return Some(inner);
                }
            }
        }
    }
    None
}


fn generate_conversion(name_str: &str, ident: &Option<syn::Ident>, ty: Option<&Type>) -> proc_macro2::TokenStream {
    let ident = ident.as_ref().unwrap();
    // 这里 db_common 是 使用些过程宏, 定义普通宏的 crate; 而非调用者的crate, 重点 db_common 必须存在 lib  pub use SqlValue;
    let crate_ident = create_crate_ident("db_cores");
    let root = if crate_ident == "crate" {
        quote!(crate)
    } else {
        quote!(::#crate_ident)
    };
    if let Some(ty) = ty {
        let type_name = quote! {#ty}.to_string();
        match type_name.as_str() {
            // "i64" | "i32" => quote! { map.insert(#name_str.to_string(), SqlValue::Num(self.#ident as i64)); },
            "i16" | "i8" | "u64" | "u32" | "u16" => quote! { map.insert(#name_str.to_string(), #root::SqlValue::Num(self.#ident as i64)); },
            "f64" => quote! { map.insert(#name_str.to_string(), #root::SqlValue::Float(self.#ident)); },
            "bool" => quote! { map.insert(#name_str.to_string(), #root::SqlValue::Bool(self.#ident)); },
            "String" => quote! { map.insert(#name_str.to_string(), #root::SqlValue::Str(self.#ident.clone())); },
            "chrono::DateTime < chrono::Utc >" | "DateTime < Utc >"=> quote! { map.insert(#name_str.to_string(), #root::SqlValue::UtcTime(self.#ident.clone())); },
            "chrono::NaiveTime" | "NaiveTime"=> quote! { map.insert(#name_str.to_string(), #root::SqlValue::Time(self.#ident)); },
            "chrono::NaiveDate" | "NaiveDate" => quote! { map.insert(#name_str.to_string(), #root::SqlValue::Date(self.#ident)); },
            "chrono::NaiveDateTime" | "NaiveDateTime" => quote! { map.insert(#name_str.to_string(), #root::SqlValue::DateTime(self.#ident)); },
            "Vec < u8 >" => quote! { map.insert(#name_str.to_string(), #root::SqlValue::Buff(self.#ident.clone())); },
            "JsonValue" | "Value" => quote! { map.insert(#name_str.to_string(), #root::SqlValue::Json(self.#ident.clone())); },
            // 嵌套类型一般对应 数据库的json 字段
            _ => quote! { map.insert(#name_str.to_string(), #root::SqlValue::Json(self.#ident.clone())); }
        }
    } else {
        quote! { map.insert(#name_str.to_string(), #root::SqlValue::Null); }
    }
}