hyproxy/
upgrade.rs

1//! Middleware to facilitate upgrades across reverse proxy boundaries.
2
3use core::fmt;
4use std::ops;
5use std::str::FromStr;
6
7use bytes::{BufMut, Bytes, BytesMut};
8use http::HeaderValue;
9use nom::bytes::complete::tag;
10use nom::combinator::{map, opt};
11use nom::multi::separated_list0;
12use nom::sequence::tuple;
13use nom::IResult;
14use thiserror::Error;
15
16use crate::headers::fields::Token;
17use crate::headers::parser::{strip_whitespace, token, NoTail as _};
18/// The `Upgrade` header field allows the sender to specify what protocols they would like to upgrade to.
19pub const UPGRADE: http::HeaderName = http::header::UPGRADE;
20
21/// Errors that can occur when parsing an upgrade protocol.
22#[derive(Debug, Error)]
23#[error("upgrade protocol error")]
24pub struct UpgradeProtocolError(nom::error::Error<Bytes>);
25
26impl From<nom::error::Error<Bytes>> for UpgradeProtocolError {
27    fn from(error: nom::error::Error<Bytes>) -> Self {
28        UpgradeProtocolError(error)
29    }
30}
31
32impl From<nom::error::Error<&[u8]>> for UpgradeProtocolError {
33    fn from(error: nom::error::Error<&[u8]>) -> Self {
34        UpgradeProtocolError(nom::error::Error::new(
35            Bytes::copy_from_slice(error.input),
36            error.code,
37        ))
38    }
39}
40
41fn protocol<'v>() -> impl FnMut(&'v [u8]) -> IResult<&'v [u8], UpgradeProtocol> {
42    let v = tuple((tag(b"/"), token()));
43    let version = opt(map(v, |(_, version)| version));
44
45    map(tuple((token(), version)), |(name, version)| {
46        UpgradeProtocol { name, version }
47    })
48}
49
50fn parse_upgrade_protocols(
51    value: &HeaderValue,
52) -> Result<Vec<UpgradeProtocol>, UpgradeProtocolError> {
53    separated_list0(tag(b","), strip_whitespace(protocol()))(value.as_bytes())
54        .no_tail()
55        .map_err(Into::into)
56}
57
58fn parse_connection_headers(value: &HeaderValue) -> Result<Vec<Token>, UpgradeProtocolError> {
59    separated_list0(tag(b","), strip_whitespace(token()))(value.as_bytes())
60        .no_tail()
61        .map_err(Into::into)
62}
63
64// Get upgrade state for the inbound request
65fn get_upgrade_request(headers: &http::HeaderMap) -> Result<UpgradeRequest, UpgradeProtocolError> {
66    if let Some(connection) = headers.get(http::header::CONNECTION) {
67        let connection_headers = parse_connection_headers(connection)?;
68        if connection_headers.contains(&Token::from_static("upgrade")) {
69            if let Some(upgrade) = headers.get(UPGRADE) {
70                tracing::trace!("Found upgrade header: {:?}", upgrade);
71                return parse_upgrade_protocols(upgrade)
72                    .map(|protocols| UpgradeRequest { protocols });
73            }
74        }
75    }
76
77    Ok(Default::default())
78}
79
80fn get_upgrade_response(headers: &http::HeaderMap) -> Option<UpgradeProtocol> {
81    match get_upgrade_request(headers) {
82        Ok(mut protocols) if protocols.len() == 1 => protocols.pop(),
83        _ => None,
84    }
85}
86
87/// A protocol that can be upgraded.
88#[derive(Clone)]
89pub struct UpgradeProtocol {
90    name: Token,
91    version: Option<Token>,
92}
93
94impl PartialEq for UpgradeProtocol {
95    fn eq(&self, other: &Self) -> bool {
96        if let Some((version, other_version)) = self.version().zip(other.version()) {
97            self.name.eq_ignore_ascii_case(&other.name)
98                && version.eq_ignore_ascii_case(other_version)
99        } else {
100            self.name.eq_ignore_ascii_case(&other.name)
101        }
102    }
103}
104
105impl fmt::Debug for UpgradeProtocol {
106    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107        let name = String::from_utf8_lossy(self.name.as_bytes());
108        write!(f, "UpgradeProtocol(")?;
109        match self.version {
110            Some(ref version) => write!(
111                f,
112                "{}/{}",
113                name,
114                String::from_utf8_lossy(version.as_bytes())
115            ),
116            None => write!(f, "{}", name),
117        }?;
118        write!(f, ")")
119    }
120}
121
122impl UpgradeProtocol {
123    /// The name of the protocol.
124    pub fn name(&self) -> &Token {
125        &self.name
126    }
127
128    /// The version of the protocol.
129    pub fn version(&self) -> Option<&Token> {
130        self.version.as_ref()
131    }
132
133    fn extend_buffer(&self, buffer: &mut BytesMut) {
134        buffer.extend_from_slice(self.name.as_bytes());
135        if let Some(version) = &self.version {
136            buffer.put_u8(b'/');
137            buffer.extend_from_slice(version.as_bytes());
138        }
139    }
140}
141
142impl FromStr for UpgradeProtocol {
143    type Err = UpgradeProtocolError;
144
145    fn from_str(value: &str) -> Result<Self, Self::Err> {
146        protocol()(value.as_bytes()).no_tail().map_err(Into::into)
147    }
148}
149
150/// A request to upgrade to one or more protocols.
151#[derive(Debug, Clone, Default)]
152pub struct UpgradeRequest {
153    protocols: Vec<UpgradeProtocol>,
154}
155
156impl UpgradeRequest {
157    /// Check if the request is expecting a particular protocol
158    pub fn matching(&self, protocol: &UpgradeProtocol) -> bool {
159        self.protocols.contains(protocol)
160    }
161
162    /// Add a protocol to the upgrade request
163    pub fn push(&mut self, protocol: UpgradeProtocol) {
164        self.protocols.push(protocol);
165    }
166
167    /// Convert the upgrade request to a header value
168    pub fn to_header_value(&self) -> HeaderValue {
169        let mut buf = BytesMut::new();
170
171        let mut iter = self.protocols.iter();
172        if let Some(protocol) = iter.next() {
173            protocol.extend_buffer(&mut buf);
174        }
175
176        for protocol in iter {
177            buf.put(&b", "[..]);
178            protocol.extend_buffer(&mut buf);
179        }
180
181        HeaderValue::from_bytes(&buf).unwrap()
182    }
183
184    fn pop(&mut self) -> Option<UpgradeProtocol> {
185        self.protocols.pop()
186    }
187}
188
189impl ops::Deref for UpgradeRequest {
190    type Target = [UpgradeProtocol];
191
192    fn deref(&self) -> &Self::Target {
193        &self.protocols
194    }
195}
196
/// Layer to facilitate upgrades across reverse proxy boundaries.
///
/// Wraps services in [`ProxyUpgrade`]. Build one with [`ProxyUpgradeLayer::new`] or
/// [`Default::default`].
#[derive(Clone, Debug)]
pub struct ProxyUpgradeLayer {
    // Private unit field: keeps the struct from being built with a literal outside
    // this module.
    _priv: (),
}
202
203impl Default for ProxyUpgradeLayer {
204    fn default() -> Self {
205        Self::new()
206    }
207}
208
209impl ProxyUpgradeLayer {
210    /// Create a new `ProxyUpgradeLayer`.
211    pub fn new() -> Self {
212        Self { _priv: () }
213    }
214}
215
216impl<S> tower::layer::Layer<S> for ProxyUpgradeLayer {
217    type Service = ProxyUpgrade<S>;
218
219    fn layer(&self, inner: S) -> Self::Service {
220        ProxyUpgrade::new(inner)
221    }
222}
223
/// Middleware to facilitate upgrades across reverse proxy boundaries.
///
/// Suppose the wrapped service answers an upgrade request with
/// `101 Switching Protocols`, naming a protocol the client asked for. In that case
/// the client's upgraded connection and the response's upgraded connection are
/// joined, and bytes are copied between them.
#[derive(Clone, Debug)]
pub struct ProxyUpgrade<S> {
    // The wrapped service that forwards requests upstream.
    inner: S,
}
229
230impl<S> ProxyUpgrade<S> {
231    /// Create a new `ProxyUpgrade` middleware.
232    pub fn new(inner: S) -> Self {
233        Self { inner }
234    }
235}
236
237impl<S, BIn, BOut> tower::Service<http::Request<BIn>> for ProxyUpgrade<S>
238where
239    S: tower::Service<http::Request<BIn>, Response = http::Response<BOut>>,
240{
241    type Response = S::Response;
242    type Error = S::Error;
243    type Future = self::future::UpgradableProxyFuture<S::Future>;
244
245    fn call(&mut self, mut request: http::Request<BIn>) -> Self::Future {
246        let upgrade = self::future::Upgrade::new(&mut request);
247        let inner = self.inner.call(request);
248        self::future::UpgradableProxyFuture::new(inner, upgrade)
249    }
250
251    fn poll_ready(
252        &mut self,
253        cx: &mut std::task::Context<'_>,
254    ) -> std::task::Poll<Result<(), Self::Error>> {
255        self.inner.poll_ready(cx)
256    }
257}
258
/// Future and helper types backing `ProxyUpgrade`.
///
/// `Upgrade` captures the client's upgrade request before it is forwarded.
/// `UpgradableProxyFuture` waits for the inner service's response and, on a matching
/// `101 Switching Protocols`, spawns a task that copies bytes between the two
/// upgraded connections.
mod future {

