Skip to main content

agent_sdk_store_file/
checkpoint.rs

1use std::path::PathBuf;
2
3use agent_sdk_core::{
4    AgentError, CheckpointPrunePolicy, CheckpointPruneReport, CheckpointSaveOutcome,
5    CheckpointStore, JournalCursor, RunCheckpoint, RunId,
6};
7
8use crate::util::{read_json, remove_file_if_exists, root_join, safe_segment, write_json};
9
/// Filesystem-backed checkpoint accelerator store.
///
/// Every path this store reads or writes is derived from a single root
/// directory supplied at construction time.
#[derive(Debug, Clone)]
pub struct FileCheckpointStore {
    /// Base directory that all checkpoint paths are joined onto.
    root: PathBuf,
}
15
16impl FileCheckpointStore {
17    /// Creates a checkpoint store rooted under the provided directory.
18    pub fn new(root: impl Into<PathBuf>) -> Self {
19        Self { root: root.into() }
20    }
21
22    fn checkpoint_path(&self, checkpoint: &RunCheckpoint) -> PathBuf {
23        root_join(
24            &self.root,
25            &[
26                "runs".to_string(),
27                safe_segment(checkpoint.run_id.as_str()),
28                "checkpoints".to_string(),
29                format!(
30                    "{:020}-{}.json",
31                    checkpoint.covers_journal_seq,
32                    safe_segment(&checkpoint.checkpoint_id)
33                ),
34            ],
35        )
36    }
37
38    fn checkpoint_dir(&self, run_id: &RunId) -> PathBuf {
39        root_join(
40            &self.root,
41            &[
42                "runs".to_string(),
43                safe_segment(run_id.as_str()),
44                "checkpoints".to_string(),
45            ],
46        )
47    }
48
49    fn list(&self, run_id: &RunId) -> Result<Vec<(PathBuf, RunCheckpoint)>, AgentError> {
50        let dir = self.checkpoint_dir(run_id);
51        if !dir.exists() {
52            return Ok(Vec::new());
53        }
54        let mut entries = Vec::new();
55        for entry in std::fs::read_dir(dir).map_err(|error| {
56            AgentError::new(
57                agent_sdk_core::AgentErrorKind::RecoveryRepairNeeded,
58                agent_sdk_core::RetryClassification::Retryable,
59                error.to_string(),
60            )
61        })? {
62            let path = entry.map_err(|error| {
63                AgentError::new(
64                    agent_sdk_core::AgentErrorKind::RecoveryRepairNeeded,
65                    agent_sdk_core::RetryClassification::Retryable,
66                    error.to_string(),
67                )
68            })?;
69            let path = path.path();
70            if path.extension().and_then(|ext| ext.to_str()) != Some("json") {
71                continue;
72            }
73            if let Some(checkpoint) = read_json::<RunCheckpoint>(&path)? {
74                entries.push((path, checkpoint));
75            }
76        }
77        entries.sort_by_key(|(_, checkpoint)| {
78            (
79                checkpoint.covers_journal_seq,
80                checkpoint.checkpoint_seq,
81                checkpoint.created_at_millis,
82            )
83        });
84        Ok(entries)
85    }
86}
87
88impl CheckpointStore for FileCheckpointStore {
89    fn save(
90        &self,
91        checkpoint: RunCheckpoint,
92        latest_journal_seq: u64,
93    ) -> Result<CheckpointSaveOutcome, AgentError> {
94        checkpoint.validate_against_latest_seq(latest_journal_seq)?;
95        let checkpoint_ref = checkpoint.checkpoint_id.clone();
96        let covers_journal_seq = checkpoint.covers_journal_seq;
97        let terminal_checkpoint = checkpoint.pending_side_effects.is_empty()
98            && checkpoint.pending_approvals.is_empty()
99            && checkpoint.loop_state == "terminal";
100        write_json(&self.checkpoint_path(&checkpoint), &checkpoint)?;
101        Ok(CheckpointSaveOutcome {
102            checkpoint_ref,
103            covers_journal_seq,
104            terminal_checkpoint,
105        })
106    }
107
108    fn load_latest(&self, run_id: &RunId) -> Result<Option<RunCheckpoint>, AgentError> {
109        Ok(self
110            .list(run_id)?
111            .into_iter()
112            .map(|(_, checkpoint)| checkpoint)
113            .max_by_key(|checkpoint| {
114                (
115                    checkpoint.covers_journal_seq,
116                    checkpoint.checkpoint_seq,
117                    checkpoint.created_at_millis,
118                )
119            }))
120    }
121
122    fn load_at_or_before(
123        &self,
124        run_id: &RunId,
125        cursor: &JournalCursor,
126    ) -> Result<Option<RunCheckpoint>, AgentError> {
127        let cursor_seq = cursor
128            .as_str()
129            .strip_prefix("journal.")
130            .unwrap_or(cursor.as_str())
131            .parse::<u64>()
132            .unwrap_or(0);
133        Ok(self
134            .list(run_id)?
135            .into_iter()
136            .map(|(_, checkpoint)| checkpoint)
137            .filter(|checkpoint| checkpoint.covers_journal_seq <= cursor_seq)
138            .max_by_key(|checkpoint| {
139                (
140                    checkpoint.covers_journal_seq,
141                    checkpoint.checkpoint_seq,
142                    checkpoint.created_at_millis,
143                )
144            }))
145    }
146
147    fn prune(
148        &self,
149        run_id: &RunId,
150        policy: CheckpointPrunePolicy,
151    ) -> Result<CheckpointPruneReport, AgentError> {
152        let entries = self.list(run_id)?;
153        let terminal_to_preserve = policy.preserve_latest_terminal.then(|| {
154            entries
155                .iter()
156                .filter(|(_, checkpoint)| {
157                    checkpoint.pending_side_effects.is_empty()
158                        && checkpoint.pending_approvals.is_empty()
159                        && checkpoint.loop_state == "terminal"
160                })
161                .max_by_key(|(_, checkpoint)| {
162                    (
163                        checkpoint.covers_journal_seq,
164                        checkpoint.checkpoint_seq,
165                        checkpoint.created_at_millis,
166                    )
167                })
168                .map(|(_, checkpoint)| checkpoint.checkpoint_id.clone())
169        });
170        let terminal_to_preserve = terminal_to_preserve.flatten();
171
172        let mut pruned_count = 0;
173        let mut retained_count = 0;
174        for (path, checkpoint) in entries {
175            let preserve_terminal = terminal_to_preserve
176                .as_ref()
177                .is_some_and(|id| id == &checkpoint.checkpoint_id);
178            if checkpoint.covers_journal_seq < policy.prune_covered_before && !preserve_terminal {
179                remove_file_if_exists(&path)?;
180                pruned_count += 1;
181            } else {
182                retained_count += 1;
183            }
184        }
185
186        Ok(CheckpointPruneReport {
187            run_id: run_id.clone(),
188            pruned_count,
189            retained_count,
190            preserved_terminal_checkpoint: terminal_to_preserve,
191        })
192    }
193}