cartel-gen 0.1.2

Procedural macros for cartel
Documentation
use quote::{format_ident, quote};
use syn::spanned::Spanned;
use syn::{ItemFn, ReturnType, Type};

use crate::backend::{Backend, Compiled, Dialect, Placeholder, TableSpec};
use crate::emit::{DecodeKind, DispatchKind};
use crate::util::SqlIdent;

const UNSUPPORTED_OPS: &[&str] = &[
    "ilike",
    "not_ilike",
    "pg_contains",
    "pg_overlaps",
    "regex_match",
    "regex_imatch",
    "not_regex_match",
    "not_regex_imatch",
    "is_descendant_of",
    "is_ancestor_of",
    "fts_match",
    "in_",
    "not_in",
];

pub(super) struct SqliteBackend;

impl SqliteBackend {
    fn no_param_cast(_: &syn::Type) -> Option<&'static str> {
        None
    }
}

impl Backend for SqliteBackend {
    fn dialect() -> Dialect {
        Dialect {
            rt_crate: syn::parse_quote! { ::cartel_sqlite },
            placeholder: Placeholder::Question,
            unsupported_ops: UNSUPPORTED_OPS,
            param_cast: Self::no_param_cast,
        }
    }

    fn emit_table(spec: TableSpec<'_>, dialect: &Dialect) -> syn::Result<proc_macro2::TokenStream> {
        Self::emit_table(spec, dialect)
    }
}

impl SqliteBackend {
    pub(super) fn query_free(f: ItemFn, no_probe: bool) -> syn::Result<proc_macro2::TokenStream> {
        if f.sig.asyncness.is_some() {
            return Err(syn::Error::new(
                f.sig.span(),
                "#[query] functions must not be async",
            ));
        }
        let compiled = Compiled::build::<Self>(
            &f.sig.generics,
            f.sig.inputs.iter().map(Compiled::fn_arg),
            &f.block,
            &f.sig.output,
            f.sig.span(),
        )?;
        Self::emit_query(&f, &compiled, no_probe)
    }

