1use core::fmt;
4use std::ops;
5use std::str::FromStr;
6
7use bytes::{BufMut, Bytes, BytesMut};
8use http::HeaderValue;
9use nom::bytes::complete::tag;
10use nom::combinator::{map, opt};
11use nom::multi::separated_list0;
12use nom::sequence::tuple;
13use nom::IResult;
14use thiserror::Error;
15
16use crate::headers::fields::Token;
17use crate::headers::parser::{strip_whitespace, token, NoTail as _};
/// Name of the `Upgrade` header; an alias for [`http::header::UPGRADE`].
pub const UPGRADE: http::HeaderName = http::header::UPGRADE;
20
/// Error returned when an `Upgrade` / `Connection` header value, or a single
/// protocol string, cannot be parsed.
///
/// Wraps the underlying nom error; the input the parser stopped at is kept as
/// owned [`Bytes`] so the error does not borrow from the header.
#[derive(Debug, Error)]
#[error("upgrade protocol error")]
pub struct UpgradeProtocolError(nom::error::Error<Bytes>);
25
26impl From<nom::error::Error<Bytes>> for UpgradeProtocolError {
27 fn from(error: nom::error::Error<Bytes>) -> Self {
28 UpgradeProtocolError(error)
29 }
30}
31
32impl From<nom::error::Error<&[u8]>> for UpgradeProtocolError {
33 fn from(error: nom::error::Error<&[u8]>) -> Self {
34 UpgradeProtocolError(nom::error::Error::new(
35 Bytes::copy_from_slice(error.input),
36 error.code,
37 ))
38 }
39}
40
41fn protocol<'v>() -> impl FnMut(&'v [u8]) -> IResult<&'v [u8], UpgradeProtocol> {
42 let v = tuple((tag(b"/"), token()));
43 let version = opt(map(v, |(_, version)| version));
44
45 map(tuple((token(), version)), |(name, version)| {
46 UpgradeProtocol { name, version }
47 })
48}
49
50fn parse_upgrade_protocols(
51 value: &HeaderValue,
52) -> Result<Vec<UpgradeProtocol>, UpgradeProtocolError> {
53 separated_list0(tag(b","), strip_whitespace(protocol()))(value.as_bytes())
54 .no_tail()
55 .map_err(Into::into)
56}
57
58fn parse_connection_headers(value: &HeaderValue) -> Result<Vec<Token>, UpgradeProtocolError> {
59 separated_list0(tag(b","), strip_whitespace(token()))(value.as_bytes())
60 .no_tail()
61 .map_err(Into::into)
62}
63
64fn get_upgrade_request(headers: &http::HeaderMap) -> Result<UpgradeRequest, UpgradeProtocolError> {
66 if let Some(connection) = headers.get(http::header::CONNECTION) {
67 let connection_headers = parse_connection_headers(connection)?;
68 if connection_headers.contains(&Token::from_static("upgrade")) {
69 if let Some(upgrade) = headers.get(UPGRADE) {
70 tracing::trace!("Found upgrade header: {:?}", upgrade);
71 return parse_upgrade_protocols(upgrade)
72 .map(|protocols| UpgradeRequest { protocols });
73 }
74 }
75 }
76
77 Ok(Default::default())
78}
79
80fn get_upgrade_response(headers: &http::HeaderMap) -> Option<UpgradeProtocol> {
81 match get_upgrade_request(headers) {
82 Ok(mut protocols) if protocols.len() == 1 => protocols.pop(),
83 _ => None,
84 }
85}
86
/// One protocol from an `Upgrade` header, such as `websocket` or `http/2`.
///
/// Parsed as a `name` token optionally followed by `/` and a `version` token.
#[derive(Clone)]
pub struct UpgradeProtocol {
    /// Protocol name; compared case-insensitively by `PartialEq`.
    name: Token,
    /// Version after the `/`, or `None` when the header omits it.
    version: Option<Token>,
}
93
94impl PartialEq for UpgradeProtocol {
95 fn eq(&self, other: &Self) -> bool {
96 if let Some((version, other_version)) = self.version().zip(other.version()) {
97 self.name.eq_ignore_ascii_case(&other.name)
98 && version.eq_ignore_ascii_case(other_version)
99 } else {
100 self.name.eq_ignore_ascii_case(&other.name)
101 }
102 }
103}
104
105impl fmt::Debug for UpgradeProtocol {
106 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107 let name = String::from_utf8_lossy(self.name.as_bytes());
108 write!(f, "UpgradeProtocol(")?;
109 match self.version {
110 Some(ref version) => write!(
111 f,
112 "{}/{}",
113 name,
114 String::from_utf8_lossy(version.as_bytes())
115 ),
116 None => write!(f, "{}", name),
117 }?;
118 write!(f, ")")
119 }
120}
121
122impl UpgradeProtocol {
123 pub fn name(&self) -> &Token {
125 &self.name
126 }
127
128 pub fn version(&self) -> Option<&Token> {
130 self.version.as_ref()
131 }
132
133 fn extend_buffer(&self, buffer: &mut BytesMut) {
134 buffer.extend_from_slice(self.name.as_bytes());
135 if let Some(version) = &self.version {
136 buffer.put_u8(b'/');
137 buffer.extend_from_slice(version.as_bytes());
138 }
139 }
140}
141
142impl FromStr for UpgradeProtocol {
143 type Err = UpgradeProtocolError;
144
145 fn from_str(value: &str) -> Result<Self, Self::Err> {
146 protocol()(value.as_bytes()).no_tail().map_err(Into::into)
147 }
148}
149
/// The protocols listed in a message's `Upgrade` header.
///
/// Empty when the message did not request an upgrade (no `upgrade` option in
/// `Connection`). [`ProxyUpgrade`] also stores a clone of the parsed request
/// protocols in the request's extensions.
#[derive(Debug, Clone, Default)]
pub struct UpgradeRequest {
    /// Protocols in the order they were parsed from the header.
    protocols: Vec<UpgradeProtocol>,
}
155
156impl UpgradeRequest {
157 pub fn matching(&self, protocol: &UpgradeProtocol) -> bool {
159 self.protocols.contains(protocol)
160 }
161
162 pub fn push(&mut self, protocol: UpgradeProtocol) {
164 self.protocols.push(protocol);
165 }
166
167 pub fn to_header_value(&self) -> HeaderValue {
169 let mut buf = BytesMut::new();
170
171 let mut iter = self.protocols.iter();
172 if let Some(protocol) = iter.next() {
173 protocol.extend_buffer(&mut buf);
174 }
175
176 for protocol in iter {
177 buf.put(&b", "[..]);
178 protocol.extend_buffer(&mut buf);
179 }
180
181 HeaderValue::from_bytes(&buf).unwrap()
182 }
183
184 fn pop(&mut self) -> Option<UpgradeProtocol> {
185 self.protocols.pop()
186 }
187}
188
189impl ops::Deref for UpgradeRequest {
190 type Target = [UpgradeProtocol];
191
192 fn deref(&self) -> &Self::Target {
193 &self.protocols
194 }
195}
196
/// A [`tower::layer::Layer`] that wraps services in [`ProxyUpgrade`].
#[derive(Clone, Debug)]
pub struct ProxyUpgradeLayer {
    // Private field so the layer can only be built through `new` / `Default`.
    _priv: (),
}
202
203impl Default for ProxyUpgradeLayer {
204 fn default() -> Self {
205 Self::new()
206 }
207}
208
209impl ProxyUpgradeLayer {
210 pub fn new() -> Self {
212 Self { _priv: () }
213 }
214}
215
216impl<S> tower::layer::Layer<S> for ProxyUpgradeLayer {
217 type Service = ProxyUpgrade<S>;
218
219 fn layer(&self, inner: S) -> Self::Service {
220 ProxyUpgrade::new(inner)
221 }
222}
223
/// Middleware that carries HTTP upgrades through a proxying service.
///
/// The trigger is a `101 Switching Protocols` response from the inner service
/// whose selected protocol matches one the client requested. In that case a
/// spawned task waits for the client-side and upstream-side connection upgrades
/// and then copies bytes between them in both directions.
#[derive(Clone, Debug)]
pub struct ProxyUpgrade<S> {
    inner: S,
}
229
230impl<S> ProxyUpgrade<S> {
231 pub fn new(inner: S) -> Self {
233 Self { inner }
234 }
235}
236
237impl<S, BIn, BOut> tower::Service<http::Request<BIn>> for ProxyUpgrade<S>
238where
239 S: tower::Service<http::Request<BIn>, Response = http::Response<BOut>>,
240{
241 type Response = S::Response;
242 type Error = S::Error;
243 type Future = self::future::UpgradableProxyFuture<S::Future>;
244
245 fn call(&mut self, mut request: http::Request<BIn>) -> Self::Future {
246 let upgrade = self::future::Upgrade::new(&mut request);
247 let inner = self.inner.call(request);
248 self::future::UpgradableProxyFuture::new(inner, upgrade)
249 }
250
251 fn poll_ready(
252 &mut self,
253 cx: &mut std::task::Context<'_>,
254 ) -> std::task::Poll<Result<(), Self::Error>> {
255 self.inner.poll_ready(cx)
256 }
257}
258
259mod future {
260
261 use std::task::ready;
262
263 use hyperdriver::bridge::io::TokioIo;
264 use tokio::io::copy_bidirectional;
265
266 use super::*;
267
268 #[derive(Debug)]
269 pub(super) struct Upgrade {
270 protocol: Option<UpgradeRequest>,
271 on: Option<hyper::upgrade::OnUpgrade>,
272 }
273
274 impl Upgrade {
275 pub(super) fn new<B>(request: &mut http::Request<B>) -> Self {
276 let protocol = get_upgrade_request(request.headers())
277 .map(Some)
278 .unwrap_or_else(|error| {
279 tracing::error!("Unable to parse upgrade protocols from request: {error}");
280 None
281 });
282
283 if let Some(protocol) = &protocol {
284 request.extensions_mut().insert(protocol.clone());
285 }
286
287 let on = hyper::upgrade::on(request);
288 Self {
289 protocol,
290 on: Some(on),
291 }
292 }
293 }
294
295 pin_project_lite::pin_project! {
296 pub struct UpgradableProxyFuture<F> {
297 #[pin]
298 inner: F,
299 request_upgrade: Upgrade,
300 }
301 }
302
303 impl<F> UpgradableProxyFuture<F> {
304 pub(super) fn new(inner: F, upgrade: Upgrade) -> Self {
305 Self {
306 inner,
307 request_upgrade: upgrade,
308 }
309 }
310 }
311
312 impl<F, BOut, E> std::future::Future for UpgradableProxyFuture<F>
313 where
314 F: std::future::Future<Output = Result<http::Response<BOut>, E>>,
315 {
316 type Output = Result<http::Response<BOut>, E>;
317
318 fn poll(
319 self: std::pin::Pin<&mut Self>,
320 cx: &mut std::task::Context<'_>,
321 ) -> std::task::Poll<Self::Output> {
322 let this = self.project();
323 let mut response = ready!(this.inner.poll(cx));
324
325 if let Ok(response) = &mut response {
326 if response.status() == http::StatusCode::SWITCHING_PROTOCOLS {
327 let request_protocol = this.request_upgrade.protocol.as_ref();
328 let response_protocol = get_upgrade_response(response.headers());
329 if request_protocol
330 .zip(response_protocol.as_ref())
331 .is_some_and(|(protocols, response_protocol)| {
332 protocols.matching(response_protocol)
333 })
334 {
335 let response_upgraded = hyper::upgrade::on(response);
336 let request_upgraded = this.request_upgrade.on.take().unwrap();
337
338 tokio::spawn(async move {
339 let upstream_io = match request_upgraded.await {
340 Ok(upgraded) => {
341 tracing::debug!("Request upgraded");
342 upgraded
343 }
344 Err(e) => {
345 tracing::error!("Request upgrade failed: {:?}", e);
346 return;
347 }
348 };
349
350 let downstream_io = match response_upgraded.await {
351 Ok(upgraded) => {
352 tracing::debug!("Response upgraded");
353 upgraded
354 }
355 Err(e) => {
356 tracing::error!("Response upgrade failed: {:?}", e);
357 return;
358 }
359 };
360
361 match copy_bidirectional(
362 &mut TokioIo::new(upstream_io),
363 &mut TokioIo::new(downstream_io),
364 )
365 .await
366 {
367 Ok((up, down)) => {
368 tracing::debug!(
369 "Upgrade complete: {} bytes upstream, {} bytes downstream",
370 up,
371 down
372 );
373 }
374 Err(error) => {
375 tracing::debug!("Upgrade IO error: {}", error);
376 }
377 }
378 });
379 } else {
380 let protocol_options = request_protocol
381 .map(|p| {
382 p.iter()
383 .map(|p| format!("{p:?}"))
384 .collect::<Vec<_>>()
385 .join(", ")
386 })
387 .unwrap_or_default();
388
389 tracing::debug!(
390 requested = %protocol_options,
391 response = %response_protocol.as_ref().map(|p| format!("{p:?}")).unwrap_or_default(),
392 "Proxy Upgrade protocol mismatch, refusing to start upgrade"
393 );
394 }
395 }
396 }
397
398 std::task::Poll::Ready(response)
399 }
400 }
401}
402
#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn parse_protocol() {
        let protocol = "websocket".parse::<UpgradeProtocol>().unwrap();
        assert_eq!(protocol.name().as_bytes(), b"websocket");
        assert!(protocol.version().is_none());
    }

    #[test]
    fn parse_protocol_with_version() {
        let protocol = "http/2".parse::<UpgradeProtocol>().unwrap();
        assert_eq!(protocol.name().as_bytes(), b"http");
        assert_eq!(protocol.version().unwrap().as_bytes(), b"2");
    }

    #[test]
    fn parse_protocol_with_invalid_characters() {
        let protocol = "websocket/ 2".parse::<UpgradeProtocol>();
        assert!(protocol.is_err());
    }

    #[test]
    fn parse_protocol_requests() {
        let protocols =
            parse_upgrade_protocols(&"websocket, http/2".parse::<http::HeaderValue>().unwrap())
                .unwrap();
        assert_eq!(protocols.len(), 2);

        let request = UpgradeRequest { protocols };

        assert!(request.matching(&"http/2".parse().unwrap()))
    }

    #[test]
    fn protocol_without_version_matches_any_version() {
        let request = UpgradeRequest {
            protocols: vec!["websocket/13".parse().unwrap()],
        };

        assert!(request.matching(&"websocket".parse().unwrap()));
        assert!(request.matching(&"WebSocket/13".parse().unwrap()));
        assert!(!request.matching(&"websocket/14".parse().unwrap()));
        assert!(!request.matching(&"h2c".parse().unwrap()));
    }

    #[test]
    fn debug_formats_name_and_version() {
        let with_version: UpgradeProtocol = "http/2".parse().unwrap();
        assert_eq!(format!("{with_version:?}"), "UpgradeProtocol(http/2)");

        let without_version: UpgradeProtocol = "websocket".parse().unwrap();
        assert_eq!(format!("{without_version:?}"), "UpgradeProtocol(websocket)");
    }

    #[test]
    fn header_value_joins_protocols_and_round_trips() {
        let mut request = UpgradeRequest::default();
        assert!(request.to_header_value().is_empty());

        request.push("websocket".parse().unwrap());
        request.push("http/2".parse().unwrap());

        let value = request.to_header_value();
        assert_eq!(value.as_bytes(), b"websocket, http/2");
        assert_eq!(parse_upgrade_protocols(&value).unwrap().len(), 2);
    }

    #[test]
    fn parse_headers_without_upgrade_in_connection() {
        let mut headers = http::HeaderMap::new();
        headers.insert(http::header::CONNECTION, "close".parse().unwrap());
        headers.insert(http::header::UPGRADE, "websocket".parse().unwrap());

        let request = get_upgrade_request(&headers).unwrap();
        assert!(request.is_empty());
    }

    #[test]
    fn parse_headers_with_multiple_connection_options() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::CONNECTION,
            "keep-alive, upgrade".parse().unwrap(),
        );
        headers.insert(http::header::UPGRADE, "websocket".parse().unwrap());

        let request = get_upgrade_request(&headers).unwrap();
        assert_eq!(request.len(), 1);
        assert_eq!(request.first().unwrap().name().as_bytes(), b"websocket");
    }

    #[test]
    fn upgrade_response_requires_exactly_one_protocol() {
        let mut headers = http::HeaderMap::new();
        headers.insert(http::header::CONNECTION, "upgrade".parse().unwrap());
        headers.insert(http::header::UPGRADE, "websocket".parse().unwrap());

        let selected = get_upgrade_response(&headers).expect("one protocol is selected");
        assert_eq!(selected.name().as_bytes(), b"websocket");

        headers.insert(http::header::UPGRADE, "websocket, h2c".parse().unwrap());
        assert!(get_upgrade_response(&headers).is_none());
    }
}
441}