1use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use crate::delta::{DeltaOp, ReducerRegistry, StateDelta};
7use crate::ids::TraceId;
8use crate::state::State;
9
10#[derive(Debug, Clone)]
14pub enum CheckpointTrigger {
15 BarrierResolved,
16 ExecutionCompleted,
17 HumanDecision,
18 Explicit,
19 Adaptive(ExecutionMetadata),
20}
21
22impl PartialEq for CheckpointTrigger {
23 fn eq(&self, other: &Self) -> bool {
24 matches!(
25 (self, other),
26 (Self::BarrierResolved, Self::BarrierResolved)
27 | (Self::ExecutionCompleted, Self::ExecutionCompleted)
28 | (Self::HumanDecision, Self::HumanDecision)
29 | (Self::Explicit, Self::Explicit)
30 | (Self::Adaptive(_), Self::Adaptive(_))
31 )
32 }
33}
34
35#[derive(Debug, Clone)]
37pub struct CheckpointPolicy {
38 pub triggers: Vec<CheckpointTrigger>,
39}
40
41impl Default for CheckpointPolicy {
42 fn default() -> Self {
43 Self::conservative()
44 }
45}
46
47impl CheckpointPolicy {
48 pub fn conservative() -> Self {
49 Self {
50 triggers: vec![
51 CheckpointTrigger::BarrierResolved,
52 CheckpointTrigger::ExecutionCompleted,
53 CheckpointTrigger::HumanDecision,
54 ],
55 }
56 }
57
58 pub fn minimal() -> Self {
59 Self {
60 triggers: vec![
61 CheckpointTrigger::BarrierResolved,
62 CheckpointTrigger::ExecutionCompleted,
63 ],
64 }
65 }
66
67 pub fn manual() -> Self {
68 Self {
69 triggers: vec![CheckpointTrigger::Explicit],
70 }
71 }
72
73 pub fn should_checkpoint_on_barrier(&self) -> bool {
74 self.triggers.contains(&CheckpointTrigger::BarrierResolved)
75 }
76
77 pub fn should_checkpoint_on_completion(&self) -> bool {
78 self.triggers
79 .contains(&CheckpointTrigger::ExecutionCompleted)
80 }
81
82 pub fn should_checkpoint_on_human_decision(&self) -> bool {
83 self.triggers.contains(&CheckpointTrigger::HumanDecision)
84 }
85
86 pub fn should_checkpoint_on_explicit(&self) -> bool {
87 self.triggers.contains(&CheckpointTrigger::Explicit)
88 }
89
90 pub fn has_adaptive_trigger(&self) -> bool {
91 self.triggers
92 .iter()
93 .any(|t| matches!(t, CheckpointTrigger::Adaptive(_)))
94 }
95}
96
/// Measurements describing one execution, consumed by [`CheckpointScore`].
#[derive(Debug, Clone, Default)]
pub struct ExecutionMetadata {
    /// Duration in milliseconds.
    pub duration_ms: u64,
    /// Token cost. NOTE(review): the unit is not established in this file.
    pub token_cost: f64,
    /// Whether the execution had side effects.
    pub has_side_effects: bool,
}

impl ExecutionMetadata {
    /// Preset: 2 ms, zero token cost, no side effects.
    pub fn lightweight() -> Self {
        Self {
            duration_ms: 2,
            ..Self::default()
        }
    }

    /// Preset: 90 000 ms, token cost 0.01, no side effects.
    pub fn heavy() -> Self {
        Self {
            duration_ms: 90_000,
            token_cost: 0.01,
            ..Self::default()
        }
    }

