use std::sync::Arc;
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use criterion_hypothesis_core::protocol::{
BenchmarkListResponse, ClaimRequest, ClaimResponse, HealthResponse, ReleaseRequest,
ReleaseResponse, RunIterationRequest, RunIterationResponse, ShutdownResponse, CLAIM_HEADER,
};
use tokio::sync::{watch, Mutex};
use crate::BenchmarkRegistry;
use std::sync::atomic::{AtomicU64, Ordering};
/// A progress line is written to stderr each time the count of successful
/// `/run` calls reaches a multiple of this value.
const LOG_INTERVAL: u64 = 100;
/// State shared by every HTTP handler through axum's `State` extractor.
struct AppState {
    /// Benchmarks this harness can list and execute.
    registry: Arc<BenchmarkRegistry>,
    /// Publishing `true` here ends the server's graceful-shutdown wait in
    /// `run_harness_async`.
    shutdown_tx: watch::Sender<bool>,
    /// Nonce of the orchestrator currently holding the claim; `None` while unclaimed.
    claim: Mutex<Option<String>>,
    /// Number of successful `/run` calls served so far; only used for periodic logging.
    iteration_count: AtomicU64,
}
/// Liveness probe. Always reports healthy and is not gated by the claim.
async fn health() -> Json<HealthResponse> {
    let status = HealthResponse::healthy();
    Json(status)
}
/// Lists every benchmark registered with this harness.
///
/// When the harness is claimed, the caller must present the matching claim
/// nonce; otherwise the rejection produced by `check_claim` is returned as-is.
async fn list_benchmarks(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
) -> impl IntoResponse {
    match check_claim(&state, &headers).await {
        Err(rejection) => rejection,
        Ok(()) => {
            let names = state.registry.list();
            eprintln!("[harness] Listed {} benchmark(s)", names.len());
            let body = Json(BenchmarkListResponse::new(names));
            (StatusCode::OK, body).into_response()
        }
    }
}
async fn run_iteration(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(request): Json<RunIterationRequest>,
) -> impl IntoResponse {
if let Err(response) = check_claim(&state, &headers).await {
return response;
}
if request.iterations == 0 {
return (
StatusCode::BAD_REQUEST,
Json(RunIterationResponse::failure("iterations must be >= 1")),
)
.into_response();
}
match state
.registry
.run(&request.benchmark_id, request.iterations)
{
Some(duration) => {
let count = state.iteration_count.fetch_add(1, Ordering::Relaxed) + 1;
if count % LOG_INTERVAL == 0 {
eprintln!("[harness] {} run calls completed", count);
}
(
StatusCode::OK,
Json(RunIterationResponse::success(request.iterations, duration)),
)
.into_response()
}
None => {
eprintln!("[harness] Benchmark '{}' not found", request.benchmark_id);
(
StatusCode::NOT_FOUND,
Json(RunIterationResponse::failure(format!(
"Benchmark '{}' not found",
request.benchmark_id
))),
)
.into_response()
}
}
}
/// Requests graceful shutdown of the server and acknowledges the request.
///
/// Gated by `check_claim`: a claimed harness only accepts shutdown from the
/// claim holder.
async fn shutdown(State(state): State<Arc<AppState>>, headers: HeaderMap) -> impl IntoResponse {
    match check_claim(&state, &headers).await {
        Ok(()) => {
            // `watch::Sender::send` only fails when no receiver is alive, in
            // which case there is nothing left to notify; the error is ignored.
            let _ = state.shutdown_tx.send(true);
            (StatusCode::OK, Json(ShutdownResponse::acknowledged())).into_response()
        }
        Err(rejection) => rejection,
    }
}
/// Returns at most the first eight characters of `nonce`, for log output.
///
/// It cuts on a `char` boundary, so it never panics on multi-byte UTF-8 input.
/// Byte-slicing `&nonce[..8]` would panic when byte 8 falls inside a character.
fn nonce_log_prefix(nonce: &str) -> &str {
    let end = nonce
        .char_indices()
        .nth(8)
        .map_or(nonce.len(), |(byte_idx, _)| byte_idx);
    &nonce[..end]
}

