1use apalis_core::{
2 error::{AbortError, BoxDynError},
3 layers::{Layer, Service},
4 task::{Parts, status::Status},
5 worker::{
6 context::WorkerContext,
7 ext::ack::{Acknowledge, AcknowledgeLayer},
8 },
9};
10use futures::{
11 FutureExt,
12 future::{BoxFuture, Either},
13};
14use serde::Serialize;
15use ulid::Ulid;
16
17use std::sync::Arc;
18
19use crate::{Error, PgContext, PgPool, PgTask, queries};
20
21#[derive(Debug, Clone)]
30pub struct PgAck {
31 pool: PgPool,
32 lease_token: Option<Arc<str>>,
33}
34
#[cfg(test)]
mod tests {
    // Unit tests for status calculation, acknowledgement validation, error
    // payload truncation and the lock/ack middleware. The pool used throughout
    // is built with `build_unchecked` against 127.0.0.1:1, so tests either fail
    // before any query is issued or observe the resulting connection failure.
    use std::{
        future::{Ready, ready},
        task::{Context, Poll},
    };

    use apalis_core::{
        error::BoxDynError,
        layers::Service,
        task::{Parts, attempt::Attempt, builder::TaskBuilder, status::Status, task_id::TaskId},
        worker::ext::ack::Acknowledge,
    };
    use diesel::{
        PgConnection,
        r2d2::{ConnectionManager, Pool},
    };
    use futures::{executor::block_on, task::noop_waker_ref};
    use lets_expect::{AssertionError, AssertionResult, *};

    use super::*;

    /// Readiness answer the stub inner service gives from `poll_ready`.
    #[derive(Debug, Clone)]
    enum ReadyState {
        Ready,
        Error,
        Pending,
    }

    /// Stub inner service: readiness is fixed by `state`; `call` always returns `Ok(())`.
    #[derive(Debug, Clone)]
    struct ReadyService {
        state: ReadyState,
    }

    impl Service<PgTask<()>> for ReadyService {
        type Response = ();
        type Error = std::io::Error;
        type Future = Ready<Result<(), Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            match self.state {
                ReadyState::Ready => Poll::Ready(Ok(())),
                ReadyState::Error => Poll::Ready(Err(std::io::Error::other("inner failed"))),
                ReadyState::Pending => Poll::Pending,
            }
        }

