use axum::{
body::Body,
extract::{Request, State},
http::{header::HeaderValue, HeaderName, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use super::service::LoginLockout;
use crate::error::{Error, ErrorResponse};
/// Axum middleware that enforces login lockout for JSON login requests.
///
/// For requests with a JSON `Content-Type`, the body is buffered, the top-level field named by
/// `identity_field` is read as a string, and [`LoginLockout`] is consulted before the inner
/// handler runs. The handler's outcome is then fed back: `401 Unauthorized` records a failure
/// and any 2xx status records a success.
#[derive(Clone)]
pub struct LockoutMiddleware {
    /// Lockout service queried (`check`) and updated (`record_failure` / `record_success`)
    /// per identity.
    lockout: LoginLockout,
    /// Name of the top-level JSON body field that carries the login identity.
    identity_field: String,
}

impl LockoutMiddleware {
    /// Maximum number of request-body bytes buffered while looking for the identity field.
    const MAX_BODY_BYTES: usize = 1024 * 1024;

    /// Creates a middleware that reads the login identity from the top-level JSON field
    /// `identity_field`.
    pub fn new(lockout: LoginLockout, identity_field: &str) -> Self {
        Self {
            lockout,
            identity_field: identity_field.to_string(),
        }
    }

    /// Middleware entry point (its `State<Self>` / `Request` / `Next` signature fits
    /// `axum::middleware::from_fn_with_state`).
    ///
    /// Flow:
    /// 1. Requests whose `Content-Type` is not a JSON media type pass straight through.
    /// 2. JSON bodies are buffered (up to [`Self::MAX_BODY_BYTES`]). If the body is not valid
    ///    JSON, or the identity field is missing or not a string, the request is forwarded
    ///    unchanged.
    /// 3. A locked identity receives `423 Locked` with a `Retry-After` header, and the inner
    ///    handler is not called.
    /// 4. Otherwise the handler runs. A `401` response records a failure and then sleeps for
    ///    the returned `delay_ms` before responding. A 2xx response records a success.
    ///
    /// # Errors
    ///
    /// Returns `Error::BadRequest` if the body cannot be read (including exceeding the size
    /// limit). Any error from `LoginLockout::check`, `record_failure` or `record_success` is
    /// propagated and replaces the handler's response.
    pub async fn middleware(
        State(mw): State<Self>,
        request: Request<Body>,
        next: Next,
    ) -> Result<Response, Error> {
        let is_json = request
            .headers()
            .get(axum::http::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .map(Self::is_json_content_type)
            .unwrap_or(false);
        if !is_json {
            return Ok(next.run(request).await);
        }

        let (parts, body) = request.into_parts();
        let bytes = axum::body::to_bytes(body, Self::MAX_BODY_BYTES)
            .await
            .map_err(|e| Error::BadRequest(format!("Failed to read request body: {}", e)))?;

        // NOTE(review): the identity is used verbatim as the lockout key, so values differing
        // only in letter case or surrounding whitespace are tracked separately. Confirm this
        // matches how the login handler looks up accounts.
        let Some(identity) = Self::extract_identity(&bytes, &mw.identity_field) else {
            // Nothing to rate-limit on; hand the already-buffered body to the handler.
            let request = Request::from_parts(parts, Body::from(bytes));
            return Ok(next.run(request).await);
        };

        let status = mw.lockout.check(&identity).await?;
        if status.locked {
            return Ok(Self::locked_response(status.lockout_remaining_secs));
        }

        // The original body stream was consumed above; rebuild the request from the buffer.
        let request = Request::from_parts(parts, Body::from(bytes));
        let response = next.run(request).await;
        let response_status = response.status();
        if response_status == StatusCode::UNAUTHORIZED {
            let failure_status = mw.lockout.record_failure(&identity).await?;
            // Hold back the 401 for the back-off the lockout service asked for.
            if failure_status.delay_ms > 0 {
                tokio::time::sleep(std::time::Duration::from_millis(failure_status.delay_ms))
                    .await;
            }
        } else if response_status.is_success() {
            mw.lockout.record_success(&identity).await?;
        }
        Ok(response)
    }

    /// Returns `true` for JSON media types: `application/json` or any `application/*+json`.
    /// The comparison is case-insensitive and ignores parameters such as `; charset=utf-8`.
    ///
    /// The gate must be at least as permissive as whatever the login handler parses. Assuming
    /// the handler uses axum's `Json` extractor (confirm), a narrower check such as a
    /// case-sensitive substring match lets clients skip the lockout entirely. For example,
    /// sending `Application/JSON` or `application/vnd.api+json` would still be parsed by the
    /// handler but never rate-limited.
    fn is_json_content_type(content_type: &str) -> bool {
        let essence = content_type
            .split(';')
            .next()
            .unwrap_or("")
            .trim()
            .to_ascii_lowercase();
        match essence.split_once('/') {
            Some((kind, subtype)) => {
                kind == "application" && (subtype == "json" || subtype.ends_with("+json"))
            }
            None => false,
        }
    }

    /// Parses `bytes` as JSON and returns the string value of the top-level `field`.
    ///
    /// Returns `None` if the body is not valid JSON, or if `field` is absent or not a string.
    fn extract_identity(bytes: &[u8], field: &str) -> Option<String> {
        let value: serde_json::Value = serde_json::from_slice(bytes).ok()?;
        value.get(field)?.as_str().map(str::to_owned)
    }

    /// Builds the `423 Locked` response for a locked identity.
    ///
    /// The body is an `ACCOUNT_LOCKED` error. A `Retry-After` header carrying `retry_after`
    /// is added when that value forms a valid header value.
    fn locked_response(retry_after: impl std::fmt::Display) -> Response {
        let error_response = ErrorResponse::with_code(
            StatusCode::LOCKED,
            "ACCOUNT_LOCKED",
            format!("Account locked. Try again in {} seconds", retry_after),
        );
        let mut response = (StatusCode::LOCKED, axum::Json(error_response)).into_response();
        if let Ok(value) = HeaderValue::from_str(&retry_after.to_string()) {
            response
                .headers_mut()
                .insert(HeaderName::from_static("retry-after"), value);
        }
        response
    }
}