// The `native-tls` and `rustls` features are mutually exclusive (presumably alternative TLS
// backends for the HTTP clients); reject builds that enable both with an explicit message.
#[cfg(all(feature = "native-tls", feature = "rustls"))]
compile_error!("Features `native-tls` and `rustls` are mutually exclusive — enable only one.");
pub mod cache;
pub mod compression;
pub mod config;
pub mod control;
pub mod path_matcher;
pub mod proxy;
use axum::{extract::Extension, Router};
use cache::{CacheHandle, CacheStore};
use proxy::ProxyState;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::mpsc;
/// Selects which upstream responses are eligible for caching, based on their `Content-Type`.
///
/// Serialized in `snake_case` (e.g. `only_html`); the default is [`CacheStrategy::All`].
/// See [`CacheStrategy::allows_content_type`] for the exact matching rules.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CacheStrategy {
    /// Every content type is allowed, including a missing one.
    #[default]
    All,
    /// No content type is allowed.
    None,
    /// Only HTML documents (`text/html`, `application/xhtml+xml`).
    OnlyHtml,
    /// Everything except `image/*` (a missing content type is allowed).
    NoImages,
    /// Only `image/*`.
    OnlyImages,
    /// Only static assets: images, fonts, CSS, JavaScript, JSON, manifests, WebAssembly and XML.
    OnlyAssets,
}
impl CacheStrategy {
    /// Reports whether a response carrying `content_type` may be cached under this strategy.
    ///
    /// Only the media type before the first `;` is inspected. It is trimmed and ASCII-lowercased
    /// before comparison, so `"Text/HTML; charset=utf-8"` counts as `text/html`. When the
    /// content type is unknown (`None`), only [`CacheStrategy::All`] and
    /// [`CacheStrategy::NoImages`] allow caching.
    pub fn allows_content_type(&self, content_type: Option<&str>) -> bool {
        fn is_image(media_type: &str) -> bool {
            media_type.starts_with("image/")
        }

        // Drop any parameters (`; charset=...`) and normalize case/whitespace.
        let normalized = content_type.map(|header| {
            let essence = header.split_once(';').map_or(header, |(head, _)| head);
            essence.trim().to_ascii_lowercase()
        });

        let Some(media_type) = normalized.as_deref() else {
            return matches!(self, Self::All | Self::NoImages);
        };

        match self {
            Self::All => true,
            Self::None => false,
            Self::OnlyHtml => matches!(media_type, "text/html" | "application/xhtml+xml"),
            Self::NoImages => !is_image(media_type),
            Self::OnlyImages => is_image(media_type),
            Self::OnlyAssets => {
                is_image(media_type)
                    || media_type.starts_with("font/")
                    || matches!(
                        media_type,
                        "text/css"
                            | "text/javascript"
                            | "application/javascript"
                            | "application/x-javascript"
                            | "application/json"
                            | "application/manifest+json"
                            | "application/wasm"
                            | "application/xml"
                            | "text/xml"
                    )
            }
        }
    }
}
impl std::fmt::Display for CacheStrategy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let value = match self {
Self::All => "all",
Self::None => "none",
Self::OnlyHtml => "only_html",
Self::NoImages => "no_images",
Self::OnlyImages => "only_images",
Self::OnlyAssets => "only_assets",
};
f.write_str(value)
}
}
/// Compression algorithm configured through `CreateProxyConfig::compress_strategy`.
///
/// Serialized in `snake_case`; the default is [`CompressStrategy::Brotli`].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompressStrategy {
    /// No compression.
    None,
    /// Brotli (the default).
    #[default]
    Brotli,
    /// Gzip.
    Gzip,
    /// Deflate.
    Deflate,
}
impl std::fmt::Display for CompressStrategy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let value = match self {
Self::None => "none",
Self::Brotli => "brotli",
Self::Gzip => "gzip",
Self::Deflate => "deflate",
};
f.write_str(value)
}
}
/// Backing storage for the response cache; passed to `CacheStore::with_storage`.
///
/// Serialized in `snake_case`; the default is [`CacheStorageMode::Memory`].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CacheStorageMode {
    /// In-memory storage (the default).
    #[default]
    Memory,
    /// Filesystem-backed storage; presumably rooted at `CreateProxyConfig::cache_directory`.
    Filesystem,
}
impl std::fmt::Display for CacheStorageMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let value = match self {
Self::Memory => "memory",
Self::Filesystem => "filesystem",
};
f.write_str(value)
}
}
/// Kind of a configured webhook, serialized in `snake_case`; the default is
/// [`WebhookType::Notify`]. The behavior of each kind is implemented by the proxy layer.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WebhookType {
    /// Blocking webhook (presumably awaited before the request proceeds — confirm in `proxy`).
    Blocking,
    /// Notification webhook (the default).
    #[default]
    Notify,
    /// Webhook related to cache keys (semantics defined in `proxy`).
    CacheKey,
}
/// Configuration of a single webhook endpoint.
#[derive(Clone, Debug)]
#[derive(Serialize, Deserialize)]
pub struct WebhookConfig {
    /// Endpoint URL of the webhook.
    pub url: String,
    /// Webhook kind, read from the `type` key; falls back to [`WebhookType::Notify`] when absent.
    #[serde(rename = "type", default)]
    pub webhook_type: WebhookType,
    /// Optional timeout in milliseconds; `None` when the key is absent.
    #[serde(default)]
    pub timeout_ms: Option<u64>,
}
/// Operating mode of the proxy.
#[derive(Clone, Debug, Default)]
pub enum ProxyMode {
    /// Default mode: no snapshot worker is started.
    #[default]
    Dynamic,
    /// A background snapshot worker pre-generates `paths` at startup.
    PreGenerate {
        /// Request paths to snapshot.
        paths: Vec<String>,
        /// Presumably whether requests outside `paths` are still forwarded upstream — confirm
        /// in `proxy`.
        fallthrough: bool,
    },
}
/// Borrowed view of a request, handed to the cache-key function
/// (`CreateProxyConfig::cache_key_fn`).
#[derive(Clone, Debug)]
pub struct RequestInfo<'req> {
    /// HTTP method, e.g. `"GET"`.
    pub method: &'req str,
    /// Request path.
    pub path: &'req str,
    /// Query string without the leading `?`; empty when there is none.
    pub query: &'req str,
    /// Request headers.
    pub headers: &'req axum::http::HeaderMap,
}
/// Configuration consumed by [`create_proxy`] and [`create_proxy_with_handle`].
///
/// Build it with [`CreateProxyConfig::new`] and the `with_*` builder methods.
#[derive(Clone)]
pub struct CreateProxyConfig {
    /// URL of the upstream server being proxied.
    pub proxy_url: String,
    /// Path patterns to include; interpreted by the proxy layer.
    pub include_paths: Vec<String>,
    /// Path patterns to exclude; interpreted by the proxy layer.
    pub exclude_paths: Vec<String>,
    /// Whether WebSocket support is enabled.
    pub enable_websocket: bool,
    /// Presumably restricts forwarding to `GET` requests — confirm in `proxy`.
    pub forward_get_only: bool,
    /// Computes the cache key for a request.
    pub cache_key_fn: Arc<dyn Fn(&RequestInfo) -> String + Send + Sync>,
    /// 404 capacity passed to `CacheStore::with_storage`.
    pub cache_404_capacity: usize,
    /// Flag forwarded to the proxy layer; its semantics are defined there.
    pub use_404_meta: bool,
    /// Which content types may be cached.
    pub cache_strategy: CacheStrategy,
    /// Compression algorithm.
    pub compress_strategy: CompressStrategy,
    /// Cache backing storage.
    pub cache_storage_mode: CacheStorageMode,
    /// Optional cache directory passed to `CacheStore::with_storage`.
    pub cache_directory: Option<PathBuf>,
    /// Dynamic or pre-generated operation.
    pub proxy_mode: ProxyMode,
    /// Configured webhooks.
    pub webhooks: Vec<WebhookConfig>,
}

