1use std::sync::Arc;
14use std::time::Duration;
15
16use tokio::task::JoinHandle;
17use tokio::time::{sleep, timeout};
18use tokio_util::sync::CancellationToken;
19use tracing::{Instrument, Level, event, info_span};
20
21use crate::jobs::{Job, JobKind, JobState, MemoryJobsStore};
22
23use super::{Client, ClientError, ClientInner};
24
/// Default pause between claim attempts while the queue is empty; also the back-off
/// applied after a failed claim.
pub const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(1_000);

/// Default lease duration handed to the jobs store's expired-lease sweep
/// (`reset_expired_leases`). Presumably claims older than this become eligible to be
/// reset — confirm against the store implementation.
pub const DEFAULT_LEASE_DURATION: Duration = Duration::from_millis(60_000);

/// Default attempt budget forwarded to the jobs store when a job fails.
pub const DEFAULT_MAX_ATTEMPTS: i32 = 3;

/// Default upper bound on how long [`WorkerHandle::shutdown`] waits for the worker task
/// to exit after cancellation has been requested.
pub const DEFAULT_DRAIN_TIMEOUT: Duration = Duration::from_millis(30_000);
53
54#[derive(Debug)]
72#[must_use = "spawn_worker() returns a builder; call .start() to launch the task"]
73pub struct WorkerBuilder<'a> {
74 client: &'a Client,
75 poll_interval: Duration,
76 lease_duration: Duration,
77 max_attempts: i32,
78 drain_timeout: Duration,
79 claimed_by: Option<String>,
80}
81
82impl<'a> WorkerBuilder<'a> {
83 pub(super) fn new(client: &'a Client) -> Self {
84 Self {
85 client,
86 poll_interval: DEFAULT_POLL_INTERVAL,
87 lease_duration: DEFAULT_LEASE_DURATION,
88 max_attempts: DEFAULT_MAX_ATTEMPTS,
89 drain_timeout: DEFAULT_DRAIN_TIMEOUT,
90 claimed_by: None,
91 }
92 }
93
94 pub fn poll_interval(mut self, interval: Duration) -> Self {
96 self.poll_interval = interval;
97 self
98 }
99
100 pub fn lease_duration(mut self, lease: Duration) -> Self {
106 self.lease_duration = lease;
107 self
108 }
109
110 pub fn max_attempts(mut self, max: i32) -> Self {
113 self.max_attempts = max;
114 self
115 }
116
117 pub fn drain_timeout(mut self, timeout: Duration) -> Self {
120 self.drain_timeout = timeout;
121 self
122 }
123
124 pub fn claimed_by(mut self, id: impl Into<String>) -> Self {
130 self.claimed_by = Some(id.into());
131 self
132 }
133
134 pub async fn start(self) -> Result<WorkerHandle, ClientError> {
142 let WorkerBuilder {
143 client,
144 poll_interval,
145 lease_duration,
146 max_attempts,
147 drain_timeout,
148 claimed_by,
149 } = self;
150
151 let token = CancellationToken::new();
152 let inner = client.inner.clone();
153 let config = WorkerConfig {
154 poll_interval,
155 lease_duration,
156 max_attempts,
157 claimed_by,
158 };
159
160 let span = info_span!("memoir.worker");
161 let token_for_task = token.clone();
162 let join = tokio::spawn(async move { run_worker(inner, config, token_for_task).await }.instrument(span));
163
164 Ok(WorkerHandle {
165 inner: Arc::new(WorkerHandleInner {
166 join: tokio::sync::Mutex::new(Some(join)),
167 token,
168 drain_timeout,
169 }),
170 })
171 }
172}
173
174#[derive(Clone)]
180pub struct WorkerHandle {
181 inner: Arc<WorkerHandleInner>,
182}
183
184impl std::fmt::Debug for WorkerHandle {
185 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186 f.debug_struct("WorkerHandle")
187 .field("is_shutting_down", &self.inner.token.is_cancelled())
188 .field("drain_timeout", &self.inner.drain_timeout)
189 .finish_non_exhaustive()
190 }
191}
192
193struct WorkerHandleInner {
194 join: tokio::sync::Mutex<Option<JoinHandle<()>>>,
195 token: CancellationToken,
196 drain_timeout: Duration,
197}
198
199impl WorkerHandle {
200 #[must_use]
202 pub fn is_shutting_down(&self) -> bool {
203 self.inner.token.is_cancelled()
204 }
205
206 #[must_use]
216 pub fn cancellation_token(&self) -> CancellationToken {
217 self.inner.token.child_token()
218 }
219
220 pub async fn shutdown(&self) {
229 self.inner.token.cancel();
230
231 let mut guard = self.inner.join.lock().await;
232 let Some(join) = guard.take() else {
233 return;
234 };
235
236 match timeout(self.inner.drain_timeout, join).await {
237 Ok(Ok(())) => {
238 event!(
239 name: "memoir.worker.shutdown",
240 Level::INFO,
241 outcome = "drained",
242 "worker shutdown {{outcome}}",
243 );
244 }
245 Ok(Err(err)) => {
246 event!(
247 name: "memoir.worker.shutdown",
248 Level::WARN,
249 outcome = "join_failed",
250 error.message = %err,
251 "worker shutdown {{outcome}}: {{error.message}}",
252 );
253 }
254 Err(_) => {
255 event!(
256 name: "memoir.worker.shutdown",
257 Level::WARN,
258 outcome = "timeout",
259 "worker shutdown {{outcome}} (drain deadline exceeded; task continues until natural exit)",
260 );
261 }
266 }
267 }
268
269 pub async fn abort(&self) {
274 self.inner.token.cancel();
275 let mut guard = self.inner.join.lock().await;
276 if let Some(join) = guard.take() {
277 join.abort();
278 event!(
279 name: "memoir.worker.aborted",
280 Level::WARN,
281 outcome = "aborted",
282 "worker {{outcome}}",
283 );
284 }
285 }
286}
287
/// The builder settings the worker loop needs, moved into the spawned task.
#[derive(Clone)]
struct WorkerConfig {
    /// Optional identity passed to every `claim` call.
    claimed_by: Option<String>,
    /// Attempt budget forwarded to the jobs store when a job fails.
    max_attempts: i32,
    /// Lease age handed to the expired-lease sweep.
    lease_duration: Duration,
    /// Sleep between claims while idle, and back-off after a failed claim.
    poll_interval: Duration,
}
295
296async fn run_worker(inner: Arc<ClientInner>, config: WorkerConfig, token: CancellationToken) {
297 let poll_interval_ms = u64::try_from(config.poll_interval.as_millis()).unwrap_or(u64::MAX);
300 event!(
301 name: "memoir.worker.started",
302 Level::INFO,
303 poll_interval_ms = poll_interval_ms,
304 lease_secs = config.lease_duration.as_secs(),
305 max_attempts = config.max_attempts,
306 "worker started: poll_interval={{poll_interval_ms}}ms lease={{lease_secs}}s max_attempts={{max_attempts}}",
307 );
308
309 while !token.is_cancelled() {
310 let claimed_by = config.claimed_by.as_deref();
311 let claim_result = inner.jobs.claim(claimed_by).await;
312
313 match claim_result {
314 Ok(Some(job)) => {
315 dispatch(&inner, job, config.max_attempts).await;
316 }
317 Ok(None) => {
318 match inner.jobs.reset_expired_leases(config.lease_duration).await {
320 Ok(0) => {}
321 Ok(n) => {
322 event!(
323 name: "memoir.worker.lease_recovered",
324 Level::INFO,
325 count = n,
326 "recovered {{count}} expired lease(s)",
327 );
328 }
329 Err(err) => {
330 event!(
331 name: "memoir.worker.lease_recovery_failed",
332 Level::WARN,
333 error.message = %err,
334 "lease recovery failed: {{error.message}}",
335 );
336 }
337 }
338
339 wait_or_cancel(&token, config.poll_interval).await;
340 }
341 Err(err) => {
342 event!(
343 name: "memoir.worker.claim_failed",
344 Level::WARN,
345 error.message = %err,
346 "claim failed: {{error.message}}; backing off",
347 );
348 wait_or_cancel(&token, config.poll_interval).await;
349 }
350 }
351 }
352
353 event!(
354 name: "memoir.worker.exited",
355 Level::INFO,
356 outcome = "exited",
357 "worker loop {{outcome}}",
358 );
359}
360
361async fn wait_or_cancel(token: &CancellationToken, dur: Duration) {
363 tokio::select! {
364 _ = sleep(dur) => {}
365 _ = token.cancelled() => {}
366 }
367}
368
369async fn dispatch(inner: &Arc<ClientInner>, job: Job, max_attempts: i32) {
374 debug_assert_eq!(job.state, JobState::Claimed);
375
376 let job_span = info_span!(
377 "memoir.worker.job",
378 job_id = job.id,
379 kind = %job.kind,
380 source_pid = %job.source_pid,
381 );
382 let _enter = job_span.enter();
383
384 event!(
385 name: "memoir.worker.job_started",
386 Level::DEBUG,
387 outcome = "claimed",
388 "job {{outcome}}",
389 );
390
391 let result: Result<(), String> = match job.kind {
392 JobKind::Extract => inner.run_extract(job.clone()).await.map_err(|err| err.to_string()),
393 JobKind::Embed => inner
394 .run_embed_job(&job.source_pid)
395 .await
396 .map_err(|err| err.to_string()),
397 JobKind::Categorize => inner.run_categorize(job.clone()).await.map_err(|err| err.to_string()),
398 JobKind::Reprocess => inner.run_reprocess(job.clone()).await.map_err(|err| err.to_string()),
399 #[cfg(feature = "knowledge-graph")]
400 JobKind::RelationalExtract => inner
401 .run_relational_extract(job.clone())
402 .await
403 .map_err(|err| err.to_string()),
404 #[cfg(feature = "knowledge-graph")]
405 JobKind::Synthesize => inner.run_synthesize(job.clone()).await.map_err(|err| err.to_string()),
406 #[cfg(not(feature = "knowledge-graph"))]
409 JobKind::RelationalExtract | JobKind::Synthesize => Ok(()),
410 };
411
412 match result {
413 Ok(()) => match inner.jobs.complete(job.id).await {
414 Ok(()) => event!(
415 name: "memoir.worker.job_succeeded",
416 Level::INFO,
417 outcome = "succeeded",
418 "job {{outcome}}",
419 ),
420 Err(err) => event!(
421 name: "memoir.worker.complete_failed",
422 Level::WARN,
423 error.message = %err,
424 "complete failed after successful dispatch: {{error.message}}",
425 ),
426 },
427 Err(reason) => {
428 if let Err(fail_err) = inner.jobs.fail(job.id, reason.clone(), max_attempts).await {
429 event!(
430 name: "memoir.worker.fail_failed",
431 Level::WARN,
432 error.message = %fail_err,
433 "fail call itself failed: {{error.message}}",
434 );
435 } else {
436 event!(
437 name: "memoir.worker.job_failed",
438 Level::WARN,
439 error.message = %reason,
440 "job failed: {{error.message}}",
441 );
442 }
443 }
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time check: a `WorkerHandle` must be movable across threads/tasks.
    const fn assert_send<T: Send>() {}
    const _: () = assert_send::<WorkerHandle>();

    /// Wraps a trivially-finishing task and `token` in a handle with a 1 s drain timeout.
    fn handle_for(token: &CancellationToken) -> WorkerHandle {
        let join = tokio::spawn(async {});
        WorkerHandle {
            inner: Arc::new(WorkerHandleInner {
                join: tokio::sync::Mutex::new(Some(join)),
                token: token.clone(),
                drain_timeout: Duration::from_secs(1),
            }),
        }
    }

    #[test]
    fn should_use_default_constants_for_builder() {
        let durations = [
            (DEFAULT_POLL_INTERVAL, Duration::from_secs(1)),
            (DEFAULT_LEASE_DURATION, Duration::from_secs(60)),
            (DEFAULT_DRAIN_TIMEOUT, Duration::from_secs(30)),
        ];
        for (actual, expected) in durations {
            assert_eq!(actual, expected);
        }
        assert_eq!(DEFAULT_MAX_ATTEMPTS, 3);
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_wait_or_cancel_complete_when_uncancelled() {
        let token = CancellationToken::new();
        let began = std::time::Instant::now();

        wait_or_cancel(&token, Duration::from_millis(10)).await;

        let waited = began.elapsed();
        assert!(
            waited >= Duration::from_millis(10),
            "expected ~10ms sleep without cancellation"
        );
        assert!(!token.is_cancelled());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_wait_or_cancel_return_immediately_when_cancelled() {
        let token = CancellationToken::new();
        token.cancel();

        let began = std::time::Instant::now();
        wait_or_cancel(&token, Duration::from_secs(60)).await;

        let waited = began.elapsed();
        assert!(
            waited < Duration::from_millis(100),
            "cancellation should wake us nearly instantly"
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_worker_handle_track_shutdown_state() {
        let token = CancellationToken::new();
        let handle = handle_for(&token);

        assert!(!handle.is_shutting_down());
        token.cancel();
        assert!(handle.is_shutting_down());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_child_token_inherit_cancellation_from_parent() {
        let token = CancellationToken::new();
        let handle = handle_for(&token);

        let child = handle.cancellation_token();
        assert!(!child.is_cancelled());
        token.cancel();
        assert!(child.is_cancelled(), "child should observe parent cancellation");
    }
}
525}