use crate::component::concurrent::futures_and_streams::{self, TransmitOrigin};
use crate::component::concurrent::{TableId, TransmitHandle};
use crate::component::func::{LiftContext, LowerContext, bad_type_info, desc};
use crate::component::matching::InstanceType;
use crate::component::types::{self, FutureType, StreamType};
use crate::component::{
ComponentInstanceId, ComponentType, FutureReader, Lift, Lower, StreamReader,
};
use crate::store::StoreOpaque;
use crate::{AsContextMut, Result, bail, error::Context};
use std::any::TypeId;
use std::mem::MaybeUninit;
use wasmtime_environ::component::{
CanonicalAbiInfo, InterfaceType, TypeFutureTableIndex, TypeStreamTableIndex,
};
/// A handle to the readable end of a component-model `future` whose payload
/// type is not fixed at compile time.
///
/// The payload type is recorded at runtime and checked when this value is
/// converted into a typed `FutureReader<T>` (`try_into_future_reader`) or
/// lowered into a guest (`Lower` impl).
#[derive(Debug, Clone, PartialEq)]
pub struct FutureAny {
    // Handle into the concurrent state's transmit table for this future.
    id: TableId<TransmitHandle>,
    // Description of the payload type, either guest-described or a host Rust type.
    ty: PayloadType<FutureType>,
}
impl FutureAny {
fn lower_to_index<T>(&self, cx: &mut LowerContext<'_, T>, ty: InterfaceType) -> Result<u32> {
let future_ty = match ty {
InterfaceType::Future(payload) => payload,
_ => bad_type_info(),
};
let payload = cx.types[cx.types[future_ty].ty].payload.as_ref();
self.ty.typecheck_guest(
&cx.instance_type(),
payload,
FutureType::equivalent_payload_guest,
)?;
futures_and_streams::lower_future_to_index(self.id, cx, ty)
}
pub fn try_into_future_reader<T>(self) -> Result<FutureReader<T>>
where
T: ComponentType + 'static,
{
self.ty
.typecheck_host::<T>(FutureType::equivalent_payload_host::<T>)?;
Ok(FutureReader::new_(self.id))
}
pub fn try_from_future_reader<T>(
mut store: impl AsContextMut,
reader: FutureReader<T>,
) -> Result<Self>
where
T: ComponentType + 'static,
{
let store = store.as_context_mut();
let ty = match store.0.transmit_origin(reader.id())? {
TransmitOrigin::Host => PayloadType::new_host::<T>(),
TransmitOrigin::GuestFuture(id, ty) => PayloadType::new_guest_future(store.0, id, ty),
TransmitOrigin::GuestStream(..) => bail!("not a future"),
};
Ok(FutureAny {
id: reader.id(),
ty,
})
}
fn lift_from_index(cx: &mut LiftContext<'_>, ty: InterfaceType, index: u32) -> Result<Self> {
let id = futures_and_streams::lift_index_to_future(cx, ty, index)?;
let InterfaceType::Future(ty) = ty else {
unreachable!()
};
let ty = cx.types[ty].ty;
Ok(FutureAny {
id,
ty: PayloadType::Guest(FutureType::from(ty, &cx.instance_type())),
})
}
pub fn close(&mut self, mut store: impl AsContextMut) {
futures_and_streams::future_close(store.as_context_mut().0, &mut self.id)
}
}
unsafe impl ComponentType for FutureAny {
const ABI: CanonicalAbiInfo = CanonicalAbiInfo::SCALAR4;
type Lower = <u32 as ComponentType>::Lower;
fn typecheck(ty: &InterfaceType, _types: &InstanceType<'_>) -> Result<()> {
match ty {
InterfaceType::Future(_) => Ok(()),
other => bail!("expected `future`, found `{}`", desc(other)),
}
}
}
// SAFETY: both lowering paths produce a `u32` handle index and delegate to
// `u32`'s own lowering with `InterfaceType::U32`, consistent with the
// `SCALAR4` ABI declared for this type.
unsafe impl Lower for FutureAny {
    /// Lowers the future as its `u32` handle index into flat storage.
    fn linear_lower_to_flat<T>(
        &self,
        cx: &mut LowerContext<'_, T>,
        ty: InterfaceType,
        dst: &mut MaybeUninit<Self::Lower>,
    ) -> Result<()> {
        let index = self.lower_to_index(cx, ty)?;
        <u32 as Lower>::linear_lower_to_flat(&index, cx, InterfaceType::U32, dst)
    }

