1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
use super::{WebSocket, LogWebSocketReturn};
use crate::{Result, Data, Response};
use crate::error::ClientErrorKind;
use crate::header::{
StatusCode, UPGRADE, SEC_WEBSOCKET_VERSION, SEC_WEBSOCKET_KEY, CONNECTION,
SEC_WEBSOCKET_ACCEPT
};
use crate::server::HyperRequest;
use std::mem::ManuallyDrop;
use std::any::{Any, TypeId};
use tracing::error;
use sha1::Digest;
use hyper::upgrade::OnUpgrade;
#[doc(hidden)]
pub use tokio::task::spawn;
/// Reports whether the type parameter `T` is exactly `WebSocket`.
///
/// The check compares `TypeId`s, so only the identical type matches.
fn is_ws<T: Any>() -> bool {
    let requested = TypeId::of::<T>();
    let websocket = TypeId::of::<WebSocket>();
    websocket == requested
}
/// Reports whether the type parameter `T` is exactly `Data`.
///
/// The check compares `TypeId`s, so only the identical type matches.
fn is_data<T: Any>() -> bool {
    TypeId::of::<Data>().eq(&TypeId::of::<T>())
}
/// Reports whether a `&T` can be served for the type parameter `T`.
///
/// Returns `true` when `T` is `WebSocket`, when `T` is `Data`, or when
/// `data.exists::<T>()` reports `true`. The checks run in that order and
/// stop at the first one that succeeds.
#[inline]
pub fn valid_ws_data_as_ref<T: Any>(data: &Data) -> bool {
    // The two built-in types are accepted without consulting `data`.
    if is_ws::<T>() || is_data::<T>() {
        return true;
    }

    data.exists::<T>()
}
#[inline]
pub fn valid_ws_data_as_owned<T: Any>(_: &Data) -> bool {
is_ws::<T>()
}
/// Returns a shared reference to the value of type `T`.
///
/// The source of the reference depends on `T`:
/// - `WebSocket`: the socket currently stored in `ws`,
/// - `Data`: the `data` container itself,
/// - anything else: the result of `data.get::<T>()`.
///
/// # Panics
///
/// Panics if `T` is `WebSocket` and `ws` is `None`, or if `T` is any other
/// type and `data.get::<T>()` returns `None`.
#[inline]
pub fn get_ws_data_as_ref<'a, T: Any>(
    data: &'a Data,
    ws: &'a mut Option<WebSocket>
) -> &'a T {
    if is_ws::<T>() {
        // `T` is `WebSocket` here, so the downcast only changes the static type.
        let socket: &dyn Any = ws.as_ref().unwrap();
        return socket.downcast_ref::<T>().unwrap();
    }

    if is_data::<T>() {
        // `T` is `Data` here, so the container itself is handed out.
        let container: &dyn Any = data;
        return container.downcast_ref::<T>().unwrap();
    }

    data.get::<T>().unwrap()
}
/// Takes the `WebSocket` out of `ws` and returns it as an owned `T`.
///
/// Only `T == WebSocket` is supported; `ws` is left as `None` afterwards.
///
/// # Panics
///
/// Panics if `T` is not `WebSocket` (that branch is `unreachable!()`), or if
/// `ws` is already `None`.
#[inline]
pub fn get_ws_data_as_owned<T: Any>(
    _data: &Data,
    ws: &mut Option<WebSocket>
) -> T {
    if !is_ws::<T>() {
        unreachable!();
    }

    let socket = ws.take().unwrap();
    // SAFETY: `is_ws::<T>()` returned true, so `T` and `WebSocket` have the
    // same `TypeId` and are the same type.
    unsafe { transform_websocket(socket) }
}
/// Reinterprets an owned `WebSocket` as a value of type `T`.
///
/// # Safety
///
/// `T` must be the same type as `WebSocket`: the socket's memory is read
/// back as a `T` without any conversion.
unsafe fn transform_websocket<T: Any>(ws: WebSocket) -> T {
    // Ownership moves into the returned `T`, so the original binding must
    // not run its destructor.
    let parked = ManuallyDrop::new(ws);
    let source: *const WebSocket = &*parked;
    (source as *const T).read()
}
#[doc(hidden)]
pub fn upgrade_error(e: hyper::Error) {
error!("websocket upgrade error {:?}", e);
}
/// Logs the given return value at `error` level when it asks to be logged.
///
/// Nothing is emitted unless `r.should_log_error()` returns `true`; the
/// value is rendered with its `Debug` representation.
#[doc(hidden)]
pub fn log_websocket_return(r: impl LogWebSocketReturn) {
    if !r.should_log_error() {
        return;
    }

    error!("websocket connection closed with error {:?}", r);
}
/// Validates a WebSocket opening handshake and prepares the upgrade.
///
/// The request must carry an `Upgrade` header equal to `websocket`
/// (compared ASCII case-insensitively), a `Sec-WebSocket-Version` header
/// equal to `13` and a `Sec-WebSocket-Key` header. A header whose value
/// cannot be read via `to_str()` is treated as missing.
///
/// On success returns the `OnUpgrade` future obtained from
/// `hyper::upgrade::on` together with the value for the
/// `Sec-WebSocket-Accept` response header: the base64 encoded SHA-1 digest
/// of the client key followed by the fixed handshake GUID.
///
/// # Errors
///
/// Returns `ClientErrorKind::BadRequest` if any of the three headers is
/// missing or does not have the expected value.
#[doc(hidden)]
pub fn upgrade(req: &mut HyperRequest) -> Result<(OnUpgrade, String)> {
    let headers = req.headers();

    // RFC 6455 section 4.2.1: the `Upgrade` value is matched ASCII
    // case-insensitively, so `WebSocket` or `WEBSOCKET` are valid too.
    let wants_websocket = headers
        .get(UPGRADE)
        .and_then(|v| v.to_str().ok())
        .map_or(false, |v| v.eq_ignore_ascii_case("websocket"));

    let is_version_13 = headers
        .get(SEC_WEBSOCKET_VERSION)
        .and_then(|v| v.to_str().ok())
        .map_or(false, |v| v == "13");

    // Extract the key and reject invalid handshakes in one step, so no
    // separate `unwrap()` is needed afterwards.
    let websocket_key = match headers.get(SEC_WEBSOCKET_KEY) {
        Some(key) if wants_websocket && is_version_13 => key.as_bytes(),
        _ => return Err(ClientErrorKind::BadRequest.into()),
    };

    let ws_accept = {
        let mut sha1 = sha1::Sha1::new();
        sha1.update(websocket_key);
        // Fixed GUID that every WebSocket server appends to the client key.
        sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
        base64::encode(sha1.finalize())
    };

    // The header borrows end above, so the request can be handed over
    // mutably here.
    let on_upgrade = hyper::upgrade::on(req);

    Ok((on_upgrade, ws_accept))
}
/// Builds the response that accepts a WebSocket upgrade.
///
/// The response has the status `StatusCode::SWITCHING_PROTOCOLS` and the
/// headers `CONNECTION: upgrade`, `UPGRADE: websocket` and
/// `SEC_WEBSOCKET_ACCEPT` set to `ws_accept`.
///
/// `ws_accept` is the accept value to send back, for example the string
/// returned by `upgrade`.
#[doc(hidden)]
pub fn switching_protocols(ws_accept: String) -> Response {
Response::builder()
.status_code(StatusCode::SWITCHING_PROTOCOLS)
.header(CONNECTION, "upgrade")
.header(UPGRADE, "websocket")
// Accept value derived from the client's handshake key.
.header(SEC_WEBSOCKET_ACCEPT, ws_accept)
.build()
}