use hyper::body::Incoming;
use hyper::service::service_fn;
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
use hyper_util::rt::TokioIo;
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::sync::{RwLock, Semaphore};
use tracing::{error, info, warn};
#[cfg(feature = "tls")]
use crate::proxy::tls::{build_tls_acceptor, listen_http_redirect, listen_tls};
use crate::config::{extract_hostname, resolve_listen_addr, tls_redirect_port, Config};
use crate::proxy::handler::proxy;
/// Front end that accepts downstream HTTP/HTTPS connections and hands each
/// request to `crate::proxy::handler::proxy` together with a shared upstream client.
pub struct Proxy {
    /// Live, shared configuration. Requests clone a snapshot of it when they
    /// arrive, so updates reach new requests without restarting listeners.
    config: Arc<RwLock<Config>>,
    /// Pooled upstream client shared by every listener.
    client: Client<HttpsConnector<HttpConnector>, Incoming>,
    /// Number of permits `semaphore` was created with; used for logging and
    /// reported by `max_concurrency()`.
    max_concurrency: usize,
    /// Connection limiter handed to every listener when it starts.
    semaphore: Arc<Semaphore>,
}
impl Proxy {
pub fn new(config: Config) -> Self {
let mut http = HttpConnector::new();
http.set_keepalive(Some(Duration::from_secs(60)));
http.set_nodelay(true);
let https = HttpsConnectorBuilder::new()
.with_native_roots()
.expect("Failed to load native TLS root certificates")
.https_or_http()
.enable_http1()
.wrap_connector(http);
let client = Client::builder(TokioExecutor::new())
.pool_max_idle_per_host(100)
.pool_idle_timeout(Duration::from_secs(90))
.build::<_, Incoming>(https);
let max_concurrency = std::env::var("TINY_PROXY_MAX_CONCURRENCY")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or_else(|| num_cpus::get() * 256);
let semaphore = Arc::new(Semaphore::new(max_concurrency));
info!(
"Proxy initialized with max_concurrency={} (default: {})",
max_concurrency,
num_cpus::get() * 256
);
Self {
config: Arc::new(RwLock::new(config)),
client,
max_concurrency,
semaphore,
}
}
pub fn from_shared(config: Arc<RwLock<Config>>) -> Self {
let mut http = HttpConnector::new();
http.set_keepalive(Some(Duration::from_secs(60)));
http.set_nodelay(true);
let https = HttpsConnectorBuilder::new()
.with_native_roots()
.expect("Failed to load native TLS root certificates")
.https_or_http()
.enable_http1()
.wrap_connector(http);
let client = Client::builder(TokioExecutor::new())
.pool_max_idle_per_host(100)
.pool_idle_timeout(Duration::from_secs(90))
.build::<_, Incoming>(https);
let max_concurrency = std::env::var("TINY_PROXY_MAX_CONCURRENCY")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or_else(|| num_cpus::get() * 256);
let semaphore = Arc::new(Semaphore::new(max_concurrency));
info!(
"Proxy initialized with max_concurrency={} (default: {})",
max_concurrency,
num_cpus::get() * 256
);
Self {
config,
client,
max_concurrency,
semaphore,
}
}
pub async fn start(&self, addr: &str) -> anyhow::Result<()> {
let addr: SocketAddr = addr.parse()?;
self.start_with_addr(addr).await
}
pub async fn start_with_addr(&self, addr: SocketAddr) -> anyhow::Result<()> {
let config_snapshot = self.config.read().await.clone();
let tls_sites: Vec<(String, crate::config::TlsConfig)> = config_snapshot
.sites
.values()
.filter(|site| {
site_addr_matches(&site.address, &addr) && site.tls.is_some()
})
.filter_map(|site| {
let hostname = extract_hostname(&site.address);
site.tls.clone().map(|tls| (hostname.to_string(), tls))
})
.collect();
if !tls_sites.is_empty() {
#[cfg(feature = "tls")]
{
self.start_tls(addr, tls_sites).await
}
#[cfg(not(feature = "tls"))]
{
anyhow::bail!(
"TLS configuration found for {} but 'tls' feature is disabled. \
Refusing to start as plain HTTP (security risk). \
Rebuild with --features tls or remove 'tls' from config.",
addr
);
}
} else {
self.start_http(addr).await
}
}
pub async fn start_all(&self) -> anyhow::Result<()> {
let config_snapshot = self.config.read().await.clone();
let mut socket_groups: HashMap<SocketAddr, Vec<&crate::config::SiteConfig>> =
HashMap::new();
for site in config_snapshot.sites.values() {
let listen_addr = resolve_listen_addr(&site.address)?;
socket_groups.entry(listen_addr).or_default().push(site);
}
let mut http_handles = Vec::new();
let mut tls_redirects: HashSet<(SocketAddr, u16)> = HashSet::new();
for (listen_addr, sites) in socket_groups {
let tls_sites: Vec<_> = sites.iter().copied().filter(|s| s.tls.is_some()).collect();
let has_tls = !tls_sites.is_empty();
let has_plain = tls_sites.len() != sites.len();
if has_tls && has_plain {
anyhow::bail!(
"Mixed TLS and non-TLS sites on the same listen address {} is not supported",
listen_addr
);
}
if has_tls {
#[cfg(feature = "tls")]
{
let tls_entries: Vec<(String, crate::config::TlsConfig)> = tls_sites
.iter()
.filter_map(|s| {
let hostname = extract_hostname(&s.address);
s.tls.clone().map(|tls| (hostname.to_string(), tls))
})
.collect();
let tls_port = listen_addr.port();
let client = self.client.clone();
let config = self.config.clone();
let semaphore = self.semaphore.clone();
let acceptor = build_tls_acceptor(&tls_entries, None)?;
info!(
"Starting HTTPS listener on {} ({} domain(s))",
listen_addr,
tls_entries.len()
);
let handle = tokio::spawn(async move {
if let Err(e) =
listen_tls(listen_addr, acceptor, semaphore, move |req, remote_addr| {
let client = client.clone();
let config = config.clone();
async move {
let config_guard = config.read().await;
let config_snapshot = Arc::new(config_guard.clone());
drop(config_guard);
proxy(req, client, config_snapshot, remote_addr, true).await
}
})
.await
{
error!("TLS listener error: {}", e);
}
});
http_handles.push(handle);
tls_redirects.insert((
SocketAddr::new(listen_addr.ip(), tls_redirect_port(tls_port)),
tls_port,
));
}
#[cfg(not(feature = "tls"))]
{
anyhow::bail!(
"TLS configuration found for {} but 'tls' feature is disabled. \
Refusing to start as plain HTTP (security risk). \
Rebuild with --features tls or remove 'tls' from config.",
listen_addr
);
}
} else {
let client = self.client.clone();
let config = self.config.clone();
let semaphore = self.semaphore.clone();
let max_concurrency = self.max_concurrency;
let handle = tokio::spawn(async move {
if let Err(e) =
Self::run_http_loop(listen_addr, client, config, semaphore, max_concurrency)
.await
{
error!("HTTP listener error: {}", e);
}
});
http_handles.push(handle);
}
}
#[cfg(feature = "tls")]
for (redirect_addr, tls_port) in tls_redirects {
info!(
"Starting HTTP→HTTPS redirect on http://{} → :{}",
redirect_addr, tls_port
);
let handle = tokio::spawn(async move {
match listen_http_redirect(redirect_addr, tls_port).await {
Ok(()) => {}
Err(e) => {
warn!(
"HTTP redirect on port {} failed (HTTPS on :{} still active): {}",
redirect_addr.port(),
tls_port,
e
);
}
}
});
http_handles.push(handle);
}
if http_handles.is_empty() {
warn!("No listeners configured — proxy has no sites");
return Ok(());
}
info!(
"Started {} listener(s), max concurrency: {} ({})",
http_handles.len(),
self.max_concurrency,
if self.max_concurrency == num_cpus::get() * 256 {
"default"
} else {
"custom"
}
);
for handle in http_handles {
if let Err(e) = handle.await {
error!("Listener task panicked: {}", e);
}
}
Ok(())
}
async fn start_http(&self, addr: SocketAddr) -> anyhow::Result<()> {
Self::run_http_loop(
addr,
self.client.clone(),
self.config.clone(),
self.semaphore.clone(),
self.max_concurrency,
)
.await
}
async fn run_http_loop(
addr: SocketAddr,
client: Client<HttpsConnector<HttpConnector>, Incoming>,
config: Arc<RwLock<Config>>,
semaphore: Arc<Semaphore>,
max_concurrency: usize,
) -> anyhow::Result<()> {
let listener = TcpListener::bind(&addr).await?;
info!("Tiny Proxy listening on http://{}", addr);
loop {
let (stream, remote_addr) = listener.accept().await?;
let io = TokioIo::new(stream);
let client = client.clone();
let config = config.clone();
let semaphore = semaphore.clone();
match semaphore.try_acquire_owned() {
Ok(permit) => {
tokio::task::spawn(async move {
let _permit = permit;
let service = service_fn(move |req| {
let client = client.clone();
let config = config.clone();
let config_clone = config.clone();
async move {
let config_guard = config_clone.read().await;
let config_snapshot = Arc::new(config_guard.clone());
drop(config_guard);
proxy(req, client, config_snapshot, remote_addr, false).await
}
});
let mut builder = hyper::server::conn::http1::Builder::new();
builder.keep_alive(true).pipeline_flush(false);
builder.serve_connection(io, service).await
});
}
Err(_) => {
warn!(
"Concurrency limit exceeded ({}), rejecting connection",
max_concurrency
);
}
}
}
}
#[cfg(feature = "tls")]
async fn start_tls(
&self,
addr: SocketAddr,
tls_sites: Vec<(String, crate::config::TlsConfig)>,
) -> anyhow::Result<()> {
let acceptor = build_tls_acceptor(&tls_sites, None)?;
info!(
"Starting HTTPS listener on https://{} ({} domain(s))",
addr,
tls_sites.len()
);
let client = self.client.clone();
let config = self.config.clone();
let semaphore = self.semaphore.clone();
listen_tls(addr, acceptor, semaphore, move |req, remote_addr| {
let client = client.clone();
let config = config.clone();
async move {
let config_guard = config.read().await;
let config_snapshot = Arc::new(config_guard.clone());
drop(config_guard);
proxy(req, client, config_snapshot, remote_addr, true).await
}
})
.await
}
pub fn shared_config(&self) -> Arc<RwLock<Config>> {
self.config.clone()
}
pub async fn config_snapshot(&self) -> Config {
self.config.read().await.clone()
}
pub fn max_concurrency(&self) -> usize {
self.max_concurrency
}
pub fn set_max_concurrency(&mut self, max: usize) {
self.max_concurrency = max;
self.semaphore = Arc::new(Semaphore::new(max));
info!("Max concurrency updated to {}", max);
}
pub async fn update_config(&self, config: Config) {
let mut guard = self.config.write().await;
info!("Configuration updated ({} sites)", config.sites.len());
*guard = config;
}
}
/// Returns `true` if a site configured with `site_address` (`host:port`) should
/// be served by a listener bound to `listen_addr`.
///
/// Matching rules:
/// - the port must parse and equal `listen_addr`'s port;
/// - an empty host, `0.0.0.0` or `::` matches any listen IP;
/// - `localhost` (any case) is treated as `127.0.0.1`;
/// - an IP literal must equal the listen IP. IPv6 may be written bracketed,
///   e.g. `[::1]:8080`;
/// - any other host is a DNS name, for which a port match suffices.
///
/// An address without a colon is treated as a bare port, e.g. `"8080"`.
fn site_addr_matches(site_address: &str, listen_addr: &SocketAddr) -> bool {
    // Split on the last colon so IPv6 hosts keep their internal colons.
    let (host, port) = site_address.rsplit_once(':').unwrap_or(("", site_address));
    if port.parse::<u16>().ok() != Some(listen_addr.port()) {
        return false;
    }
    // `[::1]` → `::1`. Without this, bracketed IPv6 literals failed to parse and
    // were treated as hostnames, matching every listen IP on the port.
    let host = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);
    if host.is_empty() || host == "0.0.0.0" || host == "::" {
        return true;
    }
    let site_ip = if host.eq_ignore_ascii_case("localhost") {
        std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
    } else {
        match host.parse::<std::net::IpAddr>() {
            Ok(ip) => ip,
            // A DNS name: it cannot be checked against the bound IP here.
            Err(_) => return true,
        }
    };
    site_ip == listen_addr.ip()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SiteConfig;
    use std::collections::HashMap;

    /// A configuration with no sites.
    fn empty_config() -> Config {
        Config {
            sites: HashMap::new(),
        }
    }

    /// Inserts a plain-HTTP site whose map key equals its address.
    fn add_plain_site(config: &mut Config, address: &str) {
        config.sites.insert(
            address.to_string(),
            SiteConfig {
                address: address.to_string(),
                directives: vec![],
                tls: None,
            },
        );
    }

    /// Parses a socket-address literal used as a listen address.
    fn socket(addr: &str) -> SocketAddr {
        addr.parse().unwrap()
    }

    #[test]
    fn test_proxy_creation() {
        let proxy = Proxy::new(empty_config());
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let snapshot = runtime.block_on(proxy.config_snapshot());
        assert!(snapshot.sites.is_empty());
    }

    #[tokio::test]
    async fn test_config_access() {
        let mut config = empty_config();
        add_plain_site(&mut config, "localhost:8080");
        let proxy = Proxy::new(config);

        let snapshot = proxy.config_snapshot().await;
        assert_eq!(snapshot.sites.len(), 1);
        assert!(snapshot.sites.contains_key("localhost:8080"));
    }

    #[tokio::test]
    async fn test_config_update() {
        let proxy = Proxy::new(empty_config());
        assert!(proxy.config_snapshot().await.sites.is_empty());

        let mut replacement = empty_config();
        add_plain_site(&mut replacement, "test.local");
        proxy.update_config(replacement).await;

        let snapshot = proxy.config_snapshot().await;
        assert_eq!(snapshot.sites.len(), 1);
        assert!(snapshot.sites.contains_key("test.local"));
    }

    #[tokio::test]
    async fn test_shared_config_handle() {
        let proxy = Proxy::new(empty_config());
        let handle = proxy.shared_config();
        // The write guard is a temporary, released at the end of this statement.
        add_plain_site(&mut *handle.write().await, "shared.local");

        let snapshot = proxy.config_snapshot().await;
        assert_eq!(snapshot.sites.len(), 1);
        assert!(snapshot.sites.contains_key("shared.local"));
    }

    #[test]
    fn test_from_shared() {
        let shared = Arc::new(RwLock::new(empty_config()));
        let proxy = Proxy::from_shared(Arc::clone(&shared));
        let runtime = tokio::runtime::Runtime::new().unwrap();
        add_plain_site(&mut *runtime.block_on(shared.write()), "from-shared.local");

        let snapshot = runtime.block_on(proxy.config_snapshot());
        assert_eq!(snapshot.sites.len(), 1);
        assert!(snapshot.sites.contains_key("from-shared.local"));
    }

    #[test]
    fn test_site_addr_matches_localhost() {
        assert!(site_addr_matches("localhost:8080", &socket("127.0.0.1:8080")));
    }

    #[test]
    fn test_site_addr_matches_ip() {
        assert!(site_addr_matches("0.0.0.0:443", &socket("0.0.0.0:443")));
    }

    #[test]
    fn test_site_addr_matches_hostname_by_port() {
        assert!(site_addr_matches("example.com:443", &socket("0.0.0.0:443")));
    }

    #[test]
    fn test_site_addr_matches_port_mismatch() {
        assert!(!site_addr_matches("example.com:8443", &socket("0.0.0.0:443")));
    }

    #[test]
    fn test_site_addr_matches_wildcard_host() {
        assert!(site_addr_matches(":9090", &socket("0.0.0.0:9090")));
    }
}