rama_http_core/server/conn/
auto.rs

1//! Http1 or Http2 connection.
2
3use futures_util::ready;
4use std::marker::PhantomPinned;
5use std::mem::MaybeUninit;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::{io, time::Duration};
9use tokio::io::AsyncRead;
10use tokio::io::AsyncWrite;
11use tokio::io::ReadBuf;
12
13use bytes::Bytes;
14use pin_project_lite::pin_project;
15
16use crate::body::Incoming;
17use crate::common::io::Rewind;
18use crate::service::HttpService;
19use rama_core::error::BoxError;
20use rama_core::rt::Executor;
21
22use super::{http1, http2};
23
24type Result<T> = std::result::Result<T, BoxError>;
25
/// The HTTP/2 client connection preface. A connection whose first bytes match this
/// sequence is served as HTTP/2; any deviation (or an early EOF) falls back to HTTP/1.
const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
27
28/// Http1 or Http2 connection builder.
29#[derive(Clone, Debug)]
30pub struct Builder {
31    http1: http1::Builder,
32    http2: http2::Builder,
33    version: Option<Version>,
34}
35
36impl Builder {
37    /// Create a new auto connection builder.
38    pub fn new(executor: Executor) -> Self {
39        Self {
40            http1: http1::Builder::new(),
41            http2: http2::Builder::new(executor),
42            version: None,
43        }
44    }
45
46    /// Http1 configuration.
47    pub fn http1(&mut self) -> Http1Builder<'_> {
48        Http1Builder { inner: self }
49    }
50
51    /// Http2 configuration.
52    pub fn http2(&mut self) -> Http2Builder<'_> {
53        Http2Builder { inner: self }
54    }
55
56    /// Only accepts HTTP/2
57    ///
58    /// Does not do anything if used with [`serve_connection_with_upgrades`]
59    ///
60    /// [`serve_connection_with_upgrades`]: Builder::serve_connection_with_upgrades
61    pub fn http2_only(mut self) -> Self {
62        assert!(self.version.is_none());
63        self.version = Some(Version::H2);
64        self
65    }
66
67    /// Only accepts HTTP/1
68    ///
69    /// Does not do anything if used with [`serve_connection_with_upgrades`]
70    ///
71    /// [`serve_connection_with_upgrades`]: Builder::serve_connection_with_upgrades
72    pub fn http1_only(mut self) -> Self {
73        assert!(self.version.is_none());
74        self.version = Some(Version::H1);
75        self
76    }
77
78    /// Returns `true` if this builder can serve an HTTP/1.1-based connection.
79    pub fn is_http1_available(&self) -> bool {
80        match self.version {
81            Some(Version::H1) => true,
82            Some(Version::H2) => false,
83            _ => true,
84        }
85    }
86
87    /// Returns `true` if this builder can serve an HTTP/2-based connection.
88    pub fn is_http2_available(&self) -> bool {
89        match self.version {
90            Some(Version::H1) => false,
91            Some(Version::H2) => true,
92            _ => true,
93        }
94    }
95
96    /// Bind a connection together with a [`Service`].
97    pub fn serve_connection<I, S>(&self, io: I, service: S) -> Connection<'_, I, S>
98    where
99        S: HttpService<Incoming>,
100        I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
101    {
102        let state = match self.version {
103            Some(Version::H1) => {
104                let io = Rewind::new_buffered(io, Bytes::new());
105                let conn = self.http1.serve_connection(io, service);
106                ConnState::H1 { conn }
107            }
108            Some(Version::H2) => {
109                let io = Rewind::new_buffered(io, Bytes::new());
110                let conn = self.http2.serve_connection(io, service);
111                ConnState::H2 { conn }
112            }
113            _ => ConnState::ReadVersion {
114                read_version: read_version(io),
115                builder: Cow::Borrowed(self),
116                service: Some(service),
117            },
118        };
119
120        Connection { state }
121    }
122
123    /// Bind a connection together with a [`Service`], with the ability to
124    /// handle HTTP upgrades. This requires that the IO object implements
125    /// `Send`.
126    pub fn serve_connection_with_upgrades<I, S>(
127        &self,
128        io: I,
129        service: S,
130    ) -> UpgradeableConnection<'_, I, S>
131    where
132        S: HttpService<Incoming>,
133        I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
134    {
135        UpgradeableConnection {
136            state: UpgradeableConnState::ReadVersion {
137                read_version: read_version(io),
138                builder: Cow::Borrowed(self),
139                service: Some(service),
140            },
141        }
142    }
143}
144
/// The HTTP protocol family a connection is served with.
#[derive(Copy, Clone, Debug)]
enum Version {
    /// HTTP/1.x
    H1,
    /// HTTP/2
    H2,
}
150
151fn read_version<I>(io: I) -> ReadVersion<I>
152where
153    I: AsyncRead + Unpin,
154{
155    ReadVersion {
156        io: Some(io),
157        buf: [MaybeUninit::uninit(); 24],
158        filled: 0,
159        version: Version::H2,
160        cancelled: false,
161        _pin: PhantomPinned,
162    }
163}
164
165pin_project! {
166    struct ReadVersion<I> {
167        io: Option<I>,
168        buf: [MaybeUninit<u8>; 24],
169        // the amount of `buf` thats been filled
170        filled: usize,
171        version: Version,
172        cancelled: bool,
173        // Make this future `!Unpin` for compatibility with async trait methods.
174        #[pin]
175        _pin: PhantomPinned,
176    }
177}
178
179impl<I> ReadVersion<I> {
180    pub fn cancel(self: Pin<&mut Self>) {
181        *self.project().cancelled = true;
182    }
183}
184
185impl<I> Future for ReadVersion<I>
186where
187    I: AsyncRead + Unpin,
188{
189    type Output = io::Result<(Version, Rewind<I>)>;
190
191    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
192        let this = self.project();
193        if *this.cancelled {
194            return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "Cancelled")));
195        }
196
197        let mut buf = ReadBuf::uninit(&mut *this.buf);
198        buf.advance(*this.filled);
199
200        // We start as H2 and switch to H1 as soon as we don't have the preface.
201        while buf.filled().len() < H2_PREFACE.len() {
202            let len = buf.filled().len();
203            ready!(Pin::new(this.io.as_mut().unwrap()).poll_read(cx, &mut buf))?;
204            *this.filled = buf.filled().len();
205
206            // We starts as H2 and switch to H1 when we don't get the preface.
207            if buf.filled().len() == len
208                || buf.filled()[len..] != H2_PREFACE[len..buf.filled().len()]
209            {
210                *this.version = Version::H1;
211                break;
212            }
213        }
214
215        let io = this.io.take().unwrap();
216        let buf = buf.filled().to_vec();
217        Poll::Ready(Ok((
218            *this.version,
219            Rewind::new_buffered(io, Bytes::from(buf)),
220        )))
221    }
222}
223
224pin_project! {
225    /// A [`Future`](core::future::Future) representing an HTTP/1 connection, returned from
226    /// [`Builder::serve_connection`](struct.Builder.html#method.serve_connection).
227    ///
228    /// To drive HTTP on this connection this future **must be polled**, typically with
229    /// `.await`. If it isn't polled, no progress will be made on this connection.
230    #[must_use = "futures do nothing unless polled"]
231    pub struct Connection<'a, I, S>
232    where
233        S: HttpService<Incoming>,
234    {
235        #[pin]
236        state: ConnState<'a, I, S>,
237    }
238}
239
/// Minimal borrowed-or-owned holder.
///
/// `std::borrow::Cow` carries a `ToOwned` bound that is more demanding than
/// needed here, so this local variant is used instead.
enum Cow<'a, T> {
    Borrowed(&'a T),
    Owned(T),
}

