Skip to main content

trillium_client/conn/
shared.rs

1use super::{Body, Conn, Transport, TypeSet};
2use crate::{ClientHandler, ConnExt, Error, Result, Version};
3use smallvec::SmallVec;
4#[cfg(feature = "hickory")]
5use std::net::IpAddr;
6use std::{
7    borrow::Cow,
8    fmt::{self, Debug, Formatter},
9    future::{Future, IntoFuture},
10    mem,
11    net::SocketAddr,
12    pin::Pin,
13};
14use trillium_http::{ProtocolSession, Upgrade};
15use trillium_server_common::Destination;
16
/// A wrapper error for [`trillium_http::Error`] or, depending on the json serializer feature,
/// either `sonic_rs::Error` or `serde_json::Error`. Only available when either the `sonic-rs` or
/// `serde_json` cargo feature is enabled.
///
/// If both features are enabled at once, `sonic-rs` takes precedence and
/// [`ClientSerdeError::JsonError`] wraps `sonic_rs::Error`.
#[cfg(any(feature = "serde_json", feature = "sonic-rs"))]
#[derive(thiserror::Error, Debug)]
pub enum ClientSerdeError {
    /// A [`trillium_http::Error`]
    #[error(transparent)]
    HttpError(#[from] Error),

    #[cfg(feature = "sonic-rs")]
    /// A [`sonic_rs::Error`]
    #[error(transparent)]
    JsonError(#[from] sonic_rs::Error),

    // Cargo features are additive, so both may be enabled together. Without the
    // `not(sonic-rs)` guard, that would declare `JsonError` twice and fail to compile.
    #[cfg(all(feature = "serde_json", not(feature = "sonic-rs")))]
    /// A [`serde_json::Error`]
    #[error(transparent)]
    JsonError(#[from] serde_json::Error),
}
37
38impl Conn {
39    pub(crate) async fn exec(&mut self) -> Result<()> {
40        // A build-time error (e.g. a malformed url from `build_conn`) is fatal and
41        // unrecoverable: there is nothing to dial, so short-circuit before running
42        // handlers or touching the network.
43        if let Some(error) = self.error.take() {
44            return Err(error);
45        }
46
47        // Arc-clone to dodge conflict with the `&mut self` we pass to `run`.
48        let handler = self.client.arc_handler().clone();
49        handler.run(self).await?;
50
51        if !self.halted {
52            // Stash, don't return: `after_response` runs unconditionally so recovery handlers
53            // (stale-if-error, retry-with-fallback) get a chance to clear it.
54            if let Err(e) = self.exec_network().await {
55                self.error = Some(e);
56            }
57        } else {
58            log::trace!("conn is halted, skipping network round-trip");
59        }
60
61        // Reverse order, regardless of halt/error — mirrors server-side `before_send`.
62        handler.after_response(self).await?;
63
64        if let Some(e) = self.error.take() {
65            Err(e)
66        } else {
67            Ok(())
68        }
69    }
70
71    async fn exec_network(&mut self) -> Result<()> {
72        if self.http_version == Some(Version::Http0_9) {
73            return Err(Error::UnsupportedVersion(Version::Http0_9));
74        }
75
76        // Phase 1 — reuse a live pooled connection, best protocol first. No DNS, no new connect.
77        // A pooled h2 connection is reused in preference to establishing a new h3 connection: we
78        // do not proactively migrate h2→h3, since a general-purpose client can't assume the
79        // request locality that makes eager migration pay off (see the migration-policy backlog
80        // item). A pooled h1 connection, by contrast, does not block establishing h3 below.
81        if self.try_reuse_h3_pool().await? {
82            return Ok(());
83        }
84        if self.try_exec_h2_pooled().await? {
85            return Ok(());
86        }
87
88        // Phase 2/3 — establish a new connection, preferring h3 when the origin is known to speak
89        // it (pinned, Alt-Svc, or SVCB). This runs before the h1 path, so h1→h3 is immediate.
90        if self.try_establish_h3().await? {
91            return Ok(());
92        }
93
94        // Prior-knowledge h2: caller asserted h2, skip h1/ALPN. Useful for TLS connectors
95        // that don't expose `negotiated_alpn` (e.g. native-tls). No fallback — a non-h2
96        // server here surfaces as a plain IO error.
97        if self.http_version == Some(Version::Http2) {
98            return self.exec_h2_prior_knowledge().await;
99        }
100
101        self.exec_h1_or_promote_h2().await
102    }
103
104    pub(crate) fn body_len(&self) -> Option<u64> {
105        if let Some(ref body) = self.request_body {
106            body.len()
107        } else {
108            Some(0)
109        }
110    }
111
112    pub(crate) fn finalize_headers(&mut self) -> Result<()> {
113        match self.http_version() {
114            Version::Http1_0 | Version::Http1_1 => self.finalize_headers_h1(),
115            Version::Http2 => self.finalize_headers_h2(),
116            Version::Http3 if self.client.h3().is_some() => self.finalize_headers_h3(),
117            other => Err(Error::UnsupportedVersion(other)),
118        }
119    }
120
121    /// The [`Destination`] for connecting to this conn's origin over h1/h2: scheme, host, and port
122    /// from the URL, plus any DoH-resolved addresses. A bare-IP origin keeps the address
123    /// [`from_url`](Destination::from_url) derived and is never resolved.
124    ///
125    /// An explicit version pin constrains the connection's ALPN so the pin is honored over TLS: an
126    /// h1 pin advertises only `http/1.1` (a server that would otherwise negotiate `h2` falls back
127    /// to h1), an h2 pin advertises only `h2`. Without a pin the connector's configured default
128    /// ALPN is left in place, so auto-discovery can promote to h2 via ALPN.
129    pub(crate) async fn origin_destination(&self) -> Result<Destination> {
130        let mut destination = Destination::from_url(&self.url)?;
131        let addrs = self.origin_socket_addrs().await?;
132        if !addrs.is_empty() {
133            destination.set_addrs(addrs);
134        }
135        match self.http_version {
136            Some(Version::Http1_0 | Version::Http1_1) => {
137                destination.set_alpn([Cow::Borrowed(b"http/1.1".as_slice())]);
138            }
139            Some(Version::Http2) => {
140                destination.set_alpn([Cow::Borrowed(b"h2".as_slice())]);
141            }
142            _ => {}
143        }
144        Ok(destination)
145    }
146
147    /// Pre-resolved socket addresses for this conn's origin host:port, for the protocols that
148    /// always connect to the origin (h1/h2). Empty when DoH is not configured or the host is an IP
149    /// literal, so the connector falls back to its own (trivial, for an IP) resolution.
150    pub(crate) async fn origin_socket_addrs(&self) -> Result<SmallVec<[SocketAddr; 4]>> {
151        let Some(host) = self.url.host_str() else {
152            return Ok(SmallVec::new());
153        };
154        let port = self.url.port_or_known_default().unwrap_or(443);
155        self.resolve_socket_addrs(host, port).await
156    }
157}
158
#[cfg(feature = "hickory")]
impl Conn {
    /// Resolve `host:port` through the configured DoH resolver. Returns `None` when DoH is not
    /// configured, so the caller can fall back to the connector's own resolution.
    ///
    /// This is the single place this conn touches DNS. As a side effect, the resolver reads and
    /// populates a shared, TTL'd cache. Repeated calls for the same host therefore issue at most
    /// one set of queries, whether they come from different protocols or from both the SVCB
    /// decision and the eventual connect.
    ///
    /// Fail-closed: once DoH is configured, a lookup the resolver can't answer fails the
    /// request. It does not fall back to the (possibly plaintext) system resolver.
    ///
    /// An IP-literal host returns `None` without touching the resolver. There is nothing to
    /// look up, and no SVCB/HTTPS records exist for a bare address.
    pub(crate) async fn resolve(
        &self,
        host: &str,
        port: u16,
    ) -> Result<Option<crate::dns::Resolved>> {
        if host.parse::<IpAddr>().is_ok() {
            return Ok(None);
        }

        let Some(resolver) = &self.client.resolver else {
            // No DoH resolver configured.
            return Ok(None);
        };

        let resolved = resolver
            .resolve(&self.client, host, port, self.timeout)
            .await?;
        Ok(Some(resolved))
    }

    /// Socket addresses for `host:port` via `resolve`. The list is empty when `resolve` yields
    /// `None` (IP-literal host or no DoH resolver configured).
    pub(crate) async fn resolve_socket_addrs(
        &self,
        host: &str,
        port: u16,
    ) -> Result<SmallVec<[SocketAddr; 4]>> {
        let socket_addrs = match self.resolve(host, port).await? {
            Some(resolved) => resolved.socket_addrs(port),
            None => SmallVec::new(),
        };
        Ok(socket_addrs)
    }
}
203
204#[cfg(not(feature = "hickory"))]
205impl Conn {
206    pub(crate) async fn resolve_socket_addrs(
207        &self,
208        _host: &str,
209        _port: u16,
210    ) -> Result<SmallVec<[SocketAddr; 4]>> {
211        Ok(SmallVec::new())
212    }
213}
214
215impl Drop for Conn {
216    fn drop(&mut self) {
217        log::trace!("dropping client conn");
218        drop(self.take_response_body());
219    }
220}
221
222impl From<Conn> for Body {
223    fn from(mut conn: Conn) -> Body {
224        // body_override (e.g. cache hit, set via `set_response_body`) bypasses the transport;
225        // transport pooling is left to `Drop`.
226        if let Some(body) = conn.body_override.take() {
227            return body;
228        }
229
230        match conn.take_received_body(true) {
231            Some(rb) => rb.into(),
232            None => Body::default(),
233        }
234    }
235}
236
237impl From<Conn> for Upgrade<Box<dyn Transport>> {
238    /// Convert a client conn into a [`trillium_http::Upgrade`] after response headers
239    /// arrive, handing off the open transport for direct `AsyncRead` / `AsyncWrite`
240    /// exchange with per-protocol framing applied.
241    ///
242    /// # Panics
243    ///
244    /// Panics if the conn has no live transport (request not yet sent, or transport
245    /// already taken).
246    fn from(mut conn: Conn) -> Self {
247        // `Conn: Drop` rules out destructuring — pull each field with `mem::take` /
248        // `mem::replace`. New fields on `Conn` won't show up here automatically.
249        let path = conn.path.take().unwrap_or_else(|| match conn.url.query() {
250            Some(q) => Cow::Owned(format!("{}?{q}", conn.url.path())),
251            None => Cow::Owned(conn.url.path().to_owned()),
252        });
253        let secure = conn.url.scheme() == "https";
254
255        Upgrade::from_parts(
256            mem::take(&mut conn.response_headers),
257            mem::take(&mut conn.request_headers),
258            path,
259            conn.method,
260            conn.transport
261                .take()
262                .expect("client conn has no transport — request not yet sent"),
263            mem::take(&mut conn.buffer),
264            mem::take(&mut conn.state),
265            conn.context.clone(),
266            None,
267            conn.authority.take(),
268            conn.scheme.take(),
269            mem::replace(&mut conn.protocol_session, ProtocolSession::Http1),
270            conn.protocol.take(),
271            conn.http_version(),
272            conn.status,
273            secure,
274            // Client-side inbound = response body.
275            mem::take(&mut conn.response_body_state),
276            // Carry through any pre-upgrade-decoded trailers so a downstream reader
277            // can observe them.
278            conn.response_trailers.take(),
279        )
280    }
281}
282
283impl IntoFuture for Conn {
284    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'static>>;
285    type Output = Result<Conn>;
286
287    fn into_future(mut self) -> Self::IntoFuture {
288        Box::pin(async move { (&mut self).await.map(|()| self) })
289    }
290}
291
292impl<'conn> IntoFuture for &'conn mut Conn {
293    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'conn>>;
294    type Output = Result<()>;
295
296    fn into_future(self) -> Self::IntoFuture {
297        Box::pin(async move {
298            // Re-issuing handlers (FollowRedirects, retry, auth-refresh) queue a follow-up
299            // via `set_followup` in `after_response`; we recycle, swap, re-exec.
300            loop {
301                let result = if let Some(duration) = self.timeout {
302                    self.client
303                        .connector()
304                        .runtime()
305                        .timeout(duration, self.exec())
306                        .await
307                        .unwrap_or(Err(Error::TimedOut("Conn", duration)))
308                } else {
309                    self.exec().await
310                };
311
312                // `halted` is handler-internal; don't leak it out to the caller.
313                self.halted = false;
314
315                if let Err(e) = result {
316                    // Unrecovered error wins over any queued follow-up. Recovery handlers
317                    // that want the follow-up to run must `take_error()` in `after_response`.
318                    self.followup = None;
319                    return Err(e);
320                }
321
322                let Some(next) = self.take_followup() else {
323                    break;
324                };
325
326                if let Some(body) = self.take_response_body() {
327                    body.recycle().await;
328                }
329
330                let _displaced = mem::replace(self, next);
331            }
332            Ok(())
333        })
334    }
335}
336
337impl Debug for Conn {
338    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
339        f.debug_struct("Conn")
340            .field("authority", &self.authority)
341            .field("buffer", &String::from_utf8_lossy(&self.buffer))
342            .field("client", &self.client)
343            .field("protocol_session", &self.protocol_session)
344            .field("http_version", &self.http_version)
345            .field("method", &self.method)
346            .field("path", &self.path)
347            .field("request_body", &self.request_body)
348            .field("request_headers", &self.request_headers)
349            .field("request_target", &self.request_target)
350            .field("request_trailers", &self.request_trailers)
351            .field("response_body_state", &self.response_body_state)
352            .field("response_headers", &self.response_headers)
353            .field("response_trailers", &self.response_trailers)
354            .field("scheme", &self.scheme)
355            .field("state", &self.state)
356            .field("status", &self.status)
357            .field("url", &self.url)
358            .finish()
359    }
360}
361
362impl AsRef<TypeSet> for Conn {
363    fn as_ref(&self) -> &TypeSet {
364        &self.state
365    }
366}
367
368impl AsMut<TypeSet> for Conn {
369    fn as_mut(&mut self) -> &mut TypeSet {
370        &mut self.state
371    }
372}