use crate::models::AppState;
use axum::{
extract::{Request, State},
http::{HeaderMap, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
/// Axum middleware that caps the number of in-flight requests per account on the
/// `/v1/query`, `/v1/batch` and `/v1/transaction` path prefixes.
///
/// Requests are passed through untouched when:
/// - the path is not one of the limited prefixes;
/// - no API key can be extracted;
/// - the key does not match an entry in `state.accounts`;
/// - the account's `max_connections` is `0`.
///
/// Otherwise the account's counter in `state.active_connections` is checked and
/// incremented under a single map-entry lock. If the counter has already reached
/// `max_connections`, the request is rejected with `503 Service Unavailable` and a
/// JSON error body.
///
/// The slot is released by a drop guard, so the counter is decremented when the
/// downstream future completes, is cancelled (dropped), or unwinds from a panic.
///
/// # Errors
///
/// Never returns `Err`; the `Result` return type is kept for signature compatibility.
pub async fn connection_limit_middleware(
    State(state): State<AppState>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let path = request.uri().path();
    let is_limited_path = path.starts_with("/v1/query")
        || path.starts_with("/v1/batch")
        || path.starts_with("/v1/transaction");
    if !is_limited_path {
        return Ok(next.run(request).await);
    }

    // Requests without a key are not limited here.
    // NOTE(review): presumably authentication rejects them elsewhere — confirm.
    let Ok(api_key) = extract_api_key(request.headers()) else {
        return Ok(next.run(request).await);
    };

    // Copy out what we need and release the read lock BEFORE awaiting downstream.
    // Holding the guard across `next.run` would block writers for the whole request.
    // With a fair RwLock it would also stall readers queued behind such a writer.
    let account = {
        let accounts = state.accounts.read().await;
        accounts
            .get(&api_key)
            .map(|account| (account.id.clone(), account.max_connections))
    };
    let Some((account_id, max_connections)) = account else {
        return Ok(next.run(request).await);
    };

    // A limit of 0 disables limiting for this account.
    if max_connections == 0 {
        return Ok(next.run(request).await);
    }

    // Check and increment while holding the same entry lock. Otherwise concurrent
    // requests could all observe `count < max` and overshoot the limit.
    // `Some(count)` means the request is rejected at that count.
    let rejected_at = {
        let mut count = state
            .active_connections
            .entry(account_id.clone())
            .or_insert(0);
        if *count >= max_connections {
            Some(*count)
        } else {
            *count += 1;
            None
        }
    };

    if let Some(current_count) = rejected_at {
        let error_response = serde_json::json!({
            "success": false,
            "error": {
                "code": "CONNECTION_LIMIT_EXCEEDED",
                "message": format!("Maximum connections ({}) exceeded. Current: {}", max_connections, current_count)
            }
        });
        return Ok((
            StatusCode::SERVICE_UNAVAILABLE,
            [("Content-Type", "application/json")],
            axum::Json(error_response),
        )
            .into_response());
    }

    // Release the slot on every exit path: normal completion, cancellation of this
    // future (e.g. client disconnect), or panic unwinding out of the handler.
    let _slot = OnDrop::new(move || {
        state
            .active_connections
            .entry(account_id)
            .and_modify(|count| *count = count.saturating_sub(1));
    });

    Ok(next.run(request).await)
}

/// Drop guard that runs the wrapped closure exactly once when it goes out of scope.
struct OnDrop<F: FnOnce()> {
    action: Option<F>,
}

impl<F: FnOnce()> OnDrop<F> {
    /// Arms the guard with `action`, which will run when the guard is dropped.
    fn new(action: F) -> Self {
        Self {
            action: Some(action),
        }
    }
}

impl<F: FnOnce()> Drop for OnDrop<F> {
    fn drop(&mut self) {
        if let Some(action) = self.action.take() {
            action();
        }
    }
}
/// Pulls the caller's API key out of the request headers.
///
/// The `x-api-key` header is consulted first. Only when it is absent is
/// `authorization` used. A leading `"Bearer "` (exact, case-sensitive match) is
/// stripped once; any other value is returned verbatim.
///
/// # Errors
///
/// Returns `Err("Missing API key")` in two cases:
/// - neither header is present;
/// - the selected header's value is not valid visible ASCII.
///
/// In the second case there is no fallback to `authorization` when a malformed
/// `x-api-key` is present.
fn extract_api_key(headers: &HeaderMap) -> Result<String, String> {
    const MISSING: &str = "Missing API key";

    let header_value = match headers.get("x-api-key") {
        Some(value) => value,
        None => headers
            .get("authorization")
            .ok_or_else(|| MISSING.to_string())?,
    };

    let text = header_value.to_str().map_err(|_| MISSING.to_string())?;
    let key = text.strip_prefix("Bearer ").unwrap_or(text);
    Ok(key.to_owned())
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Builds a `HeaderMap` from static `(name, value)` pairs.
    fn headers_with(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for &(name, value) in pairs {
            headers.insert(name, HeaderValue::from_static(value));
        }
        headers
    }

    #[test]
    fn test_extract_api_key() {
        let mut headers = HeaderMap::new();
        headers.insert("x-api-key", HeaderValue::from_static("test-key"));
        assert_eq!(extract_api_key(&headers).unwrap(), "test-key");
        headers.clear();
        headers.insert("authorization", HeaderValue::from_static("Bearer test-key-2"));
        assert_eq!(extract_api_key(&headers).unwrap(), "test-key-2");
        headers.clear();
        assert!(extract_api_key(&headers).is_err());
    }

    #[test]
    fn x_api_key_takes_precedence_over_authorization() {
        let headers = headers_with(&[
            ("x-api-key", "primary"),
            ("authorization", "Bearer secondary"),
        ]);
        assert_eq!(extract_api_key(&headers).unwrap(), "primary");
    }

    #[test]
    fn authorization_without_bearer_prefix_is_returned_verbatim() {
        let headers = headers_with(&[("authorization", "raw-key")]);
        assert_eq!(extract_api_key(&headers).unwrap(), "raw-key");
    }

    #[test]
    fn missing_headers_report_missing_key() {
        let headers = HeaderMap::new();
        assert_eq!(
            extract_api_key(&headers).unwrap_err(),
            "Missing API key".to_string()
        );
    }
}