1use crate::{WebSocketError, WebSocketStream, WsHeartbeatConfig};
4use http::{header, Response, StatusCode};
5use hyper::upgrade::OnUpgrade;
6use hyper_util::rt::TokioIo;
7use rustapi_core::{IntoResponse, ResponseBody};
8use rustapi_openapi::{Operation, ResponseModifier, ResponseSpec};
9use std::future::Future;
10use std::pin::Pin;
11use tokio_tungstenite::tungstenite::protocol::Role;
12
13type UpgradeCallback =
15 Box<dyn FnOnce(WebSocketStream) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send>;
16
17use crate::compression::WsCompressionConfig;
22
23pub struct WebSocketUpgrade {
28 response: Response<ResponseBody>,
30 on_upgrade: Option<UpgradeCallback>,
32 #[allow(dead_code)]
34 sec_key: String,
35 client_extensions: Option<String>,
37 compression: Option<WsCompressionConfig>,
39 heartbeat: Option<WsHeartbeatConfig>,
41 on_upgrade_fut: Option<OnUpgrade>,
43}
44
45impl WebSocketUpgrade {
46 pub(crate) fn new(
48 sec_key: String,
49 client_extensions: Option<String>,
50 on_upgrade_fut: Option<OnUpgrade>,
51 ) -> Self {
52 let accept_key = generate_accept_key(&sec_key);
54
55 let response = Response::builder()
57 .status(StatusCode::SWITCHING_PROTOCOLS)
58 .header(header::UPGRADE, "websocket")
59 .header(header::CONNECTION, "Upgrade")
60 .header("Sec-WebSocket-Accept", accept_key)
61 .body(ResponseBody::empty())
62 .unwrap();
63
64 Self {
65 response,
66 on_upgrade: None,
67 sec_key,
68 client_extensions,
69 compression: None,
70 heartbeat: None,
71 on_upgrade_fut,
72 }
73 }
74
75 pub fn heartbeat(mut self, config: WsHeartbeatConfig) -> Self {
77 self.heartbeat = Some(config);
78 self
79 }
80
81 pub fn compress(mut self, config: WsCompressionConfig) -> Self {
83 self.compression = Some(config);
84
85 if let Some(exts) = &self.client_extensions {
87 if exts.contains("permessage-deflate") {
88 let mut header_val = String::from("permessage-deflate");
91
92 header_val.push_str("; server_no_context_takeover");
95 header_val.push_str("; client_no_context_takeover");
96
97 if config.window_bits < 15 {
98 header_val
99 .push_str(&format!("; server_max_window_bits={}", config.window_bits));
100 }
101 if config.client_window_bits < 15 {
102 header_val.push_str(&format!(
103 "; client_max_window_bits={}",
104 config.client_window_bits
105 ));
106 }
107
108 if let Ok(val) = header::HeaderValue::from_str(&header_val) {
109 self.response
110 .headers_mut()
111 .insert("Sec-WebSocket-Extensions", val);
112 }
113 }
114 }
115 self
116 }
117
118 pub fn on_upgrade<F, Fut>(mut self, callback: F) -> Self
131 where
132 F: FnOnce(WebSocketStream) -> Fut + Send + 'static,
133 Fut: Future<Output = ()> + Send + 'static,
134 {
135 self.on_upgrade = Some(Box::new(move |stream| Box::pin(callback(stream))));
136 self
137 }
138
139 pub fn protocol(mut self, protocol: &str) -> Self {
141 self.response.headers_mut().insert(
144 "Sec-WebSocket-Protocol",
145 header::HeaderValue::from_str(protocol).unwrap(),
146 );
147 self
148 }
149
150 #[allow(dead_code)]
152 pub(crate) fn into_response_inner(self) -> Response<ResponseBody> {
153 self.response
154 }
155
156 #[allow(dead_code)]
158 pub(crate) fn take_callback(&mut self) -> Option<UpgradeCallback> {
159 self.on_upgrade.take()
160 }
161}
162
163impl IntoResponse for WebSocketUpgrade {
164 fn into_response(mut self) -> rustapi_core::Response {
165 if let (Some(on_upgrade), Some(callback)) =
167 (self.on_upgrade_fut.take(), self.on_upgrade.take())
168 {
169 let heartbeat = self.heartbeat;
170
171 tokio::spawn(async move {
175 match on_upgrade.await {
176 Ok(upgraded) => {
177 let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
178 TokioIo::new(upgraded),
179 Role::Server,
180 None,
181 )
182 .await;
183
184 let socket = if let Some(hb_config) = heartbeat {
185 WebSocketStream::new_managed(ws_stream, hb_config)
186 } else {
187 WebSocketStream::new(ws_stream)
188 };
189
190 callback(socket).await;
191 }
192 Err(e) => {
193 tracing::error!("WebSocket upgrade failed: {:?}", e);
194 if let Some(source) = std::error::Error::source(&e) {
196 tracing::error!("Cause: {:?}", source);
197 }
198 }
199 }
200 });
201 }
202
203 self.response
204 }
205}
206
207impl ResponseModifier for WebSocketUpgrade {
208 fn update_response(op: &mut Operation) {
209 op.responses.insert(
210 "101".to_string(),
211 ResponseSpec {
212 description: "WebSocket upgrade successful".to_string(),
213 content: None,
214 },
215 );
216 }
217}
218
219fn generate_accept_key(key: &str) -> String {
221 use base64::Engine;
222 use sha1::{Digest, Sha1};
223
224 const GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
225
226 let mut hasher = Sha1::new();
227 hasher.update(key.as_bytes());
228 hasher.update(GUID.as_bytes());
229 let hash = hasher.finalize();
230
231 base64::engine::general_purpose::STANDARD.encode(hash)
232}
233
234pub(crate) fn validate_upgrade_request(
236 method: &http::Method,
237 headers: &http::HeaderMap,
238) -> Result<String, WebSocketError> {
239 if method != http::Method::GET {
241 return Err(WebSocketError::invalid_upgrade("Method must be GET"));
242 }
243
244 let upgrade = headers
246 .get(header::UPGRADE)
247 .and_then(|v| v.to_str().ok())
248 .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Upgrade header"))?;
249
250 if !upgrade.eq_ignore_ascii_case("websocket") {
251 return Err(WebSocketError::invalid_upgrade(
252 "Upgrade header must be 'websocket'",
253 ));
254 }
255
256 let connection = headers
258 .get(header::CONNECTION)
259 .and_then(|v| v.to_str().ok())
260 .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Connection header"))?;
261
262 let has_upgrade = connection
263 .split(',')
264 .any(|s| s.trim().eq_ignore_ascii_case("upgrade"));
265
266 if !has_upgrade {
267 return Err(WebSocketError::invalid_upgrade(
268 "Connection header must contain 'Upgrade'",
269 ));
270 }
271
272 let sec_key = headers
274 .get("Sec-WebSocket-Key")
275 .and_then(|v| v.to_str().ok())
276 .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Sec-WebSocket-Key header"))?;
277
278 let version = headers
280 .get("Sec-WebSocket-Version")
281 .and_then(|v| v.to_str().ok())
282 .ok_or_else(|| WebSocketError::invalid_upgrade("Missing Sec-WebSocket-Version header"))?;
283
284 if version != "13" {
285 return Err(WebSocketError::invalid_upgrade(
286 "Sec-WebSocket-Version must be 13",
287 ));
288 }
289
290 Ok(sec_key.to_string())
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample nonce from RFC 6455 §1.3.
    const SAMPLE_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";

    /// Builds the headers of a minimal valid opening handshake.
    fn handshake_headers() -> http::HeaderMap {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            http::header::UPGRADE,
            http::HeaderValue::from_static("websocket"),
        );
        headers.insert(
            http::header::CONNECTION,
            http::HeaderValue::from_static("Upgrade"),
        );
        headers.insert(
            http::header::SEC_WEBSOCKET_KEY,
            http::HeaderValue::from_static(SAMPLE_KEY),
        );
        headers.insert(
            http::header::SEC_WEBSOCKET_VERSION,
            http::HeaderValue::from_static("13"),
        );
        headers
    }

    #[test]
    fn test_accept_key_generation() {
        // Expected value taken from RFC 6455 §1.3.
        let accept = generate_accept_key(SAMPLE_KEY);
        assert_eq!(accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn accepts_rfc_sample_handshake() {
        let key = validate_upgrade_request(&http::Method::GET, &handshake_headers()).ok();
        assert_eq!(key.as_deref(), Some(SAMPLE_KEY));
    }

    #[test]
    fn accepts_connection_token_list() {
        let mut headers = handshake_headers();
        headers.insert(
            http::header::CONNECTION,
            http::HeaderValue::from_static("keep-alive, Upgrade"),
        );
        assert!(validate_upgrade_request(&http::Method::GET, &headers).is_ok());
    }

    #[test]
    fn rejects_non_get_method() {
        assert!(validate_upgrade_request(&http::Method::POST, &handshake_headers()).is_err());
    }

    #[test]
    fn rejects_non_websocket_upgrade() {
        let mut headers = handshake_headers();
        headers.insert(http::header::UPGRADE, http::HeaderValue::from_static("h2c"));
        assert!(validate_upgrade_request(&http::Method::GET, &headers).is_err());
    }

    #[test]
    fn rejects_missing_key() {
        let mut headers = handshake_headers();
        headers.remove(http::header::SEC_WEBSOCKET_KEY);
        assert!(validate_upgrade_request(&http::Method::GET, &headers).is_err());
    }

    #[test]
    fn rejects_unsupported_version() {
        let mut headers = handshake_headers();
        headers.insert(
            http::header::SEC_WEBSOCKET_VERSION,
            http::HeaderValue::from_static("8"),
        );
        assert!(validate_upgrade_request(&http::Method::GET, &headers).is_err());
    }
}