Skip to main content

reqwest/async_impl/h3_client/
pool.rs

1use bytes::Bytes;
2use std::collections::HashMap;
3use std::future;
4use std::pin::Pin;
5use std::sync::mpsc::{Receiver, TryRecvError};
6use std::sync::{Arc, Mutex};
7use std::task::{Context, Poll};
8use std::time::Duration;
9use tokio::sync::{oneshot, watch};
10use tokio::time::Instant;
11
12use crate::async_impl::body::ResponseBody;
13use crate::error::{BoxError, Error, Kind};
14use crate::Body;
15use bytes::Buf;
16use h3::client::SendRequest;
17use h3_quinn::{Connection, OpenStreams};
18use http::uri::{Authority, Scheme};
19use http::{Request, Response, Uri};
20use log::{error, trace};
21
22pub(super) type Key = (Scheme, Authority);
23
24#[derive(Clone)]
25pub struct Pool {
26    inner: Arc<Mutex<PoolInner>>,
27}
28
29struct ConnectingLockInner {
30    key: Key,
31    pool: Arc<Mutex<PoolInner>>,
32}
33
34/// A lock that ensures only one HTTP/3 connection is established per host at a
35/// time. The lock is automatically released when dropped.
36pub struct ConnectingLock(Option<ConnectingLockInner>);
37
38/// A waiter that allows subscribers to receive updates when a new connection is
39/// established or when the connection attempt fails. For example, when
40/// connection lock is dropped due to an error.
41pub struct ConnectingWaiter {
42    receiver: watch::Receiver<Option<PoolClient>>,
43}
44
45pub enum Connecting {
46    /// A connection attempt is already in progress.
47    /// You must subscribe to updates instead of initiating a new connection.
48    InProgress(ConnectingWaiter),
49    /// The connection lock has been acquired, allowing you to initiate a
50    /// new connection.
51    Acquired(ConnectingLock),
52}
53
54impl ConnectingLock {
55    fn new(key: Key, pool: Arc<Mutex<PoolInner>>) -> Self {
56        Self(Some(ConnectingLockInner { key, pool }))
57    }
58
59    /// Forget the lock and return corresponding Key
60    fn forget(mut self) -> Key {
61        // Unwrap is safe because the Option can be None only after dropping the
62        // lock
63        self.0.take().unwrap().key
64    }
65}
66
67impl Drop for ConnectingLock {
68    fn drop(&mut self) {
69        if let Some(ConnectingLockInner { key, pool }) = self.0.take() {
70            let mut pool = pool.lock().unwrap();
71            pool.connecting.remove(&key);
72            trace!("HTTP/3 connecting lock for {:?} is dropped", key);
73        }
74    }
75}
76
77impl ConnectingWaiter {
78    pub async fn receive(mut self) -> Option<PoolClient> {
79        match self.receiver.wait_for(Option::is_some).await {
80            // unwrap because we already checked that option is Some
81            Ok(ok) => Some(ok.as_ref().unwrap().to_owned()),
82            Err(_) => None,
83        }
84    }
85}
86
87impl Pool {
88    pub fn new(timeout: Option<Duration>) -> Self {
89        Self {
90            inner: Arc::new(Mutex::new(PoolInner {
91                connecting: HashMap::new(),
92                idle_conns: HashMap::new(),
93                timeout,
94            })),
95        }
96    }
97
98    /// Acquire a connecting lock. This is to ensure that we have only one HTTP3
99    /// connection per host.
100    pub fn connecting(&self, key: &Key) -> Connecting {
101        let mut inner = self.inner.lock().unwrap();
102
103        if let Some(sender) = inner.connecting.get(key) {
104            Connecting::InProgress(ConnectingWaiter {
105                receiver: sender.subscribe(),
106            })
107        } else {
108            let (tx, _) = watch::channel(None);
109            inner.connecting.insert(key.clone(), tx);
110            Connecting::Acquired(ConnectingLock::new(key.clone(), Arc::clone(&self.inner)))
111        }
112    }
113
114    pub fn try_pool(&self, key: &Key) -> Option<PoolClient> {
115        let mut inner = self.inner.lock().unwrap();
116        let timeout = inner.timeout;
117        if let Some(conn) = inner.idle_conns.get(&key) {
118            // We check first if the connection still valid
119            // and if not, we remove it from the pool.
120            if conn.is_invalid() {
121                trace!("pooled HTTP/3 connection is invalid so removing it...");
122                inner.idle_conns.remove(&key);
123                return None;
124            }
125
126            if let Some(duration) = timeout {
127                if Instant::now().saturating_duration_since(conn.idle_timeout) > duration {
128                    trace!("pooled connection expired");
129                    inner.idle_conns.remove(&key);
130                    return None;
131                }
132            }
133        }
134
135        inner
136            .idle_conns
137            .get_mut(&key)
138            .and_then(|conn| Some(conn.pool()))
139    }
140
141    pub fn new_connection(
142        &mut self,
143        lock: ConnectingLock,
144        mut driver: h3::client::Connection<Connection, Bytes>,
145        tx: SendRequest<OpenStreams, Bytes>,
146    ) -> PoolClient {
147        let (close_tx, close_rx) = std::sync::mpsc::channel();
148        tokio::spawn(async move {
149            let e = future::poll_fn(|cx| driver.poll_close(cx)).await;
150            trace!("poll_close returned error {e:?}");
151            close_tx.send(e).ok();
152        });
153
154        let mut inner = self.inner.lock().unwrap();
155
156        // We clean up "connecting" here so we don't have to acquire the lock again.
157        let key = lock.forget();
158        let Some(notifier) = inner.connecting.remove(&key) else {
159            unreachable!("there should be one connecting lock at a time");
160        };
161        let client = PoolClient::new(tx);
162
163        // Send the client to all our awaiters
164        let pool_client = if let Err(watch::error::SendError(Some(unsent_client))) =
165            notifier.send(Some(client.clone()))
166        {
167            // If there are no awaiters, the client is returned to us. As a
168            // micro optimisation, let's reuse it and avoid cloning.
169            unsent_client
170        } else {
171            client.clone()
172        };
173
174        let conn = PoolConnection::new(pool_client, close_rx);
175        inner.insert(key, conn);
176
177        client
178    }
179}
180
181struct PoolInner {
182    connecting: HashMap<Key, watch::Sender<Option<PoolClient>>>,
183    idle_conns: HashMap<Key, PoolConnection>,
184    timeout: Option<Duration>,
185}
186
187impl PoolInner {
188    fn insert(&mut self, key: Key, conn: PoolConnection) {
189        if self.idle_conns.contains_key(&key) {
190            trace!("connection already exists for key {key:?}");
191        }
192
193        self.idle_conns.insert(key, conn);
194    }
195}
196
197#[derive(Clone)]
198pub struct PoolClient {
199    inner: SendRequest<OpenStreams, Bytes>,
200}
201
202impl PoolClient {
203    pub fn new(tx: SendRequest<OpenStreams, Bytes>) -> Self {
204        Self { inner: tx }
205    }
206
207    pub async fn send_request(
208        &mut self,
209        req: Request<Body>,
210    ) -> Result<Response<ResponseBody>, BoxError> {
211        use hyper::body::Body as _;
212
213        let (head, mut req_body) = req.into_parts();
214        let mut req = Request::from_parts(head, ());
215
216        if let Some(n) = req_body.size_hint().exact() {
217            if n > 0 {
218                req.headers_mut()
219                    .insert(http::header::CONTENT_LENGTH, n.into());
220            }
221        }
222
223        let (mut send, mut recv) = self.inner.send_request(req).await?.split();
224
225        let (tx, mut rx) = oneshot::channel::<Result<(), BoxError>>();
226        tokio::spawn(async move {
227            let mut req_body = Pin::new(&mut req_body);
228            loop {
229                match std::future::poll_fn(|cx| req_body.as_mut().poll_frame(cx)).await {
230                    Some(Ok(frame)) => {
231                        if let Ok(b) = frame.into_data() {
232                            if let Err(e) = send.send_data(Bytes::copy_from_slice(&b)).await {
233                                if is_stop_sending(&e) {
234                                    let _ = tx.send(Ok(()));
235                                    return;
236                                }
237                                if let Err(e) = tx.send(Err(e.into())) {
238                                    error!("Failed to communicate send.send_data() error: {e:?}");
239                                }
240                                return;
241                            }
242                        }
243                    }
244                    Some(Err(e)) => {
245                        if let Err(e) = tx.send(Err(e.into())) {
246                            error!("Failed to communicate req_body read error: {e:?}");
247                        }
248                        return;
249                    }
250
251                    None => break,
252                }
253            }
254
255            if let Err(e) = send.finish().await {
256                if !is_stop_sending(&e) {
257                    if let Err(e) = tx.send(Err(e.into())) {
258                        error!("Failed to communicate send.finish read error: {e:?}");
259                    }
260                    return;
261                }
262            }
263
264            let _ = tx.send(Ok(()));
265        });
266
267        tokio::select! {
268            Ok(Err(e)) = &mut rx => Err(e),
269            resp = recv.recv_response() => {
270                let resp = resp?;
271                let resp_body = crate::async_impl::body::boxed(Incoming::new(recv, resp.headers(), rx));
272                Ok(resp.map(|_| resp_body))
273            }
274        }
275    }
276}
277
278pub struct PoolConnection {
279    // This receives errors from polling h3 driver.
280    close_rx: Receiver<h3::error::ConnectionError>,
281    client: PoolClient,
282    idle_timeout: Instant,
283}
284
285impl PoolConnection {
286    pub fn new(client: PoolClient, close_rx: Receiver<h3::error::ConnectionError>) -> Self {
287        Self {
288            close_rx,
289            client,
290            idle_timeout: Instant::now(),
291        }
292    }
293
294    pub fn pool(&mut self) -> PoolClient {
295        self.idle_timeout = Instant::now();
296        self.client.clone()
297    }
298
299    pub fn is_invalid(&self) -> bool {
300        match self.close_rx.try_recv() {
301            Err(TryRecvError::Empty) => false,
302            Err(TryRecvError::Disconnected) => true,
303            Ok(_) => true,
304        }
305    }
306}
307
308struct Incoming<S, B> {
309    inner: h3::client::RequestStream<S, B>,
310    content_length: Option<u64>,
311    send_rx: oneshot::Receiver<Result<(), BoxError>>,
312}
313
314impl<S, B> Incoming<S, B> {
315    fn new(
316        stream: h3::client::RequestStream<S, B>,
317        headers: &http::header::HeaderMap,
318        send_rx: oneshot::Receiver<Result<(), BoxError>>,
319    ) -> Self {
320        Self {
321            inner: stream,
322            content_length: headers
323                .get(http::header::CONTENT_LENGTH)
324                .and_then(|h| h.to_str().ok())
325                .and_then(|v| v.parse().ok()),
326            send_rx,
327        }
328    }
329}
330
331impl<S, B> http_body::Body for Incoming<S, B>
332where
333    S: h3::quic::RecvStream,
334{
335    type Data = Bytes;
336    type Error = crate::error::Error;
337
338    fn poll_frame(
339        mut self: Pin<&mut Self>,
340        cx: &mut Context,
341    ) -> Poll<Option<Result<hyper::body::Frame<Self::Data>, Self::Error>>> {
342        if let Ok(Err(e)) = self.send_rx.try_recv() {
343            return Poll::Ready(Some(Err(crate::error::body(e))));
344        }
345
346        match futures_core::ready!(self.inner.poll_recv_data(cx)) {
347            Ok(Some(mut b)) => Poll::Ready(Some(Ok(hyper::body::Frame::data(
348                b.copy_to_bytes(b.remaining()),
349            )))),
350            Ok(None) => Poll::Ready(None),
351            Err(e) => Poll::Ready(Some(Err(crate::error::body(e)))),
352        }
353    }
354
355    fn size_hint(&self) -> hyper::body::SizeHint {
356        if let Some(content_length) = self.content_length {
357            hyper::body::SizeHint::with_exact(content_length)
358        } else {
359            hyper::body::SizeHint::default()
360        }
361    }
362}
363
364pub(crate) fn extract_domain(uri: &mut Uri) -> Result<Key, Error> {
365    let uri_clone = uri.clone();
366    match (uri_clone.scheme(), uri_clone.authority()) {
367        (Some(scheme), Some(auth)) => {
368            let scheme_str = scheme.as_str();
369            if scheme_str != "https" && scheme_str != "h3" {
370                return Err(Error::new(
371                    Kind::Request,
372                    Some(Box::new(std::io::Error::new(
373                        std::io::ErrorKind::InvalidInput,
374                        format!(
375                            "HTTP/3 only supports 'https' or 'h3' schemes, got: {}",
376                            scheme_str
377                        ),
378                    ))),
379                ));
380            }
381            Ok((scheme.clone(), auth.clone()))
382        }
383        _ => Err(Error::new(Kind::Request, None::<Error>)),
384    }
385}
386
387pub(crate) fn domain_as_uri((scheme, auth): Key) -> Uri {
388    http::uri::Builder::new()
389        .scheme(scheme)
390        .authority(auth)
391        .path_and_query("/")
392        .build()
393        .expect("domain is valid Uri")
394}
395
396/// Indicates the remote requested the peer to stop sending data without error.
397fn is_stop_sending(e: &h3::error::StreamError) -> bool {
398    matches!(
399        e,
400        h3::error::StreamError::RemoteTerminate {
401            code: h3::error::Code::H3_NO_ERROR,
402            ..
403        }
404    )
405}