use std::collections::HashMap;
use std::net::{SocketAddr, ToSocketAddrs};
use std::path::Path;
use std::sync::Arc;
use futures_util::{SinkExt, StreamExt};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::RwLock;
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::tungstenite::handshake::server::{Request, Response};
use crate::config::{ConfigChange, ConfigWatcher, ResolvedConfig, ServerFileConfig};
use crate::error::{Error, Result};
pub enum Bindable {
Address(SocketAddr),
Listener(TcpListener),
}
impl Bindable {
pub fn new<T: ToSocketAddrs>(addr: T) -> Result<Self> {
let socket_addr = addr
.to_socket_addrs()?
.next()
.ok_or_else(|| Error::config("could not resolve address"))?;
Ok(Bindable::Address(socket_addr))
}
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
match self {
Bindable::Address(addr) => Ok(*addr),
Bindable::Listener(listener) => listener.local_addr(),
}
}
}
impl std::fmt::Debug for Bindable {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Bindable::Address(addr) => f.debug_tuple("Bindable::Address").field(addr).finish(),
Bindable::Listener(listener) => f
.debug_tuple("Bindable::Listener")
.field(&listener.local_addr())
.finish(),
}
}
}
impl From<TcpListener> for Bindable {
fn from(listener: TcpListener) -> Self {
Bindable::Listener(listener)
}
}
impl From<SocketAddr> for Bindable {
fn from(addr: SocketAddr) -> Self {
Bindable::Address(addr)
}
}
/// Conversion into a [`Bindable`], accepted by [`ProxyServerBuilder::bind`].
///
/// Listeners, socket addresses and IP/port tuples convert infallibly; string-like inputs
/// are resolved through [`Bindable::new`] and can fail.
pub trait IntoBindable {
    /// Performs the conversion.
    fn into_bindable(self) -> Result<Bindable>;
}

impl IntoBindable for TcpListener {
    fn into_bindable(self) -> Result<Bindable> {
        Ok(self.into())
    }
}

impl IntoBindable for SocketAddr {
    fn into_bindable(self) -> Result<Bindable> {
        Ok(self.into())
    }
}

impl IntoBindable for &str {
    fn into_bindable(self) -> Result<Bindable> {
        Bindable::new(self)
    }
}

impl IntoBindable for String {
    // Delegates to the `&str` conversion.
    fn into_bindable(self) -> Result<Bindable> {
        self.as_str().into_bindable()
    }
}

impl IntoBindable for (&str, u16) {
    fn into_bindable(self) -> Result<Bindable> {
        Bindable::new(self)
    }
}

impl IntoBindable for (std::net::IpAddr, u16) {
    fn into_bindable(self) -> Result<Bindable> {
        SocketAddr::from(self).into_bindable()
    }
}

impl IntoBindable for (std::net::Ipv4Addr, u16) {
    fn into_bindable(self) -> Result<Bindable> {
        SocketAddr::from(self).into_bindable()
    }
}

impl IntoBindable for (std::net::Ipv6Addr, u16) {
    fn into_bindable(self) -> Result<Bindable> {
        SocketAddr::from(self).into_bindable()
    }
}
/// A lazily-resolved upstream address.
///
/// The input given to [`Address::new`] is kept inside a shared resolver closure and is
/// resolved again on every [`resolve`](Address::resolve) call. Cloning only bumps the
/// reference count of that closure.
#[derive(Clone)]
pub struct Address {
    resolver: Arc<dyn Fn() -> std::io::Result<SocketAddr> + Send + Sync>,
}

impl Address {
    /// Wraps `addr` without resolving it yet.
    pub fn new<T>(addr: T) -> Self
    where
        T: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        let resolve_first = move || -> std::io::Result<SocketAddr> {
            let mut candidates = addr.clone().to_socket_addrs()?;
            match candidates.next() {
                Some(first) => Ok(first),
                None => Err(std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    "could not resolve address",
                )),
            }
        };
        Self {
            resolver: Arc::new(resolve_first),
        }
    }

    /// Resolves the wrapped input and returns the first socket address.
    ///
    /// Host names go through [`ToSocketAddrs`], so this may block on a DNS lookup.
    ///
    /// # Errors
    ///
    /// Propagates the resolver's I/O error, or returns `ErrorKind::NotFound` when
    /// resolution yields no address.
    pub fn resolve(&self) -> std::io::Result<SocketAddr> {
        (*self.resolver)()
    }
}

