1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
use super::{Log, RunResult};
use parking_lot::Mutex;
use std::{sync::Arc, task::Waker};
use tokio_stream::Stream;

/// State shared between a [`StreamSender`] and its [`StreamReceiver`].
struct SharedState {
  /// Every log pushed by the sender so far, in push order.
  logs: Vec<Log>,
  /// Final run result; `Some` once the sender has called `end`.
  result: Option<RunResult>,
  /// Waker stored by the receiver's `poll_next`; the sender takes and wakes it
  /// after each change to this state.
  waker: Option<Waker>,
}

/// Reading half of the log stream; yields [`Log`] items through its `Stream` impl.
pub struct StreamReceiver {
  /// Index into the shared `logs` vector of the next entry to yield.
  current_index: Mutex<usize>,
  /// State shared with the sending half.
  state: Arc<Mutex<SharedState>>,
}

impl StreamReceiver {
  /// Builds a receiver over `state` whose read position starts at the first log.
  fn new(state: Arc<Mutex<SharedState>>) -> Self {
    let current_index = Mutex::new(0);
    Self {
      current_index,
      state,
    }
  }

  /// Returns a clone of the final result, or `None` if none has been recorded yet.
  pub fn result(&self) -> Option<RunResult> {
    let shared = self.state.lock();
    shared.result.clone()
  }
}

impl Stream for StreamReceiver {
  type Item = Log;

  /// Yields the next unread log, finishes once a result is recorded and all
  /// logs have been read, and otherwise parks until the sender changes state.
  ///
  /// Returns:
  /// - `Ready(Some(log))` while there is a log at the current read position;
  /// - `Ready(None)` when every log has been yielded and a result is set;
  /// - `Pending` otherwise, after storing the task's waker in the shared state.
  fn poll_next(
    self: std::pin::Pin<&mut Self>,
    cx: &mut std::task::Context<'_>,
  ) -> std::task::Poll<Option<Self::Item>> {
    use std::task::Poll;

    // Lock order (shared state, then read position) is kept consistent.
    let mut state = self.state.lock();
    let mut index = self.current_index.lock();

    // Clone only the single entry being yielded instead of the whole vector.
    if let Some(log) = state.logs.get(*index).cloned() {
      *index += 1;
      // No self-wake here: returning `Ready` already hands control back to the
      // caller, which polls again when it wants the next item.
      return Poll::Ready(Some(log));
    }

    if state.result.is_some() {
      return Poll::Ready(None);
    }

    // Register the waker only when we actually have to wait. This happens
    // while the state lock is held, so a sender cannot push or end between
    // the checks above and the registration.
    state.waker = Some(cx.waker().clone());
    Poll::Pending
  }
}

/// Writing half of the log stream. Cloning it clones the `Arc`, so every clone
/// writes to the same shared state.
#[derive(Clone)]
pub struct StreamSender {
  /// State shared with the receiving half.
  state: Arc<Mutex<SharedState>>,
}

impl StreamSender {
  fn new(state: Arc<Mutex<SharedState>>) -> Self {
    Self { state }
  }

  pub fn log(&self, message: impl Into<String>) {
    let mut state = self.state.lock();
    state.logs.push(Log::log(message.into()));

    if let Some(waker) = state.waker.take() {
      waker.wake();
    }
  }

  pub fn error(&self, message: impl Into<String>) {
    let mut state = self.state.lock();
    state.logs.push(Log::error(message.into()));

    if let Some(waker) = state.waker.take() {
      waker.wake();
    }
  }

  pub fn end(&self, result: RunResult) {
    let mut state = self.state.lock();
    state.result = Some(result);

    if let Some(waker) = state.waker.take() {
      waker.wake();
    }
  }

  pub fn is_ended(&self) -> bool {
    self.state.lock().result.is_some()
  }
}

/// Creates a connected sender/receiver pair over a fresh shared state with no
/// logs, no result and no registered waker.
pub fn stream() -> (StreamSender, StreamReceiver) {
  let shared = SharedState {
    logs: Vec::new(),
    result: None,
    waker: None,
  };
  let state = Arc::new(Mutex::new(shared));

  (
    StreamSender::new(Arc::clone(&state)),
    StreamReceiver::new(state),
  )
}

#[cfg(test)]
mod tests {
  use super::*;
  use tokio_stream::StreamExt;

  /// Logs pushed before `end` come out in push order, the stream then
  /// terminates, and the recorded result is readable from the receiver.
  #[tokio::test]
  async fn test_stream() {
    let (tx, mut rx) = stream();

    tx.log("test");
    tx.error("error");
    tx.end(RunResult::Succeeded);

    let mut received = Vec::new();
    while let Some(entry) = rx.next().await {
      received.push(entry);
    }

    let expected = vec![Log::log("test"), Log::error("error")];
    assert_eq!(received, expected);
    assert_eq!(rx.result(), Some(RunResult::Succeeded));
  }
}