1use std::net::SocketAddr;
2
3use super::LogWebSocketReturn;
4use crate::error::ClientErrorKind;
5use crate::extractor::ExtractorError;
6use crate::header::{
7 StatusCode, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY,
8 SEC_WEBSOCKET_VERSION, UPGRADE,
9};
10use crate::server::HyperRequest;
11use crate::util::convert_hyper_req_to_fire_header;
12use crate::{Error, Response, Result};
13
14use tracing::error;
15
16use sha1::Digest;
17
18use hyper::upgrade::OnUpgrade;
19
20#[doc(hidden)]
21pub use tokio::task::spawn;
22
23use base64::prelude::{Engine as _, BASE64_STANDARD};
24use types::header::RequestHeader;
25
26#[doc(hidden)]
30pub fn upgrade_error(e: hyper::Error) {
31 error!("websocket upgrade error {:?}", e);
32}
33
34#[doc(hidden)]
38pub fn log_websocket_return(r: impl LogWebSocketReturn) {
39 if r.should_log_error() {
40 error!("websocket connection closed with error {:?}", r);
41 }
42}
43
44#[doc(hidden)]
48pub fn log_extractor_error(r: impl ExtractorError) {
49 let err = r.into_std();
50
51 error!("websocket extractor error: {}", err);
52}
53
54#[doc(hidden)]
56pub fn upgrade(req: &mut HyperRequest) -> Result<(OnUpgrade, String)> {
57 let header_upgrade =
60 req.headers().get(UPGRADE).and_then(|v| v.to_str().ok());
61 let header_version = req
62 .headers()
63 .get(SEC_WEBSOCKET_VERSION)
64 .and_then(|v| v.to_str().ok());
65 let websocket_key =
66 req.headers().get(SEC_WEBSOCKET_KEY).map(|v| v.as_bytes());
67
68 if !matches!(
69 (header_upgrade, header_version, websocket_key),
70 (Some("websocket"), Some("13"), Some(_))
71 ) {
72 return Err(ClientErrorKind::BadRequest.into());
73 }
74
75 let websocket_key = websocket_key.unwrap();
78 let ws_accept = {
79 let mut sha1 = sha1::Sha1::new();
80 sha1.update(websocket_key);
81 sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
82 BASE64_STANDARD.encode(sha1.finalize())
84 };
85
86 let on_upgrade = hyper::upgrade::on(req);
87
88 Ok((on_upgrade, ws_accept))
89}
90
91#[doc(hidden)]
92pub fn switching_protocols(ws_accept: String) -> Response {
93 Response::builder()
94 .status_code(StatusCode::SWITCHING_PROTOCOLS)
95 .header(CONNECTION, "upgrade")
96 .header(UPGRADE, "websocket")
97 .header(SEC_WEBSOCKET_ACCEPT, ws_accept)
98 .build()
99}
100
101#[doc(hidden)]
102pub fn hyper_req_to_header(
103 req: &mut HyperRequest,
104 address: SocketAddr,
105) -> Result<RequestHeader> {
106 convert_hyper_req_to_fire_header(req, address).map_err(|e| {
107 Error::new(
108 ClientErrorKind::BadRequest,
109 format!("failed to convert hyper request {:?}", e),
110 )
111 })
112}