content_extractor_rl/
checkpoint.rs1use crate::Result;
7use serde::{Deserialize, Serialize};
8use std::path::{Path, PathBuf};
9use std::fs;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Checkpoint {
14 pub episode: usize,
15 pub step_count: usize,
16 pub avg_reward: f32,
17 pub avg_quality: f32,
18 pub best_quality: f32,
19 pub epsilon: f32,
20 pub timestamp: String,
21 pub model_path: PathBuf,
22 pub optimizer_state: Option<PathBuf>,
23}
24
25impl Checkpoint {
26 pub fn new(
28 episode: usize,
29 step_count: usize,
30 avg_reward: f32,
31 avg_quality: f32,
32 best_quality: f32,
33 epsilon: f32,
34 model_path: PathBuf,
35 ) -> Self {
36 Self {
37 episode,
38 step_count,
39 avg_reward,
40 avg_quality,
41 best_quality,
42 epsilon,
43 timestamp: chrono::Utc::now().to_rfc3339(),
44 model_path,
45 optimizer_state: None,
46 }
47 }
48
49 pub fn save(&self, path: &Path) -> Result<()> {
51 let json = serde_json::to_string_pretty(self)?;
52 fs::write(path, json)?;
53 Ok(())
54 }
55
56 pub fn load(path: &Path) -> Result<Self> {
58 let json = fs::read_to_string(path)?;
59 let checkpoint = serde_json::from_str(&json)?;
60 Ok(checkpoint)
61 }
62}
63
/// Stores checkpoints in a single directory and prunes the oldest ones.
pub struct CheckpointManager {
    /// Directory that holds the checkpoint JSON files.
    checkpoints_dir: PathBuf,
    /// Upper bound on the number of checkpoints kept after a save.
    max_checkpoints: usize,
}
69
70impl CheckpointManager {
71 pub fn new(checkpoints_dir: PathBuf, max_checkpoints: usize) -> Result<Self> {
73 fs::create_dir_all(&checkpoints_dir)?;
74 Ok(Self {
75 checkpoints_dir,
76 max_checkpoints,
77 })
78 }
79
80 pub fn save_checkpoint(&self, checkpoint: &Checkpoint) -> Result<()> {
82 let checkpoint_file = self.checkpoints_dir.join(format!(
83 "checkpoint_ep{}.json",
84 checkpoint.episode
85 ));
86
87 checkpoint.save(&checkpoint_file)?;
88
89 self.cleanup_old_checkpoints()?;
91
92 Ok(())
93 }
94
95 pub fn load_latest(&self) -> Result<Option<Checkpoint>> {
97 let mut checkpoints = self.list_checkpoints()?;
98
99 if checkpoints.is_empty() {
100 return Ok(None);
101 }
102
103 checkpoints.sort_by_key(|c| c.episode);
104 let latest = checkpoints.last().unwrap();
105
106 Ok(Some(latest.clone()))
107 }
108
109 pub fn load_best(&self) -> Result<Option<Checkpoint>> {
111 let checkpoints = self.list_checkpoints()?;
112
113 if checkpoints.is_empty() {
114 return Ok(None);
115 }
116
117 let best = checkpoints.iter()
118 .max_by(|a, b| a.best_quality.partial_cmp(&b.best_quality).unwrap())
119 .cloned();
120
121 Ok(best)
122 }
123
124 pub fn list_checkpoints(&self) -> Result<Vec<Checkpoint>> {
126 let mut checkpoints = Vec::new();
127
128 for entry in fs::read_dir(&self.checkpoints_dir)? {
129 let entry = entry?;
130 let path = entry.path();
131
132 if path.extension().and_then(|s| s.to_str()) == Some("json") {
133 if let Ok(checkpoint) = Checkpoint::load(&path) {
134 checkpoints.push(checkpoint);
135 }
136 }
137 }
138
139 Ok(checkpoints)
140 }
141
142 fn cleanup_old_checkpoints(&self) -> Result<()> {
144 let mut checkpoints = self.list_checkpoints()?;
145
146 if checkpoints.len() <= self.max_checkpoints {
147 return Ok(());
148 }
149
150 checkpoints.sort_by_key(|c| c.episode);
152
153 let to_remove = checkpoints.len() - self.max_checkpoints;
155
156 for checkpoint in checkpoints.iter().take(to_remove) {
157 let checkpoint_file = self.checkpoints_dir.join(format!(
158 "checkpoint_ep{}.json",
159 checkpoint.episode
160 ));
161
162 if checkpoint_file.exists() {
163 fs::remove_file(checkpoint_file)?;
164 }
165
166 if checkpoint.model_path.exists() {
168 fs::remove_file(&checkpoint.model_path)?;
169 }
170 }
171
172 Ok(())
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a checkpoint with fixed metrics; only the episode, step count
    /// and model file name vary between tests.
    fn sample(episode: usize, step_count: usize, model: &str) -> Checkpoint {
        Checkpoint::new(episode, step_count, 0.5, 0.7, 0.8, 0.1, PathBuf::from(model))
    }

    /// A checkpoint written to disk can be read back with the same counters.
    #[test]
    fn test_checkpoint_save_load() {
        let scratch = TempDir::new().unwrap();
        let target = scratch.path().join("checkpoint.json");

        sample(100, 5000, "model.onnx").save(&target).unwrap();

        let restored = Checkpoint::load(&target).unwrap();
        assert_eq!(restored.episode, 100);
        assert_eq!(restored.step_count, 5000);
    }

    /// Saving more checkpoints than the configured limit never leaves more
    /// than the limit on disk.
    #[test]
    fn test_checkpoint_manager() {
        let scratch = TempDir::new().unwrap();
        let manager = CheckpointManager::new(scratch.path().to_path_buf(), 3).unwrap();

        for idx in 0..5 {
            let model_name = format!("model_{}.onnx", idx);
            let snapshot = sample(idx * 100, idx * 1000, &model_name);
            manager.save_checkpoint(&snapshot).unwrap();
        }

        let remaining = manager.list_checkpoints().unwrap();
        assert!(remaining.len() <= 3);
    }
}