// Exactly one TLS backend feature may be enabled. This guard fails the build
// with an explicit message when both `native-tls` and `rustls` are turned on.
#[cfg(all(feature = "native-tls", feature = "rustls"))]
compile_error!("Features `native-tls` and `rustls` are mutually exclusive — enable only one.");
3
4pub mod cache;
5pub mod compression;
6pub mod config;
7pub mod control;
8pub mod path_matcher;
9pub mod proxy;
10
11use axum::{extract::Extension, Router};
12use cache::{CacheHandle, CacheStore};
13use proxy::ProxyState;
14use serde::{Deserialize, Serialize};
15use std::path::PathBuf;
16use std::sync::Arc;
17use tokio::sync::mpsc;
18
/// Which upstream responses may be cached, decided from their `Content-Type`
/// by [`CacheStrategy::allows_content_type`].
///
/// Serialized in `snake_case` (`"all"`, `"none"`, `"only_html"`, `"no_images"`,
/// `"only_images"`, `"only_assets"`).
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CacheStrategy {
    /// Allow every response, including ones without a `Content-Type`. Default.
    #[default]
    All,
    /// Allow nothing.
    None,
    /// Allow only `text/html` and `application/xhtml+xml`.
    OnlyHtml,
    /// Allow everything except `image/*`; a missing `Content-Type` is allowed.
    NoImages,
    /// Allow only `image/*`.
    OnlyImages,
    /// Allow only static-asset types: `image/*`, `font/*`, CSS, JavaScript,
    /// JSON, web manifests, WebAssembly and XML.
    OnlyAssets,
}
37
38impl CacheStrategy {
39 pub fn allows_content_type(&self, content_type: Option<&str>) -> bool {
41 let content_type = content_type
42 .and_then(|value| value.split(';').next())
43 .map(|value| value.trim().to_ascii_lowercase());
44
45 match self {
46 Self::All => true,
47 Self::None => false,
48 Self::OnlyHtml => content_type
49 .as_deref()
50 .is_some_and(|value| value == "text/html" || value == "application/xhtml+xml"),
51 Self::NoImages => !content_type
52 .as_deref()
53 .is_some_and(|value| value.starts_with("image/")),
54 Self::OnlyImages => content_type
55 .as_deref()
56 .is_some_and(|value| value.starts_with("image/")),
57 Self::OnlyAssets => content_type.as_deref().is_some_and(|value| {
58 value.starts_with("image/")
59 || value.starts_with("font/")
60 || value == "text/css"
61 || value == "text/javascript"
62 || value == "application/javascript"
63 || value == "application/x-javascript"
64 || value == "application/json"
65 || value == "application/manifest+json"
66 || value == "application/wasm"
67 || value == "application/xml"
68 || value == "text/xml"
69 }),
70 }
71 }
72}
73
74impl std::fmt::Display for CacheStrategy {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 let value = match self {
77 Self::All => "all",
78 Self::None => "none",
79 Self::OnlyHtml => "only_html",
80 Self::NoImages => "no_images",
81 Self::OnlyImages => "only_images",
82 Self::OnlyAssets => "only_assets",
83 };
84
85 f.write_str(value)
86 }
87}
88
/// Compression codec selected for the proxy.
///
/// The value is handed to the proxy state and to snapshot fetching.
/// NOTE(review): presumably applied to cached response bodies — confirm in the
/// `compression`/`proxy` modules. Serialized in `snake_case`.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CompressStrategy {
    /// No compression.
    None,
    /// Brotli. Default.
    #[default]
    Brotli,
    /// Gzip.
    Gzip,
    /// Deflate.
    Deflate,
}
103
104impl std::fmt::Display for CompressStrategy {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 let value = match self {
107 Self::None => "none",
108 Self::Brotli => "brotli",
109 Self::Gzip => "gzip",
110 Self::Deflate => "deflate",
111 };
112
113 f.write_str(value)
114 }
115}
116
/// Storage backend for the response cache, passed to
/// `CacheStore::with_storage` together with the optional cache directory.
///
/// Serialized in `snake_case` (`"memory"`, `"filesystem"`).
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CacheStorageMode {
    /// In-memory storage. Default.
    #[default]
    Memory,
    /// Filesystem-backed storage.
    /// NOTE(review): presumably uses `CreateProxyConfig::cache_directory` — confirm in `cache`.
    Filesystem,
}
127
128impl std::fmt::Display for CacheStorageMode {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 let value = match self {
131 Self::Memory => "memory",
132 Self::Filesystem => "filesystem",
133 };
134
135 f.write_str(value)
136 }
137}
138
/// Kind of a configured webhook (the `type` key of [`WebhookConfig`]).
///
/// Serialized in `snake_case` (`"blocking"`, `"notify"`, `"cache_key"`).
/// NOTE(review): the per-variant behavior is implemented in the `proxy`
/// module; from the names, `Blocking` presumably waits on the webhook,
/// `Notify` presumably fires without waiting, and `CacheKey` presumably
/// influences the cache key — confirm there.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WebhookType {
    Blocking,
    /// Default when `type` is omitted.
    #[default]
    Notify,
    CacheKey,
}
155
/// A single webhook entry in [`CreateProxyConfig::webhooks`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct WebhookConfig {
    /// Endpoint URL of the webhook.
    pub url: String,

    /// Webhook kind, read from the `type` key; defaults to
    /// [`WebhookType::Notify`] when omitted.
    #[serde(rename = "type", default)]
    pub webhook_type: WebhookType,

    /// Optional timeout in milliseconds; `None` when omitted.
    /// NOTE(review): fallback behavior when `None` is decided in `proxy` — confirm.
    #[serde(default)]
    pub timeout_ms: Option<u64>,
}
171
/// How the proxy obtains the content it serves.
#[derive(Clone, Debug, Default)]
pub enum ProxyMode {
    /// Regular on-demand proxying. Default.
    #[default]
    Dynamic,
    /// Snapshot mode.
    ///
    /// `create_proxy` starts a background worker that fetches each of `paths`
    /// into the cache at startup. The worker also serves snapshot add/refresh/
    /// remove requests sent through the returned `CacheHandle`.
    PreGenerate {
        /// Paths fetched up front as snapshots.
        paths: Vec<String>,
        /// NOTE(review): presumably whether requests for non-snapshotted paths
        /// fall through to upstream — confirm in `proxy`.
        fallthrough: bool,
    },
}
197
/// Borrowed view of an incoming request, passed to
/// [`CreateProxyConfig::cache_key_fn`] to compute a cache key.
#[derive(Clone, Debug)]
pub struct RequestInfo<'a> {
    /// HTTP method, e.g. `"GET"`.
    pub method: &'a str,
    /// Request path, e.g. `"/api/items"`.
    pub path: &'a str,
    /// Query string without the leading `?`; empty when there is none.
    pub query: &'a str,
    /// Request headers.
    pub headers: &'a axum::http::HeaderMap,
}
210
/// Settings consumed by [`create_proxy`] and [`create_proxy_with_handle`].
///
/// Build one with [`CreateProxyConfig::new`] and the `with_*` builder methods.
/// `Debug` is not derived because `cache_key_fn` is a `dyn Fn`, which does not
/// implement `Debug`.
#[derive(Clone)]
pub struct CreateProxyConfig {
    /// Upstream URL to proxy to (also used by the snapshot worker).
    pub proxy_url: String,

