1use std::sync::Arc;
2
3use axum::extract::{Path, State};
4use axum::routing::post;
5use axum::{Json, Router};
6use serde::{Deserialize, Serialize};
7use utoipa::ToSchema;
8
9use crate::api::workflows::AppError;
10use crate::api::AppState;
11use crate::store::WorkflowStore;
12use crate::types::WorkflowWorker;
13
14pub fn router<S: WorkflowStore + 'static>() -> Router<Arc<AppState<S>>> {
15 Router::new()
16 .route("/workers/register", post(register_worker))
17 .route("/workers/heartbeat", post(worker_heartbeat))
18 .route("/tasks/poll", post(poll_task))
19 .route("/tasks/{id}/complete", post(complete_task))
20 .route("/tasks/{id}/fail", post(fail_task))
21 .route("/tasks/{id}/heartbeat", post(heartbeat_task))
22}
23
/// JSON body for `POST /workers/register`.
#[derive(Deserialize, ToSchema)]
pub struct RegisterWorkerRequest {
    /// Namespace to register the worker in; defaults to `"main"` when omitted.
    #[serde(default = "default_namespace")]
    pub namespace: String,
    /// Caller-chosen identity string recorded on the worker.
    pub identity: String,
    /// Task queue recorded as the worker's `task_queue`.
    pub queue: String,
    /// Optional list of workflows advertised by the worker; stored JSON-encoded.
    pub workflows: Option<Vec<String>>,
    /// Optional list of activities advertised by the worker; stored JSON-encoded.
    pub activities: Option<Vec<String>>,
    /// Workflow concurrency limit recorded on the worker; defaults to 10 when omitted.
    #[serde(default = "default_concurrent")]
    pub max_concurrent_workflows: i32,
    /// Activity concurrency limit recorded on the worker; defaults to 10 when omitted.
    #[serde(default = "default_concurrent")]
    pub max_concurrent_activities: i32,
}
42
/// Serde default for `RegisterWorkerRequest::namespace`: the `"main"` namespace.
fn default_namespace() -> String {
    String::from("main")
}
46
/// Serde default for both `max_concurrent_*` fields of `RegisterWorkerRequest`.
fn default_concurrent() -> i32 {
    const DEFAULT_LIMIT: i32 = 10;
    DEFAULT_LIMIT
}
50
/// Response for `POST /workers/register`.
#[derive(Serialize, ToSchema)]
pub struct RegisterWorkerResponse {
    /// Server-generated worker ID (`w-` followed by 16 hex digits); the same
    /// value the heartbeat and poll requests take as `worker_id`.
    pub worker_id: String,
}
56
57#[utoipa::path(
58 post, path = "/api/v1/workers/register",
59 tag = "tasks",
60 request_body = RegisterWorkerRequest,
61 responses(
62 (status = 200, description = "Worker registered", body = RegisterWorkerResponse),
63 ),
64)]
65pub async fn register_worker<S: WorkflowStore>(
66 State(state): State<Arc<AppState<S>>>,
67 Json(req): Json<RegisterWorkerRequest>,
68) -> Result<Json<RegisterWorkerResponse>, AppError> {
69 let now = timestamp_now();
70 let worker_id = format!("w-{}", &uuid_short());
71
72 let worker = WorkflowWorker {
73 id: worker_id.clone(),
74 namespace: req.namespace,
75 identity: req.identity,
76 task_queue: req.queue,
77 workflows: req.workflows.map(|v| serde_json::to_string(&v).unwrap()),
78 activities: req.activities.map(|v| serde_json::to_string(&v).unwrap()),
79 max_concurrent_workflows: req.max_concurrent_workflows,
80 max_concurrent_activities: req.max_concurrent_activities,
81 active_tasks: 0,
82 last_heartbeat: now,
83 registered_at: now,
84 };
85
86 state.engine.register_worker(&worker).await?;
87 Ok(Json(RegisterWorkerResponse { worker_id }))
88}
89
/// JSON body for `POST /workers/heartbeat`.
#[derive(Deserialize, ToSchema)]
pub struct HeartbeatRequest {
    /// ID of the worker sending the heartbeat (as returned by `/workers/register`).
    pub worker_id: String,
}
94
95#[utoipa::path(
96 post, path = "/api/v1/workers/heartbeat",
97 tag = "tasks",
98 responses((status = 200, description = "Heartbeat recorded")),
99)]
100pub async fn worker_heartbeat<S: WorkflowStore>(
101 State(state): State<Arc<AppState<S>>>,
102 Json(req): Json<HeartbeatRequest>,
103) -> Result<axum::http::StatusCode, AppError> {
104 state.engine.heartbeat_worker(&req.worker_id).await?;
105 Ok(axum::http::StatusCode::OK)
106}
107
/// JSON body for `POST /tasks/poll`.
#[derive(Deserialize, ToSchema)]
pub struct PollRequest {
    /// Task queue to claim an activity task from.
    pub queue: String,
    /// ID of the worker claiming the task.
    pub worker_id: String,
}
115
116#[utoipa::path(
117 post, path = "/api/v1/tasks/poll",
118 tag = "tasks",
119 request_body = PollRequest,
120 responses(
121 (status = 200, description = "Activity task (or null if none available)", body = WorkflowActivity),
122 ),
123)]
124pub async fn poll_task<S: WorkflowStore>(
125 State(state): State<Arc<AppState<S>>>,
126 Json(req): Json<PollRequest>,
127) -> Result<Json<serde_json::Value>, AppError> {
128 let activity = state
129 .engine
130 .claim_activity(&req.queue, &req.worker_id)
131 .await?;
132
133 match activity {
134 Some(act) => Ok(Json(serde_json::to_value(act)?)),
135 None => Ok(Json(serde_json::json!({ "task": null }))),
136 }
137}
138
/// JSON body for `POST /tasks/{id}/complete`.
#[derive(Deserialize, ToSchema)]
pub struct CompleteTaskBody {
    /// Optional result value; re-encoded as a JSON string before being stored.
    pub result: Option<serde_json::Value>,
}
144
145#[utoipa::path(
146 post, path = "/api/v1/tasks/{id}/complete",
147 tag = "tasks",
148 params(("id" = i64, Path, description = "Activity task ID")),
149 request_body = CompleteTaskBody,
150 responses((status = 200, description = "Task completed")),
151)]
152pub async fn complete_task<S: WorkflowStore>(
153 State(state): State<Arc<AppState<S>>>,
154 Path(id): Path<i64>,
155 Json(body): Json<CompleteTaskBody>,
156) -> Result<axum::http::StatusCode, AppError> {
157 let result = body.result.map(|v| v.to_string());
158 state
159 .engine
160 .complete_activity(id, result.as_deref(), None, false)
161 .await?;
162 Ok(axum::http::StatusCode::OK)
163}
164
/// JSON body for `POST /tasks/{id}/fail`.
#[derive(Deserialize, ToSchema)]
pub struct FailTaskBody {
    /// Error message passed to the engine when marking the task failed.
    pub error: String,
}
170
171#[utoipa::path(
172 post, path = "/api/v1/tasks/{id}/fail",
173 tag = "tasks",
174 params(("id" = i64, Path, description = "Activity task ID")),
175 request_body = FailTaskBody,
176 responses((status = 200, description = "Task marked as failed")),
177)]
178pub async fn fail_task<S: WorkflowStore>(
179 State(state): State<Arc<AppState<S>>>,
180 Path(id): Path<i64>,
181 Json(body): Json<FailTaskBody>,
182) -> Result<axum::http::StatusCode, AppError> {
183 state.engine.fail_activity(id, &body.error).await?;
187 Ok(axum::http::StatusCode::OK)
188}
189
/// JSON body for `POST /tasks/{id}/heartbeat`.
#[derive(Deserialize, ToSchema)]
pub struct HeartbeatTaskBody {
    /// Optional free-form details forwarded to the engine with the heartbeat.
    pub details: Option<String>,
}
194
195#[utoipa::path(
196 post, path = "/api/v1/tasks/{id}/heartbeat",
197 tag = "tasks",
198 params(("id" = i64, Path, description = "Activity task ID")),
199 responses((status = 200, description = "Heartbeat recorded")),
200)]
201pub async fn heartbeat_task<S: WorkflowStore>(
202 State(state): State<Arc<AppState<S>>>,
203 Path(id): Path<i64>,
204 Json(body): Json<HeartbeatTaskBody>,
205) -> Result<axum::http::StatusCode, AppError> {
206 state
207 .engine
208 .heartbeat_activity(id, body.details.as_deref())
209 .await?;
210 Ok(axum::http::StatusCode::OK)
211}
212
/// Current wall-clock time as fractional seconds since the UNIX epoch.
///
/// # Panics
/// Panics if the system clock reports a time before the UNIX epoch.
fn timestamp_now() -> f64 {
    let since_epoch = std::time::UNIX_EPOCH.elapsed().unwrap();
    since_epoch.as_secs_f64()
}
219
/// Generates a short, 16-hex-digit identifier used for worker IDs.
///
/// The original derived the ID only from the current time and thread id, hashed
/// with a fixed-key `DefaultHasher`. Two calls on one thread within a single clock
/// tick, or two replicas with the same thread id at the same instant, produced
/// identical IDs. This version also mixes in:
///
/// * a process-wide monotonically increasing sequence number (distinct inputs
///   for every call in this process), and
/// * the process id plus a randomly seeded hasher (variation across processes
///   and restarts).
fn uuid_short() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};

    // Only uniqueness of the fetched value matters, so `Relaxed` suffices.
    static SEQUENCE: AtomicU64 = AtomicU64::new(0);
    let seq = SEQUENCE.fetch_add(1, Ordering::Relaxed);

    let mut h = RandomState::new().build_hasher();
    std::time::SystemTime::now().hash(&mut h);
    std::thread::current().id().hash(&mut h);
    std::process::id().hash(&mut h);
    seq.hash(&mut h);
    format!("{:016x}", h.finish())
}
228
229use crate::types::WorkflowActivity;