use std::collections::HashSet;
use heck::ToKebabCase;
use proc_macro2::{Span, TokenStream};
use proc_macro_error::abort;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use syn::{spanned::Spanned, FnArg, Ident, ItemTrait, LitStr, ReturnType, TraitItem, TraitItemFn};
use super::wit_interface;
use crate::util::{AttributeParameters, TokensSetItem};
/// Expands a trait describing an imported WIT interface into the host-side wrapper code.
///
/// `trait_definition` is the trait the attribute was applied to and `parameters` are the
/// attribute's arguments. Both are handed to [`WitImportGenerator`], and its output is returned.
pub fn generate(trait_definition: ItemTrait, parameters: AttributeParameters) -> TokenStream {
    let generator = WitImportGenerator::new(&trait_definition, parameters);
    generator.generate()
}
/// Generator for the code that imports a WIT interface described by a Rust trait.
pub struct WitImportGenerator<'input> {
    /// Name of the imported trait, reused as the name of the generated wrapper type.
    trait_name: &'input Ident,
    /// Information extracted from each function declared in the trait.
    functions: Vec<FunctionInformation<'input>>,
    /// Namespace prefixed to each function's name when loading it from the instance.
    namespace: LitStr,
    /// Arguments passed to the attribute macro.
    parameters: AttributeParameters,
}
/// Pre-computed code fragments describing one function of an imported trait.
pub(crate) struct FunctionInformation<'input> {
    /// The original function declaration from the trait.
    pub(crate) function: &'input TraitItemFn,
    /// The return type, or `()` when the declaration has none.
    return_type: TokenStream,
    /// Comma-terminated list of `pattern: Type` parameter definitions.
    parameter_definitions: TokenStream,
    /// Comma-terminated list of parameter patterns, used to build the parameter `hlist!`.
    parameter_bindings: TokenStream,
    /// Qualified path to the function's `linera_witty::ImportedFunctionInterface` impl.
    interface: TokenStream,
    /// `linera_witty::InstanceWithFunction` bound required to load this function.
    instance_constraint: TokenStream,
}
impl<'input> WitImportGenerator<'input> {
    /// Collects the information needed to generate the import code for `trait_definition`.
    ///
    /// The namespace used when loading functions is derived from `parameters` and the trait's
    /// name. Every trait item is converted into a [`FunctionInformation`]. Non-function items
    /// abort macro expansion (see the `From<&TraitItem>` implementation for
    /// `FunctionInformation`).
    fn new(trait_definition: &'input ItemTrait, parameters: AttributeParameters) -> Self {
        let trait_name = &trait_definition.ident;
        let namespace = parameters.namespace(trait_name);
        let functions = trait_definition
            .items
            .iter()
            .map(FunctionInformation::from)
            .collect::<Vec<_>>();

        WitImportGenerator {
            trait_name,
            parameters,
            namespace,
            functions,
        }
    }

