use axum::{
extract::{Path, Query, State},
http::HeaderMap,
response::sse::{Event, KeepAlive, Sse},
Json,
};
use futures_util::stream::{self, Stream};
use mockforge_bench::ssrf::{validate_target_url, Policy as SsrfPolicy};
use mockforge_registry_core::models::test_run::EnqueueTestRun;
use serde::Deserialize;
use std::convert::Infallible;
use uuid::Uuid;
use crate::{
error::{ApiError, ApiResult},
handlers::usage::effective_limits,
middleware::{resolve_org_context, AuthUser},
models::{TestRun, TestSuite},
AppState,
};
/// Page size used by the list endpoints when the caller omits `limit`.
const DEFAULT_LIMIT: i64 = 50;
/// Hard ceiling for caller-supplied `limit` values (enforced via `clamp`).
const MAX_LIMIT: i64 = 500;
/// JSON body accepted by [`trigger_run`]. All fields are optional; an empty
/// body (`{}`) triggers a plain manual run.
#[derive(Debug, Deserialize)]
pub struct TriggerRunRequest {
    /// Trigger source; validated against `is_valid_trigger_source`
    /// ("manual" | "schedule" | "ci" | "webhook"). Defaults to "manual".
    #[serde(default)]
    pub triggered_by: Option<String>,
    /// Git ref (branch/tag) associated with the run, if any.
    #[serde(default)]
    pub git_ref: Option<String>,
    /// Git commit SHA associated with the run, if any.
    #[serde(default)]
    pub git_sha: Option<String>,
}
/// POST handler: enqueue a new test run for `suite_id`.
///
/// Flow: load the suite → resolve the caller's org context → verify the
/// suite's workspace belongs to that org → enforce plan concurrency limits →
/// validate request fields and the configured target URL (SSRF guard) →
/// persist the run row → push a job onto the Redis run queue (best-effort).
///
/// Cross-org access and a missing suite are both reported as
/// "Test suite not found" so callers cannot probe for suites in other orgs.
pub async fn trigger_run(
    State(state): State<AppState>,
    AuthUser(user_id): AuthUser,
    Path(suite_id): Path<Uuid>,
    headers: HeaderMap,
    Json(request): Json<TriggerRunRequest>,
) -> ApiResult<Json<TestRun>> {
    let suite = TestSuite::find_by_id(state.db.pool(), suite_id)
        .await
        .map_err(ApiError::Database)?
        .ok_or_else(|| ApiError::InvalidRequest("Test suite not found".into()))?;
    let ctx = resolve_org_context(&state, user_id, &headers, None)
        .await
        .map_err(|_| ApiError::InvalidRequest("Organization not found".into()))?;
    let workspace = mockforge_registry_core::models::CloudWorkspace::find_by_id(
        state.db.pool(),
        suite.workspace_id,
    )
    .await?
    .ok_or_else(|| ApiError::InvalidRequest("Workspace not found".into()))?;
    // Same "not found" message as a genuinely missing suite: do not leak
    // the existence of suites belonging to other orgs.
    if workspace.org_id != ctx.org_id {
        return Err(ApiError::InvalidRequest("Test suite not found".into()));
    }
    let limits = effective_limits(&state, &ctx.org).await?;
    // Missing limit defaults to 0, which disables test execution entirely.
    // A negative value skips both checks below, i.e. acts as unlimited —
    // NOTE(review): confirm that's how limits are provisioned upstream.
    let max_concurrent = limits.get("max_concurrent_runs").and_then(|v| v.as_i64()).unwrap_or(0);
    if max_concurrent == 0 {
        return Err(ApiError::ResourceLimitExceeded(
            "Test execution is not enabled on this plan — upgrade to Pro or Team to run tests"
                .into(),
        ));
    }
    // Positive limit: reject when the org already has that many runs in flight.
    if max_concurrent > 0 {
        let inflight = TestRun::count_inflight(state.db.pool(), ctx.org_id)
            .await
            .map_err(ApiError::Database)?;
        if inflight.total() >= max_concurrent {
            return Err(ApiError::ResourceLimitExceeded(format!(
                "Concurrent run limit reached ({}/{}). Wait for a run to finish or upgrade your plan.",
                inflight.total(),
                max_concurrent,
            )));
        }
    }
    let triggered_by = request.triggered_by.as_deref().unwrap_or("manual");
    if !is_valid_trigger_source(triggered_by) {
        return Err(ApiError::InvalidRequest(
            "triggered_by must be one of: manual, schedule, ci, webhook".into(),
        ));
    }
    // SSRF guard: if the suite config names a target URL, reject internal /
    // loopback / link-local destinations before any work is enqueued.
    if let Some(target_url) = extract_target_url(&suite.config) {
        validate_target_url(&target_url, ssrf_policy())
            .await
            .map_err(|e| ApiError::InvalidRequest(format!("target_url rejected: {}", e)))?;
    }
    let run = TestRun::enqueue(
        state.db.pool(),
        EnqueueTestRun {
            suite_id: suite.id,
            org_id: ctx.org_id,
            kind: &suite.kind,
            triggered_by,
            triggered_by_user: Some(user_id),
            git_ref: request.git_ref.as_deref(),
            git_sha: request.git_sha.as_deref(),
        },
    )
    .await
    .map_err(ApiError::Database)?;
    // Queue push is best-effort: the run row already exists in the DB, so a
    // Redis failure is only logged and the HTTP request still succeeds.
    if let Err(e) = crate::run_queue::enqueue(
        state.redis.as_ref(),
        crate::run_queue::EnqueuedJob {
            run_id: run.id,
            org_id: run.org_id,
            source_id: suite.id,
            kind: &suite.kind,
            payload: suite.config.clone(),
        },
    )
    .await
    {
        tracing::error!(run_id = %run.id, error = %e, "failed to enqueue test_run");
    }
    Ok(Json(run))
}
/// Query string accepted by [`list_suite_runs`].
#[derive(Debug, Deserialize)]
pub struct ListRunsQuery {
    /// Max rows to return; clamped to [1, MAX_LIMIT], defaults to DEFAULT_LIMIT.
    #[serde(default)]
    pub limit: Option<i64>,
}
pub async fn list_suite_runs(
State(state): State<AppState>,
AuthUser(user_id): AuthUser,
Path(suite_id): Path<Uuid>,
Query(query): Query<ListRunsQuery>,
headers: HeaderMap,
) -> ApiResult<Json<Vec<TestRun>>> {
let suite = TestSuite::find_by_id(state.db.pool(), suite_id)
.await
.map_err(ApiError::Database)?
.ok_or_else(|| ApiError::InvalidRequest("Test suite not found".into()))?;
let ctx = resolve_org_context(&state, user_id, &headers, None)
.await
.map_err(|_| ApiError::InvalidRequest("Organization not found".into()))?;
let workspace = mockforge_registry_core::models::CloudWorkspace::find_by_id(
state.db.pool(),
suite.workspace_id,
)
.await?
.ok_or_else(|| ApiError::InvalidRequest("Workspace not found".into()))?;
if workspace.org_id != ctx.org_id {
return Err(ApiError::InvalidRequest("Test suite not found".into()));
}
let limit = query.limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT);
let runs = TestRun::list_by_suite(state.db.pool(), suite.id, limit)
.await
.map_err(ApiError::Database)?;
Ok(Json(runs))
}
/// Query string accepted by [`list_org_runs`].
#[derive(Debug, Deserialize)]
pub struct ListOrgRunsQuery {
    /// Optional status filter passed through to the model query.
    #[serde(default)]
    pub status: Option<String>,
    /// Max rows to return; clamped to [1, MAX_LIMIT], defaults to DEFAULT_LIMIT.
    #[serde(default)]
    pub limit: Option<i64>,
}
pub async fn list_org_runs(
State(state): State<AppState>,
AuthUser(user_id): AuthUser,
Path(org_id): Path<Uuid>,
Query(query): Query<ListOrgRunsQuery>,
headers: HeaderMap,
) -> ApiResult<Json<Vec<TestRun>>> {
let ctx = resolve_org_context(&state, user_id, &headers, None)
.await
.map_err(|_| ApiError::InvalidRequest("Organization not found".into()))?;
if ctx.org_id != org_id {
return Err(ApiError::InvalidRequest("Cannot list runs for a different org".into()));
}
let limit = query.limit.unwrap_or(DEFAULT_LIMIT).clamp(1, MAX_LIMIT);
let runs = TestRun::list_by_org(state.db.pool(), org_id, query.status.as_deref(), limit)
.await
.map_err(ApiError::Database)?;
Ok(Json(runs))
}
pub async fn get_run(
State(state): State<AppState>,
AuthUser(user_id): AuthUser,
Path(id): Path<Uuid>,
headers: HeaderMap,
) -> ApiResult<Json<TestRun>> {
let run = load_authorized_run(&state, user_id, &headers, id).await?;
Ok(Json(run))
}
/// GET handler: stream a run's events over SSE.
///
/// Authorizes the caller first (the run row itself is discarded), then
/// unfolds [`advance_event_cursor`] starting from seq 0 so the client
/// receives the full event history followed by live updates.
pub async fn stream_run_events(
    State(state): State<AppState>,
    AuthUser(user_id): AuthUser,
    Path(id): Path<Uuid>,
    headers: HeaderMap,
) -> ApiResult<Sse<impl Stream<Item = Result<Event, Infallible>>>> {
    load_authorized_run(&state, user_id, &headers, id).await?;
    // Clone the pool handle so the stream can outlive this handler's scope.
    let initial = EventCursor {
        run_id: id,
        pool: state.db.pool().clone(),
        seq: 0,
        buffered: Vec::new(),
        terminal_emitted: false,
    };
    let events = stream::unfold(initial, advance_event_cursor);
    Ok(Sse::new(events).keep_alive(KeepAlive::default()))
}
async fn advance_event_cursor(
mut cursor: EventCursor,
) -> Option<(Result<Event, Infallible>, EventCursor)> {
if cursor.terminal_emitted {
return None;
}
if let Some(row) = cursor.buffered.pop() {
let payload = serde_json::json!({
"seq": row.seq,
"type": row.event_type,
"payload": row.payload,
"occurred_at": row.occurred_at,
});
let evt = Event::default().event(&row.event_type).data(payload.to_string());
return Some((Ok(evt), cursor));
}
tokio::time::sleep(std::time::Duration::from_millis(1000)).await;
let events: Vec<TestRunEventRow> = match sqlx::query_as::<_, TestRunEventRow>(
"SELECT seq, event_type, payload, occurred_at \
FROM test_run_events \
WHERE run_id = $1 AND seq > $2 \
ORDER BY seq ASC LIMIT 200",
)
.bind(cursor.run_id)
.bind(cursor.seq)
.fetch_all(&cursor.pool)
.await
{
Ok(rows) => rows,
Err(e) => {
let evt = Event::default()
.event("stream_error")
.data(serde_json::json!({ "error": e.to_string() }).to_string());
cursor.terminal_emitted = true;
return Some((Ok(evt), cursor));
}
};
for row in &events {
cursor.seq = row.seq.max(cursor.seq);
}
cursor.buffered = events.into_iter().rev().collect();
if let Some(row) = cursor.buffered.pop() {
let payload = serde_json::json!({
"seq": row.seq,
"type": row.event_type,
"payload": row.payload,
"occurred_at": row.occurred_at,
});
let evt = Event::default().event(&row.event_type).data(payload.to_string());
return Some((Ok(evt), cursor));
}
let terminal = matches!(
sqlx::query_as::<_, (String,)>("SELECT status FROM test_runs WHERE id = $1")
.bind(cursor.run_id)
.fetch_optional(&cursor.pool)
.await,
Ok(Some((ref s,))) if matches!(
s.as_str(),
"passed" | "failed" | "cancelled" | "errored"
)
);
if !terminal {
let evt = Event::default().event("ping").data("{}");
return Some((Ok(evt), cursor));
}
let final_payload = match sqlx::query_as::<_, (String, Option<i32>, Option<serde_json::Value>)>(
"SELECT status, runner_seconds, summary FROM test_runs WHERE id = $1",
)
.bind(cursor.run_id)
.fetch_optional(&cursor.pool)
.await
{
Ok(Some((status, runner_seconds, summary))) => serde_json::json!({
"status": status,
"runner_seconds": runner_seconds,
"summary": summary,
}),
_ => serde_json::json!({ "status": "unknown" }),
};
cursor.terminal_emitted = true;
let evt = Event::default().event("done").data(final_payload.to_string());
Some((Ok(evt), cursor))
}
/// Polling state threaded through `stream::unfold` by `advance_event_cursor`.
struct EventCursor {
    // Run whose events are being streamed.
    run_id: Uuid,
    // Owned pool handle so the stream can outlive the originating request.
    pool: sqlx::PgPool,
    // Highest event `seq` fetched so far; the poll query asks for `seq > $2`.
    seq: i32,
    // Rows fetched but not yet emitted, stored reversed so `pop()` yields
    // them in ascending `seq` order.
    buffered: Vec<TestRunEventRow>,
    // Set once the terminal `done`/`stream_error` event is sent; next step
    // returns `None` and ends the stream.
    terminal_emitted: bool,
}
/// One row of `test_run_events`, as selected by the polling query in
/// `advance_event_cursor`.
#[derive(sqlx::FromRow)]
struct TestRunEventRow {
    // Per-run sequence number; doubles as the stream cursor key.
    seq: i32,
    // Event type; also used as the SSE event name.
    event_type: String,
    // Arbitrary JSON payload attached to the event.
    payload: serde_json::Value,
    // Event timestamp (UTC).
    occurred_at: chrono::DateTime<chrono::Utc>,
}
/// POST handler: request cancellation of a run.
///
/// Authorizes the caller first; `TestRun::cancel` returning `None` means the
/// run was already terminal (or vanished), which is reported as an invalid
/// request rather than a success.
pub async fn cancel_run(
    State(state): State<AppState>,
    AuthUser(user_id): AuthUser,
    Path(id): Path<Uuid>,
    headers: HeaderMap,
) -> ApiResult<Json<TestRun>> {
    // Authorization only; the fetched run row itself is not needed.
    let _ = load_authorized_run(&state, user_id, &headers, id).await?;
    match TestRun::cancel(state.db.pool(), id).await.map_err(ApiError::Database)? {
        Some(run) => Ok(Json(run)),
        None => Err(ApiError::InvalidRequest(
            "Run is not cancellable (already terminal or not found)".into(),
        )),
    }
}
/// Fetch a run by id and verify it belongs to the caller's org.
///
/// A run in a different org is reported with the same "Test run not found"
/// error as a genuinely missing run, so callers cannot probe for runs they
/// don't own.
async fn load_authorized_run(
    state: &AppState,
    user_id: Uuid,
    headers: &HeaderMap,
    id: Uuid,
) -> ApiResult<TestRun> {
    // Capture-free closure, so it is Copy and usable in both places below.
    let not_found = || ApiError::InvalidRequest("Test run not found".into());
    let run = TestRun::find_by_id(state.db.pool(), id)
        .await
        .map_err(ApiError::Database)?
        .ok_or_else(not_found)?;
    let ctx = resolve_org_context(state, user_id, headers, None)
        .await
        .map_err(|_| ApiError::InvalidRequest("Organization not found".into()))?;
    if run.org_id == ctx.org_id {
        Ok(run)
    } else {
        Err(not_found())
    }
}
/// True when `s` is exactly one of the canonical (lowercase) trigger sources.
fn is_valid_trigger_source(s: &str) -> bool {
    const ALLOWED: [&str; 4] = ["manual", "schedule", "ci", "webhook"];
    ALLOWED.contains(&s)
}
fn extract_target_url(config: &serde_json::Value) -> Option<String> {
let raw = config.get("target_url")?.as_str()?.trim();
if raw.is_empty() {
None
} else {
Some(raw.to_string())
}
}
/// Choose the SSRF policy for validating target URLs.
///
/// Strict by default; the permissive test policy is selected only when the
/// `MOCKFORGE_SSRF_ALLOW_LOOPBACK` env var is exactly "1" or "true"
/// (lowercase).
fn ssrf_policy() -> SsrfPolicy {
    let allow_loopback = matches!(
        std::env::var("MOCKFORGE_SSRF_ALLOW_LOOPBACK").as_deref(),
        Ok("1" | "true")
    );
    if allow_loopback {
        SsrfPolicy::for_test()
    } else {
        SsrfPolicy::strict()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests that read/write the process-wide
    /// `MOCKFORGE_SSRF_ALLOW_LOOPBACK` environment variable. The default test
    /// harness runs tests on multiple threads, so without this lock those
    /// tests race on shared process state and fail intermittently (e.g. the
    /// "loose" test setting the var while a "strict" test reads it).
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquire the env-var lock, tolerating poisoning: a poisoned mutex only
    /// means another test panicked while holding it, and the guard is still
    /// valid for mutual exclusion.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn trigger_source_accepts_canonical_values() {
        assert!(is_valid_trigger_source("manual"));
        assert!(is_valid_trigger_source("schedule"));
        assert!(is_valid_trigger_source("ci"));
        assert!(is_valid_trigger_source("webhook"));
    }

    #[test]
    fn trigger_source_rejects_others() {
        assert!(!is_valid_trigger_source("MANUAL"));
        assert!(!is_valid_trigger_source(""));
        assert!(!is_valid_trigger_source("api"));
    }

    #[test]
    fn extract_target_url_pulls_string() {
        assert_eq!(
            extract_target_url(&json!({"target_url": "https://api.example.com"})),
            Some("https://api.example.com".to_string())
        );
        // Surrounding whitespace is trimmed.
        assert_eq!(
            extract_target_url(&json!({"target_url": " https://x.com "})),
            Some("https://x.com".to_string())
        );
    }

    #[test]
    fn extract_target_url_none_when_missing_blank_or_wrong_type() {
        assert!(extract_target_url(&json!({})).is_none());
        assert!(extract_target_url(&json!({"target_url": ""})).is_none());
        assert!(extract_target_url(&json!({"target_url": " "})).is_none());
        assert!(extract_target_url(&json!({"target_url": 42})).is_none());
        assert!(extract_target_url(&json!({"other": "https://x.com"})).is_none());
    }

    #[tokio::test]
    async fn ssrf_policy_blocks_loopback_target_in_strict_mode() {
        let _env = env_lock();
        std::env::remove_var("MOCKFORGE_SSRF_ALLOW_LOOPBACK");
        let policy = ssrf_policy();
        let err = validate_target_url("http://127.0.0.1/", policy).await.unwrap_err();
        assert!(err.to_string().contains("loopback"), "got: {err}");
    }

    #[tokio::test]
    async fn ssrf_policy_blocks_metadata_ip() {
        let _env = env_lock();
        std::env::remove_var("MOCKFORGE_SSRF_ALLOW_LOOPBACK");
        let policy = ssrf_policy();
        let err = validate_target_url("http://169.254.169.254/latest/meta-data/", policy)
            .await
            .unwrap_err();
        assert!(err.to_string().contains("link-local"), "got: {err}");
    }

    #[tokio::test]
    async fn ssrf_policy_loose_allows_loopback() {
        let _env = env_lock();
        std::env::set_var("MOCKFORGE_SSRF_ALLOW_LOOPBACK", "1");
        let policy = ssrf_policy();
        validate_target_url("http://127.0.0.1:8080/", policy).await.unwrap();
        std::env::remove_var("MOCKFORGE_SSRF_ALLOW_LOOPBACK");
    }
}