1use std::collections::BTreeSet;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::Arc;
7use std::task::Poll;
8
9use serde::Serialize;
10use serde::de::DeserializeOwned;
11use tracing::{error, info, warn};
12
13use crate::activity::{ActivityRegistry, HandlerFuture};
14use crate::config::WorkerConfig;
15use crate::context::ActivityContext;
16use crate::error::WorkerError;
17use crate::protocol::reconnect::{
18 ReconnectBackoff, UnackedResultTracker, re_report_unacked, reconnect_with_backoff,
19 register_connected_session,
20};
21use crate::protocol::{GrpcWorkerSession, WorkerSession};
22use crate::runtime::{
23 NoShutdown, ServeEnd, SessionHealth, serve_activity_tasks, serve_activity_tasks_until,
24};
25
/// Builder that pairs a [`WorkerConfig`] with the activity handlers registered so far
/// and turns them into a [`Worker`] via [`WorkerBuilder::build`].
#[must_use]
pub struct WorkerBuilder {
    /// Configuration handed unchanged to the built [`Worker`].
    config: WorkerConfig,
    /// Registry that accumulates the handlers added through `register_activity`.
    activities: ActivityRegistry,
}
32
33impl WorkerBuilder {
34 pub fn new(config: WorkerConfig) -> Self {
36 Self {
37 config,
38 activities: ActivityRegistry::new(),
39 }
40 }
41
42 pub fn register_activity<Input, Output, Handler>(
48 mut self,
49 activity_type: impl Into<String>,
50 handler: Handler,
51 ) -> Result<Self, WorkerError>
52 where
53 Input: Serialize + DeserializeOwned + Send + Sync + 'static,
54 Output: Serialize + Send + Sync + 'static,
55 Handler: for<'context> Fn(Input, &'context ActivityContext) -> HandlerFuture<'context, Output>
56 + Send
57 + Sync
58 + 'static,
59 {
60 self.activities = self.activities.register_activity(activity_type, handler)?;
61 Ok(self)
62 }
63
64 pub fn build(self) -> Result<Worker, WorkerError> {
70 if self.activities.is_empty() {
71 return Err(WorkerError::registration(EmptyActivitySet));
72 }
73 let available_handlers = self.activities.activity_types();
74 let activity_types = available_handlers.iter().cloned().collect();
75 Ok(Worker {
76 config: self.config,
77 activity_types,
78 available_handlers,
79 activities: Arc::new(self.activities),
80 })
81 }
82}
83
/// A configured worker holding its registered activity handlers, ready to be run
/// against a worker session.
#[must_use]
pub struct Worker {
    /// Configuration used for connecting, registering, and serving.
    config: WorkerConfig,
    /// Activity type names as a list; built by iterating `available_handlers`.
    activity_types: Vec<String>,
    /// Set of activity type names reported by the registry at build time.
    available_handlers: BTreeSet<String>,
    /// Shared registry of the activity handlers.
    activities: Arc<ActivityRegistry>,
}
92
93impl Worker {
94 pub fn builder(config: WorkerConfig) -> WorkerBuilder {
96 WorkerBuilder::new(config)
97 }
98
99 #[must_use]
101 pub fn activity_types(&self) -> &[String] {
102 &self.activity_types
103 }
104
105 #[must_use]
107 pub fn available_handlers(&self) -> &BTreeSet<String> {
108 &self.available_handlers
109 }
110
111 fn log_session_established(&self) {
115 info!(
116 identity = %self.config.identity,
117 endpoint = %self.config.endpoint,
118 activity_types = ?self.activity_types,
119 "worker session established; serving activities"
120 );
121 }
122
123 pub async fn run(self) -> Result<(), WorkerError> {
144 self.run_until(std::future::pending::<()>()).await
145 }
146
147 pub async fn run_until<Shutdown>(self, shutdown: Shutdown) -> Result<(), WorkerError>
159 where
160 Shutdown: Future<Output = ()> + Send,
161 {
162 let config = self.config.clone();
163 self.run_with_connector_until(move || GrpcWorkerSession::connect(config.clone()), shutdown)
164 .await
165 }
166
    /// Runs the worker over sessions produced by `connect`, re-establishing the session
    /// after each drop until `shutdown` resolves, a non-retryable outcome is reached, or
    /// the drop budget is exhausted.
    ///
    /// Each loop iteration: establish a session via `reconnect_with_backoff`, replay the
    /// results tracked as unacknowledged via `re_report_unacked`, serve tasks, then
    /// classify how serving ended with `classify_serve_outcome`. A drop leads to a
    /// backoff sleep followed by another iteration. Every wait is raced against
    /// `shutdown`, and each `select!` is `biased` with the shutdown branch listed first.
    ///
    /// # Errors
    ///
    /// Returns an error when the backoff cannot be built from the configuration, when
    /// session establishment fails, when serving ends with an error classified as final,
    /// when the number of counted drops reaches `backoff.attempts()`, or when shutdown
    /// fires while recovering from a drop that was caused by a failure.
    pub async fn run_with_connector_until<S, F, Fut, Shutdown>(
        self,
        mut connect: F,
        shutdown: Shutdown,
    ) -> Result<(), WorkerError>
    where
        S: WorkerSession,
        F: FnMut() -> Fut,
        Fut: Future<Output = Result<S, WorkerError>>,
        Shutdown: Future<Output = ()> + Send,
    {
        let backoff = ReconnectBackoff::from_config(&self.config)?;
        // The tracker is created once and reused across sessions so results recorded on
        // one session can be replayed on the next.
        let mut tracker = UnackedResultTracker::new();
        // Pin the shutdown future on the stack and wrap it so it can be awaited at
        // several points and remembers that it has fired.
        tokio::pin!(shutdown);
        let mut shutdown = SharedShutdown::new(shutdown);
        // Number of drops counted against the reconnect budget.
        let mut drop_failures = 0_usize;
        // Error of the most recent drop, if that drop was a failure; reported when
        // shutdown interrupts the next establishment attempt.
        let mut recovery_error: Option<WorkerError> = None;

        loop {
            let connected = tokio::select! {
                biased;
                () = shutdown.wait() => {
                    // Shutdown during establishment: surface the failure that caused the
                    // reconnect, or succeed when there was none.
                    return recovery_error.take().map_or(Ok(()), Err);
                }
                result = reconnect_with_backoff(
                    &self.config,
                    self.activity_types.clone(),
                    &self.available_handlers,
                    &mut connect,
                ) => result,
            };
            let mut session = connected?;
            self.log_session_established();
            let session_started = tokio::time::Instant::now();
            let mut health = SessionHealth::default();
            // Replay unacknowledged results on the fresh session; `None` means shutdown
            // won the race.
            let replay = tokio::select! {
                biased;
                () = shutdown.wait() => None,
                result = re_report_unacked(&tracker, &mut session) => Some(result),
            };
            // NOTE(review): unlike the establishment race above, this path returns
            // `Ok(())` without consulting `recovery_error` — confirm this is intended.
            let Some(replay_result) = replay else {
                return Ok(());
            };
            let served = match replay_result {
                Ok(()) => {
                    serve_activity_tasks_until(
                        &self.config,
                        &mut session,
                        Arc::clone(&self.activities),
                        &mut tracker,
                        &mut health,
                        shutdown.wait(),
                    )
                    .await
                }
                // A failed replay is classified exactly like a failed serve.
                Err(report_error) => Err(report_error),
            };
            // Release the old session before classifying the outcome and backing off.
            drop(session);
            let cause = match classify_serve_outcome(served, &health, shutdown.fired()) {
                ServeClassification::End(result) => return result,
                ServeClassification::Drop(cause) => cause,
            };
            // Session lifetime, measured up to the recorded stream end when there is
            // one and up to now otherwise.
            let connected_for = health
                .stream_ended_at
                .unwrap_or_else(tokio::time::Instant::now)
                .saturating_duration_since(session_started);
            // A session that reported at least one task, or outlived the maximum
            // backoff delay, resets the drop budget.
            let proved_healthy = health.tasks_reported > 0 || connected_for > backoff.max_delay();
            if proved_healthy && drop_failures > 0 {
                info!(
                    drop_failures,
                    tasks_reported = health.tasks_reported,
                    "worker session proved healthy; drop budget reset"
                );
                drop_failures = 0;
            }
            // Drains are not counted against the budget and always wait the initial
            // backoff; every other cause consumes one attempt.
            let delay = if matches!(cause, DropCause::Drain) {
                self.config.reconnect.initial_backoff
            } else {
                drop_failures += 1;
                if drop_failures >= backoff.attempts() {
                    let error = cause.into_exhaustion_error();
                    error!(
                        drop_failures,
                        error = %error,
                        "worker session drop budget exhausted; not reconnecting"
                    );
                    return Err(error);
                }
                backoff.delay_for_attempt(drop_failures)
            };
            warn!(
                drop_failures,
                delay_ms = delay.as_millis(),
                cause = %cause,
                "worker session dropped; reconnecting after backoff"
            );
            // Sleep out the backoff unless shutdown arrives first.
            let shutdown_won = tokio::select! {
                biased;
                () = shutdown.wait() => true,
                () = tokio::time::sleep(delay) => false,
            };
            if shutdown_won {
                return cause.into_shutdown_result();
            }
            // Remember a failure cause (or clear the slot) for the next establishment race.
            recovery_error = cause.into_recovery_error();
        }
    }
328
329 pub async fn run_with_session<S>(self, session: S) -> Result<S, WorkerError>
335 where
336 S: WorkerSession,
337 {
338 self.run_with_session_until(session, std::future::pending::<()>())
339 .await
340 }
341
342 pub async fn run_with_session_until<S, Shutdown>(
348 self,
349 session: S,
350 shutdown: Shutdown,
351 ) -> Result<S, WorkerError>
352 where
353 S: WorkerSession,
354 Shutdown: Future<Output = ()> + Send,
355 {
356 let mut session = register_connected_session(
357 session,
358 &self.config,
359 self.activity_types.clone(),
360 &self.available_handlers,
361 )
362 .await?;
363 let mut tracker = UnackedResultTracker::new();
364 let mut health = SessionHealth::default();
365 serve_activity_tasks_until(
366 &self.config,
367 &mut session,
368 self.activities,
369 &mut tracker,
370 &mut health,
371 shutdown,
372 )
373 .await?;
374 Ok(session)
375 }
376}
377
/// Verdict of [`classify_serve_outcome`] on how one serve attempt ended.
enum ServeClassification {
    /// The run is over; the reconnect loop returns this result as-is.
    End(Result<(), WorkerError>),
    /// The session was lost for the given cause; the reconnect loop goes on to its
    /// backoff and budget handling.
    Drop(DropCause),
}
385
386fn classify_serve_outcome(
393 served: Result<ServeEnd, WorkerError>,
394 health: &SessionHealth,
395 shutdown_fired: bool,
396) -> ServeClassification {
397 match served {
398 Ok(ServeEnd::Shutdown) => ServeClassification::End(Ok(())),
399 Ok(ServeEnd::Drained) => {
400 if shutdown_fired {
401 return ServeClassification::End(Ok(()));
402 }
403 ServeClassification::Drop(DropCause::Drain)
404 }
405 Ok(ServeEnd::StreamClosed) => {
406 if shutdown_fired {
407 return ServeClassification::End(Ok(()));
408 }
409 ServeClassification::Drop(DropCause::CleanClose)
410 }
411 Err(error) if !error.is_retryable() => {
412 error!(error = %error, "worker session denied by server; not reconnecting");
413 ServeClassification::End(Err(error))
414 }
415 Err(error) if health.drain_received => {
416 warn!(
420 error = %error,
421 "session error after server drain; classified as drain drop"
422 );
423 if shutdown_fired {
424 return ServeClassification::End(Ok(()));
425 }
426 ServeClassification::Drop(DropCause::Drain)
427 }
428 Err(error) => {
429 if shutdown_fired {
430 return ServeClassification::End(Err(error));
431 }
432 ServeClassification::Drop(DropCause::Failure(error))
433 }
434 }
435}
436
/// Why a worker session was lost while the worker was still meant to be running.
enum DropCause {
    /// Serving (or replaying unacknowledged results) ended with this retryable error.
    Failure(WorkerError),
    /// Serving ended with `ServeEnd::StreamClosed`.
    CleanClose,
    /// Serving ended with `ServeEnd::Drained`, or failed after the session health
    /// recorded `drain_received`.
    Drain,
}
447
448impl DropCause {
449 fn into_exhaustion_error(self) -> WorkerError {
455 match self {
456 Self::Failure(error) => error,
457 Self::CleanClose | Self::Drain => WorkerError::CleanCloseExhausted,
458 }
459 }
460
461 fn into_shutdown_result(self) -> Result<(), WorkerError> {
464 match self {
465 Self::Failure(error) => Err(error),
466 Self::CleanClose | Self::Drain => Ok(()),
467 }
468 }
469
470 fn into_recovery_error(self) -> Option<WorkerError> {
472 match self {
473 Self::Failure(error) => Some(error),
474 Self::CleanClose | Self::Drain => None,
475 }
476 }
477}
478
479impl std::fmt::Display for DropCause {
480 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
481 match self {
482 Self::Failure(error) => write!(formatter, "{error}"),
483 Self::CleanClose => write!(formatter, "server closed the worker stream cleanly"),
484 Self::Drain => write!(formatter, "server drained the worker stream"),
485 }
486 }
487}
488
489struct SharedShutdown<'a, S> {
498 inner: Pin<&'a mut S>,
499 fired: bool,
500}
501
502impl<'a, S> SharedShutdown<'a, S>
503where
504 S: Future<Output = ()> + Send,
505{
506 const fn new(inner: Pin<&'a mut S>) -> Self {
507 Self {
508 inner,
509 fired: false,
510 }
511 }
512
513 const fn fired(&self) -> bool {
515 self.fired
516 }
517
518 fn wait(&mut self) -> impl Future<Output = ()> + Send {
520 std::future::poll_fn(|context| {
521 if self.fired {
522 return Poll::Ready(());
523 }
524 match self.inner.as_mut().poll(context) {
525 Poll::Ready(()) => {
526 self.fired = true;
527 Poll::Ready(())
528 }
529 Poll::Pending => Poll::Pending,
530 }
531 })
532 }
533}
534
535pub async fn run_worker_with_session<S>(worker: Worker, session: S) -> Result<S, WorkerError>
541where
542 S: WorkerSession,
543{
544 worker.run_with_session(session).await
545}
546
/// Error source used by [`WorkerBuilder::build`] when no activity handler was registered.
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
#[error("worker must register at least one activity handler")]
pub struct EmptyActivitySet;
551
552fn _assert_live_session_type() {
553 let _ = std::mem::size_of::<GrpcWorkerSession>();
554 let _ = std::mem::size_of::<NoShutdown>();
555 let _ = serve_activity_tasks::<GrpcWorkerSession, ActivityRegistry>;
556}
557
558#[cfg(test)]
559mod tests {
560 use std::collections::BTreeSet;
561 use std::sync::Arc;
562 use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
563 use std::time::Duration;
564
565 use aion_core::{ActivityError, ActivityId, ContentType, Payload, WorkflowId};
566 use aion_proto::{ProtoActivityId, ProtoActivityTask, ProtoPayload, ProtoWorkflowId};
567 use async_trait::async_trait;
568 use futures::StreamExt as _;
569 use futures::stream;
570 use serde::{Deserialize, Serialize};
571 use tokio::sync::{Notify, mpsc};
572
573 use super::{Worker, WorkerBuilder};
574 use crate::config::{ReconnectConfig, WorkerConfig};
575 use crate::context::ActivityContext;
576 use crate::error::WorkerError;
577 use crate::protocol::{
578 WorkerSession, WorkerSessionEvent, WorkerTaskStream, validate_activity_handlers,
579 };
580
    /// Input type taken by the activity handlers registered in these tests.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestInput {
        value: i32,
    }