/// Claims the harness for the orchestrator identified by `request.nonce`.
///
/// - Unclaimed: stores the nonce and returns `200 OK`.
/// - Already claimed with the same nonce: treated as a refresh, `200 OK`.
/// - Claimed with a different nonce: `409 Conflict`; the existing claim is kept.
///
/// This endpoint does not go through `check_claim`.
async fn claim(
    State(state): State<Arc<AppState>>,
    Json(request): Json<ClaimRequest>,
) -> impl IntoResponse {
    let mut claim = state.claim.lock().await;
    match &*claim {
        Some(existing) if existing != &request.nonce => {
            eprintln!("[harness] Claim rejected - already claimed by another orchestrator");
            (StatusCode::CONFLICT, Json(ClaimResponse::already_claimed()))
        }
        Some(_) => {
            eprintln!("[harness] Claim refreshed (same nonce)");
            (StatusCode::OK, Json(ClaimResponse::success()))
        }
        None => {
            // The nonce comes from the request body (untrusted), so the
            // logged prefix must be cut on a char boundary.
            eprintln!(
                "[harness] Claimed by orchestrator (nonce: {}...)",
                nonce_log_prefix(&request.nonce)
            );
            *claim = Some(request.nonce);
            (StatusCode::OK, Json(ClaimResponse::success()))
        }
    }
}
/// Releases the claim when `request.nonce` matches the current holder.
///
/// On a match the claim is cleared and `200 OK` is returned. If the harness is
/// unclaimed or held under a different nonce, `400 Bad Request` is returned and
/// the claim is left untouched.
async fn release(
    State(state): State<Arc<AppState>>,
    Json(request): Json<ReleaseRequest>,
) -> impl IntoResponse {
    let mut holder = state.claim.lock().await;
    let is_holder = holder.as_deref() == Some(request.nonce.as_str());
    if is_holder {
        eprintln!("[harness] Released by orchestrator");
        *holder = None;
        (StatusCode::OK, Json(ReleaseResponse::success()))
    } else {
        eprintln!("[harness] Release rejected - wrong nonce or not claimed");
        (
            StatusCode::BAD_REQUEST,
            Json(ReleaseResponse { success: false }),
        )
    }
}
async fn check_claim(
state: &AppState,
headers: &HeaderMap,
) -> Result<(), axum::response::Response> {
let claim = state.claim.lock().await;
if let Some(expected_nonce) = &*claim {
match headers.get(CLAIM_HEADER) {
Some(value) => {
let provided = value.to_str().unwrap_or("");
if provided != expected_nonce {
return Err((
StatusCode::FORBIDDEN,
Json(serde_json::json!({
"error": "Invalid claim nonce"
})),
)
.into_response());
}
}
None => {
return Err((
StatusCode::FORBIDDEN,
Json(serde_json::json!({
"error": "Harness is claimed, X-Harness-Claim header required"
})),
)
.into_response());
}
}
}
Ok(())
}
/// Assembles the harness HTTP API over the shared `state`.
///
/// `/health` is a plain GET probe. `/benchmarks`, `/run` and `/shutdown` call
/// `check_claim` in their handlers. `/claim` and `/release` manage the claim itself.
fn build_router(state: Arc<AppState>) -> Router {
    let benchmark_routes = Router::new()
        .route("/health", get(health))
        .route("/benchmarks", get(list_benchmarks))
        .route("/run", post(run_iteration));
    let full_api = benchmark_routes
        .route("/shutdown", post(shutdown))
        .route("/claim", post(claim))
        .route("/release", post(release));
    full_api.with_state(state)
}
/// Blocking entry point: creates a tokio runtime and drives
/// [`run_harness_async`] on it to completion.
///
/// # Errors
/// Fails if the runtime cannot be created or if [`run_harness_async`] returns an error.
pub fn run_harness(registry: BenchmarkRegistry, port: u16) -> anyhow::Result<()> {
    let rt = tokio::runtime::Runtime::new()?;
    rt.block_on(run_harness_async(registry, port))
}
/// Serves the harness API on `0.0.0.0:port` until shutdown is requested.
///
/// Shutdown is graceful and begins once `/shutdown` publishes `true` on the
/// internal watch channel, or once that channel's sender is dropped.
///
/// # Errors
/// Fails if the listener cannot bind or the server returns an I/O error.
pub async fn run_harness_async(registry: BenchmarkRegistry, port: u16) -> anyhow::Result<()> {
    let (shutdown_tx, shutdown_rx) = watch::channel(false);
    let app = build_router(Arc::new(AppState {
        registry: Arc::new(registry),
        shutdown_tx,
        claim: Mutex::new(None),
        iteration_count: AtomicU64::new(0),
    }));

    let addr = format!("0.0.0.0:{port}");
    let listener = tokio::net::TcpListener::bind(&addr).await?;
    eprintln!("Benchmark harness listening on {}", addr);

    axum::serve(listener, app)
        .with_graceful_shutdown(wait_for_shutdown(shutdown_rx))
        .await?;
    Ok(())
}

