use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::Duration;
use axum::Router as AxumRouter;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use crate::error::{ServiceError, ServiceResult};
use crate::router::ServiceRouter;
use crate::server::axum::{bind_random_port, AxumServer};
use crate::server::ServerConfig;
/// Settings for a [`TestHarness`].
#[derive(Debug, Clone)]
pub struct TestHarnessConfig {
    /// Name handed to the server as `ServerConfig::name`.
    pub service_name: String,
    /// Host/interface the server listener is bound on.
    pub host: String,
    /// Metrics toggle (not read by `TestHarness::start`).
    pub enable_metrics: bool,
    /// Tracing toggle (not read by `TestHarness::start`).
    pub enable_tracing: bool,
}

impl Default for TestHarnessConfig {
    /// Service `"test-service"` on the IPv4 loopback host, metrics and tracing disabled.
    fn default() -> Self {
        let service_name = String::from("test-service");
        let host = String::from("127.0.0.1");
        Self {
            service_name,
            host,
            enable_metrics: false,
            enable_tracing: false,
        }
    }
}
/// State of a started server. Stored in `TestHarness::running` by `start` and
/// taken out again by `stop` or `Drop`.
struct Running {
/// Address the listener was bound to.
addr: SocketAddr,
/// Sender that triggers the server's graceful shutdown; an `Option` so it can be
/// `take()`n and consumed by `send`.
shutdown: Option<oneshot::Sender<()>>,
/// Handle of the spawned `axum::serve` task; an `Option` so `stop` can `take()` and await it.
handle: Option<JoinHandle<()>>,
}
/// Runs the service on a port obtained from `bind_random_port` and sends real
/// HTTP requests to it from tests.
///
/// Typical flow: build with [`TestHarness::new`] or [`TestHarness::with_config`],
/// add routes with `with_router`/`with_routes`, call `start`, issue requests,
/// then `stop`.
pub struct TestHarness {
/// Harness settings (service name, bind host, ...).
config: TestHarnessConfig,
/// Router to serve; taken by `start`, which falls back to `ServiceRouter::default()`.
service_router: Option<ServiceRouter>,
/// Extra axum routes merged into the server; taken by `start`.
extra_routes: Option<AxumRouter>,
/// Live server state; `Some` between `start` and `stop`/drop.
running: Option<Running>,
/// HTTP client used by `request` (built with a 5 second timeout).
client: reqwest::Client,
}
impl std::fmt::Debug for TestHarness {
    /// Shows the config and the bound address (if running). Routers are summarized
    /// as presence flags because they are not printable here, and the HTTP client is
    /// omitted; `finish_non_exhaustive` renders a trailing `..` to signal that.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TestHarness")
            .field("config", &self.config)
            .field("running", &self.running.as_ref().map(|r| r.addr))
            .field("has_router", &self.service_router.is_some())
            .field("has_extra_routes", &self.extra_routes.is_some())
            .finish_non_exhaustive()
    }
}
impl TestHarness {
    /// Longest time [`TestHarness::stop`] waits for the server task to finish a
    /// graceful shutdown before aborting it.
    const SHUTDOWN_GRACE: Duration = Duration::from_secs(5);

    /// Delay between readiness probes in [`TestHarness::wait_for_ready`].
    const READY_POLL_INTERVAL: Duration = Duration::from_millis(25);

    /// Creates a harness using [`TestHarnessConfig::default`].
    pub fn new() -> Self {
        Self::with_config(TestHarnessConfig::default())
    }

    /// Creates a harness with an explicit configuration and no routes yet.
    ///
    /// # Panics
    ///
    /// Panics if the internal `reqwest` client (5 second request timeout) cannot be built.
    pub fn with_config(config: TestHarnessConfig) -> Self {
        Self {
            config,
            service_router: None,
            extra_routes: None,
            running: None,
            client: reqwest::Client::builder()
                .timeout(Duration::from_secs(5))
                .build()
                .expect("build reqwest client"),
        }
    }

    /// Sets the [`ServiceRouter`] served by `start`, replacing any previously set router.
    #[must_use]
    pub fn with_router(mut self, router: ServiceRouter) -> Self {
        self.service_router = Some(router);
        self
    }

