1use std::any::Any;
4use std::collections::{BTreeMap, BTreeSet};
5use std::future::Future;
6use std::panic::AssertUnwindSafe;
7use std::pin::Pin;
8
9use aion_core::{ActivityError, ActivityErrorKind, Payload};
10use async_trait::async_trait;
11use futures::FutureExt;
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use tracing::error;
15
16use crate::context::ActivityContext;
17use crate::error::{MissingActivityHandler, WorkerError};
18use crate::protocol::ActivityTask;
19use crate::runtime::loop_::{ActivityDispatcher, DispatchOutcome};
20
/// How a failed activity attempt is classified when it is reported.
///
/// Each variant maps one-to-one onto the `ActivityErrorKind` variant of the
/// same name (see the `From<Classification>` impl).
///
/// The enum is fieldless, so it is `Copy` and `Hash`; callers can pass it by
/// value and use it as a map/set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Classification {
    /// Reported as `ActivityErrorKind::Retryable`.
    Retryable,
    /// Reported as `ActivityErrorKind::Terminal`.
    Terminal,
}
29
30#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
32#[error("{message}")]
33pub struct ActivityFailure {
34 classification: Classification,
35 message: String,
36 detail: Option<Payload>,
37}
38
39impl ActivityFailure {
40 #[must_use]
42 pub fn retryable(message: impl Into<String>) -> Self {
43 Self::new(Classification::Retryable, message, None)
44 }
45
46 #[must_use]
48 pub fn terminal(message: impl Into<String>) -> Self {
49 Self::new(Classification::Terminal, message, None)
50 }
51
52 #[must_use]
54 pub fn with_detail(mut self, detail: Payload) -> Self {
55 self.detail = Some(detail);
56 self
57 }
58
59 #[must_use]
61 pub const fn classification(&self) -> &Classification {
62 &self.classification
63 }
64
65 #[must_use]
67 pub fn message(&self) -> &str {
68 &self.message
69 }
70
71 #[must_use]
73 pub const fn detail(&self) -> Option<&Payload> {
74 self.detail.as_ref()
75 }
76
77 fn new(
78 classification: Classification,
79 message: impl Into<String>,
80 detail: Option<Payload>,
81 ) -> Self {
82 Self {
83 classification,
84 message: message.into(),
85 detail,
86 }
87 }
88}
89
90impl From<Classification> for ActivityErrorKind {
91 fn from(value: Classification) -> Self {
92 match value {
93 Classification::Retryable => Self::Retryable,
94 Classification::Terminal => Self::Terminal,
95 }
96 }
97}
98
99impl From<ActivityFailure> for ActivityError {
100 fn from(value: ActivityFailure) -> Self {
101 Self {
102 kind: ActivityErrorKind::from(value.classification),
103 message: value.message,
104 details: value.detail,
105 }
106 }
107}
108
/// Boxed, `Send` future returned by an activity handler.
///
/// The future may borrow the [`ActivityContext`] for `'context`. It resolves to the
/// handler's output on success or an [`ActivityFailure`] otherwise.
pub type HandlerFuture<'context, Output> =
    Pin<Box<dyn Future<Output = Result<Output, ActivityFailure>> + Send + 'context>>;
112
/// Type-erased storage for a handler closure.
///
/// The closure takes the decoded input plus a borrowed context and returns a
/// [`HandlerFuture`] tied to that borrow. The `for<'context>` bound means it must
/// accept a context reference of any lifetime.
type BoxedHandler<Input, Output> = Box<
    dyn for<'context> Fn(Input, &'context ActivityContext) -> HandlerFuture<'context, Output>
        + Send
        + Sync,
