stdout_channel/
lib.rs

1pub mod rate_limiter;
2
3pub use rate_limiter::RateLimiter;
4
5use deadqueue::unlimited::Queue;
6use std::{
7    fmt,
8    fmt::Display,
9    io::{Error as IoError, Write},
10    ops::Deref,
11    sync::Arc,
12};
13use thiserror::Error;
14use tokio::{
15    io::{stderr, stdout, AsyncWriteExt},
16    sync::Mutex,
17    task::{spawn, JoinError, JoinHandle},
18};
19
20#[derive(Error, Debug)]
21pub enum StdoutChannelError {
22    #[error("task join error")]
23    JoinError(#[from] JoinError),
24    #[error("io error")]
25    IoError(#[from] IoError),
26}
27
/// One entry on a worker queue.
enum StdoutMessage<T> {
    /// A value to be formatted and written as a single line.
    Mesg(T),
    /// Instructs the worker loop to stop reading its queue and exit.
    Close,
}
32
33type StdoutQueue<T> = Queue<StdoutMessage<T>>;
34type StdoutTask = JoinHandle<Result<(), StdoutChannelError>>;
35
36#[derive(Clone)]
37pub struct StdoutChannel<T> {
38    stdout_queue: Arc<StdoutQueue<T>>,
39    stderr_queue: Arc<StdoutQueue<T>>,
40    stdout_task: Arc<Mutex<Option<StdoutTask>>>,
41    stderr_task: Arc<Mutex<Option<StdoutTask>>>,
42}
43
44impl<T> Default for StdoutChannel<T>
45where
46    T: Display + Send + 'static,
47{
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl<T> fmt::Debug for StdoutChannel<T> {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        write!(f, "StdoutChannel")
56    }
57}
58
59impl<T> StdoutChannel<T>
60where
61    T: Display + Send + 'static,
62{
63    #[must_use]
64    pub fn new() -> Self {
65        let stdout_queue = Queue::new().into();
66        let stderr_queue = Queue::new().into();
67        let stdout_task = Mutex::new(Some(spawn({
68            let queue = Arc::clone(&stdout_queue);
69            async move { Self::process_stdout(&queue).await }
70        })))
71        .into();
72        let stderr_task = Mutex::new(Some(spawn({
73            let queue = Arc::clone(&stderr_queue);
74            async move { Self::process_stderr(&queue).await }
75        })))
76        .into();
77        Self {
78            stdout_queue,
79            stderr_queue,
80            stdout_task,
81            stderr_task,
82        }
83    }
84
85    #[must_use]
86    pub fn with_mock_stdout(mock_stdout: MockStdout<T>, mock_stderr: MockStdout<T>) -> Self {
87        let stdout_queue = Queue::new().into();
88        let stderr_queue = Queue::new().into();
89        let stdout_task = Mutex::new(Some(spawn({
90            let queue = Arc::clone(&stdout_queue);
91            async move { Self::process_mock(&queue, &mock_stdout).await }
92        })))
93        .into();
94        let stderr_task = Mutex::new(Some(spawn({
95            let queue = Arc::clone(&stderr_queue);
96            async move { Self::process_mock(&queue, &mock_stderr).await }
97        })))
98        .into();
99        Self {
100            stdout_queue,
101            stderr_queue,
102            stdout_task,
103            stderr_task,
104        }
105    }
106
107    pub fn send(&self, item: impl Into<T>) {
108        self.stdout_queue.push(StdoutMessage::Mesg(item.into()));
109    }
110
111    pub fn send_err(&self, item: impl Into<T>) {
112        self.stderr_queue.push(StdoutMessage::Mesg(item.into()));
113    }
114
115    /// Close the `StdoutChannel`
116    /// # Errors
117    ///
118    /// Will error if there have been any errors or panics in the stdout and
119    /// stderr tasks
120    pub async fn close(&self) -> Result<(), StdoutChannelError> {
121        self.stdout_queue.push(StdoutMessage::Close);
122        self.stderr_queue.push(StdoutMessage::Close);
123        if let Some(stdout_task) = self.stdout_task.lock().await.take() {
124            stdout_task.await??;
125        }
126        if let Some(stderr_task) = self.stderr_task.lock().await.take() {
127            stderr_task.await??;
128        }
129        Ok(())
130    }
131
132    async fn process_stdout(queue: &StdoutQueue<T>) -> Result<(), StdoutChannelError> {
133        let mut buf = Buffer::new();
134        while let StdoutMessage::Mesg(line) = queue.pop().await {
135            stdout().write_all(buf.write_line(line)?).await?;
136        }
137        Ok(())
138    }
139
140    async fn process_stderr(queue: &StdoutQueue<T>) -> Result<(), StdoutChannelError> {
141        let mut buf = Buffer::new();
142        while let StdoutMessage::Mesg(line) = queue.pop().await {
143            stderr().write_all(buf.write_line(line)?).await?;
144        }
145        Ok(())
146    }
147
148    async fn process_mock(
149        queue: &StdoutQueue<T>,
150        mock_stdout: &MockStdout<T>,
151    ) -> Result<(), StdoutChannelError> {
152        while let StdoutMessage::Mesg(line) = queue.pop().await {
153            mock_stdout.lock().await.push(line);
154        }
155        Ok(())
156    }
157}
158
/// Capacity, in bytes, that the line-formatting buffer is shrunk back to after
/// an unusually long line has grown it.
const MAX_BUFFER_CAPACITY: usize = 4 * 1024;
160
161struct Buffer(Vec<u8>);
162
163impl Buffer {
164    pub fn new() -> Self {
165        Self(Vec::new())
166    }
167
168    pub fn write_line<T: Display>(&mut self, line: T) -> Result<&[u8], StdoutChannelError> {
169        self.0.clear();
170        if self.0.capacity() > MAX_BUFFER_CAPACITY {
171            self.0.shrink_to(MAX_BUFFER_CAPACITY);
172        }
173        writeln!(self.0, "{line}")?;
174        Ok(&self.0)
175    }
176}
177
178#[derive(Clone)]
179pub struct MockStdout<T>(Arc<Mutex<Vec<T>>>);
180
181impl<T> Default for MockStdout<T> {
182    fn default() -> Self {
183        Self::new()
184    }
185}
186
187impl<T> Deref for MockStdout<T> {
188    type Target = Mutex<Vec<T>>;
189    fn deref(&self) -> &Self::Target {
190        &self.0
191    }
192}
193
194impl<T> MockStdout<T> {
195    #[must_use]
196    pub fn new() -> Self {
197        Self(Mutex::new(Vec::new()).into())
198    }
199}
200
#[cfg(test)]
mod tests {
    use stack_string::StackString;

    use super::{MockStdout, StdoutChannel, StdoutChannelError};

    #[tokio::test]
    async fn test_default_mockstdout() -> Result<(), StdoutChannelError> {
        let mock = MockStdout::default();
        mock.lock().await.push(StackString::from("HEY"));
        assert_eq!(mock.lock().await.len(), 1);
        assert_eq!(mock.lock().await[0].as_str(), "HEY");
        Ok(())
    }

    /// Channel built through `Default`.
    #[tokio::test]
    async fn test_default() -> Result<(), StdoutChannelError> {
        let chan = StdoutChannel::<StackString>::default();

        chan.send("stdout: Hey There");
        chan.send("What's happening");
        chan.send_err("stderr: How it goes");

        chan.close().await?;
        Ok(())
    }

    /// Same traffic as `test_default`, but through the explicit `new`
    /// constructor so both entry points are exercised.
    #[tokio::test]
    async fn test_stdout_task() -> Result<(), StdoutChannelError> {
        let chan = StdoutChannel::<StackString>::new();

        chan.send("stdout: Hey There");
        chan.send("What's happening");
        chan.send_err("stderr: How it goes");

        chan.close().await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_mock_stdout() -> Result<(), StdoutChannelError> {
        let stdout = MockStdout::<StackString>::new();
        let stderr = MockStdout::new();

        let chan = StdoutChannel::with_mock_stdout(stdout.clone(), stderr.clone());

        chan.send("stdout: Hey There");
        chan.send("What's happening");
        chan.send_err("stderr: How it goes");
        chan.close().await?;

        assert_eq!(stdout.lock().await.len(), 2);
        assert_eq!(stdout.lock().await[0], "stdout: Hey There");
        assert_eq!(stdout.lock().await[1], "What's happening");
        assert_eq!(stderr.lock().await.len(), 1);
        assert_eq!(stderr.lock().await[0], "stderr: How it goes");

        Ok(())
    }

    /// Clones feed the same workers in send order, and closing from more
    /// than one handle is harmless.
    #[tokio::test]
    async fn test_clone_and_close_twice() -> Result<(), StdoutChannelError> {
        let stdout = MockStdout::<StackString>::new();
        let stderr = MockStdout::new();

        let chan = StdoutChannel::with_mock_stdout(stdout.clone(), stderr.clone());
        let other = chan.clone();

        chan.send("first");
        other.send("second");
        other.send_err("oops");
        chan.close().await?;
        other.close().await?;
        chan.close().await?;

        assert_eq!(stdout.lock().await.len(), 2);
        assert_eq!(stdout.lock().await[0], "first");
        assert_eq!(stdout.lock().await[1], "second");
        assert_eq!(stderr.lock().await.len(), 1);
        assert_eq!(stderr.lock().await[0], "oops");

        Ok(())
    }
}