impl std::fmt::Debug for Address {
    // The resolver closure is opaque, so only the type name is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Address").finish_non_exhaustive()
    }
}

impl From<String> for Address {
    fn from(s: String) -> Self {
        Self::new(s)
    }
}

impl From<&'static str> for Address {
    fn from(s: &'static str) -> Self {
        Self::new(String::from(s))
    }
}

impl From<SocketAddr> for Address {
    fn from(addr: SocketAddr) -> Self {
        Self::new(addr)
    }
}
/// TLS settings recorded by a [`ProxyServerBuilder`] and applied when it binds.
#[derive(Debug, Clone)]
pub enum TlsConfig {
    /// Read a PEM certificate chain and private key from these file paths.
    Files { cert_path: String, key_path: String },
    /// Generate a self-signed certificate for `localhost` when the server is built.
    SelfSigned,
}

/// TLS choice for [`run`], borrowing the certificate and key paths from the caller.
#[derive(Debug, Clone)]
pub enum TlsMode<'a> {
    /// Plain WebSocket, no TLS.
    None,
    /// TLS using the certificate (`cert`) and private key (`key`) at these paths.
    Files { cert: &'a str, key: &'a str },
    /// TLS using a certificate generated at startup.
    SelfSigned,
}
pub async fn run(
listen: &str,
routes: &[String],
default_target: Option<&str>,
tls: TlsMode<'_>,
) -> Result<()> {
let mut builder = ProxyServer::builder();
for r in routes {
let (path, target) = r.split_once('=').ok_or_else(|| {
Error::config(format!(
"Invalid route format '{}', expected 'path=target'",
r
))
})?;
builder = builder.route(path.to_string(), target.to_string());
}
if let Some(target) = default_target {
builder = builder.default_target(target.to_string());
}
let is_tls = !matches!(tls, TlsMode::None);
match tls {
TlsMode::None => {}
TlsMode::Files { cert, key } => {
builder = builder.tls(cert, key);
}
TlsMode::SelfSigned => {
builder = builder.tls_self_signed();
}
}
let server = builder.bind(listen)?;
if is_tls {
eprintln!("Proxy server listening on {} (WSS)", listen);
} else {
eprintln!("Proxy server listening on {}", listen);
}
server.run().await
}
/// Runs the proxy from a config file and keeps it running across config changes.
///
/// Each pass does the following:
///
/// 1. Loads and resolves `config_path`.
/// 2. Prepares TLS.
/// 3. Starts a [`ConfigWatcher`].
/// 4. Binds the listen address.
/// 5. Serves until the server loop reports one of two outcomes:
///    - a change that needs a full restart, which starts another pass with a fresh load;
///    - the end of the watcher's stream, which returns `Ok(())`.
///
/// # Errors
///
/// Returns an error from loading or resolving the config, building TLS, creating the
/// watcher, binding, or the accept loop. This also applies on restart passes, so an
/// invalid config written to disk stops the server.
pub async fn run_with_config(config_path: impl AsRef<Path>) -> Result<()> {
    let config_path = config_path.as_ref();
    loop {
        let config = ServerFileConfig::load(config_path)?;
        let resolved = ResolvedConfig::from_file_config(&config)?;
        let tls_acceptor = build_tls_acceptor(&config)?;
        let mut watcher = ConfigWatcher::new(config_path, config.clone())?;
        let listener = TcpListener::bind(&config.listen).await?;

        // Announce only after the bind succeeded, so the log never claims a socket we
        // failed to acquire (e.g. a port conflict on a restart pass).
        let mode = if config.has_tls() { " (WSS)" } else { "" };
        eprintln!(
            "Proxy server listening on {}{} - config: {}",
            config.listen,
            mode,
            config_path.display()
        );

        let shared_config = Arc::new(RwLock::new(resolved));
        let restart_needed =
            run_server_loop(listener, tls_acceptor, shared_config, &mut watcher).await?;
        if !restart_needed {
            return Ok(());
        }
        eprintln!("Configuration changed, restarting server...");
    }
}
fn build_tls_acceptor(config: &ServerFileConfig) -> Result<Option<tokio_rustls::TlsAcceptor>> {
if !config.has_tls() {
return Ok(None);
}
let (certs, key) = if config.tls.self_signed {
generate_self_signed_cert()?
} else if let (Some(cert), Some(key)) = (&config.tls.cert, &config.tls.key) {
load_certs_from_files(cert, key)?
} else {
return Err(Error::config("invalid TLS configuration"));
};
let tls_config = tokio_rustls::rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)
.map_err(|e| Error::config(format!("failed to create TLS config: {}", e)))?;
Ok(Some(tokio_rustls::TlsAcceptor::from(Arc::new(tls_config))))
}
/// Serves connections on `listener` until the config watcher requests a restart or stops.
///
/// - Every accepted connection is handled on its own task, after a TLS handshake when
///   `tls_acceptor` is set.
/// - Routing-only changes are applied in place by replacing the contents of
///   `shared_config`.
///
/// Returns `Ok(true)` when a full restart is required and `Ok(false)` when the watcher's
/// stream ends.
///
/// # Errors
///
/// Returns an error when `accept` fails with anything other than a per-connection error
/// (refused, aborted or reset). Per-connection errors are logged and skipped.
async fn run_server_loop(
    listener: TcpListener,
    tls_acceptor: Option<tokio_rustls::TlsAcceptor>,
    shared_config: Arc<RwLock<ResolvedConfig>>,
    watcher: &mut ConfigWatcher,
) -> Result<bool> {
    // Accept errors of these kinds concern only the peer being accepted (it went away
    // first); the listening socket itself is still usable, so they must not stop the server.
    fn is_connection_error(err: &std::io::Error) -> bool {
        matches!(
            err.kind(),
            std::io::ErrorKind::ConnectionRefused
                | std::io::ErrorKind::ConnectionAborted
                | std::io::ErrorKind::ConnectionReset
        )
    }

    loop {
        tokio::select! {
            accept_result = listener.accept() => {
                let (stream, peer_addr) = match accept_result {
                    Ok(accepted) => accepted,
                    Err(e) if is_connection_error(&e) => {
                        eprintln!("Failed to accept connection: {}", e);
                        continue;
                    }
                    Err(e) => return Err(e.into()),
                };
                let config = Arc::clone(&shared_config);
                let tls = tls_acceptor.clone();
                tokio::spawn(async move {
                    let result = if let Some(ref tls_acceptor) = tls {
                        match tls_acceptor.accept(stream).await {
                            Ok(tls_stream) => handle_ws_connection_shared(tls_stream, config).await,
                            Err(e) => {
                                eprintln!("TLS handshake failed from {}: {}", peer_addr, e);
                                return;
                            }
                        }
                    } else {
                        handle_ws_connection_shared(stream, config).await
                    };
                    if let Err(e) = result {
                        eprintln!("Error handling connection from {}: {}", peer_addr, e);
                    }
                });
            }
            change = watcher.recv() => {
                match change {
                    Some(ConfigChange::RoutingOnly(new_config)) => {
                        // A bad routing update is reported and the previous routes stay active.
                        match ResolvedConfig::from_file_config(&new_config) {
                            Ok(resolved) => {
                                let mut config = shared_config.write().await;
                                *config = resolved;
                                eprintln!("Configuration reloaded (routing updated)");
                            }
                            Err(e) => {
                                eprintln!("Failed to apply config: {}", e);
                            }
                        }
                    }
                    Some(ConfigChange::FullRestart(_)) => {
                        return Ok(true);
                    }
                    Some(ConfigChange::Error(e)) => {
                        eprintln!("Config reload error: {}", e);
                    }
                    None => {
                        return Ok(false);
                    }
                }
            }
        }
    }
}
/// Builder for [`ProxyServer`].
///
/// Collects routes, an optional default target and TLS settings. All of them are checked
/// and prepared in [`bind`](Self::bind).
#[derive(Debug, Clone, Default)]
pub struct ProxyServerBuilder {
    /// Request path -> upstream target.
    routes: HashMap<String, Address>,
    /// Upstream used when no route matches.
    default_target: Option<Address>,
    /// TLS settings; `None` serves plain WebSocket.
    tls_config: Option<TlsConfig>,
}

