Skip to main content

apalis_diesel_postgres/
ack.rs

1use apalis_core::{
2    error::{AbortError, BoxDynError},
3    layers::{Layer, Service},
4    task::{Parts, status::Status},
5    worker::{
6        context::WorkerContext,
7        ext::ack::{Acknowledge, AcknowledgeLayer},
8    },
9};
10use futures::{
11    FutureExt,
12    future::{BoxFuture, Either},
13};
14use serde::Serialize;
15use ulid::Ulid;
16
17use std::sync::Arc;
18
19use crate::{Error, PgContext, PgPool, PgTask, queries};
20
21/// Acknowledges task completion by updating `apalis.jobs`.
22///
23/// When constructed via [`PgAck::with_lease_token`], the acknowledge SQL is
24/// additionally bound to the worker's `lease_token` so the per-process secret
25/// that already protects heartbeat refreshes (migration `20260521000002`) also
26/// guards ack writes. Callers that hold only `(task_id, queue, worker_id,
27/// lock_at, attempts)` — values that appear in dashboards and admin payloads —
28/// cannot forge an ack without also possessing the token.
29#[derive(Debug, Clone)]
30pub struct PgAck {
31    pool: PgPool,
32    lease_token: Option<Arc<str>>,
33}
34
#[cfg(test)]
mod tests {
    use std::{
        future::{Ready, ready},
        task::{Context, Poll},
    };

    use apalis_core::{
        error::BoxDynError,
        layers::Service,
        task::{Parts, attempt::Attempt, builder::TaskBuilder, status::Status, task_id::TaskId},
        worker::ext::ack::Acknowledge,
    };
    use diesel::{
        PgConnection,
        r2d2::{ConnectionManager, Pool},
    };
    use futures::{executor::block_on, task::noop_waker_ref};
    use lets_expect::{AssertionError, AssertionResult, *};

    use super::*;

    /// Outcome the stub inner service reports from `poll_ready`.
    #[derive(Debug, Clone)]
    enum ReadyState {
        Ready,
        Error,
        Pending,
    }

    /// Stub inner service: `poll_ready` follows `state`, `call` always succeeds.
    #[derive(Debug, Clone)]
    struct ReadyService {
        state: ReadyState,
    }

    impl Service<PgTask<()>> for ReadyService {
        type Response = ();
        type Error = std::io::Error;
        type Future = Ready<Result<(), Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            match self.state {
                ReadyState::Ready => Poll::Ready(Ok(())),
                ReadyState::Error => Poll::Ready(Err(std::io::Error::other("inner failed"))),
                ReadyState::Pending => Poll::Pending,
            }
        }

