astro_run/
stream.rs

1use super::{Log, RunResult};
2use parking_lot::Mutex;
3use std::{sync::Arc, task::Waker};
4use tokio_stream::Stream;
5
6struct SharedState {
7  logs: Vec<Log>,
8  result: Option<RunResult>,
9  waker: Option<Waker>,
10}
11
12pub struct StreamReceiver {
13  current_index: Mutex<usize>,
14  state: Arc<Mutex<SharedState>>,
15}
16
17impl StreamReceiver {
18  fn new(state: Arc<Mutex<SharedState>>) -> Self {
19    Self {
20      current_index: Mutex::new(0),
21      state,
22    }
23  }
24
25  pub fn result(&self) -> Option<RunResult> {
26    self.state.lock().result.clone()
27  }
28}
29
30impl Stream for StreamReceiver {
31  type Item = Log;
32
33  fn poll_next(
34    self: std::pin::Pin<&mut Self>,
35    cx: &mut std::task::Context<'_>,
36  ) -> std::task::Poll<Option<Self::Item>> {
37    let mut state = self.state.lock();
38    state.waker = Some(cx.waker().clone());
39
40    let logs = state.logs.clone();
41    let total = logs.len();
42    let current_index = *self.current_index.lock();
43
44    if current_index < total {
45      let log = logs[current_index].clone();
46      *self.current_index.lock() += 1;
47
48      cx.waker().wake_by_ref();
49
50      return std::task::Poll::Ready(Some(log));
51    }
52
53    if state.result.is_some() {
54      return std::task::Poll::Ready(None);
55    }
56
57    std::task::Poll::Pending
58  }
59}
60
61#[derive(Clone)]
62pub struct StreamSender {
63  state: Arc<Mutex<SharedState>>,
64}
65
66impl StreamSender {
67  fn new(state: Arc<Mutex<SharedState>>) -> Self {
68    Self { state }
69  }
70
71  pub fn log(&self, message: impl Into<String>) {
72    let mut state = self.state.lock();
73    state.logs.push(Log::log(message.into()));
74
75    if let Some(waker) = state.waker.take() {
76      waker.wake();
77    }
78  }
79
80  pub fn error(&self, message: impl Into<String>) {
81    let mut state = self.state.lock();
82    state.logs.push(Log::error(message.into()));
83
84    if let Some(waker) = state.waker.take() {
85      waker.wake();
86    }
87  }
88
89  pub fn succeeded(&self) {
90    self.end(RunResult::Succeeded)
91  }
92
93  pub fn cancelled(&self) {
94    self.end(RunResult::Cancelled)
95  }
96
97  pub fn failed(&self, exit_code: i32) {
98    self.end(RunResult::Failed { exit_code })
99  }
100
101  pub fn timeout(&self) {
102    // TODO: use a different exit code
103    self.end(RunResult::Failed { exit_code: 123 })
104  }
105
106  pub fn end(&self, result: RunResult) {
107    if self.is_ended() {
108      log::trace!("StreamSender: already ended");
109      return;
110    }
111
112    let mut state = self.state.lock();
113    state.result = Some(result);
114
115    if let Some(waker) = state.waker.take() {
116      waker.wake();
117    }
118  }
119
120  pub fn is_ended(&self) -> bool {
121    self.state.lock().result.is_some()
122  }
123}
124
125pub fn stream() -> (StreamSender, StreamReceiver) {
126  let state = Arc::new(Mutex::new(SharedState {
127    logs: Vec::new(),
128    waker: None,
129    result: None,
130  }));
131
132  let sender = StreamSender::new(state.clone());
133  let receiver = StreamReceiver::new(state);
134
135  (sender, receiver)
136}
137
#[cfg(test)]
mod tests {
  use super::*;
  use tokio_stream::StreamExt;

  /// Entries pushed before the result is recorded are yielded in push order,
  /// after which the stream terminates and exposes the result.
  #[tokio::test]
  async fn test_stream() {
    let (tx, mut rx) = stream();

    tx.log("test");
    tx.error("error");
    tx.succeeded();

    let mut received = Vec::new();
    while let Some(entry) = rx.next().await {
      received.push(entry);
    }

    let expected = vec![Log::log("test"), Log::error("error")];
    assert_eq!(received, expected);
    assert_eq!(rx.result(), Some(RunResult::Succeeded));
  }

  /// A second attempt to end the stream must not replace the first result.
  #[tokio::test]
  async fn test_stream_twice() {
    let (tx, rx) = stream();

    tx.succeeded();
    tx.cancelled();

    assert_eq!(rx.result(), Some(RunResult::Succeeded));
  }
}