sqlw_macro 0.1.0

Procedural derive macros for the `sqlw` crate.
Documentation
pub mod parser {
    use syn::{DeriveInput, parse::Parse};

    /// Parsed input of `#[derive(FromRow)]`: the complete item the derive is attached to.
    pub struct FromRowInput {
        /// The derive target (name, generics, attributes and body).
        pub input: DeriveInput,
    }

    impl Parse for FromRowInput {
        /// Parses the token stream as a single `DeriveInput` and wraps it.
        fn parse(stream: syn::parse::ParseStream) -> syn::Result<Self> {
            stream
                .parse::<DeriveInput>()
                .map(|input| Self { input })
        }
    }
}

pub mod codegen {
    use super::parser::FromRowInput;
    use proc_macro2::TokenStream;
    use quote::quote;
    use syn::{Data, Field, Fields, Meta};

    pub struct FromRowGenerator {
        input: FromRowInput,
    }

    impl FromRowGenerator {
        pub fn new(input: FromRowInput) -> Self {
            Self { input }
        }

        pub fn generate(self) -> TokenStream {
            let input = self.input.input;
            let name = &input.ident;

            let fields = match &input.data {
                Data::Struct(data) => match &data.fields {
                    Fields::Named(fields) => &fields.named,
                    _ => panic!("FromRow can only be derived for structs with named fields"),
                },
                _ => panic!("FromRow can only be derived for structs"),
            };

            let field_assignments = fields.iter().map(|field| {
                let field_name = field.ident.as_ref().unwrap();
                let is_optional = Self::is_optional_field(field);
                let column_expr = Self::extract_column_expr(field);

                let col_name = if let Some(expr) = column_expr {
                    expr
                } else {
                    let name = field_name.to_string();
                    quote! { #name }
                };

                if is_optional {
                    quote! {
                        #field_name: row.try_get_typed(#col_name)?
                    }
                } else {
                    quote! {
                        #field_name: row.get_typed(#col_name)?
                    }
                }
            });

            quote! {
                impl sqlw::FromRow for #name {
                    fn from_row<R: sqlw::RowLike>(row: &R) -> Result<Self, sqlw::RowError> {
                        Ok(Self {
                            #(#field_assignments,)*
                        })
                    }
                }
            }
        }

