use reinhardt_di::InjectionContext;
use reinhardt_http::Handler;
use reinhardt_http::{Request, Response};
use reinhardt_server::{
HttpServer, RateLimitConfig, RateLimitHandler, ShutdownCoordinator, TimeoutHandler,
};
use reinhardt_urls::routers::ServerRouter as Router;
use rstest::fixture;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::task::JoinHandle;
#[cfg(feature = "websockets")]
use reinhardt_server::WebSocketServer;
#[cfg(feature = "graphql")]
use reinhardt_server::GraphQLHandler;
/// A running HTTP/1.1 test server that serves a [`Router`] and shuts down when dropped.
///
/// Obtain one through [`test_server_guard`]. The server listens on an ephemeral
/// port of `127.0.0.1`.
pub struct TestServerGuard {
    /// Base URL of the server in the form `http://127.0.0.1:<port>`.
    pub url: String,
    /// Shutdown coordinator; its `shutdown()` is called when the guard is dropped.
    pub coordinator: Arc<ShutdownCoordinator>,
    /// Accept-loop task; aborted when the guard is dropped.
    server_task: Option<JoinHandle<()>>,
}

impl TestServerGuard {
    /// Binds `127.0.0.1:0`, spawns an accept loop that serves `router`, and waits
    /// until the bound port accepts TCP connections.
    ///
    /// The accept loop stops when the coordinator signals shutdown or when `accept`
    /// returns an error. Connections are handled without a DI context (`None`).
    ///
    /// # Panics
    ///
    /// Panics if binding fails or the server never becomes ready.
    async fn new(router: Router) -> Self {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let bound_addr = listener.local_addr().unwrap();
        let url = format!("http://{}", bound_addr);

        let coordinator = Arc::new(ShutdownCoordinator::new(Duration::from_secs(5)));
        let loop_coordinator = (*coordinator).clone();

        let router_handler: Arc<dyn Handler> = Arc::new(router);
        let http_server = HttpServer::new(router_handler);

        // Subscribe before spawning so the receiver exists before the loop starts running.
        let mut shutdown_rx = loop_coordinator.subscribe();

        let accept_loop = async move {
            loop {
                tokio::select! {
                    accepted = listener.accept() => {
                        let (stream, peer_addr) = match accepted {
                            Ok(connection) => connection,
                            Err(err) => {
                                eprintln!("Error accepting connection: {:?}", err);
                                break;
                            }
                        };
                        let conn_handler = http_server.handler();
                        // One task per connection so a slow client cannot stall the accept loop.
                        tokio::spawn(async move {
                            let outcome =
                                HttpServer::handle_connection(stream, peer_addr, conn_handler, None)
                                    .await;
                            if let Err(err) = outcome {
                                eprintln!("Error handling connection: {:?}", err);
                            }
                        });
                    }
                    _ = shutdown_rx.recv() => {
                        break;
                    }
                }
            }
        };
        let server_task = tokio::spawn(accept_loop);

        wait_for_server_ready(bound_addr)
            .await
            .expect("Test server failed to become ready");

        Self {
            url,
            coordinator,
            server_task: Some(server_task),
        }
    }
}

impl Drop for TestServerGuard {
    /// Signals shutdown, then aborts the accept-loop task if it is still held.
    fn drop(&mut self) {
        self.coordinator.shutdown();
        if let Some(handle) = std::mem::take(&mut self.server_task) {
            handle.abort();
        }
    }
}
/// Starts a [`TestServerGuard`] serving `router` and returns it once its port
/// accepts TCP connections.
///
/// # Panics
///
/// Panics if the server cannot bind or never becomes ready.
pub async fn test_server_guard(router: Router) -> TestServerGuard {
    TestServerGuard::new(router).await
}
/// Minimal [`Handler`] that ignores the request and answers with `Response::ok()`
/// carrying the body `"OK"`.
#[derive(Clone)]
pub struct BasicHandler;

