rrag_graph/
state.rs

//! # Graph State Management
//!
//! This module provides the state management system for RGraph workflows.
//! State flows through graph execution, accumulating results along the way
//! and providing context for decision-making.
6
7use crate::{RGraphError, RGraphResult};
8use parking_lot::RwLock;
9use std::collections::HashMap;
10use std::sync::Arc;
11
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14
/// Dot-separated address of a value inside the graph state.
///
/// A path such as `"parent.child"` names the entry `child` of the object
/// stored under `parent`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct StatePath(pub String);

impl StatePath {
    /// Build a path from anything convertible into a `String`.
    pub fn new(path: impl Into<String>) -> Self {
        StatePath(path.into())
    }

    /// Build the path `"<parent>.<child>"`.
    pub fn nested(parent: impl Into<String>, child: impl Into<String>) -> Self {
        let mut joined: String = parent.into();
        joined.push('.');
        joined.push_str(&child.into());
        StatePath(joined)
    }

    /// Borrow the raw path text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Return the segments between `.` separators, in order.
    ///
    /// An empty path yields a single empty segment.
    pub fn components(&self) -> Vec<&str> {
        let segments = self.as_str().split('.');
        segments.collect()
    }
}
41
42impl From<String> for StatePath {
43    fn from(path: String) -> Self {
44        StatePath(path)
45    }
46}
47
48impl From<&str> for StatePath {
49    fn from(path: &str) -> Self {
50        StatePath(path.to_string())
51    }
52}
53
/// A dynamically typed value held in the graph state.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum StateValue {
    /// UTF-8 text
    String(String),
    /// Signed 64-bit integer
    Integer(i64),
    /// 64-bit floating point number
    Float(f64),
    /// `true` / `false`
    Boolean(bool),
    /// Ordered list of values
    Array(Vec<StateValue>),
    /// String-keyed map of values
    Object(HashMap<String, StateValue>),
    /// Absence of a value
    Null,
    /// Raw binary data
    Bytes(Vec<u8>),
}

impl StateValue {
    /// Borrow the text if this is a `String` value.
    pub fn as_string(&self) -> Option<&str> {
        if let Self::String(text) = self {
            Some(text.as_str())
        } else {
            None
        }
    }

    /// Read the value as an integer.
    ///
    /// `Float` values are converted with an `as` cast: the fractional part is
    /// dropped, out-of-range values saturate and NaN becomes `0`.
    pub fn as_integer(&self) -> Option<i64> {
        match *self {
            Self::Integer(n) => Some(n),
            Self::Float(x) => Some(x as i64),
            _ => None,
        }
    }

    /// Read the value as a float; `Integer` values are widened with an `as` cast.
    pub fn as_float(&self) -> Option<f64> {
        match *self {
            Self::Float(x) => Some(x),
            Self::Integer(n) => Some(n as f64),
            _ => None,
        }
    }

    /// Read the flag if this is a `Boolean` value.
    pub fn as_boolean(&self) -> Option<bool> {
        if let Self::Boolean(flag) = *self {
            Some(flag)
        } else {
            None
        }
    }

    /// Borrow the elements if this is an `Array` value.
    pub fn as_array(&self) -> Option<&Vec<StateValue>> {
        if let Self::Array(items) = self {
            Some(items)
        } else {
            None
        }
    }

    /// Borrow the map if this is an `Object` value.
    pub fn as_object(&self) -> Option<&HashMap<String, StateValue>> {
        if let Self::Object(fields) = self {
            Some(fields)
        } else {
            None
        }
    }

    /// Report whether this is the `Null` variant.
    pub fn is_null(&self) -> bool {
        matches!(*self, Self::Null)
    }

