1use std::pin::Pin;
2use std::task::{Context, Poll};
3use std::{fmt, io, mem, time};
4
5use requiem_codec::{AsyncRead, AsyncWrite, Framed};
6use bytes::{Buf, Bytes};
7use futures_util::future::{err, Either, Future, FutureExt, LocalBoxFuture, Ready};
8use h2::client::SendRequest;
9use pin_project::{pin_project, project};
10
11use crate::body::MessageBody;
12use crate::h1::ClientCodec;
13use crate::message::{RequestHeadType, ResponseHead};
14use crate::payload::Payload;
15
16use super::error::SendRequestError;
17use super::pool::{Acquired, Protocol};
18use super::{h1proto, h2proto};
19
20pub(crate) enum ConnectionType<Io> {
21 H1(Io),
22 H2(SendRequest<Bytes>),
23}
24
25pub trait Connection {
26 type Io: AsyncRead + AsyncWrite + Unpin;
27 type Future: Future<Output = Result<(ResponseHead, Payload), SendRequestError>>;
28
29 fn protocol(&self) -> Protocol;
30
31 fn send_request<B: MessageBody + 'static, H: Into<RequestHeadType>>(
33 self,
34 head: H,
35 body: B,
36 ) -> Self::Future;
37
38 type TunnelFuture: Future<
39 Output = Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
40 >;
41
42 fn open_tunnel<H: Into<RequestHeadType>>(self, head: H) -> Self::TunnelFuture;
44}
45
46pub(crate) trait ConnectionLifetime: AsyncRead + AsyncWrite + 'static {
47 fn close(&mut self);
49
50 fn release(&mut self);
52}
53
#[doc(hidden)]
/// HTTP client connection: either an HTTP/1 transport or an HTTP/2 send
/// handle, plus the pool handle it was acquired with (if any).
///
/// `io` is set to `Some` by `IoConnection::new` and is only taken by
/// `send_request` / `open_tunnel`, both of which consume `self`.
pub struct IoConnection<T> {
    /// The wrapped connection; `None` only after it has been taken for use.
    io: Option<ConnectionType<T>>,
    /// Timestamp supplied at construction; forwarded to the protocol
    /// handlers and to the pool when the connection is released.
    created: time::Instant,
    /// Handle from the connection pool (`super::pool::Acquired`), if any.
    // Fields drop in declaration order, so this is dropped after `io`.
    pool: Option<Acquired<T>>,
}
61
62impl<T> fmt::Debug for IoConnection<T>
63where
64 T: fmt::Debug,
65{
66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67 match self.io {
68 Some(ConnectionType::H1(ref io)) => write!(f, "H1Connection({:?})", io),
69 Some(ConnectionType::H2(_)) => write!(f, "H2Connection"),
70 None => write!(f, "Connection(Empty)"),
71 }
72 }
73}
74
75impl<T: AsyncRead + AsyncWrite + Unpin> IoConnection<T> {
76 pub(crate) fn new(
77 io: ConnectionType<T>,
78 created: time::Instant,
79 pool: Option<Acquired<T>>,
80 ) -> Self {
81 IoConnection {
82 pool,
83 created,
84 io: Some(io),
85 }
86 }
87
88 pub(crate) fn into_inner(self) -> (ConnectionType<T>, time::Instant) {
89 (self.io.unwrap(), self.created)
90 }
91}
92
93impl<T> Connection for IoConnection<T>
94where
95 T: AsyncRead + AsyncWrite + Unpin + 'static,
96{
97 type Io = T;
98 type Future =
99 LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>;
100
101 fn protocol(&self) -> Protocol {
102 match self.io {
103 Some(ConnectionType::H1(_)) => Protocol::Http1,
104 Some(ConnectionType::H2(_)) => Protocol::Http2,
105 None => Protocol::Http1,
106 }
107 }
108
109 fn send_request<B: MessageBody + 'static, H: Into<RequestHeadType>>(
110 mut self,
111 head: H,
112 body: B,
113 ) -> Self::Future {
114 match self.io.take().unwrap() {
115 ConnectionType::H1(io) => {
116 h1proto::send_request(io, head.into(), body, self.created, self.pool)
117 .boxed_local()
118 }
119 ConnectionType::H2(io) => {
120 h2proto::send_request(io, head.into(), body, self.created, self.pool)
121 .boxed_local()
122 }
123 }
124 }
125
126 type TunnelFuture = Either<
127 LocalBoxFuture<
128 'static,
129 Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
130 >,
131 Ready<Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>>,
132 >;
133
134 fn open_tunnel<H: Into<RequestHeadType>>(mut self, head: H) -> Self::TunnelFuture {
136 match self.io.take().unwrap() {
137 ConnectionType::H1(io) => {
138 Either::Left(h1proto::open_tunnel(io, head.into()).boxed_local())
139 }
140 ConnectionType::H2(io) => {
141 if let Some(mut pool) = self.pool.take() {
142 pool.release(IoConnection::new(
143 ConnectionType::H2(io),
144 self.created,
145 None,
146 ));
147 }
148 Either::Right(err(SendRequestError::TunnelNotSupported))
149 }
150 }
151 }
152}
153
154#[allow(dead_code)]
155pub(crate) enum EitherConnection<A, B> {
156 A(IoConnection<A>),
157 B(IoConnection<B>),
158}
159
160impl<A, B> Connection for EitherConnection<A, B>
161where
162 A: AsyncRead + AsyncWrite + Unpin + 'static,
163 B: AsyncRead + AsyncWrite + Unpin + 'static,
164{
165 type Io = EitherIo<A, B>;
166 type Future =
167 LocalBoxFuture<'static, Result<(ResponseHead, Payload), SendRequestError>>;
168
169 fn protocol(&self) -> Protocol {
170 match self {
171 EitherConnection::A(con) => con.protocol(),
172 EitherConnection::B(con) => con.protocol(),
173 }
174 }
175
176 fn send_request<RB: MessageBody + 'static, H: Into<RequestHeadType>>(
177 self,
178 head: H,
179 body: RB,
180 ) -> Self::Future {
181 match self {
182 EitherConnection::A(con) => con.send_request(head, body),
183 EitherConnection::B(con) => con.send_request(head, body),
184 }
185 }
186
187 type TunnelFuture = LocalBoxFuture<
188 'static,
189 Result<(ResponseHead, Framed<Self::Io, ClientCodec>), SendRequestError>,
190 >;
191
192 fn open_tunnel<H: Into<RequestHeadType>>(self, head: H) -> Self::TunnelFuture {
194 match self {
195 EitherConnection::A(con) => con
196 .open_tunnel(head)
197 .map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::A))))
198 .boxed_local(),
199 EitherConnection::B(con) => con
200 .open_tunnel(head)
201 .map(|res| res.map(|(head, framed)| (head, framed.map_io(EitherIo::B))))
202 .boxed_local(),
203 }
204 }
205}
206
207#[pin_project]
208pub enum EitherIo<A, B> {
209 A(#[pin] A),
210 B(#[pin] B),
211}
212
/// Reads from whichever transport this `EitherIo` wraps.
impl<A, B> AsyncRead for EitherIo<A, B>
where
    A: AsyncRead,
    B: AsyncRead,
{
    // pin-project's `#[project]` attributes let the `EitherIo::A` /
    // `EitherIo::B` patterns below match on the projection returned by
    // `self.project()`, yielding pinned references to the `#[pin]` fields.
    #[project]
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        #[project]
        match self.project() {
            EitherIo::A(val) => val.poll_read(cx, buf),
            EitherIo::B(val) => val.poll_read(cx, buf),
        }
    }

    /// Forwards to the wrapped transport's `prepare_uninitialized_buffer`.
    ///
    /// # Safety
    ///
    /// Carries the same contract as the inner type's implementation; this
    /// method adds no behavior of its own.
    unsafe fn prepare_uninitialized_buffer(
        &self,
        buf: &mut [mem::MaybeUninit<u8>],
    ) -> bool {
        match self {
            EitherIo::A(ref val) => val.prepare_uninitialized_buffer(buf),
            EitherIo::B(ref val) => val.prepare_uninitialized_buffer(buf),
        }
    }
}
241
/// Writes to whichever transport this `EitherIo` wraps.
///
/// Every method forwards to the same method on the wrapped value.
impl<A, B> AsyncWrite for EitherIo<A, B>
where
    A: AsyncWrite,
    B: AsyncWrite,
{
    // `#[project]` (pin-project) lets the `EitherIo::A` / `EitherIo::B`
    // patterns match on the pinned projection from `self.project()`.
    #[project]
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        #[project]
        match self.project() {
            EitherIo::A(val) => val.poll_write(cx, buf),
            EitherIo::B(val) => val.poll_write(cx, buf),
        }
    }

    #[project]
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        #[project]
        match self.project() {
            EitherIo::A(val) => val.poll_flush(cx),
            EitherIo::B(val) => val.poll_flush(cx),
        }
    }

    #[project]
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<()>> {
        #[project]
        match self.project() {
            EitherIo::A(val) => val.poll_shutdown(cx),
            EitherIo::B(val) => val.poll_shutdown(cx),
        }
    }

    // Overridden so the wrapped type's own `poll_write_buf` is called
    // directly rather than going through the trait's provided method.
    #[project]
    fn poll_write_buf<U: Buf>(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut U,
    ) -> Poll<Result<usize, io::Error>>
    where
        Self: Sized,
    {
        #[project]
        match self.project() {
            EitherIo::A(val) => val.poll_write_buf(cx, buf),
            EitherIo::B(val) => val.poll_write_buf(cx, buf),
        }
    }
}