#![expect(
clippy::unwrap_used,
reason = "legacy middleware tests use static fixtures; cleanup is tracked in policy/clippy-debt.toml"
)]
use axum::{
body::Body,
extract::connect_info::ConnectInfo,
http::{Request, StatusCode},
};
use hl7v2_server::{
handlers::{parse_handler, validate_handler},
metrics::{init_metrics_recorder, metrics_handler},
server::AppState,
};
use std::sync::Arc;
use std::time::Instant;
use tower::ServiceExt; use tower::limit::ConcurrencyLimitLayer;
use tower_governor::governor::GovernorConfigBuilder;
use tower_http::{
compression::CompressionLayer,
cors::{Any, CorsLayer},
trace::TraceLayer,
};
use utoipa_swagger_ui::SwaggerUi;
fn build_test_router(
rate_per_second: u64,
burst_size: u32,
max_concurrency: usize,
) -> axum::Router {
let metrics_handle = init_metrics_recorder();
let state = Arc::new(AppState {
start_time: Instant::now(),
metrics_handle: Arc::new(metrics_handle),
api_key: None,
cors_allowed_origins: Default::default(),
});
let governor_conf = Arc::new(
GovernorConfigBuilder::default()
.per_second(rate_per_second)
.burst_size(burst_size)
.finish()
.unwrap(),
);
const OPENAPI_YAML: &str = include_str!("../../../api/openapi/hl7v2-api-v1.yaml");
let api_routes = axum::Router::new()
.route("/parse", axum::routing::post(parse_handler))
.route("/validate", axum::routing::post(validate_handler));
axum::Router::new()
.merge(
SwaggerUi::new("/api/docs")
.config(utoipa_swagger_ui::Config::from("/api/openapi.yaml")),
)
.route(
"/api/openapi.yaml",
axum::routing::get(|| async {
(
[(axum::http::header::CONTENT_TYPE, "text/yaml")],
OPENAPI_YAML,
)
}),
)
.route(
"/health",
axum::routing::get(|| async { (StatusCode::OK, "{\"status\":\"healthy\"}") }),
)
.route(
"/ready",
axum::routing::get(|| async { "{\"ready\":true}" }),
)
.route("/metrics", axum::routing::get(metrics_handler))
.nest("/hl7", api_routes) .with_state(state)
.layer(axum::middleware::from_fn(
hl7v2_server::metrics::middleware::metrics_middleware,
))
.layer(CompressionLayer::new())
.layer(
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any),
)
.layer(TraceLayer::new_for_http())
.layer(tower_governor::GovernorLayer::new(governor_conf.clone())) .layer(ConcurrencyLimitLayer::new(max_concurrency)) }
/// Requests that fit within the client's burst allowance must not be throttled.
///
/// The router allows a burst of 10, which comfortably covers the three
/// sequential requests sent from the same client address.
#[tokio::test]
async fn test_rate_limiting_allows_requests_within_limit() {
    let router = build_test_router(5, 10, 100);
    // The governor keys on the peer IP, so every request carries the same address.
    let client_addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080));

    for attempt in 1..=3 {
        let request = Request::builder()
            .extension(ConnectInfo(client_addr))
            .uri("/hl7/parse")
            .method("POST")
            .header("Content-Type", "application/json")
            .body(Body::from(create_parse_request_payload()))
            .unwrap();

        let status = router.clone().oneshot(request).await.unwrap().status();

        assert_ne!(
            status,
            StatusCode::TOO_MANY_REQUESTS,
            "Request {} should not be rate limited",
            attempt
        );
    }
}
/// With a burst of a single token and a one-second replenish interval, a second
/// back-to-back request from the same client must receive `429 Too Many Requests`.
#[tokio::test]
async fn test_rate_limiting_blocks_requests_over_limit() {
    let router = build_test_router(1, 1, 100);

    // Builds an identical parse request from a fixed client address each time.
    let build_request = || {
        Request::builder()
            .extension(ConnectInfo(std::net::SocketAddr::from((
                [127, 0, 0, 1],
                8080,
            ))))
            .uri("/hl7/parse")
            .method("POST")
            .header("Content-Type", "application/json")
            .body(Body::from(create_parse_request_payload()))
            .unwrap()
    };

    let first = router.clone().oneshot(build_request()).await.unwrap();
    assert_eq!(first.status(), StatusCode::OK, "First request should succeed");

    let second = router.clone().oneshot(build_request()).await.unwrap();
    assert_eq!(
        second.status(),
        StatusCode::TOO_MANY_REQUESTS,
        "Second request should be rate limited"
    );
}
/// Ten concurrent requests against a concurrency limit of 50 (and a burst of
/// 100) must all complete successfully.
///
/// `ConcurrencyLimitLayer` applies backpressure: excess requests wait for a
/// permit, and the layer never produces a response status of its own. Asserting
/// only "not 503" therefore could never fail. Requiring `200 OK` instead also
/// catches accidental throttling (429) or handler errors.
#[tokio::test]
async fn test_concurrency_limiting_allows_requests_under_limit() {
    let app = build_test_router(100, 100, 50);

    // Spawn every request up front so they are in flight concurrently.
    let tasks: Vec<_> = (0..10)
        .map(|_| {
            let app_clone = app.clone();
            tokio::spawn(async move {
                app_clone
                    .oneshot(
                        Request::builder()
                            .extension(ConnectInfo(std::net::SocketAddr::from((
                                [127, 0, 0, 1],
                                8080,
                            ))))
                            .uri("/hl7/parse")
                            .method("POST")
                            .header("Content-Type", "application/json")
                            .body(Body::from(create_parse_request_payload()))
                            .unwrap(),
                    )
                    .await
                    .unwrap()
            })
        })
        .collect();

    for (index, task) in tasks.into_iter().enumerate() {
        let response = task.await.unwrap();
        assert_eq!(
            response.status(),
            StatusCode::OK,
            "Request {} should succeed when under the concurrency limit",
            index + 1
        );
    }
}
/// Sends four concurrent requests through a router whose concurrency limit is two.
///
/// Despite the test name, `ConcurrencyLimitLayer` does not reject excess requests.
/// It holds them until an in-flight request releases its permit. The expected
/// outcome is therefore that every request eventually completes with `200 OK`.
#[tokio::test]
async fn test_concurrency_limiting_blocks_requests_over_limit() {
    let router = build_test_router(100, 100, 2);

    let handles: Vec<_> = (0..4)
        .map(|i| {
            let svc = router.clone();
            let handle = tokio::spawn(async move {
                let request = Request::builder()
                    .extension(ConnectInfo(std::net::SocketAddr::from((
                        [127, 0, 0, 1],
                        8080,
                    ))))
                    .uri("/hl7/parse")
                    .method("POST")
                    .header("Content-Type", "application/json")
                    .body(Body::from(create_parse_request_payload()))
                    .unwrap();
                svc.oneshot(request).await.unwrap()
            });
            (i, handle)
        })
        .collect();

    // Join every task before inspecting any status, so all requests have finished.
    let mut outcomes = Vec::with_capacity(handles.len());
    for (i, handle) in handles {
        outcomes.push((i, handle.await.unwrap()));
    }

    for (i, response) in outcomes {
        assert_eq!(
            response.status(),
            StatusCode::OK,
            "Request {} should eventually succeed",
            i + 1
        );
    }
}
/// Serializes a JSON body for `POST /hl7/parse`.
///
/// The body carries a minimal `ADT^A01` message made of an MSH and a PID segment,
/// each terminated by `\r`. The message is marked as not MLLP-framed, and the
/// request asks for JSON output and structure validation.
fn create_parse_request_payload() -> String {
    const HL7_MESSAGE: &str = "MSH|^~\\&|SendingApp|SendingFac|ReceivingApp|ReceivingFac|20231119120000||ADT^A01|MSG001|P|2.5\rPID|1||MRN123^^^Facility^MR||Doe^John^A||19800101|M\r";

    let options = serde_json::json!({
        "include_json": true,
        "validate_structure": true
    });
    let body = serde_json::json!({
        "message": HL7_MESSAGE,
        "mllp_framed": false,
        "options": options
    });

    serde_json::to_string(&body).unwrap()
}