Skip to main content

ironflow_worker/
worker.rs

1//! Worker — polls the API for pending runs and executes them.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use tokio::spawn;
7use tokio::sync::Semaphore;
8use tokio::time::sleep;
9use tracing::{error, info, warn};
10
11#[cfg(feature = "prometheus")]
12use ironflow_core::metric_names::{WORKER_ACTIVE, WORKER_POLLS_TOTAL};
13use ironflow_core::provider::AgentProvider;
14use ironflow_engine::engine::Engine;
15use ironflow_engine::handler::WorkflowHandler;
16use ironflow_store::store::RunStore;
17#[cfg(feature = "prometheus")]
18use metrics::{counter, gauge};
19#[cfg(feature = "heartbeat")]
20use reqwest::Client;
21
22use crate::api_store::ApiRunStore;
23use crate::error::WorkerError;
24
/// Concurrency limit a [`WorkerBuilder`] starts with when
/// [`WorkerBuilder::concurrency`] is not called.
const DEFAULT_CONCURRENCY: usize = 2;
/// Poll interval a [`WorkerBuilder`] starts with when
/// [`WorkerBuilder::poll_interval`] is not called.
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_secs(2);
/// Heartbeat interval a [`WorkerBuilder`] starts with when
/// `WorkerBuilder::heartbeat_interval` is not called.
#[cfg(feature = "heartbeat")]
const DEFAULT_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(30);
29
/// Builder for configuring and creating a [`Worker`].
///
/// Start with [`WorkerBuilder::new`], chain the setters you need, and finish
/// with [`WorkerBuilder::build`]. A provider is mandatory; every other
/// setting has a default.
///
/// # Examples
///
/// ```no_run
/// use std::sync::Arc;
/// use std::time::Duration;
/// use ironflow_worker::WorkerBuilder;
/// use ironflow_core::providers::claude::ClaudeCodeProvider;
///
/// # async fn example() -> Result<(), ironflow_worker::WorkerError> {
/// let worker = WorkerBuilder::new("http://localhost:3000", "my-token")
///     .provider(Arc::new(ClaudeCodeProvider::new()))
///     .concurrency(4)
///     .poll_interval(Duration::from_secs(2))
///     .build()?;
///
/// worker.run().await?;
/// # Ok(())
/// # }
/// ```
pub struct WorkerBuilder {
    /// API server URL, stored exactly as given and handed to `ApiRunStore::new` in `build`.
    api_url: String,
    /// Worker token handed to `ApiRunStore::new` together with `api_url`.
    worker_token: String,
    /// Agent provider; `build` fails while this is still `None`.
    provider: Option<Arc<dyn AgentProvider>>,
    /// Workflow handlers registered on the engine, in insertion order, during `build`.
    handlers: Vec<Box<dyn WorkflowHandler>>,
    /// Maximum number of concurrent workflow executions.
    concurrency: usize,
    /// Interval between polls for new runs.
    poll_interval: Duration,
    /// Optional URL pinged as a heartbeat; `None` means no heartbeat is emitted.
    #[cfg(feature = "heartbeat")]
    heartbeat_url: Option<String>,
    /// Period between heartbeat pings.
    #[cfg(feature = "heartbeat")]
    heartbeat_interval: Duration,
}
63
64impl WorkerBuilder {
65    /// Create a new builder targeting the given API server.
66    pub fn new(api_url: &str, worker_token: &str) -> Self {
67        Self {
68            api_url: api_url.to_string(),
69            worker_token: worker_token.to_string(),
70            provider: None,
71            handlers: Vec::new(),
72            concurrency: DEFAULT_CONCURRENCY,
73            poll_interval: DEFAULT_POLL_INTERVAL,
74            #[cfg(feature = "heartbeat")]
75            heartbeat_url: None,
76            #[cfg(feature = "heartbeat")]
77            heartbeat_interval: DEFAULT_HEARTBEAT_INTERVAL,
78        }
79    }
80
81    /// Set the agent provider for AI operations.
82    pub fn provider(mut self, provider: Arc<dyn AgentProvider>) -> Self {
83        self.provider = Some(provider);
84        self
85    }
86
87    /// Register a workflow handler.
88    pub fn register(mut self, handler: impl WorkflowHandler + 'static) -> Self {
89        self.handlers.push(Box::new(handler));
90        self
91    }
92
93    /// Set the maximum number of concurrent workflow executions.
94    pub fn concurrency(mut self, n: usize) -> Self {
95        self.concurrency = n;
96        self
97    }
98
99    /// Set the interval between polls for new runs.
100    pub fn poll_interval(mut self, interval: Duration) -> Self {
101        self.poll_interval = interval;
102        self
103    }
104
105    /// Set the heartbeat URL (dead man's switch).
106    ///
107    /// The worker pings this URL at every heartbeat interval with an HTTP
108    /// HEAD request. Compatible with BetterStack Heartbeats, Cronitor,
109    /// Healthchecks.io, or any dead man's switch service.
110    ///
111    /// If not set, no heartbeat is emitted even when the feature is enabled.
112    ///
113    /// # Examples
114    ///
115    /// ```no_run
116    /// use ironflow_worker::WorkerBuilder;
117    ///
118    /// # fn example() {
119    /// let builder = WorkerBuilder::new("http://localhost:3000", "token")
120    ///     .heartbeat_url("https://uptime.betterstack.com/api/v1/heartbeat/abc123");
121    /// # }
122    /// ```
123    #[cfg(feature = "heartbeat")]
124    pub fn heartbeat_url(mut self, url: &str) -> Self {
125        self.heartbeat_url = Some(url.to_string());
126        self
127    }
128
129    /// Set the heartbeat interval.
130    ///
131    /// Controls how often the worker pings the [`heartbeat_url`](Self::heartbeat_url).
132    /// Defaults to 30 seconds.
133    ///
134    /// # Examples
135    ///
136    /// ```no_run
137    /// use std::time::Duration;
138    /// use ironflow_worker::WorkerBuilder;
139    ///
140    /// # fn example() {
141    /// let builder = WorkerBuilder::new("http://localhost:3000", "token")
142    ///     .heartbeat_url("https://uptime.betterstack.com/api/v1/heartbeat/abc123")
143    ///     .heartbeat_interval(Duration::from_secs(60));
144    /// # }
145    /// ```
146    #[cfg(feature = "heartbeat")]
147    pub fn heartbeat_interval(mut self, interval: Duration) -> Self {
148        self.heartbeat_interval = interval;
149        self
150    }
151
152    /// Build the worker.
153    ///
154    /// # Errors
155    ///
156    /// Returns [`WorkerError::Internal`] if no provider has been set.
157    /// Returns [`WorkerError::Engine`] if a handler registration fails.
158    pub fn build(self) -> Result<Worker, WorkerError> {
159        let provider = self
160            .provider
161            .ok_or_else(|| WorkerError::Internal("WorkerBuilder: provider is required".into()))?;
162
163        let store: Arc<dyn RunStore> =
164            Arc::new(ApiRunStore::new(&self.api_url, &self.worker_token));
165
166        let mut engine = Engine::new(store, provider);
167        for handler in self.handlers {
168            engine
169                .register_boxed(handler)
170                .map_err(WorkerError::Engine)?;
171        }
172
173        #[cfg(feature = "heartbeat")]
174        let heartbeat_client = Client::builder()
175            .timeout(Duration::from_secs(5))
176            .build()
177            .expect("failed to build heartbeat HTTP client");
178
179        Ok(Worker {
180            engine: Arc::new(engine),
181            concurrency: self.concurrency,
182            poll_interval: self.poll_interval,
183            #[cfg(feature = "heartbeat")]
184            heartbeat_url: self.heartbeat_url,
185            #[cfg(feature = "heartbeat")]
186            heartbeat_interval: self.heartbeat_interval,
187            #[cfg(feature = "heartbeat")]
188            heartbeat_client,
189        })
190    }
191}
192
/// Background worker that polls the API and executes workflows.
///
/// Constructed through [`WorkerBuilder::build`]; started with [`Worker::run`].
pub struct Worker {
    /// Engine used to execute runs; shared (via `Arc`) with every spawned run task.
    engine: Arc<Engine>,
    /// Number of permits in the semaphore that bounds concurrent run executions.
    concurrency: usize,
    /// Base delay between polls; multiplied after consecutive empty polls.
    poll_interval: Duration,
    /// Optional heartbeat URL; when `None` no heartbeat task is started.
    #[cfg(feature = "heartbeat")]
    heartbeat_url: Option<String>,
    /// Period between heartbeat pings.
    #[cfg(feature = "heartbeat")]
    heartbeat_interval: Duration,
    /// HTTP client used to send the heartbeat HEAD requests.
    #[cfg(feature = "heartbeat")]
    heartbeat_client: Client,
}
205
206impl Worker {
207    /// Run the worker loop. Blocks until an error occurs or the process exits.
208    ///
209    /// # Errors
210    ///
211    /// Returns [`WorkerError`] if the polling loop encounters an unrecoverable error.
212    pub async fn run(&self) -> Result<(), WorkerError> {
213        let semaphore = Arc::new(Semaphore::new(self.concurrency));
214        let mut idle_streak = 0u32;
215
216        info!(
217            concurrency = self.concurrency,
218            poll_interval_ms = self.poll_interval.as_millis() as u64,
219            "worker started"
220        );
221
222        #[cfg(feature = "heartbeat")]
223        if let Some(ref url) = self.heartbeat_url {
224            let interval = self.heartbeat_interval;
225            let url = url.clone();
226            let client = self.heartbeat_client.clone();
227
228            spawn(async move {
229                let mut ticker = tokio::time::interval(interval);
230                // skip the first immediate tick
231                ticker.tick().await;
232                loop {
233                    ticker.tick().await;
234                    match client.head(&url).send().await {
235                        Ok(resp) if resp.status().is_success() => {
236                            info!(url = %url, "heartbeat sent");
237                        }
238                        Ok(resp) => {
239                            warn!(
240                                url = %url,
241                                status = %resp.status(),
242                                "heartbeat ping returned non-success status"
243                            );
244                        }
245                        Err(err) => {
246                            warn!(
247                                url = %url,
248                                error = %err,
249                                "heartbeat ping failed"
250                            );
251                        }
252                    }
253                }
254            });
255        }
256
257        loop {
258            let run = self.engine.store().pick_next_pending().await;
259
260            match run {
261                Ok(Some(run)) => {
262                    #[cfg(feature = "prometheus")]
263                    counter!(WORKER_POLLS_TOTAL, "result" => "hit").increment(1);
264
265                    let permit = semaphore
266                        .clone()
267                        .acquire_owned()
268                        .await
269                        .map_err(|_| WorkerError::Internal("semaphore closed".to_string()))?;
270
271                    idle_streak = 0;
272                    let engine = self.engine.clone();
273                    let run_id = run.id;
274                    let workflow = run.workflow_name.clone();
275
276                    info!(run_id = %run_id, workflow = %workflow, "executing run");
277
278                    #[cfg(feature = "prometheus")]
279                    gauge!(WORKER_ACTIVE).increment(1.0);
280
281                    let handle = spawn(async move {
282                        let _permit = permit;
283                        match engine.execute_handler_run(run_id).await {
284                            Ok(_) => {
285                                info!(run_id = %run_id, workflow = %workflow, "run completed");
286                            }
287                            Err(e) => {
288                                error!(run_id = %run_id, workflow = %workflow, error = %e, "run failed");
289                            }
290                        }
291                        #[cfg(feature = "prometheus")]
292                        gauge!(WORKER_ACTIVE).decrement(1.0);
293                    });
294
295                    // Spawn a watcher to catch panics and mark the run as failed
296                    let store = self.engine.store().clone();
297                    spawn(async move {
298                        if let Err(e) = handle.await {
299                            error!(run_id = %run_id, "spawned task panicked: {e}");
300                            if let Err(store_err) = store
301                                .update_run_status(
302                                    run_id,
303                                    ironflow_store::entities::RunStatus::Failed,
304                                )
305                                .await
306                            {
307                                error!(run_id = %run_id, error = %store_err, "failed to mark panicked run as failed");
308                            }
309                        }
310                    });
311                }
312                Ok(None) => {
313                    #[cfg(feature = "prometheus")]
314                    counter!(WORKER_POLLS_TOTAL, "result" => "miss").increment(1);
315
316                    idle_streak += 1;
317                    let backoff = if idle_streak > 10 {
318                        self.poll_interval * 3
319                    } else if idle_streak > 5 {
320                        self.poll_interval * 2
321                    } else {
322                        self.poll_interval
323                    };
324                    sleep(backoff).await;
325                }
326                Err(e) => {
327                    warn!(error = %e, "poll error");
328                    sleep(self.poll_interval).await;
329                }
330            }
331        }
332    }
333}
334
/// Unit tests for [`WorkerBuilder`] configuration and [`WorkerBuilder::build`].
///
/// These tests only inspect builder/worker fields; none of them starts the
/// polling loop or performs network I/O.
#[cfg(test)]
mod tests {
    use super::*;
    use ironflow_core::providers::claude::ClaudeCodeProvider;

