1use crate::{RGraphError, RGraphResult};
8use parking_lot::RwLock;
9use std::collections::HashMap;
10use std::sync::Arc;
11
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14
/// A dot-separated path addressing a (possibly nested) state entry,
/// e.g. `"user.profile.name"`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct StatePath(pub String);

impl StatePath {
    /// Wraps `path` verbatim; no validation or normalisation is performed.
    pub fn new(path: impl Into<String>) -> Self {
        let raw: String = path.into();
        StatePath(raw)
    }

    /// Builds `"<parent>.<child>"`.
    pub fn nested(parent: impl Into<String>, child: impl Into<String>) -> Self {
        let mut joined: String = parent.into();
        joined.push('.');
        joined.push_str(&child.into());
        StatePath(joined)
    }

    /// Borrows the raw path text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Splits the path on `.`; always yields at least one (possibly empty) segment.
    pub fn components(&self) -> Vec<&str> {
        self.as_str().split('.').collect::<Vec<_>>()
    }
}

impl From<String> for StatePath {
    fn from(path: String) -> Self {
        Self::new(path)
    }
}

impl From<&str> for StatePath {
    fn from(path: &str) -> Self {
        Self::new(path)
    }
}
53
/// A dynamically typed value stored in graph state.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum StateValue {
    /// UTF-8 text.
    String(String),
    /// 64-bit signed integer.
    Integer(i64),
    /// 64-bit float (may be NaN or infinite).
    Float(f64),
    /// Boolean flag.
    Boolean(bool),
    /// Ordered list of values.
    Array(Vec<StateValue>),
    /// String-keyed map of values; the target of dotted-path lookups.
    Object(HashMap<String, StateValue>),
    /// Explicit absence of a value.
    Null,
    /// Raw bytes.
    Bytes(Vec<u8>),
}

impl StateValue {
    /// Returns the text if this is a `String`.
    pub fn as_string(&self) -> Option<&str> {
        match self {
            StateValue::String(s) => Some(s),
            _ => None,
        }
    }

    /// Returns the value as an `i64`.
    ///
    /// `Integer` is returned as-is. A `Float` is truncated toward zero, but only
    /// when it is finite and within `i64`'s range; NaN, infinities and
    /// out-of-range floats yield `None` rather than the `0`/saturated result a
    /// bare `as` cast would silently produce.
    pub fn as_integer(&self) -> Option<i64> {
        // Half-open: `i64::MAX as f64` rounds up to 2^63, which is itself out of range.
        // NaN fails every comparison, so `contains` rejects it too.
        const I64_RANGE: std::ops::Range<f64> = (i64::MIN as f64)..(i64::MAX as f64);

        match *self {
            StateValue::Integer(i) => Some(i),
            StateValue::Float(f) if I64_RANGE.contains(&f) => Some(f as i64),
            _ => None,
        }
    }

    /// Returns the value as an `f64`; `Integer` is widened (may lose precision
    /// beyond 2^53).
    pub fn as_float(&self) -> Option<f64> {
        match self {
            StateValue::Float(f) => Some(*f),
            StateValue::Integer(i) => Some(*i as f64),
            _ => None,
        }
    }

    /// Returns the flag if this is a `Boolean`.
    pub fn as_boolean(&self) -> Option<bool> {
        match self {
            StateValue::Boolean(b) => Some(*b),
            _ => None,
        }
    }

    /// Borrows the elements if this is an `Array`.
    pub fn as_array(&self) -> Option<&Vec<StateValue>> {
        match self {
            StateValue::Array(arr) => Some(arr),
            _ => None,
        }
    }

    /// Borrows the map if this is an `Object`.
    pub fn as_object(&self) -> Option<&HashMap<String, StateValue>> {
        match self {
            StateValue::Object(obj) => Some(obj),
            _ => None,
        }
    }

    /// Borrows the bytes if this is `Bytes`.
    pub fn as_bytes(&self) -> Option<&[u8]> {
        match self {
            StateValue::Bytes(bytes) => Some(bytes),
            _ => None,
        }
    }

    /// Returns `true` for `Null`.
    pub fn is_null(&self) -> bool {
        matches!(self, StateValue::Null)
    }

