// The TLS backend is chosen through Cargo features. Turning on both backends at once is a
// configuration error, so the build is stopped early with an explicit message.
#[cfg(all(feature = "rustls", feature = "native-tls"))]
compile_error!("Features `native-tls` and `rustls` are mutually exclusive — enable only one.");
3
4pub mod cache;
5pub mod compression;
6pub mod config;
7pub mod control;
8pub mod path_matcher;
9pub mod proxy;
10
11use axum::{extract::Extension, Router};
12use cache::{CacheHandle, CacheStore};
13use proxy::ProxyState;
14use serde::{Deserialize, Serialize};
15use std::path::PathBuf;
16use std::sync::Arc;
17use tokio::sync::mpsc;
18
19#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
21#[serde(rename_all = "snake_case")]
22pub enum CacheStrategy {
23 #[default]
25 All,
26 None,
28 OnlyHtml,
30 NoImages,
32 OnlyImages,
34 OnlyAssets,
36}
37
38impl CacheStrategy {
39 pub fn allows_content_type(&self, content_type: Option<&str>) -> bool {
41 let content_type = content_type
42 .and_then(|value| value.split(';').next())
43 .map(|value| value.trim().to_ascii_lowercase());
44
45 match self {
46 Self::All => true,
47 Self::None => false,
48 Self::OnlyHtml => content_type
49 .as_deref()
50 .is_some_and(|value| value == "text/html" || value == "application/xhtml+xml"),
51 Self::NoImages => !content_type
52 .as_deref()
53 .is_some_and(|value| value.starts_with("image/")),
54 Self::OnlyImages => content_type
55 .as_deref()
56 .is_some_and(|value| value.starts_with("image/")),
57 Self::OnlyAssets => content_type.as_deref().is_some_and(|value| {
58 value.starts_with("image/")
59 || value.starts_with("font/")
60 || value == "text/css"
61 || value == "text/javascript"
62 || value == "application/javascript"
63 || value == "application/x-javascript"
64 || value == "application/json"
65 || value == "application/manifest+json"
66 || value == "application/wasm"
67 || value == "application/xml"
68 || value == "text/xml"
69 }),
70 }
71 }
72}
73
74impl std::fmt::Display for CacheStrategy {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 let value = match self {
77 Self::All => "all",
78 Self::None => "none",
79 Self::OnlyHtml => "only_html",
80 Self::NoImages => "no_images",
81 Self::OnlyImages => "only_images",
82 Self::OnlyAssets => "only_assets",
83 };
84
85 f.write_str(value)
86 }
87}
88
89#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
91#[serde(rename_all = "snake_case")]
92pub enum CompressStrategy {
93 None,
95 #[default]
97 Brotli,
98 Gzip,
100 Deflate,
102}
103
104impl std::fmt::Display for CompressStrategy {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 let value = match self {
107 Self::None => "none",
108 Self::Brotli => "brotli",
109 Self::Gzip => "gzip",
110 Self::Deflate => "deflate",
111 };
112
113 f.write_str(value)
114 }
115}
116
117#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
119#[serde(rename_all = "snake_case")]
120pub enum CacheStorageMode {
121 #[default]
123 Memory,
124 Filesystem,
126}
127
128impl std::fmt::Display for CacheStorageMode {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 let value = match self {
131 Self::Memory => "memory",
132 Self::Filesystem => "filesystem",
133 };
134
135 f.write_str(value)
136 }
137}
138
/// Kind of a configured webhook, read from the `type` key of [`WebhookConfig`].
///
/// Serialized in `snake_case`: `blocking`, `notify` or `cache_key`.
/// NOTE(review): the behaviour of each kind is implemented outside this file — confirm there.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum WebhookType {
    /// Presumably a webhook the request waits on before continuing — confirm in `proxy`.
    Blocking,
    /// The default kind; presumably a fire-and-forget notification — confirm in `proxy`.
    #[default]
    Notify,
    /// Presumably a webhook involved in computing the cache key — confirm in `proxy`.
    CacheKey,
}
155
/// A webhook endpoint together with how it is invoked.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct WebhookConfig {
    /// URL of the webhook endpoint.
    pub url: String,

    /// Webhook kind, (de)serialized under the key `type`.
    /// Falls back to [`WebhookType::Notify`] (the enum's default) when the key is absent.
    #[serde(rename = "type", default)]
    pub webhook_type: WebhookType,

    /// Optional timeout; `None` when the key is omitted.
    /// The `_ms` suffix suggests milliseconds — NOTE(review): the fallback used when unset
    /// is decided outside this file.
    #[serde(default)]
    pub timeout_ms: Option<u64>,
}
171
/// How the proxy obtains the content it serves.
#[derive(Clone, Debug, Default)]
pub enum ProxyMode {
    /// The default mode; no snapshot worker is spawned.
    /// NOTE(review): request handling lives in `proxy` — presumably forwards upstream and
    /// caches on demand.
    #[default]
    Dynamic,
    /// Pre-generated snapshots.
    ///
    /// `create_proxy` spawns a snapshot worker that fetches every entry of `paths` at startup.
    /// The worker then handles snapshot requests sent through the returned `CacheHandle`.
    PreGenerate {
        /// Paths fetched and cached when the snapshot worker starts.
        paths: Vec<String>,
        /// Not read in this file; presumably controls whether requests outside the snapshot
        /// set fall through to the upstream server — confirm in `proxy`.
        fallthrough: bool,
    },
}
197
/// Borrowed view of a request, passed to `CreateProxyConfig::cache_key_fn` to build cache keys.
#[derive(Clone, Debug)]
pub struct RequestInfo<'a> {
    /// HTTP method, e.g. `"GET"`.
    pub method: &'a str,
    /// Request path.
    pub path: &'a str,
    /// Query string, empty when there is none.
    /// The default key function re-adds a `?`, so it is expected without the leading `?`.
    pub query: &'a str,
    /// Request headers.
    pub headers: &'a axum::http::HeaderMap,
}
210
/// Configuration accepted by [`create_proxy`] and [`create_proxy_with_handle`].
///
/// Start from [`CreateProxyConfig::new`] and adjust it with the `with_*` builder methods.
/// Defaults quoted below are the values set by `new`.
#[derive(Clone)]
pub struct CreateProxyConfig {
    /// Upstream URL; the snapshot worker also fetches snapshots from it.
    pub proxy_url: String,

