carboncopy_tokio/
lib.rs

1use carboncopy::{BoxFuture, Sink};
2use std::fmt;
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::io::{stdout, AsyncWriteExt, Stdout};
6use tokio::sync::mpsc::{unbounded_channel, UnboundedSender as DropTx};
7use tokio::sync::watch::{channel as watch_channel, Receiver as WatchRx};
8use tokio::sync::Mutex;
9use tokio::time::sleep;
10
11/// A sink with memory buffer and periodic flusher built with Tokio facilities so it is
12/// suitable for binaries relying on the Tokio executor. It can also be used by async-blind
13/// library clients since the Sink trait offers blocking API.
14///
15/// Its mutable interior is guarded by `Arc<Mutex<...>>`.
16pub struct BufSink<T: AsyncWriteExt + Unpin + Send + 'static> {
17    rt: tokio::runtime::Handle,
18    interior: Arc<Mutex<Interior<T>>>,
19    drop_chan_tx: DropTx<EmptySignal>,
20    last_flush_err_chan_rx: WatchRx<Option<Arc<std::io::Error>>>,
21}
22
23impl<T: AsyncWriteExt + Unpin + Send + 'static> Sink for BufSink<T> {
24    fn sink_blocking(&self, entry: String) -> std::io::Result<()> {
25        self.rt.block_on(self.sink(entry))
26    }
27
28    fn sink(&self, entry: String) -> BoxFuture<std::io::Result<()>> {
29        Box::pin(async move {
30            let mut inner = self.interior.lock().await;
31            if let Some((buf, _)) = inner.buf.as_mut() {
32                let _ = buf.write(entry.as_bytes()).await; // infallible, writing to memory
33                Ok(())
34            } else {
35                // no buffer, write directly to stdout
36                inner.output_writer.write(entry.as_bytes()).await?;
37                Ok(())
38            }
39        })
40    }
41}
42
43impl<T: AsyncWriteExt + Unpin + Send + 'static> Drop for BufSink<T> {
44    fn drop(&mut self) {
45        let _ = self.drop_chan_tx.send(EmptySignal);
46    }
47}
48
49impl<T: AsyncWriteExt + Unpin + Send + 'static> BufSink<T> {
50    /// At the same time as the instantiation, a flusher task is spawned in the background
51    /// whose job is to flush after the buffer overflows or a set timeout has elapsed
52    /// (whichever happens first).
53    ///
54    /// The flusher task will terminate when the instance is dropped.
55    pub fn new(opts: SinkOptions<T>) -> Self {
56        let interior = Arc::new(Mutex::new(Interior {
57            backlogged: false,
58            buf: if opts.buffer.is_none() {
59                None
60            } else {
61                let cap = opts.buffer.as_ref().unwrap();
62                Some((Vec::with_capacity(cap.0), cap.0))
63            },
64            output_writer: opts.output_writer,
65        }));
66
67        let (drop_tx, mut drop_rx) = unbounded_channel();
68        let (err_tx, err_rx) = watch_channel(None);
69
70        let rt = opts.tokio_runtime.clone();
71        let interior_clone = interior.clone();
72        let timeout_ms = opts.flush_timeout_ms;
73        rt.spawn(async move {
74            if interior_clone.lock().await.buf.is_some() {
75                loop {
76                    let overflow = async {
77                        loop {
78                            {
79                                let interior_check = interior_clone.lock().await;
80                                if interior_check.buf.as_ref().unwrap().0.len()
81                                    >= interior_check.buf.as_ref().unwrap().1
82                                {
83                                    return;
84                                }
85                            }
86                            if timeout_ms > 1 {
87                                sleep(Duration::from_millis(1)).await;
88                            }
89                        }
90                    };
91                    let timeout = async move {
92                        sleep(Duration::from_millis(timeout_ms)).await;
93                    };
94                    tokio::select! {
95                        _ = overflow => {
96                            if let Err(io_err) = interior_clone.lock().await.flush().await {
97                                // error will be returned by send() if the sink has been dropped,
98                                // at which point, the error no longer matters
99                                let _ = err_tx.send(Some(Arc::new(io_err)));
100                            } else {
101                                let _ = err_tx.send(None);
102                            };
103                        }
104                        _ = timeout => {
105                            if let Err(io_err) = interior_clone.lock().await.flush().await {
106                                let _ = err_tx.send(Some(Arc::new(io_err)));
107                            } else {
108                                let _ = err_tx.send(None);
109                            };
110                        }
111                        _ = drop_rx.recv() => {
112                            return; // sink instance dropped, terminate loop/task
113                        }
114                    }
115                }
116            } else {
117                return; // no need for buffer checks if there is no buffer
118            }
119        });
120
121        Self {
122            rt: rt,
123            interior: interior,
124            drop_chan_tx: drop_tx,
125            last_flush_err_chan_rx: err_rx,
126        }
127    }
128
129    /// Attempts to manually flush the underlying buffer to Stdout.
130    pub async fn flush(&self) -> std::io::Result<usize> {
131        self.interior.lock().await.flush().await
132    }
133
134    /// Checks if buffer flushing is being backlogged (not necessarily by errors).
135    pub async fn backlogged(&self) -> bool {
136        self.interior.lock().await.backlogged()
137    }
138
139    /// Checks if the flusher has just encountered an error. Only use this function to check
140    /// for long running errors. A temporary error could already be cleared by retries by the
141    /// time you call this function.
142    ///
143    /// A bufferless sink will always return None.
144    pub fn last_flush_err(&self) -> Option<Arc<std::io::Error>> {
145        self.last_flush_err_chan_rx.borrow().clone()
146    }
147}
148
149struct Interior<T: AsyncWriteExt + Unpin + Send + 'static> {
150    backlogged: bool,
151    buf: Option<(Vec<u8>, usize)>,
152    output_writer: T,
153}
154
155impl<T: AsyncWriteExt + Unpin + Send + 'static> Interior<T> {
156    async fn flush(&mut self) -> Result<usize, std::io::Error> {
157        if self.buf.is_none() {
158            Ok(0)
159        } else {
160            let vec_len = self.buf.as_ref().unwrap().0.len();
161            if vec_len > 0 {
162                let mut written: usize = 0;
163                while vec_len > 0 {
164                    let res = self
165                        .output_writer
166                        .write(self.buf.as_ref().unwrap().0.as_slice())
167                        .await;
168
169                    // clear first N elements of vec according to res,
170                    // or empty vec if N == vec_len
171                    if let Ok(delta) = res {
172                        if delta == 0 {
173                            return res;
174                        }
175                        if delta == vec_len {
176                            self.buf.as_mut().unwrap().0 =
177                                Vec::with_capacity(self.buf.as_ref().unwrap().1);
178                            self.backlogged = false;
179                        } else {
180                            self.buf.as_mut().unwrap().0.drain(0..delta);
181                            self.backlogged = true;
182                        }
183                        written += delta;
184                    } else {
185                        self.backlogged = true;
186                        return res;
187                    }
188                }
189                Ok(written)
190            } else {
191                Ok(0)
192            }
193        }
194    }
195
196    fn backlogged(&self) -> bool {
197        self.backlogged
198    }
199}
200
201/// Implements the Default trait.
202pub struct SinkOptions<T: AsyncWriteExt + Unpin + Send + 'static> {
203    pub buffer: Option<BufferOverflowThreshold>,
204    pub flush_timeout_ms: u64,
205    pub tokio_runtime: tokio::runtime::Handle,
206    pub output_writer: T,
207}
208
209impl Default for SinkOptions<Stdout> {
210    /// # Panic
211    ///
212    /// Panics if called outside of a tokio runtime.
213    fn default() -> Self {
214        Self {
215            // unwrap safety: any panic will cause default_options_dont_panic() test to fail
216            buffer: Some(BufferOverflowThreshold::new(64 * 1024).unwrap()),
217            flush_timeout_ms: 100,
218            tokio_runtime: if let Ok(handle) = tokio::runtime::Handle::try_current() {
219                handle
220            } else {
221                panic!("SinkOptions::default() called outside of a tokio runtime")
222            },
223            output_writer: stdout(),
224        }
225    }
226}
227
/// A size threshold after which the buffer will be flushed. The size of the buffer itself is
/// unlimited.
#[derive(Debug, PartialEq, Eq, Copy, Clone, Ord, PartialOrd)]
pub struct BufferOverflowThreshold(usize);