/// Manual `Debug`: `cache_key_fn` is a closure and cannot be formatted, so it is omitted
/// (rendered as `..`).
impl std::fmt::Debug for CreateProxyConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CreateProxyConfig")
            .field("proxy_url", &self.proxy_url)
            .field("include_paths", &self.include_paths)
            .field("exclude_paths", &self.exclude_paths)
            .field("enable_websocket", &self.enable_websocket)
            .field("forward_get_only", &self.forward_get_only)
            .field("cache_404_capacity", &self.cache_404_capacity)
            .field("use_404_meta", &self.use_404_meta)
            .field("cache_strategy", &self.cache_strategy)
            .field("compress_strategy", &self.compress_strategy)
            .field("cache_storage_mode", &self.cache_storage_mode)
            .field("cache_directory", &self.cache_directory)
            .field("proxy_mode", &self.proxy_mode)
            .field("webhooks", &self.webhooks)
            .finish_non_exhaustive()
    }
}
impl CreateProxyConfig {
    /// Creates a configuration that proxies to `proxy_url`, with default settings.
    ///
    /// Defaults:
    /// - no include/exclude paths;
    /// - `enable_websocket = true`, `forward_get_only = false`;
    /// - cache keys of the form `METHOD:path` or `METHOD:path?query`;
    /// - `cache_404_capacity = 100`, `use_404_meta = false`;
    /// - the `#[default]` variant of every strategy/mode enum;
    /// - no cache directory and no webhooks.
    #[must_use]
    pub fn new(proxy_url: String) -> Self {
        Self {
            proxy_url,
            include_paths: Vec::new(),
            exclude_paths: Vec::new(),
            enable_websocket: true,
            forward_get_only: false,
            cache_key_fn: Arc::new(Self::default_cache_key),
            cache_404_capacity: 100,
            use_404_meta: false,
            // Use the enum defaults so `new` cannot drift from their `#[default]` attributes.
            cache_strategy: CacheStrategy::default(),
            compress_strategy: CompressStrategy::default(),
            cache_storage_mode: CacheStorageMode::default(),
            cache_directory: None,
            proxy_mode: ProxyMode::default(),
            webhooks: Vec::new(),
        }
    }

