use super::TraitDefinition;
use crate::{
generator::{
self,
},
traits::GenerateCode,
EnforcedErrors,
};
use derive_more::From;
use proc_macro2::{
Span,
TokenStream as TokenStream2,
};
use quote::{
format_ident,
quote,
quote_spanned,
};
use syn::{
parse_quote,
spanned::Spanned,
};
impl<'a> TraitDefinition<'a> {
pub fn generate_trait_registry_impl(&self) -> TokenStream2 {
TraitRegistry::from(*self).generate_code()
}
pub fn trait_info_ident(&self) -> syn::Ident {
self.append_trait_suffix("TraitInfo")
}
}
/// Code generator for the trait registry items of one ink! trait definition.
///
/// `From<TraitDefinition>` is derived so it can be built with
/// `TraitRegistry::from(trait_def)`.
#[derive(From)]
struct TraitRegistry<'def> {
    /// The ink! trait definition that registry code is generated for.
    trait_def: TraitDefinition<'def>,
}
impl GenerateCode for TraitRegistry<'_> {
    /// Emits the registry trait implementation, followed by the trait info
    /// object and its associated implementations.
    fn generate_code(&self) -> TokenStream2 {
        // Build in the same order as the output: the registry impl first,
        // then the trait info object appended after it.
        let mut tokens = self.generate_registry_impl();
        tokens.extend(self.generate_trait_info_object());
        tokens
    }
}
impl TraitRegistry<'_> {
fn span(&self) -> Span {
self.trait_def.span()
}
fn trait_ident(&self) -> &syn::Ident {
self.trait_def.trait_def.item().ident()
}
fn generate_registry_impl(&self) -> TokenStream2 {
let span = self.span();
let name = self.trait_ident();
let trait_info_ident = self.trait_def.trait_info_ident();
let messages = self.generate_registry_messages();
quote_spanned!(span=>
impl<E> #name for ::ink::reflect::TraitDefinitionRegistry<E>
where
E: ::ink::env::Environment,
{
#[allow(non_camel_case_types)]
type __ink_TraitInfo = #trait_info_ident<E>;
#messages
}
)
}
fn generate_registry_messages(&self) -> TokenStream2 {
let messages = self.trait_def.trait_def.item().iter_items().filter_map(
|(item, selector)| {
item.filter_map_message()
.map(|message| self.generate_registry_for_message(&message, selector))
},
);
quote! {
#( #messages )*
}
}
fn generate_inout_guards_for_message(message: &ir::InkTraitMessage) -> TokenStream2 {
let message_span = message.span();
let message_inputs = message.inputs().map(|input| {
let input_span = input.span();
let input_type = &*input.ty;
quote_spanned!(input_span=>
::ink::codegen::utils::consume_type::<
::ink::codegen::DispatchInput<#input_type>
>();
)
});
let message_output = message.output().map(|output_type| {
let output_span = output_type.span();
quote_spanned!(output_span=>
::ink::codegen::utils::consume_type::<
::ink::codegen::DispatchOutput<#output_type>
>();
)
});
quote_spanned!(message_span=>
#( #message_inputs )*
#message_output
)
}
fn generate_registry_for_message(
&self,
message: &ir::InkTraitMessage,
selector: ir::Selector,
) -> TokenStream2 {
let span = message.span();
let ident = message.ident();
let attrs = message.attrs();
let cfg_attrs = message.get_cfg_attrs(span);
let output_ident = generator::output_ident(message.ident());
let output_type = message
.output()
.cloned()
.unwrap_or_else(|| parse_quote! { () });
let mut_token = message.receiver().is_ref_mut().then(|| quote! { mut });
let (input_bindings, input_types) =
Self::input_bindings_and_types(message.inputs());
let linker_error_ident = EnforcedErrors::cannot_call_trait_message(
self.trait_ident(),
message.ident(),
selector,
message.mutates(),
);
let inout_guards = Self::generate_inout_guards_for_message(message);
let impl_body = match option_env!("INK_COVERAGE_REPORTING") {
Some("true") => {
quote! {
::core::unreachable!(
"this is an invalid ink! message call which should never be possible."
);
}
}
_ => {
quote! {
extern {
fn #linker_error_ident() -> !;
}
unsafe { #linker_error_ident() }
}
}
};
quote_spanned!(span=>
#( #cfg_attrs )*
type #output_ident = #output_type;
#( #attrs )*
#[cold]
fn #ident(
& #mut_token self
#( , #input_bindings : #input_types )*
) -> Self::#output_ident {
#inout_guards
#impl_body
}
)
}
fn input_bindings_and_types(
inputs: ir::InputsIter,
) -> (Vec<syn::Ident>, Vec<&syn::Type>) {
inputs
.enumerate()
.map(|(n, pat_type)| {
let binding = format_ident!("__ink_binding_{}", n);
let ty = &*pat_type.ty;
(binding, ty)
})
.unzip()
}
fn generate_trait_info_object(&self) -> TokenStream2 {
let span = self.span();
let trait_id = self.generate_trait_id();
let trait_ident = self.trait_ident();
let trait_info_ident = self.trait_def.trait_info_ident();
let trait_call_forwarder = self.trait_def.call_forwarder_ident();
let trait_message_builder = self.trait_def.message_builder_ident();
let trait_message_info = self.generate_info_for_trait_messages();
quote_spanned!(span =>
#[doc(hidden)]
#[allow(non_camel_case_types)]
pub struct #trait_info_ident<E> {
marker: ::core::marker::PhantomData<fn() -> E>,
}
#trait_message_info
impl<E> ::ink::reflect::TraitInfo for #trait_info_ident<E>
where
E: ::ink::env::Environment,
{
const ID: u32 = #trait_id;
const PATH: &'static ::core::primitive::str = ::core::module_path!();
const NAME: &'static ::core::primitive::str = ::core::stringify!(#trait_ident);
}
impl<E> ::ink::codegen::TraitCallForwarder for #trait_info_ident<E>
where
E: ::ink::env::Environment,
{
type Forwarder = #trait_call_forwarder<E>;
}
impl<E> ::ink::codegen::TraitMessageBuilder for #trait_info_ident<E>
where
E: ::ink::env::Environment,
{
type MessageBuilder = #trait_message_builder<E>;
}
)
}
fn generate_trait_id(&self) -> syn::LitInt {
let span = self.span();
let mut id = 0u32;
debug_assert!(
self.trait_def
.trait_def
.item()
.iter_items()
.next()
.is_some(),
"invalid empty ink! trait definition"
);
for (_, selector) in self.trait_def.trait_def.item().iter_items() {
id ^= selector.into_be_u32()
}
syn::LitInt::new(&format!("{id}"), span)
}
fn generate_info_for_trait_messages(&self) -> TokenStream2 {
let span = self.span();
let message_impls = self.trait_def.trait_def.item().iter_items().filter_map(
|(trait_item, selector)| {
trait_item.filter_map_message().map(|message| {
self.generate_info_for_trait_for_message(&message, selector)
})
},
);
quote_spanned!(span=>
#( #message_impls )*
)
}
fn generate_info_for_trait_for_message(
&self,
message: &ir::InkTraitMessage,
selector: ir::Selector,
) -> TokenStream2 {
let span = message.span();
let trait_info_ident = self.trait_def.trait_info_ident();
let local_id = message.local_id();
let selector_bytes = selector.hex_lits();
let is_payable = message.ink_attrs().is_payable();
quote_spanned!(span=>
impl<E> ::ink::reflect::TraitMessageInfo<#local_id> for #trait_info_ident<E> {
const PAYABLE: ::core::primitive::bool = #is_payable;
const SELECTOR: [::core::primitive::u8; 4usize] = [ #( #selector_bytes ),* ];
}
)
}
}