spacetimedb-bindings-macro 2.4.0

Easy support for interacting between SpacetimeDB and Rust.
Documentation
use heck::ToSnakeCase;
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use syn::ext::IdentExt;
use syn::parse::Parser;
use syn::{FnArg, ItemFn, LitStr};

use crate::reducer::generate_explicit_names_impl;
use crate::sym;
use crate::util::{check_duplicate_msg, match_meta};

/// Arguments accepted by the `#[view(...)]` attribute, as produced by `ViewArgs::parse`.
pub(crate) struct ViewArgs {
    /// Identifier supplied via the mandatory `accessor = ...` key.
    accessor: Ident,
    /// Explicit name supplied via the optional `name = "..."` key.
    name: Option<LitStr>,
    /// Set when `public` was given. Parsing rejects views without `public`, so this is always
    /// `true`; `view_impl` does not read it, hence the lint allowance.
    #[allow(unused)]
    public: bool,
}

impl ViewArgs {
    /// Parses the arguments of a `#[view(...)]` attribute,
    /// e.g. `#[view(accessor = my_view, public)]`.
    ///
    /// Recognized keys:
    /// - `accessor = <ident>` — required.
    /// - `public` — required; every view must be public.
    /// - `name = "<string>"` — optional explicit name.
    ///
    /// `func_ident` is the annotated function's identifier. It is only used to suggest an
    /// `accessor` value in the error reported when that key is missing.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - a key is given more than once;
    /// - a value fails to parse;
    /// - `accessor` or `public` is missing;
    /// - (presumably, via `match_meta!`) an unrecognized key is present.
    pub(crate) fn parse(input: TokenStream, func_ident: &Ident) -> syn::Result<Self> {
        let mut name = None;
        let mut accessor = None;
        let mut public = None;
        syn::meta::parser(|meta| {
            match_meta!(match meta {
                sym::name => {
                    check_duplicate_msg(&name, &meta, "`name` already specified")?;
                    name = Some(meta.value()?.parse::<LitStr>()?);
                }
                sym::public => {
                    check_duplicate_msg(&public, &meta, "`public` already specified")?;
                    public = Some(());
                }
                sym::accessor => {
                    check_duplicate_msg(&accessor, &meta, "`accessor` already specified")?;
                    accessor = Some(meta.value()?.parse()?);
                }
            });
            Ok(())
        })
        .parse2(input)?;
        // `accessor` has no default; suggest the snake_case form of the function name instead.
        let accessor = accessor.ok_or_else(|| {
            let view = func_ident.to_string().to_snake_case();
            syn::Error::new(
                Span::call_site(),
                format_args!("must specify view accessor, e.g. `#[spacetimedb::view(accessor = {view})]`"),
            )
        })?;
        // `public` is mandatory, so its absence is an error rather than a default of `false`.
        let () = public
            .ok_or_else(|| syn::Error::new(Span::call_site(), "views must be `public`, e.g. `#[view(public)]`"))?;
        Ok(Self {
            name,
            public: true,
            accessor,
        })
    }
}

/// If `ty` is `impl Query<T>`, returns `Some(T)`. Otherwise `None`.
fn extract_impl_query_inner(ty: &syn::Type) -> Option<&syn::Type> {
    if let syn::Type::ImplTrait(impl_trait) = ty {
        for bound in &impl_trait.bounds {
            if let syn::TypeParamBound::Trait(trait_bound) = bound
                && let Some(seg) = trait_bound.path.segments.last()
                && seg.ident == "Query"
                && let syn::PathArguments::AngleBracketed(args) = &seg.arguments
                && let Some(syn::GenericArgument::Type(inner)) = args.args.first()
            {
                return Some(inner);
            }
        }
    }
    None
}

/// Returns the path of the first `Query<..>` bound of an `impl Trait` type, exactly as the
/// user spelled it (e.g. `Query<T>` or `spacetimedb::Query<T>`), or `None` otherwise.
///
/// The matching conditions mirror [`extract_impl_query_inner`]. For any given type, both
/// helpers select the same bound: this one yields the trait path, that one the type argument.
fn impl_query_trait_path(ty: &syn::Type) -> Option<&syn::Path> {
    let syn::Type::ImplTrait(impl_trait) = ty else {
        return None;
    };
    impl_trait.bounds.iter().find_map(|bound| {
        if let syn::TypeParamBound::Trait(trait_bound) = bound
            && let Some(seg) = trait_bound.path.segments.last()
            && seg.ident == "Query"
            && let syn::PathArguments::AngleBracketed(args) = &seg.arguments
            && let Some(syn::GenericArgument::Type(_)) = args.args.first()
        {
            Some(&trait_bound.path)
        } else {
            None
        }
    })
}

