use axum::{
extract::{Path, State},
http::StatusCode,
response::{
sse::{Event, Sse},
Json,
},
routing::{delete, get, post},
Router,
};
use dashmap::DashMap;
use futures::stream::{self, Stream};
use mockforge_bench::conformance::{
ConformanceConfig, ConformanceProgress, NativeConformanceExecutor,
};
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::sync::Arc;
use tokio::sync::broadcast;
use tracing::{error, info};
use uuid::Uuid;
/// Lifecycle state of a conformance run.
///
/// Serialized with `snake_case` names (e.g. `"running"`). Variant order is kept stable on
/// purpose: index-based serde formats would encode variants by position.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum RunStatus {
/// Registered by `start_run`; the background task has not yet created the executor.
Pending,
/// The executor was created and checks are being executed.
Running,
/// Execution finished successfully and `ConformanceRun::report` is populated.
Completed,
/// Executor creation or execution returned an error; see `ConformanceRun::error`.
Failed,
}
/// A single conformance run, as stored in `ConformanceState::runs` and returned by
/// `GET /run/{id}`.
///
/// Field order determines the key order of the serialized JSON, so keep it stable.
#[derive(Debug, Clone, Serialize)]
pub struct ConformanceRun {
/// Run identifier, generated with `Uuid::new_v4` when the run is started.
pub id: Uuid,
/// Current lifecycle state of the run.
pub status: RunStatus,
/// The request that started this run, stored verbatim.
// NOTE(review): this includes `api_key` / `basic_auth`, so any caller who knows the run id
// gets the credentials back from `GET /run/{id}` — consider redacting on serialization.
pub config: ConformanceRunRequest,
/// JSON report from the executor; only set once the run completed. Omitted from JSON when absent.
#[serde(skip_serializing_if = "Option::is_none")]
pub report: Option<serde_json::Value>,
/// Human-readable failure reason; only set when the run failed. Omitted from JSON when absent.
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
/// Checks finished so far, updated from `CheckCompleted` progress events.
pub checks_done: usize,
/// Total checks reported by the executor once it has been created; `0` before that.
pub total_checks: usize,
}
/// JSON body accepted by `POST /run`.
///
/// Only `target_url` is required (it has no `#[serde(default)]`); every other field is
/// optional and falls back to the default noted on it when `start_run` builds the executor
/// configuration. Field order determines the key order when the request is echoed back
/// inside `ConformanceRun::config`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ConformanceRunRequest {
/// Base URL of the API under test. An empty string is rejected with `400 Bad Request`.
pub target_url: String,
/// NOTE(review): accepted but currently not forwarded to the executor by `start_run`.
#[serde(default)]
pub spec: Option<String>,
/// Forwarded unchanged as the executor's `categories` setting (presumably restricts which
/// check categories run — confirm against `ConformanceConfig`).
#[serde(default)]
pub categories: Option<Vec<String>>,
/// Extra `(name, value)` header pairs; forwarded as an empty list when absent.
#[serde(default)]
pub custom_headers: Option<Vec<(String, String)>>,
/// Forwarded unchanged as the executor's `api_key` setting.
#[serde(default)]
pub api_key: Option<String>,
/// Forwarded unchanged as the executor's `basic_auth` setting (format expected by the
/// executor is not visible here — presumably `user:password`; verify).
#[serde(default)]
pub basic_auth: Option<String>,
/// Disable TLS certificate verification; defaults to `false`.
#[serde(default)]
pub skip_tls_verify: Option<bool>,
/// Forwarded unchanged as the executor's `base_path` setting.
#[serde(default)]
pub base_path: Option<String>,
/// Forwarded as the executor's `all_operations` flag; defaults to `false`.
#[serde(default)]
pub all_operations: Option<bool>,
/// Delay between requests in milliseconds; defaults to `0`.
#[serde(default)]
pub request_delay_ms: Option<u64>,
/// NOTE(review): accepted but currently not forwarded (`start_run` passes
/// `custom_checks_file: None`).
#[serde(default)]
pub custom_checks_yaml: Option<String>,
}
/// Shared state behind the conformance API.
///
/// Both maps live behind `Arc`, so every clone handed to axum handlers or background tasks
/// observes the same runs and progress channels.
#[derive(Clone, Default)]
pub struct ConformanceState {
    /// Every known run, keyed by its id.
    pub runs: Arc<DashMap<Uuid, ConformanceRun>>,
    /// Broadcast senders used to fan progress events out to SSE subscribers, keyed by run id.
    /// An entry is inserted when a run starts and removed once its stream is no longer needed.
    pub progress_channels: Arc<DashMap<Uuid, broadcast::Sender<ConformanceProgress>>>,
}

