async_io_map/write.rs
1use std::{
2 pin::Pin,
3 task::{Context, Poll},
4};
5
6use futures_lite::{
7 io::{self, Result},
8 ready, AsyncWrite,
9};
10
11use crate::DEFAULT_BUFFER_SIZE;
12
/// Transformation applied to buffered data before it is handed to the underlying writer.
///
/// Every `FnMut(&mut Vec<u8>)` implements this trait automatically.
pub trait MapWriteFn {
    /// Rewrites `buf` in place just before its contents are written out.
    ///
    /// The buffer's length may change freely. Its capacity, however, doubles as the
    /// writer's buffering threshold, so growing or shrinking it changes how much data
    /// later writes accumulate before being mapped. If that is not wanted, restore the
    /// original capacity after processing.
    ///
    /// Letting the buffer grow is deliberate: transformations such as base64 encoding
    /// produce more bytes than they consume.
    fn map_write(&mut self, buf: &mut Vec<u8>);
}
25
26impl<F> MapWriteFn for F
27where
28 F: FnMut(&mut Vec<u8>),
29{
30 fn map_write(&mut self, buf: &mut Vec<u8>) {
31 self(buf)
32 }
33}
34
35pin_project_lite::pin_project! {
36 /// A wrapper around an `AsyncWrite` that allows for data processing
37 /// before the actual I/O operation.
38 ///
39 /// This struct buffers the data written to the underlying writer and applies a mapping function
40 /// to the data before writing it out. It is designed to optimize writes by using a buffer
41 /// of a specified size (default is 8KB).
42 ///
43 /// The buffer size also acts as a threshold for the length of data passed to the mapping function,
44 /// and will be gauranteed to be equal to or less than the specified capacity, unless the
45 /// function modifies the buffer capacity itself.
46 pub struct AsyncMapWriter<'a, W> {
47 #[pin]
48 inner: W,
49 process_fn: Box<dyn MapWriteFn + 'a>,
50 buf: Vec<u8>, // Buffer to hold data before writing
51 written: usize, // Track how much has been written to the buffer
52 transformed: bool, // Add a flag to track if the buffer is already transformed
53 }
54}
55
56impl<'a, W: AsyncWrite> AsyncMapWriter<'a, W> {
57 /// Creates a new `AsyncMapWriter` with a default buffer size of 8KB.
58 ///
59 /// This function initializes the writer with the provided `process_fn` to map the data before writing.
60 pub fn new(writer: W, process_fn: impl MapWriteFn + 'a) -> Self {
61 Self::with_capacity(writer, process_fn, DEFAULT_BUFFER_SIZE)
62 }
63
64 /// Creates a new `AsyncMapWriter` with a specified buffer capacity.
65 ///
66 /// This function initializes the writer with the provided `process_fn` to map the data before writing.
67 pub fn with_capacity(writer: W, process_fn: impl MapWriteFn + 'a, capacity: usize) -> Self {
68 Self {
69 inner: writer,
70 process_fn: Box::new(process_fn),
71 buf: Vec::with_capacity(capacity),
72 written: 0,
73 transformed: false,
74 }
75 }
76
77 /// Consumes the `AsyncMapWriter` and returns the underlying writer.
78 pub fn into_inner(self) -> W {
79 self.inner
80 }
81
82 fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut W> {
83 self.project().inner
84 }
85
86 /// Flushes the internal buffer, applying the mapping function if necessary.
87 /// This function writes the transformed data to the underlying writer.
88 fn poll_flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
89 let mut this = self.project();
90 // If nothing has been written yet and the buffer isn't transformed, apply the transformation
91 if *this.written == 0 && !this.buf.is_empty() && !*this.transformed {
92 (this.process_fn).map_write(this.buf);
93 *this.transformed = true; // Mark as transformed
94 }
95 let len = this.buf.len();
96 let mut ret = Ok(());
97
98 while *this.written < len {
99 match this
100 .inner
101 .as_mut()
102 .poll_write(cx, &this.buf[*this.written..])
103 {
104 Poll::Ready(Ok(0)) => {
105 ret = Err(io::Error::new(io::ErrorKind::WriteZero, "write zero"));
106 break;
107 }
108 Poll::Ready(Ok(n)) => {
109 *this.written += n;
110 }
111 Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::Interrupted => {}
112 Poll::Ready(Err(e)) => {
113 ret = Err(e);
114 break;
115 }
116 Poll::Pending => {
117 return Poll::Pending;
118 }
119 }
120 }
121
122 if *this.written > 0 {
123 this.buf.drain(..*this.written);
124 }
125 *this.written = 0;
126 *this.transformed = false; // Reset transformed flag when buffer is drained
127
128 Poll::Ready(ret)
129 }
130
131 /// Handles large writes by processing the data before writing it to the underlying writer.
132 /// This function ensures that the internal buffer is transformed before writing.
133 ///
134 /// returns the number of bytes written to the internal buffer.
135 fn partial_write(self: Pin<&mut Self>, buf: &[u8]) -> usize {
136 let this = self.project();
137 debug_assert!(
138 !*this.transformed,
139 "large_write should only be called when the buffer is not transformed"
140 );
141 // Determine how many bytes can fit into the unused part of the internal buffer.
142 let available = this.buf.capacity() - this.buf.len();
143 let to_read = available.min(buf.len());
144
145 // Only append if there's space.
146 if to_read > 0 {
147 this.buf.extend_from_slice(&buf[..to_read]);
148 // If not yet transformed, process the accumulated data.
149 if !*this.transformed {
150 (this.process_fn).map_write(this.buf);
151 *this.transformed = true;
152 }
153 }
154 to_read
155 }
156}
157
158impl<W: AsyncWrite> AsyncWrite for AsyncMapWriter<'_, W> {
159 fn poll_write(
160 mut self: Pin<&mut Self>,
161 cx: &mut Context<'_>,
162 buf: &[u8],
163 ) -> Poll<Result<usize>> {
164 // Flush the internal buffer if adding new data would exceed capacity.
165 if self.buf.len() + buf.len() > self.buf.capacity() {
166 ready!(self.as_mut().poll_flush_buf(cx))?;
167 }
168
169 if buf.len() < self.buf.capacity() {
170 // For small writes, write into our internal buffer so that the
171 // mapping function is applied later in poll_flush_buf.
172 return Pin::new(&mut *self.project().buf).poll_write(cx, buf);
173 }
174 // If data is large, process it before writing using the internal buffer.
175 let read = self.as_mut().partial_write(buf);
176
177 // Instead of attempting to write immediately and potentially leaving
178 // data behind, we'll just report however many bytes we've processed
179 // so far.
180 Poll::Ready(Ok(read))
181 }
182
183 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
184 ready!(self.as_mut().poll_flush_buf(cx))?;
185 self.get_pin_mut().poll_flush(cx)
186 }
187
188 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
189 ready!(self.as_mut().poll_flush_buf(cx))?;
190 self.get_pin_mut().poll_close(cx)
191 }
192}
193
194/// A trait for types that can be mapped to an `AsyncMapWriter`.
195pub trait AsyncMapWrite<'a, W> {
196 /// Maps the data written to the writer using the provided function.
197 ///
198 /// This function will apply the mapping function to the data before writing it to the underlying writer.
199 /// This also buffers the data (with a buffer size of 8KB) to optimize writes.
200 fn map(self, process_fn: impl MapWriteFn + 'a) -> AsyncMapWriter<'a, W>
201 where
202 Self: Sized,
203 {
204 self.map_with_capacity(process_fn, DEFAULT_BUFFER_SIZE)
205 }
206
207 /// Maps the data written to the writer using the provided function with a specified buffer capacity.
208 ///
209 /// This function allows you to specify the size of the internal buffer used for writing.
210 /// The default buffer size is 8KB.
211 /// If you need to optimize for larger writes, you can increase this size.
212 fn map_with_capacity(
213 self,
214 process_fn: impl MapWriteFn + 'a,
215 capacity: usize,
216 ) -> AsyncMapWriter<'a, W>;
217}
218
219impl<'a, W: AsyncWrite> AsyncMapWrite<'a, W> for W {
220 fn map_with_capacity(
221 self,
222 process_fn: impl MapWriteFn + 'a,
223 capacity: usize,
224 ) -> AsyncMapWriter<'a, W> {
225 AsyncMapWriter::with_capacity(self, process_fn, capacity)
226 }
227}