585
    /// Output type returned by the activity handlers registered in these tests.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestOutput {
        value: i32,
    }
590
    /// Test session whose task stream is fed from an mpsc channel and which records
    /// everything the worker reports.
    struct ChannelSession {
        /// Source of session events; taken by the first `receive_tasks` call.
        receiver: Option<mpsc::Receiver<Result<WorkerSessionEvent, WorkerError>>>,
        /// Every completion or failure reported through this session, in call order.
        reports: Vec<RecordedReport>,
        /// Activity types passed to the most recent successful `register` call.
        registered: Vec<String>,
    }
596
    /// One report call captured by a test session.
    #[derive(Clone, Debug, PartialEq, Eq)]
    enum RecordedReport {
        /// Arguments of a `report_result` call.
        Completed(WorkflowId, ActivityId, Payload),
        /// Arguments of a `report_failure` call.
        Failed(WorkflowId, ActivityId, ActivityError),
    }
602
603 #[async_trait]
604 impl WorkerSession for ChannelSession {
605 async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
606 drop(config.clone());
607 Ok(())
608 }
609
610 async fn register(
611 &mut self,
612 activity_types: Vec<String>,
613 available_handlers: &BTreeSet<String>,
614 ) -> Result<(), WorkerError> {
615 validate_activity_handlers(&activity_types, available_handlers)?;
616 self.registered = activity_types;
617 Ok(())
618 }
619
620 fn receive_tasks(&mut self) -> WorkerTaskStream {
621 match self.receiver.take() {
622 Some(receiver) => Box::pin(tokio_stream::wrappers::ReceiverStream::new(receiver)),
623 None => Box::pin(stream::empty()),
624 }
625 }
626
627 async fn report_result(
628 &mut self,
629 workflow_id: WorkflowId,
630 activity_id: ActivityId,
631 result: Payload,
632 ) -> Result<(), WorkerError> {
633 self.reports
634 .push(RecordedReport::Completed(workflow_id, activity_id, result));
635 Ok(())
636 }
637
638 async fn report_failure(
639 &mut self,
640 workflow_id: WorkflowId,
641 activity_id: ActivityId,
642 failure: ActivityError,
643 ) -> Result<(), WorkerError> {
644 self.reports
645 .push(RecordedReport::Failed(workflow_id, activity_id, failure));
646 Ok(())
647 }
648
649 async fn send_heartbeat(
650 &mut self,
651 workflow_id: WorkflowId,
652 activity_id: ActivityId,
653 progress: Option<Payload>,
654 ) -> Result<(), WorkerError> {
655 drop((workflow_id, activity_id, progress));
656 Ok(())
657 }
658 }
659
    /// Test session whose task stream never yields and whose report calls never complete.
    struct HungReportSession {
        /// Channel on which successful registrations are logged.
        log: mpsc::UnboundedSender<SessionLog>,
        /// Index written into the log entries sent by this session.
        index: usize,
    }
666
667 #[async_trait]
668 impl WorkerSession for HungReportSession {
669 async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
670 drop(config.clone());
671 Ok(())
672 }
673
674 async fn register(
675 &mut self,
676 activity_types: Vec<String>,
677 available_handlers: &BTreeSet<String>,
678 ) -> Result<(), WorkerError> {
679 validate_activity_handlers(&activity_types, available_handlers)?;
680 self.log
681 .send(SessionLog::Registered(self.index, activity_types))
682 .map_err(WorkerError::decode)
683 }
684
685 fn receive_tasks(&mut self) -> WorkerTaskStream {
686 Box::pin(stream::pending())
687 }
688
689 async fn report_result(
690 &mut self,
691 _workflow_id: WorkflowId,
692 _activity_id: ActivityId,
693 _result: Payload,
694 ) -> Result<(), WorkerError> {
695 std::future::pending::<()>().await;
696 Ok(())
697 }
698
699 async fn report_failure(
700 &mut self,
701 _workflow_id: WorkflowId,
702 _activity_id: ActivityId,
703 _failure: ActivityError,
704 ) -> Result<(), WorkerError> {
705 std::future::pending::<()>().await;
706 Ok(())
707 }
708
709 async fn send_heartbeat(
710 &mut self,
711 _workflow_id: WorkflowId,
712 _activity_id: ActivityId,
713 _progress: Option<Payload>,
714 ) -> Result<(), WorkerError> {
715 Ok(())
716 }
717 }
718
    /// Lets one connector hand out either a [`ScriptedSession`] or a
    /// [`HungReportSession`] under a single session type.
    enum SessionKind {
        Scripted(ScriptedSession),
        Hung(HungReportSession),
    }
