use http::{Request, Response, StatusCode};
use ipnet::IpNet;
use serde::{Deserialize, Serialize};
use tower_http::validate_request::{ValidateRequest, ValidateRequestHeaderLayer};
use super::ApplyLayer;
use crate::automatic_body::{add_automatic_body, AutomaticBody};
use crate::client_ip::GetClientIp;
/// Selects whether the prefix list of an [`IpFilterConfig`] acts as an allow-list or a
/// deny-list.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum IpFilterMode {
    /// Admit only clients whose address lies inside one of the configured prefixes;
    /// everyone else is answered with `403 Forbidden`.
    Allow,
    /// Answer clients whose address lies inside one of the configured prefixes with
    /// `403 Forbidden`; everyone else is admitted.
    Deny,
}
/// Settings for the client-IP filter layer produced by `from_config`.
///
/// Serialized/deserialized with serde under the field names `filter_mode` and `prefixes`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct IpFilterConfig {
    /// Whether `prefixes` is interpreted as an allow-list or a deny-list.
    filter_mode: IpFilterMode,
    /// Network prefixes (IPv4 and/or IPv6) that the client address is tested against.
    prefixes: Vec<IpNet>,
}
/// Admits or rejects each request according to the client's IP address.
///
/// A request "matches" when its client address lies inside any configured prefix. In
/// `Allow` mode only matching requests pass; in `Deny` mode matching requests are
/// rejected. A rejected request gets a `403 Forbidden` response whose body is
/// `AutomaticBody::default()`.
impl<B> ValidateRequest<B> for IpFilterConfig {
    type ResponseBody = AutomaticBody;

    fn validate(&mut self, req: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
        // NOTE(review): assumes `client_ip()` yields `std::net::IpAddr` (the only owned
        // address type `IpNet::contains` accepts by reference besides `IpNet`).
        let client_ip = req.client_ip();

        // A dual-stack listener reports IPv4 peers as IPv4-mapped IPv6 addresses
        // (`::ffff:a.b.c.d`), which never fall inside IPv4 prefixes like `10.0.0.0/8`.
        // Without this, a `Deny` rule could be bypassed (and an `Allow` rule would
        // wrongly reject) purely because of the socket family. The original address is
        // still checked too, so every prefix that matched before keeps matching.
        let mapped_v4 = match client_ip {
            std::net::IpAddr::V6(v6) => v6.to_ipv4_mapped().map(std::net::IpAddr::V4),
            std::net::IpAddr::V4(_) => None,
        };

        let matched = std::iter::once(client_ip)
            .chain(mapped_v4)
            .any(|ip| self.prefixes.iter().any(|net| net.contains(&ip)));

        let allowed = match self.filter_mode {
            IpFilterMode::Allow => matched,
            IpFilterMode::Deny => !matched,
        };

        if allowed {
            return Ok(());
        }

        let mut res = Response::new(Self::ResponseBody::default());
        *res.status_mut() = StatusCode::FORBIDDEN;
        Err(res)
    }
}
/// Builds the client-IP filter layer described by `config`.
///
/// The config is installed as a custom [`ValidateRequestHeaderLayer`] validator and the
/// result is wrapped with `add_automatic_body`. This function always returns `Ok`; the
/// `anyhow::Result` return type presumably keeps its signature uniform with other layer
/// constructors (not visible here).
pub fn from_config(config: IpFilterConfig) -> anyhow::Result<impl ApplyLayer> {
    let validator_layer = ValidateRequestHeaderLayer::custom(config);
    let layer = add_automatic_body(validator_layer);
    Ok(layer)
}