    /// Generates all of the code for the imported interface.
    ///
    /// The output contains:
    /// - a `#trait_name<Instance>` struct holding the `Instance` and one lazily filled
    ///   function slot per imported function;
    /// - an inherent `impl` with a `new` constructor and one method per imported function;
    /// - a `linera_witty::wit_generation::WitInterface` implementation for that struct;
    /// - the instance trait alias that bundles every constraint the `Instance` must satisfy.
    fn generate(self) -> TokenStream {
        let function_slots = self.function_slots();
        let slot_initializations = self.slot_initializations();
        let imported_functions = self.imported_functions();
        let (instance_trait_alias_name, instance_trait_alias) = self.instance_trait_alias();
        let trait_name = self.trait_name;

        let wit_interface_implementation = wit_interface::generate(
            self.parameters.package_name(),
            self.parameters.interface_name(trait_name),
            &self.functions,
        );

        quote! {
            #[allow(clippy::type_complexity)]
            pub struct #trait_name<Instance>
            where
                Instance: #instance_trait_alias_name,
                <Instance::Runtime as linera_witty::Runtime>::Memory:
                    linera_witty::RuntimeMemory<Instance>,
            {
                instance: Instance,
                #( #function_slots ),*
            }

            impl<Instance> #trait_name<Instance>
            where
                Instance: #instance_trait_alias_name,
                <Instance::Runtime as linera_witty::Runtime>::Memory:
                    linera_witty::RuntimeMemory<Instance>,
            {
                pub fn new(instance: Instance) -> Self {
                    #trait_name {
                        instance,
                        #( #slot_initializations ),*
                    }
                }

                #( #imported_functions )*
            }

            impl<Instance> linera_witty::wit_generation::WitInterface for #trait_name<Instance>
            where
                Instance: #instance_trait_alias_name,
                <Instance::Runtime as linera_witty::Runtime>::Memory:
                    linera_witty::RuntimeMemory<Instance>,
            {
                #wit_interface_implementation
            }

            #instance_trait_alias
        }
    }

    /// Generates one struct field per function, caching the loaded instance function.
    ///
    /// Each field is named after the function and has type
    /// `Option<<Instance as InstanceWithFunction<..>>::Function>`.
    fn function_slots(&self) -> impl Iterator<Item = TokenStream> + '_ {
        self.functions.iter().map(|function| {
            let function_name = function.name();
            let instance_constraint = &function.instance_constraint;

            quote_spanned! { function.span() =>
                #function_name: Option<<Instance as #instance_constraint>::Function>
            }
        })
    }

    /// Generates the `field: None` initializers for every function slot, used by `new`.
    fn slot_initializations(&self) -> impl Iterator<Item = TokenStream> + '_ {
        self.functions.iter().map(|function| {
            let function_name = function.name();

            quote_spanned! { function.span() =>
                #function_name: None
            }
        })
    }

    /// Generates one host-side method per imported function.
    ///
    /// On its first call, each generated method loads the guest function through
    /// `load_function`, using the name `"{namespace}#{kebab-case function name}"`, and caches
    /// it in the function's slot. Every call then lowers the arguments with the instance's
    /// memory, calls the guest function, and lifts the results back into the Rust return type.
    fn imported_functions(&self) -> impl Iterator<Item = TokenStream> + '_ {
        self.functions.iter().map(|function| {
            let namespace = &self.namespace;
            let function_name = function.name();
            let function_wit_name = function_name.to_string().to_kebab_case();
            let instance = &function.instance_constraint;
            let parameters = &function.parameter_definitions;
            let parameter_bindings = &function.parameter_bindings;
            let return_type = &function.return_type;
            let interface = &function.interface;

            quote_spanned! { function.span() =>
                pub fn #function_name(
                    &mut self,
                    #parameters
                ) -> Result<#return_type, linera_witty::RuntimeError> {
                    let function = match &self.#function_name {
                        Some(function) => function,
                        None => {
                            self.#function_name = Some(<Instance as #instance>::load_function(
                                &mut self.instance,
                                &format!("{}#{}", #namespace, #function_wit_name),
                            )?);

                            self.#function_name
                                .as_ref()
                                .expect("Function loaded into slot, but the slot remains empty")
                        }
                    };

                    let flat_parameters = #interface::lower_parameters(
                        linera_witty::hlist![#parameter_bindings],
                        &mut self.instance.memory()?,
                    )?;

                    let flat_results = self.instance.call(function, flat_parameters)?;

                    #[allow(clippy::let_unit_value)]
                    let result = #interface::lift_results(flat_results, &self.instance.memory()?)?;

                    Ok(result)
                }
            }
        })
    }

    /// Generates the `InstanceFor{TraitName}` trait alias.
    ///
    /// The alias trait has every constraint from [`Self::instance_constraints`] as a
    /// supertrait. A blanket implementation covers every type that satisfies those
    /// constraints. Returns the alias name together with its definition.
    fn instance_trait_alias(&self) -> (Ident, TokenStream) {
        let name = format_ident!("InstanceFor{}", self.trait_name);
        let constraints = self.instance_constraints();

        let definition = quote! {
            pub trait #name : #constraints
            where
                <<Self as linera_witty::Instance>::Runtime as linera_witty::Runtime>::Memory:
                    linera_witty::RuntimeMemory<Self>,
            {}

            impl<AnyInstance> #name for AnyInstance
            where
                AnyInstance: #constraints,
                <AnyInstance::Runtime as linera_witty::Runtime>::Memory:
                    linera_witty::RuntimeMemory<AnyInstance>,
            {}
        };

        (name, definition)
    }

    /// Builds the `+`-separated supertrait list for the instance trait alias.
    ///
    /// The list starts with `linera_witty::InstanceWithMemory`, followed by each distinct
    /// per-function `InstanceWithFunction<..>` bound. Duplicates (functions sharing a
    /// signature) are dropped. The remaining bounds keep the order in which the functions were
    /// declared, so the macro output is identical on every compilation. Iterating a `HashSet`
    /// directly would produce a different order on each run.
    fn instance_constraints(&self) -> TokenStream {
        let mut seen_constraints = HashSet::new();

        self.functions
            .iter()
            .map(|function| &function.instance_constraint)
            // `insert` returns `false` (and keeps the existing entry) for a duplicate, so only
            // the first occurrence of each bound is kept.
            .filter(|constraint| seen_constraints.insert(TokensSetItem::from(*constraint)))
            .fold(
                quote! { linera_witty::InstanceWithMemory },
                |list, item| quote! { #list + #item },
            )
    }
}
impl<'input> FunctionInformation<'input> {
    /// Extracts the code fragments needed to generate the import code for `function`.
    ///
    /// A missing return type is treated as `()`. The `interface` fragment is the qualified
    /// path `<(HList![parameter types], return type) as linera_witty::ImportedFunctionInterface>`.
    /// The `instance_constraint` fragment is the `linera_witty::InstanceWithFunction` bound over
    /// that interface's `GuestParameters` and `GuestResults`.
    ///
    /// Aborts macro expansion if the function declares a `self` parameter.
    pub fn new(function: &'input TraitItemFn) -> Self {
        let (parameter_definitions, parameter_bindings, parameter_types) =
            Self::parse_parameters(function.sig.inputs.iter());

        let return_type = match &function.sig.output {
            ReturnType::Default => quote_spanned! { function.sig.output.span() => () },
            ReturnType::Type(_, return_type) => return_type.to_token_stream(),
        };

        let interface = quote_spanned! { function.sig.span() =>
            <(linera_witty::HList![#parameter_types], #return_type)
                as linera_witty::ImportedFunctionInterface>
        };

        let instance_constraint = quote_spanned! { function.sig.span() =>
            linera_witty::InstanceWithFunction<
                #interface::GuestParameters,
                #interface::GuestResults,
            >
        };

        FunctionInformation {
            function,
            parameter_definitions,
            parameter_bindings,
            return_type,
            interface,
            instance_constraint,
        }
    }

    /// Splits a function's inputs into the three token lists used by the generated code.
    ///
    /// Returns, in order:
    /// 1. the parameter definitions (each typed parameter as written, comma-terminated), used
    ///    in the generated method's signature;
    /// 2. the parameter bindings (each parameter's pattern, comma-terminated), used to build
    ///    the `hlist!` of arguments;
    /// 3. the parameter types (each parameter's type, comma-terminated), used to build the
    ///    `HList!` parameter type.
    ///
    /// Aborts macro expansion on a `self` receiver, because imported interfaces are free
    /// functions.
    fn parse_parameters(
        function_inputs: impl Iterator<Item = &'input FnArg>,
    ) -> (TokenStream, TokenStream, TokenStream) {
        let parameters = function_inputs.map(|input| match input {
            FnArg::Typed(parameter) => parameter,
            FnArg::Receiver(receiver) => abort!(
                receiver.self_token,
                "Imported interfaces can not have `self` parameters"
            ),
        });

        let mut parameter_definitions = quote! {};
        let mut parameter_bindings = quote! {};
        let mut parameter_types = quote! {};

        for parameter in parameters {
            let parameter_binding = &parameter.pat;
            let parameter_type = &parameter.ty;

            parameter_definitions.extend(quote! { #parameter, });
            parameter_bindings.extend(quote! { #parameter_binding, });
            parameter_types.extend(quote! { #parameter_type, });
        }

        (parameter_definitions, parameter_bindings, parameter_types)
    }

    /// Returns the function's name as declared in the trait.
    pub fn name(&self) -> &Ident {
        &self.function.sig.ident
    }

    /// Returns the span of the whole function declaration, used for generated code.
    pub fn span(&self) -> Span {
        self.function.span()
    }
}
impl<'input> From<&'input TraitItem> for FunctionInformation<'input> {
    /// Converts a trait item into [`FunctionInformation`].
    ///
    /// Only `fn` items can be imported. Any other kind of trait item aborts macro expansion,
    /// with an error pointing at the offending item.
    fn from(item: &'input TraitItem) -> Self {
        if let TraitItem::Fn(function) = item {
            return FunctionInformation::new(function);
        }

        // Every remaining case is unsupported. The catch-all arm is required because
        // `TraitItem` can gain new variants.
        match item {
            TraitItem::Const(const_item) => abort!(
                const_item.ident,
                "Const items are not supported in imported traits"
            ),
            TraitItem::Type(type_item) => abort!(
                type_item.ident,
                "Type items are not supported in imported traits"
            ),
            TraitItem::Macro(macro_item) => abort!(
                macro_item.mac.path,
                "Macro items are not supported in imported traits"
            ),
            _ => abort!(item, "Only function items are supported in imported traits"),
        }
    }
}