impl BufferOverflowThreshold {
    /// Creates a threshold of `cap` bytes.
    ///
    /// # Errors
    ///
    /// `cap` must lie between 1KB (1024) and 1GB (1024 * 1024 * 1024), both inclusive.
    /// Smaller values yield [`ThresholdError::LessThan1KB`]; larger values yield
    /// [`ThresholdError::MoreThan1GB`].
    pub fn new(cap: usize) -> Result<Self, ThresholdError> {
        const KB: usize = 1024;
        const GB: usize = 1024 * 1024 * 1024;
        if cap < KB {
            Err(ThresholdError::LessThan1KB)
        } else if cap > GB {
            Err(ThresholdError::MoreThan1GB)
        } else {
            Ok(Self(cap))
        }
    }
}

/// Reason a value was rejected by [`BufferOverflowThreshold::new`].
#[derive(Debug, PartialEq, Eq, Copy, Clone, Ord, PartialOrd)]
pub enum ThresholdError {
    /// The requested threshold is below 1024 bytes.
    LessThan1KB,
    /// The requested threshold is above 1024 * 1024 * 1024 bytes.
    MoreThan1GB,
}

impl fmt::Display for ThresholdError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let msg = match self {
            Self::LessThan1KB => "buffer overflow threshold can't be less than 1024 bytes (1KB)",
            Self::MoreThan1GB => {
                "buffer overflow threshold can't be greater than 1024 * 1024 * 1024 bytes (1GB)"
            }
        };
        f.write_str(msg)
    }
}

