1use chrono::{DateTime, Utc};
3use concepts::{
4 ExecutionId,
5 prefixed_ulid::RunId,
6 storage::{LogEntry, LogInfoAppendRow, LogStreamType},
7};
8use std::{
9 pin::Pin,
10 sync::{
11 Arc,
12 atomic::{AtomicBool, Ordering},
13 },
14 task::{Context, Poll},
15};
16use tokio::{io::AsyncWrite, sync::mpsc};
17use tracing::{debug, instrument};
18use wasmtime_wasi::p2::{StreamError, StreamResult};
19
/// Where a WASI output stream's bytes should go, resolved from a [`StdOutputConfig`].
///
/// Unlike [`StdOutputConfig`], the `Db` variant carries what is needed to forward output:
/// the channel `sender` for [`LogInfoAppendRow`]s and the [`LogStreamType`]
/// (`forwarding_from`) that is stamped on every forwarded entry.
#[derive(Clone)]
pub enum StdOutputConfigWithSender {
    /// Write to the host process's stdout.
    Stdout,
    /// Write to the host process's stderr.
    Stderr,
    /// Forward output as log rows through `sender`.
    Db {
        // Channel the per-run `DbOutput` sends rows on.
        sender: mpsc::Sender<LogInfoAppendRow>,
        // Stream type recorded on each forwarded `LogEntry::Stream`.
        forwarding_from: LogStreamType,
    },
}
29impl StdOutputConfigWithSender {
30 #[must_use]
31 pub fn new(
32 config: Option<StdOutputConfig>,
33 log_forwarder_sender: &mpsc::Sender<LogInfoAppendRow>,
34 forwarding_from: LogStreamType,
35 ) -> Option<Self> {
36 config.map(|config| match config {
37 StdOutputConfig::Stdout => StdOutputConfigWithSender::Stdout,
38 StdOutputConfig::Stderr => StdOutputConfigWithSender::Stderr,
39 StdOutputConfig::Db => StdOutputConfigWithSender::Db {
40 sender: log_forwarder_sender.clone(),
41 forwarding_from,
42 },
43 })
44 }
45
46 #[must_use]
47 pub fn build(&self, execution_id: &ExecutionId, run_id: RunId) -> StdOutput {
48 match self {
49 StdOutputConfigWithSender::Stdout => StdOutput::Stdout,
50 StdOutputConfigWithSender::Stderr => StdOutput::Stderr,
51 StdOutputConfigWithSender::Db {
52 sender,
53 forwarding_from,
54 } => StdOutput::Db(DbOutput {
55 execution_id: execution_id.clone(),
56 run_id,
57 sender: sender.clone(),
58 forwarding_from: *forwarding_from,
59 }),
60 }
61 }
62}
63
/// Configured destination for a WASI output stream.
///
/// It is turned into a [`StdOutputConfigWithSender`] by [`StdOutputConfigWithSender::new`].
#[derive(Debug, Clone, Copy)]
pub enum StdOutputConfig {
    /// Host process stdout.
    Stdout,
    /// Host process stderr.
    Stderr,
    /// Forwarded as log entries over a channel.
    Db,
}
70
/// Per-run destination for the bytes written to a [`LogStream`].
#[derive(Clone)]
pub enum StdOutput {
    /// Host process stdout; `LogStream` adds its line prefix.
    Stdout,
    /// Host process stderr; `LogStream` adds its line prefix.
    Stderr,
    /// Forwarded unprefixed as log rows for one execution run.
    Db(DbOutput),
}
77
/// Forwards output chunks as [`LogInfoAppendRow`]s for a single execution run.
#[derive(Clone)]
pub struct DbOutput {
    /// Channel the rows are sent on. Sends use `try_send`, so chunks may be dropped.
    pub sender: mpsc::Sender<LogInfoAppendRow>,
    /// Execution the forwarded output belongs to.
    pub execution_id: ExecutionId,
    /// Run of that execution the forwarded output belongs to.
    pub run_id: RunId,
    /// Stream type recorded on every forwarded entry.
    pub forwarding_from: LogStreamType,
}
85impl DbOutput {
86 #[instrument(skip_all, fields(execution_id = %self.execution_id, run_id = %self.run_id))]
87 fn write(&mut self, buf: &[u8]) {
88 let res = self.sender.try_send(LogInfoAppendRow {
89 execution_id: self.execution_id.clone(),
90 run_id: self.run_id,
91 log_entry: LogEntry::Stream {
92 created_at: Utc::now(),
93 payload: Vec::from(buf),
94 stream_type: self.forwarding_from,
95 },
96 });
97 if res.is_err() {
98 debug!("Dropping stream message");
99 }
100 }
101}
102
/// A chunk of output bytes paired with a UTC timestamp.
///
/// NOTE(review): not constructed or read anywhere in this file; presumably used elsewhere
/// in the crate — confirm before changing.
#[derive(Clone)]
pub struct OutputEvent {
    /// Raw output bytes.
    pub buf: Vec<u8>,
    /// Timestamp of the chunk — presumably when it was produced; verify against callers.
    pub created_at: DateTime<Utc>,
}
108
/// The host standard stream a `LogStream` writes to when it is not forwarding to the DB.
#[derive(PartialEq, Eq, Clone, Copy)]
enum OutOrErr {
    Stdout,
    Stderr,
}

