// rocket_community/data/transform.rs

1use std::io;
2use std::ops::{Deref, DerefMut};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use tokio::io::ReadBuf;
7
8/// Chainable, in-place, streaming data transformer.
9///
10/// [`Transform`] operates on [`TransformBuf`]s similar to how [`AsyncRead`]
11/// operats on [`ReadBuf`]. A [`Transform`] sits somewhere in a chain of
12/// transforming readers. The head (most upstream part) of the chain is _always_
13/// an [`AsyncRead`]: the data source. The tail (all downstream parts) is
14/// composed _only_ of other [`Transform`]s:
15///
16/// ```text
17///                          downstream --->
18///  AsyncRead | Transform | .. | Transform
19/// <---- upstream
20/// ```
21///
22/// When the upstream source makes data available, the
23/// [`Transform::transform()`] method is called. [`Transform`]s may obtain the
24/// subset of the filled section added by an upstream data source with
25/// [`TransformBuf::fresh()`]. They may modify this data at will, potentially
26/// changing the size of the filled section. For example,
27/// [`TransformBuf::spoil()`] "removes" all of the fresh data, and
28/// [`TransformBuf::fresh_mut()`] can be used to modify the data in-place.
29///
30/// Additionally, new data may be added in-place via the traditional approach:
31/// write to (or overwrite) the initialized section of the buffer and mark it as
32/// filled. All of the remaining filled data will be passed to downstream
33/// transforms as "fresh" data. To add data to the end of the (potentially
34/// rewritten) stream, the [`Transform::poll_finish()`] method can be
35/// implemented.
36///
37/// [`AsyncRead`]: tokio::io::AsyncRead
38pub trait Transform {
39    /// Called when data is read from the upstream source. For any given fresh
40    /// data, this method is called only once. [`TransformBuf::fresh()`] is
41    /// guaranteed to contain at least one byte.
42    ///
43    /// While this method is not _async_ (it does not return [`Poll`]), it is
44    /// nevertheless executed in an async context and should respect all such
45    /// restrictions including not blocking.
46    fn transform(self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>) -> io::Result<()>;
47
48    /// Called when the upstream is finished, that is, it has no more data to
49    /// fill. At this point, the transform becomes an async reader. This method
50    /// thus has identical semantics to [`AsyncRead::poll_read()`]. This method
51    /// may never be called if the upstream does not finish.
52    ///
53    /// The default implementation returns `Poll::Ready(Ok(()))`.
54    ///
55    /// [`AsyncRead::poll_read()`]: tokio::io::AsyncRead::poll_read()
56    fn poll_finish(
57        self: Pin<&mut Self>,
58        cx: &mut Context<'_>,
59        buf: &mut ReadBuf<'_>,
60    ) -> Poll<io::Result<()>> {
61        let (_, _) = (cx, buf);
62        Poll::Ready(Ok(()))
63    }
64}
65
66/// A buffer of transformable streaming data.
67///
68/// # Overview
69///
70/// A byte buffer, similar to a [`ReadBuf`], with a "fresh" dimension. Fresh
71/// data is always a subset of the filled data, filled data is always a subset
72/// of initialized data, and initialized data is always a subset of the buffer
73/// itself. Both the filled and initialized data sections are guaranteed to be
74/// at the start of the buffer, but the fresh subset is likely to begin
75/// somewhere inside the filled section.
76///
77/// To visualize this, the diagram below represents a possible state for the
78/// byte buffer being tracked. The square `[ ]` brackets represent the complete
79/// buffer, while the curly `{ }` represent the named subset.
80///
81/// ```text
82/// [  { !! fresh !! }                                 ]
83/// { +++ filled +++ }          unfilled               ]
84/// { ----- initialized ------ }     uninitialized     ]
85/// [                    capacity                      ]
86/// ```
87///
88/// The same buffer represented in its true single dimension is below:
89///
90/// ```text
91/// [ ++!!!!!!!!!!!!!!---------xxxxxxxxxxxxxxxxxxxxxxxx]
92/// ```
93///
94/// * `+`: filled (implies initialized)
95/// * `!`: fresh (implies filled)
96/// * `-`: unfilled / initialized (implies initialized)
97/// * `x`: uninitialized (implies unfilled)
98///
99/// As with [`ReadBuf`], [`AsyncRead`] readers fill the initialized portion of a
100/// [`TransformBuf`] to indicate that data is available. _Filling_ initialized
101/// portions of the byte buffers is what increases the size of the _filled_
102/// section. Because a [`ReadBuf`] may already be partially filled when a reader
103/// adds bytes to it, a mechanism to track where the _newly_ filled portion
104/// exists is needed. This is exactly what the "fresh" section tracks.
105///
106/// [`AsyncRead`]: tokio::io::AsyncRead
107pub struct TransformBuf<'a, 'b> {
108    pub(crate) buf: &'a mut ReadBuf<'b>,
109    pub(crate) cursor: usize,
110}
111
112impl TransformBuf<'_, '_> {
113    /// Returns a borrow to the fresh data: data filled by the upstream source.
114    pub fn fresh(&self) -> &[u8] {
115        &self.filled()[self.cursor..]
116    }
117
118    /// Returns a mutable borrow to the fresh data: data filled by the upstream
119    /// source.
120    pub fn fresh_mut(&mut self) -> &mut [u8] {
121        let cursor = self.cursor;
122        &mut self.filled_mut()[cursor..]
123    }
124
125    /// Spoils the fresh data by resetting the filled section to its value
126    /// before any new data was added. As a result, the data will never be seen
127    /// by any downstream consumer unless it is returned via another mechanism.
128    pub fn spoil(&mut self) {
129        let cursor = self.cursor;
130        self.set_filled(cursor);
131    }
132}
133
134pub struct Inspect(pub(crate) Box<dyn FnMut(&[u8]) + Send + Sync + 'static>);
135
136impl Transform for Inspect {
137    fn transform(mut self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>) -> io::Result<()> {
138        (self.0)(buf.fresh());
139        Ok(())
140    }
141}
142
143pub struct InPlaceMap(
144    pub(crate) Box<dyn FnMut(&mut TransformBuf<'_, '_>) -> io::Result<()> + Send + Sync + 'static>,
145);
146
147impl Transform for InPlaceMap {
148    fn transform(mut self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>) -> io::Result<()> {
149        (self.0)(buf)
150    }
151}
152
153impl<'a, 'b> Deref for TransformBuf<'a, 'b> {
154    type Target = ReadBuf<'b>;
155
156    fn deref(&self) -> &Self::Target {
157        self.buf
158    }
159}
160
161impl<'a, 'b> DerefMut for TransformBuf<'a, 'b> {
162    fn deref_mut(&mut self) -> &mut Self::Target {
163        self.buf
164    }
165}
166
// TODO: Test chaining various transform combinations:
//  * consume | consume
//  * add | consume
//  * consume | add
//  * add | add
// Where `add` is a transformer that adds data to the stream, and `consume` is
// one that removes data.
#[cfg(test)]
// `SipHasher` is deprecated in `std`. It is used here because its algorithm,
// unlike `DefaultHasher`'s, is fixed, which lets the test pin exact hashes.
#[allow(deprecated)]
mod tests {
    use std::hash::SipHasher;
    use std::sync::{
        atomic::{AtomicU64, AtomicU8, Ordering},
        Arc,
    };

    use parking_lot::Mutex;
    use ubyte::ToByteUnit;

    use crate::fairing::AdHoc;
    use crate::http::Method;
    use crate::local::blocking::Client;
    use crate::{route, Data, Request, Response, Route};

    mod hash_transform {
        use std::hash::Hasher;
        use std::io::Cursor;

        use tokio::io::AsyncRead;

        use super::super::*;

        /// Test transform that consumes every byte of the stream, feeding it
        /// to `hasher`. Once upstream finishes, it emits the big-endian 8-byte
        /// hash in place of the consumed data.
        pub struct HashTransform<H: Hasher> {
            pub(crate) hasher: H,
            /// The encoded hash. Created on the first `poll_finish` call.
            pub(crate) hash: Option<Cursor<[u8; 8]>>,
        }

        impl<H: Hasher + Unpin> Transform for HashTransform<H> {
            fn transform(
                self: Pin<&mut Self>,
                buf: &mut TransformBuf<'_, '_>,
            ) -> io::Result<()> {
                let this = self.get_mut();
                this.hasher.write(buf.fresh());
                // Drop the consumed bytes so downstream never sees them.
                buf.spoil();
                Ok(())
            }

            fn poll_finish(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<io::Result<()>> {
                let this = self.get_mut();
                let hasher = &this.hasher;
                let digest = this
                    .hash
                    .get_or_insert_with(|| Cursor::new(hasher.finish().to_be_bytes()));

                Pin::new(digest).poll_read(cx, buf)
            }
        }

        impl crate::Data<'_> {
            /// Chain an in-place hash [`Transform`] to `self`.
            pub fn chain_hash_transform<H>(&mut self, hasher: H) -> &mut Self
            where
                H: Hasher + Unpin + Send + Sync + 'static,
            {
                let transform = HashTransform { hasher, hash: None };
                self.chain_transform(transform)
            }
        }
    }

    #[test]
    fn test_transform_series() {
        // Drains up to 128 bytes of the body into a sink so the transform
        // chain is driven to completion.
        fn handler<'r>(_: &'r Request<'_>, data: Data<'r>) -> route::BoxFuture<'r> {
            Box::pin(async move {
                let stream = data.open(128.bytes());
                stream.stream_to(tokio::io::sink()).await.expect("read ok");
                route::Outcome::Success(Response::new())
            })
        }

        // Decodes the 8-byte big-endian hash emitted by `HashTransform`.
        fn decode_hash(bytes: &[u8]) -> u64 {
            assert_eq!(bytes.len(), 8);
            let array: [u8; 8] = bytes.try_into().expect("[u8; 8]");
            u64::from_be_bytes(array)
        }

        let hash: Arc<AtomicU64> = Arc::new(AtomicU64::new(0));
        let raw_data: Arc<Mutex<Vec<u8>>> = Arc::new(Mutex::new(Vec::new()));
        let inspect2: Arc<AtomicU8> = Arc::new(AtomicU8::new(0));

        let fairing = AdHoc::on_request("transforms", |req, data| {
            Box::pin(async {
                let app = req.rocket();
                let hash1 = app.state::<Arc<AtomicU64>>().cloned().unwrap();
                let hash2 = Arc::clone(&hash1);
                let raw_data = app.state::<Arc<Mutex<Vec<u8>>>>().cloned().unwrap();
                let inspect2 = app.state::<Arc<AtomicU8>>().cloned().unwrap();

                // inspect | hash | inspect | inspect
                data.chain_inspect(move |bytes| {
                    *raw_data.lock() = bytes.to_vec();
                })
                .chain_hash_transform(SipHasher::new())
                .chain_inspect(move |bytes| {
                    hash1.store(decode_hash(bytes), Ordering::Release);
                })
                .chain_inspect(move |bytes| {
                    let value = decode_hash(bytes);
                    assert_eq!(hash2.load(Ordering::Acquire), value);
                    inspect2.fetch_add(1, Ordering::Release);
                });
            })
        });

        let rocket = crate::build()
            .manage(Arc::clone(&hash))
            .manage(Arc::clone(&raw_data))
            .manage(Arc::clone(&inspect2))
            .mount("/", vec![Route::new(Method::Post, "/", handler)])
            .attach(fairing);

        // Asserts the state recorded by the transform chain.
        let check = |expected_body: &str, expected_hash: u64, expected_runs: u8| {
            assert_eq!(raw_data.lock().as_slice(), expected_body.as_bytes());
            assert_eq!(hash.load(Ordering::Acquire), expected_hash);
            assert_eq!(inspect2.load(Ordering::Acquire), expected_runs);
        };

        // Nothing is recorded before any request is dispatched.
        check("", 0, 0);

        // The route only accepts POST, so a GET never reads the body and the
        // transforms never observe any data.
        let client = Client::debug(rocket).unwrap();
        client.get("/").body("Hello, world!").dispatch();
        check("", 0, 0);

        // inspect + hash + inspect + inspect.
        client.post("/").body("Hello, world!").dispatch();
        check("Hello, world!", 0xae5020d7cf49d14f, 1);

        // Same chain again with a longer body.
        let string = "Rocket, Rocket, where art thee? Oh, tis in the sky, I see!";
        client.post("/").body(string).dispatch();
        check(string, 0x323f9aa98f907faf, 2);
    }
}