1use std::collections::BTreeSet;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::Poll;
8
9use serde::Serialize;
10use serde::de::DeserializeOwned;
11use tracing::{error, info, warn};
12
13use crate::activity::{ActivityRegistry, HandlerFuture};
14use crate::config::WorkerConfig;
15use crate::context::ActivityContext;
16use crate::error::WorkerError;
17use crate::protocol::reconnect::{
18 ReconnectBackoff, UnackedResultTracker, re_report_unacked, reconnect_with_backoff,
19 register_connected_session,
20};
21use crate::protocol::{GrpcWorkerSession, WorkerSession};
22use crate::runtime::{
23 NoShutdown, ServeEnd, SessionHealth, serve_activity_tasks, serve_activity_tasks_until,
24};
25
26#[must_use]
28pub struct WorkerBuilder {
29 config: WorkerConfig,
30 activities: ActivityRegistry,
31}
32
33impl WorkerBuilder {
34 pub fn new(config: WorkerConfig) -> Self {
36 Self {
37 config,
38 activities: ActivityRegistry::new(),
39 }
40 }
41
42 pub fn register_activity<Input, Output, Handler>(
48 mut self,
49 activity_type: impl Into<String>,
50 handler: Handler,
51 ) -> Result<Self, WorkerError>
52 where
53 Input: Serialize + DeserializeOwned + Send + Sync + 'static,
54 Output: Serialize + Send + Sync + 'static,
55 Handler: for<'context> Fn(Input, &'context ActivityContext) -> HandlerFuture<'context, Output>
56 + Send
57 + Sync
58 + 'static,
59 {
60 self.activities = self.activities.register_activity(activity_type, handler)?;
61 Ok(self)
62 }
63
64 pub fn build(self) -> Result<Worker, WorkerError> {
70 if self.activities.is_empty() {
71 return Err(WorkerError::registration(EmptyActivitySet));
72 }
73 let available_handlers = self.activities.activity_types();
74 let activity_types = available_handlers.iter().cloned().collect();
75 Ok(Worker {
76 config: self.config,
77 activity_types,
78 available_handlers,
79 activities: Arc::new(self.activities),
80 })
81 }
82}
83
/// A configured worker ready to connect and serve activity tasks.
///
/// Produced by [`WorkerBuilder::build`], which rejects an empty activity set.
#[must_use]
pub struct Worker {
    /// Connection and reconnect settings used for every session.
    config: WorkerConfig,
    /// Activity type names sent at registration; collected from
    /// `available_handlers`, so they follow its (sorted) iteration order.
    activity_types: Vec<String>,
    /// Names of every activity type that has a registered handler.
    available_handlers: BTreeSet<String>,
    /// Registered handlers, shared (via `Arc::clone`) with each session's serve loop.
    activities: Arc<ActivityRegistry>,
}

impl Worker {
    /// Shorthand for [`WorkerBuilder::new`].
    pub fn builder(config: WorkerConfig) -> WorkerBuilder {
        WorkerBuilder::new(config)
    }

    /// Activity type names this worker registers with the server.
    #[must_use]
    pub fn activity_types(&self) -> &[String] {
        &self.activity_types
    }

    /// The set of activity types that have a registered handler.
    #[must_use]
    pub fn available_handlers(&self) -> &BTreeSet<String> {
        &self.available_handlers
    }

    /// Connects over gRPC and serves tasks with no external shutdown signal.
    ///
    /// Equivalent to [`Worker::run_until`] with a shutdown future that never
    /// completes, so the run only ends when the reconnect loop terminates.
    ///
    /// # Errors
    ///
    /// See [`Worker::run_with_connector_until`].
    pub async fn run(self) -> Result<(), WorkerError> {
        self.run_until(std::future::pending::<()>()).await
    }

    /// Connects with [`GrpcWorkerSession::connect`] and serves tasks until
    /// `shutdown` completes.
    ///
    /// Each (re)connect attempt builds a new gRPC session from a clone of the
    /// worker's config.
    ///
    /// # Errors
    ///
    /// See [`Worker::run_with_connector_until`].
    pub async fn run_until<Shutdown>(self, shutdown: Shutdown) -> Result<(), WorkerError>
    where
        Shutdown: Future<Output = ()> + Send,
    {
        let config = self.config.clone();
        self.run_with_connector_until(move || GrpcWorkerSession::connect(config.clone()), shutdown)
            .await
    }

    /// Runs the connect -> replay -> serve -> reconnect loop using `connect`
    /// to open sessions.
    ///
    /// Each iteration:
    /// 1. establishes and registers a session via `reconnect_with_backoff`;
    /// 2. re-reports results still held by the unacked-result tracker;
    /// 3. serves activity tasks until the session ends or `shutdown` fires;
    /// 4. classifies the outcome and, when it is a reconnectable drop, sleeps
    ///    for a backoff delay before looping.
    ///
    /// Drains reconnect after `reconnect.initial_backoff` without touching the
    /// drop budget. Other drops increment `drop_failures` and fail the run
    /// once it reaches `backoff.attempts()`. The counter is reset when a
    /// session reported at least one task or stayed connected longer than
    /// `backoff.max_delay()`.
    ///
    /// # Errors
    ///
    /// - `ReconnectBackoff::from_config` rejects the configuration;
    /// - `reconnect_with_backoff` returns an error;
    /// - a session fails with a non-retryable error;
    /// - the drop budget is exhausted (the last failure, or
    ///   `WorkerError::CleanCloseExhausted` for clean closes);
    /// - `shutdown` fires while recovering from a failed session, in which case
    ///   that session's failure is returned.
    pub async fn run_with_connector_until<S, F, Fut, Shutdown>(
        self,
        mut connect: F,
        shutdown: Shutdown,
    ) -> Result<(), WorkerError>
    where
        S: WorkerSession,
        F: FnMut() -> Fut,
        Fut: Future<Output = Result<S, WorkerError>>,
        Shutdown: Future<Output = ()> + Send,
    {
        let backoff = ReconnectBackoff::from_config(&self.config)?;
        // Created once, outside the loop, so results recorded on one session can
        // be replayed on the next by `re_report_unacked`.
        let mut tracker = UnackedResultTracker::new();
        tokio::pin!(shutdown);
        // Lets the single shutdown future be awaited in every phase below and
        // remembers (via `fired`) that it has completed.
        let mut shutdown = SharedShutdown::new(shutdown);
        // Consecutive non-drain drops since the budget was last reset.
        let mut drop_failures = 0_usize;
        // Failure from the previous dropped session; returned if shutdown wins
        // while reconnecting.
        let mut recovery_error: Option<WorkerError> = None;

        loop {
            // `biased` polls the shutdown branch first, so an already-fired
            // shutdown wins over starting another connect.
            let connected = tokio::select! {
                biased;
                () = shutdown.wait() => {
                    return recovery_error.take().map_or(Ok(()), Err);
                }
                result = reconnect_with_backoff(
                    &self.config,
                    self.activity_types.clone(),
                    &self.available_handlers,
                    &mut connect,
                ) => result,
            };
            let mut session = connected?;
            let session_started = tokio::time::Instant::now();
            let mut health = SessionHealth::default();
            // Replay results left unacknowledged by earlier sessions before
            // accepting new tasks; shutdown may interrupt the replay.
            let replay = tokio::select! {
                biased;
                () = shutdown.wait() => None,
                result = re_report_unacked(&tracker, &mut session) => Some(result),
            };
            // Shutdown during replay ends the run with `Ok(())`.
            // NOTE(review): unlike the reconnect phase above, `recovery_error`
            // is not surfaced here -- confirm that asymmetry is intended.
            let Some(replay_result) = replay else {
                return Ok(());
            };
            // A failed replay skips serving and is classified like a serve error.
            let served = match replay_result {
                Ok(()) => {
                    serve_activity_tasks_until(
                        &self.config,
                        &mut session,
                        Arc::clone(&self.activities),
                        &mut tracker,
                        &mut health,
                        shutdown.wait(),
                    )
                    .await
                }
                Err(report_error) => Err(report_error),
            };
            // Release the session before classifying and backing off.
            drop(session);
            let cause = match classify_serve_outcome(served, &health, shutdown.fired()) {
                ServeClassification::End(result) => return result,
                ServeClassification::Drop(cause) => cause,
            };
            // Session lifetime: up to the recorded stream end, or now if the
            // serve loop did not record one.
            let connected_for = health
                .stream_ended_at
                .unwrap_or_else(tokio::time::Instant::now)
                .saturating_duration_since(session_started);
            let proved_healthy = health.tasks_reported > 0 || connected_for > backoff.max_delay();
            if proved_healthy && drop_failures > 0 {
                info!(
                    drop_failures,
                    tasks_reported = health.tasks_reported,
                    "worker session proved healthy; drop budget reset"
                );
                drop_failures = 0;
            }
            // Drains reconnect at the initial backoff without consuming budget;
            // every other drop counts toward exhaustion.
            let delay = if matches!(cause, DropCause::Drain) {
                self.config.reconnect.initial_backoff
            } else {
                drop_failures += 1;
                if drop_failures >= backoff.attempts() {
                    let error = cause.into_exhaustion_error();
                    error!(
                        drop_failures,
                        error = %error,
                        "worker session drop budget exhausted; not reconnecting"
                    );
                    return Err(error);
                }
                backoff.delay_for_attempt(drop_failures)
            };
            warn!(
                drop_failures,
                delay_ms = delay.as_millis(),
                cause = %cause,
                "worker session dropped; reconnecting after backoff"
            );
            // Shutdown during the backoff sleep ends the run; a failure cause
            // is surfaced as the run's error.
            let shutdown_won = tokio::select! {
                biased;
                () = shutdown.wait() => true,
                () = tokio::time::sleep(delay) => false,
            };
            if shutdown_won {
                return cause.into_shutdown_result();
            }
            recovery_error = cause.into_recovery_error();
        }
    }

    /// Registers an already-connected `session`, serves tasks on it, and
    /// returns the session once serving finishes.
    ///
    /// Uses a never-completing shutdown future. There is no reconnect loop.
    ///
    /// # Errors
    ///
    /// See [`Worker::run_with_session_until`].
    pub async fn run_with_session<S>(self, session: S) -> Result<S, WorkerError>
    where
        S: WorkerSession,
    {
        self.run_with_session_until(session, std::future::pending::<()>())
            .await
    }

