// flussab/deferred_writer.rs

1use std::io::{self, Write};
2
/// A buffered writer with deferred error checking.
///
/// This can be used like [`std::io::BufWriter`], but like [`DeferredReader`][crate::DeferredReader]
/// this performs deferred error checking. This means that any call to [`write`][Write::write] will
/// always succeed. IO errors that occur during writing will be reported during the next call to
/// [`flush`][Write::flush] or [`check_io_error`][Self::check_io_error]. Any data written after an
/// IO error occurred, before it is eventually reported, will be discarded.
///
/// Deferring error checks like this can result in a significant speed up for some usage patterns.
pub struct DeferredWriter<'a> {
    // Destination for the buffered data; boxed so any `Write + 'a` can be used.
    write: Box<dyn Write + 'a>,
    // Internal buffer. Data accumulates here up to the buffer's capacity and is
    // then flushed to `write` in one call.
    buf: Vec<u8>,
    // First IO error encountered and not yet reported via `check_io_error`.
    // While this is `Some`, further writes are silently discarded.
    io_error: Option<io::Error>,
    // True only while calling into the inner writer; checked by the `Drop` impl
    // so a panicking inner writer is not re-entered during unwinding.
    panicked: bool,
}
18
impl<'a> DeferredWriter<'a> {
    // Initial capacity of the internal buffer: 16 KiB.
    const DEFAULT_CHUNK_SIZE: usize = 16 << 10;

    /// Creates a [`DeferredWriter`] writing data to a [`Write`] instance.
    pub fn from_write(write: impl Write + 'a) -> Self {
        Self::from_boxed_dyn_write(Box::new(write))
    }

    /// Creates a [`DeferredWriter`] writing data to a boxed [`Write`] instance.
    ///
    /// The writer starts with an empty buffer of [`Self::DEFAULT_CHUNK_SIZE`] capacity
    /// and no pending IO error.
    #[inline(never)]
    pub fn from_boxed_dyn_write(write: Box<dyn Write + 'a>) -> Self {
        DeferredWriter {
            write,
            buf: Vec::with_capacity(Self::DEFAULT_CHUNK_SIZE),
            io_error: None,
            panicked: false,
        }
    }

    /// Flush the buffered data to the underlying [`Write`] instance, deferring IO errors.
    ///
    /// The internal buffer is always empty afterwards, whether or not the write succeeded.
    pub fn flush_defer_err(&mut self) {
        // Silently discard data if we errored before but haven't reported it yet
        if self.io_error.is_none() {
            // Set the guard around the call into the inner writer so that, if it
            // panics, the `Drop` impl does not call back into it while unwinding.
            self.panicked = true;
            if let Err(err) = self.write.write_all(&self.buf) {
                self.io_error = Some(err);
            }
            self.panicked = false;
        }
        self.buf.clear();
    }

    /// Write a slice of bytes, deferring IO errors.
    ///
    /// Both, [`write`][Write::write] and [`write_all`][Write::write_all] directly call this method.
    /// Unlike them, this does not return a `#[must_use]` value, making it clear that this cannot
    /// return an error.
    #[inline]
    pub fn write_all_defer_err(&mut self, buf: &[u8]) {
        let old_len = self.buf.len();
        // SAFETY add cannot overflow as both are at most `isize::MAX`, so the
        // sum fits in a `usize`.
        let new_len = old_len + buf.len();
        if new_len <= self.buf.capacity() {
            // Fast path: the data fits into the remaining capacity, append it
            // with a raw copy (avoids the capacity-growth checks of
            // `extend_from_slice` on the hot path).
            unsafe {
                // SAFETY this writes to `old_len..new_len` which we just checked to be within the
                // capacity of `self.buf`; `set_len` is only called after those
                // bytes were initialized by the copy.
                self.buf
                    .as_mut_ptr()
                    .add(old_len)
                    .copy_from_nonoverlapping(buf.as_ptr(), buf.len());
                self.buf.set_len(new_len)
            }
        } else {
            self.write_all_defer_err_cold(buf);
        }
    }