    /// Adds extra axum routes; successive calls are combined with `Router::merge`.
    #[must_use]
    pub fn with_routes(mut self, routes: AxumRouter) -> Self {
        self.extra_routes = Some(match self.extra_routes.take() {
            Some(existing) => existing.merge(routes),
            None => routes,
        });
        self
    }

    /// Returns the harness configuration.
    pub fn config(&self) -> &TestHarnessConfig {
        &self.config
    }

    /// Returns the bound address while the server is running, `None` otherwise.
    pub fn addr(&self) -> Option<SocketAddr> {
        self.running.as_ref().map(|r| r.addr)
    }

    /// Returns `http://<addr>` while the server is running, `None` otherwise.
    pub fn base_url(&self) -> Option<String> {
        self.addr().map(|a| format!("http://{a}"))
    }

    /// Binds a listener on `config.host` and serves the configured routes on a
    /// spawned tokio task.
    ///
    /// Uses `ServiceRouter::default()` when no router was set. The router and extra
    /// routes are consumed here, so a later `stop` + `start` serves only the default
    /// router.
    ///
    /// # Errors
    ///
    /// Returns `ServiceError::Configuration` if the harness is already running, and
    /// propagates any error from `bind_random_port`. On a bind error, the configured
    /// routes are left in place so `start` can be retried.
    pub async fn start(&mut self) -> ServiceResult<()> {
        if self.running.is_some() {
            return Err(ServiceError::Configuration("harness already started".into()));
        }
        // Bind before consuming the configured routers: if binding fails, the harness
        // must stay unchanged rather than silently losing its router.
        let (listener, addr) = bind_random_port(&self.config.host).await?;
        let router = self.service_router.take().unwrap_or_default();
        let mut server = AxumServer::with_config(
            router,
            ServerConfig {
                addr,
                name: self.config.service_name.clone(),
                ..ServerConfig::default()
            },
        );
        if let Some(extra) = self.extra_routes.take() {
            server = server.merge(extra);
        }
        let app = server.app();
        let (tx, rx) = oneshot::channel::<()>();
        let handle = tokio::spawn(async move {
            // The serve result is discarded: the harness has no channel to report it.
            let _ = axum::serve(listener, app)
                .with_graceful_shutdown(async move {
                    // Resolves on an explicit send or when the sender is dropped.
                    let _ = rx.await;
                })
                .await;
        });
        self.running = Some(Running {
            addr,
            shutdown: Some(tx),
            handle: Some(handle),
        });
        Ok(())
    }

    /// Signals graceful shutdown and waits for the server task to finish.
    ///
    /// If the task does not finish within [`Self::SHUTDOWN_GRACE`] (for example, because
    /// a request is stuck), it is aborted so the test cannot hang. This is a no-op
    /// when the harness is not running. It currently always returns `Ok(())`.
    pub async fn stop(&mut self) -> ServiceResult<()> {
        if let Some(mut running) = self.running.take() {
            if let Some(tx) = running.shutdown.take() {
                // The server may already have exited; a closed channel is fine.
                let _ = tx.send(());
            }
            if let Some(mut handle) = running.handle.take() {
                if tokio::time::timeout(Self::SHUTDOWN_GRACE, &mut handle)
                    .await
                    .is_err()
                {
                    handle.abort();
                    // Reap the aborted task; its cancellation error is expected.
                    let _ = handle.await;
                }
            }
        }
        Ok(())
    }

    /// Joins `base` (which has no trailing slash) with `path`, inserting a `/` when a
    /// non-empty `path` does not already start with one.
    fn join_url(base: &str, path: &str) -> String {
        if path.is_empty() || path.starts_with('/') {
            format!("{base}{path}")
        } else {
            format!("{base}/{path}")
        }
    }