        fn call(&mut self, _req: PgTask<()>) -> Self::Future {
            ready(Ok(()))
        }
    }

    /// Pool that does not connect at construction time (`build_unchecked`).
    /// Presumably every checkout fails quickly: nothing is expected to listen on
    /// port 1 and the connection timeout is 10 ms.
    fn unchecked_pool() -> PgPool {
        let manager = ConnectionManager::<PgConnection>::new("postgres://127.0.0.1:1/not-used");
        Pool::builder()
            .max_size(1)
            .connection_timeout(std::time::Duration::from_millis(10))
            .build_unchecked(manager)
    }

    /// Fresh random task id.
    fn task_id() -> TaskId<Ulid> {
        TaskId::new(Ulid::new())
    }

    /// Task parts with an id, the given attempt counter and `max_attempts`.
    fn parts_for_ack(attempts: usize, max_attempts: i32) -> Parts<PgContext, Ulid> {
        TaskBuilder::new(())
            .with_task_id(task_id())
            .with_attempt(Attempt::new_with_value(attempts))
            .with_ctx(PgContext::new().with_max_attempts(max_attempts))
            .build()
            .parts
    }

    /// Boxed I/O error carrying `message`.
    fn box_error(message: &'static str) -> BoxDynError {
        std::io::Error::other(message).into()
    }

    /// Runs `PgAck::ack` with an `Ok(())` result after leaving out whichever of
    /// the task id, lock owner, queue and lock timestamp are flagged `false`.
    fn ack_missing_field(
        has_task_id: bool,
        has_lock_by: bool,
        has_queue: bool,
        has_lock_at: bool,
    ) -> Result<(), crate::Error> {
        block_on(async move {
            let mut parts = parts_for_ack(1, 3);
            if !has_task_id {
                parts.task_id = None;
            }
            let mut ctx = parts.ctx.clone();
            if has_lock_by {
                ctx = ctx.with_lock_by(Some("ack-worker".to_owned()));
            }
            if has_queue {
                ctx = ctx.with_queue("ack-queue".to_owned());
            }
            if has_lock_at {
                ctx = ctx.with_lock_at(Some(1_700_000_000));
            }
            parts.ctx = ctx;

            let mut ack = PgAck::new(unchecked_pool());
            let result: Result<(), BoxDynError> = Ok(());
            ack.ack(&result, &parts).await
        })
    }

    /// Byte length of the truncated form of an `input_len`-byte ASCII payload.
    fn truncated_payload_length(input_len: usize) -> usize {
        truncate_error_payload("x".repeat(input_len)).len()
    }

    /// Whether truncating an `input_len`-byte ASCII payload appends the marker.
    fn truncated_payload_marker_present(input_len: usize) -> bool {
        truncate_error_payload("x".repeat(input_len)).ends_with("…[truncated]")
    }

    /// Polls a `LockTaskService` whose inner stub reports `state`.
    fn poll_lock_ready(state: ReadyState) -> Poll<Result<(), BoxDynError>> {
        let mut service = LockTaskService {
            inner: ReadyService { state },
            pool: unchecked_pool(),
        };
        let mut cx = Context::from_waker(noop_waker_ref());
        service.poll_ready(&mut cx)
    }

    /// Debug rendering of a service produced by `LockTaskLayer`.
    fn layered_service_debug() -> String {
        let layer = LockTaskLayer::new(unchecked_pool());
        let service = layer.layer(ReadyService {
            state: ReadyState::Ready,
        });
        format!("{service:?}")
    }

    /// Reports whether `PgMiddleware` installed its acknowledgement layer.
    fn middleware_auto_ack_enabled(auto_ack: bool) -> bool {
        PgMiddleware::new(unchecked_pool(), auto_ack).auto_ack()
    }

    /// Calls a `LockTaskService` with a task that optionally lacks a worker
    /// context and/or a task id.
    async fn lock_service_call_async(
        has_worker: bool,
        has_task_id: bool,
    ) -> Result<(), BoxDynError> {
        let mut task = TaskBuilder::new(())
            .with_ctx(PgContext::new().with_queue("lock-service-unit".to_owned()))
            .build();
        if has_worker {
            task.parts
                .data
                .insert(WorkerContext::new::<()>("lock-service-worker"));
        }
        if has_task_id {
            task.parts.task_id = Some(task_id());
        }

        let mut service = LockTaskService {
            inner: ReadyService {
                state: ReadyState::Ready,
            },
            pool: unchecked_pool(),
        };
        service.call(task).await
    }

    /// Blocking wrapper around [`lock_service_call_async`].
    fn lock_service_call_missing_field(
        has_worker: bool,
        has_task_id: bool,
    ) -> Result<(), BoxDynError> {
        block_on(lock_service_call_async(has_worker, has_task_id))
    }

    /// Assertion: the error is `Error::MissingField(field)`.
    fn missing_field(field: &'static str) -> impl Fn(&crate::Error) -> AssertionResult {
        move |error| match error {
            crate::Error::MissingField(found) if *found == field => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected missing field {field}, got {other:?}"
            )])),
        }
    }

    /// Assertion: the poll resolved to `Ready(Ok(()))`.
    fn poll_ready_ok(result: &Poll<Result<(), BoxDynError>>) -> AssertionResult {
        match result {
            Poll::Ready(Ok(())) => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected ready ok, got {other:?}"
            )])),
        }
    }

    /// Assertion: the poll resolved to `Ready(Err(_))`.
    fn poll_ready_err(result: &Poll<Result<(), BoxDynError>>) -> AssertionResult {
        match result {
            Poll::Ready(Err(_)) => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected ready error, got {other:?}"
            )])),
        }
    }

    /// Assertion: the poll is still `Pending`.
    fn poll_pending(result: &Poll<Result<(), BoxDynError>>) -> AssertionResult {
        match result {
            Poll::Pending => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected pending, got {other:?}"
            )])),
        }
    }

    /// Assertion: the debug output names `LockTaskService` and its `pool` field.
    // `&String` (not `&str`) because lets_expect hands the subject by reference.
    fn debug_mentions_lock_service(result: &String) -> AssertionResult {
        if result.contains("LockTaskService") && result.contains("pool") {
            Ok(())
        } else {
            Err(AssertionError::new(vec![format!(
                "expected lock service debug output, got {result}"
            )]))
        }
    }

    /// Assertion: the boxed error's message contains `expected`.
    fn abort_contains(expected: &'static str) -> impl Fn(&BoxDynError) -> AssertionResult {
        move |error| {
            let message = error.to_string();
            if message.contains(expected) {
                Ok(())
            } else {
                Err(AssertionError::new(vec![format!(
                    "expected abort containing {expected:?}, got {message:?}"
                )]))
            }
        }
    }

    lets_expect! {
        expect(calculate_status(&parts, &result)) {
            let parts = parts_for_ack(attempts, max_attempts);
            let result: Result<(), BoxDynError> = Ok(());
            let attempts = 1;
            let max_attempts = 3;

            when task_succeeds {
                to marks_the_task_done { equal(Status::Done) }
            }

            when task_fails_below_the_attempt_limit {
                let result: Result<(), BoxDynError> = Err(box_error("retry"));
                to marks_the_task_failed { equal(Status::Failed) }
            }

            when task_fails_at_the_attempt_limit {
                let attempts = 3;
                let result: Result<(), BoxDynError> = Err(box_error("exact limit"));
                to kills_the_task { equal(Status::Killed) }
            }

            when task_fails_above_the_attempt_limit {
                let attempts = 4;
                let result: Result<(), BoxDynError> = Err(box_error("above limit"));
                to kills_the_task { equal(Status::Killed) }
            }

            when task_fails_with_a_negative_max_attempts_from_a_corrupt_row {
                // A negative limit cannot convert to usize; it must be treated
                // as "no retries left" instead of retrying forever.
                let max_attempts = -1;
                let attempts = 0;
                let result: Result<(), BoxDynError> = Err(box_error("corrupt row"));
                to kills_the_task_to_avoid_an_infinite_retry { equal(Status::Killed) }
            }

            when task_fails_with_zero_max_attempts_on_the_first_attempt {
                // Zero allowed attempts: even the very first failure is final.
                let max_attempts = 0;
                let attempts = 0;
                let result: Result<(), BoxDynError> = Err(box_error("no retries"));
                to kills_the_task { equal(Status::Killed) }
            }
        }

        expect(poll_lock_ready(state)) {
            let state = ReadyState::Ready;

            when inner_service_is_ready {
                to returns_ready { poll_ready_ok }
            }

            when inner_service_returns_an_error {
                let state = ReadyState::Error;
                to returns_the_error { poll_ready_err }
            }

            when inner_service_is_pending {
                let state = ReadyState::Pending;
                to stays_pending { poll_pending }
            }
        }

        expect(layered_service_debug()) {
            to wraps_the_inner_service_with_the_pool { debug_mentions_lock_service }
        }

        expect(middleware_auto_ack_enabled(auto_ack)) {
            let auto_ack = true;

            when config_enables_auto_ack {
                to installs_the_acknowledgement_layer { equal(true) }
            }

            when config_disables_auto_ack {
                let auto_ack = false;
                to leaves_acknowledgement_to_the_caller { equal(false) }
            }
        }
    }

    lets_expect! {
        expect(ack_missing_field(has_task_id, has_lock_by, has_queue, has_lock_at)) {
            let has_task_id = true;
            let has_lock_by = true;
            let has_queue = true;
            let has_lock_at = true;

            when task_id_is_missing {
                let has_task_id = false;
                to rejects_before_querying_the_database { be_err_and missing_field("task_id") }
            }

            when lock_owner_is_missing {
                let has_lock_by = false;
                to rejects_before_querying_the_database { be_err_and missing_field("lock_by") }
            }

            when queue_is_missing {
                let has_queue = false;
                to rejects_before_querying_the_database { be_err_and missing_field("queue") }
            }

            when lock_timestamp_is_missing {
                let has_lock_at = false;
                to rejects_before_querying_the_database { be_err_and missing_field("lock_at") }
            }
        }

        expect(lock_service_call_missing_field(has_worker, has_task_id)) {
            let has_worker = true;
            let has_task_id = true;

            when worker_context_is_missing {
                let has_worker = false;
                to aborts_before_locking_the_task { be_err_and abort_contains("worker_context") }
            }

            when task_id_is_missing {
                let has_task_id = false;
                to aborts_before_locking_the_task { be_err_and abort_contains("task_id") }
            }
        }

        expect(truncated_payload_length(input_len)) {
            let input_len = 100;

            when payload_is_shorter_than_the_eight_kib_cap {
                to leaves_the_payload_length_unchanged { equal(100) }
            }

            when payload_is_exactly_eight_kib {
                let input_len = 8 * 1024;
                to leaves_the_payload_length_unchanged { equal(8 * 1024) }
            }

            when payload_is_one_byte_above_eight_kib {
                let input_len = 8 * 1024 + 1;
                to truncates_to_eight_kib_plus_the_marker_byte_length {
                    equal(8 * 1024 + "…[truncated]".len())
                }
            }

            when payload_is_far_above_eight_kib {
                let input_len = 64 * 1024;
                to truncates_to_eight_kib_plus_the_marker_byte_length {
                    equal(8 * 1024 + "…[truncated]".len())
                }
            }
        }

        expect(truncated_payload_marker_present(input_len)) {
            let input_len = 100;

            when payload_is_within_budget {
                to does_not_append_a_truncation_marker { equal(false) }
            }

            when payload_overflows_the_budget {
                let input_len = 8 * 1024 + 1;
                to appends_the_truncation_marker { equal(true) }
            }
        }
    }

    #[cfg(feature = "tokio")]
    mod tokio_tests {
        use super::*;
        use serde::{Serialize, Serializer, ser};

        /// Acks a fully populated task whose attempt counter does not fit in an i32.
        async fn ack_with_attempt_overflow() -> Result<(), crate::Error> {
            let mut parts = parts_for_ack(1, 3);
            parts.attempt = Attempt::new_with_value(i32::MAX as usize + 1);
            parts.ctx = parts
                .ctx
                .clone()
                .with_queue("ack-queue".to_owned())
                .with_lock_by(Some("ack-worker".to_owned()))
                .with_lock_at(Some(1_700_000_000));
            let mut ack = PgAck::new(unchecked_pool());
            let result: Result<(), BoxDynError> = Ok(());
            ack.ack(&result, &parts).await
        }

        /// Assertion: `Error::InvalidArgument` mentioning the attempt counter.
        fn invalid_attempt_overflow(error: &crate::Error) -> AssertionResult {
            match error {
                crate::Error::InvalidArgument(msg) if msg.contains("attempt counter") => Ok(()),
                other => Err(AssertionError::new(vec![format!(
                    "expected InvalidArgument citing attempt counter overflow, got {other:?}"
                )])),
            }
        }

        /// Payload whose `Serialize` impl always fails.
        #[derive(Debug)]
        struct PoisonOk;

        impl Serialize for PoisonOk {
            fn serialize<S: Serializer>(&self, _: S) -> Result<S::Ok, S::Error> {
                Err(ser::Error::custom("intentional serialize failure"))
            }
        }

        /// Acks a fully populated task whose `Ok` payload cannot be serialized.
        async fn ack_with_unserializable_result() -> Result<(), crate::Error> {
            let parts: Parts<PgContext, Ulid> = TaskBuilder::new(())
                .with_task_id(task_id())
                .with_attempt(Attempt::new_with_value(1))
                .with_ctx(
                    PgContext::new()
                        .with_max_attempts(3)
                        .with_queue("ack-queue".to_owned())
                        .with_lock_by(Some("ack-worker".to_owned()))
                        .with_lock_at(Some(1_700_000_000)),
                )
                .build()
                .parts;
            let mut ack = PgAck::new(unchecked_pool());
            let result: Result<PoisonOk, BoxDynError> = Ok(PoisonOk);
            ack.ack(&result, &parts).await
        }

        /// Assertion: the error is `Error::Json`.
        fn json_serialize_error(error: &crate::Error) -> AssertionResult {
            match error {
                crate::Error::Json(_) => Ok(()),
                other => Err(AssertionError::new(vec![format!(
                    "expected Error::Json from a failing Serialize impl, got {other:?}"
                )])),
            }
        }

        /// Calls a `LockTaskService` with a task whose stored `lock_by` matches
        /// the current worker and whose `lock_at` is set, i.e. already claimed.
        async fn lock_service_call_preclaimed() -> Result<(), BoxDynError> {
            let mut task = TaskBuilder::new(())
                .with_task_id(task_id())
                .with_ctx(
                    PgContext::new()
                        .with_queue("lock-service-unit".to_owned())
                        .with_lock_by(Some("lock-service-worker".to_owned()))
                        .with_lock_at(Some(1_700_000_000)),
                )
                .build();
            task.parts
                .data
                .insert(WorkerContext::new::<()>("lock-service-worker"));

            let mut service = LockTaskService {
                inner: ReadyService {
                    state: ReadyState::Ready,
                },
                pool: unchecked_pool(),
            };
            service.call(task).await
        }

        /// Assertion: the preclaimed call completed without error.
        fn lock_service_call_preclaimed_succeeds(
            result: &Result<(), BoxDynError>,
        ) -> AssertionResult {
            match result {
                Ok(()) => Ok(()),
                Err(error) => Err(AssertionError::new(vec![format!(
                    "expected the preclaimed branch to bypass lock_task and succeed, got {error}"
                )])),
            }
        }

        lets_expect! { #tokio_test
            expect(lock_service_call_async(true, true).await) {
                when task_has_worker_and_id_but_the_database_is_unavailable {
                    to aborts_with_the_lock_error { be_err_and abort_contains("failed to acquire PostgreSQL connection") }
                }
            }

            expect(ack_with_attempt_overflow().await) {
                when the_attempt_counter_exceeds_i32_max {
                    to surfaces_invalid_argument_before_touching_the_database {
                        be_err_and invalid_attempt_overflow
                    }
                }
            }

            expect(ack_with_unserializable_result().await) {
                when the_jobs_ok_payload_fails_to_serialize {
                    to surfaces_an_error_json_before_touching_the_database {
                        be_err_and json_serialize_error
                    }
                }
            }

            expect(lock_service_call_preclaimed().await) {
                when the_task_already_carries_a_matching_lock_by_and_lock_at {
                    to bypasses_the_sql_lock_task_round_trip_and_completes {
                        lock_service_call_preclaimed_succeeds
                    }
                }
            }
        }
    }
}
580
581impl PgAck {
582 #[must_use]
589 pub fn new(pool: PgPool) -> Self {
590 Self {
591 pool,
592 lease_token: None,
593 }
594 }
595
596 #[must_use]
603 pub fn with_lease_token(pool: PgPool, lease_token: Arc<str>) -> Self {
604 Self {
605 pool,
606 lease_token: Some(lease_token),
607 }
608 }
609}
610
/// Byte budget kept from an error message before it is cut short.
const MAX_ERROR_PAYLOAD_LEN: usize = 8 * 1024;
/// Suffix appended to an error message that was cut short.
const TRUNCATION_MARKER: &str = "…[truncated]";

