use crate::error::CanoError;
use crate::resource::Resources;
use crate::task::{TaskConfig, TaskResult};
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::borrow::Cow;
use std::fmt;
use std::future::Future;
use std::hash::Hash;
use std::pin::Pin;
use std::sync::Arc;
/// What a single [`SteppedTask::step`] call produced.
///
/// A stepped run keeps going while steps return `More` and ends at the first
/// `Done`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum StepOutcome<TCursor, TState> {
    /// More work remains; this cursor is handed to the next `step` call.
    More(TCursor),
    /// The run is finished with this final task result.
    Done(TaskResult<TState>),
}
/// A task that makes progress through a sequence of resumable steps.
///
/// Each call to `step` receives the cursor produced by the previous call
/// (`None` on the first call). It either returns a new cursor via
/// [`StepOutcome::More`] or finishes via [`StepOutcome::Done`]. The
/// free function [`run_stepped`] drives this loop until `Done` or an error.
///
/// The cursor is required to be `Serialize + DeserializeOwned`, so it can be
/// turned into bytes between steps. `SteppedAdapter` in this module
/// round-trips it through `serde_json`.
// NOTE(review): the `#[crate::task::stepped]` attribute macro rewrites this
// trait (presumably to make `async fn step` usable via `dyn`) — confirm in
// `crate::task`.
#[crate::task::stepped]
pub trait SteppedTask<TState, TResourceKey = Cow<'static, str>>: Send + Sync
where
    TState: Clone + fmt::Debug + Send + Sync + 'static,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    /// Progress marker threaded from one `step` call to the next.
    type Cursor: Serialize + DeserializeOwned + Send + Sync + 'static;

    /// Execution settings for this task; defaults to `TaskConfig::default()`.
    fn config(&self) -> TaskConfig {
        TaskConfig::default()
    }

    /// Identifier for this task; defaults to the implementing type's
    /// `std::any::type_name`.
    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed(std::any::type_name::<Self>())
    }

    /// Performs one unit of work.
    ///
    /// `cursor` is `None` on the first call and otherwise the value carried
    /// by the previous [`StepOutcome::More`].
    ///
    /// # Errors
    ///
    /// Returns a [`CanoError`] when the step fails; [`run_stepped`]
    /// propagates it immediately.
    async fn step(
        &self,
        res: &Resources<TResourceKey>,
        cursor: Option<Self::Cursor>,
    ) -> Result<StepOutcome<Self::Cursor, TState>, CanoError>;
}
/// Drives a [`SteppedTask`] to completion in the current async context.
///
/// The first `step` call gets no cursor. Every [`StepOutcome::More`] cursor is
/// fed into the following call, and the loop ends with the result carried by
/// the first [`StepOutcome::Done`]. If the task never reports `Done`, this
/// future never completes.
///
/// # Errors
///
/// The first error returned by any `step` call is propagated as-is.
pub async fn run_stepped<S, S2, K>(s: &S, res: &Resources<K>) -> Result<TaskResult<S2>, CanoError>
where
    S: SteppedTask<S2, K> + ?Sized,
    S2: Clone + fmt::Debug + Send + Sync + 'static,
    K: Hash + Eq + Send + Sync + 'static,
{
    let mut next_cursor: Option<S::Cursor> = None;
    loop {
        // Each iteration consumes the previous cursor and either yields the
        // next one or leaves the loop with the final result.
        next_cursor = match s.step(res, next_cursor).await? {
            StepOutcome::Done(result) => break Ok(result),
            StepOutcome::More(cursor) => Some(cursor),
        };
    }
}
pub type DefaultStepCursor = Vec<u8>;
pub type DynSteppedTask<TState, TResourceKey = Cow<'static, str>> =
dyn SteppedTask<TState, TResourceKey, Cursor = Vec<u8>> + Send + Sync;
pub type SteppedTaskObject<TState, TResourceKey = Cow<'static, str>> =
std::sync::Arc<DynSteppedTask<TState, TResourceKey>>;
/// Outcome of one type-erased step.
///
/// This is the [`StepOutcome`] counterpart used by [`ErasedSteppedTask`]: the
/// cursor is carried as raw bytes instead of a typed value. It derives the
/// same traits as `StepOutcome` so erased results can be logged, cloned and
/// compared the same way.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErasedStep<TState> {
    /// More work remains; these bytes are passed back as the next cursor.
    More(Vec<u8>),
    /// The run is finished with this final task result.
    Done(TaskResult<TState>),
}
/// Boxed, `Send` future returned by [`ErasedSteppedTask::step`].
///
/// It may borrow from the task and its resources for `'a`, and resolves to
/// either the erased step outcome or a [`CanoError`].
pub type StepFuture<'a, TState> = Pin<
    Box<
        dyn Future<Output = Result<ErasedStep<TState>, CanoError>>
            + Send
            + 'a,
    >,