723
724 #[async_trait]
725 impl WorkerSession for SessionKind {
726 async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
727 match self {
728 Self::Scripted(session) => session.handshake(config).await,
729 Self::Hung(session) => session.handshake(config).await,
730 }
731 }
732
733 async fn register(
734 &mut self,
735 activity_types: Vec<String>,
736 available_handlers: &BTreeSet<String>,
737 ) -> Result<(), WorkerError> {
738 match self {
739 Self::Scripted(session) => {
740 session.register(activity_types, available_handlers).await
741 }
742 Self::Hung(session) => session.register(activity_types, available_handlers).await,
743 }
744 }
745
746 fn receive_tasks(&mut self) -> WorkerTaskStream {
747 match self {
748 Self::Scripted(session) => session.receive_tasks(),
749 Self::Hung(session) => session.receive_tasks(),
750 }
751 }
752
753 async fn report_result(
754 &mut self,
755 workflow_id: WorkflowId,
756 activity_id: ActivityId,
757 result: Payload,
758 ) -> Result<(), WorkerError> {
759 match self {
760 Self::Scripted(session) => {
761 session
762 .report_result(workflow_id, activity_id, result)
763 .await
764 }
765 Self::Hung(session) => {
766 session
767 .report_result(workflow_id, activity_id, result)
768 .await
769 }
770 }
771 }
772
773 async fn report_failure(
774 &mut self,
775 workflow_id: WorkflowId,
776 activity_id: ActivityId,
777 failure: ActivityError,
778 ) -> Result<(), WorkerError> {
779 match self {
780 Self::Scripted(session) => {
781 session
782 .report_failure(workflow_id, activity_id, failure)
783 .await
784 }
785 Self::Hung(session) => {
786 session
787 .report_failure(workflow_id, activity_id, failure)
788 .await
789 }
790 }
791 }
792
793 async fn send_heartbeat(
794 &mut self,
795 workflow_id: WorkflowId,
796 activity_id: ActivityId,
797 progress: Option<Payload>,
798 ) -> Result<(), WorkerError> {
799 match self {
800 Self::Scripted(session) => {
801 session
802 .send_heartbeat(workflow_id, activity_id, progress)
803 .await
804 }
805 Self::Hung(session) => {
806 session
807 .send_heartbeat(workflow_id, activity_id, progress)
808 .await
809 }
810 }
811 }
812 }
813
    /// Test session that replays a fixed list of events and fails the result report of
    /// one chosen activity with a transport error.
    struct DrainLatchSession {
        /// Events yielded by the first `receive_tasks` call; left empty afterwards.
        events: Vec<Result<WorkerSessionEvent, WorkerError>>,
        /// Activity whose `report_result` call returns the transport error.
        fail_id: ActivityId,
    }
821
822 #[async_trait]
823 impl WorkerSession for DrainLatchSession {
824 async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
825 drop(config.clone());
826 Ok(())
827 }
828
829 async fn register(
830 &mut self,
831 activity_types: Vec<String>,
832 available_handlers: &BTreeSet<String>,
833 ) -> Result<(), WorkerError> {
834 validate_activity_handlers(&activity_types, available_handlers)
835 }
836
837 fn receive_tasks(&mut self) -> WorkerTaskStream {
838 Box::pin(stream::iter(std::mem::take(&mut self.events)))
839 }
840
841 async fn report_result(
842 &mut self,
843 _workflow_id: WorkflowId,
844 activity_id: ActivityId,
845 _result: Payload,
846 ) -> Result<(), WorkerError> {
847 if activity_id == self.fail_id {
848 return Err(WorkerError::Transport {
849 source: tonic::Status::unavailable(
850 "stream broke abruptly after the drain frame",
851 ),
852 });
853 }
854 Ok(())
855 }
856
857 async fn report_failure(
858 &mut self,
859 _workflow_id: WorkflowId,
860 _activity_id: ActivityId,
861 _failure: ActivityError,
862 ) -> Result<(), WorkerError> {
863 Ok(())
864 }
865
866 async fn send_heartbeat(
867 &mut self,
868 _workflow_id: WorkflowId,
869 _activity_id: ActivityId,
870 _progress: Option<Payload>,
871 ) -> Result<(), WorkerError> {
872 Ok(())
873 }
874 }
875
    /// Lets one connector hand out either a [`DrainLatchSession`] or a
    /// [`ScriptedSession`] under a single session type.
    enum LatchKind {
        Latch(DrainLatchSession),
        Deny(ScriptedSession),
    }
880
881 #[async_trait]
882 impl WorkerSession for LatchKind {
883 async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
884 match self {
885 Self::Latch(session) => session.handshake(config).await,
886 Self::Deny(session) => session.handshake(config).await,
887 }
888 }
889
890 async fn register(
891 &mut self,
892 activity_types: Vec<String>,
893 available_handlers: &BTreeSet<String>,
894 ) -> Result<(), WorkerError> {
895 match self {
896 Self::Latch(session) => session.register(activity_types, available_handlers).await,
897 Self::Deny(session) => session.register(activity_types, available_handlers).await,
898 }
899 }
900
901 fn receive_tasks(&mut self) -> WorkerTaskStream {
902 match self {
903 Self::Latch(session) => session.receive_tasks(),
904 Self::Deny(session) => session.receive_tasks(),
905 }
906 }
907
908 async fn report_result(
909 &mut self,
910 workflow_id: WorkflowId,
911 activity_id: ActivityId,
912 result: Payload,
913 ) -> Result<(), WorkerError> {
914 match self {
915 Self::Latch(session) => {
916 session
917 .report_result(workflow_id, activity_id, result)
918 .await
919 }
920 Self::Deny(session) => {
921 session
922 .report_result(workflow_id, activity_id, result)
923 .await
924 }
925 }
926 }
927
928 async fn report_failure(
929 &mut self,
930 workflow_id: WorkflowId,
931 activity_id: ActivityId,
932 failure: ActivityError,
933 ) -> Result<(), WorkerError> {
934 match self {
935 Self::Latch(session) => {
936 session
937 .report_failure(workflow_id, activity_id, failure)
938 .await
939 }
940 Self::Deny(session) => {
941 session
942 .report_failure(workflow_id, activity_id, failure)
943 .await
944 }
945 }
946 }
947
948 async fn send_heartbeat(
949 &mut self,
950 workflow_id: WorkflowId,
951 activity_id: ActivityId,
952 progress: Option<Payload>,
953 ) -> Result<(), WorkerError> {
954 match self {
955 Self::Latch(session) => {
956 session
957 .send_heartbeat(workflow_id, activity_id, progress)
958 .await
959 }
960 Self::Deny(session) => {
961 session
962 .send_heartbeat(workflow_id, activity_id, progress)
963 .await
964 }
965 }
966 }
967 }
968
969 #[test]
970 fn empty_worker_is_rejected() {
971 let error = WorkerBuilder::new(test_config()).build().err();
972
973 assert!(error.is_some_and(|error| error.to_string().contains("at least one activity")));
974 }
975
976 #[test]
977 fn worker_collects_two_activity_registration_names() -> Result<(), WorkerError> {
978 let worker = two_activity_worker()?;
979 let expected = [String::from("double"), String::from("increment")]
980 .into_iter()
981 .collect::<BTreeSet<_>>();
982
983 assert_eq!(worker.available_handlers(), &expected);
984 assert_eq!(
985 worker.activity_types(),
986 &[String::from("double"), String::from("increment")]
987 );
988 Ok(())
989 }
990
991 #[tokio::test]
992 async fn worker_registers_names_with_session() -> Result<(), WorkerError> {
993 let worker = two_activity_worker()?;
994 let session = worker
995 .run_with_session(ChannelSession {
996 receiver: None,
997 reports: Vec::new(),
998 registered: Vec::new(),
999 })
1000 .await?;
1001
1002 assert_eq!(
1003 session.registered,
1004 vec![String::from("double"), String::from("increment")]
1005 );
1006 Ok(())
1007 }
1008
    /// A worker that has been told to shut down must keep running until the activity
    /// it already started has finished, and must still report that activity's result.
    #[tokio::test]
    async fn shutdown_waits_for_slow_in_flight_activity() -> Result<(), WorkerError> {
        let workflow_id = WorkflowId::new_v4();
        let activity_id = ActivityId::from_sequence_position(7);
        // Queue exactly one task for the "slow" activity before the worker starts.
        let (sender, receiver) = mpsc::channel(2);
        sender
            .send(Ok(WorkerSessionEvent::Task(proto_task(
                workflow_id,
                activity_id.clone(),
                "slow",
                0,
            ))))
            .await
            .map_err(WorkerError::decode)?;
        // `release` lets the test decide when the handler may finish; `started` tells
        // the test that the handler is running.
        let release = Arc::new(AtomicBool::new(false));
        let started = Arc::new(AtomicUsize::new(0));
        let worker = Worker::builder(test_config())
            .register_activity("slow", {
                let release = Arc::clone(&release);
                let started = Arc::clone(&started);
                move |input: TestInput, context: &ActivityContext| {
                    let release = Arc::clone(&release);
                    let started = Arc::clone(&started);
                    Box::pin(async move {
                        let _ = input;
                        started.fetch_add(1, Ordering::SeqCst);
                        // Wait on the context's cancellation future first, then keep
                        // polling until the test releases the handler.
                        context.cancelled().await;
                        while !release.load(Ordering::SeqCst) {
                            tokio::time::sleep(Duration::from_millis(1)).await;
                        }
                        Ok(TestOutput { value: 1 })
                    })
                }
            })?
            .build()?;
        let (shutdown_sender, shutdown_receiver) = tokio::sync::oneshot::channel::<()>();
        let session = ChannelSession {
            receiver: Some(receiver),
            reports: Vec::new(),
            registered: Vec::new(),
        };
        // Run the worker in the background so the test can drive shutdown and release.
        let handle = tokio::spawn(async move {
            worker
                .run_with_session_until(session, async {
                    let _ = shutdown_receiver.await;
                })
                .await
        });

        wait_until_started(&started).await;
        shutdown_sender
            .send(())
            .map_err(|()| WorkerError::decode(SendFailed))?;
        // Give the worker time to observe shutdown; it must not have exited while the
        // handler is still held back.
        tokio::time::sleep(Duration::from_millis(20)).await;
        assert!(!handle.is_finished());
        release.store(true, Ordering::SeqCst);
        drop(sender);
        let session = handle.await.map_err(WorkerError::decode)??;

        // Exactly one report, a completion for the queued activity.
        assert_eq!(session.reports.len(), 1);
        assert!(matches!(
            &session.reports[0],
            RecordedReport::Completed(_, reported_id, _) if reported_id == &activity_id
        ));
        Ok(())
    }