        fn extract_column_expr(field: &Field) -> Option<proc_macro2::TokenStream> {
            for attr in &field.attrs {
                if attr.path().is_ident("field") {
                    match &attr.meta {
                        // #[field(Table::FIELD)] — schema constant, needs .desc()
                        Meta::List(list) => {
                            if let Ok(expr) = list.parse_args::<syn::Expr>() {
                                return Some(quote! { #expr.desc() });
                            }
                        }
                        // #[field = "column_name"] — raw string, use as-is
                        Meta::NameValue(name_value) => {
                            if let syn::Expr::Lit(expr_lit) = &name_value.value {
                                if let syn::Lit::Str(lit_str) = &expr_lit.lit {
                                    let column_name = lit_str.value();
                                    return Some(quote! { #column_name });
                                }
                            }
                        }
                        _ => {}
                    }
                }
            }
            None
        }

        /// Checks if a field is marked as optional via `#[optional]`.
        fn is_optional_field(field: &Field) -> bool {
            field
                .attrs
                .iter()
                .any(|attr| attr.path().is_ident("optional"))
        }
    }
}

pub mod try_from_value_ref {
    use proc_macro2::TokenStream;
    use quote::quote;
    use syn::{Data, DeriveInput, Fields, Lit, Meta, Variant};

    pub fn generate(input: DeriveInput) -> TokenStream {
        let name = &input.ident;

        match &input.data {
            Data::Struct(data) => generate_struct(name, data),
            Data::Enum(data) => generate_enum(name, data),
            _ => panic!("TryFromValueRef can only be derived for structs and enums"),
        }
    }

    fn generate_struct(name: &syn::Ident, data: &syn::DataStruct) -> TokenStream {
        let inner_type = match &data.fields {
            Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
                &fields.unnamed.first().unwrap().ty
            }
            _ => {
                panic!("TryFromValueRef for structs requires a tuple struct with exactly one field")
            }
        };

        quote! {
            impl<'a> ::core::convert::TryFrom<sqlw::ValueRef<'a>> for #name {
                type Error = sqlw::RowError;

                fn try_from(value: sqlw::ValueRef<'a>) -> Result<Self, Self::Error> {
                    <#inner_type as ::core::convert::TryFrom<sqlw::ValueRef<'a>>>::try_from(value)
                        .map(#name)
                }
            }
        }
    }

    fn generate_enum(name: &syn::Ident, data: &syn::DataEnum) -> TokenStream {
        let arms: Vec<TokenStream> = data
            .variants
            .iter()
            .map(|variant| {
                let variant_name = &variant.ident;
                let value_str = extract_value_attr(variant)
                    .unwrap_or_else(|| variant_name.to_string().to_lowercase());
                quote! { #value_str => Ok(#name::#variant_name), }
            })
            .collect();

        quote! {
            impl<'a> ::core::convert::TryFrom<sqlw::ValueRef<'a>> for #name {
                type Error = sqlw::RowError;

                fn try_from(value: sqlw::ValueRef<'a>) -> Result<Self, Self::Error> {
                    let s = <String as ::core::convert::TryFrom<sqlw::ValueRef<'a>>>::try_from(value)?;
                    match s.as_str() {
                        #(#arms)*
                        other => Err(sqlw::RowError::Any(format!("unknown variant: {}", other))),
                    }
                }
            }
        }
    }

    fn extract_value_attr(variant: &Variant) -> Option<String> {
        use syn::Expr;
        for attr in &variant.attrs {
            if attr.path().is_ident("value") {
                if let Meta::NameValue(nv) = &attr.meta {
                    if let Expr::Lit(expr_lit) = &nv.value {
                        if let Lit::Str(s) = &expr_lit.lit {
                            return Some(s.value());
                        }
                    }
                }
            }
        }
        None
    }
}

pub mod into_value {
    use proc_macro2::TokenStream;
    use quote::quote;
    use syn::{Data, DeriveInput, Fields, Lit, Meta, Variant};

    pub fn generate(input: DeriveInput) -> TokenStream {
        let name = &input.ident;

        match &input.data {
            Data::Struct(data) => generate_struct(name, data),
            Data::Enum(data) => generate_enum(name, data),
            _ => panic!("IntoValue can only be derived for structs and enums"),
        }
    }

    fn generate_struct(name: &syn::Ident, data: &syn::DataStruct) -> TokenStream {
        let _ = match &data.fields {
            Fields::Unnamed(fields) if fields.unnamed.len() == 1 => {
                &fields.unnamed.first().unwrap().ty
            }
            _ => panic!("IntoValue for structs requires a tuple struct with exactly one field"),
        };

        quote! {
            impl ::core::convert::From<#name> for sqlw::Value {
                fn from(v: #name) -> Self {
                    v.0.into()
                }
            }
            impl ::core::convert::From<&#name> for sqlw::Value {
                fn from(v: &#name) -> Self {
                    v.0.clone().into()
                }
            }
        }
    }

    fn generate_enum(name: &syn::Ident, data: &syn::DataEnum) -> TokenStream {
        let arms: Vec<TokenStream> = data
            .variants
            .iter()
            .map(|variant| {
                let variant_name = &variant.ident;
                let value_str = extract_value_attr(variant)
                    .unwrap_or_else(|| variant_name.to_string().to_lowercase());
                quote! {
                    #name::#variant_name => #value_str,
                }
            })
            .collect();

        quote! {
            impl ::core::convert::From<#name> for sqlw::Value {
                fn from(v: #name) -> Self {
                    sqlw::Value::Text(match v {
                        #(#arms)*
                    }.to_string())
                }
            }
            impl ::core::convert::From<&#name> for sqlw::Value {
                fn from(v: &#name) -> Self {
                    sqlw::Value::Text(match v {
                        #(#arms)*
                    }.to_string())
                }
            }
        }
    }

    fn extract_value_attr(variant: &Variant) -> Option<String> {
        use syn::Expr;
        for attr in &variant.attrs {
            if attr.path().is_ident("value") {
                if let Meta::NameValue(nv) = &attr.meta {
                    if let Expr::Lit(expr_lit) = &nv.value {
                        if let Lit::Str(s) = &expr_lit.lit {
                            return Some(s.value());
                        }
                    }
                }
            }
        }
        None
    }
}