/// Lets `ThresholdError` be boxed into `Box<dyn Error>` and propagated with `?`.
impl std::error::Error for ThresholdError {}
272
/// Payload-free message telling the flusher task that its `BufSink` has been dropped.
#[derive(Debug, Clone, Copy)]
struct EmptySignal;
274
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::runtime::Runtime;

    /// Expected output after `sink_hello_worlds` has been fully written out.
    const HELLO_WORLDS: &str =
        "hello world 0\nhello world 1\nhello world 2\nhello world 3\nhello world 4\n";

    /// Options for a sink writing into an in-memory `Vec<u8>`, spawned on `rt`.
    fn mem_opts(
        rt: &Runtime,
        buffer: Option<BufferOverflowThreshold>,
        flush_timeout_ms: u64,
    ) -> SinkOptions<Vec<u8>> {
        SinkOptions {
            buffer,
            flush_timeout_ms,
            tokio_runtime: rt.handle().clone(),
            output_writer: Vec::new(),
        }
    }

    /// Snapshot (as UTF-8) of everything written to the sink's in-memory writer so far.
    fn written(rt: &Runtime, mem_sink: &BufSink<Vec<u8>>) -> String {
        rt.block_on(async {
            let bytes = mem_sink.interior.lock().await.output_writer.clone();
            String::from_utf8(bytes).unwrap()
        })
    }

    /// Sinks "hello world {i}\n" for i in 0..5, asserting that every call succeeds.
    fn sink_hello_worlds(rt: &Runtime, mem_sink: &BufSink<Vec<u8>>) {
        for i in 0..5 {
            assert!(rt
                .block_on(mem_sink.sink(format!("hello world {}\n", i)))
                .is_ok());
        }
    }

    #[test]
    fn overflow_threshold() {
        assert_eq!(
            BufferOverflowThreshold::new(1000),
            Err(ThresholdError::LessThan1KB)
        );
        assert_eq!(
            BufferOverflowThreshold::new(1024 * 1024 * 1024 + 1),
            Err(ThresholdError::MoreThan1GB)
        );
    }

    #[test]
    fn default_options_dont_panic() {
        let rt = Runtime::new().unwrap();
        rt.block_on(async {
            // not panicking is the real check
            assert_eq!(100, SinkOptions::default().flush_timeout_ms);
        });
    }

    #[test]
    fn no_buffer() {
        let rt = Runtime::new().unwrap();
        let mem_sink = BufSink::new(mem_opts(&rt, None, 30));

        sink_hello_worlds(&rt, &mem_sink);

        // Without a buffer every entry is written straight through.
        assert_eq!(HELLO_WORLDS, written(&rt, &mem_sink));
    }

    #[test]
    fn timeout_flush() {
        let rt = Runtime::new().unwrap();
        let threshold = BufferOverflowThreshold::new(64 * 1024).unwrap();
        let mem_sink = BufSink::new(mem_opts(&rt, Some(threshold), 30));

        sink_hello_worlds(&rt, &mem_sink);

        // Far below the 64KB threshold, so nothing is written before the 30ms timeout.
        assert_ne!(HELLO_WORLDS, written(&rt, &mem_sink));

        // simulate timeout with sleep (created inside the runtime, which `sleep` requires)
        rt.block_on(async {
            sleep(Duration::from_millis(40)).await;
        });

        assert_eq!(HELLO_WORLDS, written(&rt, &mem_sink));
    }

    #[test]
    fn overflow_flush() {
        let rt = Runtime::new().unwrap();
        let threshold = BufferOverflowThreshold::new(1024).unwrap();
        let mem_sink = BufSink::new(mem_opts(&rt, Some(threshold), 30));

        // Fill the buffer exactly up to the 1KB threshold.
        for _ in 0..1024 {
            assert!(rt.block_on(mem_sink.sink(String::from("X"))).is_ok());
        }

        let mut ref_output = "X".repeat(1024);
        assert_ne!(ref_output, written(&rt, &mem_sink));

        // trigger overflow
        assert!(rt.block_on(mem_sink.sink(String::from("X"))).is_ok());
        // 1 ms sleep between overflow checks, plus margin
        rt.block_on(async {
            sleep(Duration::from_millis(1 + 9)).await;
        });
        // ref_output got additional 'X' from overflow trigger
        ref_output.push('X');

        assert_eq!(ref_output, written(&rt, &mem_sink));
    }

    #[test]
    fn flush_err() {
        // setup
        use core::task::{Context, Poll};
        use std::io::{Error, ErrorKind};
        use std::pin::Pin;
        use tokio::io::AsyncWrite;

        /// The error every `ProblematicWriter` operation fails with.
        fn kaboom() -> Error {
            Error::new(ErrorKind::Other, "kaboom!")
        }

        /// Writer that rejects every operation.
        struct ProblematicWriter;
        impl AsyncWrite for ProblematicWriter {
            fn poll_write(
                self: Pin<&mut Self>,
                _: &mut Context<'_>,
                _: &[u8],
            ) -> Poll<Result<usize, Error>> {
                Poll::Ready(Err(kaboom()))
            }
            fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Error>> {
                Poll::Ready(Err(kaboom()))
            }
            fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Error>> {
                Poll::Ready(Err(kaboom()))
            }
        }

        let rt = Runtime::new().unwrap();
        let opts = SinkOptions {
            buffer: Some(BufferOverflowThreshold::new(1024).unwrap()),
            flush_timeout_ms: 20,
            tokio_runtime: rt.handle().clone(),
            output_writer: ProblematicWriter,
        };
        let mem_sink = BufSink::new(opts);
        // end setup

        // Buffered mode accepts the entry even though the writer is broken.
        assert!(rt
            .block_on(mem_sink.sink(String::from("hello world\n")))
            .is_ok());
        // No flush has been attempted yet.
        assert!(mem_sink.last_flush_err().is_none());

        // wait for flush timeout
        rt.block_on(async {
            sleep(Duration::from_millis(20 + 5)).await;
        });

        let err = mem_sink
            .last_flush_err()
            .expect("the failed timeout flush should have been published");
        assert_eq!(ErrorKind::Other, err.kind());
        assert_eq!("kaboom!", err.to_string());
    }
}