>;
/// Dyn-compatible counterpart of [`SteppedTask`].
///
/// The cursor type is erased to raw bytes and `step` returns an explicitly
/// boxed [`StepFuture`]. This lets stepped tasks with different `Cursor`
/// types be handled behind one interface.
pub trait ErasedSteppedTask<TState, TResourceKey>: Send + Sync
where
    TState: Clone + Send + Sync + 'static,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    /// Execution settings of the underlying task.
    fn config(&self) -> TaskConfig;

    /// Identifier of the underlying task.
    fn name(&self) -> Cow<'static, str>;

    /// Runs one step. `cursor_bytes` is `None` on the first call and otherwise
    /// the bytes from the previous [`ErasedStep::More`].
    fn step<'a>(
        &'a self,
        res: &'a Resources<TResourceKey>,
        cursor_bytes: Option<Vec<u8>>,
    ) -> StepFuture<'a, TState>;
}
/// Wraps a typed [`SteppedTask`] so it can be driven through the byte-oriented
/// [`ErasedSteppedTask`] interface.
///
/// Cursors cross the erased boundary as JSON. Incoming bytes are decoded with
/// `serde_json::from_slice` before each step, and a `More` cursor from the
/// inner task is encoded back with `serde_json::to_vec`.
pub(crate) struct SteppedAdapter<T>(pub Arc<T>);

