// The `native-tls` and `rustls` Cargo features must not both be enabled; fail the build
// early with a clear message if they are.
#[cfg(all(feature = "native-tls", feature = "rustls"))]
compile_error!("Features `native-tls` and `rustls` are mutually exclusive — enable only one.");
3
4pub mod cache;
5pub mod compression;
6pub mod config;
7pub mod control;
8pub mod path_matcher;
9pub mod proxy;
10
11use axum::{extract::Extension, Router};
12use cache::{CacheHandle, CacheStore};
13use proxy::ProxyState;
14use serde::{Deserialize, Serialize};
15use std::path::PathBuf;
16use std::sync::Arc;
17use tokio::sync::mpsc;
18
/// Content-type filter that decides which responses are eligible for caching.
///
/// The decision is made by [`CacheStrategy::allows_content_type`], which looks only at the
/// media type of the `Content-Type` header. Values (de)serialize as snake_case names
/// (`"all"`, `"only_html"`, ...). The default is [`CacheStrategy::All`].
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CacheStrategy {
    /// Accept every content type, including responses with no `Content-Type`.
    #[default]
    All,
    /// Accept no content type at all.
    None,
    /// Accept only `text/html` and `application/xhtml+xml`.
    OnlyHtml,
    /// Accept everything except `image/*`; responses with no `Content-Type` are accepted.
    NoImages,
    /// Accept only `image/*`.
    OnlyImages,
    /// Accept only static-asset types: `image/*`, `font/*`, CSS, JavaScript, JSON, web app
    /// manifests, WebAssembly and XML.
    OnlyAssets,
}
37
38impl CacheStrategy {
39 pub fn allows_content_type(&self, content_type: Option<&str>) -> bool {
41 let content_type = content_type
42 .and_then(|value| value.split(';').next())
43 .map(|value| value.trim().to_ascii_lowercase());
44
45 match self {
46 Self::All => true,
47 Self::None => false,
48 Self::OnlyHtml => content_type
49 .as_deref()
50 .is_some_and(|value| value == "text/html" || value == "application/xhtml+xml"),
51 Self::NoImages => !content_type
52 .as_deref()
53 .is_some_and(|value| value.starts_with("image/")),
54 Self::OnlyImages => content_type
55 .as_deref()
56 .is_some_and(|value| value.starts_with("image/")),
57 Self::OnlyAssets => content_type.as_deref().is_some_and(|value| {
58 value.starts_with("image/")
59 || value.starts_with("font/")
60 || value == "text/css"
61 || value == "text/javascript"
62 || value == "application/javascript"
63 || value == "application/x-javascript"
64 || value == "application/json"
65 || value == "application/manifest+json"
66 || value == "application/wasm"
67 || value == "application/xml"
68 || value == "text/xml"
69 }),
70 }
71 }
72}
73
74impl std::fmt::Display for CacheStrategy {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 let value = match self {
77 Self::All => "all",
78 Self::None => "none",
79 Self::OnlyHtml => "only_html",
80 Self::NoImages => "no_images",
81 Self::OnlyImages => "only_images",
82 Self::OnlyAssets => "only_assets",
83 };
84
85 f.write_str(value)
86 }
87}
88
/// Compression algorithm setting for the proxy.
///
/// NOTE(review): the encoding itself happens outside this file. The value is stored in
/// `CreateProxyConfig` and forwarded to `proxy::fetch_and_cache_snapshot`. Presumably it
/// selects how cached bodies are compressed; confirm in the `proxy`/`compression` modules.
/// Values (de)serialize as snake_case names. The default is [`CompressStrategy::Brotli`].
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompressStrategy {
    /// `"none"`: no compression algorithm selected.
    None,
    /// `"brotli"`: Brotli (the default).
    #[default]
    Brotli,
    /// `"gzip"`: gzip.
    Gzip,
    /// `"deflate"`: deflate.
    Deflate,
}
103
104impl std::fmt::Display for CompressStrategy {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 let value = match self {
107 Self::None => "none",
108 Self::Brotli => "brotli",
109 Self::Gzip => "gzip",
110 Self::Deflate => "deflate",
111 };
112
113 f.write_str(value)
114 }
115}
116
/// Backing store used by the response cache.
///
/// The mode is passed to `CacheStore::with_storage` together with
/// `CreateProxyConfig::cache_directory`. NOTE(review): presumably `Filesystem` keeps
/// entries under that directory; confirm in the `cache` module. Values (de)serialize as
/// snake_case names. The default is [`CacheStorageMode::Memory`].
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CacheStorageMode {
    /// `"memory"` (the default).
    #[default]
    Memory,
    /// `"filesystem"`.
    Filesystem,
}
127
128impl std::fmt::Display for CacheStorageMode {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 let value = match self {
131 Self::Memory => "memory",
132 Self::Filesystem => "filesystem",
133 };
134
135 f.write_str(value)
136 }
137}
138
/// Operating mode of the proxy.
///
/// The default is [`ProxyMode::Dynamic`]. Within this file only [`create_proxy`] acts on
/// [`ProxyMode::PreGenerate`]. It wires a snapshot channel into the returned `CacheHandle`
/// and spawns a worker that fetches `paths` up front. [`create_proxy_with_handle`] does
/// not spawn that worker.
#[derive(Clone, Debug, Default)]
pub enum ProxyMode {
    /// No snapshot worker is started.
    #[default]
    Dynamic,
    /// Pre-generate snapshots for a configured set of paths.
    PreGenerate {
        /// Paths fetched and cached when the snapshot worker starts; they also form the
        /// worker's initial list of tracked snapshots.
        paths: Vec<String>,
        /// Not read in this file (it reaches `ProxyState` via the config).
        /// NOTE(review): presumably controls whether requests for paths without a snapshot
        /// fall through to the upstream; confirm in `proxy`.
        fallthrough: bool,
    },
}
164
/// Borrowed view of a request, handed to the configurable cache-key function
/// (`CreateProxyConfig::cache_key_fn`).
#[derive(Clone, Debug)]
pub struct RequestInfo<'a> {
    /// HTTP method, e.g. `"GET"`.
    pub method: &'a str,
    /// Request path.
    pub path: &'a str,
    /// Query string; empty when there is none. Presumably given without the leading `?`,
    /// since the default key function inserts one itself.
    pub query: &'a str,
    /// Request headers (the default key function does not use them).
    pub headers: &'a axum::http::HeaderMap,
}
177
/// Configuration consumed by [`create_proxy`] and [`create_proxy_with_handle`].
///
/// Build one with [`CreateProxyConfig::new`] and the `with_*` builder methods; the
/// defaults noted below are the values `new` sets.
#[derive(Clone)]
pub struct CreateProxyConfig {
    /// Upstream URL; also handed to the snapshot worker for fetching pre-generated paths.
    pub proxy_url: String,

