use std::io;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::ReadBuf;
/// A streaming transformation over data as it is read into a buffer.
///
/// Implementors inspect or rewrite newly read bytes through
/// [`Transform::transform`], and may emit extra trailing bytes through
/// [`Transform::poll_finish`].
pub trait Transform {
    /// Processes the fresh bytes in `buf`.
    ///
    /// The fresh region is exposed by [`TransformBuf::fresh`] and
    /// [`TransformBuf::fresh_mut`]; an implementation may modify it in place or
    /// drop it with [`TransformBuf::spoil`]. (Presumably invoked once per read —
    /// confirm against the caller.)
    ///
    /// # Errors
    ///
    /// Returns whatever I/O error the implementation chooses to report.
    fn transform(self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>) -> io::Result<()>;

    /// Gives the transform a chance to write final bytes into `buf`.
    ///
    /// The default implementation writes nothing and is immediately ready.
    /// NOTE(review): presumably polled once the underlying data is exhausted —
    /// confirm against the caller.
    fn poll_finish(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Nothing to emit; consume the arguments to silence unused warnings.
        let _ = (cx, buf);
        Poll::Ready(Ok(()))
    }
}
/// A wrapper around a [`ReadBuf`] that remembers where the "fresh" bytes begin.
///
/// The fresh region is `buf.filled()[cursor..]`. The wrapper dereferences to the
/// underlying [`ReadBuf`], so its methods can be called directly.
pub struct TransformBuf<'a, 'b> {
    /// The buffer being read into.
    pub(crate) buf: &'a mut ReadBuf<'b>,
    /// Offset into the filled region at which the fresh bytes start
    /// (presumably the filled length before the latest read — confirm at the
    /// construction site).
    pub(crate) cursor: usize,
}
impl TransformBuf<'_, '_> {
    /// Returns the fresh bytes: the filled region from `cursor` onward.
    ///
    /// # Panics
    ///
    /// Panics if `cursor` lies past the end of the filled region.
    pub fn fresh(&self) -> &[u8] {
        let filled = self.buf.filled();
        &filled[self.cursor..]
    }

    /// Mutable view of the fresh bytes, allowing them to be rewritten in place.
    ///
    /// # Panics
    ///
    /// Panics if `cursor` lies past the end of the filled region.
    pub fn fresh_mut(&mut self) -> &mut [u8] {
        let start = self.cursor;
        let filled = self.buf.filled_mut();
        &mut filled[start..]
    }

    /// Discards the fresh bytes by shrinking the filled region back to `cursor`.
    pub fn spoil(&mut self) {
        let keep = self.cursor;
        self.buf.set_filled(keep);
    }
}
/// A [`Transform`] that hands every batch of fresh bytes to a callback as a
/// read-only slice, leaving the data unchanged.
pub struct Inspect(pub(crate) Box<dyn FnMut(&[u8]) + Send + Sync + 'static>);

impl Transform for Inspect {
    fn transform(self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>) -> io::Result<()> {
        // `Inspect` only holds a `Box`, so it is `Unpin` and the pin can be unwrapped.
        let this = self.get_mut();
        (this.0)(buf.fresh());
        Ok(())
    }
}
/// A [`Transform`] that passes the whole [`TransformBuf`] to a callback. The
/// callback may modify the buffer (e.g. through `fresh_mut`) and its result is
/// returned as-is.
pub struct InPlaceMap(
    pub(crate) Box<dyn FnMut(&mut TransformBuf<'_, '_>) -> io::Result<()> + Send + Sync + 'static>,
);

impl Transform for InPlaceMap {
    fn transform(self: Pin<&mut Self>, buf: &mut TransformBuf<'_, '_>) -> io::Result<()> {
        // `InPlaceMap` only holds a `Box`, so it is `Unpin` and the pin can be unwrapped.
        let this = self.get_mut();
        (this.0)(buf)
    }
}
/// Exposes the wrapped [`ReadBuf`] read-only, so its methods can be called
/// directly on a [`TransformBuf`].
impl<'b> Deref for TransformBuf<'_, 'b> {
    type Target = ReadBuf<'b>;

    fn deref(&self) -> &ReadBuf<'b> {
        &*self.buf
    }
}
impl<'a, 'b> DerefMut for TransformBuf<'a, 'b> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.buf
}
}
#[cfg(test)]
// `std::hash::SipHasher` is deprecated, but `SipHasher::new()` uses fixed keys,
// which keeps the hard-coded expected digests below stable.
#[allow(deprecated)]
mod tests {
    use std::hash::SipHasher;
    use std::sync::{
        atomic::{AtomicU64, AtomicU8, Ordering},
        Arc,
    };

    use parking_lot::Mutex;
    use ubyte::ToByteUnit;

    use crate::fairing::AdHoc;
    use crate::http::Method;
    use crate::local::blocking::Client;
    use crate::{route, Data, Request, Response, Route};

    /// Test-only transform that hashes the body and replaces it with the
    /// 8-byte big-endian digest.
    mod hash_transform {
        use std::hash::Hasher;
        use std::io::Cursor;

        use tokio::io::AsyncRead;

        use super::super::*;

        pub struct HashTransform<H: Hasher> {
            /// Accumulates every fresh byte seen by `transform`.
            pub(crate) hasher: H,
            /// The finished digest, created lazily by `poll_finish` and then
            /// drained into the output buffer across calls.
            pub(crate) hash: Option<Cursor<[u8; 8]>>,
        }