/// Caps `text` at [`MAX_ERROR_PAYLOAD_LEN`] bytes.
///
/// Text at or under the budget is returned unchanged. Longer text is cut at the
/// last UTF-8 character boundary not past the budget, and then
/// [`TRUNCATION_MARKER`] is appended. The result can therefore exceed the budget
/// by the marker's byte length.
pub(crate) fn truncate_error_payload(mut text: String) -> String {
    if text.len() <= MAX_ERROR_PAYLOAD_LEN {
        return text;
    }
    // Search downward for a char boundary so `truncate` never splits a
    // multi-byte character; index 0 is always a boundary.
    let boundary = (0..=MAX_ERROR_PAYLOAD_LEN)
        .rev()
        .find(|&idx| text.is_char_boundary(idx))
        .unwrap_or(0);
    text.truncate(boundary);
    text.push_str(TRUNCATION_MARKER);
    text
}
633
634impl<Res: Serialize> Acknowledge<Res, PgContext, Ulid> for PgAck {
635 type Error = Error;
636 type Future = BoxFuture<'static, Result<(), Self::Error>>;
637
638 fn ack(
639 &mut self,
640 res: &Result<Res, BoxDynError>,
641 parts: &Parts<PgContext, Ulid>,
642 ) -> Self::Future {
643 let task_id = parts.task_id;
644 let worker_id = parts.ctx.lock_by().clone();
645 let queue = parts.ctx.queue().clone();
646 let lock_at = *parts.ctx.lock_at();
647 let response = serde_json::to_value(
648 res.as_ref()
649 .map_err(|error| truncate_error_payload(error.to_string())),
650 );
651 let status = calculate_status(parts, res);
652 let response = response.map(Some);
661 let attempts_raw = parts.attempt.current();
667 let attempts = i32::try_from(attempts_raw);
668 let pool = self.pool.clone();
669 let lease_token = self.lease_token.clone();
670
671 async move {
672 let attempts = attempts.map_err(|_| {
673 Error::InvalidArgument(format!(
674 "task attempt counter {attempts_raw} exceeds i32::MAX and cannot be stored"
675 ))
676 })?;
677 let started_attempts = attempts.saturating_sub(1);
678 queries::ack_task(
679 pool,
680 queries::AckTaskUpdate {
681 task_id: task_id.ok_or(Error::MissingField("task_id"))?,
682 attempts,
683 started_attempts,
684 result: response?,
685 status,
686 worker_id: worker_id.ok_or(Error::MissingField("lock_by"))?,
687 queue: queue.ok_or(Error::MissingField("queue"))?,
688 lock_at: lock_at.ok_or(Error::MissingField("lock_at"))?,
689 lease_token: lease_token.as_deref().map(str::to_owned),
690 },
691 )
692 .await
693 }
694 .boxed()
695 }
696}
697
698#[must_use]
703pub(crate) fn calculate_status<Res>(
704 parts: &Parts<PgContext, Ulid>,
705 res: &Result<Res, BoxDynError>,
706) -> Status {
707 match res {
708 Ok(_) => Status::Done,
709 Err(_) => match usize::try_from(parts.ctx.max_attempts()) {
710 Ok(max) if max > parts.attempt.current() => Status::Failed,
711 _ => Status::Killed,
712 },
713 }
714}
715
716pub async fn lock_task(pool: &PgPool, task_id: &Ulid, worker_id: &str) -> Result<(), Error> {
730 queries::lock_task(pool.clone(), *task_id, worker_id.to_owned(), None).await
731}
732
733pub async fn lock_task_in_queue(
740 pool: &PgPool,
741 task_id: &Ulid,
742 worker_id: &str,
743 queue: &str,
744) -> Result<(), Error> {
745 queries::lock_task(
746 pool.clone(),
747 *task_id,
748 worker_id.to_owned(),
749 Some(queue.to_owned()),
750 )
751 .await
752}
753
754#[derive(Debug, Clone)]
761pub(crate) struct LockTaskLayer {
762 pool: PgPool,
763}
764
765impl LockTaskLayer {
766 #[must_use]
768 pub(crate) fn new(pool: PgPool) -> Self {
769 Self { pool }
770 }
771}
772
773impl<S> Layer<S> for LockTaskLayer {
774 type Service = LockTaskService<S>;
775
776 fn layer(&self, inner: S) -> Self::Service {
777 LockTaskService {
778 inner,
779 pool: self.pool.clone(),
780 }
781 }
782}
783
784#[derive(Debug, Clone)]
789pub struct PgMiddleware {
790 lock: LockTaskLayer,
791 ack: Option<AcknowledgeLayer<PgAck>>,
792}
793
794impl PgMiddleware {
795 #[must_use]
797 pub fn new(pool: PgPool, auto_ack: bool) -> Self {
798 Self {
799 lock: LockTaskLayer::new(pool.clone()),
800 ack: auto_ack.then(|| AcknowledgeLayer::new(PgAck::new(pool))),
801 }
802 }
803
804 #[must_use]
808 pub fn with_lease_token(pool: PgPool, auto_ack: bool, lease_token: Arc<str>) -> Self {
809 Self {
810 lock: LockTaskLayer::new(pool.clone()),
811 ack: auto_ack
812 .then(|| AcknowledgeLayer::new(PgAck::with_lease_token(pool, lease_token))),
813 }
814 }
815
816 #[must_use]
818 pub fn auto_ack(&self) -> bool {
819 self.ack.is_some()
820 }
821}
822
823impl<S> Layer<S> for PgMiddleware
824where
825 AcknowledgeLayer<PgAck>: Layer<LockTaskService<S>>,
826{
827 type Service = PgMiddlewareService<
828 <AcknowledgeLayer<PgAck> as Layer<LockTaskService<S>>>::Service,
829 LockTaskService<S>,
830 >;
831
832 fn layer(&self, inner: S) -> Self::Service {
833 let locked = LockTaskService {
840 inner,
841 pool: self.lock.pool.clone(),
842 };
843 match &self.ack {
844 Some(ack) => PgMiddlewareService::AutoAck(ack.layer(locked)),
845 None => PgMiddlewareService::ManualAck(locked),
846 }
847 }
848}
849
/// Service produced by [`PgMiddleware`]. It is either the lock service wrapped
/// by the acknowledgement layer, or the bare lock service.
#[derive(Clone, Debug)]
pub enum PgMiddlewareService<AutoAck, ManualAck> {
    /// Lock service wrapped by the acknowledgement layer.
    AutoAck(AutoAck),
    /// Bare lock service; acknowledging is up to the caller.
    ManualAck(ManualAck),
}
858
859impl<Req, AutoAck, ManualAck> Service<Req> for PgMiddlewareService<AutoAck, ManualAck>
860where
861 AutoAck: Service<Req>,
862 ManualAck: Service<Req, Response = AutoAck::Response, Error = AutoAck::Error>,
863{
864 type Response = AutoAck::Response;
865 type Error = AutoAck::Error;
866 type Future = Either<AutoAck::Future, ManualAck::Future>;
867
868 fn poll_ready(
869 &mut self,
870 cx: &mut std::task::Context<'_>,
871 ) -> std::task::Poll<Result<(), Self::Error>> {
872 match self {
873 Self::AutoAck(service) => service.poll_ready(cx),
874 Self::ManualAck(service) => service.poll_ready(cx),
875 }
876 }
877
878 fn call(&mut self, req: Req) -> Self::Future {
879 match self {
880 Self::AutoAck(service) => Either::Left(service.call(req)),
881 Self::ManualAck(service) => Either::Right(service.call(req)),
882 }
883 }
884}
885
886#[derive(Debug, Clone)]
888pub struct LockTaskService<S> {
889 inner: S,
890 pool: PgPool,
891}
892
893impl<S, Args> Service<PgTask<Args>> for LockTaskService<S>
894where
895 S: Service<PgTask<Args>> + Clone + Send + 'static,
896 S::Future: Send + 'static,
897 S::Error: Into<BoxDynError>,
898 Args: Send + 'static,
899{
900 type Response = S::Response;
901 type Error = BoxDynError;
902 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
903
904 fn poll_ready(
905 &mut self,
906 cx: &mut std::task::Context<'_>,
907 ) -> std::task::Poll<Result<(), Self::Error>> {
908 self.inner.poll_ready(cx).map_err(Into::into)
909 }
910
911 fn call(&mut self, req: PgTask<Args>) -> Self::Future {
912 let pool = self.pool.clone();
913 let worker_id = req
914 .parts
915 .data
916 .get::<WorkerContext>()
917 .map(|worker| worker.name().to_owned());
918 let queue = req.parts.ctx.queue().clone();
919 let task_id = req.parts.task_id.map(|id| *id.inner());
920 let preclaimed = matches!(
929 (req.parts.ctx.lock_by().as_deref(), worker_id.as_deref()),
930 (Some(stored), Some(current)) if stored == current
931 ) && req.parts.ctx.lock_at().is_some();
932 let clone = self.inner.clone();
939 let mut ready_inner = std::mem::replace(&mut self.inner, clone);
940
941 async move {
942 let worker_id =
943 worker_id.ok_or_else(|| AbortError::new(Error::MissingField("worker_context")))?;
944 let task_id = task_id.ok_or_else(|| AbortError::new(Error::MissingField("task_id")))?;
945 if !preclaimed {
946 queries::lock_task(pool, task_id, worker_id, queue)
947 .await
948 .map_err(AbortError::new)?;
949 }
950 ready_inner.call(req).await.map_err(Into::into)
951 }
952 .boxed()
953 }
954}