use hyper::body::Incoming;
use hyper::service::service_fn;
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
use hyper_util::rt::TokioIo;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::sync::{RwLock, Semaphore};
use tracing::{info, warn};
use crate::config::Config;
use crate::proxy::handler::proxy;
/// Proxy server state: the shared configuration, the outbound client and the
/// connection-concurrency limiter.
pub struct Proxy {
    /// Live configuration, shared behind an async read/write lock so it can
    /// be swapped while the server is running.
    config: Arc<RwLock<Config>>,
    /// Pooled HTTP(S) client that is cloned and handed to the request handler.
    client: Client<HttpsConnector<HttpConnector>, Incoming>,
    /// Number of connections that may be served at the same time.
    max_concurrency: usize,
    /// Permit pool created with `max_concurrency` permits; one permit is held
    /// for the lifetime of each served connection.
    semaphore: Arc<Semaphore>,
}
impl Proxy {
pub fn new(config: Config) -> Self {
let mut http = HttpConnector::new();
http.set_keepalive(Some(Duration::from_secs(60)));
http.set_nodelay(true);
let https = HttpsConnectorBuilder::new()
.with_native_roots()
.unwrap()
.https_or_http()
.enable_http1()
.wrap_connector(http);
let client = Client::builder(TokioExecutor::new())
.pool_max_idle_per_host(100)
.pool_idle_timeout(Duration::from_secs(90))
.build::<_, Incoming>(https);
let max_concurrency = std::env::var("TINY_PROXY_MAX_CONCURRENCY")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or_else(|| num_cpus::get() * 256);
let semaphore = Arc::new(Semaphore::new(max_concurrency));
info!(
"Proxy initialized with max_concurrency={} (default: {})",
max_concurrency,
num_cpus::get() * 256
);
Self {
config: Arc::new(RwLock::new(config)),
client,
max_concurrency,
semaphore,
}
}
pub fn from_shared(config: Arc<RwLock<Config>>) -> Self {
let mut http = HttpConnector::new();
http.set_keepalive(Some(Duration::from_secs(60)));
http.set_nodelay(true);
let https = HttpsConnectorBuilder::new()
.with_native_roots()
.unwrap()
.https_or_http()
.enable_http1()
.wrap_connector(http);
let client = Client::builder(TokioExecutor::new())
.pool_max_idle_per_host(100)
.pool_idle_timeout(Duration::from_secs(90))
.build::<_, Incoming>(https);
let max_concurrency = std::env::var("TINY_PROXY_MAX_CONCURRENCY")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or_else(|| num_cpus::get() * 256);
let semaphore = Arc::new(Semaphore::new(max_concurrency));
info!(
"Proxy initialized with max_concurrency={} (default: {})",
max_concurrency,
num_cpus::get() * 256
);
Self {
config,
client,
max_concurrency,
semaphore,
}
}
pub async fn start(&self, addr: &str) -> anyhow::Result<()> {
let addr: SocketAddr = addr.parse()?;
self.start_with_addr(addr).await
}
pub async fn start_with_addr(&self, addr: SocketAddr) -> anyhow::Result<()> {
let listener = TcpListener::bind(&addr).await?;
info!("Tiny Proxy listening on http://{}", addr);
info!(
"Max concurrency: {} ({})",
self.max_concurrency,
if self.max_concurrency == num_cpus::get() * 256 {
"default"
} else {
"custom"
}
);
loop {
let (stream, remote_addr) = listener.accept().await?;
let io = TokioIo::new(stream);
let client = self.client.clone();
let config = self.config.clone();
let semaphore = self.semaphore.clone();
match semaphore.try_acquire_owned() {
Ok(permit) => {
tokio::task::spawn(async move {
let _permit = permit;
let service = service_fn(move |req| {
let client = client.clone();
let config = config.clone();
let config_clone = config.clone();
async move {
let config_guard = config_clone.read().await;
let config_snapshot = Arc::new(config_guard.clone());
drop(config_guard);
proxy(req, client, config_snapshot, remote_addr).await
}
});
let mut builder = hyper::server::conn::http1::Builder::new();
builder.keep_alive(true).pipeline_flush(false);
builder.serve_connection(io, service).await
});
}
Err(_) => {
warn!(
"Concurrency limit exceeded ({}), rejecting connection",
self.max_concurrency
);
}
}
}
}
pub fn shared_config(&self) -> Arc<RwLock<Config>> {
self.config.clone()
}
pub async fn config_snapshot(&self) -> Config {
self.config.read().await.clone()
}
pub fn max_concurrency(&self) -> usize {
self.max_concurrency
}
pub fn set_max_concurrency(&mut self, max: usize) {
self.max_concurrency = max;
self.semaphore = Arc::new(Semaphore::new(max));
info!("Max concurrency updated to {}", max);
}
pub async fn update_config(&self, config: Config) {
let mut guard = self.config.write().await;
info!("Configuration updated ({} sites)", config.sites.len());
*guard = config;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a configuration that contains no sites.
    fn empty_config() -> Config {
        Config {
            sites: HashMap::new(),
        }
    }

    /// Builds a site entry for `address` with no directives.
    fn site(address: &str) -> crate::config::SiteConfig {
        crate::config::SiteConfig {
            address: address.to_string(),
            directives: vec![],
        }
    }

    /// A proxy built from an empty configuration exposes an empty snapshot.
    #[test]
    fn test_proxy_creation() {
        let proxy = Proxy::new(empty_config());
        let rt = tokio::runtime::Runtime::new().unwrap();
        let snapshot = rt.block_on(proxy.config_snapshot());
        assert_eq!(snapshot.sites.len(), 0);
    }

    /// Sites present at construction time are visible through the snapshot.
    #[tokio::test]
    async fn test_config_access() {
        let mut config = empty_config();
        config
            .sites
            .insert("localhost:8080".to_string(), site("localhost:8080"));

        let proxy = Proxy::new(config);
        let snapshot = proxy.config_snapshot().await;
        assert_eq!(snapshot.sites.len(), 1);
        assert!(snapshot.sites.contains_key("localhost:8080"));
    }

    /// `update_config` replaces the configuration seen by later snapshots.
    #[tokio::test]
    async fn test_config_update() {
        let proxy = Proxy::new(empty_config());
        let before = proxy.config_snapshot().await;
        assert_eq!(before.sites.len(), 0);

        let mut replacement = empty_config();
        replacement
            .sites
            .insert("test.local".to_string(), site("test.local"));
        proxy.update_config(replacement).await;

        let after = proxy.config_snapshot().await;
        assert_eq!(after.sites.len(), 1);
        assert!(after.sites.contains_key("test.local"));
    }

    /// Writes made through the `shared_config` handle are seen by the proxy.
    #[tokio::test]
    async fn test_shared_config_handle() {
        let proxy = Proxy::new(empty_config());
        let handle = proxy.shared_config();
        // The write guard is a temporary and is released at the end of the
        // statement, before the snapshot is taken.
        handle
            .write()
            .await
            .sites
            .insert("shared.local".to_string(), site("shared.local"));

        let snapshot = proxy.config_snapshot().await;
        assert_eq!(snapshot.sites.len(), 1);
        assert!(snapshot.sites.contains_key("shared.local"));
    }

    /// A proxy built with `from_shared` observes writes made by the caller.
    #[test]
    fn test_from_shared() {
        let shared = Arc::new(RwLock::new(empty_config()));
        let proxy = Proxy::from_shared(shared.clone());
        let rt = tokio::runtime::Runtime::new().unwrap();
        // The write guard is a temporary and is released at the end of the
        // statement, before the snapshot is taken.
        rt.block_on(shared.write())
            .sites
            .insert("from-shared.local".to_string(), site("from-shared.local"));

        let snapshot = rt.block_on(proxy.config_snapshot());
        assert_eq!(snapshot.sites.len(), 1);
        assert!(snapshot.sites.contains_key("from-shared.local"));
    }
}