    /// Paths to include. NOTE(review): matching rules live in `path_matcher`/`proxy`;
    /// confirm there. Default: empty.
    pub include_paths: Vec<String>,

    /// Paths to exclude (same matching caveat as `include_paths`). Default: empty.
    pub exclude_paths: Vec<String>,

    /// WebSocket support flag. Default: `true`.
    pub enable_websocket: bool,

    /// NOTE(review): presumably restricts forwarding to GET requests; confirm in `proxy`.
    /// Default: `false`.
    pub forward_get_only: bool,

    /// Derives a cache key from a request. The default yields `"METHOD:path"`, or
    /// `"METHOD:path?query"` when the query is non-empty. The snapshot worker also uses it
    /// to compute the key to clear when a snapshot is removed.
    pub cache_key_fn: Arc<dyn Fn(&RequestInfo) -> String + Send + Sync>,
    /// Capacity passed to `CacheStore::with_storage` (presumably bounds the number of
    /// cached 404 responses; confirm in `cache`). Default: `100`.
    pub cache_404_capacity: usize,

    /// Not read in this file; forwarded to `ProxyState`. NOTE(review): semantics defined in
    /// `proxy`. Default: `false`.
    pub use_404_meta: bool,

    /// Content-type filter for cacheable responses. Default: [`CacheStrategy::All`].
    pub cache_strategy: CacheStrategy,

    /// Compression setting. Default: [`CompressStrategy::Brotli`].
    pub compress_strategy: CompressStrategy,

    /// Cache backing store, passed to `CacheStore::with_storage`.
    /// Default: [`CacheStorageMode::Memory`].
    pub cache_storage_mode: CacheStorageMode,

    /// Directory passed to `CacheStore::with_storage` alongside `cache_storage_mode`.
    /// Default: `None`.
    pub cache_directory: Option<PathBuf>,

