// mofa_kernel/workflow/context.rs

//! Runtime context for workflow execution.
//!
//! Provides runtime information and configuration for workflow execution,
//! including recursion-limit tracking and execution metadata.

use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;

/// Remaining-steps tracker used to enforce a recursion limit.
///
/// Tracks the execution steps left before the limit is reached, to prevent
/// infinite loops. It is decremented during execution and can be checked by
/// nodes.
///
/// Cloning a `RemainingSteps` does **not** create an independent budget: the
/// count lives behind an [`Arc`], so every clone observes (and affects) the same
/// counter. Call [`RemainingSteps::new`] when an independent budget is needed.
///
/// The counter is a lock-free [`AtomicU32`]. The accessor methods remain `async`
/// only for compatibility with existing callers; they never suspend.
///
/// # Example
///
/// ```rust,ignore
/// let remaining = RemainingSteps::new(100);
///
/// // Check before proceeding
/// if remaining.is_exhausted().await {
///     return Err(AgentError::RecursionLimitExceeded);
/// }
///
/// // Decrement after each step
/// remaining.decrement().await;
/// ```
#[derive(Debug, Clone)]
pub struct RemainingSteps {
    /// Steps still available; shared between all clones.
    current: Arc<AtomicU32>,
    /// Starting value and upper bound for `current`.
    max: u32,
}

impl RemainingSteps {
    /// Create a tracker that starts with `max` remaining steps.
    pub fn new(max: u32) -> Self {
        Self {
            current: Arc::new(AtomicU32::new(max)),
            max,
        }
    }

    /// Number of steps currently remaining.
    pub async fn current(&self) -> u32 {
        self.current.load(Ordering::SeqCst)
    }

    /// Maximum number of steps this tracker was created with.
    pub fn max(&self) -> u32 {
        self.max
    }

    /// Consume one step, saturating at zero.
    ///
    /// Returns the number of steps remaining after the decrement.
    pub async fn decrement(&self) -> u32 {
        self.consume(1)
    }

    /// Consume `amount` steps, saturating at zero.
    ///
    /// Returns the number of steps remaining after the decrement.
    pub async fn decrement_by(&self, amount: u32) -> u32 {
        self.consume(amount)
    }

    /// `true` when no steps remain.
    pub async fn is_exhausted(&self) -> bool {
        self.current.load(Ordering::SeqCst) == 0
    }

    /// `true` when at least `n` steps remain.
    pub async fn has_at_least(&self, n: u32) -> bool {
        self.current.load(Ordering::SeqCst) >= n
    }

    /// Restore the counter to [`max`](Self::max).
    pub async fn reset(&self) {
        self.current.store(self.max, Ordering::SeqCst);
    }

    /// Set the remaining steps to `value`, clamped so it never exceeds
    /// [`max`](Self::max).
    pub async fn set(&self, value: u32) {
        self.current.store(value.min(self.max), Ordering::SeqCst);
    }

    /// Atomically subtract `amount` (saturating at zero) and return the new count.
    fn consume(&self, amount: u32) -> u32 {
        // `fetch_update` retries its compare-and-swap until it succeeds, so
        // concurrent decrements are never lost and the counter cannot wrap.
        // The closure always returns `Some`, so the `Err` arm is unreachable in
        // practice; both arms carry the previous value anyway.
        let previous = match self.current.fetch_update(
            Ordering::SeqCst,
            Ordering::SeqCst,
            |steps| Some(steps.saturating_sub(amount)),
        ) {
            Ok(prev) | Err(prev) => prev,
        };
        previous.saturating_sub(amount)
    }
}

