use axum::Router as AXRouter;
use serde::{Deserialize, Deserializer, Serialize};
use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Result};
/// How the request body size limit is applied.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
pub enum DefaultBodyLimitKind {
    /// No body size limit is enforced.
    Disable,
    /// Request bodies are limited to the given number of bytes.
    Limit(usize),
}
/// Configuration for the `limit_payload` middleware.
///
/// `body_limit` is read from a string (`"disable"` or a byte size) via
/// `deserialize_body_limit`, and falls back to `default_body_limit()` when omitted.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LimitPayload {
    /// Maximum accepted request body size, or `Disable` for no limit.
    #[serde(default = "default_body_limit", deserialize_with = "deserialize_body_limit")]
    pub body_limit: DefaultBodyLimitKind,
}
impl Default for LimitPayload {
fn default() -> Self {
Self {
body_limit: default_body_limit(),
}
}
}
/// Default maximum request body size, in bytes (2,000,000 bytes).
const DEFAULT_BODY_LIMIT_BYTES: usize = 2_000_000;

/// Returns the body limit used when `body_limit` is not specified in the configuration.
fn default_body_limit() -> DefaultBodyLimitKind {
    DefaultBodyLimitKind::Limit(DEFAULT_BODY_LIMIT_BYTES)
}
fn deserialize_body_limit<'de, D>(deserializer: D) -> Result<DefaultBodyLimitKind, D::Error>
where
D: Deserializer<'de>,
{
let s: String = String::deserialize(deserializer)?;
match s.as_str() {
"disable" => Ok(DefaultBodyLimitKind::Disable),
limit => {
let bytes = byte_unit::Byte::from_str(limit)
.map_err(|err| serde::de::Error::custom(err.to_string()))?
.get_bytes();
Ok(DefaultBodyLimitKind::Limit(bytes as usize))
}
}
}
impl MiddlewareLayer for LimitPayload {
    /// Identifier of this middleware.
    fn name(&self) -> &'static str {
        "limit_payload"
    }

    /// This middleware has no on/off switch of its own; it always reports enabled.
    fn is_enabled(&self) -> bool {
        true
    }

    /// Serializes the current configuration to a JSON value.
    fn config(&self) -> serde_json::Result<serde_json::Value> {
        serde_json::to_value(self)
    }

    /// Wraps the router in axum's `DefaultBodyLimit` layer built from `body_limit`.
    fn apply(&self, app: AXRouter<AppContext>) -> Result<AXRouter<AppContext>> {
        use axum::extract::DefaultBodyLimit;

        let layer = match self.body_limit {
            DefaultBodyLimitKind::Limit(max_bytes) => DefaultBodyLimit::max(max_bytes),
            DefaultBodyLimitKind::Disable => DefaultBodyLimit::disable(),
        };
        let router = app.layer(layer);
        Ok(router)
    }
}