    /// Registers an already-connected `session` and serves tasks until
    /// `shutdown` fires or serving otherwise finishes, then returns the session.
    ///
    /// A fresh unacked-result tracker and health record are used and then
    /// discarded, so nothing is replayed and there is no reconnect.
    ///
    /// # Errors
    ///
    /// Returns registration errors from `register_connected_session` and any
    /// error from `serve_activity_tasks_until`, unchanged.
    pub async fn run_with_session_until<S, Shutdown>(
        self,
        session: S,
        shutdown: Shutdown,
    ) -> Result<S, WorkerError>
    where
        S: WorkerSession,
        Shutdown: Future<Output = ()> + Send,
    {
        let mut session = register_connected_session(
            session,
            &self.config,
            self.activity_types.clone(),
            &self.available_handlers,
        )
        .await?;
        let mut tracker = UnackedResultTracker::new();
        let mut health = SessionHealth::default();
        serve_activity_tasks_until(
            &self.config,
            &mut session,
            self.activities,
            &mut tracker,
            &mut health,
            shutdown,
        )
        .await?;
        Ok(session)
    }
}
364
365enum ServeClassification {
367 End(Result<(), WorkerError>),
369 Drop(DropCause),
371}
372
373fn classify_serve_outcome(
380 served: Result<ServeEnd, WorkerError>,
381 health: &SessionHealth,
382 shutdown_fired: bool,
383) -> ServeClassification {
384 match served {
385 Ok(ServeEnd::Shutdown) => ServeClassification::End(Ok(())),
386 Ok(ServeEnd::Drained) => {
387 if shutdown_fired {
388 return ServeClassification::End(Ok(()));
389 }
390 ServeClassification::Drop(DropCause::Drain)
391 }
392 Ok(ServeEnd::StreamClosed) => {
393 if shutdown_fired {
394 return ServeClassification::End(Ok(()));
395 }
396 ServeClassification::Drop(DropCause::CleanClose)
397 }
398 Err(error) if !error.is_retryable() => {
399 error!(error = %error, "worker session denied by server; not reconnecting");
400 ServeClassification::End(Err(error))
401 }
402 Err(error) if health.drain_received => {
403 warn!(
407 error = %error,
408 "session error after server drain; classified as drain drop"
409 );
410 if shutdown_fired {
411 return ServeClassification::End(Ok(()));
412 }
413 ServeClassification::Drop(DropCause::Drain)
414 }
415 Err(error) => {
416 if shutdown_fired {
417 return ServeClassification::End(Err(error));
418 }
419 ServeClassification::Drop(DropCause::Failure(error))
420 }
421 }
422}
423
424enum DropCause {
426 Failure(WorkerError),
428 CleanClose,
430 Drain,
433}
434
435impl DropCause {
436 fn into_exhaustion_error(self) -> WorkerError {
442 match self {
443 Self::Failure(error) => error,
444 Self::CleanClose | Self::Drain => WorkerError::CleanCloseExhausted,
445 }
446 }
447
448 fn into_shutdown_result(self) -> Result<(), WorkerError> {
451 match self {
452 Self::Failure(error) => Err(error),
453 Self::CleanClose | Self::Drain => Ok(()),
454 }
455 }
456
457 fn into_recovery_error(self) -> Option<WorkerError> {
459 match self {
460 Self::Failure(error) => Some(error),
461 Self::CleanClose | Self::Drain => None,
462 }
463 }
464}
465
466impl std::fmt::Display for DropCause {
467 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
468 match self {
469 Self::Failure(error) => write!(formatter, "{error}"),
470 Self::CleanClose => write!(formatter, "server closed the worker stream cleanly"),
471 Self::Drain => write!(formatter, "server drained the worker stream"),
472 }
473 }
474}
475
476struct SharedShutdown<'a, S> {
485 inner: Pin<&'a mut S>,
486 fired: bool,
487}
488
489impl<'a, S> SharedShutdown<'a, S>
490where
491 S: Future<Output = ()> + Send,
492{
493 const fn new(inner: Pin<&'a mut S>) -> Self {
494 Self {
495 inner,
496 fired: false,
497 }
498 }
499
500 const fn fired(&self) -> bool {
502 self.fired
503 }
504
505 fn wait(&mut self) -> impl Future<Output = ()> + Send {
507 std::future::poll_fn(|context| {
508 if self.fired {
509 return Poll::Ready(());
510 }
511 match self.inner.as_mut().poll(context) {
512 Poll::Ready(()) => {
513 self.fired = true;
514 Poll::Ready(())
515 }
516 Poll::Pending => Poll::Pending,
517 }
518 })
519 }
520}
521
522pub async fn run_worker_with_session<S>(worker: Worker, session: S) -> Result<S, WorkerError>
528where
529 S: WorkerSession,
530{
531 worker.run_with_session(session).await
532}
533
534#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
536#[error("worker must register at least one activity handler")]
537pub struct EmptyActivitySet;
538
539fn _assert_live_session_type() {
540 let _ = std::mem::size_of::<GrpcWorkerSession>();
541 let _ = std::mem::size_of::<NoShutdown>();
542 let _ = serve_activity_tasks::<GrpcWorkerSession, ActivityRegistry>;
543}
544
545#[cfg(test)]
546mod tests {
547 use std::collections::BTreeSet;
548 use std::sync::Arc;
549 use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
550 use std::time::Duration;
551
552 use aion_core::{ActivityError, ActivityId, ContentType, Payload, WorkflowId};
553 use aion_proto::{ProtoActivityId, ProtoActivityTask, ProtoPayload, ProtoWorkflowId};
554 use async_trait::async_trait;
555 use futures::StreamExt as _;
556 use futures::stream;
557 use serde::{Deserialize, Serialize};
558 use tokio::sync::{Notify, mpsc};
559
560 use super::{Worker, WorkerBuilder};
561 use crate::config::{ReconnectConfig, WorkerConfig};
562 use crate::context::ActivityContext;
563 use crate::error::WorkerError;
564 use crate::protocol::{
565 WorkerSession, WorkerSessionEvent, WorkerTaskStream, validate_activity_handlers,
566 };
567
568 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
569 struct TestInput {
570 value: i32,
571 }
572
573 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
574 struct TestOutput {
575 value: i32,
576 }
577
578 struct ChannelSession {
579 receiver: Option<mpsc::Receiver<Result<WorkerSessionEvent, WorkerError>>>,
580 reports: Vec<RecordedReport>,
581 registered: Vec<String>,
582 }
583
584 #[derive(Clone, Debug, PartialEq, Eq)]
585 enum RecordedReport {
586 Completed(WorkflowId, ActivityId, Payload),
587 Failed(WorkflowId, ActivityId, ActivityError),
588 }
589
590 #[async_trait]
591 impl WorkerSession for ChannelSession {
592 async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
593 drop(config.clone());
594 Ok(())
595 }
596
597 async fn register(
598 &mut self,
599 activity_types: Vec<String>,
600 available_handlers: &BTreeSet<String>,
601 ) -> Result<(), WorkerError> {
602 validate_activity_handlers(&activity_types, available_handlers)?;
603 self.registered = activity_types;
604 Ok(())
605 }
606
607 fn receive_tasks(&mut self) -> WorkerTaskStream {
608 match self.receiver.take() {
609 Some(receiver) => Box::pin(tokio_stream::wrappers::ReceiverStream::new(receiver)),
610 None => Box::pin(stream::empty()),
611 }
612 }
613
614 async fn report_result(
615 &mut self,
616 workflow_id: WorkflowId,
617 activity_id: ActivityId,
618 result: Payload,
619 ) -> Result<(), WorkerError> {
620 self.reports
621 .push(RecordedReport::Completed(workflow_id, activity_id, result));
622 Ok(())
623 }
624
625 async fn report_failure(
626 &mut self,
627 workflow_id: WorkflowId,
628 activity_id: ActivityId,
629 failure: ActivityError,
630 ) -> Result<(), WorkerError> {
631 self.reports
632 .push(RecordedReport::Failed(workflow_id, activity_id, failure));
633 Ok(())
634 }
635
636 async fn send_heartbeat(
637 &mut self,
638 workflow_id: WorkflowId,
639 activity_id: ActivityId,
640 progress: Option<Payload>,
641 ) -> Result<(), WorkerError> {
642 drop((workflow_id, activity_id, progress));
643 Ok(())
644 }
645 }
646
647 struct HungReportSession {
650 log: mpsc::UnboundedSender<SessionLog>,
651 index: usize,
652 }
653
654 #[async_trait]
655 impl WorkerSession for HungReportSession {
656 async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
657 drop(config.clone());
658 Ok(())
659 }
660
661 async fn register(
662 &mut self,
663 activity_types: Vec<String>,
664 available_handlers: &BTreeSet<String>,
665 ) -> Result<(), WorkerError> {
666 validate_activity_handlers(&activity_types, available_handlers)?;
667 self.log
668 .send(SessionLog::Registered(self.index, activity_types))
669 .map_err(WorkerError::decode)
670 }
671
672 fn receive_tasks(&mut self) -> WorkerTaskStream {
673 Box::pin(stream::pending())
674 }
675
676 async fn report_result(
677 &mut self,
678 _workflow_id: WorkflowId,
679 _activity_id: ActivityId,
680 _result: Payload,
681 ) -> Result<(), WorkerError> {
682 std::future::pending::<()>().await;
683 Ok(())
684 }
685
686 async fn report_failure(
687 &mut self,
688 _workflow_id: WorkflowId,
689 _activity_id: ActivityId,
690 _failure: ActivityError,
691 ) -> Result<(), WorkerError> {
692 std::future::pending::<()>().await;
693 Ok(())
694 }
695
696 async fn send_heartbeat(
697 &mut self,
698 _workflow_id: WorkflowId,
699 _activity_id: ActivityId,
700 _progress: Option<Payload>,
701 ) -> Result<(), WorkerError> {
702 Ok(())
703 }
704 }
705
706 enum SessionKind {
707 Scripted(ScriptedSession),
708 Hung(HungReportSession),
709 }
710
711 #[async_trait]
712 impl WorkerSession for SessionKind {
713 async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
714 match self {
715 Self::Scripted(session) => session.handshake(config).await,
716 Self::Hung(session) => session.handshake(config).await,
717 }
718 }
719
720 async fn register(
721 &mut self,
722 activity_types: Vec<String>,
723 available_handlers: &BTreeSet<String>,
724 ) -> Result<(), WorkerError> {
725 match self {
726 Self::Scripted(session) => {
727 session.register(activity_types, available_handlers).await
728 }
729 Self::Hung(session) => session.register(activity_types, available_handlers).await,
730 }
731 }
732
733 fn receive_tasks(&mut self) -> WorkerTaskStream {
734 match self {
735 Self::Scripted(session) => session.receive_tasks(),
736 Self::Hung(session) => session.receive_tasks(),
737 }
738 }
739
740 async fn report_result(
741 &mut self,
742 workflow_id: WorkflowId,
743 activity_id: ActivityId,
744 result: Payload,
745 ) -> Result<(), WorkerError> {
746 match self {
747 Self::Scripted(session) => {
748 session
749 .report_result(workflow_id, activity_id, result)
750 .await
751 }
752 Self::Hung(session) => {
753 session
754 .report_result(workflow_id, activity_id, result)
755 .await
756 }
757 }
758 }
759
760 async fn report_failure(
761 &mut self,
762 workflow_id: WorkflowId,
763 activity_id: ActivityId,
764 failure: ActivityError,
765 ) -> Result<(), WorkerError> {
766 match self {
767 Self::Scripted(session) => {
768 session
769 .report_failure(workflow_id, activity_id, failure)
770 .await
771 }
772 Self::Hung(session) => {
773 session
774 .report_failure(workflow_id, activity_id, failure)
775 .await
776 }
777 }
778 }
779
780 async fn send_heartbeat(
781 &mut self,
782 workflow_id: WorkflowId,
783 activity_id: ActivityId,
784 progress: Option<Payload>,
785 ) -> Result<(), WorkerError> {
786 match self {
787 Self::Scripted(session) => {
788 session
789 .send_heartbeat(workflow_id, activity_id, progress)
790 .await
791 }
792 Self::Hung(session) => {
793 session
794 .send_heartbeat(workflow_id, activity_id, progress)
795 .await
796 }
797 }
798 }
799 }
800
801 struct DrainLatchSession {
805 events: Vec<Result<WorkerSessionEvent, WorkerError>>,
806 fail_id: ActivityId,
807 }
808
809 #[async_trait]
810 impl WorkerSession for DrainLatchSession {
811 async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
812 drop(config.clone());
813 Ok(())
814 }
815
816 async fn register(
817 &mut self,
818 activity_types: Vec<String>,
819 available_handlers: &BTreeSet<String>,
820 ) -> Result<(), WorkerError> {
821 validate_activity_handlers(&activity_types, available_handlers)
822 }
823
824 fn receive_tasks(&mut self) -> WorkerTaskStream {
825 Box::pin(stream::iter(std::mem::take(&mut self.events)))
826 }
827
828 async fn report_result(
829 &mut self,
830 _workflow_id: WorkflowId,
831 activity_id: ActivityId,
832 _result: Payload,
833 ) -> Result<(), WorkerError> {
834 if activity_id == self.fail_id {
835 return Err(WorkerError::Transport {
836 source: tonic::Status::unavailable(
837 "stream broke abruptly after the drain frame",
838 ),
839 });
840 }
841 Ok(())
842 }
843
844 async fn report_failure(
845 &mut self,
846 _workflow_id: WorkflowId,
847 _activity_id: ActivityId,
848 _failure: ActivityError,
849 ) -> Result<(), WorkerError> {
850 Ok(())
851 }
852
853 async fn send_heartbeat(
854 &mut self,
855 _workflow_id: WorkflowId,
856 _activity_id: ActivityId,
857 _progress: Option<Payload>,
858 ) -> Result<(), WorkerError> {
859 Ok(())
860 }
861 }
862
863 enum LatchKind {
864 Latch(DrainLatchSession),
865 Deny(ScriptedSession),
866 }
867
868 #[async_trait]
869 impl WorkerSession for LatchKind {
870 async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
871 match self {
872 Self::Latch(session) => session.handshake(config).await,
873 Self::Deny(session) => session.handshake(config).await,
874 }
875 }
876
877 async fn register(
878 &mut self,
879 activity_types: Vec<String>,
880 available_handlers: &BTreeSet<String>,
881 ) -> Result<(), WorkerError> {
882 match self {
883 Self::Latch(session) => session.register(activity_types, available_handlers).await,
884 Self::Deny(session) => session.register(activity_types, available_handlers).await,
885 }
886 }
887
888 fn receive_tasks(&mut self) -> WorkerTaskStream {
889 match self {
890 Self::Latch(session) => session.receive_tasks(),
891 Self::Deny(session) => session.receive_tasks(),
892 }
893 }
894
895 async fn report_result(
896 &mut self,
897 workflow_id: WorkflowId,
898 activity_id: ActivityId,
899 result: Payload,
900 ) -> Result<(), WorkerError> {
901 match self {
902 Self::Latch(session) => {
903 session
904 .report_result(workflow_id, activity_id, result)
905 .await
906 }
907 Self::Deny(session) => {
908 session
909 .report_result(workflow_id, activity_id, result)
910 .await
911 }
912 }
913 }
914
915 async fn report_failure(
916 &mut self,
917 workflow_id: WorkflowId,
918 activity_id: ActivityId,
919 failure: ActivityError,
920 ) -> Result<(), WorkerError> {
921 match self {
922 Self::Latch(session) => {
923 session
924 .report_failure(workflow_id, activity_id, failure)
925 .await
926 }
927 Self::Deny(session) => {
928 session
929 .report_failure(workflow_id, activity_id, failure)
930 .await
931 }
932 }
933 }
934
935 async fn send_heartbeat(
936 &mut self,
937 workflow_id: WorkflowId,
938 activity_id: ActivityId,
939 progress: Option<Payload>,
940 ) -> Result<(), WorkerError> {
941 match self {
942 Self::Latch(session) => {
943 session
944 .send_heartbeat(workflow_id, activity_id, progress)
945 .await
946 }
947 Self::Deny(session) => {
948 session
949 .send_heartbeat(workflow_id, activity_id, progress)
950 .await
951 }
952 }
953 }
954 }
955
956 #[test]
957 fn empty_worker_is_rejected() {
958 let error = WorkerBuilder::new(test_config()).build().err();
959
960 assert!(error.is_some_and(|error| error.to_string().contains("at least one activity")));
961 }
962
963 #[test]
964 fn worker_collects_two_activity_registration_names() -> Result<(), WorkerError> {
965 let worker = two_activity_worker()?;
966 let expected = [String::from("double"), String::from("increment")]
967 .into_iter()
968 .collect::<BTreeSet<_>>();
969
970 assert_eq!(worker.available_handlers(), &expected);
971 assert_eq!(
972 worker.activity_types(),
973 &[String::from("double"), String::from("increment")]
974 );
975 Ok(())
976 }
977
978 #[tokio::test]
979 async fn worker_registers_names_with_session() -> Result<(), WorkerError> {
980 let worker = two_activity_worker()?;
981 let session = worker
982 .run_with_session(ChannelSession {
983 receiver: None,
984 reports: Vec::new(),
985 registered: Vec::new(),
986 })
987 .await?;
988
989 assert_eq!(
990 session.registered,
991 vec![String::from("double"), String::from("increment")]
992 );
993 Ok(())
994 }
995
    /// Shutdown must not abandon an in-flight activity: the run only finishes
    /// after the handler is released, and its result is still reported.
    #[tokio::test]
    async fn shutdown_waits_for_slow_in_flight_activity() -> Result<(), WorkerError> {
        let workflow_id = WorkflowId::new_v4();
        let activity_id = ActivityId::from_sequence_position(7);
        // Queue exactly one "slow" task before the worker starts.
        let (sender, receiver) = mpsc::channel(2);
        sender
            .send(Ok(WorkerSessionEvent::Task(proto_task(
                workflow_id,
                activity_id.clone(),
                "slow",
                0,
            ))))
            .await
            .map_err(WorkerError::decode)?;
        // `release` lets the handler finish; `started` signals it has begun.
        let release = Arc::new(AtomicBool::new(false));
        let started = Arc::new(AtomicUsize::new(0));
        let worker = Worker::builder(test_config())
            .register_activity("slow", {
                let release = Arc::clone(&release);
                let started = Arc::clone(&started);
                move |input: TestInput, context: &ActivityContext| {
                    let release = Arc::clone(&release);
                    let started = Arc::clone(&started);
                    Box::pin(async move {
                        let _ = input;
                        started.fetch_add(1, Ordering::SeqCst);
                        // Wait for cancellation (presumably triggered by shutdown),
                        // then keep running until the test releases it.
                        context.cancelled().await;
                        while !release.load(Ordering::SeqCst) {
                            tokio::time::sleep(Duration::from_millis(1)).await;
                        }
                        Ok(TestOutput { value: 1 })
                    })
                }
            })?
            .build()?;
        let (shutdown_sender, shutdown_receiver) = tokio::sync::oneshot::channel::<()>();
        let session = ChannelSession {
            receiver: Some(receiver),
            reports: Vec::new(),
            registered: Vec::new(),
        };
        let handle = tokio::spawn(async move {
            worker
                .run_with_session_until(session, async {
                    let _ = shutdown_receiver.await;
                })
                .await
        });

        wait_until_started(&started).await;
        shutdown_sender
            .send(())
            .map_err(|()| WorkerError::decode(SendFailed))?;
        // After shutdown, the run must still be waiting on the held handler.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert!(!handle.is_finished());
        release.store(true, Ordering::SeqCst);
        // Closing the channel ends the task stream.
        drop(sender);
        let session = handle.await.map_err(WorkerError::decode)??;

        // The released activity's completion was reported before the run ended.
        assert_eq!(session.reports.len(), 1);
        assert!(matches!(
            &session.reports[0],
            RecordedReport::Completed(_, reported_id, _) if reported_id == &activity_id
        ));
        Ok(())
    }
