use proc_macro2::{Ident, Literal, Span, TokenStream};
use syn::parse::ParseStream;
use syn::{ext::IdentExt, spanned::Spanned};
use crate::db_lifetime;
use crate::options::{AllowedOptions, Options};
pub(crate) struct SalsaStruct<'s, A: SalsaStructAllowedOptions> {
struct_item: &'s syn::ItemStruct,
args: &'s Options<A>,
fields: Vec<SalsaField<'s>>,
}
pub(crate) trait SalsaStructAllowedOptions: AllowedOptions {
const KIND: &'static str;
const ALLOW_MAYBE_UPDATE: bool;
const ALLOW_TRACKED: bool;
const HAS_LIFETIME: bool;
const ELIDABLE_LIFETIME: bool;
const ALLOW_DEFAULT: bool;
}
pub(crate) struct SalsaField<'s> {
pub(crate) field: &'s syn::Field,
pub(crate) has_tracked_attr: bool,
pub(crate) has_default_attr: bool,
pub(crate) returns: syn::Ident,
pub(crate) has_no_eq_attr: bool,
pub(crate) maybe_update_attr: Option<(syn::Path, syn::Expr)>,
get_name: syn::Ident,
set_name: syn::Ident,
unknown_attrs: Vec<&'s syn::Attribute>,
}
const BANNED_FIELD_NAMES: &[&str] = &["from", "new"];
const ALLOWED_RETURN_MODES: &[&str] = &["copy", "clone", "ref", "deref", "as_ref", "as_deref"];
#[allow(clippy::type_complexity)]
pub(crate) const FIELD_OPTION_ATTRIBUTES: &[(
&str,
fn(&syn::Attribute, &mut SalsaField) -> syn::Result<()>,
)] = &[
("tracked", |_, ef| {
ef.has_tracked_attr = true;
Ok(())
}),
("default", |_, ef| {
ef.has_default_attr = true;
Ok(())
}),
("returns", |attr, ef| {
ef.returns = attr.parse_args_with(syn::Ident::parse_any)?;
Ok(())
}),
("no_eq", |_, ef| {
ef.has_no_eq_attr = true;
Ok(())
}),
("get", |attr, ef| {
ef.get_name = attr.parse_args()?;
Ok(())
}),
("set", |attr, ef| {
ef.set_name = attr.parse_args()?;
Ok(())
}),
("maybe_update", |attr, ef| {
ef.maybe_update_attr = Some(attr.parse_args_with(|parser: ParseStream| {
let expr = parser.parse::<syn::Expr>()?;
Ok((attr.path().clone(), expr))
})?);
Ok(())
}),
];
impl<'s, A> SalsaStruct<'s, A>
where
A: SalsaStructAllowedOptions,
{
pub fn new(struct_item: &'s syn::ItemStruct, args: &'s Options<A>) -> syn::Result<Self> {
let syn::Fields::Named(n) = &struct_item.fields else {
return Err(syn::Error::new_spanned(
&struct_item.ident,
"must have named fields for a struct",
));
};
let fields = n
.named
.iter()
.map(SalsaField::new)
.collect::<syn::Result<_>>()?;
let this = Self {
struct_item,
args,
fields,
};
this.maybe_disallow_maybe_update_fields()?;
this.maybe_disallow_tracked_fields()?;
this.maybe_disallow_default_fields()?;
this.check_generics()?;
Ok(this)
}
pub(crate) fn constructor_name(&self) -> syn::Ident {
match self.args.constructor_name.clone() {
Some(name) => name,
None => Ident::new("new", self.struct_item.ident.span()),
}
}
pub(crate) fn id(&self) -> syn::Path {
match &self.args.id {
Some(id) => id.clone(),
None => parse_quote!(salsa::Id),
}
}
pub(crate) fn revisions(&self) -> impl Iterator<Item = &syn::Expr> + '_ {
self.args.revisions.iter()
}
fn maybe_disallow_maybe_update_fields(&self) -> syn::Result<()> {
if A::ALLOW_MAYBE_UPDATE {
return Ok(());
}
for ef in &self.fields {
if ef.maybe_update_attr.is_some() {
return Err(syn::Error::new_spanned(
ef.field,
format!(
"`#[maybe_update]` cannot be used with `#[salsa::{}]`",
A::KIND
),
));
}
}
Ok(())
}
fn maybe_disallow_tracked_fields(&self) -> syn::Result<()> {
if A::ALLOW_TRACKED {
return Ok(());
}
for ef in &self.fields {
if ef.has_tracked_attr {
return Err(syn::Error::new_spanned(
ef.field,
format!("`#[tracked]` cannot be used with `#[salsa::{}]`", A::KIND),
));
}
}
Ok(())
}
fn maybe_disallow_default_fields(&self) -> syn::Result<()> {
if A::ALLOW_DEFAULT {
return Ok(());
}
for ef in &self.fields {
if ef.has_default_attr {
return Err(syn::Error::new_spanned(
ef.field,
format!("`#[default]` cannot be used with `#[salsa::{}]`", A::KIND),
));
}
}
Ok(())
}
fn check_generics(&self) -> syn::Result<()> {
if A::HAS_LIFETIME {
if !A::ELIDABLE_LIFETIME {
db_lifetime::require_db_lifetime(&self.struct_item.generics)
} else {
Ok(())
}
} else {
db_lifetime::require_no_generics(&self.struct_item.generics)
}
}
pub(crate) fn field_ids(&self) -> Vec<&syn::Ident> {
self.fields
.iter()
.map(|f| f.field.ident.as_ref().unwrap())
.collect()
}
pub(crate) fn tracked_ids(&self) -> Vec<&syn::Ident> {
self.tracked_fields_iter()
.map(|(_, f)| f.field.ident.as_ref().unwrap())
.collect()
}
pub(crate) fn field_indices(&self) -> Vec<Literal> {
(0..self.fields.len())
.map(Literal::usize_unsuffixed)
.collect()
}
pub(crate) fn tracked_field_indices(&self) -> Vec<Literal> {
self.tracked_fields_iter()
.map(|(index, _)| Literal::usize_unsuffixed(index))
.collect()
}
pub(crate) fn untracked_field_indices(&self) -> Vec<Literal> {
self.untracked_fields_iter()
.map(|(index, _)| Literal::usize_unsuffixed(index))
.collect()
}
pub(crate) fn num_fields(&self) -> Literal {
Literal::usize_unsuffixed(self.fields.len())
}
pub(crate) fn num_tracked_fields(&self) -> Literal {
Literal::usize_unsuffixed(self.tracked_fields_iter().count())
}
pub(crate) fn required_fields(&self) -> Vec<TokenStream> {
self.fields
.iter()
.filter_map(|f| {
if f.has_default_attr {
None
} else {
let ident = f.field.ident.as_ref().unwrap();
let ty = &f.field.ty;
Some(quote!(#ident #ty))
}
})
.collect()
}
pub(crate) fn field_vis(&self) -> Vec<&syn::Visibility> {
self.fields.iter().map(|f| &f.field.vis).collect()
}
pub(crate) fn tracked_vis(&self) -> Vec<&syn::Visibility> {
self.tracked_fields_iter()
.map(|(_, f)| &f.field.vis)
.collect()
}
pub(crate) fn untracked_vis(&self) -> Vec<&syn::Visibility> {
self.untracked_fields_iter()
.map(|(_, f)| &f.field.vis)
.collect()
}
pub(crate) fn field_getter_ids(&self) -> Vec<&syn::Ident> {
self.fields.iter().map(|f| &f.get_name).collect()
}
pub(crate) fn tracked_getter_ids(&self) -> Vec<&syn::Ident> {
self.tracked_fields_iter()
.map(|(_, f)| &f.get_name)
.collect()
}
pub(crate) fn untracked_getter_ids(&self) -> Vec<&syn::Ident> {
self.untracked_fields_iter()
.map(|(_, f)| &f.get_name)
.collect()
}
pub(crate) fn field_setter_ids(&self) -> Vec<&syn::Ident> {
self.fields.iter().map(|f| &f.set_name).collect()
}
pub(crate) fn field_durability_ids(&self) -> Vec<syn::Ident> {
self.fields
.iter()
.map(|f| quote::format_ident!("{}_durability", f.field.ident.as_ref().unwrap()))
.collect()
}
pub(crate) fn field_tys(&self) -> Vec<&syn::Type> {
self.fields.iter().map(|f| &f.field.ty).collect()
}
pub(crate) fn tracked_tys(&self) -> Vec<&syn::Type> {
self.tracked_fields_iter()
.map(|(_, f)| &f.field.ty)
.collect()
}
pub(crate) fn untracked_tys(&self) -> Vec<&syn::Type> {
self.untracked_fields_iter()
.map(|(_, f)| &f.field.ty)
.collect()
}
pub(crate) fn field_indexed_tys(&self) -> Vec<syn::Ident> {
self.fields
.iter()
.enumerate()
.map(|(i, _)| quote::format_ident!("T{i}"))
.collect()
}
pub(crate) fn field_attrs(&self) -> Vec<&[&syn::Attribute]> {
self.fields.iter().map(|f| &*f.unknown_attrs).collect()
}
pub(crate) fn tracked_field_attrs(&self) -> Vec<&[&syn::Attribute]> {
self.tracked_fields_iter()
.map(|f| &*f.1.unknown_attrs)
.collect()
}
pub(crate) fn untracked_field_attrs(&self) -> Vec<&[&syn::Attribute]> {
self.untracked_fields_iter()
.map(|f| &*f.1.unknown_attrs)
.collect()
}
pub(crate) fn field_options(&self) -> Vec<TokenStream> {
self.fields.iter().map(SalsaField::options).collect()
}
pub(crate) fn tracked_options(&self) -> Vec<TokenStream> {
self.tracked_fields_iter()
.map(|(_, f)| f.options())
.collect()
}
pub(crate) fn untracked_options(&self) -> Vec<TokenStream> {
self.untracked_fields_iter()
.map(|(_, f)| f.options())
.collect()
}
pub fn generate_debug_impl(&self) -> bool {
self.args.debug.is_some()
}
pub fn generate_lifetime(&self) -> bool {
self.args.no_lifetime.is_none()
}
pub fn tracked_fields_iter(&self) -> impl Iterator<Item = (usize, &SalsaField<'s>)> {
self.fields
.iter()
.enumerate()
.filter(|(_, f)| f.has_tracked_attr)
}
pub fn untracked_fields_iter(&self) -> impl Iterator<Item = (usize, &SalsaField<'s>)> {
self.fields
.iter()
.enumerate()
.filter(|(_, f)| !f.has_tracked_attr)
}
pub(crate) fn serialize_fn(&self) -> impl Iterator<Item = syn::Path> + '_ {
self.args
.persist
.clone()
.map(|persist| {
persist
.serialize_fn
.unwrap_or(parse_quote! { serde::Serialize::serialize })
})
.into_iter()
}
pub(crate) fn deserialize_fn(&self) -> impl Iterator<Item = syn::Path> + '_ {
self.args
.persist
.clone()
.map(|persist| {
persist
.deserialize_fn
.unwrap_or(parse_quote! { serde::Deserialize::deserialize })
})
.into_iter()
}
}
impl<'s> SalsaField<'s> {
fn new(field: &'s syn::Field) -> syn::Result<Self> {
let field_name = field.ident.as_ref().unwrap();
let field_name_str = field_name.to_string();
if BANNED_FIELD_NAMES.iter().any(|n| *n == field_name_str) {
return Err(syn::Error::new(
field_name.span(),
format!("the field name `{field_name_str}` is disallowed in salsa structs",),
));
}
let get_name = Ident::new(&field_name_str, field_name.span());
let set_name = Ident::new(&format!("set_{field_name_str}",), field_name.span());
let returns = Ident::new("clone", field.span());
let mut result = SalsaField {
field,
has_tracked_attr: false,
returns,
has_default_attr: false,
has_no_eq_attr: false,
maybe_update_attr: None,
get_name,
set_name,
unknown_attrs: Default::default(),
};
for attr in &field.attrs {
let mut handled = false;
for (fa, func) in FIELD_OPTION_ATTRIBUTES {
if attr.path().is_ident(fa) {
func(attr, &mut result)?;
handled = true;
break;
}
}
if !handled {
result.unknown_attrs.push(attr);
}
}
if !ALLOWED_RETURN_MODES
.iter()
.any(|mode| mode == &result.returns.to_string())
{
return Err(syn::Error::new(
result.returns.span(),
format!("Invalid return mode. Allowed modes are: {ALLOWED_RETURN_MODES:?}"),
));
}
Ok(result)
}
fn options(&self) -> TokenStream {
let returns = &self.returns;
let default_ident = if self.has_default_attr {
syn::Ident::new("default", Span::call_site())
} else {
syn::Ident::new("required", Span::call_site())
};
quote!((#returns, #default_ident))
}
}