fire_http/ws/
util.rs

1use std::net::SocketAddr;
2
3use super::LogWebSocketReturn;
4use crate::error::ClientErrorKind;
5use crate::extractor::ExtractorError;
6use crate::header::{
7	StatusCode, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY,
8	SEC_WEBSOCKET_VERSION, UPGRADE,
9};
10use crate::server::HyperRequest;
11use crate::util::convert_hyper_req_to_fire_header;
12use crate::{Error, Response, Result};
13
14use tracing::error;
15
16use sha1::Digest;
17
18use hyper::upgrade::OnUpgrade;
19
20#[doc(hidden)]
21pub use tokio::task::spawn;
22
23use base64::prelude::{Engine as _, BASE64_STANDARD};
24use types::header::RequestHeader;
25
26/// we need to expose this instead of inlining it in the macro since
27/// tracing logs the crate name and we wan't it to be associated with
28/// fire http instead of the crate that uses the macro
29#[doc(hidden)]
30pub fn upgrade_error(e: hyper::Error) {
31	error!("websocket upgrade error {:?}", e);
32}
33
34/// we need to expose this instead of inlining it in the macro since
35/// tracing logs the crate name and we wan't it to be associated with
36/// fire http instead of the crate that uses the macro
37#[doc(hidden)]
38pub fn log_websocket_return(r: impl LogWebSocketReturn) {
39	if r.should_log_error() {
40		error!("websocket connection closed with error {:?}", r);
41	}
42}
43
44/// we need to expose this instead of inlining it in the macro since
45/// tracing logs the crate name and we wan't it to be associated with
46/// fire http instead of the crate that uses the macro
47#[doc(hidden)]
48pub fn log_extractor_error(r: impl ExtractorError) {
49	let err = r.into_std();
50
51	error!("websocket extractor error: {}", err);
52}
53
54// does the key need to be a specific length?
55#[doc(hidden)]
56pub fn upgrade(req: &mut HyperRequest) -> Result<(OnUpgrade, String)> {
57	// if headers not match for websocket
58	// return bad request
59	let header_upgrade =
60		req.headers().get(UPGRADE).and_then(|v| v.to_str().ok());
61	let header_version = req
62		.headers()
63		.get(SEC_WEBSOCKET_VERSION)
64		.and_then(|v| v.to_str().ok());
65	let websocket_key =
66		req.headers().get(SEC_WEBSOCKET_KEY).map(|v| v.as_bytes());
67
68	if !matches!(
69		(header_upgrade, header_version, websocket_key),
70		(Some("websocket"), Some("13"), Some(_))
71	) {
72		return Err(ClientErrorKind::BadRequest.into());
73	}
74
75	// calculate websocket key stuff
76	// unwrap does not fail because we check above
77	let websocket_key = websocket_key.unwrap();
78	let ws_accept = {
79		let mut sha1 = sha1::Sha1::new();
80		sha1.update(websocket_key);
81		sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
82		// cannot fail because
83		BASE64_STANDARD.encode(sha1.finalize())
84	};
85
86	let on_upgrade = hyper::upgrade::on(req);
87
88	Ok((on_upgrade, ws_accept))
89}
90
91#[doc(hidden)]
92pub fn switching_protocols(ws_accept: String) -> Response {
93	Response::builder()
94		.status_code(StatusCode::SWITCHING_PROTOCOLS)
95		.header(CONNECTION, "upgrade")
96		.header(UPGRADE, "websocket")
97		.header(SEC_WEBSOCKET_ACCEPT, ws_accept)
98		.build()
99}
100
101#[doc(hidden)]
102pub fn hyper_req_to_header(
103	req: &mut HyperRequest,
104	address: SocketAddr,
105) -> Result<RequestHeader> {
106	convert_hyper_req_to_fire_header(req, address).map_err(|e| {
107		Error::new(
108			ClientErrorKind::BadRequest,
109			format!("failed to convert hyper request {:?}", e),
110		)
111	})
112}