1use std::sync::Arc;
4use std::time::Duration;
5
6use tokio::spawn;
7use tokio::sync::Semaphore;
8use tokio::time::sleep;
9use tracing::{error, info, warn};
10
11#[cfg(feature = "prometheus")]
12use ironflow_core::metric_names::{WORKER_ACTIVE, WORKER_POLLS_TOTAL};
13use ironflow_core::provider::AgentProvider;
14use ironflow_engine::engine::Engine;
15use ironflow_engine::handler::WorkflowHandler;
16use ironflow_store::store::RunStore;
17#[cfg(feature = "prometheus")]
18use metrics::{counter, gauge};
19#[cfg(feature = "heartbeat")]
20use reqwest::Client;
21
22use crate::api_store::ApiRunStore;
23use crate::error::WorkerError;
24
/// Concurrency a `WorkerBuilder` starts with until `concurrency()` overrides it.
const DEFAULT_CONCURRENCY: usize = 2;
/// Poll interval a `WorkerBuilder` starts with until `poll_interval()` overrides it.
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Heartbeat interval a `WorkerBuilder` starts with until `heartbeat_interval()` overrides it.
#[cfg(feature = "heartbeat")]
const DEFAULT_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
29
/// Builder that collects the configuration needed to construct a [`Worker`].
///
/// Create one with [`WorkerBuilder::new`], adjust it through the chained
/// setters, and finish with [`WorkerBuilder::build`].
pub struct WorkerBuilder {
    /// API URL handed to `ApiRunStore::new` when the worker is built.
    api_url: String,
    /// Token handed to `ApiRunStore::new` alongside `api_url`.
    worker_token: String,
    /// Agent provider given to the engine; `build` fails while this is `None`.
    provider: Option<Arc<dyn AgentProvider>>,
    /// Workflow handlers registered on the engine during `build`.
    handlers: Vec<Box<dyn WorkflowHandler>>,
    /// Number of semaphore permits the worker creates in `run`.
    concurrency: usize,
    /// Base delay the worker sleeps between polls of the run store.
    poll_interval: Duration,
    /// URL the heartbeat task sends HTTP `HEAD` requests to; `None` means no task is spawned.
    #[cfg(feature = "heartbeat")]
    heartbeat_url: Option<String>,
    /// Period of the heartbeat ticker.
    #[cfg(feature = "heartbeat")]
    heartbeat_interval: Duration,
}
63
64impl WorkerBuilder {
65 pub fn new(api_url: &str, worker_token: &str) -> Self {
67 Self {
68 api_url: api_url.to_string(),
69 worker_token: worker_token.to_string(),
70 provider: None,
71 handlers: Vec::new(),
72 concurrency: DEFAULT_CONCURRENCY,
73 poll_interval: DEFAULT_POLL_INTERVAL,
74 #[cfg(feature = "heartbeat")]
75 heartbeat_url: None,
76 #[cfg(feature = "heartbeat")]
77 heartbeat_interval: DEFAULT_HEARTBEAT_INTERVAL,
78 }
79 }
80
81 pub fn provider(mut self, provider: Arc<dyn AgentProvider>) -> Self {
83 self.provider = Some(provider);
84 self
85 }
86
87 pub fn register(mut self, handler: impl WorkflowHandler + 'static) -> Self {
89 self.handlers.push(Box::new(handler));
90 self
91 }
92
93 pub fn concurrency(mut self, n: usize) -> Self {
95 self.concurrency = n;
96 self
97 }
98
99 pub fn poll_interval(mut self, interval: Duration) -> Self {
101 self.poll_interval = interval;
102 self
103 }
104
105 #[cfg(feature = "heartbeat")]
124 pub fn heartbeat_url(mut self, url: &str) -> Self {
125 self.heartbeat_url = Some(url.to_string());
126 self
127 }
128
129 #[cfg(feature = "heartbeat")]
147 pub fn heartbeat_interval(mut self, interval: Duration) -> Self {
148 self.heartbeat_interval = interval;
149 self
150 }
151
152 pub fn build(self) -> Result<Worker, WorkerError> {
159 let provider = self
160 .provider
161 .ok_or_else(|| WorkerError::Internal("WorkerBuilder: provider is required".into()))?;
162
163 let store: Arc<dyn RunStore> =
164 Arc::new(ApiRunStore::new(&self.api_url, &self.worker_token));
165
166 let mut engine = Engine::new(store, provider);
167 for handler in self.handlers {
168 engine
169 .register_boxed(handler)
170 .map_err(WorkerError::Engine)?;
171 }
172
173 #[cfg(feature = "heartbeat")]
174 let heartbeat_client = Client::builder()
175 .timeout(Duration::from_secs(5))
176 .build()
177 .expect("failed to build heartbeat HTTP client");
178
179 Ok(Worker {
180 engine: Arc::new(engine),
181 concurrency: self.concurrency,
182 poll_interval: self.poll_interval,
183 #[cfg(feature = "heartbeat")]
184 heartbeat_url: self.heartbeat_url,
185 #[cfg(feature = "heartbeat")]
186 heartbeat_interval: self.heartbeat_interval,
187 #[cfg(feature = "heartbeat")]
188 heartbeat_client,
189 })
190 }
191}
192
/// A polling worker produced by [`WorkerBuilder::build`].
///
/// Call [`Worker::run`] to start polling the engine's store for pending runs.
pub struct Worker {
    /// Engine used both to reach the run store and to execute picked-up runs.
    engine: Arc<Engine>,
    /// Number of permits in the semaphore created by `run`.
    concurrency: usize,
    /// Base delay between polls of the run store.
    poll_interval: Duration,
    /// URL the heartbeat task sends HTTP `HEAD` requests to; `None` means no task is spawned.
    #[cfg(feature = "heartbeat")]
    heartbeat_url: Option<String>,
    /// Period of the heartbeat ticker.
    #[cfg(feature = "heartbeat")]
    heartbeat_interval: Duration,
    /// HTTP client used for the heartbeat requests (built with a 5 second timeout).
    #[cfg(feature = "heartbeat")]
    heartbeat_client: Client,
}
205
206impl Worker {
207 pub async fn run(&self) -> Result<(), WorkerError> {
213 let semaphore = Arc::new(Semaphore::new(self.concurrency));
214 let mut idle_streak = 0u32;
215
216 info!(
217 concurrency = self.concurrency,
218 poll_interval_ms = self.poll_interval.as_millis() as u64,
219 "worker started"
220 );
221
222 #[cfg(feature = "heartbeat")]
223 if let Some(ref url) = self.heartbeat_url {
224 let interval = self.heartbeat_interval;
225 let url = url.clone();
226 let client = self.heartbeat_client.clone();
227
228 spawn(async move {
229 let mut ticker = tokio::time::interval(interval);
230 ticker.tick().await;
232 loop {
233 ticker.tick().await;
234 match client.head(&url).send().await {
235 Ok(resp) if resp.status().is_success() => {
236 info!(url = %url, "heartbeat sent");
237 }
238 Ok(resp) => {
239 warn!(
240 url = %url,
241 status = %resp.status(),
242 "heartbeat ping returned non-success status"
243 );
244 }
245 Err(err) => {
246 warn!(
247 url = %url,
248 error = %err,
249 "heartbeat ping failed"
250 );
251 }
252 }
253 }
254 });
255 }
256
257 loop {
258 let run = self.engine.store().pick_next_pending().await;
259
260 match run {
261 Ok(Some(run)) => {
262 #[cfg(feature = "prometheus")]
263 counter!(WORKER_POLLS_TOTAL, "result" => "hit").increment(1);
264
265 let permit = semaphore
266 .clone()
267 .acquire_owned()
268 .await
269 .map_err(|_| WorkerError::Internal("semaphore closed".to_string()))?;
270
271 idle_streak = 0;
272 let engine = self.engine.clone();
273 let run_id = run.id;
274 let workflow = run.workflow_name.clone();
275
276 info!(run_id = %run_id, workflow = %workflow, "executing run");
277
278 #[cfg(feature = "prometheus")]
279 gauge!(WORKER_ACTIVE).increment(1.0);
280
281 let handle = spawn(async move {
282 let _permit = permit;
283 match engine.execute_handler_run(run_id).await {
284 Ok(_) => {
285 info!(run_id = %run_id, workflow = %workflow, "run completed");
286 }
287 Err(e) => {
288 error!(run_id = %run_id, workflow = %workflow, error = %e, "run failed");
289 }
290 }
291 #[cfg(feature = "prometheus")]
292 gauge!(WORKER_ACTIVE).decrement(1.0);
293 });
294
295 let store = self.engine.store().clone();
297 spawn(async move {
298 if let Err(e) = handle.await {
299 error!(run_id = %run_id, "spawned task panicked: {e}");
300 if let Err(store_err) = store
301 .update_run_status(
302 run_id,
303 ironflow_store::entities::RunStatus::Failed,
304 )
305 .await
306 {
307 error!(run_id = %run_id, error = %store_err, "failed to mark panicked run as failed");
308 }
309 }
310 });
311 }
312 Ok(None) => {
313 #[cfg(feature = "prometheus")]
314 counter!(WORKER_POLLS_TOTAL, "result" => "miss").increment(1);
315
316 idle_streak += 1;
317 let backoff = if idle_streak > 10 {
318 self.poll_interval * 3
319 } else if idle_streak > 5 {
320 self.poll_interval * 2
321 } else {
322 self.poll_interval
323 };
324 sleep(backoff).await;
325 }
326 Err(e) => {
327 warn!(error = %e, "poll error");
328 sleep(self.poll_interval).await;
329 }
330 }
331 }
332 }
333}
334
// Unit tests for `WorkerBuilder` configuration and `build()`.
// None of these tests call `Worker::run`, so no polling or network I/O happens here.
#[cfg(test)]
mod tests {
    use super::*;
    use ironflow_core::providers::claude::ClaudeCodeProvider;

