axum_conf/fluent/
control.rs

1//! Traffic control middleware: rate limiting and panic catching.
2
3use super::router::FluentRouter;
4use crate::HttpMiddleware;
5
6use {
7    http::{Response, StatusCode},
8    tower_http::catch_panic::CatchPanicLayer,
9};
10
11#[cfg(feature = "rate-limiting")]
12use {
13    std::time::Duration,
14    tokio_util::task::AbortOnDropHandle,
15    tower_governor::{GovernorLayer, governor::GovernorConfigBuilder},
16};
17
18impl<State> FluentRouter<State>
19where
20    State: Clone + Send + Sync + 'static,
21{
22    /// Sets up per-IP rate limiting middleware.
23    ///
24    /// Limits the number of requests per second from each IP address. When the
25    /// limit is exceeded, requests receive a `429 Too Many Requests` response.
26    ///
27    /// # Configuration
28    ///
29    /// ```toml
30    /// [http]
31    /// max_requests_per_sec = 100  # Default
32    /// ```
33    ///
34    /// # Implementation
35    ///
36    /// Uses the token bucket algorithm:
37    /// - Each IP gets a bucket with tokens
38    /// - Each request consumes one token
39    /// - Tokens refill at the configured rate
40    ///
41    /// # Notes
42    ///
43    /// Rate limiting is per IP address. Behind a reverse proxy, ensure the
44    /// client's real IP is forwarded correctly.
45    ///
46    /// This middleware is automatically included in `setup_middleware()` as one of the
47    /// outermost layers to reject excessive traffic early.
48    #[cfg(feature = "rate-limiting")]
49    #[must_use]
50    pub fn setup_rate_limiting(mut self) -> Self {
51        // Skip rate limiting if max_requests_per_sec is 0
52        // This is useful for tests using oneshot() which don't have ConnectInfo<SocketAddr>
53        if self.config.http.max_requests_per_sec > 0
54            && self.is_middleware_enabled(HttpMiddleware::RateLimiting)
55        {
56            tracing::trace!(
57                max_requests_per_sec = self.config.http.max_requests_per_sec,
58                "RateLimiting middleware enabled"
59            );
60            // Used for rate limiting below
61            let governor_conf = Box::new(
62                GovernorConfigBuilder::default()
63                    .per_nanosecond((1_000_000_000 / self.config.http.max_requests_per_sec) as u64)
64                    .burst_size(self.config.http.max_requests_per_sec)
65                    .finish()
66                    .expect("Failed to build governor config for rate limiting"),
67            );
68
69            // Spawn a background thread to periodically clean up old entries
70            let governor_limiter = governor_conf.limiter().clone();
71            let interval = Duration::from_secs(60);
72
73            // Spawn a background task to clean up old entries
74            let handle = tokio::spawn(async move {
75                loop {
76                    tokio::time::sleep(interval).await;
77                    governor_limiter.retain_recent();
78                    if !governor_limiter.is_empty() {
79                        tracing::debug!("remaining rate storage size: {}", governor_limiter.len());
80                    }
81                }
82            });
83
84            // Wrap the handle so that it gets cancelled when the router is dropped
85            self.governor_handle = Some(AbortOnDropHandle::new(handle));
86
87            // Add the GovernorLayer for rate limiting
88            self.inner = self.inner.layer(GovernorLayer::new(governor_conf));
89        }
90        self
91    }
92
93    /// No-op when `rate-limiting` feature is disabled.
94    #[cfg(not(feature = "rate-limiting"))]
95    #[must_use]
96    pub fn setup_rate_limiting(self) -> Self {
97        if self.config.http.max_requests_per_sec > 0 {
98            tracing::warn!(
99                "Rate limiting is configured but the 'rate-limiting' feature is not enabled. \
100                 Add `rate-limiting` to your Cargo.toml features to enable rate limiting support."
101            );
102        }
103        self
104    }
105
106    /// Sets up panic catching middleware.
107    ///
108    /// Catches panics in request handlers and returns a `500 Internal Server Error`
109    /// response instead of crashing the server. Optionally sends panic details to
110    /// a notification channel if configured with `with_panic_notification_channel()`.
111    ///
112    /// # Panic Handling
113    ///
114    /// When a handler panics:
115    /// 1. Panic is caught before it crashes the server
116    /// 2. Client receives 500 response
117    /// 3. Panic message is sent to notification channel (if configured)
118    /// 4. Server continues running
119    ///
120    /// # Examples
121    ///
122    /// ```rust,no_run
123    /// # use axum_conf::{Config, FluentRouter};
124    /// # async fn example() -> axum_conf::Result<()> {
125    /// let (tx, rx) = tokio::sync::mpsc::channel(100);
126    ///
127    /// FluentRouter::without_state(Config::default())?
128    ///     .with_panic_notification_channel(tx)
129    ///     .setup_catch_panic();
130    /// # Ok(())
131    /// # }
132    /// ```
133    ///
134    /// # Production Use
135    ///
136    /// Essential for production to prevent panics from taking down the server.
137    /// That's why this middleware cannot be disabled.
138    ///
139    /// This middleware is automatically included in `setup_middleware()` as the
140    /// outermost layer to ensure ALL panics are caught.
141    #[must_use]
142    pub fn setup_catch_panic(mut self) -> Self {
143        // Note: Panic catching is critical and should generally not be disabled
144        // But we still respect the configuration for testing purposes
145        if !self.is_middleware_enabled(HttpMiddleware::CatchPanic) {
146            tracing::trace!("CatchPanic middleware skipped (disabled in config)");
147            return self;
148        }
149
150        tracing::trace!("CatchPanic middleware enabled");
151        let panic_channel = self.panic_channel.clone();
152        self.inner = self.inner.layer(CatchPanicLayer::custom(
153            move |err: Box<dyn std::any::Any + Send + 'static>| {
154                // NOTE: taken verbatime from the source of DefaultResponseForPanic
155                let msg = if let Some(s) = err.downcast_ref::<String>() {
156                    format!("Service panicked: {}", s)
157                } else if let Some(s) = err.downcast_ref::<&str>() {
158                    format!("Service panicked: {}", s)
159                } else {
160                    "`CatchPanic` was unable to downcast the panic info".to_string()
161                };
162
163                tracing::error!("Service panicked: {}", msg);
164                if let Some(ch) = &panic_channel {
165                    ch.try_send(msg).ok();
166                }
167
168                // Build the final response - use unwrap_or_else to avoid panicking
169                // inside the panic handler
170                Response::builder()
171                    .status(StatusCode::INTERNAL_SERVER_ERROR)
172                    .header(http::header::CONTENT_TYPE, "text/plain; charset=utf-8")
173                    .body("Internal Server Error".to_string())
174                    .unwrap_or_else(|_| Response::new("Internal Server Error".to_string()))
175            },
176        ));
177        self
178    }
179}