    // Slow path of `write_all_defer_err`: the data does not fit into the
    // remaining buffer capacity. Kept out of line so the fast path stays small.
    #[inline(never)]
    #[cold]
    fn write_all_defer_err_cold(&mut self, mut buf: &[u8]) {
        // If the passed `buf` is small enough that we don't need an individual write for it alone,
        // fill our internal `buf` up to capacity.
        if buf.len() < self.buf.capacity() {
            // This assumes that we bailed out of the fast path in `write_all_defer_err`, otherwise
            // the index passed to split_at could be out of bounds and this would panic.
            let (buf_first, buf_second) = buf.split_at(self.buf.capacity() - self.buf.len());
            self.buf.extend_from_slice(buf_first);
            buf = buf_second;
        }

        // This leaves us with an empty buffer, even if an IO error occurred.
        self.flush_defer_err();

        if buf.len() < self.buf.capacity() {
            // The remainder fits into the (now empty) buffer.
            self.buf.extend_from_slice(buf);
        } else {
            // If the passed `buf` does not fit into our internal `buf`, we directly write it, again
            // deferring any IO errors.

            // Silently discard data if we errored before but haven't reported it yet
            if self.io_error.is_none() {
                // Same panic guard as in `flush_defer_err`.
                self.panicked = true;
                if let Err(err) = self.write.write_all(buf) {
                    self.io_error = Some(err);
                }
                self.panicked = false;
            }
        }
    }

    /// Returns a pointer to the current write pointer within the internal buffer if sufficient
    /// space is available.
    ///
    /// If fewer than `len` bytes are available, this returns a null pointer.
    ///
    /// Can be used in conjunction with [`advance_unchecked`][Self::advance_unchecked] to construct
    /// output data directly within the output buffer, potentially avoiding a redundant copy.
    #[inline]
    pub fn buf_write_ptr(&mut self, len: usize) -> *mut u8 {
        let old_len = self.buf.len();
        // SAFETY add cannot overflow as both are at most `isize::MAX`.
        let new_len = old_len + len;
        if new_len <= self.buf.capacity() {
            // SAFETY this returns the offset to `old_len` which is always in range
            unsafe { self.buf.as_mut_ptr().add(old_len) }
        } else {
            std::ptr::null_mut()
        }
    }

    /// Advances the write pointer within the internal buffer.
    ///
    /// # Safety
    ///
    /// This assumes that a) there is sufficient space left in the buffer and b) that the bytes the
    /// pointer is advanced over were initialized prior to calling this (via
    /// [`buf_write_ptr`](Self::buf_write_ptr)).
    ///
    /// If either assumption does not hold calling this results in undefined behavior.
    #[inline]
    pub unsafe fn advance_unchecked(&mut self, len: usize) {
        let old_len = self.buf.len();
        // SAFETY add cannot overflow as both are at most `isize::MAX`.
        let new_len = old_len + len;
        // Capacity is only checked in debug builds; the caller guarantees it.
        debug_assert!(new_len <= self.buf.capacity());
        self.buf.set_len(new_len)
    }

    /// Returns an encountered IO error as `Err(io_err)`.
    ///
    /// This resets the stored IO error and returns `Ok(())` if no IO error is stored.
    #[inline]
    pub fn check_io_error(&mut self) -> io::Result<()> {
        if let Some(err) = self.io_error.take() {
            Err(err)
        } else {
            Ok(())
        }
    }
}
159
160impl<'a> Write for DeferredWriter<'a> {
161    #[inline]
162    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
163        self.write_all_defer_err(buf);
164        Ok(buf.len())
165    }
166
167    #[inline]
168    fn flush(&mut self) -> io::Result<()> {
169        self.flush_defer_err();
170        self.check_io_error()
171    }
172
173    #[inline]
174    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
175        self.write_all_defer_err(buf);
176        Ok(())
177    }
178}
179
180impl<'a> Drop for DeferredWriter<'a> {
181    fn drop(&mut self) {
182        if !self.panicked {
183            self.flush_defer_err();
184        }
185    }
186}