Skip to main content

lellm_graph/
checkpoint.rs

1//! Checkpoint + ExecutionTrace — 从 lellm-runtime 合并。
2
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5
6use crate::delta::{DeltaOp, ReducerRegistry, StateDelta};
7use crate::ids::TraceId;
8use crate::state::State;
9
10// ─── CheckpointPolicy ──────────────────────────────────────────
11
12/// Checkpoint 触发时机。
13#[derive(Debug, Clone)]
14pub enum CheckpointTrigger {
15    BarrierResolved,
16    ExecutionCompleted,
17    HumanDecision,
18    Explicit,
19    Adaptive(ExecutionMetadata),
20}
21
22impl PartialEq for CheckpointTrigger {
23    fn eq(&self, other: &Self) -> bool {
24        matches!(
25            (self, other),
26            (Self::BarrierResolved, Self::BarrierResolved)
27                | (Self::ExecutionCompleted, Self::ExecutionCompleted)
28                | (Self::HumanDecision, Self::HumanDecision)
29                | (Self::Explicit, Self::Explicit)
30                | (Self::Adaptive(_), Self::Adaptive(_))
31        )
32    }
33}
34
35/// Checkpoint 策略。
36#[derive(Debug, Clone)]
37pub struct CheckpointPolicy {
38    pub triggers: Vec<CheckpointTrigger>,
39}
40
41impl Default for CheckpointPolicy {
42    fn default() -> Self {
43        Self::conservative()
44    }
45}
46
47impl CheckpointPolicy {
48    pub fn conservative() -> Self {
49        Self {
50            triggers: vec![
51                CheckpointTrigger::BarrierResolved,
52                CheckpointTrigger::ExecutionCompleted,
53                CheckpointTrigger::HumanDecision,
54            ],
55        }
56    }
57
58    pub fn minimal() -> Self {
59        Self {
60            triggers: vec![
61                CheckpointTrigger::BarrierResolved,
62                CheckpointTrigger::ExecutionCompleted,
63            ],
64        }
65    }
66
67    pub fn manual() -> Self {
68        Self {
69            triggers: vec![CheckpointTrigger::Explicit],
70        }
71    }
72
73    pub fn should_checkpoint_on_barrier(&self) -> bool {
74        self.triggers.contains(&CheckpointTrigger::BarrierResolved)
75    }
76
77    pub fn should_checkpoint_on_completion(&self) -> bool {
78        self.triggers
79            .contains(&CheckpointTrigger::ExecutionCompleted)
80    }
81
82    pub fn should_checkpoint_on_human_decision(&self) -> bool {
83        self.triggers.contains(&CheckpointTrigger::HumanDecision)
84    }
85
86    pub fn should_checkpoint_on_explicit(&self) -> bool {
87        self.triggers.contains(&CheckpointTrigger::Explicit)
88    }
89
90    pub fn has_adaptive_trigger(&self) -> bool {
91        self.triggers
92            .iter()
93            .any(|t| matches!(t, CheckpointTrigger::Adaptive(_)))
94    }
95}
96
97// ─── ExecutionMetadata ────────────────────────────────────────
98
/// Execution metadata of a node, used for adaptive checkpoint decisions.
#[derive(Clone, Debug, Default)]
pub struct ExecutionMetadata {
    /// Execution duration in milliseconds.
    pub duration_ms: u64,
    /// Token cost attributed to the node (units not defined here).
    pub token_cost: f64,
    /// Whether the node had side effects.
    pub has_side_effects: bool,
}

impl ExecutionMetadata {
    /// Preset: 2 ms, zero token cost, no side effects.
    pub fn lightweight() -> Self {
        ExecutionMetadata {
            duration_ms: 2,
            ..ExecutionMetadata::default()
        }
    }

    /// Preset: 90 000 ms, token cost 0.01, no side effects.
    pub fn heavy() -> Self {
        ExecutionMetadata {
            duration_ms: 90_000,
            token_cost: 0.01,
            ..ExecutionMetadata::default()
        }
    }

    /// Preset: zero duration and token cost, but with side effects.
    pub fn with_side_effects() -> Self {
        ExecutionMetadata {
            has_side_effects: true,
            ..ExecutionMetadata::default()
        }
    }
}

