1extern crate alloc;
37
38use core::ops::Deref;
39
40use http::{
41 header::{
42 HeaderMap, HeaderValue, ALLOW, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION,
43 UPGRADE,
44 },
45 request::Request,
46 response::{Builder, Response},
47 uri::Uri,
48 Method, StatusCode, Version,
49};
50
51mod codec;
52mod error;
53mod frame;
54mod mask;
55mod proto;
56
57pub use self::{
58 codec::{Codec, Item, Message},
59 error::{HandshakeError, ProtocolError},
60 proto::{hash_key, CloseCode, CloseReason, OpCode},
61};
62
63#[allow(clippy::declare_interior_mutable_const)]
64mod const_header {
65 use super::HeaderValue;
66
67 pub(super) const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
68 pub(super) const UPGRADE_VALUE: HeaderValue = HeaderValue::from_static("upgrade");
69 pub(super) const SEC_WEBSOCKET_VERSION_VALUE: HeaderValue = HeaderValue::from_static("13");
70}
71
72use const_header::*;
73
74impl From<HandshakeError> for Builder {
75 fn from(e: HandshakeError) -> Self {
76 match e {
77 HandshakeError::GetMethodRequired => Response::builder()
78 .status(StatusCode::METHOD_NOT_ALLOWED)
79 .header(ALLOW, "GET"),
80
81 _ => Response::builder().status(StatusCode::BAD_REQUEST),
82 }
83 }
84}
85
86pub fn client_request_from_uri(uri: Uri, version: Version) -> Request<()> {
91 let mut req = Request::new(());
92 *req.uri_mut() = uri;
93 *req.version_mut() = version;
94
95 client_request_extend(&mut req);
96
97 req
98}
99
100pub fn client_request_extend<B>(req: &mut Request<B>) {
110 match req.version() {
111 Version::HTTP_11 => {
112 req.headers_mut().insert(UPGRADE, WEBSOCKET);
113 req.headers_mut().insert(CONNECTION, UPGRADE_VALUE);
114
115 let input = rand::random::<[u8; 16]>();
117 let mut output = [0u8; 24];
118
119 #[allow(clippy::needless_borrow)] let n =
121 base64::engine::Engine::encode_slice(&base64::engine::general_purpose::STANDARD, input, &mut output)
122 .unwrap();
123 assert_eq!(n, output.len());
124
125 req.headers_mut()
126 .insert(SEC_WEBSOCKET_KEY, HeaderValue::from_bytes(&output).unwrap());
127 }
128 Version::HTTP_2 => {
129 *req.method_mut() = Method::CONNECT;
130 req.extensions_mut().insert(Http2WsProtocol::new());
131 }
132 _ => {}
133 }
134
135 req.headers_mut()
136 .insert(SEC_WEBSOCKET_VERSION, SEC_WEBSOCKET_VERSION_VALUE);
137}
138
/// Marker value stored in the request extensions of an HTTP/2 WebSocket client request.
///
/// It wraps the static string `"websocket"` and exposes it through [`AsRef<str>`] and
/// [`Deref`]. Presumably consumers map it to the HTTP/2 `:protocol` pseudo-header; that
/// mapping happens outside this file and should be confirmed against the client code.
#[derive(Clone, Debug)]
pub struct Http2WsProtocol(&'static str);

impl AsRef<str> for Http2WsProtocol {
    /// Returns the wrapped protocol name.
    fn as_ref(&self) -> &str {
        self.0
    }
}

impl Deref for Http2WsProtocol {
    type Target = str;

    /// Dereferences to the wrapped protocol name.
    fn deref(&self) -> &Self::Target {
        self.0
    }
}

impl Http2WsProtocol {
    /// Creates the marker holding the `"websocket"` protocol name.
    const fn new() -> Self {
        Self("websocket")
    }
}
161
162pub fn handshake(method: &Method, headers: &HeaderMap) -> Result<Builder, HandshakeError> {
164 let key = verify_handshake(method, headers)?;
165 let builder = handshake_response(key);
166 Ok(builder)
167}
168
169pub fn handshake_h2(method: &Method, headers: &HeaderMap) -> Result<Builder, HandshakeError> {
179 if method != Method::CONNECT {
181 return Err(HandshakeError::ConnectMethodRequired);
182 }
183
184 ws_version_check(headers)?;
185
186 Ok(Response::builder().status(StatusCode::OK))
187}
188
189fn verify_handshake<'a>(method: &'a Method, headers: &'a HeaderMap) -> Result<&'a [u8], HandshakeError> {
191 if method != Method::GET {
193 return Err(HandshakeError::GetMethodRequired);
194 }
195
196 let has_upgrade_hd = headers
198 .get(UPGRADE)
199 .and_then(|hdr| hdr.to_str().ok())
200 .filter(|s| s.to_ascii_lowercase().contains("websocket"))
201 .is_some();
202
203 if !has_upgrade_hd {
204 return Err(HandshakeError::NoWebsocketUpgrade);
205 }
206
207 let has_connection_hd = headers
209 .get(CONNECTION)
210 .and_then(|hdr| hdr.to_str().ok())
211 .filter(|s| s.to_ascii_lowercase().contains("upgrade"))
212 .is_some();
213
214 if !has_connection_hd {
215 return Err(HandshakeError::NoConnectionUpgrade);
216 }
217
218 ws_version_check(headers)?;
219
220 let value = headers.get(SEC_WEBSOCKET_KEY).ok_or(HandshakeError::BadWebsocketKey)?;
222
223 Ok(value.as_bytes())
224}
225
226fn handshake_response(key: &[u8]) -> Builder {
230 let key = hash_key(key);
231
232 Response::builder()
233 .status(StatusCode::SWITCHING_PROTOCOLS)
234 .header(UPGRADE, WEBSOCKET)
235 .header(CONNECTION, UPGRADE_VALUE)
236 .header(
237 SEC_WEBSOCKET_ACCEPT,
238 HeaderValue::from_bytes(&key).unwrap(),
240 )
241}
242
243fn ws_version_check(headers: &HeaderMap) -> Result<(), HandshakeError> {
245 let value = headers
246 .get(SEC_WEBSOCKET_VERSION)
247 .ok_or(HandshakeError::NoVersionHeader)?;
248
249 if value != "13" && value != "8" && value != "7" {
250 Err(HandshakeError::UnsupportedVersion)
251 } else {
252 Ok(())
253 }
254}
255
256#[cfg(feature = "stream")]
257pub mod stream;
258
259#[cfg(feature = "stream")]
260pub use self::stream::{RequestStream, ResponseSender, ResponseStream, ResponseWeakSender, WsError};
261
/// Successful result of [`ws`]: the request-side stream wrapping a body of type `B`, the HTTP
/// response whose body is the response stream, and the sender half returned alongside it.
#[cfg(feature = "stream")]
pub type WsOutput<B> = (RequestStream<B>, Response<ResponseStream>, ResponseSender);
264
/// Runs the server handshake for `req` and wires `body` into WebSocket streams.
///
/// HTTP/2 requests go through [`handshake_h2`]; every other version goes through
/// [`handshake`]. On success the request body is wrapped in a [`RequestStream`], a response
/// stream plus its sender are obtained from it, and the response stream becomes the body of
/// the handshake response.
///
/// # Errors
///
/// Returns the [`HandshakeError`] reported by the selected handshake function.
///
/// # Panics
///
/// Panics if the builder produced by the handshake cannot be turned into a response.
#[cfg(feature = "stream")]
pub fn ws<ReqB, B, T, E>(req: &Request<ReqB>, body: B) -> Result<WsOutput<B>, HandshakeError>
where
    B: futures_core::Stream<Item = Result<T, E>>,
    T: AsRef<[u8]>,
{
    let outcome = if req.version() == Version::HTTP_2 {
        handshake_h2(req.method(), req.headers())
    } else {
        handshake(req.method(), req.headers())
    };
    let builder = outcome?;

    let request_stream = RequestStream::new(body);
    let (response_stream, sender) = request_stream.response_stream();

    let response = builder
        .body(response_stream)
        .expect("handshake function failed to generate correct Response Builder");

    Ok((request_stream, response, sender))
}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `verify_handshake` on `req` and returns the error it must produce.
    fn handshake_err(req: &Request<()>) -> HandshakeError {
        verify_handshake(req.method(), req.headers()).unwrap_err()
    }

    #[test]
    fn test_handshake() {
        // Wrong method.
        let post = Request::builder().method(Method::POST).body(()).unwrap();
        assert_eq!(HandshakeError::GetMethodRequired, handshake_err(&post));

        // No `Upgrade` header at all.
        let bare = Request::builder().body(()).unwrap();
        assert_eq!(HandshakeError::NoWebsocketUpgrade, handshake_err(&bare));

        // `Upgrade` header with an unrelated value.
        let wrong_upgrade = Request::builder()
            .header(UPGRADE, HeaderValue::from_static("test"))
            .body(())
            .unwrap();
        assert_eq!(HandshakeError::NoWebsocketUpgrade, handshake_err(&wrong_upgrade));

        // Correct `Upgrade`, missing `Connection`.
        let no_connection = Request::builder().header(UPGRADE, WEBSOCKET).body(()).unwrap();
        assert_eq!(HandshakeError::NoConnectionUpgrade, handshake_err(&no_connection));

        // Correct `Upgrade` and `Connection`, missing version.
        let no_version = Request::builder()
            .header(UPGRADE, WEBSOCKET)
            .header(CONNECTION, UPGRADE_VALUE)
            .body(())
            .unwrap();
        assert_eq!(HandshakeError::NoVersionHeader, handshake_err(&no_version));

        // Unsupported version value.
        let bad_version = Request::builder()
            .header(UPGRADE, WEBSOCKET)
            .header(CONNECTION, UPGRADE_VALUE)
            .header(SEC_WEBSOCKET_VERSION, HeaderValue::from_static("5"))
            .body(())
            .unwrap();
        assert_eq!(HandshakeError::UnsupportedVersion, handshake_err(&bad_version));

        // Everything but the key is in place from here on.
        let upgrade_request = || {
            Request::builder()
                .header(UPGRADE, WEBSOCKET)
                .header(CONNECTION, UPGRADE_VALUE)
                .header(SEC_WEBSOCKET_VERSION, SEC_WEBSOCKET_VERSION_VALUE)
        };

        let no_key = upgrade_request().body(()).unwrap();
        assert_eq!(HandshakeError::BadWebsocketKey, handshake_err(&no_key));

        let complete = upgrade_request()
            .header(SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION_VALUE)
            .body(())
            .unwrap();
        let key = verify_handshake(complete.method(), complete.headers()).unwrap();
        let response = handshake_response(key).body(()).unwrap();
        assert_eq!(StatusCode::SWITCHING_PROTOCOLS, response.status());
    }

    #[test]
    fn test_ws_error_http_response() {
        let cases = [
            (HandshakeError::GetMethodRequired, StatusCode::METHOD_NOT_ALLOWED),
            (HandshakeError::NoWebsocketUpgrade, StatusCode::BAD_REQUEST),
            (HandshakeError::NoConnectionUpgrade, StatusCode::BAD_REQUEST),
            (HandshakeError::NoVersionHeader, StatusCode::BAD_REQUEST),
            (HandshakeError::UnsupportedVersion, StatusCode::BAD_REQUEST),
            (HandshakeError::BadWebsocketKey, StatusCode::BAD_REQUEST),
        ];

        for (error, expected) in cases {
            let res = Builder::from(error).body(()).unwrap();
            assert_eq!(res.status(), expected);
        }
    }
}