1use crate::{ProtoActivityId, ProtoPayload, ProtoWorkflowId, WireError};
4
5#[derive(
7 Clone,
8 Copy,
9 Debug,
10 PartialEq,
11 Eq,
12 Hash,
13 serde::Serialize,
14 serde::Deserialize,
15 prost::Enumeration,
16)]
17#[repr(i32)]
18pub enum ProtoActivityErrorKind {
19 Unspecified = 0,
21 Retryable = 1,
23 Terminal = 2,
25}
26
27#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
29pub struct ProtoActivityError {
30 #[prost(enumeration = "ProtoActivityErrorKind", tag = "1")]
32 pub kind: i32,
33 #[prost(string, tag = "2")]
35 pub message: String,
36 #[prost(message, optional, tag = "3")]
38 pub details: Option<ProtoPayload>,
39}
40
41#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
43pub struct ProtoRegisterWorker {
44 #[prost(string, tag = "1")]
46 pub namespace: String,
47 #[prost(string, repeated, tag = "2")]
49 pub activity_types: Vec<String>,
50}
51
52#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
54pub struct ProtoActivityTask {
55 #[prost(message, optional, tag = "1")]
57 pub workflow_id: Option<ProtoWorkflowId>,
58 #[prost(message, optional, tag = "2")]
60 pub activity_id: Option<ProtoActivityId>,
61 #[prost(string, tag = "3")]
63 pub activity_type: String,
64 #[prost(message, optional, tag = "4")]
66 pub input: Option<ProtoPayload>,
67 #[prost(uint32, tag = "5")]
71 pub attempt: u32,
72}
73
74#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
82pub struct ProtoDrainRequest {}
83
84#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
88pub struct ProtoRegisterAck {
89 #[prost(uint64, tag = "1")]
91 pub worker_id: u64,
92 #[prost(string, tag = "2")]
94 pub namespace: String,
95 #[prost(uint64, tag = "3")]
97 pub heartbeat_window_ms: u64,
98}
99
100#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
104pub struct ProtoResultAck {
105 #[prost(message, optional, tag = "1")]
107 pub workflow_id: Option<ProtoWorkflowId>,
108 #[prost(message, optional, tag = "2")]
110 pub activity_id: Option<ProtoActivityId>,
111}
112
113#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
115pub struct ProtoActivityResult {
116 #[prost(message, optional, tag = "1")]
118 pub workflow_id: Option<ProtoWorkflowId>,
119 #[prost(message, optional, tag = "2")]
121 pub activity_id: Option<ProtoActivityId>,
122 #[prost(oneof = "proto_activity_result::Outcome", tags = "3, 4")]
124 pub outcome: Option<proto_activity_result::Outcome>,
125}
126
127pub mod proto_activity_result {
129 use super::{ProtoActivityError, ProtoPayload};
130
131 #[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Oneof)]
133 pub enum Outcome {
134 #[prost(message, tag = "3")]
136 Result(ProtoPayload),
137 #[prost(message, tag = "4")]
139 Error(ProtoActivityError),
140 }
141}
142
143#[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, prost::Message)]
145pub struct ProtoHeartbeat {
146 #[prost(message, optional, tag = "1")]
148 pub workflow_id: Option<ProtoWorkflowId>,
149 #[prost(message, optional, tag = "2")]
151 pub activity_id: Option<ProtoActivityId>,
152 #[prost(message, optional, tag = "3")]
154 pub progress: Option<ProtoPayload>,
155}
156
157impl From<aion_core::ActivityErrorKind> for ProtoActivityErrorKind {
158 fn from(value: aion_core::ActivityErrorKind) -> Self {
159 match value {
160 aion_core::ActivityErrorKind::Retryable => Self::Retryable,
161 aion_core::ActivityErrorKind::Terminal => Self::Terminal,
162 }
163 }
164}
165
166impl TryFrom<ProtoActivityErrorKind> for aion_core::ActivityErrorKind {
167 type Error = WireError;
168
169 fn try_from(value: ProtoActivityErrorKind) -> Result<Self, Self::Error> {
170 match value {
171 ProtoActivityErrorKind::Unspecified => {
172 Err(WireError::backend("activity error kind is missing"))
173 }
174 ProtoActivityErrorKind::Retryable => Ok(Self::Retryable),
175 ProtoActivityErrorKind::Terminal => Ok(Self::Terminal),
176 }
177 }
178}
179
180impl From<aion_core::ActivityError> for ProtoActivityError {
181 fn from(value: aion_core::ActivityError) -> Self {
182 Self {
183 kind: ProtoActivityErrorKind::from(value.kind) as i32,
184 message: value.message,
185 details: value.details.map(ProtoPayload::from),
186 }
187 }
188}
189
190impl TryFrom<ProtoActivityError> for aion_core::ActivityError {
191 type Error = WireError;
192
193 fn try_from(value: ProtoActivityError) -> Result<Self, Self::Error> {
194 let kind = ProtoActivityErrorKind::try_from(value.kind)
195 .map_err(|_| WireError::backend("activity error kind is unknown"))?;
196 Ok(Self {
197 kind: aion_core::ActivityErrorKind::try_from(kind)?,
198 message: value.message,
199 details: value
200 .details
201 .map(aion_core::Payload::try_from)
202 .transpose()?,
203 })
204 }
205}
206
#[cfg(test)]
mod tests {
    use prost::Message;
    use serde_json::json;

