reqwest/async_impl/h3_client/
pool.rs1use bytes::Bytes;
2use std::collections::HashMap;
3use std::future;
4use std::pin::Pin;
5use std::sync::mpsc::{Receiver, TryRecvError};
6use std::sync::{Arc, Mutex};
7use std::task::{Context, Poll};
8use std::time::Duration;
9use tokio::sync::{oneshot, watch};
10use tokio::time::Instant;
11
12use crate::async_impl::body::ResponseBody;
13use crate::error::{BoxError, Error, Kind};
14use crate::Body;
15use bytes::Buf;
16use h3::client::SendRequest;
17use h3_quinn::{Connection, OpenStreams};
18use http::uri::{Authority, Scheme};
19use http::{Request, Response, Uri};
20use log::{error, trace};
21
22pub(super) type Key = (Scheme, Authority);
23
24#[derive(Clone)]
25pub struct Pool {
26 inner: Arc<Mutex<PoolInner>>,
27}
28
29struct ConnectingLockInner {
30 key: Key,
31 pool: Arc<Mutex<PoolInner>>,
32}
33
34pub struct ConnectingLock(Option<ConnectingLockInner>);
37
38pub struct ConnectingWaiter {
42 receiver: watch::Receiver<Option<PoolClient>>,
43}
44
45pub enum Connecting {
46 InProgress(ConnectingWaiter),
49 Acquired(ConnectingLock),
52}
53
54impl ConnectingLock {
55 fn new(key: Key, pool: Arc<Mutex<PoolInner>>) -> Self {
56 Self(Some(ConnectingLockInner { key, pool }))
57 }
58
59 fn forget(mut self) -> Key {
61 self.0.take().unwrap().key
64 }
65}
66
67impl Drop for ConnectingLock {
68 fn drop(&mut self) {
69 if let Some(ConnectingLockInner { key, pool }) = self.0.take() {
70 let mut pool = pool.lock().unwrap();
71 pool.connecting.remove(&key);
72 trace!("HTTP/3 connecting lock for {:?} is dropped", key);
73 }
74 }
75}
76
77impl ConnectingWaiter {
78 pub async fn receive(mut self) -> Option<PoolClient> {
79 match self.receiver.wait_for(Option::is_some).await {
80 Ok(ok) => Some(ok.as_ref().unwrap().to_owned()),
82 Err(_) => None,
83 }
84 }
85}
86
87impl Pool {
88 pub fn new(timeout: Option<Duration>) -> Self {
89 Self {
90 inner: Arc::new(Mutex::new(PoolInner {
91 connecting: HashMap::new(),
92 idle_conns: HashMap::new(),
93 timeout,
94 })),
95 }
96 }
97
98 pub fn connecting(&self, key: &Key) -> Connecting {
101 let mut inner = self.inner.lock().unwrap();
102
103 if let Some(sender) = inner.connecting.get(key) {
104 Connecting::InProgress(ConnectingWaiter {
105 receiver: sender.subscribe(),
106 })
107 } else {
108 let (tx, _) = watch::channel(None);
109 inner.connecting.insert(key.clone(), tx);
110 Connecting::Acquired(ConnectingLock::new(key.clone(), Arc::clone(&self.inner)))
111 }
112 }
113
114 pub fn try_pool(&self, key: &Key) -> Option<PoolClient> {
115 let mut inner = self.inner.lock().unwrap();
116 let timeout = inner.timeout;
117 if let Some(conn) = inner.idle_conns.get(&key) {
118 if conn.is_invalid() {
121 trace!("pooled HTTP/3 connection is invalid so removing it...");
122 inner.idle_conns.remove(&key);
123 return None;
124 }
125
126 if let Some(duration) = timeout {
127 if Instant::now().saturating_duration_since(conn.idle_timeout) > duration {
128 trace!("pooled connection expired");
129 inner.idle_conns.remove(&key);
130 return None;
131 }
132 }
133 }
134
135 inner
136 .idle_conns
137 .get_mut(&key)
138 .and_then(|conn| Some(conn.pool()))
139 }
140
141 pub fn new_connection(
142 &mut self,
143 lock: ConnectingLock,
144 mut driver: h3::client::Connection<Connection, Bytes>,
145 tx: SendRequest<OpenStreams, Bytes>,
146 ) -> PoolClient {
147 let (close_tx, close_rx) = std::sync::mpsc::channel();
148 tokio::spawn(async move {
149 let e = future::poll_fn(|cx| driver.poll_close(cx)).await;
150 trace!("poll_close returned error {e:?}");
151 close_tx.send(e).ok();
152 });
153
154 let mut inner = self.inner.lock().unwrap();
155
156 let key = lock.forget();
158 let Some(notifier) = inner.connecting.remove(&key) else {
159 unreachable!("there should be one connecting lock at a time");
160 };
161 let client = PoolClient::new(tx);
162
163 let pool_client = if let Err(watch::error::SendError(Some(unsent_client))) =
165 notifier.send(Some(client.clone()))
166 {
167 unsent_client
170 } else {
171 client.clone()
172 };
173
174 let conn = PoolConnection::new(pool_client, close_rx);
175 inner.insert(key, conn);
176
177 client
178 }
179}
180
181struct PoolInner {
182 connecting: HashMap<Key, watch::Sender<Option<PoolClient>>>,
183 idle_conns: HashMap<Key, PoolConnection>,
184 timeout: Option<Duration>,
185}
186
187impl PoolInner {
188 fn insert(&mut self, key: Key, conn: PoolConnection) {
189 if self.idle_conns.contains_key(&key) {
190 trace!("connection already exists for key {key:?}");
191 }
192
193 self.idle_conns.insert(key, conn);
194 }
195}
196
197#[derive(Clone)]
198pub struct PoolClient {
199 inner: SendRequest<OpenStreams, Bytes>,
200}
201
202impl PoolClient {
203 pub fn new(tx: SendRequest<OpenStreams, Bytes>) -> Self {
204 Self { inner: tx }
205 }
206
207 pub async fn send_request(
208 &mut self,
209 req: Request<Body>,
210 ) -> Result<Response<ResponseBody>, BoxError> {
211 use hyper::body::Body as _;
212
213 let (head, mut req_body) = req.into_parts();
214 let mut req = Request::from_parts(head, ());
215
216 if let Some(n) = req_body.size_hint().exact() {
217 if n > 0 {
218 req.headers_mut()
219 .insert(http::header::CONTENT_LENGTH, n.into());
220 }
221 }
222
223 let (mut send, mut recv) = self.inner.send_request(req).await?.split();
224
225 let (tx, mut rx) = oneshot::channel::<Result<(), BoxError>>();
226 tokio::spawn(async move {
227 let mut req_body = Pin::new(&mut req_body);
228 loop {
229 match std::future::poll_fn(|cx| req_body.as_mut().poll_frame(cx)).await {
230 Some(Ok(frame)) => {
231 if let Ok(b) = frame.into_data() {
232 if let Err(e) = send.send_data(Bytes::copy_from_slice(&b)).await {
233 if is_stop_sending(&e) {
234 let _ = tx.send(Ok(()));
235 return;
236 }
237 if let Err(e) = tx.send(Err(e.into())) {
238 error!("Failed to communicate send.send_data() error: {e:?}");
239 }
240 return;
241 }
242 }
243 }
244 Some(Err(e)) => {
245 if let Err(e) = tx.send(Err(e.into())) {
246 error!("Failed to communicate req_body read error: {e:?}");
247 }
248 return;
249 }
250
251 None => break,
252 }
253 }
254
255 if let Err(e) = send.finish().await {
256 if !is_stop_sending(&e) {
257 if let Err(e) = tx.send(Err(e.into())) {
258 error!("Failed to communicate send.finish read error: {e:?}");
259 }
260 return;
261 }
262 }
263
264 let _ = tx.send(Ok(()));
265 });
266
267 tokio::select! {
268 Ok(Err(e)) = &mut rx => Err(e),
269 resp = recv.recv_response() => {
270 let resp = resp?;
271 let resp_body = crate::async_impl::body::boxed(Incoming::new(recv, resp.headers(), rx));
272 Ok(resp.map(|_| resp_body))
273 }
274 }
275 }
276}
277
278pub struct PoolConnection {
279 close_rx: Receiver<h3::error::ConnectionError>,
281 client: PoolClient,
282 idle_timeout: Instant,
283}
284
285impl PoolConnection {
286 pub fn new(client: PoolClient, close_rx: Receiver<h3::error::ConnectionError>) -> Self {
287 Self {
288 close_rx,
289 client,
290 idle_timeout: Instant::now(),
291 }
292 }
293
294 pub fn pool(&mut self) -> PoolClient {
295 self.idle_timeout = Instant::now();
296 self.client.clone()
297 }
298
299 pub fn is_invalid(&self) -> bool {
300 match self.close_rx.try_recv() {
301 Err(TryRecvError::Empty) => false,
302 Err(TryRecvError::Disconnected) => true,
303 Ok(_) => true,
304 }
305 }
306}
307
308struct Incoming<S, B> {
309 inner: h3::client::RequestStream<S, B>,
310 content_length: Option<u64>,
311 send_rx: oneshot::Receiver<Result<(), BoxError>>,
312}
313
314impl<S, B> Incoming<S, B> {
315 fn new(
316 stream: h3::client::RequestStream<S, B>,
317 headers: &http::header::HeaderMap,
318 send_rx: oneshot::Receiver<Result<(), BoxError>>,
319 ) -> Self {
320 Self {
321 inner: stream,
322 content_length: headers
323 .get(http::header::CONTENT_LENGTH)
324 .and_then(|h| h.to_str().ok())
325 .and_then(|v| v.parse().ok()),
326 send_rx,
327 }
328 }
329}
330
331impl<S, B> http_body::Body for Incoming<S, B>
332where
333 S: h3::quic::RecvStream,
334{
335 type Data = Bytes;
336 type Error = crate::error::Error;
337
338 fn poll_frame(
339 mut self: Pin<&mut Self>,
340 cx: &mut Context,
341 ) -> Poll<Option<Result<hyper::body::Frame<Self::Data>, Self::Error>>> {
342 if let Ok(Err(e)) = self.send_rx.try_recv() {
343 return Poll::Ready(Some(Err(crate::error::body(e))));
344 }
345
346 match futures_core::ready!(self.inner.poll_recv_data(cx)) {
347 Ok(Some(mut b)) => Poll::Ready(Some(Ok(hyper::body::Frame::data(
348 b.copy_to_bytes(b.remaining()),
349 )))),
350 Ok(None) => Poll::Ready(None),
351 Err(e) => Poll::Ready(Some(Err(crate::error::body(e)))),
352 }
353 }
354
355 fn size_hint(&self) -> hyper::body::SizeHint {
356 if let Some(content_length) = self.content_length {
357 hyper::body::SizeHint::with_exact(content_length)
358 } else {
359 hyper::body::SizeHint::default()
360 }
361 }
362}
363
364pub(crate) fn extract_domain(uri: &mut Uri) -> Result<Key, Error> {
365 let uri_clone = uri.clone();
366 match (uri_clone.scheme(), uri_clone.authority()) {
367 (Some(scheme), Some(auth)) => {
368 let scheme_str = scheme.as_str();
369 if scheme_str != "https" && scheme_str != "h3" {
370 return Err(Error::new(
371 Kind::Request,
372 Some(Box::new(std::io::Error::new(
373 std::io::ErrorKind::InvalidInput,
374 format!(
375 "HTTP/3 only supports 'https' or 'h3' schemes, got: {}",
376 scheme_str
377 ),
378 ))),
379 ));
380 }
381 Ok((scheme.clone(), auth.clone()))
382 }
383 _ => Err(Error::new(Kind::Request, None::<Error>)),
384 }
385}
386
387pub(crate) fn domain_as_uri((scheme, auth): Key) -> Uri {
388 http::uri::Builder::new()
389 .scheme(scheme)
390 .authority(auth)
391 .path_and_query("/")
392 .build()
393 .expect("domain is valid Uri")
394}
395
396fn is_stop_sending(e: &h3::error::StreamError) -> bool {
398 matches!(
399 e,
400 h3::error::StreamError::RemoteTerminate {
401 code: h3::error::Code::H3_NO_ERROR,
402 ..
403 }
404 )
405}