1075
1076 fn two_activity_worker() -> Result<Worker, WorkerError> {
1077 two_activity_worker_with(test_config())
1078 }
1079
1080 fn two_activity_worker_with(config: WorkerConfig) -> Result<Worker, WorkerError> {
1081 Worker::builder(config)
1082 .register_activity("double", |input: TestInput, context| {
1083 Box::pin(async move {
1084 let _ = context;
1085 Ok(TestOutput {
1086 value: input.value * 2,
1087 })
1088 })
1089 })?
1090 .register_activity("increment", |input: TestInput, context| {
1091 Box::pin(async move {
1092 let _ = context;
1093 Ok(TestOutput {
1094 value: input.value + 1,
1095 })
1096 })
1097 })?
1098 .build()
1099 }
1100
1101 fn proto_task(
1102 workflow_id: WorkflowId,
1103 activity_id: ActivityId,
1104 activity_type: &str,
1105 value: i32,
1106 ) -> ProtoActivityTask {
1107 ProtoActivityTask {
1108 workflow_id: Some(ProtoWorkflowId::from(workflow_id)),
1109 activity_id: Some(ProtoActivityId::from(activity_id)),
1110 activity_type: activity_type.to_owned(),
1111 input: Some(ProtoPayload::from(Payload::new(
1112 ContentType::Json,
1113 format!("{{\"value\":{value}}}").into_bytes(),
1114 ))),
1115 attempt: 1,
1116 }
1117 }
1118
1119 async fn wait_until_started(started: &AtomicUsize) {
1120 while started.load(Ordering::SeqCst) == 0 {
1121 tokio::time::sleep(Duration::from_millis(1)).await;
1122 }
1123 }
1124
1125 fn test_config() -> WorkerConfig {
1126 test_config_with(ReconnectConfig::new(
1127 Duration::from_millis(5),
1128 Duration::from_millis(20),
1129 3,
1130 ))
1131 }
1132
1133 fn test_config_with(reconnect: ReconnectConfig) -> WorkerConfig {
1134 WorkerConfig::new(
1135 "http://127.0.0.1:50051",
1136 "payments",
1137 "worker-a",
1138 1,
1139 reconnect,
1140 None,
1141 )
1142 }
1143
1144 fn slow_reconnect_config() -> WorkerConfig {
1145 test_config_with(ReconnectConfig::new(
1146 Duration::from_secs(5),
1147 Duration::from_secs(10),
1148 5,
1149 ))
1150 }
1151
    /// Test-only error for a shutdown signal that could not be sent.
    #[derive(Debug, thiserror::Error)]
    #[error("failed to send shutdown signal")]
    struct SendFailed;
1155
    /// Test-only error raised when a worker run that should fail returns `Ok`.
    #[derive(Debug, thiserror::Error)]
    #[error("expected the worker run to fail")]
    struct UnexpectedSuccess;
1159
    /// Test-only error raised when a recorded report is not a completion.
    #[derive(Debug, thiserror::Error)]
    #[error("expected a completed activity report")]
    struct UnexpectedReportShape;
1163
    /// Entry sent by the logging test sessions; the first field is the index of the
    /// session that produced it.
    #[derive(Debug)]
    enum SessionLog {
        /// A successful registration together with the registered activity types.
        Registered(usize, Vec<String>),
        /// A report that the session accepted.
        Reported(usize, RecordedReport),
    }
1170
    /// Test session driven by a fixed script: a list of events to yield plus switches
    /// for failing reports, denying registration, and delaying the stream.
    struct ScriptedSession {
        /// Index written into the log entries sent by this session.
        index: usize,
        /// Channel on which registrations and accepted reports are logged.
        log: mpsc::UnboundedSender<SessionLog>,
        /// Events yielded by the first `receive_tasks` call; left empty afterwards.
        events: Vec<Result<WorkerSessionEvent, WorkerError>>,
        /// When set, every report call fails with an `unavailable` transport error.
        fail_reports: bool,
        /// When present, the next `register` call fails with this status (consumed once).
        register_denial: Option<tonic::Status>,
        /// When present, the first task stream sleeps this long before yielding events.
        delay_stream: Option<Duration>,
    }