        impl<H: Hasher + Unpin> Transform for HashTransform<H> {
            fn transform(
                mut self: Pin<&mut Self>,
                buf: &mut TransformBuf<'_, '_>,
            ) -> io::Result<()> {
                // Hash the fresh bytes, then discard them so that only the
                // digest emitted by `poll_finish` reaches later transforms.
                self.hasher.write(buf.fresh());
                buf.spoil();
                Ok(())
            }

            fn poll_finish(
                self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<io::Result<()>> {
                // `H: Unpin`, so the pin can be unwrapped to borrow the fields separately.
                let HashTransform { hasher, hash } = self.get_mut();
                // Finish the hash exactly once; later calls keep draining the same cursor.
                let cursor =
                    hash.get_or_insert_with(|| Cursor::new(hasher.finish().to_be_bytes()));
                Pin::new(cursor).poll_read(cx, buf)
            }
        }

        impl crate::Data<'_> {
            /// Chains a `HashTransform` built from `hasher` onto this data.
            pub fn chain_hash_transform<H: std::hash::Hasher>(&mut self, hasher: H) -> &mut Self
            where
                H: Unpin + Send + Sync + 'static,
            {
                self.chain_transform(HashTransform { hasher, hash: None })
            }
        }
    }

    #[test]
    fn test_transform_series() {
        // Reads up to 128 bytes of the body into a sink so the chained transforms run.
        fn handler<'r>(_: &'r Request<'_>, data: Data<'r>) -> route::BoxFuture<'r> {
            Box::pin(async move {
                data.open(128.bytes())
                    .stream_to(tokio::io::sink())
                    .await
                    .expect("read ok");

                route::Outcome::Success(Response::new())
            })
        }

        // Shared state that the transforms write to and the test asserts on.
        let inspect2: Arc<AtomicU8> = Arc::new(AtomicU8::new(0));
        let raw_data: Arc<Mutex<Vec<u8>>> = Arc::new(Mutex::new(Vec::new()));
        let hash: Arc<AtomicU64> = Arc::new(AtomicU64::new(0));
        let rocket = crate::build()
            .manage(hash.clone())
            .manage(raw_data.clone())
            .manage(inspect2.clone())
            .mount("/", vec![Route::new(Method::Post, "/", handler)])
            .attach(AdHoc::on_request("transforms", |req, data| {
                Box::pin(async {
                    // `hash1` and `hash2` are handles to the same managed atomic:
                    // one stage stores the digest and a later stage checks it.
                    let hash1 = req.rocket().state::<Arc<AtomicU64>>().cloned().unwrap();
                    let hash2 = req.rocket().state::<Arc<AtomicU64>>().cloned().unwrap();
                    let raw_data = req
                        .rocket()
                        .state::<Arc<Mutex<Vec<u8>>>>()
                        .cloned()
                        .unwrap();
                    let inspect2 = req.rocket().state::<Arc<AtomicU8>>().cloned().unwrap();

                    // Chain: capture the raw body, replace it with its SipHash
                    // digest, then inspect the digest twice.
                    data.chain_inspect(move |bytes| {
                        // NOTE(review): overwrites rather than appends, so this
                        // presumably relies on the short test bodies arriving in
                        // a single chunk — confirm.
                        *raw_data.lock() = bytes.to_vec();
                    })
                    .chain_hash_transform(SipHasher::new())
                    .chain_inspect(move |bytes| {
                        assert_eq!(bytes.len(), 8);
                        let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
                        let value = u64::from_be_bytes(bytes);
                        hash1.store(value, Ordering::Release);
                    })
                    .chain_inspect(move |bytes| {
                        assert_eq!(bytes.len(), 8);
                        let bytes: [u8; 8] = bytes.try_into().expect("[u8; 8]");
                        let value = u64::from_be_bytes(bytes);
                        let prev = hash2.load(Ordering::Acquire);
                        assert_eq!(prev, value);
                        inspect2.fetch_add(1, Ordering::Release);
                    });
                })
            }));

        // Nothing has run before any request is dispatched.
        assert!(raw_data.lock().is_empty());
        assert_eq!(hash.load(Ordering::Acquire), 0);
        assert_eq!(inspect2.load(Ordering::Acquire), 0);

        // Only a POST route is mounted, so the body-reading handler never runs for a GET.
        let client = Client::debug(rocket).unwrap();
        client.get("/").body("Hello, world!").dispatch();
        assert!(raw_data.lock().is_empty());
        assert_eq!(hash.load(Ordering::Acquire), 0);
        assert_eq!(inspect2.load(Ordering::Acquire), 0);

        client.post("/").body("Hello, world!").dispatch();
        assert_eq!(raw_data.lock().as_slice(), "Hello, world!".as_bytes());
        assert_eq!(hash.load(Ordering::Acquire), 0xae5020d7cf49d14f);
        assert_eq!(inspect2.load(Ordering::Acquire), 1);

        let string = "Rocket, Rocket, where art thee? Oh, tis in the sky, I see!";
        client.post("/").body(string).dispatch();
        assert_eq!(raw_data.lock().as_slice(), string.as_bytes());
        assert_eq!(hash.load(Ordering::Acquire), 0x323f9aa98f907faf);
        assert_eq!(inspect2.load(Ordering::Acquire), 2);
    }
}