1extern crate alloc;
37
38use http::{
39 header::{
40 HeaderMap, HeaderName, HeaderValue, ALLOW, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY,
41 SEC_WEBSOCKET_VERSION, UPGRADE,
42 },
43 request::Request,
44 response::{Builder, Response},
45 uri::Uri,
46 Method, StatusCode, Version,
47};
48
49mod codec;
50mod error;
51mod frame;
52mod mask;
53mod proto;
54
55pub use self::{
56 codec::{Codec, Item, Message},
57 error::{HandshakeError, ProtocolError},
58 proto::{hash_key, CloseCode, CloseReason, OpCode},
59};
60
61#[allow(clippy::declare_interior_mutable_const)]
62mod const_header {
63 use super::{HeaderName, HeaderValue};
64
65 pub(super) const PROTOCOL: HeaderName = HeaderName::from_static("protocol");
66
67 pub(super) const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
68 pub(super) const UPGRADE_VALUE: HeaderValue = HeaderValue::from_static("upgrade");
69 pub(super) const SEC_WEBSOCKET_VERSION_VALUE: HeaderValue = HeaderValue::from_static("13");
70}
71
72use const_header::*;
73
74impl From<HandshakeError> for Builder {
75 fn from(e: HandshakeError) -> Self {
76 match e {
77 HandshakeError::GetMethodRequired => Response::builder()
78 .status(StatusCode::METHOD_NOT_ALLOWED)
79 .header(ALLOW, "GET"),
80
81 _ => Response::builder().status(StatusCode::BAD_REQUEST),
82 }
83 }
84}
85
86pub fn client_request_from_uri<U, E>(uri: U, version: Version) -> Result<Request<()>, E>
90where
91 Uri: TryFrom<U, Error = E>,
92{
93 let uri = uri.try_into()?;
94
95 let mut req = Request::new(());
96 *req.uri_mut() = uri;
97 *req.version_mut() = version;
98
99 match version {
100 Version::HTTP_11 => {
101 req.headers_mut().insert(UPGRADE, WEBSOCKET);
102 req.headers_mut().insert(CONNECTION, UPGRADE_VALUE);
103
104 let input = rand::random::<[u8; 16]>();
106 let mut output = [0u8; 24];
107
108 #[allow(clippy::needless_borrow)] let n =
110 base64::engine::Engine::encode_slice(&base64::engine::general_purpose::STANDARD, input, &mut output)
111 .unwrap();
112 assert_eq!(n, output.len());
113
114 req.headers_mut()
115 .insert(SEC_WEBSOCKET_KEY, HeaderValue::from_bytes(&output).unwrap());
116 }
117 Version::HTTP_2 => {
118 *req.method_mut() = Method::CONNECT;
119 req.headers_mut().insert(PROTOCOL, WEBSOCKET);
120 }
121 _ => {}
122 }
123
124 req.headers_mut()
125 .insert(SEC_WEBSOCKET_VERSION, SEC_WEBSOCKET_VERSION_VALUE);
126
127 Ok(req)
128}
129
130pub fn handshake(method: &Method, headers: &HeaderMap) -> Result<Builder, HandshakeError> {
132 let key = verify_handshake(method, headers)?;
133 let builder = handshake_response(key);
134 Ok(builder)
135}
136
137pub fn handshake_h2(method: &Method, headers: &HeaderMap) -> Result<Builder, HandshakeError> {
139 if method != Method::CONNECT {
141 return Err(HandshakeError::ConnectMethodRequired);
142 }
143
144 ws_version_check(headers)?;
145
146 Ok(Response::builder().status(StatusCode::OK))
147}
148
149fn verify_handshake<'a>(method: &'a Method, headers: &'a HeaderMap) -> Result<&'a [u8], HandshakeError> {
151 if method != Method::GET {
153 return Err(HandshakeError::GetMethodRequired);
154 }
155
156 let has_upgrade_hd = headers
158 .get(UPGRADE)
159 .and_then(|hdr| hdr.to_str().ok())
160 .filter(|s| s.to_ascii_lowercase().contains("websocket"))
161 .is_some();
162
163 if !has_upgrade_hd {
164 return Err(HandshakeError::NoWebsocketUpgrade);
165 }
166
167 let has_connection_hd = headers
169 .get(CONNECTION)
170 .and_then(|hdr| hdr.to_str().ok())
171 .filter(|s| s.to_ascii_lowercase().contains("upgrade"))
172 .is_some();
173
174 if !has_connection_hd {
175 return Err(HandshakeError::NoConnectionUpgrade);
176 }
177
178 ws_version_check(headers)?;
179
180 let value = headers.get(SEC_WEBSOCKET_KEY).ok_or(HandshakeError::BadWebsocketKey)?;
182
183 Ok(value.as_bytes())
184}
185
186fn handshake_response(key: &[u8]) -> Builder {
190 let key = hash_key(key);
191
192 Response::builder()
193 .status(StatusCode::SWITCHING_PROTOCOLS)
194 .header(UPGRADE, WEBSOCKET)
195 .header(CONNECTION, UPGRADE_VALUE)
196 .header(
197 SEC_WEBSOCKET_ACCEPT,
198 HeaderValue::from_bytes(&key).unwrap(),
200 )
201}
202
203fn ws_version_check(headers: &HeaderMap) -> Result<(), HandshakeError> {
205 let value = headers
206 .get(SEC_WEBSOCKET_VERSION)
207 .ok_or(HandshakeError::NoVersionHeader)?;
208
209 if value != "13" && value != "8" && value != "7" {
210 Err(HandshakeError::UnsupportedVersion)
211 } else {
212 Ok(())
213 }
214}
215
/// Stream-based websocket types; compiled only with the `stream` cargo feature.
#[cfg(feature = "stream")]
pub mod stream;

