1use std::collections::BTreeSet;
4use std::pin::Pin;
5
6use aion_core::{ActivityError, ActivityId, Payload, WorkflowId};
7use aion_proto::{
8 ProtoActivityId, ProtoActivityResult, ProtoActivityTask, ProtoHeartbeat, ProtoPayload,
9 ProtoWorkflowId, proto_activity_result,
10};
11use async_trait::async_trait;
12use futures::{Stream, StreamExt};
13use tokio::sync::mpsc;
14use tokio_stream::wrappers::ReceiverStream;
15use tonic::{Request, metadata::MetadataValue, transport::Channel};
16
17use crate::config::WorkerConfig;
18use crate::error::{MissingActivityHandler, WorkerError};
19
20type GeneratedClient = aion_proto::generated::worker_protocol_client::WorkerProtocolClient<Channel>;
21
22pub type WorkerTaskStream =
24 Pin<Box<dyn Stream<Item = Result<WorkerSessionEvent, WorkerError>> + Send>>;
25
26#[derive(Clone, Debug, PartialEq, Eq)]
28pub enum WorkerSessionEvent {
29 Task(ProtoActivityTask),
31 Drain,
33 Cancel {
40 workflow_id: WorkflowId,
42 activity_id: ActivityId,
44 },
45}
46
47#[async_trait]
55pub trait WorkerSession: Send {
56 async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError>;
63
64 async fn register(
73 &mut self,
74 activity_types: Vec<String>,
75 available_handlers: &BTreeSet<String>,
76 ) -> Result<(), WorkerError>;
77
78 fn receive_tasks(&mut self) -> WorkerTaskStream;
80
81 async fn report_result(
83 &mut self,
84 workflow_id: WorkflowId,
85 activity_id: ActivityId,
86 result: Payload,
87 ) -> Result<(), WorkerError>;
88
89 async fn report_failure(
91 &mut self,
92 workflow_id: WorkflowId,
93 activity_id: ActivityId,
94 failure: ActivityError,
95 ) -> Result<(), WorkerError>;
96
97 async fn send_heartbeat(
99 &mut self,
100 workflow_id: WorkflowId,
101 activity_id: ActivityId,
102 progress: Option<Payload>,
103 ) -> Result<(), WorkerError>;
104}
105
106pub fn validate_activity_handlers(
112 activity_types: &[String],
113 available_handlers: &BTreeSet<String>,
114) -> Result<(), WorkerError> {
115 if let Some(activity_type) = activity_types
116 .iter()
117 .find(|activity_type| !available_handlers.contains(*activity_type))
118 {
119 return Err(WorkerError::registration(MissingActivityHandler {
120 activity_type: activity_type.clone(),
121 }));
122 }
123
124 Ok(())
125}
126
127pub struct GrpcWorkerSession {
129 config: WorkerConfig,
130 activity_types: Vec<String>,
131 client: Option<GeneratedClient>,
132 sender: Option<mpsc::Sender<aion_proto::generated::WorkerToServer>>,
133 receiver: Option<tonic::codec::Streaming<aion_proto::generated::ServerToWorker>>,
134}
135
136impl GrpcWorkerSession {
137 pub async fn connect(config: WorkerConfig) -> Result<Self, WorkerError> {
147 let client = GeneratedClient::connect(config.endpoint.clone())
148 .await
149 .map_err(|source| WorkerError::Connect { source })?;
150
151 Ok(Self {
152 config,
153 activity_types: Vec::new(),
154 client: Some(client),
155 sender: None,
156 receiver: None,
157 })
158 }
159
160 #[must_use]
162 pub fn from_channel(config: WorkerConfig, channel: Channel) -> Self {
163 Self {
164 config,
165 activity_types: Vec::new(),
166 client: Some(GeneratedClient::new(channel)),
167 sender: None,
168 receiver: None,
169 }
170 }
171
172 async fn open_registered_stream(
184 &mut self,
185 register: aion_proto::generated::RegisterWorker,
186 ) -> Result<(), WorkerError> {
187 let client = self.client.as_mut().ok_or_else(|| {
188 WorkerError::registration(SessionStateError {
189 message: String::from("worker session has not completed its handshake"),
190 })
191 })?;
192 let (sender, outbound) = mpsc::channel(16);
193 sender
194 .try_send(aion_proto::generated::WorkerToServer {
195 message: Some(aion_proto::generated::worker_to_server::Message::Register(
196 register,
197 )),
198 })
199 .map_err(|_| {
200 WorkerError::registration(SessionStateError {
201 message: String::from(
202 "could not queue RegisterWorker as the first stream frame",
203 ),
204 })
205 })?;
206 let mut request = Request::new(ReceiverStream::new(outbound));
207 apply_auth_metadata(request.metadata_mut(), &self.config)?;
208 let response = client
209 .stream_worker(request)
210 .await
211 .map_err(registration_denial_error)?;
212
213 self.sender = Some(sender);
214 self.receiver = Some(response.into_inner());
215 Ok(())
216 }
217
218 async fn send_to_server(
219 &self,
220 message: aion_proto::generated::worker_to_server::Message,
221 ) -> Result<(), WorkerError> {
222 let sender = self.sender.as_ref().ok_or_else(|| {
223 WorkerError::registration(SessionStateError {
224 message: String::from("worker stream has not been opened"),
225 })
226 })?;
227 sender
228 .send(aion_proto::generated::WorkerToServer {
229 message: Some(message),
230 })
231 .await
232 .map_err(|source| WorkerError::Transport {
233 source: tonic::Status::unavailable(format!("worker stream send failed: {source}")),
234 })
235 }
236}
237
238fn registration_denial_error(status: tonic::Status) -> WorkerError {
247 if status.code() == tonic::Code::Unauthenticated {
248 WorkerError::Handshake { source: status }
249 } else {
250 WorkerError::Registration {
251 source: Box::new(status),
252 }
253 }
254}
255
256fn apply_auth_metadata(
257 metadata: &mut tonic::metadata::MetadataMap,
258 config: &WorkerConfig,
259) -> Result<(), WorkerError> {
260 let namespace =
261 MetadataValue::try_from(config.namespace.as_str()).map_err(|_| WorkerError::Handshake {
262 source: tonic::Status::invalid_argument("worker namespace is not valid gRPC metadata"),
263 })?;
264 let subject =
265 MetadataValue::try_from(config.subject.as_str()).map_err(|_| WorkerError::Handshake {
266 source: tonic::Status::invalid_argument("worker subject is not valid gRPC metadata"),
267 })?;
268 metadata.insert("x-aion-namespaces", namespace);
269 metadata.insert("x-aion-subject", subject);
270 Ok(())
271}
272
273#[async_trait]
274impl WorkerSession for GrpcWorkerSession {
275 async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
276 self.config = config.clone();
277 if self.client.is_none() {
278 self.client = Some(
279 GeneratedClient::connect(self.config.endpoint.clone())
280 .await
281 .map_err(|source| WorkerError::Connect { source })?,
282 );
283 }
284 Ok(())
285 }
286
287 async fn register(
288 &mut self,
289 activity_types: Vec<String>,
290 available_handlers: &BTreeSet<String>,
291 ) -> Result<(), WorkerError> {
292 validate_activity_handlers(&activity_types, available_handlers)?;
293 self.activity_types.clone_from(&activity_types);
294
295 let register = aion_proto::generated::RegisterWorker {
296 namespace: self.config.task_queue.clone(),
297 activity_types,
298 };
299 self.open_registered_stream(register).await
300 }
301
302 fn receive_tasks(&mut self) -> WorkerTaskStream {
303 match self.receiver.take() {
304 Some(receiver) => Box::pin(receiver.filter_map(|message| async move {
305 Some(match message {
306 Ok(server_message) => decode_server_message(server_message),
307 Err(source) => Err(WorkerError::Transport { source }),
308 })
309 })),
310 None => Box::pin(futures::stream::iter([Err(WorkerError::Transport {
311 source: tonic::Status::failed_precondition(
312 "worker receive stream has not been opened",
313 ),
314 })])),
315 }
316 }
317
318 async fn report_result(
319 &mut self,
320 workflow_id: WorkflowId,
321 activity_id: ActivityId,
322 result: Payload,
323 ) -> Result<(), WorkerError> {
324 let result = ProtoActivityResult {
325 workflow_id: Some(ProtoWorkflowId::from(workflow_id)),
326 activity_id: Some(ProtoActivityId::from(activity_id)),
327 outcome: Some(proto_activity_result::Outcome::Result(ProtoPayload::from(
328 result,
329 ))),
330 };
331 self.send_to_server(aion_proto::generated::worker_to_server::Message::Result(
332 generated_activity_result(result),
333 ))
334 .await
335 }
336
337 async fn report_failure(
338 &mut self,
339 workflow_id: WorkflowId,
340 activity_id: ActivityId,
341 failure: ActivityError,
342 ) -> Result<(), WorkerError> {
343 let result = ProtoActivityResult {
344 workflow_id: Some(ProtoWorkflowId::from(workflow_id)),
345 activity_id: Some(ProtoActivityId::from(activity_id)),
346 outcome: Some(proto_activity_result::Outcome::Error(failure.into())),
347 };
348 self.send_to_server(aion_proto::generated::worker_to_server::Message::Result(
349 generated_activity_result(result),
350 ))
351 .await
352 }
353
354 async fn send_heartbeat(
355 &mut self,
356 workflow_id: WorkflowId,
357 activity_id: ActivityId,
358 progress: Option<Payload>,
359 ) -> Result<(), WorkerError> {
360 let heartbeat = ProtoHeartbeat {
361 workflow_id: Some(ProtoWorkflowId::from(workflow_id)),
362 activity_id: Some(ProtoActivityId::from(activity_id)),
363 progress: progress.map(ProtoPayload::from),
364 };
365 self.send_to_server(aion_proto::generated::worker_to_server::Message::Heartbeat(
366 generated_heartbeat(heartbeat),
367 ))
368 .await
369 }
370}
371
372fn decode_server_message(
373 message: aion_proto::generated::ServerToWorker,
374) -> Result<WorkerSessionEvent, WorkerError> {
375 match message.message {
376 Some(aion_proto::generated::server_to_worker::Message::Task(task)) => {
377 Ok(WorkerSessionEvent::Task(proto_task(task)))
378 }
379 Some(aion_proto::generated::server_to_worker::Message::Drain(_)) => {
380 Ok(WorkerSessionEvent::Drain)
381 }
382 None => Err(WorkerError::decode(SessionStateError {
383 message: String::from("server-to-worker message was empty"),
384 })),
385 }
386}
387
388fn generated_activity_result(value: ProtoActivityResult) -> aion_proto::generated::ActivityResult {
389 aion_proto::generated::ActivityResult {
390 workflow_id: value.workflow_id.map(generated_workflow_id),
391 activity_id: value.activity_id.map(generated_activity_id),
392 outcome: value.outcome.map(|outcome| match outcome {
393 proto_activity_result::Outcome::Result(result) => {
394 aion_proto::generated::activity_result::Outcome::Result(generated_payload(result))
395 }
396 proto_activity_result::Outcome::Error(error) => {
397 aion_proto::generated::activity_result::Outcome::Error(generated_error(error))
398 }
399 }),
400 }
401}
402
403fn generated_heartbeat(value: ProtoHeartbeat) -> aion_proto::generated::Heartbeat {
404 aion_proto::generated::Heartbeat {
405 workflow_id: value.workflow_id.map(generated_workflow_id),
406 activity_id: value.activity_id.map(generated_activity_id),
407 progress: value.progress.map(generated_payload),
408 }
409}
410
411fn proto_task(value: aion_proto::generated::ActivityTask) -> ProtoActivityTask {
412 ProtoActivityTask {
413 workflow_id: value.workflow_id.map(proto_workflow_id),
414 activity_id: value.activity_id.map(proto_activity_id),
415 activity_type: value.activity_type,
416 input: value.input.map(proto_payload),
417 }
418}
419
420fn generated_payload(value: ProtoPayload) -> aion_proto::generated::Payload {
421 aion_proto::generated::Payload {
422 content_type: value.content_type,
423 bytes: value.bytes,
424 }
425}
426
427fn proto_payload(value: aion_proto::generated::Payload) -> ProtoPayload {
428 ProtoPayload {
429 content_type: value.content_type,
430 bytes: value.bytes,
431 }
432}
433
434fn generated_workflow_id(value: ProtoWorkflowId) -> aion_proto::generated::WorkflowId {
435 aion_proto::generated::WorkflowId { uuid: value.uuid }
436}
437
438fn proto_workflow_id(value: aion_proto::generated::WorkflowId) -> ProtoWorkflowId {
439 ProtoWorkflowId { uuid: value.uuid }
440}
441
442fn generated_activity_id(value: ProtoActivityId) -> aion_proto::generated::ActivityId {
443 aion_proto::generated::ActivityId {
444 sequence_position: value.sequence_position,
445 }
446}
447
448fn proto_activity_id(value: aion_proto::generated::ActivityId) -> ProtoActivityId {
449 ProtoActivityId {
450 sequence_position: value.sequence_position,
451 }
452}
453
454fn generated_error(value: aion_proto::ProtoActivityError) -> aion_proto::generated::ActivityError {
455 aion_proto::generated::ActivityError {
456 kind: value.kind,
457 message: value.message,
458 details: value.details.map(generated_payload),
459 }
460}
461
462#[derive(thiserror::Error, Debug)]
463#[error("{message}")]
464struct SessionStateError {
465 message: String,
466}
467
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use aion_proto::ProtoActivityTask;
    use async_trait::async_trait;
    use futures::{StreamExt, stream};

    use super::{
        WorkerSession, WorkerSessionEvent, WorkerTaskStream, apply_auth_metadata,
        validate_activity_handlers,
    };
    use crate::error::WorkerError;
    use crate::{ReconnectConfig, WorkerConfig};

    /// In-memory [`WorkerSession`]. It records handshakes and registrations and yields a
    /// single canned task.
    #[derive(Default)]
    struct FakeSession {
        /// `(task_queue, identity)` pairs captured from each handshake config.
        handshakes: Vec<(String, String)>,
        /// Activity type lists from each registration that passed validation.
        registrations: Vec<Vec<String>>,
    }

    #[async_trait]
    impl WorkerSession for FakeSession {
        async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
            self.handshakes
                .push((config.task_queue.clone(), config.identity.clone()));
            Ok(())
        }

        async fn register(
            &mut self,
            activity_types: Vec<String>,
            available_handlers: &BTreeSet<String>,
        ) -> Result<(), WorkerError> {
            validate_activity_handlers(&activity_types, available_handlers)?;
            self.registrations.push(activity_types);
            Ok(())
        }

        fn receive_tasks(&mut self) -> WorkerTaskStream {
            Box::pin(stream::iter([Ok(WorkerSessionEvent::Task(
                ProtoActivityTask {
                    workflow_id: None,
                    activity_id: None,
                    activity_type: String::from("charge-card"),
                    input: None,
                },
            ))]))
        }

        async fn report_result(
            &mut self,
            workflow_id: aion_core::WorkflowId,
            activity_id: aion_core::ActivityId,
            result: aion_core::Payload,
        ) -> Result<(), WorkerError> {
            // Explicitly consume the arguments; the fake does not report anything.
            drop((workflow_id, activity_id, result));
            Ok(())
        }

        async fn report_failure(
            &mut self,
            workflow_id: aion_core::WorkflowId,
            activity_id: aion_core::ActivityId,
            failure: aion_core::ActivityError,
        ) -> Result<(), WorkerError> {
            drop((workflow_id, activity_id, failure));
            Ok(())
        }

        async fn send_heartbeat(
            &mut self,
            workflow_id: aion_core::WorkflowId,
            activity_id: aion_core::ActivityId,
            progress: Option<aion_core::Payload>,
        ) -> Result<(), WorkerError> {
            drop((workflow_id, activity_id, progress));
            Ok(())
        }
    }

    #[test]
    fn apply_auth_metadata_sets_worker_authorization_headers() -> Result<(), WorkerError> {
        let config = WorkerConfig::builder()
            .endpoint("http://127.0.0.1:50051")
            .task_queue("payments")
            .identity("worker-a")
            .max_concurrency(4)
            .reconnect_initial_backoff(std::time::Duration::from_millis(5))
            .reconnect_max_backoff(std::time::Duration::from_millis(20))
            .reconnect_max_attempts(3)
            .namespace("payments")
            .subject("worker-a")
            .build()
            .map_err(WorkerError::registration)?;
        let mut metadata = tonic::metadata::MetadataMap::new();

        apply_auth_metadata(&mut metadata, &config)?;

        assert_eq!(
            metadata
                .get("x-aion-namespaces")
                .and_then(|value| value.to_str().ok()),
            Some("payments")
        );
        assert_eq!(
            metadata
                .get("x-aion-subject")
                .and_then(|value| value.to_str().ok()),
            Some("worker-a")
        );
        Ok(())
    }

    #[tokio::test]
    async fn fake_session_records_handshake_and_registration() -> Result<(), WorkerError> {
        let config = WorkerConfig::new(
            "http://127.0.0.1:50051",
            "payments",
            "worker-a",
            4,
            ReconnectConfig::new(
                std::time::Duration::from_millis(5),
                std::time::Duration::from_millis(20),
                3,
            ),
            None,
        );
        let activity_types = vec![String::from("charge-card"), String::from("send-email")];
        let handlers = activity_types.iter().cloned().collect::<BTreeSet<_>>();
        let mut session = FakeSession::default();

        session.handshake(&config).await?;
        session.register(activity_types.clone(), &handlers).await?;
        let received = session.receive_tasks().next().await;

        assert_eq!(
            session.handshakes,
            vec![(String::from("payments"), String::from("worker-a"))]
        );
        assert_eq!(session.registrations, vec![activity_types]);
        assert!(received.is_some());

        Ok(())
    }

    #[test]
    fn registration_rejects_activity_without_handler() {
        let activity_types = vec![String::from("charge-card"), String::from("send-email")];
        let handlers = [String::from("charge-card")]
            .into_iter()
            .collect::<BTreeSet<_>>();

        // `.err()` maps an unexpected `Ok` to `None`. The single assertion below therefore
        // checks both that validation failed and which activity it reported, with no
        // unreachable early-return branch.
        let message = validate_activity_handlers(&activity_types, &handlers)
            .err()
            .map(|error| error.to_string());

        assert_eq!(
            message.as_deref(),
            Some("worker registration failed: activity type `send-email` has no registered handler")
        );
    }
}