use axum::extract::{ConnectInfo, FromRequestParts};
use axum::http::request::Parts;
use snafu::Snafu;
use std::net::{IpAddr, SocketAddr};
use tibba_error::Error as BaseError;
/// Error type for this module's middleware.
///
/// Every variant converts into [`BaseError`] through the `From` impl in this
/// file, which also tags the result with the `middleware` category.
#[derive(Debug, Snafu)]
pub enum Error {
    /// A generic failure carrying a free-form message.
    ///
    /// Displays as `message` alone. When converted, `category` becomes the
    /// sub-category of the resulting [`BaseError`].
    #[snafu(display("{message}"))]
    Common { message: String, category: String },
    /// Too many requests; `limit` and `current` are both shown in the message.
    ///
    /// When converted, it becomes a [`BaseError`] with status 429 and the
    /// `too_many_requests` sub-category. NOTE(review): presumably raised by the
    /// rate limiter; confirm which counter `current` reports.
    #[snafu(display("too many requests, limit: {limit}, current: {current}"))]
    TooManyRequests { limit: i64, current: i64 },
}
impl From<Error> for BaseError {
fn from(val: Error) -> Self {
let err = match val {
Error::Common { message, category } => {
BaseError::new(&message).with_sub_category(&category)
}
Error::TooManyRequests { .. } => BaseError::new(val.to_string())
.with_sub_category("too_many_requests")
.with_status(429),
};
err.with_category("middleware")
}
}
/// The IP address of the client that issued the current request.
///
/// Handlers obtain it by taking `ClientIp` as an extractor argument. The
/// `FromRequestParts` implementation in this file decides where the address comes from.
#[derive(Clone, Copy, Debug)]
pub struct ClientIp(pub IpAddr);
impl<S> FromRequestParts<S> for ClientIp
where
S: Sync,
{
type Rejection = tibba_error::Error;
async fn from_request_parts(
parts: &mut Parts,
_state: &S,
) -> std::result::Result<Self, Self::Rejection> {
let client_ip = {
parts
.headers
.get("X-Forwarded-For")
.and_then(|header| header.to_str().ok()) .and_then(|s| s.split(',').next()) .map(|s| s.trim()) .and_then(|s| s.parse::<IpAddr>().ok()) .or_else(|| {
parts
.headers
.get("X-Real-Ip")
.and_then(|header| header.to_str().ok())
.map(|s| s.trim())
.and_then(|s| s.parse::<IpAddr>().ok())
})
.or_else(|| {
parts
.extensions
.get::<ConnectInfo<SocketAddr>>()
.map(|ConnectInfo(addr)| addr.ip())
})
};
client_ip
.map(ClientIp)
.ok_or_else(|| BaseError::new("Client IP address could not be determined"))
}
}
mod common;
mod entry;
mod limit;
mod session;
mod stats;
mod tracker;
pub use common::*;
pub use entry::*;
pub use limit::*;
pub use session::*;
pub use stats::*;
pub use tracker::*;