Skip to main content

aion_proto/
worker.rs

//! Worker protocol wire types, encodable with both serde and prost.
2
3use crate::{ProtoActivityId, ProtoPayload, ProtoWorkflowId, WireError};
4
5/// Proto representation of `ActivityErrorKind`. Zero is invalid on decode.
6#[derive(
7    Clone,
8    Copy,
9    Debug,
10    PartialEq,
11    Eq,
12    Hash,
13    serde::Serialize,
14    serde::Deserialize,
15    prost::Enumeration,
16)]
17#[repr(i32)]
18pub enum ProtoActivityErrorKind {
19    /// Missing/invalid kind.
20    Unspecified = 0,
21    /// Activity failure may be retried by the engine.
22    Retryable = 1,
23    /// Activity failure is terminal.
24    Terminal = 2,
25}
26
27/// Proto representation of `ActivityError`.
28#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
29pub struct ProtoActivityError {
30    /// Explicit retryability classification.
31    #[prost(enumeration = "ProtoActivityErrorKind", tag = "1")]
32    pub kind: i32,
33    /// Human-readable error message.
34    #[prost(string, tag = "2")]
35    pub message: String,
36    /// Optional structured failure details.
37    #[prost(message, optional, tag = "3")]
38    pub details: Option<ProtoPayload>,
39}
40
41/// Worker registration advertisement.
42#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
43pub struct ProtoRegisterWorker {
44    /// Namespace that scopes this worker stream.
45    #[prost(string, tag = "1")]
46    pub namespace: String,
47    /// Activity types implemented by the worker, preserving wire order.
48    #[prost(string, repeated, tag = "2")]
49    pub activity_types: Vec<String>,
50}
51
52/// Activity invocation pushed to a worker.
53#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
54pub struct ProtoActivityTask {
55    /// Owning workflow id.
56    #[prost(message, optional, tag = "1")]
57    pub workflow_id: Option<ProtoWorkflowId>,
58    /// Correlating activity id.
59    #[prost(message, optional, tag = "2")]
60    pub activity_id: Option<ProtoActivityId>,
61    /// Activity type name.
62    #[prost(string, tag = "3")]
63    pub activity_type: String,
64    /// Serialized activity input.
65    #[prost(message, optional, tag = "4")]
66    pub input: Option<ProtoPayload>,
67    /// One-based delivery attempt stamped by the dispatching engine seam.
68    /// Zero is malformed: consumers reject a task whose attempt is 0 (the
69    /// proto3 default means the producer failed to stamp it).
70    #[prost(uint32, tag = "5")]
71    pub attempt: u32,
72}
73
74/// Server-initiated drain: the server is going away (restart, deploy,
75/// rebalance). The worker finishes already-assigned work, stops expecting
76/// new tasks, and reconnects after the schedule's initial backoff. A drain
77/// frame re-classifies the session's eventual stream end (clean or abrupt)
78/// as a drain-class drop that consumes no drop budget — distinct from denial
79/// (gRPC error status, terminal) and from an unannounced close (budgeted
80/// retryable drop).
81#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
82pub struct ProtoDrainRequest {}
83
84/// Positive registration acknowledgement — always the first frame on the
85/// response stream. There is no negative counterpart: a denied or invalid
86/// registration fails the RPC with a gRPC error status.
87#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
88pub struct ProtoRegisterAck {
89    /// Server-assigned stream identifier for this registration.
90    #[prost(uint64, tag = "1")]
91    pub worker_id: u64,
92    /// The namespace the registration was authorized against.
93    #[prost(string, tag = "2")]
94    pub namespace: String,
95    /// Operator-configured liveness window on this server, in milliseconds.
96    #[prost(uint64, tag = "3")]
97    pub heartbeat_window_ms: u64,
98}
99
100/// Per-result acknowledgement: the server has consumed the identified
101/// `ActivityResult` frame and the worker may stop re-reporting it. Not a
102/// durability receipt — the durable truth is the workflow's event history.
103#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
104pub struct ProtoResultAck {
105    /// Owning workflow id.
106    #[prost(message, optional, tag = "1")]
107    pub workflow_id: Option<ProtoWorkflowId>,
108    /// Correlating activity id.
109    #[prost(message, optional, tag = "2")]
110    pub activity_id: Option<ProtoActivityId>,
111}
112
113/// Activity result or failure reported by a worker.
114#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
115pub struct ProtoActivityResult {
116    /// Owning workflow id.
117    #[prost(message, optional, tag = "1")]
118    pub workflow_id: Option<ProtoWorkflowId>,
119    /// Correlating activity id.
120    #[prost(message, optional, tag = "2")]
121    pub activity_id: Option<ProtoActivityId>,
122    /// Successful result payload or explicit activity error.
123    #[prost(oneof = "proto_activity_result::Outcome", tags = "3, 4")]
124    pub outcome: Option<proto_activity_result::Outcome>,
125}
126
127/// Types nested under [`ProtoActivityResult`].
128pub mod proto_activity_result {
129    use super::{ProtoActivityError, ProtoPayload};
130
131    /// Proto oneof for activity success or failure.
132    #[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Oneof)]
133    pub enum Outcome {
134        /// Successful activity output.
135        #[prost(message, tag = "3")]
136        Result(ProtoPayload),
137        /// Activity failure preserving retryability classification.
138        #[prost(message, tag = "4")]
139        Error(ProtoActivityError),
140    }
141}
142
143/// Worker heartbeat for an in-flight activity.
144#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
145pub struct ProtoHeartbeat {
146    /// Owning workflow id.
147    #[prost(message, optional, tag = "1")]
148    pub workflow_id: Option<ProtoWorkflowId>,
149    /// Correlating activity id.
150    #[prost(message, optional, tag = "2")]
151    pub activity_id: Option<ProtoActivityId>,
152    /// Optional opaque progress payload.
153    #[prost(message, optional, tag = "3")]
154    pub progress: Option<ProtoPayload>,
155}
156
157impl From<aion_core::ActivityErrorKind> for ProtoActivityErrorKind {
158    fn from(value: aion_core::ActivityErrorKind) -> Self {
159        match value {
160            aion_core::ActivityErrorKind::Retryable => Self::Retryable,
161            aion_core::ActivityErrorKind::Terminal => Self::Terminal,
162        }
163    }
164}
165
166impl TryFrom<ProtoActivityErrorKind> for aion_core::ActivityErrorKind {
167    type Error = WireError;
168
169    fn try_from(value: ProtoActivityErrorKind) -> Result<Self, Self::Error> {
170        match value {
171            ProtoActivityErrorKind::Unspecified => {
172                Err(WireError::backend("activity error kind is missing"))
173            }
174            ProtoActivityErrorKind::Retryable => Ok(Self::Retryable),
175            ProtoActivityErrorKind::Terminal => Ok(Self::Terminal),
176        }
177    }
178}
179
180impl From<aion_core::ActivityError> for ProtoActivityError {
181    fn from(value: aion_core::ActivityError) -> Self {
182        Self {
183            kind: ProtoActivityErrorKind::from(value.kind) as i32,
184            message: value.message,
185            details: value.details.map(ProtoPayload::from),
186        }
187    }
188}
189
190impl TryFrom<ProtoActivityError> for aion_core::ActivityError {
191    type Error = WireError;
192
193    fn try_from(value: ProtoActivityError) -> Result<Self, Self::Error> {
194        let kind = ProtoActivityErrorKind::try_from(value.kind)
195            .map_err(|_| WireError::backend("activity error kind is unknown"))?;
196        Ok(Self {
197            kind: aion_core::ActivityErrorKind::try_from(kind)?,
198            message: value.message,
199            details: value
200                .details
201                .map(aion_core::Payload::try_from)
202                .transpose()?,
203        })
204    }
205}
206
#[cfg(test)]
mod tests {
    use prost::Message;
    use serde_json::json;

