ironflow_worker/
worker.rs1use std::sync::Arc;
4use std::time::Duration;
5
6use tokio::spawn;
7use tokio::sync::Semaphore;
8use tokio::time::sleep;
9use tracing::{error, info, warn};
10
11use ironflow_core::provider::AgentProvider;
12use ironflow_engine::engine::Engine;
13use ironflow_engine::handler::WorkflowHandler;
14use ironflow_store::store::RunStore;
15
16use crate::api_store::ApiRunStore;
17use crate::error::WorkerError;
18
/// Number of runs a worker executes at once unless `WorkerBuilder::concurrency` overrides it.
const DEFAULT_CONCURRENCY: usize = 2;
/// Base delay between store polls unless `WorkerBuilder::poll_interval` overrides it.
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(2_000);
21
22pub struct WorkerBuilder {
44 api_url: String,
45 worker_token: String,
46 provider: Option<Arc<dyn AgentProvider>>,
47 handlers: Vec<Box<dyn WorkflowHandler>>,
48 concurrency: usize,
49 poll_interval: Duration,
50}
51
52impl WorkerBuilder {
53 pub fn new(api_url: &str, worker_token: &str) -> Self {
55 Self {
56 api_url: api_url.to_string(),
57 worker_token: worker_token.to_string(),
58 provider: None,
59 handlers: Vec::new(),
60 concurrency: DEFAULT_CONCURRENCY,
61 poll_interval: DEFAULT_POLL_INTERVAL,
62 }
63 }
64
65 pub fn provider(mut self, provider: Arc<dyn AgentProvider>) -> Self {
67 self.provider = Some(provider);
68 self
69 }
70
71 pub fn register(mut self, handler: impl WorkflowHandler + 'static) -> Self {
73 self.handlers.push(Box::new(handler));
74 self
75 }
76
77 pub fn concurrency(mut self, n: usize) -> Self {
79 self.concurrency = n;
80 self
81 }
82
83 pub fn poll_interval(mut self, interval: Duration) -> Self {
85 self.poll_interval = interval;
86 self
87 }
88
89 pub fn build(self) -> Result<Worker, WorkerError> {
96 let provider = self
97 .provider
98 .ok_or_else(|| WorkerError::Internal("WorkerBuilder: provider is required".into()))?;
99
100 let store: Arc<dyn RunStore> =
101 Arc::new(ApiRunStore::new(&self.api_url, &self.worker_token));
102
103 let mut engine = Engine::new(store, provider);
104 for handler in self.handlers {
105 engine
106 .register_boxed(handler)
107 .map_err(WorkerError::Engine)?;
108 }
109
110 Ok(Worker {
111 engine: Arc::new(engine),
112 concurrency: self.concurrency,
113 poll_interval: self.poll_interval,
114 })
115 }
116}
117
118pub struct Worker {
120 engine: Arc<Engine>,
121 concurrency: usize,
122 poll_interval: Duration,
123}
124
125impl Worker {
126 pub async fn run(&self) -> Result<(), WorkerError> {
132 let semaphore = Arc::new(Semaphore::new(self.concurrency));
133 let mut idle_streak = 0u32;
134
135 info!(
136 concurrency = self.concurrency,
137 poll_interval_ms = self.poll_interval.as_millis() as u64,
138 "worker started"
139 );
140
141 loop {
142 let run = self.engine.store().pick_next_pending().await;
143
144 match run {
145 Ok(Some(run)) => {
146 let permit = semaphore
147 .clone()
148 .acquire_owned()
149 .await
150 .map_err(|_| WorkerError::Internal("semaphore closed".to_string()))?;
151
152 idle_streak = 0;
153 let engine = self.engine.clone();
154 let run_id = run.id;
155 let workflow = run.workflow_name.clone();
156
157 info!(run_id = %run_id, workflow = %workflow, "executing run");
158
159 let handle = spawn(async move {
160 let _permit = permit;
161 match engine.execute_handler_run(run_id).await {
162 Ok(_) => {
163 info!(run_id = %run_id, workflow = %workflow, "run completed");
164 }
165 Err(e) => {
166 error!(run_id = %run_id, workflow = %workflow, error = %e, "run failed");
167 }
168 }
169 });
170
171 let store = self.engine.store().clone();
173 spawn(async move {
174 if let Err(e) = handle.await {
175 error!(run_id = %run_id, "spawned task panicked: {e}");
176 if let Err(store_err) = store
177 .update_run_status(
178 run_id,
179 ironflow_store::entities::RunStatus::Failed,
180 )
181 .await
182 {
183 error!(run_id = %run_id, error = %store_err, "failed to mark panicked run as failed");
184 }
185 }
186 });
187 }
188 Ok(None) => {
189 idle_streak += 1;
190 let backoff = if idle_streak > 10 {
191 self.poll_interval * 3
192 } else if idle_streak > 5 {
193 self.poll_interval * 2
194 } else {
195 self.poll_interval
196 };
197 sleep(backoff).await;
198 }
199 Err(e) => {
200 warn!(error = %e, "poll error");
201 sleep(self.poll_interval).await;
202 }
203 }
204 }
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use ironflow_core::providers::claude::ClaudeCodeProvider;

    #[test]
    fn builder_new_creates_default_config() {
        let builder = WorkerBuilder::new("http://localhost:3000", "my-token");
        assert_eq!(builder.api_url, "http://localhost:3000");
        assert_eq!(builder.worker_token, "my-token");
        assert_eq!(builder.concurrency, DEFAULT_CONCURRENCY);
        assert_eq!(builder.poll_interval, DEFAULT_POLL_INTERVAL);
        assert!(builder.provider.is_none());
    }

    /// The builder stores the URL verbatim, trailing slash included.
    #[test]
    fn builder_with_trailing_slash_preserved() {
        let builder = WorkerBuilder::new("http://localhost:3000/", "token");
        assert_eq!(builder.api_url, "http://localhost:3000/");
    }

    #[test]
    fn builder_provider_sets_provider() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("http://localhost:3000", "token").provider(provider);
        assert!(builder.provider.is_some());
    }

    #[test]
    fn builder_concurrency_sets_concurrency() {
        let builder = WorkerBuilder::new("http://localhost:3000", "token").concurrency(8);
        assert_eq!(builder.concurrency, 8);
    }

    /// The setter stores the value as-is; validation is not the setter's job.
    #[test]
    fn builder_concurrency_zero_accepted() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .concurrency(0);
        assert_eq!(builder.concurrency, 0);
    }

    #[test]
    fn builder_poll_interval_sets_interval() {
        let interval = Duration::from_secs(5);
        let builder = WorkerBuilder::new("http://localhost:3000", "token").poll_interval(interval);
        assert_eq!(builder.poll_interval, interval);
    }

    #[test]
    fn builder_build_without_provider_fails() {
        let builder = WorkerBuilder::new("http://localhost:3000", "token");
        match builder.build() {
            Err(WorkerError::Internal(msg)) => {
                assert!(msg.contains("provider is required"));
            }
            _ => panic!("expected Internal error about missing provider"),
        }
    }

    #[test]
    fn builder_build_with_provider_succeeds() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("http://localhost:3000", "token").provider(provider);
        assert!(builder.build().is_ok());
    }

    #[test]
    fn builder_build_creates_worker_with_correct_concurrency() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let worker = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .concurrency(16)
            .build()
            .expect("build should succeed when a provider is set");
        assert_eq!(worker.concurrency, 16);
    }

    #[test]
    fn builder_build_creates_worker_with_correct_interval() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let interval = Duration::from_secs(10);
        let worker = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .poll_interval(interval)
            .build()
            .expect("build should succeed when a provider is set");
        assert_eq!(worker.poll_interval, interval);
    }

    #[test]
    fn builder_chaining_works() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let worker = WorkerBuilder::new("http://localhost:3000", "token")
            .provider(provider)
            .concurrency(4)
            .poll_interval(Duration::from_secs(3))
            .build()
            .expect("chained build should succeed");
        assert_eq!(worker.concurrency, 4);
        assert_eq!(worker.poll_interval, Duration::from_secs(3));
    }

    #[test]
    fn builder_empty_api_url_accepted() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("", "token").provider(provider);
        assert!(builder.build().is_ok());
    }

    #[test]
    fn builder_empty_token_accepted() {
        let provider = Arc::new(ClaudeCodeProvider::new());
        let builder = WorkerBuilder::new("http://localhost:3000", "").provider(provider);
        assert!(builder.build().is_ok());
    }
}