Skip to main content

aion_worker/protocol/
session.rs

1//! `WorkerSession` trait and gRPC-backed implementation.
2
3use std::collections::BTreeSet;
4use std::pin::Pin;
5
6use aion_core::{ActivityError, ActivityId, Payload, WorkflowId};
7use aion_proto::{
8    ProtoActivityId, ProtoActivityResult, ProtoActivityTask, ProtoHeartbeat, ProtoPayload,
9    ProtoWorkflowId, proto_activity_result,
10};
11use async_trait::async_trait;
12use futures::{Stream, StreamExt};
13use tokio::sync::mpsc;
14use tokio_stream::wrappers::ReceiverStream;
15use tonic::{Request, metadata::MetadataValue, transport::Channel};
16
17use crate::config::WorkerConfig;
18use crate::error::{MissingActivityHandler, WorkerError};
19
20type GeneratedClient = aion_proto::generated::worker_protocol_client::WorkerProtocolClient<Channel>;
21
22/// Boxed receive stream returned by worker sessions.
23pub type WorkerTaskStream =
24    Pin<Box<dyn Stream<Item = Result<WorkerSessionEvent, WorkerError>> + Send>>;
25
26/// Event pushed by the worker session receive stream.
27#[derive(Clone, Debug, PartialEq, Eq)]
28pub enum WorkerSessionEvent {
29    /// A new activity task to execute.
30    Task(ProtoActivityTask),
31    /// Server-initiated drain: the server is going away (restart, deploy,
32    /// rebalance). The worker finishes in-flight work, reports what it can,
33    /// stops expecting new tasks, and reconnects after the schedule's initial
34    /// backoff. A drain frame latches for the session: the eventual stream
35    /// end — clean or abrupt — is drain-class and consumes no drop budget.
36    Drain,
37    /// The server consumed the identified `ActivityResult` frame; the worker
38    /// may stop re-reporting it. Clears the matching unacked-tracker entry.
39    ResultAck {
40        /// Workflow owning the acknowledged result.
41        workflow_id: WorkflowId,
42        /// Activity whose result was acknowledged.
43        activity_id: ActivityId,
44    },
45    /// Cooperative cancellation for an in-flight activity.
46    ///
47    /// The current AW worker proto in this worktree does not yet carry this
48    /// frame, but fake sessions can emit it and the runtime handles it without
49    /// forcing task termination. When AW lands the wire variant,
50    /// `decode_server_message` should map it to this event.
51    Cancel {
52        /// Workflow owning the activity.
53        workflow_id: WorkflowId,
54        /// Activity to mark cancelled.
55        activity_id: ActivityId,
56    },
57}
58
59/// Transport abstraction for the AW-owned worker protocol.
60///
61/// The current `aion-proto` worker endpoint is `WorkerProtocol::StreamWorker`,
62/// a single bidirectional gRPC stream. These methods intentionally present the
63/// worker conversation as handshake/register/receive/report/heartbeat phases so
64/// execution machinery can be tested against fakes and never touches generated
65/// stubs directly. If AW changes the wire shape, this trait adapts in this module.
66#[async_trait]
67pub trait WorkerSession: Send {
68    /// Performs the worker handshake for the configured task queue and identity.
69    ///
70    /// Maps to transport/channel establishment for AW's `StreamWorker` RPC.
71    /// AW currently names the task-queue scope `namespace` and has no identity
72    /// field, so identity is retained at this SDK boundary until the wire adds
73    /// a corresponding shape.
74    async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError>;
75
76    /// Registers activity-type names implemented by this worker.
77    ///
78    /// Maps to opening AW's `StreamWorker` RPC with `RegisterWorker` queued as
79    /// the mandatory first frame and then awaiting the server's `RegisterAck`
80    /// — the guaranteed first frame on the response stream. Registration
81    /// succeeds only when the ack arrives; a denial fails the RPC with a gRPC
82    /// error status (`PermissionDenied` / `Unauthenticated`), and an ack that
83    /// does not arrive within the reconnect policy's `max_backoff` is a
84    /// retryable registration failure. The caller supplies
85    /// `available_handlers` so registration can be rejected before serving if
86    /// any requested name lacks a handler.
87    async fn register(
88        &mut self,
89        activity_types: Vec<String>,
90        available_handlers: &BTreeSet<String>,
91    ) -> Result<(), WorkerError>;
92
93    /// Opens the receive side of AW's `StreamWorker` RPC and yields pushed tasks.
94    fn receive_tasks(&mut self) -> WorkerTaskStream;
95
96    /// Reports successful activity output via `WorkerToServer.result`.
97    async fn report_result(
98        &mut self,
99        workflow_id: WorkflowId,
100        activity_id: ActivityId,
101        result: Payload,
102    ) -> Result<(), WorkerError>;
103
104    /// Reports explicit activity failure via `WorkerToServer.result`.
105    async fn report_failure(
106        &mut self,
107        workflow_id: WorkflowId,
108        activity_id: ActivityId,
109        failure: ActivityError,
110    ) -> Result<(), WorkerError>;
111
112    /// Sends cooperative progress via `WorkerToServer.heartbeat`.
113    async fn send_heartbeat(
114        &mut self,
115        workflow_id: WorkflowId,
116        activity_id: ActivityId,
117        progress: Option<Payload>,
118    ) -> Result<(), WorkerError>;
119}
120
121/// Validates that every requested activity type has a registered handler.
122///
123/// # Errors
124///
125/// Returns [`WorkerError::Registration`] for the first missing handler name.
126pub fn validate_activity_handlers(
127    activity_types: &[String],
128    available_handlers: &BTreeSet<String>,
129) -> Result<(), WorkerError> {
130    if let Some(activity_type) = activity_types
131        .iter()
132        .find(|activity_type| !available_handlers.contains(*activity_type))
133    {
134        return Err(WorkerError::registration(MissingActivityHandler {
135            activity_type: activity_type.clone(),
136        }));
137    }
138
139    Ok(())
140}
141
/// Facts the server assigns at registration time, as delivered by the
/// `RegisterAck` frame.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RegisteredSessionInfo {
    /// Stream identifier chosen by the server; lets worker logs be correlated
    /// with server logs (`worker_id=3 lost`).
    pub worker_id: u64,
    /// Namespace the registration was authorized against.
    pub namespace: String,
    /// Operator-configured liveness window on the server: an in-flight
    /// activity has to heartbeat at least this often or it is declared lost.
    pub heartbeat_window: std::time::Duration,
}
154
155/// gRPC-backed [`WorkerSession`] using `aion-proto` generated tonic stubs.
156pub struct GrpcWorkerSession {
157    config: WorkerConfig,
158    activity_types: Vec<String>,
159    client: Option<GeneratedClient>,
160    sender: Option<mpsc::Sender<aion_proto::generated::WorkerToServer>>,
161    receiver: Option<tonic::codec::Streaming<aion_proto::generated::ServerToWorker>>,
162    registered_info: Option<RegisteredSessionInfo>,
163}
164
165impl GrpcWorkerSession {
166    /// Connects to the configured worker endpoint.
167    ///
168    /// Opaque credentials are accepted by [`WorkerConfig`] but the current AW
169    /// worker proto does not define a credential metadata convention, so no
170    /// authentication scheme is interpreted here.
171    ///
172    /// # Errors
173    ///
174    /// Returns [`WorkerError::Connect`] if tonic cannot create the channel.
175    pub async fn connect(config: WorkerConfig) -> Result<Self, WorkerError> {
176        let client = GeneratedClient::connect(config.endpoint.clone())
177            .await
178            .map_err(|source| WorkerError::Connect { source })?;
179
180        Ok(Self {
181            config,
182            activity_types: Vec::new(),
183            client: Some(client),
184            sender: None,
185            receiver: None,
186            registered_info: None,
187        })
188    }
189
190    /// Creates a session from an existing tonic channel.
191    #[must_use]
192    pub fn from_channel(config: WorkerConfig, channel: Channel) -> Self {
193        Self {
194            config,
195            activity_types: Vec::new(),
196            client: Some(GeneratedClient::new(channel)),
197            sender: None,
198            receiver: None,
199            registered_info: None,
200        }
201    }
202
203    /// Server-assigned registration facts from the `RegisterAck`, available
204    /// once [`WorkerSession::register`] has succeeded.
205    #[must_use]
206    pub const fn registered_info(&self) -> Option<&RegisteredSessionInfo> {
207        self.registered_info.as_ref()
208    }
209
210    /// Opens AW's `StreamWorker` RPC with `RegisterWorker` queued as the first
211    /// outbound frame and awaits the server's `RegisterAck`.
212    ///
213    /// The server reads `RegisterWorker` from the inbound stream *before* it
214    /// returns its response stream (and therefore before tonic receives
215    /// response headers), so the frame must already be queued when the RPC is
216    /// issued. Awaiting `stream_worker` before sending `RegisterWorker`
217    /// deadlocks: the client waits for headers the server withholds until it
218    /// has read the registration.
219    ///
220    /// Registration succeeds only when the server's `RegisterAck` — its
221    /// guaranteed first response frame — arrives. The ack wait is bounded by
222    /// the reconnect policy's `max_backoff` (the operator's own definition of
223    /// the longest tolerable pause); a timeout, a non-ack first frame, or a
224    /// stream that ends before the ack is a retryable registration failure.
225    /// Denials surface as the RPC's gRPC error status exactly as before.
226    async fn open_registered_stream(
227        &mut self,
228        register: aion_proto::generated::RegisterWorker,
229    ) -> Result<(), WorkerError> {
230        let client = self.client.as_mut().ok_or_else(|| {
231            WorkerError::registration(SessionStateError {
232                message: String::from("worker session has not completed its handshake"),
233            })
234        })?;
235        let (sender, outbound) = mpsc::channel(16);
236        sender
237            .try_send(aion_proto::generated::WorkerToServer {
238                message: Some(aion_proto::generated::worker_to_server::Message::Register(
239                    register,
240                )),
241            })
242            .map_err(|_| {
243                WorkerError::registration(SessionStateError {
244                    message: String::from(
245                        "could not queue RegisterWorker as the first stream frame",
246                    ),
247                })
248            })?;
249        let mut request = Request::new(ReceiverStream::new(outbound));
250        apply_auth_metadata(request.metadata_mut(), &self.config)?;
251        let response = client
252            .stream_worker(request)
253            .await
254            .map_err(registration_denial_error)?;
255        let mut receiver = response.into_inner();
256
257        let first = tokio::time::timeout(self.config.reconnect.max_backoff, receiver.message())
258            .await
259            .map_err(|_| {
260                WorkerError::registration(SessionStateError {
261                    message: format!(
262                        "server did not acknowledge registration within {:?}",
263                        self.config.reconnect.max_backoff
264                    ),
265                })
266            })?
267            .map_err(registration_denial_error)?;
268        let ack = match first.and_then(|frame| frame.message) {
269            Some(aion_proto::generated::server_to_worker::Message::RegisterAck(ack)) => ack,
270            Some(_) => {
271                return Err(WorkerError::decode(SessionStateError {
272                    message: String::from(
273                        "protocol violation: server sent a non-RegisterAck frame before \
274                         acknowledging registration",
275                    ),
276                }));
277            }
278            None => {
279                return Err(WorkerError::registration(SessionStateError {
280                    message: String::from(
281                        "server ended the stream before acknowledging registration",
282                    ),
283                }));
284            }
285        };
286
287        self.registered_info = Some(RegisteredSessionInfo {
288            worker_id: ack.worker_id,
289            namespace: ack.namespace,
290            heartbeat_window: std::time::Duration::from_millis(ack.heartbeat_window_ms),
291        });
292        self.sender = Some(sender);
293        self.receiver = Some(receiver);
294        Ok(())
295    }
296
297    /// Sends one frame with a per-send deadline of the reconnect policy's
298    /// `max_backoff`: a send that outlives the operator's longest tolerable
299    /// pause is, by that same definition, a dead session and surfaces as a
300    /// retryable transport error instead of hanging the worker forever.
301    async fn send_to_server(
302        &self,
303        message: aion_proto::generated::worker_to_server::Message,
304    ) -> Result<(), WorkerError> {
305        let sender = self.sender.as_ref().ok_or_else(|| {
306            WorkerError::registration(SessionStateError {
307                message: String::from("worker stream has not been opened"),
308            })
309        })?;
310        let send = sender.send(aion_proto::generated::WorkerToServer {
311            message: Some(message),
312        });
313        tokio::time::timeout(self.config.reconnect.max_backoff, send)
314            .await
315            .map_err(|_| WorkerError::Transport {
316                source: tonic::Status::unavailable(format!(
317                    "worker stream send did not complete within {:?}",
318                    self.config.reconnect.max_backoff
319                )),
320            })?
321            .map_err(|source| WorkerError::Transport {
322                source: tonic::Status::unavailable(format!("worker stream send failed: {source}")),
323            })
324    }
325}
326
327/// Maps the `StreamWorker` RPC's rejection status to the worker error taxonomy.
328///
329/// The server validates stream metadata (credentials) and the `RegisterWorker`
330/// frame before returning response headers, so both failure classes surface
331/// from the same await: `Unauthenticated` is a credential/handshake rejection,
332/// everything else is a registration outcome (`PermissionDenied` for an
333/// ungranted namespace, `Unavailable` for transient transport faults). Both
334/// shapes preserve the status for `WorkerError::grpc_status` / `is_retryable`.
335fn registration_denial_error(status: tonic::Status) -> WorkerError {
336    if status.code() == tonic::Code::Unauthenticated {
337        WorkerError::Handshake { source: status }
338    } else {
339        WorkerError::Registration {
340            source: Box::new(status),
341        }
342    }
343}
344
345fn apply_auth_metadata(
346    metadata: &mut tonic::metadata::MetadataMap,
347    config: &WorkerConfig,
348) -> Result<(), WorkerError> {
349    let namespace =
350        MetadataValue::try_from(config.namespace.as_str()).map_err(|_| WorkerError::Handshake {
351            source: tonic::Status::invalid_argument("worker namespace is not valid gRPC metadata"),
352        })?;
353    let subject =
354        MetadataValue::try_from(config.subject.as_str()).map_err(|_| WorkerError::Handshake {
355            source: tonic::Status::invalid_argument("worker subject is not valid gRPC metadata"),
356        })?;
357    metadata.insert("x-aion-namespaces", namespace);
358    metadata.insert("x-aion-subject", subject);
359    Ok(())
360}
361
362#[async_trait]
363impl WorkerSession for GrpcWorkerSession {
364    async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
365        self.config = config.clone();
366        if self.client.is_none() {
367            self.client = Some(
368                GeneratedClient::connect(self.config.endpoint.clone())
369                    .await
370                    .map_err(|source| WorkerError::Connect { source })?,
371            );
372        }
373        Ok(())
374    }
375
376    async fn register(
377        &mut self,
378        activity_types: Vec<String>,
379        available_handlers: &BTreeSet<String>,
380    ) -> Result<(), WorkerError> {
381        validate_activity_handlers(&activity_types, available_handlers)?;
382        self.activity_types.clone_from(&activity_types);
383
384        let register = aion_proto::generated::RegisterWorker {
385            namespace: self.config.task_queue.clone(),
386            activity_types,
387        };
388        self.open_registered_stream(register).await
389    }
390
391    fn receive_tasks(&mut self) -> WorkerTaskStream {
392        match self.receiver.take() {
393            Some(receiver) => Box::pin(receiver.filter_map(|message| async move {
394                Some(match message {
395                    Ok(server_message) => decode_server_message(server_message),
396                    Err(source) => Err(WorkerError::Transport { source }),
397                })
398            })),
399            None => Box::pin(futures::stream::iter([Err(WorkerError::Transport {
400                source: tonic::Status::failed_precondition(
401                    "worker receive stream has not been opened",
402                ),
403            })])),
404        }
405    }
406
407    async fn report_result(
408        &mut self,
409        workflow_id: WorkflowId,
410        activity_id: ActivityId,
411        result: Payload,
412    ) -> Result<(), WorkerError> {
413        let result = ProtoActivityResult {
414            workflow_id: Some(ProtoWorkflowId::from(workflow_id)),
415            activity_id: Some(ProtoActivityId::from(activity_id)),
416            outcome: Some(proto_activity_result::Outcome::Result(ProtoPayload::from(
417                result,
418            ))),
419        };
420        self.send_to_server(aion_proto::generated::worker_to_server::Message::Result(
421            generated_activity_result(result),
422        ))
423        .await
424    }
425
426    async fn report_failure(
427        &mut self,
428        workflow_id: WorkflowId,
429        activity_id: ActivityId,
430        failure: ActivityError,
431    ) -> Result<(), WorkerError> {
432        let result = ProtoActivityResult {
433            workflow_id: Some(ProtoWorkflowId::from(workflow_id)),
434            activity_id: Some(ProtoActivityId::from(activity_id)),
435            outcome: Some(proto_activity_result::Outcome::Error(failure.into())),
436        };
437        self.send_to_server(aion_proto::generated::worker_to_server::Message::Result(
438            generated_activity_result(result),
439        ))
440        .await
441    }
442
443    async fn send_heartbeat(
444        &mut self,
445        workflow_id: WorkflowId,
446        activity_id: ActivityId,
447        progress: Option<Payload>,
448    ) -> Result<(), WorkerError> {
449        let heartbeat = ProtoHeartbeat {
450            workflow_id: Some(ProtoWorkflowId::from(workflow_id)),
451            activity_id: Some(ProtoActivityId::from(activity_id)),
452            progress: progress.map(ProtoPayload::from),
453        };
454        self.send_to_server(aion_proto::generated::worker_to_server::Message::Heartbeat(
455            generated_heartbeat(heartbeat),
456        ))
457        .await
458    }
459}
460
461fn decode_server_message(
462    message: aion_proto::generated::ServerToWorker,
463) -> Result<WorkerSessionEvent, WorkerError> {
464    match message.message {
465        Some(aion_proto::generated::server_to_worker::Message::Task(task)) => {
466            Ok(WorkerSessionEvent::Task(proto_task(task)))
467        }
468        Some(aion_proto::generated::server_to_worker::Message::Drain(_)) => {
469            Ok(WorkerSessionEvent::Drain)
470        }
471        Some(aion_proto::generated::server_to_worker::Message::ResultAck(ack)) => {
472            decode_result_ack(ack)
473        }
474        Some(aion_proto::generated::server_to_worker::Message::RegisterAck(_)) => {
475            // The ack is consumed inside `open_registered_stream`; a second
476            // one mid-stream is a server ordering bug that must surface.
477            Err(WorkerError::decode(SessionStateError {
478                message: String::from(
479                    "protocol violation: RegisterAck received after registration completed",
480                ),
481            }))
482        }
483        None => Err(WorkerError::decode(SessionStateError {
484            message: String::from("server-to-worker message was empty"),
485        })),
486    }
487}
488
489fn decode_result_ack(
490    ack: aion_proto::generated::ResultAck,
491) -> Result<WorkerSessionEvent, WorkerError> {
492    let workflow_id = ack
493        .workflow_id
494        .ok_or_else(|| {
495            WorkerError::decode(SessionStateError {
496                message: String::from("result ack workflow_id is missing"),
497            })
498        })
499        .and_then(|id| {
500            WorkflowId::try_from(ProtoWorkflowId { uuid: id.uuid }).map_err(|source| {
501                WorkerError::decode(SessionStateError {
502                    message: format!("result ack workflow_id is invalid: {source}"),
503                })
504            })
505        })?;
506    let activity_id = ack
507        .activity_id
508        .map(|id| ActivityId::from_sequence_position(id.sequence_position))
509        .ok_or_else(|| {
510            WorkerError::decode(SessionStateError {
511                message: String::from("result ack activity_id is missing"),
512            })
513        })?;
514    Ok(WorkerSessionEvent::ResultAck {
515        workflow_id,
516        activity_id,
517    })
518}
519
520fn generated_activity_result(value: ProtoActivityResult) -> aion_proto::generated::ActivityResult {
521    aion_proto::generated::ActivityResult {
522        workflow_id: value.workflow_id.map(generated_workflow_id),
523        activity_id: value.activity_id.map(generated_activity_id),
524        outcome: value.outcome.map(|outcome| match outcome {
525            proto_activity_result::Outcome::Result(result) => {
526                aion_proto::generated::activity_result::Outcome::Result(generated_payload(result))
527            }
528            proto_activity_result::Outcome::Error(error) => {
529                aion_proto::generated::activity_result::Outcome::Error(generated_error(error))
530            }
531        }),
532    }
533}
534
535fn generated_heartbeat(value: ProtoHeartbeat) -> aion_proto::generated::Heartbeat {
536    aion_proto::generated::Heartbeat {
537        workflow_id: value.workflow_id.map(generated_workflow_id),
538        activity_id: value.activity_id.map(generated_activity_id),
539        progress: value.progress.map(generated_payload),
540    }
541}
542
543fn proto_task(value: aion_proto::generated::ActivityTask) -> ProtoActivityTask {
544    ProtoActivityTask {
545        workflow_id: value.workflow_id.map(proto_workflow_id),
546        activity_id: value.activity_id.map(proto_activity_id),
547        activity_type: value.activity_type,
548        input: value.input.map(proto_payload),
549        attempt: value.attempt,
550    }
551}
552
553fn generated_payload(value: ProtoPayload) -> aion_proto::generated::Payload {
554    aion_proto::generated::Payload {
555        content_type: value.content_type,
556        bytes: value.bytes,
557    }
558}
559
560fn proto_payload(value: aion_proto::generated::Payload) -> ProtoPayload {
561    ProtoPayload {
562        content_type: value.content_type,
563        bytes: value.bytes,
564    }
565}
566
567fn generated_workflow_id(value: ProtoWorkflowId) -> aion_proto::generated::WorkflowId {
568    aion_proto::generated::WorkflowId { uuid: value.uuid }
569}
570
571fn proto_workflow_id(value: aion_proto::generated::WorkflowId) -> ProtoWorkflowId {
572    ProtoWorkflowId { uuid: value.uuid }
573}
574
575fn generated_activity_id(value: ProtoActivityId) -> aion_proto::generated::ActivityId {
576    aion_proto::generated::ActivityId {
577        sequence_position: value.sequence_position,
578    }
579}
580
581fn proto_activity_id(value: aion_proto::generated::ActivityId) -> ProtoActivityId {
582    ProtoActivityId {
583        sequence_position: value.sequence_position,
584    }
585}
586
587fn generated_error(value: aion_proto::ProtoActivityError) -> aion_proto::generated::ActivityError {
588    aion_proto::generated::ActivityError {
589        kind: value.kind,
590        message: value.message,
591        details: value.details.map(generated_payload),
592    }
593}
594
595#[derive(thiserror::Error, Debug)]
596#[error("{message}")]
597struct SessionStateError {
598    message: String,
599}
600
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::time::Duration;

    use aion_proto::ProtoActivityTask;
    use async_trait::async_trait;
    use futures::{StreamExt, stream};

    use super::{
        WorkerSession, WorkerSessionEvent, WorkerTaskStream, apply_auth_metadata,
        validate_activity_handlers,
    };
    use crate::error::WorkerError;
    use crate::{ReconnectConfig, WorkerConfig};

    /// Reconnect policy shared by the positional-constructor tests: 5 ms
    /// initial backoff, 20 ms maximum backoff, three attempts.
    fn short_reconnect() -> ReconnectConfig {
        ReconnectConfig::new(Duration::from_millis(5), Duration::from_millis(20), 3)
    }

    /// In-memory session that records what it was asked to do and serves one
    /// canned task.
    #[derive(Default)]
    struct FakeSession {
        handshakes: Vec<(String, String)>,
        registrations: Vec<Vec<String>>,
    }

    #[async_trait]
    impl WorkerSession for FakeSession {
        async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
            let seen = (config.task_queue.clone(), config.identity.clone());
            self.handshakes.push(seen);
            Ok(())
        }

        async fn register(
            &mut self,
            activity_types: Vec<String>,
            available_handlers: &BTreeSet<String>,
        ) -> Result<(), WorkerError> {
            validate_activity_handlers(&activity_types, available_handlers)?;
            self.registrations.push(activity_types);
            Ok(())
        }

        fn receive_tasks(&mut self) -> WorkerTaskStream {
            let canned_task = ProtoActivityTask {
                workflow_id: None,
                activity_id: None,
                activity_type: String::from("charge-card"),
                input: None,
                attempt: 1,
            };
            Box::pin(stream::iter([Ok(WorkerSessionEvent::Task(canned_task))]))
        }

        async fn report_result(
            &mut self,
            workflow_id: aion_core::WorkflowId,
            activity_id: aion_core::ActivityId,
            result: aion_core::Payload,
        ) -> Result<(), WorkerError> {
            drop((workflow_id, activity_id, result));
            Ok(())
        }

        async fn report_failure(
            &mut self,
            workflow_id: aion_core::WorkflowId,
            activity_id: aion_core::ActivityId,
            failure: aion_core::ActivityError,
        ) -> Result<(), WorkerError> {
            drop((workflow_id, activity_id, failure));
            Ok(())
        }

        async fn send_heartbeat(
            &mut self,
            workflow_id: aion_core::WorkflowId,
            activity_id: aion_core::ActivityId,
            progress: Option<aion_core::Payload>,
        ) -> Result<(), WorkerError> {
            drop((workflow_id, activity_id, progress));
            Ok(())
        }
    }

    /// The namespace and subject from the config end up in the
    /// `x-aion-namespaces` / `x-aion-subject` headers.
    #[test]
    fn apply_auth_metadata_sets_worker_authorization_headers() -> Result<(), WorkerError> {
        let config = WorkerConfig::builder()
            .endpoint("http://127.0.0.1:50051")
            .task_queue("payments")
            .identity("worker-a")
            .max_concurrency(4)
            .reconnect_initial_backoff(Duration::from_millis(5))
            .reconnect_max_backoff(Duration::from_millis(20))
            .reconnect_max_attempts(3)
            .namespace("payments")
            .subject("worker-a")
            .build()
            .map_err(WorkerError::registration)?;
        let mut headers = tonic::metadata::MetadataMap::new();

        apply_auth_metadata(&mut headers, &config)?;

        let namespaces = headers
            .get("x-aion-namespaces")
            .and_then(|value| value.to_str().ok());
        assert_eq!(namespaces, Some("payments"));

        let subject = headers
            .get("x-aion-subject")
            .and_then(|value| value.to_str().ok());
        assert_eq!(subject, Some("worker-a"));
        Ok(())
    }

    /// A fake session can be driven through handshake, registration, and one
    /// receive using nothing but the trait surface.
    #[tokio::test]
    async fn fake_session_records_handshake_and_registration() -> Result<(), WorkerError> {
        let config = WorkerConfig::new(
            "http://127.0.0.1:50051",
            "payments",
            "worker-a",
            4,
            short_reconnect(),
            None,
        );
        let activity_types = vec![String::from("charge-card"), String::from("send-email")];
        let handlers: BTreeSet<String> = activity_types.iter().cloned().collect();
        let mut session = FakeSession::default();

        session.handshake(&config).await?;
        session.register(activity_types.clone(), &handlers).await?;
        let first_event = session.receive_tasks().next().await;

        let expected_handshakes = vec![(String::from("payments"), String::from("worker-a"))];
        assert_eq!(session.handshakes, expected_handshakes);
        assert_eq!(session.registrations, vec![activity_types]);
        assert!(first_event.is_some());

        Ok(())
    }

    /// Brief test 16: a report send that never completes (server stopped
    /// reading; outbound channel full) times out retryably at the reconnect
    /// policy's `max_backoff` on a paused clock — the worker never hangs.
    #[tokio::test(start_paused = true)]
    async fn report_send_times_out_retryably_at_max_backoff() -> Result<(), WorkerError> {
        let config = WorkerConfig::new(
            "http://127.0.0.1:50051",
            "payments",
            "worker-a",
            1,
            short_reconnect(),
            None,
        );
        let (sender, receiver) = tokio::sync::mpsc::channel(1);
        // Occupy the channel's only slot so the next send can never complete,
        // modelling a server that stopped draining its receive side.
        sender
            .try_send(aion_proto::generated::WorkerToServer { message: None })
            .map_err(WorkerError::decode)?;
        let mut session = super::GrpcWorkerSession {
            config,
            activity_types: Vec::new(),
            client: None,
            sender: Some(sender),
            receiver: None,
            registered_info: None,
        };

        let outcome = session
            .report_result(
                aion_core::WorkflowId::new_v4(),
                aion_core::ActivityId::from_sequence_position(1),
                aion_core::Payload::new(aion_core::ContentType::Json, b"{}".to_vec()),
            )
            .await;

        let Err(error) = outcome else {
            return Err(WorkerError::Transport {
                source: tonic::Status::internal("a hung send must time out, not hang"),
            });
        };
        assert!(
            matches!(error, WorkerError::Transport { .. }),
            "send deadline elapse must be a retryable transport error: {error}"
        );
        assert!(error.is_retryable());
        assert!(
            error.to_string().contains("did not complete"),
            "the error must name the deadline: {error}"
        );
        // Dropped only now: the receiver has to stay alive for the whole test
        // so the channel stays full rather than closed.
        drop(receiver);
        Ok(())
    }

    /// Registration names the first requested activity type that lacks a
    /// handler.
    #[test]
    fn registration_rejects_activity_without_handler() {
        let activity_types = vec![String::from("charge-card"), String::from("send-email")];
        let handlers = BTreeSet::from([String::from("charge-card")]);

        let result = validate_activity_handlers(&activity_types, &handlers);
        assert!(result.is_err());
        let Err(error) = result else {
            return;
        };

        assert_eq!(
            error.to_string(),
            "worker registration failed: activity type `send-email` has no registered handler"
        );
    }
}