    /// Preset: zero duration and token cost, with the side-effect flag set.
    pub fn with_side_effects() -> Self {
        Self {
            has_side_effects: true,
            ..Self::default()
        }
    }
}
132
133#[derive(Debug, Clone)]
135pub struct CheckpointScore {
136 pub duration_weight: f64,
137 pub token_weight: f64,
138 pub side_effect_weight: f64,
139 pub threshold: f64,
140}
141
142impl Default for CheckpointScore {
143 fn default() -> Self {
144 Self {
145 duration_weight: 1.0,
146 token_weight: 1000.0,
147 side_effect_weight: 10000.0,
148 threshold: 100.0,
149 }
150 }
151}
152
153impl CheckpointScore {
154 pub fn calculate(&self, metadata: &ExecutionMetadata) -> f64 {
155 let mut score = self.duration_weight * metadata.duration_ms as f64;
156 score += self.token_weight * metadata.token_cost;
157 if metadata.has_side_effects {
158 score += self.side_effect_weight;
159 }
160 score
161 }
162
163 pub fn should_checkpoint(&self, metadata: &ExecutionMetadata) -> bool {
164 self.calculate(metadata) >= self.threshold
165 }
166}
167
168#[derive(Debug, thiserror::Error)]
172pub enum CheckpointStoreError {
173 #[error("storage error: {0}")]
174 Storage(String),
175 #[error("checkpoint not found: {0}")]
176 NotFound(CheckpointId),
177 #[error("corrupted checkpoint: {0}")]
178 Corrupted(String),
179}
180
181#[async_trait::async_trait]
185pub trait CheckpointStore: Send + Sync {
186 async fn save(&self, checkpoint: &Checkpoint) -> Result<(), CheckpointStoreError>;
187 async fn load(&self, id: &CheckpointId) -> Result<Option<Checkpoint>, CheckpointStoreError>;
188 async fn load_latest(
189 &self,
190 trace_id: &TraceId,
191 ) -> Result<Option<Checkpoint>, CheckpointStoreError>;
192 async fn list(&self, trace_id: &TraceId) -> Result<Vec<CheckpointId>, CheckpointStoreError>;
193 async fn delete(&self, id: &CheckpointId) -> Result<bool, CheckpointStoreError>;
194 async fn prune(&self, trace_id: &TraceId, keep: usize) -> Result<usize, CheckpointStoreError>;
195}
196
197#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
201pub struct CheckpointId(pub uuid::Uuid);
202
203impl std::fmt::Display for CheckpointId {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 write!(f, "{}", self.0)
206 }
207}
208
209#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
211pub struct NodeId(pub String);
212
213impl std::fmt::Display for NodeId {
214 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215 write!(f, "{}", self.0)
216 }
217}
218
219#[derive(Debug, Clone, Serialize, Deserialize)]
221pub struct Checkpoint {
222 pub checkpoint_id: CheckpointId,
223 pub parent_trace_id: TraceId,
224 pub graph_hash: String,
225 pub current_node: NodeId,
226 pub state: State,
227 pub created_at: String,
228 pub snapshot: Option<StateSnapshot>,
229}
230
231impl Checkpoint {
232 pub fn new(
233 parent_trace_id: TraceId,
234 graph_hash: impl Into<String>,
235 current_node: impl Into<String>,
236 state: State,
237 ) -> Self {
238 Self {
239 checkpoint_id: CheckpointId(uuid::Uuid::new_v4()),
240 parent_trace_id,
241 graph_hash: graph_hash.into(),
242 current_node: NodeId(current_node.into()),
243 state,
244 created_at: chrono_like_timestamp(),
245 snapshot: None,
246 }
247 }
248
249 pub fn with_snapshot(
250 parent_trace_id: TraceId,
251 graph_hash: impl Into<String>,
252 current_node: impl Into<String>,
253 current_state: State,
254 base_snapshot: State,
255 recent_deltas: Vec<StateDelta>,
256 ) -> Self {
257 Self {
258 checkpoint_id: CheckpointId(uuid::Uuid::new_v4()),
259 parent_trace_id,
260 graph_hash: graph_hash.into(),
261 current_node: NodeId(current_node.into()),
262 state: current_state,
263 created_at: chrono_like_timestamp(),
264 snapshot: Some(StateSnapshot {
265 base_snapshot,
266 recent_deltas,
267 }),
268 }
269 }
270
271 pub fn restore_state(
272 &self,
273 registry: &ReducerRegistry,
274 ) -> Result<State, crate::state::StateError> {
275 if let Some(snapshot) = &self.snapshot {
276 snapshot.restore(registry)
277 } else {
278 Ok(self.state.clone())
279 }
280 }
281
282 pub fn restore_state_simple(&self) -> State {
283 if let Some(snapshot) = &self.snapshot {
284 snapshot.restore_simple()
285 } else {
286 self.state.clone()
287 }
288 }
289}
290
291#[derive(Debug, Clone, Serialize, Deserialize)]
293pub struct StateSnapshot {
294 pub base_snapshot: State,
295 pub recent_deltas: Vec<StateDelta>,
296}
297
298impl StateSnapshot {
299 pub fn restore(&self, registry: &ReducerRegistry) -> Result<State, crate::state::StateError> {
300 let mut state = self.base_snapshot.clone();
301 registry.merge_deltas(&mut state, &self.recent_deltas)?;
302 Ok(state)
303 }
304
305 pub fn restore_simple(&self) -> State {
306 let mut state = self.base_snapshot.clone();
307 for delta in &self.recent_deltas {
308 match delta.op {
309 DeltaOp::Put => {
310 state.insert(delta.key.to_string(), delta.value.clone());
311 }
312 DeltaOp::Delete => {
313 state.remove(delta.key.as_ref());
314 }
315 }
316 }
317 state
318 }
319
320 pub fn base_size_bytes(&self) -> usize {
321 serde_json::to_vec(&self.base_snapshot)
322 .map(|v| v.len())
323 .unwrap_or(0)
324 }
325
326 pub fn deltas_size_bytes(&self) -> usize {
327 serde_json::to_vec(&self.recent_deltas)
328 .map(|v| v.len())
329 .unwrap_or(0)
330 }
331
332 pub fn total_size_bytes(&self) -> usize {
333 self.base_size_bytes() + self.deltas_size_bytes()
334 }
335
336 pub fn compact(&mut self, threshold: usize) {
337 if self.recent_deltas.len() > threshold {
338 let restored = self.restore_simple();
339 self.base_snapshot = restored;
340 self.recent_deltas.clear();
341 }
342 }
343}
344
345#[derive(Debug, Clone, Default)]
349pub struct IncrementalSnapshotState {
350 pub base_state: Option<State>,
351 pub pending_deltas: Vec<StateDelta>,
352 pub compact_threshold: usize,
353}
354
355impl IncrementalSnapshotState {
356 pub fn new(compact_threshold: usize) -> Self {
357 Self {
358 base_state: None,
359 pending_deltas: Vec::new(),
360 compact_threshold,
361 }
362 }
363
364 pub fn record_delta(&mut self, delta: StateDelta) {
365 self.pending_deltas.push(delta);
366 }
367
368 pub fn record_deltas(&mut self, deltas: Vec<StateDelta>) {
369 self.pending_deltas.extend(deltas);
370 }
371
372 pub fn snapshot(&mut self, current_state: &State) -> (Option<State>, Vec<StateDelta>, State) {
373 let base = self.base_state.clone();
374 let deltas = std::mem::take(&mut self.pending_deltas);
375
376 if base.is_some() && deltas.len() > self.compact_threshold {
377 self.base_state = Some(current_state.clone());
378 self.pending_deltas.clear();
379 return (None, Vec::new(), current_state.clone());
380 }
381
382 (base, deltas, current_state.clone())
383 }
384
385 pub fn from_checkpoint(checkpoint: &Checkpoint) -> Self {
386 if let Some(snapshot) = &checkpoint.snapshot {
387 Self {
388 base_state: Some(snapshot.base_snapshot.clone()),
389 pending_deltas: snapshot.recent_deltas.clone(),
390 compact_threshold: 20,
391 }
392 } else {
393 Self {
394 base_state: Some(checkpoint.state.clone()),
395 pending_deltas: Vec::new(),
396 compact_threshold: 20,
397 }
398 }
399 }
400
401 pub fn clear_pending(&mut self) {
402 self.pending_deltas.clear();
403 }
404}
405
/// Mode selector related to graph-hash handling.
///
/// NOTE(review): nothing in this file consumes this enum. Presumably `Strict`
/// rejects a checkpoint whose `graph_hash` differs from the current graph and
/// `Force` proceeds anyway — confirm against the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GraphHashMode {
    /// Strict mode.
    Strict,
    /// Force mode.
    Force,
}
412
413#[derive(Debug, Clone, Serialize, Deserialize)]
417pub struct ExecutionEntry {
418 pub step: usize,
419 pub node_name: String,
420 pub start_time: String,
421 pub end_time: String,
422 pub success: bool,
423 pub error: Option<String>,
424}
425
426#[derive(Debug, Clone, Serialize, Deserialize)]
428pub struct BarrierDecisionRecord {
429 pub barrier_id: String,
430 pub node_id: String,
431 pub decision: Value,
432 pub decided_at: String,
433}
434
435#[derive(Debug, Clone, Serialize, Deserialize)]
437pub struct ExecutionTrace {
438 pub trace_id: TraceId,
439 pub initial_state: State,
440 pub entries: Vec<ExecutionEntry>,
441 pub deltas: Vec<StateDelta>,
442 pub barrier_decisions: Vec<BarrierDecisionRecord>,
443}
444
445impl ExecutionTrace {
446 pub fn new(initial_state: State) -> Self {
447 Self {
448 trace_id: TraceId::default(),
449 initial_state,
450 entries: Vec::new(),
451 deltas: Vec::new(),
452 barrier_decisions: Vec::new(),
453 }
454 }
455}
456
457#[derive(Debug, Clone, Serialize, Deserialize)]
459pub struct GraphResult {
460 pub trace_id: TraceId,
461 pub state: State,
462 pub execution_log: Vec<ExecutionEntry>,
463 pub duration_ms: u128,
464}
465
/// Returns the current UTC time formatted as `YYYY-MM-DDTHH:MM:SSZ`.
///
/// The date is computed with exact proleptic-Gregorian arithmetic (leap
/// years and real month lengths), so the result is a valid calendar date.
/// If the system clock reports a time before the Unix epoch, the epoch
/// itself (`1970-01-01T00:00:00Z`) is returned.
fn chrono_like_timestamp() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};

    let secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    format_unix_timestamp(secs)
}