impl ConformanceState {
    /// Creates state with no runs and no progress channels.
    ///
    /// Equivalent to [`ConformanceState::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
/// Builds the router for the conformance-run API.
///
/// Paths are relative to wherever this router is mounted:
///
/// | Method   | Path                | Handler           |
/// |----------|---------------------|-------------------|
/// | `GET`    | `/runs`             | `list_runs`       |
/// | `POST`   | `/run`              | `start_run`       |
/// | `GET`    | `/run/{id}`         | `get_run`         |
/// | `DELETE` | `/run/{id}`         | `delete_run`      |
/// | `GET`    | `/run/{id}/stream`  | `stream_progress` |
pub fn conformance_router(state: ConformanceState) -> Router {
    let routes: Router<ConformanceState> = Router::new()
        // Collection-level endpoints.
        .route("/runs", get(list_runs))
        .route("/run", post(start_run))
        // Per-run endpoints; axum merges the two method routers registered on `/run/{id}`.
        .route("/run/{id}", get(get_run))
        .route("/run/{id}", delete(delete_run))
        .route("/run/{id}/stream", get(stream_progress));
    routes.with_state(state)
}
async fn start_run(
State(state): State<ConformanceState>,
Json(req): Json<ConformanceRunRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
if req.target_url.is_empty() {
return Err(StatusCode::BAD_REQUEST);
}
let id = Uuid::new_v4();
let (tx, _) = broadcast::channel(256);
state.progress_channels.insert(id, tx.clone());
let run = ConformanceRun {
id,
status: RunStatus::Pending,
config: req.clone(),
report: None,
error: None,
checks_done: 0,
total_checks: 0,
};
state.runs.insert(id, run);
let runs = state.runs.clone();
let channels = state.progress_channels.clone();
tokio::spawn(async move {
let config = ConformanceConfig {
target_url: req.target_url.clone(),
api_key: req.api_key.clone(),
basic_auth: req.basic_auth.clone(),
skip_tls_verify: req.skip_tls_verify.unwrap_or(false),
categories: req.categories.clone(),
base_path: req.base_path.clone(),
custom_headers: req.custom_headers.clone().unwrap_or_default(),
output_dir: None,
all_operations: req.all_operations.unwrap_or(false),
custom_checks_file: None,
request_delay_ms: req.request_delay_ms.unwrap_or(0),
custom_filter: None,
export_requests: false,
validate_requests: false,
};
let executor = match NativeConformanceExecutor::new(config) {
Ok(e) => e,
Err(e) => {
if let Some(mut run) = runs.get_mut(&id) {
run.status = RunStatus::Failed;
run.error = Some(format!("Failed to create executor: {}", e));
}
let _ = tx.send(ConformanceProgress::Error {
message: e.to_string(),
});
return;
}
};
let executor = executor.with_reference_checks();
if let Some(mut run) = runs.get_mut(&id) {
run.status = RunStatus::Running;
run.total_checks = executor.check_count();
}
let (progress_tx, mut progress_rx) = tokio::sync::mpsc::channel(256);
let broadcast_tx = tx.clone();
let runs_for_progress = runs.clone();
tokio::spawn(async move {
while let Some(progress) = progress_rx.recv().await {
if let ConformanceProgress::CheckCompleted { checks_done, .. } = &progress {
if let Some(mut run) = runs_for_progress.get_mut(&id) {
run.checks_done = *checks_done;
}
}
let _ = broadcast_tx.send(progress);
}
});
match executor.execute_with_progress(progress_tx).await {
Ok(report) => {
let report_json = report.to_json();
if let Some(mut run) = runs.get_mut(&id) {
run.status = RunStatus::Completed;
run.report = Some(report_json);
run.checks_done = run.total_checks;
}
info!("Conformance run {} completed", id);
}
Err(e) => {
if let Some(mut run) = runs.get_mut(&id) {
run.status = RunStatus::Failed;
run.error = Some(format!("{}", e));
}
error!("Conformance run {} failed: {}", id, e);
}
}
tokio::time::sleep(std::time::Duration::from_secs(60)).await;
channels.remove(&id);
});
Ok(Json(serde_json::json!({ "id": id })))
}
/// `GET /run/{id}`: returns a snapshot of a single run.
///
/// Responds with `404 Not Found` when no run with that id is known.
async fn get_run(
    State(state): State<ConformanceState>,
    Path(id): Path<Uuid>,
) -> Result<Json<ConformanceRun>, StatusCode> {
    let Some(entry) = state.runs.get(&id) else {
        return Err(StatusCode::NOT_FOUND);
    };
    // Clone out of the map guard so the shard lock is released when `entry` drops.
    let snapshot = entry.value().clone();
    Ok(Json(snapshot))
}
/// `GET /runs`: returns a lightweight summary of every known run.
///
/// Each summary holds `id`, `status`, `checks_done`, `total_checks` and `target_url`. Reports,
/// errors and the rest of the request are left out. Order follows `DashMap` iteration order,
/// which is unspecified.
async fn list_runs(State(state): State<ConformanceState>) -> Json<Vec<serde_json::Value>> {
    let mut summaries = Vec::with_capacity(state.runs.len());
    for entry in state.runs.iter() {
        let run = entry.value();
        let summary = serde_json::json!({
            "id": run.id,
            "status": run.status,
            "checks_done": run.checks_done,
            "total_checks": run.total_checks,
            "target_url": run.config.target_url,
        });
        summaries.push(summary);
    }
    Json(summaries)
}
/// `DELETE /run/{id}`: forgets a finished run and drops its progress channel.
///
/// Runs that are still in progress ([`RunStatus::Pending`] or [`RunStatus::Running`]) cannot
/// be deleted. There is no cancellation, so the background task would keep executing against
/// the target and its status updates would have nowhere to go.
///
/// Returns:
/// - `204 No Content` when the run was removed,
/// - `409 Conflict` when the run is still in progress,
/// - `404 Not Found` when no run with that id exists.
async fn delete_run(State(state): State<ConformanceState>, Path(id): Path<Uuid>) -> StatusCode {
    // `remove_if` checks the status and removes the entry under the same shard lock. A
    // remove-then-reinsert would briefly hide an active run. The background task could then
    // miss it in `get_mut` and lose its final status update, and readers could see a
    // spurious 404.
    let removed = state.runs.remove_if(&id, |_, run| {
        !matches!(run.status, RunStatus::Pending | RunStatus::Running)
    });

    if removed.is_some() {
        state.progress_channels.remove(&id);
        StatusCode::NO_CONTENT
    } else if state.runs.contains_key(&id) {
        StatusCode::CONFLICT
    } else {
        StatusCode::NOT_FOUND
    }
}
async fn stream_progress(
State(state): State<ConformanceState>,
Path(id): Path<Uuid>,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
let rx = state
.progress_channels
.get(&id)
.map(|entry| entry.subscribe())
.ok_or(StatusCode::NOT_FOUND)?;
let stream = stream::unfold(rx, |mut rx| async move {
match rx.recv().await {
Ok(progress) => {
let data = serde_json::to_string(&progress).unwrap_or_default();
let event = Event::default().event("conformance_progress").data(data);
Some((Ok(event), rx))
}
Err(broadcast::error::RecvError::Closed) => None,
Err(broadcast::error::RecvError::Lagged(skipped)) => {
let event = Event::default().event("conformance_progress").data(format!(
r#"{{"type":"error","message":"lagged, skipped {} events"}}"#,
skipped
));
Some((Ok(event), rx))
}
}
});
Ok(Sse::new(stream).keep_alive(
axum::response::sse::KeepAlive::new()
.interval(std::time::Duration::from_secs(15))
.text("keep-alive"),
))
}