    use super::{
        ProtoActivityError, ProtoActivityErrorKind, ProtoActivityResult, ProtoActivityTask,
        ProtoDrainRequest, ProtoHeartbeat, ProtoRegisterAck, ProtoRegisterWorker, ProtoResultAck,
        proto_activity_result,
    };
    use crate::{ProtoActivityId, ProtoPayload, ProtoWorkflowId, WireError};

    /// Deterministic workflow id shared by the fixtures below.
    fn workflow_id() -> aion_core::WorkflowId {
        aion_core::WorkflowId::new(uuid::Uuid::nil())
    }

    #[test]
    fn activity_error_round_trips_preserving_classification() -> Result<(), WireError> {
        let core = aion_core::ActivityError {
            kind: aion_core::ActivityErrorKind::Retryable,
            message: String::from("connection reset"),
            details: Some(
                aion_core::Payload::from_json(&json!({"retry_after_ms": 500}))
                    .map_err(|_| WireError::backend("test payload could not be created"))?,
            ),
        };

        let proto = ProtoActivityError::from(core.clone());
        assert_eq!(aion_core::ActivityError::try_from(proto.clone())?, core);
        assert!(aion_core::ActivityError::try_from(proto)?.is_retryable());

        let terminal = ProtoActivityError {
            kind: ProtoActivityErrorKind::Terminal as i32,
            message: String::from("invalid request"),
            details: None,
        };
        assert!(!aion_core::ActivityError::try_from(terminal)?.is_retryable());

        Ok(())
    }