impl<T> std::ops::Deref for Cow<'_, T> {
    type Target = T;

    /// Gives access to the held value regardless of which variant holds it.
    fn deref(&self) -> &T {
        let inner: &T = match self {
            Self::Owned(owned) => owned,
            Self::Borrowed(borrowed) => borrowed,
        };
        inner
    }
}
255
256type Http1Connection<I, S> = http1::Connection<Rewind<I>, S>;
257
258type Http2Connection<I, S> = http2::Connection<Rewind<I>, S>;
259
260pin_project! {
261    #[project = ConnStateProj]
262    enum ConnState<'a, I, S>
263    where
264        S: HttpService<Incoming>,
265    {
266        ReadVersion {
267            #[pin]
268            read_version: ReadVersion<I>,
269            builder: Cow<'a, Builder>,
270            service: Option<S>,
271        },
272        H1 {
273            #[pin]
274            conn: Http1Connection<I, S>,
275        },
276        H2 {
277            #[pin]
278            conn: Http2Connection<I, S>,
279        },
280    }
281}
282
283impl<I, S> Connection<'_, I, S>
284where
285    S: HttpService<Incoming>,
286    I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
287{
288    /// Start a graceful shutdown process for this connection.
289    ///
290    /// This `Connection` should continue to be polled until shutdown can finish.
291    ///
292    /// # Note
293    ///
294    /// This should only be called while the `Connection` future is still pending. If called after
295    /// `Connection::poll` has resolved, this does nothing.
296    pub fn graceful_shutdown(self: Pin<&mut Self>) {
297        match self.project().state.project() {
298            ConnStateProj::ReadVersion { read_version, .. } => read_version.cancel(),
299            ConnStateProj::H1 { conn } => conn.graceful_shutdown(),
300            ConnStateProj::H2 { conn } => conn.graceful_shutdown(),
301        }
302    }
303
304    /// Make this Connection static, instead of borrowing from Builder.
305    pub fn into_owned(self) -> Connection<'static, I, S>
306    where
307        Builder: Clone,
308    {
309        Connection {
310            state: match self.state {
311                ConnState::ReadVersion {
312                    read_version,
313                    builder,
314                    service,
315                } => ConnState::ReadVersion {
316                    read_version,
317                    service,
318                    builder: Cow::Owned(builder.clone()),
319                },
320                ConnState::H1 { conn } => ConnState::H1 { conn },
321                ConnState::H2 { conn } => ConnState::H2 { conn },
322            },
323        }
324    }
325}
326
327impl<I, S> Future for Connection<'_, I, S>
328where
329    S: HttpService<Incoming>,
330    I: AsyncRead + AsyncWrite + Send + Unpin + 'static + 'static,
331{
332    type Output = Result<()>;
333
334    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
335        loop {
336            let mut this = self.as_mut().project();
337
338            match this.state.as_mut().project() {
339                ConnStateProj::ReadVersion {
340                    read_version,
341                    builder,
342                    service,
343                } => {
344                    let (version, io) = ready!(read_version.poll(cx))?;
345                    let service = service.take().unwrap();
346                    match version {
347                        Version::H1 => {
348                            let conn = builder.http1.serve_connection(io, service);
349                            this.state.set(ConnState::H1 { conn });
350                        }
351                        Version::H2 => {
352                            let conn = builder.http2.serve_connection(io, service);
353                            this.state.set(ConnState::H2 { conn });
354                        }
355                    }
356                }
357                ConnStateProj::H1 { conn } => {
358                    return conn.poll(cx).map_err(Into::into);
359                }
360                ConnStateProj::H2 { conn } => {
361                    return conn.poll(cx).map_err(Into::into);
362                }
363            }
364        }
365    }
366}
367
368pin_project! {
369    /// An upgradable [`Connection`], returned by
370    /// [`Builder::serve_upgradable_connection`](struct.Builder.html#method.serve_connection_with_upgrades).
371    ///
372    /// To drive HTTP on this connection this future **must be polled**, typically with
373    /// `.await`. If it isn't polled, no progress will be made on this connection.
374    #[must_use = "futures do nothing unless polled"]
375    pub struct UpgradeableConnection<'a, I, S>
376    where
377        S: HttpService<Incoming>,
378    {
379        #[pin]
380        state: UpgradeableConnState<'a, I, S>,
381    }
382}
383
384type Http1UpgradeableConnection<I, S> = http1::UpgradeableConnection<I, S>;
385
386pin_project! {
387    #[project = UpgradeableConnStateProj]
388    enum UpgradeableConnState<'a, I, S>
389    where
390        S: HttpService<Incoming>,
391    {
392        ReadVersion {
393            #[pin]
394            read_version: ReadVersion<I>,
395            builder: Cow<'a, Builder>,
396            service: Option<S>,
397        },
398        H1 {
399            #[pin]
400            conn: Http1UpgradeableConnection<Rewind<I>, S>,
401        },
402        H2 {
403            #[pin]
404            conn: Http2Connection<I, S>,
405        },
406    }
407}
408
409impl<I, S> UpgradeableConnection<'_, I, S>
410where
411    S: HttpService<Incoming>,
412    I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
413{
414    /// Start a graceful shutdown process for this connection.
415    ///
416    /// This `UpgradeableConnection` should continue to be polled until shutdown can finish.
417    ///
418    /// # Note
419    ///
420    /// This should only be called while the `Connection` future is still nothing. pending. If
421    /// called after `UpgradeableConnection::poll` has resolved, this does nothing.
422    pub fn graceful_shutdown(self: Pin<&mut Self>) {
423        match self.project().state.project() {
424            UpgradeableConnStateProj::ReadVersion { read_version, .. } => read_version.cancel(),
425            UpgradeableConnStateProj::H1 { conn } => conn.graceful_shutdown(),
426            UpgradeableConnStateProj::H2 { conn } => conn.graceful_shutdown(),
427        }
428    }
429
430    /// Make this Connection static, instead of borrowing from Builder.
431    pub fn into_owned(self) -> UpgradeableConnection<'static, I, S>
432    where
433        Builder: Clone,
434    {
435        UpgradeableConnection {
436            state: match self.state {
437                UpgradeableConnState::ReadVersion {
438                    read_version,
439                    builder,
440                    service,
441                } => UpgradeableConnState::ReadVersion {
442                    read_version,
443                    service,
444                    builder: Cow::Owned(builder.clone()),
445                },
446                UpgradeableConnState::H1 { conn } => UpgradeableConnState::H1 { conn },
447                UpgradeableConnState::H2 { conn } => UpgradeableConnState::H2 { conn },
448            },
449        }
450    }
451}
452
453impl<I, S> Future for UpgradeableConnection<'_, I, S>
454where
455    S: HttpService<Incoming>,
456    I: AsyncRead + AsyncWrite + Send + Unpin + 'static + Send + 'static,
457{
458    type Output = Result<()>;
459
460    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
461        loop {
462            let mut this = self.as_mut().project();
463
464            match this.state.as_mut().project() {
465                UpgradeableConnStateProj::ReadVersion {
466                    read_version,
467                    builder,
468                    service,
469                } => {
470                    let (version, io) = ready!(read_version.poll(cx))?;
471                    let service = service.take().unwrap();
472                    match version {
473                        Version::H1 => {
474                            let conn = builder.http1.serve_connection(io, service).with_upgrades();
475                            this.state.set(UpgradeableConnState::H1 { conn });
476                        }
477                        Version::H2 => {
478                            let conn = builder.http2.serve_connection(io, service);
479                            this.state.set(UpgradeableConnState::H2 { conn });
480                        }
481                    }
482                }
483                UpgradeableConnStateProj::H1 { conn } => {
484                    return conn.poll(cx).map_err(Into::into);
485                }
486                UpgradeableConnStateProj::H2 { conn } => {
487                    return conn.poll(cx).map_err(Into::into);
488                }
489            }
490        }
491    }
492}
493
494/// Http1 part of builder.
495pub struct Http1Builder<'a> {
496    inner: &'a mut Builder,
497}
498
499impl Http1Builder<'_> {
500    /// Http2 configuration.
501    pub fn http2(&mut self) -> Http2Builder<'_> {
502        Http2Builder { inner: self.inner }
503    }
504
505    /// Set whether the `date` header should be included in HTTP responses.
506    ///
507    /// Note that including the `date` header is recommended by RFC 7231.
508    ///
509    /// Default is true.
510    pub fn auto_date_header(&mut self, enabled: bool) -> &mut Self {
511        self.inner.http1.auto_date_header(enabled);
512        self
513    }
514
515    /// Set whether HTTP/1 connections should support half-closures.
516    ///
517    /// Clients can chose to shutdown their write-side while waiting
518    /// for the server to respond. Setting this to `true` will
519    /// prevent closing the connection immediately if `read`
520    /// detects an EOF in the middle of a request.
521    ///
522    /// Default is `false`.
523    pub fn half_close(&mut self, val: bool) -> &mut Self {
524        self.inner.http1.half_close(val);
525        self
526    }
527
528    /// Enables or disables HTTP/1 keep-alive.
529    ///
530    /// Default is true.
531    pub fn keep_alive(&mut self, val: bool) -> &mut Self {
532        self.inner.http1.keep_alive(val);
533        self
534    }
535
536    /// Set whether HTTP/1 connections will write header names as title case at
537    /// the socket level.
538    ///
539    /// Note that this setting does not affect HTTP/2.
540    ///
541    /// Default is false.
542    pub fn title_case_headers(&mut self, enabled: bool) -> &mut Self {
543        self.inner.http1.title_case_headers(enabled);
544        self
545    }
546
547    /// Set whether HTTP/1 connections will silently ignored malformed header lines.
548    ///
549    /// If this is enabled and a header line does not start with a valid header
550    /// name, or does not include a colon at all, the line will be silently ignored
551    /// and no error will be reported.
552    ///
553    /// Default is false.
554    pub fn ignore_invalid_headers(&mut self, enabled: bool) -> &mut Self {
555        self.inner.http1.ignore_invalid_headers(enabled);
556        self
557    }
558
559    /// Set the maximum number of headers.
560    ///
561    /// When a request is received, the parser will reserve a buffer to store headers for optimal
562    /// performance.
563    ///
564    /// If server receives more headers than the buffer size, it responds to the client with
565    /// "431 Request Header Fields Too Large".
566    ///
567    /// The headers is allocated on the stack by default, which has higher performance. After
568    /// setting this value, headers will be allocated in heap memory, that is, heap memory
569    /// allocation will occur for each request, and there will be a performance drop of about 5%.
570    ///
571    /// Note that this setting does not affect HTTP/2.
572    ///
573    /// Default is 100.
574    pub fn max_headers(&mut self, val: usize) -> &mut Self {
575        self.inner.http1.max_headers(val);
576        self
577    }
578
579    /// Set a timeout for reading client request headers. If a client does not
580    /// transmit the entire header within this time, the connection is closed.
581    ///
582    /// Default is currently 30 seconds, but do not depend on that.
583    pub fn header_read_timeout(&mut self, read_timeout: Duration) -> &mut Self {
584        self.inner.http1.header_read_timeout(read_timeout);
585        self
586    }
587
588    /// Set whether HTTP/1 connections should try to use vectored writes,
589    /// or always flatten into a single buffer.
590    ///
591    /// Note that setting this to false may mean more copies of body data,
592    /// but may also improve performance when an IO transport doesn't
593    /// support vectored writes well, such as most TLS implementations.
594    ///
595    /// Setting this to true will force hyper to use queued strategy
596    /// which may eliminate unnecessary cloning on some TLS backends
597    ///
598    /// Default is `auto`. In this mode rama-http-core will try to guess which
599    /// mode to use
600    pub fn writev(&mut self, val: bool) -> &mut Self {
601        self.inner.http1.writev(val);
602        self
603    }
604
605    /// Set the maximum buffer size for the connection.
606    ///
607    /// Default is ~400kb.
608    ///
609    /// # Panics
610    ///
611    /// The minimum value allowed is 8192. This method panics if the passed `max` is less than the minimum.
612    pub fn max_buf_size(&mut self, max: usize) -> &mut Self {
613        self.inner.http1.max_buf_size(max);
614        self
615    }
616
617    /// Aggregates flushes to better support pipelined responses.
618    ///
619    /// Experimental, may have bugs.
620    ///
621    /// Default is false.
622    pub fn pipeline_flush(&mut self, enabled: bool) -> &mut Self {
623        self.inner.http1.pipeline_flush(enabled);
624        self
625    }
626
627    /// Bind a connection together with a [`Service`].
628    pub async fn serve_connection<I, S>(&self, io: I, service: S) -> Result<()>
629    where
630        S: HttpService<Incoming>,
631        I: AsyncRead + AsyncWrite + Send + Unpin + 'static + Send + 'static,
632    {
633        self.inner.serve_connection(io, service).await
634    }
635
636    /// Bind a connection together with a [`Service`], with the ability to
637    /// handle HTTP upgrades. This requires that the IO object implements
638    /// `Send`.
639    pub fn serve_connection_with_upgrades<I, S>(
640        &self,
641        io: I,
642        service: S,
643    ) -> UpgradeableConnection<'_, I, S>
644    where
645        S: HttpService<Incoming>,
646        I: AsyncRead + AsyncWrite + Send + Unpin + 'static + Send + 'static,
647    {
648        self.inner.serve_connection_with_upgrades(io, service)
649    }
650}
651
652/// Http2 part of builder.
653pub struct Http2Builder<'a> {
654    inner: &'a mut Builder,
655}
656
657impl Http2Builder<'_> {
658    /// Http1 configuration.
659    pub fn http1(&mut self) -> Http1Builder<'_> {
660        Http1Builder { inner: self.inner }
661    }
662
663    /// Configures the maximum number of pending reset streams allowed before a GOAWAY will be sent.
664    ///
665    /// This will default to the default value set by the [`h2` crate](https://crates.io/crates/h2).
666    /// As of v0.4.0, it is 20.
667    ///
668    /// See <https://github.com/hyperium/hyper/issues/2877> for more information.
669    pub fn max_pending_accept_reset_streams(&mut self, max: impl Into<Option<usize>>) -> &mut Self {
670        self.inner.http2.max_pending_accept_reset_streams(max);
671        self
672    }
673
674    /// Configures the maximum number of local reset streams allowed before a GOAWAY will be sent.
675    ///
676    /// If not set, rama-http-core will use a default, currently of 1024.
677    ///
678    /// If `None` is supplied, rama-http-core will not apply any limit.
679    /// This is not advised, as it can potentially expose servers to DOS vulnerabilities.
680    ///
681    /// See <https://rustsec.org/advisories/RUSTSEC-2024-0003.html> for more information.
682    pub fn max_local_error_reset_streams(&mut self, max: impl Into<Option<usize>>) -> &mut Self {
683        self.inner.http2.max_local_error_reset_streams(max);
684        self
685    }
686
687    /// Sets the [`SETTINGS_INITIAL_WINDOW_SIZE`][spec] option for HTTP2
688    /// stream-level flow control.
689    ///
690    /// Passing `None` will do nothing.
691    ///
692    /// If not set, rama-http-core will use a default.
693    ///
694    /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_INITIAL_WINDOW_SIZE
695    pub fn initial_stream_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self {
696        self.inner.http2.initial_stream_window_size(sz);
697        self
698    }
699
700    /// Sets the max connection-level flow control for HTTP2.
701    ///
702    /// Passing `None` will do nothing.
703    ///
704    /// If not set, rama-http-core will use a default.
705    pub fn initial_connection_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self {
706        self.inner.http2.initial_connection_window_size(sz);
707        self
708    }
709
710    /// Sets whether to use an adaptive flow control.
711    ///
712    /// Enabling this will override the limits set in
713    /// `http2_initial_stream_window_size` and
714    /// `http2_initial_connection_window_size`.
715    pub fn adaptive_window(&mut self, enabled: bool) -> &mut Self {
716        self.inner.http2.adaptive_window(enabled);
717        self
718    }
719
720    /// Sets the maximum frame size to use for HTTP2.
721    ///
722    /// Passing `None` will do nothing.
723    ///
724    /// If not set, rama-http-core will use a default.
725    pub fn max_frame_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self {
726        self.inner.http2.max_frame_size(sz);
727        self
728    }
729
730    /// Sets the [`SETTINGS_MAX_CONCURRENT_STREAMS`][spec] option for HTTP2
731    /// connections.
732    ///
733    /// Default is 200. Passing `None` will remove any limit.
734    ///
735    /// [spec]: https://http2.github.io/http2-spec/#SETTINGS_MAX_CONCURRENT_STREAMS
736    pub fn max_concurrent_streams(&mut self, max: impl Into<Option<u32>>) -> &mut Self {
737        self.inner.http2.max_concurrent_streams(max);
738        self
739    }
740
741    /// Sets an interval for HTTP2 Ping frames should be sent to keep a
742    /// connection alive.
743    ///
744    /// Pass `None` to disable HTTP2 keep-alive.
745    ///
746    /// Default is currently disabled.
747    ///
748    /// # Cargo Feature
749    ///
750    pub fn keep_alive_interval(&mut self, interval: impl Into<Option<Duration>>) -> &mut Self {
751        self.inner.http2.keep_alive_interval(interval);
752        self
753    }
754
755    /// Sets a timeout for receiving an acknowledgement of the keep-alive ping.
756    ///
757    /// If the ping is not acknowledged within the timeout, the connection will
758    /// be closed. Does nothing if `http2_keep_alive_interval` is disabled.
759    ///
760    /// Default is 20 seconds.
761    ///
762    /// # Cargo Feature
763    ///
764    pub fn keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self {
765        self.inner.http2.keep_alive_timeout(timeout);
766        self
767    }
768
769    /// Set the maximum write buffer size for each HTTP/2 stream.
770    ///
771    /// Default is currently ~400KB, but may change.
772    ///
773    /// # Panics
774    ///
775    /// The value must be no larger than `u32::MAX`.
776    pub fn max_send_buf_size(&mut self, max: usize) -> &mut Self {
777        self.inner.http2.max_send_buf_size(max);
778        self
779    }
780
781    /// Enables the [extended CONNECT protocol].
782    ///
783    /// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4
784    pub fn enable_connect_protocol(&mut self) -> &mut Self {
785        self.inner.http2.enable_connect_protocol();
786        self
787    }
788
789    /// Sets the max size of received header frames.
790    ///
791    /// Default is currently ~16MB, but may change.
792    pub fn max_header_list_size(&mut self, max: u32) -> &mut Self {
793        self.inner.http2.max_header_list_size(max);
794        self
795    }
796
797    /// Set whether the `date` header should be included in HTTP responses.
798    ///
799    /// Note that including the `date` header is recommended by RFC 7231.
800    ///
801    /// Default is true.
802    pub fn auto_date_header(&mut self, enabled: bool) -> &mut Self {
803        self.inner.http2.auto_date_header(enabled);
804        self
805    }
806
807    /// Bind a connection together with a [`Service`].
808    pub async fn serve_connection<I, S>(&self, io: I, service: S) -> Result<()>
809    where
810        S: HttpService<Incoming>,
811        I: AsyncRead + AsyncWrite + Send + Unpin + 'static + Send + 'static,
812    {
813        self.inner.serve_connection(io, service).await
814    }
815
816    /// Bind a connection together with a [`Service`], with the ability to
817    /// handle HTTP upgrades. This requires that the IO object implements
818    /// `Send`.
819    pub fn serve_connection_with_upgrades<I, S>(
820        &self,
821        io: I,
822        service: S,
823    ) -> UpgradeableConnection<'_, I, S>
824    where
825        S: HttpService<Incoming>,
826        I: AsyncRead + AsyncWrite + Send + Unpin + 'static + Send + 'static,
827    {
828        self.inner.serve_connection_with_upgrades(io, service)
829    }
830}
831
832#[cfg(test)]
833mod tests {
834    use crate::client::conn::http1;
835    use crate::server::conn::auto;
836    use crate::service::RamaHttpService;
837    use crate::{body::Bytes, client};
838    use rama_core::Context;
839    use rama_core::error::BoxError;
840    use rama_core::rt::Executor;
841    use rama_core::service::service_fn;
842    use rama_http_types::dep::http_body::Body;
843    use rama_http_types::dep::http_body_util::{BodyExt, Empty};
844    use rama_http_types::{Request, Response};
845    use std::{convert::Infallible, net::SocketAddr, time::Duration};
846    use tokio::{
847        net::{TcpListener, TcpStream},
848        pin,
849    };
850
851    const BODY: &[u8] = b"Hello, world!";
852
853    #[test]
854    fn configuration() {
855        // One liner.
856        auto::Builder::new(Executor::new())
857            .http1()
858            .keep_alive(true)
859            .http2()
860            .keep_alive_interval(None);
861        //  .serve_connection(io, service);
862
863        // Using variable.
864        let mut builder = auto::Builder::new(Executor::new());
865
866        builder.http1().keep_alive(true);
867        builder.http2().keep_alive_interval(None);
868        // builder.serve_connection(io, service);
869    }
870
871    #[cfg(not(miri))]
872    #[tokio::test]
873    async fn http1() {
874        let addr = start_server(false, false).await;
875        let mut sender = connect_h1(addr).await;
876
877        let response = sender
878            .send_request(Request::new(Empty::<Bytes>::new()))
879            .await
880            .unwrap();
881
882        let body = response.into_body().collect().await.unwrap().to_bytes();
883
884        assert_eq!(body, BODY);
885    }
886
887    #[cfg(not(miri))]
888    #[tokio::test]
889    async fn http2() {
890        let addr = start_server(false, false).await;
891        let mut sender = connect_h2(addr).await;
892
893        let response = sender
894            .send_request(Request::new(Empty::<Bytes>::new()))
895            .await
896            .unwrap();
897
898        let body = response.into_body().collect().await.unwrap().to_bytes();
899
900        assert_eq!(body, BODY);
901    }
902
903    #[cfg(not(miri))]
904    #[tokio::test]
905    async fn http2_only() {
906        let addr = start_server(false, true).await;
907        let mut sender = connect_h2(addr).await;
908
909        let response = sender
910            .send_request(Request::new(Empty::<Bytes>::new()))
911            .await
912            .unwrap();
913
914        let body = response.into_body().collect().await.unwrap().to_bytes();
915
916        assert_eq!(body, BODY);
917    }
918
919    #[cfg(not(miri))]
920    #[tokio::test]
921    async fn http2_only_fail_if_client_is_http1() {
922        let addr = start_server(false, true).await;
923        let mut sender = connect_h1(addr).await;
924
925        let _ = sender
926            .send_request(Request::new(Empty::<Bytes>::new()))
927            .await
928            .expect_err("should fail");
929    }
930
931    #[cfg(not(miri))]
932    #[tokio::test]
933    async fn http1_only() {
934        let addr = start_server(true, false).await;
935        let mut sender = connect_h1(addr).await;
936
937        let response = sender
938            .send_request(Request::new(Empty::<Bytes>::new()))
939            .await
940            .unwrap();
941
942        let body = response.into_body().collect().await.unwrap().to_bytes();
943
944        assert_eq!(body, BODY);
945    }
946
947    #[cfg(not(miri))]
948    #[tokio::test]
949    async fn http1_only_fail_if_client_is_http2() {
950        let addr = start_server(true, false).await;
951        let mut sender = connect_h2(addr).await;
952
953        let _ = sender
954            .send_request(Request::new(Empty::<Bytes>::new()))
955            .await
956            .expect_err("should fail");
957    }
958
959    #[cfg(not(miri))]
960    #[tokio::test]
961    async fn graceful_shutdown() {
962        use rama_core::{Context, service::service_fn};
963
964        use crate::service::RamaHttpService;
965
966        let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
967            .await
968            .unwrap();
969
970        let listener_addr = listener.local_addr().unwrap();
971
972        // Spawn the task in background so that we can connect there
973        let listen_task = tokio::spawn(async move { listener.accept().await.unwrap() });
974        // Only connect a stream, do not send headers or anything
975        let _stream = TcpStream::connect(listener_addr).await.unwrap();
976
977        let (stream, _) = listen_task.await.unwrap();
978        let builder = auto::Builder::new(Executor::new());
979        let connection = builder.serve_connection(
980            stream,
981            RamaHttpService::new(Context::default(), service_fn(hello)),
982        );
983
984        pin!(connection);
985
986        connection.as_mut().graceful_shutdown();
987
988        let connection_error = tokio::time::timeout(Duration::from_millis(200), connection)
989            .await
990            .expect("Connection should have finished in a timely manner after graceful shutdown.")
991            .expect_err("Connection should have been interrupted.");
992
993        let connection_error = connection_error
994            .downcast_ref::<std::io::Error>()
995            .expect("The error should have been `std::io::Error`.");
996        assert_eq!(connection_error.kind(), std::io::ErrorKind::Interrupted);
997    }
998
999    async fn connect_h1<B>(addr: SocketAddr) -> client::conn::http1::SendRequest<B>
1000    where
1001        B: Body<Data: Send + 'static, Error: Into<BoxError>> + Send + 'static + Unpin,
1002    {
1003        let stream = TcpStream::connect(addr).await.unwrap();
1004        let (sender, connection) = http1::handshake(stream).await.unwrap();
1005
1006        tokio::spawn(connection);
1007
1008        sender
1009    }
1010
1011    async fn connect_h2<B>(addr: SocketAddr) -> client::conn::http2::SendRequest<B>
1012    where
1013        B: Body<Data: Send + 'static, Error: Into<BoxError>> + Send + 'static + Unpin,
1014    {
1015        let stream = TcpStream::connect(addr).await.unwrap();
1016        let (sender, connection) = client::conn::http2::Builder::new(Executor::new())
1017            .handshake(stream)
1018            .await
1019            .unwrap();
1020
1021        tokio::spawn(connection);
1022
1023        sender
1024    }
1025
1026    async fn start_server(h1_only: bool, h2_only: bool) -> SocketAddr {
1027        let addr: SocketAddr = ([127, 0, 0, 1], 0).into();
1028        let listener = TcpListener::bind(addr).await.unwrap();
1029
1030        let local_addr = listener.local_addr().unwrap();
1031
1032        tokio::spawn(async move {
1033            loop {
1034                let (stream, _) = listener.accept().await.unwrap();
1035                tokio::spawn(async move {
1036                    let mut builder = auto::Builder::new(Executor::new());
1037                    if h1_only {
1038                        builder = builder.http1_only();
1039                        builder
1040                            .serve_connection(
1041                                stream,
1042                                RamaHttpService::new(Context::default(), service_fn(hello)),
1043                            )
1044                            .await
1045                    } else if h2_only {
1046                        builder = builder.http2_only();
1047                        builder
1048                            .serve_connection(
1049                                stream,
1050                                RamaHttpService::new(Context::default(), service_fn(hello)),
1051                            )
1052                            .await
1053                    } else {
1054                        builder
1055                            .http2()
1056                            .max_header_list_size(4096)
1057                            .serve_connection_with_upgrades(
1058                                stream,
1059                                RamaHttpService::new(Context::default(), service_fn(hello)),
1060                            )
1061                            .await
1062                    }
1063                    .unwrap();
1064                });
1065            }
1066        });
1067
1068        local_addr
1069    }
1070
1071    async fn hello(_req: Request) -> Result<Response, Infallible> {
1072        Ok(Response::new(rama_http_types::Body::from(BODY)))
1073    }
1074}