1062
1063 fn two_activity_worker() -> Result<Worker, WorkerError> {
1064 two_activity_worker_with(test_config())
1065 }
1066
1067 fn two_activity_worker_with(config: WorkerConfig) -> Result<Worker, WorkerError> {
1068 Worker::builder(config)
1069 .register_activity("double", |input: TestInput, context| {
1070 Box::pin(async move {
1071 let _ = context;
1072 Ok(TestOutput {
1073 value: input.value * 2,
1074 })
1075 })
1076 })?
1077 .register_activity("increment", |input: TestInput, context| {
1078 Box::pin(async move {
1079 let _ = context;
1080 Ok(TestOutput {
1081 value: input.value + 1,
1082 })
1083 })
1084 })?
1085 .build()
1086 }
1087
1088 fn proto_task(
1089 workflow_id: WorkflowId,
1090 activity_id: ActivityId,
1091 activity_type: &str,
1092 value: i32,
1093 ) -> ProtoActivityTask {
1094 ProtoActivityTask {
1095 workflow_id: Some(ProtoWorkflowId::from(workflow_id)),
1096 activity_id: Some(ProtoActivityId::from(activity_id)),
1097 activity_type: activity_type.to_owned(),
1098 input: Some(ProtoPayload::from(Payload::new(
1099 ContentType::Json,
1100 format!("{{\"value\":{value}}}").into_bytes(),
1101 ))),
1102 attempt: 1,
1103 }
1104 }
1105
1106 async fn wait_until_started(started: &AtomicUsize) {
1107 while started.load(Ordering::SeqCst) == 0 {
1108 tokio::time::sleep(Duration::from_millis(1)).await;
1109 }
1110 }
1111
1112 fn test_config() -> WorkerConfig {
1113 test_config_with(ReconnectConfig::new(
1114 Duration::from_millis(5),
1115 Duration::from_millis(20),
1116 3,
1117 ))
1118 }
1119
1120 fn test_config_with(reconnect: ReconnectConfig) -> WorkerConfig {
1121 WorkerConfig::new(
1122 "http://127.0.0.1:50051",
1123 "payments",
1124 "worker-a",
1125 1,
1126 reconnect,
1127 None,
1128 )
1129 }
1130
1131 fn slow_reconnect_config() -> WorkerConfig {
1132 test_config_with(ReconnectConfig::new(
1133 Duration::from_secs(5),
1134 Duration::from_secs(10),
1135 5,
1136 ))
1137 }
1138
1139 #[derive(Debug, thiserror::Error)]
1140 #[error("failed to send shutdown signal")]
1141 struct SendFailed;
1142
1143 #[derive(Debug, thiserror::Error)]
1144 #[error("expected the worker run to fail")]
1145 struct UnexpectedSuccess;
1146
1147 #[derive(Debug, thiserror::Error)]
1148 #[error("expected a completed activity report")]
1149 struct UnexpectedReportShape;
1150
1151 #[derive(Debug)]
1153 enum SessionLog {
1154 Registered(usize, Vec<String>),
1155 Reported(usize, RecordedReport),
1156 }
1157
1158 struct ScriptedSession {
1161 index: usize,
1162 log: mpsc::UnboundedSender<SessionLog>,
1163 events: Vec<Result<WorkerSessionEvent, WorkerError>>,
1164 fail_reports: bool,
1165 register_denial: Option<tonic::Status>,
1166 delay_stream: Option<Duration>,
1169 }
1170
1171 #[async_trait]
1172 impl WorkerSession for ScriptedSession {
1173 async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
1174 drop(config.clone());
1175 Ok(())
1176 }
1177
1178 async fn register(
1179 &mut self,
1180 activity_types: Vec<String>,
1181 available_handlers: &BTreeSet<String>,
1182 ) -> Result<(), WorkerError> {
1183 validate_activity_handlers(&activity_types, available_handlers)?;
1184 if let Some(denial) = self.register_denial.take() {
1185 return Err(WorkerError::Registration {
1186 source: Box::new(denial),
1187 });
1188 }
1189 self.log
1190 .send(SessionLog::Registered(self.index, activity_types))
1191 .map_err(WorkerError::decode)
1192 }
1193
1194 fn receive_tasks(&mut self) -> WorkerTaskStream {
1195 let events = std::mem::take(&mut self.events);
1196 match self.delay_stream.take() {
1197 Some(delay) => Box::pin(
1198 stream::once(async move {
1199 tokio::time::sleep(delay).await;
1200 stream::iter(events)
1201 })
1202 .flatten(),
1203 ),
1204 None => Box::pin(stream::iter(events)),
1205 }
1206 }
1207
1208 async fn report_result(
1209 &mut self,
1210 workflow_id: WorkflowId,
1211 activity_id: ActivityId,
1212 result: Payload,
1213 ) -> Result<(), WorkerError> {
1214 if self.fail_reports {
1215 return Err(WorkerError::Transport {
1216 source: tonic::Status::unavailable("transport dropped before result ack"),
1217 });
1218 }
1219 self.log
1220 .send(SessionLog::Reported(
1221 self.index,
1222 RecordedReport::Completed(workflow_id, activity_id, result),
1223 ))
1224 .map_err(WorkerError::decode)
1225 }
1226
1227 async fn report_failure(
1228 &mut self,
1229 workflow_id: WorkflowId,
1230 activity_id: ActivityId,
1231 failure: ActivityError,
1232 ) -> Result<(), WorkerError> {
1233 if self.fail_reports {
1234 return Err(WorkerError::Transport {
1235 source: tonic::Status::unavailable("transport dropped before failure ack"),
1236 });
1237 }
1238 self.log
1239 .send(SessionLog::Reported(
1240 self.index,
1241 RecordedReport::Failed(workflow_id, activity_id, failure),
1242 ))
1243 .map_err(WorkerError::decode)
1244 }
1245
1246 async fn send_heartbeat(
1247 &mut self,
1248 workflow_id: WorkflowId,
1249 activity_id: ActivityId,
1250 progress: Option<Payload>,
1251 ) -> Result<(), WorkerError> {
1252 drop((workflow_id, activity_id, progress));
1253 Ok(())
1254 }
1255 }
1256
1257 #[tokio::test]
1258 async fn establishment_retries_transient_failures_until_attempts_exhausted()
1259 -> Result<(), WorkerError> {
1260 let worker = two_activity_worker()?;
1261 let attempts = Arc::new(AtomicUsize::new(0));
1262 let connect = {
1263 let attempts = Arc::clone(&attempts);
1264 move || {
1265 attempts.fetch_add(1, Ordering::SeqCst);
1266 async move {
1267 Err::<ScriptedSession, _>(WorkerError::Transport {
1268 source: tonic::Status::unavailable("engine unreachable"),
1269 })
1270 }
1271 }
1272 };
1273
1274 let result = worker
1275 .run_with_connector_until(connect, std::future::pending::<()>())
1276 .await;
1277
1278 assert_eq!(attempts.load(Ordering::SeqCst), 3);
1279 let Err(error) = result else {
1280 return Err(WorkerError::decode(UnexpectedSuccess));
1281 };
1282 assert!(error.is_retryable());
1283 assert!(matches!(
1284 error.grpc_status().map(tonic::Status::code),
1285 Some(tonic::Code::Unavailable)
1286 ));
1287 Ok(())
1288 }
1289
1290 #[tokio::test]
1291 async fn establishment_denial_surfaces_after_one_attempt() -> Result<(), WorkerError> {
1292 let worker = two_activity_worker()?;
1293 let attempts = Arc::new(AtomicUsize::new(0));
1294 let (log_sender, log_receiver) = mpsc::unbounded_channel();
1295 let connect = {
1296 let attempts = Arc::clone(&attempts);
1297 move || {
1298 attempts.fetch_add(1, Ordering::SeqCst);
1299 let log = log_sender.clone();
1300 async move {
1301 Ok(ScriptedSession {
1302 index: 1,
1303 log,
1304 events: Vec::new(),
1305 fail_reports: false,
1306 register_denial: Some(tonic::Status::permission_denied(
1307 "namespace `payments` is not granted to subject `worker-a`",
1308 )),
1309 delay_stream: None,
1310 })
1311 }
1312 }
1313 };
1314
1315 let result = worker
1316 .run_with_connector_until(connect, std::future::pending::<()>())
1317 .await;
1318
1319 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1320 let Err(error) = result else {
1321 return Err(WorkerError::decode(UnexpectedSuccess));
1322 };
1323 assert!(!error.is_retryable());
1324 assert!(matches!(
1325 error.grpc_status().map(tonic::Status::code),
1326 Some(tonic::Code::PermissionDenied)
1327 ));
1328 assert_eq!(
1329 error.grpc_status().map(tonic::Status::message),
1330 Some("namespace `payments` is not granted to subject `worker-a`")
1331 );
1332 drop(log_receiver);
1333 Ok(())
1334 }
1335
1336 #[tokio::test]
1337 async fn mid_run_drop_reconnects_re_registers_and_re_reports_unacked() -> Result<(), WorkerError>
1338 {
1339 let workflow_id = WorkflowId::new_v4();
1340 let activity_id = ActivityId::from_sequence_position(3);
1341 let worker = two_activity_worker()?;
1342 let attempts = Arc::new(AtomicUsize::new(0));
1343 let (log_sender, mut log_receiver) = mpsc::unbounded_channel();
1344 let connect = {
1345 let attempts = Arc::clone(&attempts);
1346 let log_sender = log_sender.clone();
1347 let workflow_id = workflow_id.clone();
1348 let activity_id = activity_id.clone();
1349 move || {
1350 let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
1351 let log = log_sender.clone();
1352 let task = proto_task(workflow_id.clone(), activity_id.clone(), "double", 21);
1353 async move {
1354 if attempt == 1 {
1355 Ok(ScriptedSession {
1356 index: 1,
1357 log,
1358 events: vec![Ok(WorkerSessionEvent::Task(task))],
1359 fail_reports: true,
1360 register_denial: None,
1361 delay_stream: None,
1362 })
1363 } else if attempt == 2 {
1364 Ok(ScriptedSession {
1365 index: attempt,
1366 log,
1367 events: Vec::new(),
1368 fail_reports: false,
1369 register_denial: None,
1370 delay_stream: None,
1371 })
1372 } else {
1373 Ok(ScriptedSession {
1376 index: attempt,
1377 log,
1378 events: Vec::new(),
1379 fail_reports: false,
1380 register_denial: Some(tonic::Status::permission_denied(
1381 "namespace `payments` revoked for subject `worker-a`",
1382 )),
1383 delay_stream: None,
1384 })
1385 }
1386 }
1387 }
1388 };
1389
1390 let result = worker
1391 .run_with_connector_until(connect, std::future::pending::<()>())
1392 .await;
1393
1394 drop(log_sender);
1395 let mut registrations = Vec::new();
1396 let mut reports = Vec::new();
1397 while let Some(entry) = log_receiver.recv().await {
1398 match entry {
1399 SessionLog::Registered(index, types) => registrations.push((index, types)),
1400 SessionLog::Reported(index, report) => reports.push((index, report)),
1401 }
1402 }
1403 let Err(error) = result else {
1404 return Err(WorkerError::decode(UnexpectedSuccess));
1405 };
1406 assert!(!error.is_retryable());
1407 assert_eq!(attempts.load(Ordering::SeqCst), 3);
1408 let expected_types = vec![String::from("double"), String::from("increment")];
1409 assert_eq!(
1410 registrations,
1411 vec![(1, expected_types.clone()), (2, expected_types)]
1412 );
1413 assert_eq!(reports.len(), 1);
1414 let (session_index, report) = &reports[0];
1415 assert_eq!(*session_index, 2);
1416 let RecordedReport::Completed(reported_workflow, reported_id, payload) = report else {
1417 return Err(WorkerError::decode(UnexpectedReportShape));
1418 };
1419 assert_eq!(reported_workflow, &workflow_id);
1420 assert_eq!(reported_id, &activity_id);
1421 let output: TestOutput =
1422 serde_json::from_slice(payload.bytes()).map_err(WorkerError::decode)?;
1423 assert_eq!(output.value, 42);
1424 Ok(())
1425 }
1426
1427 #[tokio::test]
1428 async fn mid_run_drop_re_reports_unacked_results_for_all_workflows() -> Result<(), WorkerError>
1429 {
1430 let first_workflow = WorkflowId::new_v4();
1431 let second_workflow = WorkflowId::new_v4();
1432 let activity_id = ActivityId::from_sequence_position(3);
1433 let worker = two_activity_worker()?;
1434 let attempts = Arc::new(AtomicUsize::new(0));
1435 let (log_sender, mut log_receiver) = mpsc::unbounded_channel();
1436 let connect = {
1437 let attempts = Arc::clone(&attempts);
1438 let log_sender = log_sender.clone();
1439 let first_workflow = first_workflow.clone();
1440 let second_workflow = second_workflow.clone();
1441 let activity_id = activity_id.clone();
1442 move || {
1443 let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
1444 let log = log_sender.clone();
1445 let first_task =
1446 proto_task(first_workflow.clone(), activity_id.clone(), "double", 10);
1447 let second_task =
1448 proto_task(second_workflow.clone(), activity_id.clone(), "double", 20);
1449 async move {
1450 if attempt == 1 {
1451 Ok(ScriptedSession {
1452 index: 1,
1453 log,
1454 events: vec![
1455 Ok(WorkerSessionEvent::Task(first_task)),
1456 Ok(WorkerSessionEvent::Task(second_task)),
1457 ],
1458 fail_reports: true,
1459 register_denial: None,
1460 delay_stream: None,
1461 })
1462 } else if attempt == 2 {
1463 Ok(ScriptedSession {
1464 index: attempt,
1465 log,
1466 events: Vec::new(),
1467 fail_reports: false,
1468 register_denial: None,
1469 delay_stream: None,
1470 })
1471 } else {
1472 Ok(ScriptedSession {
1475 index: attempt,
1476 log,
1477 events: Vec::new(),
1478 fail_reports: false,
1479 register_denial: Some(tonic::Status::permission_denied(
1480 "namespace `payments` revoked for subject `worker-a`",
1481 )),
1482 delay_stream: None,
1483 })
1484 }
1485 }
1486 }
1487 };
1488
1489 let result = worker
1490 .run_with_connector_until(connect, std::future::pending::<()>())
1491 .await;
1492
1493 drop(log_sender);
1494 let mut reports = Vec::new();
1495 while let Some(entry) = log_receiver.recv().await {
1496 if let SessionLog::Reported(index, report) = entry {
1497 reports.push((index, report));
1498 }
1499 }
1500 let Err(error) = result else {
1501 return Err(WorkerError::decode(UnexpectedSuccess));
1502 };
1503 assert!(!error.is_retryable());
1504 assert_eq!(attempts.load(Ordering::SeqCst), 3);
1505 assert_eq!(
1506 reports.len(),
1507 2,
1508 "both workflows' colliding sequence-position results must be re-reported"
1509 );
1510 let mut reported_workflows = Vec::new();
1511 for (session_index, report) in &reports {
1512 assert_eq!(*session_index, 2, "re-reports must land on the new session");
1513 let RecordedReport::Completed(reported_workflow, reported_id, _) = report else {
1514 return Err(WorkerError::decode(UnexpectedReportShape));
1515 };
1516 assert_eq!(reported_id, &activity_id);
1517 reported_workflows.push(reported_workflow.clone());
1518 }
1519 assert!(reported_workflows.contains(&first_workflow));
1520 assert!(reported_workflows.contains(&second_workflow));
1521 Ok(())
1522 }
1523
/// Shutdown that arrives while recovery is still connecting.
///
/// Session 1's stream fails with `Unavailable("stream reset by peer")`. The
/// second (recovery) connect signals the shutdown future and then never
/// completes. The run must end within 5 s after exactly two connects. It must
/// surface the drop error that triggered recovery, not the placeholder
/// `"unreachable"` error.
#[tokio::test]
async fn shutdown_during_recovery_establishment_returns_original_drop_error()
-> Result<(), WorkerError> {
    let worker = two_activity_worker()?;
    let attempts = Arc::new(AtomicUsize::new(0));
    // Fired by the recovery connect attempt to trigger shutdown.
    let notify = Arc::new(Notify::new());
    let (log_sender, log_receiver) = mpsc::unbounded_channel();
    let connect = {
        let attempts = Arc::clone(&attempts);
        let notify = Arc::clone(&notify);
        move || {
            let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
            let notify = Arc::clone(&notify);
            let log = log_sender.clone();
            async move {
                if attempt == 1 {
                    Ok(ScriptedSession {
                        index: 1,
                        log,
                        events: vec![Err(WorkerError::Transport {
                            source: tonic::Status::unavailable("stream reset by peer"),
                        })],
                        fail_reports: false,
                        register_denial: None,
                        delay_stream: None,
                    })
                } else {
                    // `notify_one` stores a permit if the shutdown future is not
                    // waiting yet, so the signal cannot be lost.
                    notify.notify_one();
                    // Park the recovery connect forever; only shutdown can end the run.
                    std::future::pending::<()>().await;
                    Err(WorkerError::Transport {
                        source: tonic::Status::unavailable("unreachable"),
                    })
                }
            }
        }
    };
    let shutdown = {
        let notify = Arc::clone(&notify);
        async move {
            notify.notified().await;
        }
    };

    let run = worker.run_with_connector_until(connect, shutdown);
    // The outer `?` turns a timeout into a test failure; `result` is the run's own outcome.
    let result = tokio::time::timeout(Duration::from_secs(5), run)
        .await
        .map_err(WorkerError::decode)?;

    assert_eq!(attempts.load(Ordering::SeqCst), 2);
    let Err(error) = result else {
        return Err(WorkerError::decode(UnexpectedSuccess));
    };
    assert!(matches!(
        error.grpc_status().map(tonic::Status::code),
        Some(tonic::Code::Unavailable)
    ));
    assert_eq!(
        error.grpc_status().map(tonic::Status::message),
        Some("stream reset by peer"),
        "shutdown during recovery establishment must surface the original drop error"
    );
    // Keep the receiver alive until the end so session log sends never hit a closed channel.
    drop(log_receiver);
    Ok(())
}
1591
1592 #[tokio::test(start_paused = true)]
1596 async fn mid_run_drop_budget_exhaustion_surfaces_last_drop_error() -> Result<(), WorkerError> {
1597 let worker = two_activity_worker()?;
1598 let attempts = Arc::new(AtomicUsize::new(0));
1599 let (log_sender, log_receiver) = mpsc::unbounded_channel();
1600 let connect = {
1601 let attempts = Arc::clone(&attempts);
1602 move || {
1603 let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
1604 let log = log_sender.clone();
1605 async move {
1606 Ok(ScriptedSession {
1607 index: attempt,
1608 log,
1609 events: vec![Err(WorkerError::Transport {
1610 source: tonic::Status::unavailable("stream reset by peer"),
1611 })],
1612 fail_reports: false,
1613 register_denial: None,
1614 delay_stream: None,
1615 })
1616 }
1617 }
1618 };
1619
1620 let run = worker.run_with_connector_until(connect, std::future::pending::<()>());
1621 let result = tokio::time::timeout(Duration::from_secs(5), run)
1622 .await
1623 .map_err(WorkerError::decode)?;
1624
1625 assert_eq!(attempts.load(Ordering::SeqCst), 3);
1628 let Err(error) = result else {
1629 return Err(WorkerError::decode(UnexpectedSuccess));
1630 };
1631 assert!(error.is_retryable());
1632 assert!(matches!(
1633 error.grpc_status().map(tonic::Status::code),
1634 Some(tonic::Code::Unavailable)
1635 ));
1636 assert_eq!(
1637 error.grpc_status().map(tonic::Status::message),
1638 Some("stream reset by peer")
1639 );
1640 drop(log_receiver);
1641 Ok(())
1642 }
1643
1644 #[tokio::test]
1645 async fn mid_run_denial_surfaces_without_reconnect() -> Result<(), WorkerError> {
1646 let worker = two_activity_worker()?;
1647 let attempts = Arc::new(AtomicUsize::new(0));
1648 let (log_sender, log_receiver) = mpsc::unbounded_channel();
1649 let connect = {
1650 let attempts = Arc::clone(&attempts);
1651 move || {
1652 attempts.fetch_add(1, Ordering::SeqCst);
1653 let log = log_sender.clone();
1654 async move {
1655 Ok(ScriptedSession {
1656 index: 1,
1657 log,
1658 events: vec![Err(WorkerError::Transport {
1659 source: tonic::Status::permission_denied(
1660 "namespace `payments` revoked for subject `worker-a`",
1661 ),
1662 })],
1663 fail_reports: false,
1664 register_denial: None,
1665 delay_stream: None,
1666 })
1667 }
1668 }
1669 };
1670
1671 let result = worker
1672 .run_with_connector_until(connect, std::future::pending::<()>())
1673 .await;
1674
1675 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1676 let Err(error) = result else {
1677 return Err(WorkerError::decode(UnexpectedSuccess));
1678 };
1679 assert!(!error.is_retryable());
1680 assert!(matches!(
1681 error.grpc_status().map(tonic::Status::code),
1682 Some(tonic::Code::PermissionDenied)
1683 ));
1684 assert_eq!(
1685 error.grpc_status().map(tonic::Status::message),
1686 Some("namespace `payments` revoked for subject `worker-a`")
1687 );
1688 drop(log_receiver);
1689 Ok(())
1690 }
1691
1692 #[tokio::test]
1693 async fn shutdown_during_establishment_backoff_returns_promptly() -> Result<(), WorkerError> {
1694 let worker = two_activity_worker_with(slow_reconnect_config())?;
1695 let attempts = Arc::new(AtomicUsize::new(0));
1696 let notify = Arc::new(Notify::new());
1697 let connect = {
1698 let attempts = Arc::clone(&attempts);
1699 let notify = Arc::clone(¬ify);
1700 move || {
1701 attempts.fetch_add(1, Ordering::SeqCst);
1702 notify.notify_one();
1703 async move {
1704 Err::<ScriptedSession, _>(WorkerError::Transport {
1705 source: tonic::Status::unavailable("engine unreachable"),
1706 })
1707 }
1708 }
1709 };
1710 let shutdown = {
1711 let notify = Arc::clone(¬ify);
1712 async move {
1713 notify.notified().await;
1714 }
1715 };
1716
1717 let run = worker.run_with_connector_until(connect, shutdown);
1718 tokio::time::timeout(Duration::from_millis(500), run)
1719 .await
1720 .map_err(WorkerError::decode)??;
1721
1722 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1723 Ok(())
1724 }
1725
1726 #[tokio::test]
1727 async fn shutdown_during_mid_run_drop_backoff_returns_promptly() -> Result<(), WorkerError> {
1728 let worker = two_activity_worker_with(slow_reconnect_config())?;
1729 let attempts = Arc::new(AtomicUsize::new(0));
1730 let (log_sender, log_receiver) = mpsc::unbounded_channel();
1731 let connect = {
1732 let attempts = Arc::clone(&attempts);
1733 move || {
1734 attempts.fetch_add(1, Ordering::SeqCst);
1735 let log = log_sender.clone();
1736 async move {
1737 Ok(ScriptedSession {
1738 index: 1,
1739 log,
1740 events: vec![Err(WorkerError::Transport {
1741 source: tonic::Status::unavailable("stream reset by peer"),
1742 })],
1743 fail_reports: false,
1744 register_denial: None,
1745 delay_stream: None,
1746 })
1747 }
1748 }
1749 };
1750 let shutdown = async {
1751 tokio::time::sleep(Duration::from_millis(50)).await;
1752 };
1753
1754 let run = worker.run_with_connector_until(connect, shutdown);
1755 let result = tokio::time::timeout(Duration::from_millis(500), run)
1756 .await
1757 .map_err(WorkerError::decode)?;
1758
1759 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1760 let Err(error) = result else {
1761 return Err(WorkerError::decode(UnexpectedSuccess));
1762 };
1763 assert!(error.is_retryable());
1764 assert!(matches!(
1765 error.grpc_status().map(tonic::Status::code),
1766 Some(tonic::Code::Unavailable)
1767 ));
1768 drop(log_receiver);
1769 Ok(())
1770 }
1771
/// Serving a task resets the drop budget.
///
/// Sessions 1-4 each deliver one `double` task and then fail with
/// `Unavailable`. The test expects all four to be re-established, with session
/// 5 denied at registration. That only holds if serving a task resets the drop
/// budget between drops.
#[tokio::test]
async fn served_tasks_reset_drop_budget_across_cycles() -> Result<(), WorkerError> {
    let workflow_id = WorkflowId::new_v4();
    let activity_id = ActivityId::from_sequence_position(7);
    // NOTE(review): arguments assumed to be (initial backoff, max backoff, drop budget).
    // The huge max backoff presumably keeps sessions from "outliving" it, so only
    // served tasks can be what resets the budget here. Confirm against ReconnectConfig.
    let worker = two_activity_worker_with(test_config_with(ReconnectConfig::new(
        Duration::from_millis(1),
        Duration::from_secs(3600),
        2,
    )))?;
    let attempts = Arc::new(AtomicUsize::new(0));
    let (log_sender, mut log_receiver) = mpsc::unbounded_channel();
    let connect = {
        let attempts = Arc::clone(&attempts);
        let log_sender = log_sender.clone();
        let workflow_id = workflow_id.clone();
        let activity_id = activity_id.clone();
        move || {
            let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
            let log = log_sender.clone();
            let task = proto_task(workflow_id.clone(), activity_id.clone(), "double", 21);
            async move {
                if attempt <= 4 {
                    // Serve one task, then drop the stream.
                    Ok(ScriptedSession {
                        index: attempt,
                        log,
                        events: vec![
                            Ok(WorkerSessionEvent::Task(task)),
                            Err(WorkerError::Transport {
                                source: tonic::Status::unavailable("stream reset by peer"),
                            }),
                        ],
                        fail_reports: false,
                        register_denial: None,
                        delay_stream: None,
                    })
                } else {
                    // Terminate the run with a non-retryable registration denial.
                    Ok(ScriptedSession {
                        index: attempt,
                        log,
                        events: Vec::new(),
                        fail_reports: false,
                        register_denial: Some(tonic::Status::permission_denied(
                            "namespace `payments` revoked for subject `worker-a`",
                        )),
                        delay_stream: None,
                    })
                }
            }
        }
    };

    let run = worker.run_with_connector_until(connect, std::future::pending::<()>());
    let result = tokio::time::timeout(Duration::from_secs(5), run)
        .await
        .map_err(WorkerError::decode)?;

    // Drop the test's sender so the drain loop below can observe channel closure.
    drop(log_sender);
    let mut registrations = 0_usize;
    while let Some(entry) = log_receiver.recv().await {
        if let SessionLog::Registered(..) = entry {
            registrations += 1;
        }
    }
    assert_eq!(attempts.load(Ordering::SeqCst), 5);
    // The denied fifth registration is not logged as `Registered`.
    assert_eq!(registrations, 4);
    let Err(error) = result else {
        return Err(WorkerError::decode(UnexpectedSuccess));
    };
    assert!(!error.is_retryable());
    assert!(matches!(
        error.grpc_status().map(tonic::Status::code),
        Some(tonic::Code::PermissionDenied)
    ));
    Ok(())
}
1853
1854 #[tokio::test(start_paused = true)]
1855 async fn session_outliving_max_backoff_resets_drop_budget() -> Result<(), WorkerError> {
1856 let worker = two_activity_worker_with(test_config_with(ReconnectConfig::new(
1857 Duration::from_millis(5),
1858 Duration::from_millis(20),
1859 2,
1860 )))?;
1861 let attempts = Arc::new(AtomicUsize::new(0));
1862 let (log_sender, log_receiver) = mpsc::unbounded_channel();
1863 let connect = {
1864 let attempts = Arc::clone(&attempts);
1865 move || {
1866 let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
1867 let log = log_sender.clone();
1868 async move {
1869 Ok(ScriptedSession {
1870 index: attempt,
1871 log,
1872 events: vec![Err(WorkerError::Transport {
1873 source: tonic::Status::unavailable("stream reset by peer"),
1874 })],
1875 fail_reports: false,
1876 register_denial: None,
1877 delay_stream: (attempt == 2).then_some(Duration::from_millis(30)),
1881 })
1882 }
1883 }
1884 };
1885
1886 let run = worker.run_with_connector_until(connect, std::future::pending::<()>());
1887 let result = tokio::time::timeout(Duration::from_secs(5), run)
1888 .await
1889 .map_err(WorkerError::decode)?;
1890
1891 assert_eq!(attempts.load(Ordering::SeqCst), 3);
1898 let Err(error) = result else {
1899 return Err(WorkerError::decode(UnexpectedSuccess));
1900 };
1901 assert!(error.is_retryable());
1902 assert!(matches!(
1903 error.grpc_status().map(tonic::Status::code),
1904 Some(tonic::Code::Unavailable)
1905 ));
1906 drop(log_receiver);
1907 Ok(())
1908 }
1909
/// Time spent after a drop must not count as session lifetime for budget resets.
///
/// - Session 1's stream fails immediately.
/// - Session 2 delivers a `slow` task and then immediately fails. The task
///   sleeps 60 ms, longer than the 20 ms assumed max backoff, and
///   `fail_reports` is set.
///
/// The run must give up after two connects. The time spent finishing the
/// in-flight task after the drop must not count as session lifetime that
/// resets the drop budget.
#[tokio::test(start_paused = true)]
async fn post_drop_drain_time_does_not_reset_drop_budget() -> Result<(), WorkerError> {
    let workflow_id = WorkflowId::new_v4();
    let activity_id = ActivityId::from_sequence_position(9);
    // Built by hand (not via `test_config_with`) so the worker can register the `slow` activity.
    // NOTE(review): ReconnectConfig args assumed to be (initial backoff, max backoff, drop budget).
    let config = WorkerConfig::new(
        "http://127.0.0.1:50051",
        "payments",
        "worker-a",
        2,
        ReconnectConfig::new(Duration::from_millis(5), Duration::from_millis(20), 2),
        None,
    );
    let worker = Worker::builder(config)
        .register_activity("slow", |input: TestInput, context: &ActivityContext| {
            let _ = (input, context);
            Box::pin(async move {
                // Outlasts the max backoff; on the paused clock this advances virtual time only.
                tokio::time::sleep(Duration::from_millis(60)).await;
                Ok(TestOutput { value: 1 })
            })
        })?
        .build()?;
    let attempts = Arc::new(AtomicUsize::new(0));
    let (log_sender, log_receiver) = mpsc::unbounded_channel();
    let connect = {
        let attempts = Arc::clone(&attempts);
        let workflow_id = workflow_id.clone();
        let activity_id = activity_id.clone();
        move || {
            let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
            let log = log_sender.clone();
            let task = proto_task(workflow_id.clone(), activity_id.clone(), "slow", 1);
            async move {
                if attempt == 1 {
                    // First drop: consumes budget immediately.
                    Ok(ScriptedSession {
                        index: 1,
                        log,
                        events: vec![Err(WorkerError::Transport {
                            source: tonic::Status::unavailable("stream reset by peer"),
                        })],
                        fail_reports: false,
                        register_denial: None,
                        delay_stream: None,
                    })
                } else {
                    // Second drop happens right after the task is delivered,
                    // before the 60 ms handler can finish.
                    Ok(ScriptedSession {
                        index: attempt,
                        log,
                        events: vec![
                            Ok(WorkerSessionEvent::Task(task)),
                            Err(WorkerError::Transport {
                                source: tonic::Status::unavailable("stream reset by peer"),
                            }),
                        ],
                        fail_reports: true,
                        register_denial: None,
                        delay_stream: None,
                    })
                }
            }
        }
    };

    let run = worker.run_with_connector_until(connect, std::future::pending::<()>());
    let result = tokio::time::timeout(Duration::from_secs(5), run)
        .await
        .map_err(WorkerError::decode)?;

    // No third connect: the drain time did not reset the budget.
    assert_eq!(attempts.load(Ordering::SeqCst), 2);
    let Err(error) = result else {
        return Err(WorkerError::decode(UnexpectedSuccess));
    };
    assert!(error.is_retryable());
    assert!(matches!(
        error.grpc_status().map(tonic::Status::code),
        Some(tonic::Code::Unavailable)
    ));
    drop(log_receiver);
    Ok(())
}
2012
/// After a clean close, the worker reconnects, re-registers and keeps serving.
///
/// - Sessions 1 and 2 each deliver a single task and then run out of events,
///   which is presumably a clean stream close.
/// - Session 3 is denied at registration, which ends the run.
///
/// The expected reports are, in order:
/// 1. session 1 reports activity 1;
/// 2. session 2 re-reports activity 1, which was never acked;
/// 3. session 2 reports activity 2.
#[tokio::test]
async fn clean_close_reconnects_re_registers_and_keeps_serving() -> Result<(), WorkerError> {
    let workflow_id = WorkflowId::new_v4();
    let first_activity = ActivityId::from_sequence_position(1);
    let second_activity = ActivityId::from_sequence_position(2);
    let worker = two_activity_worker()?;
    let attempts = Arc::new(AtomicUsize::new(0));
    let (log_sender, mut log_receiver) = mpsc::unbounded_channel();
    let connect = {
        let attempts = Arc::clone(&attempts);
        let log_sender = log_sender.clone();
        let workflow_id = workflow_id.clone();
        let first_activity = first_activity.clone();
        let second_activity = second_activity.clone();
        move || {
            let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
            let log = log_sender.clone();
            let first_task =
                proto_task(workflow_id.clone(), first_activity.clone(), "double", 10);
            let second_task =
                proto_task(workflow_id.clone(), second_activity.clone(), "double", 20);
            async move {
                match attempt {
                    1 => Ok(ScriptedSession {
                        index: 1,
                        log,
                        events: vec![Ok(WorkerSessionEvent::Task(first_task))],
                        fail_reports: false,
                        register_denial: None,
                        delay_stream: None,
                    }),
                    2 => Ok(ScriptedSession {
                        index: 2,
                        log,
                        events: vec![Ok(WorkerSessionEvent::Task(second_task))],
                        fail_reports: false,
                        register_denial: None,
                        delay_stream: None,
                    }),
                    // Any later attempt is denied so the run terminates.
                    _ => Ok(ScriptedSession {
                        index: attempt,
                        log,
                        events: Vec::new(),
                        fail_reports: false,
                        register_denial: Some(tonic::Status::permission_denied(
                            "namespace `payments` revoked for subject `worker-a`",
                        )),
                        delay_stream: None,
                    }),
                }
            }
        }
    };

    let run = worker.run_with_connector_until(connect, std::future::pending::<()>());
    let result = tokio::time::timeout(Duration::from_secs(5), run)
        .await
        .map_err(WorkerError::decode)?;

    // Drop the test's sender so the drain loop below can observe channel closure.
    drop(log_sender);
    let mut registrations = Vec::new();
    let mut reports = Vec::new();
    while let Some(entry) = log_receiver.recv().await {
        match entry {
            SessionLog::Registered(index, types) => registrations.push((index, types)),
            SessionLog::Reported(index, report) => reports.push((index, report)),
        }
    }
    assert_eq!(attempts.load(Ordering::SeqCst), 3);
    let expected_types = vec![String::from("double"), String::from("increment")];
    assert_eq!(
        registrations,
        vec![(1, expected_types.clone()), (2, expected_types)]
    );
    // Report order matters: the re-report of activity 1 precedes activity 2 on session 2.
    assert_eq!(reports.len(), 3);
    assert!(matches!(
        &reports[0],
        (1, RecordedReport::Completed(_, id, _)) if id == &first_activity
    ));
    assert!(matches!(
        &reports[1],
        (2, RecordedReport::Completed(_, id, _)) if id == &first_activity
    ));
    assert!(matches!(
        &reports[2],
        (2, RecordedReport::Completed(_, id, _)) if id == &second_activity
    ));
    let Err(error) = result else {
        return Err(WorkerError::decode(UnexpectedSuccess));
    };
    assert!(!error.is_retryable());
    assert!(matches!(
        error.grpc_status().map(tonic::Status::code),
        Some(tonic::Code::PermissionDenied)
    ));
    Ok(())
}
2115
2116 #[tokio::test(start_paused = true)]
2117 async fn clean_close_loop_exhausts_drop_budget_with_classified_error() -> Result<(), WorkerError>
2118 {
2119 let worker = two_activity_worker()?;
2120 let attempts = Arc::new(AtomicUsize::new(0));
2121 let (log_sender, log_receiver) = mpsc::unbounded_channel();
2122 let connect = {
2123 let attempts = Arc::clone(&attempts);
2124 move || {
2125 let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
2126 let log = log_sender.clone();
2127 async move {
2128 Ok(ScriptedSession {
2129 index: attempt,
2130 log,
2131 events: Vec::new(),
2132 fail_reports: false,
2133 register_denial: None,
2134 delay_stream: None,
2135 })
2136 }
2137 }
2138 };
2139
2140 let run = worker.run_with_connector_until(connect, std::future::pending::<()>());
2141 let result = tokio::time::timeout(Duration::from_secs(5), run)
2142 .await
2143 .map_err(WorkerError::decode)?;
2144
2145 assert_eq!(attempts.load(Ordering::SeqCst), 3);
2150 let Err(error) = result else {
2151 return Err(WorkerError::decode(UnexpectedSuccess));
2152 };
2153 assert!(matches!(error, WorkerError::CleanCloseExhausted));
2154 assert!(error.to_string().contains("closed the stream cleanly"));
2155 drop(log_receiver);
2156 Ok(())
2157 }
2158
2159 #[tokio::test]
2160 async fn shutdown_during_clean_close_backoff_returns_ok_promptly() -> Result<(), WorkerError> {
2161 let worker = two_activity_worker_with(slow_reconnect_config())?;
2162 let attempts = Arc::new(AtomicUsize::new(0));
2163 let (log_sender, log_receiver) = mpsc::unbounded_channel();
2164 let connect = {
2165 let attempts = Arc::clone(&attempts);
2166 move || {
2167 attempts.fetch_add(1, Ordering::SeqCst);
2168 let log = log_sender.clone();
2169 async move {
2170 Ok(ScriptedSession {
2171 index: 1,
2172 log,
2173 events: Vec::new(),
2174 fail_reports: false,
2175 register_denial: None,
2176 delay_stream: None,
2177 })
2178 }
2179 }
2180 };
2181 let shutdown = async {
2182 tokio::time::sleep(Duration::from_millis(50)).await;
2183 };
2184
2185 let run = worker.run_with_connector_until(connect, shutdown);
2188 tokio::time::timeout(Duration::from_millis(500), run)
2189 .await
2190 .map_err(WorkerError::decode)??;
2191
2192 assert_eq!(attempts.load(Ordering::SeqCst), 1);
2193 drop(log_receiver);
2194 Ok(())
2195 }
2196
2197 #[tokio::test]
2201 async fn result_ack_clears_exactly_its_tracker_entry() -> Result<(), WorkerError> {
2202 use crate::protocol::reconnect::{PendingActivityReport, UnackedResultTracker};
2203 use crate::runtime::loop_::{SessionHealth, serve_activity_tasks_until};
2204
2205 let workflow_a = WorkflowId::new_v4();
2206 let workflow_b = WorkflowId::new_v4();
2207 let position = ActivityId::from_sequence_position(5);
2208 let mut tracker = UnackedResultTracker::new();
2209 for workflow in [&workflow_a, &workflow_b] {
2210 tracker.record(PendingActivityReport::Completed {
2211 workflow_id: workflow.clone(),
2212 activity_id: position.clone(),
2213 output: Payload::new(ContentType::Json, b"{\"value\":1}".to_vec()),
2214 });
2215 }
2216
2217 let worker = two_activity_worker()?;
2218 let mut session = ChannelSession {
2219 receiver: None,
2220 reports: Vec::new(),
2221 registered: Vec::new(),
2222 };
2223 let (sender, receiver) = mpsc::channel(4);
2224 sender
2225 .send(Ok(WorkerSessionEvent::ResultAck {
2226 workflow_id: workflow_a.clone(),
2227 activity_id: position.clone(),
2228 }))
2229 .await
2230 .map_err(WorkerError::decode)?;
2231 sender
2233 .send(Ok(WorkerSessionEvent::ResultAck {
2234 workflow_id: WorkflowId::new_v4(),
2235 activity_id: ActivityId::from_sequence_position(99),
2236 }))
2237 .await
2238 .map_err(WorkerError::decode)?;
2239 drop(sender);
2240 session.receiver = Some(receiver);
2241
2242 let mut health = SessionHealth::default();
2243 serve_activity_tasks_until(
2244 &test_config(),
2245 &mut session,
2246 Arc::new(crate::activity::ActivityRegistry::new()),
2247 &mut tracker,
2248 &mut health,
2249 std::future::pending(),
2250 )
2251 .await?;
2252
2253 assert_eq!(tracker.len(), 1, "exactly the acked entry must clear");
2254 assert!(tracker.get(&workflow_a, &position).is_none());
2255 assert!(tracker.get(&workflow_b, &position).is_some());
2256 drop(worker);
2257 Ok(())
2258 }
2259
/// Acked results drop out of the reconnect replay.
///
/// The tracker starts with two completions for one workflow.
/// 1. An ack for the first completion means a replay re-reports only the
///    second one.
/// 2. Once the second is acked too, the tracker is empty.
/// 3. A further replay sends nothing.
#[tokio::test]
async fn acked_results_decay_out_of_the_reconnect_replay() -> Result<(), WorkerError> {
    use crate::protocol::re_report_unacked;
    use crate::protocol::reconnect::{PendingActivityReport, UnackedResultTracker};
    use crate::runtime::loop_::{SessionHealth, serve_activity_tasks_until};

    let workflow_id = WorkflowId::new_v4();
    let acked_id = ActivityId::from_sequence_position(1);
    let unacked_id = ActivityId::from_sequence_position(2);
    let mut tracker = UnackedResultTracker::new();
    for id in [&acked_id, &unacked_id] {
        tracker.record(PendingActivityReport::Completed {
            workflow_id: workflow_id.clone(),
            activity_id: id.clone(),
            output: Payload::new(ContentType::Json, b"{\"value\":2}".to_vec()),
        });
    }

    // Phase 1: ack only `acked_id`.
    let mut session = ChannelSession {
        receiver: None,
        reports: Vec::new(),
        registered: Vec::new(),
    };
    let (sender, receiver) = mpsc::channel(2);
    sender
        .send(Ok(WorkerSessionEvent::ResultAck {
            workflow_id: workflow_id.clone(),
            activity_id: acked_id.clone(),
        }))
        .await
        .map_err(WorkerError::decode)?;
    // Closing the sender ends the event stream so the serve loop returns.
    drop(sender);
    session.receiver = Some(receiver);
    let mut health = SessionHealth::default();
    serve_activity_tasks_until(
        &test_config(),
        &mut session,
        Arc::new(crate::activity::ActivityRegistry::new()),
        &mut tracker,
        &mut health,
        std::future::pending(),
    )
    .await?;

    // Phase 2: replay on a fresh session re-reports only the still-unacked result.
    let mut replay_session = ChannelSession {
        receiver: None,
        reports: Vec::new(),
        registered: Vec::new(),
    };
    re_report_unacked(&tracker, &mut replay_session).await?;
    assert_eq!(
        replay_session.reports.len(),
        1,
        "only the un-acked result may be re-reported"
    );
    assert!(matches!(
        &replay_session.reports[0],
        RecordedReport::Completed(_, id, _) if id == &unacked_id
    ));

    // Phase 3: the replay session acks the remaining result, draining the tracker.
    let (sender, receiver) = mpsc::channel(2);
    sender
        .send(Ok(WorkerSessionEvent::ResultAck {
            workflow_id: workflow_id.clone(),
            activity_id: unacked_id.clone(),
        }))
        .await
        .map_err(WorkerError::decode)?;
    drop(sender);
    replay_session.receiver = Some(receiver);
    let mut health = SessionHealth::default();
    serve_activity_tasks_until(
        &test_config(),
        &mut replay_session,
        Arc::new(crate::activity::ActivityRegistry::new()),
        &mut tracker,
        &mut health,
        std::future::pending(),
    )
    .await?;
    assert!(tracker.is_empty(), "acks must drain the tracker");

    // Phase 4: with nothing pending, a replay is a no-op.
    let mut decayed_session = ChannelSession {
        receiver: None,
        reports: Vec::new(),
        registered: Vec::new(),
    };
    re_report_unacked(&tracker, &mut decayed_session).await?;
    assert!(
        decayed_session.reports.is_empty(),
        "steady-state replay must send nothing"
    );
    Ok(())
}
2362
/// Shutdown interrupts a replay that never completes.
///
/// Session 1 serves a task with `fail_reports: true`, leaving an un-acked
/// result to replay. Session 2 is a `HungReportSession`, whose report calls
/// presumably never complete (confirm against its impl). When the connector
/// builds session 2 it fires a oneshot that triggers shutdown.
///
/// The run must return `Ok` within the timeout, after two connects, with no
/// report logged for session 2.
#[tokio::test(start_paused = true)]
async fn shutdown_interrupts_hung_unacked_replay_promptly() -> Result<(), WorkerError> {
    let workflow_id = WorkflowId::new_v4();
    let activity_id = ActivityId::from_sequence_position(3);
    let worker = two_activity_worker()?;
    let attempts = Arc::new(AtomicUsize::new(0));
    let (log_sender, mut log_receiver) = mpsc::unbounded_channel();
    // Despite the name, this fires when the connector builds session 2, not
    // after registration.
    let (registered_2_tx, registered_2_rx) = tokio::sync::oneshot::channel::<()>();
    // The Mutex<Option<_>> lets the closure take the one-shot sender out through a
    // shared reference, at most once.
    let registered_2_tx = std::sync::Mutex::new(Some(registered_2_tx));
    let connect = {
        let log_sender = log_sender.clone();
        let workflow_id = workflow_id.clone();
        let activity_id = activity_id.clone();
        // Takes the attempt number from the wrapper below so the two session
        // kinds can be chosen by attempt.
        move |attempt_override: usize| {
            let log = log_sender.clone();
            let task = proto_task(workflow_id.clone(), activity_id.clone(), "double", 21);
            let notify = if attempt_override == 2 {
                registered_2_tx
                    .lock()
                    .ok()
                    .and_then(|mut guard| guard.take())
            } else {
                None
            };
            async move {
                if attempt_override == 1 {
                    Ok(SessionKind::Scripted(ScriptedSession {
                        index: 1,
                        log,
                        events: vec![Ok(WorkerSessionEvent::Task(task))],
                        fail_reports: true,
                        register_denial: None,
                        delay_stream: None,
                    }))
                } else {
                    if let Some(notify) = notify {
                        let _ = notify.send(());
                    }
                    Ok(SessionKind::Hung(HungReportSession { index: 2, log }))
                }
            }
        }
    };

    let attempts_for_connect = Arc::clone(&attempts);
    let run = worker.run_with_connector_until(
        move || {
            let attempt = attempts_for_connect.fetch_add(1, Ordering::SeqCst) + 1;
            connect(attempt)
        },
        // Shutdown as soon as session 2 has been handed out. The result is
        // ignored, so a dropped sender also counts as the signal.
        async move {
            let _ = registered_2_rx.await;
        },
    );

    // On the paused clock tokio auto-advances idle time, so this generous bound
    // costs no wall time. The double `?` fails on a timeout or a run error.
    tokio::time::timeout(Duration::from_secs(60), run)
        .await
        .map_err(WorkerError::decode)??;

    // Drop the test's sender so the drain loop below can observe channel closure.
    drop(log_sender);
    let mut hung_session_reports = 0_usize;
    while let Some(entry) = log_receiver.recv().await {
        if let SessionLog::Reported(2, _) = entry {
            hung_session_reports += 1;
        }
    }
    assert_eq!(
        hung_session_reports, 0,
        "the hung replay must not have produced a report"
    );
    assert_eq!(attempts.load(Ordering::SeqCst), 2);
    Ok(())
}
2442
2443 #[tokio::test(start_paused = true)]
2447 async fn drain_cycles_reconnect_without_consuming_drop_budget() -> Result<(), WorkerError> {
2448 let worker = two_activity_worker_with(test_config_with(ReconnectConfig::new(
2449 Duration::from_millis(5),
2450 Duration::from_millis(20),
2451 2,
2452 )))?;
2453 let attempts = Arc::new(AtomicUsize::new(0));
2454 let (log_sender, mut log_receiver) = mpsc::unbounded_channel();
2455 let connect = {
2456 let attempts = Arc::clone(&attempts);
2457 move || {
2458 let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
2459 let log = log_sender.clone();
2460 async move {
2461 if attempt <= 3 {
2462 Ok(ScriptedSession {
2463 index: attempt,
2464 log,
2465 events: vec![Ok(WorkerSessionEvent::Drain)],
2466 fail_reports: false,
2467 register_denial: None,
2468 delay_stream: None,
2469 })
2470 } else {
2471 Ok(ScriptedSession {
2472 index: attempt,
2473 log,
2474 events: Vec::new(),
2475 fail_reports: false,
2476 register_denial: Some(tonic::Status::permission_denied(
2477 "namespace `payments` revoked for subject `worker-a`",
2478 )),
2479 delay_stream: None,
2480 })
2481 }
2482 }
2483 }
2484 };
2485
2486 let result = worker
2487 .run_with_connector_until(connect, std::future::pending::<()>())
2488 .await;
2489
2490 assert_eq!(attempts.load(Ordering::SeqCst), 4);
2494 let Err(error) = result else {
2495 return Err(WorkerError::decode(UnexpectedSuccess));
2496 };
2497 assert!(matches!(
2498 error.grpc_status().map(tonic::Status::code),
2499 Some(tonic::Code::PermissionDenied)
2500 ));
2501 let mut registrations = 0_usize;
2502 while let Some(entry) = log_receiver.recv().await {
2503 if matches!(entry, SessionLog::Registered(..)) {
2504 registrations += 1;
2505 }
2506 }
2507 assert_eq!(registrations, 3, "every drain cycle must re-register");
2508 Ok(())
2509 }
2510
2511 #[tokio::test(start_paused = true)]
2517 async fn drain_latch_keeps_abrupt_post_drain_failures_unbudgeted() -> Result<(), WorkerError> {
2518 let workflow_id = WorkflowId::new_v4();
2519 let worker = Worker::builder(test_config_with(ReconnectConfig::new(
2524 Duration::from_millis(5),
2525 Duration::from_millis(20),
2526 2,
2527 )))
2528 .register_activity("slow_double", |input: TestInput, context| {
2529 Box::pin(async move {
2530 let _ = context;
2531 tokio::time::sleep(Duration::from_millis(1)).await;
2532 Ok(TestOutput {
2533 value: input.value * 2,
2534 })
2535 })
2536 })?
2537 .build()?;
2538 let attempts = Arc::new(AtomicUsize::new(0));
2539 let (log_sender, log_receiver) = mpsc::unbounded_channel();
2540 let connect = {
2541 let attempts = Arc::clone(&attempts);
2542 let workflow_id = workflow_id.clone();
2543 move || {
2544 let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
2545 let log = log_sender.clone();
2546 let attempt_u64 = u64::try_from(attempt).unwrap_or(u64::MAX);
2547 let activity_id = ActivityId::from_sequence_position(attempt_u64);
2548 let task = proto_task(workflow_id.clone(), activity_id.clone(), "slow_double", 21);
2549 async move {
2550 if attempt <= 3 {
2551 Ok(LatchKind::Latch(DrainLatchSession {
2552 events: vec![
2553 Ok(WorkerSessionEvent::Task(task)),
2554 Ok(WorkerSessionEvent::Drain),
2555 ],
2556 fail_id: activity_id,
2557 }))
2558 } else {
2559 Ok(LatchKind::Deny(ScriptedSession {
2560 index: attempt,
2561 log,
2562 events: Vec::new(),
2563 fail_reports: false,
2564 register_denial: Some(tonic::Status::permission_denied(
2565 "namespace `payments` revoked for subject `worker-a`",
2566 )),
2567 delay_stream: None,
2568 }))
2569 }
2570 }
2571 }
2572 };
2573
2574 let result = worker
2575 .run_with_connector_until(connect, std::future::pending::<()>())
2576 .await;
2577
2578 assert_eq!(attempts.load(Ordering::SeqCst), 4);
2581 let Err(error) = result else {
2582 return Err(WorkerError::decode(UnexpectedSuccess));
2583 };
2584 assert!(matches!(
2585 error.grpc_status().map(tonic::Status::code),
2586 Some(tonic::Code::PermissionDenied)
2587 ));
2588 drop(log_receiver);
2589 Ok(())
2590 }
2591
2592 #[tokio::test]
2596 async fn shutdown_during_post_drain_backoff_returns_ok_promptly() -> Result<(), WorkerError> {
2597 let worker = two_activity_worker_with(test_config_with(ReconnectConfig::new(
2598 Duration::from_secs(5),
2599 Duration::from_secs(10),
2600 5,
2601 )))?;
2602 let attempts = Arc::new(AtomicUsize::new(0));
2603 let (log_sender, log_receiver) = mpsc::unbounded_channel();
2604 let connect = {
2605 let attempts = Arc::clone(&attempts);
2606 move || {
2607 attempts.fetch_add(1, Ordering::SeqCst);
2608 let log = log_sender.clone();
2609 async move {
2610 Ok(ScriptedSession {
2611 index: 1,
2612 log,
2613 events: vec![Ok(WorkerSessionEvent::Drain)],
2614 fail_reports: false,
2615 register_denial: None,
2616 delay_stream: None,
2617 })
2618 }
2619 }
2620 };
2621 let shutdown = async {
2622 tokio::time::sleep(Duration::from_millis(50)).await;
2623 };
2624
2625 let run = worker.run_with_connector_until(connect, shutdown);
2628 tokio::time::timeout(Duration::from_millis(500), run)
2629 .await
2630 .map_err(WorkerError::decode)??;
2631
2632 assert_eq!(attempts.load(Ordering::SeqCst), 1);
2633 drop(log_receiver);
2634 Ok(())
2635 }
2636}