1pub mod rate_limiter;
2
3pub use rate_limiter::RateLimiter;
4
5use deadqueue::unlimited::Queue;
6use std::{
7 fmt,
8 fmt::Display,
9 io::{Error as IoError, Write},
10 ops::Deref,
11 sync::Arc,
12};
13use thiserror::Error;
14use tokio::{
15 io::{stderr, stdout, AsyncWriteExt},
16 sync::Mutex,
17 task::{spawn, JoinError, JoinHandle},
18};
19
20#[derive(Error, Debug)]
21pub enum StdoutChannelError {
22 #[error("task join error")]
23 JoinError(#[from] JoinError),
24 #[error("io error")]
25 IoError(#[from] IoError),
26}
27
/// Item travelling through a writer queue.
enum StdoutMessage<T> {
    /// A payload for the writer task to emit.
    Mesg(T),
    /// Sentinel that makes the writer task stop draining its queue and exit.
    Close,
}
32
33type StdoutQueue<T> = Queue<StdoutMessage<T>>;
34type StdoutTask = JoinHandle<Result<(), StdoutChannelError>>;
35
36#[derive(Clone)]
37pub struct StdoutChannel<T> {
38 stdout_queue: Arc<StdoutQueue<T>>,
39 stderr_queue: Arc<StdoutQueue<T>>,
40 stdout_task: Arc<Mutex<Option<StdoutTask>>>,
41 stderr_task: Arc<Mutex<Option<StdoutTask>>>,
42}
43
44impl<T> Default for StdoutChannel<T>
45where
46 T: Display + Send + 'static,
47{
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl<T> fmt::Debug for StdoutChannel<T> {
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 write!(f, "StdoutChannel")
56 }
57}
58
59impl<T> StdoutChannel<T>
60where
61 T: Display + Send + 'static,
62{
63 #[must_use]
64 pub fn new() -> Self {
65 let stdout_queue = Queue::new().into();
66 let stderr_queue = Queue::new().into();
67 let stdout_task = Mutex::new(Some(spawn({
68 let queue = Arc::clone(&stdout_queue);
69 async move { Self::process_stdout(&queue).await }
70 })))
71 .into();
72 let stderr_task = Mutex::new(Some(spawn({
73 let queue = Arc::clone(&stderr_queue);
74 async move { Self::process_stderr(&queue).await }
75 })))
76 .into();
77 Self {
78 stdout_queue,
79 stderr_queue,
80 stdout_task,
81 stderr_task,
82 }
83 }
84
85 #[must_use]
86 pub fn with_mock_stdout(mock_stdout: MockStdout<T>, mock_stderr: MockStdout<T>) -> Self {
87 let stdout_queue = Queue::new().into();
88 let stderr_queue = Queue::new().into();
89 let stdout_task = Mutex::new(Some(spawn({
90 let queue = Arc::clone(&stdout_queue);
91 async move { Self::process_mock(&queue, &mock_stdout).await }
92 })))
93 .into();
94 let stderr_task = Mutex::new(Some(spawn({
95 let queue = Arc::clone(&stderr_queue);
96 async move { Self::process_mock(&queue, &mock_stderr).await }
97 })))
98 .into();
99 Self {
100 stdout_queue,
101 stderr_queue,
102 stdout_task,
103 stderr_task,
104 }
105 }
106
107 pub fn send(&self, item: impl Into<T>) {
108 self.stdout_queue.push(StdoutMessage::Mesg(item.into()));
109 }
110
111 pub fn send_err(&self, item: impl Into<T>) {
112 self.stderr_queue.push(StdoutMessage::Mesg(item.into()));
113 }
114
115 pub async fn close(&self) -> Result<(), StdoutChannelError> {
121 self.stdout_queue.push(StdoutMessage::Close);
122 self.stderr_queue.push(StdoutMessage::Close);
123 if let Some(stdout_task) = self.stdout_task.lock().await.take() {
124 stdout_task.await??;
125 }
126 if let Some(stderr_task) = self.stderr_task.lock().await.take() {
127 stderr_task.await??;
128 }
129 Ok(())
130 }
131
132 async fn process_stdout(queue: &StdoutQueue<T>) -> Result<(), StdoutChannelError> {
133 let mut buf = Buffer::new();
134 while let StdoutMessage::Mesg(line) = queue.pop().await {
135 stdout().write_all(buf.write_line(line)?).await?;
136 }
137 Ok(())
138 }
139
140 async fn process_stderr(queue: &StdoutQueue<T>) -> Result<(), StdoutChannelError> {
141 let mut buf = Buffer::new();
142 while let StdoutMessage::Mesg(line) = queue.pop().await {
143 stderr().write_all(buf.write_line(line)?).await?;
144 }
145 Ok(())
146 }
147
148 async fn process_mock(
149 queue: &StdoutQueue<T>,
150 mock_stdout: &MockStdout<T>,
151 ) -> Result<(), StdoutChannelError> {
152 while let StdoutMessage::Mesg(line) = queue.pop().await {
153 mock_stdout.lock().await.push(line);
154 }
155 Ok(())
156 }
157}
158
/// Upper bound (in bytes) on the capacity a `Buffer` keeps between lines; a
/// buffer grown past this by one long message is shrunk back before reuse.
const MAX_BUFFER_CAPACITY: usize = 4 * 1024;
160
161struct Buffer(Vec<u8>);
162
163impl Buffer {
164 pub fn new() -> Self {
165 Self(Vec::new())
166 }
167
168 pub fn write_line<T: Display>(&mut self, line: T) -> Result<&[u8], StdoutChannelError> {
169 self.0.clear();
170 if self.0.capacity() > MAX_BUFFER_CAPACITY {
171 self.0.shrink_to(MAX_BUFFER_CAPACITY);
172 }
173 writeln!(self.0, "{line}")?;
174 Ok(&self.0)
175 }
176}
177
178#[derive(Clone)]
179pub struct MockStdout<T>(Arc<Mutex<Vec<T>>>);
180
181impl<T> Default for MockStdout<T> {
182 fn default() -> Self {
183 Self::new()
184 }
185}
186
187impl<T> Deref for MockStdout<T> {
188 type Target = Mutex<Vec<T>>;
189 fn deref(&self) -> &Self::Target {
190 &self.0
191 }
192}
193
194impl<T> MockStdout<T> {
195 #[must_use]
196 pub fn new() -> Self {
197 Self(Mutex::new(Vec::new()).into())
198 }
199}
200
#[cfg(test)]
mod tests {
    use stack_string::StackString;

    use super::{MockStdout, StdoutChannel, StdoutChannelError};

    #[tokio::test]
    async fn test_default_mockstdout() -> Result<(), StdoutChannelError> {
        let mock = MockStdout::default();
        mock.lock().await.push(StackString::from("HEY"));
        assert_eq!(mock.lock().await.len(), 1);
        assert_eq!(mock.lock().await[0].as_str(), "HEY");
        Ok(())
    }

    #[tokio::test]
    async fn test_default() -> Result<(), StdoutChannelError> {
        let chan = StdoutChannel::<StackString>::default();

        chan.send("stdout: Hey There");
        chan.send("What's happening");
        chan.send_err("stderr: How it goes");

        chan.close().await?;
        Ok(())
    }

    /// `close` joins both writer tasks, leaving their slots empty; a second
    /// `close` (here through a clone, which shares those slots) still succeeds.
    #[tokio::test]
    async fn test_stdout_task() -> Result<(), StdoutChannelError> {
        let chan = StdoutChannel::<StackString>::default();
        let other = chan.clone();

        chan.send("stdout: Hey There");
        chan.send_err("stderr: How it goes");
        chan.close().await?;

        assert!(chan.stdout_task.lock().await.is_none());
        assert!(chan.stderr_task.lock().await.is_none());

        other.close().await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_mock_stdout() -> Result<(), StdoutChannelError> {
        let stdout = MockStdout::<StackString>::new();
        let stderr = MockStdout::new();

        let chan = StdoutChannel::with_mock_stdout(stdout.clone(), stderr.clone());

        chan.send("stdout: Hey There");
        chan.send("What's happening");
        chan.send_err("stderr: How it goes");
        chan.close().await?;

        assert_eq!(stdout.lock().await.len(), 2);
        assert_eq!(stdout.lock().await[0], "stdout: Hey There");
        assert_eq!(stdout.lock().await[1], "What's happening");
        assert_eq!(stderr.lock().await.len(), 1);
        assert_eq!(stderr.lock().await[0], "stderr: How it goes");

        Ok(())
    }
}