// obeli_sk_wasm_workers/std_output_stream.rs
//
// Based on https://github.com/bytecodealliance/wasmtime/blob/v36.0.1/src/commands/serve.rs#L874
use chrono::{DateTime, Utc};
use concepts::{
    ExecutionId,
    prefixed_ulid::RunId,
    storage::{LogEntry, LogInfoAppendRow, LogStreamType},
};
use std::{
    pin::Pin,
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    },
    task::{Context, Poll},
};
use tokio::{io::AsyncWrite, sync::mpsc};
use tracing::{debug, instrument};
use wasmtime_wasi::p2::{StreamError, StreamResult};

20#[derive(Clone)]
21pub enum StdOutputConfigWithSender {
22    Stdout,
23    Stderr,
24    Db {
25        sender: mpsc::Sender<LogInfoAppendRow>,
26        forwarding_from: LogStreamType,
27    },
28}
29impl StdOutputConfigWithSender {
30    #[must_use]
31    pub fn new(
32        config: Option<StdOutputConfig>,
33        log_forwarder_sender: &mpsc::Sender<LogInfoAppendRow>,
34        forwarding_from: LogStreamType,
35    ) -> Option<Self> {
36        config.map(|config| match config {
37            StdOutputConfig::Stdout => StdOutputConfigWithSender::Stdout,
38            StdOutputConfig::Stderr => StdOutputConfigWithSender::Stderr,
39            StdOutputConfig::Db => StdOutputConfigWithSender::Db {
40                sender: log_forwarder_sender.clone(),
41                forwarding_from,
42            },
43        })
44    }
45
46    #[must_use]
47    pub fn build(&self, execution_id: &ExecutionId, run_id: RunId) -> StdOutput {
48        match self {
49            StdOutputConfigWithSender::Stdout => StdOutput::Stdout,
50            StdOutputConfigWithSender::Stderr => StdOutput::Stderr,
51            StdOutputConfigWithSender::Db {
52                sender,
53                forwarding_from,
54            } => StdOutput::Db(DbOutput {
55                execution_id: execution_id.clone(),
56                run_id,
57                sender: sender.clone(),
58                forwarding_from: *forwarding_from,
59            }),
60        }
61    }
62}
63
/// Destination selector for an output stream, without any channel attached.
///
/// See `StdOutputConfigWithSender::new` for how it is combined with a sender.
#[derive(Clone, Copy, Debug)]
pub enum StdOutputConfig {
    /// Write to this process's stdout.
    Stdout,
    /// Write to this process's stderr.
    Stderr,
    /// Forward output to the log channel.
    Db,
}

71#[derive(Clone)]
72pub enum StdOutput {
73    Stdout,
74    Stderr,
75    Db(DbOutput),
76}
77
78#[derive(Clone)]
79pub struct DbOutput {
80    pub sender: mpsc::Sender<LogInfoAppendRow>,
81    pub execution_id: ExecutionId,
82    pub run_id: RunId,
83    pub forwarding_from: LogStreamType,
84}
85impl DbOutput {
86    #[instrument(skip_all, fields(execution_id = %self.execution_id, run_id = %self.run_id))]
87    fn write(&mut self, buf: &[u8]) {
88        let res = self.sender.try_send(LogInfoAppendRow {
89            execution_id: self.execution_id.clone(),
90            run_id: self.run_id,
91            log_entry: LogEntry::Stream {
92                created_at: Utc::now(),
93                payload: Vec::from(buf),
94                stream_type: self.forwarding_from,
95            },
96        });
97        if res.is_err() {
98            debug!("Dropping stream message");
99        }
100    }
101}
102
103#[derive(Clone)]
104pub struct OutputEvent {
105    pub buf: Vec<u8>,
106    pub created_at: DateTime<Utc>,
107}
108
/// Selects which of this process's standard output handles to write to.
#[derive(Clone, Copy, PartialEq, Eq)]
enum OutOrErr {
    Stdout,
    Stderr,
}

