use std::collections::{HashMap, HashSet};
use serde_json::Value as JsonValue;
use crate::channels::Channel;
use crate::constants::{RESUME, NULL_TASK_ID, CONFIG_KEY_SCRATCHPAD, RESERVED};
use crate::types::PregelScratchpad;
use super::{PregelNode, PregelExecutableTask, ChannelVersions, TriggerToNodes};
/// Interprets a JSON value as a floating-point number when possible.
///
/// JSON numbers are converted via `Number::as_f64`; JSON strings are parsed
/// as `f64`. Any other kind of value, or a string that does not parse,
/// yields `None`.
fn as_f64(v: &JsonValue) -> Option<f64> {
    if let JsonValue::Number(number) = v {
        return number.as_f64();
    }
    if let JsonValue::String(text) = v {
        return text.parse::<f64>().ok();
    }
    None
}
/// Returns `true` when version `a` is strictly greater than version `b`.
///
/// Comparison rules, in order:
/// 1. If both values can be read as `f64` (see `as_f64`), compare numerically.
/// 2. Otherwise, if both values are strings or numbers, compare their raw
///    textual forms lexicographically.
/// 3. Any other combination (null, bool, array, object on either side) is
///    not comparable and yields `false`.
fn version_gt(a: &JsonValue, b: &JsonValue) -> bool {
    /// Raw text of a string or number version; `None` for every other JSON kind.
    fn text(v: &JsonValue) -> Option<std::borrow::Cow<'_, str>> {
        match v {
            JsonValue::String(s) => Some(std::borrow::Cow::Borrowed(s.as_str())),
            JsonValue::Number(n) => Some(std::borrow::Cow::Owned(n.to_string())),
            _ => None,
        }
    }

    if let (Some(an), Some(bn)) = (as_f64(a), as_f64(b)) {
        return an > bn;
    }

    // Compare the unquoted text on both sides. Going through
    // `JsonValue::to_string()` would serialise string operands with their
    // surrounding quotes, so the quote character would decide the result
    // instead of the version text itself.
    match (text(a), text(b)) {
        (Some(lhs), Some(rhs)) => lhs > rhs,
        _ => false,
    }
}
/// Builds the list of tasks that should run in the next step.
///
/// # Candidate selection
/// When `updated_channels` is provided, only nodes listed in
/// `trigger_to_nodes` for one of those channels are considered; otherwise
/// every node in `nodes` is a candidate. Candidates are processed in sorted
/// name order so the returned task list is deterministic.
///
/// # Trigger rule
/// * If `versions_seen` has an entry for the node, the node runs when at
///   least one of its trigger channels is available **and** that channel's
///   current version in `channel_versions` is greater than the version the
///   node last saw (missing versions are treated as the empty string).
/// * If the node has no `versions_seen` entry, it runs when at least one of
///   its trigger channels is available.
///
/// # Produced tasks
/// Each task gets the id `"{checkpoint_id}:{step:04}:PULL:{name}"`, the input
/// gathered from the node's channels, and a copy of `config` whose
/// `"configurable"` object carries the serialised scratchpad under
/// `CONFIG_KEY_SCRATCHPAD`. Resume values are taken from `pending_writes`
/// entries on the `RESUME` channel: those addressed to the task id, and the
/// first one addressed to `NULL_TASK_ID`.
///
/// `versions_seen` is only read here; it is not modified.
pub fn prepare_next_tasks(
    nodes: &HashMap<String, PregelNode>,
    channels: &HashMap<String, Box<dyn Channel>>,
    config: &langgraph_checkpoint::config::RunnableConfig,
    step: u64,
    versions_seen: &mut HashMap<String, HashMap<String, JsonValue>>,
    trigger_to_nodes: &TriggerToNodes,
    updated_channels: Option<&HashSet<String>>,
    checkpoint_id: &str,
    pending_writes: &[(String, String, JsonValue)],
    channel_versions: &ChannelVersions,
) -> Vec<PregelExecutableTask> {
    // Stand-in for "no version recorded".
    let null_version = JsonValue::String(String::new());

    let mut candidates: Vec<String> = if let Some(updated) = updated_channels {
        let mut candidate_set = HashSet::new();
        for chan in updated {
            if let Some(node_names) = trigger_to_nodes.get(chan) {
                candidate_set.extend(node_names.iter().cloned());
            }
        }
        candidate_set.into_iter().collect()
    } else {
        nodes.keys().cloned().collect()
    };
    // Hash-based collections iterate in arbitrary order; sorting makes the
    // order of the returned tasks reproducible from run to run.
    candidates.sort_unstable();

    // First resume value that is not tied to a specific task, if any.
    let null_resume: Option<&JsonValue> = pending_writes
        .iter()
        .find_map(|(tid, chan, val)| (tid == NULL_TASK_ID && chan == RESUME).then_some(val));

    let mut tasks = Vec::with_capacity(candidates.len());
    for name in candidates {
        let Some(node) = nodes.get(&name) else {
            continue;
        };

        let should_trigger = match versions_seen.get(&name) {
            Some(seen) => node.triggers.iter().any(|chan| {
                let chan_available = channels.get(chan).is_some_and(|c| c.is_available());
                let chan_version = channel_versions.get(chan).unwrap_or(&null_version);
                let last_seen = seen.get(chan).unwrap_or(&null_version);
                chan_available && version_gt(chan_version, last_seen)
            }),
            // The node has never recorded any version: availability alone decides.
            None => node
                .triggers
                .iter()
                .any(|chan| channels.get(chan).is_some_and(|c| c.is_available())),
        };
        if !should_trigger {
            continue;
        }

        let input = gather_input(node, channels);
        let task_id = format!("{}:{:04}:PULL:{}", checkpoint_id, step, name);

        // Resume values addressed to this exact task.
        let task_resume: Vec<JsonValue> = pending_writes
            .iter()
            .filter(|(tid, chan, _)| tid == &task_id && chan == RESUME)
            .map(|(_, _, val)| val.clone())
            .collect();
        let scratchpad = create_scratchpad(null_resume, &task_resume, step);

        let mut task_config = config.clone();
        let configurable = task_config
            .entry("configurable".to_string())
            .or_insert_with(|| JsonValue::Object(serde_json::Map::new()));
        // If "configurable" exists but is not a JSON object, the scratchpad is
        // not attached. A scratchpad that fails to serialise is stored as the
        // default JSON value.
        if let Some(conf_obj) = configurable.as_object_mut() {
            conf_obj.insert(
                CONFIG_KEY_SCRATCHPAD.to_string(),
                serde_json::to_value(&scratchpad).unwrap_or_default(),
            );
        }

        tasks.push(PregelExecutableTask {
            name,
            input,
            proc: node.bound.clone(),
            writes: Vec::new(),
            config: task_config,
            triggers: node.triggers.clone(),
            id: task_id,
        });
    }
    tasks
}
/// Collects the current values of a node's input channels into a JSON object.
///
/// Each name in `node.channels` is looked up in `channels`. Names with no
/// matching channel, and channels whose `get()` returns an error, are left
/// out of the result. The object is keyed by channel name.
fn gather_input(
    node: &PregelNode,
    channels: &HashMap<String, Box<dyn Channel>>,
) -> JsonValue {
    let mut snapshot = serde_json::Map::new();
    for key in &node.channels {
        let current = channels.get(key).and_then(|source| source.get().ok());
        if let Some(value) = current {
            snapshot.insert(key.clone(), value);
        }
    }
    JsonValue::Object(snapshot)
}
/// Creates the scratchpad handed to a task for the given step.
///
/// * `resume` holds the task-specific resume values; when there are none and
///   a `null_resume` value exists, it holds just that value.
/// * `is_resuming` is `true` if either source provided a value.
/// * `interrupt_counter` always starts at `0`.
fn create_scratchpad(
    null_resume: Option<&JsonValue>,
    task_resume: &[JsonValue],
    step: u64,
) -> PregelScratchpad {
    let has_task_values = !task_resume.is_empty();
    let resume = match (has_task_values, null_resume) {
        // Only fall back to the task-independent value when nothing is
        // addressed to the task itself.
        (false, Some(fallback)) => vec![fallback.clone()],
        _ => task_resume.to_vec(),
    };
    PregelScratchpad {
        step,
        interrupt_counter: 0,
        resume,
        is_resuming: null_resume.is_some() || has_task_values,
    }
}
/// Applies the writes produced by `tasks` to `channels` and returns the names
/// of channels that were updated and are available afterwards.
///
/// Steps, in order:
/// 1. For every task, copy the current version of each of its trigger
///    channels into `versions_seen[task.name]`.
/// 2. Compute the next version by passing the largest current channel
///    version (per `version_gt_partial`) to `get_next_version`.
/// 3. Call `consume()` on every non-reserved trigger channel; bump its
///    version when it reports a change.
/// 4. Group task writes by channel (skipping `RESERVED` names) and call
///    `update` on each target channel; on a reported change, bump the
///    version and, if the channel is available, record it as updated.
/// 5. If any task had triggers, call `update(&[])` on every available channel
///    that was not updated in step 4, with the same bookkeeping.
/// 6. If any task had triggers and no updated channel appears as a key in
///    `trigger_to_nodes`, call `finish()` on every channel, with the same
///    bookkeeping.
///
/// Errors returned by `Channel::update` are treated as "no change".
pub fn apply_writes(
    channels: &mut HashMap<String, Box<dyn Channel>>,
    tasks: &[PregelExecutableTask],
    versions_seen: &mut HashMap<String, HashMap<String, JsonValue>>,
    channel_versions: &mut ChannelVersions,
    trigger_to_nodes: &TriggerToNodes,
    get_next_version: impl Fn(Option<&JsonValue>) -> JsonValue,
) -> HashSet<String> {
    let mut changed: HashSet<String> = HashSet::new();
    let any_triggered = tasks.iter().any(|task| !task.triggers.is_empty());

    // Step 1: remember which versions each task has now observed.
    for task in tasks {
        let observed = versions_seen.entry(task.name.clone()).or_default();
        observed.extend(task.triggers.iter().filter_map(|trigger| {
            channel_versions
                .get(trigger)
                .map(|version| (trigger.clone(), version.clone()))
        }));
    }

    // Step 2: derive the version stamped on everything that changes below.
    let highest = channel_versions
        .values()
        .max_by(|lhs, rhs| version_gt_partial(lhs, rhs))
        .cloned();
    let next_version = get_next_version(highest.as_ref());

    // Step 3: consume the channels that triggered the tasks.
    let triggered: HashSet<String> = tasks
        .iter()
        .flat_map(|task| task.triggers.iter().cloned())
        .collect();
    for name in &triggered {
        if RESERVED.contains(&name.as_str()) {
            continue;
        }
        let Some(channel) = channels.get(name.as_str()) else {
            continue;
        };
        if channel.consume() {
            channel_versions.insert(name.clone(), next_version.clone());
        }
    }

    // Step 4: group the pending writes per channel and apply them.
    let mut grouped: HashMap<String, Vec<JsonValue>> = HashMap::new();
    for (target, value) in tasks.iter().flat_map(|task| task.writes.iter()) {
        if RESERVED.contains(&target.as_str()) {
            continue;
        }
        grouped.entry(target.clone()).or_default().push(value.clone());
    }
    for (name, values) in &grouped {
        let Some(channel) = channels.get(name.as_str()) else {
            continue;
        };
        if !channel.update(values).unwrap_or(false) {
            continue;
        }
        channel_versions.insert(name.clone(), next_version.clone());
        if channel.is_available() {
            changed.insert(name.clone());
        }
    }

    // Step 5: give untouched, available channels an empty update.
    if any_triggered {
        for (name, channel) in channels.iter() {
            if !channel.is_available() || changed.contains(name) {
                continue;
            }
            if !channel.update(&[]).unwrap_or(false) {
                continue;
            }
            channel_versions.insert(name.clone(), next_version.clone());
            if channel.is_available() {
                changed.insert(name.clone());
            }
        }
    }

    // Step 6: nothing updated would trigger a node, so finish every channel.
    if any_triggered && !changed.iter().any(|name| trigger_to_nodes.contains_key(name)) {
        for (name, channel) in channels.iter() {
            if !channel.finish() {
                continue;
            }
            channel_versions.insert(name.clone(), next_version.clone());
            if channel.is_available() {
                changed.insert(name.clone());
            }
        }
    }

    changed
}
/// Orders two channel versions; used as the comparator when looking for the
/// largest version.
///
/// Comparison rules, in order:
/// 1. If both values can be read as `f64` (see `as_f64`), compare
///    numerically; values that cannot be ordered (NaN) count as equal.
/// 2. Otherwise, if both values are strings or numbers, compare their raw
///    textual forms lexicographically.
/// 3. Any other combination (null, bool, array, object on either side)
///    counts as equal.
///
/// The result is symmetric: swapping the arguments reverses the ordering.
fn version_gt_partial(a: &JsonValue, b: &JsonValue) -> std::cmp::Ordering {
    /// Raw text of a string or number version; `None` for every other JSON kind.
    fn text(v: &JsonValue) -> Option<std::borrow::Cow<'_, str>> {
        match v {
            JsonValue::String(s) => Some(std::borrow::Cow::Borrowed(s.as_str())),
            JsonValue::Number(n) => Some(std::borrow::Cow::Owned(n.to_string())),
            _ => None,
        }
    }

    if let (Some(an), Some(bn)) = (as_f64(a), as_f64(b)) {
        return an.partial_cmp(&bn).unwrap_or(std::cmp::Ordering::Equal);
    }

    // Both sides are compared on their unquoted text. Serialising a string
    // operand with `JsonValue::to_string()` would include its quotes and make
    // (number, string) and (string, number) disagree with each other.
    match (text(a), text(b)) {
        (Some(lhs), Some(rhs)) => lhs.cmp(&rhs),
        _ => std::cmp::Ordering::Equal,
    }
}
/// Reports whether any of the given task names is configured as an
/// interrupt node.
///
/// Returns `false` when `task_names` is empty or shares no name with
/// `interrupt_nodes`.
pub fn should_interrupt(
    interrupt_nodes: &HashSet<String>,
    task_names: &[String],
) -> bool {
    for candidate in task_names {
        if interrupt_nodes.contains(candidate) {
            return true;
        }
    }
    false
}