use crate::config::Config;
use crate::errors::Result;
use crate::state::State;
use async_trait::async_trait;
use std::fmt::Debug;
use std::future::Future;
use std::sync::Arc;
/// An asynchronous step of the graph: it takes the current state `S` by value
/// and returns the next state, with a shared reference to the run [`Config`].
///
/// Implement it directly on a type (with `#[async_trait]`, as below), or use a
/// closure of shape `Fn(S, &Config) -> impl Future<Output = Result<S>> + Send`,
/// which implements it through the blanket impl in this module.
///
/// Implementors must be `Send + Sync`, so a node can be shared between threads,
/// e.g. behind a [`NodeArc`].
#[async_trait]
pub trait Node<S: State>: Send + Sync {
/// Runs the node on `state` and returns the resulting state.
///
/// # Errors
///
/// Returns whatever error the implementation produces; the trait itself adds
/// no error cases.
async fn invoke(&self, state: S, config: &Config) -> Result<S>;
}
/// Every `Send + Sync` function or closure of shape
/// `Fn(S, &Config) -> Fut`, where `Fut: Future<Output = Result<S>> + Send`,
/// is a [`Node`]. Invoking it calls the function and awaits its future.
///
/// `Fut` is a single type for every `&Config` lifetime, so the returned future
/// cannot borrow `config`. Copy out whatever it needs before the `async move`
/// block.
#[async_trait]
impl<S, F, Fut> Node<S> for F
where
    S: State,
    F: Fn(S, &Config) -> Fut + Send + Sync,
    Fut: Future<Output = Result<S>> + Send,
{
    async fn invoke(&self, state: S, config: &Config) -> Result<S> {
        let pending = (self)(state, config);
        pending.await
    }
}
/// Uniquely owned trait object for a [`Node`] over state `S` (implicitly `'static`).
pub type NodeBox<S> = Box<dyn Node<S>>;
/// Reference-counted trait object for a [`Node`]. Cloning it shares the same
/// node rather than copying it. This is the type stored in [`PregelNode::bound`].
pub type NodeArc<S> = Arc<dyn Node<S>>;
#[derive(Clone)]
pub struct PregelNode<S: State> {
pub name: String,
pub channels: Vec<String>,
pub triggers: Vec<String>,
pub bound: NodeArc<S>,
pub writers: Vec<ChannelWrite>,
}
impl<S: State> Debug for PregelNode<S> {
    /// Formats every field. `bound` is a trait object without a `Debug` impl,
    /// so it is rendered as the placeholder string `"<node>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructuring makes adding a field a compile error here until it is
        // also formatted (or explicitly ignored).
        let Self {
            name,
            channels,
            triggers,
            bound: _,
            writers,
        } = self;

        let mut builder = f.debug_struct("PregelNode");
        builder.field("name", name);
        builder.field("channels", channels);
        builder.field("triggers", triggers);
        builder.field("bound", &"<node>");
        builder.field("writers", writers);
        builder.finish()
    }
}
impl<S: State> PregelNode<S> {
    /// Creates a node from an already type-erased, shared implementation.
    pub fn new(
        name: impl Into<String>,
        channels: Vec<String>,
        triggers: Vec<String>,
        bound: NodeArc<S>,
        writers: Vec<ChannelWrite>,
    ) -> Self {
        let name = name.into();
        Self {
            name,
            channels,
            triggers,
            bound,
            writers,
        }
    }

    /// Creates a node from any concrete [`Node`] implementation, including a
    /// plain closure, and wraps it in an `Arc`.
    pub fn from_node(
        name: impl Into<String>,
        channels: Vec<String>,
        triggers: Vec<String>,
        bound: impl Node<S> + 'static,
        writers: Vec<ChannelWrite>,
    ) -> Self {
        let shared: NodeArc<S> = Arc::new(bound);
        Self::new(name, channels, triggers, shared, writers)
    }

    /// Returns `true` when at least one of this node's triggers appears in
    /// `written_channels`. An empty slice never triggers the node.
    pub fn is_triggered(&self, written_channels: &[String]) -> bool {
        written_channels
            .iter()
            .any(|written| self.triggers.contains(written))
    }
}
/// A write target attached to a node through its `writers` list.
#[derive(Debug, Clone)]
pub struct ChannelWrite {
    /// Name of the target channel.
    pub channel: String,
    /// Skip flag, `true` by default via [`ChannelWrite::new`]. Presumably the
    /// write is skipped when the value is `None`; confirm against the executor.
    pub skip_none: bool,
    /// Optional mapper name. How it is resolved and applied is not visible in
    /// this module.
    pub mapper: Option<String>,
}

