use convert_case::{Case, Casing};
use proc_macro::TokenStream;
use quote::quote;
use crate::schema::{FieldIdentifier, KeyStrategy, Schema, SchemaKey};
/// Code generator that wraps a parsed [`Schema`] and emits the key type,
/// index types and trait implementations for the record it describes.
pub struct Generator(Schema);
impl Generator {
pub fn new(schema: Schema) -> Self {
Self(schema)
}
pub fn generate_record_impl(&self, visibility: &syn::Visibility) -> TokenStream {
let name = &self.0.name;
let key_type = syn::Ident::new(&format!("{name}Key"), name.span());
let keys = self.generate_keys();
let mut key_impl = self.generate_key_impl(&key_type, &keys, visibility);
let (index_impl, index_values) = self.generate_index_impls(&key_type, visibility);
key_impl.extend(index_impl);
let trait_impls = self.generate_main_impl(&key_type, &index_values);
key_impl.extend(trait_impls);
TokenStream::from(key_impl)
}
/// Resolves the schema's key strategy into the list of key components.
///
/// * `FieldKeys` – the declared keys are returned as-is (cloned).
/// * `Autoincrement` – a single `u64` component at position 0.
/// * `Derived` – one positional component per listed type, in order.
fn generate_keys(&self) -> Vec<SchemaKey> {
    match &self.0.key_strategy {
        KeyStrategy::FieldKeys(declared) => declared.clone(),
        KeyStrategy::Autoincrement => {
            let counter = SchemaKey {
                field_id: FieldIdentifier::Indexed(0),
                ty: syn::parse_quote!(u64),
            };
            vec![counter]
        }
        KeyStrategy::Derived(component_types) => {
            let mut components = Vec::new();
            for (position, component_ty) in component_types.iter().enumerate() {
                components.push(SchemaKey {
                    field_id: FieldIdentifier::Indexed(position),
                    ty: component_ty.clone(),
                });
            }
            components
        }
    }
}
/// Emits the tuple struct used as the record's key, followed by the trait
/// impl required by the key strategy.
///
/// The struct has one public field per entry in `keys`, carries the
/// attributes stored in the schema's `attrs`, and is declared with the given
/// `visibility`.
fn generate_key_impl(
    &self,
    key_type: &syn::Ident,
    keys: &[SchemaKey],
    visibility: &syn::Visibility,
) -> proc_macro2::TokenStream {
    let strategy_impl = self.generate_key_trait_impl(key_type, keys);
    let forwarded_attrs = &self.0.attrs;
    // `quote!` repetition accepts any iterator, so no intermediate Vec is needed.
    let component_types = keys.iter().map(|key| &key.ty);

    quote! {
        #(#forwarded_attrs)*
        #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
        #visibility struct #key_type(#(pub #component_types),*);
        #strategy_impl
    }
}
/// Emits the trait impl that connects the key type to the key strategy.
///
/// * `Derived` – nothing is generated.
/// * `Autoincrement` – `::kivis::Incrementable` for the key type, using
///   `checked_add(1)` on the wrapped value.
/// * `FieldKeys` – `::kivis::DeriveKey` for the record, building the key by
///   cloning each field named in `keys`, in order.
fn generate_key_trait_impl(
    &self,
    key_type: &syn::Ident,
    keys: &[SchemaKey],
) -> proc_macro2::TokenStream {
    match &self.0.key_strategy {
        KeyStrategy::Derived(_) => proc_macro2::TokenStream::new(),
        KeyStrategy::Autoincrement => quote! {
            impl ::kivis::Incrementable for #key_type {
                fn next_id(&self) -> ::core::option::Option<Self> {
                    self.0.checked_add(1).map(|id| #key_type(id))
                }
            }
        },
        KeyStrategy::FieldKeys(_) => {
            let record = &self.0.name;
            let (impl_generics, ty_generics, where_clause) = self.0.generics.split_for_impl();

            // One `c.<field>.clone()` expression per key component.
            let mut components = Vec::with_capacity(keys.len());
            for key in keys {
                let component = match &key.field_id {
                    FieldIdentifier::Named(field) => quote! { c.#field.clone() },
                    FieldIdentifier::Indexed(position) => {
                        let member = syn::Index::from(*position);
                        quote! { c.#member.clone() }
                    }
                };
                components.push(component);
            }

            quote! {
                impl #impl_generics ::kivis::DeriveKey for #record #ty_generics #where_clause {
                    type Key = #key_type;
                    fn key(c: &<Self::Key as ::kivis::RecordKey>::Record) -> Self::Key {
                        #key_type(#(#components),*)
                    }
                }
            }
        }
    }
}
/// Emits one newtype struct plus `::kivis::Index` impl per schema index.
///
/// Returns the generated definitions together with, for each index in
/// declaration order, the `&self.<field>` expression that reads the indexed
/// value from a record.
///
/// Index structs are named `<Record><Suffix>Index`, where the suffix is the
/// Pascal-cased field name for named fields or `Field<N>` for positional
/// ones. The `INDEX` constant is the index's position in the schema.
fn generate_index_impls(
    &self,
    key_type: &syn::Ident,
    visibility: &syn::Visibility,
) -> (proc_macro2::TokenStream, Vec<proc_macro2::TokenStream>) {
    let record = &self.0.name;
    let mut definitions = proc_macro2::TokenStream::new();
    let mut accessors = Vec::new();

    for (discriminator, index_def) in self.0.indexes.iter().enumerate() {
        // Resolve the type-name suffix and the field accessor in a single pass.
        let (suffix, accessor) = match &index_def.field_id {
            FieldIdentifier::Named(field) => (
                field.to_string().to_case(Case::Pascal),
                quote! { &self.#field },
            ),
            FieldIdentifier::Indexed(position) => {
                let member = syn::Index::from(*position);
                (format!("Field{position}"), quote! { &self.#member })
            }
        };

        let index_name = syn::Ident::new(&format!("{record}{suffix}Index"), record.span());
        let indexed_type = &index_def.ty;

        definitions.extend(quote! {
            #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
            #visibility struct #index_name(pub #indexed_type);
            impl ::kivis::Index for #index_name {
                type Key = #key_type;
                type Record = #record;
                const INDEX: u8 = #discriminator as u8;
            }
        });
        accessors.push(accessor);
    }

    (definitions, accessors)
}
/// Emits the `::kivis::RecordKey` impl for the key type and the
/// `::kivis::DatabaseEntry` impl for the record.
///
/// `index_values` holds one field-access expression per index, in
/// discriminator order; `index_key` serializes the expression whose position
/// matches the requested discriminator and does nothing for unknown ones.
///
/// If the number of indexes does not fit in a `u8` (more than 255), the
/// returned tokens are a single `compile_error!` invocation instead.
///
/// NOTE(review): `#ty_generics` is applied to the key type here, while the
/// key struct is declared without generic parameters and `type Record` names
/// the record without them — confirm whether generic records are meant to be
/// supported.
fn generate_main_impl(
    &self,
    key_type: &syn::Ident,
    index_values: &[proc_macro2::TokenStream],
) -> proc_macro2::TokenStream {
    let name = &self.0.name;
    let (impl_generics, ty_generics, where_clause) = self.0.generics.split_for_impl();

    // The count is stored in a `u8`, so 255 is the largest number of indexes
    // that can be represented; 256 already fails the conversion.
    let Ok(index_count) = u8::try_from(index_values.len()) else {
        return quote! {
            compile_error!("Too many indexes: maximum of 255 indexes allowed per record");
        };
    };
    let indices = 0..index_count;

    // All paths in the generated code are absolute so that items named `core`
    // or `kivis` in the caller's scope cannot capture them.
    quote! {
        impl #impl_generics ::kivis::RecordKey for #key_type #ty_generics #where_clause {
            type Record = #name;
        }
        impl #impl_generics ::kivis::DatabaseEntry for #name #ty_generics #where_clause {
            type Key = #key_type;
            // `index_count` is interpolated as a `u8` literal; no cast is needed.
            const INDEX_COUNT_HINT: u8 = #index_count;
            fn index_key<KU: ::kivis::Unifier>(
                &self,
                buffer: &mut KU::D,
                discriminator: u8,
                serializer: &KU,
            ) -> ::core::result::Result<(), ::kivis::BufferOverflowOr<KU::SerError>> {
                match discriminator {
                    #(
                        #indices => {
                            serializer.serialize_ref(buffer, #index_values)?;
                        }
                    )*
                    _ => {}
                }
                ::core::result::Result::Ok(())
            }
        }
    }
}
}