hyper_caching_body/lib.rs
1//! Cacheable HTTP body
2//!
3//! See [`CachingBody`] for details
4
5//! # Example
6//!
7//!
8//!
9//! ```no_run
10//!
11//! # {
12//!use hyper::{server::conn::http1, service::service_fn};
13//!use hyper_caching_body::CachingBody;
14//!use hyper_util::rt::TokioIo;
15//!use std::net::SocketAddr;
16//!use tokio::net::{TcpListener, TcpStream};
17//!
18//!#[tokio::main]
19//!async fn main() -> Result<(), Box<dyn std::error::Error>> {
20//! let in_addr: SocketAddr = ([127, 0, 0, 1], 3001).into();
21//! let out_addr: SocketAddr = ([127, 0, 0, 1], 8080).into();
22//!
23//! let out_addr_clone = out_addr;
24//!
25//! let listener = TcpListener::bind(in_addr).await?;
26//!
27//! println!("Listening on http://{}", in_addr);
28//! println!("Proxying on http://{}", out_addr);
29//!
30//! loop {
31//! let (stream, _) = listener.accept().await?;
32//! let io = TokioIo::new(stream);
33//!
34//! let service = service_fn(move |mut req| {
35//! let uri_string = format!(
36//! "http://{}{}",
37//! out_addr_clone,
38//! req.uri()
39//! .path_and_query()
40//! .map(|x| x.as_str())
41//! .unwrap_or("/")
42//! );
43//! let uri = uri_string.parse().unwrap();
44//! *req.uri_mut() = uri;
45//!
46//! let host = req.uri().host().expect("uri has no host");
47//! let port = req.uri().port_u16().unwrap_or(80);
48//! let addr = format!("{}:{}", host, port);
49//!
50//! async move {
51//! let client_stream = TcpStream::connect(addr).await.unwrap();
52//! let io = TokioIo::new(client_stream);
53//!
54//! let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?;
55//! tokio::task::spawn(async move {
56//! if let Err(err) = conn.await {
57//! println!("Connection failed: {:?}", err);
58//! }
59//! });
60//!
61//! // Here we create a channel to receive buffer contents in another task
62//!
63//! let (tx, mut rx) = tokio::sync::mpsc::channel(1);
64//! let res = sender
65//! .send_request(req)
66//! .await
67//! .map(|r| r.map(|b| CachingBody::new(b, tx))); // Wrap the body
68//!
69//! // Spawn a task to receive the buffered contents
70//! tokio::task::spawn(async move {
71//! if let Some(body) = rx.recv().await {
72//! dbg!(body);
73//! }
74//! });
75//! res
76//! }
77//! });
78//!
79//! tokio::task::spawn(async move {
80//! if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
81//! println!("Failed to serve the connection: {:?}", err);
82//! }
83//! });
84//! }
85//!}
86//! # }
87//! ```
88use bytes::{Bytes, BytesMut};
89use std::sync::Mutex;
90use std::{
91 pin::Pin,
92 sync::Arc,
93 task::{Context, Poll},
94};
95
96/// A wrapper for [`hyper::body::Incoming`] that caches the contents in-flight
97/// On each
98#[derive(Debug)]
99pub struct CachingBody {
100 /// Inner body
101 body: hyper::body::Incoming,
102 /// Zero-copy buffer for received content
103 buf: Arc<Mutex<BytesMut>>,
104 /// Sender for buffered body
105 tx: tokio::sync::mpsc::Sender<Bytes>,
106}
107
108impl CachingBody {
109 /// Create a [`CachingBody`] wrapping [`hyper::body::Incoming`]
110 ///
111 /// On each successful poll with [`Some(Ok(http_body::Frame<bytes::Bytes>))`] writes to the underlying buf
112 ///
113 /// When the underlying [`hyper::body::Incoming`] reaches the end of stream it is dropped and the buf is sent over a channel
114 pub fn new(body: hyper::body::Incoming, tx: tokio::sync::mpsc::Sender<Bytes>) -> Self {
115 let buf = Arc::new(Mutex::new(BytesMut::new()));
116 Self { body, buf, tx }
117 }
118}
119
120impl Drop for CachingBody {
121 // This is a hack?
122 //
123 // So for some reason poll_frame on Incoming body never returns a Poll::Ready(None) e.g. the end is reached (maybe i'm dumb)
124 fn drop(&mut self) {
125 let guard_res = self.buf.lock();
126 if let Ok(guard) = guard_res {
127 let bytes = guard.clone().freeze();
128 let _ = self.tx.try_send(bytes);
129 }
130 }
131}
132
133impl hyper::body::Body for CachingBody {
134 type Data = Bytes;
135 type Error = hyper::Error;
136
137 fn poll_frame(
138 mut self: Pin<&mut Self>,
139 cx: &mut Context<'_>,
140 ) -> Poll<Option<core::result::Result<http_body::Frame<bytes::Bytes>, Self::Error>>> {
141 let incoming = &mut self.body;
142 let poll = Pin::new(incoming).poll_frame(cx);
143
144 match poll {
145 Poll::Ready(Some(Ok(frame))) => {
146 if let Some(data) = frame.data_ref()
147 && let Ok(mut guard) = self.buf.lock()
148 {
149 guard.extend_from_slice(data);
150 }
151 Poll::Ready(Some(Ok(frame)))
152 }
153
154 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
155
156 Poll::Ready(None) => Poll::Ready(None),
157
158 Poll::Pending => Poll::Pending,
159 }
160 }
161
162 fn is_end_stream(&self) -> bool {
163 self.body.is_end_stream()
164 }
165
166 fn size_hint(&self) -> hyper::body::SizeHint {
167 self.body.size_hint()
168 }
169}