rocket_community/data/transform.rs
1use std::io;
2use std::ops::{Deref, DerefMut};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use tokio::io::ReadBuf;
7
8/// Chainable, in-place, streaming data transformer.
9///
10/// [`Transform`] operates on [`TransformBuf`]s similar to how [`AsyncRead`]
11/// operats on [`ReadBuf`]. A [`Transform`] sits somewhere in a chain of
12/// transforming readers. The head (most upstream part) of the chain is _always_
13/// an [`AsyncRead`]: the data source. The tail (all downstream parts) is
14/// composed _only_ of other [`Transform`]s:
15///
16/// ```text
17/// downstream --->
18/// AsyncRead | Transform | .. | Transform
19/// <---- upstream
20/// ```
21///
22/// When the upstream source makes data available, the
23/// [`Transform::transform()`] method is called. [`Transform`]s may obtain the
24/// subset of the filled section added by an upstream data source with
25/// [`TransformBuf::fresh()`]. They may modify this data at will, potentially
26/// changing the size of the filled section. For example,
27/// [`TransformBuf::spoil()`] "removes" all of the fresh data, and
28/// [`TransformBuf::fresh_mut()`] can be used to modify the data in-place.
29///
30/// Additionally, new data may be added in-place via the traditional approach:
31/// write to (or overwrite) the initialized section of the buffer and mark it as
32/// filled. All of the remaining filled data will be passed to downstream
33/// transforms as "fresh" data. To add data to the end of the (potentially
34/// rewritten) stream, the [`Transform::poll_finish()`] method can be
35/// implemented.
36///
37/// [`AsyncRead`]: tokio::io::AsyncRead
38pub trait Transform {
39 /// Called when data is read from the upstream source. For any given fresh
40 /// data, this method is called only once. [`TransformBuf::fresh()`] is
41 /// guaranteed to contain at least one byte.
42 ///
43 /// While this method is not _async_ (it does not return [`Poll`]), it is
44 /// nevertheless executed in an async context and should respect all such
45 /// restrictions including not blocking.
46 fn transform(self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>) -> io::Result<()>;
47
48 /// Called when the upstream is finished, that is, it has no more data to
49 /// fill. At this point, the transform becomes an async reader. This method
50 /// thus has identical semantics to [`AsyncRead::poll_read()`]. This method
51 /// may never be called if the upstream does not finish.
52 ///
53 /// The default implementation returns `Poll::Ready(Ok(()))`.
54 ///
55 /// [`AsyncRead::poll_read()`]: tokio::io::AsyncRead::poll_read()
56 fn poll_finish(
57 self: Pin<&mut Self>,
58 cx: &mut Context<'_>,
59 buf: &mut ReadBuf<'_>,
60 ) -> Poll<io::Result<()>> {
61 let (_, _) = (cx, buf);
62 Poll::Ready(Ok(()))
63 }
64}
65
66/// A buffer of transformable streaming data.
67///
68/// # Overview
69///
70/// A byte buffer, similar to a [`ReadBuf`], with a "fresh" dimension. Fresh
71/// data is always a subset of the filled data, filled data is always a subset
72/// of initialized data, and initialized data is always a subset of the buffer
73/// itself. Both the filled and initialized data sections are guaranteed to be
74/// at the start of the buffer, but the fresh subset is likely to begin
75/// somewhere inside the filled section.
76///
77/// To visualize this, the diagram below represents a possible state for the
78/// byte buffer being tracked. The square `[ ]` brackets represent the complete
79/// buffer, while the curly `{ }` represent the named subset.
80///
81/// ```text
82/// [ { !! fresh !! } ]
83/// { +++ filled +++ } unfilled ]
84/// { ----- initialized ------ } uninitialized ]
85/// [ capacity ]
86/// ```
87///
88/// The same buffer represented in its true single dimension is below:
89///
90/// ```text
91/// [ ++!!!!!!!!!!!!!!---------xxxxxxxxxxxxxxxxxxxxxxxx]
92/// ```
93///
94/// * `+`: filled (implies initialized)
95/// * `!`: fresh (implies filled)
96/// * `-`: unfilled / initialized (implies initialized)
97/// * `x`: uninitialized (implies unfilled)
98///
99/// As with [`ReadBuf`], [`AsyncRead`] readers fill the initialized portion of a
100/// [`TransformBuf`] to indicate that data is available. _Filling_ initialized
101/// portions of the byte buffers is what increases the size of the _filled_
102/// section. Because a [`ReadBuf`] may already be partially filled when a reader
103/// adds bytes to it, a mechanism to track where the _newly_ filled portion
104/// exists is needed. This is exactly what the "fresh" section tracks.
105///
106/// [`AsyncRead`]: tokio::io::AsyncRead
107pub struct TransformBuf<'a, 'b> {
108 pub(crate) buf: &'a mut ReadBuf<'b>,
109 pub(crate) cursor: usize,
110}
111
112impl TransformBuf<'_, '_> {
113 /// Returns a borrow to the fresh data: data filled by the upstream source.
114 pub fn fresh(&self) -> &[u8] {
115 &self.filled()[self.cursor..]
116 }
117
118 /// Returns a mutable borrow to the fresh data: data filled by the upstream
119 /// source.
120 pub fn fresh_mut(&mut self) -> &mut [u8] {
121 let cursor = self.cursor;
122 &mut self.filled_mut()[cursor..]
123 }
124
125 /// Spoils the fresh data by resetting the filled section to its value
126 /// before any new data was added. As a result, the data will never be seen
127 /// by any downstream consumer unless it is returned via another mechanism.
128 pub fn spoil(&mut self) {
129 let cursor = self.cursor;
130 self.set_filled(cursor);
131 }
132}
133
134pub struct Inspect(pub(crate) Box<dyn FnMut(&[u8]) + Send + Sync + 'static>);
135
136impl Transform for Inspect {
137 fn transform(mut self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>) -> io::Result<()> {
138 (self.0)(buf.fresh());
139 Ok(())
140 }
141}
142
143pub struct InPlaceMap(
144 pub(crate) Box<dyn FnMut(&mut TransformBuf<'_, '_>) -> io::Result<()> + Send + Sync + 'static>,
145);
146
147impl Transform for InPlaceMap {
148 fn transform(mut self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>) -> io::Result<()> {
149 (self.0)(buf)
150 }
151}
152
153impl<'a, 'b> Deref for TransformBuf<'a, 'b> {
154 type Target = ReadBuf<'b>;
155
156 fn deref(&self) -> &Self::Target {
157 self.buf
158 }
159}
160
161impl<'a, 'b> DerefMut for TransformBuf<'a, 'b> {
162 fn deref_mut(&mut self) -> &mut Self::Target {
163 self.buf
164 }
165}
166
// TODO: Test chaining various transform combinations:
// * consume | consume
// * add | consume
// * consume | add
// * add | add
// Where `add` is a transformer that adds data to the stream, and `consume` is
// one that removes data.
#[cfg(test)]
// `SipHasher` is deprecated in `std`, but the test pins exact hash values
// produced by `SipHasher::new()`, so the deprecation warning is silenced.
#[allow(deprecated)]
mod tests {
    use std::hash::SipHasher;
    use std::sync::{
        atomic::{AtomicU64, AtomicU8, Ordering},
        Arc,
    };

    use parking_lot::Mutex;
    use ubyte::ToByteUnit;

    use crate::fairing::AdHoc;
    use crate::http::Method;
    use crate::local::blocking::Client;
    use crate::{route, Data, Request, Response, Route};

    mod hash_transform {
        use std::hash::Hasher;
        use std::io::Cursor;

        use tokio::io::AsyncRead;

        use super::super::*;

        /// Test-only "consume then add" transform: it feeds every chunk of
        /// fresh data into `hasher` and spoils it, so downstream transforms
        /// see none of the original bytes. Once upstream finishes, it emits
        /// the 8-byte big-endian hash instead.
        pub struct HashTransform<H: Hasher> {
            /// Accumulates all fresh data seen by `transform()`.
            pub(crate) hasher: H,
            /// The finished hash, created lazily on the first `poll_finish()`;
            /// the cursor tracks how much of it has already been emitted.
            pub(crate) hash: Option<Cursor<[u8; 8]>>,
        }

        impl<H: Hasher + Unpin> Transform for HashTransform<H> {
            fn transform(
                mut self: Pin<&mut Self>,
                buf: &mut TransformBuf<'_, '_>,
            ) -> io::Result<()> {
                self.hasher.write(buf.fresh());
                // Consume the input: only the hash is passed downstream.
                buf.spoil();
                Ok(())
            }

            fn poll_finish(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<io::Result<()>> {
                // Finish the hasher only once; later polls keep reading from
                // the same cursor, so the 8 hash bytes are emitted once and
                // reads past the end yield no further data.
                if self.hash.is_none() {
                    let hash = self.hasher.finish();
                    self.hash = Some(Cursor::new(hash.to_be_bytes()));
                }

                let cursor = self.hash.as_mut().unwrap();
                Pin::new(cursor).poll_read(cx, buf)
            }
        }

        impl crate::Data<'_> {
            /// Chain an in-place hash [`Transform`] to `self`.
            ///
            /// Test-only helper: downstream of this transform, the data stream
            /// is replaced by the big-endian hash of its contents.
            pub fn chain_hash_transform<H: std::hash::Hasher>(&mut self, hasher: H) -> &mut Self
            where
                H: Unpin + Send + Sync + 'static,
            {
                self.chain_transform(HashTransform { hasher, hash: None })
            }
        }
    }

    #[test]
    fn test_transform_series() {
        // Reads the body (with a 128-byte limit) to completion so that every
        // chained transform runs; the bytes themselves are discarded.
        fn handler<'r>(_: &'r Request<'_>, data: Data<'r>) -> route::BoxFuture<'r> {
            Box::pin(async move {
                data.open(128.bytes())
                    .stream_to(tokio::io::sink())
                    .await
                    .expect("read ok");
                route::Outcome::Success(Response::new())
            })
        }

        // Shared state written by the transforms: how many times the last
        // inspector ran, the raw body seen by the first inspector, and the
        // hash value seen after `HashTransform`.
        let inspect2: Arc<AtomicU8> = Arc::new(AtomicU8::new(0));
        let raw_data: Arc<Mutex<Vec<u8>>> = Arc::new(Mutex::new(Vec::new()));
        let hash: Arc<AtomicU64> = Arc::new(AtomicU64::new(0));
        let rocket = crate::build()
            .manage(hash.clone())
            .manage(raw_data.clone())
            .manage(inspect2.clone())
            .mount("/", vec![Route::new(Method::Post, "/", handler)])
            .attach(AdHoc::on_request("transforms", |req, data| {
                Box::pin(async {
                    // `hash1` and `hash2` are clones of the same managed `Arc`.
                    let hash1 = req.rocket().state::<Arc<AtomicU64>>().cloned().unwrap();
                    let hash2 = req.rocket().state::<Arc<AtomicU64>>().cloned().unwrap();
                    let raw_data = req
                        .rocket()
                        .state::<Arc<Mutex<Vec<u8>>>>()
                        .cloned()
                        .unwrap();
                    let inspect2 = req.rocket().state::<Arc<AtomicU8>>().cloned().unwrap();
                    data.chain_inspect(move |bytes| {
                        // Records the raw body before it is hashed.
                        // NOTE(review): this overwrites rather than appends, so
                        // it only captures the whole body if it arrives as a
                        // single fresh chunk (true for the small bodies below).
                        *raw_data.lock() = bytes.to_vec();
                    })
                    .chain_hash_transform(SipHasher::new())
                    .chain_inspect(move |bytes| {
                        // Only the 8-byte hash should survive `HashTransform`.
                        // NOTE(review): assumes `poll_finish` emits all 8 bytes
                        // in one read.
                        assert_eq!(bytes.len(), 8);
                        let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
                        let value = u64::from_be_bytes(bytes);
                        hash1.store(value, Ordering::Release);
                    })
                    .chain_inspect(move |bytes| {
                        assert_eq!(bytes.len(), 8);
                        let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
                        let value = u64::from_be_bytes(bytes);
                        // The previous inspector must already have stored this
                        // same value, i.e. transforms run in chain order.
                        let prev = hash2.load(Ordering::Acquire);
                        assert_eq!(prev, value);
                        inspect2.fetch_add(1, Ordering::Release);
                    });
                })
            }));

        // Make sure nothing has happened yet.
        assert!(raw_data.lock().is_empty());
        assert_eq!(hash.load(Ordering::Acquire), 0);
        assert_eq!(inspect2.load(Ordering::Acquire), 0);

        // Check that nothing happens if the data isn't read. Only `POST /` is
        // mounted, so this `GET` never reaches `handler`, which is the only
        // code that opens the body.
        let client = Client::debug(rocket).unwrap();
        client.get("/").body("Hello, world!").dispatch();
        assert!(raw_data.lock().is_empty());
        assert_eq!(hash.load(Ordering::Acquire), 0);
        assert_eq!(inspect2.load(Ordering::Acquire), 0);

        // Check inspect + hash + inspect + inspect.
        client.post("/").body("Hello, world!").dispatch();
        assert_eq!(raw_data.lock().as_slice(), "Hello, world!".as_bytes());
        assert_eq!(hash.load(Ordering::Acquire), 0xae5020d7cf49d14f);
        assert_eq!(inspect2.load(Ordering::Acquire), 1);

        // Check inspect + hash + inspect + inspect, round 2.
        let string = "Rocket, Rocket, where art thee? Oh, tis in the sky, I see!";
        client.post("/").body(string).dispatch();
        assert_eq!(raw_data.lock().as_slice(), string.as_bytes());
        assert_eq!(hash.load(Ordering::Acquire), 0x323f9aa98f907faf);
        assert_eq!(inspect2.load(Ordering::Acquire), 2);
    }
}