1use std::collections::BTreeSet;
4use std::pin::Pin;
5
6use aion_core::{ActivityError, ActivityId, Payload, WorkflowId};
7use aion_proto::{
8 ProtoActivityId, ProtoActivityResult, ProtoActivityTask, ProtoHeartbeat, ProtoPayload,
9 ProtoWorkflowId, proto_activity_result,
10};
11use async_trait::async_trait;
12use futures::{Stream, StreamExt};
13use tokio::sync::mpsc;
14use tokio_stream::wrappers::ReceiverStream;
15use tonic::{Request, metadata::MetadataValue, transport::Channel};
16
17use crate::config::WorkerConfig;
18use crate::error::{MissingActivityHandler, WorkerError};
19
20type GeneratedClient = aion_proto::generated::worker_protocol_client::WorkerProtocolClient<Channel>;
21
22pub type WorkerTaskStream =
24 Pin<Box<dyn Stream<Item = Result<WorkerSessionEvent, WorkerError>> + Send>>;
25
26#[derive(Clone, Debug, PartialEq, Eq)]
28pub enum WorkerSessionEvent {
29 Task(ProtoActivityTask),
31 Drain,
37 ResultAck {
40 workflow_id: WorkflowId,
42 activity_id: ActivityId,
44 },
45 Cancel {
52 workflow_id: WorkflowId,
54 activity_id: ActivityId,
56 },
57}
58
59#[async_trait]
67pub trait WorkerSession: Send {
68 async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError>;
75
76 async fn register(
88 &mut self,
89 activity_types: Vec<String>,
90 available_handlers: &BTreeSet<String>,
91 ) -> Result<(), WorkerError>;
92
93 fn receive_tasks(&mut self) -> WorkerTaskStream;
95
96 async fn report_result(
98 &mut self,
99 workflow_id: WorkflowId,
100 activity_id: ActivityId,
101 result: Payload,
102 ) -> Result<(), WorkerError>;
103
104 async fn report_failure(
106 &mut self,
107 workflow_id: WorkflowId,
108 activity_id: ActivityId,
109 failure: ActivityError,
110 ) -> Result<(), WorkerError>;
111
112 async fn send_heartbeat(
114 &mut self,
115 workflow_id: WorkflowId,
116 activity_id: ActivityId,
117 progress: Option<Payload>,
118 ) -> Result<(), WorkerError>;
119}
120
121pub fn validate_activity_handlers(
127 activity_types: &[String],
128 available_handlers: &BTreeSet<String>,
129) -> Result<(), WorkerError> {
130 if let Some(activity_type) = activity_types
131 .iter()
132 .find(|activity_type| !available_handlers.contains(*activity_type))
133 {
134 return Err(WorkerError::registration(MissingActivityHandler {
135 activity_type: activity_type.clone(),
136 }));
137 }
138
139 Ok(())
140}
141
/// Values recorded from the server's `RegisterAck` frame once registration succeeds.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct RegisteredSessionInfo {
    /// `worker_id` field of the acknowledgement.
    pub worker_id: u64,
    /// `namespace` field of the acknowledgement.
    pub namespace: String,
    /// Built from the acknowledgement's `heartbeat_window_ms`, interpreted as milliseconds.
    pub heartbeat_window: std::time::Duration,
}
154
155pub struct GrpcWorkerSession {
157 config: WorkerConfig,
158 activity_types: Vec<String>,
159 client: Option<GeneratedClient>,
160 sender: Option<mpsc::Sender<aion_proto::generated::WorkerToServer>>,
161 receiver: Option<tonic::codec::Streaming<aion_proto::generated::ServerToWorker>>,
162 registered_info: Option<RegisteredSessionInfo>,
163}
164
165impl GrpcWorkerSession {
166 pub async fn connect(config: WorkerConfig) -> Result<Self, WorkerError> {
176 let client = GeneratedClient::connect(config.endpoint.clone())
177 .await
178 .map_err(|source| WorkerError::Connect { source })?;
179
180 Ok(Self {
181 config,
182 activity_types: Vec::new(),
183 client: Some(client),
184 sender: None,
185 receiver: None,
186 registered_info: None,
187 })
188 }
189
190 #[must_use]
192 pub fn from_channel(config: WorkerConfig, channel: Channel) -> Self {
193 Self {
194 config,
195 activity_types: Vec::new(),
196 client: Some(GeneratedClient::new(channel)),
197 sender: None,
198 receiver: None,
199 registered_info: None,
200 }
201 }
202
203 #[must_use]
206 pub const fn registered_info(&self) -> Option<&RegisteredSessionInfo> {
207 self.registered_info.as_ref()
208 }
209
210 async fn open_registered_stream(
227 &mut self,
228 register: aion_proto::generated::RegisterWorker,
229 ) -> Result<(), WorkerError> {
230 let client = self.client.as_mut().ok_or_else(|| {
231 WorkerError::registration(SessionStateError {
232 message: String::from("worker session has not completed its handshake"),
233 })
234 })?;
235 let (sender, outbound) = mpsc::channel(16);
236 sender
237 .try_send(aion_proto::generated::WorkerToServer {
238 message: Some(aion_proto::generated::worker_to_server::Message::Register(
239 register,
240 )),
241 })
242 .map_err(|_| {
243 WorkerError::registration(SessionStateError {
244 message: String::from(
245 "could not queue RegisterWorker as the first stream frame",
246 ),
247 })
248 })?;
249 let mut request = Request::new(ReceiverStream::new(outbound));
250 apply_auth_metadata(request.metadata_mut(), &self.config)?;
251 let response = client
252 .stream_worker(request)
253 .await
254 .map_err(registration_denial_error)?;
255 let mut receiver = response.into_inner();
256
257 let first = tokio::time::timeout(self.config.reconnect.max_backoff, receiver.message())
258 .await
259 .map_err(|_| {
260 WorkerError::registration(SessionStateError {
261 message: format!(
262 "server did not acknowledge registration within {:?}",
263 self.config.reconnect.max_backoff
264 ),
265 })
266 })?
267 .map_err(registration_denial_error)?;
268 let ack = match first.and_then(|frame| frame.message) {
269 Some(aion_proto::generated::server_to_worker::Message::RegisterAck(ack)) => ack,
270 Some(_) => {
271 return Err(WorkerError::decode(SessionStateError {
272 message: String::from(
273 "protocol violation: server sent a non-RegisterAck frame before \
274 acknowledging registration",
275 ),
276 }));
277 }
278 None => {
279 return Err(WorkerError::registration(SessionStateError {
280 message: String::from(
281 "server ended the stream before acknowledging registration",
282 ),
283 }));
284 }
285 };
286
287 self.registered_info = Some(RegisteredSessionInfo {
288 worker_id: ack.worker_id,
289 namespace: ack.namespace,
290 heartbeat_window: std::time::Duration::from_millis(ack.heartbeat_window_ms),
291 });
292 self.sender = Some(sender);
293 self.receiver = Some(receiver);
294 Ok(())
295 }
296
297 async fn send_to_server(
302 &self,
303 message: aion_proto::generated::worker_to_server::Message,
304 ) -> Result<(), WorkerError> {
305 let sender = self.sender.as_ref().ok_or_else(|| {
306 WorkerError::registration(SessionStateError {
307 message: String::from("worker stream has not been opened"),
308 })
309 })?;
310 let send = sender.send(aion_proto::generated::WorkerToServer {
311 message: Some(message),
312 });
313 tokio::time::timeout(self.config.reconnect.max_backoff, send)
314 .await
315 .map_err(|_| WorkerError::Transport {
316 source: tonic::Status::unavailable(format!(
317 "worker stream send did not complete within {:?}",
318 self.config.reconnect.max_backoff
319 )),
320 })?
321 .map_err(|source| WorkerError::Transport {
322 source: tonic::Status::unavailable(format!("worker stream send failed: {source}")),
323 })
324 }
325}
326
327fn registration_denial_error(status: tonic::Status) -> WorkerError {
336 if status.code() == tonic::Code::Unauthenticated {
337 WorkerError::Handshake { source: status }
338 } else {
339 WorkerError::Registration {
340 source: Box::new(status),
341 }
342 }
343}
344
345fn apply_auth_metadata(
346 metadata: &mut tonic::metadata::MetadataMap,
347 config: &WorkerConfig,
348) -> Result<(), WorkerError> {
349 let namespace =
350 MetadataValue::try_from(config.namespace.as_str()).map_err(|_| WorkerError::Handshake {
351 source: tonic::Status::invalid_argument("worker namespace is not valid gRPC metadata"),
352 })?;
353 let subject =
354 MetadataValue::try_from(config.subject.as_str()).map_err(|_| WorkerError::Handshake {
355 source: tonic::Status::invalid_argument("worker subject is not valid gRPC metadata"),
356 })?;
357 metadata.insert("x-aion-namespaces", namespace);
358 metadata.insert("x-aion-subject", subject);
359 Ok(())
360}
361
362#[async_trait]
363impl WorkerSession for GrpcWorkerSession {
364 async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
365 self.config = config.clone();
366 if self.client.is_none() {
367 self.client = Some(
368 GeneratedClient::connect(self.config.endpoint.clone())
369 .await
370 .map_err(|source| WorkerError::Connect { source })?,
371 );
372 }
373 Ok(())
374 }
375
376 async fn register(
377 &mut self,
378 activity_types: Vec<String>,
379 available_handlers: &BTreeSet<String>,
380 ) -> Result<(), WorkerError> {
381 validate_activity_handlers(&activity_types, available_handlers)?;
382 self.activity_types.clone_from(&activity_types);
383
384 let register = aion_proto::generated::RegisterWorker {
385 namespace: self.config.task_queue.clone(),
386 activity_types,
387 };
388 self.open_registered_stream(register).await
389 }
390
391 fn receive_tasks(&mut self) -> WorkerTaskStream {
392 match self.receiver.take() {
393 Some(receiver) => Box::pin(receiver.filter_map(|message| async move {
394 Some(match message {
395 Ok(server_message) => decode_server_message(server_message),
396 Err(source) => Err(WorkerError::Transport { source }),
397 })
398 })),
399 None => Box::pin(futures::stream::iter([Err(WorkerError::Transport {
400 source: tonic::Status::failed_precondition(
401 "worker receive stream has not been opened",
402 ),
403 })])),
404 }
405 }
406
407 async fn report_result(
408 &mut self,
409 workflow_id: WorkflowId,
410 activity_id: ActivityId,
411 result: Payload,
412 ) -> Result<(), WorkerError> {
413 let result = ProtoActivityResult {
414 workflow_id: Some(ProtoWorkflowId::from(workflow_id)),
415 activity_id: Some(ProtoActivityId::from(activity_id)),
416 outcome: Some(proto_activity_result::Outcome::Result(ProtoPayload::from(
417 result,
418 ))),
419 };
420 self.send_to_server(aion_proto::generated::worker_to_server::Message::Result(
421 generated_activity_result(result),
422 ))
423 .await
424 }
425
426 async fn report_failure(
427 &mut self,
428 workflow_id: WorkflowId,
429 activity_id: ActivityId,
430 failure: ActivityError,
431 ) -> Result<(), WorkerError> {
432 let result = ProtoActivityResult {
433 workflow_id: Some(ProtoWorkflowId::from(workflow_id)),
434 activity_id: Some(ProtoActivityId::from(activity_id)),
435 outcome: Some(proto_activity_result::Outcome::Error(failure.into())),
436 };
437 self.send_to_server(aion_proto::generated::worker_to_server::Message::Result(
438 generated_activity_result(result),
439 ))
440 .await
441 }
442
443 async fn send_heartbeat(
444 &mut self,
445 workflow_id: WorkflowId,
446 activity_id: ActivityId,
447 progress: Option<Payload>,
448 ) -> Result<(), WorkerError> {
449 let heartbeat = ProtoHeartbeat {
450 workflow_id: Some(ProtoWorkflowId::from(workflow_id)),
451 activity_id: Some(ProtoActivityId::from(activity_id)),
452 progress: progress.map(ProtoPayload::from),
453 };
454 self.send_to_server(aion_proto::generated::worker_to_server::Message::Heartbeat(
455 generated_heartbeat(heartbeat),
456 ))
457 .await
458 }
459}
460
461fn decode_server_message(
462 message: aion_proto::generated::ServerToWorker,
463) -> Result<WorkerSessionEvent, WorkerError> {
464 match message.message {
465 Some(aion_proto::generated::server_to_worker::Message::Task(task)) => {
466 Ok(WorkerSessionEvent::Task(proto_task(task)))
467 }
468 Some(aion_proto::generated::server_to_worker::Message::Drain(_)) => {
469 Ok(WorkerSessionEvent::Drain)
470 }
471 Some(aion_proto::generated::server_to_worker::Message::ResultAck(ack)) => {
472 decode_result_ack(ack)
473 }
474 Some(aion_proto::generated::server_to_worker::Message::RegisterAck(_)) => {
475 Err(WorkerError::decode(SessionStateError {
478 message: String::from(
479 "protocol violation: RegisterAck received after registration completed",
480 ),
481 }))
482 }
483 None => Err(WorkerError::decode(SessionStateError {
484 message: String::from("server-to-worker message was empty"),
485 })),
486 }
487}
488
489fn decode_result_ack(
490 ack: aion_proto::generated::ResultAck,
491) -> Result<WorkerSessionEvent, WorkerError> {
492 let workflow_id = ack
493 .workflow_id
494 .ok_or_else(|| {
495 WorkerError::decode(SessionStateError {
496 message: String::from("result ack workflow_id is missing"),
497 })
498 })
499 .and_then(|id| {
500 WorkflowId::try_from(ProtoWorkflowId { uuid: id.uuid }).map_err(|source| {
501 WorkerError::decode(SessionStateError {
502 message: format!("result ack workflow_id is invalid: {source}"),
503 })
504 })
505 })?;
506 let activity_id = ack
507 .activity_id
508 .map(|id| ActivityId::from_sequence_position(id.sequence_position))
509 .ok_or_else(|| {
510 WorkerError::decode(SessionStateError {
511 message: String::from("result ack activity_id is missing"),
512 })
513 })?;
514 Ok(WorkerSessionEvent::ResultAck {
515 workflow_id,
516 activity_id,
517 })
518}
519
520fn generated_activity_result(value: ProtoActivityResult) -> aion_proto::generated::ActivityResult {
521 aion_proto::generated::ActivityResult {
522 workflow_id: value.workflow_id.map(generated_workflow_id),
523 activity_id: value.activity_id.map(generated_activity_id),
524 outcome: value.outcome.map(|outcome| match outcome {
525 proto_activity_result::Outcome::Result(result) => {
526 aion_proto::generated::activity_result::Outcome::Result(generated_payload(result))
527 }
528 proto_activity_result::Outcome::Error(error) => {
529 aion_proto::generated::activity_result::Outcome::Error(generated_error(error))
530 }
531 }),
532 }
533}
534
535fn generated_heartbeat(value: ProtoHeartbeat) -> aion_proto::generated::Heartbeat {
536 aion_proto::generated::Heartbeat {
537 workflow_id: value.workflow_id.map(generated_workflow_id),
538 activity_id: value.activity_id.map(generated_activity_id),
539 progress: value.progress.map(generated_payload),
540 }
541}
542
543fn proto_task(value: aion_proto::generated::ActivityTask) -> ProtoActivityTask {
544 ProtoActivityTask {
545 workflow_id: value.workflow_id.map(proto_workflow_id),
546 activity_id: value.activity_id.map(proto_activity_id),
547 activity_type: value.activity_type,
548 input: value.input.map(proto_payload),
549 attempt: value.attempt,
550 }
551}
552
553fn generated_payload(value: ProtoPayload) -> aion_proto::generated::Payload {
554 aion_proto::generated::Payload {
555 content_type: value.content_type,
556 bytes: value.bytes,
557 }
558}
559
560fn proto_payload(value: aion_proto::generated::Payload) -> ProtoPayload {
561 ProtoPayload {
562 content_type: value.content_type,
563 bytes: value.bytes,
564 }
565}
566
567fn generated_workflow_id(value: ProtoWorkflowId) -> aion_proto::generated::WorkflowId {
568 aion_proto::generated::WorkflowId { uuid: value.uuid }
569}
570
571fn proto_workflow_id(value: aion_proto::generated::WorkflowId) -> ProtoWorkflowId {
572 ProtoWorkflowId { uuid: value.uuid }
573}
574
575fn generated_activity_id(value: ProtoActivityId) -> aion_proto::generated::ActivityId {
576 aion_proto::generated::ActivityId {
577 sequence_position: value.sequence_position,
578 }
579}
580
581fn proto_activity_id(value: aion_proto::generated::ActivityId) -> ProtoActivityId {
582 ProtoActivityId {
583 sequence_position: value.sequence_position,
584 }
585}
586
587fn generated_error(value: aion_proto::ProtoActivityError) -> aion_proto::generated::ActivityError {
588 aion_proto::generated::ActivityError {
589 kind: value.kind,
590 message: value.message,
591 details: value.details.map(generated_payload),
592 }
593}
594
595#[derive(thiserror::Error, Debug)]
596#[error("{message}")]
597struct SessionStateError {
598 message: String,
599}
600
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::time::Duration;

    use aion_proto::ProtoActivityTask;
    use async_trait::async_trait;
    use futures::{StreamExt, stream};

    use super::{
        WorkerSession, WorkerSessionEvent, WorkerTaskStream, apply_auth_metadata,
        validate_activity_handlers,
    };
    use crate::error::WorkerError;
    use crate::{ReconnectConfig, WorkerConfig};

    /// In-memory [`WorkerSession`] that records handshakes and registrations.
    #[derive(Default)]
    struct FakeSession {
        /// `(task_queue, identity)` of every config passed to `handshake`.
        handshakes: Vec<(String, String)>,
        /// Activity types of every successful `register` call.
        registrations: Vec<Vec<String>>,
    }

    #[async_trait]
    impl WorkerSession for FakeSession {
        async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
            let seen = (config.task_queue.clone(), config.identity.clone());
            self.handshakes.push(seen);
            Ok(())
        }

        async fn register(
            &mut self,
            activity_types: Vec<String>,
            available_handlers: &BTreeSet<String>,
        ) -> Result<(), WorkerError> {
            validate_activity_handlers(&activity_types, available_handlers)?;
            self.registrations.push(activity_types);
            Ok(())
        }

        fn receive_tasks(&mut self) -> WorkerTaskStream {
            let task = ProtoActivityTask {
                workflow_id: None,
                activity_id: None,
                activity_type: String::from("charge-card"),
                input: None,
                attempt: 1,
            };
            Box::pin(stream::iter([Ok(WorkerSessionEvent::Task(task))]))
        }

        async fn report_result(
            &mut self,
            workflow_id: aion_core::WorkflowId,
            activity_id: aion_core::ActivityId,
            result: aion_core::Payload,
        ) -> Result<(), WorkerError> {
            let _ = (workflow_id, activity_id, result);
            Ok(())
        }

        async fn report_failure(
            &mut self,
            workflow_id: aion_core::WorkflowId,
            activity_id: aion_core::ActivityId,
            failure: aion_core::ActivityError,
        ) -> Result<(), WorkerError> {
            let _ = (workflow_id, activity_id, failure);
            Ok(())
        }

        async fn send_heartbeat(
            &mut self,
            workflow_id: aion_core::WorkflowId,
            activity_id: aion_core::ActivityId,
            progress: Option<aion_core::Payload>,
        ) -> Result<(), WorkerError> {
            let _ = (workflow_id, activity_id, progress);
            Ok(())
        }
    }

    /// `apply_auth_metadata` must write both identity headers from the config.
    #[test]
    fn apply_auth_metadata_sets_worker_authorization_headers() -> Result<(), WorkerError> {
        let config = WorkerConfig::builder()
            .endpoint("http://127.0.0.1:50051")
            .task_queue("payments")
            .identity("worker-a")
            .max_concurrency(4)
            .reconnect_initial_backoff(Duration::from_millis(5))
            .reconnect_max_backoff(Duration::from_millis(20))
            .reconnect_max_attempts(3)
            .namespace("payments")
            .subject("worker-a")
            .build()
            .map_err(WorkerError::registration)?;
        let mut metadata = tonic::metadata::MetadataMap::new();

        apply_auth_metadata(&mut metadata, &config)?;

        let namespaces = metadata
            .get("x-aion-namespaces")
            .and_then(|value| value.to_str().ok());
        assert_eq!(namespaces, Some("payments"));

        let subject = metadata
            .get("x-aion-subject")
            .and_then(|value| value.to_str().ok());
        assert_eq!(subject, Some("worker-a"));
        Ok(())
    }

    /// The fake session records what it was given and yields one task.
    #[tokio::test]
    async fn fake_session_records_handshake_and_registration() -> Result<(), WorkerError> {
        let reconnect =
            ReconnectConfig::new(Duration::from_millis(5), Duration::from_millis(20), 3);
        let config = WorkerConfig::new(
            "http://127.0.0.1:50051",
            "payments",
            "worker-a",
            4,
            reconnect,
            None,
        );
        let activity_types = vec![String::from("charge-card"), String::from("send-email")];
        let handlers: BTreeSet<_> = activity_types.iter().cloned().collect();
        let mut session = FakeSession::default();

        session.handshake(&config).await?;
        session.register(activity_types.clone(), &handlers).await?;
        let received = session.receive_tasks().next().await;

        let expected_handshakes = vec![(String::from("payments"), String::from("worker-a"))];
        assert_eq!(session.handshakes, expected_handshakes);
        assert_eq!(session.registrations, vec![activity_types]);
        assert!(received.is_some());

        Ok(())
    }

    /// A send that cannot make progress must end as a retryable transport error once the
    /// max backoff elapses (the test runs with paused time).
    #[tokio::test(start_paused = true)]
    async fn report_send_times_out_retryably_at_max_backoff() -> Result<(), WorkerError> {
        let reconnect =
            ReconnectConfig::new(Duration::from_millis(5), Duration::from_millis(20), 3);
        let config = WorkerConfig::new(
            "http://127.0.0.1:50051",
            "payments",
            "worker-a",
            1,
            reconnect,
            None,
        );

        // Fill the single-slot channel so the next send has to wait.
        let (sender, receiver) = tokio::sync::mpsc::channel(1);
        sender
            .try_send(aion_proto::generated::WorkerToServer { message: None })
            .map_err(WorkerError::decode)?;
        let mut session = super::GrpcWorkerSession {
            config,
            activity_types: Vec::new(),
            client: None,
            sender: Some(sender),
            receiver: None,
            registered_info: None,
        };

        let workflow_id = aion_core::WorkflowId::new_v4();
        let activity_id = aion_core::ActivityId::from_sequence_position(1);
        let payload = aion_core::Payload::new(aion_core::ContentType::Json, b"{}".to_vec());
        let result = session
            .report_result(workflow_id, activity_id, payload)
            .await;

        let Err(error) = result else {
            return Err(WorkerError::Transport {
                source: tonic::Status::internal("a hung send must time out, not hang"),
            });
        };
        assert!(
            matches!(error, WorkerError::Transport { .. }),
            "send deadline elapse must be a retryable transport error: {error}"
        );
        assert!(error.is_retryable());
        assert!(
            error.to_string().contains("did not complete"),
            "the error must name the deadline: {error}"
        );
        // The receiver is kept alive until here so the channel stays open during the send.
        drop(receiver);
        Ok(())
    }

    /// Validation names the first activity type that lacks a handler.
    #[test]
    fn registration_rejects_activity_without_handler() {
        let activity_types = vec![String::from("charge-card"), String::from("send-email")];
        let handlers: BTreeSet<_> = [String::from("charge-card")].into_iter().collect();

        let result = validate_activity_handlers(&activity_types, &handlers);
        assert!(result.is_err());
        let Err(error) = result else {
            return;
        };

        assert_eq!(
            error.to_string(),
            "worker registration failed: activity type `send-email` has no registered handler"
        );
    }
}
825}