1183
1184 #[async_trait]
1185 impl WorkerSession for ScriptedSession {
1186 async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
1187 drop(config.clone());
1188 Ok(())
1189 }
1190
1191 async fn register(
1192 &mut self,
1193 activity_types: Vec<String>,
1194 available_handlers: &BTreeSet<String>,
1195 ) -> Result<(), WorkerError> {
1196 validate_activity_handlers(&activity_types, available_handlers)?;
1197 if let Some(denial) = self.register_denial.take() {
1198 return Err(WorkerError::Registration {
1199 source: Box::new(denial),
1200 });
1201 }
1202 self.log
1203 .send(SessionLog::Registered(self.index, activity_types))
1204 .map_err(WorkerError::decode)
1205 }
1206
1207 fn receive_tasks(&mut self) -> WorkerTaskStream {
1208 let events = std::mem::take(&mut self.events);
1209 match self.delay_stream.take() {
1210 Some(delay) => Box::pin(
1211 stream::once(async move {
1212 tokio::time::sleep(delay).await;
1213 stream::iter(events)
1214 })
1215 .flatten(),
1216 ),
1217 None => Box::pin(stream::iter(events)),
1218 }
1219 }
1220
1221 async fn report_result(
1222 &mut self,
1223 workflow_id: WorkflowId,
1224 activity_id: ActivityId,
1225 result: Payload,
1226 ) -> Result<(), WorkerError> {
1227 if self.fail_reports {
1228 return Err(WorkerError::Transport {
1229 source: tonic::Status::unavailable("transport dropped before result ack"),
1230 });
1231 }
1232 self.log
1233 .send(SessionLog::Reported(
1234 self.index,
1235 RecordedReport::Completed(workflow_id, activity_id, result),
1236 ))
1237 .map_err(WorkerError::decode)
1238 }
1239
1240 async fn report_failure(
1241 &mut self,
1242 workflow_id: WorkflowId,
1243 activity_id: ActivityId,
1244 failure: ActivityError,
1245 ) -> Result<(), WorkerError> {
1246 if self.fail_reports {
1247 return Err(WorkerError::Transport {
1248 source: tonic::Status::unavailable("transport dropped before failure ack"),
1249 });
1250 }
1251 self.log
1252 .send(SessionLog::Reported(
1253 self.index,
1254 RecordedReport::Failed(workflow_id, activity_id, failure),
1255 ))
1256 .map_err(WorkerError::decode)
1257 }
1258
1259 async fn send_heartbeat(
1260 &mut self,
1261 workflow_id: WorkflowId,
1262 activity_id: ActivityId,
1263 progress: Option<Payload>,
1264 ) -> Result<(), WorkerError> {
1265 drop((workflow_id, activity_id, progress));
1266 Ok(())
1267 }
1268 }
1269
1270 #[tokio::test]
1271 async fn establishment_retries_transient_failures_until_attempts_exhausted()
1272 -> Result<(), WorkerError> {
1273 let worker = two_activity_worker()?;
1274 let attempts = Arc::new(AtomicUsize::new(0));
1275 let connect = {
1276 let attempts = Arc::clone(&attempts);
1277 move || {
1278 attempts.fetch_add(1, Ordering::SeqCst);
1279 async move {
1280 Err::<ScriptedSession, _>(WorkerError::Transport {
1281 source: tonic::Status::unavailable("engine unreachable"),
1282 })
1283 }
1284 }
1285 };
1286
1287 let result = worker
1288 .run_with_connector_until(connect, std::future::pending::<()>())
1289 .await;
1290
1291 assert_eq!(attempts.load(Ordering::SeqCst), 3);
1292 let Err(error) = result else {
1293 return Err(WorkerError::decode(UnexpectedSuccess));
1294 };
1295 assert!(error.is_retryable());
1296 assert!(matches!(
1297 error.grpc_status().map(tonic::Status::code),
1298 Some(tonic::Code::Unavailable)
1299 ));
1300 Ok(())
1301 }
1302
1303 #[tokio::test]
1304 async fn establishment_denial_surfaces_after_one_attempt() -> Result<(), WorkerError> {
1305 let worker = two_activity_worker()?;
1306 let attempts = Arc::new(AtomicUsize::new(0));
1307 let (log_sender, log_receiver) = mpsc::unbounded_channel();
1308 let connect = {
1309 let attempts = Arc::clone(&attempts);
1310 move || {
1311 attempts.fetch_add(1, Ordering::SeqCst);
1312 let log = log_sender.clone();
1313 async move {
1314 Ok(ScriptedSession {
1315 index: 1,
1316 log,
1317 events: Vec::new(),
1318 fail_reports: false,
1319 register_denial: Some(tonic::Status::permission_denied(
1320 "namespace `payments` is not granted to subject `worker-a`",
1321 )),
1322 delay_stream: None,
1323 })
1324 }
1325 }
1326 };
1327
1328 let result = worker
1329 .run_with_connector_until(connect, std::future::pending::<()>())
1330 .await;
1331
1332 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1333 let Err(error) = result else {
1334 return Err(WorkerError::decode(UnexpectedSuccess));
1335 };
1336 assert!(!error.is_retryable());
1337 assert!(matches!(
1338 error.grpc_status().map(tonic::Status::code),
1339 Some(tonic::Code::PermissionDenied)
1340 ));
1341 assert_eq!(
1342 error.grpc_status().map(tonic::Status::message),
1343 Some("namespace `payments` is not granted to subject `worker-a`")
1344 );
1345 drop(log_receiver);
1346 Ok(())
1347 }
1348
1349 #[tokio::test]
1350 async fn mid_run_drop_reconnects_re_registers_and_re_reports_unacked() -> Result<(), WorkerError>
1351 {
1352 let workflow_id = WorkflowId::new_v4();
1353 let activity_id = ActivityId::from_sequence_position(3);
1354 let worker = two_activity_worker()?;
1355 let attempts = Arc::new(AtomicUsize::new(0));
1356 let (log_sender, mut log_receiver) = mpsc::unbounded_channel();
1357 let connect = {
1358 let attempts = Arc::clone(&attempts);
1359 let log_sender = log_sender.clone();
1360 let workflow_id = workflow_id.clone();
1361 let activity_id = activity_id.clone();
1362 move || {
1363 let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
1364 let log = log_sender.clone();
1365 let task = proto_task(workflow_id.clone(), activity_id.clone(), "double", 21);
1366 async move {
1367 if attempt == 1 {
1368 Ok(ScriptedSession {
1369 index: 1,
1370 log,
1371 events: vec![Ok(WorkerSessionEvent::Task(task))],
1372 fail_reports: true,
1373 register_denial: None,
1374 delay_stream: None,
1375 })
1376 } else if attempt == 2 {
1377 Ok(ScriptedSession {
1378 index: attempt,
1379 log,
1380 events: Vec::new(),
1381 fail_reports: false,
1382 register_denial: None,
1383 delay_stream: None,
1384 })
1385 } else {
1386 Ok(ScriptedSession {
1389 index: attempt,
1390 log,
1391 events: Vec::new(),
1392 fail_reports: false,
1393 register_denial: Some(tonic::Status::permission_denied(
1394 "namespace `payments` revoked for subject `worker-a`",
1395 )),
1396 delay_stream: None,
1397 })
1398 }
1399 }
1400 }
1401 };
1402
1403 let result = worker
1404 .run_with_connector_until(connect, std::future::pending::<()>())
1405 .await;
1406
1407 drop(log_sender);
1408 let mut registrations = Vec::new();
1409 let mut reports = Vec::new();
1410 while let Some(entry) = log_receiver.recv().await {
1411 match entry {
1412 SessionLog::Registered(index, types) => registrations.push((index, types)),
1413 SessionLog::Reported(index, report) => reports.push((index, report)),
1414 }
1415 }
1416 let Err(error) = result else {
1417 return Err(WorkerError::decode(UnexpectedSuccess));
1418 };
1419 assert!(!error.is_retryable());
1420 assert_eq!(attempts.load(Ordering::SeqCst), 3);
1421 let expected_types = vec![String::from("double"), String::from("increment")];
1422 assert_eq!(
1423 registrations,
1424 vec![(1, expected_types.clone()), (2, expected_types)]
1425 );
1426 assert_eq!(reports.len(), 1);
1427 let (session_index, report) = &reports[0];
1428 assert_eq!(*session_index, 2);
1429 let RecordedReport::Completed(reported_workflow, reported_id, payload) = report else {
1430 return Err(WorkerError::decode(UnexpectedReportShape));
1431 };
1432 assert_eq!(reported_workflow, &workflow_id);
1433 assert_eq!(reported_id, &activity_id);
1434 let output: TestOutput =
1435 serde_json::from_slice(payload.bytes()).map_err(WorkerError::decode)?;
1436 assert_eq!(output.value, 42);
1437 Ok(())
1438 }
1439
1440 #[tokio::test]
1441 async fn mid_run_drop_re_reports_unacked_results_for_all_workflows() -> Result<(), WorkerError>
1442 {
1443 let first_workflow = WorkflowId::new_v4();
1444 let second_workflow = WorkflowId::new_v4();
1445 let activity_id = ActivityId::from_sequence_position(3);
1446 let worker = two_activity_worker()?;
1447 let attempts = Arc::new(AtomicUsize::new(0));
1448 let (log_sender, mut log_receiver) = mpsc::unbounded_channel();
1449 let connect = {
1450 let attempts = Arc::clone(&attempts);
1451 let log_sender = log_sender.clone();
1452 let first_workflow = first_workflow.clone();
1453 let second_workflow = second_workflow.clone();
1454 let activity_id = activity_id.clone();
1455 move || {
1456 let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
1457 let log = log_sender.clone();
1458 let first_task =
1459 proto_task(first_workflow.clone(), activity_id.clone(), "double", 10);
1460 let second_task =
1461 proto_task(second_workflow.clone(), activity_id.clone(), "double", 20);
1462 async move {
1463 if attempt == 1 {
1464 Ok(ScriptedSession {
1465 index: 1,
1466 log,
1467 events: vec![
1468 Ok(WorkerSessionEvent::Task(first_task)),
1469 Ok(WorkerSessionEvent::Task(second_task)),
1470 ],
1471 fail_reports: true,
1472 register_denial: None,
1473 delay_stream: None,
1474 })
1475 } else if attempt == 2 {
1476 Ok(ScriptedSession {
1477 index: attempt,
1478 log,
1479 events: Vec::new(),
1480 fail_reports: false,
1481 register_denial: None,
1482 delay_stream: None,
1483 })
1484 } else {
1485 Ok(ScriptedSession {
1488 index: attempt,
1489 log,
1490 events: Vec::new(),
1491 fail_reports: false,
1492 register_denial: Some(tonic::Status::permission_denied(
1493 "namespace `payments` revoked for subject `worker-a`",
1494 )),
1495 delay_stream: None,
1496 })
1497 }
1498 }
1499 }
1500 };
1501
1502 let result = worker
1503 .run_with_connector_until(connect, std::future::pending::<()>())
1504 .await;
1505
1506 drop(log_sender);
1507 let mut reports = Vec::new();
1508 while let Some(entry) = log_receiver.recv().await {
1509 if let SessionLog::Reported(index, report) = entry {
1510 reports.push((index, report));
1511 }
1512 }
1513 let Err(error) = result else {
1514 return Err(WorkerError::decode(UnexpectedSuccess));
1515 };
1516 assert!(!error.is_retryable());
1517 assert_eq!(attempts.load(Ordering::SeqCst), 3);
1518 assert_eq!(
1519 reports.len(),
1520 2,
1521 "both workflows' colliding sequence-position results must be re-reported"
1522 );
1523 let mut reported_workflows = Vec::new();
1524 for (session_index, report) in &reports {
1525 assert_eq!(*session_index, 2, "re-reports must land on the new session");
1526 let RecordedReport::Completed(reported_workflow, reported_id, _) = report else {
1527 return Err(WorkerError::decode(UnexpectedReportShape));
1528 };
1529 assert_eq!(reported_id, &activity_id);
1530 reported_workflows.push(reported_workflow.clone());
1531 }
1532 assert!(reported_workflows.contains(&first_workflow));
1533 assert!(reported_workflows.contains(&second_workflow));
1534 Ok(())
1535 }
1536
1537 #[tokio::test]
1538 async fn shutdown_during_recovery_establishment_returns_original_drop_error()
1539 -> Result<(), WorkerError> {
1540 let worker = two_activity_worker()?;
1541 let attempts = Arc::new(AtomicUsize::new(0));
1542 let notify = Arc::new(Notify::new());
1543 let (log_sender, log_receiver) = mpsc::unbounded_channel();
1544 let connect = {
1545 let attempts = Arc::clone(&attempts);
1546 let notify = Arc::clone(¬ify);
1547 move || {
1548 let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
1549 let notify = Arc::clone(¬ify);
1550 let log = log_sender.clone();
1551 async move {
1552 if attempt == 1 {
1553 Ok(ScriptedSession {
1554 index: 1,
1555 log,
1556 events: vec![Err(WorkerError::Transport {
1557 source: tonic::Status::unavailable("stream reset by peer"),
1558 })],
1559 fail_reports: false,
1560 register_denial: None,
1561 delay_stream: None,
1562 })
1563 } else {
1564 notify.notify_one();
1568 std::future::pending::<()>().await;
1569 Err(WorkerError::Transport {
1570 source: tonic::Status::unavailable("unreachable"),
1571 })
1572 }
1573 }
1574 }
1575 };
1576 let shutdown = {
1577 let notify = Arc::clone(¬ify);
1578 async move {
1579 notify.notified().await;
1580 }
1581 };
1582
1583 let run = worker.run_with_connector_until(connect, shutdown);
1584 let result = tokio::time::timeout(Duration::from_secs(5), run)
1585 .await
1586 .map_err(WorkerError::decode)?;
1587
1588 assert_eq!(attempts.load(Ordering::SeqCst), 2);
1589 let Err(error) = result else {
1590 return Err(WorkerError::decode(UnexpectedSuccess));
1591 };
1592 assert!(matches!(
1593 error.grpc_status().map(tonic::Status::code),
1594 Some(tonic::Code::Unavailable)
1595 ));
1596 assert_eq!(
1597 error.grpc_status().map(tonic::Status::message),
1598 Some("stream reset by peer"),
1599 "shutdown during recovery establishment must surface the original drop error"
1600 );
1601 drop(log_receiver);
1602 Ok(())
1603 }
1604
1605 #[tokio::test(start_paused = true)]
1609 async fn mid_run_drop_budget_exhaustion_surfaces_last_drop_error() -> Result<(), WorkerError> {
1610 let worker = two_activity_worker()?;
1611 let attempts = Arc::new(AtomicUsize::new(0));
1612 let (log_sender, log_receiver) = mpsc::unbounded_channel();
1613 let connect = {
1614 let attempts = Arc::clone(&attempts);
1615 move || {
1616 let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
1617 let log = log_sender.clone();
1618 async move {
1619 Ok(ScriptedSession {
1620 index: attempt,
1621 log,
1622 events: vec![Err(WorkerError::Transport {
1623 source: tonic::Status::unavailable("stream reset by peer"),
1624 })],
1625 fail_reports: false,
1626 register_denial: None,
1627 delay_stream: None,
1628 })
1629 }
1630 }
1631 };
1632
1633 let run = worker.run_with_connector_until(connect, std::future::pending::<()>());
1634 let result = tokio::time::timeout(Duration::from_secs(5), run)
1635 .await
1636 .map_err(WorkerError::decode)?;
1637
1638 assert_eq!(attempts.load(Ordering::SeqCst), 3);
1641 let Err(error) = result else {
1642 return Err(WorkerError::decode(UnexpectedSuccess));
1643 };
1644 assert!(error.is_retryable());
1645 assert!(matches!(
1646 error.grpc_status().map(tonic::Status::code),
1647 Some(tonic::Code::Unavailable)
1648 ));
1649 assert_eq!(
1650 error.grpc_status().map(tonic::Status::message),
1651 Some("stream reset by peer")
1652 );
1653 drop(log_receiver);
1654 Ok(())
1655 }
1656
1657 #[tokio::test]
1658 async fn mid_run_denial_surfaces_without_reconnect() -> Result<(), WorkerError> {
1659 let worker = two_activity_worker()?;
1660 let attempts = Arc::new(AtomicUsize::new(0));
1661 let (log_sender, log_receiver) = mpsc::unbounded_channel();
1662 let connect = {
1663 let attempts = Arc::clone(&attempts);
1664 move || {
1665 attempts.fetch_add(1, Ordering::SeqCst);
1666 let log = log_sender.clone();
1667 async move {
1668 Ok(ScriptedSession {
1669 index: 1,
1670 log,
1671 events: vec![Err(WorkerError::Transport {
1672 source: tonic::Status::permission_denied(
1673 "namespace `payments` revoked for subject `worker-a`",
1674 ),
1675 })],
1676 fail_reports: false,
1677 register_denial: None,
1678 delay_stream: None,
1679 })
1680 }
1681 }
1682 };
1683
1684 let result = worker
1685 .run_with_connector_until(connect, std::future::pending::<()>())
1686 .await;
1687
1688 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1689 let Err(error) = result else {
1690 return Err(WorkerError::decode(UnexpectedSuccess));
1691 };
1692 assert!(!error.is_retryable());
1693 assert!(matches!(
1694 error.grpc_status().map(tonic::Status::code),
1695 Some(tonic::Code::PermissionDenied)
1696 ));
1697 assert_eq!(
1698 error.grpc_status().map(tonic::Status::message),
1699 Some("namespace `payments` revoked for subject `worker-a`")
1700 );
1701 drop(log_receiver);
1702 Ok(())
1703 }
1704
1705 #[tokio::test]
1706 async fn shutdown_during_establishment_backoff_returns_promptly() -> Result<(), WorkerError> {
1707 let worker = two_activity_worker_with(slow_reconnect_config())?;
1708 let attempts = Arc::new(AtomicUsize::new(0));
1709 let notify = Arc::new(Notify::new());
1710 let connect = {
1711 let attempts = Arc::clone(&attempts);
1712 let notify = Arc::clone(¬ify);
1713 move || {
1714 attempts.fetch_add(1, Ordering::SeqCst);
1715 notify.notify_one();
1716 async move {
1717 Err::<ScriptedSession, _>(WorkerError::Transport {
1718 source: tonic::Status::unavailable("engine unreachable"),
1719 })
1720 }
1721 }
1722 };
1723 let shutdown = {
1724 let notify = Arc::clone(¬ify);
1725 async move {
1726 notify.notified().await;
1727 }
1728 };
1729
1730 let run = worker.run_with_connector_until(connect, shutdown);
1731 tokio::time::timeout(Duration::from_millis(500), run)
1732 .await
1733 .map_err(WorkerError::decode)??;
1734
1735 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1736 Ok(())
1737 }
1738
1739 #[tokio::test]
1740 async fn shutdown_during_mid_run_drop_backoff_returns_promptly() -> Result<(), WorkerError> {
1741 let worker = two_activity_worker_with(slow_reconnect_config())?;
1742 let attempts = Arc::new(AtomicUsize::new(0));
1743 let (log_sender, log_receiver) = mpsc::unbounded_channel();
1744 let connect = {
1745 let attempts = Arc::clone(&attempts);
1746 move || {
1747 attempts.fetch_add(1, Ordering::SeqCst);
1748 let log = log_sender.clone();
1749 async move {
1750 Ok(ScriptedSession {
1751 index: 1,
1752 log,
1753 events: vec![Err(WorkerError::Transport {
1754 source: tonic::Status::unavailable("stream reset by peer"),
1755 })],
1756 fail_reports: false,
1757 register_denial: None,
1758 delay_stream: None,
1759 })
1760 }
1761 }
1762 };
1763 let shutdown = async {
1764 tokio::time::sleep(Duration::from_millis(50)).await;
1765 };
1766
1767 let run = worker.run_with_connector_until(connect, shutdown);
1768 let result = tokio::time::timeout(Duration::from_millis(500), run)
1769 .await
1770 .map_err(WorkerError::decode)?;
1771
1772 assert_eq!(attempts.load(Ordering::SeqCst), 1);
1773 let Err(error) = result else {
1774 return Err(WorkerError::decode(UnexpectedSuccess));
1775 };
1776 assert!(error.is_retryable());
1777 assert!(matches!(
1778 error.grpc_status().map(tonic::Status::code),
1779 Some(tonic::Code::Unavailable)
1780 ));
1781 drop(log_receiver);
1782 Ok(())
1783 }
1784
1785 #[tokio::test]
1786 async fn served_tasks_reset_drop_budget_across_cycles() -> Result<(), WorkerError> {
1787 let workflow_id = WorkflowId::new_v4();
1788 let activity_id = ActivityId::from_sequence_position(7);
1789 let worker = two_activity_worker_with(test_config_with(ReconnectConfig::new(
1792 Duration::from_millis(1),
1793 Duration::from_secs(3600),
1794 2,
1795 )))?;
1796 let attempts = Arc::new(AtomicUsize::new(0));
1797 let (log_sender, mut log_receiver) = mpsc::unbounded_channel();
1798 let connect = {
1799 let attempts = Arc::clone(&attempts);
1800 let log_sender = log_sender.clone();
1801 let workflow_id = workflow_id.clone();
1802 let activity_id = activity_id.clone();
1803 move || {
1804 let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
1805 let log = log_sender.clone();
1806 let task = proto_task(workflow_id.clone(), activity_id.clone(), "double", 21);
1807 async move {
1808 if attempt <= 4 {
1809 Ok(ScriptedSession {
1810 index: attempt,
1811 log,
1812 events: vec![
1813 Ok(WorkerSessionEvent::Task(task)),
1814 Err(WorkerError::Transport {
1815 source: tonic::Status::unavailable("stream reset by peer"),
1816 }),
1817 ],
1818 fail_reports: false,
1819 register_denial: None,
1820 delay_stream: None,
1821 })
1822 } else {
1823 Ok(ScriptedSession {
1824 index: attempt,
1825 log,
1826 events: Vec::new(),
1827 fail_reports: false,
1828 register_denial: Some(tonic::Status::permission_denied(
1829 "namespace `payments` revoked for subject `worker-a`",
1830 )),
1831 delay_stream: None,
1832 })
1833 }
1834 }
1835 }
1836 };
1837
1838 let run = worker.run_with_connector_until(connect, std::future::pending::<()>());
1839 let result = tokio::time::timeout(Duration::from_secs(5), run)
1840 .await
1841 .map_err(WorkerError::decode)?;
1842
1843 drop(log_sender);
1844 let mut registrations = 0_usize;
1845 while let Some(entry) = log_receiver.recv().await {
1846 if let SessionLog::Registered(..) = entry {
1847 registrations += 1;
1848 }
1849 }
1850 assert_eq!(attempts.load(Ordering::SeqCst), 5);
1855 assert_eq!(registrations, 4);
1856 let Err(error) = result else {
1857 return Err(WorkerError::decode(UnexpectedSuccess));
1858 };
1859 assert!(!error.is_retryable());
1860 assert!(matches!(
1861 error.grpc_status().map(tonic::Status::code),
1862 Some(tonic::Code::PermissionDenied)
1863 ));
1864 Ok(())
1865 }
1866
1867 #[tokio::test(start_paused = true)]
1868 async fn session_outliving_max_backoff_resets_drop_budget() -> Result<(), WorkerError> {
1869 let worker = two_activity_worker_with(test_config_with(ReconnectConfig::new(
1870 Duration::from_millis(5),
1871 Duration::from_millis(20),
1872 2,
1873 )))?;
1874 let attempts = Arc::new(AtomicUsize::new(0));
1875 let (log_sender, log_receiver) = mpsc::unbounded_channel();
1876 let connect = {
1877 let attempts = Arc::clone(&attempts);
1878 move || {
1879 let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
1880 let log = log_sender.clone();
1881 async move {
1882 Ok(ScriptedSession {
1883 index: attempt,
1884 log,
1885 events: vec![Err(WorkerError::Transport {
1886 source: tonic::Status::unavailable("stream reset by peer"),
1887 })],
1888 fail_reports: false,
1889 register_denial: None,
1890 delay_stream: (attempt == 2).then_some(Duration::from_millis(30)),
1894 })
1895 }
1896 }
1897 };
1898
1899 let run = worker.run_with_connector_until(connect, std::future::pending::<()>());
1900 let result = tokio::time::timeout(Duration::from_secs(5), run)
1901 .await
1902 .map_err(WorkerError::decode)?;
1903
1904 assert_eq!(attempts.load(Ordering::SeqCst), 3);
1911 let Err(error) = result else {
1912 return Err(WorkerError::decode(UnexpectedSuccess));
1913 };
1914 assert!(error.is_retryable());
1915 assert!(matches!(
1916 error.grpc_status().map(tonic::Status::code),
1917 Some(tonic::Code::Unavailable)
1918 ));
1919 drop(log_receiver);
1920 Ok(())
1921 }
1922
1923 #[tokio::test(start_paused = true)]
1930 async fn post_drop_drain_time_does_not_reset_drop_budget() -> Result<(), WorkerError> {
1931 let workflow_id = WorkflowId::new_v4();
1932 let activity_id = ActivityId::from_sequence_position(9);
1933 let config = WorkerConfig::new(
1936 "http://127.0.0.1:50051",
1937 "payments",
1938 "worker-a",
1939 2,
1940 ReconnectConfig::new(Duration::from_millis(5), Duration::from_millis(20), 2),
1941 None,
1942 );
1943 let worker = Worker::builder(config)
1944 .register_activity("slow", |input: TestInput, context: &ActivityContext| {
1945 let _ = (input, context);
1946 Box::pin(async move {
1947 tokio::time::sleep(Duration::from_millis(60)).await;
1950 Ok(TestOutput { value: 1 })
1951 })
1952 })?
1953 .build()?;
1954 let attempts = Arc::new(AtomicUsize::new(0));
1955 let (log_sender, log_receiver) = mpsc::unbounded_channel();
1956 let connect = {
1957 let attempts = Arc::clone(&attempts);
1958 let workflow_id = workflow_id.clone();
1959 let activity_id = activity_id.clone();
1960 move || {
1961 let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
1962 let log = log_sender.clone();
1963 let task = proto_task(workflow_id.clone(), activity_id.clone(), "slow", 1);
1964 async move {
1965 if attempt == 1 {
1966 Ok(ScriptedSession {
1970 index: 1,
1971 log,
1972 events: vec![Err(WorkerError::Transport {
1973 source: tonic::Status::unavailable("stream reset by peer"),
1974 })],
1975 fail_reports: false,
1976 register_denial: None,
1977 delay_stream: None,
1978 })
1979 } else {
1980 Ok(ScriptedSession {
1985 index: attempt,
1986 log,
1987 events: vec![
1988 Ok(WorkerSessionEvent::Task(task)),
1989 Err(WorkerError::Transport {
1990 source: tonic::Status::unavailable("stream reset by peer"),
1991 }),
1992 ],
1993 fail_reports: true,
1994 register_denial: None,
1995 delay_stream: None,
1996 })
1997 }
1998 }
1999 }
2000 };
2001
2002 let run = worker.run_with_connector_until(connect, std::future::pending::<()>());
2003 let result = tokio::time::timeout(Duration::from_secs(5), run)
2004 .await
2005 .map_err(WorkerError::decode)?;
2006
2007 assert_eq!(attempts.load(Ordering::SeqCst), 2);
2014 let Err(error) = result else {
2015 return Err(WorkerError::decode(UnexpectedSuccess));
2016 };
2017 assert!(error.is_retryable());
2018 assert!(matches!(
2019 error.grpc_status().map(tonic::Status::code),
2020 Some(tonic::Code::Unavailable)
2021 ));
2022 drop(log_receiver);
2023 Ok(())
2024 }
2025
2026 #[tokio::test]
2027 async fn clean_close_reconnects_re_registers_and_keeps_serving() -> Result<(), WorkerError> {
2028 let workflow_id = WorkflowId::new_v4();
2029 let first_activity = ActivityId::from_sequence_position(1);
2030 let second_activity = ActivityId::from_sequence_position(2);
2031 let worker = two_activity_worker()?;
2032 let attempts = Arc::new(AtomicUsize::new(0));
2033 let (log_sender, mut log_receiver) = mpsc::unbounded_channel();
2034 let connect = {
2035 let attempts = Arc::clone(&attempts);
2036 let log_sender = log_sender.clone();
2037 let workflow_id = workflow_id.clone();
2038 let first_activity = first_activity.clone();
2039 let second_activity = second_activity.clone();
2040 move || {
2041 let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
2042 let log = log_sender.clone();
2043 let first_task =
2044 proto_task(workflow_id.clone(), first_activity.clone(), "double", 10);
2045 let second_task =
2046 proto_task(workflow_id.clone(), second_activity.clone(), "double", 20);
2047 async move {
2048 match attempt {
2049 1 => Ok(ScriptedSession {
2052 index: 1,
2053 log,
2054 events: vec![Ok(WorkerSessionEvent::Task(first_task))],
2055 fail_reports: false,
2056 register_denial: None,
2057 delay_stream: None,
2058 }),
2059 2 => Ok(ScriptedSession {
2060 index: 2,
2061 log,
2062 events: vec![Ok(WorkerSessionEvent::Task(second_task))],
2063 fail_reports: false,
2064 register_denial: None,
2065 delay_stream: None,
2066 }),
2067 _ => Ok(ScriptedSession {
2068 index: attempt,
2069 log,
2070 events: Vec::new(),
2071 fail_reports: false,
2072 register_denial: Some(tonic::Status::permission_denied(
2073 "namespace `payments` revoked for subject `worker-a`",
2074 )),
2075 delay_stream: None,
2076 }),
2077 }
2078 }
2079 }
2080 };
2081
2082 let run = worker.run_with_connector_until(connect, std::future::pending::<()>());
2083 let result = tokio::time::timeout(Duration::from_secs(5), run)
2084 .await
2085 .map_err(WorkerError::decode)?;
2086
2087 drop(log_sender);
2088 let mut registrations = Vec::new();
2089 let mut reports = Vec::new();
2090 while let Some(entry) = log_receiver.recv().await {
2091 match entry {
2092 SessionLog::Registered(index, types) => registrations.push((index, types)),
2093 SessionLog::Reported(index, report) => reports.push((index, report)),
2094 }
2095 }
2096 assert_eq!(attempts.load(Ordering::SeqCst), 3);
2100 let expected_types = vec![String::from("double"), String::from("increment")];
2101 assert_eq!(
2102 registrations,
2103 vec![(1, expected_types.clone()), (2, expected_types)]
2104 );
2105 assert_eq!(reports.len(), 3);
2106 assert!(matches!(
2107 &reports[0],
2108 (1, RecordedReport::Completed(_, id, _)) if id == &first_activity
2109 ));
2110 assert!(matches!(
2111 &reports[1],
2112 (2, RecordedReport::Completed(_, id, _)) if id == &first_activity
2113 ));
2114 assert!(matches!(
2115 &reports[2],
2116 (2, RecordedReport::Completed(_, id, _)) if id == &second_activity
2117 ));
2118 let Err(error) = result else {
2119 return Err(WorkerError::decode(UnexpectedSuccess));
2120 };
2121 assert!(!error.is_retryable());
2122 assert!(matches!(
2123 error.grpc_status().map(tonic::Status::code),
2124 Some(tonic::Code::PermissionDenied)
2125 ));
2126 Ok(())
2127 }
2128
2129 #[tokio::test(start_paused = true)]
2130 async fn clean_close_loop_exhausts_drop_budget_with_classified_error() -> Result<(), WorkerError>
2131 {
2132 let worker = two_activity_worker()?;
2133 let attempts = Arc::new(AtomicUsize::new(0));
2134 let (log_sender, log_receiver) = mpsc::unbounded_channel();
2135 let connect = {
2136 let attempts = Arc::clone(&attempts);
2137 move || {
2138 let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
2139 let log = log_sender.clone();
2140 async move {
2141 Ok(ScriptedSession {
2142 index: attempt,
2143 log,
2144 events: Vec::new(),
2145 fail_reports: false,
2146 register_denial: None,
2147 delay_stream: None,
2148 })
2149 }
2150 }
2151 };
2152
2153 let run = worker.run_with_connector_until(connect, std::future::pending::<()>());
2154 let result = tokio::time::timeout(Duration::from_secs(5), run)
2155 .await
2156 .map_err(WorkerError::decode)?;
2157
2158 assert_eq!(attempts.load(Ordering::SeqCst), 3);
2163 let Err(error) = result else {
2164 return Err(WorkerError::decode(UnexpectedSuccess));
2165 };
2166 assert!(matches!(error, WorkerError::CleanCloseExhausted));
2167 assert!(error.to_string().contains("closed the stream cleanly"));
2168 drop(log_receiver);
2169 Ok(())
2170 }
2171
2172 #[tokio::test]
2173 async fn shutdown_during_clean_close_backoff_returns_ok_promptly() -> Result<(), WorkerError> {
2174 let worker = two_activity_worker_with(slow_reconnect_config())?;
2175 let attempts = Arc::new(AtomicUsize::new(0));
2176 let (log_sender, log_receiver) = mpsc::unbounded_channel();
2177 let connect = {
2178 let attempts = Arc::clone(&attempts);
2179 move || {
2180 attempts.fetch_add(1, Ordering::SeqCst);
2181 let log = log_sender.clone();
2182 async move {
2183 Ok(ScriptedSession {
2184 index: 1,
2185 log,
2186 events: Vec::new(),
2187 fail_reports: false,
2188 register_denial: None,
2189 delay_stream: None,
2190 })
2191 }
2192 }
2193 };
2194 let shutdown = async {
2195 tokio::time::sleep(Duration::from_millis(50)).await;
2196 };
2197
2198 let run = worker.run_with_connector_until(connect, shutdown);
2201 tokio::time::timeout(Duration::from_millis(500), run)
2202 .await
2203 .map_err(WorkerError::decode)??;
2204
2205 assert_eq!(attempts.load(Ordering::SeqCst), 1);
2206 drop(log_receiver);
2207 Ok(())
2208 }
2209
2210 #[tokio::test]
2214 async fn result_ack_clears_exactly_its_tracker_entry() -> Result<(), WorkerError> {
2215 use crate::protocol::reconnect::{PendingActivityReport, UnackedResultTracker};
2216 use crate::runtime::loop_::{SessionHealth, serve_activity_tasks_until};
2217
2218 let workflow_a = WorkflowId::new_v4();
2219 let workflow_b = WorkflowId::new_v4();
2220 let position = ActivityId::from_sequence_position(5);
2221 let mut tracker = UnackedResultTracker::new();
2222 for workflow in [&workflow_a, &workflow_b] {
2223 tracker.record(PendingActivityReport::Completed {
2224 workflow_id: workflow.clone(),
2225 activity_id: position.clone(),
2226 output: Payload::new(ContentType::Json, b"{\"value\":1}".to_vec()),
2227 });
2228 }
2229
2230 let worker = two_activity_worker()?;
2231 let mut session = ChannelSession {
2232 receiver: None,
2233 reports: Vec::new(),
2234 registered: Vec::new(),
2235 };
2236 let (sender, receiver) = mpsc::channel(4);
2237 sender
2238 .send(Ok(WorkerSessionEvent::ResultAck {
2239 workflow_id: workflow_a.clone(),
2240 activity_id: position.clone(),
2241 }))
2242 .await
2243 .map_err(WorkerError::decode)?;
2244 sender
2246 .send(Ok(WorkerSessionEvent::ResultAck {
2247 workflow_id: WorkflowId::new_v4(),
2248 activity_id: ActivityId::from_sequence_position(99),
2249 }))
2250 .await
2251 .map_err(WorkerError::decode)?;
2252 drop(sender);
2253 session.receiver = Some(receiver);
2254
2255 let mut health = SessionHealth::default();
2256 serve_activity_tasks_until(
2257 &test_config(),
2258 &mut session,
2259 Arc::new(crate::activity::ActivityRegistry::new()),
2260 &mut tracker,
2261 &mut health,
2262 std::future::pending(),
2263 )
2264 .await?;
2265
2266 assert_eq!(tracker.len(), 1, "exactly the acked entry must clear");
2267 assert!(tracker.get(&workflow_a, &position).is_none());
2268 assert!(tracker.get(&workflow_b, &position).is_some());
2269 drop(worker);
2270 Ok(())
2271 }
2272
    /// Acked results must drop out of the replay that `re_report_unacked` performs.
    ///
    /// Runs three phases against one shared tracker:
    /// 1. seed two `Completed` reports, then serve a stream that acks the first;
    /// 2. replay into a fresh session and expect only the second report, then
    ///    serve a stream that acks it too;
    /// 3. replay again and expect nothing to be sent.
    #[tokio::test]
    async fn acked_results_decay_out_of_the_reconnect_replay() -> Result<(), WorkerError> {
        use crate::protocol::re_report_unacked;
        use crate::protocol::reconnect::{PendingActivityReport, UnackedResultTracker};
        use crate::runtime::loop_::{SessionHealth, serve_activity_tasks_until};

        let workflow_id = WorkflowId::new_v4();
        let acked_id = ActivityId::from_sequence_position(1);
        let unacked_id = ActivityId::from_sequence_position(2);
        let mut tracker = UnackedResultTracker::new();
        // Seed the tracker with one pending report per activity id.
        for id in [&acked_id, &unacked_id] {
            tracker.record(PendingActivityReport::Completed {
                workflow_id: workflow_id.clone(),
                activity_id: id.clone(),
                output: Payload::new(ContentType::Json, b"{\"value\":2}".to_vec()),
            });
        }

        // Phase 1: a stream carrying a single ack for `acked_id`.
        let mut session = ChannelSession {
            receiver: None,
            reports: Vec::new(),
            registered: Vec::new(),
        };
        let (sender, receiver) = mpsc::channel(2);
        sender
            .send(Ok(WorkerSessionEvent::ResultAck {
                workflow_id: workflow_id.clone(),
                activity_id: acked_id.clone(),
            }))
            .await
            .map_err(WorkerError::decode)?;
        // Dropping the sender closes the channel once the queued ack is consumed.
        drop(sender);
        session.receiver = Some(receiver);
        let mut health = SessionHealth::default();
        serve_activity_tasks_until(
            &test_config(),
            &mut session,
            Arc::new(crate::activity::ActivityRegistry::new()),
            &mut tracker,
            &mut health,
            std::future::pending(),
        )
        .await?;

        // Phase 2: replay into a fresh session; only `unacked_id` may be sent.
        let mut replay_session = ChannelSession {
            receiver: None,
            reports: Vec::new(),
            registered: Vec::new(),
        };
        re_report_unacked(&tracker, &mut replay_session).await?;
        assert_eq!(
            replay_session.reports.len(),
            1,
            "only the un-acked result may be re-reported"
        );
        assert!(matches!(
            &replay_session.reports[0],
            RecordedReport::Completed(_, id, _) if id == &unacked_id
        ));

        // Ack the remaining entry through the replay session's stream.
        let (sender, receiver) = mpsc::channel(2);
        sender
            .send(Ok(WorkerSessionEvent::ResultAck {
                workflow_id: workflow_id.clone(),
                activity_id: unacked_id.clone(),
            }))
            .await
            .map_err(WorkerError::decode)?;
        drop(sender);
        replay_session.receiver = Some(receiver);
        let mut health = SessionHealth::default();
        serve_activity_tasks_until(
            &test_config(),
            &mut replay_session,
            Arc::new(crate::activity::ActivityRegistry::new()),
            &mut tracker,
            &mut health,
            std::future::pending(),
        )
        .await?;
        assert!(tracker.is_empty(), "acks must drain the tracker");

        // Phase 3: with the tracker empty, a further replay must send nothing.
        let mut decayed_session = ChannelSession {
            receiver: None,
            reports: Vec::new(),
            registered: Vec::new(),
        };
        re_report_unacked(&tracker, &mut decayed_session).await?;
        assert!(
            decayed_session.reports.is_empty(),
            "steady-state replay must send nothing"
        );
        Ok(())
    }
