//! sqlw_macro 0.1.0
//!
//! Procedural macros for sqlw.
pub mod parser {
    //! Parsing of the schema macro input.
    //!
    //! Accepted shape (outer attributes are allowed before the schema and
    //! before every entry):
    //!
    //! ```text
    //! StructName "table_name" {
    //!     CONST_NAME "column",
    //!     CONST_NAME: Type "column",
    //! }
    //! ```

    use syn::parse::{Parse, ParseStream, Result};
    use syn::{Attribute, Ident, LitStr, Token, Type, braced};

    /// One entry inside the schema braces: `#[attrs] NAME [: Type] "column"`.
    pub struct SchemaField {
        /// Outer attributes written before the entry.
        pub attrs: Vec<Attribute>,
        /// The entry's identifier (`NAME`).
        pub name: Ident,
        /// Type annotation, present only when the name is followed by `:`.
        pub ty: Option<Type>,
        /// The column-name string literal that ends the entry.
        pub value: LitStr,
    }

    impl Parse for SchemaField {
        fn parse(input: ParseStream) -> Result<Self> {
            let attrs = input.call(Attribute::parse_outer)?;
            let name = input.parse::<Ident>()?;
            // A `:` right after the name introduces the optional type.
            let ty: Option<Type> = match input.parse::<Option<Token![:]>>()? {
                Some(_colon) => Some(input.parse()?),
                None => None,
            };
            let value = input.parse::<LitStr>()?;
            Ok(Self {
                attrs,
                name,
                ty,
                value,
            })
        }
    }

    /// Whole macro input: `#[attrs] StructName "table" { entry, ... }`.
    pub struct SchemaInput {
        /// Outer attributes written before the struct name.
        pub attrs: Vec<Attribute>,
        /// Identifier of the struct.
        pub struct_name: Ident,
        /// Table-name string literal.
        pub table_name: LitStr,
        /// Entries found inside the braces, in source order.
        pub fields: Vec<SchemaField>,
    }

    impl Parse for SchemaInput {
        fn parse(input: ParseStream) -> Result<Self> {
            let attrs = input.call(Attribute::parse_outer)?;
            let struct_name = input.parse::<Ident>()?;
            let table_name = input.parse::<LitStr>()?;

            let body;
            braced!(body in input);

            // Comma-separated entries (syn's `parse_terminated` also accepts
            // an optional trailing comma).
            let fields = body
                .parse_terminated(SchemaField::parse, Token![,])?
                .into_iter()
                .collect::<Vec<_>>();

            Ok(Self {
                attrs,
                struct_name,
                table_name,
                fields,
            })
        }
    }
}

pub mod codegen {
    use super::parser::SchemaInput;
    use proc_macro2::TokenStream;
    use quote::quote;

    pub struct SchemaGenerator {
        input: SchemaInput,
    }

    impl SchemaGenerator {
        pub fn new(input: SchemaInput) -> Self {
            Self { input }
        }

        /// Returns `true` when the type is `Option<…>` (regardless of leading
        /// path segments like `std::option::Option`).
        fn is_option_type(ty: &syn::Type) -> bool {
            if let syn::Type::Path(type_path) = ty {
                if let Some(segment) = type_path.path.segments.last() {
                    return segment.ident == "Option"
                        && matches!(segment.arguments, syn::PathArguments::AngleBracketed(_));
                }
            }
            false
        }

        pub fn generate(self) -> TokenStream {
            let attrs = &self.input.attrs;
            let struct_name = &self.input.struct_name;
            let table_name = &self.input.table_name;

            let typed_fields: Vec<_> = self
                .input
                .fields
                .iter()
                .filter_map(|field| {
                    let ty = field.ty.as_ref()?;
                    Some((field, ty))
                })
                .collect();

            // Generate named struct fields from typed schema entries.
            // The field identifier is derived from the column name string
            // (e.g. `"id"` becomes `pub id: i64`).
            let struct_fields: Vec<_> = typed_fields
                .iter()
                .map(|(field, ty)| {
                    let attrs = &field.attrs;
                    let field_ident = syn::Ident::new(&field.value.value(), field.value.span());
                    quote! {
                        #(#attrs)*
                        pub #field_ident: #ty,
                    }
                })
                .collect();

            // Generate FromRow impl — uses try_get_typed for Option<T> fields
            // and get_typed for everything else.
            let from_row_assignments: Vec<_> = typed_fields
                .iter()
                .map(|(field, ty)| {
                    let field_ident = syn::Ident::new(&field.value.value(), field.value.span());
                    let col = &field.value.value();
                    if Self::is_option_type(ty) {
                        quote! {
                            #field_ident: row.try_get_typed(#col)?,
                        }
                    } else {
                        quote! {
                            #field_ident: row.get_typed(#col)?,
                        }
                    }
                })
                .collect();

            // Generate const definitions (same as before).
            let field_defs = self.input.fields.iter().map(|field| {
                let attrs = &field.attrs;
                let name = &field.name;
                let ty = &field.ty;
                let value = &field.value;
                match ty {
                    Some(ty) => quote! {
                        #(#attrs)*
                        pub const #name: sqlw::Def<Self, sqlw::Typed<#ty>> = sqlw::Def::new(#value);
                    },
                    None => quote! {
                        #(#attrs)*
                        pub const #name: sqlw::Def<Self> = sqlw::Def::new(#value);
                    },
                }
            });

            quote! {
                #(#attrs)*
                #[derive(Debug, Default)]
                pub struct #struct_name {
                    #(#struct_fields)*
                }

                impl sqlw::Schema for #struct_name {}

                impl sqlw::FromRow for #struct_name {
                    fn from_row<R: sqlw::RowLike>(row: &R) -> Result<Self, sqlw::RowError> {
                        Ok(Self {
                            #(#from_row_assignments)*
                        })
                    }
                }

                impl #struct_name {
                    pub const TABLE: sqlw::Def<Self> = sqlw::Def::new(#table_name);
                    #(#field_defs)*
                }
            }
        }
    }
}