use crate::error::NodeError;
use crate::state::State;
use async_trait::async_trait;
use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
/// Outcome produced by a node, always carrying the state it finished with.
#[derive(Clone, Debug)]
pub enum NodeResult<S: State> {
    /// The node finished and hands back `S` with the "continue" marker.
    Continue(S),
    /// The node stopped early; `reason` is a free-form explanation.
    Interrupt {
        /// State at the moment of the interrupt.
        state: S,
        /// Text describing why the node interrupted.
        reason: String,
    },
    /// The node finished and hands back `S` with the "end" marker.
    End(S),
}

impl<S: State> NodeResult<S> {
    /// Builds a [`NodeResult::Continue`] around `state`.
    pub fn cont(state: S) -> Self {
        NodeResult::Continue(state)
    }

    /// Builds a [`NodeResult::Interrupt`] around `state`, converting `reason`
    /// into an owned `String`.
    pub fn interrupt(state: S, reason: impl Into<String>) -> Self {
        let reason: String = reason.into();
        NodeResult::Interrupt { state, reason }
    }

    /// Builds a [`NodeResult::End`] around `state`.
    pub fn end(state: S) -> Self {
        NodeResult::End(state)
    }

    /// Borrows the state held by whichever variant this is.
    pub fn state(&self) -> &S {
        match self {
            NodeResult::Continue(inner) => inner,
            NodeResult::Interrupt { state, .. } => state,
            NodeResult::End(inner) => inner,
        }
    }

    /// Consumes the result and returns the state it held.
    pub fn into_state(self) -> S {
        match self {
            NodeResult::Continue(inner) => inner,
            NodeResult::Interrupt { state, .. } => state,
            NodeResult::End(inner) => inner,
        }
    }

    /// `true` only for [`NodeResult::Continue`].
    pub fn is_continue(&self) -> bool {
        match self {
            NodeResult::Continue(_) => true,
            NodeResult::Interrupt { .. } | NodeResult::End(_) => false,
        }
    }

    /// `true` only for [`NodeResult::Interrupt`].
    pub fn is_interrupt(&self) -> bool {
        match self {
            NodeResult::Interrupt { .. } => true,
            NodeResult::Continue(_) | NodeResult::End(_) => false,
        }
    }

    /// `true` only for [`NodeResult::End`].
    pub fn is_end(&self) -> bool {
        match self {
            NodeResult::End(_) => true,
            NodeResult::Continue(_) | NodeResult::Interrupt { .. } => false,
        }
    }
}
/// Exponential-backoff retry settings.
///
/// The delay before retry `n` is `initial_delay * multiplier^n`, capped at
/// `max_delay`; see [`RetryConfig::delay_for_attempt`].
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Retry budget. This type only stores the value.
    /// NOTE(review): enforcement presumably lives in the executor — confirm.
    pub max_retries: u32,
    /// Base delay; returned unscaled for attempt `0` (subject to `max_delay`).
    pub initial_delay: Duration,
    /// Upper bound on every delay returned by `delay_for_attempt`.
    pub max_delay: Duration,
    /// Growth factor applied once per attempt.
    pub multiplier: f64,
    /// Selector for which failures are retried; not consulted in this file.
    pub retry_on: RetryOn,
}

impl Default for RetryConfig {
    /// 3 retries, 100 ms initial delay, 10 s cap, doubling, retry on everything.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(10),
            multiplier: 2.0,
            retry_on: RetryOn::All,
        }
    }
}

impl RetryConfig {
    /// Creates a config with the given retry budget and default values for
    /// every other field.
    pub fn new(max_retries: u32) -> Self {
        Self {
            max_retries,
            ..Default::default()
        }
    }

    /// Replaces `initial_delay` and returns the updated config.
    pub fn with_initial_delay(mut self, delay: Duration) -> Self {
        self.initial_delay = delay;
        self
    }

    /// Replaces `max_delay` and returns the updated config.
    pub fn with_max_delay(mut self, delay: Duration) -> Self {
        self.max_delay = delay;
        self
    }

    /// Returns the back-off delay for the zero-based `attempt`:
    /// `initial_delay * multiplier^attempt`, never more than `max_delay`.
    ///
    /// This function never panics. Results that cannot be represented as a
    /// `Duration` are mapped as follows:
    ///
    /// * overflow / `+inf` → `max_delay`
    /// * zero or negative (e.g. a negative `multiplier`) → `Duration::ZERO`
    /// * `NaN` (e.g. a `NaN` multiplier, or `0 * inf`) → the unscaled
    ///   `initial_delay`, still capped at `max_delay`
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // Saturate instead of `as i32`: a wrapping cast would turn attempts
        // above `i32::MAX` into negative exponents and *shrink* the delay.
        let exponent = attempt.min(i32::MAX as u32) as i32;
        let scaled = self.initial_delay.as_secs_f64() * self.multiplier.powi(exponent);

