1use std::fmt::{Display, Formatter};
2use std::path::{Path, PathBuf};
3
4use burn::module::Module;
5use burn::prelude::*;
6use burn::record::DefaultRecorder;
7use serde::de::DeserializeOwned;
8use serde::{Deserialize, Serialize};
9
10use crate::config::DdlConfig;
11use crate::training::{
12 TrainingComparisonReport, TrainingOutcome, TrainingReport, TrainingSweepOutcome,
13};
14use crate::variant::{ModelInstance, ModelVariant};
15
/// File name of the JSON manifest written into every checkpoint directory.
const MANIFEST_FILE: &str = "manifest.json";
/// Default stem of the weights file. It is joined onto the checkpoint directory without an
/// extension; presumably the recorder appends its own extension (TODO confirm against burn).
const WEIGHTS_STEM: &str = "model";
/// File name of the serialized `TrainingReport` inside a training-artifact directory.
const TRAINING_REPORT_FILE: &str = "training_report.json";
/// File name of the serialized `TrainingComparisonReport` at the root of a sweep directory.
const TRAINING_COMPARISON_REPORT_FILE: &str = "training_comparison.json";
/// File name of the raw optimizer-state bytes inside a training-artifact directory.
const OPTIMIZER_STATE_FILE: &str = "optimizer_state.bin";
/// Sub-directory of a training artifact that holds the best-validation checkpoint.
const BEST_VALIDATION_DIR: &str = "best_validation";
22
/// Metadata persisted as `manifest.json` next to a checkpoint's weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointManifest {
    /// Model variant the weights belong to.
    pub variant: ModelVariant,
    /// Configuration used to rebuild the model; `save_checkpoint` stores the config returned
    /// by `ModelVariant::resolve_config`.
    pub config: DdlConfig,
    /// Parameter count reported by the model's `num_params()` at save time.
    pub num_params: usize,
    /// Stem of the weights file inside the checkpoint directory.
    pub weights_stem: String,
}
30
31impl CheckpointManifest {
32 pub fn new(variant: ModelVariant, config: DdlConfig, num_params: usize) -> Self {
33 Self {
34 variant,
35 config,
36 num_params,
37 weights_stem: WEIGHTS_STEM.to_string(),
38 }
39 }
40}
41
/// Everything `load_training_artifact` restores from a training-artifact directory.
#[derive(Debug, Clone)]
pub struct LoadedTrainingArtifact<B: Backend> {
    /// Manifest of the final checkpoint.
    pub manifest: CheckpointManifest,
    /// Training report read from `training_report.json`.
    pub report: TrainingReport,
    /// Model restored from the final checkpoint.
    pub model: ModelInstance<B>,
    /// Best-validation model when the report records one; may be a clone of `model` when the
    /// best step was the last completed step.
    pub best_validation_model: Option<ModelInstance<B>>,
    /// Raw optimizer-state bytes, or `None` when no `optimizer_state.bin` was found.
    pub optimizer_state: Option<Vec<u8>>,
}
50
/// Errors produced while saving or loading checkpoints and training artifacts.
#[derive(Debug)]
pub enum CheckpointError {
    /// Filesystem failure (creating, reading, writing, or removing files/directories).
    Io(std::io::Error),
    /// JSON (de)serialization failure for a manifest or report file.
    Serde(serde_json::Error),
    /// Failure of the burn recorder while saving or loading model weights.
    Recorder(burn::record::RecorderError),
    /// Metadata of a training artifact is internally inconsistent; the string says how.
    ArtifactMismatch(String),
    /// The report requires a best-validation checkpoint, but none exists at this path.
    MissingBestValidationArtifact(PathBuf),
    /// A model instance was about to be saved under a variant of a different kind.
    ModelVariantMismatch {
        /// Variant the caller asked to save under.
        variant: ModelVariant,
        /// Kind of model actually supplied (`"baseline"` or `"ddl"`).
        actual_model: &'static str,
    },
}
63
64impl Display for CheckpointError {
65 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
66 match self {
67 Self::Io(error) => write!(f, "checkpoint I/O failed: {error}"),
68 Self::Serde(error) => write!(f, "checkpoint manifest serialization failed: {error}"),
69 Self::Recorder(error) => write!(f, "checkpoint weight serialization failed: {error}"),
70 Self::ArtifactMismatch(message) => {
71 write!(f, "training artifact is inconsistent: {message}")
72 }
73 Self::MissingBestValidationArtifact(path) => write!(
74 f,
75 "training artifact is missing the best-validation checkpoint at {}",
76 path.display()
77 ),
78 Self::ModelVariantMismatch {
79 variant,
80 actual_model,
81 } => write!(
82 f,
83 "model kind {actual_model} does not match checkpoint variant {}",
84 variant.slug()
85 ),
86 }
87 }
88}
89
90impl std::error::Error for CheckpointError {
91 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
92 match self {
93 Self::Io(error) => Some(error),
94 Self::Serde(error) => Some(error),
95 Self::Recorder(error) => Some(error),
96 Self::ArtifactMismatch(_) => None,
97 Self::MissingBestValidationArtifact(_) => None,
98 Self::ModelVariantMismatch { .. } => None,
99 }
100 }
101}
102
103impl From<std::io::Error> for CheckpointError {
104 fn from(error: std::io::Error) -> Self {
105 Self::Io(error)
106 }
107}
108
109impl From<serde_json::Error> for CheckpointError {
110 fn from(error: serde_json::Error) -> Self {
111 Self::Serde(error)
112 }
113}
114
115impl From<burn::record::RecorderError> for CheckpointError {
116 fn from(error: burn::record::RecorderError) -> Self {
117 Self::Recorder(error)
118 }
119}
120
121pub fn save_checkpoint<B: Backend>(
122 model: &ModelInstance<B>,
123 variant: ModelVariant,
124 config: &DdlConfig,
125 directory: impl AsRef<Path>,
126) -> Result<CheckpointManifest, CheckpointError> {
127 ensure_variant_matches_model(model, variant)?;
128
129 let directory = directory.as_ref();
130 let resolved_config = variant.resolve_config(config);
131 let manifest = CheckpointManifest::new(variant, resolved_config, model.num_params());
132 let recorder = DefaultRecorder::new();
133
134 std::fs::create_dir_all(directory)?;
135 match model {
136 ModelInstance::Baseline(model) => {
137 model
138 .as_ref()
139 .clone()
140 .save_file(weights_path(directory, &manifest.weights_stem), &recorder)?;
141 }
142 ModelInstance::Ddl(model) => {
143 model
144 .as_ref()
145 .clone()
146 .save_file(weights_path(directory, &manifest.weights_stem), &recorder)?;
147 }
148 }
149
150 let manifest_json = serde_json::to_string_pretty(&manifest)?;
151 std::fs::write(manifest_path(directory), manifest_json)?;
152
153 Ok(manifest)
154}
155
156pub fn load_checkpoint<B: Backend>(
157 directory: impl AsRef<Path>,
158 device: &B::Device,
159) -> Result<(CheckpointManifest, ModelInstance<B>), CheckpointError> {
160 let directory = directory.as_ref();
161 let manifest = load_manifest(directory)?;
162 let recorder = DefaultRecorder::new();
163 let model = match manifest.variant.init_model(&manifest.config, device) {
164 ModelInstance::Baseline(model) => {
165 ModelInstance::Baseline(Box::new(model.as_ref().clone().load_file(
166 weights_path(directory, &manifest.weights_stem),
167 &recorder,
168 device,
169 )?))
170 }
171 ModelInstance::Ddl(model) => {
172 ModelInstance::Ddl(Box::new(model.as_ref().clone().load_file(
173 weights_path(directory, &manifest.weights_stem),
174 &recorder,
175 device,
176 )?))
177 }
178 };
179
180 Ok((manifest, model))
181}
182
183pub fn load_manifest(directory: impl AsRef<Path>) -> Result<CheckpointManifest, CheckpointError> {
184 let manifest = std::fs::read_to_string(manifest_path(directory.as_ref()))?;
185 Ok(serde_json::from_str(&manifest)?)
186}
187
188pub fn load_training_artifact<B: Backend>(
189 directory: impl AsRef<Path>,
190 device: &B::Device,
191) -> Result<LoadedTrainingArtifact<B>, CheckpointError> {
192 let directory = directory.as_ref();
193 let report = load_training_report(directory)?;
194 let (manifest, model) = load_checkpoint::<B>(directory, device)?;
195
196 if report.variant != manifest.variant {
197 return Err(CheckpointError::ArtifactMismatch(format!(
198 "report variant {} does not match manifest variant {}",
199 report.variant.slug(),
200 manifest.variant.slug()
201 )));
202 }
203 if report.config != manifest.config {
204 return Err(CheckpointError::ArtifactMismatch(
205 "report config does not match manifest config".to_string(),
206 ));
207 }
208 if report.num_params != manifest.num_params {
209 return Err(CheckpointError::ArtifactMismatch(format!(
210 "report num_params {} does not match manifest {}",
211 report.num_params, manifest.num_params
212 )));
213 }
214
215 let best_validation_directory = best_validation_path(directory);
216 let best_validation_model = if report.best_validation.is_none() {
217 None
218 } else if best_validation_directory.join(MANIFEST_FILE).is_file() {
219 Some(load_checkpoint::<B>(&best_validation_directory, device)?.1)
220 } else if report.best_validation_step == Some(report.steps_completed) {
221 Some(model.clone())
222 } else {
223 return Err(CheckpointError::MissingBestValidationArtifact(
224 best_validation_directory,
225 ));
226 };
227 let optimizer_state = optimizer_state_path(directory)
228 .is_file()
229 .then(|| std::fs::read(optimizer_state_path(directory)))
230 .transpose()?;
231
232 Ok(LoadedTrainingArtifact {
233 manifest,
234 report,
235 model,
236 best_validation_model,
237 optimizer_state,
238 })
239}
240
241pub fn save_training_report(
242 report: &TrainingReport,
243 directory: impl AsRef<Path>,
244) -> Result<(), CheckpointError> {
245 let directory = directory.as_ref();
246 std::fs::create_dir_all(directory)?;
247 write_json(training_report_path(directory), report)
248}
249
250pub fn load_training_report(
251 directory: impl AsRef<Path>,
252) -> Result<TrainingReport, CheckpointError> {
253 read_json(training_report_path(directory.as_ref()))
254}
255
256pub fn save_training_comparison_report(
257 report: &TrainingComparisonReport,
258 directory: impl AsRef<Path>,
259) -> Result<(), CheckpointError> {
260 let directory = directory.as_ref();
261 std::fs::create_dir_all(directory)?;
262 write_json(training_comparison_report_path(directory), report)
263}
264
265pub fn load_training_comparison_report(
266 directory: impl AsRef<Path>,
267) -> Result<TrainingComparisonReport, CheckpointError> {
268 read_json(training_comparison_report_path(directory.as_ref()))
269}
270
271pub fn save_training_artifact<B: Backend>(
272 outcome: &TrainingOutcome<B>,
273 directory: impl AsRef<Path>,
274) -> Result<CheckpointManifest, CheckpointError> {
275 let directory = directory.as_ref();
276 let manifest = save_checkpoint(
277 &outcome.model,
278 outcome.report.variant,
279 &outcome.report.config,
280 directory,
281 )?;
282 save_training_report(&outcome.report, directory)?;
283 std::fs::write(optimizer_state_path(directory), &outcome.optimizer_state)?;
284
285 let best_validation_directory = best_validation_path(directory);
286 if best_validation_directory.exists() {
287 std::fs::remove_dir_all(&best_validation_directory)?;
288 }
289 if let Some(best_model) = outcome.best_validation_model.as_ref() {
290 save_checkpoint(
291 best_model,
292 outcome.report.variant,
293 &outcome.report.config,
294 &best_validation_directory,
295 )?;
296 save_training_report(&outcome.report, &best_validation_directory)?;
297 }
298
299 Ok(manifest)
300}
301
302pub fn save_training_sweep<B: Backend>(
303 sweep: &TrainingSweepOutcome<B>,
304 directory: impl AsRef<Path>,
305) -> Result<Vec<CheckpointManifest>, CheckpointError> {
306 let directory = directory.as_ref();
307 std::fs::create_dir_all(directory)?;
308 save_training_comparison_report(&sweep.report, directory)?;
309
310 sweep
311 .outcomes
312 .iter()
313 .map(|outcome| {
314 save_training_artifact(outcome, directory.join(outcome.report.variant.slug()))
315 })
316 .collect()
317}
318
319fn ensure_variant_matches_model<B: Backend>(
320 model: &ModelInstance<B>,
321 variant: ModelVariant,
322) -> Result<(), CheckpointError> {
323 match (model, variant.uses_ddl()) {
324 (ModelInstance::Baseline(_), false) | (ModelInstance::Ddl(_), true) => Ok(()),
325 (ModelInstance::Baseline(_), true) => Err(CheckpointError::ModelVariantMismatch {
326 variant,
327 actual_model: "baseline",
328 }),
329 (ModelInstance::Ddl(_), false) => Err(CheckpointError::ModelVariantMismatch {
330 variant,
331 actual_model: "ddl",
332 }),
333 }
334}
335
336fn manifest_path(directory: &Path) -> PathBuf {
337 directory.join(MANIFEST_FILE)
338}
339
340fn weights_path(directory: &Path, weights_stem: &str) -> PathBuf {
341 directory.join(weights_stem)
342}
343
344fn training_report_path(directory: &Path) -> PathBuf {
345 directory.join(TRAINING_REPORT_FILE)
346}
347
348fn training_comparison_report_path(directory: &Path) -> PathBuf {
349 directory.join(TRAINING_COMPARISON_REPORT_FILE)
350}
351
352fn optimizer_state_path(directory: &Path) -> PathBuf {
353 directory.join(OPTIMIZER_STATE_FILE)
354}
355
356fn best_validation_path(directory: &Path) -> PathBuf {
357 directory.join(BEST_VALIDATION_DIR)
358}
359
360fn write_json<T: Serialize>(path: PathBuf, value: &T) -> Result<(), CheckpointError> {
361 std::fs::write(path, serde_json::to_string_pretty(value)?)?;
362 Ok(())
363}
364
365fn read_json<T: DeserializeOwned>(path: PathBuf) -> Result<T, CheckpointError> {
366 let contents = std::fs::read_to_string(path)?;
367 Ok(serde_json::from_str(&contents)?)
368}