1use serde::{de::DeserializeOwned, Serialize};
4use std::fmt::Debug;
5
/// Marker trait for values usable as the state of a graph run.
///
/// A state must be clonable, shareable across threads (`Send + Sync`),
/// printable with `Debug`, and must not borrow non-`'static` data. There is
/// nothing to implement by hand: the blanket impl below covers every type
/// that satisfies these bounds.
pub trait GraphState: Clone + Send + Sync + Debug + 'static {}

// Blanket impl: any type meeting the bounds is automatically a `GraphState`.
impl<T: Clone + Send + Sync + Debug + 'static> GraphState for T {}
14
/// Execution context for a single graph run.
///
/// Bundles the caller's state and dependencies with the bookkeeping used to
/// enforce a step budget (`step` compared against `max_steps`).
#[derive(Clone, Debug)]
pub struct GraphRunContext<State, Deps = ()> {
    /// State carried through the run.
    pub state: State,
    /// Dependencies supplied by the caller; `()` when none are needed.
    pub deps: Deps,
    /// Number of steps counted so far.
    pub step: u32,
    /// Identifier of this run.
    pub run_id: String,
    /// Step budget; the budget counts as exhausted once `step >= max_steps`.
    pub max_steps: u32,
}
29
30impl<State, Deps> GraphRunContext<State, Deps> {
31 pub fn new(state: State, deps: Deps, run_id: impl Into<String>) -> Self {
33 Self {
34 state,
35 deps,
36 step: 0,
37 run_id: run_id.into(),
38 max_steps: 100,
39 }
40 }
41
42 pub fn with_max_steps(mut self, max: u32) -> Self {
44 self.max_steps = max;
45 self
46 }
47
48 pub fn increment_step(&mut self) {
50 self.step += 1;
51 }
52
53 pub fn is_max_steps_reached(&self) -> bool {
55 self.step >= self.max_steps
56 }
57}
58
59impl<State: Default, Deps: Default> Default for GraphRunContext<State, Deps> {
60 fn default() -> Self {
61 Self {
62 state: State::default(),
63 deps: Deps::default(),
64 step: 0,
65 run_id: generate_run_id(),
66 max_steps: 100,
67 }
68 }
69}
70
/// Value returned from a graph run.
///
/// Holds the final result together with the state, step count, history and
/// run id captured at the end of the run.
#[derive(Clone, Debug)]
pub struct GraphRunResult<State, End = ()> {
    /// Final value produced by the run; `()` when the run yields nothing.
    pub result: End,
    /// State as it stood when the run finished.
    pub state: State,
    /// Number of steps the run executed.
    pub steps: u32,
    /// Ordered record of the run (presumably visited node names — TODO confirm).
    pub history: Vec<String>,
    /// Identifier of the run that produced this result.
    pub run_id: String,
}
85
86impl<State, End> GraphRunResult<State, End> {
87 pub fn new(result: End, state: State, steps: u32, run_id: impl Into<String>) -> Self {
89 Self {
90 result,
91 state,
92 steps,
93 history: Vec::new(),
94 run_id: run_id.into(),
95 }
96 }
97
98 pub fn with_history(mut self, history: Vec<String>) -> Self {
100 self.history = history;
101 self
102 }
103}
104
/// Generates an identifier for a graph run, formatted as `run-<hex>`.
///
/// The hex part comes from the wall-clock time in nanoseconds since the Unix
/// epoch. A process-wide atomic makes each call return a value strictly
/// greater than the previous one, so ids stay distinct in three cases:
/// two calls within the clock's resolution, concurrent calls, and a clock
/// that steps backwards. Uniqueness is only guaranteed within one process.
pub fn generate_run_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::SystemTime;

    // Last value handed out by this function in the current process.
    static LAST_ID: AtomicU64 = AtomicU64::new(0);

    // A clock before the epoch falls back to 0. u64 nanoseconds cover dates
    // until roughly the year 2554; past that the value saturates.
    let now = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .map(|elapsed| u64::try_from(elapsed.as_nanos()).unwrap_or(u64::MAX))
        .unwrap_or_default();

    // Next id: the current time, bumped past the previous id whenever the
    // clock has not advanced since the last call.
    let next = |prev: u64| now.max(prev.saturating_add(1));

    // Relaxed suffices: uniqueness depends only on the modification order of
    // this single atomic, which read-modify-write operations always respect.
    // The closure always returns `Some`, so the `Err` arm is never taken.
    let prev = LAST_ID
        .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |prev| Some(next(prev)))
        .unwrap_or_else(|prev| prev);

    format!("run-{:x}", next(prev))
}
114
115pub trait PersistableState: GraphState + Serialize + DeserializeOwned {}
117
118impl<T> PersistableState for T where T: GraphState + Serialize + DeserializeOwned {}
119
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Default)]
    struct TestState {
        value: i32,
    }

    #[test]
    fn test_graph_state_trait() {
        // Compile-time check that the blanket impl applies to a user type.
        fn assert_graph_state<S: GraphState>(_: &S) {}

        let state = TestState { value: 42 };
        assert_graph_state(&state);
        let cloned = state.clone();
        assert_eq!(cloned.value, 42);
    }

    #[test]
    fn test_run_context() {
        let mut ctx = GraphRunContext::new(TestState { value: 0 }, (), "test-run");

        assert_eq!(ctx.step, 0);
        assert_eq!(ctx.run_id, "test-run");
        ctx.increment_step();
        assert_eq!(ctx.step, 1);
    }

    #[test]
    fn test_max_steps() {
        let mut ctx = GraphRunContext::new(TestState::default(), (), "test").with_max_steps(5);
        assert_eq!(ctx.max_steps, 5);

        // The budget must not be reported as exhausted before the fifth step.
        for _ in 0..4 {
            ctx.increment_step();
            assert!(!ctx.is_max_steps_reached());
        }
        ctx.increment_step();
        assert!(ctx.is_max_steps_reached());
    }

    #[test]
    fn test_default_context() {
        let ctx: GraphRunContext<TestState> = GraphRunContext::default();
        assert_eq!(ctx.state.value, 0);
        assert_eq!(ctx.step, 0);
        assert_eq!(ctx.max_steps, 100);
        assert!(ctx.run_id.starts_with("run-"));
    }

    #[test]
    fn test_run_result() {
        let result = GraphRunResult::new("done", TestState { value: 7 }, 3, "test-run");
        assert_eq!(result.result, "done");
        assert_eq!(result.state.value, 7);
        assert_eq!(result.steps, 3);
        assert!(result.history.is_empty());
        assert_eq!(result.run_id, "test-run");

        let result = result.with_history(vec!["a".to_string(), "b".to_string()]);
        assert_eq!(result.history, ["a", "b"]);
    }

    #[test]
    fn test_generate_run_id() {
        let id1 = generate_run_id();
        let id2 = generate_run_id();
        for id in [&id1, &id2] {
            let digits = id
                .strip_prefix("run-")
                .expect("run id must start with `run-`");
            assert!(!digits.is_empty());
            assert!(digits.chars().all(|c| c.is_ascii_hexdigit()));
        }
    }
}