>;
118
119#[derive(Default)]
121pub struct ActivityRegistry {
122 handlers: BTreeMap<String, Box<dyn ErasedActivityHandler>>,
123}
124
125impl ActivityRegistry {
126 #[must_use]
128 pub fn new() -> Self {
129 Self::default()
130 }
131
132 pub fn register_activity<Input, Output, Handler>(
138 mut self,
139 activity_type: impl Into<String>,
140 handler: Handler,
141 ) -> Result<Self, WorkerError>
142 where
143 Input: Serialize + DeserializeOwned + Send + Sync + 'static,
144 Output: Serialize + Send + Sync + 'static,
145 Handler: for<'context> Fn(Input, &'context ActivityContext) -> HandlerFuture<'context, Output>
146 + Send
147 + Sync
148 + 'static,
149 {
150 let activity_type = activity_type.into();
151 if self.handlers.contains_key(&activity_type) {
152 return Err(WorkerError::registration(DuplicateActivityType {
153 activity_type,
154 }));
155 }
156 self.handlers
157 .insert(activity_type, Box::new(TypedHandler::new(handler)));
158 Ok(self)
159 }
160
161 #[must_use]
163 pub fn is_empty(&self) -> bool {
164 self.handlers.is_empty()
165 }
166
167 #[must_use]
169 pub fn activity_types(&self) -> BTreeSet<String> {
170 self.handlers.keys().cloned().collect()
171 }
172}
173
174#[async_trait]
175impl ActivityDispatcher for ActivityRegistry {
176 async fn dispatch(
177 &self,
178 task: ActivityTask,
179 context: ActivityContext,
180 ) -> Result<DispatchOutcome, WorkerError> {
181 let Some(handler) = self.handlers.get(&task.activity_type) else {
182 return Err(WorkerError::registration(MissingActivityHandler {
183 activity_type: task.activity_type,
184 }));
185 };
186 handler.dispatch(task, context).await
187 }
188
189 fn activity_types(&self) -> BTreeSet<String> {
190 self.activity_types()
191 }
192}
193
/// Alias for [`ActivityRegistry`]; both names refer to the same type.
pub type TypedActivityDispatcher = ActivityRegistry;
196
197pub fn decode_payload<T>(payload: &Payload) -> Result<T, WorkerError>
204where
205 T: DeserializeOwned,
206{
207 let value = payload.to_json().map_err(WorkerError::decode)?;
208 serde_json::from_value(value).map_err(WorkerError::decode)
209}
210
211pub fn encode_payload<T>(value: &T) -> Result<Payload, WorkerError>
217where
218 T: Serialize,
219{
220 let value = serde_json::to_value(value).map_err(WorkerError::encode)?;
221 Payload::from_json(&value).map_err(WorkerError::encode)
222}
223
/// Object-safe handler interface.
///
/// Lets handlers with different `Input`/`Output` types sit in one
/// `BTreeMap<String, Box<dyn ErasedActivityHandler>>`.
#[async_trait]
trait ErasedActivityHandler: Send + Sync {
    /// Runs the handler for `task` with `context` and returns the dispatch outcome.
    async fn dispatch(
        &self,
        task: ActivityTask,
        context: ActivityContext,
    ) -> Result<DispatchOutcome, WorkerError>;
}
232
233struct TypedHandler<Input, Output> {
234 handler: BoxedHandler<Input, Output>,
235}
236
237impl<Input, Output> TypedHandler<Input, Output> {
238 fn new(
239 handler: impl for<'context> Fn(
240 Input,
241 &'context ActivityContext,
242 ) -> HandlerFuture<'context, Output>
243 + Send
244 + Sync
245 + 'static,
246 ) -> Self {
247 Self {
248 handler: Box::new(handler),
249 }
250 }
251}
252
253#[async_trait]
254impl<Input, Output> ErasedActivityHandler for TypedHandler<Input, Output>
255where
256 Input: DeserializeOwned + Send + Sync + 'static,
257 Output: Serialize + Send + Sync + 'static,
258{
259 async fn dispatch(
260 &self,
261 task: ActivityTask,
262 context: ActivityContext,
263 ) -> Result<DispatchOutcome, WorkerError> {
264 let input = match decode_payload::<Input>(&task.input) {
265 Ok(input) => input,
266 Err(error) => {
267 error!(
268 activity_type = %task.activity_type,
269 activity_id = task.activity_id.sequence_position(),
270 attempt = task.attempt,
271 error = %error,
272 "failed to decode activity input; reporting terminal activity failure"
273 );
274 let failure =
275 ActivityFailure::terminal(format!("failed to decode activity input: {error}"));
276 return Ok(DispatchOutcome::Failed {
277 failure: ActivityError::from(failure),
278 });
279 }
280 };
281 let handler_future =
282 match std::panic::catch_unwind(AssertUnwindSafe(|| (self.handler)(input, &context))) {
283 Ok(handler_future) => handler_future,
284 Err(panic) => return Ok(panic_failure(&task, &panic)),
285 };
286 let handler_result = AssertUnwindSafe(handler_future).catch_unwind().await;
287 let outcome = match handler_result {
288 Ok(Ok(output)) => DispatchOutcome::Completed {
289 output: encode_payload(&output)?,
290 },
291 Ok(Err(failure)) => DispatchOutcome::Failed {
292 failure: ActivityError::from(failure),
293 },
294 Err(panic) => panic_failure(&task, &panic),
295 };
296 Ok(outcome)
297 }
298}
299
300fn panic_failure(task: &ActivityTask, panic: &Box<dyn Any + Send>) -> DispatchOutcome {
301 let message = panic_message(panic);
302 error!(
303 activity_type = %task.activity_type,
304 activity_id = task.activity_id.sequence_position(),
305 attempt = task.attempt,
306 panic = %message,
307 "activity handler panicked; reporting retryable activity failure"
308 );
309 DispatchOutcome::Failed {
310 failure: ActivityError::from(ActivityFailure::retryable(format!(
311 "activity handler panicked: {message}"
312 ))),
313 }
314}
315
/// Extracts a printable message from a caught panic payload.
///
/// Handles the two payload types `panic!` produces: `&'static str` and `String`.
/// Any other payload yields `"unknown panic payload"`.
///
/// The parameter stays `&Box<dyn Any + Send>` deliberately. Passing a
/// `&Box<dyn Any + Send>` where `&(dyn Any + Send)` is expected can coerce the box
/// itself into the trait object, and both downcasts below would then fail.
fn panic_message(panic: &Box<dyn Any + Send>) -> String {
    panic
        .downcast_ref::<&str>()
        .map(|text| (*text).to_owned())
        .or_else(|| panic.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| String::from("unknown panic payload"))
}
325
/// Registration error raised by [`ActivityRegistry::register_activity`].
///
/// The same activity type name was registered twice.
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
#[error("activity type `{activity_type}` already has a registered handler")]
pub struct DuplicateActivityType {
    /// The activity type name that already had a handler.
    pub activity_type: String,
}
333
#[cfg(test)]
mod tests {
    use aion_core::{ActivityError, ActivityErrorKind, ActivityId, ContentType, WorkflowId};
    use aion_proto::{
        ProtoActivityError, ProtoActivityErrorKind, ProtoActivityId, ProtoActivityTask,
        ProtoPayload, ProtoWorkflowId,
    };
    use serde::{Deserialize, Serialize};

