use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream, ToSocketAddrs};
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use std::time::Duration;
use crate::request::Request;
use crate::websocket::WebSocket;
/// Configuration for a blocking, thread-per-connection WebSocket proxy that
/// forwards upgrade requests to one of several backends.
pub struct WsProxy {
/// Backend address strings; one is picked per connection and resolved with
/// `to_socket_addrs` at that time.
backends: Vec<String>,
/// Ever-increasing counter; its value modulo `backends.len()` selects the
/// backend for the next connection (round-robin).
counter: Arc<AtomicUsize>,
/// Timeout used when opening the TCP connection to a backend.
connect_timeout: Duration,
/// Read timeout configured on each accepted client socket.
read_timeout: Duration,
}
impl WsProxy {
pub fn new<I, S>(backends: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
WsProxy {
backends: backends.into_iter().map(|b| b.into()).collect(),
counter: Arc::new(AtomicUsize::new(0)),
connect_timeout: Duration::from_secs(5),
read_timeout: Duration::from_secs(30),
}
}
pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
self.connect_timeout = Duration::from_millis(ms);
self
}
pub fn read_timeout_ms(mut self, ms: u64) -> Self {
self.read_timeout = Duration::from_millis(ms);
self
}
pub fn bind(self, addr: &str) -> Result<(), String> {
if self.backends.is_empty() {
return Err("WsProxy: no backends configured".to_string());
}
let listener = TcpListener::bind(addr)
.map_err(|e| format!("WsProxy: bind on {} failed: {}", addr, e))?;
println!("WsProxy: listening on {}", addr);
let proxy = Arc::new(self);
for incoming in listener.incoming() {
let client = match incoming {
Ok(s) => s,
Err(e) => {
eprintln!("WsProxy: accept error: {}", e);
continue;
}
};
let p = Arc::clone(&proxy);
std::thread::spawn(move || {
if let Err(e) = p.handle(client) {
eprintln!("WsProxy: {}", e);
}
});
}
Ok(())
}
fn pick_backend(&self) -> &str {
let i = self.counter.fetch_add(1, Ordering::Relaxed) % self.backends.len();
&self.backends[i]
}
fn handle(&self, mut client: TcpStream) -> Result<(), String> {
client.set_read_timeout(Some(self.read_timeout)).ok();
let mut buf = vec![0u8; 8192];
let n = client.read(&mut buf).map_err(|e| e.to_string())?;
if n == 0 {
return Ok(());
}
let request = Request::parse(&buf[..n])
.map_err(|e| format!("WsProxy: invalid HTTP request: {}", e))?;
if !WebSocket::is_upgrade_request(&request) {
let _ = client.write_all(b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n");
return Err(format!(
"WsProxy: not a WebSocket upgrade — method={}, uri={}",
request.method, request.request_uri
));
}
let backend_str = self.pick_backend().to_string();
let backend_sock = backend_str
.to_socket_addrs()
.map_err(|e| format!("WsProxy: DNS lookup for {} failed: {}", backend_str, e))?
.next()
.ok_or_else(|| format!("WsProxy: no address for {}", backend_str))?;
let mut backend = TcpStream::connect_timeout(&backend_sock, self.connect_timeout)
.map_err(|e| format!("WsProxy: connect to {} failed: {}", backend_str, e))?;
let upgrade_req = build_upgrade_request(&request, &backend_str);
backend
.write_all(&upgrade_req)
.map_err(|e| format!("WsProxy: write upgrade to backend failed: {}", e))?;
let mut resp_buf = vec![0u8; 4096];
let m = backend
.read(&mut resp_buf)
.map_err(|e| format!("WsProxy: read 101 from backend failed: {}", e))?;
let resp_preview = &resp_buf[..m.min(20)];
if !resp_preview.starts_with(b"HTTP/1.1 101") && !resp_preview.starts_with(b"HTTP/1.0 101") {
return Err(format!(
"WsProxy: backend {} did not send 101 (got {:?})",
backend_str,
std::str::from_utf8(&resp_buf[..m.min(80)]).unwrap_or("?")
));
}
let response_101 = WebSocket::handshake_response(&request)?;
let raw_101 = format_response_head(&response_101);
client
.write_all(&raw_101)
.map_err(|e| format!("WsProxy: write 101 to client failed: {}", e))?;
let mut client_r = client.try_clone().map_err(|e| e.to_string())?;
let mut backend_r = backend.try_clone().map_err(|e| e.to_string())?;
let mut client_w = client;
let mut backend_w = backend;
let t1 = std::thread::spawn(move || {
std::io::copy(&mut client_r, &mut backend_w).ok();
let _ = backend_w.shutdown(std::net::Shutdown::Write);
});
let t2 = std::thread::spawn(move || {
std::io::copy(&mut backend_r, &mut client_w).ok();
let _ = client_w.shutdown(std::net::Shutdown::Write);
});
let _ = t1.join();
let _ = t2.join();
Ok(())
}
}
/// Serialises the HTTP/1.1 request head that is sent to the backend.
///
/// The request line reuses the client's method and URI, `Host` is replaced
/// by `backend_host`, and every other client header is copied in order.
fn build_upgrade_request(request: &Request, backend_host: &str) -> Vec<u8> {
    let mut head = format!("{} {} HTTP/1.1\r\n", request.method, request.request_uri);
    head += &format!("Host: {}\r\n", backend_host);

    for field in &request.headers {
        // The client's own Host header is dropped; ours was written above.
        let is_host = field.name.eq_ignore_ascii_case("host");
        if !is_host {
            head += &format!("{}: {}\r\n", field.name, field.value);
        }
    }

    // Blank line terminating the head.
    head += "\r\n";
    head.into_bytes()
}
/// Serialises the status line and headers of `response` as raw bytes,
/// terminated by the blank line that ends an HTTP head. No body is written.
fn format_response_head(response: &crate::response::Response) -> Vec<u8> {
    let status_line = format!(
        "HTTP/1.1 {} {}\r\n",
        response.status_code, response.reason_phrase
    );
    let mut head: Vec<u8> = Vec::from(status_line);

    for field in &response.headers {
        // One `Name: value\r\n` line per header, in declaration order.
        let pieces = [
            field.name.as_bytes(),
            &b": "[..],
            field.value.as_bytes(),
            &b"\r\n"[..],
        ];
        for piece in pieces {
            head.extend_from_slice(piece);
        }
    }

    head.extend_from_slice(b"\r\n");
    head
}