/// Expands `#[view(...)]` on `original_function` into the items needed to register and
/// invoke the view.
///
/// Validation (each failure is returned as a spanned `syn::Error`):
/// - only lifetime generics are allowed; type and const parameters are rejected;
/// - a `self` receiver is rejected;
/// - the first parameter must be a reference to the context type;
/// - no further parameters are allowed (parameterized views are not supported yet);
/// - an explicit return type is required.
///
/// Emitted items:
/// - the view function, possibly rewritten according to its return type (see below);
/// - an exported `__preinit__20_register_describer_<name>` function that calls
///   `spacetimedb::rt::ViewRegistrar::register`;
/// - an uninhabited marker struct with the function's name, implementing
///   `spacetimedb::rt::FnInfo`;
/// - never-called `_assert_args` functions that force the context type to implement
///   `ViewContextArg` and the return type to implement `ViewReturn`;
/// - the output of `generate_explicit_names_impl` for the optional explicit `name`.
pub(crate) fn view_impl(args: ViewArgs, original_function: &ItemFn) -> syn::Result<TokenStream> {
    let vis = &original_function.vis;
    let func_name = &original_function.sig.ident;
    let view_ident = args.accessor;
    // `unraw` strips a leading `r#`, so a raw-identifier accessor yields the plain name.
    let view_name = view_ident.unraw().to_string();

    // Views may be generic over lifetimes only.
    for param in &original_function.sig.generics.params {
        let err = |msg| syn::Error::new_spanned(param, msg);
        match param {
            syn::GenericParam::Lifetime(_) => {}
            syn::GenericParam::Type(_) => return Err(err("type parameters are not allowed on views")),
            syn::GenericParam::Const(_) => return Err(err("const parameters are not allowed on views")),
        }
    }

    // Extract parameters, rejecting a `self` receiver.
    let typed_args = original_function
        .sig
        .inputs
        .iter()
        .map(|arg| match arg {
            FnArg::Typed(arg) => Ok(arg),
            FnArg::Receiver(_) => Err(syn::Error::new_spanned(
                arg,
                "The `self` parameter is not allowed in views",
            )),
        })
        .collect::<syn::Result<Vec<_>>>()?;

    // Names for `FnInfo::ARG_NAMES`. This covers every parameter, including the context.
    // A parameter whose pattern is not a plain identifier (e.g. destructuring) maps to `None`.
    let opt_arg_names = typed_args.iter().map(|arg| {
        if let syn::Pat::Ident(i) = &*arg.pat {
            let name = i.ident.to_string();
            quote!(Some(#name))
        } else {
            quote!(None)
        }
    });

    let arg_tys = typed_args.iter().map(|arg| arg.ty.as_ref()).collect::<Vec<_>>();

    // Split off the context type; `arg_tys` is shadowed by the remaining parameter types.
    let [ctx_ty, arg_tys @ ..] = &arg_tys[..] else {
        return Err(syn::Error::new_spanned(
            &original_function.sig,
            "Views must always have a context parameter: `&ViewContext` or `&AnonymousViewContext`",
        ));
    };

    // TODO: Re-enable parameterized views once we can pass args from sql
    if !arg_tys.is_empty() {
        return Err(syn::Error::new_spanned(
            &original_function.sig,
            "Views do not take parameters other than `&ViewContext` or `&AnonymousViewContext`",
        ));
    }

    // The context must be passed by reference; keep only the referenced type.
    let ctx_ty = match ctx_ty {
        syn::Type::Reference(ctx_ty) => ctx_ty.elem.as_ref(),
        _ => {
            return Err(syn::Error::new_spanned(
                ctx_ty,
                "The first parameter of a view must be a context parameter: `&ViewContext` or `&AnonymousViewContext`; passed by reference",
            ));
        }
    };

    // Views must declare a return type.
    let ret_ty = match &original_function.sig.output {
        syn::ReturnType::Type(_, t) => t.as_ref(),
        syn::ReturnType::Default => {
            // NOTE(review): this message does not mention `impl Query<T>` / `RawQuery<T>`,
            // which are also handled below.
            return Err(syn::Error::new_spanned(
                &original_function.sig,
                "views must return `Vec<T>` or `Option<T>` where `T` is a `SpacetimeType`",
            ));
        }
    };

    let register_describer_symbol = format!("__preinit__20_register_describer_{}", view_name);

    // Only lifetime parameters can appear here (checked above). syn's `ToTokens` for
    // `Generics` emits only the `<...>` list, so the where clause is spliced separately.
    let lt_params = &original_function.sig.generics;
    let lt_where_clause = &lt_params.where_clause;

    let generated_describe_function = quote! {
        #[unsafe(export_name = #register_describer_symbol)]
        pub extern "C" fn __register_describer() {
            spacetimedb::rt::ViewRegistrar::<#ctx_ty>::register::<_, #func_name, _, _>(#func_name)
        }
    };

    let explicit_name = args.name.as_ref();
    let generate_explicit_names = generate_explicit_names_impl(&view_name, func_name, explicit_name);

    let original_attrs = &original_function.attrs;
    let original_body = &original_function.block;

    // Detect `impl Query<T>` return type and extract `T`.
    let impl_query_inner = extract_impl_query_inner(ret_ty);

    // When the return type is `impl Query<T>`:
    //   - Rewrite the function to return `RawQuery<T>`
    //   - Wrap the body: `RawQuery::new(<_ as Query<T>>::into_sql({ body }))`, naming the
    //     trait by the user's own path so expansion does not depend on what the caller
    //     happens to have imported under the bare name `Query`
    //   - Use `RawQuery<T>` for SpacetimeType/ViewReturn assertions
    // When the return type is `RawQuery<T>` (concrete query struct):
    //   - Wrap with `.into()` so builder types auto-convert
    // Otherwise (Vec<T>, Option<T>):
    //   - Emit unchanged to preserve type inference
    let (emitted_fn, effective_ret_ty) = match (impl_query_inner, impl_query_trait_path(ret_ty)) {
        (Some(inner_ty), Some(query_trait)) => {
            let original_sig = &original_function.sig;
            // Build a new signature with the return type replaced
            let mut new_sig = original_sig.clone();
            new_sig.output = syn::parse_quote!(-> spacetimedb::RawQuery<#inner_ty>);
            let effective_ty: syn::Type = syn::parse_quote!(spacetimedb::RawQuery<#inner_ty>);
            (
                quote! {
                    #(#original_attrs)*
                    #vis
                    #new_sig {
                        spacetimedb::RawQuery::new(
                            <_ as #query_trait>::into_sql(#original_body)
                        )
                    }
                },
                effective_ty,
            )
        }
        _ => {
            let original_sig = &original_function.sig;
            let returns_raw_query = matches!(
                ret_ty,
                syn::Type::Path(p) if p.path.segments.last().is_some_and(|s| s.ident == "RawQuery")
            );
            let emitted_body = if returns_raw_query {
                quote! { { ::core::convert::Into::into(#original_body) } }
            } else {
                quote! { #original_body }
            };
            (
                quote! {
                    #(#original_attrs)*
                    #vis
                    #original_sig
                        #emitted_body
                },
                ret_ty.clone(),
            )
        }
    };

    let eff_ret_ty = &effective_ret_ty;

    Ok(quote! {
        #emitted_fn

        const _: () = { #generated_describe_function };

        #[allow(non_camel_case_types)]
        #vis struct #func_name { _never: ::core::convert::Infallible }

        const _: () = {
            fn _assert_args #lt_params () #lt_where_clause {
                let _ = <#ctx_ty as spacetimedb::rt::ViewContextArg>::_ITEM;
                let _ = <#eff_ret_ty as spacetimedb::rt::ViewReturn>::_ITEM;
            }
        };

        // `arg_tys` is currently always empty (enforced above), so this expands to an empty
        // function; it is kept for when parameterized views are re-enabled.
        const _: () = {
            fn _assert_args #lt_params () #lt_where_clause {
                #(let _ = <#arg_tys as spacetimedb::rt::ViewArg>::_ITEM;)*
            }
        };

        impl #func_name {
            fn invoke(__ctx: #ctx_ty, __args: &[u8]) -> Vec<u8> {
                spacetimedb::rt::ViewDispatcher::<#ctx_ty>::invoke::<_, _, _>(#func_name, __ctx, __args)
            }
        }

        #[automatically_derived]
        impl spacetimedb::rt::FnInfo for #func_name {
            /// The type of this function
            type Invoke = <spacetimedb::rt::ViewKind<#ctx_ty> as spacetimedb::rt::ViewKindTrait>::InvokeFn;

            /// The function kind, which will cause scheduled tables to reject views.
            type FnKind = spacetimedb::rt::FnKindView;

            /// The name of this function
            const NAME: &'static str = #view_name;

            /// The parameter names of this function
            const ARG_NAMES: &'static [Option<&'static str>] = &[#(#opt_arg_names),*];

            /// The pointer for invoking this function
            const INVOKE: Self::Invoke = #func_name::invoke;

            /// The return type of this function
            fn return_type(
                ts: &mut impl spacetimedb::sats::typespace::TypespaceBuilder
            ) -> Option<spacetimedb::sats::AlgebraicType> {
                Some(<#eff_ret_ty as spacetimedb::SpacetimeType>::make_type(ts))
            }
        }

        #generate_explicit_names
    })
}