    /// Path patterns to include; empty by default.
    /// NOTE(review): matching semantics presumably live in `path_matcher` — confirm.
    pub include_paths: Vec<String>,

    /// Path patterns to exclude; empty by default.
    /// NOTE(review): matching semantics presumably live in `path_matcher` — confirm.
    pub exclude_paths: Vec<String>,

    /// Whether WebSocket support is enabled; `true` by default.
    pub enable_websocket: bool,

    /// `false` by default.
    /// NOTE(review): presumably restricts forwarding to GET requests — confirm in `proxy`.
    pub forward_get_only: bool,

    /// Computes the cache key for a request. The default built by `new` is
    /// `METHOD:path`, with `?query` appended when the query is non-empty.
    pub cache_key_fn: Arc<dyn Fn(&RequestInfo) -> String + Send + Sync>,
    /// Passed to `CacheStore::with_storage`; 100 by default.
    /// NOTE(review): presumably the number of cached 404 entries — confirm in `cache`.
    pub cache_404_capacity: usize,

    /// `false` by default. NOTE(review): semantics defined in `proxy` — confirm.
    pub use_404_meta: bool,

    /// Which responses may be cached, by content type; `All` by default.
    pub cache_strategy: CacheStrategy,

    /// Compression codec; `Brotli` by default.
    pub compress_strategy: CompressStrategy,