    fn emit_query(
        f: &ItemFn,
        compiled: &Compiled,
        no_probe: bool,
    ) -> syn::Result<proc_macro2::TokenStream> {
        let Compiled {
            param_ids,
            param_tys,
            captures,
            plan,
            ..
        } = compiled;

        let fn_vis = &f.vis;
        let fn_name = &f.sig.ident;
        let arg_decls = param_ids
            .iter()
            .zip(param_tys.iter())
            .map(|(id, ty)| quote! { #id: #ty })
            .collect::<Vec<_>>();
        let bind_idents = captures.iter().map(|c| quote! { #c }).collect::<Vec<_>>();

        let sql_parts = &plan.sql_parts;
        let sql_const = quote! {
            const __SQL: &str = {
                const __PARTS: &[&str] = &[ #( #sql_parts ),* ];
                const __LEN: usize = ::cartel_sqlite::__internal::concat_len(__PARTS);
                const __BYTES: [u8; __LEN] = ::cartel_sqlite::__internal::concat::<__LEN>(__PARTS);
                match ::core::str::from_utf8(&__BYTES) {
                    ::core::result::Result::Ok(s) => s,
                    ::core::result::Result::Err(_) => panic!("cartel_sqlite: SQL build failed"),
                }
            };
        };

        let decode_one = match &plan.decode {
            DecodeKind::Row(t) => quote! { <#t as ::cartel_sqlite::Decode>::decode(__r)? },
            DecodeKind::Unit => quote! { () },
        };

        let body = match &plan.dispatch {
            DispatchKind::NoRows => quote! {
                #sql_const
                let mut __stmt = __conn.prepare_cached(__SQL)?;
                __stmt.execute(::cartel_sqlite::params![ #( #bind_idents ),* ])?;
                ::core::result::Result::Ok(())
            },
            DispatchKind::One(_) => quote! {
                #sql_const
                let mut __stmt = __conn.prepare_cached(__SQL)?;
                let mut __rows = __stmt.query(::cartel_sqlite::params![ #( #bind_idents ),* ])?;
                match __rows.next()? {
                    ::core::option::Option::Some(__r) => ::core::result::Result::Ok(#decode_one),
                    ::core::option::Option::None => ::core::result::Result::Err(::cartel_sqlite::Error::QueryReturnedNoRows),
                }
            },
            DispatchKind::First(_) => quote! {
                #sql_const
                let mut __stmt = __conn.prepare_cached(__SQL)?;
                let mut __rows = __stmt.query(::cartel_sqlite::params![ #( #bind_idents ),* ])?;
                match __rows.next()? {
                    ::core::option::Option::Some(__r) => ::core::result::Result::Ok(::core::option::Option::Some(#decode_one)),
                    ::core::option::Option::None => ::core::result::Result::Ok(::core::option::Option::None),
                }
            },
            DispatchKind::All(_) | DispatchKind::Stream(_) => quote! {
                #sql_const
                let mut __stmt = __conn.prepare_cached(__SQL)?;
                let mut __rows = __stmt.query(::cartel_sqlite::params![ #( #bind_idents ),* ])?;
                let mut __out = ::std::vec::Vec::new();
                while let ::core::option::Option::Some(__r) = __rows.next()? {
                    __out.push(#decode_one);
                }
                ::core::result::Result::Ok(__out)
            },
        };

        let ret = match &plan.dispatch {
            DispatchKind::NoRows => quote! { ::cartel_sqlite::Result<()> },
            DispatchKind::One(t) => quote! { ::cartel_sqlite::Result<#t> },
            DispatchKind::First(t) => {
                quote! { ::cartel_sqlite::Result<::core::option::Option<#t>> }
            }
            DispatchKind::All(t) | DispatchKind::Stream(t) => {
                quote! { ::cartel_sqlite::Result<::std::vec::Vec<#t>> }
            }
        };

        let probe_emit = if no_probe {
            quote! {}
        } else {
            let probe_fn_name = format_ident!("__cartel_sqlite_probe_{}", fn_name);
            let probe_body_block = &f.block;
            let probe_returns_unit = match &f.sig.output {
                ReturnType::Default => true,
                ReturnType::Type(_, t) => {
                    matches!(t.as_ref(), Type::Tuple(tup) if tup.elems.is_empty())
                }
            };
            let (probe_ret_arrow, probe_body) = if probe_returns_unit {
                (quote! {}, quote! { let _ = #probe_body_block; })
            } else {
                let ret_ty = match &f.sig.output {
                    ReturnType::Type(_, t) => t.clone(),
                    ReturnType::Default => unreachable!(),
                };
                (quote! { -> #ret_ty }, quote! { #probe_body_block })
            };
            quote! {
                #[doc(hidden)]
                #[allow(
                    dead_code, unreachable_code, unused_variables,
                    non_snake_case, clippy::let_unit_value, clippy::unused_unit,
                    clippy::no_effect
                )]
                fn #probe_fn_name(#( #arg_decls, )*) #probe_ret_arrow {
                    #probe_body
                }
            }
        };

        Ok(quote! {
            #fn_vis fn #fn_name(
                __conn: &::cartel_sqlite::Connection,
                #( #arg_decls, )*
            ) -> #ret {
                #body
            }

            #probe_emit
        })
    }

    fn emit_table(
        spec: TableSpec<'_>,
        _dialect: &Dialect,
    ) -> syn::Result<proc_macro2::TokenStream> {
        let TableSpec {
            input,
            table_name,
            fields,
            pk_cols,
        } = spec;
        let name = &input.ident;
        let (impl_g, type_g, where_g) = input.generics.split_for_impl();
        let table_name_sql = table_name.quote_if_needed();

        let mut field_decodes = Vec::new();
        let mut n_cols_acc = quote! { 0usize };
        let mut col_names: Vec<String> = Vec::new();
        for f in fields {
            let fname = f.ident.as_ref().expect("named field");
            let fty = &f.ty;
            field_decodes.push(quote! {
                #fname: <#fty as ::cartel_sqlite::Decode>::decode_at(__r, __off + (#n_cols_acc))?,
            });
            n_cols_acc = quote! { #n_cols_acc + <#fty as ::cartel_sqlite::Decode>::N_COLS };
            col_names.push(fname.to_string());
        }

        let col_names_sql: Vec<String> = col_names.iter().map(|c| c.quote_if_needed()).collect();
        let select_cols = col_names_sql.join(",");
        let qualified_select_cols = col_names_sql
            .iter()
            .map(|c| format!("{table_name_sql}.{c}"))
            .collect::<Vec<_>>()
            .join(",");
        let pk_const = pk_cols
            .iter()
            .map(|c| c.quote_if_needed())
            .collect::<Vec<_>>()
            .join(",");
        let n_cols = col_names.len() as u16;

        Ok(quote! {
            impl #impl_g ::cartel_sqlite::Decode for #name #type_g #where_g {
                const N_COLS: usize = #n_cols_acc;
                fn decode_at(__r: &::cartel_sqlite::Row<'_>, __off: usize) -> ::cartel_sqlite::Result<Self> {
                    ::core::result::Result::Ok(Self { #( #field_decodes )* })
                }
            }

            impl #impl_g #name #type_g #where_g {
                #[doc(hidden)]
                pub const __CARTEL_TABLE: &'static str = #table_name_sql;
                #[doc(hidden)]
                pub const __CARTEL_SELECT_COLS: &'static str = #select_cols;
                #[doc(hidden)]
                pub const __CARTEL_SELECT_COLS_QUALIFIED: &'static str = #qualified_select_cols;
                #[doc(hidden)]
                pub const __CARTEL_PK_COL: &'static str = #pk_const;
                #[doc(hidden)]
                pub const __CARTEL_N_COLS: u16 = #n_cols;
            }

            #[allow(unused_variables, clippy::unused_self)]
            impl #impl_g #name #type_g #where_g {
                pub fn filter(_f: impl ::core::ops::FnOnce(#name #type_g) -> bool)
                    -> ::cartel_sqlite::FilterBuilder<#name #type_g>
                {
                    ::core::unreachable!("cartel_sqlite: Table::filter only valid inside #[query] body")
                }

                pub fn join<__T>(
                    _f: impl ::core::ops::FnOnce(#name #type_g, __T) -> bool,
                ) -> ::cartel_sqlite::JoinBuilder<#name #type_g, __T> {
                    ::core::unreachable!("cartel_sqlite: Table::join only valid inside #[query] body")
                }
                pub fn left_join<__T>(
                    _f: impl ::core::ops::FnOnce(#name #type_g, __T) -> bool,
                ) -> ::cartel_sqlite::JoinBuilder<#name #type_g, __T> {
                    ::core::unreachable!("cartel_sqlite: Table::left_join only valid inside #[query] body")
                }
                pub fn right_join<__T>(
                    _f: impl ::core::ops::FnOnce(#name #type_g, __T) -> bool,
                ) -> ::cartel_sqlite::JoinBuilder<#name #type_g, __T> {
                    ::core::unreachable!("cartel_sqlite: Table::right_join only valid inside #[query] body")
                }
                pub fn full_join<__T>(
                    _f: impl ::core::ops::FnOnce(#name #type_g, __T) -> bool,
                ) -> ::cartel_sqlite::JoinBuilder<#name #type_g, __T> {
                    ::core::unreachable!("cartel_sqlite: Table::full_join only valid inside #[query] body")
                }

                pub fn insert(_f: impl ::core::ops::FnOnce(&mut #name #type_g))
                    -> ::cartel_sqlite::InsertBuilder<#name #type_g>
                {
                    ::core::unreachable!("cartel_sqlite: Table::insert only valid inside #[query] body")
                }
                pub fn insert_from<__S: ::cartel_sqlite::SourceRow>(
                    _s: __S,
                    _f: impl ::core::ops::FnOnce(&mut #name #type_g, __S::Row),
                ) -> ::cartel_sqlite::InsertBuilder<#name #type_g> {
                    ::core::unreachable!("cartel_sqlite: Table::insert_from only valid inside #[query] body")
                }
            }
        })
    }
}