use heck::ToSnakeCase;
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use syn::ext::IdentExt;
use syn::parse::Parser;
use syn::{FnArg, ItemFn, LitStr};
use crate::reducer::generate_explicit_names_impl;
use crate::sym;
use crate::util::{check_duplicate_msg, match_meta};
/// Arguments accepted by the `#[view(...)]` attribute.
pub(crate) struct ViewArgs {
    /// Explicit name given via `name = "..."`, if any.
    name: Option<LitStr>,
    /// Identifier given via the mandatory `accessor = ...` argument.
    accessor: Ident,
    // `parse` rejects views without `public`, so this is always `true`.
    // The field is private and not read anywhere in this module, hence the
    // `dead_code` allowance.
    #[allow(dead_code)]
    public: bool,
}

impl ViewArgs {
    /// Parses the tokens inside `#[view(...)]`.
    ///
    /// Recognised arguments are `name = "..."`, `accessor = <ident>` and the
    /// bare flag `public`. Each may be given at most once.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// * an argument is repeated or its value fails to parse;
    /// * `accessor` is missing — the message suggests the snake_case form of
    ///   `func_ident`;
    /// * `public` is missing.
    pub(crate) fn parse(input: TokenStream, func_ident: &Ident) -> syn::Result<Self> {
        let mut name = None;
        let mut accessor = None;
        let mut public = None;
        syn::meta::parser(|meta| {
            match_meta!(match meta {
                sym::name => {
                    check_duplicate_msg(&name, &meta, "`name` already specified")?;
                    name = Some(meta.value()?.parse::<LitStr>()?);
                }
                sym::public => {
                    check_duplicate_msg(&public, &meta, "`public` already specified")?;
                    public = Some(());
                }
                sym::accessor => {
                    check_duplicate_msg(&accessor, &meta, "`accessor` already specified")?;
                    accessor = Some(meta.value()?.parse()?);
                }
            });
            Ok(())
        })
        .parse2(input)?;

        let accessor = accessor.ok_or_else(|| {
            let view = func_ident.to_string().to_snake_case();
            syn::Error::new(
                Span::call_site(),
                // The suggested attribute is wrapped in a closed code span.
                format_args!("must specify view accessor, e.g. `#[spacetimedb::view(accessor = {view})]`"),
            )
        })?;

        // `public` is mandatory: omitting it is an error, so the stored flag is always `true`.
        let () = public
            .ok_or_else(|| syn::Error::new(Span::call_site(), "views must be `public`, e.g. `#[view(public)]`"))?;

        Ok(Self {
            name,
            public: true,
            accessor,
        })
    }
}
fn extract_impl_query_inner(ty: &syn::Type) -> Option<&syn::Type> {
if let syn::Type::ImplTrait(impl_trait) = ty {
for bound in &impl_trait.bounds {
if let syn::TypeParamBound::Trait(trait_bound) = bound
&& let Some(seg) = trait_bound.path.segments.last()
&& seg.ident == "Query"
&& let syn::PathArguments::AngleBracketed(args) = &seg.arguments
&& let Some(syn::GenericArgument::Type(inner)) = args.args.first()
{
return Some(inner);
}
}
}
None
}
/// Returns the path of the `Query<T>` bound of an `impl Query<T>` type,
/// exactly as written in the source (e.g. `Query<T>` or `spacetimedb::Query<T>`).
///
/// The matching rules mirror `extract_impl_query_inner`, so both functions
/// select the same bound. That bound is the first trait bound that meets both
/// conditions:
/// * its last path segment is `Query`;
/// * its first angle-bracketed argument is a type.
fn impl_query_trait_path(ty: &syn::Type) -> Option<&syn::Path> {
    let syn::Type::ImplTrait(impl_trait) = ty else {
        return None;
    };
    impl_trait.bounds.iter().find_map(|bound| {
        if let syn::TypeParamBound::Trait(trait_bound) = bound
            && let Some(seg) = trait_bound.path.segments.last()
            && seg.ident == "Query"
            && let syn::PathArguments::AngleBracketed(args) = &seg.arguments
            && let Some(syn::GenericArgument::Type(_)) = args.args.first()
        {
            Some(&trait_bound.path)
        } else {
            None
        }
    })
}

