Skip to main content

memoir_core/client/
worker.rs

//! In-library worker that drains the `memory_jobs` queue.
//!
//! The worker is a detached tokio task launched via
//! [`super::Client::spawn_worker`]. It polls the queue, dispatches each job
//! to the stage handler for its `JobKind`, and completes-or-fails the job in
//! the store. When the queue is empty it recovers expired leases and then
//! sleeps for the poll interval.
//!
//! Shutdown is cooperative. Sending the shutdown signal lets the worker
//! finish its current job before exiting. A drain timeout caps how long the
//! caller is willing to wait.
12
13use std::sync::Arc;
14use std::time::Duration;
15
16use tokio::task::JoinHandle;
17use tokio::time::{sleep, timeout};
18use tokio_util::sync::CancellationToken;
19use tracing::{Instrument, Level, event, info_span};
20
21use crate::jobs::{Job, JobKind, JobState, MemoryJobsStore};
22
23use super::{Client, ClientError, ClientInner};
24
/// How long the worker sleeps between polls when the queue is empty.
///
/// One second is a compromise: a shorter interval picks up newly enqueued
/// jobs sooner but burns CPU while the queue is idle, whereas a longer one
/// is kinder to laptops and inexpensive hosts.
pub const DEFAULT_POLL_INTERVAL: Duration = Duration::from_secs(1);

/// Age after which a claim is considered abandoned and gets recovered.
///
/// Sixty seconds keeps a healthy worker's claim from being stolen mid-job
/// while still recovering a crashed worker's claim quickly. Raise it when
/// extraction LLM calls routinely run longer than this; lower it when fast
/// recovery matters more than tolerating slow operations.
pub const DEFAULT_LEASE_DURATION: Duration = Duration::from_secs(60);

/// Number of attempts a job gets before it moves to `failed`.
///
/// Three attempts absorb transient failures (network blips, momentary LLM
/// rate limits) without amplifying systemic ones. Operators should raise
/// this only for provider tiers with heavy throttling.
pub const DEFAULT_MAX_ATTEMPTS: i32 = 3;