    /// Sends an HTTP request to the running server and captures the full response.
    ///
    /// `path` is appended to [`TestHarness::base_url`]; a missing leading `/` is added.
    /// Response header names come back lowercase (as `http::HeaderName` stores them).
    /// Header values that are not visible ASCII are skipped, and for a repeated
    /// header only the last value is kept.
    ///
    /// # Errors
    ///
    /// - `ServiceError::Configuration` if the harness is not started.
    /// - `ServiceError::InvalidArgument` if `method` is not a valid HTTP method.
    /// - `ServiceError::Internal` if sending the request or reading the body fails.
    pub async fn request(
        &self,
        method: &str,
        path: &str,
        body: Option<Vec<u8>>,
        headers: Option<HashMap<String, String>>,
    ) -> ServiceResult<TestResponse> {
        let base = self
            .base_url()
            .ok_or_else(|| ServiceError::Configuration("harness not started".into()))?;
        let url = Self::join_url(&base, path);
        let method = reqwest::Method::from_bytes(method.as_bytes())
            .map_err(|e| ServiceError::InvalidArgument(format!("bad method {method}: {e}")))?;
        let mut req = self.client.request(method, &url);
        for (name, value) in headers.into_iter().flatten() {
            req = req.header(name, value);
        }
        if let Some(body) = body {
            req = req.body(body);
        }
        let resp = req
            .send()
            .await
            .map_err(|e| ServiceError::Internal(format!("send: {e}")))?;
        let status = resp.status().as_u16();
        // Collect into an owned map now; reading the body below consumes `resp`.
        let response_headers: HashMap<String, String> = resp
            .headers()
            .iter()
            .filter_map(|(name, value)| {
                value
                    .to_str()
                    .ok()
                    .map(|v| (name.as_str().to_owned(), v.to_owned()))
            })
            .collect();
        let body = resp
            .bytes()
            .await
            .map_err(|e| ServiceError::Internal(format!("read body: {e}")))?
            .to_vec();
        Ok(TestResponse {
            status,
            body,
            headers: response_headers,
        })
    }

    /// Sends a `GET` to `path`. See [`TestHarness::request`].
    pub async fn get(&self, path: &str) -> ServiceResult<TestResponse> {
        self.request("GET", path, None, None).await
    }

    /// Sends a `POST` with `body` to `path`. See [`TestHarness::request`].
    pub async fn post(&self, path: &str, body: impl Into<Vec<u8>>) -> ServiceResult<TestResponse> {
        self.request("POST", path, Some(body.into()), None).await
    }

    /// Polls `GET path` until it returns a 2xx status or `timeout` elapses.
    ///
    /// # Errors
    ///
    /// - `ServiceError::Configuration` immediately if the harness is not started.
    /// - `ServiceError::DeadlineExceeded` if no 2xx response arrives within `timeout`.
    pub async fn wait_for_ready(&self, path: &str, timeout: Duration) -> ServiceResult<()> {
        // Polling cannot succeed before `start`; report that instead of spinning until
        // the deadline and returning a misleading `DeadlineExceeded`.
        if self.running.is_none() {
            return Err(ServiceError::Configuration("harness not started".into()));
        }
        let deadline = tokio::time::Instant::now() + timeout;
        loop {
            // Cap each probe at the overall deadline. Otherwise a stalled response could
            // hold the loop for the full 5 second client timeout.
            if let Ok(Ok(resp)) = tokio::time::timeout_at(deadline, self.get(path)).await {
                if resp.is_success() {
                    return Ok(());
                }
            }
            let now = tokio::time::Instant::now();
            if now >= deadline {
                return Err(ServiceError::DeadlineExceeded(format!(
                    "{path} not ready within {timeout:?}"
                )));
            }
            // Never sleep past the deadline.
            tokio::time::sleep(Self::READY_POLL_INTERVAL.min(deadline - now)).await;
        }
    }
}
impl Default for TestHarness {
fn default() -> Self {
Self::new()
}
}
impl Drop for TestHarness {
fn drop(&mut self) {
if let Some(mut running) = self.running.take() {
if let Some(tx) = running.shutdown.take() {
let _ = tx.send(());
}
}
}
}
/// Status, body and headers captured from an HTTP response.
#[derive(Debug, Clone)]
pub struct TestResponse {
    /// HTTP status code.
    pub status: u16,
    /// Raw response body bytes.
    pub body: Vec<u8>,
    /// Response headers keyed by name. When produced by `TestHarness::request`,
    /// names are lowercase and a repeated header keeps only its last value.
    pub headers: HashMap<String, String>,
}

impl TestResponse {
    /// Returns the body decoded as UTF-8, with invalid sequences replaced by U+FFFD.
    pub fn body_text(&self) -> String {
        String::from_utf8_lossy(&self.body).into_owned()
    }

