requiem_http/client/connection.rs

1use std::pin::Pin;
2use std::task::{Context, Poll};
3use std::{fmt, io, mem, time};
4
5use requiem_codec::{AsyncRead, AsyncWrite, Framed};
6use bytes::{Buf, Bytes};
7use futures_util::future::{err, Either, Future, FutureExt, LocalBoxFuture, Ready};
8use h2::client::SendRequest;
9use pin_project::{pin_project, project};
10
11use crate::body::MessageBody;
12use crate::h1::ClientCodec;
13use crate::message::{RequestHeadType, ResponseHead};
14use crate::payload::Payload;
15
16use super::error::SendRequestError;
17use super::pool::{Acquired, Protocol};
18use super::{h1proto, h2proto};
19
/// The protocol-specific transport wrapped by an `IoConnection`.
pub(crate) enum ConnectionType<Io> {
    /// HTTP/1.x: the connection owns the raw I/O object directly.
    H1(Io),
    /// HTTP/2: the connection holds an `h2` client request sender.
    H2(SendRequest<Bytes>),
}
24
25pub trait Connection {
26    type Io: AsyncRead + AsyncWrite + Unpin;
27    type Future: Future<Output = Result<(ResponseHead, Payload), SendRequestError>>;
28
29    fn protocol(&self) -> Protocol;
30
31    /// Send request and body
32    fn send_request<B: MessageBody + 'static, H: Into<RequestHeadType>>(
33        self,
34        head: H,
35        body: B,
36    ) -> Self::Future;
37
38    type TunnelFuture: Future<
39        Output = Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
40    >;
41
42    /// Send request, returns Response and Framed
43    fn open_tunnel<H: Into<RequestHeadType>>(self, head: H) -> Self::TunnelFuture;
44}
45
/// Lifecycle hooks for an I/O object used by a client connection.
///
/// NOTE(review): no implementor or caller is visible in this file — confirm
/// where it is implemented and that it is still in use.
pub(crate) trait ConnectionLifetime: AsyncRead + AsyncWrite + 'static {
    /// Close connection
    fn close(&mut self);

    /// Release connection to the connection pool
    fn release(&mut self);
}
53
#[doc(hidden)]
/// HTTP client connection
///
/// Bundles a protocol-specific transport with its creation timestamp and,
/// optionally, the pool handle it was acquired through.
pub struct IoConnection<T> {
    // `Some` from `new` until one of the `self`-consuming methods takes it.
    io: Option<ConnectionType<T>>,
    // Timestamp supplied to `new`; forwarded to the protocol drivers and
    // back into the pool when the connection is released.
    created: time::Instant,
    // Pool handle this connection came from, if any.
    pool: Option<Acquired<T>>,
}
61
62impl<T> fmt::Debug for IoConnection<T>
63where
64    T: fmt::Debug,
65{
66    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67        match self.io {
68            Some(ConnectionType::H1(ref io)) => write!(f, "H1Connection({:?})", io),
69            Some(ConnectionType::H2(_)) => write!(f, "H2Connection"),
70            None => write!(f, "Connection(Empty)"),
71        }
72    }
73}
74
75impl<T: AsyncRead + AsyncWrite + Unpin> IoConnection<T> {
76    pub(crate) fn new(
77        io: ConnectionType<T>,
78        created: time::Instant,
79        pool: Option<Acquired<T>>,
80    ) -> Self {
81        IoConnection {
82            pool,
83            created,
84            io: Some(io),
85        }
86    }
87
88    pub(crate) fn into_inner(self) -> (ConnectionType<T>, time::Instant) {
89        (self.io.unwrap(), self.created)
90    }
91}
92
93impl<T> Connection for IoConnection<T>
94where
95    T: AsyncRead + AsyncWrite + Unpin + 'static,
96{
97    type Io = T;
98    type Future =
99        LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>;
100
101    fn protocol(&self) -> Protocol {
102        match self.io {
103            Some(ConnectionType::H1(_)) => Protocol::Http1,
104            Some(ConnectionType::H2(_)) => Protocol::Http2,
105            None => Protocol::Http1,
106        }
107    }
108
109    fn send_request<B: MessageBody + 'static, H: Into<RequestHeadType>>(
110        mut self,
111        head: H,
112        body: B,
113    ) -> Self::Future {
114        match self.io.take().unwrap() {
115            ConnectionType::H1(io) => {
116                h1proto::send_request(io, head.into(), body, self.created, self.pool)
117                    .boxed_local()
118            }
119            ConnectionType::H2(io) => {
120                h2proto::send_request(io, head.into(), body, self.created, self.pool)
121                    .boxed_local()
122            }
123        }
124    }
125
126    type TunnelFuture = Either<
127        LocalBoxFuture<
128            'static,
129            Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
130        >,
131        Ready<Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>>,
132    >;
133
134    /// Send request, returns Response and Framed
135    fn open_tunnel<H: Into<RequestHeadType>>(mut self, head: H) -> Self::TunnelFuture {
136        match self.io.take().unwrap() {
137            ConnectionType::H1(io) => {
138                Either::Left(h1proto::open_tunnel(io, head.into()).boxed_local())
139            }
140            ConnectionType::H2(io) => {
141                if let Some(mut pool) = self.pool.take() {
142                    pool.release(IoConnection::new(
143                        ConnectionType::H2(io),
144                        self.created,
145                        None,
146                    ));
147                }
148                Either::Right(err(SendRequestError::TunnelNotSupported))
149            }
150        }
151    }
152}
153
// NOTE(review): `dead_code` is allowed presumably because this type is only
// constructed under some build/feature configurations — confirm and document.
#[allow(dead_code)]
/// A client connection over one of two I/O types, `A` or `B`.
pub(crate) enum EitherConnection<A, B> {
    /// Connection whose I/O is of type `A`.
    A(IoConnection<A>),
    /// Connection whose I/O is of type `B`.
    B(IoConnection<B>),
}
159
160impl<A, B> Connection for EitherConnection<A, B>
161where
162    A: AsyncRead + AsyncWrite + Unpin + 'static,
163    B: AsyncRead + AsyncWrite + Unpin + 'static,
164{
165    type Io = EitherIo<A, B>;
166    type Future =
167        LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>;
168
169    fn protocol(&self) -> Protocol {
170        match self {
171            EitherConnection::A(con) => con.protocol(),
172            EitherConnection::B(con) => con.protocol(),
173        }
174    }
175
176    fn send_request<RB: MessageBody + 'static, H: Into<RequestHeadType>>(
177        self,
178        head: H,
179        body: RB,
180    ) -> Self::Future {
181        match self {
182            EitherConnection::A(con) => con.send_request(head, body),
183            EitherConnection::B(con) => con.send_request(head, body),
184        }
185    }
186
187    type TunnelFuture = LocalBoxFuture<
188        'static,
189        Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
190    >;
191
192    /// Send request, returns Response and Framed
193    fn open_tunnel<H: Into<RequestHeadType>>(self, head: H) -> Self::TunnelFuture {
194        match self {
195            EitherConnection::A(con) => con
196                .open_tunnel(head)
197                .map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::A))))
198                .boxed_local(),
199            EitherConnection::B(con) => con
200                .open_tunnel(head)
201                .map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::B))))
202                .boxed_local(),
203        }
204    }
205}
206
/// An I/O object that is one of two underlying types; the `AsyncRead` and
/// `AsyncWrite` impls for it forward every call to the held variant.
#[pin_project]
pub enum EitherIo<A, B> {
    /// I/O of type `A` (structurally pinned).
    A(#[pin] A),
    /// I/O of type `B` (structurally pinned).
    B(#[pin] B),
}
212
213impl<A, B> AsyncRead for EitherIo<A, B>
214where
215    A: AsyncRead,
216    B: AsyncRead,
217{
218    #[project]
219    fn poll_read(
220        self: Pin<&mut Self>,
221        cx: &mut Context<'_>,
222        buf: &mut [u8],
223    ) -> Poll<io::Result<usize>> {
224        #[project]
225        match self.project() {
226            EitherIo::A(val) => val.poll_read(cx, buf),
227            EitherIo::B(val) => val.poll_read(cx, buf),
228        }
229    }
230
231    unsafe fn prepare_uninitialized_buffer(
232        &self,
233        buf: &mut [mem::MaybeUninit<u8>],
234    ) -> bool {
235        match self {
236            EitherIo::A(ref val) => val.prepare_uninitialized_buffer(buf),
237            EitherIo::B(ref val) => val.prepare_uninitialized_buffer(buf),
238        }
239    }
240}
241
242impl<A, B> AsyncWrite for EitherIo<A, B>
243where
244    A: AsyncWrite,
245    B: AsyncWrite,
246{
247    #[project]
248    fn poll_write(
249        self: Pin<&mut Self>,
250        cx: &mut Context<'_>,
251        buf: &[u8],
252    ) -> Poll<io::Result<usize>> {
253        #[project]
254        match self.project() {
255            EitherIo::A(val) => val.poll_write(cx, buf),
256            EitherIo::B(val) => val.poll_write(cx, buf),
257        }
258    }
259
260    #[project]
261    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
262        #[project]
263        match self.project() {
264            EitherIo::A(val) => val.poll_flush(cx),
265            EitherIo::B(val) => val.poll_flush(cx),
266        }
267    }
268
269    #[project]
270    fn poll_shutdown(
271        self: Pin<&mut Self>,
272        cx: &mut Context<'_>,
273    ) -> Poll<io::Result<()>> {
274        #[project]
275        match self.project() {
276            EitherIo::A(val) => val.poll_shutdown(cx),
277            EitherIo::B(val) => val.poll_shutdown(cx),
278        }
279    }
280
281    #[project]
282    fn poll_write_buf<U: Buf>(
283        self: Pin<&mut Self>,
284        cx: &mut Context<'_>,
285        buf: &mut U,
286    ) -> Poll<Result<usize, io::Error>>
287    where
288        Self: Sized,
289    {
290        #[project]
291        match self.project() {
292            EitherIo::A(val) => val.poll_write_buf(cx, buf),
293            EitherIo::B(val) => val.poll_write_buf(cx, buf),
294        }
295    }
296}