use axum::{
extract::{Query, State},
http::{header::AUTHORIZATION, HeaderMap, StatusCode},
Json,
};
use mlua_swarm::{CapToken, ContentRef, OutputEvent, TaskId, WorkerPayload};
use serde::Deserialize;
use serde_json::Value;
use crate::{ApiError, AppState};
/// Query-string parameters accepted by [`worker_prompt`].
#[derive(Debug, Deserialize)]
pub struct PromptQuery {
/// Identifier of the task whose payload is requested; the handler wraps it in a `TaskId`.
pub task_id: String,
}
/// Serves the worker payload for the task named by the `task_id` query parameter.
///
/// The `Authorization: Bearer …` credential is interpreted in one of two ways:
///
/// * a worker handle (as recognised by `parse_worker_handle`): the engine resolves the
///   handle to a task, that task must equal the requested one, and the payload is then
///   read through `fetch_worker_payload_trusted`;
/// * anything else: decoded as a [`CapToken`] and passed to `fetch_worker_payload`
///   together with the requested task id.
///
/// # Errors
///
/// Returns `ApiError::bad_request` when the header is missing or malformed, when the
/// token cannot be decoded, or when a handle resolves to a different task. Engine
/// failures are mapped through `ApiError::engine`, prefixed with the failing call's name.
pub async fn worker_prompt(
    State(state): State<AppState>,
    headers: HeaderMap,
    Query(q): Query<PromptQuery>,
) -> Result<Json<WorkerPayload>, ApiError> {
    // `q` is owned and never used again, so the id is moved rather than cloned.
    let task_id = TaskId(q.task_id);
    let bearer = extract_bearer_raw(&headers)?;

    let payload = match parse_worker_handle(&bearer) {
        Some(handle) => {
            let bound = state
                .engine
                .task_id_from_handle(handle)
                .await
                .map_err(|e| ApiError::engine(format!("task_id_from_handle: {e}")))?;

            // A handle only grants access to the task it resolves to.
            if bound != task_id {
                return Err(ApiError::bad_request(format!(
                    "handle {handle} is bound to task {}, not {}",
                    bound.0, task_id.0
                )));
            }

            state
                .engine
                .fetch_worker_payload_trusted(&task_id)
                .await
                .map_err(|e| ApiError::engine(format!("fetch_worker_payload_trusted: {e}")))?
        }
        None => {
            // `extract_bearer_raw` already returns a trimmed credential, so no second trim.
            let token = CapToken::decode(bearer.as_str())
                .map_err(|e| ApiError::bad_request(format!("invalid token: {e}")))?;

            state
                .engine
                .fetch_worker_payload(&token, &task_id)
                .await
                .map_err(|e| ApiError::engine(format!("fetch_worker_payload: {e}")))?
        }
    };

    Ok(Json(payload))
}
/// JSON request body accepted by [`worker_result`].
#[derive(Debug, Deserialize)]
pub struct WorkerResultReq {
/// Identifier of the task the result belongs to; the handler wraps it in a `TaskId`.
pub task_id: String,
/// Result value; the handler sends it inline in the final output event and again to `post_result`.
pub value: Value,
/// Success flag carried on the final output event; deserializes as `true` when the field is absent.
#[serde(default = "default_ok_true")]
pub ok: bool,
/// Attempt number; when absent the handler asks the engine for the task's attempt instead.
#[serde(default)]
pub attempt: Option<u32>,
}
/// Serde default for `WorkerResultReq::ok`: a request that omits `ok` deserializes as `true`.
fn default_ok_true() -> bool {
    true
}
pub async fn worker_result(
State(state): State<AppState>,
headers: HeaderMap,
Json(req): Json<WorkerResultReq>,
) -> Result<StatusCode, ApiError> {
let token = decode_worker_bearer(&headers)?;
let task_id = TaskId(req.task_id);
let attempt = match req.attempt {
Some(n) => n,
None => state
.engine
.task_attempt(&task_id)
.await
.map_err(|e| ApiError::engine(format!("task_attempt: {e}")))?,
};
let event = OutputEvent::Final {
content: ContentRef::Inline {
value: req.value.clone(),
},
ok: req.ok,
};
state
.engine
.submit_output(&token, &task_id, attempt, event)
.await
.map_err(|e| ApiError::engine(format!("submit_output: {e}")))?;
state
.engine
.post_result(&token, &task_id, req.value)
.await
.map_err(|e| ApiError::engine(format!("post_result: {e}")))?;
Ok(StatusCode::NO_CONTENT)
}
/// Query-string parameters accepted by [`worker_submit`].
#[derive(Debug, Deserialize, Default)]
pub struct SubmitQuery {
/// Optional success flag; the handler treats a missing value as `true`.
#[serde(default)]
pub ok: Option<bool>,
}
/// Accepts a raw request body as a worker's result, identifying the task from the credential.
///
/// The bearer credential is either a worker handle (resolved with `task_id_from_handle`)
/// or an encoded [`CapToken`] (resolved with `task_id_from_token`). The body is decoded
/// as UTF-8 lossily, stripped of trailing whitespace, and submitted as a JSON string
/// through `submit_worker_result_trusted` together with the task's current attempt.
/// The `ok` query parameter defaults to `true`. Responds with `204 No Content`.
///
/// # Errors
///
/// Returns `ApiError::bad_request` for a missing, malformed or undecodable credential,
/// and `ApiError::engine` (prefixed with the failing call's name) for engine failures.
pub async fn worker_submit(
    State(state): State<AppState>,
    headers: HeaderMap,
    Query(q): Query<SubmitQuery>,
    body: axum::body::Bytes,
) -> Result<StatusCode, ApiError> {
    let credential = extract_bearer_raw(&headers)?;
    let engine = &state.engine;

    let task_id = match parse_worker_handle(&credential) {
        Some(handle) => engine
            .task_id_from_handle(handle)
            .await
            .map_err(|e| ApiError::engine(format!("task_id_from_handle: {e}")))?,
        None => {
            let token = CapToken::decode(credential.trim())
                .map_err(|e| ApiError::bad_request(format!("invalid token: {e}")))?;
            engine
                .task_id_from_token(&token)
                .await
                .map_err(|e| ApiError::engine(format!("task_id_from_token: {e}")))?
        }
    };

    let attempt = engine
        .task_attempt(&task_id)
        .await
        .map_err(|e| ApiError::engine(format!("task_attempt: {e}")))?;

    // Invalid UTF-8 is replaced rather than rejected; only trailing whitespace is dropped.
    let text = String::from_utf8_lossy(&body);
    let value = Value::String(text.trim_end().to_string());

    engine
        .submit_worker_result_trusted(&task_id, attempt, value, q.ok.unwrap_or(true))
        .await
        .map_err(|e| ApiError::engine(format!("submit_worker_result_trusted: {e}")))?;

    Ok(StatusCode::NO_CONTENT)
}
/// Extracts the credential from an `Authorization: Bearer <credential>` header.
///
/// The scheme name is compared case-insensitively (HTTP auth schemes are
/// case-insensitive), so `bearer abc` and `BEARER abc` are accepted alongside
/// `Bearer abc`. The scheme must be followed by a single space; surrounding whitespace
/// on the credential is trimmed. The credential itself is returned undecoded.
///
/// # Errors
///
/// Returns `ApiError::bad_request` when the header is absent, is not valid visible
/// ASCII text, does not use the `Bearer` scheme, or carries an empty credential.
fn extract_bearer_raw(headers: &HeaderMap) -> Result<String, ApiError> {
    let raw = headers
        .get(AUTHORIZATION)
        .ok_or_else(|| ApiError::bad_request("missing Authorization header".into()))?
        .to_str()
        .map_err(|_| ApiError::bad_request("invalid Authorization header encoding".into()))?;

    // Split at the first space so only the scheme token is matched case-insensitively.
    let credential = match raw.split_once(' ') {
        Some((scheme, rest)) if scheme.eq_ignore_ascii_case("Bearer") => rest.trim(),
        _ => {
            return Err(ApiError::bad_request(
                "Authorization must be 'Bearer <token>'".into(),
            ))
        }
    };

    if credential.is_empty() {
        return Err(ApiError::bad_request("Bearer is empty".into()));
    }
    Ok(credential.to_string())
}
/// Recognises a worker handle and returns it with surrounding whitespace removed.
///
/// A handle is the literal prefix `wh-` followed only by ASCII letters and digits, with a
/// total length between 5 and 64 bytes inclusive. Anything else yields `None`.
fn parse_worker_handle(s: &str) -> Option<&str> {
    let candidate = s.trim();
    let suffix = candidate.strip_prefix("wh-")?;

    let length_ok = (5..=64).contains(&candidate.len());
    let charset_ok = suffix.bytes().all(|b| b.is_ascii_alphanumeric());

    (length_ok && charset_ok).then_some(candidate)
}
/// Reads the `Authorization: Bearer <token>` header and decodes the token as a [`CapToken`].
///
/// The scheme name is compared case-insensitively (HTTP auth schemes are
/// case-insensitive), so `bearer <token>` is accepted alongside `Bearer <token>`. The
/// scheme must be followed by a single space; surrounding whitespace on the token is
/// trimmed before decoding.
///
/// # Errors
///
/// Returns `ApiError::bad_request` when the header is absent, is not valid visible
/// ASCII text, does not use the `Bearer` scheme, carries an empty token, or when
/// `CapToken::decode` rejects the token.
fn decode_worker_bearer(headers: &HeaderMap) -> Result<CapToken, ApiError> {
    let raw = headers
        .get(AUTHORIZATION)
        .ok_or_else(|| ApiError::bad_request("missing Authorization header".into()))?
        .to_str()
        .map_err(|_| ApiError::bad_request("invalid Authorization header encoding".into()))?;

    // Split at the first space so only the scheme token is matched case-insensitively.
    let encoded = match raw.split_once(' ') {
        Some((scheme, rest)) if scheme.eq_ignore_ascii_case("Bearer") => rest.trim(),
        _ => {
            return Err(ApiError::bad_request(
                "Authorization must be 'Bearer <token>'".into(),
            ))
        }
    };

    if encoded.is_empty() {
        return Err(ApiError::bad_request("Bearer token is empty".into()));
    }
    CapToken::decode(encoded).map_err(|e| ApiError::bad_request(format!("invalid token: {e}")))
}