Skip to main content

content_extractor_rl/
checkpoint.rs

//! Model checkpoint management
// ============================================================================
// FILE: crates/content-extractor-rl/src/checkpoint.rs
// ============================================================================
5
6use crate::Result;
7use serde::{Deserialize, Serialize};
8use std::path::{Path, PathBuf};
9use std::fs;
10
11/// Model checkpoint metadata
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Checkpoint {
14    pub episode: usize,
15    pub step_count: usize,
16    pub avg_reward: f32,
17    pub avg_quality: f32,
18    pub best_quality: f32,
19    pub epsilon: f32,
20    pub timestamp: String,
21    pub model_path: PathBuf,
22    pub optimizer_state: Option<PathBuf>,
23}
24
25impl Checkpoint {
26    /// Create new checkpoint
27    pub fn new(
28        episode: usize,
29        step_count: usize,
30        avg_reward: f32,
31        avg_quality: f32,
32        best_quality: f32,
33        epsilon: f32,
34        model_path: PathBuf,
35    ) -> Self {
36        Self {
37            episode,
38            step_count,
39            avg_reward,
40            avg_quality,
41            best_quality,
42            epsilon,
43            timestamp: chrono::Utc::now().to_rfc3339(),
44            model_path,
45            optimizer_state: None,
46        }
47    }
48
49    /// Save checkpoint to JSON
50    pub fn save(&self, path: &Path) -> Result<()> {
51        let json = serde_json::to_string_pretty(self)?;
52        fs::write(path, json)?;
53        Ok(())
54    }
55
56    /// Load checkpoint from JSON
57    pub fn load(path: &Path) -> Result<Self> {
58        let json = fs::read_to_string(path)?;
59        let checkpoint = serde_json::from_str(&json)?;
60        Ok(checkpoint)
61    }
62}
63
/// Stores checkpoint metadata files in one directory and prunes old ones.
pub struct CheckpointManager {
    /// Directory that holds the `checkpoint_ep{N}.json` metadata files.
    checkpoints_dir: PathBuf,
    /// Upper bound on the number of checkpoints kept after each save.
    max_checkpoints: usize,
}
69
70impl CheckpointManager {
71    /// Create new checkpoint manager
72    pub fn new(checkpoints_dir: PathBuf, max_checkpoints: usize) -> Result<Self> {
73        fs::create_dir_all(&checkpoints_dir)?;
74        Ok(Self {
75            checkpoints_dir,
76            max_checkpoints,
77        })
78    }
79
80    /// Save checkpoint
81    pub fn save_checkpoint(&self, checkpoint: &Checkpoint) -> Result<()> {
82        let checkpoint_file = self.checkpoints_dir.join(format!(
83            "checkpoint_ep{}.json",
84            checkpoint.episode
85        ));
86
87        checkpoint.save(&checkpoint_file)?;
88
89        // Clean up old checkpoints
90        self.cleanup_old_checkpoints()?;
91
92        Ok(())
93    }
94
95    /// Load latest checkpoint
96    pub fn load_latest(&self) -> Result<Option<Checkpoint>> {
97        let mut checkpoints = self.list_checkpoints()?;
98
99        if checkpoints.is_empty() {
100            return Ok(None);
101        }
102
103        checkpoints.sort_by_key(|c| c.episode);
104        let latest = checkpoints.last().unwrap();
105
106        Ok(Some(latest.clone()))
107    }
108
109    /// Load best checkpoint (by quality)
110    pub fn load_best(&self) -> Result<Option<Checkpoint>> {
111        let checkpoints = self.list_checkpoints()?;
112
113        if checkpoints.is_empty() {
114            return Ok(None);
115        }
116
117        let best = checkpoints.iter()
118            .max_by(|a, b| a.best_quality.partial_cmp(&b.best_quality).unwrap())
119            .cloned();
120
121        Ok(best)
122    }
123
124    /// List all checkpoints
125    pub fn list_checkpoints(&self) -> Result<Vec<Checkpoint>> {
126        let mut checkpoints = Vec::new();
127
128        for entry in fs::read_dir(&self.checkpoints_dir)? {
129            let entry = entry?;
130            let path = entry.path();
131
132            if path.extension().and_then(|s| s.to_str()) == Some("json") {
133                if let Ok(checkpoint) = Checkpoint::load(&path) {
134                    checkpoints.push(checkpoint);
135                }
136            }
137        }
138
139        Ok(checkpoints)
140    }
141
142    /// Clean up old checkpoints, keeping only the most recent ones
143    fn cleanup_old_checkpoints(&self) -> Result<()> {
144        let mut checkpoints = self.list_checkpoints()?;
145
146        if checkpoints.len() <= self.max_checkpoints {
147            return Ok(());
148        }
149
150        // Sort by episode
151        checkpoints.sort_by_key(|c| c.episode);
152
153        // Remove oldest checkpoints
154        let to_remove = checkpoints.len() - self.max_checkpoints;
155
156        for checkpoint in checkpoints.iter().take(to_remove) {
157            let checkpoint_file = self.checkpoints_dir.join(format!(
158                "checkpoint_ep{}.json",
159                checkpoint.episode
160            ));
161
162            if checkpoint_file.exists() {
163                fs::remove_file(checkpoint_file)?;
164            }
165
166            // Also remove model file if it exists
167            if checkpoint.model_path.exists() {
168                fs::remove_file(&checkpoint.model_path)?;
169            }
170        }
171
172        Ok(())
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a checkpoint with fixed statistics; only the identifying fields vary.
    fn sample_checkpoint(episode: usize, step_count: usize, model_path: &str) -> Checkpoint {
        Checkpoint::new(
            episode,
            step_count,
            0.5,
            0.7,
            0.8,
            0.1,
            PathBuf::from(model_path),
        )
    }

    /// A checkpoint written to disk must load back with the same metadata.
    #[test]
    fn test_checkpoint_save_load() {
        let temp_dir = TempDir::new().unwrap();
        let checkpoint_path = temp_dir.path().join("checkpoint.json");

        let checkpoint = sample_checkpoint(100, 5000, "model.onnx");

        checkpoint.save(&checkpoint_path).unwrap();
        let loaded = Checkpoint::load(&checkpoint_path).unwrap();

        assert_eq!(loaded.episode, 100);
        assert_eq!(loaded.step_count, 5000);
        assert_eq!(loaded.timestamp, checkpoint.timestamp);
        assert_eq!(loaded.model_path, PathBuf::from("model.onnx"));
        assert!(loaded.optimizer_state.is_none());
    }

    /// Saving more checkpoints than `max_checkpoints` keeps exactly the newest ones.
    #[test]
    fn test_checkpoint_manager() {
        let temp_dir = TempDir::new().unwrap();
        let manager = CheckpointManager::new(temp_dir.path().to_path_buf(), 3).unwrap();

        // Nothing saved yet: both lookups report "no checkpoint".
        assert!(manager.load_latest().unwrap().is_none());
        assert!(manager.load_best().unwrap().is_none());

        // Save multiple checkpoints
        for i in 0..5 {
            let checkpoint =
                sample_checkpoint(i * 100, i * 1000, &format!("model_{}.onnx", i));
            manager.save_checkpoint(&checkpoint).unwrap();
        }

        // Exactly the 3 most recent episodes must survive pruning. The listing
        // order is unspecified, so sort before comparing.
        let mut episodes: Vec<usize> = manager
            .list_checkpoints()
            .unwrap()
            .iter()
            .map(|checkpoint| checkpoint.episode)
            .collect();
        episodes.sort_unstable();
        assert_eq!(episodes, vec![200, 300, 400]);

        let latest = manager
            .load_latest()
            .unwrap()
            .expect("checkpoints were saved, so a latest one must exist");
        assert_eq!(latest.episode, 400);

        assert!(manager.load_best().unwrap().is_some());
    }
}
226}