Skip to main content

aion_worker/protocol/
task.rs

1//! `ActivityTask` decode and `TaskResult`/`TaskFailure` encode.
2
3use aion_core::{ActivityId, Payload, WorkflowId};
4use aion_proto::ProtoActivityTask;
5
6use crate::error::WorkerError;
7
8/// SDK-level activity task envelope decoded from the AW-owned worker proto.
9#[derive(Clone, Debug, PartialEq, Eq)]
10pub struct ActivityTask {
11    /// Owning workflow id, required later when reporting this task's outcome.
12    pub workflow_id: WorkflowId,
13    /// Activity id correlating reports and heartbeats with this task.
14    pub activity_id: ActivityId,
15    /// Registered activity type name requested by the engine.
16    pub activity_type: String,
17    /// One-based delivery attempt stamped by the dispatching engine seam and
18    /// read from the wire. Zero is malformed and rejected at decode.
19    pub attempt: u32,
20    /// Opaque activity input payload, preserving its content-type tag.
21    pub input: Payload,
22}
23
24impl TryFrom<ProtoActivityTask> for ActivityTask {
25    type Error = WorkerError;
26
27    fn try_from(value: ProtoActivityTask) -> Result<Self, Self::Error> {
28        let workflow_id = value
29            .workflow_id
30            .ok_or(MalformedActivityTask::MissingWorkflowId)
31            .and_then(|workflow_id| {
32                WorkflowId::try_from(workflow_id)
33                    .map_err(|source| MalformedActivityTask::InvalidWorkflowId { source })
34            })
35            .map_err(WorkerError::decode)?;
36        let activity_id = value
37            .activity_id
38            .ok_or(MalformedActivityTask::MissingActivityId)
39            .map(ActivityId::from)
40            .map_err(WorkerError::decode)?;
41        if value.activity_type.is_empty() {
42            return Err(WorkerError::decode(
43                MalformedActivityTask::MissingActivityType,
44            ));
45        }
46        let input = value
47            .input
48            .ok_or(MalformedActivityTask::MissingInput)
49            .and_then(|input| {
50                Payload::try_from(input)
51                    .map_err(|source| MalformedActivityTask::InvalidInput { source })
52            })
53            .map_err(WorkerError::decode)?;
54
55        if value.attempt == 0 {
56            // proto3 zero default = the producer failed to stamp the attempt.
57            return Err(WorkerError::decode(MalformedActivityTask::MissingAttempt));
58        }
59
60        Ok(Self {
61            workflow_id,
62            activity_id,
63            activity_type: value.activity_type,
64            attempt: value.attempt,
65            input,
66        })
67    }
68}
69
70#[derive(Debug, thiserror::Error)]
71enum MalformedActivityTask {
72    #[error("activity task workflow_id is missing")]
73    MissingWorkflowId,
74    #[error("activity task workflow_id is invalid: {source}")]
75    InvalidWorkflowId { source: aion_proto::WireError },
76    #[error("activity task activity_id is missing")]
77    MissingActivityId,
78    #[error("activity task activity_type is missing")]
79    MissingActivityType,
80    #[error("activity task input payload is missing")]
81    MissingInput,
82    #[error("activity task attempt is missing or zero (producer failed to stamp it)")]
83    MissingAttempt,
84    #[error("activity task input payload is invalid: {source}")]
85    InvalidInput { source: aion_proto::WireError },
86}
87
#[cfg(test)]
mod tests {
    use aion_core::{ActivityId, ContentType, Payload, WorkflowId};
    use aion_proto::{ProtoActivityId, ProtoActivityTask, ProtoPayload, ProtoWorkflowId};
    use serde_json::json;

    use super::ActivityTask;
    use crate::WorkerError;

    /// Builds a well-formed wire task; negative tests knock out exactly one
    /// field so each rejection is attributable to that field alone.
    fn valid_proto_task() -> ProtoActivityTask {
        ProtoActivityTask {
            workflow_id: Some(ProtoWorkflowId::from(WorkflowId::new_v4())),
            activity_id: Some(ProtoActivityId::from(ActivityId::from_sequence_position(1))),
            activity_type: String::from("charge-card"),
            input: Some(ProtoPayload::from(Payload::new(
                ContentType::Json,
                b"{}".to_vec(),
            ))),
            attempt: 1,
        }
    }

    /// Asserts that decoding `task` fails with a decode error whose message
    /// names `field`.
    fn assert_rejected(task: ProtoActivityTask, field: &str) {
        let Err(error) = ActivityTask::try_from(task) else {
            panic!("a task with a malformed {field} must be rejected");
        };
        assert!(
            matches!(error, WorkerError::Decode { .. }),
            "a malformed {field} must map to a decode error: {error}"
        );
        assert!(
            error.to_string().contains(field),
            "error must name the {field} field: {error}"
        );
    }

    #[test]
    fn decodes_proto_activity_task_preserving_payload_content_type()
    -> Result<(), Box<dyn std::error::Error>> {
        let workflow_id = WorkflowId::new_v4();
        let activity_id = ActivityId::from_sequence_position(42);
        let input_value = json!({"amount": 1250, "currency": "USD"});
        let input = Payload::from_json(&input_value)?;
        let proto = ProtoActivityTask {
            workflow_id: Some(ProtoWorkflowId::from(workflow_id.clone())),
            activity_id: Some(ProtoActivityId::from(activity_id.clone())),
            activity_type: String::from("charge-card"),
            input: Some(ProtoPayload::from(input.clone())),
            attempt: 3,
        };

        let task = ActivityTask::try_from(proto)?;

        assert_eq!(task.workflow_id, workflow_id);
        assert_eq!(task.activity_id, activity_id);
        assert_eq!(task.activity_type, "charge-card");
        assert_eq!(task.attempt, 3, "attempt must be read from the wire");
        assert_eq!(task.input.content_type(), &ContentType::Json);
        assert_eq!(task.input.bytes(), input.bytes());
        assert_eq!(task.input.to_json()?, input_value);
        Ok(())
    }

    #[test]
    fn valid_fixture_decodes() {
        // Guards the negative tests: they must fail because of the one field
        // they change, not because the fixture itself is malformed.
        assert!(ActivityTask::try_from(valid_proto_task()).is_ok());
    }

    #[test]
    fn missing_required_field_maps_to_decode_error() {
        let mut task = valid_proto_task();
        task.workflow_id = None;
        assert_rejected(task, "workflow_id");
    }

    #[test]
    fn missing_activity_id_is_a_malformed_task() {
        let mut task = valid_proto_task();
        task.activity_id = None;
        assert_rejected(task, "activity_id");
    }

    #[test]
    fn empty_activity_type_is_a_malformed_task() {
        let mut task = valid_proto_task();
        task.activity_type = String::new();
        assert_rejected(task, "activity_type");
    }

    #[test]
    fn missing_input_is_a_malformed_task() {
        let mut task = valid_proto_task();
        task.input = None;
        assert_rejected(task, "input");
    }

    #[test]
    fn zero_attempt_is_a_malformed_task() {
        let mut task = valid_proto_task();
        task.attempt = 0;
        assert_rejected(task, "attempt");
    }
}