use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use langgraph_checkpoint::config::RunnableConfig;
use crate::config;
use crate::constants::{CONFIG_KEY_SCRATCHPAD, CONFIG_KEY_CHECKPOINT_NS};
/// Durability setting for a graph run.
///
/// Serialized with lowercase variant names: `"sync"`, `"async"`, `"exit"`.
/// NOTE(review): the persistence semantics of each mode live in the executor, not in
/// this file — presumably this mirrors LangGraph's `durability` option; confirm there.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Durability {
    /// Serialized as `"sync"`.
    Sync,
    /// Serialized as `"async"`.
    Async,
    /// Serialized as `"exit"`.
    Exit,
}
/// Selects which kind of output a graph run streams.
///
/// Serialized with lowercase variant names (e.g. `StreamMode::Values` ⇄ `"values"`).
/// Derives `Hash`/`Eq` so modes can be collected into sets or used as map keys.
/// NOTE(review): what each mode emits is decided by the streaming code, not here.
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum StreamMode {
    /// `"values"`
    Values,
    /// `"updates"`
    Updates,
    /// `"checkpoints"`
    Checkpoints,
    /// `"tasks"`
    Tasks,
    /// `"debug"`
    Debug,
    /// `"messages"`
    Messages,
    /// `"custom"`
    Custom,
}
/// A single pending interrupt: the payload surfaced to the caller plus its identifier.
///
/// Built by [`interrupt`] when no resume value is available; in that path `id` is a
/// hash of the current checkpoint namespace.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Interrupt {
    /// Payload passed to [`interrupt`].
    pub value: JsonValue,
    /// Identifier of this interrupt.
    pub id: String,
}
/// A request to invoke node `node` with `arg` as its input
/// (presumably mirrors LangGraph's `Send`; note this type shadows `std::marker::Send`
/// within this module).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Send {
    /// Name of the target node.
    pub node: String,
    /// Argument handed to the target node.
    pub arg: JsonValue,
}

impl Send {
    /// Creates a `Send` targeting `node` with the given `arg`.
    pub fn new(node: impl Into<String>, arg: JsonValue) -> Self {
        let node: String = node.into();
        Send { node, arg }
    }
}
impl std::hash::Hash for Send {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.node.hash(state);
self.arg.to_string().hash(state);
}
}
/// Two `Send`s are equal when they target the same node with equal (structurally
/// compared) JSON arguments.
impl PartialEq for Send {
    fn eq(&self, other: &Self) -> bool {
        (&self.node, &self.arg) == (&other.node, &other.arg)
    }
}

/// `serde_json::Value` equality is total, so `Send` equality is an equivalence relation.
impl Eq for Send {}
/// A control-flow command: optional state update, resume value, target graph, and
/// navigation targets.
///
/// When serialized, `None` options and an empty `goto` are omitted. A missing `goto`
/// deserializes as an empty list.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Command {
    /// Graph the command is addressed to (see [`Command::PARENT`]).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub graph: Option<String>,
    /// Update payload to apply (shape defined by the consumer — TODO confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub update: Option<JsonValue>,
    /// Value used to resume an interrupted run.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub resume: Option<JsonValue>,
    /// Nodes or [`Send`]s to go to next.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub goto: Vec<CommandGoto>,
}
/// One navigation target of a [`Command`]: a node name or a [`Send`].
///
/// Untagged: deserialization tries the variants in declaration order. A JSON string
/// becomes `Node`; an object shaped like `Send` becomes `Send`.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum CommandGoto {
    /// Go to the node with this name.
    Node(String),
    /// Dispatch a [`Send`].
    Send(Send),
}
impl Command {
pub const PARENT: &'static str = "__parent__";
pub fn new() -> Self {
Self {
graph: None,
update: None,
resume: None,
goto: Vec::new(),
}
}
pub fn resume(value: JsonValue) -> Self {
Self {
resume: Some(value),
..Self::new()
}
}
pub fn goto(node: impl Into<String>) -> Self {
Self {
goto: vec![CommandGoto::Node(node.into())],
..Self::new()
}
}
}
impl Default for Command {
fn default() -> Self {
Self::new()
}
}
/// Retry configuration for a task.
///
/// NOTE(review): interval units are not established in this file (presumably seconds),
/// and how `jitter`/`backoff_factor` are applied is up to the retry loop — confirm there.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Delay before the first retry.
    pub initial_interval: f64,
    /// Multiplier applied to the interval between successive retries.
    pub backoff_factor: f64,
    /// Upper bound on the interval.
    pub max_interval: f64,
    /// Maximum number of attempts.
    pub max_attempts: usize,
    /// Whether to randomize intervals.
    pub jitter: bool,
}

