Skip to main content

aion_worker/
activity.rs

1//! `Activity` trait, `ActivityFailure`, and typed registration.
2
3use std::any::Any;
4use std::collections::{BTreeMap, BTreeSet};
5use std::future::Future;
6use std::panic::AssertUnwindSafe;
7use std::pin::Pin;
8
9use aion_core::{ActivityError, ActivityErrorKind, Payload};
10use async_trait::async_trait;
11use futures::FutureExt;
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use tracing::error;
15
16use crate::context::ActivityContext;
17use crate::error::{MissingActivityHandler, WorkerError};
18use crate::protocol::ActivityTask;
19use crate::runtime::loop_::{ActivityDispatcher, DispatchOutcome};
20
/// How the engine should treat a failed activity attempt.
///
/// Handlers choose one of these explicitly when they build an
/// [`ActivityFailure`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Classification {
    /// Transient failure: the engine may retry the activity according to policy.
    Retryable,
    /// Permanent failure: the activity must not be retried.
    Terminal,
}
29
30/// Handler-returned failure with explicit retryability classification.
31#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
32#[error("{message}")]
33pub struct ActivityFailure {
34    classification: Classification,
35    message: String,
36    detail: Option<Payload>,
37}
38
39impl ActivityFailure {
40    /// Creates a retryable activity failure.
41    #[must_use]
42    pub fn retryable(message: impl Into<String>) -> Self {
43        Self::new(Classification::Retryable, message, None)
44    }
45
46    /// Creates a terminal activity failure.
47    #[must_use]
48    pub fn terminal(message: impl Into<String>) -> Self {
49        Self::new(Classification::Terminal, message, None)
50    }
51
52    /// Attaches opaque structured detail to this failure.
53    #[must_use]
54    pub fn with_detail(mut self, detail: Payload) -> Self {
55        self.detail = Some(detail);
56        self
57    }
58
59    /// Returns the explicit retryability classification.
60    #[must_use]
61    pub const fn classification(&self) -> &Classification {
62        &self.classification
63    }
64
65    /// Returns the human-readable failure message.
66    #[must_use]
67    pub fn message(&self) -> &str {
68        &self.message
69    }
70
71    /// Returns the optional structured failure detail.
72    #[must_use]
73    pub const fn detail(&self) -> Option<&Payload> {
74        self.detail.as_ref()
75    }
76
77    fn new(
78        classification: Classification,
79        message: impl Into<String>,
80        detail: Option<Payload>,
81    ) -> Self {
82        Self {
83            classification,
84            message: message.into(),
85            detail,
86        }
87    }
88}
89
90impl From<Classification> for ActivityErrorKind {
91    fn from(value: Classification) -> Self {
92        match value {
93            Classification::Retryable => Self::Retryable,
94            Classification::Terminal => Self::Terminal,
95        }
96    }
97}
98
99impl From<ActivityFailure> for ActivityError {
100    fn from(value: ActivityFailure) -> Self {
101        Self {
102            kind: ActivityErrorKind::from(value.classification),
103            message: value.message,
104            details: value.detail,
105        }
106    }
107}
108
109/// Boxed future returned by a typed activity handler.
110pub type HandlerFuture<'context, Output> =
111    Pin<Box<dyn Future<Output = Result<Output, ActivityFailure>> + Send + 'context>>;
112
113type BoxedHandler<Input, Output> = Box<
114    dyn for<'context> Fn(Input, &'context ActivityContext) -> HandlerFuture<'context, Output>
115        + Send
116        + Sync,
117>;
118
119/// Registry of typed activity handlers keyed by activity-type name.
120#[derive(Default)]
121pub struct ActivityRegistry {
122    handlers: BTreeMap<String, Box<dyn ErasedActivityHandler>>,
123}
124
125impl ActivityRegistry {
126    /// Creates an empty activity registry.
127    #[must_use]
128    pub fn new() -> Self {
129        Self::default()
130    }
131
132    /// Registers one typed activity handler under an activity-type name.
133    ///
134    /// # Errors
135    ///
136    /// Returns [`WorkerError::Registration`] when the name is already registered.
137    pub fn register_activity<Input, Output, Handler>(
138        mut self,
139        activity_type: impl Into<String>,
140        handler: Handler,
141    ) -> Result<Self, WorkerError>
142    where
143        Input: Serialize + DeserializeOwned + Send + Sync + 'static,
144        Output: Serialize + Send + Sync + 'static,
145        Handler: for<'context> Fn(Input, &'context ActivityContext) -> HandlerFuture<'context, Output>
146            + Send
147            + Sync
148            + 'static,
149    {
150        let activity_type = activity_type.into();
151        if self.handlers.contains_key(&activity_type) {
152            return Err(WorkerError::registration(DuplicateActivityType {
153                activity_type,
154            }));
155        }
156        self.handlers
157            .insert(activity_type, Box::new(TypedHandler::new(handler)));
158        Ok(self)
159    }
160
161    /// Returns true when no activity handlers have been registered.
162    #[must_use]
163    pub fn is_empty(&self) -> bool {
164        self.handlers.is_empty()
165    }
166
167    /// Returns the registered activity-type names in deterministic order.
168    #[must_use]
169    pub fn activity_types(&self) -> BTreeSet<String> {
170        self.handlers.keys().cloned().collect()
171    }
172}
173
174#[async_trait]
175impl ActivityDispatcher for ActivityRegistry {
176    async fn dispatch(
177        &self,
178        task: ActivityTask,
179        context: ActivityContext,
180    ) -> Result<DispatchOutcome, WorkerError> {
181        let Some(handler) = self.handlers.get(&task.activity_type) else {
182            return Err(WorkerError::registration(MissingActivityHandler {
183                activity_type: task.activity_type,
184            }));
185        };
186        handler.dispatch(task, context).await
187    }
188
189    fn activity_types(&self) -> BTreeSet<String> {
190        self.activity_types()
191    }
192}
193
194/// Backwards-compatible name for the typed activity registry used by the runtime.
195pub type TypedActivityDispatcher = ActivityRegistry;
196
197/// Decodes a payload into a typed value using the payload content-type tag.
198///
199/// # Errors
200///
201/// Returns [`WorkerError::Decode`] when the payload tag or bytes cannot produce
202/// the requested type.
203pub fn decode_payload<T>(payload: &Payload) -> Result<T, WorkerError>
204where
205    T: DeserializeOwned,
206{
207    let value = payload.to_json().map_err(WorkerError::decode)?;
208    serde_json::from_value(value).map_err(WorkerError::decode)
209}
210
211/// Encodes a typed value into the baseline JSON payload codec.
212///
213/// # Errors
214///
215/// Returns [`WorkerError::Encode`] when the value cannot be serialized.
216pub fn encode_payload<T>(value: &T) -> Result<Payload, WorkerError>
217where
218    T: Serialize,
219{
220    let value = serde_json::to_value(value).map_err(WorkerError::encode)?;
221    Payload::from_json(&value).map_err(WorkerError::encode)
222}
223
224#[async_trait]
225trait ErasedActivityHandler: Send + Sync {
226    async fn dispatch(
227        &self,
228        task: ActivityTask,
229        context: ActivityContext,
230    ) -> Result<DispatchOutcome, WorkerError>;
231}
232
233struct TypedHandler<Input, Output> {
234    handler: BoxedHandler<Input, Output>,
235}
236
237impl<Input, Output> TypedHandler<Input, Output> {
238    fn new(
239        handler: impl for<'context> Fn(
240            Input,
241            &'context ActivityContext,
242        ) -> HandlerFuture<'context, Output>
243        + Send
244        + Sync
245        + 'static,
246    ) -> Self {
247        Self {
248            handler: Box::new(handler),
249        }
250    }
251}
252
253#[async_trait]
254impl<Input, Output> ErasedActivityHandler for TypedHandler<Input, Output>
255where
256    Input: DeserializeOwned + Send + Sync + 'static,
257    Output: Serialize + Send + Sync + 'static,
258{
259    async fn dispatch(
260        &self,
261        task: ActivityTask,
262        context: ActivityContext,
263    ) -> Result<DispatchOutcome, WorkerError> {
264        let input = match decode_payload::<Input>(&task.input) {
265            Ok(input) => input,
266            Err(error) => {
267                error!(
268                    activity_type = %task.activity_type,
269                    activity_id = task.activity_id.sequence_position(),
270                    attempt = task.attempt,
271                    error = %error,
272                    "failed to decode activity input; reporting terminal activity failure"
273                );
274                let failure =
275                    ActivityFailure::terminal(format!("failed to decode activity input: {error}"));
276                return Ok(DispatchOutcome::Failed {
277                    failure: ActivityError::from(failure),
278                });
279            }
280        };
281        let handler_future =
282            match std::panic::catch_unwind(AssertUnwindSafe(|| (self.handler)(input, &context))) {
283                Ok(handler_future) => handler_future,
284                Err(panic) => return Ok(panic_failure(&task, &panic)),
285            };
286        let handler_result = AssertUnwindSafe(handler_future).catch_unwind().await;
287        let outcome = match handler_result {
288            Ok(Ok(output)) => DispatchOutcome::Completed {
289                output: encode_payload(&output)?,
290            },
291            Ok(Err(failure)) => DispatchOutcome::Failed {
292                failure: ActivityError::from(failure),
293            },
294            Err(panic) => panic_failure(&task, &panic),
295        };
296        Ok(outcome)
297    }
298}
299
300fn panic_failure(task: &ActivityTask, panic: &Box<dyn Any + Send>) -> DispatchOutcome {
301    let message = panic_message(panic);
302    error!(
303        activity_type = %task.activity_type,
304        activity_id = task.activity_id.sequence_position(),
305        attempt = task.attempt,
306        panic = %message,
307        "activity handler panicked; reporting retryable activity failure"
308    );
309    DispatchOutcome::Failed {
310        failure: ActivityError::from(ActivityFailure::retryable(format!(
311            "activity handler panicked: {message}"
312        ))),
313    }
314}
315
// Extracts readable text from a panic payload.
//
// Payloads that are a `&str` or a `String` yield their text; any other payload
// type yields a fixed placeholder.
fn panic_message(panic: &Box<dyn Any + Send>) -> String {
    // Deref through the box explicitly: the downcasts must inspect the boxed
    // payload, not the `Box` itself (which is also `Any`).
    let payload: &(dyn Any + Send) = &**panic;
    payload
        .downcast_ref::<&str>()
        .map(|text| (*text).to_owned())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| String::from("unknown panic payload"))
}
325
326/// Error returned when an activity type is registered more than once.
327#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
328#[error("activity type `{activity_type}` already has a registered handler")]
329pub struct DuplicateActivityType {
330    /// Duplicate activity type name.
331    pub activity_type: String,
332}
333
#[cfg(test)]
mod tests {
    use aion_core::{ActivityError, ActivityId, ContentType, WorkflowId};
    use aion_proto::{
        ProtoActivityError, ProtoActivityErrorKind, ProtoActivityId, ProtoActivityTask,
        ProtoPayload, ProtoWorkflowId,
    };
    use serde::{Deserialize, Serialize};

