//! db_proc_macro 0.1.1
//!
//! Procedural macros for the db_core project.
//!
//! Documentation: this module implements the `SqlEnum` derive (`derive_to_sqlx_enum`).
use proc_macro::TokenStream; // proc_macro 提供的 TokenStream 类型,是编译器传给宏的语法片段
use quote::quote;
use syn::{parse_macro_input, DeriveInput, Fields}; // syn 用于解析 Rust 语法树 // quote 用于生成 Rust 代码
use crate::create_crate_ident;
/// 真实宏逻辑实现
/// 输入: TokenStream(Rust 编译器传入的枚举语法片段)
/// 输出: TokenStream(生成的 impl 代码)
/// 作用: 自动为枚举生成 FromStr、Serialize/Deserialize、sqlx Decode/Type 等实现
pub fn derive_to_sqlx_enum(input: TokenStream) -> TokenStream {
    // 将 TokenStream 解析成 AST(抽象语法树),方便操作
    // DeriveInput 是 syn 提供的类型,代表一个可派生宏的 Rust item(struct、enum、union)
    let input = parse_macro_input!(input as DeriveInput);

    // 获取枚举名字,例如 enum Status {...} 的名字就是 "Status"
    let name = &input.ident;
    let crate_ident = create_crate_ident("db_cores");
    let serde_path = if crate_ident == "crate" {
        quote!(crate::serde)
    } else {
        quote!(::#crate_ident::serde)
    };
    let sqlx_path = if crate_ident == "crate" {
        quote!(crate::sqlx) // 如果是db_common本身调 crate::sqlx,则直接使用 crate::sqlx,不能写成::crate::sqlx
    } else {
        quote!(::#crate_ident::sqlx) // 如果是其它组件调用则是 ::db_common::sqlx,则使用 ::db_common::sqlx
    };

    // 检查枚举类型,并收集所有 unit variant
    // unit variant = 没有字段的枚举,如 Pending、Approved、Rejected
    let variants = if let syn::Data::Enum(data_enum) = &input.data {
        // 遍历枚举的每个 variant
        data_enum
            .variants
            .iter()
            .map(|v| {
                let ident = &v.ident; // variant 名字,例如 Pending
                match &v.fields {
                    Fields::Unit => {
                        // 只支持无字段的枚举
                        let value = ident.to_string(); // 将 variant 名字转换为字符串,用于 FromStr/Serialize
                        (ident.clone(), value) // 返回元组 (variant 标识符, 对应字符串)
                    }
                    _ => panic!("SqlEnum only supports unit variants"), // 如果有字段,宏直接报错
                }
            })
            .collect::<Vec<_>>() // 收集到 Vec,避免迭代器被移动
    } else {
        panic!("SqlEnum can only be derived on enums"); // 如果不是枚举,宏报错
    };

    // 构建 FromStr match 分支
    // 用于把字符串解析成枚举,例如 "Pending" => Ok(Status::Pending)
    let from_str_match: Vec<_> = variants
        .iter()
        .map(|(ident, value)| {
            quote! {
                x if x.eq_ignore_ascii_case(#value) => Ok(#name::#ident), //忽略 大小
            }
        })
        .collect();
    
    let lowercases: Vec<_> = variants
        .iter()
        .map(|(ident, value)| {
            quote! {
               #name::#ident => #value.to_ascii_lowercase(), //忽略 大小
            }
        })
        .collect();
    
    
    // 为 Deserialize 单独构建 match 分支,使用正确的变量名
    let deserialize_match: Vec<_> = variants
        .iter()
        .map(|(ident, value)| {
            quote! {
                x if x.eq_ignore_ascii_case(#value) => Ok(#name::#ident),
            }
    }).collect();



    // 使用 quote! 生成最终代码
    // quote! 可以把 Rust 代码写成模板,变量用 #var 插入
    let expanded = quote! {
        // 实现 std::str::FromStr
        impl std::str::FromStr for #name {
            type Err = String;
            fn from_str(s: &str) -> Result<Self, Self::Err> {
                match s {
                    #(#from_str_match)* // 插入所有 match 分支
                    _ => Err(format!("Invalid value for {}: {}", stringify!(#name), s)),
                }
            }
        }
        
        impl #name {
            pub fn to_lower(&self) -> String {
                match self {
                    #(#lowercases)* // 插入所有 match 分支
                }
            }
        }
        
        // 实现 Deserialize  当sqlx 解析son内容时会调用
        // impl<'de> ::#crate_ident::serde::Deserialize<'de> for #name {
        //     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        //     where
        //         D: ::#crate_ident::serde::Deserializer<'de>,
        //     {
        //         // let s = String::deserialize(deserializer)?;
        //         let s = <String as ::#crate_ident::serde::Deserialize>::deserialize(deserializer)?;
        //         match s.as_str() {
        //             #(#deserialize_match)*
        //             _ => Err(::#crate_ident::serde::de::Error::custom(
        //                 format!(concat!("Invalid value for ", stringify!($name), ": {}"), s))
        //             ),
        //         }
        //     }
        // }
        

        impl<'de> #serde_path::Deserialize<'de> for #name {
            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
            where
                D: #serde_path::Deserializer<'de>,
            {
                let s = <String as #serde_path::Deserialize>::deserialize(deserializer)?;
                match s.as_str() {
                    #(#deserialize_match)*
                    _ => Err(<D::Error as #serde_path::de::Error>::custom(
                        format!(
                            concat!("Invalid value for ", stringify!(#name), ": {}"),
                            s
                        )
                    )),
                }
            }
        }


        
        // SQLite
        #[cfg(feature = "sqlite")]
        impl<'r> #sqlx_path::Decode<'r, #sqlx_path::Sqlite> for #name {
            fn decode(value: #sqlx_path::sqlite::SqliteValueRef<'r>) -> Result<Self,#sqlx_path::error::BoxDynError> {
                // 先把数据库字段解析成 String
                let value_str = <String as #sqlx_path::Decode<#sqlx_path::Sqlite>>::decode(value)?;
                // 再用 FromStr 转成枚举
                #name::from_str(&value_str).map_err(|e| e.into())
            }
        }

        // MySQL
        #[cfg(feature = "mysql")]
        impl<'r>#sqlx_path::Decode<'r,::sqlx::MySql> for #name {
            fn decode(value: #sqlx_path::mysql::MySqlValueRef<'r>) -> Result<Self,#sqlx_path::error::BoxDynError> {
                let value_str = <String as #sqlx_path::Decode<#sqlx_path::MySql>>::decode(value)?;
                #name::from_str(&value_str).map_err(|e| e.into())
            }
        }

        // PostgreSQL
        #[cfg(feature = "postgres")]
        impl<'r>#sqlx_path::Decode<'r,#sqlx_path::Postgres> for #name {
            fn decode(value: #sqlx_path::postgres::PgValueRef<'r>) -> Result<Self,#sqlx_path::error::BoxDynError> {
                let value_str = <String as #sqlx_path::Decode<#sqlx_path::Postgres>>::decode(value)?;
                #name::from_str(&value_str).map_err(|e| e.into())
            }
        }

        // ==================::sqlx::Type impls ==================
        #[cfg(feature = "sqlite")]
        impl #sqlx_path::Type<#sqlx_path::Sqlite> for #name {
            fn type_info() ->#sqlx_path::sqlite::SqliteTypeInfo {
                <String as #sqlx_path::Type<#sqlx_path::Sqlite>>::type_info() // 用 String 类型信息
            }
        }

        #[cfg(feature = "mysql")]
        impl #sqlx_path::Type<#sqlx_path::MySql> for #name {
            fn type_info() ->#sqlx_path::mysql::MySqlTypeInfo {
                <String as #sqlx_path::Type<#sqlx_path::MySql>>::type_info()
            }
        }

        #[cfg(feature = "postgres")]
        impl #sqlx_path::Type<#sqlx_path::Postgres> for #name {
            fn type_info() ->#sqlx_path::postgres::PgTypeInfo {
                <String as #sqlx_path::Type<#sqlx_path::Postgres>>::type_info()
            }
        }
    };

    // 将 quote! 生成的 TokenStream 返回给编译器
    TokenStream::from(expanded)
}