1use futures_util::ready;
4use hyper::service::HttpService;
5use std::future::Future;
6use std::io::{Error as IoError, ErrorKind, Result as IoResult};
7use std::marker::PhantomPinned;
8use std::mem::MaybeUninit;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::{error::Error as StdError, marker::Unpin, time::Duration};
12
13use bytes::Bytes;
14use http::{Request, Response};
15use http_body::Body;
16use hyper::{
17 body::Incoming,
18 rt::{bounds::Http2ServerConnExec, Read, ReadBuf, Timer, Write},
19 server::conn::{http1, http2},
20 service::Service,
21};
22use pin_project_lite::pin_project;
23
24use crate::common::rewind::Rewind;
25
/// Result type used by every connection future in this module; errors are boxed so that
/// service, body and IO errors can all be returned through one type.
type Result<T> = std::result::Result<T, Box<dyn StdError + Send + Sync>>;

/// The 24-byte client connection preface that opens every HTTP/2 connection.
const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
29
30#[derive(Clone, Debug)]
32pub struct Builder<E> {
33 http1: http1::Builder,
34 http2: http2::Builder<E>,
35}
36
37impl<E> Builder<E> {
38 pub fn new(executor: E) -> Self {
54 Self {
55 http1: http1::Builder::new(),
56 http2: http2::Builder::new(executor),
57 }
58 }
59
60 pub fn http1(&mut self) -> Http1Builder<'_, E> {
62 Http1Builder { inner: self }
63 }
64
65 pub fn http2(&mut self) -> Http2Builder<'_, E> {
67 Http2Builder { inner: self }
68 }
69
70 pub fn serve_connection<I, S, B>(&self, io: I, service: S) -> Connection<'_, I, S, E>
72 where
73 S: Service<Request<Incoming>, Response = Response<B>>,
74 S::Future: 'static,
75 S::Error: Into<Box<dyn StdError + Send + Sync>>,
76 B: Body + 'static,
77 B::Error: Into<Box<dyn StdError + Send + Sync>>,
78 I: Read + Write + Unpin + 'static,
79 E: Http2ServerConnExec<S::Future, B>,
80 {
81 Connection {
82 state: ConnState::ReadVersion {
83 read_version: read_version(io),
84 builder: self,
85 service: Some(service),
86 },
87 }
88 }
89
90 pub fn serve_connection_with_upgrades<I, S, B>(
94 &self,
95 io: I,
96 service: S,
97 ) -> UpgradeableConnection<'_, I, S, E>
98 where
99 S: Service<Request<Incoming>, Response = Response<B>>,
100 S::Future: 'static,
101 S::Error: Into<Box<dyn StdError + Send + Sync>>,
102 B: Body + 'static,
103 B::Error: Into<Box<dyn StdError + Send + Sync>>,
104 I: Read + Write + Unpin + Send + 'static,
105 E: Http2ServerConnExec<S::Future, B>,
106 {
107 UpgradeableConnection {
108 state: UpgradeableConnState::ReadVersion {
109 read_version: read_version(io),
110 builder: self,
111 service: Some(service),
112 },
113 }
114 }
115}
/// HTTP protocol version detected from the first bytes of a connection.
///
/// `Debug`, `PartialEq` and `Eq` are derived alongside `Copy`/`Clone` so the detected
/// version can be logged and compared directly instead of only being matched on.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum Version {
    /// The connection did not open with the full HTTP/2 preface.
    H1,
    /// The connection opened with the full HTTP/2 preface.
    H2,
}
121
122fn read_version<I>(io: I) -> ReadVersion<I>
123where
124 I: Read + Unpin,
125{
126 ReadVersion {
127 io: Some(io),
128 buf: [MaybeUninit::uninit(); 24],
129 filled: 0,
130 version: Version::H1,
131 _pin: PhantomPinned,
132 }
133}
134
135pin_project! {
136 struct ReadVersion<I> {
137 io: Option<I>,
138 buf: [MaybeUninit<u8>; 24],
139 filled: usize,
141 version: Version,
142 #[pin]
144 _pin: PhantomPinned,
145 }
146}
147
148impl<I> Future for ReadVersion<I>
149where
150 I: Read + Unpin,
151{
152 type Output = IoResult<(Version, Rewind<I>)>;
153
154 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
155 let this = self.project();
156
157 let mut buf = ReadBuf::uninit(&mut *this.buf);
158 unsafe {
161 buf.unfilled().advance(*this.filled);
162 };
163
164 while buf.filled().len() < H2_PREFACE.len() {
165 if buf.filled() != &H2_PREFACE[0..buf.filled().len()] {
166 let io = this.io.take().unwrap();
167 let buf = buf.filled().to_vec();
168 return Poll::Ready(Ok((
169 *this.version,
170 Rewind::new_buffered(io, Bytes::from(buf)),
171 )));
172 } else {
173 let len = buf.filled().len();
175 ready!(Pin::new(this.io.as_mut().unwrap()).poll_read(cx, buf.unfilled()))?;
176 *this.filled = buf.filled().len();
177 if buf.filled().len() == len {
178 return Err(IoError::new(ErrorKind::UnexpectedEof, "early eof")).into();
179 }
180 }
181 }
182 if buf.filled() == H2_PREFACE {
183 *this.version = Version::H2;
184 }
185 let io = this.io.take().unwrap();
186 let buf = buf.filled().to_vec();
187 Poll::Ready(Ok((
188 *this.version,
189 Rewind::new_buffered(io, Bytes::from(buf)),
190 )))
191 }
192}
193
194pin_project! {
195 pub struct Connection<'a, I, S, E>
197 where
198 S: HttpService<Incoming>,
199 {
200 #[pin]
201 state: ConnState<'a, I, S, E>,
202 }
203}
204
205pin_project! {
206 #[project = ConnStateProj]
207 enum ConnState<'a, I, S, E>
208 where
209 S: HttpService<Incoming>,
210 {
211 ReadVersion {
212 #[pin]
213 read_version: ReadVersion<I>,
214 builder: &'a Builder<E>,
215 service: Option<S>,
216 },
217 H1 {
218 #[pin]
219 conn: hyper::server::conn::http1::Connection<Rewind<I>, S>,
220 },
221 H2 {
222 #[pin]
223 conn: hyper::server::conn::http2::Connection<Rewind<I>, S, E>,
224 },
225 }
226}
227
228impl<I, S, E, B> Connection<'_, I, S, E>
229where
230 S: HttpService<Incoming, ResBody = B>,
231 S::Error: Into<Box<dyn StdError + Send + Sync>>,
232 I: Read + Write + Unpin,
233 B: Body + 'static,
234 B::Error: Into<Box<dyn StdError + Send + Sync>>,
235 E: Http2ServerConnExec<S::Future, B>,
236{
237 pub fn graceful_shutdown(self: Pin<&mut Self>) {
246 match self.project().state.project() {
247 ConnStateProj::ReadVersion { .. } => {}
248 ConnStateProj::H1 { conn } => conn.graceful_shutdown(),
249 ConnStateProj::H2 { conn } => conn.graceful_shutdown(),
250 }
251 }
252}
253
254impl<I, S, E, B> Future for Connection<'_, I, S, E>
255where
256 S: Service<Request<Incoming>, Response = Response<B>>,
257 S::Future: 'static,
258 S::Error: Into<Box<dyn StdError + Send + Sync>>,
259 B: Body + 'static,
260 B::Error: Into<Box<dyn StdError + Send + Sync>>,
261 I: Read + Write + Unpin + 'static,
262 E: Http2ServerConnExec<S::Future, B>,
263{
264 type Output = Result<()>;
265
266 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
267 loop {
268 let mut this = self.as_mut().project();
269
270 match this.state.as_mut().project() {
271 ConnStateProj::ReadVersion {
272 read_version,
273 builder,
274 service,
275 } => {
276 let (version, io) = ready!(read_version.poll(cx))?;
277 let service = service.take().unwrap();
278 match version {
279 Version::H1 => {
280 let conn = builder.http1.serve_connection(io, service);
281 this.state.set(ConnState::H1 { conn });
282 }
283 Version::H2 => {
284 let conn = builder.http2.serve_connection(io, service);
285 this.state.set(ConnState::H2 { conn });
286 }
287 }
288 }
289 ConnStateProj::H1 { conn } => {
290 return conn.poll(cx).map_err(Into::into);
291 }
292 ConnStateProj::H2 { conn } => {
293 return conn.poll(cx).map_err(Into::into);
294 }
295 }
296 }
297 }
298}
299
300pin_project! {
301 pub struct UpgradeableConnection<'a, I, S, E>
303 where
304 S: HttpService<Incoming>,
305 {
306 #[pin]
307 state: UpgradeableConnState<'a, I, S, E>,
308 }
309}
310
311pin_project! {
312 #[project = UpgradeableConnStateProj]
313 enum UpgradeableConnState<'a, I, S, E>
314 where
315 S: HttpService<Incoming>,
316 {
317 ReadVersion {
318 #[pin]
319 read_version: ReadVersion<I>,
320 builder: &'a Builder<E>,
321 service: Option<S>,
322 },
323 H1 {
324 #[pin]
325 conn: hyper::server::conn::http1::UpgradeableConnection<Rewind<I>, S>,
326 },
327 H2 {
328 #[pin]
329 conn: hyper::server::conn::http2::Connection<Rewind<I>, S, E>,
330 },
331 }
332}
333
334impl<I, S, E, B> UpgradeableConnection<'_, I, S, E>
335where
336 S: HttpService<Incoming, ResBody = B>,
337 S::Error: Into<Box<dyn StdError + Send + Sync>>,
338 I: Read + Write + Unpin,
339 B: Body + 'static,
340 B::Error: Into<Box<dyn StdError + Send + Sync>>,
341 E: Http2ServerConnExec<S::Future, B>,
342{
343 pub fn graceful_shutdown(self: Pin<&mut Self>) {
352 match self.project().state.project() {
353 UpgradeableConnStateProj::ReadVersion { .. } => {}
354 UpgradeableConnStateProj::H1 { conn } => conn.graceful_shutdown(),
355 UpgradeableConnStateProj::H2 { conn } => conn.graceful_shutdown(),
356 }
357 }
358}
359
360impl<I, S, E, B> Future for UpgradeableConnection<'_, I, S, E>
361where
362 S: Service<Request<Incoming>, Response = Response<B>>,
363 S::Future: 'static,
364 S::Error: Into<Box<dyn StdError + Send + Sync>>,
365 B: Body + 'static,
366 B::Error: Into<Box<dyn StdError + Send + Sync>>,
367 I: Read + Write + Unpin + Send + 'static,
368 E: Http2ServerConnExec<S::Future, B>,
369{
370 type Output = Result<()>;
371
372 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
373 loop {
374 let mut this = self.as_mut().project();
375
376 match this.state.as_mut().project() {
377 UpgradeableConnStateProj::ReadVersion {
378 read_version,
379 builder,
380 service,
381 } => {
382 let (version, io) = ready!(read_version.poll(cx))?;
383 let service = service.take().unwrap();
384 match version {
385 Version::H1 => {
386 let conn = builder.http1.serve_connection(io, service).with_upgrades();
387 this.state.set(UpgradeableConnState::H1 { conn });
388 }
389 Version::H2 => {
390 let conn = builder.http2.serve_connection(io, service);
391 this.state.set(UpgradeableConnState::H2 { conn });
392 }
393 }
394 }
395 UpgradeableConnStateProj::H1 { conn } => {
396 return conn.poll(cx).map_err(Into::into);
397 }
398 UpgradeableConnStateProj::H2 { conn } => {
399 return conn.poll(cx).map_err(Into::into);
400 }
401 }
402 }
403 }
404}
405
406pub struct Http1Builder<'a, E> {
408 inner: &'a mut Builder<E>,
409}
410
411impl<E> Http1Builder<'_, E> {
412 pub fn http2(&mut self) -> Http2Builder<'_, E> {
414 Http2Builder {
415 inner: &mut self.inner,
416 }
417 }
418
419 pub fn half_close(&mut self, val: bool) -> &mut Self {
428 self.inner.http1.half_close(val);
429 self
430 }
431
432 pub fn keep_alive(&mut self, val: bool) -> &mut Self {
436 self.inner.http1.keep_alive(val);
437 self
438 }
439
440 pub fn title_case_headers(&mut self, enabled: bool) -> &mut Self {
447 self.inner.http1.title_case_headers(enabled);
448 self
449 }
450
451 pub fn preserve_header_case(&mut self, enabled: bool) -> &mut Self {
465 self.inner.http1.preserve_header_case(enabled);
466 self
467 }
468
469 pub fn header_read_timeout(&mut self, read_timeout: Duration) -> &mut Self {
474 self.inner.http1.header_read_timeout(read_timeout);
475 self
476 }
477
478 pub fn writev(&mut self, val: bool) -> &mut Self {
491 self.inner.http1.writev(val);
492 self
493 }
494
495 pub fn max_buf_size(&mut self, max: usize) -> &mut Self {
503 self.inner.http1.max_buf_size(max);
504 self
505 }
506
507 pub fn pipeline_flush(&mut self, enabled: bool) -> &mut Self {
513 self.inner.http1.pipeline_flush(enabled);
514 self
515 }
516
517 pub fn timer<M>(&mut self, timer: M) -> &mut Self
519 where
520 M: Timer + Send + Sync + 'static,
521 {
522 self.inner.http1.timer(timer);
523 self
524 }
525
526 pub async fn serve_connection<I, S, B>(&self, io: I, service: S) -> Result<()>
528 where
529 S: Service<Request<Incoming>, Response = Response<B>>,
530 S::Future: 'static,
531 S::Error: Into<Box<dyn StdError + Send + Sync>>,
532 B: Body + 'static,
533 B::Error: Into<Box<dyn StdError + Send + Sync>>,
534 I: Read + Write + Unpin + 'static,
535 E: Http2ServerConnExec<S::Future, B>,
536 {
537 self.inner.serve_connection(io, service).await
538 }
539}
540
541pub struct Http2Builder<'a, E> {
543 inner: &'a mut Builder<E>,
544}
545
546impl<E> Http2Builder<'_, E> {
547 pub fn http1(&mut self) -> Http1Builder<'_, E> {
549 Http1Builder {
550 inner: &mut self.inner,
551 }
552 }
553
554 pub fn initial_stream_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self {
563 self.inner.http2.initial_stream_window_size(sz);
564 self
565 }
566
567 pub fn initial_connection_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self {
573 self.inner.http2.initial_connection_window_size(sz);
574 self
575 }
576
577 pub fn adaptive_window(&mut self, enabled: bool) -> &mut Self {
583 self.inner.http2.adaptive_window(enabled);
584 self
585 }
586
587 pub fn max_frame_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self {
593 self.inner.http2.max_frame_size(sz);
594 self
595 }
596
597 pub fn max_concurrent_streams(&mut self, max: impl Into<Option<u32>>) -> &mut Self {
604 self.inner.http2.max_concurrent_streams(max);
605 self
606 }
607
608 pub fn keep_alive_interval(&mut self, interval: impl Into<Option<Duration>>) -> &mut Self {
618 self.inner.http2.keep_alive_interval(interval);
619 self
620 }
621
622 pub fn keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self {
632 self.inner.http2.keep_alive_timeout(timeout);
633 self
634 }
635
636 pub fn max_send_buf_size(&mut self, max: usize) -> &mut Self {
644 self.inner.http2.max_send_buf_size(max);
645 self
646 }
647
648 pub fn enable_connect_protocol(&mut self) -> &mut Self {
652 self.inner.http2.enable_connect_protocol();
653 self
654 }
655
656 pub fn max_header_list_size(&mut self, max: u32) -> &mut Self {
660 self.inner.http2.max_header_list_size(max);
661 self
662 }
663
664 pub fn timer<M>(&mut self, timer: M) -> &mut Self
666 where
667 M: Timer + Send + Sync + 'static,
668 {
669 self.inner.http2.timer(timer);
670 self
671 }
672
673 pub async fn serve_connection<I, S, B>(&self, io: I, service: S) -> Result<()>
675 where
676 S: Service<Request<Incoming>, Response = Response<B>>,
677 S::Future: 'static,
678 S::Error: Into<Box<dyn StdError + Send + Sync>>,
679 B: Body + 'static,
680 B::Error: Into<Box<dyn StdError + Send + Sync>>,
681 I: Read + Write + Unpin + 'static,
682 E: Http2ServerConnExec<S::Future, B>,
683 {
684 self.inner.serve_connection(io, service).await
685 }
686}
687
#[cfg(test)]
mod tests {
    use crate::{
        rt::{TokioExecutor, TokioIo},
        server::conn::auto,
    };
    use http::{Request, Response};
    use http_body::Body;
    use http_body_util::{BodyExt, Empty, Full};
    use hyper::{body, body::Bytes, client, service::service_fn};
    use std::{convert::Infallible, error::Error as StdError, net::SocketAddr};
    use tokio::net::{TcpListener, TcpStream};