        // `Duration::from_secs_f64` panics on NaN, negative and overflowing
        // input, so every such case is resolved before the conversion.
        if scaled.is_nan() {
            return self.initial_delay.min(self.max_delay);
        }
        if scaled <= 0.0 {
            return Duration::ZERO;
        }
        if scaled >= self.max_delay.as_secs_f64() {
            return self.max_delay;
        }

        // `scaled` is finite, positive and below the cap, so the conversion
        // is safe; the final `min` guards against float rounding at the cap.
        Duration::from_secs_f64(scaled).min(self.max_delay)
    }
}

/// Selector describing which failures should be retried.
#[derive(Clone, Debug)]
pub enum RetryOn {
    /// Retry regardless of the failure.
    All,
    /// Retry only failures matching one of the listed strings.
    /// NOTE(review): the matching rule is not visible in this file — confirm
    /// what the strings are compared against.
    Specific(Vec<String>),
    /// Never retry.
    None,
}
#[async_trait]
/// A unit of work operating on a state value of type `S`.
///
/// Implementors must be `Send + Sync`. `run` is the only required behavior
/// besides `id`; the remaining methods are optional metadata whose default
/// implementations all return `None`.
pub trait Node<S: State>: Send + Sync {
/// Returns this node's identifier.
fn id(&self) -> &str;
/// Runs the node, taking ownership of `state`.
///
/// Returns a [`NodeResult`] (which carries a state value) on success, or a
/// [`NodeError`] on failure.
async fn run(&self, state: S) -> Result<NodeResult<S>, NodeError>;
/// Optional retry settings for this node. Defaults to `None`.
/// NOTE(review): how a `None` is interpreted is up to the caller — not
/// visible in this file.
fn retry_config(&self) -> Option<RetryConfig> {
None
}
/// Optional time limit for this node. Defaults to `None`.
/// NOTE(review): nothing in this file enforces the limit — presumably the
/// executor does; confirm.
fn timeout(&self) -> Option<Duration> {
None
}
/// Optional human-readable description. Defaults to `None`.
fn description(&self) -> Option<&str> {
None
}
}
/// A boxed, dynamically dispatched [`Node`] trait object for state type `S`.
pub type BoxedNode<S> = Box<dyn Node<S>>;
/// A node backed by a plain function or closure.
///
/// `F` is called with the state and returns the future `Fut`, which resolves
/// to the node's result. Retry settings, a timeout and a description can be
/// attached with the `with_*` builder methods.
pub struct FnNode<S, F, Fut>
where
    S: State,
    F: Fn(S) -> Fut + Send + Sync,
    Fut: Future<Output = Result<NodeResult<S>, NodeError>> + Send,
{
    /// Identifier of the node.
    id: String,
    /// The wrapped function.
    func: F,
    /// Retry settings, if any were attached.
    retry_config: Option<RetryConfig>,
    /// Time limit, if one was attached.
    timeout: Option<Duration>,
    /// Human-readable description, if one was attached.
    description: Option<String>,
    // `S` and `Fut` only appear in `F`'s signature, so a marker is needed.
    // `fn(S) -> Fut` names both without pretending to own a value of either:
    // function pointers are always `Send + Sync`, so the node's auto traits
    // depend on `F` alone. The previous `PhantomData<(S, Fut)>` made the node
    // `Sync` only when `Fut: Sync`, although `Fut` is merely required to be
    // `Send` and a future value is never stored here.
    _phantom: std::marker::PhantomData<fn(S) -> Fut>,
}