/// Upper bound on the graceful drain that follows a `.shutdown()` call.
///
/// Thirty seconds lets a typical extraction job finish (LLM + DB writes
/// usually take <10s) while ensuring a hung worker cannot block server
/// shutdown indefinitely.
pub const DEFAULT_DRAIN_TIMEOUT: Duration = Duration::from_secs(30);
53
54/// Per-call builder returned by [`Client::spawn_worker`].
55///
56/// Configure via toggle methods, then call [`Self::start`] to spawn the
57/// worker task. Returns a [`WorkerHandle`] the caller uses to observe and
58/// shut down the worker.
59///
60/// # Examples
61///
62/// ```no_run
63/// # use memoir_core::client::Client;
64/// # async fn example(client: &Client) -> Result<(), Box<dyn std::error::Error>> {
65/// let worker = client.spawn_worker().start().await?;
66/// // ... server runs ...
67/// worker.shutdown().await;
68/// # Ok(())
69/// # }
70/// ```
71#[derive(Debug)]
72#[must_use = "spawn_worker() returns a builder; call .start() to launch the task"]
73pub struct WorkerBuilder<'a> {
74    client: &'a Client,
75    poll_interval: Duration,
76    lease_duration: Duration,
77    max_attempts: i32,
78    drain_timeout: Duration,
79    claimed_by: Option<String>,
80}
81
82impl<'a> WorkerBuilder<'a> {
83    pub(super) fn new(client: &'a Client) -> Self {
84        Self {
85            client,
86            poll_interval: DEFAULT_POLL_INTERVAL,
87            lease_duration: DEFAULT_LEASE_DURATION,
88            max_attempts: DEFAULT_MAX_ATTEMPTS,
89            drain_timeout: DEFAULT_DRAIN_TIMEOUT,
90            claimed_by: None,
91        }
92    }
93
94    /// Interval between polls when the queue is empty. Default 1 second.
95    pub fn poll_interval(mut self, interval: Duration) -> Self {
96        self.poll_interval = interval;
97        self
98    }
99
100    /// Lease duration for in-flight claims. Default 60 seconds.
101    ///
102    /// A worker that crashes before completing a job leaves the row in
103    /// `claimed` state with stale `claimed_at`. The lease-recovery sweep
104    /// re-pends rows older than this duration.
105    pub fn lease_duration(mut self, lease: Duration) -> Self {
106        self.lease_duration = lease;
107        self
108    }
109
110    /// Maximum failed attempts before a job moves to terminal `failed`.
111    /// Default 3.
112    pub fn max_attempts(mut self, max: i32) -> Self {
113        self.max_attempts = max;
114        self
115    }
116
117    /// Maximum time `.shutdown()` waits for the current job to finish.
118    /// Default 30 seconds. After this, the task is aborted.
119    pub fn drain_timeout(mut self, timeout: Duration) -> Self {
120        self.drain_timeout = timeout;
121        self
122    }
123
124    /// Identifier persisted on each claim's `claimed_by` column.
125    ///
126    /// Default `None`. Useful when multiple workers share a queue — e.g.
127    /// `hostname-pid` lets operators identify which process holds a stale
128    /// lease.
129    pub fn claimed_by(mut self, id: impl Into<String>) -> Self {
130        self.claimed_by = Some(id.into());
131        self
132    }
133
134    /// Spawns the worker task and returns a handle.
135    ///
136    /// # Errors
137    ///
138    /// Currently infallible (returns `Ok` unconditionally); the `Result`
139    /// signature reserves room for startup-time validation that downstream
140    /// tickets (0010 LLM config) may add.
141    pub async fn start(self) -> Result<WorkerHandle, ClientError> {
142        let WorkerBuilder {
143            client,
144            poll_interval,
145            lease_duration,
146            max_attempts,
147            drain_timeout,
148            claimed_by,
149        } = self;
150
151        let token = CancellationToken::new();
152        let inner = client.inner.clone();
153        let config = WorkerConfig {
154            poll_interval,
155            lease_duration,
156            max_attempts,
157            claimed_by,
158        };
159
160        let span = info_span!("memoir.worker");
161        let token_for_task = token.clone();
162        let join = tokio::spawn(async move { run_worker(inner, config, token_for_task).await }.instrument(span));
163
164        Ok(WorkerHandle {
165            inner: Arc::new(WorkerHandleInner {
166                join: tokio::sync::Mutex::new(Some(join)),
167                token,
168                drain_timeout,
169            }),
170        })
171    }
172}
173
174/// Handle to a running worker task.
175///
176/// Cheap to clone — internally `Arc`'d so multiple call sites can hold
177/// references. Dropping the last clone does NOT trigger shutdown; callers
178/// should explicitly invoke [`Self::shutdown`] on graceful-stop paths.
179#[derive(Clone)]
180pub struct WorkerHandle {
181    inner: Arc<WorkerHandleInner>,
182}
183
184impl std::fmt::Debug for WorkerHandle {
185    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
186        f.debug_struct("WorkerHandle")
187            .field("is_shutting_down", &self.inner.token.is_cancelled())
188            .field("drain_timeout", &self.inner.drain_timeout)
189            .finish_non_exhaustive()
190    }
191}
192
193struct WorkerHandleInner {
194    join: tokio::sync::Mutex<Option<JoinHandle<()>>>,
195    token: CancellationToken,
196    drain_timeout: Duration,
197}
198
199impl WorkerHandle {
200    /// Returns `true` if the worker has been signaled to stop.
201    #[must_use]
202    pub fn is_shutting_down(&self) -> bool {
203        self.inner.token.is_cancelled()
204    }
205
206    /// Returns a child [`CancellationToken`] tied to the worker's lifecycle.
207    ///
208    /// Child tokens are cancelled when the worker itself is shut down. Useful
209    /// when downstream subtasks want to share the same shutdown semantics.
210    ///
211    /// `CancellationToken` is leaked here from `tokio-util` deliberately:
212    /// it is the de-facto-standard cooperative cancellation primitive in the
213    /// tokio ecosystem, and exposing it gives consumers direct
214    /// interoperability with the rest of their async code.
215    #[must_use]
216    pub fn cancellation_token(&self) -> CancellationToken {
217        self.inner.token.child_token()
218    }
219
220    /// Signals the worker to stop and waits for it to drain.
221    ///
222    /// The worker finishes its current job (if any), declines to claim a new
223    /// one, and exits. If the drain timeout elapses first, the task is
224    /// aborted and any in-flight claim leaks until the lease expires.
225    ///
226    /// Calling `shutdown` more than once is safe — subsequent calls observe
227    /// the already-shut-down state and return immediately.
228    pub async fn shutdown(&self) {
229        self.inner.token.cancel();
230
231        let mut guard = self.inner.join.lock().await;
232        let Some(join) = guard.take() else {
233            return;
234        };
235
236        match timeout(self.inner.drain_timeout, join).await {
237            Ok(Ok(())) => {
238                event!(
239                    name: "memoir.worker.shutdown",
240                    Level::INFO,
241                    outcome = "drained",
242                    "worker shutdown {{outcome}}",
243                );
244            }
245            Ok(Err(err)) => {
246                event!(
247                    name: "memoir.worker.shutdown",
248                    Level::WARN,
249                    outcome = "join_failed",
250                    error.message = %err,
251                    "worker shutdown {{outcome}}: {{error.message}}",
252                );
253            }
254            Err(_) => {
255                event!(
256                    name: "memoir.worker.shutdown",
257                    Level::WARN,
258                    outcome = "timeout",
259                    "worker shutdown {{outcome}} (drain deadline exceeded; task continues until natural exit)",
260                );
261                // Note: we can't abort here because we already took the
262                // JoinHandle out of the Option. The task continues until it
263                // naturally exits or the runtime drops. Consumers who need
264                // hard-abort semantics should call `.abort()` explicitly.
265            }
266        }
267    }
268
269    /// Aborts the worker task without waiting for graceful drain.
270    ///
271    /// In-flight claims leak until their lease expires. Prefer
272    /// [`Self::shutdown`] except in emergency shutdown paths.
273    pub async fn abort(&self) {
274        self.inner.token.cancel();
275        let mut guard = self.inner.join.lock().await;
276        if let Some(join) = guard.take() {
277            join.abort();
278            event!(
279                name: "memoir.worker.aborted",
280                Level::WARN,
281                outcome = "aborted",
282                "worker {{outcome}}",
283            );
284        }
285    }
286}
287
/// Settings the worker loop needs, copied out of the builder at start time.
#[derive(Clone)]
struct WorkerConfig {
    /// Sleep between polls of an empty queue, and back-off after a failed claim.
    poll_interval: Duration,
    /// Claim age passed to the expired-lease recovery call.
    lease_duration: Duration,
    /// Attempt budget forwarded to the store when a job fails.
    max_attempts: i32,
    /// Optional worker identity passed to the store on each claim.
    claimed_by: Option<String>,
}
295
296async fn run_worker(inner: Arc<ClientInner>, config: WorkerConfig, token: CancellationToken) {
297    // `as_millis()` returns u128; cap at u64::MAX since tracing event fields
298    // accept u64 and durations beyond ~584 million years aren't a real concern.
299    let poll_interval_ms = u64::try_from(config.poll_interval.as_millis()).unwrap_or(u64::MAX);
300    event!(
301        name: "memoir.worker.started",
302        Level::INFO,
303        poll_interval_ms = poll_interval_ms,
304        lease_secs = config.lease_duration.as_secs(),
305        max_attempts = config.max_attempts,
306        "worker started: poll_interval={{poll_interval_ms}}ms lease={{lease_secs}}s max_attempts={{max_attempts}}",
307    );
308
309    while !token.is_cancelled() {
310        let claimed_by = config.claimed_by.as_deref();
311        let claim_result = inner.jobs.claim(claimed_by).await;
312
313        match claim_result {
314            Ok(Some(job)) => {
315                dispatch(&inner, job, config.max_attempts).await;
316            }
317            Ok(None) => {
318                // Queue empty: recover expired leases, then wait.
319                match inner.jobs.reset_expired_leases(config.lease_duration).await {
320                    Ok(0) => {}
321                    Ok(n) => {
322                        event!(
323                            name: "memoir.worker.lease_recovered",
324                            Level::INFO,
325                            count = n,
326                            "recovered {{count}} expired lease(s)",
327                        );
328                    }
329                    Err(err) => {
330                        event!(
331                            name: "memoir.worker.lease_recovery_failed",
332                            Level::WARN,
333                            error.message = %err,
334                            "lease recovery failed: {{error.message}}",
335                        );
336                    }
337                }
338
339                wait_or_cancel(&token, config.poll_interval).await;
340            }
341            Err(err) => {
342                event!(
343                    name: "memoir.worker.claim_failed",
344                    Level::WARN,
345                    error.message = %err,
346                    "claim failed: {{error.message}}; backing off",
347                );
348                wait_or_cancel(&token, config.poll_interval).await;
349            }
350        }
351    }
352
353    event!(
354        name: "memoir.worker.exited",
355        Level::INFO,
356        outcome = "exited",
357        "worker loop {{outcome}}",
358    );
359}
360
361/// Sleeps for `dur` or returns immediately when the token is cancelled.
362async fn wait_or_cancel(token: &CancellationToken, dur: Duration) {
363    tokio::select! {
364        _ = sleep(dur) => {}
365        _ = token.cancelled() => {}
366    }
367}
368
369/// Dispatches one claimed job. No-op in this ticket — completes immediately.
370///
371/// Tickets 0006 (extract) and 0007 (embed) replace this body with real
372/// stage handlers nested under per-job spans.
373async fn dispatch(inner: &Arc<ClientInner>, job: Job, max_attempts: i32) {
374    debug_assert_eq!(job.state, JobState::Claimed);
375
376    let job_span = info_span!(
377        "memoir.worker.job",
378        job_id = job.id,
379        kind = %job.kind,
380        source_pid = %job.source_pid,
381    );
382    let _enter = job_span.enter();
383
384    event!(
385        name: "memoir.worker.job_started",
386        Level::DEBUG,
387        outcome = "claimed",
388        "job {{outcome}}",
389    );
390
391    let result: Result<(), String> = match job.kind {
392        JobKind::Extract => inner.run_extract(job.clone()).await.map_err(|err| err.to_string()),
393        JobKind::Embed => inner
394            .run_embed_job(&job.source_pid)
395            .await
396            .map_err(|err| err.to_string()),
397        JobKind::Categorize => inner.run_categorize(job.clone()).await.map_err(|err| err.to_string()),
398        JobKind::Reprocess => inner.run_reprocess(job.clone()).await.map_err(|err| err.to_string()),
399        #[cfg(feature = "knowledge-graph")]
400        JobKind::RelationalExtract => inner
401            .run_relational_extract(job.clone())
402            .await
403            .map_err(|err| err.to_string()),
404        #[cfg(feature = "knowledge-graph")]
405        JobKind::Synthesize => inner.run_synthesize(job.clone()).await.map_err(|err| err.to_string()),
406        // A vector-only build never enqueues these kinds; a row written by a
407        // graph-enabled build is completed as a no-op rather than failing.
408        #[cfg(not(feature = "knowledge-graph"))]
409        JobKind::RelationalExtract | JobKind::Synthesize => Ok(()),
410    };
411
412    match result {
413        Ok(()) => match inner.jobs.complete(job.id).await {
414            Ok(()) => event!(
415                name: "memoir.worker.job_succeeded",
416                Level::INFO,
417                outcome = "succeeded",
418                "job {{outcome}}",
419            ),
420            Err(err) => event!(
421                name: "memoir.worker.complete_failed",
422                Level::WARN,
423                error.message = %err,
424                "complete failed after successful dispatch: {{error.message}}",
425            ),
426        },
427        Err(reason) => {
428            if let Err(fail_err) = inner.jobs.fail(job.id, reason.clone(), max_attempts).await {
429                event!(
430                    name: "memoir.worker.fail_failed",
431                    Level::WARN,
432                    error.message = %fail_err,
433                    "fail call itself failed: {{error.message}}",
434                );
435            } else {
436                event!(
437                    name: "memoir.worker.job_failed",
438                    Level::WARN,
439                    error.message = %reason,
440                    "job failed: {{error.message}}",
441                );
442            }
443        }
444    }
445}
446
#[cfg(test)]
mod tests {
    use std::time::Instant;

