Skip to main content

ironflow_worker/
worker.rs

1//! Worker — polls the API for pending runs and executes them.
2
3use std::sync::Arc;
4use std::time::Duration;
5
6use tokio::spawn;
7use tokio::sync::Semaphore;
8use tokio::time::sleep;
9use tracing::{error, info, warn};
10
11use ironflow_core::provider::AgentProvider;
12use ironflow_engine::engine::Engine;
13use ironflow_engine::handler::WorkflowHandler;
14use ironflow_store::store::RunStore;
15
16use crate::api_store::ApiRunStore;
17use crate::error::WorkerError;
18
/// Number of runs a worker executes in parallel unless
/// [`WorkerBuilder::concurrency`] overrides it.
const DEFAULT_CONCURRENCY: usize = 2;
/// Base pause between polls for pending runs unless
/// [`WorkerBuilder::poll_interval`] overrides it.
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(2_000);
21
22/// Builder for configuring and creating a [`Worker`].
23///
24/// # Examples
25///
26/// ```no_run
27/// use std::sync::Arc;
28/// use std::time::Duration;
29/// use ironflow_worker::WorkerBuilder;
30/// use ironflow_core::providers::claude::ClaudeCodeProvider;
31///
32/// # async fn example() -> Result<(), ironflow_worker::WorkerError> {
33/// let worker = WorkerBuilder::new("http://localhost:3000", "my-token")
34///     .provider(Arc::new(ClaudeCodeProvider::new()))
35///     .concurrency(4)
36///     .poll_interval(Duration::from_secs(2))
37///     .build()?;
38///
39/// worker.run().await?;
40/// # Ok(())
41/// # }
42/// ```
43pub struct WorkerBuilder {
44    api_url: String,
45    worker_token: String,
46    provider: Option<Arc<dyn AgentProvider>>,
47    handlers: Vec<Box<dyn WorkflowHandler>>,
48    concurrency: usize,
49    poll_interval: Duration,
50}
51
52impl WorkerBuilder {
53    /// Create a new builder targeting the given API server.
54    pub fn new(api_url: &str, worker_token: &str) -> Self {
55        Self {
56            api_url: api_url.to_string(),
57            worker_token: worker_token.to_string(),
58            provider: None,
59            handlers: Vec::new(),
60            concurrency: DEFAULT_CONCURRENCY,
61            poll_interval: DEFAULT_POLL_INTERVAL,
62        }
63    }
64
65    /// Set the agent provider for AI operations.
66    pub fn provider(mut self, provider: Arc<dyn AgentProvider>) -> Self {
67        self.provider = Some(provider);
68        self
69    }
70
71    /// Register a workflow handler.
72    pub fn register(mut self, handler: impl WorkflowHandler + 'static) -> Self {
73        self.handlers.push(Box::new(handler));
74        self
75    }
76
77    /// Set the maximum number of concurrent workflow executions.
78    pub fn concurrency(mut self, n: usize) -> Self {
79        self.concurrency = n;
80        self
81    }
82
83    /// Set the interval between polls for new runs.
84    pub fn poll_interval(mut self, interval: Duration) -> Self {
85        self.poll_interval = interval;
86        self
87    }
88
89    /// Build the worker.
90    ///
91    /// # Errors
92    ///
93    /// Returns [`WorkerError::Internal`] if no provider has been set.
94    /// Returns [`WorkerError::Engine`] if a handler registration fails.
95    pub fn build(self) -> Result<Worker, WorkerError> {
96        let provider = self
97            .provider
98            .ok_or_else(|| WorkerError::Internal("WorkerBuilder: provider is required".into()))?;
99
100        let store: Arc<dyn RunStore> =
101            Arc::new(ApiRunStore::new(&self.api_url, &self.worker_token));
102
103        let mut engine = Engine::new(store, provider);
104        for handler in self.handlers {
105            engine
106                .register_boxed(handler)
107                .map_err(WorkerError::Engine)?;
108        }
109
110        Ok(Worker {
111            engine: Arc::new(engine),
112            concurrency: self.concurrency,
113            poll_interval: self.poll_interval,
114        })
115    }
116}
117
118/// Background worker that polls the API and executes workflows.
119pub struct Worker {
120    engine: Arc<Engine>,
121    concurrency: usize,
122    poll_interval: Duration,
123}
124
125impl Worker {
126    /// Run the worker loop. Blocks until an error occurs or the process exits.
127    ///
128    /// # Errors
129    ///
130    /// Returns [`WorkerError`] if the polling loop encounters an unrecoverable error.
131    pub async fn run(&self) -> Result<(), WorkerError> {
132        let semaphore = Arc::new(Semaphore::new(self.concurrency));
133        let mut idle_streak = 0u32;
134
135        info!(
136            concurrency = self.concurrency,
137            poll_interval_ms = self.poll_interval.as_millis() as u64,
138            "worker started"
139        );
140
141        loop {
142            let run = self.engine.store().pick_next_pending().await;
143
144            match run {
145                Ok(Some(run)) => {
146                    let permit = semaphore
147                        .clone()
148                        .acquire_owned()
149                        .await
150                        .map_err(|_| WorkerError::Internal("semaphore closed".to_string()))?;
151
152                    idle_streak = 0;
153                    let engine = self.engine.clone();
154                    let run_id = run.id;
155                    let workflow = run.workflow_name.clone();
156
157                    info!(run_id = %run_id, workflow = %workflow, "executing run");
158
159                    let handle = spawn(async move {
160                        let _permit = permit;
161                        match engine.execute_handler_run(run_id).await {
162                            Ok(_) => {
163                                info!(run_id = %run_id, workflow = %workflow, "run completed");
164                            }
165                            Err(e) => {
166                                error!(run_id = %run_id, workflow = %workflow, error = %e, "run failed");
167                            }
168                        }
169                    });
170
171                    // Spawn a watcher to catch panics and mark the run as failed
172                    let store = self.engine.store().clone();
173                    spawn(async move {
174                        if let Err(e) = handle.await {
175                            error!(run_id = %run_id, "spawned task panicked: {e}");
176                            if let Err(store_err) = store
177                                .update_run_status(
178                                    run_id,
179                                    ironflow_store::entities::RunStatus::Failed,
180                                )
181                                .await
182                            {
183                                error!(run_id = %run_id, error = %store_err, "failed to mark panicked run as failed");
184                            }
185                        }
186                    });
187                }
188                Ok(None) => {
189                    idle_streak += 1;
190                    let backoff = if idle_streak > 10 {
191                        self.poll_interval * 3
192                    } else if idle_streak > 5 {
193                        self.poll_interval * 2
194                    } else {
195                        self.poll_interval
196                    };
197                    sleep(backoff).await;
198                }
199                Err(e) => {
200                    warn!(error = %e, "poll error");
201                    sleep(self.poll_interval).await;
202                }
203            }
204        }
205    }
206}
207
#[cfg(test)]
mod tests {
    //! Unit tests for `WorkerBuilder` configuration and `WorkerBuilder::build`.