/// Output of [`ws`]: the [`RequestStream`] wrapping the incoming body, the handshake response
/// whose body is the [`ResponseStream`] derived from that request stream, and the
/// [`ResponseSender`] paired with the response stream.
#[cfg(feature = "stream")]
pub type WsOutput<B> = (RequestStream<B>, Response<ResponseStream>, ResponseSender);

// Re-exported so callers of `ws` can name its output types without reaching into `stream`.
#[cfg(feature = "stream")]
pub use self::stream::{RequestStream, ResponseSender, ResponseStream, ResponseWeakSender, WsError};
224
/// Run the server-side websocket handshake for `req` and wire `body` up as the incoming
/// message stream.
///
/// HTTP/2 requests are validated with [`handshake_h2`]; every other version goes through
/// [`handshake`].
///
/// # Errors
/// Propagates the [`HandshakeError`] returned by the version-specific handshake check.
///
/// # Panics
/// Panics if the builder returned by the handshake cannot produce a response. That would
/// indicate a bug in this module.
#[cfg(feature = "stream")]
pub fn ws<ReqB, B, T, E>(req: &Request<ReqB>, body: B) -> Result<WsOutput<B>, HandshakeError>
where
    B: futures_core::Stream<Item = Result<T, E>>,
    T: AsRef<[u8]>,
{
    let (method, headers) = (req.method(), req.headers());
    let builder = if req.version() == Version::HTTP_2 {
        handshake_h2(method, headers)?
    } else {
        handshake(method, headers)?
    };

    let request_stream = RequestStream::new(body);
    let (response_body, sender) = request_stream.response_stream();

    let response = builder
        .body(response_body)
        .expect("handshake function failed to generate correct Response Builder");

    Ok((request_stream, response, sender))
}
292
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `verify_handshake` rejects `req` with exactly `expected`.
    fn assert_rejected(req: Request<()>, expected: HandshakeError) {
        let err = verify_handshake(req.method(), req.headers()).unwrap_err();
        assert_eq!(expected, err);
    }

    #[test]
    fn test_handshake() {
        let post = Request::builder().method(Method::POST).body(()).unwrap();
        assert_rejected(post, HandshakeError::GetMethodRequired);

        let bare = Request::builder().body(()).unwrap();
        assert_rejected(bare, HandshakeError::NoWebsocketUpgrade);

        let wrong_upgrade = Request::builder()
            .header(UPGRADE, HeaderValue::from_static("test"))
            .body(())
            .unwrap();
        assert_rejected(wrong_upgrade, HandshakeError::NoWebsocketUpgrade);

        let no_connection = Request::builder().header(UPGRADE, WEBSOCKET).body(()).unwrap();
        assert_rejected(no_connection, HandshakeError::NoConnectionUpgrade);

        // Request carrying both upgrade headers; later cases add version/key on top of it.
        let upgrade_headers = || {
            Request::builder()
                .header(UPGRADE, WEBSOCKET)
                .header(CONNECTION, UPGRADE_VALUE)
        };

        assert_rejected(upgrade_headers().body(()).unwrap(), HandshakeError::NoVersionHeader);

        let bad_version = upgrade_headers()
            .header(SEC_WEBSOCKET_VERSION, HeaderValue::from_static("5"))
            .body(())
            .unwrap();
        assert_rejected(bad_version, HandshakeError::UnsupportedVersion);

        let versioned =
            || upgrade_headers().header(SEC_WEBSOCKET_VERSION, SEC_WEBSOCKET_VERSION_VALUE);

        assert_rejected(versioned().body(()).unwrap(), HandshakeError::BadWebsocketKey);

        let accepted = versioned()
            .header(SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION_VALUE)
            .body(())
            .unwrap();
        let key = verify_handshake(accepted.method(), accepted.headers()).unwrap();
        let res = handshake_response(key).body(()).unwrap();
        assert_eq!(StatusCode::SWITCHING_PROTOCOLS, res.status());
    }

    #[test]
    fn test_ws_error_http_response() {
        let cases = [
            (HandshakeError::GetMethodRequired, StatusCode::METHOD_NOT_ALLOWED),
            (HandshakeError::NoWebsocketUpgrade, StatusCode::BAD_REQUEST),
            (HandshakeError::NoConnectionUpgrade, StatusCode::BAD_REQUEST),
            (HandshakeError::NoVersionHeader, StatusCode::BAD_REQUEST),
            (HandshakeError::UnsupportedVersion, StatusCode::BAD_REQUEST),
            (HandshakeError::BadWebsocketKey, StatusCode::BAD_REQUEST),
        ];

        for (err, expected_status) in cases {
            let res = Builder::from(err).body(()).unwrap();
            assert_eq!(res.status(), expected_status);
        }
    }
}