use std::net::SocketAddr;
use tokio::io::copy_bidirectional;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::watch;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
use crate::config::ProxyEntry;
use crate::safety::parse_backend;
use crate::tls::{self, CertSource};
/// Lifecycle phase of a proxy listener task, as published on its status channel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ListenerState {
    /// Initial phase, published before TLS setup and socket bind have finished.
    Starting,
    /// The socket is bound and connections are being accepted.
    Running,
    /// Setup failed; the accompanying status carries a message.
    Error,
    /// The listener was cancelled and its accept loop has exited.
    Stopped,
}

impl ListenerState {
    /// Returns the lowercase name of this state, e.g. `"running"`.
    pub fn as_str(self) -> &'static str {
        use ListenerState::*;
        let label = match self {
            Starting => "starting",
            Running => "running",
            Error => "error",
            Stopped => "stopped",
        };
        label
    }
}
/// Snapshot of a listener's health, broadcast to observers through a `watch` channel.
#[derive(Debug, Clone)]
pub struct ListenerStatus {
    /// Current lifecycle phase.
    pub state: ListenerState,
    /// Human-readable failure description; the constructors in this module only
    /// populate it for [`ListenerState::Error`].
    pub error: Option<String>,
    /// Origin of the served certificate. Before TLS setup has succeeded this is a
    /// `CertSource::SelfSigned` placeholder rather than an observed value.
    pub cert_source: CertSource,
}

impl ListenerStatus {
    /// Status published before the listener task has done any work.
    fn starting() -> Self {
        let cert_source = CertSource::SelfSigned;
        Self {
            state: ListenerState::Starting,
            error: None,
            cert_source,
        }
    }

