Skip to main content

trillium_client/conn/
shared.rs

1use super::{Body, Conn, Transport, TypeSet};
2use crate::{ClientHandler, ConnExt, Error, Result, Version};
3use std::{
4    borrow::Cow,
5    fmt::{self, Debug, Formatter},
6    future::{Future, IntoFuture},
7    mem,
8    pin::Pin,
9};
10use trillium_http::{ProtocolSession, Upgrade};
11
/// Error wrapper that carries either a [`trillium_http::Error`] or the error type of the
/// json serializer selected by cargo feature: `sonic_rs::Error` under `sonic-rs`,
/// `serde_json::Error` under `serde_json`.
///
/// The type is only compiled when at least one of the `sonic-rs` / `serde_json` cargo
/// features is enabled.
#[cfg(any(feature = "sonic-rs", feature = "serde_json"))]
#[derive(Debug, thiserror::Error)]
pub enum ClientSerdeError {
    /// A [`trillium_http::Error`]
    #[error(transparent)]
    HttpError(#[from] Error),

    /// A [`sonic_rs::Error`]
    #[cfg(feature = "sonic-rs")]
    #[error(transparent)]
    JsonError(#[from] sonic_rs::Error),

    /// A [`serde_json::Error`]
    #[cfg(feature = "serde_json")]
    #[error(transparent)]
    JsonError(#[from] serde_json::Error),
}
32
33impl Conn {
34    pub(crate) async fn exec(&mut self) -> Result<()> {
35        // Arc-clone to dodge conflict with the `&mut self` we pass to `run`.
36        let handler = self.client.arc_handler().clone();
37        handler.run(self).await?;
38
39        if !self.halted {
40            // Stash, don't return: `after_response` runs unconditionally so recovery handlers
41            // (stale-if-error, retry-with-fallback) get a chance to clear it.
42            if let Err(e) = self.exec_network().await {
43                self.error = Some(e);
44            }
45        } else {
46            log::trace!("conn is halted, skipping network round-trip");
47        }
48
49        // Reverse order, regardless of halt/error — mirrors server-side `before_send`.
50        handler.after_response(self).await?;
51
52        if let Some(e) = self.error.take() {
53            Err(e)
54        } else {
55            Ok(())
56        }
57    }
58
59    async fn exec_network(&mut self) -> Result<()> {
60        if matches!(self.http_version, Version::Http0_9) {
61            return Err(Error::UnsupportedVersion(self.http_version));
62        }
63
64        if self.try_exec_h3().await? {
65            return Ok(());
66        }
67        if self.try_exec_h2_pooled().await? {
68            return Ok(());
69        }
70
71        // Prior-knowledge h2: caller asserted h2, skip h1/ALPN. Useful for TLS connectors
72        // that don't expose `negotiated_alpn` (e.g. native-tls). No fallback — a non-h2
73        // server here surfaces as a plain IO error.
74        if self.http_version == Version::Http2 {
75            return self.exec_h2_prior_knowledge().await;
76        }
77
78        self.exec_h1_or_promote_h2().await
79    }
80
81    pub(crate) fn body_len(&self) -> Option<u64> {
82        if let Some(ref body) = self.request_body {
83            body.len()
84        } else {
85            Some(0)
86        }
87    }
88
89    pub(crate) fn finalize_headers(&mut self) -> Result<()> {
90        match self.http_version {
91            Version::Http1_0 | Version::Http1_1 => self.finalize_headers_h1(),
92            Version::Http2 => self.finalize_headers_h2(),
93            Version::Http3 if self.client.h3().is_some() => self.finalize_headers_h3(),
94            other => Err(Error::UnsupportedVersion(other)),
95        }
96    }
97}
98
99impl Drop for Conn {
100    fn drop(&mut self) {
101        log::trace!("dropping client conn");
102        drop(self.take_response_body());
103    }
104}
105
106impl From<Conn> for Body {
107    fn from(mut conn: Conn) -> Body {
108        // body_override (e.g. cache hit, set via `set_response_body`) bypasses the transport;
109        // transport pooling is left to `Drop`.
110        if let Some(body) = conn.body_override.take() {
111            return body;
112        }
113
114        match conn.take_received_body(true) {
115            Some(rb) => rb.into(),
116            None => Body::default(),
117        }
118    }
119}
120
121impl From<Conn> for Upgrade<Box<dyn Transport>> {
122    /// Convert a client conn into a [`trillium_http::Upgrade`] after response headers
123    /// arrive, handing off the open transport for direct `AsyncRead` / `AsyncWrite`
124    /// exchange with per-protocol framing applied.
125    ///
126    /// # Panics
127    ///
128    /// Panics if the conn has no live transport (request not yet sent, or transport
129    /// already taken).
130    fn from(mut conn: Conn) -> Self {
131        // `Conn: Drop` rules out destructuring — pull each field with `mem::take` /
132        // `mem::replace`. New fields on `Conn` won't show up here automatically.
133        let path = conn.path.take().unwrap_or_else(|| match conn.url.query() {
134            Some(q) => Cow::Owned(format!("{}?{q}", conn.url.path())),
135            None => Cow::Owned(conn.url.path().to_owned()),
136        });
137        let secure = conn.url.scheme() == "https";
138
139        Upgrade::from_parts(
140            mem::take(&mut conn.response_headers),
141            mem::take(&mut conn.request_headers),
142            path,
143            conn.method,
144            conn.transport
145                .take()
146                .expect("client conn has no transport — request not yet sent"),
147            mem::take(&mut conn.buffer),
148            mem::take(&mut conn.state),
149            conn.context.clone(),
150            None,
151            conn.authority.take(),
152            conn.scheme.take(),
153            mem::replace(&mut conn.protocol_session, ProtocolSession::Http1),
154            conn.protocol.take(),
155            conn.http_version,
156            conn.status,
157            secure,
158            // Client-side inbound = response body.
159            mem::take(&mut conn.response_body_state),
160            // Carry through any pre-upgrade-decoded trailers so a downstream reader
161            // can observe them.
162            conn.response_trailers.take(),
163        )
164    }
165}
166
167impl IntoFuture for Conn {
168    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'static>>;
169    type Output = Result<Conn>;
170
171    fn into_future(mut self) -> Self::IntoFuture {
172        Box::pin(async move { (&mut self).await.map(|()| self) })
173    }
174}
175
176impl<'conn> IntoFuture for &'conn mut Conn {
177    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'conn>>;
178    type Output = Result<()>;
179
180    fn into_future(self) -> Self::IntoFuture {
181        Box::pin(async move {
182            // Re-issuing handlers (FollowRedirects, retry, auth-refresh) queue a follow-up
183            // via `set_followup` in `after_response`; we recycle, swap, re-exec.
184            loop {
185                let result = if let Some(duration) = self.timeout {
186                    self.client
187                        .connector()
188                        .runtime()
189                        .timeout(duration, self.exec())
190                        .await
191                        .unwrap_or(Err(Error::TimedOut("Conn", duration)))
192                } else {
193                    self.exec().await
194                };
195
196                // `halted` is handler-internal; don't leak it out to the caller.
197                self.halted = false;
198
199                if let Err(e) = result {
200                    // Unrecovered error wins over any queued follow-up. Recovery handlers
201                    // that want the follow-up to run must `take_error()` in `after_response`.
202                    self.followup = None;
203                    return Err(e);
204                }
205
206                let Some(next) = self.take_followup() else {
207                    break;
208                };
209
210                if let Some(body) = self.take_response_body() {
211                    body.recycle().await;
212                }
213
214                let _displaced = mem::replace(self, next);
215            }
216            Ok(())
217        })
218    }
219}
220
221impl Debug for Conn {
222    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
223        f.debug_struct("Conn")
224            .field("authority", &self.authority)
225            .field("buffer", &String::from_utf8_lossy(&self.buffer))
226            .field("client", &self.client)
227            .field("protocol_session", &self.protocol_session)
228            .field("http_version", &self.http_version)
229            .field("method", &self.method)
230            .field("path", &self.path)
231            .field("request_body", &self.request_body)
232            .field("request_headers", &self.request_headers)
233            .field("request_target", &self.request_target)
234            .field("request_trailers", &self.request_trailers)
235            .field("response_body_state", &self.response_body_state)
236            .field("response_headers", &self.response_headers)
237            .field("response_trailers", &self.response_trailers)
238            .field("scheme", &self.scheme)
239            .field("state", &self.state)
240            .field("status", &self.status)
241            .field("url", &self.url)
242            .finish()
243    }
244}
245
246impl AsRef<TypeSet> for Conn {
247    fn as_ref(&self) -> &TypeSet {
248        &self.state
249    }
250}
251
252impl AsMut<TypeSet> for Conn {
253    fn as_mut(&mut self) -> &mut TypeSet {
254        &mut self.state
255    }
256}