#[async_trait::async_trait]
impl Handler for BasicHandler {
    async fn handle(&self, _request: Request) -> reinhardt_core::exception::Result<Response> {
        let response = Response::ok().with_body("OK");
        Ok(response)
    }
}
/// rstest fixture: a `reqwest` client whose requests time out after 10 seconds.
///
/// # Panics
///
/// Panics if the client cannot be constructed.
#[fixture]
pub fn http_client() -> reqwest::Client {
    let request_timeout = Duration::from_secs(10);
    let builder = reqwest::Client::builder().timeout(request_timeout);
    builder.build().expect("Failed to create HTTP client")
}
/// rstest fixture: a [`TestServer`] serving [`BasicHandler`] over HTTP/1.1.
///
/// # Panics
///
/// Panics if the server cannot be built.
#[fixture]
pub async fn http1_server() -> TestServer {
    let basic = Arc::new(BasicHandler);
    let builder = TestServer::builder().handler(basic);
    builder
        .build()
        .await
        .expect("Failed to create HTTP/1.1 server")
}
/// rstest fixture: a [`TestServer`] serving [`BasicHandler`] with the builder's
/// `http2` flag enabled.
///
/// # Panics
///
/// Panics if the server cannot be built.
#[fixture]
pub async fn http2_server() -> TestServer {
    let basic = Arc::new(BasicHandler);
    let builder = TestServer::builder().handler(basic).http2(true);
    builder
        .build()
        .await
        .expect("Failed to create HTTP/2 server")
}
/// rstest fixture: a [`TestServer`] whose [`BasicHandler`] is wrapped in a
/// [`TimeoutHandler`].
///
/// `timeout` defaults to 5 seconds and can be overridden by the test.
///
/// # Panics
///
/// Panics if the server cannot be built.
#[fixture]
pub async fn server_with_timeout(
    #[default(Duration::from_secs(5))] timeout: Duration,
) -> TestServer {
    let inner = Arc::new(BasicHandler);
    let wrapped = Arc::new(TimeoutHandler::new(inner, timeout));
    let builder = TestServer::builder().handler(wrapped);
    builder
        .build()
        .await
        .expect("Failed to create server with timeout")
}
/// rstest fixture: a [`TestServer`] whose [`BasicHandler`] is wrapped in a
/// [`RateLimitHandler`] configured with `RateLimitConfig::per_minute(limit)`.
///
/// `limit` defaults to 100 and can be overridden by the test.
///
/// # Panics
///
/// Panics if the server cannot be built.
#[fixture]
pub async fn server_with_rate_limit(#[default(100)] limit: u32) -> TestServer {
    let inner = Arc::new(BasicHandler);
    let limiter = Arc::new(RateLimitHandler::new(
        inner,
        RateLimitConfig::per_minute(limit as usize),
    ));
    let builder = TestServer::builder().handler(limiter);
    builder
        .build()
        .await
        .expect("Failed to create server with rate limit")
}
/// rstest fixture: a [`TestServer`] with a two-layer wrapper chain around
/// [`BasicHandler`].
///
/// The outer layer is a [`RateLimitHandler`] (`per_minute(100)`). The inner layer
/// is a [`TimeoutHandler`] with a 5-second timeout.
///
/// # Panics
///
/// Panics if the server cannot be built.
#[fixture]
pub async fn server_with_middleware_chain() -> TestServer {
    // Build inside-out: core handler -> timeout layer -> rate-limit layer.
    let core = Arc::new(BasicHandler);
    let with_timeout = Arc::new(TimeoutHandler::new(core, Duration::from_secs(5)));
    let outermost = Arc::new(RateLimitHandler::new(
        with_timeout,
        RateLimitConfig::per_minute(100),
    ));
    let builder = TestServer::builder().handler(outermost);
    builder
        .build()
        .await
        .expect("Failed to create server with middleware chain")
}
/// rstest fixture: a [`TestServer`] serving [`BasicHandler`] with an
/// [`InjectionContext`] attached.
///
/// The context is built over a fresh `SingletonScope`. It is returned alongside
/// the server so the test can inspect or populate it.
///
/// # Panics
///
/// Panics if the server cannot be built.
#[fixture]
pub async fn server_with_di() -> (TestServer, Arc<InjectionContext>) {
    use reinhardt_di::SingletonScope;
    let basic = Arc::new(BasicHandler);
    let singleton_scope = Arc::new(SingletonScope::new());
    let di_context = Arc::new(InjectionContext::builder(singleton_scope).build());
    let builder = TestServer::builder()
        .handler(basic)
        .di_context(Arc::clone(&di_context));
    let server = builder
        .build()
        .await
        .expect("Failed to create server with DI context");
    (server, di_context)
}
/// rstest fixture: a [`TestServer`] in WebSocket mode whose handler echoes every
/// text message back unchanged.
///
/// Only compiled with the `websockets` feature.
///
/// # Panics
///
/// Panics if the server cannot be built.
#[cfg(feature = "websockets")]
#[fixture]
pub async fn websocket_server() -> TestServer {
    use reinhardt_server::WebSocketHandler;

    /// Echoes each message; connect/disconnect hooks are no-ops.
    #[derive(Clone)]
    struct EchoHandler;

    #[async_trait::async_trait]
    impl WebSocketHandler for EchoHandler {
        async fn handle_message(&self, message: String) -> Result<String, String> {
            Ok(message)
        }

        async fn on_connect(&self) {}

        async fn on_disconnect(&self) {}
    }

    let echo = Arc::new(EchoHandler);
    let builder = TestServer::builder().websocket_handler(echo);
    builder
        .build()
        .await
        .expect("Failed to create WebSocket server")
}
/// rstest fixture: a WebSocket client connected to [`websocket_server`].
///
/// The server's `http://` URL is rewritten to `ws://` before connecting.
///
/// The backing [`TestServer`] is intentionally leaked (see the comment in the
/// body). Its `Drop` would otherwise shut the server down the moment this fixture
/// returns. Only compiled with the `websockets` feature.
///
/// # Panics
///
/// Panics if the WebSocket handshake fails.
#[cfg(feature = "websockets")]
#[fixture]
pub async fn websocket_client(
    #[from(websocket_server)]
    #[future]
    server: TestServer,
) -> tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>> {
    let server = server.await;
    let ws_url = server.url.replace("http://", "ws://");
    let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url)
        .await
        .expect("Failed to connect WebSocket");
    // `TestServer::drop` calls `coordinator.shutdown()` and aborts the server task.
    // Letting `server` drop here would tear the server down as soon as this fixture
    // returns, handing the caller a stream to a dying server. The return type cannot
    // also carry the guard, so leak it deliberately. The spawned task is then left to
    // the Tokio runtime, presumably the per-test runtime of `#[tokio::test]`, which
    // cancels its tasks when the test finishes.
    std::mem::forget(server);
    ws_stream
}
/// rstest fixture: a [`TestServer`] whose handler is a [`GraphQLHandler`].
///
/// The schema exposes a single `hello` query returning `"Hello, GraphQL!"`, with
/// empty mutation and subscription roots. Only compiled with the `graphql` feature.
///
/// # Panics
///
/// Panics if the server cannot be built.
#[cfg(feature = "graphql")]
#[fixture]
pub async fn graphql_server() -> TestServer {
    use async_graphql::{EmptyMutation, EmptySubscription, Object, Schema};

    struct Query;

    #[Object]
    impl Query {
        async fn hello(&self) -> &'static str {
            "Hello, GraphQL!"
        }
    }

    let schema = Schema::build(Query, EmptyMutation, EmptySubscription).finish();
    let graphql_handler = Arc::new(GraphQLHandler::new(schema));
    TestServer::builder()
        .handler(graphql_handler)
        .build()
        .await
        .expect("Failed to create GraphQL server")
}
/// A running test server started by [`TestServerBuilder::build`].
///
/// Dropping it calls `shutdown()` on the coordinator and then aborts the server task.
pub struct TestServer {
    /// Base URL in the form `http://<addr>`.
    pub url: String,
    /// Socket address the server was bound to.
    pub addr: SocketAddr,
    /// Shutdown coordinator; a clone of it is handed to the server task.
    pub coordinator: Arc<ShutdownCoordinator>,
    /// Handle of the spawned server task; taken and aborted on drop.
    server_task: Option<JoinHandle<()>>,
}