    /// Returns `true` for 2xx status codes.
    pub fn is_success(&self) -> bool {
        (200..300).contains(&self.status)
    }

    /// Looks up a header value by name, ignoring ASCII case (HTTP header names are
    /// case-insensitive).
    ///
    /// An exact key match wins. Otherwise, the first key in map iteration order that
    /// matches case-insensitively is returned.
    pub fn get_header(&self, key: &str) -> Option<&String> {
        self.headers.get(key).or_else(|| {
            self.headers
                .iter()
                .find(|(name, _)| name.eq_ignore_ascii_case(key))
                .map(|(_, value)| value)
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        routing::{get, post},
        Router as AxumRouter,
    };

    #[tokio::test]
    async fn test_harness_new() {
        let harness = TestHarness::new();
        assert_eq!(harness.config.service_name, "test-service");
        assert!(harness.addr().is_none());
    }

    #[tokio::test]
    async fn test_harness_start_stop() {
        let mut harness = TestHarness::new();
        harness.start().await.unwrap();
        let addr = harness.addr().expect("bound");
        assert_ne!(addr.port(), 0);
        harness.stop().await.unwrap();
        assert!(harness.addr().is_none());
    }

    #[tokio::test]
    async fn test_harness_double_start_is_rejected() {
        let mut harness = TestHarness::new();
        harness.start().await.unwrap();
        let err = harness.start().await.unwrap_err();
        assert!(matches!(err, ServiceError::Configuration(_)));
        harness.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_stop_is_idempotent() {
        let mut harness = TestHarness::new();
        // Stopping a never-started harness is a no-op.
        harness.stop().await.unwrap();
        harness.start().await.unwrap();
        harness.stop().await.unwrap();
        harness.stop().await.unwrap();
        assert!(harness.addr().is_none());
    }

    #[tokio::test]
    async fn test_request_before_start_fails() {
        let harness = TestHarness::new();
        let err = harness.get("/ping").await.unwrap_err();
        assert!(matches!(err, ServiceError::Configuration(_)));
    }

    #[tokio::test]
    async fn test_wait_for_ready_before_start_fails() {
        let harness = TestHarness::new();
        assert!(harness
            .wait_for_ready("/health", Duration::from_millis(50))
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_harness_real_request() {
        let extra = AxumRouter::new().route("/ping", get(|| async { "pong" }));
        let mut harness = TestHarness::new().with_routes(extra);
        harness.start().await.unwrap();
        let resp = harness.get("/ping").await.unwrap();
        assert!(resp.is_success());
        assert_eq!(resp.body_text(), "pong");
        assert!(resp.get_header("content-type").is_some());
        harness.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_harness_post_echo() {
        let extra = AxumRouter::new().route("/echo", post(|body: String| async move { body }));
        let mut harness = TestHarness::new().with_routes(extra);
        harness.start().await.unwrap();
        let resp = harness.post("/echo", "hello").await.unwrap();
        assert!(resp.is_success());
        assert_eq!(resp.body_text(), "hello");
        harness.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_with_routes_merges_successive_calls() {
        let mut harness = TestHarness::new()
            .with_routes(AxumRouter::new().route("/a", get(|| async { "a" })))
            .with_routes(AxumRouter::new().route("/b", get(|| async { "b" })));
        harness.start().await.unwrap();
        assert_eq!(harness.get("/a").await.unwrap().body_text(), "a");
        assert_eq!(harness.get("/b").await.unwrap().body_text(), "b");
        harness.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_harness_404_on_missing_route() {
        let mut harness = TestHarness::new();
        harness.start().await.unwrap();
        let resp = harness.get("/nope").await.unwrap();
        assert_eq!(resp.status, 404);
        harness.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_harness_wait_for_ready() {
        let extra = AxumRouter::new().route("/health", get(|| async { "ok" }));
        let mut harness = TestHarness::new().with_routes(extra);
        harness.start().await.unwrap();
        harness
            .wait_for_ready("/health", Duration::from_secs(2))
            .await
            .unwrap();
        harness.stop().await.unwrap();
    }

    #[tokio::test]
    async fn test_response_helpers() {
        let response = TestResponse {
            status: 200,
            body: b"hello".to_vec(),
            headers: HashMap::new(),
        };
        assert_eq!(response.body_text(), "hello");
        assert!(response.is_success());
    }
}