impl OutOrErr {
    /// Writes all of `buf` to the selected host stream via `std::io::Write::write_all`.
    fn write_all(&self, buf: &[u8]) -> Result<(), std::io::Error> {
        use std::io::Write as _;

        if *self == OutOrErr::Stderr {
            std::io::stderr().write_all(buf)
        } else {
            std::io::stdout().write_all(buf)
        }
    }
}
124
/// A WASI output stream that writes to the host's stdout or stderr, or to the DB log forwarder.
///
/// All clones share one [`LogStreamState`] through the `Arc`. They therefore share the same
/// prefix and agree on whether the next byte starts a new line.
#[derive(Clone)]
pub(crate) struct LogStream {
    /// Destination of written bytes.
    output: StdOutput,
    /// Line-prefix state shared by all clones of this stream.
    state: Arc<LogStreamState>,
}
130
/// State shared (via `Arc`) by all clones of a [`LogStream`].
struct LogStreamState {
    /// Written before each line sent to stdout/stderr; not used for DB output.
    prefix: String,
    /// `true` when the next byte written starts a new line and must be preceded by `prefix`.
    needs_prefix_on_next_write: AtomicBool,
}
135
136impl LogStream {
137 pub(crate) fn new(prefix: String, output: StdOutput) -> LogStream {
138 LogStream {
139 output,
140 state: Arc::new(LogStreamState {
141 prefix,
142 needs_prefix_on_next_write: AtomicBool::new(true),
143 }),
144 }
145 }
146}
147
148impl wasmtime_wasi::cli::StdoutStream for LogStream {
149 fn p2_stream(&self) -> Box<dyn wasmtime_wasi::p2::OutputStream> {
150 Box::new(self.clone())
151 }
152 fn async_stream(&self) -> Box<dyn tokio::io::AsyncWrite + Send + Sync> {
153 Box::new(self.clone())
154 }
155}
156
157impl wasmtime_wasi::cli::IsTerminal for LogStream {
158 fn is_terminal(&self) -> bool {
159 match &self.output {
160 StdOutput::Stdout => std::io::stdout().is_terminal(),
161 StdOutput::Stderr => std::io::stderr().is_terminal(),
162 StdOutput::Db { .. } => false,
163 }
164 }
165}
166
167impl wasmtime_wasi::p2::OutputStream for LogStream {
168 fn write(&mut self, bytes: bytes::Bytes) -> StreamResult<()> {
169 self.write_all(&bytes)
170 .map_err(|e| StreamError::LastOperationFailed(e.into()))?;
171 Ok(())
172 }
173
174 fn flush(&mut self) -> StreamResult<()> {
175 Ok(())
176 }
177
178 fn check_write(&mut self) -> StreamResult<usize> {
179 Ok(1024 * 1024)
180 }
181}
182
183impl LogStream {
184 fn write_all(&mut self, mut bytes: &[u8]) -> std::io::Result<()> {
185 let our_or_err = match &mut self.output {
186 StdOutput::Db(db_output) => {
187 db_output.write(bytes);
188 return Ok(());
189 }
190 StdOutput::Stdout => OutOrErr::Stdout,
191 StdOutput::Stderr => OutOrErr::Stderr,
192 };
193 while !bytes.is_empty() {
195 if self
196 .state
197 .needs_prefix_on_next_write
198 .load(Ordering::Relaxed)
199 {
200 our_or_err.write_all(self.state.prefix.as_bytes())?;
201 self.state
202 .needs_prefix_on_next_write
203 .store(false, Ordering::Relaxed);
204 }
205 if let Some(i) = bytes.iter().position(|b| *b == b'\n') {
206 let (a, b) = bytes.split_at(i + 1);
207 bytes = b;
208 our_or_err.write_all(a)?;
209 self.state
210 .needs_prefix_on_next_write
211 .store(true, Ordering::Relaxed);
212 } else {
213 our_or_err.write_all(bytes)?;
214 break;
215 }
216 }
217 Ok(())
218 }
219}
220
221impl AsyncWrite for LogStream {
222 fn poll_write(
223 mut self: Pin<&mut Self>,
224 _cx: &mut Context<'_>,
225 buf: &[u8],
226 ) -> Poll<std::io::Result<usize>> {
227 Poll::Ready(self.write_all(buf).map(|()| buf.len()))
228 }
229 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
230 Poll::Ready(Ok(()))
231 }
232 fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
233 Poll::Ready(Ok(()))
234 }
235}
236
237#[async_trait::async_trait]
238impl wasmtime_wasi::p2::Pollable for LogStream {
239 async fn ready(&mut self) {}
240}