    /// Cache storage backend; `Memory` by default.
    pub cache_storage_mode: CacheStorageMode,

    /// Directory passed to `CacheStore::with_storage`; `None` by default.
    pub cache_directory: Option<PathBuf>,

    /// Dynamic proxying or pre-generated snapshots; `Dynamic` by default.
    pub proxy_mode: ProxyMode,

    /// Webhooks to invoke; empty by default.
    pub webhooks: Vec<WebhookConfig>,
}
266
267impl CreateProxyConfig {
268 pub fn new(proxy_url: String) -> Self {
270 Self {
271 proxy_url,
272 include_paths: vec![],
273 exclude_paths: vec![],
274 enable_websocket: true,
275 forward_get_only: false,
276 cache_key_fn: Arc::new(|req_info| {
277 if req_info.query.is_empty() {
278 format!("{}:{}", req_info.method, req_info.path)
279 } else {
280 format!("{}:{}?{}", req_info.method, req_info.path, req_info.query)
281 }
282 }),
283 cache_404_capacity: 100,
284 use_404_meta: false,
285 cache_strategy: CacheStrategy::All,
286 compress_strategy: CompressStrategy::Brotli,
287 cache_storage_mode: CacheStorageMode::Memory,
288 cache_directory: None,
289 proxy_mode: ProxyMode::Dynamic,
290 webhooks: vec![],
291 }
292 }
293
294 pub fn with_include_paths(mut self, paths: Vec<String>) -> Self {
296 self.include_paths = paths;
297 self
298 }
299
300 pub fn with_exclude_paths(mut self, paths: Vec<String>) -> Self {
302 self.exclude_paths = paths;
303 self
304 }
305
306 pub fn with_websocket_enabled(mut self, enabled: bool) -> Self {
308 self.enable_websocket = enabled;
309 self
310 }
311
312 pub fn with_forward_get_only(mut self, enabled: bool) -> Self {
314 self.forward_get_only = enabled;
315 self
316 }
317
318 pub fn with_cache_key_fn<F>(mut self, f: F) -> Self
320 where
321 F: Fn(&RequestInfo) -> String + Send + Sync + 'static,
322 {
323 self.cache_key_fn = Arc::new(f);
324 self
325 }
326
327 pub fn with_cache_404_capacity(mut self, capacity: usize) -> Self {
329 self.cache_404_capacity = capacity;
330 self
331 }
332
333 pub fn with_use_404_meta(mut self, enabled: bool) -> Self {
335 self.use_404_meta = enabled;
336 self
337 }
338
339 pub fn with_cache_strategy(mut self, strategy: CacheStrategy) -> Self {
341 self.cache_strategy = strategy;
342 self
343 }
344
345 pub fn caching_strategy(self, strategy: CacheStrategy) -> Self {
347 self.with_cache_strategy(strategy)
348 }
349
350 pub fn with_compress_strategy(mut self, strategy: CompressStrategy) -> Self {
352 self.compress_strategy = strategy;
353 self
354 }
355
356 pub fn compression_strategy(self, strategy: CompressStrategy) -> Self {
358 self.with_compress_strategy(strategy)
359 }
360
361 pub fn with_cache_storage_mode(mut self, mode: CacheStorageMode) -> Self {
363 self.cache_storage_mode = mode;
364 self
365 }
366
367 pub fn with_cache_directory(mut self, directory: impl Into<PathBuf>) -> Self {
369 self.cache_directory = Some(directory.into());
370 self
371 }
372
373 pub fn with_proxy_mode(mut self, mode: ProxyMode) -> Self {
376 self.proxy_mode = mode;
377 self
378 }
379
380 pub fn with_webhooks(mut self, webhooks: Vec<WebhookConfig>) -> Self {
383 self.webhooks = webhooks;
384 self
385 }
386}
387
388pub fn create_proxy(config: CreateProxyConfig) -> (Router, CacheHandle) {
391 let upstream_client =
392 proxy::build_upstream_client().expect("failed to build shared upstream HTTP client");
393 let webhook_client =
394 proxy::build_webhook_client().expect("failed to build shared webhook HTTP client");
395
396 let (handle, snapshot_rx) = if let ProxyMode::PreGenerate { .. } = &config.proxy_mode {
398 let (tx, rx) = mpsc::channel(32);
399 (CacheHandle::new_with_snapshots(tx), Some(rx))
400 } else {
401 (CacheHandle::new(), None)
402 };
403
404 let cache = CacheStore::with_storage(
405 handle.clone(),
406 config.cache_404_capacity,
407 config.cache_storage_mode.clone(),
408 config.cache_directory.clone(),
409 );
410
411 spawn_invalidation_listener(cache.clone());
413
414 if let (Some(rx), ProxyMode::PreGenerate { paths, .. }) = (snapshot_rx, &config.proxy_mode) {
416 let worker = SnapshotWorker {
417 rx,
418 cache: cache.clone(),
419 upstream_client: upstream_client.clone(),
420 proxy_url: config.proxy_url.clone(),
421 compress_strategy: config.compress_strategy.clone(),
422 cache_key_fn: config.cache_key_fn.clone(),
423 snapshots: paths.clone(),
424 };
425 tokio::spawn(worker.run());
426 }
427
428 let proxy_state = Arc::new(ProxyState::new(
429 cache,
430 config,
431 upstream_client,
432 webhook_client,
433 ));
434
435 let app = Router::new()
436 .fallback(proxy::proxy_handler)
437 .layer(Extension(proxy_state));
438
439 (app, handle)
440}
441
442pub fn create_proxy_with_handle(config: CreateProxyConfig, handle: CacheHandle) -> Router {
449 let upstream_client =
450 proxy::build_upstream_client().expect("failed to build shared upstream HTTP client");
451 let webhook_client =
452 proxy::build_webhook_client().expect("failed to build shared webhook HTTP client");
453
454 let cache = CacheStore::with_storage(
455 handle,
456 config.cache_404_capacity,
457 config.cache_storage_mode.clone(),
458 config.cache_directory.clone(),
459 );
460
461 spawn_invalidation_listener(cache.clone());
463
464 let proxy_state = Arc::new(ProxyState::new(
465 cache,
466 config,
467 upstream_client,
468 webhook_client,
469 ));
470
471 Router::new()
472 .fallback(proxy::proxy_handler)
473 .layer(Extension(proxy_state))
474}
475
476fn spawn_invalidation_listener(cache: CacheStore) {
478 let mut receiver = cache.handle().subscribe();
479
480 tokio::spawn(async move {
481 loop {
482 match receiver.recv().await {
483 Ok(cache::InvalidationMessage::All) => {
484 tracing::debug!("Cache invalidation triggered: clearing all entries");
485 cache.clear().await;
486 }
487 Ok(cache::InvalidationMessage::Pattern(pattern)) => {
488 tracing::debug!(
489 "Cache invalidation triggered: clearing entries matching pattern '{}'",
490 pattern
491 );
492 cache.clear_by_pattern(&pattern).await;
493 }
494 Err(e) => {
495 tracing::error!("Invalidation channel error: {}", e);
496 break;
497 }
498 }
499 }
500 });
501}
502
/// State of the background snapshot task started by `create_proxy` in
/// [`ProxyMode::PreGenerate`].
///
/// It owns the receiving end of the snapshot channel and everything needed to
/// fetch pages from upstream into the cache.
struct SnapshotWorker {
    /// Snapshot requests sent through the `CacheHandle` created with
    /// `CacheHandle::new_with_snapshots`.
    rx: mpsc::Receiver<cache::SnapshotRequest>,
    /// Cache that fetched snapshots are stored in.
    cache: CacheStore,
    /// Shared upstream HTTP client.
    upstream_client: reqwest::Client,
    /// Upstream URL, copied from `CreateProxyConfig::proxy_url`.
    proxy_url: String,
    /// Compression setting passed to each snapshot fetch.
    compress_strategy: CompressStrategy,
    /// Cache key function copied from the config. It is also used to compute
    /// the key to clear when a snapshot is removed.
    cache_key_fn: Arc<dyn Fn(&RequestInfo) -> String + Send + Sync>,
    /// Paths currently tracked as snapshots; starts as the configured paths.
    snapshots: Vec<String>,
}
515
516impl SnapshotWorker {
517 async fn run(mut self) {
518 let initial = self.snapshots.clone();
520 for path in &initial {
521 if let Err(e) = self.fetch_and_store(path).await {
522 tracing::warn!("Failed to pre-generate snapshot '{}': {}", path, e);
523 }
524 }
525
526 while let Some(req) = self.rx.recv().await {
528 match req.op {
529 cache::SnapshotOp::Add(path) => match self.fetch_and_store(&path).await {
530 Ok(()) => self.snapshots.push(path),
531 Err(e) => tracing::warn!("add_snapshot '{}' failed: {}", path, e),
532 },
533 cache::SnapshotOp::Refresh(path) => {
534 if let Err(e) = self.fetch_and_store(&path).await {
535 tracing::warn!("refresh_snapshot '{}' failed: {}", path, e);
536 }
537 }
538 cache::SnapshotOp::Remove(path) => {
539 let empty_headers = axum::http::HeaderMap::new();
540 let req_info = RequestInfo {
541 method: "GET",
542 path: &path,
543 query: "",
544 headers: &empty_headers,
545 };
546 let key = (self.cache_key_fn)(&req_info);
547 self.cache.clear_by_pattern(&key).await;
548 self.snapshots.retain(|s| s != &path);
549 }
550 cache::SnapshotOp::RefreshAll => {
551 let paths: Vec<String> = self.snapshots.clone();
552 for path in &paths {
553 if let Err(e) = self.fetch_and_store(path).await {
554 tracing::warn!("refresh_all_snapshots '{}' failed: {}", path, e);
555 }
556 }
557 }
558 }
559 let _ = req.done.send(());
561 }
562 }
563
564 async fn fetch_and_store(&self, path: &str) -> anyhow::Result<()> {
565 proxy::fetch_and_cache_snapshot(
566 path,
567 &self.upstream_client,
568 &self.proxy_url,
569 &self.cache,
570 &self.compress_strategy,
571 &self.cache_key_fn,
572 )
573 .await
574 }
575}
576
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_strategy_content_types() {
        assert!(CacheStrategy::All.allows_content_type(None));
        assert!(!CacheStrategy::None.allows_content_type(Some("text/html")));
        assert!(CacheStrategy::OnlyHtml.allows_content_type(Some("text/html; charset=utf-8")));
        assert!(!CacheStrategy::OnlyHtml.allows_content_type(Some("image/png")));
        assert!(CacheStrategy::NoImages.allows_content_type(Some("text/css")));
        assert!(!CacheStrategy::NoImages.allows_content_type(Some("image/webp")));
        assert!(CacheStrategy::OnlyImages.allows_content_type(Some("image/svg+xml")));
        assert!(!CacheStrategy::OnlyImages.allows_content_type(Some("application/javascript")));
        assert!(CacheStrategy::OnlyAssets.allows_content_type(Some("application/javascript")));
        assert!(CacheStrategy::OnlyAssets.allows_content_type(Some("image/png")));
        assert!(!CacheStrategy::OnlyAssets.allows_content_type(Some("text/html")));
        assert!(!CacheStrategy::OnlyAssets.allows_content_type(None));
    }

    /// Media types are compared case-insensitively, ignoring parameters and
    /// surrounding whitespace.
    #[test]
    fn test_cache_strategy_normalizes_content_type() {
        assert!(CacheStrategy::OnlyHtml.allows_content_type(Some("TEXT/HTML; Charset=UTF-8")));
        assert!(CacheStrategy::OnlyAssets.allows_content_type(Some("  text/css ; charset=utf-8")));
        assert!(CacheStrategy::OnlyAssets.allows_content_type(Some("font/woff2")));
        assert!(!CacheStrategy::NoImages.allows_content_type(Some("Image/PNG")));
    }

    /// Only `All` and `NoImages` accept a response without a content type.
    #[test]
    fn test_cache_strategy_missing_content_type() {
        assert!(CacheStrategy::NoImages.allows_content_type(None));
        assert!(!CacheStrategy::None.allows_content_type(None));
        assert!(!CacheStrategy::OnlyHtml.allows_content_type(None));
        assert!(!CacheStrategy::OnlyImages.allows_content_type(None));
    }

    #[test]
    fn test_compress_strategy_display() {
        assert_eq!(CompressStrategy::default().to_string(), "brotli");
        assert_eq!(CompressStrategy::None.to_string(), "none");
        assert_eq!(CompressStrategy::Gzip.to_string(), "gzip");
        assert_eq!(CompressStrategy::Deflate.to_string(), "deflate");
    }

    #[test]
    fn test_cache_strategy_and_storage_mode_display() {
        assert_eq!(CacheStrategy::default().to_string(), "all");
        assert_eq!(CacheStrategy::None.to_string(), "none");
        assert_eq!(CacheStrategy::OnlyHtml.to_string(), "only_html");
        assert_eq!(CacheStrategy::NoImages.to_string(), "no_images");
        assert_eq!(CacheStrategy::OnlyImages.to_string(), "only_images");
        assert_eq!(CacheStrategy::OnlyAssets.to_string(), "only_assets");
        assert_eq!(CacheStorageMode::default().to_string(), "memory");
        assert_eq!(CacheStorageMode::Filesystem.to_string(), "filesystem");
    }

    #[test]
    fn test_enum_defaults() {
        assert_eq!(WebhookType::default(), WebhookType::Notify);
        assert!(matches!(ProxyMode::default(), ProxyMode::Dynamic));
    }

    /// The default key is `METHOD:path`, plus `?query` only when non-empty.
    #[test]
    fn test_default_cache_key_fn() {
        let config = CreateProxyConfig::new("http://localhost:8080".to_string());
        let headers = axum::http::HeaderMap::new();

        let plain = RequestInfo {
            method: "GET",
            path: "/a",
            query: "",
            headers: &headers,
        };
        assert_eq!((config.cache_key_fn)(&plain), "GET:/a");

        let with_query = RequestInfo {
            method: "POST",
            path: "/a",
            query: "x=1",
            headers: &headers,
        };
        assert_eq!((config.cache_key_fn)(&with_query), "POST:/a?x=1");
    }

    #[test]
    fn test_builder_aliases() {
        let config = CreateProxyConfig::new("http://localhost:8080".to_string())
            .caching_strategy(CacheStrategy::OnlyHtml)
            .compression_strategy(CompressStrategy::Gzip)
            .with_cache_directory("/tmp/proxy-cache");
        assert_eq!(config.cache_strategy, CacheStrategy::OnlyHtml);
        assert_eq!(config.compress_strategy, CompressStrategy::Gzip);
        assert_eq!(
            config.cache_directory.as_deref(),
            Some(std::path::Path::new("/tmp/proxy-cache"))
        );
    }

    #[tokio::test]
    async fn test_create_proxy() {
        let config = CreateProxyConfig::new("http://localhost:8080".to_string());
        assert_eq!(config.compress_strategy, CompressStrategy::Brotli);
        let (_app, handle) = create_proxy(config);
        handle.invalidate_all();
        handle.invalidate("GET:/api/*");
    }
}