1use super::{Log, RunResult};
2use parking_lot::Mutex;
3use std::{sync::Arc, task::Waker};
4use tokio_stream::Stream;
5
6struct SharedState {
7 logs: Vec<Log>,
8 result: Option<RunResult>,
9 waker: Option<Waker>,
10}
11
12pub struct StreamReceiver {
13 current_index: Mutex<usize>,
14 state: Arc<Mutex<SharedState>>,
15}
16
17impl StreamReceiver {
18 fn new(state: Arc<Mutex<SharedState>>) -> Self {
19 Self {
20 current_index: Mutex::new(0),
21 state,
22 }
23 }
24
25 pub fn result(&self) -> Option<RunResult> {
26 self.state.lock().result.clone()
27 }
28}
29
30impl Stream for StreamReceiver {
31 type Item = Log;
32
33 fn poll_next(
34 self: std::pin::Pin<&mut Self>,
35 cx: &mut std::task::Context<'_>,
36 ) -> std::task::Poll<Option<Self::Item>> {
37 let mut state = self.state.lock();
38 state.waker = Some(cx.waker().clone());
39
40 let logs = state.logs.clone();
41 let total = logs.len();
42 let current_index = *self.current_index.lock();
43
44 if current_index < total {
45 let log = logs[current_index].clone();
46 *self.current_index.lock() += 1;
47
48 cx.waker().wake_by_ref();
49
50 return std::task::Poll::Ready(Some(log));
51 }
52
53 if state.result.is_some() {
54 return std::task::Poll::Ready(None);
55 }
56
57 std::task::Poll::Pending
58 }
59}
60
61#[derive(Clone)]
62pub struct StreamSender {
63 state: Arc<Mutex<SharedState>>,
64}
65
66impl StreamSender {
67 fn new(state: Arc<Mutex<SharedState>>) -> Self {
68 Self { state }
69 }
70
71 pub fn log(&self, message: impl Into<String>) {
72 let mut state = self.state.lock();
73 state.logs.push(Log::log(message.into()));
74
75 if let Some(waker) = state.waker.take() {
76 waker.wake();
77 }
78 }
79
80 pub fn error(&self, message: impl Into<String>) {
81 let mut state = self.state.lock();
82 state.logs.push(Log::error(message.into()));
83
84 if let Some(waker) = state.waker.take() {
85 waker.wake();
86 }
87 }
88
89 pub fn succeeded(&self) {
90 self.end(RunResult::Succeeded)
91 }
92
93 pub fn cancelled(&self) {
94 self.end(RunResult::Cancelled)
95 }
96
97 pub fn failed(&self, exit_code: i32) {
98 self.end(RunResult::Failed { exit_code })
99 }
100
101 pub fn timeout(&self) {
102 self.end(RunResult::Failed { exit_code: 123 })
104 }
105
106 pub fn end(&self, result: RunResult) {
107 if self.is_ended() {
108 log::trace!("StreamSender: already ended");
109 return;
110 }
111
112 let mut state = self.state.lock();
113 state.result = Some(result);
114
115 if let Some(waker) = state.waker.take() {
116 waker.wake();
117 }
118 }
119
120 pub fn is_ended(&self) -> bool {
121 self.state.lock().result.is_some()
122 }
123}
124
125pub fn stream() -> (StreamSender, StreamReceiver) {
126 let state = Arc::new(Mutex::new(SharedState {
127 logs: Vec::new(),
128 waker: None,
129 result: None,
130 }));
131
132 let sender = StreamSender::new(state.clone());
133 let receiver = StreamReceiver::new(state);
134
135 (sender, receiver)
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_stream::StreamExt;

    /// Entries pushed before consumption are yielded in order, then the stream ends.
    #[tokio::test]
    async fn test_stream() {
        let (sender, mut receiver) = stream();

        sender.log("test");
        sender.error("error");
        sender.succeeded();

        let mut logs = Vec::new();
        while let Some(log) = receiver.next().await {
            logs.push(log);
        }

        assert_eq!(logs, vec![Log::log("test"), Log::error("error"),]);
        assert_eq!(receiver.result().unwrap(), RunResult::Succeeded);
    }

    /// A second `end` call must not overwrite the first result.
    #[tokio::test]
    async fn test_stream_twice() {
        let (sender, receiver) = stream();

        sender.succeeded();
        sender.cancelled();
        assert_eq!(receiver.result().unwrap(), RunResult::Succeeded);
    }

    /// A receiver that is already waiting is woken by later logs and by the final result.
    #[tokio::test]
    async fn test_stream_wakes_pending_receiver() {
        let (sender, mut receiver) = stream();

        // On the current-thread test runtime, the receiver below polls first, parks with
        // `Pending`, and must be woken by these sends.
        let producer = tokio::spawn(async move {
            tokio::task::yield_now().await;
            sender.log("late");
            tokio::task::yield_now().await;
            sender.failed(2);
        });

        let mut logs = Vec::new();
        while let Some(log) = receiver.next().await {
            logs.push(log);
        }
        producer.await.unwrap();

        assert_eq!(logs, vec![Log::log("late")]);
        assert_eq!(
            receiver.result().unwrap(),
            RunResult::Failed { exit_code: 2 }
        );
    }
}