/// Weighted scoring used to decide whether an adaptive checkpoint is taken.
///
/// `score = duration_weight * duration_ms + token_weight * token_cost`,
/// plus `side_effect_weight` when the node had side effects.
#[derive(Clone, Debug)]
pub struct CheckpointScore {
    /// Weight per millisecond of duration.
    pub duration_weight: f64,
    /// Weight per unit of token cost.
    pub token_weight: f64,
    /// Flat bonus added when the node had side effects.
    pub side_effect_weight: f64,
    /// Inclusive minimum score at which a checkpoint is taken.
    pub threshold: f64,
}

impl Default for CheckpointScore {
    /// Weights 1.0 / 1000.0 / 10000.0 with a threshold of 100.0.
    fn default() -> Self {
        CheckpointScore {
            duration_weight: 1.0,
            token_weight: 1000.0,
            side_effect_weight: 10000.0,
            threshold: 100.0,
        }
    }
}

impl CheckpointScore {
    /// Computes the weighted score for `metadata`.
    pub fn calculate(&self, metadata: &ExecutionMetadata) -> f64 {
        // Same evaluation order as a running sum: duration term, then token term,
        // then (optionally) the side-effect bonus.
        let weighted = self.duration_weight * metadata.duration_ms as f64
            + self.token_weight * metadata.token_cost;
        if metadata.has_side_effects {
            weighted + self.side_effect_weight
        } else {
            weighted
        }
    }

