use std::io;
use std::path::Path;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use anyhow::Result;
use parking_lot::Mutex;
use surrealism_types::err::PrefixErr;
use tokio::io::AsyncWrite;
use wasmtime::component::ResourceTable;
use wasmtime_wasi::cli::{IsTerminal, StdoutStream};
use wasmtime_wasi::sockets::SocketAddrUse;
use wasmtime_wasi::{DirPerms, FilePerms, WasiCtx, WasiCtxBuilder};
use crate::net_allow::ResolvedNetAllow;
/// Shared, replaceable sink for guest stdio output.
///
/// The outer `Mutex` allows the callback to be swapped while streams are live; the inner
/// `Arc` lets writers clone the current callback and invoke it without holding the lock.
/// The callback is invoked once per complete line with the trailing `\n` removed, or with
/// whatever partial line remains when the stream is flushed.
pub type StdioCallback = Arc<Mutex<Arc<dyn Fn(&str) + Send + Sync>>>;

/// Creates the default stdout callback, which echoes each guest line to the host's stdout.
pub fn new_stdout_callback() -> StdioCallback {
    // Lines arrive without their `\n`, so `println!` restores the line break that
    // `print!` would have dropped (which glued consecutive lines together).
    Arc::new(Mutex::new(Arc::new(|output: &str| println!("{output}"))))
}

/// Creates the default stderr callback, which echoes each guest line to the host's stderr.
pub fn new_stderr_callback() -> StdioCallback {
    // Same newline restoration as the stdout default.
    Arc::new(Mutex::new(Arc::new(|output: &str| eprintln!("{output}"))))
}
struct CallbackStdoutStream {
callback: StdioCallback,
}
impl IsTerminal for CallbackStdoutStream {
fn is_terminal(&self) -> bool {
false
}
}
impl StdoutStream for CallbackStdoutStream {
fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {
Box::new(CallbackWriter {
callback: Arc::clone(&self.callback),
buffer: Vec::new(),
})
}
}
/// Line-buffering [`AsyncWrite`] sink that forwards guest output to a [`StdioCallback`].
///
/// Written bytes accumulate until a `\n` is seen. Each complete line is then passed to the
/// callback with the newline stripped. An incomplete trailing line is held back until the
/// stream is flushed, shut down, or dropped. Bytes are decoded with
/// [`String::from_utf8_lossy`], so invalid UTF-8 is replaced rather than rejected.
struct CallbackWriter {
    /// Shared callback slot; the current callback is re-read before each batch of emits.
    callback: StdioCallback,
    /// Bytes not yet forwarded. Between calls this never contains a `\n`.
    buffer: Vec<u8>,
}

impl CallbackWriter {
    /// Clones the currently installed callback out of the shared slot.
    ///
    /// The lock guard is released when this returns, so the callback itself runs
    /// without the lock held.
    fn snapshot_callback(&self) -> Arc<dyn Fn(&str) + Send + Sync> {
        Arc::clone(&self.callback.lock())
    }

    /// Forwards every complete line in the buffer and keeps any trailing partial line.
    fn emit_lines(&mut self) {
        // Everything up to the last newline is a run of complete lines.
        let Some(last_newline) = self.buffer.iter().rposition(|&b| b == b'\n') else {
            return;
        };
        let cb = self.snapshot_callback();
        for line in self.buffer[..last_newline].split(|&b| b == b'\n') {
            cb(&String::from_utf8_lossy(line));
        }
        // Drain once instead of once per line, avoiding repeated O(n) buffer shifts.
        self.buffer.drain(..=last_newline);
    }

    /// Forwards whatever partial line is still buffered, if any, and empties the buffer.
    fn flush_remaining(&mut self) {
        if self.buffer.is_empty() {
            return;
        }
        let cb = self.snapshot_callback();
        cb(&String::from_utf8_lossy(&self.buffer));
        self.buffer.clear();
    }
}

impl Drop for CallbackWriter {
    /// Forwards a final partial line so output lacking a trailing newline is not lost
    /// when the writer is dropped without being flushed or shut down.
    fn drop(&mut self) {
        // Invoking a user callback while already unwinding risks a double panic (abort).
        if !std::thread::panicking() {
            self.flush_remaining();
        }
    }
}

impl AsyncWrite for CallbackWriter {
    /// Buffers `buf` and forwards any lines it completes; always accepts the whole slice.
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        this.buffer.extend_from_slice(buf);
        // The buffer held no newline before this write, so a line can only have been
        // completed if `buf` itself contains one; skip the scan and the lock otherwise.
        if buf.contains(&b'\n') {
            this.emit_lines();
        }
        Poll::Ready(Ok(buf.len()))
    }

    /// Forwards any buffered partial line.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.get_mut().flush_remaining();
        Poll::Ready(Ok(()))
    }

    /// Shutting down only needs to flush; there is no underlying resource to close.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.poll_flush(cx)
    }
}
/// Builds a sandboxed WASI context for a guest module, paired with an empty resource table.
///
/// * `fs_root` — when `Some`, the directory is preopened at the guest path `/` with
///   read-only directory and file permissions; when `None`, this function preopens nothing.
/// * `allow_net` — when empty, TCP and UDP are disabled outright. Otherwise only outbound
///   TCP connects, UDP connects and outgoing UDP datagrams to an address matched by one of
///   the entries are permitted. IP name lookup is disabled in both cases.
/// * `stdout_cb` / `stderr_cb` — receive the guest's stdout / stderr output line by line.
///
/// # Errors
/// Fails if `fs_root` cannot be preopened.
pub fn build(
    fs_root: Option<&Path>,
    allow_net: Arc<Vec<ResolvedNetAllow>>,
    stdout_cb: StdioCallback,
    stderr_cb: StdioCallback,
) -> Result<(WasiCtx, ResourceTable)> {
    let mut builder = WasiCtxBuilder::new();

    // Route guest stdio through the host-provided callbacks.
    builder
        .stdout(CallbackStdoutStream {
            callback: stdout_cb,
        })
        .stderr(CallbackStdoutStream {
            callback: stderr_cb,
        });

    // Name resolution is never offered to the guest, whatever the network policy.
    builder.allow_ip_name_lookup(false);

    if allow_net.is_empty() {
        builder.allow_tcp(false).allow_udp(false);
    } else {
        // NOTE(review): every bind request (TcpBind/UdpBind) is rejected here; confirm UDP
        // is still usable if the guest must bind explicitly before sending datagrams.
        builder.socket_addr_check(move |addr, reason| {
            let permitted = matches!(
                reason,
                SocketAddrUse::TcpConnect
                    | SocketAddrUse::UdpConnect
                    | SocketAddrUse::UdpOutgoingDatagram
            ) && allow_net.iter().any(|rule| rule.matches_socket_addr(&addr));
            Box::pin(async move { permitted })
        });
    }

    if let Some(root) = fs_root {
        // Read-only view of `root`, mounted as the guest's `/`.
        builder
            .preopened_dir(root, "/", DirPerms::READ, FilePerms::READ)
            .prefix_err(|| "Failed to preopen filesystem directory")?;
    }

    let ctx = builder.build();
    Ok((ctx, ResourceTable::new()))
}