ironflow_worker/
worker.rs1use std::sync::Arc;
4use std::time::Duration;
5
6use tokio::sync::Semaphore;
7use tracing::{error, info, warn};
8
9use ironflow_core::provider::AgentProvider;
10use ironflow_engine::engine::Engine;
11use ironflow_engine::handler::WorkflowHandler;
12use ironflow_store::store::RunStore;
13
14use crate::api_store::ApiRunStore;
15use crate::error::WorkerError;
16
/// Number of runs a worker executes at the same time unless
/// `WorkerBuilder::concurrency` overrides it.
const DEFAULT_CONCURRENCY: usize = 2;
/// Base delay between store polls (two seconds) unless
/// `WorkerBuilder::poll_interval` overrides it.
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(2_000);
19
20pub struct WorkerBuilder {
42 api_url: String,
43 worker_token: String,
44 provider: Option<Arc<dyn AgentProvider>>,
45 handlers: Vec<Box<dyn WorkflowHandler>>,
46 concurrency: usize,
47 poll_interval: Duration,
48}
49
50impl WorkerBuilder {
51 pub fn new(api_url: &str, worker_token: &str) -> Self {
53 Self {
54 api_url: api_url.to_string(),
55 worker_token: worker_token.to_string(),
56 provider: None,
57 handlers: Vec::new(),
58 concurrency: DEFAULT_CONCURRENCY,
59 poll_interval: DEFAULT_POLL_INTERVAL,
60 }
61 }
62
63 pub fn provider(mut self, provider: Arc<dyn AgentProvider>) -> Self {
65 self.provider = Some(provider);
66 self
67 }
68
69 pub fn register(mut self, handler: impl WorkflowHandler + 'static) -> Self {
71 self.handlers.push(Box::new(handler));
72 self
73 }
74
75 pub fn concurrency(mut self, n: usize) -> Self {
77 self.concurrency = n;
78 self
79 }
80
81 pub fn poll_interval(mut self, interval: Duration) -> Self {
83 self.poll_interval = interval;
84 self
85 }
86
87 pub fn build(self) -> Result<Worker, WorkerError> {
94 let provider = self
95 .provider
96 .ok_or_else(|| WorkerError::Internal("WorkerBuilder: provider is required".into()))?;
97
98 let store: Arc<dyn RunStore> =
99 Arc::new(ApiRunStore::new(&self.api_url, &self.worker_token));
100
101 let mut engine = Engine::new(store, provider);
102 for handler in self.handlers {
103 engine
104 .register_boxed(handler)
105 .map_err(WorkerError::Engine)?;
106 }
107
108 Ok(Worker {
109 engine: Arc::new(engine),
110 concurrency: self.concurrency,
111 poll_interval: self.poll_interval,
112 })
113 }
114}
115
116pub struct Worker {
118 engine: Arc<Engine>,
119 concurrency: usize,
120 poll_interval: Duration,
121}
122
123impl Worker {
124 pub async fn run(&self) -> Result<(), WorkerError> {
130 let semaphore = Arc::new(Semaphore::new(self.concurrency));
131 let mut idle_streak = 0u32;
132
133 info!(
134 concurrency = self.concurrency,
135 poll_interval_ms = self.poll_interval.as_millis() as u64,
136 "worker started"
137 );
138
139 loop {
140 let run = self.engine.store().pick_next_pending().await;
141
142 match run {
143 Ok(Some(run)) => {
144 let permit = semaphore
145 .clone()
146 .acquire_owned()
147 .await
148 .map_err(|_| WorkerError::Internal("semaphore closed".to_string()))?;
149
150 idle_streak = 0;
151 let engine = self.engine.clone();
152 let run_id = run.id;
153 let workflow = run.workflow_name.clone();
154
155 info!(run_id = %run_id, workflow = %workflow, "executing run");
156
157 let handle = tokio::spawn(async move {
158 let _permit = permit;
159 match engine.execute_handler_run(run_id).await {
160 Ok(_) => {
161 info!(run_id = %run_id, workflow = %workflow, "run completed");
162 }
163 Err(e) => {
164 error!(run_id = %run_id, workflow = %workflow, error = %e, "run failed");
165 }
166 }
167 });
168
169 let store = self.engine.store().clone();
171 tokio::spawn(async move {
172 if let Err(e) = handle.await {
173 error!(run_id = %run_id, "spawned task panicked: {e}");
174 if let Err(store_err) = store
175 .update_run_status(
176 run_id,
177 ironflow_store::entities::RunStatus::Failed,
178 )
179 .await
180 {
181 error!(run_id = %run_id, error = %store_err, "failed to mark panicked run as failed");
182 }
183 }
184 });
185 }
186 Ok(None) => {
187 idle_streak += 1;
188 let backoff = if idle_streak > 10 {
189 self.poll_interval * 3
190 } else if idle_streak > 5 {
191 self.poll_interval * 2
192 } else {
193 self.poll_interval
194 };
195 tokio::time::sleep(backoff).await;
196 }
197 Err(e) => {
198 warn!(error = %e, "poll error");
199 tokio::time::sleep(self.poll_interval).await;
200 }
201 }
202 }
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use ironflow_core::providers::claude::ClaudeCodeProvider;

    /// Provider instance shared by the tests; its concrete type does not matter here.
    fn test_provider() -> Arc<ClaudeCodeProvider> {
        Arc::new(ClaudeCodeProvider::new())
    }

    #[test]
    fn builder_new_creates_default_config() {
        let builder = WorkerBuilder::new("http://localhost:3000", "my-token");
        assert_eq!(builder.api_url, "http://localhost:3000");
        assert_eq!(builder.worker_token, "my-token");
        assert_eq!(builder.concurrency, DEFAULT_CONCURRENCY);
        assert_eq!(builder.poll_interval, DEFAULT_POLL_INTERVAL);
        assert!(builder.provider.is_none());
    }

    #[test]
    fn builder_with_trailing_slash_preserved() {
        // The builder stores the URL verbatim; it does not strip a trailing slash.
        let builder = WorkerBuilder::new("http://localhost:3000/", "token");
        assert_eq!(builder.api_url, "http://localhost:3000/");
    }

    #[test]
    fn builder_provider_sets_provider() {
        let builder = WorkerBuilder::new("http://localhost:3000", "token").provider(test_provider());
        assert!(builder.provider.is_some());
    }

    #[test]
    fn builder_concurrency_sets_concurrency() {
        let builder = WorkerBuilder::new("http://localhost:3000", "token").concurrency(8);
        assert_eq!(builder.concurrency, 8);
    }

    #[test]
    fn builder_concurrency_zero_accepted() {
        // The setter itself stores any value; validation (if any) happens in `build`.
        let builder = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(test_provider())
            .concurrency(0);
        assert_eq!(builder.concurrency, 0);
    }

    #[test]
    fn builder_poll_interval_sets_interval() {
        let interval = Duration::from_secs(5);
        let builder = WorkerBuilder::new("http://localhost:3000", "token").poll_interval(interval);
        assert_eq!(builder.poll_interval, interval);
    }

    #[test]
    fn builder_build_without_provider_fails() {
        let result = WorkerBuilder::new("http://localhost:3000", "token").build();
        assert!(result.is_err());
        match result {
            Err(WorkerError::Internal(msg)) => {
                assert!(msg.contains("provider is required"));
            }
            _ => panic!("expected Internal error about missing provider"),
        }
    }

    #[test]
    fn builder_build_with_provider_succeeds() {
        let result = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(test_provider())
            .build();
        assert!(result.is_ok());
    }

    #[test]
    fn builder_build_creates_worker_with_correct_concurrency() {
        let worker = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(test_provider())
            .concurrency(16)
            .build()
            .unwrap();
        assert_eq!(worker.concurrency, 16);
    }

    #[test]
    fn builder_build_creates_worker_with_correct_interval() {
        let interval = Duration::from_secs(10);
        let worker = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(test_provider())
            .poll_interval(interval)
            .build()
            .unwrap();
        assert_eq!(worker.poll_interval, interval);
    }

    #[test]
    fn builder_chaining_works() {
        let result = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(test_provider())
            .concurrency(4)
            .poll_interval(Duration::from_secs(3))
            .build();
        assert!(result.is_ok());
        let worker = result.unwrap();
        assert_eq!(worker.concurrency, 4);
        assert_eq!(worker.poll_interval, Duration::from_secs(3));
    }

    #[test]
    fn builder_empty_api_url_accepted() {
        let result = WorkerBuilder::new("", "token")
            .provider(test_provider())
            .build();
        assert!(result.is_ok());
    }

    #[test]
    fn builder_empty_token_accepted() {
        let result = WorkerBuilder::new("http://localhost:3000", "")
            .provider(test_provider())
            .build();
        assert!(result.is_ok());
    }
}