impl ProxyServerBuilder {
    /// Creates an empty builder: no routes, no default target, no TLS.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Routes WebSocket requests for `path` to `target`.
    ///
    /// `target` is not resolved here but when a connection is proxied. Registering the
    /// same `path` again replaces the earlier target.
    #[must_use]
    pub fn route<T>(mut self, path: impl Into<String>, target: T) -> Self
    where
        T: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        self.routes.insert(path.into(), Address::new(target));
        self
    }

    /// Sets the upstream used for request paths that match no route (resolved lazily).
    #[must_use]
    pub fn default_target<T>(mut self, target: T) -> Self
    where
        T: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        self.default_target = Some(Address::new(target));
        self
    }

    /// Enables TLS with a PEM certificate chain and private key read from disk in
    /// [`bind`](Self::bind). Replaces any earlier TLS setting.
    #[must_use]
    pub fn tls(mut self, cert_path: impl Into<String>, key_path: impl Into<String>) -> Self {
        self.tls_config = Some(TlsConfig::Files {
            cert_path: cert_path.into(),
            key_path: key_path.into(),
        });
        self
    }

    /// Enables TLS with a self-signed certificate generated in [`bind`](Self::bind).
    /// Replaces any earlier TLS setting.
    #[must_use]
    pub fn tls_self_signed(mut self) -> Self {
        self.tls_config = Some(TlsConfig::SelfSigned);
        self
    }

    /// Validates the configuration, prepares TLS and returns a server ready to run.
    ///
    /// An address is only recorded here; the socket is bound when the server runs.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    ///
    /// - neither a route nor a default target was configured;
    /// - `bindable` cannot be converted or resolved, or its address cannot be read;
    /// - the TLS certificate and key cannot be loaded or generated;
    /// - rustls rejects the certificate/key pair.
    pub fn bind(self, bindable: impl IntoBindable) -> Result<ProxyServer> {
        // Check the pure configuration error first; converting `bindable` may do a
        // blocking DNS lookup that is pointless if we are going to fail anyway.
        if self.routes.is_empty() && self.default_target.is_none() {
            return Err(Error::config(
                "at least one route or a default_target is required",
            ));
        }
        let bindable = bindable.into_bindable()?;
        let listen_addr = bindable.local_addr()?;

        let tls_acceptor = match &self.tls_config {
            None => None,
            Some(tls_config) => {
                let (certs, key) = match tls_config {
                    TlsConfig::Files {
                        cert_path,
                        key_path,
                    } => load_certs_from_files(cert_path, key_path)?,
                    TlsConfig::SelfSigned => generate_self_signed_cert()?,
                };
                let server_config = tokio_rustls::rustls::ServerConfig::builder()
                    .with_no_client_auth()
                    .with_single_cert(certs, key)
                    .map_err(|e| Error::config(format!("failed to create TLS config: {}", e)))?;
                Some(tokio_rustls::TlsAcceptor::from(Arc::new(server_config)))
            }
        };

        Ok(ProxyServer {
            bindable,
            inner: Arc::new(ProxyServerInner {
                listen_addr,
                routes: self.routes,
                default_target: self.default_target,
                tls_acceptor,
            }),
        })
    }
}
fn load_certs_from_files(
cert_path: &str,
key_path: &str,
) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
use std::io::BufReader;
let cert_file = std::fs::File::open(cert_path).map_err(|e| {
Error::config(format!(
"failed to open TLS certificate '{}': {}",
cert_path, e
))
})?;
let key_file = std::fs::File::open(key_path)
.map_err(|e| Error::config(format!("failed to open TLS key '{}': {}", key_path, e)))?;
let certs: Vec<_> = rustls_pemfile::certs(&mut BufReader::new(cert_file))
.collect::<std::result::Result<_, _>>()
.map_err(|e| Error::config(format!("failed to parse TLS certificate: {}", e)))?;
let key = rustls_pemfile::private_key(&mut BufReader::new(key_file))
.map_err(|e| Error::config(format!("failed to parse TLS key: {}", e)))?
.ok_or_else(|| Error::config("no private key found in key file"))?;
Ok((certs, key))
}
/// Generates a fresh self-signed certificate and its private key.
///
/// The certificate covers `localhost` (common name and DNS SAN) and `127.0.0.1`
/// (IP SAN). It is limited to server authentication with the digital-signature key usage.
///
/// # Errors
///
/// Returns a config error if any of these steps fails: building the SAN, generating the
/// key pair, signing the certificate, or converting the key to DER.
fn generate_self_signed_cert() -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
    use rcgen::{CertificateParams, DnType, ExtendedKeyUsagePurpose, KeyUsagePurpose, SanType};
    use std::net::{IpAddr, Ipv4Addr};

    let mut params = CertificateParams::default();
    params.distinguished_name.push(DnType::CommonName, "localhost");
    params.subject_alt_names = vec![
        SanType::DnsName(
            "localhost"
                .try_into()
                .map_err(|e| Error::config(format!("failed to create SAN: {}", e)))?,
        ),
        SanType::IpAddress(IpAddr::V4(Ipv4Addr::LOCALHOST)),
    ];
    params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
    params.key_usages = vec![KeyUsagePurpose::DigitalSignature];

    let key_pair = rcgen::KeyPair::generate()
        .map_err(|e| Error::config(format!("failed to generate key pair: {}", e)))?;
    let certificate = params
        .self_signed(&key_pair)
        .map_err(|e| Error::config(format!("failed to generate self-signed certificate: {}", e)))?;

    let chain = vec![CertificateDer::from(certificate.der().to_vec())];
    let private_key = PrivateKeyDer::try_from(key_pair.serialize_der())
        .map_err(|e| Error::config(format!("failed to serialize private key: {}", e)))?;
    Ok((chain, private_key))
}
/// State shared (read-only) by every connection task of a [`ProxyServer`].
struct ProxyServerInner {
    /// Address reported by [`ProxyServer::local_addr`], captured when the server was built.
    listen_addr: SocketAddr,
    /// Routing table keyed by request path.
    routes: HashMap<String, Address>,
    /// Upstream used when no route matches.
    default_target: Option<Address>,
    /// Present when each connection must complete a TLS handshake first.
    tls_acceptor: Option<tokio_rustls::TlsAcceptor>,
}