    /// Lowers the future as its `u32` handle index into linear memory at `offset`.
    fn linear_lower_to_memory<T>(
        &self,
        cx: &mut LowerContext<'_, T>,
        ty: InterfaceType,
        offset: usize,
    ) -> Result<()> {
        let index = self.lower_to_index(cx, ty)?;
        <u32 as Lower>::linear_lower_to_memory(&index, cx, InterfaceType::U32, offset)
    }
}
// SAFETY: both lifting paths first read a `u32` handle index via `u32`'s own
// lifting with `InterfaceType::U32`, matching the `SCALAR4` ABI of this type.
unsafe impl Lift for FutureAny {
    /// Reads a `u32` handle index from flat storage and lifts it as a future.
    fn linear_lift_from_flat(
        cx: &mut LiftContext<'_>,
        ty: InterfaceType,
        src: &Self::Lower,
    ) -> Result<Self> {
        <u32 as Lift>::linear_lift_from_flat(cx, InterfaceType::U32, src)
            .and_then(|index| Self::lift_from_index(cx, ty, index))
    }

    /// Reads a `u32` handle index from `bytes` and lifts it as a future.
    fn linear_lift_from_memory(
        cx: &mut LiftContext<'_>,
        ty: InterfaceType,
        bytes: &[u8],
    ) -> Result<Self> {
        <u32 as Lift>::linear_lift_from_memory(cx, InterfaceType::U32, bytes)
            .and_then(|index| Self::lift_from_index(cx, ty, index))
    }
}
/// A handle to the readable end of a component-model `stream` whose payload
/// type is not fixed at compile time.
///
/// The payload type is recorded at runtime and checked when this value is
/// converted into a typed `StreamReader<T>` (`try_into_stream_reader`) or
/// lowered into a guest (`Lower` impl).
#[derive(Debug, Clone, PartialEq)]
pub struct StreamAny {
    // Handle into the concurrent state's transmit table for this stream.
    id: TableId<TransmitHandle>,
    // Description of the payload type, either guest-described or a host Rust type.
    ty: PayloadType<StreamType>,
}
impl StreamAny {
    /// Checks that this stream's payload type matches the payload expected by
    /// the guest `stream` type `ty`, then converts the handle into the `u32`
    /// index that the `Lower` impl writes out.
    ///
    /// If `ty` is not `InterfaceType::Stream`, `bad_type_info` is invoked.
    fn lower_to_index<T>(&self, cx: &mut LowerContext<'_, T>, ty: InterfaceType) -> Result<u32> {
        let stream_ty = match ty {
            InterfaceType::Stream(payload) => payload,
            _ => bad_type_info(),
        };
        let payload = cx.types[cx.types[stream_ty].ty].payload.as_ref();
        self.ty.typecheck_guest(
            &cx.instance_type(),
            payload,
            StreamType::equivalent_payload_guest,
        )?;
        futures_and_streams::lower_stream_to_index(self.id, cx, ty)
    }

    /// Converts this type-erased stream into a typed `StreamReader<T>`.
    ///
    /// # Errors
    ///
    /// Returns an error if the recorded payload type is not compatible with `T`.
    pub fn try_into_stream_reader<T>(self) -> Result<StreamReader<T>>
    where
        T: ComponentType + 'static,
    {
        self.ty
            .typecheck_host::<T>(StreamType::equivalent_payload_host::<T>)?;
        Ok(StreamReader::new_(self.id))
    }