    use std::task::ready;

    use hyperdriver::bridge::io::TokioIo;
    use tokio::io::copy_bidirectional;

    use super::*;

    /// Upgrade state captured from an inbound request before it is forwarded.
    #[derive(Debug)]
    pub(super) struct Upgrade {
        // Protocols requested by the client. `None` when the request's
        // `Connection`/`Upgrade` headers failed to parse; an empty request when no
        // upgrade was asked for.
        protocol: Option<UpgradeRequest>,
        // Handle resolving to the client's upgraded connection. Set in `new` and
        // taken in `poll` when a tunnel is spawned.
        on: Option<hyper::upgrade::OnUpgrade>,
    }

    impl Upgrade {
        /// Capture upgrade state from `request`.
        ///
        /// Parses the requested protocols (logging and discarding parse errors),
        /// stores them as a request extension so the inner service can inspect
        /// them, and obtains the request's upgrade handle.
        pub(super) fn new<B>(request: &mut http::Request<B>) -> Self {
            let protocol = get_upgrade_request(request.headers())
                .map(Some)
                .unwrap_or_else(|error| {
                    tracing::error!("Unable to parse upgrade protocols from request: {error}");
                    None
                });

            // Inserted whenever parsing succeeded, including the empty
            // (no-upgrade) case.
            if let Some(protocol) = &protocol {
                request.extensions_mut().insert(protocol.clone());
            }

            // NOTE(review): presumably this removes hyper's `OnUpgrade` from the
            // request extensions, so the inner service never sees it — confirm
            // against the hyper `upgrade::on` docs.
            let on = hyper::upgrade::on(request);
            Self {
                protocol,
                on: Some(on),
            }
        }
    }