impl ChannelWrite {
    /// Creates a write to `channel` with `skip_none = true` and no mapper.
    pub fn new(channel: impl Into<String>) -> Self {
        Self {
            channel: channel.into(),
            skip_none: true,
            mapper: None,
        }
    }

    /// Builder-style setter for [`skip_none`](Self::skip_none).
    #[must_use]
    pub fn with_skip_none(mut self, skip: bool) -> Self {
        self.skip_none = skip;
        self
    }

    /// Builder-style setter for [`mapper`](Self::mapper). It mirrors
    /// [`with_skip_none`](Self::with_skip_none) so every field can be set
    /// fluently.
    #[must_use]
    pub fn with_mapper(mut self, mapper: impl Into<String>) -> Self {
        self.mapper = Some(mapper.into());
        self
    }
}
/// Turns an async closure taking `(state, &config)` into a [`Node`].
///
/// Nothing is wrapped: closures of this shape already implement [`Node`]
/// through the blanket impl. This function only fixes the types and hides the
/// concrete closure type behind `impl Node<S>`.
pub fn node_fn<S, F, Fut>(func: F) -> impl Node<S>
where
    S: State,
    Fut: Future<Output = Result<S>> + Send + 'static,
    F: Fn(S, &Config) -> Fut + Send + Sync + 'static,
{
    func
}
/// Adapts an async function of the state alone into a [`Node`]. The
/// [`Config`] argument passed at invocation time is ignored.
pub fn simple_node<S, F, Fut>(step: F) -> impl Node<S>
where
    S: State,
    F: Fn(S) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<S>> + Send + 'static,
{
    move |input: S, _: &Config| step(input)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::State as StateTrait;
    use serde::{Deserialize, Serialize};

    /// Minimal state for the tests: a counter that `merge` adds together.
    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct TestState {
        count: i32,
    }

    impl StateTrait for TestState {
        fn merge(&mut self, other: Self) -> Result<()> {
            self.count += other.count;
            Ok(())
        }
    }

    /// Converts string literals into the owned `Vec<String>` the APIs expect.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// A hand-written node that doubles the counter.
    struct CustomNode;

    #[async_trait]
    impl Node<TestState> for CustomNode {
        async fn invoke(&self, mut state: TestState, _config: &Config) -> Result<TestState> {
            state.count *= 2;
            Ok(state)
        }
    }

    #[tokio::test]
    async fn test_node_from_closure() {
        let increment = |mut state: TestState, _config: &Config| async move {
            state.count += 1;
            Ok(state)
        };
        let config = Config::default();
        let result = increment
            .invoke(TestState { count: 0 }, &config)
            .await
            .unwrap();
        assert_eq!(result.count, 1);
    }

    #[tokio::test]
    async fn test_simple_node() {
        let add_ten = simple_node(|mut state: TestState| async move {
            state.count += 10;
            Ok(state)
        });
        let config = Config::default();
        let result = add_ten
            .invoke(TestState { count: 5 }, &config)
            .await
            .unwrap();
        assert_eq!(result.count, 15);
    }

    #[tokio::test]
    async fn test_custom_node() {
        let config = Config::default();
        let result = CustomNode
            .invoke(TestState { count: 5 }, &config)
            .await
            .unwrap();
        assert_eq!(result.count, 10);
    }

    #[test]
    fn test_pregel_node_is_triggered() {
        let node = PregelNode::from_node(
            "test",
            strings(&["in"]),
            strings(&["trigger_a", "trigger_b"]),
            |state: TestState, _: &Config| async move { Ok(state) },
            vec![],
        );
        assert!(node.is_triggered(&strings(&["trigger_a"])));
        assert!(node.is_triggered(&strings(&["trigger_b"])));
        assert!(node.is_triggered(&strings(&["trigger_a", "other"])));
        assert!(!node.is_triggered(&strings(&["other"])));
        assert!(!node.is_triggered(&[]));
    }

    #[test]
    fn test_channel_write() {
        let write = ChannelWrite::new("output").with_skip_none(false);
        assert_eq!(write.channel, "output");
        assert!(!write.skip_none);
    }
}