Skip to main content

deep_delta_learning/
checkpoint.rs

1use std::fmt::{Display, Formatter};
2use std::path::{Path, PathBuf};
3
4use burn::module::Module;
5use burn::prelude::*;
6use burn::record::DefaultRecorder;
7use serde::de::DeserializeOwned;
8use serde::{Deserialize, Serialize};
9
10use crate::config::DdlConfig;
11use crate::training::{
12    TrainingComparisonReport, TrainingOutcome, TrainingReport, TrainingSweepOutcome,
13};
14use crate::variant::{ModelInstance, ModelVariant};
15
/// File name of the JSON manifest stored inside a checkpoint directory.
const MANIFEST_FILE: &str = "manifest.json";
/// Weights stem recorded in every manifest built by `CheckpointManifest::new`.
const WEIGHTS_STEM: &str = "model";
/// File name of the JSON-encoded `TrainingReport` inside an artifact directory.
const TRAINING_REPORT_FILE: &str = "training_report.json";
/// File name of the JSON-encoded `TrainingComparisonReport`.
const TRAINING_COMPARISON_REPORT_FILE: &str = "training_comparison.json";
/// File name of the raw optimizer-state bytes stored next to the checkpoint.
const OPTIMIZER_STATE_FILE: &str = "optimizer_state.bin";
/// Sub-directory of an artifact that holds the best-validation checkpoint.
const BEST_VALIDATION_DIR: &str = "best_validation";
22
/// Metadata persisted as JSON next to a model's weights.
///
/// `load_checkpoint` reads it back to re-initialise the model before the
/// weights are loaded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointManifest {
    /// Model variant the checkpoint belongs to.
    pub variant: ModelVariant,
    /// Configuration stored with the checkpoint; `save_checkpoint` stores the
    /// configuration returned by `ModelVariant::resolve_config`.
    pub config: DdlConfig,
    /// Parameter count recorded for the model.
    pub num_params: usize,
    /// Path component joined onto the checkpoint directory to locate the weights.
    pub weights_stem: String,
}
30
31impl CheckpointManifest {
32    pub fn new(variant: ModelVariant, config: DdlConfig, num_params: usize) -> Self {
33        Self {
34            variant,
35            config,
36            num_params,
37            weights_stem: WEIGHTS_STEM.to_string(),
38        }
39    }
40}
41
/// Everything `load_training_artifact` reads back from an artifact directory.
#[derive(Debug, Clone)]
pub struct LoadedTrainingArtifact<B: Backend> {
    /// Manifest of the checkpoint stored at the artifact root.
    pub manifest: CheckpointManifest,
    /// Training report stored at the artifact root.
    pub report: TrainingReport,
    /// Model restored from the checkpoint at the artifact root.
    pub model: ModelInstance<B>,
    /// Best-validation model, or `None` when the report records no best
    /// validation result.
    pub best_validation_model: Option<ModelInstance<B>>,
    /// Raw bytes of the optimizer-state file, or `None` when that file is absent.
    pub optimizer_state: Option<Vec<u8>>,
}
50
/// Errors produced while saving or loading checkpoints and training artifacts.
#[derive(Debug)]
pub enum CheckpointError {
    /// A filesystem operation failed.
    Io(std::io::Error),
    /// JSON encoding or decoding of a manifest or report failed.
    Serde(serde_json::Error),
    /// The burn recorder failed while saving or loading model weights.
    Recorder(burn::record::RecorderError),
    /// The pieces of a training artifact disagree; the message names the mismatch.
    ArtifactMismatch(String),
    /// The report refers to a best-validation checkpoint that is not present
    /// at the given path.
    MissingBestValidationArtifact(PathBuf),
    /// The model instance kind does not fit the requested variant.
    ModelVariantMismatch {
        /// Variant the caller asked to save the model as.
        variant: ModelVariant,
        /// Kind of the model actually supplied (`"baseline"` or `"ddl"`).
        actual_model: &'static str,
    },
}
63
64impl Display for CheckpointError {
65    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
66        match self {
67            Self::Io(error) => write!(f, "checkpoint I/O failed: {error}"),
68            Self::Serde(error) => write!(f, "checkpoint manifest serialization failed: {error}"),
69            Self::Recorder(error) => write!(f, "checkpoint weight serialization failed: {error}"),
70            Self::ArtifactMismatch(message) => {
71                write!(f, "training artifact is inconsistent: {message}")
72            }
73            Self::MissingBestValidationArtifact(path) => write!(
74                f,
75                "training artifact is missing the best-validation checkpoint at {}",
76                path.display()
77            ),
78            Self::ModelVariantMismatch {
79                variant,
80                actual_model,
81            } => write!(
82                f,
83                "model kind {actual_model} does not match checkpoint variant {}",
84                variant.slug()
85            ),
86        }
87    }
88}
89
90impl std::error::Error for CheckpointError {
91    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
92        match self {
93            Self::Io(error) => Some(error),
94            Self::Serde(error) => Some(error),
95            Self::Recorder(error) => Some(error),
96            Self::ArtifactMismatch(_) => None,
97            Self::MissingBestValidationArtifact(_) => None,
98            Self::ModelVariantMismatch { .. } => None,
99        }
100    }
101}
102
103impl From<std::io::Error> for CheckpointError {
104    fn from(error: std::io::Error) -> Self {
105        Self::Io(error)
106    }
107}
108
109impl From<serde_json::Error> for CheckpointError {
110    fn from(error: serde_json::Error) -> Self {
111        Self::Serde(error)
112    }
113}
114
115impl From<burn::record::RecorderError> for CheckpointError {
116    fn from(error: burn::record::RecorderError) -> Self {
117        Self::Recorder(error)
118    }
119}
120
121pub fn save_checkpoint<B: Backend>(
122    model: &ModelInstance<B>,
123    variant: ModelVariant,
124    config: &DdlConfig,
125    directory: impl AsRef<Path>,
126) -> Result<CheckpointManifest, CheckpointError> {
127    ensure_variant_matches_model(model, variant)?;
128
129    let directory = directory.as_ref();
130    let resolved_config = variant.resolve_config(config);
131    let manifest = CheckpointManifest::new(variant, resolved_config, model.num_params());
132    let recorder = DefaultRecorder::new();
133
134    std::fs::create_dir_all(directory)?;
135    match model {
136        ModelInstance::Baseline(model) => {
137            model
138                .as_ref()
139                .clone()
140                .save_file(weights_path(directory, &manifest.weights_stem), &recorder)?;
141        }
142        ModelInstance::Ddl(model) => {
143            model
144                .as_ref()
145                .clone()
146                .save_file(weights_path(directory, &manifest.weights_stem), &recorder)?;
147        }
148    }
149
150    let manifest_json = serde_json::to_string_pretty(&manifest)?;
151    std::fs::write(manifest_path(directory), manifest_json)?;
152
153    Ok(manifest)
154}
155
156pub fn load_checkpoint<B: Backend>(
157    directory: impl AsRef<Path>,
158    device: &B::Device,
159) -> Result<(CheckpointManifest, ModelInstance<B>), CheckpointError> {
160    let directory = directory.as_ref();
161    let manifest = load_manifest(directory)?;
162    let recorder = DefaultRecorder::new();
163    let model = match manifest.variant.init_model(&manifest.config, device) {
164        ModelInstance::Baseline(model) => {
165            ModelInstance::Baseline(Box::new(model.as_ref().clone().load_file(
166                weights_path(directory, &manifest.weights_stem),
167                &recorder,
168                device,
169            )?))
170        }
171        ModelInstance::Ddl(model) => {
172            ModelInstance::Ddl(Box::new(model.as_ref().clone().load_file(
173                weights_path(directory, &manifest.weights_stem),
174                &recorder,
175                device,
176            )?))
177        }
178    };
179
180    Ok((manifest, model))
181}
182
183pub fn load_manifest(directory: impl AsRef<Path>) -> Result<CheckpointManifest, CheckpointError> {
184    let manifest = std::fs::read_to_string(manifest_path(directory.as_ref()))?;
185    Ok(serde_json::from_str(&manifest)?)
186}
187
188pub fn load_training_artifact<B: Backend>(
189    directory: impl AsRef<Path>,
190    device: &B::Device,
191) -> Result<LoadedTrainingArtifact<B>, CheckpointError> {
192    let directory = directory.as_ref();
193    let report = load_training_report(directory)?;
194    let (manifest, model) = load_checkpoint::<B>(directory, device)?;
195
196    if report.variant != manifest.variant {
197        return Err(CheckpointError::ArtifactMismatch(format!(
198            "report variant {} does not match manifest variant {}",
199            report.variant.slug(),
200            manifest.variant.slug()
201        )));
202    }
203    if report.config != manifest.config {
204        return Err(CheckpointError::ArtifactMismatch(
205            "report config does not match manifest config".to_string(),
206        ));
207    }
208    if report.num_params != manifest.num_params {
209        return Err(CheckpointError::ArtifactMismatch(format!(
210            "report num_params {} does not match manifest {}",
211            report.num_params, manifest.num_params
212        )));
213    }
214
215    let best_validation_directory = best_validation_path(directory);
216    let best_validation_model = if report.best_validation.is_none() {
217        None
218    } else if best_validation_directory.join(MANIFEST_FILE).is_file() {
219        Some(load_checkpoint::<B>(&best_validation_directory, device)?.1)
220    } else if report.best_validation_step == Some(report.steps_completed) {
221        Some(model.clone())
222    } else {
223        return Err(CheckpointError::MissingBestValidationArtifact(
224            best_validation_directory,
225        ));
226    };
227    let optimizer_state = optimizer_state_path(directory)
228        .is_file()
229        .then(|| std::fs::read(optimizer_state_path(directory)))
230        .transpose()?;
231
232    Ok(LoadedTrainingArtifact {
233        manifest,
234        report,
235        model,
236        best_validation_model,
237        optimizer_state,
238    })
239}
240
241pub fn save_training_report(
242    report: &TrainingReport,
243    directory: impl AsRef<Path>,
244) -> Result<(), CheckpointError> {
245    let directory = directory.as_ref();
246    std::fs::create_dir_all(directory)?;
247    write_json(training_report_path(directory), report)
248}
249
250pub fn load_training_report(
251    directory: impl AsRef<Path>,
252) -> Result<TrainingReport, CheckpointError> {
253    read_json(training_report_path(directory.as_ref()))
254}
255
256pub fn save_training_comparison_report(
257    report: &TrainingComparisonReport,
258    directory: impl AsRef<Path>,
259) -> Result<(), CheckpointError> {
260    let directory = directory.as_ref();
261    std::fs::create_dir_all(directory)?;
262    write_json(training_comparison_report_path(directory), report)
263}
264
265pub fn load_training_comparison_report(
266    directory: impl AsRef<Path>,
267) -> Result<TrainingComparisonReport, CheckpointError> {
268    read_json(training_comparison_report_path(directory.as_ref()))
269}
270
271pub fn save_training_artifact<B: Backend>(
272    outcome: &TrainingOutcome<B>,
273    directory: impl AsRef<Path>,
274) -> Result<CheckpointManifest, CheckpointError> {
275    let directory = directory.as_ref();
276    let manifest = save_checkpoint(
277        &outcome.model,
278        outcome.report.variant,
279        &outcome.report.config,
280        directory,
281    )?;
282    save_training_report(&outcome.report, directory)?;
283    std::fs::write(optimizer_state_path(directory), &outcome.optimizer_state)?;
284
285    let best_validation_directory = best_validation_path(directory);
286    if best_validation_directory.exists() {
287        std::fs::remove_dir_all(&best_validation_directory)?;
288    }
289    if let Some(best_model) = outcome.best_validation_model.as_ref() {
290        save_checkpoint(
291            best_model,
292            outcome.report.variant,
293            &outcome.report.config,
294            &best_validation_directory,
295        )?;
296        save_training_report(&outcome.report, &best_validation_directory)?;
297    }
298
299    Ok(manifest)
300}
301
302pub fn save_training_sweep<B: Backend>(
303    sweep: &TrainingSweepOutcome<B>,
304    directory: impl AsRef<Path>,
305) -> Result<Vec<CheckpointManifest>, CheckpointError> {
306    let directory = directory.as_ref();
307    std::fs::create_dir_all(directory)?;
308    save_training_comparison_report(&sweep.report, directory)?;
309
310    sweep
311        .outcomes
312        .iter()
313        .map(|outcome| {
314            save_training_artifact(outcome, directory.join(outcome.report.variant.slug()))
315        })
316        .collect()
317}
318
319fn ensure_variant_matches_model<B: Backend>(
320    model: &ModelInstance<B>,
321    variant: ModelVariant,
322) -> Result<(), CheckpointError> {
323    match (model, variant.uses_ddl()) {
324        (ModelInstance::Baseline(_), false) | (ModelInstance::Ddl(_), true) => Ok(()),
325        (ModelInstance::Baseline(_), true) => Err(CheckpointError::ModelVariantMismatch {
326            variant,
327            actual_model: "baseline",
328        }),
329        (ModelInstance::Ddl(_), false) => Err(CheckpointError::ModelVariantMismatch {
330            variant,
331            actual_model: "ddl",
332        }),
333    }
334}
335
336fn manifest_path(directory: &Path) -> PathBuf {
337    directory.join(MANIFEST_FILE)
338}
339
/// Path handed to the recorder for the weights: `weights_stem` joined onto `directory`.
fn weights_path(directory: &Path, weights_stem: &str) -> PathBuf {
    let mut path = directory.to_path_buf();
    path.push(weights_stem);
    path
}
343
344fn training_report_path(directory: &Path) -> PathBuf {
345    directory.join(TRAINING_REPORT_FILE)
346}
347
348fn training_comparison_report_path(directory: &Path) -> PathBuf {
349    directory.join(TRAINING_COMPARISON_REPORT_FILE)
350}
351
352fn optimizer_state_path(directory: &Path) -> PathBuf {
353    directory.join(OPTIMIZER_STATE_FILE)
354}
355
356fn best_validation_path(directory: &Path) -> PathBuf {
357    directory.join(BEST_VALIDATION_DIR)
358}
359
360fn write_json<T: Serialize>(path: PathBuf, value: &T) -> Result<(), CheckpointError> {
361    std::fs::write(path, serde_json::to_string_pretty(value)?)?;
362    Ok(())
363}
364
365fn read_json<T: DeserializeOwned>(path: PathBuf) -> Result<T, CheckpointError> {
366    let contents = std::fs::read_to_string(path)?;
367    Ok(serde_json::from_str(&contents)?)
368}