Skip to main content

tibba_middleware/
lib.rs

1// Copyright 2026 Tree xie.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use axum::extract::{ConnectInfo, FromRequestParts};
16use axum::http::request::Parts;
17use snafu::Snafu;
18use std::net::{IpAddr, SocketAddr};
19use tibba_error::Error as BaseError;
20
/// Tracing target string intended for the log events emitted by this crate.
///
/// Filtering matches on the target value itself, i.e. `tibba:middleware`
/// (for example `RUST_LOG=tibba:middleware=info`, or `debug`).
/// NOTE(review): an earlier version of this comment suggested
/// `tibba-middleware`, which does not match the string below — confirm which
/// spelling the deployed log filter expects.
pub(crate) const LOG_TARGET: &str = "tibba:middleware";
24
/// Errors produced by the middleware in this crate.
///
/// Each variant's `Display` output is defined by its `#[snafu(display(...))]`
/// attribute. The enum converts into [`tibba_error::Error`] via `From`.
#[derive(Debug, Snafu)]
pub enum Error {
    /// Generic failure carrying a free-form `message` (used as the display
    /// text) and a `category` string.
    #[snafu(display("{message}"))]
    Common { message: String, category: String },
    /// A request limit was exceeded; `limit` is the configured bound and
    /// `current` is the observed value reported alongside it.
    #[snafu(display("too many requests, limit: {limit}, current: {current}"))]
    TooManyRequests { limit: i64, current: i64 },
    /// A header value could not be converted to a string
    /// (wraps `axum::http::header::ToStrError`).
    #[snafu(display("{source}"))]
    HeaderValue {
        source: axum::http::header::ToStrError,
    },
}
36
37impl From<Error> for BaseError {
38    fn from(val: Error) -> Self {
39        let err = match val {
40            Error::Common { message, category } => {
41                BaseError::new(&message).with_sub_category(&category)
42            }
43            Error::TooManyRequests { limit, current } => BaseError::new(format!(
44                "too many requests, limit: {limit}, current: {current}"
45            ))
46            .with_sub_category("too_many_requests")
47            .with_status(429),
48            Error::HeaderValue { source } => {
49                BaseError::new(source).with_sub_category("header_value")
50            }
51        };
52        err.with_category("middleware")
53    }
54}
55
/// Extractor wrapping the IP address resolved for the current request.
///
/// The `FromRequestParts` implementation tries, in order: the first entry of
/// the `X-Forwarded-For` header, the `X-Real-Ip` header, and finally the
/// `ConnectInfo<SocketAddr>` request extension.
/// NOTE(review): the two headers are request-supplied values — presumably a
/// trusted proxy sets them in deployment; confirm before relying on this for
/// security decisions.
#[derive(Debug, Clone, Copy)]
pub struct ClientIp(pub IpAddr);
58
59impl<S> FromRequestParts<S> for ClientIp
60where
61    S: Sync,
62{
63    type Rejection = tibba_error::Error;
64
65    async fn from_request_parts(
66        parts: &mut Parts,
67        _state: &S,
68    ) -> std::result::Result<Self, Self::Rejection> {
69        let client_ip = parts
70            .headers
71            .get("X-Forwarded-For")
72            .and_then(|header| header.to_str().ok())
73            .and_then(|s| s.split(',').next())
74            .map(|s| s.trim())
75            .and_then(|s| s.parse::<IpAddr>().ok())
76            .or_else(|| {
77                parts
78                    .headers
79                    .get("X-Real-Ip")
80                    .and_then(|header| header.to_str().ok())
81                    .map(|s| s.trim())
82                    .and_then(|s| s.parse::<IpAddr>().ok())
83            })
84            .or_else(|| {
85                parts
86                    .extensions
87                    .get::<ConnectInfo<SocketAddr>>()
88                    .map(|ConnectInfo(addr)| addr.ip())
89            });
90
91        // if all attempts fail (result is None), return error
92        client_ip
93            .map(ClientIp)
94            .ok_or_else(|| BaseError::new("Client IP address could not be determined"))
95    }
96}
97
98mod common;
99mod entry;
100mod limit;
101mod stats;
102mod tracker;
103
104pub use common::*;
105pub use entry::*;
106pub use limit::*;
107pub use stats::*;
108pub use tracker::*;