use super::EncodingState;
use anyhow::Result;
use std::collections::HashMap;
use wasm_encoder::*;
use wit_parser::{
Enum, Flags, Function, Handle, InterfaceId, Param, Record, Resolve, Result_, Tuple, Type,
TypeDefKind, TypeId, TypeOwner, Variant,
};
#[derive(Clone)]
struct ParamSignatures<'a>(&'a [Param]);
impl PartialEq for ParamSignatures<'_> {
fn eq(&self, other: &Self) -> bool {
self.0.len() == other.0.len()
&& self
.0
.iter()
.zip(other.0)
.all(|(a, b)| a.name == b.name && a.ty == b.ty)
}
}
impl Eq for ParamSignatures<'_> {}
impl std::hash::Hash for ParamSignatures<'_> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
for p in self.0 {
p.name.hash(state);
p.ty.hash(state);
}
}
}
/// Deduplication key for encoded function types.
///
/// Functions that agree on async-ness, parameter names/types and result
/// type share a single encoded function type.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct FunctionKey<'a> {
    async_: bool,
    params: ParamSignatures<'a>,
    result: &'a Option<Type>,
}

/// Caches of already-encoded types, so repeated references reuse an
/// existing index instead of defining the type again.
#[derive(Clone, Default)]
pub struct TypeEncodingMaps<'a> {
    /// Index recorded for each WIT type id that has been encoded.
    pub id_to_index: HashMap<TypeId, u32>,
    /// Index recorded for each anonymous (unnamed) type definition, keyed
    /// structurally by its definition.
    pub def_to_index: HashMap<&'a TypeDefKind, u32>,
    /// Index recorded for each distinct function signature.
    pub func_type_map: HashMap<FunctionKey<'a>, u32>,
    /// Index of the payload-less `future` type, once it has been defined.
    pub unit_future: Option<u32>,
    /// Index of the payload-less `stream` type, once it has been defined.
    pub unit_stream: Option<u32>,
}

impl<'a> TypeEncodingMaps<'a> {
    /// Finds a previously recorded index for `id`.
    ///
    /// The id itself is checked first. If it is unknown and the type is
    /// anonymous, a structurally identical definition is looked up instead.
    fn lookup(&self, resolve: &'a Resolve, id: TypeId) -> Option<u32> {
        self.id_to_index.get(&id).copied().or_else(|| {
            let def = &resolve.types[id];
            // Named types are only ever identified by their id.
            match def.name {
                Some(_) => None,
                None => self.def_to_index.get(&def.kind).copied(),
            }
        })
    }