impl<S, F, Fut> FnNode<S, F, Fut>
where
    S: State,
    F: Fn(S) -> Fut + Send + Sync,
    Fut: Future<Output = Result<NodeResult<S>, NodeError>> + Send,
{
    /// Creates a node named `id` that runs `func`, with no retry settings,
    /// timeout or description.
    pub fn new(id: impl Into<String>, func: F) -> Self {
        Self {
            id: id.into(),
            func,
            retry_config: None,
            timeout: None,
            description: None,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Attaches retry settings, replacing any previous ones.
    pub fn with_retry(mut self, config: RetryConfig) -> Self {
        self.retry_config = Some(config);
        self
    }

    /// Attaches a time limit, replacing any previous one.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }

    /// Attaches a description, replacing any previous one.
    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }
}
/// [`Node`] implementation that delegates `run` to the wrapped function and
/// reports the metadata stored on the [`FnNode`].
#[async_trait]
impl<S, F, Fut> Node<S> for FnNode<S, F, Fut>
where
    S: State,
    F: Fn(S) -> Fut + Send + Sync,
    Fut: Future<Output = Result<NodeResult<S>, NodeError>> + Send,
{
    /// Returns the identifier given at construction.
    fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Calls the wrapped function with `state` and awaits the future it returns.
    async fn run(&self, state: S) -> Result<NodeResult<S>, NodeError> {
        let pending = (self.func)(state);
        pending.await
    }

    /// Returns a copy of the stored retry settings, if any.
    fn retry_config(&self) -> Option<RetryConfig> {
        self.retry_config.clone()
    }

    /// Returns the stored time limit, if any.
    fn timeout(&self) -> Option<Duration> {
        self.timeout
    }

    /// Returns the stored description, if any.
    fn description(&self) -> Option<&str> {
        self.description.as_ref().map(String::as_str)
    }
}
/// Decorator around another node that can run a callback before and/or after
/// the inner node.
pub struct WrappedNode<S: State, Inner: Node<S>> {
    /// The node being decorated.
    inner: Inner,
    /// Callback given a borrow of the incoming state, if set.
    before: Option<Arc<dyn Fn(&S) + Send + Sync>>,
    /// Callback given a borrow of the inner node's result, if set.
    after: Option<Arc<dyn Fn(&NodeResult<S>) + Send + Sync>>,
    _phantom: std::marker::PhantomData<S>,
}

impl<S: State, Inner: Node<S>> WrappedNode<S, Inner> {
    /// Wraps `inner` with no callbacks installed.
    pub fn new(inner: Inner) -> Self {
        let marker = std::marker::PhantomData;
        Self {
            inner,
            before: None,
            after: None,
            _phantom: marker,
        }
    }

    /// Installs `f` as the "before" callback, replacing any previous one.
    pub fn before(self, f: impl Fn(&S) + Send + Sync + 'static) -> Self {
        let hook: Arc<dyn Fn(&S) + Send + Sync> = Arc::new(f);
        Self {
            before: Some(hook),
            ..self
        }
    }

    /// Installs `f` as the "after" callback, replacing any previous one.
    pub fn after(self, f: impl Fn(&NodeResult<S>) + Send + Sync + 'static) -> Self {
        let hook: Arc<dyn Fn(&NodeResult<S>) + Send + Sync> = Arc::new(f);
        Self {
            after: Some(hook),
            ..self
        }
    }
}
/// [`Node`] implementation that forwards everything to the inner node and
/// invokes the optional callbacks around `run`.
#[async_trait]
impl<S: State, Inner: Node<S>> Node<S> for WrappedNode<S, Inner> {
    /// Returns the inner node's identifier.
    fn id(&self) -> &str {
        self.inner.id()
    }

    /// Runs the inner node, calling the "before" callback first and the
    /// "after" callback once the inner node has succeeded.
    ///
    /// If the inner node returns an error, it is propagated immediately and
    /// the "after" callback is not called.
    async fn run(&self, state: S) -> Result<NodeResult<S>, NodeError> {
        if let Some(hook) = self.before.as_ref() {
            hook(&state);
        }

        let outcome = self.inner.run(state).await?;

        if let Some(hook) = self.after.as_ref() {
            hook(&outcome);
        }

        Ok(outcome)
    }

    /// Returns the inner node's retry settings.
    fn retry_config(&self) -> Option<RetryConfig> {
        self.inner.retry_config()
    }

    /// Returns the inner node's time limit.
    fn timeout(&self) -> Option<Duration> {
        self.inner.timeout()
    }

    /// Returns the inner node's description.
    fn description(&self) -> Option<&str> {
        self.inner.description()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Minimal state type used to exercise the helpers in this module.
    #[derive(Clone, Debug, Default, Serialize, Deserialize)]
    struct TestState {
        counter: u32,
    }

    impl State for TestState {
        fn schema() -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }
    }

    /// Each variant reports the matching predicate and exposes its state.
    #[tokio::test]
    async fn test_node_result() {
        let seed = TestState { counter: 42 };

        let continued = NodeResult::Continue(seed.clone());
        assert!(continued.is_continue());
        assert_eq!(continued.state().counter, 42);

        let ended = NodeResult::End(seed.clone());
        assert!(ended.is_end());

        let interrupted = NodeResult::Interrupt {
            reason: "Need input".to_string(),
            state: seed,
        };
        assert!(interrupted.is_interrupt());
    }

    /// Builder methods store their values and the delay grows with the attempt.
    #[tokio::test]
    async fn test_retry_config() {
        let base_delay = Duration::from_millis(50);
        let config = RetryConfig::new(5)
            .with_initial_delay(base_delay)
            .with_max_delay(Duration::from_secs(5));

        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_delay, Duration::from_millis(50));

        let first = config.delay_for_attempt(0);
        let second = config.delay_for_attempt(1);
        assert!(second > first);
    }
}