use crate::context::Context;
use crate::error::TaskError;
use crate::storage::Storage;
use crate::tasks::r#trait::{Task, TaskOutput};
use async_trait::async_trait;
use futures::future::BoxFuture;
use futures::FutureExt;
use std::future::Future;
use std::sync::Arc;
/// Type-erased step closure held by [`StatefulTask`].
///
/// It is called with the shared execution context and the previously
/// checkpointed state, and yields a boxed future resolving to the next
/// [`StepResult`] (or a [`TaskError`]).
type StepFn = Box<
    dyn Fn(Arc<Context>, serde_json::Value) -> BoxFuture<'static, Result<StepResult, TaskError>>
        + Sync
        + Send
        + 'static,
>;
/// Outcome of a single invocation of a [`StatefulTask`]'s step closure.
#[derive(Clone, Debug)]
pub struct StepResult {
    /// State written back as the task's checkpoint once the step succeeds.
    pub state: serde_json::Value,
    /// Value returned to the caller as the task's output.
    pub output: TaskOutput,
}
impl StepResult {
    /// Pairs the next checkpoint `state` with the step's `output`.
    pub fn new(state: serde_json::Value, output: TaskOutput) -> Self {
        Self { output, state }
    }
}
/// A task whose step closure resumes from state checkpointed in a [`Storage`].
///
/// Each execution loads the last saved state for this task, runs the step
/// with it, and saves the state the step returns.
pub struct StatefulTask {
    /// Unique task identifier; also used to derive the checkpoint key.
    id: String,
    /// Identifiers of tasks this one depends on.
    deps: Vec<String>,
    /// Scheduling priority reported through [`Task::priority`].
    priority: u8,
    /// Backend where the task's checkpointed state is loaded from and saved to.
    storage: Arc<dyn Storage>,
    /// The user-supplied step, boxed so all tasks share one field type.
    step: StepFn,
}
impl StatefulTask {
    /// Creates a task identified by `id` whose work is the async `step` closure.
    ///
    /// On each execution, `step` receives the execution context and the state
    /// previously checkpointed in `storage`. If nothing has been stored yet, it
    /// receives `serde_json::Value::Null`. The task starts with no dependencies
    /// and priority `0`. Use [`with_deps`](Self::with_deps) and
    /// [`with_priority`](Self::with_priority) to change them.
    #[must_use]
    pub fn new<F, Fut>(id: impl Into<String>, storage: Arc<dyn Storage>, step: F) -> Self
    where
        F: Fn(Arc<Context>, serde_json::Value) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<StepResult, TaskError>> + Send + 'static,
    {
        StatefulTask {
            id: id.into(),
            deps: Vec::new(),
            priority: 0,
            storage,
            // Erase the concrete closure/future types behind `StepFn`.
            step: Box::new(move |ctx, state| step(ctx, state).boxed()),
        }
    }

    /// Sets the task's dependencies to `deps`.
    ///
    /// This replaces any previously configured dependencies; it does not
    /// append to them.
    #[must_use = "builder methods consume the task and return the updated one"]
    pub fn with_deps<I, S>(mut self, deps: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.deps = deps.into_iter().map(Into::into).collect();
        self
    }

    /// Sets the task's scheduling priority.
    #[must_use = "builder methods consume the task and return the updated one"]
    pub fn with_priority(mut self, priority: u8) -> Self {
        self.priority = priority;
        self
    }

    /// Storage key under which this task's state is checkpointed: `"state:<id>"`.
    fn state_key(&self) -> String {
        format!("state:{}", self.id)
    }
}
#[async_trait]
impl Task for StatefulTask {
fn id(&self) -> &str {
&self.id
}
fn dependencies(&self) -> Vec<String> {
self.deps.clone()
}
fn priority(&self) -> u8 {
self.priority
}
async fn execute(&self, ctx: Arc<Context>) -> Result<TaskOutput, TaskError> {
let key = self.state_key();
let prev = self
.storage
.load(&key)
.await
.map_err(|e| TaskError::execution(format!("loading checkpoint: {e}")))?
.unwrap_or(serde_json::Value::Null);
let result = (self.step)(ctx, prev).await?;
self.storage
.save(&key, &result.state)
.await
.map_err(|e| TaskError::execution(format!("saving checkpoint: {e}")))?;
Ok(result.output)
}
}