2375
    /// Shutdown must cut short a reconnect whose second session is a `HungReportSession`.
    ///
    /// Attempt 1 yields a scripted session that delivers one `double` task with
    /// `fail_reports: true`. Attempt 2 fires a oneshot that doubles as the
    /// shutdown signal and yields a `HungReportSession`. The test expects the
    /// run to return `Ok` within the 60 s timeout, with no report logged for
    /// session 2 and exactly two connect attempts. The test starts with
    /// tokio's clock paused.
    // NOTE(review): presumably `HungReportSession` never completes its report
    // call — confirm against its definition.
    #[tokio::test(start_paused = true)]
    async fn shutdown_interrupts_hung_unacked_replay_promptly() -> Result<(), WorkerError> {
        let workflow_id = WorkflowId::new_v4();
        let activity_id = ActivityId::from_sequence_position(3);
        let worker = two_activity_worker()?;
        let attempts = Arc::new(AtomicUsize::new(0));
        let (log_sender, mut log_receiver) = mpsc::unbounded_channel();
        let (registered_2_tx, registered_2_rx) = tokio::sync::oneshot::channel::<()>();
        // A oneshot sender is consumed by `send`; the `Mutex<Option<_>>` lets the
        // connect closure take it out through a shared reference.
        let registered_2_tx = std::sync::Mutex::new(Some(registered_2_tx));
        let connect = {
            let log_sender = log_sender.clone();
            let workflow_id = workflow_id.clone();
            let activity_id = activity_id.clone();
            move |attempt_override: usize| {
                let log = log_sender.clone();
                let task = proto_task(workflow_id.clone(), activity_id.clone(), "double", 21);
                // Only attempt 2 claims the sender; a poisoned lock yields `None`.
                let notify = if attempt_override == 2 {
                    registered_2_tx
                        .lock()
                        .ok()
                        .and_then(|mut guard| guard.take())
                } else {
                    None
                };
                async move {
                    if attempt_override == 1 {
                        Ok(SessionKind::Scripted(ScriptedSession {
                            index: 1,
                            log,
                            events: vec![Ok(WorkerSessionEvent::Task(task))],
                            fail_reports: true,
                            register_denial: None,
                            delay_stream: None,
                        }))
                    } else {
                        // Fire the shutdown signal; a send error is deliberately ignored.
                        if let Some(notify) = notify {
                            let _ = notify.send(());
                        }
                        Ok(SessionKind::Hung(HungReportSession { index: 2, log }))
                    }
                }
            }
        };

        // Wrap `connect` so the attempt number comes from the shared counter.
        let attempts_for_connect = Arc::clone(&attempts);
        let run = worker.run_with_connector_until(
            move || {
                let attempt = attempts_for_connect.fetch_add(1, Ordering::SeqCst) + 1;
                connect(attempt)
            },
            // Shutdown resolves once the oneshot fires or its sender is dropped.
            async move {
                let _ = registered_2_rx.await;
            },
        );

        // The outer `?` converts a timeout into a failure; the inner one propagates the run result.
        tokio::time::timeout(Duration::from_secs(60), run)
            .await
            .map_err(WorkerError::decode)??;

        // Release this function's own sender before draining the log channel.
        drop(log_sender);
        let mut hung_session_reports = 0_usize;
        while let Some(entry) = log_receiver.recv().await {
            if let SessionLog::Reported(2, _) = entry {
                hung_session_reports += 1;
            }
        }
        assert_eq!(
            hung_session_reports, 0,
            "the hung replay must not have produced a report"
        );
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
        Ok(())
    }