    use super::{ActivityFailure, ActivityRegistry, decode_payload, encode_payload};
    use crate::WorkerError;
    use crate::runtime::{ActivityDispatcher, DispatchOutcome};

    /// Sentinel error used when a test observes an outcome it did not expect.
    #[derive(Debug, thiserror::Error)]
    #[error("expected completed activity outcome")]
    struct UnexpectedFailure;

    /// Input type of the `double` test activity.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestInput {
        value: i32,
    }

    /// Output type of the `double` test activity.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct TestOutput {
        doubled: i32,
    }

    /// Builds a wire-level task for `activity_type` carrying the encoded `input`.
    fn proto_task(
        activity_type: &str,
        input: &TestInput,
    ) -> Result<ProtoActivityTask, WorkerError> {
        let encoded_input = encode_payload(input)?;
        Ok(ProtoActivityTask {
            workflow_id: Some(ProtoWorkflowId::from(WorkflowId::new_v4())),
            activity_id: Some(ProtoActivityId::from(ActivityId::from_sequence_position(1))),
            activity_type: activity_type.to_owned(),
            input: Some(ProtoPayload::from(encode_payload(&input)?)),
        })
    }

    #[test]
    fn retryable_and_terminal_failures_map_to_distinct_wire_classifications() {
        let retryable_wire = ProtoActivityError::from(ActivityError::from(
            ActivityFailure::retryable("temporary outage"),
        ));
        let terminal_wire = ProtoActivityError::from(ActivityError::from(
            ActivityFailure::terminal("invalid request"),
        ));

        assert_eq!(
            retryable_wire.kind,
            ProtoActivityErrorKind::Retryable as i32
        );
        assert_eq!(terminal_wire.kind, ProtoActivityErrorKind::Terminal as i32);
    }

    #[tokio::test]
    async fn typed_activity_round_trips_through_registry() -> Result<(), WorkerError> {
        let doubling_registry =
            ActivityRegistry::new().register_activity("double", |request: TestInput, context| {
                Box::pin(async move {
                    assert_eq!(context.attempt(), 1);
                    Ok(TestOutput {
                        doubled: request.value * 2,
                    })
                })
            })?;
        let wire_task = proto_task("double", &TestInput { value: 21 })?;
        let (context, cancellation) = crate::ActivityContext::for_workflow(
            Some(WorkflowId::new_v4()),
            ActivityId::from_sequence_position(99),
            1,
            None,
        );
        // Release the cancellation handle before dispatching, as the test
        // never cancels.
        drop(cancellation);

        let outcome = doubling_registry
            .dispatch(wire_task.try_into()?, context)
            .await?;

        match outcome {
            DispatchOutcome::Completed { output } => {
                assert_eq!(output.content_type(), &ContentType::Json);
                let decoded: TestOutput = decode_payload(&output)?;
                assert_eq!(decoded, TestOutput { doubled: 42 });
                Ok(())
            }
            _ => Err(WorkerError::decode(UnexpectedFailure)),
        }
    }

    #[test]
    fn duplicate_activity_registration_is_rejected() -> Result<(), WorkerError> {
        let first_registration =
            ActivityRegistry::new().register_activity("double", |request: TestInput, context| {
                Box::pin(async move {
                    let _ = context;
                    Ok(TestOutput {
                        doubled: request.value * 2,
                    })
                })
            })?;

        let second_attempt =
            first_registration.register_activity("double", |request: TestInput, context| {
                Box::pin(async move {
                    let _ = context;
                    Ok(TestOutput {
                        doubled: request.value,
                    })
                })
            });
        let Err(rejection) = second_attempt else {
            return Err(WorkerError::decode(UnexpectedFailure));
        };

        let rendered = rejection.to_string();
        assert!(rendered.contains("already has a registered handler"));
        Ok(())
    }
}