    use super::*;

    // M-TYPES-SEND: public types must be `Send` so they compose with tokio.
    const fn assert_send<T: Send>() {}
    const _: () = assert_send::<WorkerHandle>();

    /// Builds a handle around a trivial spawned task that shares `token`.
    /// Must be called from inside a tokio runtime because it spawns.
    fn handle_sharing(token: &CancellationToken) -> WorkerHandle {
        let state = WorkerHandleInner {
            join: tokio::sync::Mutex::new(Some(tokio::spawn(async {}))),
            token: token.clone(),
            drain_timeout: Duration::from_secs(1),
        };
        WorkerHandle {
            inner: Arc::new(state),
        }
    }

    #[test]
    fn should_use_default_constants_for_builder() {
        // Pin the defaults: fast enough for tests, slow enough not to pin a
        // CPU.
        assert_eq!(DEFAULT_POLL_INTERVAL, Duration::from_secs(1));
        assert_eq!(DEFAULT_LEASE_DURATION, Duration::from_secs(60));
        assert_eq!(DEFAULT_MAX_ATTEMPTS, 3);
        assert_eq!(DEFAULT_DRAIN_TIMEOUT, Duration::from_secs(30));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_wait_or_cancel_complete_when_uncancelled() {
        let token = CancellationToken::new();
        let nap = Duration::from_millis(10);

        let started = Instant::now();
        wait_or_cancel(&token, nap).await;
        let waited = started.elapsed();

        assert!(waited >= nap, "expected ~10ms sleep without cancellation");
        assert!(!token.is_cancelled());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_wait_or_cancel_return_immediately_when_cancelled() {
        let token = CancellationToken::new();
        token.cancel();

        let started = Instant::now();
        // A 60s sleep: the test would stall here if cancellation were ignored.
        wait_or_cancel(&token, Duration::from_secs(60)).await;
        let waited = started.elapsed();

        assert!(
            waited < Duration::from_millis(100),
            "cancellation should wake us nearly instantly"
        );
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_worker_handle_track_shutdown_state() {
        let token = CancellationToken::new();
        let handle = handle_sharing(&token);

        assert!(!handle.is_shutting_down());
        token.cancel();
        assert!(handle.is_shutting_down());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn should_child_token_inherit_cancellation_from_parent() {
        let token = CancellationToken::new();
        let handle = handle_sharing(&token);

        let child = handle.cancellation_token();
        assert!(!child.is_cancelled());
        token.cancel();
        assert!(child.is_cancelled(), "child should observe parent cancellation");
    }
}