use std::{
borrow::Cow,
future::Future,
sync::Arc,
task::{Context, Poll},
};
use bytes::Bytes;
use futures::task::noop_waker_ref;
use parking_lot::Mutex;
use smallvec::SmallVec;
use tokio::io::AsyncWrite;
use wasmtime_wasi::{
cli::{IsTerminal, StdoutStream},
p2::{OutputStream, Pollable, StreamError, StreamResult},
};
use crate::host::{LogContext, LogLevel, OutputSink};
/// Stdout-style stream factory whose WASI p2 streams forward guest output to
/// the host log sink instead of a terminal.
///
/// Every stream created from this value copies `level` and `context` and
/// shares the same `sink_store`.
pub struct TraceOutput {
// Level passed to the sink with every message emitted by derived streams.
level: LogLevel,
// Context passed to the sink with every message emitted by derived streams.
context: LogContext<'static>,
// Shared slot holding the sink that messages are delivered to, if any.
sink_store: LogSinkStore,
}
impl TraceOutput {
/// Creates a stream factory that reports at `level` under `context`.
///
/// `sink_store` is stored as given; streams created later look up the sink
/// in it each time they emit a message.
pub const fn new(
level: LogLevel,
context: LogContext<'static>,
sink_store: LogSinkStore,
) -> Self {
Self {
level,
context,
sink_store,
}
}
}
impl StdoutStream for TraceOutput {
    /// Returns a writer that discards everything written to it.
    ///
    /// NOTE(review): output written through this path never reaches the log
    /// sink — confirm that is intended.
    fn async_stream(&self) -> Box<dyn AsyncWrite + Send + Sync> {
        let discard = tokio::io::sink();
        Box::new(discard)
    }

