use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
/// Countdown of execution steps still available, bounded above by `max`.
///
/// The live counter sits behind an `Arc<RwLock<_>>`, so cloning a
/// `RemainingSteps` yields a handle to the *same* counter rather than an
/// independent copy.
#[derive(Debug, Clone)]
pub struct RemainingSteps {
/// Current number of steps left; shared between all clones of this value.
current: Arc<RwLock<u32>>,
/// Ceiling for `current`: the initial value, the value restored by `reset`,
/// and the clamp applied by `set`.
max: u32,
}
impl RemainingSteps {
    /// Builds a counter that starts full, i.e. with `max` steps available.
    pub fn new(max: u32) -> Self {
        let current = Arc::new(RwLock::new(max));
        Self { current, max }
    }

    /// Returns how many steps are left right now.
    pub async fn current(&self) -> u32 {
        let guard = self.current.read().await;
        *guard
    }

    /// Returns the ceiling this counter was created with.
    pub fn max(&self) -> u32 {
        self.max
    }

    /// Consumes a single step and returns the new remaining count.
    ///
    /// The counter never goes below zero; calling this on an exhausted
    /// counter leaves it at `0`.
    pub async fn decrement(&self) -> u32 {
        self.decrement_by(1).await
    }

    /// Consumes `amount` steps, stopping at zero, and returns the new
    /// remaining count.
    pub async fn decrement_by(&self, amount: u32) -> u32 {
        let mut guard = self.current.write().await;
        let left = (*guard).saturating_sub(amount);
        *guard = left;
        left
    }

    /// Returns `true` once no steps are left.
    pub async fn is_exhausted(&self) -> bool {
        self.current().await == 0
    }

    /// Returns `true` when at least `n` steps are still available.
    pub async fn has_at_least(&self, n: u32) -> bool {
        self.current().await >= n
    }

    /// Restores the counter to its ceiling.
    pub async fn reset(&self) {
        *self.current.write().await = self.max;
    }

    /// Overwrites the counter with `value`, capped at the ceiling.
    pub async fn set(&self, value: u32) {
        let mut guard = self.current.write().await;
        // Values above `max` are silently clamped rather than rejected.
        *guard = if value > self.max { self.max } else { value };
    }
}
/// Serializable settings for a graph execution.
///
/// Only `max_steps` and `debug` are consumed in this part of the file; the
/// notes on the other fields are inferred from their names and should be
/// confirmed against the code that reads them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphConfig {
/// Step budget used to seed the `RemainingSteps` returned by `remaining_steps()`.
pub max_steps: u32,
/// Debug flag; surfaced through `RuntimeContext::is_debug`.
pub debug: bool,
/// Whether checkpointing is turned on.
pub checkpoint_enabled: bool,
/// Checkpoint cadence — presumably counted in steps; TODO confirm with the consumer.
pub checkpoint_interval: u32,
/// Timeout in milliseconds. The default is `0`.
/// NOTE(review): whether `0` means "no timeout" is not shown here — confirm.
pub timeout_ms: u64,
/// Upper bound on parallelism — presumably concurrently running nodes; TODO confirm.
pub max_parallelism: usize,
/// Free-form extra settings stored as JSON values keyed by name.
pub custom: HashMap<String, Value>,
}
impl Default for GraphConfig {
fn default() -> Self {
Self {
max_steps: 100,
debug: false,
checkpoint_enabled: false,
checkpoint_interval: 10,
timeout_ms: 0,
max_parallelism: 10,
custom: HashMap::new(),
}
}
}
impl GraphConfig {
    /// Creates a configuration populated with the [`Default`] values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the step budget and returns the updated configuration.
    #[must_use = "builder methods return the updated config instead of mutating in place"]
    pub fn with_max_steps(mut self, max_steps: u32) -> Self {
        self.max_steps = max_steps;
        self
    }

    /// Sets the debug flag and returns the updated configuration.
    #[must_use = "builder methods return the updated config instead of mutating in place"]
    pub fn with_debug(mut self, debug: bool) -> Self {
        self.debug = debug;
        self
    }

    /// Sets both checkpoint fields at once and returns the updated
    /// configuration.
    ///
    /// `interval` is stored even when `enabled` is `false`.
    #[must_use = "builder methods return the updated config instead of mutating in place"]
    pub fn with_checkpoints(mut self, enabled: bool, interval: u32) -> Self {
        self.checkpoint_enabled = enabled;
        self.checkpoint_interval = interval;
        self
    }