        fn call(&mut self, _req: PgTask<()>) -> Self::Future {
            ready(Ok(()))
        }
    }

    /// Pool aimed at an unreachable address. It is built unchecked so
    /// construction never connects; any checkout is expected to fail after the
    /// 10 ms connection timeout.
    fn unchecked_pool() -> PgPool {
        let manager = ConnectionManager::<PgConnection>::new("postgres://127.0.0.1:1/not-used");
        Pool::builder()
            .max_size(1)
            .connection_timeout(std::time::Duration::from_millis(10))
            .build_unchecked(manager)
    }

    fn task_id() -> TaskId<Ulid> {
        TaskId::new(Ulid::new())
    }

    /// Task parts with a fresh id, the given attempt counter and `max_attempts`.
    fn parts_for_ack(attempts: usize, max_attempts: i32) -> Parts<PgContext, Ulid> {
        TaskBuilder::new(())
            .with_task_id(task_id())
            .with_attempt(Attempt::new_with_value(attempts))
            .with_ctx(PgContext::new().with_max_attempts(max_attempts))
            .build()
            .parts
    }

    fn box_error(message: &'static str) -> BoxDynError {
        std::io::Error::other(message).into()
    }

    /// Run `PgAck::ack` on an `Ok(())` result with each required field present
    /// only when its flag is `true`.
    fn ack_missing_field(
        has_task_id: bool,
        has_lock_by: bool,
        has_queue: bool,
        has_lock_at: bool,
    ) -> Result<(), crate::Error> {
        block_on(async move {
            let mut parts = parts_for_ack(1, 3);
            if !has_task_id {
                parts.task_id = None;
            }
            let mut ctx = parts.ctx.clone();
            if has_lock_by {
                ctx = ctx.with_lock_by(Some("ack-worker".to_owned()));
            }
            if has_queue {
                ctx = ctx.with_queue("ack-queue".to_owned());
            }
            if has_lock_at {
                ctx = ctx.with_lock_at(Some(1_700_000_000));
            }
            parts.ctx = ctx;

            let mut ack = PgAck::new(unchecked_pool());
            let result: Result<(), BoxDynError> = Ok(());
            ack.ack(&result, &parts).await
        })
    }

    fn truncated_payload_length(input_len: usize) -> usize {
        truncate_error_payload("x".repeat(input_len)).len()
    }

    fn truncated_payload_marker_present(input_len: usize) -> bool {
        truncate_error_payload("x".repeat(input_len)).ends_with("…[truncated]")
    }

    /// Byte length after truncating `char_count` copies of `€`. The character
    /// is three bytes in UTF-8, so the 8 KiB cap (8192) lands mid-codepoint
    /// and the cut has to walk back to byte 8190.
    fn truncated_multibyte_payload_length(char_count: usize) -> usize {
        truncate_error_payload("€".repeat(char_count)).len()
    }

    fn poll_lock_ready(state: ReadyState) -> Poll<Result<(), BoxDynError>> {
        let mut service = LockTaskService {
            inner: ReadyService { state },
            pool: unchecked_pool(),
        };
        let mut cx = Context::from_waker(noop_waker_ref());
        service.poll_ready(&mut cx)
    }

    fn layered_service_debug() -> String {
        let layer = LockTaskLayer::new(unchecked_pool());
        let service = layer.layer(ReadyService {
            state: ReadyState::Ready,
        });
        format!("{service:?}")
    }

    fn middleware_auto_ack_enabled(auto_ack: bool) -> bool {
        PgMiddleware::new(unchecked_pool(), auto_ack).auto_ack()
    }

    fn lease_token_middleware_auto_ack_enabled(auto_ack: bool) -> bool {
        PgMiddleware::with_lease_token(unchecked_pool(), auto_ack, Arc::from("test-lease-token"))
            .auto_ack()
    }

    /// Call `LockTaskService` on a task that may lack a worker context and/or
    /// a task id.
    async fn lock_service_call_async(
        has_worker: bool,
        has_task_id: bool,
    ) -> Result<(), BoxDynError> {
        let mut task = TaskBuilder::new(())
            .with_ctx(PgContext::new().with_queue("lock-service-unit".to_owned()))
            .build();
        if has_worker {
            task.parts
                .data
                .insert(WorkerContext::new::<()>("lock-service-worker"));
        }
        if has_task_id {
            task.parts.task_id = Some(task_id());
        }

        let mut service = LockTaskService {
            inner: ReadyService {
                state: ReadyState::Ready,
            },
            pool: unchecked_pool(),
        };
        service.call(task).await
    }

    fn lock_service_call_missing_field(
        has_worker: bool,
        has_task_id: bool,
    ) -> Result<(), BoxDynError> {
        block_on(lock_service_call_async(has_worker, has_task_id))
    }

    fn missing_field(field: &'static str) -> impl Fn(&crate::Error) -> AssertionResult {
        move |error| match error {
            crate::Error::MissingField(found) if *found == field => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected missing field {field}, got {other:?}"
            )])),
        }
    }

    fn poll_ready_ok(result: &Poll<Result<(), BoxDynError>>) -> AssertionResult {
        match result {
            Poll::Ready(Ok(())) => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected ready ok, got {other:?}"
            )])),
        }
    }

    fn poll_ready_err(result: &Poll<Result<(), BoxDynError>>) -> AssertionResult {
        match result {
            Poll::Ready(Err(_)) => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected ready error, got {other:?}"
            )])),
        }
    }

    fn poll_pending(result: &Poll<Result<(), BoxDynError>>) -> AssertionResult {
        match result {
            Poll::Pending => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected pending, got {other:?}"
            )])),
        }
    }

    fn debug_mentions_lock_service(result: &String) -> AssertionResult {
        if result.contains("LockTaskService") && result.contains("pool") {
            Ok(())
        } else {
            Err(AssertionError::new(vec![format!(
                "expected lock service debug output, got {result}"
            )]))
        }
    }

    fn abort_contains(expected: &'static str) -> impl Fn(&BoxDynError) -> AssertionResult {
        move |error| {
            let message = error.to_string();
            if message.contains(expected) {
                Ok(())
            } else {
                Err(AssertionError::new(vec![format!(
                    "expected abort containing {expected:?}, got {message:?}"
                )]))
            }
        }
    }

    lets_expect! {
        expect(calculate_status(&parts, &result)) {
            let parts = parts_for_ack(attempts, max_attempts);
            let result: Result<(), BoxDynError> = Ok(());
            let attempts = 1;
            let max_attempts = 3;

            when task_succeeds {
                to marks_the_task_done { equal(Status::Done) }
            }

            when task_fails_below_the_attempt_limit {
                let result: Result<(), BoxDynError> = Err(box_error("retry"));
                to marks_the_task_failed { equal(Status::Failed) }
            }

            when task_fails_at_the_attempt_limit {
                let attempts = 3;
                let result: Result<(), BoxDynError> = Err(box_error("exact limit"));
                to kills_the_task { equal(Status::Killed) }
            }

            when task_fails_above_the_attempt_limit {
                let attempts = 4;
                let result: Result<(), BoxDynError> = Err(box_error("above limit"));
                to kills_the_task { equal(Status::Killed) }
            }

            when task_fails_with_a_negative_max_attempts_from_a_corrupt_row {
                // Documents the doc-comment contract on `calculate_status`:
                // negative `max_attempts` (which the schema rejects but a
                // hand-crafted row could carry) is treated as terminal so a
                // corrupt row cannot drive an infinite retry loop.
                // `usize::try_from(-1)` returns Err, falling into the `_ =>
                // Killed` arm.
                let max_attempts = -1;
                let attempts = 0;
                let result: Result<(), BoxDynError> = Err(box_error("corrupt row"));
                to kills_the_task_to_avoid_an_infinite_retry { equal(Status::Killed) }
            }

            when task_fails_with_zero_max_attempts_on_the_first_attempt {
                // Boundary case: max=0 means no retries are allowed. The
                // first failure (attempts >= max=0) must terminate.
                let max_attempts = 0;
                let attempts = 0;
                let result: Result<(), BoxDynError> = Err(box_error("no retries"));
                to kills_the_task { equal(Status::Killed) }
            }
        }

        expect(poll_lock_ready(state)) {
            let state = ReadyState::Ready;

            when inner_service_is_ready {
                to returns_ready { poll_ready_ok }
            }

            when inner_service_returns_an_error {
                let state = ReadyState::Error;
                to returns_the_error { poll_ready_err }
            }

            when inner_service_is_pending {
                let state = ReadyState::Pending;
                to stays_pending { poll_pending }
            }
        }

        expect(layered_service_debug()) {
            to wraps_the_inner_service_with_the_pool { debug_mentions_lock_service }
        }

        expect(middleware_auto_ack_enabled(auto_ack)) {
            let auto_ack = true;

            when config_enables_auto_ack {
                to installs_the_acknowledgement_layer { equal(true) }
            }

            when config_disables_auto_ack {
                let auto_ack = false;
                to leaves_acknowledgement_to_the_caller { equal(false) }
            }
        }

        expect(lease_token_middleware_auto_ack_enabled(auto_ack)) {
            let auto_ack = true;

            when config_enables_auto_ack {
                to installs_the_lease_bound_acknowledgement_layer { equal(true) }
            }

            when config_disables_auto_ack {
                let auto_ack = false;
                to leaves_acknowledgement_to_the_caller { equal(false) }
            }
        }
    }

    lets_expect! {
        expect(ack_missing_field(has_task_id, has_lock_by, has_queue, has_lock_at)) {
            let has_task_id = true;
            let has_lock_by = true;
            let has_queue = true;
            let has_lock_at = true;

            when task_id_is_missing {
                let has_task_id = false;
                to rejects_before_querying_the_database { be_err_and missing_field("task_id") }
            }

            when lock_owner_is_missing {
                let has_lock_by = false;
                to rejects_before_querying_the_database { be_err_and missing_field("lock_by") }
            }

            when queue_is_missing {
                let has_queue = false;
                to rejects_before_querying_the_database { be_err_and missing_field("queue") }
            }

            when lock_timestamp_is_missing {
                let has_lock_at = false;
                to rejects_before_querying_the_database { be_err_and missing_field("lock_at") }
            }
        }

        expect(lock_service_call_missing_field(has_worker, has_task_id)) {
            let has_worker = true;
            let has_task_id = true;

            when worker_context_is_missing {
                let has_worker = false;
                to aborts_before_locking_the_task { be_err_and abort_contains("worker_context") }
            }

            when task_id_is_missing {
                let has_task_id = false;
                to aborts_before_locking_the_task { be_err_and abort_contains("task_id") }
            }
        }

        expect(truncated_payload_length(input_len)) {
            let input_len = 100;

            when payload_is_shorter_than_the_eight_kib_cap {
                to leaves_the_payload_length_unchanged { equal(100) }
            }

            when payload_is_exactly_eight_kib {
                let input_len = 8 * 1024;
                to leaves_the_payload_length_unchanged { equal(8 * 1024) }
            }

            when payload_is_one_byte_above_eight_kib {
                let input_len = 8 * 1024 + 1;
                to truncates_to_eight_kib_plus_the_marker_byte_length {
                    equal(8 * 1024 + "…[truncated]".len())
                }
            }

            when payload_is_far_above_eight_kib {
                let input_len = 64 * 1024;
                to truncates_to_eight_kib_plus_the_marker_byte_length {
                    equal(8 * 1024 + "…[truncated]".len())
                }
            }
        }

        expect(truncated_payload_marker_present(input_len)) {
            let input_len = 100;

            when payload_is_within_budget {
                to does_not_append_a_truncation_marker { equal(false) }
            }

            when payload_overflows_the_budget {
                let input_len = 8 * 1024 + 1;
                to appends_the_truncation_marker { equal(true) }
            }
        }

        expect(truncated_multibyte_payload_length(char_count)) {
            let char_count = 100;

            when payload_of_three_byte_chars_fits_the_budget {
                to leaves_the_payload_length_unchanged { equal(300) }
            }

            when payload_of_three_byte_chars_overflows_the_budget {
                let char_count = 3000;
                to truncates_at_the_last_char_boundary_below_the_cap {
                    equal(8190 + "…[truncated]".len())
                }
            }
        }
    }

    #[cfg(feature = "tokio")]
    mod tokio_tests {
        use super::*;
        use serde::{Serialize, Serializer, ser};

        /// Drive `PgAck::ack` with an oversized attempt counter. The bounds
        /// check on `i32::try_from(attempts_raw)` in `PgAck::ack` returns
        /// `Error::InvalidArgument`; without this branch a saturated cast
        /// would silently mismatch the row's `attempts` column and surface
        /// as a spurious `StaleAcknowledgement`.
        async fn ack_with_attempt_overflow() -> Result<(), crate::Error> {
            let mut parts = parts_for_ack(1, 3);
            // Force an overflow regardless of host pointer width.
            parts.attempt = Attempt::new_with_value(i32::MAX as usize + 1);
            parts.ctx = parts
                .ctx
                .clone()
                .with_queue("ack-queue".to_owned())
                .with_lock_by(Some("ack-worker".to_owned()))
                .with_lock_at(Some(1_700_000_000));
            let mut ack = PgAck::new(unchecked_pool());
            let result: Result<(), BoxDynError> = Ok(());
            ack.ack(&result, &parts).await
        }

        fn invalid_attempt_overflow(error: &crate::Error) -> AssertionResult {
            match error {
                crate::Error::InvalidArgument(msg) if msg.contains("attempt counter") => Ok(()),
                other => Err(AssertionError::new(vec![format!(
                    "expected InvalidArgument citing attempt counter overflow, got {other:?}"
                )])),
            }
        }

        /// Custom type that fails to serialize. It drives the
        /// `serde_json::to_value` failure path in `PgAck::ack`, which the
        /// ack future surfaces as `Error::Json`. Reachable for any job that
        /// returns a custom `Ok` payload with a fallible `Serialize` impl.
        #[derive(Debug)]
        struct PoisonOk;

        impl Serialize for PoisonOk {
            fn serialize<S: Serializer>(&self, _: S) -> Result<S::Ok, S::Error> {
                Err(ser::Error::custom("intentional serialize failure"))
            }
        }

        async fn ack_with_unserializable_result() -> Result<(), crate::Error> {
            // `Parts<PgContext, Ulid>` is not parameterised by the job's result
            // type, so these parts can be acknowledged with a `PoisonOk` result.
            let parts: Parts<PgContext, Ulid> = TaskBuilder::new(())
                .with_task_id(task_id())
                .with_attempt(Attempt::new_with_value(1))
                .with_ctx(
                    PgContext::new()
                        .with_max_attempts(3)
                        .with_queue("ack-queue".to_owned())
                        .with_lock_by(Some("ack-worker".to_owned()))
                        .with_lock_at(Some(1_700_000_000)),
                )
                .build()
                .parts;
            let mut ack = PgAck::new(unchecked_pool());
            let result: Result<PoisonOk, BoxDynError> = Ok(PoisonOk);
            ack.ack(&result, &parts).await
        }

        fn json_serialize_error(error: &crate::Error) -> AssertionResult {
            match error {
                crate::Error::Json(_) => Ok(()),
                other => Err(AssertionError::new(vec![format!(
                    "expected Error::Json from a failing Serialize impl, got {other:?}"
                )])),
            }
        }

        /// Drive `LockTaskService::call` with a task whose `lock_by` already
        /// matches the worker context and whose `lock_at` is populated. The
        /// `preclaimed` branch in `LockTaskService::call` must bypass the SQL
        /// `lock_task` call entirely, so this exercise succeeds even with a
        /// pool that cannot connect.
        async fn lock_service_call_preclaimed() -> Result<(), BoxDynError> {
            let mut task = TaskBuilder::new(())
                .with_task_id(task_id())
                .with_ctx(
                    PgContext::new()
                        .with_queue("lock-service-unit".to_owned())
                        .with_lock_by(Some("lock-service-worker".to_owned()))
                        .with_lock_at(Some(1_700_000_000)),
                )
                .build();
            task.parts
                .data
                .insert(WorkerContext::new::<()>("lock-service-worker"));

            let mut service = LockTaskService {
                inner: ReadyService {
                    state: ReadyState::Ready,
                },
                pool: unchecked_pool(),
            };
            service.call(task).await
        }

        fn lock_service_call_preclaimed_succeeds(
            result: &Result<(), BoxDynError>,
        ) -> AssertionResult {
            match result {
                Ok(()) => Ok(()),
                Err(error) => Err(AssertionError::new(vec![format!(
                    "expected the preclaimed branch to bypass lock_task and succeed, got {error}"
                )])),
            }
        }

        lets_expect! { #tokio_test
            expect(lock_service_call_async(true, true).await) {
                when task_has_worker_and_id_but_the_database_is_unavailable {
                    to aborts_with_the_lock_error { be_err_and abort_contains("failed to acquire PostgreSQL connection") }
                }
            }

            expect(ack_with_attempt_overflow().await) {
                when the_attempt_counter_exceeds_i32_max {
                    to surfaces_invalid_argument_before_touching_the_database {
                        be_err_and invalid_attempt_overflow
                    }
                }
            }

            expect(ack_with_unserializable_result().await) {
                when the_jobs_ok_payload_fails_to_serialize {
                    to surfaces_an_error_json_before_touching_the_database {
                        be_err_and json_serialize_error
                    }
                }
            }

            expect(lock_service_call_preclaimed().await) {
                when the_task_already_carries_a_matching_lock_by_and_lock_at {
                    to bypasses_the_sql_lock_task_round_trip_and_completes {
                        lock_service_call_preclaimed_succeeds
                    }
                }
            }
        }
    }
}
580
581impl PgAck {
582    /// Create a PostgreSQL acknowledger without lease-token binding.
583    ///
584    /// Ack writes are gated only by `(lock_by, lock_at, attempts)`; prefer
585    /// [`PgAck::with_lease_token`] for the defense-in-depth variant that also
586    /// checks the per-process token. This constructor exists for test harnesses
587    /// and admin tooling that do not own a lease token.
588    #[must_use]
589    pub fn new(pool: PgPool) -> Self {
590        Self {
591            pool,
592            lease_token: None,
593        }
594    }
595
596    /// Create a PostgreSQL acknowledger bound to a specific worker lease token.
597    ///
598    /// The token is added to the ack SQL as an `EXISTS` check against
599    /// `apalis.workers.lease_token`, mirroring the heartbeat path. A storage
600    /// handle's `middleware()` wires this automatically; manual callers should
601    /// reuse the token they passed to `initial_heartbeat`/`keep_alive`.
602    #[must_use]
603    pub fn with_lease_token(pool: PgPool, lease_token: Arc<str>) -> Self {
604        Self {
605            pool,
606            lease_token: Some(lease_token),
607        }
608    }
609}
610
// Upper bound, in bytes, on error text persisted into
// `apalis.jobs.last_result` (a JSONB column). Without it, a job whose error
// `Display` yields megabytes of text could grow storage without bound; 8 KiB
// still keeps enough context for diagnosis. Oversized text is cut and tagged
// with `TRUNCATION_MARKER` so readers can tell it is incomplete.
const MAX_ERROR_PAYLOAD_LEN: usize = 8 * 1024;
const TRUNCATION_MARKER: &str = "…[truncated]";

