Skip to main content

tibba_middleware/
common.rs

1// Copyright 2026 Tree xie.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::{Error, HeaderValueSnafu, LOG_TARGET};
16use axum::extract::Request;
17use axum::extract::State;
18use axum::middleware::Next;
19use axum::response::Response;
20use scopeguard::defer;
21use snafu::ResultExt;
22use std::time::Duration;
23use tibba_cache::RedisCache;
24use tibba_state::CTX;
25use tokio::time::sleep;
26use tracing::debug;
27
28type Result<T, E = Error> = std::result::Result<T, E>;
29
/// Settings for the `wait` middleware.
///
/// Holds the target duration a request should take and whether padding is
/// restricted to error responses.
#[derive(Debug, Clone, Default)]
pub struct WaitParams {
    // Target duration for the whole request (a `Duration`, not a raw millisecond count).
    wait: Duration,
    // When true, padding is applied only to responses whose status is >= 400.
    only_error_occurred: bool,
}

impl WaitParams {
    /// Builds parameters targeting `ms` milliseconds, applied to every response.
    pub fn new(ms: u64) -> Self {
        Self {
            wait: Duration::from_millis(ms),
            only_error_occurred: false,
        }
    }

    /// Restricts padding to responses with a status code of 400 or above.
    #[must_use]
    pub fn only_on_error(self) -> Self {
        Self {
            only_error_occurred: true,
            ..self
        }
    }
}
56
57/// Middleware that adds a configurable delay after request processing
58///
59/// This middleware can be useful for:
60/// - Rate limiting
61/// - Simulating network latency
62/// - Preventing timing attacks
63/// - Ensuring minimum response times
64///
65/// # Arguments
66/// * `State(params)` - Wait configuration parameters
67/// * `req` - The incoming request
68/// * `next` - The next middleware in the chain
69pub async fn wait(State(params): State<WaitParams>, req: Request, next: Next) -> Response {
70    // Log middleware entry
71    debug!(target: LOG_TARGET, "--> wait");
72    // Ensure exit logging happens even if processing panics
73    defer!(debug!(target: LOG_TARGET, "<-- wait"););
74
75    // Process the request through the middleware chain
76    let res = next.run(req).await;
77
78    // Check if we should wait based on error condition
79    if params.only_error_occurred && res.status().as_u16() < 400 {
80        return res;
81    }
82
83    // Calculate remaining time to wait
84    let elapsed = CTX.get().elapsed();
85    let remaining_wait = params.wait.saturating_sub(elapsed);
86
87    // Only wait if the remaining time is significant (>= 10ms)
88    if remaining_wait.as_millis() >= 10 {
89        sleep(remaining_wait).await
90    }
91
92    res
93}
94
95/// Middleware to validate captcha tokens in incoming requests
96///
97/// # Arguments
98/// * `magic_code` - Special code that can bypass normal captcha validation (for testing)
99/// * `cache` - Redis cache instance for storing/retrieving captcha codes
100/// * `req` - The incoming HTTP request
101/// * `next` - The next middleware handler
102///
103/// # Format
104/// The X-Captcha header should contain a colon-separated string with 3 parts:
105/// `key:code` where:
106/// - key: unique key to look up the stored captcha code
107/// - code: the actual captcha code to validate
108pub async fn validate_captcha(
109    State((magic_code, cache)): State<(String, &'static RedisCache)>,
110    req: Request,
111    next: Next,
112) -> Result<Response, tibba_error::Error> {
113    // Category name for error handling
114    let category = "captcha";
115
116    // Extract and parse the X-Captcha header
117    let value = req
118        .headers()
119        .get("X-Captcha")
120        .ok_or(Error::Common {
121            message: "captcha is required".to_string(),
122            category: category.to_string(),
123        })?
124        .to_str()
125        .context(HeaderValueSnafu)?;
126
127    let (key, user_code) = value.split_once(':').ok_or_else(|| Error::Common {
128        message: "captcha parameter is invalid, expect 'key:code'".to_string(),
129        category: category.to_string(),
130    })?;
131
132    // Check if this is a mock request using the magic code
133    if !magic_code.is_empty() && user_code == magic_code {
134        return Ok(next.run(req).await);
135    }
136
137    // Retrieve and delete the stored code from cache using the key (arr[1])
138    let code: Option<String> = cache.get_del(key).await?;
139    let Some(code) = code else {
140        return Err(Error::Common {
141            message: "captcha is expired".to_string(),
142            category: category.to_string(),
143        }
144        .into());
145    };
146
147    // Compare the provided code against the stored code
148    if code != user_code {
149        return Err(Error::Common {
150            message: "captcha is invalid".to_string(),
151            category: category.to_string(),
152        }
153        .into());
154    }
155
156    Ok(next.run(req).await)
157}