async_io_map/
write.rs

1use std::{
2    pin::Pin,
3    task::{Context, Poll},
4};
5
6use futures_lite::{
7    io::{self, Result},
8    ready, AsyncWrite,
9};
10
11use crate::DEFAULT_BUFFER_SIZE;
12
/// A transformation applied by `AsyncMapWriter` to buffered bytes before they are
/// passed on to the wrapped writer.
pub trait MapWriteFn {
    /// Rewrites `buf` in place. Whatever `buf` contains when this returns is what will be
    /// written to the underlying writer.
    ///
    /// The buffer may grow or shrink. Keep in mind that its capacity is reused afterwards:
    /// if this call enlarges the capacity, later writes are buffered against the larger
    /// capacity. Restore the original capacity here if that is not wanted.
    ///
    /// Growing is allowed on purpose, so that size-increasing transformations such as
    /// base64 encoding can be expressed.
    fn map_write(&mut self, buf: &mut Vec<u8>);
}

/// Lets any `FnMut(&mut Vec<u8>)` closure or function be used directly as a
/// [`MapWriteFn`].
impl<Func> MapWriteFn for Func
where
    Func: FnMut(&mut Vec<u8>),
{
    fn map_write(&mut self, buf: &mut Vec<u8>) {
        (*self)(buf)
    }
}
34
35pin_project_lite::pin_project! {
36  /// A wrapper around an `AsyncWrite` that allows for data processing
37  /// before the actual I/O operation.
38  /// 
39  /// This struct buffers the data written to the underlying writer and applies a mapping function
40  /// to the data before writing it out. It is designed to optimize writes by using a buffer
41  /// of a specified size (default is 8KB).
42  /// 
43  /// The buffer size also acts as a threshold for the length of data passed to the mapping function, 
44  /// and will be gauranteed to be equal to or less than the specified capacity, unless the 
45  /// function modifies the buffer capacity itself.
46  pub struct AsyncMapWriter<'a, W> {
47     #[pin]
48     inner: W,
49     process_fn: Box<dyn MapWriteFn + 'a>,
50     buf: Vec<u8>, // Buffer to hold data before writing
51     written: usize, // Track how much has been written to the buffer
52     transformed: bool, // Add a flag to track if the buffer is already transformed
53  }
54}
55
56impl<'a, W: AsyncWrite> AsyncMapWriter<'a, W> {
57    /// Creates a new `AsyncMapWriter` with a default buffer size of 8KB.
58    /// 
59    /// This function initializes the writer with the provided `process_fn` to map the data before writing.
60    pub fn new(writer: W, process_fn: impl MapWriteFn + 'a) -> Self {
61      Self::with_capacity(writer, process_fn, DEFAULT_BUFFER_SIZE)
62    }
63    
64    /// Creates a new `AsyncMapWriter` with a specified buffer capacity.
65    /// 
66    /// This function initializes the writer with the provided `process_fn` to map the data before writing.
67    pub fn with_capacity(writer: W, process_fn: impl MapWriteFn + 'a, capacity: usize) -> Self {
68        Self {
69            inner: writer,
70            process_fn: Box::new(process_fn),
71            buf: Vec::with_capacity(capacity),
72            written: 0,
73            transformed: false,
74        }
75    }
76
77    /// Consumes the `AsyncMapWriter` and returns the underlying writer.
78    pub fn into_inner(self) -> W {
79        self.inner
80    }
81
82    fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
83        self.project().inner
84    }
85
86    /// Flushes the internal buffer, applying the mapping function if necessary.
87    /// This function writes the transformed data to the underlying writer.
88    fn poll_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
89        let mut this = self.project();
90        // If nothing has been written yet and the buffer isn't transformed, apply the transformation
91        if *this.written == 0 && !this.buf.is_empty() && !*this.transformed {
92            (this.process_fn).map_write(this.buf);
93            *this.transformed = true; // Mark as transformed
94        }
95        let len = this.buf.len();
96        let mut ret = Ok(());
97
98        while *this.written < len {
99            match this
100                .inner
101                .as_mut()
102                .poll_write(cx, &this.buf[*this.written..])
103            {
104                Poll::Ready(Ok(0)) => {
105                    ret = Err(io::Error::new(io::ErrorKind::WriteZero, "write zero"));
106                    break;
107                }
108                Poll::Ready(Ok(n)) => {
109                    *this.written += n;
110                }
111                Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::Interrupted => {}
112                Poll::Ready(Err(e)) => {
113                    ret = Err(e);
114                    break;
115                }
116                Poll::Pending => {
117                    return Poll::Pending;
118                }
119            }
120        }
121
122        if *this.written > 0 {
123            this.buf.drain(..*this.written);
124        }
125        *this.written = 0;
126        *this.transformed = false; // Reset transformed flag when buffer is drained
127
128        Poll::Ready(ret)
129    }
130
131    /// Handles large writes by processing the data before writing it to the underlying writer.
132    /// This function ensures that the internal buffer is transformed before writing.
133    /// 
134    /// returns the number of bytes written to the internal buffer.
135    fn partial_write(self: Pin<&mut Self>, buf: &[u8]) -> usize {
136        let this = self.project();
137        debug_assert!(
138            !*this.transformed,
139            "large_write should only be called when the buffer is not transformed"
140        );
141        // Determine how many bytes can fit into the unused part of the internal buffer.
142        let available = this.buf.capacity() - this.buf.len();
143        let to_read = available.min(buf.len());
144
145        // Only append if there's space.
146        if to_read > 0 {
147            this.buf.extend_from_slice(&buf[..to_read]);
148            // If not yet transformed, process the accumulated data.
149            if !*this.transformed {
150                (this.process_fn).map_write(this.buf);
151                *this.transformed = true;
152            }
153        }
154        to_read
155    }
156}
157
158impl<W: AsyncWrite> AsyncWrite for AsyncMapWriter<'_, W> {
159    fn poll_write(
160        mut self: Pin<&mut Self>,
161        cx: &mut Context<'_>,
162        buf: &[u8],
163    ) -> Poll<Result<usize>> {
164        // Flush the internal buffer if adding new data would exceed capacity.
165        if self.buf.len() + buf.len() > self.buf.capacity() {
166            ready!(self.as_mut().poll_flush_buf(cx))?;
167        }
168
169        if buf.len() < self.buf.capacity() {
170            // For small writes, write into our internal buffer so that the
171            // mapping function is applied later in poll_flush_buf.
172            return Pin::new(&mut *self.project().buf).poll_write(cx, buf);
173        }
174        // If data is large, process it before writing using the internal buffer.
175        let read = self.as_mut().partial_write(buf);
176
177        // Instead of attempting to write immediately and potentially leaving
178        // data behind, we'll just report however many bytes we've processed
179        // so far.
180        Poll::Ready(Ok(read))
181    }
182
183    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
184        ready!(self.as_mut().poll_flush_buf(cx))?;
185        self.get_pin_mut().poll_flush(cx)
186    }
187
188    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
189        ready!(self.as_mut().poll_flush_buf(cx))?;
190        self.get_pin_mut().poll_close(cx)
191    }
192}
193
194/// A trait for types that can be mapped to an `AsyncMapWriter`.
195pub trait AsyncMapWrite<'a, W> {
196    /// Maps the data written to the writer using the provided function.
197    /// 
198    /// This function will apply the mapping function to the data before writing it to the underlying writer.
199    /// This also buffers the data (with a buffer size of 8KB) to optimize writes.
200    fn map(self, process_fn: impl MapWriteFn + 'a) -> AsyncMapWriter<'a, W>
201    where
202        Self: Sized,
203    {
204        self.map_with_capacity(process_fn, DEFAULT_BUFFER_SIZE)
205    }
206
207    /// Maps the data written to the writer using the provided function with a specified buffer capacity.
208    /// 
209    /// This function allows you to specify the size of the internal buffer used for writing.
210    /// The default buffer size is 8KB.
211    /// If you need to optimize for larger writes, you can increase this size.
212    fn map_with_capacity(
213        self,
214        process_fn: impl MapWriteFn + 'a,
215        capacity: usize,
216    ) -> AsyncMapWriter<'a, W>;
217}
218
219impl<'a, W: AsyncWrite> AsyncMapWrite<'a, W> for W {
220    fn map_with_capacity(
221        self,
222        process_fn: impl MapWriteFn + 'a,
223        capacity: usize,
224    ) -> AsyncMapWriter<'a, W> {
225        AsyncMapWriter::with_capacity(self, process_fn, capacity)
226    }
227}