    #[test]
    fn activity_error_with_missing_kind_is_rejected()
    -> Result<(), Box<dyn std::error::Error>> {
        // An explicit `Unspecified` kind must not convert into a core kind.
        assert!(
            aion_core::ActivityErrorKind::try_from(ProtoActivityErrorKind::Unspecified).is_err()
        );

        // A frame that omits `kind` still decodes (proto3 default 0), but the
        // conversion must fail rather than silently pick a classification.
        let empty: &[u8] = &[];
        let decoded = ProtoActivityError::decode(empty)?;
        assert_eq!(decoded.kind, ProtoActivityErrorKind::Unspecified as i32);
        assert!(aion_core::ActivityError::try_from(decoded).is_err());

        Ok(())
    }

    #[test]
    fn activity_error_with_unknown_kind_is_rejected() {
        // A value outside the known enumeration (e.g. from a newer peer) is
        // rejected instead of being coerced to a known kind.
        let proto = ProtoActivityError {
            kind: 99,
            message: String::from("unrecognized kind"),
            details: None,
        };
        assert!(aion_core::ActivityError::try_from(proto).is_err());
    }

    #[test]
    fn worker_registration_round_trips_through_serde_and_proto()
    -> Result<(), Box<dyn std::error::Error>> {
        let registration = ProtoRegisterWorker {
            namespace: String::from("tenant-a"),
            activity_types: vec![String::from("charge-card"), String::from("send-email")],
        };

        assert_json_and_proto_round_trip(&registration)
    }

    #[test]
    fn activity_task_round_trips_through_serde_and_proto() -> Result<(), Box<dyn std::error::Error>>
    {
        let task = ProtoActivityTask {
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            activity_id: Some(ProtoActivityId::from(
                aion_core::ActivityId::from_sequence_position(7),
            )),
            activity_type: String::from("charge-card"),
            input: Some(ProtoPayload::from(aion_core::Payload::from_json(
                &json!({"amount": 42}),
            )?)),
            attempt: 3,
        };

        assert_json_and_proto_round_trip(&task)
    }

    #[test]
    fn drain_request_round_trips_through_serde_and_proto() -> Result<(), Box<dyn std::error::Error>>
    {
        assert_json_and_proto_round_trip(&ProtoDrainRequest {})
    }

    #[test]
    fn register_ack_round_trips_through_serde_and_proto() -> Result<(), Box<dyn std::error::Error>>
    {
        let ack = ProtoRegisterAck {
            worker_id: 7,
            namespace: String::from("tenant-a"),
            heartbeat_window_ms: 30_000,
        };

        assert_json_and_proto_round_trip(&ack)
    }

    #[test]
    fn result_ack_round_trips_through_serde_and_proto() -> Result<(), Box<dyn std::error::Error>> {
        let ack = ProtoResultAck {
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            activity_id: Some(ProtoActivityId::from(
                aion_core::ActivityId::from_sequence_position(11),
            )),
        };

        assert_json_and_proto_round_trip(&ack)
    }

