tibba_middleware/common.rs
1// Copyright 2026 Tree xie.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::{Error, HeaderValueSnafu, LOG_TARGET};
16use axum::extract::Request;
17use axum::extract::State;
18use axum::middleware::Next;
19use axum::response::Response;
20use scopeguard::defer;
21use snafu::ResultExt;
22use std::time::Duration;
23use tibba_cache::RedisCache;
24use tibba_state::CTX;
25use tokio::time::sleep;
26use tracing::debug;
27
/// Module-local `Result` alias whose error type defaults to the `Error` imported
/// from the parent module; callers may still override `E` explicitly.
type Result<T, E = Error> = std::result::Result<T, E>;
29
/// Configuration consumed by the `wait` middleware.
///
/// Holds the target duration of a request and a flag deciding whether the
/// delay applies to every response or only to error responses.
#[derive(Debug, Clone, Default)]
pub struct WaitParams {
    // Target total duration of a request; the middleware sleeps for whatever
    // part of it has not elapsed yet.
    wait: Duration,
    // When true, the delay is applied only to responses with status >= 400.
    only_error_occurred: bool,
}

impl WaitParams {
    /// Builds parameters with a wait of `ms` milliseconds that applies to all responses.
    pub fn new(ms: u64) -> Self {
        let wait = Duration::from_millis(ms);
        Self {
            wait,
            only_error_occurred: false,
        }
    }

    /// Restricts the delay to error responses (status >= 400), consuming `self`.
    #[must_use]
    pub fn only_on_error(self) -> Self {
        Self {
            only_error_occurred: true,
            ..self
        }
    }
}
56
57/// Middleware that adds a configurable delay after request processing
58///
59/// This middleware can be useful for:
60/// - Rate limiting
61/// - Simulating network latency
62/// - Preventing timing attacks
63/// - Ensuring minimum response times
64///
65/// # Arguments
66/// * `State(params)` - Wait configuration parameters
67/// * `req` - The incoming request
68/// * `next` - The next middleware in the chain
69pub async fn wait(State(params): State<WaitParams>, req: Request, next: Next) -> Response {
70 // Log middleware entry
71 debug!(target: LOG_TARGET, "--> wait");
72 // Ensure exit logging happens even if processing panics
73 defer!(debug!(target: LOG_TARGET, "<-- wait"););
74
75 // Process the request through the middleware chain
76 let res = next.run(req).await;
77
78 // Check if we should wait based on error condition
79 if params.only_error_occurred && res.status().as_u16() < 400 {
80 return res;
81 }
82
83 // Calculate remaining time to wait
84 let elapsed = CTX.get().elapsed();
85 let remaining_wait = params.wait.saturating_sub(elapsed);
86
87 // Only wait if the remaining time is significant (>= 10ms)
88 if remaining_wait.as_millis() >= 10 {
89 sleep(remaining_wait).await
90 }
91
92 res
93}
94
95/// Middleware to validate captcha tokens in incoming requests
96///
97/// # Arguments
98/// * `magic_code` - Special code that can bypass normal captcha validation (for testing)
99/// * `cache` - Redis cache instance for storing/retrieving captcha codes
100/// * `req` - The incoming HTTP request
101/// * `next` - The next middleware handler
102///
103/// # Format
104/// The X-Captcha header should contain a colon-separated string with 3 parts:
105/// `key:code` where:
106/// - key: unique key to look up the stored captcha code
107/// - code: the actual captcha code to validate
108pub async fn validate_captcha(
109 State((magic_code, cache)): State<(String, &'static RedisCache)>,
110 req: Request,
111 next: Next,
112) -> Result<Response, tibba_error::Error> {
113 // Category name for error handling
114 let category = "captcha";
115
116 // Extract and parse the X-Captcha header
117 let value = req
118 .headers()
119 .get("X-Captcha")
120 .ok_or(Error::Common {
121 message: "captcha is required".to_string(),
122 category: category.to_string(),
123 })?
124 .to_str()
125 .context(HeaderValueSnafu)?;
126
127 let (key, user_code) = value.split_once(':').ok_or_else(|| Error::Common {
128 message: "captcha parameter is invalid, expect 'key:code'".to_string(),
129 category: category.to_string(),
130 })?;
131
132 // Check if this is a mock request using the magic code
133 if !magic_code.is_empty() && user_code == magic_code {
134 return Ok(next.run(req).await);
135 }
136
137 // Retrieve and delete the stored code from cache using the key (arr[1])
138 let code: Option<String> = cache.get_del(key).await?;
139 let Some(code) = code else {
140 return Err(Error::Common {
141 message: "captcha is expired".to_string(),
142 category: category.to_string(),
143 }
144 .into());
145 };
146
147 // Compare the provided code against the stored code
148 if code != user_code {
149 return Err(Error::Common {
150 message: "captcha is invalid".to_string(),
151 category: category.to_string(),
152 }
153 .into());
154 }
155
156 Ok(next.run(req).await)
157}