1use crate::{WebSocketError, WebSocketStream, WsHeartbeatConfig};
4use http::{header, Response, StatusCode};
5use hyper::upgrade::OnUpgrade;
6use hyper_util::rt::TokioIo;
7use rustapi_core::{IntoResponse, ResponseBody};
8use rustapi_openapi::{Operation, ResponseModifier, ResponseSpec};
9use std::collections::BTreeMap;
10use std::future::Future;
11use std::pin::Pin;
12use tokio_tungstenite::tungstenite::protocol::Role;
13
/// Boxed, one-shot handler invoked with the established [`WebSocketStream`].
///
/// The handler returns a pinned, boxed future resolving to `()`; both the
/// handler and the future it returns must be `Send`.
type UpgradeCallback =
    Box<dyn FnOnce(WebSocketStream) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send>;
17
18use crate::compression::WsCompressionConfig;
23
/// A pending WebSocket upgrade: the handshake response plus everything needed
/// to run the user's handler once the connection has been upgraded.
pub struct WebSocketUpgrade {
    /// Handshake response handed back to the client; the builder methods add
    /// optional headers to it.
    response: Response<ResponseBody>,
    /// Handler registered through `on_upgrade`, if any.
    on_upgrade: Option<UpgradeCallback>,
    // Stored by `new` but not read by any code visible in this file, hence the
    // lint suppression.
    #[allow(dead_code)]
    /// Key the `Sec-WebSocket-Accept` header was derived from.
    sec_key: String,
    /// Extension offer inspected by `compress`; presumably the client's
    /// `Sec-WebSocket-Extensions` header value (confirm at the call site).
    client_extensions: Option<String>,
    /// Compression settings recorded by `compress`.
    compression: Option<WsCompressionConfig>,
    /// Heartbeat settings recorded by `heartbeat`; when present the socket is
    /// built with `WebSocketStream::new_managed`.
    heartbeat: Option<WsHeartbeatConfig>,
    /// Hyper upgrade future awaited after the response is produced.
    on_upgrade_fut: Option<OnUpgrade>,
}
45
46impl WebSocketUpgrade {
47 pub(crate) fn new(
49 sec_key: String,
50 client_extensions: Option<String>,
51 on_upgrade_fut: Option<OnUpgrade>,
52 ) -> Self {
53 let accept_key = generate_accept_key(&sec_key);
55
56 let response = Response::builder()
58 .status(StatusCode::SWITCHING_PROTOCOLS)
59 .header(header::UPGRADE, "websocket")
60 .header(header::CONNECTION, "Upgrade")
61 .header("Sec-WebSocket-Accept", accept_key)
62 .body(ResponseBody::empty())
63 .unwrap();
64
65 Self {
66 response,
67 on_upgrade: None,
68 sec_key,
69 client_extensions,
70 compression: None,
71 heartbeat: None,
72 on_upgrade_fut,
73 }
74 }
75
76 pub fn heartbeat(mut self, config: WsHeartbeatConfig) -> Self {
78 self.heartbeat = Some(config);
79 self
80 }
81
82 pub fn compress(mut self, config: WsCompressionConfig) -> Self {
84 self.compression = Some(config);
85
86 if let Some(exts) = &self.client_extensions {
88 if exts.contains("permessage-deflate") {
89 let mut header_val = String::from("permessage-deflate");
92
93 header_val.push_str("; server_no_context_takeover");
96 header_val.push_str("; client_no_context_takeover");
97
98 if config.window_bits < 15 {
99 header_val
100 .push_str(&format!("; server_max_window_bits={}", config.window_bits));
101 }
102 if config.client_window_bits < 15 {
103 header_val.push_str(&format!(
104 "; client_max_window_bits={}",
105 config.client_window_bits
106 ));
107 }
108
109 if let Ok(val) = header::HeaderValue::from_str(&header_val) {
110 self.response
111 .headers_mut()
112 .insert("Sec-WebSocket-Extensions", val);
113 }
114 }
115 }
116 self
117 }
118
119 pub fn on_upgrade<F, Fut>(mut self, callback: F) -> Self
132 where
133 F: FnOnce(WebSocketStream) -> Fut + Send + 'static,
134 Fut: Future<Output = ()> + Send + 'static,
135 {
136 self.on_upgrade = Some(Box::new(move |stream| Box::pin(callback(stream))));
137 self
138 }
139
140 pub fn protocol(mut self, protocol: &str) -> Self {
142 self.response.headers_mut().insert(
145 "Sec-WebSocket-Protocol",
146 header::HeaderValue::from_str(protocol).unwrap(),
147 );
148 self
149 }
150
151 #[allow(dead_code)]
153 pub(crate) fn into_response_inner(self) -> Response<ResponseBody> {
154 self.response
155 }
156
157 #[allow(dead_code)]
159 pub(crate) fn take_callback(&mut self) -> Option<UpgradeCallback> {
160 self.on_upgrade.take()
161 }
162}
163
164impl IntoResponse for WebSocketUpgrade {
165 fn into_response(mut self) -> rustapi_core::Response {
166 if let (Some(on_upgrade), Some(callback)) =
168 (self.on_upgrade_fut.take(), self.on_upgrade.take())
169 {
170 let heartbeat = self.heartbeat;
171
172 tokio::spawn(async move {
176 match on_upgrade.await {
177 Ok(upgraded) => {
178 let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
179 TokioIo::new(upgraded),
180 Role::Server,
181 None,
182 )
183 .await;
184
185 let socket = if let Some(hb_config) = heartbeat {
186 WebSocketStream::new_managed(ws_stream, hb_config)
187 } else {
188 WebSocketStream::new(ws_stream)
189 };
190
191 callback(socket).await;
192 }
193 Err(e) => {
194 tracing::error!("WebSocket upgrade failed: {:?}", e);
195 if let Some(source) = std::error::Error::source(&e) {
197 tracing::error!("Cause: {:?}", source);
198 }
199 }
200 }
201 });
202 }
203
204 self.response
205 }
206}
207
208impl ResponseModifier for WebSocketUpgrade {
209 fn update_response(op: &mut Operation) {
210 op.responses.insert(
211 "101".to_string(),
212 ResponseSpec {
213 description: "WebSocket upgrade successful".to_string(),
214 content: BTreeMap::new(),
215 headers: BTreeMap::new(),
216 },
217 );
218 }
219}
220
221fn generate_accept_key(key: &str) -> String {
223 use base64::Engine;
224 use sha1::{Digest, Sha1};
225
226 const GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
227
228 let mut hasher = Sha1::new();
229 hasher.update(key.as_bytes());
230 hasher.update(GUID.as_bytes());
231 let hash = hasher.finalize();
232
233 base64::engine::general_purpose::STANDARD.encode(hash)
234}
235
236pub(crate) fn validate_upgrade_request(
238 method: &http::Method,
239 headers: &http::HeaderMap,
240) -> Result<String, WebSocketError> {
241 if method != http::Method::GET {
243 return Err(WebSocketError::invalid_upgrade("Method must be GET"));
244 }
245
246 let upgrade = headers
248 .get(header::UPGRADE)
249 .and_then(|v| v.to_str().ok())
250 .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Upgrade header"))?;
251
252 if !upgrade.eq_ignore_ascii_case("websocket") {
253 return Err(WebSocketError::invalid_upgrade(
254 "Upgrade header must be 'websocket'",
255 ));
256 }
257
258 let connection = headers
260 .get(header::CONNECTION)
261 .and_then(|v| v.to_str().ok())
262 .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Connection header"))?;
263
264 let has_upgrade = connection
265 .split(',')
266 .any(|s| s.trim().eq_ignore_ascii_case("upgrade"));
267
268 if !has_upgrade {
269 return Err(WebSocketError::invalid_upgrade(
270 "Connection header must contain 'Upgrade'",
271 ));
272 }
273
274 let sec_key = headers
276 .get("Sec-WebSocket-Key")
277 .and_then(|v| v.to_str().ok())
278 .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Sec-WebSocket-Key header"))?;
279
280 let version = headers
282 .get("Sec-WebSocket-Version")
283 .and_then(|v| v.to_str().ok())
284 .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Sec-WebSocket-Version header"))?;
285
286 if version != "13" {
287 return Err(WebSocketError::invalid_upgrade(
288 "Sec-WebSocket-Version must be 13",
289 ));
290 }
291
292 Ok(sec_key.to_string())
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the accept key for a known client key (the sample handshake pair
    /// given in RFC 6455).
    #[test]
    fn test_accept_key_generation() {
        let client_key = "dGhlIHNhbXBsZSBub25jZQ==";
        let expected = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";

        assert_eq!(generate_accept_key(client_key), expected);
    }
}