    #[cfg(feature = "generated")]
    #[test]
    fn server_to_worker_ack_arms_pin_oneof_tags_three_and_four()
    -> Result<(), Box<dyn std::error::Error>> {
        // Pins the ServerToWorker ack oneof arms to wire tags 3 (register_ack)
        // and 4 (result_ack): field key = (tag << 3) | 2 (length-delimited).
        let register_ack = crate::generated::ServerToWorker {
            message: Some(crate::generated::server_to_worker::Message::RegisterAck(
                crate::generated::RegisterAck {
                    worker_id: 1,
                    namespace: String::from("tenant-a"),
                    heartbeat_window_ms: 1_000,
                },
            )),
        };
        let mut bytes = Vec::new();
        register_ack.encode(&mut bytes)?;
        assert_eq!(bytes.first(), Some(&0x1A));
        assert_eq!(
            crate::generated::ServerToWorker::decode(bytes.as_slice())?,
            register_ack
        );

        let result_ack = crate::generated::ServerToWorker {
            message: Some(crate::generated::server_to_worker::Message::ResultAck(
                crate::generated::ResultAck {
                    workflow_id: None,
                    activity_id: None,
                },
            )),
        };
        let mut bytes = Vec::new();
        result_ack.encode(&mut bytes)?;
        assert_eq!(bytes.first(), Some(&0x22));
        assert_eq!(
            crate::generated::ServerToWorker::decode(bytes.as_slice())?,
            result_ack
        );
        Ok(())
    }

    #[test]
    fn activity_task_attempt_uses_wire_tag_five() -> Result<(), Box<dyn std::error::Error>> {
        // Pins the attempt field to proto tag 5 (field key 0x28 = tag 5,
        // varint wire type) so the hand-written SDK stubs cannot drift.
        let task = ProtoActivityTask {
            workflow_id: None,
            activity_id: None,
            activity_type: String::new(),
            input: None,
            attempt: 9,
        };
        let mut bytes = Vec::new();
        task.encode(&mut bytes)?;
        assert_eq!(bytes, vec![0x28, 9]);
        Ok(())
    }

    #[test]
    fn activity_success_result_round_trips_through_serde_and_proto()
    -> Result<(), Box<dyn std::error::Error>> {
        let result = ProtoActivityResult {
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            activity_id: Some(ProtoActivityId::from(
                aion_core::ActivityId::from_sequence_position(8),
            )),
            outcome: Some(proto_activity_result::Outcome::Result(ProtoPayload::from(
                aion_core::Payload::from_json(&json!({"authorization": "ok"}))?,
            ))),
        };

        assert_json_and_proto_round_trip(&result)
    }

    #[test]
    fn activity_error_result_round_trips_through_serde_and_proto()
    -> Result<(), Box<dyn std::error::Error>> {
        let result = ProtoActivityResult {
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            activity_id: Some(ProtoActivityId::from(
                aion_core::ActivityId::from_sequence_position(9),
            )),
            outcome: Some(proto_activity_result::Outcome::Error(
                ProtoActivityError::from(aion_core::ActivityError {
                    kind: aion_core::ActivityErrorKind::Terminal,
                    message: String::from("card declined"),
                    details: Some(aion_core::Payload::from_json(&json!({"code": "declined"}))?),
                }),
            )),
        };

        assert_json_and_proto_round_trip(&result)
    }

    #[test]
    fn heartbeat_round_trips_through_serde_and_proto() -> Result<(), Box<dyn std::error::Error>> {
        let heartbeat = ProtoHeartbeat {
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            activity_id: Some(ProtoActivityId::from(
                aion_core::ActivityId::from_sequence_position(10),
            )),
            progress: Some(ProtoPayload::from(aion_core::Payload::from_json(
                &json!({"percent": 50}),
            )?)),
        };

        assert_json_and_proto_round_trip(&heartbeat)
    }

    /// Asserts that `value` survives both a serde JSON and a prost binary
    /// round trip unchanged.
    fn assert_json_and_proto_round_trip<T>(value: &T) -> Result<(), Box<dyn std::error::Error>>
    where
        T: Message
            + Default
            + serde::Serialize
            + serde::de::DeserializeOwned
            + PartialEq
            + std::fmt::Debug,
    {
        assert_eq!(
            serde_json::from_str::<T>(&serde_json::to_string(value)?)?,
            *value
        );
        assert_eq!(prost_round_trip(value)?, *value);
        Ok(())
    }

    /// Encodes `value` with prost and decodes it back.
    fn prost_round_trip<T>(value: &T) -> Result<T, Box<dyn std::error::Error>>
    where
        T: Message + Default,
    {
        let mut bytes = Vec::new();
        value.encode(&mut bytes)?;
        Ok(T::decode(bytes.as_slice())?)
    }
}