/// Formats `secs` seconds since the Unix epoch as `YYYY-MM-DDTHH:MM:SSZ` (UTC).
fn format_unix_timestamp(secs: u64) -> String {
    const SECS_PER_DAY: u64 = 86_400;

    let (year, month, day) = civil_from_days(secs / SECS_PER_DAY);
    let secs_of_day = secs % SECS_PER_DAY;
    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
        year,
        month,
        day,
        secs_of_day / 3600,
        secs_of_day % 3600 / 60,
        secs_of_day % 60
    )
}

/// Converts days since 1970-01-01 into a `(year, month, day)` Gregorian date.
///
/// Works on 400-year eras of 146 097 days with the year shifted to start on
/// 1 March, which puts the leap day at the end of the shifted year and lets
/// month and day be derived with pure integer arithmetic.
fn civil_from_days(days: u64) -> (u64, u64, u64) {
    // 719 468 = days from 0000-03-01 to 1970-01-01.
    let shifted = days + 719_468;
    let era = shifted / 146_097;
    let day_of_era = shifted % 146_097; // 0..=146_096
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365; // 0..=399
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100); // 0..=365
    let month_index = (5 * day_of_year + 2) / 153; // 0 = March .. 11 = February
    let day = day_of_year - (153 * month_index + 2) / 5 + 1; // 1..=31
    let month = if month_index < 10 {
        month_index + 3
    } else {
        month_index - 9
    };
    // January and February belong to the next civil year of the shifted calendar.
    let year = year_of_era + era * 400 + u64::from(month <= 2);
    (year, month, day)
}