    /// Erases the payload type of `reader`, producing a `StreamAny`.
    ///
    /// Host-created streams remember the Rust type `T`; guest-created streams
    /// remember the payload type from the originating component instance.
    ///
    /// # Errors
    ///
    /// Returns an error if the handle's origin cannot be looked up in `store`,
    /// or if the handle actually refers to a future (`"not a stream"`).
    pub fn try_from_stream_reader<T>(
        mut store: impl AsContextMut,
        reader: StreamReader<T>,
    ) -> Result<Self>
    where
        T: ComponentType + 'static,
    {
        let store = store.as_context_mut();
        let ty = match store.0.transmit_origin(reader.id())? {
            TransmitOrigin::Host => PayloadType::new_host::<T>(),
            TransmitOrigin::GuestStream(id, ty) => PayloadType::new_guest_stream(store.0, id, ty),
            TransmitOrigin::GuestFuture(..) => bail!("not a stream"),
        };
        Ok(StreamAny {
            id: reader.id(),
            ty,
        })
    }

    /// Lifts the guest stream handle at `index` (of guest type `ty`), recording
    /// the guest's payload type for later checks.
    fn lift_from_index(cx: &mut LiftContext<'_>, ty: InterfaceType, index: u32) -> Result<Self> {
        let id = futures_and_streams::lift_index_to_stream(cx, ty, index)?;
        let InterfaceType::Stream(ty) = ty else {
            unreachable!()
        };
        let ty = cx.types[ty].ty;
        Ok(StreamAny {
            id,
            ty: PayloadType::Guest(StreamType::from(ty, &cx.instance_type())),
        })
    }

    /// Closes this stream handle in `store`.
    ///
    /// This uses the stream-specific close routine rather than the future one,
    /// mirroring the stream-specific lift/lower helpers used above.
    /// NOTE(review): `stream_close` receives `&mut self.id` and may invalidate
    /// it; confirm the exact post-close semantics there.
    pub fn close(&mut self, mut store: impl AsContextMut) {
        futures_and_streams::stream_close(store.as_context_mut().0, &mut self.id)
    }
}
unsafe impl ComponentType for StreamAny {
const ABI: CanonicalAbiInfo = CanonicalAbiInfo::SCALAR4;
type Lower = <u32 as ComponentType>::Lower;
fn typecheck(ty: &InterfaceType, _types: &InstanceType<'_>) -> Result<()> {
match ty {
InterfaceType::Stream(_) => Ok(()),
other => bail!("expected `stream`, found `{}`", desc(other)),
}
}
}
// SAFETY: both lowering paths produce a `u32` handle index and delegate to
// `u32`'s own lowering with `InterfaceType::U32`, consistent with the
// `SCALAR4` ABI declared for this type.
unsafe impl Lower for StreamAny {
    /// Lowers the stream as its `u32` handle index into flat storage.
    fn linear_lower_to_flat<T>(
        &self,
        cx: &mut LowerContext<'_, T>,
        ty: InterfaceType,
        dst: &mut MaybeUninit<Self::Lower>,
    ) -> Result<()> {
        let handle = self.lower_to_index(cx, ty)?;
        <u32 as Lower>::linear_lower_to_flat(&handle, cx, InterfaceType::U32, dst)
    }

    /// Lowers the stream as its `u32` handle index into linear memory at `offset`.
    fn linear_lower_to_memory<T>(
        &self,
        cx: &mut LowerContext<'_, T>,
        ty: InterfaceType,
        offset: usize,
    ) -> Result<()> {
        let handle = self.lower_to_index(cx, ty)?;
        <u32 as Lower>::linear_lower_to_memory(&handle, cx, InterfaceType::U32, offset)
    }
}
// SAFETY: both lifting paths first read a `u32` handle index via `u32`'s own
// lifting with `InterfaceType::U32`, matching the `SCALAR4` ABI of this type.
unsafe impl Lift for StreamAny {
    /// Reads a `u32` handle index from flat storage and lifts it as a stream.
    fn linear_lift_from_flat(
        cx: &mut LiftContext<'_>,
        ty: InterfaceType,
        src: &Self::Lower,
    ) -> Result<Self> {
        <u32 as Lift>::linear_lift_from_flat(cx, InterfaceType::U32, src)
            .and_then(|handle| Self::lift_from_index(cx, ty, handle))
    }

