Skip to main content

trillium_proxy/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(
3    clippy::dbg_macro,
4    missing_copy_implementations,
5    rustdoc::missing_crate_level_docs,
6    missing_debug_implementations,
7    missing_docs,
8    nonstandard_style,
9    unused_qualifications
10)]
11
//! An HTTP reverse and forward proxy handler for trillium.
13
// Runs the README's code examples as doctests. Rustdoc sets `cfg(doctest)` -- not
// `cfg(test)` -- while collecting doctests, so gating this module on `test` would strip it
// before collection and the README examples would never run.
#[cfg(doctest)]
#[doc = include_str!("../README.md")]
mod readme {}
17
18mod body_streamer;
19mod forward_proxy_connect;
20pub mod upstream;
21
22use body_streamer::stream_body;
23pub use forward_proxy_connect::ForwardProxyConnect;
24use full_duplex_async_copy::full_duplex_copy;
25use futures_lite::future::zip;
26use size::{Base, Size};
27use std::{borrow::Cow, fmt::Debug, future::IntoFuture};
28use trillium::{
29    Conn, Handler, KnownHeaderName,
30    Status::{NotFound, SwitchingProtocols},
31    Upgrade,
32};
33use trillium_client::ConnExt as _;
34pub use trillium_client::{Client, Connector};
35use trillium_forwarding::Forwarded;
36use trillium_http::{HeaderName, Headers, HttpContext, Status, Version};
37use upstream::{IntoUpstreamSelector, UpstreamSelector};
38pub use url::Url;
39
40/// constructs a new [`Proxy`]. alias of [`Proxy::new`]
41pub fn proxy<I>(client: impl Into<Client>, upstream: I) -> Proxy<I::UpstreamSelector>
42where
43    I: IntoUpstreamSelector,
44{
45    Proxy::new(client, upstream)
46}
47
/// the proxy handler
///
/// A trillium [`Handler`] that sends each conn's request to an upstream chosen by an
/// [`UpstreamSelector`], using a trillium [`Client`], and relays the upstream response back
/// onto the conn. Construct with [`Proxy::new`] or [`proxy`].
#[derive(Debug)]
pub struct Proxy<U> {
    // picks the upstream url per conn; when it yields none, the conn is returned untouched
    upstream: U,
    // http client used to issue the upstream request
    client: Client,
    // true by default: an upstream 404 leaves the downstream conn unmodified
    pass_through_not_found: bool,
    // true by default: proxied responses halt the conn
    halt: bool,
    // pseudonym for the `Via` header; no `Via` header is written when `None`
    via_pseudonym: Option<Cow<'static, str>>,
    // false by default: whether `Upgrade: websocket` requests are forwarded as upgrades
    allow_websocket_upgrade: bool,
}
58
59impl<U: UpstreamSelector> Proxy<U> {
60    /// construct a new proxy handler that sends all requests to the upstream
61    /// provided
62    ///
63    /// ```
64    /// use trillium_proxy::Proxy;
65    /// use trillium_smol::ClientConfig;
66    ///
67    /// let proxy = Proxy::new(
68    ///     ClientConfig::default(),
69    ///     "http://docs.trillium.rs/trillium_proxy",
70    /// );
71    /// ```
72    pub fn new<I>(client: impl Into<Client>, upstream: I) -> Self
73    where
74        I: IntoUpstreamSelector<UpstreamSelector = U>,
75    {
76        let client = client
77            .into()
78            .without_default_header(KnownHeaderName::UserAgent)
79            .without_default_header(KnownHeaderName::Accept);
80
81        Self {
82            upstream: upstream.into_upstream(),
83            client,
84            pass_through_not_found: true,
85            halt: true,
86            via_pseudonym: None,
87            allow_websocket_upgrade: false,
88        }
89    }
90
91    /// chainable constructor to set the 404 Not Found handling
92    /// behavior. By default, this proxy will pass through the trillium
93    /// Conn unmodified if the proxy response is a 404 not found, allowing
94    /// it to be chained in a tuple handler. To modify this behavior, call
95    /// proxy_not_found, and the full 404 response will be forwarded. The
96    /// Conn will be halted unless [`Proxy::without_halting`] was
97    /// configured
98    ///
99    /// ```
100    /// # use trillium_smol::ClientConfig;
101    /// # use trillium_proxy::Proxy;
102    /// let proxy = Proxy::new(ClientConfig::default(), "http://trillium.rs").proxy_not_found();
103    /// ```
104    pub fn proxy_not_found(mut self) -> Self {
105        self.pass_through_not_found = false;
106        self
107    }
108
109    /// The default behavior for this handler is to halt the conn on any
110    /// response other than a 404. If [`Proxy::proxy_not_found`] has been
111    /// configured, the default behavior for all response statuses is to
112    /// halt the trillium conn. To change this behavior, call
113    /// without_halting when constructing the proxy, and it will not halt
114    /// the conn. This is useful when passing the proxy reply through
115    /// [`trillium_html_rewriter`](https://docs.trillium.rs/trillium_html_rewriter).
116    ///
117    /// ```
118    /// # use trillium_smol::ClientConfig;
119    /// # use trillium_proxy::Proxy;
120    /// let proxy = Proxy::new(ClientConfig::default(), "http://trillium.rs").without_halting();
121    /// ```
122    pub fn without_halting(mut self) -> Self {
123        self.halt = false;
124        self
125    }
126
127    /// populate the pseudonym for a
128    /// [`Via`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Via)
129    /// header. If no pseudonym is provided, no via header will be
130    /// inserted.
131    pub fn with_via_pseudonym(mut self, via_pseudonym: impl Into<Cow<'static, str>>) -> Self {
132        self.via_pseudonym = Some(via_pseudonym.into());
133        self
134    }
135
136    /// Allow websockets to be proxied
137    ///
138    /// This is not currently the default, but that may change at some (semver-minor) point in the
139    /// future
140    pub fn with_websocket_upgrades(mut self) -> Self {
141        self.allow_websocket_upgrade = true;
142        self
143    }
144
145    fn set_via_pseudonym(&self, headers: &mut Headers, version: Version) {
146        if self.via_pseudonym.is_none() {
147            return;
148        }
149
150        use std::fmt::Write;
151        let mut via = String::new();
152        let _ = write!(&mut via, "{version}");
153
154        if let Some(pseudonym) = &self.via_pseudonym {
155            let _ = write!(&mut via, " {pseudonym}");
156        }
157
158        if let Some(old_via) = headers.get_values(KnownHeaderName::Via) {
159            for old_via in old_via {
160                let _ = write!(&mut via, ", {old_via}");
161            }
162        }
163
164        headers.insert(KnownHeaderName::Via, via);
165    }
166}
167
/// Conn state holding the upstream side of a `101 Switching Protocols` exchange.
///
/// `Proxy::run` stores it on the downstream conn when the upstream switches protocols.
/// `Proxy::upgrade` takes it back out and copies bytes between it and the downstream
/// upgrade.
#[derive(Debug)]
struct UpstreamUpgrade(Upgrade);
170
impl<U: UpstreamSelector> Handler for Proxy<U> {
    /// Gives the client an http context carrying the server's swansong (keeping the
    /// client's existing config), then logs the configured upstream.
    async fn init(&mut self, info: &mut trillium::Info) {
        // this little dance is necessary to set the swansong on the client currently.
        // this is only necessary because we're not wiring together the client.
        let old_context = self.client.context();
        let new_context = HttpContext::default()
            .with_config(*old_context.config())
            .with_swansong(info.swansong().clone());
        self.client.set_context(new_context);
        log::info!("proxying to {:?}", self.upstream);
    }

