Skip to main content

ironflow_worker/
worker.rs

1//! Worker — polls the API for pending runs and executes them.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use tokio::sync::Semaphore;
7use tracing::{error, info, warn};
8
9use ironflow_core::provider::AgentProvider;
10use ironflow_engine::engine::Engine;
11use ironflow_engine::handler::WorkflowHandler;
12use ironflow_store::store::RunStore;
13
14use crate::api_store::ApiRunStore;
15use crate::error::WorkerError;
16
/// Concurrency limit a [`WorkerBuilder`] starts with until
/// [`WorkerBuilder::concurrency`] overrides it.
const DEFAULT_CONCURRENCY: usize = 2;
/// Poll interval a [`WorkerBuilder`] starts with until
/// [`WorkerBuilder::poll_interval`] overrides it.
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_secs(2);
19
/// Builder for configuring and creating a [`Worker`].
///
/// Collects the API endpoint, credentials, agent provider, workflow handlers
/// and scheduling knobs, then assembles them into a [`Worker`] via
/// [`WorkerBuilder::build`]. A provider is mandatory; everything else has a
/// default.
///
/// # Examples
///
/// ```no_run
/// use std::sync::Arc;
/// use std::time::Duration;
/// use ironflow_worker::WorkerBuilder;
/// use ironflow_core::providers::claude::ClaudeCodeProvider;
///
/// # async fn example() -> Result<(), ironflow_worker::WorkerError> {
/// let worker = WorkerBuilder::new("http://localhost:3000", "my-token")
///     .provider(Arc::new(ClaudeCodeProvider::new()))
///     .concurrency(4)
///     .poll_interval(Duration::from_secs(2))
///     .build()?;
///
/// worker.run().await?;
/// # Ok(())
/// # }
/// ```
pub struct WorkerBuilder {
    /// API server URL, stored exactly as given and handed to [`ApiRunStore::new`] at build time.
    api_url: String,
    /// Worker token handed to [`ApiRunStore::new`] together with the URL.
    worker_token: String,
    /// Agent provider given to the engine; `build` fails while this is `None`.
    provider: Option<Arc<dyn AgentProvider>>,
    /// Workflow handlers registered on the engine, in insertion order, at build time.
    handlers: Vec<Box<dyn WorkflowHandler>>,
    /// Number of semaphore permits the worker loop is created with.
    concurrency: usize,
    /// Base delay the worker loop sleeps between polls.
    poll_interval: Duration,
}
49
50impl WorkerBuilder {
51    /// Create a new builder targeting the given API server.
52    pub fn new(api_url: &str, worker_token: &str) -> Self {
53        Self {
54            api_url: api_url.to_string(),
55            worker_token: worker_token.to_string(),
56            provider: None,
57            handlers: Vec::new(),
58            concurrency: DEFAULT_CONCURRENCY,
59            poll_interval: DEFAULT_POLL_INTERVAL,
60        }
61    }
62
63    /// Set the agent provider for AI operations.
64    pub fn provider(mut self, provider: Arc<dyn AgentProvider>) -> Self {
65        self.provider = Some(provider);
66        self
67    }
68
69    /// Register a workflow handler.
70    pub fn register(mut self, handler: impl WorkflowHandler + 'static) -> Self {
71        self.handlers.push(Box::new(handler));
72        self
73    }
74
75    /// Set the maximum number of concurrent workflow executions.
76    pub fn concurrency(mut self, n: usize) -> Self {
77        self.concurrency = n;
78        self
79    }
80
81    /// Set the interval between polls for new runs.
82    pub fn poll_interval(mut self, interval: Duration) -> Self {
83        self.poll_interval = interval;
84        self
85    }
86
87    /// Build the worker.
88    ///
89    /// # Errors
90    ///
91    /// Returns [`WorkerError::Internal`] if no provider has been set.
92    /// Returns [`WorkerError::Engine`] if a handler registration fails.
93    pub fn build(self) -> Result<Worker, WorkerError> {
94        let provider = self
95            .provider
96            .ok_or_else(|| WorkerError::Internal("WorkerBuilder: provider is required".into()))?;
97
98        let store: Arc<dyn RunStore> =
99            Arc::new(ApiRunStore::new(&self.api_url, &self.worker_token));
100
101        let mut engine = Engine::new(store, provider);
102        for handler in self.handlers {
103            engine
104                .register_boxed(handler)
105                .map_err(WorkerError::Engine)?;
106        }
107
108        Ok(Worker {
109            engine: Arc::new(engine),
110            concurrency: self.concurrency,
111            poll_interval: self.poll_interval,
112        })
113    }
114}
115
/// Background worker that polls the API and executes workflows.
///
/// Created through [`WorkerBuilder::build`]; drive it with [`Worker::run`].
pub struct Worker {
    /// Shared engine; cloned into each spawned task to execute a run, and the
    /// source of the run store used for polling.
    engine: Arc<Engine>,
    /// Number of permits the run loop's semaphore is created with.
    concurrency: usize,
    /// Base sleep between polls; the run loop multiplies it after repeated empty polls.
    poll_interval: Duration,
}
122
123impl Worker {
124    /// Run the worker loop. Blocks until an error occurs or the process exits.
125    ///
126    /// # Errors
127    ///
128    /// Returns [`WorkerError`] if the polling loop encounters an unrecoverable error.
129    pub async fn run(&self) -> Result<(), WorkerError> {
130        let semaphore = Arc::new(Semaphore::new(self.concurrency));
131        let mut idle_streak = 0u32;
132
133        info!(
134            concurrency = self.concurrency,
135            poll_interval_ms = self.poll_interval.as_millis() as u64,
136            "worker started"
137        );
138
139        loop {
140            let run = self.engine.store().pick_next_pending().await;
141
142            match run {
143                Ok(Some(run)) => {
144                    let permit = semaphore
145                        .clone()
146                        .acquire_owned()
147                        .await
148                        .map_err(|_| WorkerError::Internal("semaphore closed".to_string()))?;
149
150                    idle_streak = 0;
151                    let engine = self.engine.clone();
152                    let run_id = run.id;
153                    let workflow = run.workflow_name.clone();
154
155                    info!(run_id = %run_id, workflow = %workflow, "executing run");
156
157                    let handle = tokio::spawn(async move {
158                        let _permit = permit;
159                        match engine.execute_handler_run(run_id).await {
160                            Ok(_) => {
161                                info!(run_id = %run_id, workflow = %workflow, "run completed");
162                            }
163                            Err(e) => {
164                                error!(run_id = %run_id, workflow = %workflow, error = %e, "run failed");
165                            }
166                        }
167                    });
168
169                    // Spawn a watcher to catch panics and mark the run as failed
170                    let store = self.engine.store().clone();
171                    tokio::spawn(async move {
172                        if let Err(e) = handle.await {
173                            error!(run_id = %run_id, "spawned task panicked: {e}");
174                            if let Err(store_err) = store
175                                .update_run_status(
176                                    run_id,
177                                    ironflow_store::entities::RunStatus::Failed,
178                                )
179                                .await
180                            {
181                                error!(run_id = %run_id, error = %store_err, "failed to mark panicked run as failed");
182                            }
183                        }
184                    });
185                }
186                Ok(None) => {
187                    idle_streak += 1;
188                    let backoff = if idle_streak > 10 {
189                        self.poll_interval * 3
190                    } else if idle_streak > 5 {
191                        self.poll_interval * 2
192                    } else {
193                        self.poll_interval
194                    };
195                    tokio::time::sleep(backoff).await;
196                }
197                Err(e) => {
198                    warn!(error = %e, "poll error");
199                    tokio::time::sleep(self.poll_interval).await;
200                }
201            }
202        }
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use ironflow_core::providers::claude::ClaudeCodeProvider;

    /// API base URL shared by most builder tests.
    const API_URL: &str = "http://localhost:3000";

    /// Fresh provider instance for tests that need one.
    fn test_provider() -> Arc<ClaudeCodeProvider> {
        Arc::new(ClaudeCodeProvider::new())
    }

    /// Builder aimed at [`API_URL`] with a placeholder token and nothing else configured.
    fn bare_builder() -> WorkerBuilder {
        WorkerBuilder::new(API_URL, "token")
    }

    #[test]
    fn builder_new_creates_default_config() {
        let WorkerBuilder {
            api_url,
            worker_token,
            provider,
            concurrency,
            poll_interval,
            ..
        } = WorkerBuilder::new(API_URL, "my-token");

        assert_eq!(api_url, "http://localhost:3000");
        assert_eq!(worker_token, "my-token");
        assert_eq!(concurrency, DEFAULT_CONCURRENCY);
        assert_eq!(poll_interval, DEFAULT_POLL_INTERVAL);
        assert!(provider.is_none());
    }

    // Despite the name, this pins that the builder keeps the URL verbatim,
    // trailing slash included.
    #[test]
    fn builder_with_trailing_slash_normalized() {
        let stored = WorkerBuilder::new("http://localhost:3000/", "token").api_url;
        assert_eq!(stored, "http://localhost:3000/");
    }

    #[test]
    fn builder_provider_sets_provider() {
        let configured = bare_builder().provider(test_provider());
        assert!(configured.provider.is_some());
    }

    #[test]
    fn builder_concurrency_sets_concurrency() {
        assert_eq!(bare_builder().concurrency(8).concurrency, 8);
    }

    #[test]
    fn builder_concurrency_zero_accepted() {
        let configured = bare_builder().provider(test_provider()).concurrency(0);
        assert_eq!(configured.concurrency, 0);
    }

    #[test]
    fn builder_poll_interval_sets_interval() {
        let five_seconds = Duration::from_secs(5);
        let configured = bare_builder().poll_interval(five_seconds);
        assert_eq!(configured.poll_interval, five_seconds);
    }

    #[test]
    fn builder_build_without_provider_fails() {
        match bare_builder().build() {
            Err(WorkerError::Internal(msg)) => {
                assert!(msg.contains("provider is required"));
            }
            _ => panic!("expected Internal error about missing provider"),
        }
    }

    #[test]
    fn builder_build_with_provider_succeeds() {
        let outcome = bare_builder().provider(test_provider()).build();
        assert!(outcome.is_ok());
    }

    #[test]
    fn builder_build_creates_worker_with_correct_concurrency() {
        let worker = bare_builder()
            .provider(test_provider())
            .concurrency(16)
            .build()
            .expect("build with a provider should succeed");
        assert_eq!(worker.concurrency, 16);
    }

    #[test]
    fn builder_build_creates_worker_with_correct_interval() {
        let ten_seconds = Duration::from_secs(10);
        let worker = bare_builder()
            .provider(test_provider())
            .poll_interval(ten_seconds)
            .build()
            .expect("build with a provider should succeed");
        assert_eq!(worker.poll_interval, ten_seconds);
    }

    #[test]
    fn builder_chaining_works() {
        let worker = bare_builder()
            .provider(test_provider())
            .concurrency(4)
            .poll_interval(Duration::from_secs(3))
            .build()
            .expect("fully chained builder should build");
        assert_eq!(worker.concurrency, 4);
        assert_eq!(worker.poll_interval, Duration::from_secs(3));
    }

    #[test]
    fn builder_empty_api_url_accepted() {
        let outcome = WorkerBuilder::new("", "token")
            .provider(test_provider())
            .build();
        assert!(outcome.is_ok());
    }

    #[test]
    fn builder_empty_token_accepted() {
        let outcome = WorkerBuilder::new(API_URL, "")
            .provider(test_provider())
            .build();
        assert!(outcome.is_ok());
    }
}