    /// Reads a `u32` handle index from `bytes` and lifts it as a stream.
    fn linear_lift_from_memory(
        cx: &mut LiftContext<'_>,
        ty: InterfaceType,
        bytes: &[u8],
    ) -> Result<Self> {
        <u32 as Lift>::linear_lift_from_memory(cx, InterfaceType::U32, bytes)
            .and_then(|handle| Self::lift_from_index(cx, ty, handle))
    }
}
/// Runtime description of the payload type carried by a `FutureAny` or
/// `StreamAny`.
///
/// In this file `T` is `FutureType` (for `FutureAny`) or `StreamType` (for
/// `StreamAny`).
#[derive(Debug, Clone)]
enum PayloadType<T> {
    /// The handle came from a guest (lifted from a guest, or created with a
    /// guest origin), so the payload is described by component type information.
    Guest(T),
    /// The handle was created by the host with a Rust payload type.
    Host {
        // `TypeId` of the Rust payload type; used for equality and host-side checks.
        id: TypeId,
        // Checks the Rust payload type against a guest payload type in an instance.
        typecheck: fn(Option<&InterfaceType>, &InstanceType<'_>) -> Result<()>,
    },
}
impl<T: PartialEq> PartialEq for PayloadType<T> {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(PayloadType::Guest(a), PayloadType::Guest(b)) => a == b,
(PayloadType::Guest(_), _) => false,
(PayloadType::Host { id: a_id, .. }, PayloadType::Host { id: b_id, .. }) => {
a_id == b_id
}
(PayloadType::Host { .. }, _) => false,
}
}
}
impl PayloadType<FutureType> {
    /// Builds a guest payload description from the future table type `ty` of
    /// component instance `id`, using that instance's type information.
    fn new_guest_future(
        store: &StoreOpaque,
        id: ComponentInstanceId,
        ty: TypeFutureTableIndex,
    ) -> Self {
        let instance = store.component_instance(id);
        let instance_ty = InstanceType::new(&instance);
        let future = instance_ty.types[ty].ty;
        Self::Guest(FutureType::from(future, &instance_ty))
    }
}
impl PayloadType<StreamType> {
    /// Builds a guest payload description from the stream table type `ty` of
    /// component instance `id`, using that instance's type information.
    fn new_guest_stream(
        store: &StoreOpaque,
        id: ComponentInstanceId,
        ty: TypeStreamTableIndex,
    ) -> Self {
        let instance = store.component_instance(id);
        let instance_ty = InstanceType::new(&instance);
        let stream = instance_ty.types[ty].ty;
        Self::Guest(StreamType::from(stream, &instance_ty))
    }
}
impl<T> PayloadType<T> {
fn new_host<P>() -> Self
where
P: ComponentType + 'static,
{
PayloadType::Host {
typecheck: types::typecheck_payload::<P>,
id: TypeId::of::<P>(),
}
}
fn typecheck_guest(
&self,
types: &InstanceType<'_>,
payload: Option<&InterfaceType>,
equivalent: fn(&T, &InstanceType<'_>, Option<&InterfaceType>) -> bool,
) -> Result<()> {
match self {
Self::Guest(ty) => {
if equivalent(ty, types, payload) {
Ok(())
} else {
bail!("future payload types differ")
}
}
Self::Host { typecheck, .. } => {
typecheck(payload, types).context("future payload types differ")
}
}
}
fn typecheck_host<P>(&self, equivalent: fn(&T) -> Result<()>) -> Result<()>
where
P: ComponentType + 'static,
{
match self {
Self::Guest(ty) => equivalent(ty),
Self::Host { id, .. } => {
if *id == TypeId::of::<P>() {
Ok(())
} else {
bail!("future payload types differ")
}
}
}
}
}