    /// Sets the timeout in milliseconds and returns the updated configuration.
    #[must_use = "builder methods return the updated config instead of mutating in place"]
    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
        self.timeout_ms = timeout_ms;
        self
    }

    /// Sets the parallelism limit and returns the updated configuration.
    #[must_use = "builder methods return the updated config instead of mutating in place"]
    pub fn with_max_parallelism(mut self, max: usize) -> Self {
        self.max_parallelism = max;
        self
    }

    /// Inserts (or replaces) a custom entry under `key` and returns the
    /// updated configuration.
    #[must_use = "builder methods return the updated config instead of mutating in place"]
    pub fn with_custom(mut self, key: impl Into<String>, value: Value) -> Self {
        self.custom.insert(key.into(), value);
        self
    }

    /// Creates a fresh, full [`RemainingSteps`] counter sized by `max_steps`.
    ///
    /// Every call returns a new, independent counter; it does not hand out a
    /// handle to a counter that is already in use.
    pub fn remaining_steps(&self) -> RemainingSteps {
        RemainingSteps::new(self.max_steps)
    }
}
/// Per-execution state for a graph run: identity, position, step budget,
/// configuration and free-form annotations.
#[derive(Debug)]
pub struct RuntimeContext {
/// Identifier of this execution; the constructors fill it with a fresh UUID v4 string.
pub execution_id: String,
/// Identifier of the graph being executed.
pub graph_id: String,
/// Id of the node currently being processed; starts out as an empty string.
pub current_node: Arc<RwLock<String>>,
/// Step budget for this execution.
pub remaining_steps: RemainingSteps,
/// Configuration this execution runs with.
pub config: GraphConfig,
/// Arbitrary JSON metadata attached to the execution.
pub metadata: HashMap<String, Value>,
/// Execution id of the parent run; `Some` only for sub-workflows.
pub parent_execution_id: Option<String>,
/// Free-form tags attached to the execution.
pub tags: Vec<String>,
}
impl RuntimeContext {
    /// Single place where a context is assembled, shared by every public
    /// constructor.
    ///
    /// The step budget is always derived from `config`, so
    /// `remaining_steps.max()` and `config.max_steps` cannot disagree at
    /// construction time.
    fn from_parts(
        graph_id: String,
        config: GraphConfig,
        parent_execution_id: Option<String>,
    ) -> Self {
        Self {
            execution_id: Uuid::new_v4().to_string(),
            graph_id,
            current_node: Arc::new(RwLock::new(String::new())),
            // Must be read before `config` is moved into the struct below.
            remaining_steps: config.remaining_steps(),
            config,
            metadata: HashMap::new(),
            parent_execution_id,
            tags: Vec::new(),
        }
    }

    /// Creates a top-level context for `graph_id` using the default
    /// [`GraphConfig`].
    pub fn new(graph_id: impl Into<String>) -> Self {
        Self::from_parts(graph_id.into(), GraphConfig::default(), None)
    }

    /// Creates a top-level context for `graph_id` using the supplied
    /// configuration; the step budget is taken from `config.max_steps`.
    pub fn with_config(graph_id: impl Into<String>, config: GraphConfig) -> Self {
        Self::from_parts(graph_id.into(), config, None)
    }

    /// Creates a context for a sub-workflow, recording the execution id of
    /// the parent run.
    ///
    /// The sub-workflow receives its own execution id and its own step
    /// budget derived from `config`.
    pub fn for_sub_workflow(
        graph_id: impl Into<String>,
        parent_execution_id: impl Into<String>,
        config: GraphConfig,
    ) -> Self {
        Self::from_parts(
            graph_id.into(),
            config,
            Some(parent_execution_id.into()),
        )
    }

    /// Returns a copy of the id of the node currently being processed.
    pub async fn current_node(&self) -> String {
        self.current_node.read().await.clone()
    }

    /// Records `node_id` as the node currently being processed.
    pub async fn set_current_node(&self, node_id: impl Into<String>) {
        let mut current = self.current_node.write().await;
        *current = node_id.into();
    }

    /// Returns `true` once the step budget has been used up.
    pub async fn is_recursion_limit_reached(&self) -> bool {
        self.remaining_steps.is_exhausted().await
    }

    /// Consumes one step from the budget and returns the number left.
    pub async fn decrement_steps(&self) -> u32 {
        self.remaining_steps.decrement().await
    }

