Skip to main content

aion_worker/protocol/
session.rs

1//! `WorkerSession` trait and gRPC-backed implementation.
2
3use std::collections::BTreeSet;
4use std::pin::Pin;
5
6use aion_core::{ActivityError, ActivityId, Payload, WorkflowId};
7use aion_proto::{
8    ProtoActivityId, ProtoActivityResult, ProtoActivityTask, ProtoHeartbeat, ProtoPayload,
9    ProtoWorkflowId, proto_activity_result,
10};
11use async_trait::async_trait;
12use futures::{Stream, StreamExt};
13use tokio::sync::mpsc;
14use tokio_stream::wrappers::ReceiverStream;
15use tonic::{Request, metadata::MetadataValue, transport::Channel};
16
17use crate::config::WorkerConfig;
18use crate::error::{MissingActivityHandler, WorkerError};
19
20type GeneratedClient = aion_proto::generated::worker_protocol_client::WorkerProtocolClient<Channel>;
21
22/// Boxed receive stream returned by worker sessions.
23pub type WorkerTaskStream =
24    Pin<Box<dyn Stream<Item = Result<WorkerSessionEvent, WorkerError>> + Send>>;
25
26/// Event pushed by the worker session receive stream.
27#[derive(Clone, Debug, PartialEq, Eq)]
28pub enum WorkerSessionEvent {
29    /// A new activity task to execute.
30    Task(ProtoActivityTask),
31    /// Server is draining and will not dispatch more activity tasks on this stream.
32    Drain,
33    /// Cooperative cancellation for an in-flight activity.
34    ///
35    /// The current AW worker proto in this worktree does not yet carry this
36    /// frame, but fake sessions can emit it and the runtime handles it without
37    /// forcing task termination. When AW lands the wire variant,
38    /// `decode_server_message` should map it to this event.
39    Cancel {
40        /// Workflow owning the activity.
41        workflow_id: WorkflowId,
42        /// Activity to mark cancelled.
43        activity_id: ActivityId,
44    },
45}
46
47/// Transport abstraction for the AW-owned worker protocol.
48///
49/// The current `aion-proto` worker endpoint is `WorkerProtocol::StreamWorker`,
50/// a single bidirectional gRPC stream. These methods intentionally present the
51/// worker conversation as handshake/register/receive/report/heartbeat phases so
52/// execution machinery can be tested against fakes and never touches generated
53/// stubs directly. If AW changes the wire shape, this trait adapts in this module.
54#[async_trait]
55pub trait WorkerSession: Send {
56    /// Performs the worker handshake for the configured task queue and identity.
57    ///
58    /// Maps to transport/channel establishment for AW's `StreamWorker` RPC.
59    /// AW currently names the task-queue scope `namespace` and has no identity
60    /// field, so identity is retained at this SDK boundary until the wire adds
61    /// a corresponding shape.
62    async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError>;
63
64    /// Registers activity-type names implemented by this worker.
65    ///
66    /// Maps to opening AW's `StreamWorker` RPC with `RegisterWorker` queued as
67    /// the mandatory first frame: the server reads that frame before it sends
68    /// response headers, and there is no registration-ack frame in the wire
69    /// protocol — header receipt is the registration outcome. The caller
70    /// supplies `available_handlers` so registration can be rejected before
71    /// serving if any requested name lacks a handler.
72    async fn register(
73        &mut self,
74        activity_types: Vec<String>,
75        available_handlers: &BTreeSet<String>,
76    ) -> Result<(), WorkerError>;
77
78    /// Opens the receive side of AW's `StreamWorker` RPC and yields pushed tasks.
79    fn receive_tasks(&mut self) -> WorkerTaskStream;
80
81    /// Reports successful activity output via `WorkerToServer.result`.
82    async fn report_result(
83        &mut self,
84        workflow_id: WorkflowId,
85        activity_id: ActivityId,
86        result: Payload,
87    ) -> Result<(), WorkerError>;
88
89    /// Reports explicit activity failure via `WorkerToServer.result`.
90    async fn report_failure(
91        &mut self,
92        workflow_id: WorkflowId,
93        activity_id: ActivityId,
94        failure: ActivityError,
95    ) -> Result<(), WorkerError>;
96
97    /// Sends cooperative progress via `WorkerToServer.heartbeat`.
98    async fn send_heartbeat(
99        &mut self,
100        workflow_id: WorkflowId,
101        activity_id: ActivityId,
102        progress: Option<Payload>,
103    ) -> Result<(), WorkerError>;
104}
105
106/// Validates that every requested activity type has a registered handler.
107///
108/// # Errors
109///
110/// Returns [`WorkerError::Registration`] for the first missing handler name.
111pub fn validate_activity_handlers(
112    activity_types: &[String],
113    available_handlers: &BTreeSet<String>,
114) -> Result<(), WorkerError> {
115    if let Some(activity_type) = activity_types
116        .iter()
117        .find(|activity_type| !available_handlers.contains(*activity_type))
118    {
119        return Err(WorkerError::registration(MissingActivityHandler {
120            activity_type: activity_type.clone(),
121        }));
122    }
123
124    Ok(())
125}
126
127/// gRPC-backed [`WorkerSession`] using `aion-proto` generated tonic stubs.
128pub struct GrpcWorkerSession {
129    config: WorkerConfig,
130    activity_types: Vec<String>,
131    client: Option<GeneratedClient>,
132    sender: Option<mpsc::Sender<aion_proto::generated::WorkerToServer>>,
133    receiver: Option<tonic::codec::Streaming<aion_proto::generated::ServerToWorker>>,
134}
135
136impl GrpcWorkerSession {
137    /// Connects to the configured worker endpoint.
138    ///
139    /// Opaque credentials are accepted by [`WorkerConfig`] but the current AW
140    /// worker proto does not define a credential metadata convention, so no
141    /// authentication scheme is interpreted here.
142    ///
143    /// # Errors
144    ///
145    /// Returns [`WorkerError::Connect`] if tonic cannot create the channel.
146    pub async fn connect(config: WorkerConfig) -> Result<Self, WorkerError> {
147        let client = GeneratedClient::connect(config.endpoint.clone())
148            .await
149            .map_err(|source| WorkerError::Connect { source })?;
150
151        Ok(Self {
152            config,
153            activity_types: Vec::new(),
154            client: Some(client),
155            sender: None,
156            receiver: None,
157        })
158    }
159
160    /// Creates a session from an existing tonic channel.
161    #[must_use]
162    pub fn from_channel(config: WorkerConfig, channel: Channel) -> Self {
163        Self {
164            config,
165            activity_types: Vec::new(),
166            client: Some(GeneratedClient::new(channel)),
167            sender: None,
168            receiver: None,
169        }
170    }
171
172    /// Opens AW's `StreamWorker` RPC with `RegisterWorker` queued as the first
173    /// outbound frame.
174    ///
175    /// The server reads `RegisterWorker` from the inbound stream *before* it
176    /// returns its response stream (and therefore before tonic receives
177    /// response headers), so the frame must already be queued when the RPC is
178    /// issued. Awaiting `stream_worker` before sending `RegisterWorker`
179    /// deadlocks: the client waits for headers the server withholds until it
180    /// has read the registration. There is no registration-ack frame in the
181    /// wire protocol — receiving response headers *is* the successful
182    /// registration outcome, and a denial surfaces as the RPC's error status.
183    async fn open_registered_stream(
184        &mut self,
185        register: aion_proto::generated::RegisterWorker,
186    ) -> Result<(), WorkerError> {
187        let client = self.client.as_mut().ok_or_else(|| {
188            WorkerError::registration(SessionStateError {
189                message: String::from("worker session has not completed its handshake"),
190            })
191        })?;
192        let (sender, outbound) = mpsc::channel(16);
193        sender
194            .try_send(aion_proto::generated::WorkerToServer {
195                message: Some(aion_proto::generated::worker_to_server::Message::Register(
196                    register,
197                )),
198            })
199            .map_err(|_| {
200                WorkerError::registration(SessionStateError {
201                    message: String::from(
202                        "could not queue RegisterWorker as the first stream frame",
203                    ),
204                })
205            })?;
206        let mut request = Request::new(ReceiverStream::new(outbound));
207        apply_auth_metadata(request.metadata_mut(), &self.config)?;
208        let response = client
209            .stream_worker(request)
210            .await
211            .map_err(registration_denial_error)?;
212
213        self.sender = Some(sender);
214        self.receiver = Some(response.into_inner());
215        Ok(())
216    }
217
218    async fn send_to_server(
219        &self,
220        message: aion_proto::generated::worker_to_server::Message,
221    ) -> Result<(), WorkerError> {
222        let sender = self.sender.as_ref().ok_or_else(|| {
223            WorkerError::registration(SessionStateError {
224                message: String::from("worker stream has not been opened"),
225            })
226        })?;
227        sender
228            .send(aion_proto::generated::WorkerToServer {
229                message: Some(message),
230            })
231            .await
232            .map_err(|source| WorkerError::Transport {
233                source: tonic::Status::unavailable(format!("worker stream send failed: {source}")),
234            })
235    }
236}
237
238/// Maps the `StreamWorker` RPC's rejection status to the worker error taxonomy.
239///
240/// The server validates stream metadata (credentials) and the `RegisterWorker`
241/// frame before returning response headers, so both failure classes surface
242/// from the same await: `Unauthenticated` is a credential/handshake rejection,
243/// everything else is a registration outcome (`PermissionDenied` for an
244/// ungranted namespace, `Unavailable` for transient transport faults). Both
245/// shapes preserve the status for `WorkerError::grpc_status` / `is_retryable`.
246fn registration_denial_error(status: tonic::Status) -> WorkerError {
247    if status.code() == tonic::Code::Unauthenticated {
248        WorkerError::Handshake { source: status }
249    } else {
250        WorkerError::Registration {
251            source: Box::new(status),
252        }
253    }
254}
255
256fn apply_auth_metadata(
257    metadata: &mut tonic::metadata::MetadataMap,
258    config: &WorkerConfig,
259) -> Result<(), WorkerError> {
260    let namespace =
261        MetadataValue::try_from(config.namespace.as_str()).map_err(|_| WorkerError::Handshake {
262            source: tonic::Status::invalid_argument("worker namespace is not valid gRPC metadata"),
263        })?;
264    let subject =
265        MetadataValue::try_from(config.subject.as_str()).map_err(|_| WorkerError::Handshake {
266            source: tonic::Status::invalid_argument("worker subject is not valid gRPC metadata"),
267        })?;
268    metadata.insert("x-aion-namespaces", namespace);
269    metadata.insert("x-aion-subject", subject);
270    Ok(())
271}
272
273#[async_trait]
274impl WorkerSession for GrpcWorkerSession {
275    async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
276        self.config = config.clone();
277        if self.client.is_none() {
278            self.client = Some(
279                GeneratedClient::connect(self.config.endpoint.clone())
280                    .await
281                    .map_err(|source| WorkerError::Connect { source })?,
282            );
283        }
284        Ok(())
285    }
286
287    async fn register(
288        &mut self,
289        activity_types: Vec<String>,
290        available_handlers: &BTreeSet<String>,
291    ) -> Result<(), WorkerError> {
292        validate_activity_handlers(&activity_types, available_handlers)?;
293        self.activity_types.clone_from(&activity_types);
294
295        let register = aion_proto::generated::RegisterWorker {
296            namespace: self.config.task_queue.clone(),
297            activity_types,
298        };
299        self.open_registered_stream(register).await
300    }
301
302    fn receive_tasks(&mut self) -> WorkerTaskStream {
303        match self.receiver.take() {
304            Some(receiver) => Box::pin(receiver.filter_map(|message| async move {
305                Some(match message {
306                    Ok(server_message) => decode_server_message(server_message),
307                    Err(source) => Err(WorkerError::Transport { source }),
308                })
309            })),
310            None => Box::pin(futures::stream::iter([Err(WorkerError::Transport {
311                source: tonic::Status::failed_precondition(
312                    "worker receive stream has not been opened",
313                ),
314            })])),
315        }
316    }
317
318    async fn report_result(
319        &mut self,
320        workflow_id: WorkflowId,
321        activity_id: ActivityId,
322        result: Payload,
323    ) -> Result<(), WorkerError> {
324        let result = ProtoActivityResult {
325            workflow_id: Some(ProtoWorkflowId::from(workflow_id)),
326            activity_id: Some(ProtoActivityId::from(activity_id)),
327            outcome: Some(proto_activity_result::Outcome::Result(ProtoPayload::from(
328                result,
329            ))),
330        };
331        self.send_to_server(aion_proto::generated::worker_to_server::Message::Result(
332            generated_activity_result(result),
333        ))
334        .await
335    }
336
337    async fn report_failure(
338        &mut self,
339        workflow_id: WorkflowId,
340        activity_id: ActivityId,
341        failure: ActivityError,
342    ) -> Result<(), WorkerError> {
343        let result = ProtoActivityResult {
344            workflow_id: Some(ProtoWorkflowId::from(workflow_id)),
345            activity_id: Some(ProtoActivityId::from(activity_id)),
346            outcome: Some(proto_activity_result::Outcome::Error(failure.into())),
347        };
348        self.send_to_server(aion_proto::generated::worker_to_server::Message::Result(
349            generated_activity_result(result),
350        ))
351        .await
352    }
353
354    async fn send_heartbeat(
355        &mut self,
356        workflow_id: WorkflowId,
357        activity_id: ActivityId,
358        progress: Option<Payload>,
359    ) -> Result<(), WorkerError> {
360        let heartbeat = ProtoHeartbeat {
361            workflow_id: Some(ProtoWorkflowId::from(workflow_id)),
362            activity_id: Some(ProtoActivityId::from(activity_id)),
363            progress: progress.map(ProtoPayload::from),
364        };
365        self.send_to_server(aion_proto::generated::worker_to_server::Message::Heartbeat(
366            generated_heartbeat(heartbeat),
367        ))
368        .await
369    }
370}
371
372fn decode_server_message(
373    message: aion_proto::generated::ServerToWorker,
374) -> Result<WorkerSessionEvent, WorkerError> {
375    match message.message {
376        Some(aion_proto::generated::server_to_worker::Message::Task(task)) => {
377            Ok(WorkerSessionEvent::Task(proto_task(task)))
378        }
379        Some(aion_proto::generated::server_to_worker::Message::Drain(_)) => {
380            Ok(WorkerSessionEvent::Drain)
381        }
382        None => Err(WorkerError::decode(SessionStateError {
383            message: String::from("server-to-worker message was empty"),
384        })),
385    }
386}
387
388fn generated_activity_result(value: ProtoActivityResult) -> aion_proto::generated::ActivityResult {
389    aion_proto::generated::ActivityResult {
390        workflow_id: value.workflow_id.map(generated_workflow_id),
391        activity_id: value.activity_id.map(generated_activity_id),
392        outcome: value.outcome.map(|outcome| match outcome {
393            proto_activity_result::Outcome::Result(result) => {
394                aion_proto::generated::activity_result::Outcome::Result(generated_payload(result))
395            }
396            proto_activity_result::Outcome::Error(error) => {
397                aion_proto::generated::activity_result::Outcome::Error(generated_error(error))
398            }
399        }),
400    }
401}
402
403fn generated_heartbeat(value: ProtoHeartbeat) -> aion_proto::generated::Heartbeat {
404    aion_proto::generated::Heartbeat {
405        workflow_id: value.workflow_id.map(generated_workflow_id),
406        activity_id: value.activity_id.map(generated_activity_id),
407        progress: value.progress.map(generated_payload),
408    }
409}
410
411fn proto_task(value: aion_proto::generated::ActivityTask) -> ProtoActivityTask {
412    ProtoActivityTask {
413        workflow_id: value.workflow_id.map(proto_workflow_id),
414        activity_id: value.activity_id.map(proto_activity_id),
415        activity_type: value.activity_type,
416        input: value.input.map(proto_payload),
417    }
418}
419
420fn generated_payload(value: ProtoPayload) -> aion_proto::generated::Payload {
421    aion_proto::generated::Payload {
422        content_type: value.content_type,
423        bytes: value.bytes,
424    }
425}
426
427fn proto_payload(value: aion_proto::generated::Payload) -> ProtoPayload {
428    ProtoPayload {
429        content_type: value.content_type,
430        bytes: value.bytes,
431    }
432}
433
434fn generated_workflow_id(value: ProtoWorkflowId) -> aion_proto::generated::WorkflowId {
435    aion_proto::generated::WorkflowId { uuid: value.uuid }
436}
437
438fn proto_workflow_id(value: aion_proto::generated::WorkflowId) -> ProtoWorkflowId {
439    ProtoWorkflowId { uuid: value.uuid }
440}
441
442fn generated_activity_id(value: ProtoActivityId) -> aion_proto::generated::ActivityId {
443    aion_proto::generated::ActivityId {
444        sequence_position: value.sequence_position,
445    }
446}
447
448fn proto_activity_id(value: aion_proto::generated::ActivityId) -> ProtoActivityId {
449    ProtoActivityId {
450        sequence_position: value.sequence_position,
451    }
452}
453
454fn generated_error(value: aion_proto::ProtoActivityError) -> aion_proto::generated::ActivityError {
455    aion_proto::generated::ActivityError {
456        kind: value.kind,
457        message: value.message,
458        details: value.details.map(generated_payload),
459    }
460}
461
462#[derive(thiserror::Error, Debug)]
463#[error("{message}")]
464struct SessionStateError {
465    message: String,
466}
467
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use aion_proto::ProtoActivityTask;
    use async_trait::async_trait;
    use futures::{StreamExt, stream};

    use super::{
        WorkerSession, WorkerSessionEvent, WorkerTaskStream, apply_auth_metadata,
        validate_activity_handlers,
    };
    use crate::error::WorkerError;
    use crate::{ReconnectConfig, WorkerConfig};

    /// In-memory session that records handshakes and registrations and serves
    /// a single canned task.
    #[derive(Default)]
    struct FakeSession {
        /// `(task_queue, identity)` pairs seen by `handshake`.
        handshakes: Vec<(String, String)>,
        /// Activity-type lists accepted by `register`.
        registrations: Vec<Vec<String>>,
    }

    #[async_trait]
    impl WorkerSession for FakeSession {
        async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
            let seen = (config.task_queue.clone(), config.identity.clone());
            self.handshakes.push(seen);
            Ok(())
        }

        async fn register(
            &mut self,
            activity_types: Vec<String>,
            available_handlers: &BTreeSet<String>,
        ) -> Result<(), WorkerError> {
            validate_activity_handlers(&activity_types, available_handlers)?;
            self.registrations.push(activity_types);
            Ok(())
        }

        fn receive_tasks(&mut self) -> WorkerTaskStream {
            let task = ProtoActivityTask {
                workflow_id: None,
                activity_id: None,
                activity_type: String::from("charge-card"),
                input: None,
            };
            Box::pin(stream::iter([Ok(WorkerSessionEvent::Task(task))]))
        }

        async fn report_result(
            &mut self,
            workflow_id: aion_core::WorkflowId,
            activity_id: aion_core::ActivityId,
            result: aion_core::Payload,
        ) -> Result<(), WorkerError> {
            // The fake ignores reports; consume the arguments explicitly.
            let _ = (workflow_id, activity_id, result);
            Ok(())
        }

        async fn report_failure(
            &mut self,
            workflow_id: aion_core::WorkflowId,
            activity_id: aion_core::ActivityId,
            failure: aion_core::ActivityError,
        ) -> Result<(), WorkerError> {
            let _ = (workflow_id, activity_id, failure);
            Ok(())
        }

        async fn send_heartbeat(
            &mut self,
            workflow_id: aion_core::WorkflowId,
            activity_id: aion_core::ActivityId,
            progress: Option<aion_core::Payload>,
        ) -> Result<(), WorkerError> {
            let _ = (workflow_id, activity_id, progress);
            Ok(())
        }
    }

    /// The namespace and subject from the config end up in the two auth headers.
    #[test]
    fn apply_auth_metadata_sets_worker_authorization_headers() -> Result<(), WorkerError> {
        let config = WorkerConfig::builder()
            .endpoint("http://127.0.0.1:50051")
            .task_queue("payments")
            .identity("worker-a")
            .max_concurrency(4)
            .reconnect_initial_backoff(std::time::Duration::from_millis(5))
            .reconnect_max_backoff(std::time::Duration::from_millis(20))
            .reconnect_max_attempts(3)
            .namespace("payments")
            .subject("worker-a")
            .build()
            .map_err(WorkerError::registration)?;
        let mut metadata = tonic::metadata::MetadataMap::new();

        apply_auth_metadata(&mut metadata, &config)?;

        let namespaces = metadata
            .get("x-aion-namespaces")
            .and_then(|value| value.to_str().ok());
        let subject = metadata
            .get("x-aion-subject")
            .and_then(|value| value.to_str().ok());
        assert_eq!(namespaces, Some("payments"));
        assert_eq!(subject, Some("worker-a"));
        Ok(())
    }

    /// The fake session records what it was given and serves its canned task.
    #[tokio::test]
    async fn fake_session_records_handshake_and_registration() -> Result<(), WorkerError> {
        let reconnect = ReconnectConfig::new(
            std::time::Duration::from_millis(5),
            std::time::Duration::from_millis(20),
            3,
        );
        let config = WorkerConfig::new(
            "http://127.0.0.1:50051",
            "payments",
            "worker-a",
            4,
            reconnect,
            None,
        );
        let activity_types = vec![String::from("charge-card"), String::from("send-email")];
        let handlers: BTreeSet<String> = activity_types.iter().cloned().collect();
        let mut session = FakeSession::default();

        session.handshake(&config).await?;
        session.register(activity_types.clone(), &handlers).await?;
        let received = session.receive_tasks().next().await;

        let expected_handshakes = vec![(String::from("payments"), String::from("worker-a"))];
        assert_eq!(session.handshakes, expected_handshakes);
        assert_eq!(session.registrations, vec![activity_types]);
        assert!(received.is_some());

        Ok(())
    }

    /// A requested activity type with no handler is rejected, naming the type.
    #[test]
    fn registration_rejects_activity_without_handler() {
        let activity_types = vec![String::from("charge-card"), String::from("send-email")];
        let handlers: BTreeSet<String> = [String::from("charge-card")].into_iter().collect();

        let outcome = validate_activity_handlers(&activity_types, &handlers);
        assert!(outcome.is_err());

        if let Err(error) = outcome {
            assert_eq!(
                error.to_string(),
                "worker registration failed: activity type `send-email` has no registered handler"
            );
        }
    }
}