    /// Proxies `conn` to the upstream chosen by the selector.
    ///
    /// - If the selector yields no url, `conn` is returned unchanged.
    /// - The forwarded request drops hop-by-hop headers, `Host`, `Alt-Used`, the
    ///   `X-Forwarded-*` family, and every header named in the request's `Connection`
    ///   header. It also gains a `Forwarded` header.
    /// - If the upstream request fails, the conn gets `503 Service Unavailable`, is halted,
    ///   and stores the error as state.
    /// - On `101 Switching Protocols`, the upstream transport is stored as
    ///   `UpstreamUpgrade` state for `upgrade`.
    /// - A `404` returns the conn unmodified when `pass_through_not_found` is set.
    /// - Any other status is relayed with the upstream headers and body.
    async fn run(&self, mut conn: Conn) -> Conn {
        // No upstream for this conn: leave it for later handlers.
        let Some(request_url) = self.upstream.determine_upstream(&mut conn) else {
            return conn;
        };

        log::debug!("proxying to {}", request_url.as_str());

        // Start from the incoming Forwarded header (an unparseable or absent one becomes
        // empty) and record this hop's peer ip and the downstream host.
        let mut forwarded = Forwarded::from_headers(conn.request_headers())
            .ok()
            .flatten()
            .unwrap_or_default()
            .into_owned();

        if let Some(peer_ip) = conn.peer_ip() {
            forwarded.add_for(peer_ip.to_string());
        };

        if let Some(host) = conn.host() {
            forwarded.set_host(host);
        }

        // Copy the request headers, minus hop-by-hop headers, Host (the downstream host
        // was recorded in Forwarded above), Alt-Used, and the X-Forwarded-* family.
        // The Forwarded header built above is inserted in their place.
        let mut request_headers = conn
            .request_headers()
            .clone()
            .without_headers([
                KnownHeaderName::Connection,
                KnownHeaderName::KeepAlive,
                KnownHeaderName::ProxyAuthenticate,
                KnownHeaderName::ProxyAuthorization,
                KnownHeaderName::Te,
                KnownHeaderName::Trailer,
                KnownHeaderName::TransferEncoding,
                KnownHeaderName::Upgrade,
                KnownHeaderName::Host,
                KnownHeaderName::XforwardedBy,
                KnownHeaderName::XforwardedFor,
                KnownHeaderName::XforwardedHost,
                KnownHeaderName::XforwardedProto,
                KnownHeaderName::XforwardedSsl,
                KnownHeaderName::AltUsed,
            ])
            .with_inserted_header(KnownHeaderName::Forwarded, forwarded.to_string());

        // Headers listed in the downstream Connection header are also hop-by-hop and are
        // removed. Note whether the client asked for an upgrade while scanning.
        let mut connection_is_upgrade = false;
        for header in conn
            .request_headers()
            .get_str(KnownHeaderName::Connection)
            .unwrap_or_default()
            .split(',')
            .map(|h| HeaderName::from(h.trim()))
        {
            if header == KnownHeaderName::Upgrade {
                connection_is_upgrade = true;
            }
            request_headers.remove(header);
        }

        // Re-add the upgrade headers stripped above, but only for websocket upgrades and
        // only when websocket proxying was enabled.
        if self.allow_websocket_upgrade
            && connection_is_upgrade
            && conn
                .request_headers()
                .eq_ignore_ascii_case(KnownHeaderName::Upgrade, "websocket")
        {
            request_headers.extend([
                (KnownHeaderName::Upgrade, "WebSocket"),
                (KnownHeaderName::Connection, "Upgrade"),
            ]);
        }

        // Via uses the http version the downstream request arrived with.
        self.set_via_pseudonym(&mut request_headers, conn.http_version());

        // The request is treated as having a body when Content-Length is present and not
        // exactly "0", or when it is chunked.
        let content_length = !matches!(
            conn.request_headers()
                .get_str(KnownHeaderName::ContentLength),
            Some("0") | None
        );

        let chunked = conn
            .request_headers()
            .eq_ignore_ascii_case(KnownHeaderName::TransferEncoding, "chunked");

        let method = conn.method();
        let conn_result = if chunked || content_length {
            // body_fut presumably pumps the downstream body into request_body. zip drives
            // it concurrently with the upstream request, which consumes request_body.
            let (body_fut, request_body) = stream_body(&mut conn);

            let client_fut = self
                .client
                .build_conn(method, request_url)
                .with_request_headers(request_headers)
                .with_body(request_body)
                .into_future();

            zip(body_fut, client_fut).await.1
        } else {
            self.client
                .build_conn(method, request_url)
                .with_request_headers(request_headers)
                .await
        };

        // Upstream request failed: 503, halted, with the error kept in conn state.
        let mut client_conn = match conn_result {
            Ok(client_conn) => client_conn,
            Err(e) => {
                return conn
                    .with_status(Status::ServiceUnavailable)
                    .halt()
                    .with_state(e);
            }
        };

        // Captured before client_conn is moved into the conn below; used for the
        // response's Via entry.
        let client_conn_version = client_conn.http_version();

        let mut conn = match client_conn.status() {
            // Upstream switched protocols: copy its response headers and stash the
            // upstream transport so `upgrade` can connect the two sides.
            Some(SwitchingProtocols) => {
                conn.response_headers_mut()
                    .extend(std::mem::take(client_conn.response_headers_mut()));

                conn.with_state(UpstreamUpgrade(
                    trillium_http::Upgrade::from(client_conn).into(),
                ))
                .with_status(SwitchingProtocols)
            }

            // Pass-through 404: recycle the upstream connection and hand the untouched
            // conn to subsequent handlers.
            Some(NotFound) if self.pass_through_not_found => {
                client_conn.recycle().await;
                return conn;
            }

            // Relay status and upstream headers, replacing our own Server header. The
            // upstream client conn becomes the response body.
            Some(status) => {
                conn.response_headers_mut().remove(KnownHeaderName::Server);
                conn.response_headers_mut()
                    .append_all(client_conn.response_headers().clone());
                conn.with_body(client_conn).with_status(status)
            }

            None => return conn.with_status(Status::ServiceUnavailable).halt(),
        };

        // Unless this is a 101 carrying `Connection: Upgrade` (which must reach the
        // downstream client intact), drop the upstream Connection header and every header
        // it lists.
        if Some(SwitchingProtocols) != conn.status()
            || !conn
                .response_headers()
                .eq_ignore_ascii_case(KnownHeaderName::Connection, "Upgrade")
        {
            let connection = conn
                .response_headers_mut()
                .remove(KnownHeaderName::Connection);

            conn.response_headers_mut().remove_all(
                connection
                    .iter()
                    .flatten()
                    .filter_map(|s| s.as_str())
                    .flat_map(|s| s.split(','))
                    .map(|t| HeaderName::from(t.trim()).into_owned()),
            );
        }

        // Remaining hop-by-hop headers are never relayed downstream.
        conn.response_headers_mut().remove_all([
            KnownHeaderName::KeepAlive,
            KnownHeaderName::ProxyAuthenticate,
            KnownHeaderName::ProxyAuthorization,
            KnownHeaderName::Te,
            KnownHeaderName::Trailer,
            KnownHeaderName::TransferEncoding,
        ]);

        self.set_via_pseudonym(conn.response_headers_mut(), client_conn_version);

        if self.halt { conn.halt() } else { conn }
    }

    /// Claims only upgrades that carry `UpstreamUpgrade` state set by `run`.
    fn has_upgrade(&self, upgrade: &Upgrade) -> bool {
        upgrade.state().contains::<UpstreamUpgrade>()
    }

    /// Copies bytes in both directions between the stashed upstream transport and the
    /// downstream upgrade via `full_duplex_copy`. Logs the byte counts on success and
    /// the error otherwise.
    async fn upgrade(&self, mut upgrade: Upgrade) {
        let Some(UpstreamUpgrade(upstream)) = upgrade.state_mut().take() else {
            return;
        };
        let downstream = upgrade;
        match full_duplex_copy(upstream, downstream).await {
            Err(e) => log::error!("upgrade stream error: {:?}", e),
            Ok((up, down)) => {
                log::debug!("streamed upgrade {} up and {} down", bytes(up), bytes(down))
            }
        }
    }
}
371
372fn bytes(bytes: u64) -> String {
373    Size::from_bytes(bytes)
374        .format()
375        .with_base(Base::Base10)
376        .to_string()
377}