    /// Inserts (or replaces) a metadata entry and returns the updated context.
    #[must_use = "builder methods return the updated context instead of mutating in place"]
    pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
        self.metadata.insert(key.into(), value);
        self
    }

    /// Appends a tag and returns the updated context.
    #[must_use = "builder methods return the updated context instead of mutating in place"]
    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
        self.tags.push(tag.into());
        self
    }

    /// Returns the debug flag from the configuration.
    pub fn is_debug(&self) -> bool {
        self.config.debug
    }

    /// Returns `true` when this context was created for a sub-workflow, i.e.
    /// a parent execution id is present.
    pub fn is_sub_workflow(&self) -> bool {
        self.parent_execution_id.is_some()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Happy path: read, decrement, bulk decrement and reset.
    #[tokio::test]
    async fn test_remaining_steps() {
        let steps = RemainingSteps::new(10);
        assert_eq!(steps.current().await, 10);
        assert_eq!(steps.max(), 10);
        assert!(!steps.is_exhausted().await);
        assert!(steps.has_at_least(5).await);
        assert!(steps.has_at_least(10).await);
        assert!(!steps.has_at_least(11).await);

        assert_eq!(steps.decrement().await, 9);
        assert_eq!(steps.current().await, 9);

        assert_eq!(steps.decrement_by(5).await, 4);
        assert_eq!(steps.current().await, 4);

        steps.reset().await;
        assert_eq!(steps.current().await, 10);
    }

    /// Decrementing past zero stays at zero.
    #[tokio::test]
    async fn test_remaining_steps_exhausted() {
        let steps = RemainingSteps::new(2);
        assert!(!steps.is_exhausted().await);
        steps.decrement().await;
        assert!(!steps.is_exhausted().await);
        steps.decrement().await;
        assert!(steps.is_exhausted().await);
        assert_eq!(steps.decrement().await, 0);
        assert!(steps.is_exhausted().await);
    }

    /// `decrement_by` saturates at zero and `set` clamps to the ceiling.
    #[tokio::test]
    async fn test_remaining_steps_saturates_and_clamps() {
        let steps = RemainingSteps::new(10);
        assert_eq!(steps.decrement_by(100).await, 0);
        assert!(steps.is_exhausted().await);
        assert!(steps.has_at_least(0).await);
        assert!(!steps.has_at_least(1).await);

        steps.set(50).await;
        assert_eq!(steps.current().await, 10);

        steps.set(3).await;
        assert_eq!(steps.current().await, 3);
        assert_eq!(steps.max(), 10);
    }

    /// Clones are handles to the same counter, not independent copies.
    #[tokio::test]
    async fn test_remaining_steps_clone_shares_counter() {
        let steps = RemainingSteps::new(5);
        let alias = steps.clone();
        alias.decrement().await;
        assert_eq!(steps.current().await, 4);
    }

    /// Pins the documented defaults.
    #[test]
    fn test_graph_config_default() {
        let config = GraphConfig::default();
        assert_eq!(config.max_steps, 100);
        assert!(!config.debug);
        assert!(!config.checkpoint_enabled);
        assert_eq!(config.checkpoint_interval, 10);
        assert_eq!(config.timeout_ms, 0);
        assert_eq!(config.max_parallelism, 10);
        assert!(config.custom.is_empty());
    }

    /// Every builder method writes the field it is named after.
    #[test]
    fn test_graph_config() {
        let config = GraphConfig::new()
            .with_max_steps(50)
            .with_debug(true)
            .with_checkpoints(true, 5)
            .with_timeout(30000)
            .with_max_parallelism(4)
            .with_custom("flag", serde_json::json!(1));
        assert_eq!(config.max_steps, 50);
        assert!(config.debug);
        assert!(config.checkpoint_enabled);
        assert_eq!(config.checkpoint_interval, 5);
        assert_eq!(config.timeout_ms, 30000);
        assert_eq!(config.max_parallelism, 4);
        assert_eq!(config.custom.get("flag"), Some(&serde_json::json!(1)));
        assert_eq!(config.remaining_steps().max(), 50);
    }

    /// Default-constructed context, including the metadata and tag builders.
    #[tokio::test]
    async fn test_runtime_context() {
        let ctx = RuntimeContext::new("test_graph")
            .with_metadata("key", serde_json::json!("value"))
            .with_tag("test");
        assert!(!ctx.execution_id.is_empty());
        assert_eq!(ctx.graph_id, "test_graph");
        assert!(ctx.current_node().await.is_empty());
        assert!(!ctx.is_sub_workflow());
        assert!(!ctx.is_debug());
        assert_eq!(ctx.parent_execution_id, None);
        assert_eq!(ctx.metadata.get("key"), Some(&serde_json::json!("value")));
        assert_eq!(ctx.tags, vec!["test".to_string()]);

        // The default context's budget must match the default configuration.
        assert_eq!(ctx.remaining_steps.max(), ctx.config.max_steps);
        assert_eq!(ctx.remaining_steps.current().await, ctx.config.max_steps);

        ctx.set_current_node("node_1").await;
        assert_eq!(ctx.current_node().await, "node_1");
    }

    /// A custom configuration drives the step budget and the debug flag.
    #[tokio::test]
    async fn test_runtime_context_with_config() {
        let config = GraphConfig::new().with_max_steps(2).with_debug(true);
        let ctx = RuntimeContext::with_config("cfg_graph", config);
        assert_eq!(ctx.graph_id, "cfg_graph");
        assert!(ctx.is_debug());
        assert!(!ctx.is_sub_workflow());
        assert_eq!(ctx.remaining_steps.max(), 2);

        assert!(!ctx.is_recursion_limit_reached().await);
        assert_eq!(ctx.decrement_steps().await, 1);
        assert_eq!(ctx.decrement_steps().await, 0);
        assert!(ctx.is_recursion_limit_reached().await);
    }

    /// Sub-workflow contexts remember their parent and get their own id.
    #[tokio::test]
    async fn test_runtime_context_sub_workflow() {
        let ctx = RuntimeContext::for_sub_workflow(
            "sub_graph",
            "parent-execution-123",
            GraphConfig::default(),
        );
        assert!(ctx.is_sub_workflow());
        assert_eq!(ctx.graph_id, "sub_graph");
        assert!(!ctx.execution_id.is_empty());
        assert_ne!(ctx.execution_id, "parent-execution-123");
        assert_eq!(
            ctx.parent_execution_id,
            Some("parent-execution-123".to_string())
        );
    }
}