impl TestServer {
    /// Returns a [`TestServerBuilder`] with its default settings.
    pub fn builder() -> TestServerBuilder {
        TestServerBuilder::new()
    }
}

impl Drop for TestServer {
    fn drop(&mut self) {
        // Ask the server to stop first, then force-cancel the serving task.
        self.coordinator.shutdown();
        if let Some(handle) = std::mem::take(&mut self.server_task) {
            handle.abort();
        }
    }
}
pub struct TestServerBuilder {
handler: Option<Arc<dyn Handler>>,
#[cfg(feature = "websockets")]
websocket_handler: Option<Arc<dyn reinhardt_server::WebSocketHandler>>,
di_context: Option<Arc<InjectionContext>>,
http2: bool,
shutdown_timeout: Duration,
}
impl TestServerBuilder {
fn new() -> Self {
Self {
handler: None,
#[cfg(feature = "websockets")]
websocket_handler: None,
di_context: None,
http2: false,
shutdown_timeout: Duration::from_secs(5),
}
}
pub fn handler(mut self, handler: Arc<dyn Handler>) -> Self {
self.handler = Some(handler);
self
}
#[cfg(feature = "websockets")]
pub fn websocket_handler(
mut self,
handler: Arc<dyn reinhardt_server::WebSocketHandler>,
) -> Self {
self.websocket_handler = Some(handler);
self
}
pub fn di_context(mut self, context: Arc<InjectionContext>) -> Self {
self.di_context = Some(context);
self
}
pub fn http2(mut self, enabled: bool) -> Self {
self.http2 = enabled;
self
}
pub fn shutdown_timeout(mut self, timeout: Duration) -> Self {
self.shutdown_timeout = timeout;
self
}
pub async fn build(self) -> Result<TestServer, Box<dyn std::error::Error>> {
let listener = TcpListener::bind("127.0.0.1:0").await?;
let actual_addr = listener.local_addr()?;
let url = format!("http://{}", actual_addr);
let coordinator = Arc::new(ShutdownCoordinator::new(self.shutdown_timeout));
let server_coordinator = (*coordinator).clone();
#[cfg(feature = "websockets")]
let websocket_handler = self.websocket_handler;
let handler = self.handler;
let di_context = self.di_context;
let http2 = self.http2;
let server_task = tokio::spawn(async move {
#[cfg(feature = "websockets")]
if let Some(ws_handler) = websocket_handler {
drop(listener);
let server = WebSocketServer::from_arc(ws_handler);
let _ = server
.listen_with_shutdown(actual_addr, server_coordinator)
.await;
return;
}
if let Some(h) = handler {
if http2 {
drop(listener);
let server = reinhardt_server::Http2Server::new(h);
let _ = server
.listen_with_shutdown(actual_addr, server_coordinator)
.await;
} else {
let server = HttpServer::new(h);
let mut shutdown_rx = server_coordinator.subscribe();
loop {
tokio::select! {
result = listener.accept() => {
match result {
Ok((stream, socket_addr)) => {
let handler_clone = server.handler();
let di_ctx = di_context.clone();
tokio::spawn(async move {
if let Err(e) =
HttpServer::handle_connection(stream, socket_addr, handler_clone, di_ctx)
.await
{
eprintln!("Error handling connection: {:?}", e);
}
});
}
Err(e) => {
eprintln!("Error accepting connection: {:?}", e);
break;
}
}
}
_ = shutdown_rx.recv() => {
break;
}
}
}
}
}
});
wait_for_server_ready(actual_addr)
.await
.expect("Test server failed to become ready");
Ok(TestServer {
url,
addr: actual_addr,
coordinator,
server_task: Some(server_task),
})
}
}
/// Maximum number of TCP connection attempts made by [`wait_for_server_ready`].
const SERVER_READY_MAX_ATTEMPTS: u32 = 20;
/// Delay, in milliseconds, between consecutive readiness probes.
const SERVER_READY_PROBE_INTERVAL_MS: u64 = 50;

