fastwebsockets/
upgrade.rs

// Port of hyper-tungstenite for fastwebsockets.
2// https://github.com/de-vri-es/hyper-tungstenite-rs
3//
4// Copyright 2021, Maarten de Vries maarten@de-vri.es
5// BSD 2-Clause "Simplified" License
6//
7// Copyright 2023 Divy Srivastava <dj.srivastava23@gmail.com>
8//
9// Licensed under the Apache License, Version 2.0 (the "License");
10// you may not use this file except in compliance with the License.
11// You may obtain a copy of the License at
12//
13// http://www.apache.org/licenses/LICENSE-2.0
14//
15// Unless required by applicable law or agreed to in writing, software
16// distributed under the License is distributed on an "AS IS" BASIS,
17// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18// See the License for the specific language governing permissions and
19// limitations under the License.
20
21use base64;
22use base64::engine::general_purpose::STANDARD;
23use base64::Engine;
24use http_body_util::Empty;
25use hyper::body::Bytes;
26use hyper::Request;
27use hyper::Response;
28use hyper_util::rt::TokioIo;
29use pin_project::pin_project;
30use sha1::Digest;
31use sha1::Sha1;
32use std::pin::Pin;
33use std::task::Context;
34use std::task::Poll;
35
36use crate::Role;
37use crate::WebSocket;
38use crate::WebSocketError;
39
40fn sec_websocket_protocol(key: &[u8]) -> String {
41  let mut sha1 = Sha1::new();
42  sha1.update(key);
43  sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); // magic string
44  let result = sha1.finalize();
45  STANDARD.encode(&result[..])
46}
47
// Local shorthand for the crate's `WebSocketError`, used in the `Result`
// types returned by the upgrade functions in this file.
type Error = WebSocketError;
49
50pub struct IncomingUpgrade {
51  key: String,
52  on_upgrade: hyper::upgrade::OnUpgrade,
53}
54
55impl IncomingUpgrade {
56  pub fn upgrade(self) -> Result<(Response<Empty<Bytes>>, UpgradeFut), Error> {
57    let response = Response::builder()
58      .status(hyper::StatusCode::SWITCHING_PROTOCOLS)
59      .header(hyper::header::CONNECTION, "upgrade")
60      .header(hyper::header::UPGRADE, "websocket")
61      .header("Sec-WebSocket-Accept", self.key)
62      .body(Empty::new())
63      .expect("bug: failed to build response");
64
65    let stream = UpgradeFut {
66      inner: self.on_upgrade,
67    };
68
69    Ok((response, stream))
70  }
71}
72
#[cfg(feature = "with_axum")]
impl<S> axum_core::extract::FromRequestParts<S> for IncomingUpgrade
where
  S: Send + Sync,
{
  /// Every failed precondition is reported as `400 Bad Request`.
  type Rejection = hyper::StatusCode;

  /// Extracts a pending websocket upgrade from the request head.
  ///
  /// Rejects when `Sec-WebSocket-Key` is absent, when
  /// `Sec-WebSocket-Version` is not exactly `13`, or when the request
  /// extensions hold no `hyper::upgrade::OnUpgrade` handle. On success the
  /// handle has been removed from the extensions and the accept value has
  /// been derived from the client key.
  async fn from_request_parts(
    parts: &mut http::request::Parts,
    _state: &S,
  ) -> Result<Self, Self::Rejection> {
    let client_key = match parts.headers.get("Sec-WebSocket-Key") {
      Some(value) => value,
      None => return Err(hyper::StatusCode::BAD_REQUEST),
    };

    let version_ok = matches!(
      parts.headers.get("Sec-WebSocket-Version"),
      Some(version) if version.as_bytes() == b"13"
    );
    if !version_ok {
      return Err(hyper::StatusCode::BAD_REQUEST);
    }

    // The handle is only taken out of the extensions once the header checks
    // have passed, so a rejected request leaves the extensions untouched.
    let on_upgrade =
      match parts.extensions.remove::<hyper::upgrade::OnUpgrade>() {
        Some(handle) => handle,
        None => return Err(hyper::StatusCode::BAD_REQUEST),
      };

    let key = sec_websocket_protocol(client_key.as_bytes());
    Ok(Self { on_upgrade, key })
  }
}
107
/// A future that resolves to a websocket stream when the associated HTTP upgrade completes.
///
/// Polling this future polls the wrapped hyper upgrade handle. On success the
/// upgraded connection is wrapped in `TokioIo` and handed to
/// `WebSocket::after_handshake` with `Role::Server`.
#[pin_project]
#[derive(Debug)]
pub struct UpgradeFut {
  /// The hyper upgrade handle; pinned so it can be polled through the
  /// pin projection.
  #[pin]
  inner: hyper::upgrade::OnUpgrade,
}
115
116/// Try to upgrade a received `hyper::Request` to a websocket connection.
117///
118/// The function returns a HTTP response and a future that resolves to the websocket stream.
119/// The response body *MUST* be sent to the client before the future can be resolved.
120///
121/// This functions checks `Sec-WebSocket-Key` and `Sec-WebSocket-Version` headers.
122/// It does not inspect the `Origin`, `Sec-WebSocket-Protocol` or `Sec-WebSocket-Extensions` headers.
123/// You can inspect the headers manually before calling this function,
124/// and modify the response headers appropriately.
125///
126/// This function also does not look at the `Connection` or `Upgrade` headers.
127/// To check if a request is a websocket upgrade request, you can use [`is_upgrade_request`].
128/// Alternatively you can inspect the `Connection` and `Upgrade` headers manually.
129///
130pub fn upgrade<B>(
131  mut request: impl std::borrow::BorrowMut<Request<B>>,
132) -> Result<(Response<Empty<Bytes>>, UpgradeFut), Error> {
133  let request = request.borrow_mut();
134
135  let key = request
136    .headers()
137    .get("Sec-WebSocket-Key")
138    .ok_or(WebSocketError::MissingSecWebSocketKey)?;
139  if request
140    .headers()
141    .get("Sec-WebSocket-Version")
142    .map(|v| v.as_bytes())
143    != Some(b"13")
144  {
145    return Err(WebSocketError::InvalidSecWebsocketVersion);
146  }
147
148  let response = Response::builder()
149    .status(hyper::StatusCode::SWITCHING_PROTOCOLS)
150    .header(hyper::header::CONNECTION, "upgrade")
151    .header(hyper::header::UPGRADE, "websocket")
152    .header(
153      "Sec-WebSocket-Accept",
154      &sec_websocket_protocol(key.as_bytes()),
155    )
156    .body(Empty::new())
157    .expect("bug: failed to build response");
158
159  let stream = UpgradeFut {
160    inner: hyper::upgrade::on(request),
161  };
162
163  Ok((response, stream))
164}
165
166/// Check if a request is a websocket upgrade request.
167///
168/// If the `Upgrade` header lists multiple protocols,
169/// this function returns true if of them are `"websocket"`,
170/// If the server supports multiple upgrade protocols,
171/// it would be more appropriate to try each listed protocol in order.
172pub fn is_upgrade_request<B>(request: &hyper::Request<B>) -> bool {
173  header_contains_value(request.headers(), hyper::header::CONNECTION, "Upgrade")
174    && header_contains_value(
175      request.headers(),
176      hyper::header::UPGRADE,
177      "websocket",
178    )
179}
180
181/// Check if there is a header of the given name containing the wanted value.
182fn header_contains_value(
183  headers: &hyper::HeaderMap,
184  header: impl hyper::header::AsHeaderName,
185  value: impl AsRef<[u8]>,
186) -> bool {
187  let value = value.as_ref();
188  for header in headers.get_all(header) {
189    if header
190      .as_bytes()
191      .split(|&c| c == b',')
192      .any(|x| trim(x).eq_ignore_ascii_case(value))
193    {
194      return true;
195    }
196  }
197  false
198}
199
200fn trim(data: &[u8]) -> &[u8] {
201  trim_end(trim_start(data))
202}
203
/// Strips leading ASCII whitespace from `data`.
///
/// Returns an empty slice when `data` is empty or consists solely of
/// whitespace.
fn trim_start(data: &[u8]) -> &[u8] {
  match data.iter().position(|byte| !byte.is_ascii_whitespace()) {
    Some(first) => &data[first..],
    None => b"",
  }
}
211
/// Strips trailing ASCII whitespace from `data`.
///
/// Returns an empty slice when `data` is empty or consists solely of
/// whitespace.
fn trim_end(data: &[u8]) -> &[u8] {
  match data.iter().rposition(|byte| !byte.is_ascii_whitespace()) {
    Some(last) => &data[..=last],
    None => b"",
  }
}
219
220impl std::future::Future for UpgradeFut {
221  type Output = Result<WebSocket<TokioIo<hyper::upgrade::Upgraded>>, Error>;
222
223  fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
224    let this = self.project();
225    let upgraded = match this.inner.poll(cx) {
226      Poll::Pending => return Poll::Pending,
227      Poll::Ready(x) => x,
228    };
229    Poll::Ready(Ok(WebSocket::after_handshake(
230      TokioIo::new(upgraded?),
231      Role::Server,
232    )))
233  }
234}