    /// Default cache key: `METHOD:path`, with `?query` appended only when a query is present.
    fn default_cache_key(req_info: &RequestInfo) -> String {
        if req_info.query.is_empty() {
            format!("{}:{}", req_info.method, req_info.path)
        } else {
            format!("{}:{}?{}", req_info.method, req_info.path, req_info.query)
        }
    }

    /// Replaces the include path patterns.
    #[must_use]
    pub fn with_include_paths(mut self, paths: Vec<String>) -> Self {
        self.include_paths = paths;
        self
    }

    /// Replaces the exclude path patterns.
    #[must_use]
    pub fn with_exclude_paths(mut self, paths: Vec<String>) -> Self {
        self.exclude_paths = paths;
        self
    }

    /// Enables or disables WebSocket support.
    #[must_use]
    pub fn with_websocket_enabled(mut self, enabled: bool) -> Self {
        self.enable_websocket = enabled;
        self
    }

    /// Sets `forward_get_only`.
    #[must_use]
    pub fn with_forward_get_only(mut self, enabled: bool) -> Self {
        self.forward_get_only = enabled;
        self
    }

    /// Replaces the function that derives a cache key from a request.
    #[must_use]
    pub fn with_cache_key_fn<F>(mut self, f: F) -> Self
    where
        F: Fn(&RequestInfo) -> String + Send + Sync + 'static,
    {
        self.cache_key_fn = Arc::new(f);
        self
    }

    /// Sets the 404 capacity passed to the cache store.
    #[must_use]
    pub fn with_cache_404_capacity(mut self, capacity: usize) -> Self {
        self.cache_404_capacity = capacity;
        self
    }

    /// Sets `use_404_meta`.
    #[must_use]
    pub fn with_use_404_meta(mut self, enabled: bool) -> Self {
        self.use_404_meta = enabled;
        self
    }

    /// Sets which content types may be cached.
    #[must_use]
    pub fn with_cache_strategy(mut self, strategy: CacheStrategy) -> Self {
        self.cache_strategy = strategy;
        self
    }

    /// Alias for [`Self::with_cache_strategy`].
    #[must_use]
    pub fn caching_strategy(self, strategy: CacheStrategy) -> Self {
        self.with_cache_strategy(strategy)
    }

    /// Sets the compression algorithm.
    #[must_use]
    pub fn with_compress_strategy(mut self, strategy: CompressStrategy) -> Self {
        self.compress_strategy = strategy;
        self
    }

    /// Alias for [`Self::with_compress_strategy`].
    #[must_use]
    pub fn compression_strategy(self, strategy: CompressStrategy) -> Self {
        self.with_compress_strategy(strategy)
    }

