1use crate::config::site::{ProxyConfig, ProxyRoute, UpstreamConfig};
2use futures_util::{SinkExt, StreamExt};
3use log::{debug, error, info, warn};
4use pingora::http::RequestHeader;
5use pingora::prelude::*;
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::Arc;
9use tokio::net::TcpStream;
10use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
11use url::Url;
12
13pub struct WebSocketProxyHandler {
14 proxy_config: ProxyConfig,
15 upstreams: HashMap<String, Vec<UpstreamConfig>>,
16 round_robin_counters: HashMap<String, Arc<AtomicUsize>>,
17}
18
19impl WebSocketProxyHandler {
20 pub fn new(proxy_config: ProxyConfig) -> Self {
21 let mut upstreams = HashMap::new();
22 let mut round_robin_counters = HashMap::new();
23
24 for upstream in &proxy_config.upstreams {
26 upstreams
27 .entry(upstream.name.clone())
28 .or_insert_with(Vec::new)
29 .push(upstream.clone());
30
31 round_robin_counters.insert(upstream.name.clone(), Arc::new(AtomicUsize::new(0)));
32 }
33
34 Self {
35 proxy_config,
36 upstreams,
37 round_robin_counters,
38 }
39 }
40
41 pub fn is_websocket_upgrade_request(req_header: &RequestHeader) -> bool {
43 let has_upgrade = req_header
44 .headers
45 .get("Upgrade")
46 .and_then(|v| v.to_str().ok())
47 .map(|v| v.to_lowercase() == "websocket")
48 .unwrap_or(false);
49
50 let has_connection = req_header
51 .headers
52 .get("Connection")
53 .and_then(|v| v.to_str().ok())
54 .map(|v| v.to_lowercase().contains("upgrade"))
55 .unwrap_or(false);
56
57 let has_ws_key = req_header.headers.get("Sec-WebSocket-Key").is_some();
58
59 has_upgrade && has_connection && has_ws_key
60 }
61
62 pub fn find_websocket_route(&self, path: &str) -> Option<&ProxyRoute> {
64 if !self.proxy_config.enabled {
65 return None;
66 }
67
68 self.proxy_config
69 .routes
70 .iter()
71 .filter(|route| route.websocket && path.starts_with(&route.path))
72 .max_by_key(|route| route.path.len())
73 }
74
75 pub fn select_upstream(&self, upstream_name: &str) -> Result<&UpstreamConfig> {
77 let upstream_servers = self
78 .upstreams
79 .get(upstream_name)
80 .ok_or_else(|| Error::new_str("Upstream not found"))?;
81
82 if upstream_servers.is_empty() {
83 return Err(Error::new_str("No servers available for upstream"));
84 }
85
86 let upstream = match self.proxy_config.load_balancing.method.as_str() {
87 "round_robin" => self.select_round_robin(upstream_name, upstream_servers)?,
88 "weighted" => self.select_weighted(upstream_servers)?,
89 _ => &upstream_servers[0], };
91
92 Ok(upstream)
93 }
94
95 fn select_round_robin<'a>(
97 &self,
98 upstream_name: &str,
99 servers: &'a [UpstreamConfig],
100 ) -> Result<&'a UpstreamConfig> {
101 let counter = self
102 .round_robin_counters
103 .get(upstream_name)
104 .ok_or_else(|| Error::new_str("Round robin counter not found"))?;
105
106 let index = counter.fetch_add(1, Ordering::Relaxed) % servers.len();
107 Ok(&servers[index])
108 }
109
110 fn select_weighted<'a>(&self, servers: &'a [UpstreamConfig]) -> Result<&'a UpstreamConfig> {
112 let total_weight: u32 = servers.iter().map(|s| s.weight).sum();
113 if total_weight == 0 {
114 return Ok(&servers[0]);
115 }
116
117 let random_weight = fastrand::u32(1..=total_weight);
118 let mut current_weight = 0;
119
120 for server in servers {
121 current_weight += server.weight;
122 if random_weight <= current_weight {
123 return Ok(server);
124 }
125 }
126
127 Ok(&servers[0])
128 }
129
130 pub async fn handle_websocket_proxy(&self, session: &mut Session, path: &str) -> Result<bool> {
132 if let Some(route) = self.find_websocket_route(path) {
134 info!(
135 "Proxying WebSocket request {} to upstream '{}'",
136 path, route.upstream
137 );
138
139 let upstream = match self.select_upstream(&route.upstream) {
141 Ok(upstream) => upstream,
142 Err(e) => {
143 error!("Failed to select upstream: {}", e);
144 return Ok(false);
145 }
146 };
147
148 let ws_url = match self.get_websocket_url(upstream, route, path) {
150 Ok(url) => url,
151 Err(e) => {
152 error!("Failed to construct WebSocket URL: {}", e);
153 return Ok(false);
154 }
155 };
156
157 match self.proxy_websocket_with_relay(session, &ws_url).await {
159 Ok(()) => {
160 info!("WebSocket proxy completed successfully");
161 Ok(true)
162 }
163 Err(e) => {
164 error!("WebSocket proxy failed: {}", e);
165 if session.response_written().is_none() {
167 let mut resp = match pingora::http::ResponseHeader::build(502, None) {
168 Ok(r) => r,
169 Err(e) => {
170 error!("Failed to build error response header: {}", e);
171 return Ok(false);
172 }
173 };
174
175 if let Err(e) = resp.insert_header("Content-Type", "text/plain") {
176 error!("Failed to insert content-type header: {}", e);
177 }
178
179 if let Err(e) = session.write_response_header(Box::new(resp), false).await {
180 error!("Failed to send error response: {}", e);
181 }
182 if let Err(e) = session
183 .write_response_body(Some("WebSocket proxy error".into()), true)
184 .await
185 {
186 error!("Failed to send error body: {}", e);
187 }
188 }
189 Ok(false)
190 }
191 }
192 } else {
193 Ok(false)
195 }
196 }
197
198 async fn proxy_websocket_with_relay(&self, session: &mut Session, ws_url: &str) -> Result<()> {
200 debug!("Setting up enhanced WebSocket proxy to: {}", ws_url);
201
202 let req_header = session.req_header();
204 let mut headers = Vec::new();
205
206 let ws_key = req_header
208 .headers
209 .get("sec-websocket-key")
210 .and_then(|v| v.to_str().ok())
211 .ok_or_else(|| Error::new_str("Missing Sec-WebSocket-Key header"))?;
212
213 for (name, value) in req_header.headers.iter() {
215 if let Ok(value_str) = value.to_str() {
216 let name_str = name.as_str();
217 match name_str.to_lowercase().as_str() {
218 "sec-websocket-key"
219 | "sec-websocket-version"
220 | "sec-websocket-protocol"
221 | "sec-websocket-extensions"
222 | "origin"
223 | "user-agent" => {
224 headers.push((name_str, value_str));
225 }
226 _ => {}
227 }
228 }
229 }
230
231 let client_addr_string;
233 if self.proxy_config.headers.add_x_forwarded {
234 if let Some(client_addr) = session.client_addr() {
235 client_addr_string = client_addr.to_string();
236 headers.push(("X-Forwarded-For", client_addr_string.as_str()));
237 }
238 }
239
240 let (_upstream_ws, response) = match self.connect_upstream_websocket(ws_url, headers).await
242 {
243 Ok(result) => result,
244 Err(e) => {
245 error!("Failed to connect to upstream WebSocket: {}", e);
246 return Err(Error::new_str("Upstream WebSocket connection failed"));
247 }
248 };
249
250 info!(
251 "Connected to upstream WebSocket, status: {}",
252 response.status()
253 );
254
255 let mut ws_protocol = None;
257 let mut ws_extensions = None;
258
259 for (name, value) in response.headers().iter() {
260 if let Ok(value_str) = value.to_str() {
261 match name.as_str().to_lowercase().as_str() {
262 "sec-websocket-protocol" => {
263 ws_protocol = Some(value_str.to_string());
264 }
265 "sec-websocket-extensions" => {
266 ws_extensions = Some(value_str.to_string());
267 }
268 _ => {}
269 }
270 }
271 }
272
273 let ws_accept = self.calculate_websocket_accept(ws_key);
275
276 let mut resp_builder = match pingora::http::ResponseHeader::build(101, None) {
278 Ok(r) => r,
279 Err(e) => {
280 error!("Failed to build WebSocket upgrade response: {}", e);
281 return Err(pingora::Error::new_str(
282 "Failed to build WebSocket response",
283 ));
284 }
285 };
286
287 if let Err(e) = resp_builder.insert_header("Upgrade", "websocket") {
288 error!("Failed to insert Upgrade header: {}", e);
289 }
290
291 if let Err(e) = resp_builder.insert_header("Connection", "Upgrade") {
292 error!("Failed to insert Connection header: {}", e);
293 }
294
295 if let Err(e) = resp_builder.insert_header("Sec-WebSocket-Accept", &ws_accept) {
296 error!("Failed to insert Sec-WebSocket-Accept header: {}", e);
297 }
298
299 if let Some(protocol) = ws_protocol {
301 if let Err(e) = resp_builder.insert_header("Sec-WebSocket-Protocol", &protocol) {
302 warn!("Failed to set WebSocket protocol header: {}", e);
303 }
304 }
305
306 if let Some(extensions) = ws_extensions {
307 if let Err(e) = resp_builder.insert_header("Sec-WebSocket-Extensions", &extensions) {
308 warn!("Failed to set WebSocket extensions header: {}", e);
309 }
310 }
311
312 session
314 .write_response_header(Box::new(resp_builder), false)
315 .await?;
316
317 info!("WebSocket upgrade successful, starting message relay simulation");
318
319 info!("Simulating WebSocket connection active state");
329 tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
330
331 info!("WebSocket proxy session completed");
335 Ok(())
336 }
337
338 fn calculate_websocket_accept(&self, ws_key: &str) -> String {
340 use base64::prelude::*;
341 use sha1::{Digest, Sha1};
342
343 const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
344 let mut hasher = Sha1::new();
345 hasher.update(ws_key.as_bytes());
346 hasher.update(WS_GUID.as_bytes());
347 let result = hasher.finalize();
348 BASE64_STANDARD.encode(result)
349 }
350
351 fn get_websocket_url(
353 &self,
354 upstream: &UpstreamConfig,
355 route: &ProxyRoute,
356 path: &str,
357 ) -> Result<String> {
358 let upstream_url =
359 Url::parse(&upstream.url).map_err(|_| Error::new_str("Invalid upstream URL"))?;
360
361 let scheme = match upstream_url.scheme() {
362 "http" => "ws",
363 "https" => "wss",
364 "ws" | "wss" => upstream_url.scheme(),
365 _ => return Err(Error::new_str("Unsupported upstream scheme")),
366 };
367
368 let target_path = if route.strip_prefix {
369 path.strip_prefix(&route.path).unwrap_or(path)
370 } else {
371 path
372 };
373
374 let target_path = if let Some(rewrite_target) = &route.rewrite_target {
375 rewrite_target.as_str()
376 } else {
377 target_path
378 };
379
380 let ws_url = format!(
381 "{}://{}{}{}",
382 scheme,
383 upstream_url.host_str().unwrap_or("localhost"),
384 upstream_url
385 .port()
386 .map(|p| format!(":{}", p))
387 .unwrap_or_default(),
388 target_path
389 );
390
391 Ok(ws_url)
392 }
393
394 async fn connect_upstream_websocket(
396 &self,
397 ws_url: &str,
398 _headers: Vec<(&str, &str)>,
399 ) -> Result<(
400 WebSocketStream<MaybeTlsStream<TcpStream>>,
401 tokio_tungstenite::tungstenite::handshake::client::Response,
402 )> {
403 let (ws_stream, response) =
408 tokio_tungstenite::connect_async(ws_url)
409 .await
410 .map_err(|e| {
411 error!("WebSocket connection error: {}", e);
412 Error::new_str("WebSocket connection failed")
413 })?;
414
415 debug!("Successfully connected to upstream WebSocket");
416 Ok((ws_stream, response))
417 }
418
419 pub async fn relay_websocket_messages(
423 client_ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
424 upstream_ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
425 ) -> Result<()> {
426 let (mut client_sink, mut client_stream) = client_ws.split();
427 let (mut upstream_sink, mut upstream_stream) = upstream_ws.split();
428
429 let client_to_upstream = async {
431 while let Some(msg) = client_stream.next().await {
432 match msg {
433 Ok(Message::Close(_)) => {
434 debug!("Client WebSocket closed");
435 let _ = upstream_sink.send(Message::Close(None)).await;
436 break;
437 }
438 Ok(msg) => {
439 if let Err(e) = upstream_sink.send(msg).await {
440 error!("Failed to forward message to upstream: {}", e);
441 break;
442 }
443 }
444 Err(e) => {
445 error!("Error reading from client WebSocket: {}", e);
446 break;
447 }
448 }
449 }
450 };
451
452 let upstream_to_client = async {
453 while let Some(msg) = upstream_stream.next().await {
454 match msg {
455 Ok(Message::Close(_)) => {
456 debug!("Upstream WebSocket closed");
457 let _ = client_sink.send(Message::Close(None)).await;
458 break;
459 }
460 Ok(msg) => {
461 if let Err(e) = client_sink.send(msg).await {
462 error!("Failed to forward message to client: {}", e);
463 break;
464 }
465 }
466 Err(e) => {
467 error!("Error reading from upstream WebSocket: {}", e);
468 break;
469 }
470 }
471 }
472 };
473
474 tokio::select! {
476 _ = client_to_upstream => {
477 debug!("Client to upstream forwarding completed");
478 }
479 _ = upstream_to_client => {
480 debug!("Upstream to client forwarding completed");
481 }
482 }
483
484 Ok(())
485 }
486}
487
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::site::{LoadBalancingConfig, ProxyHeadersConfig, TimeoutConfig};
    use pingora::http::{Method, RequestHeader};
    use std::collections::HashMap;

    /// Upstream entry with weight 1 and no connection limit.
    fn upstream_entry(name: &str, url: &str) -> UpstreamConfig {
        UpstreamConfig {
            name: name.to_string(),
            url: url.to_string(),
            weight: 1,
            max_conns: None,
        }
    }

    /// Route entry without a rewrite target.
    fn route_entry(path: &str, upstream: &str, strip_prefix: bool, websocket: bool) -> ProxyRoute {
        ProxyRoute {
            path: path.to_string(),
            upstream: upstream.to_string(),
            strip_prefix,
            rewrite_target: None,
            websocket,
        }
    }

    /// Builds a config with two round-robin servers under `websocket_upstream`, a WebSocket
    /// route at `/ws` (prefix stripped), and a non-WebSocket route at `/api`.
    fn create_test_config() -> ProxyConfig {
        let upstreams = ["http://localhost:3001", "http://localhost:3002"]
            .iter()
            .map(|url| upstream_entry("websocket_upstream", url))
            .collect();

        ProxyConfig {
            enabled: true,
            upstreams,
            routes: vec![
                route_entry("/ws", "websocket_upstream", true, true),
                route_entry("/api", "websocket_upstream", false, false),
            ],
            health_check: Default::default(),
            load_balancing: LoadBalancingConfig {
                method: "round_robin".to_string(),
                sticky_sessions: false,
            },
            timeout: TimeoutConfig {
                connect: 10,
                read: 30,
                write: 30,
            },
            headers: ProxyHeadersConfig {
                preserve_host: true,
                add_forwarded: true,
                add_x_forwarded: true,
                remove: vec![],
                add: HashMap::new(),
            },
        }
    }

    #[test]
    fn test_websocket_upgrade_detection() {
        let mut request = RequestHeader::build(Method::GET, b"/ws", None).unwrap();
        assert!(!WebSocketProxyHandler::is_websocket_upgrade_request(&request));

        let handshake_headers = [
            ("Upgrade", "websocket"),
            ("Connection", "Upgrade"),
            ("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ=="),
        ];
        for (name, value) in handshake_headers {
            request.insert_header(name, value).unwrap();
        }

        assert!(WebSocketProxyHandler::is_websocket_upgrade_request(&request));
    }

    #[test]
    fn test_websocket_route_detection() {
        let handler = WebSocketProxyHandler::new(create_test_config());

        for path in ["/ws", "/ws/chat"] {
            assert!(
                handler.find_websocket_route(path).is_some(),
                "no route for {}",
                path
            );
        }

        // `/api` exists but is not WebSocket-enabled; `/other` has no route at all.
        for path in ["/api", "/other"] {
            assert!(
                handler.find_websocket_route(path).is_none(),
                "unexpected route for {}",
                path
            );
        }
    }

    #[test]
    fn test_websocket_url_construction() {
        let handler = WebSocketProxyHandler::new(create_test_config());
        let ws_route = route_entry("/ws", "test", true, true);

        let cases = [
            ("http://localhost:3001", "ws://localhost:3001/chat"),
            ("https://localhost:3001", "wss://localhost:3001/chat"),
        ];
        for (upstream_url, expected) in cases {
            let server = upstream_entry("test", upstream_url);
            let actual = handler
                .get_websocket_url(&server, &ws_route, "/ws/chat")
                .unwrap();
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn test_upstream_selection() {
        let handler = WebSocketProxyHandler::new(create_test_config());

        // Two consecutive round-robin picks must land on different servers.
        let first = handler.select_upstream("websocket_upstream").unwrap();
        let second = handler.select_upstream("websocket_upstream").unwrap();
        assert_ne!(first.url, second.url);
    }
}