actori_http/client/
pool.rs

1use std::cell::RefCell;
2use std::collections::VecDeque;
3use std::future::Future;
4use std::pin::Pin;
5use std::rc::Rc;
6use std::task::{Context, Poll};
7use std::time::{Duration, Instant};
8
9use actori_codec::{AsyncRead, AsyncWrite};
10use actori_rt::time::{delay_for, Delay};
11use actori_service::Service;
12use actori_utils::{oneshot, task::LocalWaker};
13use bytes::Bytes;
14use futures_util::future::{poll_fn, FutureExt, LocalBoxFuture};
15use fxhash::FxHashMap;
16use h2::client::{handshake, Connection, SendRequest};
17use http::uri::Authority;
18use indexmap::IndexSet;
19use slab::Slab;
20
21use super::connection::{ConnectionType, IoConnection};
22use super::error::ConnectError;
23use super::Connect;
24
25#[derive(Clone, Copy, PartialEq)]
26/// Protocol version
27pub enum Protocol {
28    Http1,
29    Http2,
30}
31
32#[derive(Hash, Eq, PartialEq, Clone, Debug)]
33pub(crate) struct Key {
34    authority: Authority,
35}
36
37impl From<Authority> for Key {
38    fn from(authority: Authority) -> Key {
39        Key { authority }
40    }
41}
42
43/// Connections pool
44pub(crate) struct ConnectionPool<T, Io: 'static>(Rc<RefCell<T>>, Rc<RefCell<Inner<Io>>>);
45
46impl<T, Io> ConnectionPool<T, Io>
47where
48    Io: AsyncRead + AsyncWrite + Unpin + 'static,
49    T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
50        + 'static,
51{
52    pub(crate) fn new(
53        connector: T,
54        conn_lifetime: Duration,
55        conn_keep_alive: Duration,
56        disconnect_timeout: Option<Duration>,
57        limit: usize,
58    ) -> Self {
59        ConnectionPool(
60            Rc::new(RefCell::new(connector)),
61            Rc::new(RefCell::new(Inner {
62                conn_lifetime,
63                conn_keep_alive,
64                disconnect_timeout,
65                limit,
66                acquired: 0,
67                waiters: Slab::new(),
68                waiters_queue: IndexSet::new(),
69                available: FxHashMap::default(),
70                waker: LocalWaker::new(),
71            })),
72        )
73    }
74}
75
76impl<T, Io> Clone for ConnectionPool<T, Io>
77where
78    Io: 'static,
79{
80    fn clone(&self) -> Self {
81        ConnectionPool(self.0.clone(), self.1.clone())
82    }
83}
84
85impl<T, Io> Service for ConnectionPool<T, Io>
86where
87    Io: AsyncRead + AsyncWrite + Unpin + 'static,
88    T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>
89        + 'static,
90{
91    type Request = Connect;
92    type Response = IoConnection<Io>;
93    type Error = ConnectError;
94    type Future = LocalBoxFuture<'static, Result<IoConnection<Io>, ConnectError>>;
95
96    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
97        self.0.poll_ready(cx)
98    }
99
100    fn call(&mut self, req: Connect) -> Self::Future {
101        // start support future
102        actori_rt::spawn(ConnectorPoolSupport {
103            connector: self.0.clone(),
104            inner: self.1.clone(),
105        });
106
107        let mut connector = self.0.clone();
108        let inner = self.1.clone();
109
110        let fut = async move {
111            let key = if let Some(authority) = req.uri.authority() {
112                authority.clone().into()
113            } else {
114                return Err(ConnectError::Unresolverd);
115            };
116
117            // acquire connection
118            match poll_fn(|cx| Poll::Ready(inner.borrow_mut().acquire(&key, cx))).await {
119                Acquire::Acquired(io, created) => {
120                    // use existing connection
121                    return Ok(IoConnection::new(
122                        io,
123                        created,
124                        Some(Acquired(key, Some(inner))),
125                    ));
126                }
127                Acquire::Available => {
128                    // open tcp connection
129                    let (io, proto) = connector.call(req).await?;
130
131                    let guard = OpenGuard::new(key, inner);
132
133                    if proto == Protocol::Http1 {
134                        Ok(IoConnection::new(
135                            ConnectionType::H1(io),
136                            Instant::now(),
137                            Some(guard.consume()),
138                        ))
139                    } else {
140                        let (snd, connection) = handshake(io).await?;
141                        actori_rt::spawn(connection.map(|_| ()));
142                        Ok(IoConnection::new(
143                            ConnectionType::H2(snd),
144                            Instant::now(),
145                            Some(guard.consume()),
146                        ))
147                    }
148                }
149                _ => {
150                    // connection is not available, wait
151                    let (rx, token) = inner.borrow_mut().wait_for(req);
152
153                    let guard = WaiterGuard::new(key, token, inner);
154                    let res = match rx.await {
155                        Err(_) => Err(ConnectError::Disconnected),
156                        Ok(res) => res,
157                    };
158                    guard.consume();
159                    res
160                }
161            }
162        };
163
164        fut.boxed_local()
165    }
166}
167
168struct WaiterGuard<Io>
169where
170    Io: AsyncRead + AsyncWrite + Unpin + 'static,
171{
172    key: Key,
173    token: usize,
174    inner: Option<Rc<RefCell<Inner<Io>>>>,
175}
176
177impl<Io> WaiterGuard<Io>
178where
179    Io: AsyncRead + AsyncWrite + Unpin + 'static,
180{
181    fn new(key: Key, token: usize, inner: Rc<RefCell<Inner<Io>>>) -> Self {
182        Self {
183            key,
184            token,
185            inner: Some(inner),
186        }
187    }
188
189    fn consume(mut self) {
190        let _ = self.inner.take();
191    }
192}
193
194impl<Io> Drop for WaiterGuard<Io>
195where
196    Io: AsyncRead + AsyncWrite + Unpin + 'static,
197{
198    fn drop(&mut self) {
199        if let Some(i) = self.inner.take() {
200            let mut inner = i.as_ref().borrow_mut();
201            inner.release_waiter(&self.key, self.token);
202            inner.check_availibility();
203        }
204    }
205}
206
207struct OpenGuard<Io>
208where
209    Io: AsyncRead + AsyncWrite + Unpin + 'static,
210{
211    key: Key,
212    inner: Option<Rc<RefCell<Inner<Io>>>>,
213}
214
215impl<Io> OpenGuard<Io>
216where
217    Io: AsyncRead + AsyncWrite + Unpin + 'static,
218{
219    fn new(key: Key, inner: Rc<RefCell<Inner<Io>>>) -> Self {
220        Self {
221            key,
222            inner: Some(inner),
223        }
224    }
225
226    fn consume(mut self) -> Acquired<Io> {
227        Acquired(self.key.clone(), self.inner.take())
228    }
229}
230
231impl<Io> Drop for OpenGuard<Io>
232where
233    Io: AsyncRead + AsyncWrite + Unpin + 'static,
234{
235    fn drop(&mut self) {
236        if let Some(i) = self.inner.take() {
237            let mut inner = i.as_ref().borrow_mut();
238            inner.release();
239            inner.check_availibility();
240        }
241    }
242}
243
244enum Acquire<T> {
245    Acquired(ConnectionType<T>, Instant),
246    Available,
247    NotAvailable,
248}
249
250struct AvailableConnection<Io> {
251    io: ConnectionType<Io>,
252    used: Instant,
253    created: Instant,
254}
255
256pub(crate) struct Inner<Io> {
257    conn_lifetime: Duration,
258    conn_keep_alive: Duration,
259    disconnect_timeout: Option<Duration>,
260    limit: usize,
261    acquired: usize,
262    available: FxHashMap<Key, VecDeque<AvailableConnection<Io>>>,
263    waiters: Slab<
264        Option<(
265            Connect,
266            oneshot::Sender<Result<IoConnection<Io>, ConnectError>>,
267        )>,
268    >,
269    waiters_queue: IndexSet<(Key, usize)>,
270    waker: LocalWaker,
271}
272
273impl<Io> Inner<Io> {
274    fn reserve(&mut self) {
275        self.acquired += 1;
276    }
277
278    fn release(&mut self) {
279        self.acquired -= 1;
280    }
281
282    fn release_waiter(&mut self, key: &Key, token: usize) {
283        self.waiters.remove(token);
284        let _ = self.waiters_queue.shift_remove(&(key.clone(), token));
285    }
286}
287
288impl<Io> Inner<Io>
289where
290    Io: AsyncRead + AsyncWrite + Unpin + 'static,
291{
292    /// connection is not available, wait
293    fn wait_for(
294        &mut self,
295        connect: Connect,
296    ) -> (
297        oneshot::Receiver<Result<IoConnection<Io>, ConnectError>>,
298        usize,
299    ) {
300        let (tx, rx) = oneshot::channel();
301
302        let key: Key = connect.uri.authority().unwrap().clone().into();
303        let entry = self.waiters.vacant_entry();
304        let token = entry.key();
305        entry.insert(Some((connect, tx)));
306        assert!(self.waiters_queue.insert((key, token)));
307
308        (rx, token)
309    }
310
311    fn acquire(&mut self, key: &Key, cx: &mut Context<'_>) -> Acquire<Io> {
312        // check limits
313        if self.limit > 0 && self.acquired >= self.limit {
314            return Acquire::NotAvailable;
315        }
316
317        self.reserve();
318
319        // check if open connection is available
320        // cleanup stale connections at the same time
321        if let Some(ref mut connections) = self.available.get_mut(key) {
322            let now = Instant::now();
323            while let Some(conn) = connections.pop_back() {
324                // check if it still usable
325                if (now - conn.used) > self.conn_keep_alive
326                    || (now - conn.created) > self.conn_lifetime
327                {
328                    if let Some(timeout) = self.disconnect_timeout {
329                        if let ConnectionType::H1(io) = conn.io {
330                            actori_rt::spawn(CloseConnection::new(io, timeout))
331                        }
332                    }
333                } else {
334                    let mut io = conn.io;
335                    let mut buf = [0; 2];
336                    if let ConnectionType::H1(ref mut s) = io {
337                        match Pin::new(s).poll_read(cx, &mut buf) {
338                            Poll::Pending => (),
339                            Poll::Ready(Ok(n)) if n > 0 => {
340                                if let Some(timeout) = self.disconnect_timeout {
341                                    if let ConnectionType::H1(io) = io {
342                                        actori_rt::spawn(CloseConnection::new(
343                                            io, timeout,
344                                        ))
345                                    }
346                                }
347                                continue;
348                            }
349                            _ => continue,
350                        }
351                    }
352                    return Acquire::Acquired(io, conn.created);
353                }
354            }
355        }
356        Acquire::Available
357    }
358
359    fn release_conn(&mut self, key: &Key, io: ConnectionType<Io>, created: Instant) {
360        self.acquired -= 1;
361        self.available
362            .entry(key.clone())
363            .or_insert_with(VecDeque::new)
364            .push_back(AvailableConnection {
365                io,
366                created,
367                used: Instant::now(),
368            });
369        self.check_availibility();
370    }
371
372    fn release_close(&mut self, io: ConnectionType<Io>) {
373        self.acquired -= 1;
374        if let Some(timeout) = self.disconnect_timeout {
375            if let ConnectionType::H1(io) = io {
376                actori_rt::spawn(CloseConnection::new(io, timeout))
377            }
378        }
379        self.check_availibility();
380    }
381
382    fn check_availibility(&self) {
383        if !self.waiters_queue.is_empty() && self.acquired < self.limit {
384            self.waker.wake();
385        }
386    }
387}
388
389struct CloseConnection<T> {
390    io: T,
391    timeout: Delay,
392}
393
394impl<T> CloseConnection<T>
395where
396    T: AsyncWrite + Unpin,
397{
398    fn new(io: T, timeout: Duration) -> Self {
399        CloseConnection {
400            io,
401            timeout: delay_for(timeout),
402        }
403    }
404}
405
406impl<T> Future for CloseConnection<T>
407where
408    T: AsyncWrite + Unpin,
409{
410    type Output = ();
411
412    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
413        let this = self.get_mut();
414
415        match Pin::new(&mut this.timeout).poll(cx) {
416            Poll::Ready(_) => Poll::Ready(()),
417            Poll::Pending => match Pin::new(&mut this.io).poll_shutdown(cx) {
418                Poll::Ready(_) => Poll::Ready(()),
419                Poll::Pending => Poll::Pending,
420            },
421        }
422    }
423}
424
425struct ConnectorPoolSupport<T, Io>
426where
427    Io: AsyncRead + AsyncWrite + Unpin + 'static,
428{
429    connector: T,
430    inner: Rc<RefCell<Inner<Io>>>,
431}
432
433impl<T, Io> Future for ConnectorPoolSupport<T, Io>
434where
435    Io: AsyncRead + AsyncWrite + Unpin + 'static,
436    T: Service<Request = Connect, Response = (Io, Protocol), Error = ConnectError>,
437    T::Future: 'static,
438{
439    type Output = ();
440
441    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
442        let this = unsafe { self.get_unchecked_mut() };
443
444        let mut inner = this.inner.as_ref().borrow_mut();
445        inner.waker.register(cx.waker());
446
447        // check waiters
448        loop {
449            let (key, token) = {
450                if let Some((key, token)) = inner.waiters_queue.get_index(0) {
451                    (key.clone(), *token)
452                } else {
453                    break;
454                }
455            };
456            if inner.waiters.get(token).unwrap().is_none() {
457                continue;
458            }
459
460            match inner.acquire(&key, cx) {
461                Acquire::NotAvailable => break,
462                Acquire::Acquired(io, created) => {
463                    let tx = inner.waiters.get_mut(token).unwrap().take().unwrap().1;
464                    if let Err(conn) = tx.send(Ok(IoConnection::new(
465                        io,
466                        created,
467                        Some(Acquired(key.clone(), Some(this.inner.clone()))),
468                    ))) {
469                        let (io, created) = conn.unwrap().into_inner();
470                        inner.release_conn(&key, io, created);
471                    }
472                }
473                Acquire::Available => {
474                    let (connect, tx) =
475                        inner.waiters.get_mut(token).unwrap().take().unwrap();
476                    OpenWaitingConnection::spawn(
477                        key.clone(),
478                        tx,
479                        this.inner.clone(),
480                        this.connector.call(connect),
481                    );
482                }
483            }
484            let _ = inner.waiters_queue.swap_remove_index(0);
485        }
486
487        Poll::Pending
488    }
489}
490
491struct OpenWaitingConnection<F, Io>
492where
493    Io: AsyncRead + AsyncWrite + Unpin + 'static,
494{
495    fut: F,
496    key: Key,
497    h2: Option<
498        LocalBoxFuture<
499            'static,
500            Result<(SendRequest<Bytes>, Connection<Io, Bytes>), h2::Error>,
501        >,
502    >,
503    rx: Option<oneshot::Sender<Result<IoConnection<Io>, ConnectError>>>,
504    inner: Option<Rc<RefCell<Inner<Io>>>>,
505}
506
507impl<F, Io> OpenWaitingConnection<F, Io>
508where
509    F: Future<Output = Result<(Io, Protocol), ConnectError>> + 'static,
510    Io: AsyncRead + AsyncWrite + Unpin + 'static,
511{
512    fn spawn(
513        key: Key,
514        rx: oneshot::Sender<Result<IoConnection<Io>, ConnectError>>,
515        inner: Rc<RefCell<Inner<Io>>>,
516        fut: F,
517    ) {
518        actori_rt::spawn(OpenWaitingConnection {
519            key,
520            fut,
521            h2: None,
522            rx: Some(rx),
523            inner: Some(inner),
524        })
525    }
526}
527
528impl<F, Io> Drop for OpenWaitingConnection<F, Io>
529where
530    Io: AsyncRead + AsyncWrite + Unpin + 'static,
531{
532    fn drop(&mut self) {
533        if let Some(inner) = self.inner.take() {
534            let mut inner = inner.as_ref().borrow_mut();
535            inner.release();
536            inner.check_availibility();
537        }
538    }
539}
540
541impl<F, Io> Future for OpenWaitingConnection<F, Io>
542where
543    F: Future<Output = Result<(Io, Protocol), ConnectError>>,
544    Io: AsyncRead + AsyncWrite + Unpin,
545{
546    type Output = ();
547
548    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
549        let this = unsafe { self.get_unchecked_mut() };
550
551        if let Some(ref mut h2) = this.h2 {
552            return match Pin::new(h2).poll(cx) {
553                Poll::Ready(Ok((snd, connection))) => {
554                    actori_rt::spawn(connection.map(|_| ()));
555                    let rx = this.rx.take().unwrap();
556                    let _ = rx.send(Ok(IoConnection::new(
557                        ConnectionType::H2(snd),
558                        Instant::now(),
559                        Some(Acquired(this.key.clone(), this.inner.take())),
560                    )));
561                    Poll::Ready(())
562                }
563                Poll::Pending => Poll::Pending,
564                Poll::Ready(Err(err)) => {
565                    let _ = this.inner.take();
566                    if let Some(rx) = this.rx.take() {
567                        let _ = rx.send(Err(ConnectError::H2(err)));
568                    }
569                    Poll::Ready(())
570                }
571            };
572        }
573
574        match unsafe { Pin::new_unchecked(&mut this.fut) }.poll(cx) {
575            Poll::Ready(Err(err)) => {
576                let _ = this.inner.take();
577                if let Some(rx) = this.rx.take() {
578                    let _ = rx.send(Err(err));
579                }
580                Poll::Ready(())
581            }
582            Poll::Ready(Ok((io, proto))) => {
583                if proto == Protocol::Http1 {
584                    let rx = this.rx.take().unwrap();
585                    let _ = rx.send(Ok(IoConnection::new(
586                        ConnectionType::H1(io),
587                        Instant::now(),
588                        Some(Acquired(this.key.clone(), this.inner.take())),
589                    )));
590                    Poll::Ready(())
591                } else {
592                    this.h2 = Some(handshake(io).boxed_local());
593                    unsafe { Pin::new_unchecked(this) }.poll(cx)
594                }
595            }
596            Poll::Pending => Poll::Pending,
597        }
598    }
599}
600
601pub(crate) struct Acquired<T>(Key, Option<Rc<RefCell<Inner<T>>>>);
602
603impl<T> Acquired<T>
604where
605    T: AsyncRead + AsyncWrite + Unpin + 'static,
606{
607    pub(crate) fn close(&mut self, conn: IoConnection<T>) {
608        if let Some(inner) = self.1.take() {
609            let (io, _) = conn.into_inner();
610            inner.as_ref().borrow_mut().release_close(io);
611        }
612    }
613    pub(crate) fn release(&mut self, conn: IoConnection<T>) {
614        if let Some(inner) = self.1.take() {
615            let (io, created) = conn.into_inner();
616            inner
617                .as_ref()
618                .borrow_mut()
619                .release_conn(&self.0, io, created);
620        }
621    }
622}
623
624impl<T> Drop for Acquired<T> {
625    fn drop(&mut self) {
626        if let Some(inner) = self.1.take() {
627            inner.as_ref().borrow_mut().release();
628        }
629    }
630}