    /// Lower-case name of the variant, suitable for error messages.
    pub fn type_name(&self) -> &'static str {
        match *self {
            Self::Null => "null",
            Self::Boolean(_) => "boolean",
            Self::Integer(_) => "integer",
            Self::Float(_) => "float",
            Self::String(_) => "string",
            Self::Bytes(_) => "bytes",
            Self::Array(_) => "array",
            Self::Object(_) => "object",
        }
    }
}
146
147// Convenient conversions
148impl From<String> for StateValue {
149    fn from(s: String) -> Self {
150        StateValue::String(s)
151    }
152}
153
154impl From<&str> for StateValue {
155    fn from(s: &str) -> Self {
156        StateValue::String(s.to_string())
157    }
158}
159
160impl From<i64> for StateValue {
161    fn from(i: i64) -> Self {
162        StateValue::Integer(i)
163    }
164}
165
166impl From<i32> for StateValue {
167    fn from(i: i32) -> Self {
168        StateValue::Integer(i as i64)
169    }
170}
171
172impl From<f64> for StateValue {
173    fn from(f: f64) -> Self {
174        StateValue::Float(f)
175    }
176}
177
178impl From<f32> for StateValue {
179    fn from(f: f32) -> Self {
180        StateValue::Float(f as f64)
181    }
182}
183
184impl From<bool> for StateValue {
185    fn from(b: bool) -> Self {
186        StateValue::Boolean(b)
187    }
188}
189
190impl From<Vec<StateValue>> for StateValue {
191    fn from(arr: Vec<StateValue>) -> Self {
192        StateValue::Array(arr)
193    }
194}
195
196impl From<HashMap<String, StateValue>> for StateValue {
197    fn from(obj: HashMap<String, StateValue>) -> Self {
198        StateValue::Object(obj)
199    }
200}
201
202impl From<Vec<u8>> for StateValue {
203    fn from(bytes: Vec<u8>) -> Self {
204        StateValue::Bytes(bytes)
205    }
206}
207
#[cfg(feature = "serde")]
impl From<serde_json::Value> for StateValue {
    /// Convert a JSON value into the equivalent state value, recursing into
    /// arrays and objects.
    ///
    /// Numbers become `Integer` when they fit in an `i64`, otherwise `Float`;
    /// a number that is representable as neither becomes `Null`.
    fn from(value: serde_json::Value) -> Self {
        use serde_json::Value as Json;

        match value {
            Json::Null => StateValue::Null,
            Json::Bool(flag) => StateValue::Boolean(flag),
            Json::Number(num) => num
                .as_i64()
                .map(StateValue::Integer)
                .or_else(|| num.as_f64().map(StateValue::Float))
                .unwrap_or(StateValue::Null),
            Json::String(text) => StateValue::String(text),
            Json::Array(items) => {
                let converted = items.into_iter().map(StateValue::from).collect();
                StateValue::Array(converted)
            }
            Json::Object(fields) => {
                let converted = fields
                    .into_iter()
                    .map(|(name, item)| (name, StateValue::from(item)))
                    .collect();
                StateValue::Object(converted)
            }
        }
    }
}
235
#[cfg(feature = "serde")]
impl From<StateValue> for serde_json::Value {
    /// Convert a state value into JSON, recursing into arrays and objects.
    ///
    /// Two variants have no faithful JSON representation and become `null`:
    ///
    /// * `Float` values that are NaN or infinite (JSON numbers must be finite).
    ///   Emitting `null` instead of a made-up `0` keeps "no representable
    ///   value" distinguishable from a real zero.
    /// * `Bytes`, because JSON has no binary type.
    fn from(value: StateValue) -> Self {
        match value {
            StateValue::String(s) => serde_json::Value::String(s),
            StateValue::Integer(i) => serde_json::Value::Number(i.into()),
            // `from_f64` returns `None` for NaN and +/-infinity.
            StateValue::Float(f) => serde_json::Number::from_f64(f)
                .map_or(serde_json::Value::Null, serde_json::Value::Number),
            StateValue::Boolean(b) => serde_json::Value::Bool(b),
            StateValue::Array(arr) => {
                serde_json::Value::Array(arr.into_iter().map(serde_json::Value::from).collect())
            }
            StateValue::Object(obj) => serde_json::Value::Object(
                obj.into_iter()
                    .map(|(k, v)| (k, serde_json::Value::from(v)))
                    .collect(),
            ),
            StateValue::Null => serde_json::Value::Null,
            StateValue::Bytes(_) => serde_json::Value::Null, // Can't represent bytes in JSON
        }
    }
}
259
260/// The shared state that flows through the graph execution
261#[derive(Debug, Clone)]
262#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
263pub struct GraphState {
264    /// The state data
265    #[cfg_attr(feature = "serde", serde(skip, default = "default_data"))]
266    data: Arc<RwLock<HashMap<String, StateValue>>>,
267    /// Metadata about the state
268    #[cfg_attr(feature = "serde", serde(skip, default = "default_metadata"))]
269    metadata: Arc<RwLock<HashMap<String, StateValue>>>,
270    /// Execution history
271    #[cfg_attr(feature = "serde", serde(skip, default = "default_execution_log"))]
272    execution_log: Arc<RwLock<Vec<StateHistoryEntry>>>,
273}
274
275/// Entry in the state execution history
276#[derive(Debug, Clone)]
277#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
278pub struct StateHistoryEntry {
279    pub timestamp: chrono::DateTime<chrono::Utc>,
280    pub node_id: String,
281    pub operation: StateOperation,
282    pub key: String,
283    pub old_value: Option<StateValue>,
284    pub new_value: Option<StateValue>,
285}
286
/// Kinds of operation that can appear in the state's execution history.
///
/// The enum is a plain tag without payload, so it is `Copy` and supports
/// equality and hashing; this lets callers filter or compare history entries
/// by operation (`entry.operation == StateOperation::Set`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum StateOperation {
    /// A value was written
    Set,
    /// A value was read
    Get,
    /// A value was removed
    Remove,
    /// All values were removed
    Clear,
}
296
297impl GraphState {
298    /// Create a new empty graph state
299    pub fn new() -> Self {
300        Self {
301            data: Arc::new(RwLock::new(HashMap::new())),
302            metadata: Arc::new(RwLock::new(HashMap::new())),
303            execution_log: Arc::new(RwLock::new(Vec::new())),
304        }
305    }
306
307    /// Create a new graph state with initial data
308    pub fn with_data(data: HashMap<String, StateValue>) -> Self {
309        Self {
310            data: Arc::new(RwLock::new(data)),
311            metadata: Arc::new(RwLock::new(HashMap::new())),
312            execution_log: Arc::new(RwLock::new(Vec::new())),
313        }
314    }
315
316    /// Set a value in the state
317    pub fn set(&self, key: impl Into<String>, value: impl Into<StateValue>) -> &Self {
318        let key = key.into();
319        let value = value.into();
320
321        // Log the operation
322        self.log_operation(
323            "system",
324            StateOperation::Set,
325            &key,
326            None,
327            Some(value.clone()),
328        );
329
330        // Set the value
331        let mut data = self.data.write();
332        data.insert(key, value);
333
334        self
335    }
336
337    /// Set a value in the state with node context
338    pub fn set_with_context(
339        &self,
340        node_id: &str,
341        key: impl Into<String>,
342        value: impl Into<StateValue>,
343    ) -> &Self {
344        let key = key.into();
345        let value = value.into();
346
347        // Get old value for logging
348        let old_value = self.data.read().get(&key).cloned();
349
350        // Log the operation
351        self.log_operation(
352            node_id,
353            StateOperation::Set,
354            &key,
355            old_value,
356            Some(value.clone()),
357        );
358
359        // Set the value
360        let mut data = self.data.write();
361        data.insert(key, value);
362
363        self
364    }
365
366    /// Get a value from the state
367    pub fn get(&self, key: &str) -> RGraphResult<StateValue> {
368        let path = StatePath::new(key);
369        self.get_by_path(&path)
370    }
371
372    /// Get a value by path (supports nested access)
373    pub fn get_by_path(&self, path: &StatePath) -> RGraphResult<StateValue> {
374        let components = path.components();
375        let data = self.data.read();
376
377        if components.len() == 1 {
378            // Simple key access
379            data.get(components[0])
380                .cloned()
381                .ok_or_else(|| RGraphError::state(format!("Key '{}' not found", components[0])))
382        } else {
383            // Nested access
384            let mut current_value = data
385                .get(components[0])
386                .ok_or_else(|| RGraphError::state(format!("Key '{}' not found", components[0])))?;
387
388            for component in &components[1..] {
389                match current_value {
390                    StateValue::Object(ref obj) => {
391                        current_value = obj.get(*component).ok_or_else(|| {
392                            RGraphError::state(format!("Nested key '{}' not found", component))
393                        })?;
394                    }
395                    _ => {
396                        return Err(RGraphError::state(format!(
397                            "Cannot access '{}' on non-object value",
398                            component
399                        )))
400                    }
401                }
402            }
403
404            Ok(current_value.clone())
405        }
406    }
407
408    /// Get a typed value from the state
409    pub fn get_typed<T>(&self, key: &str) -> RGraphResult<T>
410    where
411        T: TryFrom<StateValue>,
412        T::Error: std::fmt::Display,
413    {
414        let value = self.get(key)?;
415        T::try_from(value).map_err(|e| RGraphError::state(e.to_string()))
416    }
417
418    /// Check if a key exists in the state
419    pub fn contains_key(&self, key: &str) -> bool {
420        self.data.read().contains_key(key)
421    }
422
423    /// Remove a value from the state
424    pub fn remove(&self, key: &str) -> Option<StateValue> {
425        let mut data = self.data.write();
426        let old_value = data.remove(key);
427
428        // Log the operation
429        self.log_operation(
430            "system",
431            StateOperation::Remove,
432            key,
433            old_value.clone(),
434            None,
435        );
436
437        old_value
438    }
439
440    /// Clear all data from the state
441    pub fn clear(&self) {
442        let mut data = self.data.write();
443        data.clear();
444
445        // Log the operation
446        self.log_operation("system", StateOperation::Clear, "all", None, None);
447    }
448
449    /// Get all keys in the state
450    pub fn keys(&self) -> Vec<String> {
451        self.data.read().keys().cloned().collect()
452    }
453
454    /// Get the number of items in the state
455    pub fn len(&self) -> usize {
456        self.data.read().len()
457    }
458
459    /// Check if the state is empty
460    pub fn is_empty(&self) -> bool {
461        self.data.read().is_empty()
462    }
463
464    /// Merge another state into this one
465    pub fn merge(&self, other: &GraphState) {
466        let other_data = other.data.read();
467        let mut data = self.data.write();
468
469        for (key, value) in other_data.iter() {
470            data.insert(key.clone(), value.clone());
471        }
472    }
473
474    /// Create a snapshot of the current state
475    pub fn snapshot(&self) -> HashMap<String, StateValue> {
476        self.data.read().clone()
477    }
478
479    /// Set metadata
480    pub fn set_metadata(&self, key: impl Into<String>, value: impl Into<StateValue>) {
481        let mut metadata = self.metadata.write();
482        metadata.insert(key.into(), value.into());
483    }
484
485    /// Get metadata
486    pub fn get_metadata(&self, key: &str) -> Option<StateValue> {
487        self.metadata.read().get(key).cloned()
488    }
489
490    /// Get execution history
491    pub fn execution_history(&self) -> Vec<StateHistoryEntry> {
492        self.execution_log.read().clone()
493    }
494
495    /// Convenience method to add input data
496    pub fn with_input(self, key: impl Into<String>, value: impl Into<StateValue>) -> Self {
497        self.set(key, value);
498        self
499    }
500
501    /// Get output data as a specific type
502    pub fn get_output<T>(&self, key: &str) -> RGraphResult<T>
503    where
504        T: TryFrom<StateValue>,
505        T::Error: std::fmt::Display,
506    {
507        self.get_typed(key)
508    }
509
510    /// Log a state operation
511    fn log_operation(
512        &self,
513        node_id: &str,
514        operation: StateOperation,
515        key: &str,
516        old_value: Option<StateValue>,
517        new_value: Option<StateValue>,
518    ) {
519        let entry = StateHistoryEntry {
520            timestamp: chrono::Utc::now(),
521            node_id: node_id.to_string(),
522            operation,
523            key: key.to_string(),
524            old_value,
525            new_value,
526        };
527
528        self.execution_log.write().push(entry);
529    }
530}
531
532impl Default for GraphState {
533    fn default() -> Self {
534        Self::new()
535    }
536}
537
538// Implement TryFrom for common types from StateValue
539impl TryFrom<StateValue> for String {
540    type Error = RGraphError;
541
542    fn try_from(value: StateValue) -> Result<Self, Self::Error> {
543        match value {
544            StateValue::String(s) => Ok(s),
545            _ => Err(RGraphError::state(format!(
546                "Cannot convert {} to String",
547                value.type_name()
548            ))),
549        }
550    }
551}
552
553impl TryFrom<StateValue> for i64 {
554    type Error = RGraphError;
555
556    fn try_from(value: StateValue) -> Result<Self, Self::Error> {
557        match value {
558            StateValue::Integer(i) => Ok(i),
559            StateValue::Float(f) => Ok(f as i64),
560            _ => Err(RGraphError::state(format!(
561                "Cannot convert {} to i64",
562                value.type_name()
563            ))),
564        }
565    }
566}
567
568impl TryFrom<StateValue> for f64 {
569    type Error = RGraphError;
570
571    fn try_from(value: StateValue) -> Result<Self, Self::Error> {
572        match value {
573            StateValue::Float(f) => Ok(f),
574            StateValue::Integer(i) => Ok(i as f64),
575            _ => Err(RGraphError::state(format!(
576                "Cannot convert {} to f64",
577                value.type_name()
578            ))),
579        }
580    }
581}
582
583impl TryFrom<StateValue> for bool {
584    type Error = RGraphError;
585
586    fn try_from(value: StateValue) -> Result<Self, Self::Error> {
587        match value {
588            StateValue::Boolean(b) => Ok(b),
589            _ => Err(RGraphError::state(format!(
590                "Cannot convert {} to bool",
591                value.type_name()
592            ))),
593        }
594    }
595}
596
597impl TryFrom<StateValue> for Vec<StateValue> {
598    type Error = RGraphError;
599
600    fn try_from(value: StateValue) -> Result<Self, Self::Error> {
601        match value {
602            StateValue::Array(arr) => Ok(arr),
603            _ => Err(RGraphError::state(format!(
604                "Cannot convert {} to Vec<StateValue>",
605                value.type_name()
606            ))),
607        }
608    }
609}
610
#[cfg(test)]
mod tests {
    use super::*;