    use super::{
        ProtoActivityError, ProtoActivityErrorKind, ProtoActivityResult, ProtoActivityTask,
        ProtoDrainRequest, ProtoHeartbeat, ProtoRegisterAck, ProtoRegisterWorker, ProtoResultAck,
        proto_activity_result,
    };
    use crate::{ProtoActivityId, ProtoPayload, ProtoWorkflowId, WireError};

    /// Deterministic workflow id shared by the fixtures below.
    fn workflow_id() -> aion_core::WorkflowId {
        aion_core::WorkflowId::new(uuid::Uuid::nil())
    }

    #[test]
    fn activity_error_round_trips_preserving_classification() -> Result<(), WireError> {
        let core = aion_core::ActivityError {
            kind: aion_core::ActivityErrorKind::Retryable,
            message: String::from("connection reset"),
            details: Some(
                aion_core::Payload::from_json(&json!({"retry_after_ms": 500}))
                    .map_err(|_| WireError::backend("test payload could not be created"))?,
            ),
        };

        let proto = ProtoActivityError::from(core.clone());
        assert_eq!(aion_core::ActivityError::try_from(proto.clone())?, core);
        assert!(aion_core::ActivityError::try_from(proto)?.is_retryable());

        let terminal = ProtoActivityError {
            kind: ProtoActivityErrorKind::Terminal as i32,
            message: String::from("invalid request"),
            details: None,
        };
        assert!(!aion_core::ActivityError::try_from(terminal)?.is_retryable());

        Ok(())
    }

    /// `Unspecified` (the protobuf default) and discriminants outside the enum must
    /// both be rejected instead of being mapped to a core classification.
    #[test]
    fn activity_error_rejects_unspecified_and_unknown_kinds() {
        let unspecified = ProtoActivityError {
            kind: ProtoActivityErrorKind::Unspecified as i32,
            message: String::from("no classification"),
            details: None,
        };
        assert!(aion_core::ActivityError::try_from(unspecified).is_err());

        let unknown = ProtoActivityError {
            kind: 99,
            message: String::from("unrecognised classification"),
            details: None,
        };
        assert!(aion_core::ActivityError::try_from(unknown).is_err());
    }

    #[test]
    fn worker_registration_round_trips_through_serde_and_proto()
    -> Result<(), Box<dyn std::error::Error>> {
        let registration = ProtoRegisterWorker {
            namespace: String::from("tenant-a"),
            activity_types: vec![String::from("charge-card"), String::from("send-email")],
        };

        assert_json_and_proto_round_trip(&registration)
    }

    #[test]
    fn activity_task_round_trips_through_serde_and_proto() -> Result<(), Box<dyn std::error::Error>>
    {
        let task = ProtoActivityTask {
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            activity_id: Some(ProtoActivityId::from(
                aion_core::ActivityId::from_sequence_position(7),
            )),
            activity_type: String::from("charge-card"),
            input: Some(ProtoPayload::from(aion_core::Payload::from_json(
                &json!({"amount": 42}),
            )?)),
            attempt: 3,
        };

        assert_json_and_proto_round_trip(&task)
    }

    #[test]
    fn drain_request_round_trips_through_serde_and_proto() -> Result<(), Box<dyn std::error::Error>>
    {
        assert_json_and_proto_round_trip(&ProtoDrainRequest {})
    }

    #[test]
    fn register_ack_round_trips_through_serde_and_proto() -> Result<(), Box<dyn std::error::Error>>
    {
        let ack = ProtoRegisterAck {
            worker_id: 7,
            namespace: String::from("tenant-a"),
            heartbeat_window_ms: 30_000,
        };

        assert_json_and_proto_round_trip(&ack)
    }

    #[test]
    fn result_ack_round_trips_through_serde_and_proto() -> Result<(), Box<dyn std::error::Error>> {
        let ack = ProtoResultAck {
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            activity_id: Some(ProtoActivityId::from(
                aion_core::ActivityId::from_sequence_position(11),
            )),
        };

        assert_json_and_proto_round_trip(&ack)
    }