    /// Path patterns to include (default empty).
    /// NOTE(review): presumably matched via `path_matcher`; confirm what an empty list means.
    pub include_paths: Vec<String>,

    /// Path patterns to exclude (default empty).
    /// NOTE(review): presumably matched via `path_matcher` — confirm.
    pub exclude_paths: Vec<String>,

    /// WebSocket support switch (default `true`); handled in `proxy`.
    pub enable_websocket: bool,

    /// Default `false`. NOTE(review): presumably restricts forwarding to GET requests —
    /// confirm in `proxy`.
    pub forward_get_only: bool,

    /// Computes the cache key for a request.
    /// Default: `"{method}:{path}"`, or `"{method}:{path}?{query}"` when the query is non-empty.
    pub cache_key_fn: Arc<dyn Fn(&RequestInfo) -> String + Send + Sync>,
    /// Capacity handed to `CacheStore::with_storage` (default 100).
    /// Presumably bounds the number of cached 404 responses.
    pub cache_404_capacity: usize,

    /// Default `false`. NOTE(review): meaning defined in `proxy` — confirm.
    pub use_404_meta: bool,

    /// Which responses may be cached (default [`CacheStrategy::All`]).
    pub cache_strategy: CacheStrategy,

    /// Compression algorithm (default [`CompressStrategy::Brotli`]).
    pub compress_strategy: CompressStrategy,

    /// Cache storage backend (default [`CacheStorageMode::Memory`]).
    pub cache_storage_mode: CacheStorageMode,