    /// Sets the cache backing storage.
    #[must_use]
    pub fn with_cache_storage_mode(mut self, mode: CacheStorageMode) -> Self {
        self.cache_storage_mode = mode;
        self
    }

    /// Sets the cache directory.
    #[must_use]
    pub fn with_cache_directory(mut self, directory: impl Into<PathBuf>) -> Self {
        self.cache_directory = Some(directory.into());
        self
    }

    /// Sets the proxy mode (dynamic or pre-generated snapshots).
    #[must_use]
    pub fn with_proxy_mode(mut self, mode: ProxyMode) -> Self {
        self.proxy_mode = mode;
        self
    }

    /// Replaces the configured webhooks.
    #[must_use]
    pub fn with_webhooks(mut self, webhooks: Vec<WebhookConfig>) -> Self {
        self.webhooks = webhooks;
        self
    }
}
pub fn create_proxy(config: CreateProxyConfig) -> (Router, CacheHandle) {
let upstream_client =
proxy::build_upstream_client().expect("failed to build shared upstream HTTP client");
let webhook_client =
proxy::build_webhook_client().expect("failed to build shared webhook HTTP client");
let (handle, snapshot_rx) = if let ProxyMode::PreGenerate { .. } = &config.proxy_mode {
let (tx, rx) = mpsc::channel(32);
(CacheHandle::new_with_snapshots(tx), Some(rx))
} else {
(CacheHandle::new(), None)
};
let cache = CacheStore::with_storage(
handle.clone(),
config.cache_404_capacity,
config.cache_storage_mode.clone(),
config.cache_directory.clone(),
);
spawn_invalidation_listener(cache.clone());
if let (Some(rx), ProxyMode::PreGenerate { paths, .. }) = (snapshot_rx, &config.proxy_mode) {
let worker = SnapshotWorker {
rx,
cache: cache.clone(),
upstream_client: upstream_client.clone(),
proxy_url: config.proxy_url.clone(),
compress_strategy: config.compress_strategy.clone(),
cache_key_fn: config.cache_key_fn.clone(),
snapshots: paths.clone(),
};
tokio::spawn(worker.run());
}
let proxy_state = Arc::new(ProxyState::new(
cache,
config,
upstream_client,
webhook_client,
));
let app = Router::new()
.fallback(proxy::proxy_handler)
.layer(Extension(proxy_state));
(app, handle)
}
pub fn create_proxy_with_handle(config: CreateProxyConfig, handle: CacheHandle) -> Router {
let upstream_client =
proxy::build_upstream_client().expect("failed to build shared upstream HTTP client");
let webhook_client =
proxy::build_webhook_client().expect("failed to build shared webhook HTTP client");
let cache = CacheStore::with_storage(
handle,
config.cache_404_capacity,
config.cache_storage_mode.clone(),
config.cache_directory.clone(),
);
spawn_invalidation_listener(cache.clone());
let proxy_state = Arc::new(ProxyState::new(
cache,
config,
upstream_client,
webhook_client,
));
Router::new()
.fallback(proxy::proxy_handler)
.layer(Extension(proxy_state))
}
fn spawn_invalidation_listener(cache: CacheStore) {
let mut receiver = cache.handle().subscribe();
tokio::spawn(async move {
loop {
match receiver.recv().await {
Ok(cache::InvalidationMessage::All) => {
tracing::debug!("Cache invalidation triggered: clearing all entries");
cache.clear().await;
}
Ok(cache::InvalidationMessage::Pattern(pattern)) => {
tracing::debug!(
"Cache invalidation triggered: clearing entries matching pattern '{}'",
pattern
);
cache.clear_by_pattern(&pattern).await;
}
Err(e) => {
tracing::error!("Invalidation channel error: {}", e);
break;
}
}
}
});
}
/// Background task state for [`ProxyMode::PreGenerate`]: fetches snapshot paths from upstream
/// and stores them in the cache.
struct SnapshotWorker {
    /// Snapshot operations sent through the `CacheHandle`.
    rx: mpsc::Receiver<cache::SnapshotRequest>,
    /// Paths currently tracked as snapshots (refetched by `SnapshotOp::RefreshAll`).
    snapshots: Vec<String>,
    /// Cache the fetched responses are written to.
    cache: CacheStore,
    /// Shared upstream HTTP client.
    upstream_client: reqwest::Client,
    /// Upstream URL snapshots are fetched from.
    proxy_url: String,
    /// Compression applied when storing snapshots.
    compress_strategy: CompressStrategy,
    /// Same key function the proxy uses, so snapshots land under the keys requests look up.
    cache_key_fn: Arc<dyn Fn(&RequestInfo) -> String + Send + Sync>,
}
impl SnapshotWorker {
    /// Drives the snapshot lifecycle until every sender of `rx` has been dropped.
    ///
    /// First pre-generates every configured path, fetching each distinct path once. Then it
    /// services `cache::SnapshotOp` requests one at a time and acknowledges each through its
    /// `done` sender. Fetch failures are logged and never stop the worker.
    async fn run(mut self) {
        // Collapse duplicate configured paths (keeping the first occurrence) so each is
        // fetched once per (re)generation.
        {
            let mut seen = std::collections::HashSet::new();
            self.snapshots.retain(|path| seen.insert(path.clone()));
        }
        // `fetch_and_store` only needs `&self`, so the list can be iterated without a copy.
        for path in &self.snapshots {
            if let Err(e) = self.fetch_and_store(path).await {
                tracing::warn!("Failed to pre-generate snapshot '{}': {}", path, e);
            }
        }
        while let Some(req) = self.rx.recv().await {
            match req.op {
                cache::SnapshotOp::Add(path) => match self.fetch_and_store(&path).await {
                    Ok(()) => {
                        // Track each path once; duplicates would be refetched on every RefreshAll.
                        if !self.snapshots.contains(&path) {
                            self.snapshots.push(path);
                        }
                    }
                    Err(e) => tracing::warn!("add_snapshot '{}' failed: {}", path, e),
                },
                cache::SnapshotOp::Refresh(path) => {
                    if let Err(e) = self.fetch_and_store(&path).await {
                        tracing::warn!("refresh_snapshot '{}' failed: {}", path, e);
                    }
                }
                cache::SnapshotOp::Remove(path) => {
                    // Rebuild the key the snapshot was stored under (GET, no query, no headers).
                    let empty_headers = axum::http::HeaderMap::new();
                    let req_info = RequestInfo {
                        method: "GET",
                        path: &path,
                        query: "",
                        headers: &empty_headers,
                    };
                    let key = (self.cache_key_fn)(&req_info);
                    self.cache.clear_by_pattern(&key).await;
                    self.snapshots.retain(|s| s != &path);
                }
                cache::SnapshotOp::RefreshAll => {
                    for path in &self.snapshots {
                        if let Err(e) = self.fetch_and_store(path).await {
                            tracing::warn!("refresh_all_snapshots '{}' failed: {}", path, e);
                        }
                    }
                }
            }
            // The requester may have stopped waiting; a dropped receiver is not an error here.
            let _ = req.done.send(());
        }
    }

