use axum::Router as AXRouter;
use serde::{Deserialize, Deserializer, Serialize};
use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Result};
/// Configuration for the `limit_payload` middleware, which caps the size of
/// incoming request bodies.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LimitPayload {
    /// Whether the middleware is enabled; `false` when omitted from config.
    #[serde(default)]
    pub enable: bool,
    /// Maximum accepted request body size, in bytes.
    ///
    /// Read through `deserialize_body_limit`, which parses a human-readable
    /// size string with `byte_unit`. Falls back to `default_body_limit()`
    /// when the key is absent.
    #[serde(default = "default_body_limit", deserialize_with = "deserialize_body_limit")]
    pub body_limit: usize,
}
/// The default configuration: middleware disabled, body limit taken from
/// `default_body_limit()`.
impl Default for LimitPayload {
    fn default() -> Self {
        let body_limit = default_body_limit();
        Self {
            enable: false,
            body_limit,
        }
    }
}
/// Default maximum request body size in bytes: 2,000,000 (2 MB, decimal units).
fn default_body_limit() -> usize {
    const TWO_MEGABYTES: usize = 2 * 1_000_000;
    TWO_MEGABYTES
}
fn deserialize_body_limit<'de, D>(deserializer: D) -> Result<usize, D::Error>
where
D: Deserializer<'de>,
{
Ok(
byte_unit::Byte::from_str(String::deserialize(deserializer)?)
.map_err(|err| serde::de::Error::custom(err.to_string()))?
.get_bytes() as usize,
)
}
impl MiddlewareLayer for LimitPayload {
    /// Name of this middleware.
    fn name(&self) -> &'static str {
        "limit_payload"
    }

    /// Reports the configured `enable` flag.
    fn is_enabled(&self) -> bool {
        self.enable
    }

    /// Returns this configuration serialized as a JSON value.
    fn config(&self) -> serde_json::Result<serde_json::Value> {
        serde_json::to_value(self)
    }

    /// Wraps the router in axum's `DefaultBodyLimit` layer, capped at
    /// `body_limit` bytes.
    fn apply(&self, app: AXRouter<AppContext>) -> Result<AXRouter<AppContext>> {
        let limit_layer = axum::extract::DefaultBodyLimit::max(self.body_limit);
        Ok(app.layer(limit_layer))
    }
}