    /// `true` when the score reaches the threshold (inclusive). A NaN score
    /// never triggers a checkpoint.
    pub fn should_checkpoint(&self, metadata: &ExecutionMetadata) -> bool {
        self.threshold <= self.calculate(metadata)
    }
}
167
168// ─── CheckpointStoreError ──────────────────────────────────────
169
170/// Checkpoint 存储操作错误。
171#[derive(Debug, thiserror::Error)]
172pub enum CheckpointStoreError {
173    #[error("storage error: {0}")]
174    Storage(String),
175    #[error("checkpoint not found: {0}")]
176    NotFound(CheckpointId),
177    #[error("corrupted checkpoint: {0}")]
178    Corrupted(String),
179}
180
181// ─── CheckpointStore trait ─────────────────────────────────────
182
183/// Checkpoint 存储后端 SPI。
184#[async_trait::async_trait]
185pub trait CheckpointStore: Send + Sync {
186    async fn save(&self, checkpoint: &Checkpoint) -> Result<(), CheckpointStoreError>;
187    async fn load(&self, id: &CheckpointId) -> Result<Option<Checkpoint>, CheckpointStoreError>;
188    async fn load_latest(
189        &self,
190        trace_id: &TraceId,
191    ) -> Result<Option<Checkpoint>, CheckpointStoreError>;
192    async fn list(&self, trace_id: &TraceId) -> Result<Vec<CheckpointId>, CheckpointStoreError>;
193    async fn delete(&self, id: &CheckpointId) -> Result<bool, CheckpointStoreError>;
194    async fn prune(&self, trace_id: &TraceId, keep: usize) -> Result<usize, CheckpointStoreError>;
195}
196
197// ─── Checkpoint ─────────────────────────────────────────────────
198
199/// Checkpoint ID。
200#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
201pub struct CheckpointId(pub uuid::Uuid);
202
203impl std::fmt::Display for CheckpointId {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        write!(f, "{}", self.0)
206    }
207}
208
209/// 执行游标。
210#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
211pub struct NodeId(pub String);
212
213impl std::fmt::Display for NodeId {
214    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215        write!(f, "{}", self.0)
216    }
217}
218
219/// Checkpoint — Materialized State + Execution Cursor。
220#[derive(Debug, Clone, Serialize, Deserialize)]
221pub struct Checkpoint {
222    pub checkpoint_id: CheckpointId,
223    pub parent_trace_id: TraceId,
224    pub graph_hash: String,
225    pub current_node: NodeId,
226    pub state: State,
227    pub created_at: String,
228    pub snapshot: Option<StateSnapshot>,
229}
230
231impl Checkpoint {
232    pub fn new(
233        parent_trace_id: TraceId,
234        graph_hash: impl Into<String>,
235        current_node: impl Into<String>,
236        state: State,
237    ) -> Self {
238        Self {
239            checkpoint_id: CheckpointId(uuid::Uuid::new_v4()),
240            parent_trace_id,
241            graph_hash: graph_hash.into(),
242            current_node: NodeId(current_node.into()),
243            state,
244            created_at: chrono_like_timestamp(),
245            snapshot: None,
246        }
247    }
248
249    pub fn with_snapshot(
250        parent_trace_id: TraceId,
251        graph_hash: impl Into<String>,
252        current_node: impl Into<String>,
253        current_state: State,
254        base_snapshot: State,
255        recent_deltas: Vec<StateDelta>,
256    ) -> Self {
257        Self {
258            checkpoint_id: CheckpointId(uuid::Uuid::new_v4()),
259            parent_trace_id,
260            graph_hash: graph_hash.into(),
261            current_node: NodeId(current_node.into()),
262            state: current_state,
263            created_at: chrono_like_timestamp(),
264            snapshot: Some(StateSnapshot {
265                base_snapshot,
266                recent_deltas,
267            }),
268        }
269    }
270
271    pub fn restore_state(
272        &self,
273        registry: &ReducerRegistry,
274    ) -> Result<State, crate::state::StateError> {
275        if let Some(snapshot) = &self.snapshot {
276            snapshot.restore(registry)
277        } else {
278            Ok(self.state.clone())
279        }
280    }
281
282    pub fn restore_state_simple(&self) -> State {
283        if let Some(snapshot) = &self.snapshot {
284            snapshot.restore_simple()
285        } else {
286            self.state.clone()
287        }
288    }
289}
290
291/// 增量 State 快照。
292#[derive(Debug, Clone, Serialize, Deserialize)]
293pub struct StateSnapshot {
294    pub base_snapshot: State,
295    pub recent_deltas: Vec<StateDelta>,
296}
297
298impl StateSnapshot {
299    pub fn restore(&self, registry: &ReducerRegistry) -> Result<State, crate::state::StateError> {
300        let mut state = self.base_snapshot.clone();
301        registry.merge_deltas(&mut state, &self.recent_deltas)?;
302        Ok(state)
303    }
304
305    pub fn restore_simple(&self) -> State {
306        let mut state = self.base_snapshot.clone();
307        for delta in &self.recent_deltas {
308            match delta.op {
309                DeltaOp::Put => {
310                    state.insert(delta.key.to_string(), delta.value.clone());
311                }
312                DeltaOp::Delete => {
313                    state.remove(delta.key.as_ref());
314                }
315            }
316        }
317        state
318    }
319
320    pub fn base_size_bytes(&self) -> usize {
321        serde_json::to_vec(&self.base_snapshot)
322            .map(|v| v.len())
323            .unwrap_or(0)
324    }
325
326    pub fn deltas_size_bytes(&self) -> usize {
327        serde_json::to_vec(&self.recent_deltas)
328            .map(|v| v.len())
329            .unwrap_or(0)
330    }
331
332    pub fn total_size_bytes(&self) -> usize {
333        self.base_size_bytes() + self.deltas_size_bytes()
334    }
335
336    pub fn compact(&mut self, threshold: usize) {
337        if self.recent_deltas.len() > threshold {
338            let restored = self.restore_simple();
339            self.base_snapshot = restored;
340            self.recent_deltas.clear();
341        }
342    }
343}
344
345// ─── IncrementalSnapshotState ─────────────────────────────────
346
347/// 增量快照运行时状态。
348#[derive(Debug, Clone, Default)]
349pub struct IncrementalSnapshotState {
350    pub base_state: Option<State>,
351    pub pending_deltas: Vec<StateDelta>,
352    pub compact_threshold: usize,
353}
354
355impl IncrementalSnapshotState {
356    pub fn new(compact_threshold: usize) -> Self {
357        Self {
358            base_state: None,
359            pending_deltas: Vec::new(),
360            compact_threshold,
361        }
362    }
363
364    pub fn record_delta(&mut self, delta: StateDelta) {
365        self.pending_deltas.push(delta);
366    }
367
368    pub fn record_deltas(&mut self, deltas: Vec<StateDelta>) {
369        self.pending_deltas.extend(deltas);
370    }
371
372    pub fn snapshot(&mut self, current_state: &State) -> (Option<State>, Vec<StateDelta>, State) {
373        let base = self.base_state.clone();
374        let deltas = std::mem::take(&mut self.pending_deltas);
375
376        if base.is_some() && deltas.len() > self.compact_threshold {
377            self.base_state = Some(current_state.clone());
378            self.pending_deltas.clear();
379            return (None, Vec::new(), current_state.clone());
380        }
381
382        (base, deltas, current_state.clone())
383    }
384
385    pub fn from_checkpoint(checkpoint: &Checkpoint) -> Self {
386        if let Some(snapshot) = &checkpoint.snapshot {
387            Self {
388                base_state: Some(snapshot.base_snapshot.clone()),
389                pending_deltas: snapshot.recent_deltas.clone(),
390                compact_threshold: 20,
391            }
392        } else {
393            Self {
394                base_state: Some(checkpoint.state.clone()),
395                pending_deltas: Vec::new(),
396                compact_threshold: 20,
397            }
398        }
399    }
400
401    pub fn clear_pending(&mut self) {
402        self.pending_deltas.clear();
403    }
404}
405
/// Graph-change validation mode.
///
/// NOTE(review): presumably `Strict` rejects a resume whose graph hash differs
/// and `Force` proceeds anyway — confirm against the resume logic.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GraphHashMode {
    /// Strict validation.
    Strict,
    /// Forced (validation overridden).
    Force,
}
412
413// ─── ExecutionTrace ─────────────────────────────────────────────
414
415/// 节点执行记录。
416#[derive(Debug, Clone, Serialize, Deserialize)]
417pub struct ExecutionEntry {
418    pub step: usize,
419    pub node_name: String,
420    pub start_time: String,
421    pub end_time: String,
422    pub success: bool,
423    pub error: Option<String>,
424}
425
426/// Barrier 决策记录。
427#[derive(Debug, Clone, Serialize, Deserialize)]
428pub struct BarrierDecisionRecord {
429    pub barrier_id: String,
430    pub node_id: String,
431    pub decision: Value,
432    pub decided_at: String,
433}
434
435/// ExecutionTrace — Delta 历史。
436#[derive(Debug, Clone, Serialize, Deserialize)]
437pub struct ExecutionTrace {
438    pub trace_id: TraceId,
439    pub initial_state: State,
440    pub entries: Vec<ExecutionEntry>,
441    pub deltas: Vec<StateDelta>,
442    pub barrier_decisions: Vec<BarrierDecisionRecord>,
443}
444
445impl ExecutionTrace {
446    pub fn new(initial_state: State) -> Self {
447        Self {
448            trace_id: TraceId::default(),
449            initial_state,
450            entries: Vec::new(),
451            deltas: Vec::new(),
452            barrier_decisions: Vec::new(),
453        }
454    }
455}
456
457/// 图执行最终结果。
458#[derive(Debug, Clone, Serialize, Deserialize)]
459pub struct GraphResult {
460    pub trace_id: TraceId,
461    pub state: State,
462    pub execution_log: Vec<ExecutionEntry>,
463    pub duration_ms: u128,
464}
465
466// ─── Helpers ────────────────────────────────────────────────────
467
/// Returns the current UTC time formatted as `YYYY-MM-DDTHH:MM:SSZ`.
///
/// A system clock set before the Unix epoch is treated as the epoch itself.
fn chrono_like_timestamp() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    let secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    format_unix_timestamp(secs)
}