    /// Creates a fresh [`TraceOutputStream`] with an empty buffer, no emit in
    /// flight and no stored error, sharing this factory's sink store.
    fn p2_stream(&self) -> Box<dyn OutputStream> {
        let Self {
            level,
            context,
            sink_store,
        } = self;
        let stream = TraceOutputStream {
            level: *level,
            context: *context,
            sink_store: Arc::clone(sink_store),
            buffer: SmallVec::new(),
            in_flight: None,
            last_error: None,
        };
        Box::new(stream)
    }
}
impl IsTerminal for TraceOutput {
/// Always reports `false`: this output is never presented as a terminal.
fn is_terminal(&self) -> bool {
false
}
}
/// WASI p2 output stream that buffers guest bytes, decodes them as UTF-8 and
/// hands the resulting text to the sink currently installed in `sink_store`.
pub struct TraceOutputStream {
// Level passed to the sink with every message.
level: LogLevel,
// Context passed to the sink with every message.
context: LogContext<'static>,
// Shared slot the sink is looked up in at emit time.
sink_store: LogSinkStore,
// Bytes accepted but not yet emitted: short writes waiting for more data and
// any undecoded tail left over by `decode_utf8`.
buffer: SmallVec<[u8; MAX_BUFFER + MAX_UTF8_BYTES]>,
// Handle of a spawned sink call that did not finish on its first poll.
in_flight: Option<wasmtime_wasi::runtime::AbortOnDropJoinHandle<wasmtime::Result<()>>>,
// Error from a finished in-flight emit, reported by the next
// `write`/`flush`/`check_write` call.
last_error: Option<wasmtime::Error>,
}
/// Writes are only accumulated (not emitted) while buffered plus incoming
/// bytes stay below this many bytes.
const MIN_BUFFER: usize = 64;
/// Write capacity advertised by `check_write` when the buffer is empty.
const MAX_BUFFER: usize = 1024;
/// Upper bound on the undecoded tail that `decode_utf8` hands back; also the
/// extra inline room reserved in the stream buffer beyond `MAX_BUFFER`.
const MAX_UTF8_BYTES: usize = 4;
impl TraceOutputStream {
fn record(&mut self, s: &str) -> StreamResult<()> {
if self.in_flight.is_some() {
return Err(StreamError::Trap(wasmtime::Error::msg(
"write not permitted while emit pending",
)));
}
if s.is_empty() {
return Ok(());
}
let Some(sink) = self.sink_store.lock().clone() else {
return Ok(());
};
let level = self.level;
let context = self.context;
let message = s.to_string();
let mut future = Box::pin(async move {
sink.on_log(level, context, &message)
.await
.map_err(wasmtime::Error::from_boxed)
});
let waker = noop_waker_ref();
let mut cx = Context::from_waker(waker);
match future.as_mut().poll(&mut cx) {
Poll::Ready(result) => result.map_err(StreamError::LastOperationFailed),
Poll::Pending => {
self.in_flight = Some(wasmtime_wasi::runtime::spawn(future));
Ok(())
}
}
}
}
#[async_trait::async_trait]
impl Pollable for TraceOutputStream {
    /// Resolves once no emit is in flight.
    ///
    /// If a spawned sink call is pending, waits for it to finish; a failure is
    /// stored in `last_error` so the next stream operation can report it.
    /// Returns immediately when nothing is in flight.
    ///
    /// The handle is awaited in place rather than moved out of `in_flight`.
    /// The future returned here may be dropped before it completes, and
    /// dropping an `AbortOnDropJoinHandle` aborts its task. Keeping the handle
    /// in `self` until it has resolved means a dropped `ready` call neither
    /// cancels the emit nor makes the stream look idle while it is pending.
    async fn ready(&mut self) {
        let Some(task) = self.in_flight.as_mut() else {
            return;
        };
        let outcome = task.await;
        // A finished join handle must not be polled again, so clear it now.
        self.in_flight = None;
        if let Err(error) = outcome {
            self.last_error = Some(error);
        }
    }
}
/// Splits `buf` into text that can be emitted now and trailing bytes that must
/// wait for more input.
///
/// Returns `(text, tail)`:
///
/// - `tail` is non-empty only when `buf` ends in the middle of a multi-byte
///   UTF-8 sequence, i.e. the decoder reports an unexpected end of input
///   ([`std::str::Utf8Error::error_len`] is `None`). Such a tail is one to
///   three bytes long and may be completed by a later write.
/// - `text` is everything before `tail`. Byte sequences that can never become
///   valid are replaced with U+FFFD rather than held back, so they are not
///   stranded in the caller's buffer. `text` borrows from `buf` when no
///   replacement was needed.
fn decode_utf8(buf: &[u8]) -> (Cow<'_, str>, SmallVec<[u8; MAX_UTF8_BYTES]>) {
    // Step over definitively invalid sequences until the input either ends
    // cleanly or ends inside a truncated sequence. Each byte is validated at
    // most once because every pass resumes right after the previous error.
    let mut scanned = 0;
    let tail_len = loop {
        match std::str::from_utf8(&buf[scanned..]) {
            Ok(_) => break 0,
            Err(error) => match error.error_len() {
                // Unexpected end of input: the rest is an incomplete sequence.
                None => break buf.len() - scanned - error.valid_up_to(),
                // Invalid bytes: skip them and keep scanning.
                Some(invalid_len) => scanned += error.valid_up_to() + invalid_len,
            },
        }
    };

    let (head, tail) = buf.split_at(buf.len() - tail_len);
    (String::from_utf8_lossy(head), SmallVec::from_slice(tail))
}
impl OutputStream for TraceOutputStream {
fn write(&mut self, bytes: Bytes) -> StreamResult<()> {
if let Some(error) = self.last_error.take() {
return Err(StreamError::LastOperationFailed(error));
}
if self.in_flight.is_some() {
return Err(StreamError::Trap(wasmtime::Error::msg(
"write not permitted while emit pending",
)));
}
if bytes.len() + self.buffer.len() < MIN_BUFFER {
self.buffer.extend(bytes);
return Ok(());
}
let buf: &[u8] = if self.buffer.is_empty() {
&bytes
} else {
self.buffer.extend(bytes);
&self.buffer
};
let (s, remainder) = decode_utf8(buf);
let message = if s.is_empty() {
None
} else {
Some(s.into_owned())
};
if let Some(message) = message {
self.record(&message)?;
}
self.buffer.clear();
if !remainder.is_empty() {
self.buffer.extend(remainder);
}
Ok(())
}
fn flush(&mut self) -> StreamResult<()> {
if let Some(error) = self.last_error.take() {
return Err(StreamError::LastOperationFailed(error));
}
if self.in_flight.is_some() {
return Ok(());
}
if !self.buffer.is_empty() {
let (s, remainder) = decode_utf8(&self.buffer);
let message = if s.is_empty() {
None
} else {
Some(s.into_owned())
};
if let Some(message) = message {
self.record(&message)?;
}
self.buffer.clear();
if !remainder.is_empty() {
self.buffer.extend(remainder);
}
}
Ok(())
}
fn check_write(&mut self) -> StreamResult<usize> {
if let Some(error) = self.last_error.take() {
return Err(StreamError::LastOperationFailed(error));
}
if self.in_flight.is_some() {
return Ok(0);
}
let local_capacity = MAX_BUFFER.saturating_sub(self.buffer.len());
Ok(local_capacity)
}
}
/// Shared, replaceable slot for the sink that receives log output.
///
/// While the slot holds `None`, streams accept output without delivering it.
pub type LogSinkStore = Arc<Mutex<Option<Arc<dyn OutputSink>>>>;
#[must_use]
pub fn new_log_sink_store() -> LogSinkStore {
Arc::new(Mutex::new(None))
}
/// Replaces the sink held in `store` with `sink`; passing `None` removes it.
///
/// The previously installed sink reference, if any, is dropped.
pub fn set_log_sink(store: &LogSinkStore, sink: Option<Arc<dyn OutputSink>>) {
    *store.lock() = sink;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a stream backed by an empty sink store, so emits are accepted
    /// without being delivered anywhere.
    fn new_stream() -> TraceOutputStream {
        TraceOutputStream {
            level: LogLevel::Info,
            context: LogContext::Other("test"),
            sink_store: new_log_sink_store(),
            buffer: SmallVec::new(),
            in_flight: None,
            last_error: None,
        }
    }

    /// A write shorter than `MIN_BUFFER` is only accumulated.
    #[test]
    fn small_write_buffers() {
        let mut stream = new_stream();
        stream
            .write(Bytes::from_static(b"hi"))
            .expect("short write is accepted");
        assert_eq!(stream.buffer.as_slice(), b"hi");
    }

    /// A write of at least `MIN_BUFFER` bytes is emitted right away.
    #[test]
    fn large_write_flushes_buffer() {
        let mut stream = new_stream();
        let payload = vec![b'a'; MIN_BUFFER + 1];
        stream
            .write(Bytes::from(payload))
            .expect("large write is accepted");
        assert!(stream.buffer.is_empty());
    }

    /// The first byte of a truncated two-byte sequence stays buffered.
    #[test]
    fn partial_utf8_retained() {
        let mut stream = new_stream();
        let payload = [vec![b'a'; MIN_BUFFER], vec![0xC3]].concat();
        stream
            .write(Bytes::from(payload))
            .expect("write with truncated tail is accepted");
        assert_eq!(stream.buffer.as_slice(), &[0xC3]);
    }

    /// A retained lead byte is consumed once its continuation byte arrives.
    #[test]
    fn partial_utf8_completed_on_next_write() {
        let mut stream = new_stream();

        let first = [vec![b'a'; MIN_BUFFER], vec![0xC3]].concat();
        stream
            .write(Bytes::from(first))
            .expect("first write is accepted");
        assert_eq!(stream.buffer.as_slice(), &[0xC3]);

        let second = [vec![0xBC], vec![b'b'; MIN_BUFFER]].concat();
        stream
            .write(Bytes::from(second))
            .expect("second write is accepted");
        assert!(stream.buffer.is_empty());
    }

    /// Flushing drains bytes held back by a short write.
    #[test]
    fn flush_emits_buffered() {
        let mut stream = new_stream();
        stream
            .write(Bytes::from_static(b"hi"))
            .expect("short write is accepted");
        assert!(!stream.buffer.is_empty());

        stream.flush().expect("flush succeeds");
        assert!(stream.buffer.is_empty());
    }

    /// Flushing an empty stream succeeds and leaves it empty.
    #[test]
    fn flush_noop_when_empty() {
        let mut stream = new_stream();
        stream.flush().expect("flush succeeds");
        assert!(stream.buffer.is_empty());
    }

    /// Advertised capacity shrinks from `MAX_BUFFER` to zero as the buffer fills.
    #[test]
    fn check_write_capacity() {
        let mut stream = new_stream();
        assert_eq!(stream.check_write().expect("idle stream"), MAX_BUFFER);

        stream.buffer.extend_from_slice(&[b'x'; MAX_BUFFER]);
        assert_eq!(stream.check_write().expect("full stream"), 0);
    }

    /// A long run of invalid bytes is emitted instead of being held back.
    #[test]
    fn invalid_utf8_uses_lossy() {
        let mut stream = new_stream();
        let garbage = vec![0xFF_u8; MAX_UTF8_BYTES + MIN_BUFFER + 1];
        stream
            .write(Bytes::from(garbage))
            .expect("invalid bytes are accepted");
        assert!(stream.buffer.is_empty());
    }

    /// Flush emits the decodable prefix but keeps a truncated trailing sequence.
    #[test]
    fn flush_with_partial_utf8_retains_tail() {
        let mut stream = new_stream();
        stream.buffer.extend_from_slice(b"hello\xE2");
        stream.flush().expect("flush succeeds");
        assert_eq!(stream.buffer.as_slice(), &[0xE2]);
    }
}