impl<TState, TResourceKey, T> ErasedSteppedTask<TState, TResourceKey> for SteppedAdapter<T>
where
    TState: Clone + fmt::Debug + Send + Sync + 'static,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
    T: SteppedTask<TState, TResourceKey> + 'static,
{
    fn name(&self) -> Cow<'static, str> {
        let inner: &T = &self.0;
        inner.name()
    }

    fn config(&self) -> TaskConfig {
        let inner: &T = &self.0;
        inner.config()
    }

    fn step<'a>(
        &'a self,
        res: &'a Resources<TResourceKey>,
        cursor_bytes: Option<Vec<u8>>,
    ) -> StepFuture<'a, TState> {
        let inner: &'a T = &self.0;
        Box::pin(async move {
            // Turn the opaque bytes (if any) back into the task's typed cursor;
            // a decode failure is reported as a task-execution error.
            let cursor = cursor_bytes
                .as_deref()
                .map(|bytes| serde_json::from_slice::<T::Cursor>(bytes))
                .transpose()
                .map_err(|e| {
                    CanoError::task_execution(format!(
                        "deserialize cursor for `{}`: {e}",
                        inner.name()
                    ))
                })?;

            match inner.step(res, cursor).await? {
                StepOutcome::Done(result) => Ok(ErasedStep::Done(result)),
                // Re-encode the next cursor so it can cross the erased boundary.
                StepOutcome::More(next) => serde_json::to_vec(&next)
                    .map(ErasedStep::More)
                    .map_err(|e| {
                        CanoError::task_execution(format!(
                            "serialize cursor for `{}`: {e}",
                            inner.name()
                        ))
                    }),
            }
        })
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `SteppedTask`, `run_stepped`, `StepOutcome`, and the
    // `Task` impl that `#[task::stepped]` makes available for stepped impls.
    use super::*;
    use crate::resource::Resources;
    use crate::task;
    use crate::task::Task;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    // State enum shared by all test tasks (and used as workflow states below).
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum MyState {
        Work,
        Done,
        Next,
    }

    // Finishes on its first step, ignoring the cursor.
    struct ImmediateStepper;
    #[task::stepped]
    impl SteppedTask<MyState> for ImmediateStepper {
        type Cursor = u32;
        async fn step(
            &self,
            _res: &Resources,
            _cursor: Option<u32>,
        ) -> Result<StepOutcome<u32, MyState>, CanoError> {
            Ok(StepOutcome::Done(TaskResult::Single(MyState::Done)))
        }
    }

    // Calling `step` directly with no cursor yields `Done` immediately.
    #[tokio::test]
    async fn test_stepped_task_immediate_via_step() {
        let stepper = ImmediateStepper;
        let res = Resources::new();
        let result = SteppedTask::step(&stepper, &res, None).await.unwrap();
        assert_eq!(result, StepOutcome::Done(TaskResult::Single(MyState::Done)));
    }

    // The same stepper is also runnable through the `Task` trait.
    #[tokio::test]
    async fn test_stepped_task_immediate_via_task_run() {
        let stepper = ImmediateStepper;
        let res = Resources::new();
        let result = Task::run(&stepper, &res).await.unwrap();
        assert_eq!(result, TaskResult::Single(MyState::Done));
    }

    // Advances the cursor by one per step (starting from 0 when no cursor is
    // given) and finishes once the new value reaches `target`.
    struct CountingStepper {
        target: u32,
    }
    impl CountingStepper {
        fn new(target: u32) -> Self {
            Self { target }
        }
    }
    #[task::stepped]
    impl SteppedTask<MyState> for CountingStepper {
        type Cursor = u32;
        async fn step(
            &self,
            _res: &Resources,
            cursor: Option<u32>,
        ) -> Result<StepOutcome<u32, MyState>, CanoError> {
            let n = cursor.unwrap_or(0) + 1;
            if n >= self.target {
                Ok(StepOutcome::Done(TaskResult::Single(MyState::Done)))
            } else {
                Ok(StepOutcome::More(n))
            }
        }
    }

    // A multi-step task still completes when run via `Task::run`.
    #[tokio::test]
    async fn test_stepped_task_multiple_steps() {
        let stepper = CountingStepper::new(5);
        let res = Resources::new();
        let result = Task::run(&stepper, &res).await.unwrap();
        assert_eq!(result, TaskResult::Single(MyState::Done));
    }

    // Threads cursors by hand to check each intermediate `More` value.
    #[tokio::test]
    async fn test_stepped_task_cursor_threading() {
        let stepper = CountingStepper::new(3);
        let res = Resources::new();
        let r1 = SteppedTask::step(&stepper, &res, None).await.unwrap();
        assert_eq!(r1, StepOutcome::More(1));
        let r2 = SteppedTask::step(&stepper, &res, Some(1)).await.unwrap();
        assert_eq!(r2, StepOutcome::More(2));
        let r3 = SteppedTask::step(&stepper, &res, Some(2)).await.unwrap();
        assert_eq!(r3, StepOutcome::Done(TaskResult::Single(MyState::Done)));
    }

    // Every step fails with a task-execution error.
    struct ErrorStepper;
    #[task::stepped]
    impl SteppedTask<MyState> for ErrorStepper {
        type Cursor = u32;
        async fn step(
            &self,
            _res: &Resources,
            _cursor: Option<u32>,
        ) -> Result<StepOutcome<u32, MyState>, CanoError> {
            Err(CanoError::task_execution("step failed"))
        }
    }

    // A failing step surfaces as a `TaskExecution` error from `Task::run`.
    #[tokio::test]
    async fn test_stepped_task_error_propagates() {
        let stepper = ErrorStepper;
        let res = Resources::new();
        let err = Task::run(&stepper, &res).await.unwrap_err();
        assert!(matches!(err, CanoError::TaskExecution(_)));
    }

    // Finishes immediately with a `Split` result.
    struct SplitStepper;
    #[task::stepped]
    impl SteppedTask<MyState> for SplitStepper {
        type Cursor = u32;
        async fn step(
            &self,
            _res: &Resources,
            _cursor: Option<u32>,
        ) -> Result<StepOutcome<u32, MyState>, CanoError> {
            Ok(StepOutcome::Done(TaskResult::Split(vec![
                MyState::Work,
                MyState::Next,
            ])))
        }
    }

    // `Split` results pass through `Task::run` unchanged.
    #[tokio::test]
    async fn test_stepped_task_split() {
        let stepper = SplitStepper;
        let res = Resources::new();
        let result = Task::run(&stepper, &res).await.unwrap();
        assert_eq!(
            result,
            TaskResult::Split(vec![MyState::Work, MyState::Next])
        );
    }

    // Overrides both `config` (minimal) and `name`.
    struct CustomStepper;
    #[task::stepped]
    impl SteppedTask<MyState> for CustomStepper {
        type Cursor = u32;
        fn config(&self) -> TaskConfig {
            TaskConfig::minimal()
        }
        fn name(&self) -> Cow<'static, str> {
            Cow::Borrowed("my-custom-stepper")
        }
        async fn step(
            &self,
            _res: &Resources,
            _cursor: Option<u32>,
        ) -> Result<StepOutcome<u32, MyState>, CanoError> {
            Ok(StepOutcome::Done(TaskResult::Single(MyState::Done)))
        }
    }

    // The overridden config is visible through `SteppedTask::config`.
    #[test]
    fn test_stepped_task_config_override() {
        let stepper = CustomStepper;
        assert_eq!(
            SteppedTask::<MyState>::config(&stepper)
                .retry_mode
                .max_attempts(),
            1
        );
    }

    // The overridden name is visible through `SteppedTask::name`.
    #[test]
    fn test_stepped_task_name_override() {
        let stepper = CustomStepper;
        assert_eq!(SteppedTask::<MyState>::name(&stepper), "my-custom-stepper");
    }

    // The `Task` view of the stepper reports the same config and name.
    #[test]
    fn test_companion_task_forwards_config_and_name() {
        let stepper = CustomStepper;
        assert_eq!(Task::config(&stepper).retry_mode.max_attempts(), 1);
        assert_eq!(Task::name(&stepper), "my-custom-stepper");
    }

    // Relies on the trait's default `config` and `name`.
    struct DefaultConfigStepper;
    #[task::stepped]
    impl SteppedTask<MyState> for DefaultConfigStepper {
        type Cursor = u32;
        async fn step(
            &self,
            _res: &Resources,
            _cursor: Option<u32>,
        ) -> Result<StepOutcome<u32, MyState>, CanoError> {
            Ok(StepOutcome::Done(TaskResult::Single(MyState::Done)))
        }
    }

    // The default config allows more than one attempt.
    #[test]
    fn test_stepped_task_default_config_has_retries() {
        let stepper = DefaultConfigStepper;
        assert!(
            SteppedTask::<MyState>::config(&stepper)
                .retry_mode
                .max_attempts()
                > 1,
            "SteppedTask default config must have retries"
        );
    }

    // The default name is derived from the implementing type's name.
    #[test]
    fn test_stepped_task_default_name_contains_type_name() {
        let stepper = DefaultConfigStepper;
        let name = SteppedTask::<MyState>::name(&stepper);
        assert!(
            name.contains("DefaultConfigStepper"),
            "default name should contain the type name, got: {name}",
        );
    }

    // A stepper can be stored and run as `Arc<dyn Task<_>>`.
    #[tokio::test]
    async fn test_stepped_task_as_dyn_task() {
        let stepper: Arc<dyn Task<MyState>> = Arc::new(ImmediateStepper);
        let res = Resources::new();
        let result = Task::run(stepper.as_ref(), &res).await.unwrap();
        assert_eq!(result, TaskResult::Single(MyState::Done));
    }

    // `run_stepped` accepts an unsized `&dyn SteppedTask` (it is `?Sized`).
    #[tokio::test]
    async fn test_run_stepped_dyn_dispatch() {
        let stepper: &dyn SteppedTask<MyState, Cursor = u32> = &ImmediateStepper;
        let res = Resources::new();
        let result = run_stepped(stepper, &res).await.unwrap();
        assert_eq!(result, TaskResult::Single(MyState::Done));
    }

    // `StepOutcome::More` round-trips through `Clone` and compares equal.
    #[test]
    fn test_step_enum_more_clone_eq() {
        let s1: StepOutcome<u32, MyState> = StepOutcome::More(42);
        let s2 = s1.clone();
        assert_eq!(s1, s2);
    }

    // `StepOutcome::Done` round-trips through `Clone` and compares equal.
    #[test]
    fn test_step_enum_done_clone_eq() {
        let s1: StepOutcome<u32, MyState> = StepOutcome::Done(TaskResult::Single(MyState::Done));
        let s2 = s1.clone();
        assert_eq!(s1, s2);
    }

    // A stepper registered in a `Workflow` runs to completion and lets the
    // workflow reach its exit state.
    #[tokio::test]
    async fn test_stepped_task_in_workflow() {
        use crate::workflow::Workflow;
        use cano_macros::task;
        struct NextTask;
        #[task]
        impl Task<MyState> for NextTask {
            async fn run_bare(&self) -> Result<TaskResult<MyState>, CanoError> {
                Ok(TaskResult::Single(MyState::Done))
            }
        }
        let stepper = CountingStepper::new(3);
        let workflow = Workflow::bare()
            .register(MyState::Work, stepper)
            .register(MyState::Next, NextTask)
            .add_exit_state(MyState::Done);
        let result = workflow.orchestrate(MyState::Work).await.unwrap();
        assert_eq!(result, MyState::Done);
    }

    // Counts `step` invocations in a shared atomic. It finishes on the call
    // whose count reaches `target`, independent of the cursor value.
    struct TrackedStepper {
        calls: Arc<AtomicU32>,
        target: u32,
    }
    impl TrackedStepper {
        // Returns the stepper together with a handle to its call counter.
        fn new(target: u32) -> (Self, Arc<AtomicU32>) {
            let calls = Arc::new(AtomicU32::new(0));
            (
                Self {
                    calls: Arc::clone(&calls),
                    target,
                },
                calls,
            )
        }
    }
    #[task::stepped]
    impl SteppedTask<MyState> for TrackedStepper {
        type Cursor = u32;
        async fn step(
            &self,
            _res: &Resources,
            cursor: Option<u32>,
        ) -> Result<StepOutcome<u32, MyState>, CanoError> {
            let n = self.calls.fetch_add(1, Ordering::Relaxed);
            let pos = cursor.unwrap_or(0) + 1;
            if n + 1 >= self.target {
                Ok(StepOutcome::Done(TaskResult::Single(MyState::Done)))
            } else {
                Ok(StepOutcome::More(pos))
            }
        }
    }

    // `Task::run` invokes `step` exactly `target` times for a successful run.
    #[tokio::test]
    async fn test_stepped_task_step_count() {
        let (stepper, calls) = TrackedStepper::new(4);
        let res = Resources::new();
        let result = Task::run(&stepper, &res).await.unwrap();
        assert_eq!(result, TaskResult::Single(MyState::Done));
        assert_eq!(calls.load(Ordering::Relaxed), 4);
    }
}