Skip to main content

aion_worker/protocol/
task.rs

1//! `ActivityTask` decode and `TaskResult`/`TaskFailure` encode.
2
3use aion_core::{ActivityId, Payload, WorkflowId};
4use aion_proto::ProtoActivityTask;
5
6use crate::error::WorkerError;
7
/// Attempt number assigned to every decoded task.
///
/// The worker wire message has no attempt field, so decoding always reports
/// the first attempt (`1`).
const WIRE_DEFAULT_ATTEMPT: u32 = 1;
9
10/// SDK-level activity task envelope decoded from the AW-owned worker proto.
11///
12/// The current worker wire shape does not carry an attempt field. The SDK keeps
13/// an attempt property for the worker-side API and reports the first-attempt
14/// value until AW adds an owned wire field.
15#[derive(Clone, Debug, PartialEq, Eq)]
16pub struct ActivityTask {
17    /// Owning workflow id, required later when reporting this task's outcome.
18    pub workflow_id: WorkflowId,
19    /// Activity id correlating reports and heartbeats with this task.
20    pub activity_id: ActivityId,
21    /// Registered activity type name requested by the engine.
22    pub activity_type: String,
23    /// Attempt number surfaced to execution machinery.
24    pub attempt: u32,
25    /// Opaque activity input payload, preserving its content-type tag.
26    pub input: Payload,
27}
28
29impl TryFrom<ProtoActivityTask> for ActivityTask {
30    type Error = WorkerError;
31
32    fn try_from(value: ProtoActivityTask) -> Result<Self, Self::Error> {
33        let workflow_id = value
34            .workflow_id
35            .ok_or(MalformedActivityTask::MissingWorkflowId)
36            .and_then(|workflow_id| {
37                WorkflowId::try_from(workflow_id)
38                    .map_err(|source| MalformedActivityTask::InvalidWorkflowId { source })
39            })
40            .map_err(WorkerError::decode)?;
41        let activity_id = value
42            .activity_id
43            .ok_or(MalformedActivityTask::MissingActivityId)
44            .map(ActivityId::from)
45            .map_err(WorkerError::decode)?;
46        if value.activity_type.is_empty() {
47            return Err(WorkerError::decode(
48                MalformedActivityTask::MissingActivityType,
49            ));
50        }
51        let input = value
52            .input
53            .ok_or(MalformedActivityTask::MissingInput)
54            .and_then(|input| {
55                Payload::try_from(input)
56                    .map_err(|source| MalformedActivityTask::InvalidInput { source })
57            })
58            .map_err(WorkerError::decode)?;
59
60        Ok(Self {
61            workflow_id,
62            activity_id,
63            activity_type: value.activity_type,
64            attempt: WIRE_DEFAULT_ATTEMPT,
65            input,
66        })
67    }
68}
69
70#[derive(Debug, thiserror::Error)]
71enum MalformedActivityTask {
72    #[error("activity task workflow_id is missing")]
73    MissingWorkflowId,
74    #[error("activity task workflow_id is invalid: {source}")]
75    InvalidWorkflowId { source: aion_proto::WireError },
76    #[error("activity task activity_id is missing")]
77    MissingActivityId,
78    #[error("activity task activity_type is missing")]
79    MissingActivityType,
80    #[error("activity task input payload is missing")]
81    MissingInput,
82    #[error("activity task input payload is invalid: {source}")]
83    InvalidInput { source: aion_proto::WireError },
84}
85
#[cfg(test)]
mod tests {
    use aion_core::{ActivityId, ContentType, Payload, WorkflowId};
    use aion_proto::{ProtoActivityId, ProtoActivityTask, ProtoPayload, ProtoWorkflowId};
    use serde_json::json;

    use super::ActivityTask;
    use crate::WorkerError;

    /// Wire task with every required field populated. Each rejection test
    /// knocks out exactly one field from it.
    fn well_formed_proto() -> ProtoActivityTask {
        ProtoActivityTask {
            workflow_id: Some(ProtoWorkflowId::from(WorkflowId::new_v4())),
            activity_id: Some(ProtoActivityId::from(ActivityId::from_sequence_position(1))),
            activity_type: String::from("charge-card"),
            input: Some(ProtoPayload::from(Payload::new(
                ContentType::Json,
                b"{}".to_vec(),
            ))),
        }
    }

    /// Asserts that decoding `proto` fails with a `WorkerError::Decode`.
    fn assert_decode_error(proto: ProtoActivityTask) {
        let result = ActivityTask::try_from(proto);
        assert!(
            matches!(result, Err(WorkerError::Decode { .. })),
            "expected a decode error, got {result:?}"
        );
    }

    #[test]
    fn decodes_proto_activity_task_preserving_payload_content_type()
    -> Result<(), Box<dyn std::error::Error>> {
        let workflow_id = WorkflowId::new_v4();
        let activity_id = ActivityId::from_sequence_position(42);
        let input_value = json!({"amount": 1250, "currency": "USD"});
        let input = Payload::from_json(&input_value)?;
        let proto = ProtoActivityTask {
            workflow_id: Some(ProtoWorkflowId::from(workflow_id.clone())),
            activity_id: Some(ProtoActivityId::from(activity_id.clone())),
            activity_type: String::from("charge-card"),
            input: Some(ProtoPayload::from(input.clone())),
        };

        let task = ActivityTask::try_from(proto)?;

        assert_eq!(task.workflow_id, workflow_id);
        assert_eq!(task.activity_id, activity_id);
        assert_eq!(task.activity_type, "charge-card");
        assert_eq!(task.attempt, 1);
        assert_eq!(task.input.content_type(), &ContentType::Json);
        assert_eq!(task.input.bytes(), input.bytes());
        assert_eq!(task.input.to_json()?, input_value);
        Ok(())
    }

    /// Guards the fixture: the rejection tests below are only meaningful if
    /// the untouched fixture decodes successfully.
    #[test]
    fn well_formed_fixture_decodes() -> Result<(), Box<dyn std::error::Error>> {
        let task = ActivityTask::try_from(well_formed_proto())?;
        assert_eq!(task.activity_type, "charge-card");
        Ok(())
    }

    #[test]
    fn missing_required_field_maps_to_decode_error() {
        assert_decode_error(ProtoActivityTask {
            workflow_id: None,
            ..well_formed_proto()
        });
    }

    #[test]
    fn missing_activity_id_maps_to_decode_error() {
        assert_decode_error(ProtoActivityTask {
            activity_id: None,
            ..well_formed_proto()
        });
    }

    #[test]
    fn empty_activity_type_maps_to_decode_error() {
        assert_decode_error(ProtoActivityTask {
            activity_type: String::new(),
            ..well_formed_proto()
        });
    }

    #[test]
    fn missing_input_maps_to_decode_error() {
        assert_decode_error(ProtoActivityTask {
            input: None,
            ..well_formed_proto()
        });
    }
}