    use super::{ActivityFailure, ActivityRegistry, decode_payload, encode_payload};
    use crate::WorkerError;
    use crate::runtime::{ActivityDispatcher, DispatchOutcome};

    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestInput {
        value: i32,
    }

    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestOutput {
        doubled: i32,
    }

    /// JSON object that lacks the `value` field `TestInput` requires.
    #[derive(Serialize)]
    struct WrongInput {
        text: String,
    }

    #[test]
    fn retryable_and_terminal_failures_map_to_distinct_wire_classifications() {
        let retryable = ActivityFailure::retryable("temporary outage");
        let terminal = ActivityFailure::terminal("invalid request");

        let retryable_core = ActivityError::from(retryable);
        let terminal_core = ActivityError::from(terminal);
        let retryable_wire = ProtoActivityError::from(retryable_core);
        let terminal_wire = ProtoActivityError::from(terminal_core);

        assert_eq!(
            retryable_wire.kind,
            ProtoActivityErrorKind::Retryable as i32
        );
        assert_eq!(terminal_wire.kind, ProtoActivityErrorKind::Terminal as i32);
    }

    #[tokio::test]
    async fn typed_activity_round_trips_through_registry() -> Result<(), WorkerError> {
        let registry =
            ActivityRegistry::new().register_activity("double", |input: TestInput, context| {
                Box::pin(async move {
                    assert_eq!(context.attempt(), 1);
                    Ok(TestOutput {
                        doubled: input.value * 2,
                    })
                })
            })?;
        let task = proto_task("double", &TestInput { value: 21 })?;
        let (context, cancellation) = crate::ActivityContext::for_workflow(
            Some(WorkflowId::new_v4()),
            ActivityId::from_sequence_position(99),
            1,
            None,
        );
        drop(cancellation);

        let outcome = registry.dispatch(task.try_into()?, context).await?;

        let DispatchOutcome::Completed { output } = outcome else {
            return Err(WorkerError::decode(UnexpectedFailure));
        };
        assert_eq!(output.content_type(), &ContentType::Json);
        let decoded: TestOutput = decode_payload(&output)?;
        assert_eq!(decoded, TestOutput { doubled: 42 });
        Ok(())
    }