    // Payload carried by every response from the test server.
    const BODY: &[u8] = b"Hello, world!";

    #[test]
    fn configuration() {
        // Chained form: hop between protocol sections within one expression.
        auto::Builder::new(TokioExecutor::new())
            .http1()
            .keep_alive(true)
            .http2()
            .keep_alive_interval(None);

        // Statement form: configure each protocol separately on one builder.
        let mut builder = auto::Builder::new(TokioExecutor::new());
        builder.http1().keep_alive(true);
        builder.http2().keep_alive_interval(None);
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn http1() {
        let addr = start_server().await;
        let mut sender = connect_h1(addr).await;

        let request = Request::new(Empty::<Bytes>::new());
        let response = sender.send_request(request).await.unwrap();
        let collected = response.into_body().collect().await.unwrap();

        assert_eq!(collected.to_bytes(), BODY);
    }

    #[cfg(not(miri))]
    #[tokio::test]
    async fn http2() {
        let addr = start_server().await;
        let mut sender = connect_h2(addr).await;

        let request = Request::new(Empty::<Bytes>::new());
        let response = sender.send_request(request).await.unwrap();
        let collected = response.into_body().collect().await.unwrap();

        assert_eq!(collected.to_bytes(), BODY);
    }

    // Opens a TCP connection to `addr` and performs an HTTP/1 client handshake; the
    // connection driver is spawned onto the runtime.
    async fn connect_h1<B>(addr: SocketAddr) -> client::conn::http1::SendRequest<B>
    where
        B: Body + Send + 'static,
        B::Data: Send,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
    {
        let tcp = TcpStream::connect(addr).await.unwrap();
        let (sender, driver) = client::conn::http1::handshake(TokioIo::new(tcp))
            .await
            .unwrap();
        tokio::spawn(driver);
        sender
    }

    // Opens a TCP connection to `addr` and performs an HTTP/2 client handshake; the
    // connection driver is spawned onto the runtime.
    async fn connect_h2<B>(addr: SocketAddr) -> client::conn::http2::SendRequest<B>
    where
        B: Body + Unpin + Send + 'static,
        B::Data: Send,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
    {
        let tcp = TcpStream::connect(addr).await.unwrap();
        let (sender, driver) = client::conn::http2::Builder::new(TokioExecutor::new())
            .handshake(TokioIo::new(tcp))
            .await
            .unwrap();
        tokio::spawn(driver);
        sender
    }

    // Binds an ephemeral localhost port and serves every accepted connection with the
    // auto builder and `hello`. Returns the bound address.
    async fn start_server() -> SocketAddr {
        let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
            .await
            .unwrap();
        let local_addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            loop {
                let (tcp, _peer) = listener.accept().await.unwrap();
                let io = TokioIo::new(tcp);
                tokio::task::spawn(async move {
                    let _ = auto::Builder::new(TokioExecutor::new())
                        .serve_connection(io, service_fn(hello))
                        .await;
                });
            }
        });

        local_addr
    }

    // Service that answers every request with `BODY`.
    async fn hello(_req: Request<body::Incoming>) -> Result<Response<Full<Bytes>>, Infallible> {
        let payload = Full::new(Bytes::from(BODY));
        Ok(Response::new(payload))
    }
}