use crate::bindings;
use crate::bindings::core;
use crate::instance::InstanceState;
use crate::model::Command;
use crate::object::IdRepr;
use tokio::sync::oneshot;
use wasmtime::component::Resource;
use wasmtime_wasi::async_trait;
use wasmtime_wasi::p2::{DynPollable, IoView, Pollable, subscribe};
/// Host-side state behind a guest-visible distribution-result resource.
///
/// In this file it is created by `get_next_token_distribution`, which stores
/// one receiver per dispatched `Command::SampleTopK`, and it is filled in by
/// the `Pollable::ready` implementation.
#[derive(Debug)]
pub struct DistributionResult {
/// One pending reply channel per dispatched command, in dispatch order.
/// NOTE(review): each payload is presumably (token ids, probabilities) —
/// confirm against the `SampleTopK` handler.
receivers: Vec<oneshot::Receiver<(Vec<u32>, Vec<f32>)>>,
/// Payloads received so far, pushed in the same order as `receivers`.
results: Vec<(Vec<u32>, Vec<f32>)>,
/// Set to `true` by `ready` once every receiver has been awaited; `get`
/// returns `None` until then.
done: bool,
}
#[async_trait]
impl Pollable for DistributionResult {
    /// Resolves once every receiver has delivered its payload.
    ///
    /// Payloads are appended to `results` in the same order as `receivers`,
    /// and `done` is set when the last one has arrived. Calling `ready` again
    /// after that returns immediately.
    ///
    /// The future returned by this method may be dropped before it completes
    /// and `ready` invoked again later (for example when another pollable in
    /// the same poll set becomes ready first). A `oneshot::Receiver` that has
    /// already yielded its value must not be awaited a second time, so each
    /// call resumes after the receivers whose payloads are already stored
    /// instead of starting over from the first one.
    ///
    /// # Panics
    ///
    /// Panics if a sender is dropped without sending a value; this method has
    /// no way to report an error to the caller.
    async fn ready(&mut self) {
        if self.done {
            return;
        }

        // `results.len()` is the number of receivers already consumed: a
        // payload is pushed right after its receiver resolves, with no await
        // point in between, so the two can never get out of step.
        let already_received = self.results.len();
        for rx in self.receivers.iter_mut().skip(already_received) {
            let result = rx
                .await
                .expect("distribution sender dropped before sending a result");
            self.results.push(result);
        }

        self.done = true;
    }
}
impl bindings::pie::inferlet::output_text::Host for InstanceState {
    /// Dispatches one `Command::SampleTopK` (with `k` fixed to 1) per entry of
    /// `emb_ids` and returns a resource that collects the replies.
    ///
    /// Each command carries this instance's id, the queue's `stream_id`, the
    /// embedding id, and the sending half of a fresh oneshot channel; it is
    /// dispatched to the queue's `service_id`. The receiving halves are stored,
    /// in `emb_ids` order, in a new `DistributionResult` pushed into the
    /// resource table.
    ///
    /// # Errors
    ///
    /// Fails if `queue` cannot be looked up in the resource table, if any
    /// dispatch fails (later ids are then not dispatched), or if the result
    /// cannot be pushed into the table.
    async fn get_next_token_distribution(
        &mut self,
        queue: Resource<core::Queue>,
        emb_ids: Vec<IdRepr>,
    ) -> anyhow::Result<Resource<DistributionResult>> {
        let inst_id = self.id();

        // Copy the routing ids out so the table borrow ends before the loop.
        let (stream_id, service_id) = {
            let target = self.table().get(&queue)?;
            (target.stream_id, target.service_id)
        };

        // Stops at the first dispatch failure, like an early `?` in a loop.
        let receivers = emb_ids
            .into_iter()
            .map(|emb_id| -> anyhow::Result<_> {
                let (handle, pending) = oneshot::channel();
                Command::SampleTopK {
                    inst_id,
                    stream_id,
                    emb_id,
                    k: 1,
                    handle,
                }
                .dispatch(service_id)?;
                Ok(pending)
            })
            .collect::<anyhow::Result<Vec<_>>>()?;

        let resource = self.table().push(DistributionResult {
            receivers,
            results: Vec::new(),
            done: false,
        })?;
        Ok(resource)
    }
}
impl bindings::pie::inferlet::output_text::HostDistributionResult for InstanceState {
    /// Creates a pollable resource for `this` by handing it to `subscribe`
    /// together with the instance's resource table.
    async fn pollable(
        &mut self,
        this: Resource<DistributionResult>,
    ) -> anyhow::Result<Resource<DynPollable>> {
        let table = self.table();
        subscribe(table, this)
    }

    /// Returns a copy of the collected results once the result is marked
    /// done, and `None` before that.
    ///
    /// # Errors
    ///
    /// Fails if `this` cannot be looked up in the resource table.
    async fn get(
        &mut self,
        this: Resource<DistributionResult>,
    ) -> anyhow::Result<Option<Vec<(Vec<u32>, Vec<f32>)>>> {
        let entry = self.table().get_mut(&this)?;
        // The clone only happens once the results are complete.
        let snapshot = entry.done.then(|| entry.results.clone());
        Ok(snapshot)
    }

    /// Removes `this` from the resource table.
    ///
    /// # Errors
    ///
    /// Fails if the table cannot delete the resource.
    async fn drop(&mut self, this: Resource<DistributionResult>) -> anyhow::Result<()> {
        let table = self.table();
        table.delete(this)?;
        Ok(())
    }
}