    /// Lower-case name of the variant, used in conversion error messages.
    pub fn type_name(&self) -> &'static str {
        match self {
            StateValue::String(_) => "string",
            StateValue::Integer(_) => "integer",
            StateValue::Float(_) => "float",
            StateValue::Boolean(_) => "boolean",
            StateValue::Array(_) => "array",
            StateValue::Object(_) => "object",
            StateValue::Null => "null",
            StateValue::Bytes(_) => "bytes",
        }
    }
}
146
147impl From<String> for StateValue {
149 fn from(s: String) -> Self {
150 StateValue::String(s)
151 }
152}
153
154impl From<&str> for StateValue {
155 fn from(s: &str) -> Self {
156 StateValue::String(s.to_string())
157 }
158}
159
160impl From<i64> for StateValue {
161 fn from(i: i64) -> Self {
162 StateValue::Integer(i)
163 }
164}
165
166impl From<i32> for StateValue {
167 fn from(i: i32) -> Self {
168 StateValue::Integer(i as i64)
169 }
170}
171
172impl From<f64> for StateValue {
173 fn from(f: f64) -> Self {
174 StateValue::Float(f)
175 }
176}
177
178impl From<f32> for StateValue {
179 fn from(f: f32) -> Self {
180 StateValue::Float(f as f64)
181 }
182}
183
184impl From<bool> for StateValue {
185 fn from(b: bool) -> Self {
186 StateValue::Boolean(b)
187 }
188}
189
190impl From<Vec<StateValue>> for StateValue {
191 fn from(arr: Vec<StateValue>) -> Self {
192 StateValue::Array(arr)
193 }
194}
195
196impl From<HashMap<String, StateValue>> for StateValue {
197 fn from(obj: HashMap<String, StateValue>) -> Self {
198 StateValue::Object(obj)
199 }
200}
201
202impl From<Vec<u8>> for StateValue {
203 fn from(bytes: Vec<u8>) -> Self {
204 StateValue::Bytes(bytes)
205 }
206}
207
#[cfg(feature = "serde")]
impl From<serde_json::Value> for StateValue {
    /// Recursively converts a JSON value.
    ///
    /// JSON numbers become `Integer` when they fit in `i64`, otherwise `Float`;
    /// a number that is neither maps to `Null`.
    fn from(value: serde_json::Value) -> Self {
        match value {
            serde_json::Value::Null => Self::Null,
            serde_json::Value::Bool(flag) => Self::Boolean(flag),
            serde_json::Value::Number(num) => num
                .as_i64()
                .map(Self::Integer)
                .or_else(|| num.as_f64().map(Self::Float))
                .unwrap_or(Self::Null),
            serde_json::Value::String(text) => Self::String(text),
            serde_json::Value::Array(items) => {
                Self::Array(items.into_iter().map(Self::from).collect())
            }
            serde_json::Value::Object(fields) => Self::Object(
                fields
                    .into_iter()
                    .map(|(name, inner)| (name, Self::from(inner)))
                    .collect(),
            ),
        }
    }
}
235
#[cfg(feature = "serde")]
impl From<StateValue> for serde_json::Value {
    /// Recursively converts a state value to JSON.
    ///
    /// * Non-finite floats (NaN, ±infinity) have no JSON representation and
    ///   become `null`, as with serde_json's own `From<f64>`. Substituting `0`
    ///   would silently fabricate data.
    /// * `Bytes` become an array of numbers (serde's encoding of `Vec<u8>`)
    ///   instead of being discarded.
    fn from(value: StateValue) -> Self {
        match value {
            StateValue::String(s) => serde_json::Value::String(s),
            StateValue::Integer(i) => serde_json::Value::Number(i.into()),
            StateValue::Float(f) => serde_json::Value::from(f),
            StateValue::Boolean(b) => serde_json::Value::Bool(b),
            StateValue::Array(arr) => {
                serde_json::Value::Array(arr.into_iter().map(serde_json::Value::from).collect())
            }
            StateValue::Object(obj) => serde_json::Value::Object(
                obj.into_iter()
                    .map(|(k, v)| (k, serde_json::Value::from(v)))
                    .collect(),
            ),
            StateValue::Null => serde_json::Value::Null,
            StateValue::Bytes(bytes) => serde_json::Value::Array(
                bytes.into_iter().map(serde_json::Value::from).collect(),
            ),
        }
    }
}
259
/// Shared key/value state read and written by graph nodes.
///
/// Every field lives behind an `Arc<RwLock<..>>`, so the derived `Clone`
/// produces another handle to the *same* underlying maps, not a deep copy.
/// Use `snapshot` to obtain an independent copy of the data.
///
/// With the `serde` feature, all fields are `#[serde(skip)]`. Serializing a
/// `GraphState` therefore emits none of its contents, and deserializing yields
/// an empty state built by the `default_*` helper functions.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct GraphState {
    /// Top-level entries; values may nest via `StateValue::Object`.
    #[cfg_attr(feature = "serde", serde(skip, default = "default_data"))]
    data: Arc<RwLock<HashMap<String, StateValue>>>,
    /// Auxiliary key/value pairs kept apart from `data`; writes here are not
    /// recorded in `execution_log`.
    #[cfg_attr(feature = "serde", serde(skip, default = "default_metadata"))]
    metadata: Arc<RwLock<HashMap<String, StateValue>>>,
    /// History of logged operations; entries are only ever appended.
    #[cfg_attr(feature = "serde", serde(skip, default = "default_execution_log"))]
    execution_log: Arc<RwLock<Vec<StateHistoryEntry>>>,
}
274
/// One record in a `GraphState`'s execution history.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct StateHistoryEntry {
    /// When the entry was logged (UTC).
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Node credited with the operation; `"system"` for writes made without
    /// an explicit node context.
    pub node_id: String,
    /// Kind of operation performed.
    pub operation: StateOperation,
    /// Affected key; `clear` logs the placeholder `"all"`.
    pub key: String,
    /// Value before the operation, when the logging call supplied one.
    pub old_value: Option<StateValue>,
    /// Value written by the operation; `None` for removals and clears.
    pub new_value: Option<StateValue>,
}
286
/// Kind of operation recorded in a `StateHistoryEntry`.
///
/// Fieldless, so it is cheap to copy and can be compared or hashed; this
/// lets callers filter history with `entry.operation == StateOperation::Set`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum StateOperation {
    /// A value was inserted or overwritten.
    Set,
    /// A value was read (not emitted by `GraphState` itself).
    Get,
    /// A key was removed.
    Remove,
    /// All data entries were cleared.
    Clear,
}
296
297impl GraphState {
298 pub fn new() -> Self {
300 Self {
301 data: Arc::new(RwLock::new(HashMap::new())),
302 metadata: Arc::new(RwLock::new(HashMap::new())),
303 execution_log: Arc::new(RwLock::new(Vec::new())),
304 }
305 }
306
307 pub fn with_data(data: HashMap<String, StateValue>) -> Self {
309 Self {
310 data: Arc::new(RwLock::new(data)),
311 metadata: Arc::new(RwLock::new(HashMap::new())),
312 execution_log: Arc::new(RwLock::new(Vec::new())),
313 }
314 }
315
316 pub fn set(&self, key: impl Into<String>, value: impl Into<StateValue>) -> &Self {
318 let key = key.into();
319 let value = value.into();
320
321 self.log_operation(
323 "system",
324 StateOperation::Set,
325 &key,
326 None,
327 Some(value.clone()),
328 );
329
330 let mut data = self.data.write();
332 data.insert(key, value);
333
334 self
335 }
336
337 pub fn set_with_context(
339 &self,
340 node_id: &str,
341 key: impl Into<String>,
342 value: impl Into<StateValue>,
343 ) -> &Self {
344 let key = key.into();
345 let value = value.into();
346
347 let old_value = self.data.read().get(&key).cloned();
349
350 self.log_operation(
352 node_id,
353 StateOperation::Set,
354 &key,
355 old_value,
356 Some(value.clone()),
357 );
358
359 let mut data = self.data.write();
361 data.insert(key, value);
362
363 self
364 }
365
366 pub fn get(&self, key: &str) -> RGraphResult<StateValue> {
368 let path = StatePath::new(key);
369 self.get_by_path(&path)
370 }
371
372 pub fn get_by_path(&self, path: &StatePath) -> RGraphResult<StateValue> {
374 let components = path.components();
375 let data = self.data.read();
376
377 if components.len() == 1 {
378 data.get(components[0])
380 .cloned()
381 .ok_or_else(|| RGraphError::state(format!("Key '{}' not found", components[0])))
382 } else {
383 let mut current_value = data
385 .get(components[0])
386 .ok_or_else(|| RGraphError::state(format!("Key '{}' not found", components[0])))?;
387
388 for component in &components[1..] {
389 match current_value {
390 StateValue::Object(ref obj) => {
391 current_value = obj.get(*component).ok_or_else(|| {
392 RGraphError::state(format!("Nested key '{}' not found", component))
393 })?;
394 }
395 _ => {
396 return Err(RGraphError::state(format!(
397 "Cannot access '{}' on non-object value",
398 component
399 )))
400 }
401 }
402 }
403
404 Ok(current_value.clone())
405 }
406 }
407
408 pub fn get_typed<T>(&self, key: &str) -> RGraphResult<T>
410 where
411 T: TryFrom<StateValue>,
412 T::Error: std::fmt::Display,
413 {
414 let value = self.get(key)?;
415 T::try_from(value).map_err(|e| RGraphError::state(e.to_string()))
416 }
417
418 pub fn contains_key(&self, key: &str) -> bool {
420 self.data.read().contains_key(key)
421 }
422
423 pub fn remove(&self, key: &str) -> Option<StateValue> {
425 let mut data = self.data.write();
426 let old_value = data.remove(key);
427
428 self.log_operation(
430 "system",
431 StateOperation::Remove,
432 key,
433 old_value.clone(),
434 None,
435 );
436
437 old_value
438 }
439
440 pub fn clear(&self) {
442 let mut data = self.data.write();
443 data.clear();
444
445 self.log_operation("system", StateOperation::Clear, "all", None, None);
447 }
448
449 pub fn keys(&self) -> Vec<String> {
451 self.data.read().keys().cloned().collect()
452 }
453
454 pub fn len(&self) -> usize {
456 self.data.read().len()
457 }
458
459 pub fn is_empty(&self) -> bool {
461 self.data.read().is_empty()
462 }
463
464 pub fn merge(&self, other: &GraphState) {
466 let other_data = other.data.read();
467 let mut data = self.data.write();
468
469 for (key, value) in other_data.iter() {
470 data.insert(key.clone(), value.clone());
471 }
472 }
473
474 pub fn snapshot(&self) -> HashMap<String, StateValue> {
476 self.data.read().clone()
477 }
478
479 pub fn set_metadata(&self, key: impl Into<String>, value: impl Into<StateValue>) {
481 let mut metadata = self.metadata.write();
482 metadata.insert(key.into(), value.into());
483 }
484
485 pub fn get_metadata(&self, key: &str) -> Option<StateValue> {
487 self.metadata.read().get(key).cloned()
488 }
489
490 pub fn execution_history(&self) -> Vec<StateHistoryEntry> {
492 self.execution_log.read().clone()
493 }
494
495 pub fn with_input(self, key: impl Into<String>, value: impl Into<StateValue>) -> Self {
497 self.set(key, value);
498 self
499 }
500
501 pub fn get_output<T>(&self, key: &str) -> RGraphResult<T>
503 where
504 T: TryFrom<StateValue>,
505 T::Error: std::fmt::Display,
506 {
507 self.get_typed(key)
508 }
509
510 fn log_operation(
512 &self,
513 node_id: &str,
514 operation: StateOperation,
515 key: &str,
516 old_value: Option<StateValue>,
517 new_value: Option<StateValue>,
518 ) {
519 let entry = StateHistoryEntry {
520 timestamp: chrono::Utc::now(),
521 node_id: node_id.to_string(),
522 operation,
523 key: key.to_string(),
524 old_value,
525 new_value,
526 };
527
528 self.execution_log.write().push(entry);
529 }
530}
531
532impl Default for GraphState {
533 fn default() -> Self {
534 Self::new()
535 }
536}
537
538impl TryFrom<StateValue> for String {
540 type Error = RGraphError;
541
542 fn try_from(value: StateValue) -> Result<Self, Self::Error> {
543 match value {
544 StateValue::String(s) => Ok(s),
545 _ => Err(RGraphError::state(format!(
546 "Cannot convert {} to String",
547 value.type_name()
548 ))),
549 }
550 }
551}
552
553impl TryFrom<StateValue> for i64 {
554 type Error = RGraphError;
555
556 fn try_from(value: StateValue) -> Result<Self, Self::Error> {
557 match value {
558 StateValue::Integer(i) => Ok(i),
559 StateValue::Float(f) => Ok(f as i64),
560 _ => Err(RGraphError::state(format!(
561 "Cannot convert {} to i64",
562 value.type_name()
563 ))),
564 }
565 }
566}
567
568impl TryFrom<StateValue> for f64 {
569 type Error = RGraphError;
570
571 fn try_from(value: StateValue) -> Result<Self, Self::Error> {
572 match value {
573 StateValue::Float(f) => Ok(f),
574 StateValue::Integer(i) => Ok(i as f64),
575 _ => Err(RGraphError::state(format!(
576 "Cannot convert {} to f64",
577 value.type_name()
578 ))),
579 }
580 }
581}
582
583impl TryFrom<StateValue> for bool {
584 type Error = RGraphError;
585
586 fn try_from(value: StateValue) -> Result<Self, Self::Error> {
587 match value {
588 StateValue::Boolean(b) => Ok(b),
589 _ => Err(RGraphError::state(format!(
590 "Cannot convert {} to bool",
591 value.type_name()
592 ))),
593 }
594 }
595}
596
597impl TryFrom<StateValue> for Vec<StateValue> {
598 type Error = RGraphError;
599
600 fn try_from(value: StateValue) -> Result<Self, Self::Error> {
601 match value {
602 StateValue::Array(arr) => Ok(arr),
603 _ => Err(RGraphError::state(format!(
604 "Cannot convert {} to Vec<StateValue>",
605 value.type_name()
606 ))),
607 }
608 }
609}
610
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_state_value_conversions() {
        let string_val: StateValue = "hello".into();
        assert_eq!(string_val.as_string(), Some("hello"));

        let int_val: StateValue = 42i64.into();
        assert_eq!(int_val.as_integer(), Some(42));

        // 2.5 is exactly representable and, unlike 3.14, does not trip the
        // deny-by-default `clippy::approx_constant` lint.
        let float_val: StateValue = 2.5f64.into();
        assert_eq!(float_val.as_float(), Some(2.5));

        let bool_val: StateValue = true.into();
        assert_eq!(bool_val.as_boolean(), Some(true));
    }

    #[test]
    fn test_graph_state_basic_operations() {
        let state = GraphState::new();

        state.set("key1", "value1");
        assert_eq!(
            state.get("key1").unwrap(),
            StateValue::String("value1".to_string())
        );

        assert!(state.contains_key("key1"));
        assert!(!state.contains_key("nonexistent"));

        let removed = state.remove("key1");
        assert_eq!(removed, Some(StateValue::String("value1".to_string())));
        assert!(!state.contains_key("key1"));
    }

    #[test]
    fn test_state_path() {
        let path = StatePath::new("parent.child.grandchild");
        let components = path.components();
        assert_eq!(components, vec!["parent", "child", "grandchild"]);

        let nested_path = StatePath::nested("parent", "child");
        assert_eq!(nested_path.as_str(), "parent.child");
    }

    #[test]
    fn test_state_with_input() {
        let state = GraphState::new()
            .with_input("name", "Alice")
            .with_input("age", 30);

        assert_eq!(state.get("name").unwrap().as_string(), Some("Alice"));
        assert_eq!(state.get("age").unwrap().as_integer(), Some(30));
    }

    #[test]
    fn test_state_merge() {
        let state1 = GraphState::new();
        state1.set("key1", "value1");

        let state2 = GraphState::new();
        state2.set("key2", "value2");

        state1.merge(&state2);

        assert!(state1.contains_key("key1"));
        assert!(state1.contains_key("key2"));
    }

    #[test]
    fn test_execution_history() {
        let state = GraphState::new();
        state.set_with_context("node1", "key1", "value1");
        state.set_with_context("node2", "key2", "value2");

        let history = state.execution_history();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].node_id, "node1");
        assert_eq!(history[1].node_id, "node2");
    }

    #[test]
    fn test_nested_get() {
        let mut user = HashMap::new();
        user.insert("name".to_string(), StateValue::from("Bob"));
        let state = GraphState::new().with_input("user", user);

        assert_eq!(state.get("user.name").unwrap(), StateValue::from("Bob"));
        assert!(state.get("user.missing").is_err());
        // "name" holds a string, so descending further must fail.
        assert!(state.get("user.name.first").is_err());
        assert!(state.get("nobody.name").is_err());
    }

    #[test]
    fn test_typed_conversions() {
        let state = GraphState::new().with_input("flag", true);
        assert!(bool::try_from(state.get("flag").unwrap()).unwrap());
        assert!(String::try_from(state.get("flag").unwrap()).is_err());
        assert_eq!(i64::try_from(StateValue::Integer(5)).unwrap(), 5);
        assert_eq!(f64::try_from(StateValue::Integer(5)).unwrap(), 5.0);
    }
}
697
/// Serde default for the skipped `GraphState::data` field: an empty map.
#[cfg(feature = "serde")]
fn default_data() -> Arc<RwLock<HashMap<String, StateValue>>> {
    Arc::default()
}

/// Serde default for the skipped `GraphState::metadata` field: an empty map.
#[cfg(feature = "serde")]
fn default_metadata() -> Arc<RwLock<HashMap<String, StateValue>>> {
    Arc::default()
}

/// Serde default for the skipped `GraphState::execution_log` field: no entries.
#[cfg(feature = "serde")]
fn default_execution_log() -> Arc<RwLock<Vec<StateHistoryEntry>>> {
    Arc::default()
}