/// Repeatedly opens a TCP connection to `addr` until one succeeds.
///
/// Makes up to [`SERVER_READY_MAX_ATTEMPTS`] attempts. It sleeps
/// [`SERVER_READY_PROBE_INTERVAL_MS`] milliseconds after every failed attempt
/// except the last. A successful connection is dropped immediately.
///
/// # Errors
///
/// Returns an error of kind [`std::io::ErrorKind::TimedOut`] if no attempt
/// succeeds. The last connect error is included in its message.
async fn wait_for_server_ready(addr: SocketAddr) -> Result<(), std::io::Error> {
    let probe_interval = Duration::from_millis(SERVER_READY_PROBE_INTERVAL_MS);
    let mut attempt = 0;
    while attempt < SERVER_READY_MAX_ATTEMPTS {
        attempt += 1;
        let connect_err = match tokio::net::TcpStream::connect(addr).await {
            Ok(_) => return Ok(()),
            Err(e) => e,
        };
        if attempt == SERVER_READY_MAX_ATTEMPTS {
            return Err(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "Server at {} not ready after {} attempts: {}",
                    addr, SERVER_READY_MAX_ATTEMPTS, connect_err
                ),
            ));
        }
        tokio::time::sleep(probe_interval).await;
    }
    // Only reachable when SERVER_READY_MAX_ATTEMPTS is zero (no attempt was made).
    Err(std::io::Error::new(
        std::io::ErrorKind::TimedOut,
        format!(
            "Server at {} not ready after {} attempts",
            addr, SERVER_READY_MAX_ATTEMPTS
        ),
    ))
}
/// Smoke tests for the fixtures and server helpers in this module.
///
/// Failures surface the underlying error (via `expect` / formatted panics) instead
/// of a bare "assertion failed: result.is_ok()".
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::*;

    #[rstest]
    #[tokio::test]
    async fn test_basic_handler_returns_ok() {
        let handler = BasicHandler;
        let request = Request::builder()
            .method(hyper::Method::GET)
            .uri("/")
            .build()
            .expect("Failed to build request");
        let resp = handler
            .handle(request)
            .await
            .expect("Expected Ok response from BasicHandler");
        assert_eq!(resp.status, hyper::StatusCode::OK);
    }

    #[rstest]
    #[tokio::test]
    async fn test_test_server_guard_starts() {
        let router = Router::new();
        let server = test_server_guard(router).await;
        assert!(
            server.url.starts_with("http://127.0.0.1:"),
            "Expected URL to start with 'http://127.0.0.1:', got: {}",
            server.url
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_test_server_builder_default() {
        let handler: Arc<dyn Handler> = Arc::new(BasicHandler);
        let result = TestServer::builder().handler(handler).build().await;
        if let Err(e) = &result {
            panic!(
                "Expected TestServer::builder().handler().build() to succeed, got: {}",
                e
            );
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_test_server_url_format() {
        let handler: Arc<dyn Handler> = Arc::new(BasicHandler);
        let server = TestServer::builder()
            .handler(handler)
            .build()
            .await
            .expect("Failed to build TestServer");
        assert!(
            server.url.starts_with("http://127.0.0.1:"),
            "Expected URL format 'http://127.0.0.1:<port>', got: {}",
            server.url
        );
        assert!(
            server.addr.port() > 0,
            "Expected non-zero port, got: {}",
            server.addr.port()
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_test_server_responds_to_request() {
        let handler: Arc<dyn Handler> = Arc::new(BasicHandler);
        let server = TestServer::builder()
            .handler(handler)
            .build()
            .await
            .expect("Failed to build TestServer");
        let client = reqwest::Client::new();
        let resp = client
            .get(&server.url)
            .send()
            .await
            .expect("Expected GET request to succeed");
        assert_eq!(resp.status(), reqwest::StatusCode::OK);
        // BasicHandler always answers with the body "OK"; verify it over the wire too.
        let body = resp.text().await.expect("Failed to read response body");
        assert_eq!(body, "OK");
    }

    #[rstest]
    fn test_http_client_fixture() {
        let client = http_client();
        let _: &reqwest::Client = &client;
    }

    #[rstest]
    #[tokio::test]
    async fn test_test_server_shutdown_timeout() {
        let handler: Arc<dyn Handler> = Arc::new(BasicHandler);
        let custom_timeout = Duration::from_secs(10);
        let result = TestServer::builder()
            .handler(handler)
            .shutdown_timeout(custom_timeout)
            .build()
            .await;
        if let Err(e) = &result {
            panic!(
                "Expected TestServer with custom shutdown timeout to build successfully, got: {}",
                e
            );
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_wait_for_server_ready() {
        // A bound but never-accepting listener still completes TCP handshakes via
        // the kernel backlog, so the probe should succeed.
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("Failed to bind listener");
        let addr = listener.local_addr().expect("Failed to get local addr");
        let result = wait_for_server_ready(addr).await;
        assert!(
            result.is_ok(),
            "Expected wait_for_server_ready to succeed for a bound address, got: {:?}",
            result.err()
        );
    }
}