    /// Fetches `path` from upstream and stores it in the cache via
    /// `proxy::fetch_and_cache_snapshot`.
    async fn fetch_and_store(&self, path: &str) -> anyhow::Result<()> {
        proxy::fetch_and_cache_snapshot(
            path,
            &self.upstream_client,
            &self.proxy_url,
            &self.cache,
            &self.compress_strategy,
            &self.cache_key_fn,
        )
        .await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_strategy_content_types() {
        assert!(CacheStrategy::All.allows_content_type(None));
        assert!(!CacheStrategy::None.allows_content_type(Some("text/html")));
        assert!(CacheStrategy::OnlyHtml.allows_content_type(Some("text/html; charset=utf-8")));
        assert!(!CacheStrategy::OnlyHtml.allows_content_type(Some("image/png")));
        assert!(CacheStrategy::NoImages.allows_content_type(Some("text/css")));
        assert!(!CacheStrategy::NoImages.allows_content_type(Some("image/webp")));
        assert!(CacheStrategy::OnlyImages.allows_content_type(Some("image/svg+xml")));
        assert!(!CacheStrategy::OnlyImages.allows_content_type(Some("application/javascript")));
        assert!(CacheStrategy::OnlyAssets.allows_content_type(Some("application/javascript")));
        assert!(CacheStrategy::OnlyAssets.allows_content_type(Some("image/png")));
        assert!(!CacheStrategy::OnlyAssets.allows_content_type(Some("text/html")));
        assert!(!CacheStrategy::OnlyAssets.allows_content_type(None));
    }

    /// Parameters, surrounding whitespace and letter case must not affect matching.
    #[test]
    fn test_cache_strategy_normalizes_content_type() {
        assert!(CacheStrategy::OnlyHtml.allows_content_type(Some("  Text/HTML ; charset=UTF-8")));
        assert!(CacheStrategy::OnlyHtml.allows_content_type(Some("application/xhtml+xml")));
        assert!(CacheStrategy::OnlyImages.allows_content_type(Some("IMAGE/PNG")));
        assert!(!CacheStrategy::NoImages.allows_content_type(Some("Image/Jpeg")));
        assert!(CacheStrategy::OnlyAssets.allows_content_type(Some("font/woff2")));
        assert!(CacheStrategy::OnlyAssets.allows_content_type(Some("text/css; charset=utf-8")));
    }

    /// Without a content type only `All` and `NoImages` allow caching.
    #[test]
    fn test_cache_strategy_missing_content_type() {
        assert!(CacheStrategy::NoImages.allows_content_type(None));
        assert!(!CacheStrategy::None.allows_content_type(None));
        assert!(!CacheStrategy::OnlyHtml.allows_content_type(None));
        assert!(!CacheStrategy::OnlyImages.allows_content_type(None));
    }

    #[test]
    fn test_cache_strategy_display() {
        assert_eq!(CacheStrategy::default().to_string(), "all");
        assert_eq!(CacheStrategy::None.to_string(), "none");
        assert_eq!(CacheStrategy::OnlyHtml.to_string(), "only_html");
        assert_eq!(CacheStrategy::NoImages.to_string(), "no_images");
        assert_eq!(CacheStrategy::OnlyImages.to_string(), "only_images");
        assert_eq!(CacheStrategy::OnlyAssets.to_string(), "only_assets");
    }

    #[test]
    fn test_compress_strategy_display() {
        assert_eq!(CompressStrategy::default().to_string(), "brotli");
        assert_eq!(CompressStrategy::None.to_string(), "none");
        assert_eq!(CompressStrategy::Gzip.to_string(), "gzip");
        assert_eq!(CompressStrategy::Deflate.to_string(), "deflate");
    }

    #[test]
    fn test_cache_storage_mode_display() {
        assert_eq!(CacheStorageMode::default().to_string(), "memory");
        assert_eq!(CacheStorageMode::Filesystem.to_string(), "filesystem");
    }

    #[test]
    fn test_default_cache_key() {
        let config = CreateProxyConfig::new("http://localhost:8080".to_string());
        let headers = axum::http::HeaderMap::new();
        let mut req = RequestInfo {
            method: "GET",
            path: "/docs",
            query: "",
            headers: &headers,
        };
        assert_eq!((config.cache_key_fn)(&req), "GET:/docs");
        req.query = "page=2";
        assert_eq!((config.cache_key_fn)(&req), "GET:/docs?page=2");
    }

    #[test]
    fn test_builder_methods() {
        let config = CreateProxyConfig::new("http://localhost:8080".to_string())
            .with_websocket_enabled(false)
            .with_forward_get_only(true)
            .caching_strategy(CacheStrategy::OnlyAssets)
            .compression_strategy(CompressStrategy::Gzip)
            .with_cache_storage_mode(CacheStorageMode::Filesystem)
            .with_cache_directory("/tmp/proxy-cache")
            .with_cache_404_capacity(7);
        assert!(!config.enable_websocket);
        assert!(config.forward_get_only);
        assert_eq!(config.cache_strategy, CacheStrategy::OnlyAssets);
        assert_eq!(config.compress_strategy, CompressStrategy::Gzip);
        assert_eq!(config.cache_storage_mode, CacheStorageMode::Filesystem);
        assert_eq!(config.cache_directory, Some(PathBuf::from("/tmp/proxy-cache")));
        assert_eq!(config.cache_404_capacity, 7);
    }

    #[tokio::test]
    async fn test_create_proxy() {
        let config = CreateProxyConfig::new("http://localhost:8080".to_string());
        assert_eq!(config.compress_strategy, CompressStrategy::Brotli);
        let (_app, handle) = create_proxy(config);
        handle.invalidate_all();
        handle.invalidate("GET:/api/*");
    }
}