axum_conf/fluent/control.rs
1//! Traffic control middleware: rate limiting and panic catching.
2
3use super::router::FluentRouter;
4use crate::HttpMiddleware;
5
6use {
7 http::{Response, StatusCode},
8 tower_http::catch_panic::CatchPanicLayer,
9};
10
11#[cfg(feature = "rate-limiting")]
12use {
13 std::time::Duration,
14 tokio_util::task::AbortOnDropHandle,
15 tower_governor::{GovernorLayer, governor::GovernorConfigBuilder},
16};
17
18impl<State> FluentRouter<State>
19where
20 State: Clone + Send + Sync + 'static,
21{
22 /// Sets up per-IP rate limiting middleware.
23 ///
24 /// Limits the number of requests per second from each IP address. When the
25 /// limit is exceeded, requests receive a `429 Too Many Requests` response.
26 ///
27 /// # Configuration
28 ///
29 /// ```toml
30 /// [http]
31 /// max_requests_per_sec = 100 # Default
32 /// ```
33 ///
34 /// # Implementation
35 ///
36 /// Uses the token bucket algorithm:
37 /// - Each IP gets a bucket with tokens
38 /// - Each request consumes one token
39 /// - Tokens refill at the configured rate
40 ///
41 /// # Notes
42 ///
43 /// Rate limiting is per IP address. Behind a reverse proxy, ensure the
44 /// client's real IP is forwarded correctly.
45 ///
46 /// This middleware is automatically included in `setup_middleware()` as one of the
47 /// outermost layers to reject excessive traffic early.
48 #[cfg(feature = "rate-limiting")]
49 #[must_use]
50 pub fn setup_rate_limiting(mut self) -> Self {
51 // Skip rate limiting if max_requests_per_sec is 0
52 // This is useful for tests using oneshot() which don't have ConnectInfo<SocketAddr>
53 if self.config.http.max_requests_per_sec > 0
54 && self.is_middleware_enabled(HttpMiddleware::RateLimiting)
55 {
56 tracing::trace!(
57 max_requests_per_sec = self.config.http.max_requests_per_sec,
58 "RateLimiting middleware enabled"
59 );
60 // Used for rate limiting below
61 let governor_conf = Box::new(
62 GovernorConfigBuilder::default()
63 .per_nanosecond((1_000_000_000 / self.config.http.max_requests_per_sec) as u64)
64 .burst_size(self.config.http.max_requests_per_sec)
65 .finish()
66 .expect("Failed to build governor config for rate limiting"),
67 );
68
69 // Spawn a background thread to periodically clean up old entries
70 let governor_limiter = governor_conf.limiter().clone();
71 let interval = Duration::from_secs(60);
72
73 // Spawn a background task to clean up old entries
74 let handle = tokio::spawn(async move {
75 loop {
76 tokio::time::sleep(interval).await;
77 governor_limiter.retain_recent();
78 if !governor_limiter.is_empty() {
79 tracing::debug!("remaining rate storage size: {}", governor_limiter.len());
80 }
81 }
82 });
83
84 // Wrap the handle so that it gets cancelled when the router is dropped
85 self.governor_handle = Some(AbortOnDropHandle::new(handle));
86
87 // Add the GovernorLayer for rate limiting
88 self.inner = self.inner.layer(GovernorLayer::new(governor_conf));
89 }
90 self
91 }
92
93 /// No-op when `rate-limiting` feature is disabled.
94 #[cfg(not(feature = "rate-limiting"))]
95 #[must_use]
96 pub fn setup_rate_limiting(self) -> Self {
97 if self.config.http.max_requests_per_sec > 0 {
98 tracing::warn!(
99 "Rate limiting is configured but the 'rate-limiting' feature is not enabled. \
100 Add `rate-limiting` to your Cargo.toml features to enable rate limiting support."
101 );
102 }
103 self
104 }
105
106 /// Sets up panic catching middleware.
107 ///
108 /// Catches panics in request handlers and returns a `500 Internal Server Error`
109 /// response instead of crashing the server. Optionally sends panic details to
110 /// a notification channel if configured with `with_panic_notification_channel()`.
111 ///
112 /// # Panic Handling
113 ///
114 /// When a handler panics:
115 /// 1. Panic is caught before it crashes the server
116 /// 2. Client receives 500 response
117 /// 3. Panic message is sent to notification channel (if configured)
118 /// 4. Server continues running
119 ///
120 /// # Examples
121 ///
122 /// ```rust,no_run
123 /// # use axum_conf::{Config, FluentRouter};
124 /// # async fn example() -> axum_conf::Result<()> {
125 /// let (tx, rx) = tokio::sync::mpsc::channel(100);
126 ///
127 /// FluentRouter::without_state(Config::default())?
128 /// .with_panic_notification_channel(tx)
129 /// .setup_catch_panic();
130 /// # Ok(())
131 /// # }
132 /// ```
133 ///
134 /// # Production Use
135 ///
136 /// Essential for production to prevent panics from taking down the server.
137 /// That's why this middleware cannot be disabled.
138 ///
139 /// This middleware is automatically included in `setup_middleware()` as the
140 /// outermost layer to ensure ALL panics are caught.
141 #[must_use]
142 pub fn setup_catch_panic(mut self) -> Self {
143 // Note: Panic catching is critical and should generally not be disabled
144 // But we still respect the configuration for testing purposes
145 if !self.is_middleware_enabled(HttpMiddleware::CatchPanic) {
146 tracing::trace!("CatchPanic middleware skipped (disabled in config)");
147 return self;
148 }
149
150 tracing::trace!("CatchPanic middleware enabled");
151 let panic_channel = self.panic_channel.clone();
152 self.inner = self.inner.layer(CatchPanicLayer::custom(
153 move |err: Box<dyn std::any::Any + Send + 'static>| {
154 // NOTE: taken verbatime from the source of DefaultResponseForPanic
155 let msg = if let Some(s) = err.downcast_ref::<String>() {
156 format!("Service panicked: {}", s)
157 } else if let Some(s) = err.downcast_ref::<&str>() {
158 format!("Service panicked: {}", s)
159 } else {
160 "`CatchPanic` was unable to downcast the panic info".to_string()
161 };
162
163 tracing::error!("Service panicked: {}", msg);
164 if let Some(ch) = &panic_channel {
165 ch.try_send(msg).ok();
166 }
167
168 // Build the final response - use unwrap_or_else to avoid panicking
169 // inside the panic handler
170 Response::builder()
171 .status(StatusCode::INTERNAL_SERVER_ERROR)
172 .header(http::header::CONTENT_TYPE, "text/plain; charset=utf-8")
173 .body("Internal Server Error".to_string())
174 .unwrap_or_else(|_| Response::new("Internal Server Error".to_string()))
175 },
176 ));
177 self
178 }
179}