/// Cap `text` at `MAX_ERROR_PAYLOAD_LEN` bytes for storage.
///
/// Text within the budget is returned unchanged. Longer text is cut at the last
/// UTF-8 character boundary at or below the cap, then suffixed with
/// `TRUNCATION_MARKER`. The result therefore exceeds the cap by the marker's
/// byte length, minus up to three bytes when the cap falls inside a multi-byte
/// character.
pub(crate) fn truncate_error_payload(mut text: String) -> String {
    if text.len() <= MAX_ERROR_PAYLOAD_LEN {
        return text;
    }
    // `String::truncate` panics off a char boundary, so search downward from
    // the cap for one. Index 0 is always a boundary, so the search cannot come
    // up empty. `str::floor_char_boundary` would do this directly, but it is
    // only stable since Rust 1.91 and the crate MSRV is 1.88.
    let cut = (0..=MAX_ERROR_PAYLOAD_LEN)
        .rev()
        .find(|&idx| text.is_char_boundary(idx))
        .unwrap_or(0);
    text.truncate(cut);
    text.push_str(TRUNCATION_MARKER);
    text
}
633
634impl<Res: Serialize> Acknowledge<Res, PgContext, Ulid> for PgAck {
635    type Error = Error;
636    type Future = BoxFuture<'static, Result<(), Self::Error>>;
637
638    fn ack(
639        &mut self,
640        res: &Result<Res, BoxDynError>,
641        parts: &Parts<PgContext, Ulid>,
642    ) -> Self::Future {
643        let task_id = parts.task_id;
644        let worker_id = parts.ctx.lock_by().clone();
645        let queue = parts.ctx.queue().clone();
646        let lock_at = *parts.ctx.lock_at();
647        let response = serde_json::to_value(
648            res.as_ref()
649                .map_err(|error| truncate_error_payload(error.to_string())),
650        );
651        let status = calculate_status(parts, res);
652        // `last_result` is always persisted as the externally-tagged
653        // `Result<O, String>` JSON (`{"Ok": ...}` or `{"Err": "..."}`).
654        // `WaitForCompletion::wait_for` reads it back with `serde_json::from_value`
655        // and the spec (queries/mod.rs tests::last_result_is_missing) requires
656        // a SQL NULL to surface as `MissingField("last_result")`. So we wrap
657        // every serialized value in `Some(...)` rather than collapsing the
658        // trivial-Ok case to SQL NULL, which would make completed Ok(()) jobs
659        // appear unread to consumers of `WaitForCompletion`.
660        let response = response.map(Some);
661        // Silent saturation would corrupt the ack's lock-check predicate:
662        // `ack_task` matches on `attempts = $started_attempts`, so a capped
663        // value would silently mismatch the stored row and the ack would be
664        // reported as `StaleAcknowledgement` for a non-stale task. Surface
665        // overflow as `InvalidArgument` instead.
666        let attempts_raw = parts.attempt.current();
667        let attempts = i32::try_from(attempts_raw);
668        let pool = self.pool.clone();
669        let lease_token = self.lease_token.clone();
670
671        async move {
672            let attempts = attempts.map_err(|_| {
673                Error::InvalidArgument(format!(
674                    "task attempt counter {attempts_raw} exceeds i32::MAX and cannot be stored"
675                ))
676            })?;
677            let started_attempts = attempts.saturating_sub(1);
678            queries::ack_task(
679                pool,
680                queries::AckTaskUpdate {
681                    task_id: task_id.ok_or(Error::MissingField("task_id"))?,
682                    attempts,
683                    started_attempts,
684                    result: response?,
685                    status,
686                    worker_id: worker_id.ok_or(Error::MissingField("lock_by"))?,
687                    queue: queue.ok_or(Error::MissingField("queue"))?,
688                    lock_at: lock_at.ok_or(Error::MissingField("lock_at"))?,
689                    lease_token: lease_token.as_deref().map(str::to_owned),
690                },
691            )
692            .await
693        }
694        .boxed()
695    }
696}
697
698/// Calculate the persisted task status from a task execution result.
699///
700/// Negative `max_attempts` values (which the database schema rejects) are
701/// treated as terminal so a corrupt row cannot drive an infinite retry loop.
702#[must_use]
703pub(crate) fn calculate_status<Res>(
704    parts: &Parts<PgContext, Ulid>,
705    res: &Result<Res, BoxDynError>,
706) -> Status {
707    match res {
708        Ok(_) => Status::Done,
709        Err(_) => match usize::try_from(parts.ctx.max_attempts()) {
710            Ok(max) if max > parts.attempt.current() => Status::Failed,
711            _ => Status::Killed,
712        },
713    }
714}
715
716/// Lock a due task for a worker.
717///
718/// The worker must already be registered for the task queue. The task must be
719/// due and in a lockable state: `Pending`, retryable `Failed`, or `Queued` by
720/// the same worker.
721///
722/// # Cross-queue semantics
723///
724/// This entry point does **not** filter by `job_type`: a caller holding a
725/// task's `Ulid` can lock it regardless of which queue it belongs to. Prefer
726/// [`lock_task_in_queue`] which scopes the lock to a specific queue and
727/// prevents a caller that learned a `Ulid` from logs or dashboards from
728/// claiming it under an unrelated queue.
729pub async fn lock_task(pool: &PgPool, task_id: &Ulid, worker_id: &str) -> Result<(), Error> {
730    queries::lock_task(pool.clone(), *task_id, worker_id.to_owned(), None).await
731}
732
733/// Lock a due task scoped to a specific queue.
734///
735/// Like [`lock_task`] but restricts the lock to `queue` so admin tooling that
736/// knows the task's `Ulid` cannot accidentally (or maliciously) lock a task
737/// belonging to another queue. Use this in any code path that does not derive
738/// the queue from a trusted `WorkerContext`.
739pub async fn lock_task_in_queue(
740    pool: &PgPool,
741    task_id: &Ulid,
742    worker_id: &str,
743    queue: &str,
744) -> Result<(), Error> {
745    queries::lock_task(
746        pool.clone(),
747        *task_id,
748        worker_id.to_owned(),
749        Some(queue.to_owned()),
750    )
751    .await
752}
753
754/// Middleware layer that transitions queued jobs to `Running` before execution.
755///
756/// Crate-private: external callers use [`PgMiddleware`], which composes this
757/// layer with the optional auto-ack layer. Exposed only to the crate so the
758/// `Layer<S>` impl on `PgMiddleware` can reference its `Service` type without
759/// leaking via a public trait bound.
760#[derive(Debug, Clone)]
761pub(crate) struct LockTaskLayer {
762    pool: PgPool,
763}
764
765impl LockTaskLayer {
766    /// Create a lock middleware layer.
767    #[must_use]
768    pub(crate) fn new(pool: PgPool) -> Self {
769        Self { pool }
770    }
771}
772
773impl<S> Layer<S> for LockTaskLayer {
774    type Service = LockTaskService<S>;
775
776    fn layer(&self, inner: S) -> Self::Service {
777        LockTaskService {
778            inner,
779            pool: self.pool.clone(),
780        }
781    }
782}
783
784/// Middleware layer used by the PostgreSQL backend.
785///
786/// The lock step always runs. The acknowledge step is installed only when the
787/// queue config has automatic acknowledgement enabled.
788#[derive(Debug, Clone)]
789pub struct PgMiddleware {
790    lock: LockTaskLayer,
791    ack: Option<AcknowledgeLayer<PgAck>>,
792}
793
794impl PgMiddleware {
795    /// Create the PostgreSQL backend middleware.
796    #[must_use]
797    pub fn new(pool: PgPool, auto_ack: bool) -> Self {
798        Self {
799            lock: LockTaskLayer::new(pool.clone()),
800            ack: auto_ack.then(|| AcknowledgeLayer::new(PgAck::new(pool))),
801        }
802    }
803
804    /// Create the PostgreSQL backend middleware with lease-token binding for
805    /// the auto-ack path. Used by [`crate::PostgresStorage`] so completed jobs
806    /// can only be acknowledged by a worker possessing the per-storage token.
807    #[must_use]
808    pub fn with_lease_token(pool: PgPool, auto_ack: bool, lease_token: Arc<str>) -> Self {
809        Self {
810            lock: LockTaskLayer::new(pool.clone()),
811            ack: auto_ack
812                .then(|| AcknowledgeLayer::new(PgAck::with_lease_token(pool, lease_token))),
813        }
814    }
815
816    /// Return whether this middleware will acknowledge tasks after execution.
817    #[must_use]
818    pub fn auto_ack(&self) -> bool {
819        self.ack.is_some()
820    }
821}
822
823impl<S> Layer<S> for PgMiddleware
824where
825    AcknowledgeLayer<PgAck>: Layer<LockTaskService<S>>,
826{
827    type Service = PgMiddlewareService<
828        <AcknowledgeLayer<PgAck> as Layer<LockTaskService<S>>>::Service,
829        LockTaskService<S>,
830    >;
831
832    fn layer(&self, inner: S) -> Self::Service {
833        // Construct `LockTaskService` directly rather than going through
834        // `LockTaskLayer::layer` so the where-bound on this public `Layer<S>`
835        // impl does not reference the crate-private `LockTaskLayer` type
836        // (which would trigger E0446). The `lock` field is kept on
837        // `PgMiddleware` to centralise pool ownership and avoid duplicating
838        // the constructor's pool-clone logic.
839        let locked = LockTaskService {
840            inner,
841            pool: self.lock.pool.clone(),
842        };
843        match &self.ack {
844            Some(ack) => PgMiddlewareService::AutoAck(ack.layer(locked)),
845            None => PgMiddlewareService::ManualAck(locked),
846        }
847    }
848}
849
/// Service produced by [`PgMiddleware`].
///
/// As built by `PgMiddleware`, `AutoAck` is the acknowledge layer's service
/// wrapped around a [`LockTaskService`], and `ManualAck` is the bare
/// `LockTaskService`. Both variants are driven through a single [`Service`]
/// impl whose future is an [`Either`] of the two inner futures.
#[derive(Debug, Clone)]
pub enum PgMiddlewareService<AutoAck, ManualAck> {
    /// Lock tasks and acknowledge them automatically.
    AutoAck(AutoAck),
    /// Lock tasks only, leaving acknowledgement to the caller.
    ManualAck(ManualAck),
}
858
859impl<Req, AutoAck, ManualAck> Service<Req> for PgMiddlewareService<AutoAck, ManualAck>
860where
861    AutoAck: Service<Req>,
862    ManualAck: Service<Req, Response = AutoAck::Response, Error = AutoAck::Error>,
863{
864    type Response = AutoAck::Response;
865    type Error = AutoAck::Error;
866    type Future = Either<AutoAck::Future, ManualAck::Future>;
867
868    fn poll_ready(
869        &mut self,
870        cx: &mut std::task::Context<'_>,
871    ) -> std::task::Poll<Result<(), Self::Error>> {
872        match self {
873            Self::AutoAck(service) => service.poll_ready(cx),
874            Self::ManualAck(service) => service.poll_ready(cx),
875        }
876    }
877
878    fn call(&mut self, req: Req) -> Self::Future {
879        match self {
880            Self::AutoAck(service) => Either::Left(service.call(req)),
881            Self::ManualAck(service) => Either::Right(service.call(req)),
882        }
883    }
884}
885
/// Middleware service that ensures a task is locked to the current worker
/// before handing it to the wrapped service.
///
/// Produced by [`PgMiddleware`]'s `Layer` impl (and by the crate-internal
/// `LockTaskLayer`). The `lock_task` query is skipped when the task's context
/// already records this worker in `lock_by` and has `lock_at` set. Otherwise
/// the query is issued before `inner` is called.
#[derive(Debug, Clone)]
pub struct LockTaskService<S> {
    /// Wrapped service that executes the task after the lock step.
    inner: S,
    /// Pool used for the `lock_task` query.
    pool: PgPool,
}
892
893impl<S, Args> Service<PgTask<Args>> for LockTaskService<S>
894where
895    S: Service<PgTask<Args>> + Clone + Send + 'static,
896    S::Future: Send + 'static,
897    S::Error: Into<BoxDynError>,
898    Args: Send + 'static,
899{
900    type Response = S::Response;
901    type Error = BoxDynError;
902    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
903
904    fn poll_ready(
905        &mut self,
906        cx: &mut std::task::Context<'_>,
907    ) -> std::task::Poll<Result<(), Self::Error>> {
908        self.inner.poll_ready(cx).map_err(Into::into)
909    }
910
911    fn call(&mut self, req: PgTask<Args>) -> Self::Future {
912        let pool = self.pool.clone();
913        let worker_id = req
914            .parts
915            .data
916            .get::<WorkerContext>()
917            .map(|worker| worker.name().to_owned());
918        let queue = req.parts.ctx.queue().clone();
919        let task_id = req.parts.task_id.map(|id| *id.inner());
920        // Skip the lock_task round-trip for tasks that the fetcher already
921        // transitioned to `Running` and locked to this worker (`fetch_next`
922        // and `queue_by_id` set both `lock_by` and `lock_at` in the dequeue
923        // UPDATE). In that case the SQL `lock_task` would only rewrite the
924        // same values, paying a full per-job round-trip + HOT-tuple write
925        // for nothing. External `lock_task` callers (and any future fetcher
926        // that does not pre-lock) still go through the SQL path because they
927        // arrive without `lock_by`/`lock_at` populated in the context.
928        let preclaimed = matches!(
929            (req.parts.ctx.lock_by().as_deref(), worker_id.as_deref()),
930            (Some(stored), Some(current)) if stored == current
931        ) && req.parts.ctx.lock_at().is_some();
932        // Tower service contract: `poll_ready` reserves capacity on
933        // `self.inner`; that exact instance MUST be the one that consumes the
934        // reservation via `call`. Take ownership of the ready instance and
935        // leave a clone behind so subsequent `poll_ready`/`call` cycles work.
936        // The clone is treated as a fresh, not-yet-ready instance — the caller
937        // will `poll_ready` it again before sending the next request.
938        let clone = self.inner.clone();
939        let mut ready_inner = std::mem::replace(&mut self.inner, clone);
940
941        async move {
942            let worker_id =
943                worker_id.ok_or_else(|| AbortError::new(Error::MissingField("worker_context")))?;
944            let task_id = task_id.ok_or_else(|| AbortError::new(Error::MissingField("task_id")))?;
945            if !preclaimed {
946                queries::lock_task(pool, task_id, worker_id, queue)
947                    .await
948                    .map_err(AbortError::new)?;
949            }
950            ready_inner.call(req).await.map_err(Into::into)
951        }
952        .boxed()
953    }
954}