use std::io;
use std::net::{TcpListener, TcpStream, ToSocketAddrs};
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use std::time::Duration;
/// A blocking, thread-per-connection TCP proxy that distributes incoming
/// clients across a fixed list of backends in round-robin order.
///
/// Each accepted client is paired with a freshly connected backend socket and
/// bytes are copied in both directions until each side has finished sending.
#[derive(Debug)]
pub struct TcpProxy {
    /// Backend addresses in any form `ToSocketAddrs` accepts for `&str`
    /// (e.g. `"host:port"`); they are re-resolved for every connection.
    backends: Vec<String>,
    /// Ever-increasing counter that selects the next backend.
    counter: Arc<AtomicUsize>,
    /// Timeout for each individual backend connect attempt; zero disables it.
    connect_timeout: Duration,
}

impl TcpProxy {
    /// Creates a proxy over `backends` with a default connect timeout of 5 s.
    ///
    /// Backends are handed out in the given order, one per connection. An
    /// empty list is accepted here but rejected by [`TcpProxy::bind`].
    pub fn new<I, S>(backends: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        TcpProxy {
            backends: backends.into_iter().map(Into::into).collect(),
            counter: Arc::new(AtomicUsize::new(0)),
            connect_timeout: Duration::from_secs(5),
        }
    }

    /// Sets the timeout for each backend connect attempt, in milliseconds.
    ///
    /// `0` disables the timeout and uses a plain blocking connect instead.
    /// `TcpStream::connect_timeout` rejects a zero `Duration`, so without this
    /// special case every connection would fail.
    pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
        self.connect_timeout = Duration::from_millis(ms);
        self
    }

    /// Binds a listener on `addr` and proxies every accepted client on its own
    /// thread. The accept loop runs on the calling thread and does not return
    /// under normal operation (`TcpListener::incoming` never yields `None`).
    ///
    /// # Errors
    /// Returns an error if no backends are configured or if binding `addr`
    /// fails. Per-connection failures (accept, spawn, relay) are logged to
    /// stderr and do not stop the loop.
    pub fn bind(self, addr: &str) -> Result<(), String> {
        if self.backends.is_empty() {
            return Err("TcpProxy: no backends configured".to_string());
        }
        let listener = TcpListener::bind(addr)
            .map_err(|e| format!("TcpProxy: bind on {} failed: {}", addr, e))?;
        println!("TcpProxy: listening on {}", addr);
        let proxy = Arc::new(self);
        for incoming in listener.incoming() {
            let client = match incoming {
                Ok(s) => s,
                Err(e) => {
                    eprintln!("TcpProxy: accept error: {}", e);
                    continue;
                }
            };
            let p = Arc::clone(&proxy);
            // `Builder::spawn` reports thread-creation failure as an error;
            // `thread::spawn` would panic and take the accept loop down with it.
            let spawned = std::thread::Builder::new().spawn(move || {
                if let Err(e) = p.relay(client) {
                    eprintln!("TcpProxy: relay error: {}", e);
                }
            });
            if let Err(e) = spawned {
                // The rejected closure is dropped, which closes this client only.
                eprintln!("TcpProxy: failed to spawn connection thread: {}", e);
            }
        }
        Ok(())
    }

    /// Returns the next backend in round-robin order.
    ///
    /// Panics if `backends` is empty (modulo by zero); `bind` rejects that
    /// case before any connection is served. `Relaxed` is sufficient because
    /// `fetch_add` is atomic under every ordering and the counter guards no
    /// other memory; wrap-around at `usize::MAX` causes only a one-off skew.
    fn pick_backend(&self) -> &str {
        let i = self.counter.fetch_add(1, Ordering::Relaxed) % self.backends.len();
        &self.backends[i]
    }

    /// Resolves the next backend and connects to it.
    ///
    /// Every resolved address is tried in turn (as `TcpStream::connect`
    /// does), so a host resolving to both IPv6 and IPv4 still works when only
    /// one family is listening. The timeout applies per attempt. On failure
    /// the error from the last attempt is reported.
    fn connect_backend(&self) -> Result<TcpStream, String> {
        let addr = self.pick_backend();
        let candidates = addr
            .to_socket_addrs()
            .map_err(|e| format!("DNS lookup for {} failed: {}", addr, e))?;
        let mut last_err = None;
        for sock_addr in candidates {
            let attempt = if self.connect_timeout.is_zero() {
                TcpStream::connect(sock_addr)
            } else {
                TcpStream::connect_timeout(&sock_addr, self.connect_timeout)
            };
            match attempt {
                Ok(stream) => return Ok(stream),
                Err(e) => last_err = Some(e),
            }
        }
        Err(match last_err {
            // No "TcpProxy:" prefix: `bind` already prefixes relay errors.
            Some(e) => format!("connect to {} failed: {}", addr, e),
            None => format!("no address resolved for {}", addr),
        })
    }

    /// Connects `client` to the next backend and shuttles bytes both ways
    /// until both directions have finished.
    fn relay(&self, client: TcpStream) -> Result<(), String> {
        let backend = self.connect_backend()?;
        let mut client_rx = client
            .try_clone()
            .map_err(|e| format!("cloning client socket failed: {}", e))?;
        let mut backend_tx = backend
            .try_clone()
            .map_err(|e| format!("cloning backend socket failed: {}", e))?;
        // Client -> backend runs on a helper thread...
        let upstream = std::thread::Builder::new()
            .spawn(move || Self::pump(&mut client_rx, &mut backend_tx))
            .map_err(|e| format!("spawning relay thread failed: {}", e))?;
        // ...while backend -> client runs here: this thread already belongs to
        // the connection, so a second helper would be pure overhead.
        let mut backend_rx = backend;
        let mut client_tx = client;
        Self::pump(&mut backend_rx, &mut client_tx);
        let _ = upstream.join();
        Ok(())
    }

    /// Copies `from` into `to` until EOF or an I/O error, then half-closes
    /// `to` so its peer observes end-of-stream.
    fn pump(from: &mut TcpStream, to: &mut TcpStream) {
        // Errors are deliberately ignored: either side hanging up mid-stream
        // is a normal way for a proxied connection to end.
        let _ = io::copy(from, to);
        let _ = to.shutdown(std::net::Shutdown::Write);
    }
}