    /// `From` conversions produce the matching variant and the `as_*`
    /// accessors read it back.
    #[test]
    fn test_state_value_conversions() {
        let string_val: StateValue = "hello".into();
        assert_eq!(string_val.as_string(), Some("hello"));

        let int_val: StateValue = 42i64.into();
        assert_eq!(int_val.as_integer(), Some(42));

        // 2.5 is exactly representable, and unlike `3.14` it does not trip
        // clippy's deny-by-default `approx_constant` lint.
        let float_val: StateValue = 2.5f64.into();
        assert_eq!(float_val.as_float(), Some(2.5));

        let bool_val: StateValue = true.into();
        assert_eq!(bool_val.as_boolean(), Some(true));
    }

    /// Round trip through `set`, `get`, `contains_key` and `remove`.
    #[test]
    fn test_graph_state_basic_operations() {
        let state = GraphState::new();

        // Test set and get
        state.set("key1", "value1");
        assert_eq!(
            state.get("key1").unwrap(),
            StateValue::String("value1".to_string())
        );

        // Test contains_key
        assert!(state.contains_key("key1"));
        assert!(!state.contains_key("nonexistent"));

        // Test remove
        let removed = state.remove("key1");
        assert_eq!(removed, Some(StateValue::String("value1".to_string())));
        assert!(!state.contains_key("key1"));
    }

