hyper_caching_body/
lib.rs

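//! Cacheable HTTP body.
//!
//! See [`CachingBody`] for details.
//!
//! # Example
//!
//! A minimal HTTP/1 reverse proxy that forwards requests received on
//! `127.0.0.1:3001` to `127.0.0.1:8080`, wraps each upstream response body in a
//! [`CachingBody`], and prints the cached body from a separate task.
//!
//! ```no_run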
//! use hyper::{server::conn::http1, service::service_fn};
//! use hyper_caching_body::CachingBody;
//! use hyper_util::rt::TokioIo;
//! use std::net::SocketAddr;
//! use tokio::net::{TcpListener, TcpStream};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let in_addr: SocketAddr = ([127, 0, 0, 1], 3001).into();
//!     let out_addr: SocketAddr = ([127, 0, 0, 1], 8080).into();
//!
//!     let out_addr_clone = out_addr;
//!
//!     let listener = TcpListener::bind(in_addr).await?;
//!
//!     println!("Listening on http://{}", in_addr);
//!     println!("Proxying on http://{}", out_addr);
//!
//!     loop {
//!         let (stream, _) = listener.accept().await?;
//!         let io = TokioIo::new(stream);
//!
//!         let service = service_fn(move |mut req| {
//!             let uri_string = format!(
//!                 "http://{}{}",
//!                 out_addr_clone,
//!                 req.uri()
//!                     .path_and_query()
//!                     .map(|x| x.as_str())
//!                     .unwrap_or("/")
//!             );
//!             let uri = uri_string.parse().unwrap();
//!             *req.uri_mut() = uri;
//!
//!             let host = req.uri().host().expect("uri has no host");
//!             let port = req.uri().port_u16().unwrap_or(80);
//!             let addr = format!("{}:{}", host, port);
//!
//!             async move {
//!                 let client_stream = TcpStream::connect(addr).await.unwrap();
//!                 let io = TokioIo::new(client_stream);
//!
//!                 let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?;
//!                 tokio::task::spawn(async move {
//!                     if let Err(err) = conn.await {
//!                         println!("Connection failed: {:?}", err);
//!                     }
//!                 });
//!
//!                 // Create a channel over which `CachingBody` delivers the cached body
//!                 let (tx, mut rx) = tokio::sync::mpsc::channel(1);
//!                 let res = sender
//!                     .send_request(req)
//!                     .await
//!                     .map(|r| r.map(|b| CachingBody::new(b, tx))); // Wrap the body
//!
//!                 // Spawn a task that receives the cached body
//!                 tokio::task::spawn(async move {
//!                     if let Some(body) = rx.recv().await {
//!                         dbg!(body);
//!                     }
//!                 });
//!                 res
//!             }
//!         });
//!
//!         tokio::task::spawn(async move {
//!             if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
//!                 println!("Failed to serve the connection: {:?}", err);
//!             }
//!         });
//!     }
//! }
87//! ```
88use bytes::{Bytes, BytesMut};
89use std::sync::Mutex;
90use std::{
91    pin::Pin,
92    sync::Arc,
93    task::{Context, Poll},
94};
95
96/// A wrapper for [`hyper::body::Incoming`] that caches the contents in-flight
97/// On each
98#[derive(Debug)]
99pub struct CachingBody {
100    /// Inner body
101    body: hyper::body::Incoming,
102    /// Zero-copy buffer for received content
103    buf: Arc<Mutex<BytesMut>>,
104    /// Sender for buffered body
105    tx: tokio::sync::mpsc::Sender<Bytes>,
106}
107
108impl CachingBody {
109    /// Create a [`CachingBody`] wrapping [`hyper::body::Incoming`]
110    ///
111    /// On each successful poll with [`Some(Ok(http_body::Frame<bytes::Bytes>))`] writes to the underlying buf
112    ///
113    /// When the underlying [`hyper::body::Incoming`] reaches the end of stream it is dropped and the buf is sent over a channel
114    pub fn new(body: hyper::body::Incoming, tx: tokio::sync::mpsc::Sender<Bytes>) -> Self {
115        let buf = Arc::new(Mutex::new(BytesMut::new()));
116        Self { body, buf, tx }
117    }
118}
119
120impl Drop for CachingBody {
121    // This is a hack?
122    //
123    // So for some reason poll_frame on Incoming body never returns a Poll::Ready(None) e.g. the end is reached (maybe i'm dumb)
124    fn drop(&mut self) {
125        let guard_res = self.buf.lock();
126        if let Ok(guard) = guard_res {
127            let bytes = guard.clone().freeze();
128            let _ = self.tx.try_send(bytes);
129        }
130    }
131}
132
133impl hyper::body::Body for CachingBody {
134    type Data = Bytes;
135    type Error = hyper::Error;
136
137    fn poll_frame(
138        mut self: Pin<&mut Self>,
139        cx: &mut Context<'_>,
140    ) -> Poll<Option<core::result::Result<http_body::Frame<bytes::Bytes>, Self::Error>>> {
141        let incoming = &mut self.body;
142        let poll = Pin::new(incoming).poll_frame(cx);
143
144        match poll {
145            Poll::Ready(Some(Ok(frame))) => {
146                if let Some(data) = frame.data_ref()
147                    && let Ok(mut guard) = self.buf.lock()
148                {
149                    guard.extend_from_slice(data);
150                }
151                Poll::Ready(Some(Ok(frame)))
152            }
153
154            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
155
156            Poll::Ready(None) => Poll::Ready(None),
157
158            Poll::Pending => Poll::Pending,
159        }
160    }
161
162    fn is_end_stream(&self) -> bool {
163        self.body.is_end_stream()
164    }
165
166    fn size_hint(&self) -> hyper::body::SizeHint {
167        self.body.size_hint()
168    }
169}