use axum::{
extract::{Multipart, State},
http::{header, HeaderMap, StatusCode},
response::IntoResponse,
routing::post,
Json, Router,
};
use base64::{engine::general_purpose::STANDARD as B64, Engine};
use chrono::Utc;
use serde_json::json;
use sha2::{Digest, Sha256};
use dragoon_proto::{
models::{
LogChunk, SignedTask, WorkerFinish, WorkerInitRequest, WorkerInitResponse,
WorkerPollRequest, WorkerPollResponse, WorkerStatus,
},
pubkey::parse_pubkey_blob,
task_sig::sign_task,
};
use crate::{
app::AppState, messages_repo, server_keys, storage, tasks_repo, users_repo, workers_repo,
};
/// Builds the worker-facing HTTP API: registration, polling, log upload,
/// completion reporting and artifact blob upload, all as `POST` routes under
/// `/v1/worker/`, bound to the shared application state.
pub fn router(state: AppState) -> Router {
    let worker_api: Router<AppState> = Router::new()
        .route("/v1/worker/init", post(init))
        .route("/v1/worker/poll", post(poll))
        .route("/v1/worker/log", post(log_chunk))
        .route("/v1/worker/finish", post(finish))
        .route("/v1/worker/blob", post(blob));
    worker_api.with_state(state)
}
/// Resolves the calling worker from an `Authorization: Bearer <token>` header.
///
/// Returns `401 UNAUTHORIZED` when the header is missing, is not visible
/// ASCII, lacks the exact `Bearer ` prefix, or the token matches no worker.
/// Returns `500 INTERNAL_SERVER_ERROR` when the token lookup itself fails.
/// Briefly takes the shared connection lock for the lookup.
fn worker_from_bearer(
    state: &AppState,
    headers: &HeaderMap,
) -> Result<workers_repo::Worker, StatusCode> {
    let Some(token) = headers
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|raw| raw.strip_prefix("Bearer "))
    else {
        return Err(StatusCode::UNAUTHORIZED);
    };
    let conn = state.conn.lock().unwrap();
    let found = workers_repo::lookup_by_token(&conn, token);
    match found {
        Ok(Some(worker)) => Ok(worker),
        Ok(None) => Err(StatusCode::UNAUTHORIZED),
        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
    }
}
/// `POST /v1/worker/init`: completes worker registration.
///
/// If `client_pubkey_b64` is present it must be valid base64 of a public-key
/// blob accepted by `parse_pubkey_blob`, otherwise `422`. The worker name and
/// register code are then exchanged for a worker token via
/// `workers_repo::finalize_register`; any failure there yields `401`. The
/// response carries the token, the worker id, the server's public key (base64)
/// and its fingerprint. Key-store failures yield `500`.
async fn init(
    State(state): State<AppState>,
    Json(req): Json<WorkerInitRequest>,
) -> Result<Json<WorkerInitResponse>, StatusCode> {
    let client_pub = match &req.client_pubkey_b64 {
        Some(encoded) => {
            let blob = B64
                .decode(encoded.as_bytes())
                .map_err(|_| StatusCode::UNPROCESSABLE_ENTITY)?;
            // Parsed only to validate; the raw blob is what gets stored.
            parse_pubkey_blob(&blob).map_err(|_| StatusCode::UNPROCESSABLE_ENTITY)?;
            Some(blob)
        }
        None => None,
    };

    let conn = state.conn.lock().unwrap();
    let (worker_token, worker_id) = workers_repo::finalize_register(
        &conn,
        &req.name,
        &req.register_code,
        client_pub.as_deref(),
    )
    .map_err(|_| StatusCode::UNAUTHORIZED)?;
    let server_pub =
        server_keys::get_public_blob(&conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
    let server_key_fingerprint =
        server_keys::fingerprint(&conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    Ok(Json(WorkerInitResponse {
        worker_token,
        worker_id,
        server_pubkey_b64: B64.encode(&server_pub),
        server_key_fingerprint,
    }))
}
async fn poll(
State(state): State<AppState>,
headers: HeaderMap,
Json(req): Json<WorkerPollRequest>,
) -> Result<Json<WorkerPollResponse>, StatusCode> {
let w = worker_from_bearer(&state, &headers)?;
let conn = state.conn.lock().unwrap();
workers_repo::update_status(
&conn,
&w.name,
match req.status {
WorkerStatus::Idle => "IDLE",
WorkerStatus::Busy => "BUSY",
},
None,
req.current_task_id.as_deref(),
)
.ok();
if let Some(current) = &req.current_task_id {
match tasks_repo::get_task(&conn, current).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? {
None => {
return Ok(Json(WorkerPollResponse {
task: None,
cancel_task_id: None,
}));
}
Some(_) => {
let cancel = tasks_repo::consume_cancel_signal(&conn, current)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
return Ok(Json(WorkerPollResponse {
task: None,
cancel_task_id: if cancel { Some(current.clone()) } else { None },
}));
}
}
}
if !matches!(req.status, WorkerStatus::Idle) {
return Ok(Json(WorkerPollResponse {
task: None,
cancel_task_id: None,
}));
}
let nxt = tasks_repo::next_queued_for_worker(&conn, &w.name)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let Some(nxt) = nxt else {
return Ok(Json(WorkerPollResponse {
task: None,
cancel_task_id: None,
}));
};
let started = Utc::now();
let t = tasks_repo::transition(
&conn,
&nxt.task_id,
dragoon_proto::models::TaskState::Running,
tasks_repo::TransitionUpdate {
started_at: Some(started),
..Default::default()
},
)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let priv_key = server_keys::get_private(&conn).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let (signature_b64, server_key_fingerprint) = sign_task(&priv_key, &t);
Ok(Json(WorkerPollResponse {
task: Some(SignedTask {
task: t,
signature_b64,
server_key_fingerprint,
}),
cancel_task_id: None,
}))
}
/// `POST /v1/worker/log`: appends a chunk of task output to the task's log.
///
/// The caller must own the task: `404` if it is unknown, `403` if it is
/// assigned to another worker, `500` if the lookup fails. A non-empty
/// `data_b64` must be valid base64 (`422` otherwise). Decoded bytes are handed
/// to `storage::append_log` for the chunk's stream. Empty chunks never touch
/// storage. Storage errors are discarded, so once validation passes the
/// worker always receives `{"ok": true}`.
async fn log_chunk(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(req): Json<LogChunk>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    let worker = worker_from_bearer(&state, &headers)?;

    // Ownership check in its own scope so the connection lock is released
    // before decoding and storage work.
    {
        let conn = state.conn.lock().unwrap();
        let lookup = tasks_repo::get_task(&conn, &req.task_id);
        match lookup {
            Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
            Ok(None) => return Err(StatusCode::NOT_FOUND),
            Ok(Some(task)) if task.worker_name != worker.name => {
                return Err(StatusCode::FORBIDDEN)
            }
            Ok(Some(_)) => {}
        }
    }

    if !req.data_b64.is_empty() {
        let bytes = B64
            .decode(req.data_b64.as_bytes())
            .map_err(|_| StatusCode::UNPROCESSABLE_ENTITY)?;
        if !bytes.is_empty() {
            let _ =
                storage::append_log(&state.settings.blobs_dir(), &req.task_id, req.stream, &bytes);
        }
    }

    Ok(Json(json!({"ok": true})))
}
async fn finish(
State(state): State<AppState>,
headers: HeaderMap,
Json(req): Json<WorkerFinish>,
) -> Result<Json<serde_json::Value>, StatusCode> {
use dragoon_proto::models::TaskState;
let w = worker_from_bearer(&state, &headers)?;
let conn = state.conn.lock().unwrap();
let t = tasks_repo::get_task(&conn, &req.task_id)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
.ok_or(StatusCode::NOT_FOUND)?;
if t.worker_name != w.name {
return Err(StatusCode::FORBIDDEN);
}
let new_state = if req.error.as_deref() == Some("timeout") {
TaskState::Timeout
} else if tasks_repo::consume_cancel_signal(&conn, &req.task_id)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
{
TaskState::Cancelled
} else if req.exit_code == Some(0) && req.error.is_none() {
TaskState::Completed
} else {
TaskState::Failed
};
if matches!(new_state, TaskState::Cancelled) && t.state == TaskState::Running {
let _ = tasks_repo::transition(
&conn,
&req.task_id,
TaskState::Cancelling,
tasks_repo::TransitionUpdate::default(),
);
}
let new_t = tasks_repo::transition(
&conn,
&req.task_id,
new_state,
tasks_repo::TransitionUpdate {
finished_at: Some(Utc::now()),
exit_code: req.exit_code,
final_pwd: req.final_pwd.clone(),
error: req.error.clone(),
..Default::default()
},
)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
workers_repo::update_status(
&conn,
&w.name,
"IDLE",
req.final_pwd.as_deref(),
None,
)
.ok();
for art in &req.artifacts {
let exists: Option<i64> = conn
.query_row(
"SELECT id FROM artifacts WHERE task_id=? AND path=? AND sha256=?",
rusqlite::params![req.task_id, art.path, art.sha256],
|r| r.get(0),
)
.ok();
if exists.is_none() {
tasks_repo::add_artifact(&conn, &req.task_id, art, "")
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
}
}
if let Some(submitter_id) = users_repo::lookup_user_id_by_name(&conn, &t.submitter)
.ok()
.flatten()
{
let kind = if req.error.is_none() {
"task_state"
} else {
"worker_error"
};
let payload = json!({
"state": match new_t.state {
TaskState::Completed => "COMPLETED",
TaskState::Failed => "FAILED",
TaskState::Timeout => "TIMEOUT",
TaskState::Cancelled => "CANCELLED",
TaskState::Cancelling => "CANCELLING",
TaskState::Running => "RUNNING",
TaskState::Queued => "QUEUED",
},
"exit_code": req.exit_code,
"final_pwd": req.final_pwd,
"task_payload": t.payload,
"worker_seq": t.worker_seq,
});
let _ = messages_repo::enqueue(
&conn,
submitter_id,
kind,
Some(&req.task_id),
Some(&w.name),
req.error.as_deref(),
&payload,
);
}
Ok(Json(json!({"ok": true, "state": match new_t.state {
TaskState::Completed => "COMPLETED",
TaskState::Failed => "FAILED",
TaskState::Timeout => "TIMEOUT",
TaskState::Cancelled => "CANCELLED",
TaskState::Cancelling => "CANCELLING",
TaskState::Running => "RUNNING",
TaskState::Queued => "QUEUED",
}})))
}
async fn blob(
State(state): State<AppState>,
headers: HeaderMap,
mut multipart: Multipart,
) -> impl IntoResponse {
let w = match worker_from_bearer(&state, &headers) {
Ok(w) => w,
Err(s) => return (s, "unauthenticated").into_response(),
};
let mut task_id: Option<String> = None;
let mut path: Option<String> = None;
let mut sha: Option<String> = None;
let mut file_bytes: Option<Vec<u8>> = None;
while let Ok(Some(field)) = multipart.next_field().await {
let name = field.name().unwrap_or("").to_string();
match name.as_str() {
"task_id" => task_id = field.text().await.ok(),
"path" => path = field.text().await.ok(),
"sha256" => sha = field.text().await.ok(),
"file" => {
let data = field.bytes().await.unwrap_or_default();
file_bytes = Some(data.to_vec());
}
_ => {}
}
}
let (Some(task_id), Some(path), Some(sha), Some(data)) = (task_id, path, sha, file_bytes)
else {
return (StatusCode::UNPROCESSABLE_ENTITY, "missing fields").into_response();
};
{
let conn = state.conn.lock().unwrap();
let Some(t) = tasks_repo::get_task(&conn, &task_id).ok().flatten() else {
return (StatusCode::NOT_FOUND, "unknown task").into_response();
};
if t.worker_name != w.name {
return (StatusCode::FORBIDDEN, "not owned").into_response();
}
}
let actual_sha = hex::encode(Sha256::digest(&data));
if actual_sha.to_lowercase() != sha.to_lowercase() {
return (StatusCode::UNPROCESSABLE_ENTITY, "sha256 mismatch").into_response();
}
if storage::store_artifact(&state.settings.blobs_dir(), &task_id, &path, &data).is_err() {
return (StatusCode::INTERNAL_SERVER_ERROR, "store failed").into_response();
}
let conn = state.conn.lock().unwrap();
let exists: Option<i64> = conn
.query_row(
"SELECT id FROM artifacts WHERE task_id=? AND path=? AND sha256=?",
rusqlite::params![task_id, path, actual_sha],
|r| r.get(0),
)
.ok();
if exists.is_none() {
let art = dragoon_proto::models::Artifact {
path: path.clone(),
size: data.len() as u64,
sha256: actual_sha.clone(),
};
let _ = tasks_repo::add_artifact(&conn, &task_id, &art, &path);
}
Json(json!({"ok": true, "size": data.len()})).into_response()
}