    /// Each arm's first encoded byte is its field key, `(tag << 3) | 2` for a
    /// length-delimited field. Tag 3 encodes to 0x1A and tag 4 to 0x22.
    #[cfg(feature = "generated")]
    #[test]
    fn server_to_worker_ack_arms_pin_oneof_tags_three_and_four()
    -> Result<(), Box<dyn std::error::Error>> {
        let register_ack = crate::generated::ServerToWorker {
            message: Some(crate::generated::server_to_worker::Message::RegisterAck(
                crate::generated::RegisterAck {
                    worker_id: 1,
                    namespace: String::from("tenant-a"),
                    heartbeat_window_ms: 1_000,
                },
            )),
        };
        let mut bytes = Vec::new();
        register_ack.encode(&mut bytes)?;
        assert_eq!(bytes.first(), Some(&0x1A));
        assert_eq!(
            crate::generated::ServerToWorker::decode(bytes.as_slice())?,
            register_ack
        );

        let result_ack = crate::generated::ServerToWorker {
            message: Some(crate::generated::server_to_worker::Message::ResultAck(
                crate::generated::ResultAck {
                    workflow_id: None,
                    activity_id: None,
                },
            )),
        };
        let mut bytes = Vec::new();
        result_ack.encode(&mut bytes)?;
        assert_eq!(bytes.first(), Some(&0x22));
        assert_eq!(
            crate::generated::ServerToWorker::decode(bytes.as_slice())?,
            result_ack
        );
        Ok(())
    }

    /// With every other field at its default, only `attempt` is encoded.
    /// The expected bytes are key `(5 << 3) | 0 = 0x28` followed by the varint 9.
    #[test]
    fn activity_task_attempt_uses_wire_tag_five() -> Result<(), Box<dyn std::error::Error>> {
        let task = ProtoActivityTask {
            workflow_id: None,
            activity_id: None,
            activity_type: String::new(),
            input: None,
            attempt: 9,
        };
        let mut bytes = Vec::new();
        task.encode(&mut bytes)?;
        assert_eq!(bytes, vec![0x28, 9]);
        Ok(())
    }

    #[test]
    fn activity_success_result_round_trips_through_serde_and_proto()
    -> Result<(), Box<dyn std::error::Error>> {
        let result = ProtoActivityResult {
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            activity_id: Some(ProtoActivityId::from(
                aion_core::ActivityId::from_sequence_position(8),
            )),
            outcome: Some(proto_activity_result::Outcome::Result(ProtoPayload::from(
                aion_core::Payload::from_json(&json!({"authorization": "ok"}))?,
            ))),
        };

        assert_json_and_proto_round_trip(&result)
    }

    #[test]
    fn activity_error_result_round_trips_through_serde_and_proto()
    -> Result<(), Box<dyn std::error::Error>> {
        let result = ProtoActivityResult {
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            activity_id: Some(ProtoActivityId::from(
                aion_core::ActivityId::from_sequence_position(9),
            )),
            outcome: Some(proto_activity_result::Outcome::Error(
                ProtoActivityError::from(aion_core::ActivityError {
                    kind: aion_core::ActivityErrorKind::Terminal,
                    message: String::from("card declined"),
                    details: Some(aion_core::Payload::from_json(&json!({"code": "declined"}))?),
                }),
            )),
        };

        assert_json_and_proto_round_trip(&result)
    }

    #[test]
    fn heartbeat_round_trips_through_serde_and_proto() -> Result<(), Box<dyn std::error::Error>> {
        let heartbeat = ProtoHeartbeat {
            workflow_id: Some(ProtoWorkflowId::from(workflow_id())),
            activity_id: Some(ProtoActivityId::from(
                aion_core::ActivityId::from_sequence_position(10),
            )),
            progress: Some(ProtoPayload::from(aion_core::Payload::from_json(
                &json!({"percent": 50}),
            )?)),
        };

        assert_json_and_proto_round_trip(&heartbeat)
    }

    /// Asserts that `value` survives both a JSON (serde) and a protobuf (prost) round trip.
    fn assert_json_and_proto_round_trip<T>(value: &T) -> Result<(), Box<dyn std::error::Error>>
    where
        T: Message
            + Default
            + serde::Serialize
            + serde::de::DeserializeOwned
            + PartialEq
            + std::fmt::Debug,
    {
        assert_eq!(
            serde_json::from_str::<T>(&serde_json::to_string(value)?)?,
            *value
        );
        assert_eq!(prost_round_trip(value)?, *value);
        Ok(())
    }

    /// Encodes `value` with prost and decodes it back into a fresh `T`.
    fn prost_round_trip<T>(value: &T) -> Result<T, Box<dyn std::error::Error>>
    where
        T: Message + Default,
    {
        let mut bytes = Vec::new();
        value.encode(&mut bytes)?;
        Ok(T::decode(bytes.as_slice())?)
    }
}