/// A WebSocket-to-TCP proxy produced by [`ProxyServerBuilder::bind`].
pub struct ProxyServer {
    bindable: Bindable,
    inner: Arc<ProxyServerInner>,
}

impl ProxyServer {
    /// Starts building a server; same as [`ProxyServerBuilder::new`].
    pub fn builder() -> ProxyServerBuilder {
        ProxyServerBuilder::new()
    }

    /// Returns the listen address captured when the server was built.
    ///
    /// When the server was given an address rather than a listener, the socket is only
    /// bound in [`run`](Self::run), so this is the address as given: a port of `0` is
    /// reported as `0`.
    pub fn local_addr(&self) -> SocketAddr {
        self.inner.listen_addr
    }

    /// Binds the address (if needed) and serves connections until a fatal error.
    ///
    /// Each accepted connection is proxied on its own task. Failures of individual
    /// connections are logged to stderr and do not stop the server.
    ///
    /// # Errors
    ///
    /// Returns an error in two cases:
    ///
    /// - binding the address fails;
    /// - `accept` fails with something other than a per-connection error (refused,
    ///   aborted or reset).
    pub async fn run(self) -> Result<()> {
        // Accept errors of these kinds concern only the peer being accepted (it went away
        // first); the listening socket itself is still usable, so they must not stop the server.
        fn is_connection_error(err: &std::io::Error) -> bool {
            matches!(
                err.kind(),
                std::io::ErrorKind::ConnectionRefused
                    | std::io::ErrorKind::ConnectionAborted
                    | std::io::ErrorKind::ConnectionReset
            )
        }

        let listener = match self.bindable {
            Bindable::Address(addr) => TcpListener::bind(addr).await?,
            Bindable::Listener(listener) => listener,
        };
        loop {
            let (stream, peer_addr) = match listener.accept().await {
                Ok(accepted) => accepted,
                Err(e) if is_connection_error(&e) => {
                    eprintln!("Failed to accept connection: {}", e);
                    continue;
                }
                Err(e) => return Err(e.into()),
            };
            let inner = Arc::clone(&self.inner);
            tokio::spawn(async move {
                let result = if let Some(ref tls_acceptor) = inner.tls_acceptor {
                    match tls_acceptor.accept(stream).await {
                        Ok(tls_stream) => handle_ws_connection(tls_stream, &inner).await,
                        Err(e) => {
                            eprintln!("TLS handshake failed from {}: {}", peer_addr, e);
                            return;
                        }
                    }
                } else {
                    handle_ws_connection(stream, &inner).await
                };
                if let Err(e) = result {
                    eprintln!("Error handling connection from {}: {}", peer_addr, e);
                }
            });
        }
    }
}
async fn handle_ws_connection<S>(stream: S, inner: &ProxyServerInner) -> Result<()>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let path = Arc::new(std::sync::Mutex::new(String::new()));
let path_clone = Arc::clone(&path);
#[allow(clippy::result_large_err)] let callback = move |req: &Request, response: Response| {
let uri_path = req.uri().path().to_string();
*path_clone.lock().unwrap() = uri_path;
Ok(response)
};
let ws_stream = tokio_tungstenite::accept_hdr_async(stream, callback).await?;
let request_path = path.lock().unwrap().clone();
let target = inner
.routes
.get(&request_path)
.or_else(|| {
let normalized = request_path.trim_end_matches('/');
inner.routes.get(normalized)
})
.or(inner.default_target.as_ref())
.ok_or_else(|| Error::no_route_found(request_path.clone()))?;
let (mut ws_write, mut ws_read) = ws_stream.split();
let target_addr = target.resolve()?;
let tcp_stream = TcpStream::connect(target_addr).await?;
let (mut tcp_read, mut tcp_write) = tcp_stream.into_split();
let ws_to_tcp = async {
while let Some(msg) = ws_read.next().await {
match msg {
Ok(Message::Binary(data)) => {
tcp_write.write_all(&data).await?;
}
Ok(Message::Text(text)) => {
tcp_write.write_all(text.as_bytes()).await?;
}
Ok(Message::Close(_)) => {
break;
}
Ok(Message::Ping(data)) => {
let _ = data;
}
Ok(Message::Pong(_)) => {
}
Ok(Message::Frame(_)) => {
}
Err(e) => {
return Err(e.into());
}
}
}
Ok::<_, Error>(())
};
let tcp_to_ws = async {
let mut buf = vec![0u8; 8192];
loop {
let n = tcp_read.read(&mut buf).await?;
if n == 0 {
break;
}
ws_write
.send(Message::Binary(buf[..n].to_vec().into()))
.await?;
}
Ok::<_, Error>(())
};
tokio::select! {
result = ws_to_tcp => result?,
result = tcp_to_ws => result?,
}
Ok(())
}
async fn handle_ws_connection_shared<S>(
stream: S,
config: Arc<RwLock<ResolvedConfig>>,
) -> Result<()>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let path = Arc::new(std::sync::Mutex::new(String::new()));
let path_clone = Arc::clone(&path);
#[allow(clippy::result_large_err)]
let callback = move |req: &Request, response: Response| {
let uri_path = req.uri().path().to_string();
*path_clone.lock().unwrap() = uri_path;
Ok(response)
};
let ws_stream = tokio_tungstenite::accept_hdr_async(stream, callback).await?;
let request_path = path.lock().unwrap().clone();
let target = {
let cfg = config.read().await;
cfg.routes
.get(&request_path)
.or_else(|| {
let normalized = request_path.trim_end_matches('/');
cfg.routes.get(normalized)
})
.or(cfg.default_target.as_ref())
.cloned()
.ok_or_else(|| Error::no_route_found(request_path.clone()))?
};
let (mut ws_write, mut ws_read) = ws_stream.split();
let target_addr = target
.to_socket_addrs()?
.next()
.ok_or_else(|| Error::config("could not resolve address"))?;
let tcp_stream = TcpStream::connect(target_addr).await?;
let (mut tcp_read, mut tcp_write) = tcp_stream.into_split();
let ws_to_tcp = async {
while let Some(msg) = ws_read.next().await {
match msg {
Ok(Message::Binary(data)) => {
tcp_write.write_all(&data).await?;
}
Ok(Message::Text(text)) => {
tcp_write.write_all(text.as_bytes()).await?;
}
Ok(Message::Close(_)) => {
break;
}
Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_)) => {}
Err(e) => {
return Err(e.into());
}
}
}
Ok::<_, Error>(())
};
let tcp_to_ws = async {
let mut buf = vec![0u8; 8192];
loop {
let n = tcp_read.read(&mut buf).await?;
if n == 0 {
break;
}
ws_write
.send(Message::Binary(buf[..n].to_vec().into()))
.await?;
}
Ok::<_, Error>(())
};
tokio::select! {
result = ws_to_tcp => result?,
result = tcp_to_ws => result?,
}
Ok(())
}