/// Defaults: `initial_interval` 0.5, `backoff_factor` 2.0, `max_interval` 128.0,
/// `max_attempts` 3, `jitter` enabled.
impl Default for RetryPolicy {
    fn default() -> Self {
        RetryPolicy {
            max_attempts: 3,
            jitter: true,
            initial_interval: 0.5,
            backoff_factor: 2.0,
            max_interval: 128.0,
        }
    }
}
/// Caching configuration for a task; the default has no TTL (`ttl: None`).
///
/// NOTE(review): the unit of `ttl` is not established in this file — confirm with the cache.
#[derive(Debug, Clone, Default)]
pub struct CachePolicy {
    /// Optional time-to-live for cached entries.
    pub ttl: Option<i64>,
}
/// Wrapper marking `value` as an overwrite.
/// Presumably the value should replace channel state rather than be reduced into it,
/// as with LangGraph's `Overwrite`; confirm with the channel code.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Overwrite {
    /// The wrapped value.
    pub value: JsonValue,
}
/// Description of a task as reported outside the executor (e.g. in a [`StateSnapshot`]).
///
/// `error` and `result` are excluded from serde. They are never serialized and always
/// deserialize as `None`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PregelTask {
    /// Task identifier.
    pub id: String,
    /// Task (node) name.
    pub name: String,
    /// Path segments locating the task.
    pub path: Vec<String>,
    /// Error message, if the task failed (not serialized).
    #[serde(skip)]
    pub error: Option<String>,
    /// Interrupts raised by the task.
    pub interrupts: Vec<Interrupt>,
    /// Task output, if any (not serialized).
    #[serde(skip)]
    pub result: Option<JsonValue>,
}
/// A task ready to be executed.
///
/// It carries its input, the config it runs under, its accumulated `(channel, value)`
/// writes, and its retry policies.
/// NOTE(review): the meaning of the `writes` tuple elements is inferred from the field
/// type — confirm with the executor.
#[derive(Debug)]
pub struct PregelExecutableTask {
    /// Node name.
    pub name: String,
    /// Input passed to the node.
    pub input: JsonValue,
    /// Writes produced by the task.
    pub writes: Vec<(String, JsonValue)>,
    /// Config the task runs with.
    pub config: RunnableConfig,
    /// Names of whatever triggered this task.
    pub triggers: Vec<String>,
    /// Retry policies that apply to this task.
    pub retry_policy: Vec<RetryPolicy>,
    /// Task identifier.
    pub id: String,
    /// Path segments locating the task.
    pub path: Vec<String>,
}
/// Point-in-time view of a graph's state.
///
/// It holds the values, the names in `next` (presumably nodes to run next), checkpoint
/// config and metadata, pending tasks, and interrupts.
#[derive(Clone, Debug)]
pub struct StateSnapshot {
    /// State values.
    pub values: JsonValue,
    /// Names of what runs next.
    pub next: Vec<String>,
    /// Config identifying this snapshot.
    pub config: RunnableConfig,
    /// Checkpoint metadata, when available.
    pub metadata: Option<langgraph_checkpoint::CheckpointMetadata>,
    /// Creation timestamp as a string, when available (format not fixed here).
    pub created_at: Option<String>,
    /// Config of the parent snapshot, when available.
    pub parent_config: Option<RunnableConfig>,
    /// Tasks associated with this snapshot.
    pub tasks: Vec<PregelTask>,
    /// Interrupts associated with this snapshot.
    pub interrupts: Vec<Interrupt>,
}
/// Per-task scratch state, stored as JSON in the config under `CONFIG_KEY_SCRATCHPAD`.
///
/// [`interrupt`] reads it to decide whether a resume value is available.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PregelScratchpad {
    /// Step number this scratchpad belongs to.
    pub step: u64,
    /// Index handed out by the next call to [`PregelScratchpad::next_interrupt_id`].
    pub interrupt_counter: u64,
    /// Resume values, indexed by interrupt id.
    pub resume: Vec<JsonValue>,
    /// Whether the task is resuming.
    pub is_resuming: bool,
}

