1use carboncopy::{BoxFuture, Sink};
2use std::fmt;
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::io::{stdout, AsyncWriteExt, Stdout};
6use tokio::sync::mpsc::{unbounded_channel, UnboundedSender as DropTx};
7use tokio::sync::watch::{channel as watch_channel, Receiver as WatchRx};
8use tokio::sync::Mutex;
9use tokio::time::sleep;
10
11pub struct BufSink<T: AsyncWriteExt + Unpin + Send + 'static> {
17 rt: tokio::runtime::Handle,
18 interior: Arc<Mutex<Interior<T>>>,
19 drop_chan_tx: DropTx<EmptySignal>,
20 last_flush_err_chan_rx: WatchRx<Option<Arc<std::io::Error>>>,
21}
22
23impl<T: AsyncWriteExt + Unpin + Send + 'static> Sink for BufSink<T> {
24 fn sink_blocking(&self, entry: String) -> std::io::Result<()> {
25 self.rt.block_on(self.sink(entry))
26 }
27
28 fn sink(&self, entry: String) -> BoxFuture<std::io::Result<()>> {
29 Box::pin(async move {
30 let mut inner = self.interior.lock().await;
31 if let Some((buf, _)) = inner.buf.as_mut() {
32 let _ = buf.write(entry.as_bytes()).await; Ok(())
34 } else {
35 inner.output_writer.write(entry.as_bytes()).await?;
37 Ok(())
38 }
39 })
40 }
41}
42
43impl<T: AsyncWriteExt + Unpin + Send + 'static> Drop for BufSink<T> {
44 fn drop(&mut self) {
45 let _ = self.drop_chan_tx.send(EmptySignal);
46 }
47}
48
49impl<T: AsyncWriteExt + Unpin + Send + 'static> BufSink<T> {
50 pub fn new(opts: SinkOptions<T>) -> Self {
56 let interior = Arc::new(Mutex::new(Interior {
57 backlogged: false,
58 buf: if opts.buffer.is_none() {
59 None
60 } else {
61 let cap = opts.buffer.as_ref().unwrap();
62 Some((Vec::with_capacity(cap.0), cap.0))
63 },
64 output_writer: opts.output_writer,
65 }));
66
67 let (drop_tx, mut drop_rx) = unbounded_channel();
68 let (err_tx, err_rx) = watch_channel(None);
69
70 let rt = opts.tokio_runtime.clone();
71 let interior_clone = interior.clone();
72 let timeout_ms = opts.flush_timeout_ms;
73 rt.spawn(async move {
74 if interior_clone.lock().await.buf.is_some() {
75 loop {
76 let overflow = async {
77 loop {
78 {
79 let interior_check = interior_clone.lock().await;
80 if interior_check.buf.as_ref().unwrap().0.len()
81 >= interior_check.buf.as_ref().unwrap().1
82 {
83 return;
84 }
85 }
86 if timeout_ms > 1 {
87 sleep(Duration::from_millis(1)).await;
88 }
89 }
90 };
91 let timeout = async move {
92 sleep(Duration::from_millis(timeout_ms)).await;
93 };
94 tokio::select! {
95 _ = overflow => {
96 if let Err(io_err) = interior_clone.lock().await.flush().await {
97 let _ = err_tx.send(Some(Arc::new(io_err)));
100 } else {
101 let _ = err_tx.send(None);
102 };
103 }
104 _ = timeout => {
105 if let Err(io_err) = interior_clone.lock().await.flush().await {
106 let _ = err_tx.send(Some(Arc::new(io_err)));
107 } else {
108 let _ = err_tx.send(None);
109 };
110 }
111 _ = drop_rx.recv() => {
112 return; }
114 }
115 }
116 } else {
117 return; }
119 });
120
121 Self {
122 rt: rt,
123 interior: interior,
124 drop_chan_tx: drop_tx,
125 last_flush_err_chan_rx: err_rx,
126 }
127 }
128
129 pub async fn flush(&self) -> std::io::Result<usize> {
131 self.interior.lock().await.flush().await
132 }
133
134 pub async fn backlogged(&self) -> bool {
136 self.interior.lock().await.backlogged()
137 }
138
139 pub fn last_flush_err(&self) -> Option<Arc<std::io::Error>> {
145 self.last_flush_err_chan_rx.borrow().clone()
146 }
147}
148
149struct Interior<T: AsyncWriteExt + Unpin + Send + 'static> {
150 backlogged: bool,
151 buf: Option<(Vec<u8>, usize)>,
152 output_writer: T,
153}
154
155impl<T: AsyncWriteExt + Unpin + Send + 'static> Interior<T> {
156 async fn flush(&mut self) -> Result<usize, std::io::Error> {
157 if self.buf.is_none() {
158 Ok(0)
159 } else {
160 let vec_len = self.buf.as_ref().unwrap().0.len();
161 if vec_len > 0 {
162 let mut written: usize = 0;
163 while vec_len > 0 {
164 let res = self
165 .output_writer
166 .write(self.buf.as_ref().unwrap().0.as_slice())
167 .await;
168
169 if let Ok(delta) = res {
172 if delta == 0 {
173 return res;
174 }
175 if delta == vec_len {
176 self.buf.as_mut().unwrap().0 =
177 Vec::with_capacity(self.buf.as_ref().unwrap().1);
178 self.backlogged = false;
179 } else {
180 self.buf.as_mut().unwrap().0.drain(0..delta);
181 self.backlogged = true;
182 }
183 written += delta;
184 } else {
185 self.backlogged = true;
186 return res;
187 }
188 }
189 Ok(written)
190 } else {
191 Ok(0)
192 }
193 }
194 }
195
196 fn backlogged(&self) -> bool {
197 self.backlogged
198 }
199}
200
201pub struct SinkOptions<T: AsyncWriteExt + Unpin + Send + 'static> {
203 pub buffer: Option<BufferOverflowThreshold>,
204 pub flush_timeout_ms: u64,
205 pub tokio_runtime: tokio::runtime::Handle,
206 pub output_writer: T,
207}
208
209impl Default for SinkOptions<Stdout> {
210 fn default() -> Self {
214 Self {
215 buffer: Some(BufferOverflowThreshold::new(64 * 1024).unwrap()),
217 flush_timeout_ms: 100,
218 tokio_runtime: if let Ok(handle) = tokio::runtime::Handle::try_current() {
219 handle
220 } else {
221 panic!("SinkOptions::default() called outside of a tokio runtime")
222 },
223 output_writer: stdout(),
224 }
225 }
226}
227
/// Size, in bytes, at which a [`BufSink`] buffer triggers an overflow flush.
///
/// Values are validated by [`BufferOverflowThreshold::new`] to lie between 1024 bytes and
/// 1024 * 1024 * 1024 bytes inclusive.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct BufferOverflowThreshold(usize);
232
233impl BufferOverflowThreshold {
234 pub fn new(cap: usize) -> Result<Self, ThresholdError> {
236 const KB: usize = 1024;
237 const GB: usize = 1024 * 1024 * 1024;
238 if cap >= 1 * KB && cap <= 1 * GB {
239 Ok(Self(cap))
240 } else if cap < 1 * KB {
241 Err(ThresholdError::LessThan1KB)
242 } else {
243 Err(ThresholdError::MoreThan1GB)
244 }
245 }
246}
247
/// Reason a requested [`BufferOverflowThreshold`] was rejected.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ThresholdError {
    /// The requested size was below 1024 bytes.
    LessThan1KB,
    /// The requested size was above 1024 * 1024 * 1024 bytes.
    MoreThan1GB,
}
253
254impl fmt::Display for ThresholdError {
255 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
256 match self {
257 Self::LessThan1KB => {
258 write!(
259 f,
260 "buffer overflow threshold can't be less than 1024 bytes (1KB)"
261 )
262 }
263 Self::MoreThan1GB => {
264 write!(
265 f,
266 "buffer overflow threshold can't be greater than 1024 * 1024 * 1024 bytes (1GB)",
267 )
268 }
269 }
270 }
271}
272
/// Payload sent on a `BufSink`'s drop channel to tell its background flush task to stop.
struct EmptySignal;
274
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::runtime::Runtime;

    /// Sinks `entry` through `sink`, driving the future on `rt`, and asserts it succeeded.
    fn sink_ok<T: AsyncWriteExt + Unpin + Send + 'static>(
        rt: &Runtime,
        sink: &BufSink<T>,
        entry: &str,
    ) {
        assert!(rt.block_on(sink.sink(entry.to_string())).is_ok());
    }

    /// Returns everything the sink has written to its in-memory writer so far.
    fn written_output(rt: &Runtime, sink: &BufSink<Vec<u8>>) -> String {
        let bytes = rt.block_on(async { sink.interior.lock().await.output_writer.clone() });
        String::from_utf8(bytes).expect("test output is valid UTF-8")
    }

    /// Blocks for `ms` milliseconds so the background flush task gets a chance to run.
    fn wait_ms(rt: &Runtime, ms: u64) {
        rt.block_on(sleep(Duration::from_millis(ms)));
    }

    /// Builds a sink that writes into an in-memory `Vec<u8>`.
    fn mem_sink(
        rt: &Runtime,
        buffer: Option<BufferOverflowThreshold>,
        flush_timeout_ms: u64,
    ) -> BufSink<Vec<u8>> {
        BufSink::new(SinkOptions {
            buffer,
            flush_timeout_ms,
            tokio_runtime: rt.handle().clone(),
            output_writer: Vec::new(),
        })
    }

    #[test]
    fn overflow_threshold() {
        assert_eq!(
            BufferOverflowThreshold::new(1000).err().unwrap(),
            ThresholdError::LessThan1KB
        );
        assert_eq!(
            BufferOverflowThreshold::new(1024 * 1024 * 1024 + 1)
                .err()
                .unwrap(),
            ThresholdError::MoreThan1GB
        );
        // Both bounds are inclusive.
        assert_eq!(
            BufferOverflowThreshold::new(1023),
            Err(ThresholdError::LessThan1KB)
        );
        assert!(BufferOverflowThreshold::new(1024).is_ok());
        assert!(BufferOverflowThreshold::new(1024 * 1024 * 1024).is_ok());
    }

    #[test]
    fn default_options_dont_panic() {
        let rt = Runtime::new().unwrap();
        rt.block_on(async {
            assert_eq!(100, SinkOptions::<Stdout>::default().flush_timeout_ms);
        });
    }

    #[test]
    #[should_panic(expected = "outside of a tokio runtime")]
    fn default_options_panic_outside_runtime() {
        let _ = SinkOptions::<Stdout>::default();
    }

    #[test]
    fn no_buffer() {
        let rt = Runtime::new().unwrap();
        let sink = mem_sink(&rt, None, 30);
        for i in 0..5 {
            sink_ok(&rt, &sink, &format!("hello world {}\n", i));
        }

        assert_eq!(
            "hello world 0\nhello world 1\nhello world 2\nhello world 3\nhello world 4\n",
            written_output(&rt, &sink)
        );
    }

    #[test]
    fn timeout_flush() {
        let rt = Runtime::new().unwrap();
        let timeout_ms = 100;
        let sink = mem_sink(
            &rt,
            Some(BufferOverflowThreshold::new(64 * 1024).unwrap()),
            timeout_ms,
        );
        for i in 0..5 {
            sink_ok(&rt, &sink, &format!("hello world {}\n", i));
        }

        // Far below the threshold and before the timeout: nothing has been written yet.
        assert_eq!("", written_output(&rt, &sink));

        // Generous margin past the timeout so a slow scheduler doesn't make this flaky.
        wait_ms(&rt, timeout_ms * 2);
        assert_eq!(
            "hello world 0\nhello world 1\nhello world 2\nhello world 3\nhello world 4\n",
            written_output(&rt, &sink)
        );
    }

    #[test]
    fn overflow_flush() {
        let rt = Runtime::new().unwrap();
        let threshold = 1024;
        // A timeout far beyond the test's duration means only an overflow can trigger a flush.
        let sink = mem_sink(
            &rt,
            Some(BufferOverflowThreshold::new(threshold).unwrap()),
            60_000,
        );

        for _ in 0..threshold - 1 {
            sink_ok(&rt, &sink, "X");
        }
        // One byte short of the threshold: the background task must not flush yet.
        wait_ms(&rt, 10);
        assert_eq!("", written_output(&rt, &sink));

        // Reaching the threshold triggers an overflow flush within a few milliseconds.
        sink_ok(&rt, &sink, "X");
        wait_ms(&rt, 50);
        assert_eq!("X".repeat(threshold), written_output(&rt, &sink));
    }

    #[test]
    fn flush_err() {
        use core::task::{Context, Poll};
        use std::io::{Error, ErrorKind};
        use std::pin::Pin;
        use tokio::io::AsyncWrite;

        /// Writer whose every operation fails.
        struct ProblematicWriter;
        impl AsyncWrite for ProblematicWriter {
            fn poll_write(
                self: Pin<&mut Self>,
                _: &mut Context<'_>,
                _: &[u8],
            ) -> Poll<Result<usize, Error>> {
                Poll::Ready(Err(Error::new(ErrorKind::Other, "kaboom!")))
            }
            fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Error>> {
                Poll::Ready(Err(Error::new(ErrorKind::Other, "kaboom!")))
            }
            fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Error>> {
                Poll::Ready(Err(Error::new(ErrorKind::Other, "kaboom!")))
            }
        }

        let rt = Runtime::new().unwrap();
        let timeout_ms = 20;
        let sink = BufSink::new(SinkOptions {
            buffer: Some(BufferOverflowThreshold::new(1024).unwrap()),
            flush_timeout_ms: timeout_ms,
            tokio_runtime: rt.handle().clone(),
            output_writer: ProblematicWriter,
        });

        sink_ok(&rt, &sink, "hello world\n");
        // No flush has been attempted yet.
        assert!(sink.last_flush_err().is_none());

        wait_ms(&rt, timeout_ms + 40);

        let err = sink
            .last_flush_err()
            .expect("the timed flush should have failed");
        assert_eq!(ErrorKind::Other, err.kind());
        assert_eq!("kaboom!", err.to_string());
        // The entry could not be written, so it is still waiting in the buffer.
        assert!(rt.block_on(sink.backlogged()));
    }
}