    // A fresh builder stores the given URL/token and uses the default settings.
    #[test]
    fn builder_new_creates_default_config() {
        let builder = WorkerBuilder::new("http://localhost:3000", "my-token");
        assert_eq!(builder.api_url, "http://localhost:3000");
        assert_eq!(builder.worker_token, "my-token");
        assert_eq!(builder.concurrency, DEFAULT_CONCURRENCY);
        assert_eq!(builder.poll_interval, DEFAULT_POLL_INTERVAL);
        assert!(builder.provider.is_none());
    }

    // NOTE(review): despite the name, this asserts that the builder keeps the
    // trailing slash verbatim; any normalization presumably happens elsewhere
    // (e.g. in `ApiRunStore`) — confirm.
    #[test]
    fn builder_with_trailing_slash_normalized() {
        let builder = WorkerBuilder::new("http://localhost:3000/", "token");
        assert_eq!(builder.api_url, "http://localhost:3000/");
    }

    #[test]
    fn builder_provider_sets_provider() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder =
            WorkerBuilder::new("http://localhost:3000", "token").provider(provider.clone());
        assert!(builder.provider.is_some());
    }

    #[test]
    fn builder_concurrency_sets_concurrency() {
        let builder = WorkerBuilder::new("http://localhost:3000", "token").concurrency(8);
        assert_eq!(builder.concurrency, 8);
    }

    // Only checks that the builder stores zero; the worker is never built or run here.
    #[test]
    fn builder_concurrency_zero_accepted() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .concurrency(0);
        assert_eq!(builder.concurrency, 0);
    }

    #[test]
    fn builder_poll_interval_sets_interval() {
        let interval = Duration::from_secs(5);
        let builder = WorkerBuilder::new("http://localhost:3000", "token").poll_interval(interval);
        assert_eq!(builder.poll_interval, interval);
    }

    // `build()` must report a missing provider as `WorkerError::Internal`.
    #[test]
    fn builder_build_without_provider_fails() {
        let builder = WorkerBuilder::new("http://localhost:3000", "token");
        let result = builder.build();
        assert!(result.is_err());
        match result {
            Err(WorkerError::Internal(msg)) => {
                assert!(msg.contains("provider is required"));
            }
            _ => panic!("expected Internal error about missing provider"),
        }
    }

    #[test]
    fn builder_build_with_provider_succeeds() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("http://localhost:3000", "token").provider(provider);
        let result = builder.build();
        assert!(result.is_ok());
    }

    // The next two tests check that `build()` carries settings over to the `Worker`.
    #[test]
    fn builder_build_creates_worker_with_correct_concurrency() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .concurrency(16);
        let worker = builder.build().unwrap();
        assert_eq!(worker.concurrency, 16);
    }

    #[test]
    fn builder_build_creates_worker_with_correct_interval() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let interval = Duration::from_secs(10);
        let builder = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .poll_interval(interval);
        let worker = builder.build().unwrap();
        assert_eq!(worker.poll_interval, interval);
    }

    #[test]
    fn builder_chaining_works() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let result = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .concurrency(4)
            .poll_interval(Duration::from_secs(3))
            .build();
        assert!(result.is_ok());
        let worker = result.unwrap();
        assert_eq!(worker.concurrency, 4);
        assert_eq!(worker.poll_interval, Duration::from_secs(3));
    }

    // `build()` succeeds with an empty API URL.
    #[test]
    fn builder_empty_api_url_accepted() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("", "token").provider(provider);
        let result = builder.build();
        assert!(result.is_ok());
    }

    // `build()` succeeds with an empty worker token.
    #[test]
    fn builder_empty_token_accepted() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("http://localhost:3000", "").provider(provider);
        let result = builder.build();
        assert!(result.is_ok());
    }

    // Heartbeat-specific tests; compiled only with the `heartbeat` feature.
    #[cfg(feature = "heartbeat")]
    #[test]
    fn builder_heartbeat_defaults() {
        let builder = WorkerBuilder::new("http://localhost:3000", "token");
        assert!(builder.heartbeat_url.is_none());
        assert_eq!(builder.heartbeat_interval, DEFAULT_HEARTBEAT_INTERVAL);
    }

    #[cfg(feature = "heartbeat")]
    #[test]
    fn builder_heartbeat_url_sets_url() {
        let builder = WorkerBuilder::new("http://localhost:3000", "token")
            .heartbeat_url("https://uptime.betterstack.com/api/v1/heartbeat/abc");
        assert_eq!(
            builder.heartbeat_url.as_deref(),
            Some("https://uptime.betterstack.com/api/v1/heartbeat/abc")
        );
    }

    #[cfg(feature = "heartbeat")]
    #[test]
    fn builder_heartbeat_custom_interval() {
        let interval = Duration::from_secs(10);
        let builder =
            WorkerBuilder::new("http://localhost:3000", "token").heartbeat_interval(interval);
        assert_eq!(builder.heartbeat_interval, interval);
    }

    // Heartbeat settings must survive the builder -> worker hand-off.
    #[cfg(feature = "heartbeat")]
    #[test]
    fn builder_build_preserves_heartbeat_config() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let interval = Duration::from_secs(15);
        let worker = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .heartbeat_url("https://example.com/heartbeat")
            .heartbeat_interval(interval)
            .build()
            .unwrap();
        assert_eq!(
            worker.heartbeat_url.as_deref(),
            Some("https://example.com/heartbeat")
        );
        assert_eq!(worker.heartbeat_interval, interval);
    }

    #[cfg(feature = "heartbeat")]
    #[test]
    fn builder_build_without_heartbeat_url_has_none() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let worker = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .build()
            .unwrap();
        assert!(worker.heartbeat_url.is_none());
    }
}