    // `pin_project!` generates `project()`, giving pinned access to `inner`
    // while `request_upgrade` remains a plain `&mut` field.
    pin_project_lite::pin_project! {
        // Future returned by `ProxyUpgrade::call`: the inner service's future plus
        // the upgrade state captured from the request.
        pub struct UpgradableProxyFuture<F> {
            #[pin]
            inner: F,
            request_upgrade: Upgrade,
        }
    }

    impl<F> UpgradableProxyFuture<F> {
        /// Pair the inner service's future with the captured request upgrade state.
        pub(super) fn new(inner: F, upgrade: Upgrade) -> Self {
            Self {
                inner,
                request_upgrade: upgrade,
            }
        }
    }

    impl<F, BOut, E> std::future::Future for UpgradableProxyFuture<F>
    where
        F: std::future::Future<Output = Result<http::Response<BOut>, E>>,
    {
        type Output = Result<http::Response<BOut>, E>;

        /// Drive the inner future; once it yields, decide whether to bridge the
        /// upgraded connections, then return the inner result to the caller.
        ///
        /// The response's status and headers are not altered here; bridging happens
        /// on a spawned task.
        fn poll(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Self::Output> {
            let this = self.project();
            // Both `Ok` and `Err` results are forwarded at the end of this function.
            let mut response = ready!(this.inner.poll(cx));

            if let Ok(response) = &mut response {
                // Only a `101 Switching Protocols` response can start a tunnel.
                if response.status() == http::StatusCode::SWITCHING_PROTOCOLS {
                    let request_protocol = this.request_upgrade.protocol.as_ref();
                    // `None` unless the response names exactly one upgrade protocol.
                    let response_protocol = get_upgrade_response(response.headers());
                    // Bridge only when the protocol the backend switched to is one
                    // the client asked for.
                    if request_protocol
                        .zip(response_protocol.as_ref())
                        .is_some_and(|(protocols, response_protocol)| {
                            protocols.matching(response_protocol)
                        })
                    {
                        // Upgrade handle for the connection this response arrived on.
                        let response_upgraded = hyper::upgrade::on(response);
                        // `on` is populated by `Upgrade::new` and only taken here.
                        let request_upgraded = this.request_upgrade.on.take().unwrap();

                        // Detached task: the 101 response is returned immediately, and
                        // the tunnel starts once both sides finish upgrading.
                        // NOTE(review): `tokio::spawn` panics outside a Tokio runtime,
                        // so this future must be polled on one.
                        tokio::spawn(async move {
                            // Connection the inbound request arrived on.
                            let upstream_io = match request_upgraded.await {
                                Ok(upgraded) => {
                                    tracing::debug!("Request upgraded");
                                    upgraded
                                }
                                Err(e) => {
                                    tracing::error!("Request upgrade failed: {:?}", e);
                                    return;
                                }
                            };

                            // Connection the 101 response arrived on.
                            let downstream_io = match response_upgraded.await {
                                Ok(upgraded) => {
                                    tracing::debug!("Response upgraded");
                                    upgraded
                                }
                                Err(e) => {
                                    tracing::error!("Response upgrade failed: {:?}", e);
                                    return;
                                }
                            };

                            // `copy_bidirectional(a, b)` returns `(a -> b, b -> a)` byte
                            // counts, so `up` counts request-side -> response-side
                            // bytes and `down` the reverse. It completes once both
                            // directions have shut down, or on the first I/O error.
                            match copy_bidirectional(
                                &mut TokioIo::new(upstream_io),
                                &mut TokioIo::new(downstream_io),
                            )
                            .await
                            {
                                Ok((up, down)) => {
                                    tracing::debug!(
                                        "Upgrade complete: {} bytes upstream, {} bytes downstream",
                                        up,
                                        down
                                    );
                                }
                                Err(error) => {
                                    tracing::debug!("Upgrade IO error: {}", error);
                                }
                            }
                        });
                    } else {
                        // NOTE(review): the 101 response is still returned to the client
                        // below even though no tunnel is spawned, so the client may
                        // believe the switch succeeded. Consider replacing the response
                        // (e.g. 502 Bad Gateway) — confirm intended behavior.
                        let protocol_options = request_protocol
                            .map(|p| {
                                p.iter()
                                    .map(|p| format!("{p:?}"))
                                    .collect::<Vec<_>>()
                                    .join(", ")
                            })
                            .unwrap_or_default();

                        tracing::debug!(
                            requested = %protocol_options,
                            response = %response_protocol.as_ref().map(|p| format!("{p:?}")).unwrap_or_default(),
                            "Proxy Upgrade protocol mismatch, refusing to start upgrade"
                        );
                    }
                }
            }

            std::task::Poll::Ready(response)
        }
    }
}
402
#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn parse_protocol() {
        let protocol = "websocket".parse::<UpgradeProtocol>().unwrap();
        assert_eq!(protocol.name().as_bytes(), b"websocket");
        assert!(protocol.version().is_none());
    }

    #[test]
    fn parse_protocol_with_version() {
        let protocol = "WebSocket/13".parse::<UpgradeProtocol>().unwrap();
        assert_eq!(protocol.name().as_bytes(), b"WebSocket");
        assert_eq!(protocol.version().unwrap().as_bytes(), b"13");
    }

    #[test]
    fn parse_protocol_with_invalid_characters() {
        // Whitespace between `/` and the version leaves unparsed trailing input.
        let protocol = "websocket/ 2".parse::<UpgradeProtocol>();
        assert!(protocol.is_err());
    }

    #[test]
    fn protocol_equality_ignores_case_and_missing_version() {
        let versioned: UpgradeProtocol = "WebSocket/13".parse().unwrap();
        let bare: UpgradeProtocol = "websocket".parse().unwrap();
        let other_version: UpgradeProtocol = "websocket/8".parse().unwrap();

        assert_eq!(versioned, bare);
        assert_eq!(bare, other_version);
        assert_ne!(versioned, other_version);
    }

    #[test]
    fn parse_protocol_requests() {
        let protocols =
            parse_upgrade_protocols(&"websocket, http/2".parse::<http::HeaderValue>().unwrap())
                .unwrap();
        assert_eq!(protocols.len(), 2);

        let request = UpgradeRequest { protocols };

        assert!(request.matching(&"http/2".parse().unwrap()))
    }

    #[test]
    fn request_round_trips_to_header_value() {
        let value: http::HeaderValue = "websocket, http/2".parse().unwrap();
        let request = UpgradeRequest {
            protocols: parse_upgrade_protocols(&value).unwrap(),
        };
        assert_eq!(request.to_header_value(), value);
    }

    #[test]
    fn parse_headers_without_upgrade_in_connection() {
        let mut headers = http::HeaderMap::new();
        headers.insert(http::header::CONNECTION, "close".parse().unwrap());
        headers.insert(http::header::UPGRADE, "websocket".parse().unwrap());

        let request = get_upgrade_request(&headers).unwrap();
        assert!(request.is_empty());
    }

    #[test]
    fn parse_headers_with_upgrade_in_connection() {
        let mut headers = http::HeaderMap::new();
        headers.insert(http::header::CONNECTION, "keep-alive, upgrade".parse().unwrap());
        headers.insert(http::header::UPGRADE, "websocket".parse().unwrap());

        let request = get_upgrade_request(&headers).unwrap();
        assert_eq!(request.len(), 1);
        assert!(request.matching(&"websocket".parse().unwrap()));
    }

    #[test]
    fn upgrade_response_requires_single_protocol() {
        let mut headers = http::HeaderMap::new();
        headers.insert(http::header::CONNECTION, "upgrade".parse().unwrap());
        headers.insert(http::header::UPGRADE, "websocket, h2c".parse().unwrap());
        assert!(get_upgrade_response(&headers).is_none());

        headers.insert(http::header::UPGRADE, "websocket".parse().unwrap());
        assert_eq!(
            get_upgrade_response(&headers),
            Some("websocket".parse().unwrap())
        );
    }
}