impl PregelScratchpad {
    /// Creates an empty scratchpad for `step`: counter 0, no resume values, not resuming.
    pub fn new(step: u64) -> Self {
        PregelScratchpad {
            step,
            resume: vec![],
            interrupt_counter: 0,
            is_resuming: false,
        }
    }

    /// Returns the current counter value, then advances the counter by one.
    pub fn next_interrupt_id(&mut self) -> u64 {
        let current = self.interrupt_counter;
        self.interrupt_counter = current + 1;
        current
    }
}
/// Error-like signal carrying one or more [`Interrupt`]s that paused a run.
#[derive(Clone, Debug)]
pub struct GraphInterrupt {
    /// The interrupts raised.
    pub interrupts: Vec<Interrupt>,
}

/// Renders as `GraphInterrupt: <n> interrupt(s)`.
impl std::fmt::Display for GraphInterrupt {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let count = self.interrupts.len();
        write!(f, "GraphInterrupt: {} interrupt(s)", count)
    }
}

impl std::error::Error for GraphInterrupt {}
/// Error returned by [`interrupt`]; a thin newtype over [`GraphInterrupt`].
#[derive(Clone, Debug)]
pub struct InterruptError(pub GraphInterrupt);

/// Displays exactly like the wrapped [`GraphInterrupt`].
impl std::fmt::Display for InterruptError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let InterruptError(inner) = self;
        write!(f, "{}", inner)
    }
}

impl std::error::Error for InterruptError {}

/// Wraps a [`GraphInterrupt`].
impl From<GraphInterrupt> for InterruptError {
    fn from(interrupt: GraphInterrupt) -> Self {
        Self(interrupt)
    }
}

/// Unwraps back to the inner [`GraphInterrupt`].
impl From<InterruptError> for GraphInterrupt {
    fn from(InterruptError(inner): InterruptError) -> GraphInterrupt {
        inner
    }
}
/// Pauses the current task with `value`, or returns a previously supplied resume value.
///
/// This reads the [`PregelScratchpad`] stored at `configurable[CONFIG_KEY_SCRATCHPAD]`
/// in the current config. If the scratchpad deserializes and its `resume` list has an
/// entry at the next interrupt index, that entry is returned as `Ok`.
///
/// # Errors
///
/// Returns an [`InterruptError`] holding a single [`Interrupt`] with `value` in any other
/// case. That covers a missing config entry, a scratchpad that fails to deserialize, or
/// no resume value at that index. The interrupt's `id` is derived from the checkpoint
/// namespace.
///
/// NOTE(review): the scratchpad is deserialized from a *clone* of the config value, and
/// the counter bumped by `next_interrupt_id` is never written back. As a result, every
/// call under the same config reads the same index: whatever `interrupt_counter` holds
/// in the stored scratchpad. A task calling `interrupt` more than once therefore gets the
/// same resume value each time. Fixing this needs a way to persist the scratchpad into
/// the config, which this file does not show.
pub fn interrupt(value: JsonValue) -> Result<JsonValue, InterruptError> {
    let config = config::get_config();

    let resumed = config
        .get("configurable")
        .and_then(|configurable| configurable.get(CONFIG_KEY_SCRATCHPAD))
        .and_then(|raw| serde_json::from_value::<PregelScratchpad>(raw.clone()).ok())
        .and_then(|mut scratchpad| {
            let index = scratchpad.next_interrupt_id() as usize;
            scratchpad.resume.into_iter().nth(index)
        });

    resumed.ok_or_else(|| {
        InterruptError(GraphInterrupt {
            interrupts: vec![Interrupt {
                value,
                id: uuid_from_config(&config),
            }],
        })
    })
}
fn uuid_from_config(config: &HashMap<String, JsonValue>) -> String {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let checkpoint_ns = config
.get("configurable")
.and_then(|c| c.get(CONFIG_KEY_CHECKPOINT_NS))
.and_then(|v| v.as_str())
.unwrap_or("");
let mut hasher = DefaultHasher::new();
checkpoint_ns.hash(&mut hasher);
let hash = hasher.finish();
format!("{:016x}", hash)
}