    /// Directory handed to the cache store (default `None`).
    /// Presumably used by [`CacheStorageMode::Filesystem`].
    pub cache_directory: Option<PathBuf>,

    /// Dynamic proxying or pre-generated snapshots (default [`ProxyMode::Dynamic`]).
    pub proxy_mode: ProxyMode,

    /// Configured webhooks (default none).
    pub webhooks: Vec<WebhookConfig>,
}
266
267impl CreateProxyConfig {
268 pub fn new(proxy_url: String) -> Self {
270 Self {
271 proxy_url,
272 include_paths: vec![],
273 exclude_paths: vec![],
274 enable_websocket: true,
275 forward_get_only: false,
276 cache_key_fn: Arc::new(|req_info| {
277 if req_info.query.is_empty() {
278 format!("{}:{}", req_info.method, req_info.path)
279 } else {
280 format!("{}:{}?{}", req_info.method, req_info.path, req_info.query)
281 }
282 }),
283 cache_404_capacity: 100,
284 use_404_meta: false,
285 cache_strategy: CacheStrategy::All,
286 compress_strategy: CompressStrategy::Brotli,
287 cache_storage_mode: CacheStorageMode::Memory,
288 cache_directory: None,
289 proxy_mode: ProxyMode::Dynamic,
290 webhooks: vec![],
291 }
292 }
293
294 pub fn with_include_paths(mut self, paths: Vec<String>) -> Self {
296 self.include_paths = paths;
297 self
298 }
299
300 pub fn with_exclude_paths(mut self, paths: Vec<String>) -> Self {
302 self.exclude_paths = paths;
303 self
304 }
305
306 pub fn with_websocket_enabled(mut self, enabled: bool) -> Self {
308 self.enable_websocket = enabled;
309 self
310 }
311
312 pub fn with_forward_get_only(mut self, enabled: bool) -> Self {
314 self.forward_get_only = enabled;
315 self
316 }
317
318 pub fn with_cache_key_fn<F>(mut self, f: F) -> Self
320 where
321 F: Fn(&RequestInfo) -> String + Send + Sync + 'static,
322 {
323 self.cache_key_fn = Arc::new(f);
324 self
325 }
326
327 pub fn with_cache_404_capacity(mut self, capacity: usize) -> Self {
329 self.cache_404_capacity = capacity;
330 self
331 }
332
333 pub fn with_use_404_meta(mut self, enabled: bool) -> Self {
335 self.use_404_meta = enabled;
336 self
337 }
338
339 pub fn with_cache_strategy(mut self, strategy: CacheStrategy) -> Self {
341 self.cache_strategy = strategy;
342 self
343 }
344
345 pub fn caching_strategy(self, strategy: CacheStrategy) -> Self {
347 self.with_cache_strategy(strategy)
348 }
349
350 pub fn with_compress_strategy(mut self, strategy: CompressStrategy) -> Self {
352 self.compress_strategy = strategy;
353 self
354 }
355
356 pub fn compression_strategy(self, strategy: CompressStrategy) -> Self {
358 self.with_compress_strategy(strategy)
359 }
360
361 pub fn with_cache_storage_mode(mut self, mode: CacheStorageMode) -> Self {
363 self.cache_storage_mode = mode;
364 self
365 }
366
367 pub fn with_cache_directory(mut self, directory: impl Into<PathBuf>) -> Self {
369 self.cache_directory = Some(directory.into());
370 self
371 }
372
373 pub fn with_proxy_mode(mut self, mode: ProxyMode) -> Self {
376 self.proxy_mode = mode;
377 self
378 }
379
380 pub fn with_webhooks(mut self, webhooks: Vec<WebhookConfig>) -> Self {
383 self.webhooks = webhooks;
384 self
385 }
386}
387
388pub fn create_proxy(config: CreateProxyConfig) -> (Router, CacheHandle) {
391 let (handle, snapshot_rx) = if let ProxyMode::PreGenerate { .. } = &config.proxy_mode {
393 let (tx, rx) = mpsc::channel(32);
394 (CacheHandle::new_with_snapshots(tx), Some(rx))
395 } else {
396 (CacheHandle::new(), None)
397 };
398
399 let cache = CacheStore::with_storage(
400 handle.clone(),
401 config.cache_404_capacity,
402 config.cache_storage_mode.clone(),
403 config.cache_directory.clone(),
404 );
405
406 spawn_invalidation_listener(cache.clone());
408
409 if let (Some(rx), ProxyMode::PreGenerate { paths, .. }) =
411 (snapshot_rx, &config.proxy_mode)
412 {
413 let worker = SnapshotWorker {
414 rx,
415 cache: cache.clone(),
416 proxy_url: config.proxy_url.clone(),
417 compress_strategy: config.compress_strategy.clone(),
418 cache_key_fn: config.cache_key_fn.clone(),
419 snapshots: paths.clone(),
420 };
421 tokio::spawn(worker.run());
422 }
423
424 let proxy_state = Arc::new(ProxyState::new(cache, config));
425
426 let app = Router::new()
427 .fallback(proxy::proxy_handler)
428 .layer(Extension(proxy_state));
429
430 (app, handle)
431}
432
433pub fn create_proxy_with_handle(config: CreateProxyConfig, handle: CacheHandle) -> Router {
440 let cache = CacheStore::with_storage(
441 handle,
442 config.cache_404_capacity,
443 config.cache_storage_mode.clone(),
444 config.cache_directory.clone(),
445 );
446
447 spawn_invalidation_listener(cache.clone());
449
450 let proxy_state = Arc::new(ProxyState::new(cache, config));
451
452 Router::new()
453 .fallback(proxy::proxy_handler)
454 .layer(Extension(proxy_state))
455}
456
457fn spawn_invalidation_listener(cache: CacheStore) {
459 let mut receiver = cache.handle().subscribe();
460
461 tokio::spawn(async move {
462 loop {
463 match receiver.recv().await {
464 Ok(cache::InvalidationMessage::All) => {
465 tracing::debug!("Cache invalidation triggered: clearing all entries");
466 cache.clear().await;
467 }
468 Ok(cache::InvalidationMessage::Pattern(pattern)) => {
469 tracing::debug!(
470 "Cache invalidation triggered: clearing entries matching pattern '{}'",
471 pattern
472 );
473 cache.clear_by_pattern(&pattern).await;
474 }
475 Err(e) => {
476 tracing::error!("Invalidation channel error: {}", e);
477 break;
478 }
479 }
480 }
481 });
482}
483
/// State of the background task that manages pre-generated snapshots in
/// [`ProxyMode::PreGenerate`].
///
/// It fetches snapshot paths through `proxy::fetch_and_cache_snapshot` and applies snapshot
/// requests received from the [`CacheHandle`].
struct SnapshotWorker {
    /// Snapshot requests sent through the `CacheHandle` created in `create_proxy`.
    rx: mpsc::Receiver<cache::SnapshotRequest>,
    /// Cache that fetched snapshots are stored in and removed from.
    cache: CacheStore,
    /// Upstream URL that snapshots are fetched from.
    proxy_url: String,
    /// Forwarded to `proxy::fetch_and_cache_snapshot`.
    compress_strategy: CompressStrategy,
    /// The proxy's cache-key function.
    /// Used when storing snapshots and to compute the key cleared on removal.
    cache_key_fn: Arc<dyn Fn(&RequestInfo) -> String + Send + Sync>,
    /// Paths currently tracked as snapshots, seeded from `ProxyMode::PreGenerate::paths`.
    snapshots: Vec<String>,
}
495
496impl SnapshotWorker {
497 async fn run(mut self) {
498 let initial = self.snapshots.clone();
500 for path in &initial {
501 if let Err(e) = self.fetch_and_store(path).await {
502 tracing::warn!("Failed to pre-generate snapshot '{}': {}", path, e);
503 }
504 }
505
506 while let Some(req) = self.rx.recv().await {
508 match req.op {
509 cache::SnapshotOp::Add(path) => {
510 match self.fetch_and_store(&path).await {
511 Ok(()) => self.snapshots.push(path),
512 Err(e) => tracing::warn!("add_snapshot '{}' failed: {}", path, e),
513 }
514 }
515 cache::SnapshotOp::Refresh(path) => {
516 if let Err(e) = self.fetch_and_store(&path).await {
517 tracing::warn!("refresh_snapshot '{}' failed: {}", path, e);
518 }
519 }
520 cache::SnapshotOp::Remove(path) => {
521 let empty_headers = axum::http::HeaderMap::new();
522 let req_info = RequestInfo {
523 method: "GET",
524 path: &path,
525 query: "",
526 headers: &empty_headers,
527 };
528 let key = (self.cache_key_fn)(&req_info);
529 self.cache.clear_by_pattern(&key).await;
530 self.snapshots.retain(|s| s != &path);
531 }
532 cache::SnapshotOp::RefreshAll => {
533 let paths: Vec<String> = self.snapshots.clone();
534 for path in &paths {
535 if let Err(e) = self.fetch_and_store(path).await {
536 tracing::warn!("refresh_all_snapshots '{}' failed: {}", path, e);
537 }
538 }
539 }
540 }
541 let _ = req.done.send(());
543 }
544 }
545
546 async fn fetch_and_store(&self, path: &str) -> anyhow::Result<()> {
547 proxy::fetch_and_cache_snapshot(
548 path,
549 &self.proxy_url,
550 &self.cache,
551 &self.compress_strategy,
552 &self.cache_key_fn,
553 )
554 .await
555 }
556}
557
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_strategy_content_types() {
        assert!(CacheStrategy::All.allows_content_type(None));
        assert!(!CacheStrategy::None.allows_content_type(Some("text/html")));
        assert!(CacheStrategy::OnlyHtml.allows_content_type(Some("text/html; charset=utf-8")));
        assert!(!CacheStrategy::OnlyHtml.allows_content_type(Some("image/png")));
        assert!(CacheStrategy::NoImages.allows_content_type(Some("text/css")));
        assert!(!CacheStrategy::NoImages.allows_content_type(Some("image/webp")));
        assert!(CacheStrategy::OnlyImages.allows_content_type(Some("image/svg+xml")));
        assert!(!CacheStrategy::OnlyImages.allows_content_type(Some("application/javascript")));
        assert!(CacheStrategy::OnlyAssets.allows_content_type(Some("application/javascript")));
        assert!(CacheStrategy::OnlyAssets.allows_content_type(Some("image/png")));
        assert!(!CacheStrategy::OnlyAssets.allows_content_type(Some("text/html")));
        assert!(!CacheStrategy::OnlyAssets.allows_content_type(None));
    }

    #[test]
    fn test_cache_strategy_normalizes_content_type() {
        // Parameters, surrounding whitespace and letter case are ignored.
        assert!(CacheStrategy::OnlyHtml.allows_content_type(Some("  Text/HTML ; charset=utf-8")));
        assert!(CacheStrategy::OnlyHtml.allows_content_type(Some("application/xhtml+xml")));
        assert!(!CacheStrategy::NoImages.allows_content_type(Some("IMAGE/PNG")));
        assert!(CacheStrategy::OnlyAssets.allows_content_type(Some("font/woff2")));
        // A missing content type is accepted only by the permissive strategies.
        assert!(CacheStrategy::NoImages.allows_content_type(None));
        assert!(!CacheStrategy::None.allows_content_type(None));
        assert!(!CacheStrategy::OnlyHtml.allows_content_type(None));
        assert!(!CacheStrategy::OnlyImages.allows_content_type(None));
    }

    #[test]
    fn test_cache_strategy_and_storage_mode_display() {
        assert_eq!(CacheStrategy::default().to_string(), "all");
        assert_eq!(CacheStrategy::None.to_string(), "none");
        assert_eq!(CacheStrategy::OnlyHtml.to_string(), "only_html");
        assert_eq!(CacheStrategy::NoImages.to_string(), "no_images");
        assert_eq!(CacheStrategy::OnlyImages.to_string(), "only_images");
        assert_eq!(CacheStrategy::OnlyAssets.to_string(), "only_assets");
        assert_eq!(CacheStorageMode::default().to_string(), "memory");
        assert_eq!(CacheStorageMode::Filesystem.to_string(), "filesystem");
    }

    #[test]
    fn test_compress_strategy_display() {
        assert_eq!(CompressStrategy::default().to_string(), "brotli");
        assert_eq!(CompressStrategy::Brotli.to_string(), "brotli");
        assert_eq!(CompressStrategy::None.to_string(), "none");
        assert_eq!(CompressStrategy::Gzip.to_string(), "gzip");
        assert_eq!(CompressStrategy::Deflate.to_string(), "deflate");
    }

    #[test]
    fn test_default_cache_key_fn() {
        let config = CreateProxyConfig::new("http://localhost:8080".to_string());
        let headers = axum::http::HeaderMap::new();
        let without_query = RequestInfo {
            method: "GET",
            path: "/a",
            query: "",
            headers: &headers,
        };
        assert_eq!((config.cache_key_fn)(&without_query), "GET:/a");
        let with_query = RequestInfo {
            query: "x=1",
            ..without_query
        };
        assert_eq!((config.cache_key_fn)(&with_query), "GET:/a?x=1");
    }

    #[test]
    fn test_builder_methods_override_defaults() {
        let config = CreateProxyConfig::new("http://localhost:8080".to_string())
            .with_include_paths(vec!["/api/*".to_string()])
            .with_exclude_paths(vec!["/admin".to_string()])
            .with_websocket_enabled(false)
            .with_forward_get_only(true)
            .with_cache_404_capacity(5)
            .with_use_404_meta(true)
            .caching_strategy(CacheStrategy::OnlyHtml)
            .compression_strategy(CompressStrategy::Gzip)
            .with_cache_storage_mode(CacheStorageMode::Filesystem)
            .with_cache_directory("/tmp/cache")
            .with_proxy_mode(ProxyMode::PreGenerate {
                paths: vec!["/".to_string()],
                fallthrough: true,
            })
            .with_webhooks(vec![WebhookConfig {
                url: "http://hooks.local".to_string(),
                webhook_type: WebhookType::Blocking,
                timeout_ms: Some(250),
            }])
            .with_cache_key_fn(|req| req.path.to_string());

        assert_eq!(config.include_paths, vec!["/api/*".to_string()]);
        assert_eq!(config.exclude_paths, vec!["/admin".to_string()]);
        assert!(!config.enable_websocket);
        assert!(config.forward_get_only);
        assert_eq!(config.cache_404_capacity, 5);
        assert!(config.use_404_meta);
        assert_eq!(config.cache_strategy, CacheStrategy::OnlyHtml);
        assert_eq!(config.compress_strategy, CompressStrategy::Gzip);
        assert_eq!(config.cache_storage_mode, CacheStorageMode::Filesystem);
        assert_eq!(config.cache_directory, Some(PathBuf::from("/tmp/cache")));
        assert!(matches!(
            &config.proxy_mode,
            ProxyMode::PreGenerate { paths, fallthrough: true } if paths.len() == 1
        ));
        assert_eq!(config.webhooks.len(), 1);
        assert_eq!(config.webhooks[0].webhook_type, WebhookType::Blocking);

        let headers = axum::http::HeaderMap::new();
        let info = RequestInfo {
            method: "GET",
            path: "/x",
            query: "q=1",
            headers: &headers,
        };
        assert_eq!((config.cache_key_fn)(&info), "/x");
    }

    #[tokio::test]
    async fn test_create_proxy() {
        let config = CreateProxyConfig::new("http://localhost:8080".to_string());
        assert_eq!(config.compress_strategy, CompressStrategy::Brotli);
        let (_app, handle) = create_proxy(config);
        handle.invalidate_all();
        handle.invalidate("GET:/api/*");
    }
}