/// Expands a `#[view]` function into the items the SpacetimeDB runtime needs.
///
/// `args` are the already-parsed attribute arguments; `original_function` is
/// the annotated function. The generated tokens contain:
///
/// * The function itself:
///   * an `impl Query<T>` return type is rewritten to `spacetimedb::RawQuery<T>`,
///     and the body is wrapped in `RawQuery::new(.. into_sql(..))`;
///   * a body returning `RawQuery` is passed through `Into::into`.
/// * An exported `__preinit__20_register_describer_<view>` function. It calls
///   `spacetimedb::rt::ViewRegistrar::register`.
/// * An uninhabited marker struct named after the function. It carries an
///   `invoke` shim and a `spacetimedb::rt::FnInfo` impl.
/// * Compile-time checks that the context type implements `ViewContextArg` and
///   the return type implements `ViewReturn`.
/// * The output of `generate_explicit_names_impl` for the view.
///
/// # Errors
///
/// Returns an error if the function has any of the following:
/// * type or const generic parameters;
/// * a `self` parameter;
/// * any parameter other than the context;
/// * a context parameter that is not a reference;
/// * no return type.
pub(crate) fn view_impl(args: ViewArgs, original_function: &ItemFn) -> syn::Result<TokenStream> {
    let vis = &original_function.vis;
    let func_name = &original_function.sig.ident;
    let view_ident = args.accessor;
    // Raw identifiers (`r#foo`) are registered under their unprefixed name.
    let view_name = view_ident.unraw().to_string();

    // Only lifetime parameters are accepted.
    for param in &original_function.sig.generics.params {
        let err = |msg| syn::Error::new_spanned(param, msg);
        match param {
            syn::GenericParam::Lifetime(_) => {}
            syn::GenericParam::Type(_) => return Err(err("type parameters are not allowed on views")),
            syn::GenericParam::Const(_) => return Err(err("const parameters are not allowed on views")),
        }
    }

    let typed_args = original_function
        .sig
        .inputs
        .iter()
        .map(|arg| match arg {
            FnArg::Typed(arg) => Ok(arg),
            FnArg::Receiver(_) => Err(syn::Error::new_spanned(
                arg,
                "The `self` parameter is not allowed in views",
            )),
        })
        .collect::<syn::Result<Vec<_>>>()?;

    // Names of all parameters (including the context), or `None` for
    // non-identifier patterns.
    let opt_arg_names = typed_args.iter().map(|arg| {
        if let syn::Pat::Ident(i) = &*arg.pat {
            let name = i.ident.to_string();
            quote!(Some(#name))
        } else {
            quote!(None)
        }
    });

    let arg_tys = typed_args.iter().map(|arg| arg.ty.as_ref()).collect::<Vec<_>>();
    let [ctx_ty, arg_tys @ ..] = &arg_tys[..] else {
        return Err(syn::Error::new_spanned(
            &original_function.sig,
            "Views must always have a context parameter: `&ViewContext` or `&AnonymousViewContext`",
        ));
    };
    if !arg_tys.is_empty() {
        return Err(syn::Error::new_spanned(
            &original_function.sig,
            "Views do not take parameters other than `&ViewContext` or `&AnonymousViewContext`",
        ));
    }

    // The context must be taken by reference; from here on `ctx_ty` is the
    // referenced type.
    let ctx_ty = match ctx_ty {
        syn::Type::Reference(ctx_ty) => ctx_ty.elem.as_ref(),
        _ => {
            return Err(syn::Error::new_spanned(
                ctx_ty,
                "The first parameter of a view must be a context parameter: `&ViewContext` or `&AnonymousViewContext`; passed by reference",
            ));
        }
    };

    let ret_ty = match &original_function.sig.output {
        syn::ReturnType::Type(_, t) => t.as_ref(),
        syn::ReturnType::Default => {
            return Err(syn::Error::new_spanned(
                &original_function.sig,
                "views must return `Vec<T>` or `Option<T>` where `T` is a `SpacetimeType`",
            ));
        }
    };

    let register_describer_symbol = format!("__preinit__20_register_describer_{}", view_name);

    // Lifetime generics (plus where clause) reused on the generated assertion fns.
    let lt_params = &original_function.sig.generics;
    let lt_where_clause = &lt_params.where_clause;

    let generated_describe_function = quote! {
        #[unsafe(export_name = #register_describer_symbol)]
        pub extern "C" fn __register_describer() {
            spacetimedb::rt::ViewRegistrar::<#ctx_ty>::register::<_, #func_name, _, _>(#func_name)
        }
    };

    let explicit_name = args.name.as_ref();
    let generate_explicit_names = generate_explicit_names_impl(&view_name, func_name, explicit_name);

    let original_attrs = &original_function.attrs;
    let original_body = &original_function.block;

    let impl_query_inner = extract_impl_query_inner(ret_ty);
    let (emitted_fn, effective_ret_ty) = if let Some(inner_ty) = impl_query_inner {
        // Call `into_sql` through the trait path exactly as the user wrote it in
        // the return type. A bare `Query::into_sql` only resolves when `Query`
        // is imported under that name, which breaks `-> impl spacetimedb::Query<T>`.
        let query_trait = impl_query_trait_path(ret_ty)
            .expect("`impl_query_trait_path` matches whenever `extract_impl_query_inner` does");
        let original_sig = &original_function.sig;
        let mut new_sig = original_sig.clone();
        new_sig.output = syn::parse_quote!(-> spacetimedb::RawQuery<#inner_ty>);
        let effective_ty: syn::Type = syn::parse_quote!(spacetimedb::RawQuery<#inner_ty>);
        (
            quote! {
                #(#original_attrs)*
                #vis
                #new_sig {
                    spacetimedb::RawQuery::new(
                        <_ as #query_trait>::into_sql(#original_body)
                    )
                }
            },
            effective_ty,
        )
    } else {
        let original_sig = &original_function.sig;
        let returns_raw_query =
            matches!(ret_ty, syn::Type::Path(p) if p.path.segments.last().is_some_and(|s| s.ident == "RawQuery"));
        // A `RawQuery` return type lets the body produce anything convertible into it.
        let emitted_body = if returns_raw_query {
            quote! { { ::core::convert::Into::into(#original_body) } }
        } else {
            quote! { #original_body }
        };
        (
            quote! {
                #(#original_attrs)*
                #vis
                #original_sig
                #emitted_body
            },
            ret_ty.clone(),
        )
    };
    let eff_ret_ty = &effective_ret_ty;

    Ok(quote! {
        #emitted_fn

        const _: () = { #generated_describe_function };

        #[allow(non_camel_case_types)]
        #vis struct #func_name { _never: ::core::convert::Infallible }

        const _: () = {
            fn _assert_args #lt_params () #lt_where_clause {
                let _ = <#ctx_ty as spacetimedb::rt::ViewContextArg>::_ITEM;
                let _ = <#eff_ret_ty as spacetimedb::rt::ViewReturn>::_ITEM;
            }
        };

        const _: () = {
            fn _assert_args #lt_params () #lt_where_clause {
                #(let _ = <#arg_tys as spacetimedb::rt::ViewArg>::_ITEM;)*
            }
        };

        impl #func_name {
            fn invoke(__ctx: #ctx_ty, __args: &[u8]) -> Vec<u8> {
                spacetimedb::rt::ViewDispatcher::<#ctx_ty>::invoke::<_, _, _>(#func_name, __ctx, __args)
            }
        }

        #[automatically_derived]
        impl spacetimedb::rt::FnInfo for #func_name {
            type Invoke = <spacetimedb::rt::ViewKind<#ctx_ty> as spacetimedb::rt::ViewKindTrait>::InvokeFn;
            type FnKind = spacetimedb::rt::FnKindView;
            const NAME: &'static str = #view_name;
            const ARG_NAMES: &'static [Option<&'static str>] = &[#(#opt_arg_names),*];
            const INVOKE: Self::Invoke = #func_name::invoke;

            fn return_type(
                ts: &mut impl spacetimedb::sats::typespace::TypespaceBuilder
            ) -> Option<spacetimedb::sats::AlgebraicType> {
                Some(<#eff_ret_ty as spacetimedb::SpacetimeType>::make_type(ts))
            }
        }

        #generate_explicit_names
    })
}