use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use wasmtime::Store;
use wasmtime::component::Component;
use super::async_bindings::Plugin as AsyncPlugin;
use super::async_bindings::PluginPre as AsyncPluginPre;
use super::async_bindings::orts::plugin::types as wit;
use super::async_host_state::{AsyncHostState, GuestResponse};
use super::async_runtime::AsyncRuntime;
use super::convert::r#async as convert;
use super::engine::WasmEngine;
use crate::plugin::controller::PluginController;
use crate::plugin::tick_input::TickInput;
use crate::plugin::{Command, PluginError};
/// A plugin component that has already been linked (WASI p2 + the plugin host
/// imports) and pre-instantiated for async execution.
///
/// Building this once lets each `AsyncWasmController` instantiate the component
/// from the cached `pre` instead of re-running linking.
pub struct AsyncPluginPreBuilt {
    /// Engine used to create the linker and each guest `Store`.
    engine: Arc<WasmEngine>,
    /// Runtime on which guest tasks are spawned and blocked on.
    runtime: Arc<AsyncRuntime>,
    /// Pre-instantiated bindings; cloned for every controller.
    pre: AsyncPluginPre<AsyncHostState>,
    /// The component that `pre` was created from.
    component: Component,
}
impl AsyncPluginPreBuilt {
    /// Links WASI p2 and the plugin host imports, then pre-instantiates `component`.
    ///
    /// # Errors
    /// Returns [`PluginError::Init`] if registering either set of imports fails,
    /// if pre-instantiation fails, or if the typed `AsyncPluginPre` wrapper
    /// cannot be built from the pre-instance.
    pub fn new(
        engine: &Arc<WasmEngine>,
        runtime: &Arc<AsyncRuntime>,
        component: &Component,
    ) -> Result<Self, PluginError> {
        let linker = Self::build_linker(engine)?;
        let pre = linker
            .instantiate_pre(component)
            .map_err(|e| PluginError::Init(format!("async instantiate_pre failed: {e}")))
            .and_then(|instance_pre| {
                AsyncPluginPre::new(instance_pre)
                    .map_err(|e| PluginError::Init(format!("AsyncPluginPre::new failed: {e}")))
            })?;

        Ok(Self {
            engine: Arc::clone(engine),
            runtime: Arc::clone(runtime),
            pre,
            component: component.clone(),
        })
    }

    /// Creates a linker for `engine` with async WASI p2 and the plugin's own
    /// imports registered; the host state doubles as the import implementation.
    fn build_linker(
        engine: &WasmEngine,
    ) -> Result<wasmtime::component::Linker<AsyncHostState>, PluginError> {
        let mut linker = wasmtime::component::Linker::new(engine.inner());
        wasmtime_wasi::p2::add_to_linker_async(&mut linker)
            .map_err(|e| PluginError::Init(format!("WASI add_to_linker_async failed: {e}")))?;
        AsyncPlugin::add_to_linker::<AsyncHostState, AsyncHostState>(&mut linker, |state| state)
            .map_err(|e| PluginError::Init(format!("async add_to_linker failed: {e}")))?;
        Ok(linker)
    }

    /// Engine shared with every controller built from this value.
    pub fn engine(&self) -> &Arc<WasmEngine> {
        &self.engine
    }

    /// Runtime that guest tasks are spawned on.
    pub fn runtime(&self) -> &Arc<AsyncRuntime> {
        &self.runtime
    }

    /// Cached pre-instance used to create each guest instance.
    pub(super) fn pre(&self) -> &AsyncPluginPre<AsyncHostState> {
        &self.pre
    }