    #[test]
    fn duplicate_activity_registration_is_rejected() -> Result<(), WorkerError> {
        let registry =
            ActivityRegistry::new().register_activity("double", |input: TestInput, context| {
                Box::pin(async move {
                    let _ = context;
                    Ok(TestOutput {
                        doubled: input.value * 2,
                    })
                })
            })?;

        let error = registry
            .register_activity("double", |input: TestInput, context| {
                Box::pin(async move {
                    let _ = context;
                    Ok(TestOutput {
                        doubled: input.value,
                    })
                })
            })
            .err()
            .ok_or_else(|| WorkerError::decode(UnexpectedFailure))?;

        assert!(
            error
                .to_string()
                .contains("already has a registered handler")
        );
        Ok(())
    }

    /// A panic inside the handler future must be contained and reported as retryable.
    #[tokio::test]
    async fn panicking_handler_reports_retryable_failure() -> Result<(), WorkerError> {
        let registry =
            ActivityRegistry::new().register_activity("explode", |input: TestInput, context| {
                Box::pin(async move {
                    let _ = context;
                    if input.value > 0 {
                        panic!("boom");
                    }
                    Ok(TestOutput {
                        doubled: input.value,
                    })
                })
            })?;
        let task = proto_task("explode", &TestInput { value: 1 })?;

        let outcome = registry.dispatch(task.try_into()?, test_context()).await?;

        let DispatchOutcome::Failed { failure } = outcome else {
            return Err(WorkerError::decode(UnexpectedCompletion));
        };
        assert!(matches!(failure.kind, ActivityErrorKind::Retryable));
        assert!(failure.message.contains("boom"));
        Ok(())
    }

    /// Input that cannot decode into the handler's type is a terminal failure.
    #[tokio::test]
    async fn undecodable_input_reports_terminal_failure() -> Result<(), WorkerError> {
        let registry =
            ActivityRegistry::new().register_activity("double", |input: TestInput, context| {
                Box::pin(async move {
                    let _ = context;
                    Ok(TestOutput {
                        doubled: input.value * 2,
                    })
                })
            })?;
        let task = proto_task(
            "double",
            &WrongInput {
                text: String::from("not a test input"),
            },
        )?;

        let outcome = registry.dispatch(task.try_into()?, test_context()).await?;

        let DispatchOutcome::Failed { failure } = outcome else {
            return Err(WorkerError::decode(UnexpectedCompletion));
        };
        assert!(matches!(failure.kind, ActivityErrorKind::Terminal));
        assert!(
            failure
                .message
                .starts_with("failed to decode activity input")
        );
        Ok(())
    }

    /// Builds a wire-level task whose payload is `input` encoded as JSON.
    fn proto_task<T: Serialize>(
        activity_type: &str,
        input: &T,
    ) -> Result<ProtoActivityTask, WorkerError> {
        Ok(ProtoActivityTask {
            workflow_id: Some(ProtoWorkflowId::from(WorkflowId::new_v4())),
            activity_id: Some(ProtoActivityId::from(ActivityId::from_sequence_position(1))),
            activity_type: activity_type.to_owned(),
            input: Some(ProtoPayload::from(encode_payload(input)?)),
        })
    }

    /// First-attempt context for a fresh workflow, with its cancellation handle dropped.
    fn test_context() -> crate::ActivityContext {
        let (context, cancellation) = crate::ActivityContext::for_workflow(
            Some(WorkflowId::new_v4()),
            ActivityId::from_sequence_position(1),
            1,
            None,
        );
        drop(cancellation);
        context
    }

    #[derive(Debug, thiserror::Error)]
    #[error("expected completed activity outcome")]
    struct UnexpectedFailure;

    #[derive(Debug, thiserror::Error)]
    #[error("expected failed activity outcome")]
    struct UnexpectedCompletion;
}