    /// Records `index` for `id`. Anonymous types are also recorded by their
    /// definition so structurally equal types can share the index.
    fn insert(&mut self, resolve: &'a Resolve, id: TypeId, index: u32) {
        self.id_to_index.insert(id, index);
        let def = &resolve.types[id];
        if def.name.is_none() {
            self.def_to_index.insert(&def.kind, index);
        }
    }
}
/// Shared logic for encoding WIT types into a component-model type index
/// space.
///
/// Implementors supply the primitive operations: allocating defined and
/// function types, attaching names, referring to types owned by other
/// interfaces, and access to the deduplication caches. The provided methods
/// build on those to encode whole WIT types and function signatures.
pub trait ValtypeEncoder<'a> {
/// Allocates a new defined type, returning its index and the encoder that
/// must be used to write its definition.
fn defined_type(&mut self) -> (u32, ComponentDefinedTypeEncoder<'_>);
/// Allocates a new function type, returning its index and the encoder that
/// must be used to write its signature.
fn define_function_type(&mut self) -> (u32, ComponentFuncTypeEncoder<'_>);
/// Attaches `name` to the type at `index`.
///
/// Returns the index to use for the named type from now on, or `None` to
/// keep using `index` unchanged.
fn export_type(&mut self, index: u32, name: &'a str) -> Option<u32>;
/// Introduces the resource type called `name` and returns its type index.
fn export_resource(&mut self, name: &'a str) -> u32;
/// Returns the caches consulted to avoid re-encoding types and signatures.
fn type_encoding_maps(&mut self) -> &mut TypeEncodingMaps<'a>;
/// Returns an index in this index space referring to `id`, which is owned
/// by `interface`. The provided methods only call this for interfaces
/// other than [`interface`](Self::interface).
fn import_type(&mut self, interface: InterfaceId, id: TypeId) -> u32;
/// The interface whose types are currently being encoded, if any.
///
/// Types owned by this interface are encoded locally rather than imported.
fn interface(&self) -> Option<InterfaceId>;
/// Encodes the function type of `func` and returns its type index.
///
/// Signatures are deduplicated through `func_type_map`: functions with the
/// same async-ness, parameter names/types and result share one type.
///
/// # Errors
///
/// Propagates any error from encoding a parameter or result type.
fn encode_func_type(&mut self, resolve: &'a Resolve, func: &'a Function) -> Result<u32> {
let key = FunctionKey {
async_: func.kind.is_async(),
params: ParamSignatures(&func.params),
result: &func.result,
};
if let Some(index) = self.type_encoding_maps().func_type_map.get(&key) {
return Ok(*index);
}
// Component types are encoded before the function type is allocated.
// The encoder returned by `define_function_type` borrows `self`, so no
// nested encoding can happen while it is live. This order also places
// the parameter/result types ahead of the function type that uses them.
let params: Vec<_> = self.encode_params(resolve, &func.params)?;
let result = func
.result
.map(|ty| self.encode_valtype(resolve, &ty))
.transpose()?;
let (index, mut f) = self.define_function_type();
f.async_(func.kind.is_async()).params(params).result(result);
// Nothing reached from the nested encoding above inserts into
// `func_type_map`, so the key cannot have appeared since the lookup.
let prev = self.type_encoding_maps().func_type_map.insert(key, index);
assert!(prev.is_none());
Ok(index)
}
/// Encodes each parameter's type, paired with the parameter's name.
///
/// # Errors
///
/// Stops at, and returns, the first error from `encode_valtype`.
fn encode_params(
&mut self,
resolve: &'a Resolve,
params: &'a [Param],
) -> Result<Vec<(&'a str, ComponentValType)>> {
params
.iter()
.map(|p| Ok((p.name.as_str(), self.encode_valtype(resolve, &p.ty)?)))
.collect::<Result<_>>()
}
/// Encodes `ty`. The result is either a primitive value type or the index
/// of an imported, cached, or newly defined type.
///
/// For `Type::Id` the steps are:
/// 1. Check the encoding cache.
/// 2. Otherwise, call `maybe_import_type` for types owned by another
///    interface.
/// 3. Otherwise, build a fresh structural encoding.
///
/// Named types then go through `export_type`. Any indexed result is cached
/// so later references reuse it.
///
/// # Panics
///
/// Panics on `TypeDefKind::Unknown`, on an unnamed resource, and when a
/// handle's resource does not encode to an indexed type.
fn encode_valtype(&mut self, resolve: &'a Resolve, ty: &Type) -> Result<ComponentValType> {
Ok(match *ty {
Type::Bool => ComponentValType::Primitive(PrimitiveValType::Bool),
Type::U8 => ComponentValType::Primitive(PrimitiveValType::U8),
Type::U16 => ComponentValType::Primitive(PrimitiveValType::U16),
Type::U32 => ComponentValType::Primitive(PrimitiveValType::U32),
Type::U64 => ComponentValType::Primitive(PrimitiveValType::U64),
Type::S8 => ComponentValType::Primitive(PrimitiveValType::S8),
Type::S16 => ComponentValType::Primitive(PrimitiveValType::S16),
Type::S32 => ComponentValType::Primitive(PrimitiveValType::S32),
Type::S64 => ComponentValType::Primitive(PrimitiveValType::S64),
Type::F32 => ComponentValType::Primitive(PrimitiveValType::F32),
Type::F64 => ComponentValType::Primitive(PrimitiveValType::F64),
Type::Char => ComponentValType::Primitive(PrimitiveValType::Char),
Type::String => ComponentValType::Primitive(PrimitiveValType::String),
Type::ErrorContext => ComponentValType::Primitive(PrimitiveValType::ErrorContext),
Type::Id(id) => {
if let Some(index) = self.type_encoding_maps().lookup(resolve, id) {
return Ok(ComponentValType::Type(index));
}
let ty = &resolve.types[id];
log::trace!("encode type name={:?} {:?}", ty.name, &ty.kind);
// Types owned by a different interface are referenced through
// `import_type` instead of being defined again here.
if let Some(index) = self.maybe_import_type(resolve, id) {
self.type_encoding_maps().insert(resolve, id, index);
return Ok(ComponentValType::Type(index));
}
// Each arm encodes nested types first and only then allocates the
// enclosing type, because the encoder returned by `defined_type`
// borrows `self`.
let mut encoded = match &ty.kind {
TypeDefKind::Record(r) => self.encode_record(resolve, r)?,
TypeDefKind::Tuple(t) => self.encode_tuple(resolve, t)?,
TypeDefKind::Flags(r) => self.encode_flags(r)?,
TypeDefKind::Variant(v) => self.encode_variant(resolve, v)?,
TypeDefKind::Option(t) => self.encode_option(resolve, t)?,
TypeDefKind::Result(r) => self.encode_result(resolve, r)?,
TypeDefKind::Enum(e) => self.encode_enum(e)?,
TypeDefKind::List(ty) => {
let ty = self.encode_valtype(resolve, ty)?;
let (index, encoder) = self.defined_type();
encoder.list(ty);
ComponentValType::Type(index)
}
TypeDefKind::Map(key_ty, value_ty) => {
let key = self.encode_valtype(resolve, key_ty)?;
let value = self.encode_valtype(resolve, value_ty)?;
let (index, encoder) = self.defined_type();
encoder.map(key, value);
ComponentValType::Type(index)
}
TypeDefKind::FixedLengthList(ty, elements) => {
let ty = self.encode_valtype(resolve, ty)?;
let (index, encoder) = self.defined_type();
encoder.fixed_length_list(ty, *elements);
ComponentValType::Type(index)
}
// An alias encodes to whatever it aliases. A name, if any, is
// attached below.
TypeDefKind::Type(ty) => self.encode_valtype(resolve, ty)?,
TypeDefKind::Future(ty) => self.encode_future(resolve, ty)?,
TypeDefKind::Stream(ty) => self.encode_stream(resolve, ty)?,
TypeDefKind::Unknown => unreachable!(),
// Resources must be named. They are introduced through
// `export_resource`, cached by id only, and returned directly,
// skipping the naming step below.
TypeDefKind::Resource => {
let name = ty.name.as_ref().expect("resources must be named");
let index = self.export_resource(name);
self.type_encoding_maps().id_to_index.insert(id, index);
return Ok(ComponentValType::Type(index));
}
// Handles refer to their resource by index, so the resource must
// encode to an indexed type.
TypeDefKind::Handle(Handle::Own(id)) => {
let ty = match self.encode_valtype(resolve, &Type::Id(*id))? {
ComponentValType::Type(index) => index,
_ => panic!("must be an indexed type"),
};
let (index, encoder) = self.defined_type();
encoder.own(ty);
ComponentValType::Type(index)
}
TypeDefKind::Handle(Handle::Borrow(id)) => {
let ty = match self.encode_valtype(resolve, &Type::Id(*id))? {
ComponentValType::Type(index) => index,
_ => panic!("must be an indexed type"),
};
let (index, encoder) = self.defined_type();
encoder.borrow(ty);
ComponentValType::Type(index)
}
};
// A name must be attached to an index. A named alias of a primitive
// is therefore first wrapped in a defined primitive type.
if let Some(name) = &ty.name {
let index = match encoded {
ComponentValType::Type(index) => index,
ComponentValType::Primitive(ty) => {
let (index, encoder) = self.defined_type();
encoder.primitive(ty);
index
}
};
// `export_type` may hand back a new index for the named type;
// otherwise the original index is kept.
let index = self.export_type(index, name).unwrap_or(index);
encoded = ComponentValType::Type(index);
}
// Only indexed results are cached; primitive results need no entry.
if let ComponentValType::Type(index) = encoded {
self.type_encoding_maps().insert(resolve, id, index);
}
encoded
}
})
}
/// Returns an index for `id` from `import_type` when `id` is owned by an
/// interface other than the one being encoded. Returns `None` when the type
/// belongs to the current interface, or is not owned by an interface.
fn maybe_import_type(&mut self, resolve: &Resolve, id: TypeId) -> Option<u32> {
let ty = &resolve.types[id];
let owner = match ty.owner {
TypeOwner::Interface(i) => i,
_ => return None,
};
if Some(owner) == self.interface() {
return None;
}
Some(self.import_type(owner, id))
}
/// Encodes `ty` when present; `None` passes through as `Ok(None)`.
fn encode_optional_valtype(
&mut self,
resolve: &'a Resolve,
ty: Option<&Type>,
) -> Result<Option<ComponentValType>> {
match ty {
Some(ty) => self.encode_valtype(resolve, ty).map(Some),
None => Ok(None),
}
}
/// Encodes a record. Field types are encoded first, then a new defined
/// record type is allocated.
fn encode_record(&mut self, resolve: &'a Resolve, record: &Record) -> Result<ComponentValType> {
let fields = record
.fields
.iter()
.map(|f| Ok((f.name.as_str(), self.encode_valtype(resolve, &f.ty)?)))
.collect::<Result<Vec<_>>>()?;
let (index, encoder) = self.defined_type();
encoder.record(fields);
Ok(ComponentValType::Type(index))
}
/// Encodes a tuple. Element types are encoded first, then a new defined
/// tuple type is allocated.
fn encode_tuple(&mut self, resolve: &'a Resolve, tuple: &Tuple) -> Result<ComponentValType> {
let tys = tuple
.types
.iter()
.map(|ty| self.encode_valtype(resolve, ty))
.collect::<Result<Vec<_>>>()?;
let (index, encoder) = self.defined_type();
encoder.tuple(tys);
Ok(ComponentValType::Type(index))
}
/// Encodes a flags type from its flag names.
fn encode_flags(&mut self, flags: &Flags) -> Result<ComponentValType> {
let (index, encoder) = self.defined_type();
encoder.flags(flags.flags.iter().map(|f| f.name.as_str()));
Ok(ComponentValType::Type(index))
}
/// Encodes a variant. Each case's optional payload type is encoded first,
/// then a new defined variant type is allocated.
fn encode_variant(
&mut self,
resolve: &'a Resolve,
variant: &Variant,
) -> Result<ComponentValType> {
let cases = variant
.cases
.iter()
.map(|c| {
Ok((
c.name.as_str(),
self.encode_optional_valtype(resolve, c.ty.as_ref())?,
))
})
.collect::<Result<Vec<_>>>()?;
let (index, encoder) = self.defined_type();
encoder.variant(cases);
Ok(ComponentValType::Type(index))
}
/// Encodes `option<payload>` as a new defined type.
fn encode_option(&mut self, resolve: &'a Resolve, payload: &Type) -> Result<ComponentValType> {
let ty = self.encode_valtype(resolve, payload)?;
let (index, encoder) = self.defined_type();
encoder.option(ty);
Ok(ComponentValType::Type(index))
}
/// Encodes `result<ok, err>`. Either side may be absent.
fn encode_result(
&mut self,
resolve: &'a Resolve,
result: &Result_,
) -> Result<ComponentValType> {
let ok = self.encode_optional_valtype(resolve, result.ok.as_ref())?;
let error = self.encode_optional_valtype(resolve, result.err.as_ref())?;
let (index, encoder) = self.defined_type();
encoder.result(ok, error);
Ok(ComponentValType::Type(index))
}
/// Encodes an enum type from its case names.
fn encode_enum(&mut self, enum_: &Enum) -> Result<ComponentValType> {
let (index, encoder) = self.defined_type();
encoder.enum_type(enum_.cases.iter().map(|c| c.name.as_str()));
Ok(ComponentValType::Type(index))
}
/// Encodes a `future` with an optional payload as a new defined type.
/// Unlike `encode_unit_future`, this does not consult a cache itself.
fn encode_future(
&mut self,
resolve: &'a Resolve,
payload: &Option<Type>,
) -> Result<ComponentValType> {
let ty = self.encode_optional_valtype(resolve, payload.as_ref())?;
let (index, encoder) = self.defined_type();
encoder.future(ty);
Ok(ComponentValType::Type(index))
}
/// Encodes a `stream` with an optional payload as a new defined type.
/// Unlike `encode_unit_stream`, this does not consult a cache itself.
fn encode_stream(
&mut self,
resolve: &'a Resolve,
payload: &Option<Type>,
) -> Result<ComponentValType> {
let ty = self.encode_optional_valtype(resolve, payload.as_ref())?;
let (index, encoder) = self.defined_type();
encoder.stream(ty);
Ok(ComponentValType::Type(index))
}
/// Returns the index of the payload-less `future` type. It is defined on
/// first use, and the cached index is reused afterwards.
fn encode_unit_future(&mut self) -> u32 {
if let Some(index) = self.type_encoding_maps().unit_future {
return index;
}
let (index, encoder) = self.defined_type();
encoder.future(None);
self.type_encoding_maps().unit_future = Some(index);
index
}
/// Returns the index of the payload-less `stream` type. It is defined on
/// first use, and the cached index is reused afterwards.
fn encode_unit_stream(&mut self) -> u32 {
if let Some(index) = self.type_encoding_maps().unit_stream {
return index;
}
let (index, encoder) = self.defined_type();
encoder.stream(None);
self.type_encoding_maps().unit_stream = Some(index);
index
}
}
/// Encodes WIT types directly into the component's own type index space.
pub struct RootTypeEncoder<'state, 'a> {
    /// Component-wide encoding state that types are emitted into.
    pub state: &'state mut EncodingState<'a>,
    /// Interface whose types are being encoded, or `None` at the top level.
    pub interface: Option<InterfaceId>,
    /// `true` while importing types and `false` while exporting them. It
    /// selects the encoding cache and how names are attached.
    pub import_types: bool,
}

impl<'a> ValtypeEncoder<'a> for RootTypeEncoder<'_, 'a> {
    fn defined_type(&mut self) -> (u32, ComponentDefinedTypeEncoder<'_>) {
        self.state.component.type_defined(None)
    }

    fn define_function_type(&mut self) -> (u32, ComponentFuncTypeEncoder<'_>) {
        self.state.component.type_function(None)
    }

    /// At the top level (no interface), a named type becomes a component
    /// import when importing and a component export otherwise. The index
    /// produced by that import or export is returned. Inside an interface
    /// nothing is emitted and `None` is returned.
    ///
    /// # Panics
    ///
    /// Panics if called inside an interface while `import_types` is set.
    fn export_type(&mut self, idx: u32, name: &'a str) -> Option<u32> {
        if self.interface.is_some() {
            assert!(!self.import_types);
            return None;
        }
        let component = &mut self.state.component;
        let index = if self.import_types {
            component.import(name, ComponentTypeRef::Type(TypeBounds::Eq(idx)))
        } else {
            component.export(name, ComponentExportKind::Type, idx, None)
        };
        Some(index)
    }

    /// Imports `name` as a resource type and returns the resulting index.
    ///
    /// # Panics
    ///
    /// Panics unless this encoder is importing top-level types, i.e. no
    /// interface is set and `import_types` is `true`.
    fn export_resource(&mut self, name: &'a str) -> u32 {
        assert!(self.interface.is_none());
        assert!(self.import_types);
        let bounds = TypeBounds::SubResource;
        self.state
            .component
            .import(name, ComponentTypeRef::Type(bounds))
    }

    /// Returns the import cache or the export cache, according to
    /// `import_types`.
    fn type_encoding_maps(&mut self) -> &mut TypeEncodingMaps<'a> {
        let state = &mut *self.state;
        if self.import_types {
            &mut state.import_type_encoding_maps
        } else {
            &mut state.export_type_encoding_maps
        }
    }

    /// Some types are aliased from an export of `interface`. This happens
    /// when exporting an interface whose `exports_used` set contains
    /// `interface`. Every other type is aliased from the imports.
    fn import_type(&mut self, interface: InterfaceId, id: TypeId) -> u32 {
        let from_export = match self.interface {
            Some(cur) if !self.import_types => {
                self.state.info.exports_used[&cur].contains(&interface)
            }
            _ => false,
        };
        if from_export {
            self.state.alias_exported_type(interface, id)
        } else {
            self.state.alias_imported_type(interface, id)
        }
    }

    fn interface(&self) -> Option<InterfaceId> {
        self.interface
    }
}
pub struct InstanceTypeEncoder<'state, 'a> {
pub state: &'state mut EncodingState<'a>,
pub interface: InterfaceId,
pub type_encoding_maps: TypeEncodingMaps<'a>,
pub ty: InstanceType,
}
impl<'a> ValtypeEncoder<'a> for InstanceTypeEncoder<'_, 'a> {
fn defined_type(&mut self) -> (u32, ComponentDefinedTypeEncoder<'_>) {
(self.ty.type_count(), self.ty.ty().defined_type())
}
fn define_function_type(&mut self) -> (u32, ComponentFuncTypeEncoder<'_>) {
(self.ty.type_count(), self.ty.ty().function())
}
fn export_type(&mut self, idx: u32, name: &str) -> Option<u32> {
let ret = self.ty.type_count();
self.ty
.export(name, ComponentTypeRef::Type(TypeBounds::Eq(idx)));
Some(ret)
}
fn export_resource(&mut self, name: &str) -> u32 {
let ret = self.ty.type_count();
self.ty
.export(name, ComponentTypeRef::Type(TypeBounds::SubResource));
ret
}
fn type_encoding_maps(&mut self) -> &mut TypeEncodingMaps<'a> {
&mut self.type_encoding_maps
}
fn interface(&self) -> Option<InterfaceId> {
Some(self.interface)
}
fn import_type(&mut self, interface: InterfaceId, id: TypeId) -> u32 {
self.ty.alias(Alias::Outer {
count: 1,
index: self.state.alias_imported_type(interface, id),
kind: ComponentOuterAliasKind::Type,
});
self.ty.type_count() - 1
}
}