    /// The component this pre-instance was created from.
    pub(super) fn component(&self) -> &Component {
        &self.component
    }
}
/// [`PluginController`] backed by a WASM guest that runs as a task on an
/// [`AsyncRuntime`], exchanging one observation and one response per `update`
/// over single-slot channels.
///
/// Note: Rust drops fields in declaration order, so keep the field order stable.
pub struct AsyncWasmController {
    /// Runtime used to block on channel operations from synchronous callers.
    runtime: Arc<AsyncRuntime>,
    /// Carries `Some(observation)` each tick; `None` is sent on drop, presumably
    /// as a shutdown signal interpreted by `AsyncHostState` (not visible here).
    input_tx: mpsc::Sender<Option<wit::TickInput>>,
    /// Receives the guest's command for the current tick, or `Done` once the
    /// guest's `run` export has returned.
    output_rx: mpsc::Receiver<GuestResponse>,
    /// Sample period (seconds) reported by the guest's metadata; validated to be
    /// finite and strictly positive during construction.
    sample_period_s: f64,
    /// Display name of the form `wasm-async:<label>`.
    name: String,
}
impl AsyncWasmController {
    /// Spawns the guest on the async runtime and blocks until it reports metadata.
    ///
    /// The spawned task runs these steps in order:
    /// 1. Instantiates the pre-linked component.
    /// 2. Calls the guest's `metadata` export with `config`.
    /// 3. Validates the reported sample period.
    /// 4. Sends the period back over a oneshot channel.
    /// 5. Calls the guest's `run` export, and reports its outcome as
    ///    `GuestResponse::Done` once it returns.
    ///
    /// # Errors
    /// Returns [`PluginError::Init`] if instantiation or the metadata call fails,
    /// if the guest reports a non-finite or non-positive sample period, or if the
    /// task ends before reporting metadata.
    pub fn new(
        built: &AsyncPluginPreBuilt,
        label: impl Into<String>,
        config: &str,
    ) -> Result<Self, PluginError> {
        let label: String = label.into();
        let guest_config = config.to_string();

        // Single-slot channels: one observation in, one response out.
        let (input_tx, input_rx) = mpsc::channel::<Option<wit::TickInput>>(1);
        let (output_tx, output_rx) = mpsc::channel::<GuestResponse>(1);
        let (meta_tx, meta_rx) = oneshot::channel::<Result<f64, String>>();

        let engine = Arc::clone(built.engine());
        let runtime = Arc::clone(built.runtime());
        let component = built.component().clone();
        let pre = built.pre().clone();
        let task_label = label.clone();

        runtime.handle().spawn(async move {
            let mut store = Store::new(
                engine.inner(),
                AsyncHostState {
                    label: task_label,
                    field: tobari::magnetic::TiltedDipole::earth(),
                    wasi: wasmtime_wasi::WasiCtxBuilder::new().build(),
                    table: wasmtime_wasi::ResourceTable::new(),
                    input_rx,
                    output_tx: output_tx.clone(),
                    pending_cmd: None,
                    is_first_wait: true,
                },
            );
            // Only referenced so the task owns a handle to the component; the
            // instance itself is created from `pre`.
            let _ = &component;

            // Instantiate, fetch and validate metadata. Any failure is reported
            // through `meta_tx` and ends the task before `run` is called.
            let handshake = async {
                let plugin = pre
                    .instantiate_async(&mut store)
                    .await
                    .map_err(|e| format!("instantiate_async: {e}"))?;
                let metadata = plugin
                    .call_metadata(&mut store, &guest_config)
                    .await
                    .map_err(|trap| format!("metadata call: {trap}"))?
                    .map_err(|guest_err| format!("metadata: {guest_err}"))?;
                let period = metadata.sample_period_s;
                if period.is_finite() && period > 0.0 {
                    Ok::<_, String>((plugin, period))
                } else {
                    Err(format!("guest returned invalid sample_period: {period}"))
                }
            }
            .await;

            let plugin = match handshake {
                Ok((plugin, period)) => {
                    let _ = meta_tx.send(Ok(period));
                    plugin
                }
                Err(reason) => {
                    let _ = meta_tx.send(Err(reason));
                    return;
                }
            };

            let done = match plugin.call_run(&mut store, &guest_config).await {
                Ok(guest_result) => guest_result,
                Err(trap) => Err(format!("trap: {trap}")),
            };
            let _ = output_tx.send(GuestResponse::Done(done)).await;
        });

        let sample_period_s = match runtime.handle().block_on(meta_rx) {
            Ok(Ok(period)) => period,
            Ok(Err(reason)) => {
                return Err(PluginError::Init(format!("metadata: {reason}")));
            }
            Err(_) => {
                return Err(PluginError::Init(
                    "async task dropped before metadata".to_string(),
                ));
            }
        };

        Ok(Self {
            runtime,
            input_tx,
            output_rx,
            sample_period_s,
            name: format!("wasm-async:{label}"),
        })
    }
}
impl PluginController for AsyncWasmController {
fn name(&self) -> &str {
&self.name
}
fn sample_period(&self) -> f64 {
self.sample_period_s
}
fn update(&mut self, obs: &TickInput<'_>) -> Result<Option<Command>, PluginError> {
let wit_obs = convert::tick_input_to_wit(obs);
let input_tx = self.input_tx.clone();
let output_rx = &mut self.output_rx;
self.runtime.handle().block_on(async move {
input_tx
.send(Some(wit_obs))
.await
.map_err(|_| PluginError::Runtime("async task dropped".to_string()))?;
match output_rx
.recv()
.await
.ok_or_else(|| PluginError::Runtime("async task channel closed".to_string()))?
{
GuestResponse::Command(Some(wit_cmd)) => {
convert::command_from_wit(wit_cmd).map(Some)
}
GuestResponse::Command(None) => Ok(None),
GuestResponse::Done(Ok(())) => Err(PluginError::Runtime(
"guest run() returned early".to_string(),
)),
GuestResponse::Done(Err(e)) => {
Err(PluginError::Runtime(format!("guest error: {e}")))
}
}
})
}
}
impl Drop for AsyncWasmController {
    /// Sends the `None` sentinel to the guest task without blocking.
    ///
    /// This uses `try_send` rather than `Handle::block_on`, for two reasons:
    /// - `block_on` panics when called from inside an async execution context, and
    ///   a panic in `drop` while unwinding aborts the process.
    /// - `block_on` could wait forever if the single input slot never freed up.
    ///
    /// `update` waits for the guest's response before returning, so the slot is
    /// presumably empty here. A failure (slot full, or channel closed because the
    /// task already exited) is deliberately ignored. `input_tx` is dropped right
    /// after this, which closes the channel anyway; how `AsyncHostState` reacts to
    /// a closed channel is not visible here.
    fn drop(&mut self) {
        let _ = self.input_tx.try_send(None);
    }
}