    /// Failure status carrying `message` and the best-known certificate source.
    fn error(message: String, cert_source: CertSource) -> Self {
        let error = Some(message);
        Self {
            state: ListenerState::Error,
            error,
            cert_source,
        }
    }
}
/// Starts a background task that serves `entry` until `cancel` fires.
///
/// Returns immediately with a receiver whose initial value is the `Starting`
/// status. The task then publishes `Running`, `Error` or `Stopped` as it
/// progresses. The task's `JoinHandle` is dropped, so the task runs detached and
/// nothing here observes its completion or a panic inside it.
pub fn spawn_listener(
    entry: ProxyEntry,
    cancel: CancellationToken,
) -> watch::Receiver<ListenerStatus> {
    let (status_tx, status_rx) = watch::channel(ListenerStatus::starting());
    // An `async fn` future does nothing until polled, so building it here and
    // handing it to the runtime is equivalent to wrapping it in `async move`.
    let task = run_listener(entry, cancel, status_tx);
    tokio::spawn(task);
    status_rx
}
/// Body of the task started by `spawn_listener`.
///
/// Builds the TLS configuration, starts the certificate watcher and binds
/// `0.0.0.0:<listen_port>`. It then accepts connections until `cancel` fires,
/// handing each one to its own detached `handle_conn` task. Lifecycle changes
/// are published on `tx`:
///
/// * TLS setup or bind failure → `Error`, and the function returns immediately;
/// * successful bind → `Running`;
/// * cancellation → `Stopped`.
///
/// Send failures on `tx` (every receiver dropped) are deliberately ignored. The
/// listener keeps serving even if nobody is watching its status.
async fn run_listener(
    entry: ProxyEntry,
    cancel: CancellationToken,
    tx: watch::Sender<ListenerStatus>,
) {
    // Pause applied after a failed `accept()`. Errors such as running out of file
    // descriptors tend to persist. Retrying immediately would busy-spin this task
    // and flood the log with one warning per iteration.
    const ACCEPT_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_millis(100);

    let setup = match tls::build_tls(&entry) {
        Ok(setup) => setup,
        Err(e) => {
            // TLS setup failed before a certificate source was known, so
            // `SelfSigned` here is only a placeholder.
            let _ = tx.send(ListenerStatus::error(
                format!("tls setup: {e}"),
                CertSource::SelfSigned,
            ));
            tracing::warn!(name = %entry.name, error = %e, "Proxy TLS setup failed");
            return;
        }
    };
    let cert_source = setup.cert_source;
    let acceptor = TlsAcceptor::from(setup.config);
    // Bound to a named variable (not `_`) so the returned value stays alive until
    // this function returns. NOTE(review): presumably the watcher reloads
    // certificates through `setup.resolver` — confirm in the `tls` module.
    let _watcher = tls::spawn_cert_watcher(entry.clone(), setup.resolver, cancel.clone());

    let addr = SocketAddr::from(([0, 0, 0, 0], entry.listen_port));
    let listener = match TcpListener::bind(addr).await {
        Ok(listener) => listener,
        Err(e) => {
            let _ = tx.send(ListenerStatus::error(bind_error_message(&e), cert_source));
            tracing::warn!(
                name = %entry.name, port = entry.listen_port, error = %e,
                "Proxy listener bind failed"
            );
            return;
        }
    };

    let _ = tx.send(ListenerStatus {
        state: ListenerState::Running,
        error: None,
        cert_source,
    });
    tracing::info!(
        name = %entry.name, port = entry.listen_port, backend = %entry.backend,
        cert = cert_source.as_str(), "Proxy listener running"
    );

    loop {
        tokio::select! {
            _ = cancel.cancelled() => break,
            accept = listener.accept() => match accept {
                Ok((tcp, peer)) => {
                    let acceptor = acceptor.clone();
                    let backend = entry.backend.clone();
                    let name = entry.name.clone();
                    // NOTE: connection tasks are detached and are not tied to
                    // `cancel`. In-flight connections keep running after this
                    // listener reports `Stopped`.
                    tokio::spawn(async move {
                        handle_conn(acceptor, tcp, peer, &backend, &name).await;
                    });
                }
                Err(e) => {
                    tracing::warn!(name = %entry.name, error = %e, "Proxy accept error");
                    // Back off before retrying, but stay responsive to shutdown.
                    // `break` exits the enclosing accept loop.
                    tokio::select! {
                        _ = cancel.cancelled() => break,
                        _ = tokio::time::sleep(ACCEPT_ERROR_BACKOFF) => {}
                    }
                }
            }
        }
    }

    let _ = tx.send(ListenerStatus {
        state: ListenerState::Stopped,
        error: None,
        cert_source,
    });
}
async fn handle_conn(
acceptor: TlsAcceptor,
tcp: TcpStream,
peer: SocketAddr,
backend: &str,
name: &str,
) {
let mut tls = match acceptor.accept(tcp).await {
Ok(stream) => stream,
Err(e) => {
tracing::debug!(name, %peer, error = %e, "Proxy TLS handshake failed");
return;
}
};
let (host, port) = match parse_backend(backend) {
Ok(hostport) => hostport,
Err(e) => {
tracing::warn!(name, backend, error = %e, "Proxy backend parse failed");
return;
}
};
let mut upstream = match TcpStream::connect((host.as_str(), port)).await {
Ok(stream) => stream,
Err(e) => {
tracing::warn!(name, backend, error = %e, "Proxy backend connect failed");
return;
}
};
if let Err(e) = copy_bidirectional(&mut tls, &mut upstream).await {
tracing::debug!(name, %peer, error = %e, "Proxy passthrough ended");
}
}
/// Converts a bind failure into a short, user-facing status message.
///
/// The common bind failures (address in use, permission denied, address not
/// available) get a fixed phrase. Any other error falls back to the error's own
/// `Display` text.
fn bind_error_message(e: &std::io::Error) -> String {
    use std::io::ErrorKind;

    let friendly = match e.kind() {
        ErrorKind::AddrInUse => "address in use",
        ErrorKind::PermissionDenied => "permission denied",
        ErrorKind::AddrNotAvailable => "address not available",
        _ => return e.to_string(),
    };
    friendly.to_owned()
}