2455
2456 #[tokio::test(start_paused = true)]
2460 async fn drain_cycles_reconnect_without_consuming_drop_budget() -> Result<(), WorkerError> {
2461 let worker = two_activity_worker_with(test_config_with(ReconnectConfig::new(
2462 Duration::from_millis(5),
2463 Duration::from_millis(20),
2464 2,
2465 )))?;
2466 let attempts = Arc::new(AtomicUsize::new(0));
2467 let (log_sender, mut log_receiver) = mpsc::unbounded_channel();
2468 let connect = {
2469 let attempts = Arc::clone(&attempts);
2470 move || {
2471 let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
2472 let log = log_sender.clone();
2473 async move {
2474 if attempt <= 3 {
2475 Ok(ScriptedSession {
2476 index: attempt,
2477 log,
2478 events: vec![Ok(WorkerSessionEvent::Drain)],
2479 fail_reports: false,
2480 register_denial: None,
2481 delay_stream: None,
2482 })
2483 } else {
2484 Ok(ScriptedSession {
2485 index: attempt,
2486 log,
2487 events: Vec::new(),
2488 fail_reports: false,
2489 register_denial: Some(tonic::Status::permission_denied(
2490 "namespace `payments` revoked for subject `worker-a`",
2491 )),
2492 delay_stream: None,
2493 })
2494 }
2495 }
2496 }
2497 };
2498
2499 let result = worker
2500 .run_with_connector_until(connect, std::future::pending::<()>())
2501 .await;
2502
2503 assert_eq!(attempts.load(Ordering::SeqCst), 4);
2507 let Err(error) = result else {
2508 return Err(WorkerError::decode(UnexpectedSuccess));
2509 };
2510 assert!(matches!(
2511 error.grpc_status().map(tonic::Status::code),
2512 Some(tonic::Code::PermissionDenied)
2513 ));
2514 let mut registrations = 0_usize;
2515 while let Some(entry) = log_receiver.recv().await {
2516 if matches!(entry, SessionLog::Registered(..)) {
2517 registrations += 1;
2518 }
2519 }
2520 assert_eq!(registrations, 3, "every drain cycle must re-register");
2521 Ok(())
2522 }
2523
2524 #[tokio::test(start_paused = true)]
2530 async fn drain_latch_keeps_abrupt_post_drain_failures_unbudgeted() -> Result<(), WorkerError> {
2531 let workflow_id = WorkflowId::new_v4();
2532 let worker = Worker::builder(test_config_with(ReconnectConfig::new(
2537 Duration::from_millis(5),
2538 Duration::from_millis(20),
2539 2,
2540 )))
2541 .register_activity("slow_double", |input: TestInput, context| {
2542 Box::pin(async move {
2543 let _ = context;
2544 tokio::time::sleep(Duration::from_millis(1)).await;
2545 Ok(TestOutput {
2546 value: input.value * 2,
2547 })
2548 })
2549 })?
2550 .build()?;
2551 let attempts = Arc::new(AtomicUsize::new(0));
2552 let (log_sender, log_receiver) = mpsc::unbounded_channel();
2553 let connect = {
2554 let attempts = Arc::clone(&attempts);
2555 let workflow_id = workflow_id.clone();
2556 move || {
2557 let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
2558 let log = log_sender.clone();
2559 let attempt_u64 = u64::try_from(attempt).unwrap_or(u64::MAX);
2560 let activity_id = ActivityId::from_sequence_position(attempt_u64);
2561 let task = proto_task(workflow_id.clone(), activity_id.clone(), "slow_double", 21);
2562 async move {
2563 if attempt <= 3 {
2564 Ok(LatchKind::Latch(DrainLatchSession {
2565 events: vec![
2566 Ok(WorkerSessionEvent::Task(task)),
2567 Ok(WorkerSessionEvent::Drain),
2568 ],
2569 fail_id: activity_id,
2570 }))
2571 } else {
2572 Ok(LatchKind::Deny(ScriptedSession {
2573 index: attempt,
2574 log,
2575 events: Vec::new(),
2576 fail_reports: false,
2577 register_denial: Some(tonic::Status::permission_denied(
2578 "namespace `payments` revoked for subject `worker-a`",
2579 )),
2580 delay_stream: None,
2581 }))
2582 }
2583 }
2584 }
2585 };
2586
2587 let result = worker
2588 .run_with_connector_until(connect, std::future::pending::<()>())
2589 .await;
2590
2591 assert_eq!(attempts.load(Ordering::SeqCst), 4);
2594 let Err(error) = result else {
2595 return Err(WorkerError::decode(UnexpectedSuccess));
2596 };
2597 assert!(matches!(
2598 error.grpc_status().map(tonic::Status::code),
2599 Some(tonic::Code::PermissionDenied)
2600 ));
2601 drop(log_receiver);
2602 Ok(())
2603 }
2604
2605 #[tokio::test]
2609 async fn shutdown_during_post_drain_backoff_returns_ok_promptly() -> Result<(), WorkerError> {
2610 let worker = two_activity_worker_with(test_config_with(ReconnectConfig::new(
2611 Duration::from_secs(5),
2612 Duration::from_secs(10),
2613 5,
2614 )))?;
2615 let attempts = Arc::new(AtomicUsize::new(0));
2616 let (log_sender, log_receiver) = mpsc::unbounded_channel();
2617 let connect = {
2618 let attempts = Arc::clone(&attempts);
2619 move || {
2620 attempts.fetch_add(1, Ordering::SeqCst);
2621 let log = log_sender.clone();
2622 async move {
2623 Ok(ScriptedSession {
2624 index: 1,
2625 log,
2626 events: vec![Ok(WorkerSessionEvent::Drain)],
2627 fail_reports: false,
2628 register_denial: None,
2629 delay_stream: None,
2630 })
2631 }
2632 }
2633 };
2634 let shutdown = async {
2635 tokio::time::sleep(Duration::from_millis(50)).await;
2636 };
2637
2638 let run = worker.run_with_connector_until(connect, shutdown);
2641 tokio::time::timeout(Duration::from_millis(500), run)
2642 .await
2643 .map_err(WorkerError::decode)??;
2644
2645 assert_eq!(attempts.load(Ordering::SeqCst), 1);
2646 drop(log_receiver);
2647 Ok(())
2648 }
2649}