    use super::*;
    use ironflow_core::providers::claude::ClaudeCodeProvider;

    #[test]
    fn builder_new_creates_default_config() {
        let builder = WorkerBuilder::new("http://localhost:3000", "my-token");
        assert_eq!(builder.api_url, "http://localhost:3000");
        assert_eq!(builder.worker_token, "my-token");
        assert_eq!(builder.concurrency, DEFAULT_CONCURRENCY);
        assert_eq!(builder.poll_interval, DEFAULT_POLL_INTERVAL);
        assert!(builder.provider.is_none());
        assert!(builder.handlers.is_empty());
    }

    /// The builder stores the URL exactly as given; it does not strip a trailing slash.
    #[test]
    fn builder_keeps_api_url_verbatim() {
        let builder = WorkerBuilder::new("http://localhost:3000/", "token");
        assert_eq!(builder.api_url, "http://localhost:3000/");
    }

    #[test]
    fn builder_provider_sets_provider() {
        let builder = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(Arc::new(ClaudeCodeProvider::new()));
        assert!(builder.provider.is_some());
    }

    #[test]
    fn builder_concurrency_sets_concurrency() {
        let builder = WorkerBuilder::new("http://localhost:3000", "token").concurrency(8);
        assert_eq!(builder.concurrency, 8);
    }

    /// The setter itself stores zero; whether `build` accepts it is a separate concern.
    #[test]
    fn builder_concurrency_zero_is_stored() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .concurrency(0);
        assert_eq!(builder.concurrency, 0);
    }

    #[test]
    fn builder_poll_interval_sets_interval() {
        let interval = Duration::from_secs(5);
        let builder = WorkerBuilder::new("http://localhost:3000", "token").poll_interval(interval);
        assert_eq!(builder.poll_interval, interval);
    }

    #[test]
    fn builder_build_without_provider_fails() {
        match WorkerBuilder::new("http://localhost:3000", "token").build() {
            Err(WorkerError::Internal(msg)) => {
                assert!(msg.contains("provider is required"));
            }
            _ => panic!("expected Internal error about missing provider"),
        }
    }

    #[test]
    fn builder_build_with_provider_succeeds() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("http://localhost:3000", "token").provider(provider);
        assert!(builder.build().is_ok());
    }

    #[test]
    fn builder_build_creates_worker_with_correct_concurrency() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let worker = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .concurrency(16)
            .build()
            .expect("builder with a provider should build");
        assert_eq!(worker.concurrency, 16);
    }

    #[test]
    fn builder_build_creates_worker_with_correct_interval() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let interval = Duration::from_secs(10);
        let worker = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .poll_interval(interval)
            .build()
            .expect("builder with a provider should build");
        assert_eq!(worker.poll_interval, interval);
    }

    #[test]
    fn builder_chaining_works() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let worker = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .concurrency(4)
            .poll_interval(Duration::from_secs(3))
            .build()
            .expect("chained builder should build");
        assert_eq!(worker.concurrency, 4);
        assert_eq!(worker.poll_interval, Duration::from_secs(3));
    }

    #[test]
    fn builder_empty_api_url_accepted() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("", "token").provider(provider);
        assert!(builder.build().is_ok());
    }

    #[test]
    fn builder_empty_token_accepted() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("http://localhost:3000", "").provider(provider);
        assert!(builder.build().is_ok());
    }
}