    /// Dynamic or pre-generated operation. Default: [`ProxyMode::Dynamic`].
    pub proxy_mode: ProxyMode,
}
229
230impl CreateProxyConfig {
231 pub fn new(proxy_url: String) -> Self {
233 Self {
234 proxy_url,
235 include_paths: vec![],
236 exclude_paths: vec![],
237 enable_websocket: true,
238 forward_get_only: false,
239 cache_key_fn: Arc::new(|req_info| {
240 if req_info.query.is_empty() {
241 format!("{}:{}", req_info.method, req_info.path)
242 } else {
243 format!("{}:{}?{}", req_info.method, req_info.path, req_info.query)
244 }
245 }),
246 cache_404_capacity: 100,
247 use_404_meta: false,
248 cache_strategy: CacheStrategy::All,
249 compress_strategy: CompressStrategy::Brotli,
250 cache_storage_mode: CacheStorageMode::Memory,
251 cache_directory: None,
252 proxy_mode: ProxyMode::Dynamic,
253 }
254 }
255
256 pub fn with_include_paths(mut self, paths: Vec<String>) -> Self {
258 self.include_paths = paths;
259 self
260 }
261
262 pub fn with_exclude_paths(mut self, paths: Vec<String>) -> Self {
264 self.exclude_paths = paths;
265 self
266 }
267
268 pub fn with_websocket_enabled(mut self, enabled: bool) -> Self {
270 self.enable_websocket = enabled;
271 self
272 }
273
274 pub fn with_forward_get_only(mut self, enabled: bool) -> Self {
276 self.forward_get_only = enabled;
277 self
278 }
279
280 pub fn with_cache_key_fn<F>(mut self, f: F) -> Self
282 where
283 F: Fn(&RequestInfo) -> String + Send + Sync + 'static,
284 {
285 self.cache_key_fn = Arc::new(f);
286 self
287 }
288
289 pub fn with_cache_404_capacity(mut self, capacity: usize) -> Self {
291 self.cache_404_capacity = capacity;
292 self
293 }
294
295 pub fn with_use_404_meta(mut self, enabled: bool) -> Self {
297 self.use_404_meta = enabled;
298 self
299 }
300
301 pub fn with_cache_strategy(mut self, strategy: CacheStrategy) -> Self {
303 self.cache_strategy = strategy;
304 self
305 }
306
307 pub fn caching_strategy(self, strategy: CacheStrategy) -> Self {
309 self.with_cache_strategy(strategy)
310 }
311
312 pub fn with_compress_strategy(mut self, strategy: CompressStrategy) -> Self {
314 self.compress_strategy = strategy;
315 self
316 }
317
318 pub fn compression_strategy(self, strategy: CompressStrategy) -> Self {
320 self.with_compress_strategy(strategy)
321 }
322
323 pub fn with_cache_storage_mode(mut self, mode: CacheStorageMode) -> Self {
325 self.cache_storage_mode = mode;
326 self
327 }
328
329 pub fn with_cache_directory(mut self, directory: impl Into<PathBuf>) -> Self {
331 self.cache_directory = Some(directory.into());
332 self
333 }
334
335 pub fn with_proxy_mode(mut self, mode: ProxyMode) -> Self {
338 self.proxy_mode = mode;
339 self
340 }
341}
342
343pub fn create_proxy(config: CreateProxyConfig) -> (Router, CacheHandle) {
346 let (handle, snapshot_rx) = if let ProxyMode::PreGenerate { .. } = &config.proxy_mode {
348 let (tx, rx) = mpsc::channel(32);
349 (CacheHandle::new_with_snapshots(tx), Some(rx))
350 } else {
351 (CacheHandle::new(), None)
352 };
353
354 let cache = CacheStore::with_storage(
355 handle.clone(),
356 config.cache_404_capacity,
357 config.cache_storage_mode.clone(),
358 config.cache_directory.clone(),
359 );
360
361 spawn_invalidation_listener(cache.clone());
363
364 if let (Some(rx), ProxyMode::PreGenerate { paths, .. }) =
366 (snapshot_rx, &config.proxy_mode)
367 {
368 let worker = SnapshotWorker {
369 rx,
370 cache: cache.clone(),
371 proxy_url: config.proxy_url.clone(),
372 compress_strategy: config.compress_strategy.clone(),
373 cache_key_fn: config.cache_key_fn.clone(),
374 snapshots: paths.clone(),
375 };
376 tokio::spawn(worker.run());
377 }
378
379 let proxy_state = Arc::new(ProxyState::new(cache, config));
380
381 let app = Router::new()
382 .fallback(proxy::proxy_handler)
383 .layer(Extension(proxy_state));
384
385 (app, handle)
386}
387
388pub fn create_proxy_with_handle(config: CreateProxyConfig, handle: CacheHandle) -> Router {
395 let cache = CacheStore::with_storage(
396 handle,
397 config.cache_404_capacity,
398 config.cache_storage_mode.clone(),
399 config.cache_directory.clone(),
400 );
401
402 spawn_invalidation_listener(cache.clone());
404
405 let proxy_state = Arc::new(ProxyState::new(cache, config));
406
407 Router::new()
408 .fallback(proxy::proxy_handler)
409 .layer(Extension(proxy_state))
410}
411
412fn spawn_invalidation_listener(cache: CacheStore) {
414 let mut receiver = cache.handle().subscribe();
415
416 tokio::spawn(async move {
417 loop {
418 match receiver.recv().await {
419 Ok(cache::InvalidationMessage::All) => {
420 tracing::debug!("Cache invalidation triggered: clearing all entries");
421 cache.clear().await;
422 }
423 Ok(cache::InvalidationMessage::Pattern(pattern)) => {
424 tracing::debug!(
425 "Cache invalidation triggered: clearing entries matching pattern '{}'",
426 pattern
427 );
428 cache.clear_by_pattern(&pattern).await;
429 }
430 Err(e) => {
431 tracing::error!("Invalidation channel error: {}", e);
432 break;
433 }
434 }
435 }
436 });
437}
438
/// State of the background task spawned by [`create_proxy`] in
/// [`ProxyMode::PreGenerate`]. It fetches snapshots into the cache and services snapshot
/// requests received on `rx`.
struct SnapshotWorker {
    /// Receiving half of the snapshot channel; the sending half was given to
    /// `CacheHandle::new_with_snapshots`.
    rx: mpsc::Receiver<cache::SnapshotRequest>,
    /// Cache store passed to `proxy::fetch_and_cache_snapshot` and cleared on removal.
    cache: CacheStore,
    /// Upstream URL passed to `proxy::fetch_and_cache_snapshot`.
    proxy_url: String,
    /// Compression setting passed to `proxy::fetch_and_cache_snapshot`.
    compress_strategy: CompressStrategy,
    /// Cache-key function shared with the proxy config; also used to compute the key
    /// cleared when a snapshot is removed.
    cache_key_fn: Arc<dyn Fn(&RequestInfo) -> String + Send + Sync>,
    /// Paths currently tracked as snapshots (re-fetched by `RefreshAll`).
    snapshots: Vec<String>,
}
450
451impl SnapshotWorker {
452 async fn run(mut self) {
453 let initial = self.snapshots.clone();
455 for path in &initial {
456 if let Err(e) = self.fetch_and_store(path).await {
457 tracing::warn!("Failed to pre-generate snapshot '{}': {}", path, e);
458 }
459 }
460
461 while let Some(req) = self.rx.recv().await {
463 match req.op {
464 cache::SnapshotOp::Add(path) => {
465 match self.fetch_and_store(&path).await {
466 Ok(()) => self.snapshots.push(path),
467 Err(e) => tracing::warn!("add_snapshot '{}' failed: {}", path, e),
468 }
469 }
470 cache::SnapshotOp::Refresh(path) => {
471 if let Err(e) = self.fetch_and_store(&path).await {
472 tracing::warn!("refresh_snapshot '{}' failed: {}", path, e);
473 }
474 }
475 cache::SnapshotOp::Remove(path) => {
476 let empty_headers = axum::http::HeaderMap::new();
477 let req_info = RequestInfo {
478 method: "GET",
479 path: &path,
480 query: "",
481 headers: &empty_headers,
482 };
483 let key = (self.cache_key_fn)(&req_info);
484 self.cache.clear_by_pattern(&key).await;
485 self.snapshots.retain(|s| s != &path);
486 }
487 cache::SnapshotOp::RefreshAll => {
488 let paths: Vec<String> = self.snapshots.clone();
489 for path in &paths {
490 if let Err(e) = self.fetch_and_store(path).await {
491 tracing::warn!("refresh_all_snapshots '{}' failed: {}", path, e);
492 }
493 }
494 }
495 }
496 let _ = req.done.send(());
498 }
499 }
500
501 async fn fetch_and_store(&self, path: &str) -> anyhow::Result<()> {
502 proxy::fetch_and_cache_snapshot(
503 path,
504 &self.proxy_url,
505 &self.cache,
506 &self.compress_strategy,
507 &self.cache_key_fn,
508 )
509 .await
510 }
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_strategy_content_types() {
        assert!(CacheStrategy::All.allows_content_type(None));
        assert!(!CacheStrategy::None.allows_content_type(Some("text/html")));
        assert!(CacheStrategy::OnlyHtml.allows_content_type(Some("text/html; charset=utf-8")));
        assert!(!CacheStrategy::OnlyHtml.allows_content_type(Some("image/png")));
        assert!(CacheStrategy::NoImages.allows_content_type(Some("text/css")));
        assert!(!CacheStrategy::NoImages.allows_content_type(Some("image/webp")));
        assert!(CacheStrategy::OnlyImages.allows_content_type(Some("image/svg+xml")));
        assert!(!CacheStrategy::OnlyImages.allows_content_type(Some("application/javascript")));
        assert!(CacheStrategy::OnlyAssets.allows_content_type(Some("application/javascript")));
        assert!(CacheStrategy::OnlyAssets.allows_content_type(Some("image/png")));
        assert!(!CacheStrategy::OnlyAssets.allows_content_type(Some("text/html")));
        assert!(!CacheStrategy::OnlyAssets.allows_content_type(None));
    }

    #[test]
    fn test_cache_strategy_normalizes_content_type() {
        // Parameters are dropped, surrounding whitespace is trimmed and case is ignored.
        assert!(CacheStrategy::OnlyHtml.allows_content_type(Some("  TEXT/HTML ; charset=UTF-8")));
        assert!(CacheStrategy::OnlyHtml.allows_content_type(Some("application/xhtml+xml")));
        assert!(!CacheStrategy::NoImages.allows_content_type(Some("Image/PNG")));
        assert!(CacheStrategy::OnlyAssets.allows_content_type(Some("font/woff2")));
        assert!(CacheStrategy::OnlyAssets.allows_content_type(Some("Text/CSS; charset=utf-8")));
        // A missing Content-Type is accepted only by the permissive strategies.
        assert!(CacheStrategy::NoImages.allows_content_type(None));
        assert!(!CacheStrategy::OnlyHtml.allows_content_type(None));
        assert!(!CacheStrategy::OnlyImages.allows_content_type(None));
        assert!(!CacheStrategy::None.allows_content_type(None));
    }

    #[test]
    fn test_cache_strategy_display() {
        assert_eq!(CacheStrategy::default().to_string(), "all");
        assert_eq!(CacheStrategy::None.to_string(), "none");
        assert_eq!(CacheStrategy::OnlyHtml.to_string(), "only_html");
        assert_eq!(CacheStrategy::NoImages.to_string(), "no_images");
        assert_eq!(CacheStrategy::OnlyImages.to_string(), "only_images");
        assert_eq!(CacheStrategy::OnlyAssets.to_string(), "only_assets");
    }

    #[test]
    fn test_compress_strategy_display() {
        assert_eq!(CompressStrategy::default().to_string(), "brotli");
        assert_eq!(CompressStrategy::None.to_string(), "none");
        assert_eq!(CompressStrategy::Gzip.to_string(), "gzip");
        assert_eq!(CompressStrategy::Deflate.to_string(), "deflate");
    }

    #[test]
    fn test_cache_storage_mode_display() {
        assert_eq!(CacheStorageMode::default().to_string(), "memory");
        assert_eq!(CacheStorageMode::Filesystem.to_string(), "filesystem");
    }

    #[test]
    fn test_config_defaults() {
        let config = CreateProxyConfig::new("http://localhost:8080".to_string());
        assert_eq!(config.cache_strategy, CacheStrategy::default());
        assert_eq!(config.compress_strategy, CompressStrategy::default());
        assert_eq!(config.cache_storage_mode, CacheStorageMode::default());
        assert!(matches!(config.proxy_mode, ProxyMode::Dynamic));
        assert_eq!(config.cache_404_capacity, 100);
        assert!(config.enable_websocket);
        assert!(!config.forward_get_only);
        assert!(!config.use_404_meta);
        assert!(config.cache_directory.is_none());
    }

    #[test]
    fn test_default_cache_key_fn() {
        let config = CreateProxyConfig::new("http://localhost:8080".to_string());
        let headers = axum::http::HeaderMap::new();
        let plain = RequestInfo { method: "GET", path: "/a", query: "", headers: &headers };
        assert_eq!((config.cache_key_fn)(&plain), "GET:/a");
        let with_query = RequestInfo { method: "GET", path: "/a", query: "x=1", headers: &headers };
        assert_eq!((config.cache_key_fn)(&with_query), "GET:/a?x=1");
    }

    #[tokio::test]
    async fn test_create_proxy() {
        let config = CreateProxyConfig::new("http://localhost:8080".to_string());
        assert_eq!(config.compress_strategy, CompressStrategy::Brotli);
        let (_app, handle) = create_proxy(config);
        handle.invalidate_all();
        handle.invalidate("GET:/api/*");
    }
}