use crate::channels::BaseChannel;
use crate::checkpoint::{BaseCheckpointSaver, Checkpoint, CheckpointMetadata, StateSnapshot};
use crate::config::Config;
use crate::errors::{Error, Result};
use crate::graph::START;
use crate::nodes::PregelNode;
use crate::state::State;
use crate::types::{StreamEvent, StreamMode};
use futures::stream::{Stream, StreamExt};
use std::collections::{HashMap, HashSet};
use std::pin::Pin;
use std::sync::Arc;
/// Executor for a graph of [`PregelNode`]s that communicate through named channels.
///
/// Each iteration of the run loop ("step") finds the nodes triggered by the
/// channels written in the previous step, runs them, and applies their outputs
/// back to the channels.
pub struct Pregel<S: State> {
// Nodes keyed by node name.
nodes: HashMap<String, PregelNode<S>>,
// Channels keyed by channel name; nodes read their input from and write their output to these.
channels: HashMap<String, Box<dyn BaseChannel>>,
// Optional persistence backend; when present, a checkpoint is stored after every step.
checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
// Node name scheduled at step 0 when no node is triggered by the written channels.
entry_point: String,
// Node names whose `<name>_output` channel is consulted when computing the final state.
finish_points: HashSet<String>,
// Source node name -> target node names; a source's output is also written to `<target>_input`.
edges: HashMap<String, Vec<String>>,
// Index of the step about to run (may be restored from checkpoint metadata).
current_step: usize,
// Step budget; initialised to 25 and overwritten from `Config` by `invoke`/`stream`.
recursion_limit: usize,
// Names of channels written during the most recent step; drives node triggering.
written_channels: HashSet<String>,
}
impl<S: State> Pregel<S> {
pub fn new(
nodes: HashMap<String, PregelNode<S>>,
channels: HashMap<String, Box<dyn BaseChannel>>,
checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
entry_point: String,
finish_points: HashSet<String>,
edges: HashMap<String, Vec<String>>,
) -> Self {
Self {
nodes,
channels,
checkpointer,
entry_point,
finish_points,
edges,
current_step: 0,
recursion_limit: 25,
written_channels: HashSet::new(),
}
}
pub fn with_recursion_limit(mut self, limit: usize) -> Self {
self.recursion_limit = limit;
self
}
pub async fn invoke(&mut self, input: S, config: Config) -> Result<S> {
self.recursion_limit = config.recursion_limit;
self.current_step = 0;
if let Some(checkpointer) = &self.checkpointer {
if let Some(tuple) = checkpointer.get_tuple(&config).await? {
self.restore_channels(&tuple.checkpoint)?;
self.current_step = tuple.metadata.step;
}
}
self.write_input_to_channels(&input)?;
loop {
if self.current_step >= self.recursion_limit {
return Err(Error::RecursionLimitError {
current: self.current_step,
limit: self.recursion_limit,
});
}
let triggered_nodes = self.find_triggered_nodes();
if triggered_nodes.is_empty() {
break; }
let mut tasks = Vec::new();
for node_name in &triggered_nodes {
if let Some(node) = self.nodes.get(node_name) {
let state = self.read_state_for_node(node)?;
let node_clone = node.clone();
let config_clone = config.clone();
let task = tokio::spawn(async move {
node_clone.bound.invoke(state, &config_clone).await
});
tasks.push((node_name.clone(), task));
}
}
let mut updates: HashMap<String, S> = HashMap::new();
for (node_name, task) in tasks {
match task.await {
Ok(Ok(result)) => {
updates.insert(node_name, result);
}
Ok(Err(e)) => return Err(e),
Err(e) => {
return Err(Error::execution(format!("Node execution panicked: {}", e)))
}
}
}
self.written_channels = self.apply_updates(updates)?;
if let Some(checkpointer) = &self.checkpointer {
let checkpoint = self.create_checkpoint(&config)?;
let metadata = CheckpointMetadata {
step: self.current_step,
source: "pregel".to_string(),
created_at: chrono::Utc::now(),
extra: HashMap::new(),
};
checkpointer.put(&checkpoint, &metadata, &config).await?;
}
self.current_step += 1;
}
self.get_final_state()
}
pub async fn stream(
&mut self,
input: S,
config: Config,
mode: StreamMode,
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + std::marker::Send>>> {
self.recursion_limit = config.recursion_limit;
self.current_step = 0;
let (tx, rx) = tokio::sync::mpsc::channel(100);
if let Some(checkpointer) = &self.checkpointer {
if let Some(tuple) = checkpointer.get_tuple(&config).await? {
self.restore_channels(&tuple.checkpoint)?;
self.current_step = tuple.metadata.step;
}
}
self.write_input_to_channels(&input)?;
let _nodes = self.nodes.clone();
let _channels: HashMap<String, Box<dyn BaseChannel>> = HashMap::new();
let _checkpointer = self.checkpointer.clone();
let _entry_point = self.entry_point.clone();
let recursion_limit = self.recursion_limit;
tokio::spawn(async move {
let mut step = 0;
loop {
if step >= recursion_limit {
let _ = tx.send(Err(Error::RecursionLimitError {
current: step,
limit: recursion_limit,
})).await;
break;
}
if matches!(mode, StreamMode::Values) {
let event = StreamEvent::Values {
ns: vec![],
data: serde_json::json!({"step": step}),
interrupts: vec![],
};
if tx.send(Ok(event)).await.is_err() {
break;
}
}
step += 1;
break; }
});
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
}
pub async fn get_state(&self, config: &Config) -> Result<Option<StateSnapshot<S>>> {
if let Some(checkpointer) = &self.checkpointer {
if let Some(tuple) = checkpointer.get_tuple(config).await? {
let state = self.state_from_checkpoint(&tuple.checkpoint)?;
return Ok(Some(StateSnapshot {
state,
checkpoint: tuple.checkpoint,
metadata: tuple.metadata,
config: tuple.config,
}));
}
}
Ok(None)
}
pub async fn get_state_history(
&self,
config: &Config,
limit: Option<usize>,
) -> Result<Vec<StateSnapshot<S>>> {
if let Some(checkpointer) = &self.checkpointer {
let tuples = checkpointer.list(config, limit).await?;
let mut snapshots = Vec::new();
for tuple in tuples {
let state = self.state_from_checkpoint(&tuple.checkpoint)?;
snapshots.push(StateSnapshot {
state,
checkpoint: tuple.checkpoint,
metadata: tuple.metadata,
config: tuple.config,
});
}
return Ok(snapshots);
}
Ok(Vec::new())
}
fn write_input_to_channels(&mut self, input: &S) -> Result<()> {
let value = input.to_value()?;
if let Some(channel) = self.channels.get_mut("__start__") {
channel.update(vec![value])?;
self.written_channels.insert("__start__".to_string());
}
Ok(())
}
fn find_triggered_nodes(&self) -> Vec<String> {
let mut triggered = Vec::new();
for (name, node) in &self.nodes {
if node.is_triggered(&self.written_channels.iter().cloned().collect::<Vec<_>>()) {
triggered.push(name.clone());
}
}
if triggered.is_empty() && self.current_step == 0 {
triggered.push(self.entry_point.clone());
}
triggered
}
fn read_state_for_node(&self, node: &PregelNode<S>) -> Result<S> {
let mut merged: Option<S> = None;
for ch_name in &node.channels {
if let Some(channel) = self.channels.get(ch_name) {
if let Some(value) = channel.get()? {
let piece = S::from_value(value)?;
merged = match merged {
None => Some(piece),
Some(mut m) => {
m.merge(piece)?;
Some(m)
}
};
}
}
}
if merged.is_none() && node.triggers.iter().any(|t| t == START) {
if let Some(channel) = self.channels.get(START) {
if let Some(value) = channel.get()? {
merged = Some(S::from_value(value)?);
}
}
}
merged.ok_or_else(|| {
Error::state(format!(
"Cannot construct state for node '{}' (input channels {:?})",
node.name, node.channels
))
})
}
fn apply_updates(&mut self, updates: HashMap<String, S>) -> Result<HashSet<String>> {
let mut next_triggers = HashSet::new();
for (node_name, state) in updates {
let value = state.to_value()?;
if let Some(node) = self.nodes.get(&node_name) {
for writer in &node.writers {
if let Some(channel) = self.channels.get_mut(&writer.channel) {
channel.update(vec![value.clone()])?;
next_triggers.insert(writer.channel.clone());
}
}
}
if let Some(targets) = self.edges.get(&node_name) {
for target in targets {
let input_ch = format!("{}_input", target);
if let Some(ch) = self.channels.get_mut(&input_ch) {
ch.update(vec![value.clone()])?;
}
}
}
}
Ok(next_triggers)
}
fn create_checkpoint(&self, config: &Config) -> Result<Checkpoint> {
let mut checkpoint = Checkpoint::new();
if let Some(thread_id) = &config.thread_id {
checkpoint.thread_id = Some(thread_id.clone());
}
for (name, channel) in &self.channels {
let channel_data = channel.checkpoint()?;
checkpoint.set_channel(name, channel_data);
}
Ok(checkpoint)
}
fn restore_channels(&mut self, checkpoint: &Checkpoint) -> Result<()> {
for (name, value) in &checkpoint.channel_values {
if let Some(channel) = self.channels.get_mut(name) {
channel.update(vec![value.clone()])?;
}
}
Ok(())
}
fn state_from_checkpoint(&self, checkpoint: &Checkpoint) -> Result<S> {
if let Some(value) = checkpoint.get_channel("__state__") {
return S::from_value(value.clone());
}
if let Some(value) = checkpoint.get_channel("__start__") {
return S::from_value(value.clone());
}
Err(Error::checkpoint("Cannot construct state from checkpoint"))
}
fn get_final_state(&self) -> Result<S> {
if let Some(channel) = self.channels.get(crate::graph::END) {
if let Some(value) = channel.get()? {
return S::from_value(value);
}
}
for fp in &self.finish_points {
let ch_name = format!("{}_output", fp);
if let Some(channel) = self.channels.get(&ch_name) {
if let Some(value) = channel.get()? {
return S::from_value(value);
}
}
}
if let Some(channel) = self.channels.get(START) {
if let Some(value) = channel.get()? {
return S::from_value(value);
}
}
Err(Error::state("Cannot determine final state"))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channels::LastValue;
    use crate::nodes::{ChannelWrite, PregelNode};
    use crate::state::State as StateTrait;
    use serde::{Deserialize, Serialize};

    /// Minimal state used by the tests: a single counter.
    #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
    struct TestState {
        count: i32,
    }

    impl StateTrait for TestState {
        /// Merging adds the other counter onto this one.
        fn merge(&mut self, other: Self) -> Result<()> {
            self.count += other.count;
            Ok(())
        }
    }

    /// Creates one empty `LastValue<TestState>` channel per given name.
    fn last_value_channels(names: &[&str]) -> HashMap<String, Box<dyn BaseChannel>> {
        names
            .iter()
            .map(|name| {
                let channel: Box<dyn BaseChannel> = Box::new(LastValue::<TestState>::new());
                ((*name).to_string(), channel)
            })
            .collect()
    }

    /// A single node reading `__start__` and writing `__end__` increments the counter once.
    #[tokio::test]
    async fn test_pregel_basic() {
        let node = PregelNode::from_node(
            "increment",
            vec!["__start__".to_string()],
            vec!["__start__".to_string()],
            |mut state: TestState, _config: &Config| async move {
                state.count += 1;
                Ok(state)
            },
            vec![ChannelWrite::new("__end__")],
        );

        let mut graph = Pregel::new(
            HashMap::from([("increment".to_string(), node)]),
            last_value_channels(&["__start__", "__end__"]),
            None,
            "increment".to_string(),
            HashSet::from(["increment".to_string()]),
            HashMap::new(),
        );

        let outcome = graph
            .invoke(TestState { count: 0 }, Config::default())
            .await
            .unwrap();
        assert_eq!(outcome.count, 1);
    }

    /// Node `a` (+1) feeds node `b` (*10) through the `a -> b` edge: 5 -> 6 -> 60.
    #[tokio::test]
    async fn test_pregel_two_node_chain() {
        let first = PregelNode::from_node(
            "a",
            vec!["a_input".to_string()],
            vec![START.to_string()],
            |mut state: TestState, _config: &Config| async move {
                state.count += 1;
                Ok(state)
            },
            vec![ChannelWrite::new("a_output")],
        );
        let second = PregelNode::from_node(
            "b",
            vec!["b_input".to_string()],
            vec!["a_output".to_string()],
            |mut state: TestState, _config: &Config| async move {
                state.count *= 10;
                Ok(state)
            },
            vec![ChannelWrite::new("b_output")],
        );

        let node_map = HashMap::from([("a".to_string(), first), ("b".to_string(), second)]);
        let channel_map =
            last_value_channels(&[START, "a_input", "a_output", "b_input", "b_output", "__end__"]);
        let edge_map = HashMap::from([("a".to_string(), vec!["b".to_string()])]);

        let mut graph = Pregel::new(
            node_map,
            channel_map,
            None,
            "a".to_string(),
            HashSet::from(["b".to_string()]),
            edge_map,
        );

        let outcome = graph
            .invoke(TestState { count: 5 }, Config::default())
            .await
            .unwrap();
        assert_eq!(outcome.count, 60);
    }
}