/// Resolves once `true` is observed on `rx` or the sender side is dropped,
/// then logs that the harness is shutting down.
async fn wait_for_shutdown(mut rx: watch::Receiver<bool>) {
    loop {
        // Copy the flag out in its own statement so the borrow guard is
        // released before awaiting.
        let stop_requested = *rx.borrow();
        if stop_requested || rx.changed().await.is_err() {
            break;
        }
    }
    eprintln!("Shutting down benchmark harness");
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use std::time::Duration;
    use tower::ServiceExt;

    /// Nonce used by tests that pre-seed the claim directly in `AppState`.
    const TEST_NONCE: &str = "nonce-123";

    /// Builds state with one benchmark, `test_bench`, that always reports 42 ms.
    fn create_test_state() -> Arc<AppState> {
        let mut registry = BenchmarkRegistry::new();
        registry.register("test_bench", |_n| Duration::from_millis(42));
        let (shutdown_tx, _) = watch::channel(false);
        Arc::new(AppState {
            registry: Arc::new(registry),
            shutdown_tx,
            claim: Mutex::new(None),
            iteration_count: AtomicU64::new(0),
        })
    }

    /// Routes `request` through a fresh router over the shared `state`, so
    /// several requests in one test observe the same claim.
    async fn send(state: &Arc<AppState>, request: Request<Body>) -> axum::response::Response {
        build_router(Arc::clone(state))
            .oneshot(request)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_health_endpoint() {
        let state = create_test_state();
        let app = build_router(state);
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/health")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let health: HealthResponse = serde_json::from_slice(&body).unwrap();
        assert_eq!(health.status, "healthy");
    }

    #[tokio::test]
    async fn test_list_benchmarks_endpoint() {
        let state = create_test_state();
        let app = build_router(state);
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/benchmarks")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let benchmarks: BenchmarkListResponse = serde_json::from_slice(&body).unwrap();
        assert!(benchmarks.benchmarks.contains(&"test_bench".to_string()));
    }

    #[tokio::test]
    async fn test_run_iteration_success() {
        let state = create_test_state();
        let app = build_router(state);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/run")
                    .header("content-type", "application/json")
                    .body(Body::from(
                        r#"{"benchmark_id": "test_bench", "iterations": 7}"#,
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let result: RunIterationResponse = serde_json::from_slice(&body).unwrap();
        assert!(result.success);
        assert_eq!(result.iterations, 7);
        assert_eq!(result.duration_ns, 42_000_000);
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn test_run_iteration_zero_iterations_rejected() {
        let state = create_test_state();
        let app = build_router(state);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/run")
                    .header("content-type", "application/json")
                    .body(Body::from(
                        r#"{"benchmark_id": "test_bench", "iterations": 0}"#,
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_run_iteration_not_found() {
        let state = create_test_state();
        let app = build_router(state);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/run")
                    .header("content-type", "application/json")
                    .body(Body::from(
                        r#"{"benchmark_id": "nonexistent", "iterations": 1}"#,
                    ))
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let result: RunIterationResponse = serde_json::from_slice(&body).unwrap();
        assert!(!result.success);
        assert_eq!(result.duration_ns, 0);
        assert!(result.error.is_some());
        assert!(result.error.unwrap().contains("nonexistent"));
    }

    #[tokio::test]
    async fn test_shutdown_endpoint() {
        let state = create_test_state();
        let app = build_router(state);
        let response = app
            .oneshot(
                Request::builder()
                    .method("POST")
                    .uri("/shutdown")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let result: ShutdownResponse = serde_json::from_slice(&body).unwrap();
        assert_eq!(result.status, "shutting_down");
    }

    #[tokio::test]
    async fn test_claimed_harness_requires_claim_header() {
        let state = create_test_state();
        *state.claim.lock().await = Some(TEST_NONCE.to_string());
        let response = send(
            &state,
            Request::builder()
                .uri("/benchmarks")
                .body(Body::empty())
                .unwrap(),
        )
        .await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_claimed_harness_rejects_wrong_nonce() {
        let state = create_test_state();
        *state.claim.lock().await = Some(TEST_NONCE.to_string());
        let response = send(
            &state,
            Request::builder()
                .uri("/benchmarks")
                .header(CLAIM_HEADER, "some-other-nonce")
                .body(Body::empty())
                .unwrap(),
        )
        .await;
        assert_eq!(response.status(), StatusCode::FORBIDDEN);
    }

    #[tokio::test]
    async fn test_claimed_harness_accepts_matching_nonce() {
        let state = create_test_state();
        *state.claim.lock().await = Some(TEST_NONCE.to_string());
        let response = send(
            &state,
            Request::builder()
                .uri("/benchmarks")
                .header(CLAIM_HEADER, TEST_NONCE)
                .body(Body::empty())
                .unwrap(),
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_health_is_not_gated_by_claim() {
        let state = create_test_state();
        *state.claim.lock().await = Some(TEST_NONCE.to_string());
        let response = send(
            &state,
            Request::builder()
                .uri("/health")
                .body(Body::empty())
                .unwrap(),
        )
        .await;
        assert_eq!(response.status(), StatusCode::OK);
    }
}