fastwebsockets/
upgrade.rs1use base64;
22use base64::engine::general_purpose::STANDARD;
23use base64::Engine;
24use http_body_util::Empty;
25use hyper::body::Bytes;
26use hyper::Request;
27use hyper::Response;
28use hyper_util::rt::TokioIo;
29use pin_project::pin_project;
30use sha1::Digest;
31use sha1::Sha1;
32use std::pin::Pin;
33use std::task::Context;
34use std::task::Poll;
35
36use crate::Role;
37use crate::WebSocket;
38use crate::WebSocketError;
39
40fn sec_websocket_protocol(key: &[u8]) -> String {
41 let mut sha1 = Sha1::new();
42 sha1.update(key);
43 sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); let result = sha1.finalize();
45 STANDARD.encode(&result[..])
46}
47
/// Local shorthand for [`WebSocketError`]; it is the error type of this
/// module's fallible handshake functions and of [`UpgradeFut`]'s output.
type Error = WebSocketError;
49
50pub struct IncomingUpgrade {
51 key: String,
52 on_upgrade: hyper::upgrade::OnUpgrade,
53}
54
55impl IncomingUpgrade {
56 pub fn upgrade(self) -> Result<(Response<Empty<Bytes>>, UpgradeFut), Error> {
57 let response = Response::builder()
58 .status(hyper::StatusCode::SWITCHING_PROTOCOLS)
59 .header(hyper::header::CONNECTION, "upgrade")
60 .header(hyper::header::UPGRADE, "websocket")
61 .header("Sec-WebSocket-Accept", self.key)
62 .body(Empty::new())
63 .expect("bug: failed to build response");
64
65 let stream = UpgradeFut {
66 inner: self.on_upgrade,
67 };
68
69 Ok((response, stream))
70 }
71}
72
#[cfg(feature = "with_axum")]
impl<S> axum_core::extract::FromRequestParts<S> for IncomingUpgrade
where
    S: Send + Sync,
{
    type Rejection = hyper::StatusCode;

    /// Extracts an [`IncomingUpgrade`] from the request parts.
    ///
    /// Rejects with `400 Bad Request` in any of these cases:
    /// - `Sec-WebSocket-Key` is missing;
    /// - `Sec-WebSocket-Version` is missing or not exactly `13`;
    /// - hyper's `OnUpgrade` extension is absent.
    ///
    /// The `OnUpgrade` extension is removed from `parts` only after both
    /// header checks pass.
    ///
    /// NOTE(review): the `Connection`/`Upgrade` headers and the request
    /// method are not validated here.
    async fn from_request_parts(
        parts: &mut http::request::Parts,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        let headers = &parts.headers;

        let Some(client_key) = headers.get("Sec-WebSocket-Key") else {
            return Err(hyper::StatusCode::BAD_REQUEST);
        };

        let version_is_13 = headers
            .get("Sec-WebSocket-Version")
            .is_some_and(|version| version.as_bytes() == b"13");
        if !version_is_13 {
            return Err(hyper::StatusCode::BAD_REQUEST);
        }

        let Some(on_upgrade) = parts.extensions.remove::<hyper::upgrade::OnUpgrade>() else {
            return Err(hyper::StatusCode::BAD_REQUEST);
        };

        let key = sec_websocket_protocol(client_key.as_bytes());
        Ok(Self { on_upgrade, key })
    }
}
107
108#[pin_project]
110#[derive(Debug)]
111pub struct UpgradeFut {
112 #[pin]
113 inner: hyper::upgrade::OnUpgrade,
114}
115
116pub fn upgrade<B>(
131 mut request: impl std::borrow::BorrowMut<Request<B>>,
132) -> Result<(Response<Empty<Bytes>>, UpgradeFut), Error> {
133 let request = request.borrow_mut();
134
135 let key = request
136 .headers()
137 .get("Sec-WebSocket-Key")
138 .ok_or(WebSocketError::MissingSecWebSocketKey)?;
139 if request
140 .headers()
141 .get("Sec-WebSocket-Version")
142 .map(|v| v.as_bytes())
143 != Some(b"13")
144 {
145 return Err(WebSocketError::InvalidSecWebsocketVersion);
146 }
147
148 let response = Response::builder()
149 .status(hyper::StatusCode::SWITCHING_PROTOCOLS)
150 .header(hyper::header::CONNECTION, "upgrade")
151 .header(hyper::header::UPGRADE, "websocket")
152 .header(
153 "Sec-WebSocket-Accept",
154 &sec_websocket_protocol(key.as_bytes()),
155 )
156 .body(Empty::new())
157 .expect("bug: failed to build response");
158
159 let stream = UpgradeFut {
160 inner: hyper::upgrade::on(request),
161 };
162
163 Ok((response, stream))
164}
165
166pub fn is_upgrade_request<B>(request: &hyper::Request<B>) -> bool {
173 header_contains_value(request.headers(), hyper::header::CONNECTION, "Upgrade")
174 && header_contains_value(
175 request.headers(),
176 hyper::header::UPGRADE,
177 "websocket",
178 )
179}
180
181fn header_contains_value(
183 headers: &hyper::HeaderMap,
184 header: impl hyper::header::AsHeaderName,
185 value: impl AsRef<[u8]>,
186) -> bool {
187 let value = value.as_ref();
188 for header in headers.get_all(header) {
189 if header
190 .as_bytes()
191 .split(|&c| c == b',')
192 .any(|x| trim(x).eq_ignore_ascii_case(value))
193 {
194 return true;
195 }
196 }
197 false
198}
199
200fn trim(data: &[u8]) -> &[u8] {
201 trim_end(trim_start(data))
202}
203
/// Returns `data` with its leading ASCII whitespace (per
/// [`u8::is_ascii_whitespace`]) removed. An all-whitespace input yields an
/// empty slice.
fn trim_start(data: &[u8]) -> &[u8] {
    let leading = data
        .iter()
        .take_while(|byte| byte.is_ascii_whitespace())
        .count();
    &data[leading..]
}
211
/// Returns `data` with its trailing ASCII whitespace (per
/// [`u8::is_ascii_whitespace`]) removed. An all-whitespace input yields an
/// empty slice.
fn trim_end(data: &[u8]) -> &[u8] {
    let trailing = data
        .iter()
        .rev()
        .take_while(|byte| byte.is_ascii_whitespace())
        .count();
    &data[..data.len() - trailing]
}
219
220impl std::future::Future for UpgradeFut {
221 type Output = Result<WebSocket<TokioIo<hyper::upgrade::Upgraded>>, Error>;
222
223 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
224 let this = self.project();
225 let upgraded = match this.inner.poll(cx) {
226 Poll::Pending => return Poll::Pending,
227 Poll::Ready(x) => x,
228 };
229 Poll::Ready(Ok(WebSocket::after_handshake(
230 TokioIo::new(upgraded?),
231 Role::Server,
232 )))
233 }
234}