use crate::{
ratelimit::RateLimit,
server_tls::ServerTls,
tls::{Tls, parse_url},
};
use clap::{Parser, ValueEnum};
use std::{fmt::Debug, time::Duration};
use trillium::{Conn, Method, Status};
use trillium_cache::{InMemoryStorage, client::Cache};
use trillium_logger::{
Logger,
client::{ClientLogger, dev_formatter as client_dev_formatter},
dev_formatter,
};
use trillium_proxy::{
ForwardProxyConnect, Proxy, Url,
upstream::{
ConnectionCounting, ForwardProxy, IntoUpstreamSelector, RandomSelector, RoundRobin,
UpstreamSelector,
},
};
use trillium_smol::ClientConfig;
#[derive(Clone, Copy, Debug, ValueEnum, Default, PartialEq, Eq)]
enum UpstreamSelectorStrategy {
#[default]
RoundRobin,
ConnectionCounting,
Random,
Forward,
}
#[derive(Parser, Debug)]
pub struct ProxyCli {
#[arg(env, value_parser = parse_url)]
upstream: Vec<Url>,
#[arg(short, long, env, default_value_t, value_enum)]
strategy: UpstreamSelectorStrategy,
#[arg(short = 'o', long, env, default_value = "localhost")]
host: String,
#[arg(short, long, env, default_value = "8080")]
port: u16,
#[command(flatten)]
server_tls: ServerTls,
#[arg(short, long, verbatim_doc_comment, default_value_t, value_enum)]
client_tls: Tls,
#[arg(short = 'k', long, verbatim_doc_comment)]
insecure: bool,
#[arg(long)]
no_compress: bool,
#[arg(long, help_heading = "Cache")]
no_cache: bool,
#[arg(long, value_parser = parse_size, default_value = "256MiB", conflicts_with = "no_cache", help_heading = "Cache")]
cache_capacity: u64,
#[arg(long, value_parser = parse_size, default_value = "16MiB", conflicts_with = "no_cache", help_heading = "Cache")]
cache_max_body: u64,
#[arg(long, value_parser = humantime::parse_duration, conflicts_with = "no_cache", help_heading = "Cache")]
cache_time_to_idle: Option<Duration>,
#[arg(long, value_parser = humantime::parse_duration, conflicts_with = "no_cache", help_heading = "Cache")]
cache_time_to_live: Option<Duration>,
#[command(flatten)]
rate_limit: RateLimit,
#[command(flatten)]
verbose: clap_verbosity_flag::Verbosity,
}
fn parse_size(s: &str) -> Result<u64, String> {
let size = size::Size::from_str(s).map_err(|e| e.to_string())?;
u64::try_from(size.bytes()).map_err(|_| "size must not be negative".to_string())
}
impl ProxyCli {
pub fn build_upstream(&self) -> Box<dyn UpstreamSelector> {
if self.strategy == UpstreamSelectorStrategy::Forward {
if !self.upstream.is_empty() {
panic!("forward proxy does not take upstreams");
} else {
println!("Running in forward proxy mode");
}
} else if self.upstream.is_empty() {
panic!("upstream required unless --strategy forward is provided");
} else if self.upstream.len() == 1 {
let upstream = self.upstream[0].clone().into_upstream();
println!("Proxying to {upstream}");
return upstream.boxed();
} else {
println!(
"Forwarding to {} with strategy {}",
self.upstream
.iter()
.map(|u| u.as_str())
.collect::<Vec<_>>()
.join(", "),
self.strategy.to_possible_value().unwrap().get_name()
);
}
match self.strategy {
UpstreamSelectorStrategy::RoundRobin => RoundRobin::new(self.upstream.clone()).boxed(),
UpstreamSelectorStrategy::ConnectionCounting => {
ConnectionCounting::new(self.upstream.clone()).boxed()
}
UpstreamSelectorStrategy::Random => RandomSelector::new(self.upstream.clone()).boxed(),
UpstreamSelectorStrategy::Forward => ForwardProxy.boxed(),
}
}
pub fn run(self) {
env_logger::Builder::new()
.filter_level(self.verbose.log_level_filter())
.init();
let cache = (!self.no_cache).then(|| {
let mut storage = InMemoryStorage::new().with_max_capacity_bytes(self.cache_capacity);
if let Some(time_to_idle) = self.cache_time_to_idle {
storage = storage.with_time_to_idle(time_to_idle);
}
if let Some(time_to_live) = self.cache_time_to_live {
storage = storage.with_time_to_live(time_to_live);
}
Cache::new(storage)
.with_max_cacheable_size(self.cache_max_body)
.shared()
});
let client = crate::tls::build_client(self.client_tls, self.insecure).with_handler((
ClientLogger::new().with_formatter(("-> ", client_dev_formatter)),
cache,
));
let server = (
Logger::new().with_formatter(("<- ", dev_formatter)),
self.rate_limit.limiter(),
(!self.no_compress).then(trillium_compression::compression),
trillium_caching_headers::caching_headers(),
if self.strategy == UpstreamSelectorStrategy::Forward {
Some((
ForwardProxyConnect::new(ClientConfig::default()),
|conn: Conn| async move {
if conn.status() == Some(Status::Ok) && conn.method() == Method::Connect {
conn.halt()
} else {
conn
}
},
))
} else {
None
},
Proxy::new(client, self.build_upstream())
.with_via_pseudonym("trillium-proxy")
.with_websocket_upgrades()
.proxy_not_found(),
);
let config = trillium_smol::config()
.with_port(self.port)
.with_host(&self.host);
self.server_tls.run_with_tls(config, server);
}
}