1use crate::config::site::{ProxyConfig, ProxyRoute, UpstreamConfig};
2use futures_util::{SinkExt, StreamExt};
3use log::{debug, error, info, warn};
4use pingora::http::RequestHeader;
5use pingora::prelude::*;
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::Arc;
9use tokio::net::TcpStream;
10use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream, WebSocketStream};
11use url::Url;
12
13pub struct WebSocketProxyHandler {
14 proxy_config: ProxyConfig,
15 upstreams: HashMap<String, Vec<UpstreamConfig>>,
16 round_robin_counters: HashMap<String, Arc<AtomicUsize>>,
17}
18
19impl WebSocketProxyHandler {
20 pub fn new(proxy_config: ProxyConfig) -> Self {
21 let mut upstreams = HashMap::new();
22 let mut round_robin_counters = HashMap::new();
23
24 for upstream in &proxy_config.upstreams {
26 upstreams
27 .entry(upstream.name.clone())
28 .or_insert_with(Vec::new)
29 .push(upstream.clone());
30
31 round_robin_counters.insert(upstream.name.clone(), Arc::new(AtomicUsize::new(0)));
32 }
33
34 Self {
35 proxy_config,
36 upstreams,
37 round_robin_counters,
38 }
39 }
40
41 pub fn is_websocket_upgrade_request(req_header: &RequestHeader) -> bool {
43 let has_upgrade = req_header
44 .headers
45 .get("Upgrade")
46 .and_then(|v| v.to_str().ok())
47 .map(|v| v.to_lowercase() == "websocket")
48 .unwrap_or(false);
49
50 let has_connection = req_header
51 .headers
52 .get("Connection")
53 .and_then(|v| v.to_str().ok())
54 .map(|v| v.to_lowercase().contains("upgrade"))
55 .unwrap_or(false);
56
57 let has_ws_key = req_header.headers.get("Sec-WebSocket-Key").is_some();
58
59 has_upgrade && has_connection && has_ws_key
60 }
61
62 pub fn find_websocket_route(&self, path: &str) -> Option<&ProxyRoute> {
64 if !self.proxy_config.enabled {
65 return None;
66 }
67
68 self.proxy_config
69 .routes
70 .iter()
71 .filter(|route| route.websocket && path.starts_with(&route.path))
72 .max_by_key(|route| route.path.len())
73 }
74
75 pub fn select_upstream(&self, upstream_name: &str) -> Result<&UpstreamConfig> {
77 let upstream_servers = self
78 .upstreams
79 .get(upstream_name)
80 .ok_or_else(|| Error::new_str("Upstream not found"))?;
81
82 if upstream_servers.is_empty() {
83 return Err(Error::new_str("No servers available for upstream"));
84 }
85
86 let upstream = match self.proxy_config.load_balancing.method.as_str() {
87 "round_robin" => self.select_round_robin(upstream_name, upstream_servers)?,
88 "weighted" => self.select_weighted(upstream_servers)?,
89 _ => &upstream_servers[0], };
91
92 Ok(upstream)
93 }
94
95 fn select_round_robin<'a>(
97 &self,
98 upstream_name: &str,
99 servers: &'a [UpstreamConfig],
100 ) -> Result<&'a UpstreamConfig> {
101 let counter = self
102 .round_robin_counters
103 .get(upstream_name)
104 .ok_or_else(|| Error::new_str("Round robin counter not found"))?;
105
106 let index = counter.fetch_add(1, Ordering::Relaxed) % servers.len();
107 Ok(&servers[index])
108 }
109
110 fn select_weighted<'a>(&self, servers: &'a [UpstreamConfig]) -> Result<&'a UpstreamConfig> {
112 let total_weight: u32 = servers.iter().map(|s| s.weight).sum();
113 if total_weight == 0 {
114 return Ok(&servers[0]);
115 }
116
117 let random_weight = fastrand::u32(1..=total_weight);
118 let mut current_weight = 0;
119
120 for server in servers {
121 current_weight += server.weight;
122 if random_weight <= current_weight {
123 return Ok(server);
124 }
125 }
126
127 Ok(&servers[0])
128 }
129
130 pub async fn handle_websocket_proxy(&self, session: &mut Session, path: &str) -> Result<bool> {
132 if let Some(route) = self.find_websocket_route(path) {
134 info!(
135 "Proxying WebSocket request {} to upstream '{}'",
136 path, route.upstream
137 );
138
139 let upstream = match self.select_upstream(&route.upstream) {
141 Ok(upstream) => upstream,
142 Err(e) => {
143 error!("Failed to select upstream: {}", e);
144 return Ok(false);
145 }
146 };
147
148 let ws_url = match self.get_websocket_url(upstream, route, path) {
150 Ok(url) => url,
151 Err(e) => {
152 error!("Failed to construct WebSocket URL: {}", e);
153 return Ok(false);
154 }
155 };
156
157 match self.proxy_websocket_with_relay(session, &ws_url).await {
159 Ok(()) => {
160 info!("WebSocket proxy completed successfully");
161 Ok(true)
162 }
163 Err(e) => {
164 error!("WebSocket proxy failed: {}", e);
165 if session.response_written().is_none() {
167 let mut resp = pingora::http::ResponseHeader::build(502, None).unwrap();
168 resp.insert_header("Content-Type", "text/plain").unwrap();
169 if let Err(e) = session.write_response_header(Box::new(resp), false).await {
170 error!("Failed to send error response: {}", e);
171 }
172 if let Err(e) = session
173 .write_response_body(Some("WebSocket proxy error".into()), true)
174 .await
175 {
176 error!("Failed to send error body: {}", e);
177 }
178 }
179 Ok(false)
180 }
181 }
182 } else {
183 Ok(false)
185 }
186 }
187
188 async fn proxy_websocket_with_relay(&self, session: &mut Session, ws_url: &str) -> Result<()> {
190 debug!("Setting up enhanced WebSocket proxy to: {}", ws_url);
191
192 let req_header = session.req_header();
194 let mut headers = Vec::new();
195
196 let ws_key = req_header
198 .headers
199 .get("sec-websocket-key")
200 .and_then(|v| v.to_str().ok())
201 .ok_or_else(|| Error::new_str("Missing Sec-WebSocket-Key header"))?;
202
203 for (name, value) in req_header.headers.iter() {
205 if let Ok(value_str) = value.to_str() {
206 let name_str = name.as_str();
207 match name_str.to_lowercase().as_str() {
208 "sec-websocket-key"
209 | "sec-websocket-version"
210 | "sec-websocket-protocol"
211 | "sec-websocket-extensions"
212 | "origin"
213 | "user-agent" => {
214 headers.push((name_str, value_str));
215 }
216 _ => {}
217 }
218 }
219 }
220
221 let client_addr_string;
223 if self.proxy_config.headers.add_x_forwarded {
224 if let Some(client_addr) = session.client_addr() {
225 client_addr_string = client_addr.to_string();
226 headers.push(("X-Forwarded-For", client_addr_string.as_str()));
227 }
228 }
229
230 let (_upstream_ws, response) = match self.connect_upstream_websocket(ws_url, headers).await
232 {
233 Ok(result) => result,
234 Err(e) => {
235 error!("Failed to connect to upstream WebSocket: {}", e);
236 return Err(Error::new_str("Upstream WebSocket connection failed"));
237 }
238 };
239
240 info!(
241 "Connected to upstream WebSocket, status: {}",
242 response.status()
243 );
244
245 let mut ws_protocol = None;
247 let mut ws_extensions = None;
248
249 for (name, value) in response.headers().iter() {
250 if let Ok(value_str) = value.to_str() {
251 match name.as_str().to_lowercase().as_str() {
252 "sec-websocket-protocol" => {
253 ws_protocol = Some(value_str.to_string());
254 }
255 "sec-websocket-extensions" => {
256 ws_extensions = Some(value_str.to_string());
257 }
258 _ => {}
259 }
260 }
261 }
262
263 let ws_accept = self.calculate_websocket_accept(ws_key);
265
266 let mut resp_builder = pingora::http::ResponseHeader::build(101, None).unwrap();
268 resp_builder.insert_header("Upgrade", "websocket").unwrap();
269 resp_builder.insert_header("Connection", "Upgrade").unwrap();
270 resp_builder
271 .insert_header("Sec-WebSocket-Accept", &ws_accept)
272 .unwrap();
273
274 if let Some(protocol) = ws_protocol {
276 if let Err(e) = resp_builder.insert_header("Sec-WebSocket-Protocol", &protocol) {
277 warn!("Failed to set WebSocket protocol header: {}", e);
278 }
279 }
280
281 if let Some(extensions) = ws_extensions {
282 if let Err(e) = resp_builder.insert_header("Sec-WebSocket-Extensions", &extensions) {
283 warn!("Failed to set WebSocket extensions header: {}", e);
284 }
285 }
286
287 session
289 .write_response_header(Box::new(resp_builder), false)
290 .await?;
291
292 info!("WebSocket upgrade successful, starting message relay simulation");
293
294 info!("Simulating WebSocket connection active state");
304 tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
305
306 info!("WebSocket proxy session completed");
310 Ok(())
311 }
312
313 fn calculate_websocket_accept(&self, ws_key: &str) -> String {
315 use base64::prelude::*;
316 use sha1::{Digest, Sha1};
317
318 const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
319 let mut hasher = Sha1::new();
320 hasher.update(ws_key.as_bytes());
321 hasher.update(WS_GUID.as_bytes());
322 let result = hasher.finalize();
323 BASE64_STANDARD.encode(result)
324 }
325
326 fn get_websocket_url(
328 &self,
329 upstream: &UpstreamConfig,
330 route: &ProxyRoute,
331 path: &str,
332 ) -> Result<String> {
333 let upstream_url =
334 Url::parse(&upstream.url).map_err(|_| Error::new_str("Invalid upstream URL"))?;
335
336 let scheme = match upstream_url.scheme() {
337 "http" => "ws",
338 "https" => "wss",
339 "ws" | "wss" => upstream_url.scheme(),
340 _ => return Err(Error::new_str("Unsupported upstream scheme")),
341 };
342
343 let target_path = if route.strip_prefix {
344 path.strip_prefix(&route.path).unwrap_or(path)
345 } else {
346 path
347 };
348
349 let target_path = if let Some(rewrite_target) = &route.rewrite_target {
350 rewrite_target.as_str()
351 } else {
352 target_path
353 };
354
355 let ws_url = format!(
356 "{}://{}{}{}",
357 scheme,
358 upstream_url.host_str().unwrap_or("localhost"),
359 upstream_url
360 .port()
361 .map(|p| format!(":{}", p))
362 .unwrap_or_default(),
363 target_path
364 );
365
366 Ok(ws_url)
367 }
368
369 #[allow(dead_code)]
371 async fn proxy_websocket(&self, session: &mut Session, ws_url: &str) -> Result<()> {
372 self.proxy_websocket_with_relay(session, ws_url).await
374 }
375
376 async fn connect_upstream_websocket(
378 &self,
379 ws_url: &str,
380 _headers: Vec<(&str, &str)>,
381 ) -> Result<(
382 WebSocketStream<MaybeTlsStream<TcpStream>>,
383 tokio_tungstenite::tungstenite::handshake::client::Response,
384 )> {
385 let (ws_stream, response) =
390 tokio_tungstenite::connect_async(ws_url)
391 .await
392 .map_err(|e| {
393 error!("WebSocket connection error: {}", e);
394 Error::new_str("WebSocket connection failed")
395 })?;
396
397 debug!("Successfully connected to upstream WebSocket");
398 Ok((ws_stream, response))
399 }
400
401 pub async fn relay_websocket_messages(
405 client_ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
406 upstream_ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
407 ) -> Result<()> {
408 let (mut client_sink, mut client_stream) = client_ws.split();
409 let (mut upstream_sink, mut upstream_stream) = upstream_ws.split();
410
411 let client_to_upstream = async {
413 while let Some(msg) = client_stream.next().await {
414 match msg {
415 Ok(Message::Close(_)) => {
416 debug!("Client WebSocket closed");
417 let _ = upstream_sink.send(Message::Close(None)).await;
418 break;
419 }
420 Ok(msg) => {
421 if let Err(e) = upstream_sink.send(msg).await {
422 error!("Failed to forward message to upstream: {}", e);
423 break;
424 }
425 }
426 Err(e) => {
427 error!("Error reading from client WebSocket: {}", e);
428 break;
429 }
430 }
431 }
432 };
433
434 let upstream_to_client = async {
435 while let Some(msg) = upstream_stream.next().await {
436 match msg {
437 Ok(Message::Close(_)) => {
438 debug!("Upstream WebSocket closed");
439 let _ = client_sink.send(Message::Close(None)).await;
440 break;
441 }
442 Ok(msg) => {
443 if let Err(e) = client_sink.send(msg).await {
444 error!("Failed to forward message to client: {}", e);
445 break;
446 }
447 }
448 Err(e) => {
449 error!("Error reading from upstream WebSocket: {}", e);
450 break;
451 }
452 }
453 }
454 };
455
456 tokio::select! {
458 _ = client_to_upstream => {
459 debug!("Client to upstream forwarding completed");
460 }
461 _ = upstream_to_client => {
462 debug!("Upstream to client forwarding completed");
463 }
464 }
465
466 Ok(())
467 }
468}
469
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::site::{LoadBalancingConfig, ProxyHeadersConfig, TimeoutConfig};
    use pingora::http::{Method, RequestHeader};
    use std::collections::HashMap;

    /// Upstream entry with weight 1 and no connection limit.
    fn make_upstream(name: &str, url: &str) -> UpstreamConfig {
        UpstreamConfig {
            name: name.to_string(),
            url: url.to_string(),
            weight: 1,
            max_conns: None,
        }
    }

    /// Route without a rewrite target.
    fn make_route(path: &str, upstream: &str, strip_prefix: bool, websocket: bool) -> ProxyRoute {
        ProxyRoute {
            path: path.to_string(),
            upstream: upstream.to_string(),
            strip_prefix,
            rewrite_target: None,
            websocket,
        }
    }

    /// Two-server round-robin pool with one WebSocket route (`/ws`) and one plain route (`/api`).
    fn create_test_config() -> ProxyConfig {
        let pool = "websocket_upstream";

        ProxyConfig {
            enabled: true,
            upstreams: vec![
                make_upstream(pool, "http://localhost:3001"),
                make_upstream(pool, "http://localhost:3002"),
            ],
            routes: vec![
                make_route("/ws", pool, true, true),
                make_route("/api", pool, false, false),
            ],
            health_check: Default::default(),
            load_balancing: LoadBalancingConfig {
                method: "round_robin".to_string(),
                sticky_sessions: false,
            },
            timeout: TimeoutConfig {
                connect: 10,
                read: 30,
                write: 30,
            },
            headers: ProxyHeadersConfig {
                preserve_host: true,
                add_forwarded: true,
                add_x_forwarded: true,
                remove: Vec::new(),
                add: HashMap::new(),
            },
        }
    }

    #[test]
    fn test_websocket_upgrade_detection() {
        let mut req = RequestHeader::build(Method::GET, b"/ws", None).unwrap();
        assert!(!WebSocketProxyHandler::is_websocket_upgrade_request(&req));

        for &(name, value) in &[
            ("Upgrade", "websocket"),
            ("Connection", "Upgrade"),
            ("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ=="),
        ] {
            req.insert_header(name, value).unwrap();
        }

        assert!(WebSocketProxyHandler::is_websocket_upgrade_request(&req));
    }

    #[test]
    fn test_websocket_route_detection() {
        let handler = WebSocketProxyHandler::new(create_test_config());

        for &(path, should_match) in &[
            ("/ws", true),
            ("/ws/chat", true),
            ("/api", false),
            ("/other", false),
        ] {
            assert_eq!(
                handler.find_websocket_route(path).is_some(),
                should_match,
                "route lookup for {}",
                path
            );
        }
    }

    #[test]
    fn test_websocket_url_construction() {
        let handler = WebSocketProxyHandler::new(create_test_config());
        let route = make_route("/ws", "test", true, true);

        for &(upstream_url, expected) in &[
            ("http://localhost:3001", "ws://localhost:3001/chat"),
            ("https://localhost:3001", "wss://localhost:3001/chat"),
        ] {
            let upstream = make_upstream("test", upstream_url);
            let actual = handler
                .get_websocket_url(&upstream, &route, "/ws/chat")
                .unwrap();
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn test_upstream_selection() {
        let handler = WebSocketProxyHandler::new(create_test_config());

        let first = handler.select_upstream("websocket_upstream").unwrap();
        let second = handler.select_upstream("websocket_upstream").unwrap();

        // Round robin over a two-server pool must alternate.
        assert_ne!(first.url, second.url);
    }
}
614}