Skip to main content

trillium_http/
conn.rs

1use crate::{
2    Body, Buffer, Headers, HttpContext,
3    KnownHeaderName::Host,
4    Method, ProtocolSession, ReceivedBody, Status, Swansong, TypeSet, Version,
5    after_send::{AfterSend, SendStatus},
6    h2::H2Connection,
7    h3::H3Connection,
8    liveness::{CancelOnDisconnect, LivenessFut},
9    received_body::ReceivedBodyState,
10    util::encoding,
11};
12use encoding_rs::Encoding;
13use futures_lite::{
14    future,
15    io::{AsyncRead, AsyncWrite},
16};
17use std::{
18    borrow::Cow,
19    fmt::{self, Debug, Formatter},
20    future::Future,
21    net::IpAddr,
22    pin::pin,
23    str,
24    sync::Arc,
25    time::Instant,
26};
27mod h1;
28mod h2;
29mod h3;
30mod shared;
31pub(crate) use h1::{HeadError, write_headers_or_trailers};
32pub(crate) use h3::{H3FirstFrame, encode_field_section_h3};
33pub(crate) use shared::ConnParts;
34
35/// An HTTP connection.
36///
37/// This struct represents both the request and the response, and holds the
38/// transport over which the response will be sent.
39#[derive(fieldwork::Fieldwork)]
40pub struct Conn<Transport> {
41    #[field(get)]
42    /// the shared [`HttpContext`]
43    pub(crate) context: Arc<HttpContext>,
44
45    /// request [headers](Headers)
46    #[field(get, get_mut)]
47    pub(crate) request_headers: Headers,
48
49    /// response [headers](Headers)
50    #[field(get, get_mut)]
51    pub(crate) response_headers: Headers,
52
53    pub(crate) path: Cow<'static, str>,
54
55    /// the http method for this conn's request
56    ///
57    /// ```
58    /// # use trillium_http::{Conn, Method};
59    /// let mut conn = Conn::new_synthetic(Method::Get, "/some/path?and&a=query", ());
60    /// assert_eq!(conn.method(), Method::Get);
61    /// ```
62    #[field(get, set, copy)]
63    pub(crate) method: Method,
64
65    /// the http status for this conn, if set
66    #[field(get, copy)]
67    pub(crate) status: Option<Status>,
68
69    /// The HTTP protocol version in use on this connection.
70    ///
71    /// ```
72    /// # use trillium_http::{Conn, Method, Version};
73    /// let conn = Conn::new_synthetic(Method::Get, "/", ());
74    /// assert_eq!(conn.http_version(), Version::Http1_1);
75    /// ```
76    #[field(get = http_version, copy)]
77    pub(crate) version: Version,
78
79    /// the [state typemap](TypeSet) for this conn
80    #[field(get, get_mut)]
81    pub(crate) state: TypeSet,
82
83    /// the response [body](Body)
84    ///
85    /// ```
86    /// # use trillium_testing::HttpTest;
87    /// HttpTest::new(|conn| async move { conn.with_response_body("hello") })
88    ///     .get("/")
89    ///     .block()
90    ///     .assert_body("hello");
91    ///
92    /// HttpTest::new(|conn| async move { conn.with_response_body(String::from("world")) })
93    ///     .get("/")
94    ///     .block()
95    ///     .assert_body("world");
96    ///
97    /// HttpTest::new(|conn| async move { conn.with_response_body(vec![99, 97, 116]) })
98    ///     .get("/")
99    ///     .block()
100    ///     .assert_body("cat");
101    /// ```
102    #[field(get, set, into, option_set_some, take, with)]
103    pub(crate) response_body: Option<Body>,
104
105    /// the transport
106    ///
107    /// This should only be used to call your own custom methods on the transport that do not read
108    /// or write any data. Calling any method that reads from or writes to the transport will
109    /// disrupt the HTTP protocol. If you're looking to transition from HTTP to another protocol,
110    /// use an HTTP upgrade.
111    #[field(get, get_mut)]
112    pub(crate) transport: Transport,
113
114    pub(crate) buffer: Buffer,
115
116    pub(crate) request_body_state: ReceivedBodyState,
117
118    pub(crate) after_send: AfterSend,
119
120    /// whether the connection is secure
121    ///
122    /// note that this does not necessarily indicate that the transport itself is secure, as it may
123    /// indicate that `trillium_http` is behind a trusted reverse proxy that has terminated tls and
124    /// provided appropriate headers to indicate this.
125    #[field(get, set, rename_predicates)]
126    pub(crate) secure: bool,
127
128    /// The [`Instant`] that the first header bytes for this conn were
129    /// received, before any processing or parsing has been performed.
130    #[field(get, copy)]
131    pub(crate) start_time: Instant,
132
133    /// The IP Address for the connection, if available
134    #[field(set, get, copy, into)]
135    pub(crate) peer_ip: Option<IpAddr>,
136
137    /// the `:authority` pseudo-header
138    #[field(set, get, into)]
139    pub(crate) authority: Option<Cow<'static, str>>,
140
141    /// the `:scheme` pseudo-header
142    #[field(set, get, into)]
143    pub(crate) scheme: Option<Cow<'static, str>>,
144
145    /// the [`ProtocolSession`] for this conn — the per-protocol session state
146    /// (h2/h3 connection driver and stream id) bundled into a single enum so the
147    /// "set together" invariant is enforced at the type level. `Http1` for
148    /// h1 / synthetic conns.
149    pub(crate) protocol_session: ProtocolSession,
150
151    /// the `:protocol` pseudo-header (extended CONNECT)
152    #[field(set, get, into)]
153    pub(crate) protocol: Option<Cow<'static, str>>,
154
155    /// request trailers, populated after the request body has been fully read
156    #[field(get, get_mut)]
157    pub(crate) request_trailers: Option<Headers>,
158
159    /// Marker set via [`Conn::upgrade`].
160    pub(crate) upgrade: bool,
161}
162
163impl<Transport> Debug for Conn<Transport> {
164    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
165        f.debug_struct("Conn")
166            .field("context", &self.context)
167            .field("request_headers", &self.request_headers)
168            .field("response_headers", &self.response_headers)
169            .field("path", &self.path)
170            .field("method", &self.method)
171            .field("status", &self.status)
172            .field("version", &self.version)
173            .field("state", &self.state)
174            .field("response_body", &self.response_body)
175            .field("transport", &format_args!(".."))
176            .field("buffer", &format_args!(".."))
177            .field("request_body_state", &self.request_body_state)
178            .field("secure", &self.secure)
179            .field("after_send", &format_args!(".."))
180            .field("start_time", &self.start_time)
181            .field("peer_ip", &self.peer_ip)
182            .field("authority", &self.authority)
183            .field("scheme", &self.scheme)
184            .field("protocol", &self.protocol)
185            .field("protocol_session", &self.protocol_session)
186            .field("request_trailers", &self.request_trailers)
187            .field("upgrade", &self.upgrade)
188            .finish()
189    }
190}
191
192impl<Transport> Conn<Transport>
193where
194    Transport: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
195{
196    /// Returns the shared state typemap for this conn.
197    pub fn shared_state(&self) -> &TypeSet {
198        &self.context.shared_state
199    }
200
201    /// sets the http status code from any `TryInto<Status>`.
202    ///
203    /// ```
204    /// # use trillium_http::Status;
205    /// # trillium_testing::HttpTest::new(|mut conn| async move {
206    /// assert!(conn.status().is_none());
207    ///
208    /// conn.set_status(200); // a status can be set as a u16
209    /// assert_eq!(conn.status().unwrap(), Status::Ok);
210    ///
211    /// conn.set_status(Status::ImATeapot); // or as a Status
212    /// assert_eq!(conn.status().unwrap(), Status::ImATeapot);
213    /// conn
214    /// # }).get("/").block().assert_status(Status::ImATeapot);
215    /// ```
216    pub fn set_status(&mut self, status: impl TryInto<Status>) -> &mut Self {
217        self.status = Some(status.try_into().unwrap_or_else(|_| {
218            log::error!("attempted to set an invalid status code");
219            Status::InternalServerError
220        }));
221        self
222    }
223
224    /// sets the http status code from any `TryInto<Status>`, returning Conn
225    #[must_use]
226    pub fn with_status(mut self, status: impl TryInto<Status>) -> Self {
227        self.set_status(status);
228        self
229    }
230
231    /// The status to send on the wire: the explicitly-set status, or a
232    /// method-appropriate default when a handler left it unset. Unhandled
233    /// requests default to `404 Not Found`, except CONNECT, which defaults to
234    /// `501 Not Implemented`: an origin server implements no tunnel, and 404's
235    /// resource model does not apply to CONNECT's authority-form target.
236    pub(crate) fn response_status(&self) -> Status {
237        self.status.unwrap_or(match self.method {
238            Method::Connect => Status::NotImplemented,
239            _ => Status::NotFound,
240        })
241    }
242
243    /// retrieves the path part of the request url, up to and excluding any query component
244    /// ```
245    /// # use trillium_testing::HttpTest;
246    /// HttpTest::new(|mut conn| async move {
247    ///     assert_eq!(conn.path(), "/some/path");
248    ///     conn.with_status(200)
249    /// })
250    /// .get("/some/path?and&a=query")
251    /// .block()
252    /// .assert_ok();
253    /// ```
254    pub fn path(&self) -> &str {
255        match self.path.split_once('?') {
256            Some((path, _)) => path,
257            None => &self.path,
258        }
259    }
260
261    /// retrieves the combined path and any query
262    pub fn path_and_query(&self) -> &str {
263        &self.path
264    }
265
266    /// retrieves the query component of the path, or an empty &str
267    ///
268    /// ```
269    /// # use trillium_testing::HttpTest;
270    /// let server = HttpTest::new(|conn| async move {
271    ///     let querystring = conn.querystring().to_string();
272    ///     conn.with_response_body(querystring).with_status(200)
273    /// });
274    ///
275    /// server
276    ///     .get("/some/path?and&a=query")
277    ///     .block()
278    ///     .assert_body("and&a=query");
279    ///
280    /// server.get("/some/path").block().assert_body("");
281    /// ```
282    pub fn querystring(&self) -> &str {
283        self.path
284            .split_once('?')
285            .map(|(_, query)| query)
286            .unwrap_or_default()
287    }
288
289    /// get the host for this conn, if it exists
290    pub fn host(&self) -> Option<&str> {
291        self.request_headers.get_str(Host)
292    }
293
294    /// set the host for this conn
295    pub fn set_host(&mut self, host: String) -> &mut Self {
296        self.request_headers.insert(Host, host);
297        self
298    }
299
300    /// Cancels and drops the future if reading from the transport results in an error or empty read
301    ///
302    /// The use of this method is not advised if your connected http client employs pipelining
303    /// (rarely seen in the wild), as it will buffer an unbounded number of requests one byte at a
304    /// time
305    ///
306    /// If the client disconnects from the conn's transport, this function will return None. If the
307    /// future completes without disconnection, this future will return Some containing the output
308    /// of the future.
309    ///
310    /// Note that the inner future cannot borrow conn, so you will need to clone or take any
311    /// information needed to execute the future prior to executing this method.
312    ///
313    /// # Example
314    ///
315    /// ```rust
316    /// # use futures_lite::{AsyncRead, AsyncWrite};
317    /// # use trillium_http::{Conn, Method};
318    /// async fn something_slow_and_cancel_safe() -> String {
319    ///     String::from("this was not actually slow")
320    /// }
321    /// async fn handler<T>(mut conn: Conn<T>) -> Conn<T>
322    /// where
323    ///     T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
324    /// {
325    ///     let Some(returned_body) = conn
326    ///         .cancel_on_disconnect(async { something_slow_and_cancel_safe().await })
327    ///         .await
328    ///     else {
329    ///         return conn;
330    ///     };
331    ///     conn.with_response_body(returned_body).with_status(200)
332    /// }
333    /// ```
334    pub async fn cancel_on_disconnect<'a, Fut>(&'a mut self, fut: Fut) -> Option<Fut::Output>
335    where
336        Fut: Future + Send + 'a,
337    {
338        CancelOnDisconnect(self, pin!(fut)).await
339    }
340
341    /// Check if the transport is connected by attempting to read from the transport
342    ///
343    /// # Example
344    ///
345    /// This is best to use at appropriate points in a long-running handler, like:
346    ///
347    /// ```rust
348    /// # use futures_lite::{AsyncRead, AsyncWrite};
349    /// # use trillium_http::{Conn, Method};
350    /// # async fn something_slow_but_not_cancel_safe() {}
351    /// async fn handler<T>(mut conn: Conn<T>) -> Conn<T>
352    /// where
353    ///     T: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
354    /// {
355    ///     for _ in 0..100 {
356    ///         if conn.is_disconnected().await {
357    ///             return conn;
358    ///         }
359    ///         something_slow_but_not_cancel_safe().await;
360    ///     }
361    ///     conn.with_status(200)
362    /// }
363    /// ```
364    pub async fn is_disconnected(&mut self) -> bool {
365        future::poll_once(LivenessFut::new(self)).await.is_some()
366    }
367
368    /// returns the [`encoding_rs::Encoding`] for this request, as determined from the mime-type
369    /// charset, if available
370    ///
371    /// ```
372    /// # use trillium_testing::HttpTest;
373    /// HttpTest::new(|mut conn| async move {
374    ///     assert_eq!(conn.request_encoding(), encoding_rs::WINDOWS_1252); // the default
375    ///
376    ///     conn.request_headers_mut()
377    ///         .insert("content-type", "text/plain;charset=utf-16");
378    ///     assert_eq!(conn.request_encoding(), encoding_rs::UTF_16LE);
379    ///
380    ///     conn.with_status(200)
381    /// })
382    /// .get("/")
383    /// .block()
384    /// .assert_ok();
385    /// ```
386    pub fn request_encoding(&self) -> &'static Encoding {
387        encoding(&self.request_headers)
388    }
389
390    /// returns the [`encoding_rs::Encoding`] for this response, as
391    /// determined from the mime-type charset, if available
392    ///
393    /// ```
394    /// # use trillium_testing::HttpTest;
395    /// HttpTest::new(|mut conn| async move {
396    ///     assert_eq!(conn.response_encoding(), encoding_rs::WINDOWS_1252); // the default
397    ///     conn.response_headers_mut()
398    ///         .insert("content-type", "text/plain;charset=utf-16");
399    ///
400    ///     assert_eq!(conn.response_encoding(), encoding_rs::UTF_16LE);
401    ///
402    ///     conn.with_status(200)
403    /// })
404    /// .get("/")
405    /// .block()
406    /// .assert_ok();
407    /// ```
408    pub fn response_encoding(&self) -> &'static Encoding {
409        encoding(&self.response_headers)
410    }
411
412    /// returns a [`ReceivedBody`] that references this conn. the conn
413    /// retains all data and holds the singular transport, but the
414    /// `ReceivedBody` provides an interface to read body content.
415    ///
416    /// If the request included an `Expect: 100-continue` header, the 100 Continue response is sent
417    /// lazily on the first read from the returned [`ReceivedBody`].
418    /// ```
419    /// # use trillium_testing::HttpTest;
420    /// let server = HttpTest::new(|mut conn| async move {
421    ///     let request_body = conn.request_body();
422    ///     assert_eq!(request_body.content_length(), Some(5));
423    ///     assert_eq!(request_body.read_string().await.unwrap(), "hello");
424    ///     conn.with_status(200)
425    /// });
426    ///
427    /// server.post("/").with_body("hello").block().assert_ok();
428    /// ```
429    pub fn request_body(&mut self) -> ReceivedBody<'_, Transport> {
430        let needs_100_continue = self.needs_100_continue();
431        let body = self.build_request_body();
432        if needs_100_continue {
433            body.with_send_100_continue()
434        } else {
435            body
436        }
437    }
438
439    /// returns a clone of the [`swansong::Swansong`] for this Conn. use
440    /// this to gracefully stop long-running futures and streams
441    /// inside of handler functions
442    pub fn swansong(&self) -> Swansong {
443        self.protocol_session
444            .h3_connection()
445            .map_or_else(|| self.context.swansong.clone(), |h| h.swansong().clone())
446    }
447
448    /// Registers a function to call after the http response has been
449    /// completely transferred.
450    ///
451    /// The callback is guaranteed to fire **exactly once** before the conn is
452    /// dropped. Either the codec's send path invokes it with the real outcome,
453    /// or — if the conn is dropped before send completes (handler panic,
454    /// transport error, mid-write disconnect) — the drop fallback invokes it
455    /// with a `SendStatus` whose `is_success()` returns false. Multiple
456    /// registrations on the same conn chain in registration order.
457    ///
458    /// Because firing is ordered by send-completion rather than handler return,
459    /// this is the right hook for instrumentation that wants to report what the
460    /// peer actually observed.
461    ///
462    /// This is a sync function and should be computationally lightweight. If
463    /// your _application_ needs additional async processing, use your runtime's
464    /// task spawn within this hook. If your _library_ needs additional async
465    /// processing in an `after_send` hook, please open an issue.
466    pub fn after_send<F>(&mut self, after_send: F)
467    where
468        F: FnOnce(SendStatus) + Send + Sync + 'static,
469    {
470        self.after_send.append(after_send);
471    }
472
473    /// applies a mapping function from one transport to another. This
474    /// is particularly useful for boxing the transport. unless you're
475    /// sure this is what you're looking for, you probably don't want
476    /// to be using this
477    pub fn map_transport<NewTransport>(
478        self,
479        f: impl Fn(Transport) -> NewTransport,
480    ) -> Conn<NewTransport>
481    where
482        NewTransport: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
483    {
484        // Manual respread: rustc treats `Conn<Transport>` and `Conn<NewTransport>` as
485        // disjoint types and rejects `..self` without the unstable
486        // `type_changing_struct_update` feature. If a new field is added to `Conn`,
487        // update this respread, `Upgrade::map_transport`, and `From<Conn> for Upgrade`
488        // (`upgrade.rs`) — they share this drift hazard.
489        Conn {
490            context: self.context,
491            request_headers: self.request_headers,
492            response_headers: self.response_headers,
493            method: self.method,
494            response_body: self.response_body,
495            path: self.path,
496            status: self.status,
497            version: self.version,
498            state: self.state,
499            transport: f(self.transport),
500            buffer: self.buffer,
501            request_body_state: self.request_body_state,
502            secure: self.secure,
503            after_send: self.after_send,
504            start_time: self.start_time,
505            peer_ip: self.peer_ip,
506            authority: self.authority,
507            scheme: self.scheme,
508            protocol: self.protocol,
509            protocol_session: self.protocol_session,
510            request_trailers: self.request_trailers,
511            upgrade: self.upgrade,
512        }
513    }
514
515    /// whether this conn is suitable for an http upgrade to another protocol
516    pub fn should_upgrade(&self) -> bool {
517        self.upgrade
518            || (self.method() == Method::Connect && self.status == Some(Status::Ok))
519            || self.status == Some(Status::SwitchingProtocols)
520    }
521
522    /// Mark this conn to be handed off as an upgrade once the response headers are sent.
523    /// Set the response status (typically `200`) and any headers describing the upgraded
524    /// byte stream before calling; the handler's `upgrade` method receives an [`Upgrade`]
525    /// with per-protocol framing applied on its `AsyncRead`/`AsyncWrite`.
526    #[doc(hidden)]
527    #[must_use]
528    pub fn upgrade(mut self) -> Self {
529        self.upgrade = true;
530        self
531    }
532
533    #[doc(hidden)]
534    pub fn finalize_headers(&mut self) {
535        if self.version == Version::Http3 {
536            self.finalize_response_headers_h3();
537        } else {
538            self.finalize_response_headers_1x();
539        }
540    }
541
542    /// the [`H2Connection`] driver for this conn, if this is an HTTP/2 request
543    pub fn h2_connection(&self) -> Option<&Arc<H2Connection>> {
544        self.protocol_session.h2_connection()
545    }
546
547    /// the h2 stream id for this conn, if this is an HTTP/2 request
548    pub fn h2_stream_id(&self) -> Option<u32> {
549        self.protocol_session.h2_stream_id()
550    }
551
552    /// the [`H3Connection`] driver for this conn, if this is an HTTP/3 request
553    pub fn h3_connection(&self) -> Option<&Arc<H3Connection>> {
554        self.protocol_session.h3_connection()
555    }
556
557    /// the h3 stream id for this conn, if this is an HTTP/3 request
558    pub fn h3_stream_id(&self) -> Option<u64> {
559        self.protocol_session.h3_stream_id()
560    }
561}