1use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11use uuid::Uuid;
12
13#[derive(Debug, Clone)]
32pub struct RemainingSteps {
33 current: Arc<RwLock<u32>>,
34 max: u32,
35}
36
37impl RemainingSteps {
38 pub fn new(max: u32) -> Self {
40 Self {
41 current: Arc::new(RwLock::new(max)),
42 max,
43 }
44 }
45
46 pub async fn current(&self) -> u32 {
48 *self.current.read().await
49 }
50
51 pub fn max(&self) -> u32 {
53 self.max
54 }
55
56 pub async fn decrement(&self) -> u32 {
58 let mut current = self.current.write().await;
59 if *current > 0 {
60 *current -= 1;
61 }
62 *current
63 }
64
65 pub async fn decrement_by(&self, amount: u32) -> u32 {
67 let mut current = self.current.write().await;
68 *current = current.saturating_sub(amount);
69 *current
70 }
71
72 pub async fn is_exhausted(&self) -> bool {
74 *self.current.read().await == 0
75 }
76
77 pub async fn has_at_least(&self, n: u32) -> bool {
79 *self.current.read().await >= n
80 }
81
82 pub async fn reset(&self) {
84 let mut current = self.current.write().await;
85 *current = self.max;
86 }
87
88 pub async fn set(&self, value: u32) {
90 let mut current = self.current.write().await;
91 *current = value.min(self.max);
92 }
93}
94
95#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct GraphConfig {
98 pub max_steps: u32,
100
101 pub debug: bool,
103
104 pub checkpoint_enabled: bool,
106
107 pub checkpoint_interval: u32,
109
110 pub timeout_ms: u64,
112
113 pub max_parallelism: usize,
115
116 pub custom: HashMap<String, Value>,
118}
119
120impl Default for GraphConfig {
121 fn default() -> Self {
122 Self {
123 max_steps: 100,
124 debug: false,
125 checkpoint_enabled: false,
126 checkpoint_interval: 10,
127 timeout_ms: 0,
128 max_parallelism: 10,
129 custom: HashMap::new(),
130 }
131 }
132}
133
134impl GraphConfig {
135 pub fn new() -> Self {
137 Self::default()
138 }
139
140 pub fn with_max_steps(mut self, max_steps: u32) -> Self {
142 self.max_steps = max_steps;
143 self
144 }
145
146 pub fn with_debug(mut self, debug: bool) -> Self {
148 self.debug = debug;
149 self
150 }
151
152 pub fn with_checkpoints(mut self, enabled: bool, interval: u32) -> Self {
154 self.checkpoint_enabled = enabled;
155 self.checkpoint_interval = interval;
156 self
157 }
158
159 pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
161 self.timeout_ms = timeout_ms;
162 self
163 }
164
165 pub fn with_max_parallelism(mut self, max: usize) -> Self {
167 self.max_parallelism = max;
168 self
169 }
170
171 pub fn with_custom(mut self, key: impl Into<String>, value: Value) -> Self {
173 self.custom.insert(key.into(), value);
174 self
175 }
176
177 pub fn remaining_steps(&self) -> RemainingSteps {
179 RemainingSteps::new(self.max_steps)
180 }
181}
182
183#[derive(Debug)]
188pub struct RuntimeContext {
189 pub execution_id: String,
191
192 pub graph_id: String,
194
195 pub current_node: Arc<RwLock<String>>,
197
198 pub remaining_steps: RemainingSteps,
200
201 pub config: GraphConfig,
203
204 pub metadata: HashMap<String, Value>,
206
207 pub parent_execution_id: Option<String>,
209
210 pub tags: Vec<String>,
212}
213
214impl RuntimeContext {
215 pub fn new(graph_id: impl Into<String>) -> Self {
217 Self {
218 execution_id: Uuid::new_v4().to_string(),
219 graph_id: graph_id.into(),
220 current_node: Arc::new(RwLock::new(String::new())),
221 remaining_steps: RemainingSteps::new(100),
222 config: GraphConfig::default(),
223 metadata: HashMap::new(),
224 parent_execution_id: None,
225 tags: Vec::new(),
226 }
227 }
228
229 pub fn with_config(graph_id: impl Into<String>, config: GraphConfig) -> Self {
231 let remaining_steps = config.remaining_steps();
232 Self {
233 execution_id: Uuid::new_v4().to_string(),
234 graph_id: graph_id.into(),
235 current_node: Arc::new(RwLock::new(String::new())),
236 remaining_steps,
237 config,
238 metadata: HashMap::new(),
239 parent_execution_id: None,
240 tags: Vec::new(),
241 }
242 }
243
244 pub fn for_sub_workflow(
246 graph_id: impl Into<String>,
247 parent_execution_id: impl Into<String>,
248 config: GraphConfig,
249 ) -> Self {
250 let remaining_steps = config.remaining_steps();
251 Self {
252 execution_id: Uuid::new_v4().to_string(),
253 graph_id: graph_id.into(),
254 current_node: Arc::new(RwLock::new(String::new())),
255 remaining_steps,
256 config,
257 metadata: HashMap::new(),
258 parent_execution_id: Some(parent_execution_id.into()),
259 tags: Vec::new(),
260 }
261 }
262
263 pub async fn current_node(&self) -> String {
265 self.current_node.read().await.clone()
266 }
267
268 pub async fn set_current_node(&self, node_id: impl Into<String>) {
270 let mut current = self.current_node.write().await;
271 *current = node_id.into();
272 }
273
274 pub async fn is_recursion_limit_reached(&self) -> bool {
276 self.remaining_steps.is_exhausted().await
277 }
278
279 pub async fn decrement_steps(&self) -> u32 {
281 self.remaining_steps.decrement().await
282 }
283
284 pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
286 self.metadata.insert(key.into(), value);
287 self
288 }
289
290 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
292 self.tags.push(tag.into());
293 self
294 }
295
296 pub fn is_debug(&self) -> bool {
298 self.config.debug
299 }
300
301 pub fn is_sub_workflow(&self) -> bool {
303 self.parent_execution_id.is_some()
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_remaining_steps() {
        let steps = RemainingSteps::new(10);

        assert_eq!(steps.current().await, 10);
        assert_eq!(steps.max(), 10);
        assert!(!steps.is_exhausted().await);
        assert!(steps.has_at_least(5).await);

        steps.decrement().await;
        assert_eq!(steps.current().await, 9);

        steps.decrement_by(5).await;
        assert_eq!(steps.current().await, 4);

        steps.reset().await;
        assert_eq!(steps.current().await, 10);
    }

    #[tokio::test]
    async fn test_remaining_steps_exhausted() {
        let steps = RemainingSteps::new(2);

        assert!(!steps.is_exhausted().await);
        steps.decrement().await;
        assert!(!steps.is_exhausted().await);
        steps.decrement().await;
        assert!(steps.is_exhausted().await);

        // Decrementing past zero must not underflow.
        steps.decrement().await;
        assert!(steps.is_exhausted().await);
    }

    #[tokio::test]
    async fn test_remaining_steps_set_and_saturation() {
        let steps = RemainingSteps::new(5);

        steps.set(3).await;
        assert_eq!(steps.current().await, 3);

        // `set` clamps to the configured maximum.
        steps.set(50).await;
        assert_eq!(steps.current().await, 5);

        // `decrement_by` saturates at zero and reports the new value.
        assert_eq!(steps.decrement_by(100).await, 0);
        assert!(steps.is_exhausted().await);
        assert!(!steps.has_at_least(1).await);

        // `decrement` on an exhausted budget stays at zero.
        assert_eq!(steps.decrement().await, 0);
    }

    #[tokio::test]
    async fn test_remaining_steps_clones_share_counter() {
        let steps = RemainingSteps::new(3);
        let handle = steps.clone();

        handle.decrement().await;
        assert_eq!(steps.current().await, 2);
    }

    #[test]
    fn test_graph_config() {
        let config = GraphConfig::new()
            .with_max_steps(50)
            .with_debug(true)
            .with_checkpoints(true, 5)
            .with_timeout(30000)
            .with_max_parallelism(4);

        assert_eq!(config.max_steps, 50);
        assert!(config.debug);
        assert!(config.checkpoint_enabled);
        assert_eq!(config.checkpoint_interval, 5);
        assert_eq!(config.timeout_ms, 30000);
        assert_eq!(config.max_parallelism, 4);
    }

    #[test]
    fn test_graph_config_defaults_and_custom() {
        let config = GraphConfig::default();
        assert_eq!(config.max_steps, 100);
        assert!(!config.debug);
        assert!(!config.checkpoint_enabled);
        assert_eq!(config.checkpoint_interval, 10);
        assert_eq!(config.timeout_ms, 0);
        assert_eq!(config.max_parallelism, 10);
        assert!(config.custom.is_empty());

        let config = config.with_custom("k", serde_json::json!(1));
        assert_eq!(config.custom.get("k"), Some(&serde_json::json!(1)));
    }

    #[tokio::test]
    async fn test_runtime_context() {
        let ctx = RuntimeContext::new("test_graph")
            .with_metadata("key", serde_json::json!("value"))
            .with_tag("test");

        assert!(!ctx.execution_id.is_empty());
        assert_eq!(ctx.graph_id, "test_graph");
        assert!(ctx.current_node().await.is_empty());
        assert!(!ctx.is_sub_workflow());
        assert_eq!(ctx.metadata.get("key"), Some(&serde_json::json!("value")));
        assert_eq!(ctx.tags, vec!["test".to_string()]);

        ctx.set_current_node("node_1").await;
        assert_eq!(ctx.current_node().await, "node_1");
    }

    #[test]
    fn test_runtime_context_default_budget_matches_config() {
        let ctx = RuntimeContext::new("g");
        assert_eq!(ctx.remaining_steps.max(), ctx.config.max_steps);
    }

    #[tokio::test]
    async fn test_runtime_context_uses_config_budget() {
        let ctx = RuntimeContext::with_config(
            "g",
            GraphConfig::new().with_max_steps(2).with_debug(true),
        );

        assert!(ctx.is_debug());
        assert_eq!(ctx.remaining_steps.max(), 2);
        assert!(!ctx.is_recursion_limit_reached().await);
        assert_eq!(ctx.decrement_steps().await, 1);
        assert_eq!(ctx.decrement_steps().await, 0);
        assert!(ctx.is_recursion_limit_reached().await);
    }

    #[tokio::test]
    async fn test_runtime_context_sub_workflow() {
        let ctx = RuntimeContext::for_sub_workflow(
            "sub_graph",
            "parent-execution-123",
            GraphConfig::default().with_max_steps(7),
        );

        assert!(ctx.is_sub_workflow());
        assert_eq!(ctx.graph_id, "sub_graph");
        assert_eq!(
            ctx.parent_execution_id,
            Some("parent-execution-123".to_string())
        );
        assert_eq!(ctx.remaining_steps.current().await, 7);
    }
}