    #[test]
    fn builder_new_creates_default_config() {
        let builder = WorkerBuilder::new("http://localhost:3000", "my-token");
        assert_eq!(builder.api_url, "http://localhost:3000");
        assert_eq!(builder.worker_token, "my-token");
        assert_eq!(builder.concurrency, DEFAULT_CONCURRENCY);
        assert_eq!(builder.poll_interval, DEFAULT_POLL_INTERVAL);
        assert!(builder.provider.is_none());
    }

    // The builder stores the URL verbatim: a trailing slash is kept, not
    // stripped (the test name previously claimed the opposite).
    #[test]
    fn builder_with_trailing_slash_preserved() {
        let builder = WorkerBuilder::new("http://localhost:3000/", "token");
        assert_eq!(builder.api_url, "http://localhost:3000/");
    }

    #[test]
    fn builder_provider_sets_provider() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("http://localhost:3000", "token").provider(provider);
        assert!(builder.provider.is_some());
    }

    #[test]
    fn builder_concurrency_sets_concurrency() {
        let builder = WorkerBuilder::new("http://localhost:3000", "token").concurrency(8);
        assert_eq!(builder.concurrency, 8);
    }

    // The setter itself performs no validation; zero is stored as given.
    #[test]
    fn builder_concurrency_zero_accepted() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .concurrency(0);
        assert_eq!(builder.concurrency, 0);
    }

    #[test]
    fn builder_poll_interval_sets_interval() {
        let interval = Duration::from_secs(5);
        let builder = WorkerBuilder::new("http://localhost:3000", "token").poll_interval(interval);
        assert_eq!(builder.poll_interval, interval);
    }

    #[test]
    fn builder_build_without_provider_fails() {
        let builder = WorkerBuilder::new("http://localhost:3000", "token");
        match builder.build() {
            Err(WorkerError::Internal(msg)) => {
                assert!(msg.contains("provider is required"));
            }
            _ => panic!("expected Internal error about missing provider"),
        }
    }

    #[test]
    fn builder_build_with_provider_succeeds() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("http://localhost:3000", "token").provider(provider);
        let result = builder.build();
        assert!(result.is_ok());
    }

    #[test]
    fn builder_build_creates_worker_with_correct_concurrency() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .concurrency(16);
        let worker = builder.build().unwrap();
        assert_eq!(worker.concurrency, 16);
    }

    #[test]
    fn builder_build_creates_worker_with_correct_interval() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let interval = Duration::from_secs(10);
        let builder = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .poll_interval(interval);
        let worker = builder.build().unwrap();
        assert_eq!(worker.poll_interval, interval);
    }

    #[test]
    fn builder_chaining_works() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let result = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .concurrency(4)
            .poll_interval(Duration::from_secs(3))
            .build();
        assert!(result.is_ok());
        let worker = result.unwrap();
        assert_eq!(worker.concurrency, 4);
        assert_eq!(worker.poll_interval, Duration::from_secs(3));
    }

    // `build` does not validate the API URL.
    #[test]
    fn builder_empty_api_url_accepted() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("", "token").provider(provider);
        let result = builder.build();
        assert!(result.is_ok());
    }

    // `build` does not validate the worker token.
    #[test]
    fn builder_empty_token_accepted() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("http://localhost:3000", "").provider(provider);
        let result = builder.build();
        assert!(result.is_ok());
    }

    #[cfg(feature = "heartbeat")]
    #[test]
    fn builder_heartbeat_defaults() {
        let builder = WorkerBuilder::new("http://localhost:3000", "token");
        assert!(builder.heartbeat_url.is_none());
        assert_eq!(builder.heartbeat_interval, DEFAULT_HEARTBEAT_INTERVAL);
    }

    #[cfg(feature = "heartbeat")]
    #[test]
    fn builder_heartbeat_url_sets_url() {
        let builder = WorkerBuilder::new("http://localhost:3000", "token")
            .heartbeat_url("https://uptime.betterstack.com/api/v1/heartbeat/abc");
        assert_eq!(
            builder.heartbeat_url.as_deref(),
            Some("https://uptime.betterstack.com/api/v1/heartbeat/abc")
        );
    }

    #[cfg(feature = "heartbeat")]
    #[test]
    fn builder_heartbeat_custom_interval() {
        let interval = Duration::from_secs(10);
        let builder =
            WorkerBuilder::new("http://localhost:3000", "token").heartbeat_interval(interval);
        assert_eq!(builder.heartbeat_interval, interval);
    }

    #[cfg(feature = "heartbeat")]
    #[test]
    fn builder_build_preserves_heartbeat_config() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let interval = Duration::from_secs(15);
        let worker = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .heartbeat_url("https://example.com/heartbeat")
            .heartbeat_interval(interval)
            .build()
            .unwrap();
        assert_eq!(
            worker.heartbeat_url.as_deref(),
            Some("https://example.com/heartbeat")
        );
        assert_eq!(worker.heartbeat_interval, interval);
    }

    #[cfg(feature = "heartbeat")]
    #[test]
    fn builder_build_without_heartbeat_url_has_none() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let worker = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .build()
            .unwrap();
        assert!(worker.heartbeat_url.is_none());
    }
}