    /// Paths split on `.` and `nested` joins with `.`.
    #[test]
    fn test_state_path() {
        let path = StatePath::new("parent.child.grandchild");
        let components = path.components();
        assert_eq!(components, vec!["parent", "child", "grandchild"]);

        let nested_path = StatePath::nested("parent", "child");
        assert_eq!(nested_path.as_str(), "parent.child");
    }

    /// `with_input` can be chained builder-style.
    #[test]
    fn test_state_with_input() {
        let state = GraphState::new()
            .with_input("name", "Alice")
            .with_input("age", 30);

        assert_eq!(state.get("name").unwrap().as_string(), Some("Alice"));
        assert_eq!(state.get("age").unwrap().as_integer(), Some(30));
    }

    /// `merge` copies the other state's entries and keeps existing ones.
    #[test]
    fn test_state_merge() {
        let state1 = GraphState::new();
        state1.set("key1", "value1");

        let state2 = GraphState::new();
        state2.set("key2", "value2");

        state1.merge(&state2);

        assert!(state1.contains_key("key1"));
        assert!(state1.contains_key("key2"));
    }

    /// Each `set_with_context` call appends one history entry, in call order.
    #[test]
    fn test_execution_history() {
        let state = GraphState::new();
        state.set_with_context("node1", "key1", "value1");
        state.set_with_context("node2", "key2", "value2");

        let history = state.execution_history();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].node_id, "node1");
        assert_eq!(history[1].node_id, "node2");
    }
}
697
698// Default functions for serde skipped fields
/// Serde default for the skipped `data` field: a fresh, empty shared map.
#[cfg(feature = "serde")]
fn default_data() -> Arc<RwLock<HashMap<String, StateValue>>> {
    let empty = HashMap::new();
    Arc::new(RwLock::new(empty))
}
703
/// Serde default for the skipped `metadata` field: a fresh, empty shared map.
#[cfg(feature = "serde")]
fn default_metadata() -> Arc<RwLock<HashMap<String, StateValue>>> {
    let empty = HashMap::new();
    Arc::new(RwLock::new(empty))
}
708
/// Serde default for the skipped `execution_log` field: a fresh, empty shared log.
#[cfg(feature = "serde")]
fn default_execution_log() -> Arc<RwLock<Vec<StateHistoryEntry>>> {
    let empty = Vec::new();
    Arc::new(RwLock::new(empty))
}