impl OutOrErr {
    /// Writes the whole of `buf` to the selected handle while holding its lock.
    ///
    /// # Errors
    /// Propagates any I/O error returned by the underlying `write_all`.
    fn write_all(&self, buf: &[u8]) -> Result<(), std::io::Error> {
        use std::io::Write as _;

        match *self {
            Self::Stderr => std::io::stderr().lock().write_all(buf),
            Self::Stdout => std::io::stdout().lock().write_all(buf),
        }
    }
}

125#[derive(Clone)]
126pub(crate) struct LogStream {
127    output: StdOutput,
128    state: Arc<LogStreamState>,
129}
130
/// State shared by every clone of a `LogStream`.
struct LogStreamState {
    /// Text emitted at the start of each line on stdout/stderr.
    prefix: String,
    /// `true` when the next byte written starts a new line and so needs `prefix`.
    needs_prefix_on_next_write: AtomicBool,
}

136impl LogStream {
137    pub(crate) fn new(prefix: String, output: StdOutput) -> LogStream {
138        LogStream {
139            output,
140            state: Arc::new(LogStreamState {
141                prefix,
142                needs_prefix_on_next_write: AtomicBool::new(true),
143            }),
144        }
145    }
146}
147
148impl wasmtime_wasi::cli::StdoutStream for LogStream {
149    fn p2_stream(&self) -> Box<dyn wasmtime_wasi::p2::OutputStream> {
150        Box::new(self.clone())
151    }
152    fn async_stream(&self) -> Box<dyn tokio::io::AsyncWrite + Send + Sync> {
153        Box::new(self.clone())
154    }
155}
156
157impl wasmtime_wasi::cli::IsTerminal for LogStream {
158    fn is_terminal(&self) -> bool {
159        match &self.output {
160            StdOutput::Stdout => std::io::stdout().is_terminal(),
161            StdOutput::Stderr => std::io::stderr().is_terminal(),
162            StdOutput::Db { .. } => false,
163        }
164    }
165}
166
167impl wasmtime_wasi::p2::OutputStream for LogStream {
168    fn write(&mut self, bytes: bytes::Bytes) -> StreamResult<()> {
169        self.write_all(&bytes)
170            .map_err(|e| StreamError::LastOperationFailed(e.into()))?;
171        Ok(())
172    }
173
174    fn flush(&mut self) -> StreamResult<()> {
175        Ok(())
176    }
177
178    fn check_write(&mut self) -> StreamResult<usize> {
179        Ok(1024 * 1024)
180    }
181}
182
183impl LogStream {
184    fn write_all(&mut self, mut bytes: &[u8]) -> std::io::Result<()> {
185        let our_or_err = match &mut self.output {
186            StdOutput::Db(db_output) => {
187                db_output.write(bytes);
188                return Ok(());
189            }
190            StdOutput::Stdout => OutOrErr::Stdout,
191            StdOutput::Stderr => OutOrErr::Stderr,
192        };
193        // write with prefix
194        while !bytes.is_empty() {
195            if self
196                .state
197                .needs_prefix_on_next_write
198                .load(Ordering::Relaxed)
199            {
200                our_or_err.write_all(self.state.prefix.as_bytes())?;
201                self.state
202                    .needs_prefix_on_next_write
203                    .store(false, Ordering::Relaxed);
204            }
205            if let Some(i) = bytes.iter().position(|b| *b == b'\n') {
206                let (a, b) = bytes.split_at(i + 1);
207                bytes = b;
208                our_or_err.write_all(a)?;
209                self.state
210                    .needs_prefix_on_next_write
211                    .store(true, Ordering::Relaxed);
212            } else {
213                our_or_err.write_all(bytes)?;
214                break;
215            }
216        }
217        Ok(())
218    }
219}
220
221impl AsyncWrite for LogStream {
222    fn poll_write(
223        mut self: Pin<&mut Self>,
224        _cx: &mut Context<'_>,
225        buf: &[u8],
226    ) -> Poll<std::io::Result<usize>> {
227        Poll::Ready(self.write_all(buf).map(|()| buf.len()))
228    }
229    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
230        Poll::Ready(Ok(()))
231    }
232    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
233        Poll::Ready(Ok(()))
234    }
235}
236
237#[async_trait::async_trait]
238impl wasmtime_wasi::p2::Pollable for LogStream {
239    async fn ready(&mut self) {}
240}