95/// Graph execution configuration
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct GraphConfig {
98    /// Maximum recursion depth
99    pub max_steps: u32,
100
101    /// Enable debug mode
102    pub debug: bool,
103
104    /// Enable checkpointing
105    pub checkpoint_enabled: bool,
106
107    /// Checkpoint interval (in steps)
108    pub checkpoint_interval: u32,
109
110    /// Timeout in milliseconds (0 = no timeout)
111    pub timeout_ms: u64,
112
113    /// Maximum parallel branches
114    pub max_parallelism: usize,
115
116    /// Custom configuration values
117    pub custom: HashMap<String, Value>,
118}
119
120impl Default for GraphConfig {
121    fn default() -> Self {
122        Self {
123            max_steps: 100,
124            debug: false,
125            checkpoint_enabled: false,
126            checkpoint_interval: 10,
127            timeout_ms: 0,
128            max_parallelism: 10,
129            custom: HashMap::new(),
130        }
131    }
132}
133
134impl GraphConfig {
135    /// Create a new config with default values
136    pub fn new() -> Self {
137        Self::default()
138    }
139
140    /// Set maximum recursion depth
141    pub fn with_max_steps(mut self, max_steps: u32) -> Self {
142        self.max_steps = max_steps;
143        self
144    }
145
146    /// Enable debug mode
147    pub fn with_debug(mut self, debug: bool) -> Self {
148        self.debug = debug;
149        self
150    }
151
152    /// Enable checkpointing
153    pub fn with_checkpoints(mut self, enabled: bool, interval: u32) -> Self {
154        self.checkpoint_enabled = enabled;
155        self.checkpoint_interval = interval;
156        self
157    }
158
159    /// Set timeout
160    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
161        self.timeout_ms = timeout_ms;
162        self
163    }
164
165    /// Set maximum parallelism
166    pub fn with_max_parallelism(mut self, max: usize) -> Self {
167        self.max_parallelism = max;
168        self
169    }
170
171    /// Add a custom config value
172    pub fn with_custom(mut self, key: impl Into<String>, value: Value) -> Self {
173        self.custom.insert(key.into(), value);
174        self
175    }
176
177    /// Create RemainingSteps from this config
178    pub fn remaining_steps(&self) -> RemainingSteps {
179        RemainingSteps::new(self.max_steps)
180    }
181}
182
183/// Runtime context passed to node functions
184///
185/// Contains non-state information about the current execution,
186/// including execution ID, current node, remaining steps, and metadata.
187#[derive(Debug)]
188pub struct RuntimeContext {
189    /// Unique execution ID
190    pub execution_id: String,
191
192    /// Graph ID
193    pub graph_id: String,
194
195    /// Current node ID (updated during execution)
196    pub current_node: Arc<RwLock<String>>,
197
198    /// Remaining steps tracker
199    pub remaining_steps: RemainingSteps,
200
201    /// Graph configuration
202    pub config: GraphConfig,
203
204    /// Execution metadata
205    pub metadata: HashMap<String, Value>,
206
207    /// Parent execution ID (for sub-workflows)
208    pub parent_execution_id: Option<String>,
209
210    /// Execution tags
211    pub tags: Vec<String>,
212}
213
214impl RuntimeContext {
215    /// Create a new runtime context
216    pub fn new(graph_id: impl Into<String>) -> Self {
217        Self {
218            execution_id: Uuid::new_v4().to_string(),
219            graph_id: graph_id.into(),
220            current_node: Arc::new(RwLock::new(String::new())),
221            remaining_steps: RemainingSteps::new(100),
222            config: GraphConfig::default(),
223            metadata: HashMap::new(),
224            parent_execution_id: None,
225            tags: Vec::new(),
226        }
227    }
228
229    /// Create a context with a specific config
230    pub fn with_config(graph_id: impl Into<String>, config: GraphConfig) -> Self {
231        let remaining_steps = config.remaining_steps();
232        Self {
233            execution_id: Uuid::new_v4().to_string(),
234            graph_id: graph_id.into(),
235            current_node: Arc::new(RwLock::new(String::new())),
236            remaining_steps,
237            config,
238            metadata: HashMap::new(),
239            parent_execution_id: None,
240            tags: Vec::new(),
241        }
242    }
243
244    /// Create a context for a sub-workflow
245    pub fn for_sub_workflow(
246        graph_id: impl Into<String>,
247        parent_execution_id: impl Into<String>,
248        config: GraphConfig,
249    ) -> Self {
250        let remaining_steps = config.remaining_steps();
251        Self {
252            execution_id: Uuid::new_v4().to_string(),
253            graph_id: graph_id.into(),
254            current_node: Arc::new(RwLock::new(String::new())),
255            remaining_steps,
256            config,
257            metadata: HashMap::new(),
258            parent_execution_id: Some(parent_execution_id.into()),
259            tags: Vec::new(),
260        }
261    }
262
263    /// Get the current node ID
264    pub async fn current_node(&self) -> String {
265        self.current_node.read().await.clone()
266    }
267
268    /// Set the current node ID
269    pub async fn set_current_node(&self, node_id: impl Into<String>) {
270        let mut current = self.current_node.write().await;
271        *current = node_id.into();
272    }
273
274    /// Check if recursion limit is reached
275    pub async fn is_recursion_limit_reached(&self) -> bool {
276        self.remaining_steps.is_exhausted().await
277    }
278
279    /// Decrement remaining steps
280    pub async fn decrement_steps(&self) -> u32 {
281        self.remaining_steps.decrement().await
282    }
283
284    /// Add metadata
285    pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
286        self.metadata.insert(key.into(), value);
287        self
288    }
289
290    /// Add a tag
291    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
292        self.tags.push(tag.into());
293        self
294    }
295
296    /// Check if debug mode is enabled
297    pub fn is_debug(&self) -> bool {
298        self.config.debug
299    }
300
301    /// Check if this is a sub-workflow execution
302    pub fn is_sub_workflow(&self) -> bool {
303        self.parent_execution_id.is_some()
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_remaining_steps() {
        let steps = RemainingSteps::new(10);

        assert_eq!(steps.current().await, 10);
        assert_eq!(steps.max(), 10);
        assert!(!steps.is_exhausted().await);
        assert!(steps.has_at_least(5).await);

        steps.decrement().await;
        assert_eq!(steps.current().await, 9);

        steps.decrement_by(5).await;
        assert_eq!(steps.current().await, 4);

        steps.reset().await;
        assert_eq!(steps.current().await, 10);
    }

    #[tokio::test]
    async fn test_remaining_steps_exhausted() {
        let steps = RemainingSteps::new(2);

        assert!(!steps.is_exhausted().await);
        steps.decrement().await;
        assert!(!steps.is_exhausted().await);
        steps.decrement().await;
        assert!(steps.is_exhausted().await);

        // Should stay at 0
        steps.decrement().await;
        assert!(steps.is_exhausted().await);
    }

    #[tokio::test]
    async fn test_remaining_steps_set_and_saturation() {
        let steps = RemainingSteps::new(10);

        // `set` is clamped to `max`.
        steps.set(50).await;
        assert_eq!(steps.current().await, 10);

        steps.set(3).await;
        assert_eq!(steps.current().await, 3);
        assert!(steps.has_at_least(3).await);
        assert!(!steps.has_at_least(4).await);

        // `decrement_by` saturates instead of wrapping.
        assert_eq!(steps.decrement_by(100).await, 0);
        assert!(steps.is_exhausted().await);
    }

    #[tokio::test]
    async fn test_remaining_steps_clones_share_counter() {
        let steps = RemainingSteps::new(5);
        let clone = steps.clone();

        clone.decrement_by(2).await;
        assert_eq!(steps.current().await, 3);
    }

    #[test]
    fn test_graph_config() {
        let config = GraphConfig::new()
            .with_max_steps(50)
            .with_debug(true)
            .with_checkpoints(true, 5)
            .with_timeout(30000)
            .with_max_parallelism(4)
            .with_custom("k", serde_json::json!(1));

        assert_eq!(config.max_steps, 50);
        assert!(config.debug);
        assert!(config.checkpoint_enabled);
        assert_eq!(config.checkpoint_interval, 5);
        assert_eq!(config.timeout_ms, 30000);
        assert_eq!(config.max_parallelism, 4);
        assert_eq!(config.custom.get("k"), Some(&serde_json::json!(1)));
    }

    #[tokio::test]
    async fn test_runtime_context() {
        let ctx = RuntimeContext::new("test_graph")
            .with_metadata("key", serde_json::json!("value"))
            .with_tag("test");

        assert!(!ctx.execution_id.is_empty());
        assert_eq!(ctx.graph_id, "test_graph");
        assert!(ctx.current_node().await.is_empty());
        assert!(!ctx.is_sub_workflow());
        assert_eq!(ctx.metadata.get("key"), Some(&serde_json::json!("value")));
        assert_eq!(ctx.tags, vec!["test".to_string()]);

        ctx.set_current_node("node_1").await;
        assert_eq!(ctx.current_node().await, "node_1");
    }

    #[test]
    fn test_runtime_context_new_budget_matches_config() {
        let ctx = RuntimeContext::new("g");
        assert_eq!(ctx.remaining_steps.max(), ctx.config.max_steps);
    }

    #[tokio::test]
    async fn test_runtime_context_with_config() {
        let config = GraphConfig::new().with_max_steps(7).with_debug(true);
        let ctx = RuntimeContext::with_config("g", config);

        assert_eq!(ctx.remaining_steps.max(), 7);
        assert_eq!(ctx.remaining_steps.current().await, 7);
        assert!(ctx.is_debug());
        assert!(!ctx.is_recursion_limit_reached().await);
        assert_eq!(ctx.decrement_steps().await, 6);
    }

    #[test]
    fn test_execution_ids_are_unique() {
        let a = RuntimeContext::new("g");
        let b = RuntimeContext::new("g");
        assert_ne!(a.execution_id, b.execution_id);
    }

    #[tokio::test]
    async fn test_runtime_context_sub_workflow() {
        let ctx = RuntimeContext::for_sub_workflow(
            "sub_graph",
            "parent-execution-123",
            GraphConfig::default().with_max_steps(3),
        );

        assert!(ctx.is_sub_workflow());
        assert_eq!(ctx.graph_id, "sub_graph");
        assert_eq!(ctx.parent_execution_id, Some("parent-execution-123".to_string()));
        assert_eq!(ctx.remaining_steps.current().await, 3);
    }
}