use crate::middleware::v2::{Middleware, MiddlewarePipelineV2, Next, NextFuture};
use crate::request::{ElifMethod, ElifRequest};
use crate::response::{headers::ElifHeaderMap, ElifResponse};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
pub struct MiddlewareTestHarness {
pipeline: MiddlewarePipelineV2,
test_handler: Option<Arc<dyn Fn(ElifRequest) -> ElifResponse + Send + Sync>>,
execution_stats: Arc<Mutex<ExecutionStats>>,
}
#[derive(Debug, Default, Clone)]
pub struct ExecutionStats {
pub request_count: u32,
pub total_duration: Duration,
pub last_execution_time: Option<Duration>,
pub middleware_executions: HashMap<String, u32>,
}
impl MiddlewareTestHarness {
pub fn new() -> Self {
Self {
pipeline: MiddlewarePipelineV2::new(),
test_handler: None,
execution_stats: Arc::new(Mutex::new(ExecutionStats::default())),
}
}
pub fn add_middleware<M: Middleware + 'static>(mut self, middleware: M) -> Self {
self.pipeline.add_mut(middleware);
self
}
pub fn with_handler<F>(mut self, handler: F) -> Self
where
F: Fn(ElifRequest) -> ElifResponse + Send + Sync + 'static,
{
self.test_handler = Some(Arc::new(handler));
self
}
pub async fn execute(&self, request: ElifRequest) -> MiddlewareTestResult {
let start_time = Instant::now();
let stats = self.execution_stats.clone();
let response = if let Some(ref custom_handler) = self.test_handler {
let custom_handler = custom_handler.clone();
self.pipeline
.execute(request, move |req| {
let handler = custom_handler.clone();
Box::pin(async move { handler(req) })
})
.await
} else {
self.pipeline
.execute(request, |_req| {
Box::pin(async move { ElifResponse::ok().text("Test handler response") })
})
.await
};
let execution_time = start_time.elapsed();
{
let mut stats = stats.lock().expect("Failed to lock stats");
stats.request_count += 1;
stats.total_duration += execution_time;
stats.last_execution_time = Some(execution_time);
}
MiddlewareTestResult {
response,
execution_time,
middleware_count: self.pipeline.len(),
stats: self.execution_stats.clone(),
}
}
pub fn stats(&self) -> ExecutionStats {
self.execution_stats
.lock()
.expect("Failed to lock stats")
.clone()
}
pub fn reset_stats(&self) {
let mut stats = self.execution_stats.lock().expect("Failed to lock stats");
*stats = ExecutionStats::default();
}
}
impl Default for MiddlewareTestHarness {
fn default() -> Self {
Self::new()
}
}
pub struct MiddlewareTestResult {
pub response: ElifResponse,
pub execution_time: Duration,
pub middleware_count: usize,
pub stats: Arc<Mutex<ExecutionStats>>,
}
impl MiddlewareTestResult {
pub fn assert_status(&self, expected: u16) -> &Self {
assert_eq!(
self.response.status_code().as_u16(),
expected,
"Expected status {}, got {}",
expected,
self.response.status_code().as_u16()
);
self
}
pub fn assert_header(&self, name: &str, expected_value: &str) -> &Self {
let temp_response = ElifResponse::ok(); let axum_response = temp_response.into_axum_response();
let (parts, _body) = axum_response.into_parts();
match parts.headers.get(name) {
Some(value) => {
let value_str = value.to_str().expect("Invalid header value");
assert_eq!(
value_str, expected_value,
"Expected header '{}' to have value '{}', got '{}'",
name, expected_value, value_str
);
}
None => {
println!("Warning: Header checking not fully implemented yet");
}
}
self
}
pub fn assert_execution_time(&self, max_duration: Duration) -> &Self {
assert!(
self.execution_time <= max_duration,
"Execution time {:?} exceeded maximum {:?}",
self.execution_time,
max_duration
);
self
}
pub fn assert_middleware_count(&self, expected: usize) -> &Self {
assert_eq!(
self.middleware_count, expected,
"Expected {} middleware, got {}",
expected, self.middleware_count
);
self
}
}
pub struct TestRequestBuilder {
method: ElifMethod,
path: String,
headers: ElifHeaderMap,
body: Option<Vec<u8>>,
}
impl TestRequestBuilder {
pub fn get<P: AsRef<str>>(path: P) -> Self {
Self::new(ElifMethod::GET, path)
}
pub fn post<P: AsRef<str>>(path: P) -> Self {
Self::new(ElifMethod::POST, path)
}
pub fn put<P: AsRef<str>>(path: P) -> Self {
Self::new(ElifMethod::PUT, path)
}
pub fn delete<P: AsRef<str>>(path: P) -> Self {
Self::new(ElifMethod::DELETE, path)
}
fn new<P: AsRef<str>>(method: ElifMethod, path: P) -> Self {
Self {
method,
path: path.as_ref().to_string(),
headers: ElifHeaderMap::new(),
body: None,
}
}
pub fn header<K: AsRef<str>, V: AsRef<str>>(mut self, key: K, value: V) -> Self {
let name = key.as_ref().parse().expect("Invalid header name");
let value = value.as_ref().parse().expect("Invalid header value");
self.headers.insert(name, value);
self
}
pub fn auth_bearer<T: AsRef<str>>(self, token: T) -> Self {
self.header("Authorization", format!("Bearer {}", token.as_ref()))
}
pub fn json(self) -> Self {
self.header("Content-Type", "application/json")
}
pub fn json_body(mut self, json: &Value) -> Self {
let body = serde_json::to_vec(json).expect("Failed to serialize JSON");
self.body = Some(body);
self.json()
}
pub fn body(mut self, body: Vec<u8>) -> Self {
self.body = Some(body);
self
}
pub fn build(self) -> ElifRequest {
let uri = self.path.parse().expect("Invalid URI");
let mut request = ElifRequest::new(self.method, uri, self.headers);
if let Some(body) = self.body {
request = request.with_body(body.into());
}
request
}
}
#[derive(Debug, Clone)]
pub struct MockMiddleware {
#[allow(dead_code)] name: String,
behavior: MockBehavior,
execution_count: Arc<Mutex<u32>>,
}
#[derive(Debug, Clone)]
pub enum MockBehavior {
PassThrough,
ReturnResponse(u16, String),
AddHeader(String, String),
Delay(Duration),
Error(String),
}
impl MockMiddleware {
pub fn new<S: Into<String>>(name: S) -> Self {
Self {
name: name.into(),
behavior: MockBehavior::PassThrough,
execution_count: Arc::new(Mutex::new(0)),
}
}
pub fn returns_response<S: Into<String>>(name: S, status: u16, body: S) -> Self {
Self {
name: name.into(),
behavior: MockBehavior::ReturnResponse(status, body.into()),
execution_count: Arc::new(Mutex::new(0)),
}
}
pub fn adds_header<S: Into<String>>(name: S, header_name: S, header_value: S) -> Self {
Self {
name: name.into(),
behavior: MockBehavior::AddHeader(header_name.into(), header_value.into()),
execution_count: Arc::new(Mutex::new(0)),
}
}
pub fn delays<S: Into<String>>(name: S, delay: Duration) -> Self {
Self {
name: name.into(),
behavior: MockBehavior::Delay(delay),
execution_count: Arc::new(Mutex::new(0)),
}
}
pub fn execution_count(&self) -> u32 {
*self
.execution_count
.lock()
.expect("Failed to lock execution count")
}
pub fn reset_count(&self) {
let mut count = self
.execution_count
.lock()
.expect("Failed to lock execution count");
*count = 0;
}
}
impl Middleware for MockMiddleware {
fn handle(&self, request: ElifRequest, next: Next) -> NextFuture<'static> {
let behavior = self.behavior.clone();
let count = self.execution_count.clone();
Box::pin(async move {
{
let mut count = count.lock().expect("Failed to lock execution count");
*count += 1;
}
match behavior {
MockBehavior::PassThrough => next.run(request).await,
MockBehavior::ReturnResponse(status, body) => {
let status_code = match status {
200 => crate::response::status::ElifStatusCode::OK,
400 => crate::response::status::ElifStatusCode::BAD_REQUEST,
401 => crate::response::status::ElifStatusCode::UNAUTHORIZED,
404 => crate::response::status::ElifStatusCode::NOT_FOUND,
500 => crate::response::status::ElifStatusCode::INTERNAL_SERVER_ERROR,
_ => crate::response::status::ElifStatusCode::OK,
};
ElifResponse::with_status(status_code).text(&body)
}
MockBehavior::AddHeader(header_name, header_value) => {
let mut response = next.run(request).await;
let _ = response.add_header(&header_name, &header_value);
response
}
MockBehavior::Delay(delay) => {
tokio::time::sleep(delay).await;
next.run(request).await
}
MockBehavior::Error(_error_msg) => {
ElifResponse::internal_server_error().text("Mock middleware error")
}
}
})
}
fn name(&self) -> &'static str {
"MockMiddleware"
}
}
pub struct MiddlewareAssertions;
impl MiddlewareAssertions {
pub fn assert_execution_order(pipeline: &MiddlewarePipelineV2, expected_order: &[&str]) {
let names = pipeline.names();
assert_eq!(
names.len(),
expected_order.len(),
"Pipeline has {} middleware, expected {}",
names.len(),
expected_order.len()
);
for (i, expected_name) in expected_order.iter().enumerate() {
assert_eq!(
names[i], *expected_name,
"Middleware at position {} is '{}', expected '{}'",
i, names[i], expected_name
);
}
}
pub fn assert_mock_execution_count(mock: &MockMiddleware, expected_count: u32) {
assert_eq!(
mock.execution_count(),
expected_count,
"Mock middleware was executed {} times, expected {}",
mock.execution_count(),
expected_count
);
}
}
pub struct MiddlewareBenchmark;
impl MiddlewareBenchmark {
pub async fn benchmark_middleware<M: Middleware + 'static>(
middleware: M,
iterations: u32,
) -> BenchmarkResult {
let harness = MiddlewareTestHarness::new().add_middleware(middleware);
let mut total_duration = Duration::ZERO;
let mut min_duration = Duration::MAX;
let mut max_duration = Duration::ZERO;
for _ in 0..iterations {
let request = TestRequestBuilder::get("/test").build();
let result = harness.execute(request).await;
total_duration += result.execution_time;
min_duration = min_duration.min(result.execution_time);
max_duration = max_duration.max(result.execution_time);
}
let average_duration = total_duration / iterations;
BenchmarkResult {
iterations,
total_duration,
average_duration,
min_duration,
max_duration,
}
}
}
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
pub iterations: u32,
pub total_duration: Duration,
pub average_duration: Duration,
pub min_duration: Duration,
pub max_duration: Duration,
}
impl BenchmarkResult {
pub fn print(&self) {
println!("Middleware Benchmark Results:");
println!(" Iterations: {}", self.iterations);
println!(" Total time: {:?}", self.total_duration);
println!(" Average: {:?}", self.average_duration);
println!(" Min: {:?}", self.min_duration);
println!(" Max: {:?}", self.max_duration);
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[tokio::test]
async fn test_middleware_harness_basic() {
let harness = MiddlewareTestHarness::new().add_middleware(MockMiddleware::new("test"));
let request = TestRequestBuilder::get("/test").build();
let result = harness.execute(request).await;
result.assert_status(200);
assert_eq!(result.middleware_count, 1);
}
#[tokio::test]
async fn test_mock_middleware_execution_count() {
let mock = MockMiddleware::new("counter");
let harness = MiddlewareTestHarness::new().add_middleware(mock.clone());
for _ in 0..3 {
let request = TestRequestBuilder::get("/test").build();
harness.execute(request).await;
}
assert_eq!(mock.execution_count(), 3);
}
#[tokio::test]
async fn test_mock_middleware_returns_response() {
let mock = MockMiddleware::returns_response("responder", 404, "Not found");
let harness = MiddlewareTestHarness::new().add_middleware(mock);
let request = TestRequestBuilder::get("/test").build();
let result = harness.execute(request).await;
result.assert_status(404);
}
#[tokio::test]
async fn test_mock_middleware_adds_header() {
let mock = MockMiddleware::adds_header("header-adder", "X-Test", "test-value");
let harness = MiddlewareTestHarness::new().add_middleware(mock);
let request = TestRequestBuilder::get("/test").build();
let result = harness.execute(request).await;
result.assert_header("X-Test", "test-value");
}
#[tokio::test]
async fn test_request_builder() {
let request = TestRequestBuilder::post("/api/users")
.auth_bearer("test-token")
.json_body(&json!({"name": "test"}))
.build();
assert_eq!(request.method, ElifMethod::POST);
assert_eq!(request.path(), "/api/users");
assert!(request.header("Authorization").is_some());
assert!(request.header("Content-Type").is_some());
}
#[tokio::test]
async fn test_middleware_pipeline_execution_order() {
let mock1 = MockMiddleware::new("first");
let mock2 = MockMiddleware::new("second");
let harness = MiddlewareTestHarness::new()
.add_middleware(mock1.clone())
.add_middleware(mock2.clone());
let request = TestRequestBuilder::get("/test").build();
harness.execute(request).await;
assert_eq!(mock1.execution_count(), 1);
assert_eq!(mock2.execution_count(), 1);
}
#[tokio::test]
async fn test_execution_stats() {
let harness = MiddlewareTestHarness::new().add_middleware(MockMiddleware::new("stats"));
for _ in 0..5 {
let request = TestRequestBuilder::get("/test").build();
harness.execute(request).await;
}
let stats = harness.stats();
assert_eq!(stats.request_count, 5);
assert!(stats.total_duration > Duration::ZERO);
assert!(stats.last_execution_time.is_some());
}
}