/// Formats `secs` seconds since the Unix epoch as a UTC
/// `YYYY-MM-DDTHH:MM:SSZ` string (proleptic Gregorian calendar).
fn format_unix_timestamp(secs: u64) -> String {
    let (year, month, day) = civil_from_days(secs / 86_400);
    let secs_of_day = secs % 86_400;
    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z",
        year,
        month,
        day,
        secs_of_day / 3_600,
        secs_of_day % 3_600 / 60,
        secs_of_day % 60
    )
}

/// Converts days since 1970-01-01 into a `(year, month, day)` civil date,
/// accounting for leap years (unsigned form of Howard Hinnant's
/// `civil_from_days`; valid for all non-negative day counts).
fn civil_from_days(days: u64) -> (u64, u64, u64) {
    // Shift the origin to 0000-03-01 so the leap day is the last day of each
    // "March-based" year, which makes month lengths a simple linear formula.
    let z = days + 719_468;
    let era = z / 146_097; // 400-year Gregorian cycles
    let doe = z - era * 146_097; // day of era, [0, 146096]
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // [0, 399]
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year, [0, 365]
    let mp = (5 * doy + 2) / 153; // March-based month, [0, 11]
    let day = doy - (153 * mp + 2) / 5 + 1; // [1, 31]
    let month = if mp < 10 { mp + 3 } else { mp - 9 }; // [1, 12]
    // January and February belong to the following calendar year.
    let year = era * 400 + yoe + u64::from(month <= 2);
    (year, month, day)
}