Skip to main content

entrenar/train/transformer_trainer/
wgpu_checkpoint.rs

//! WGPU LoRA checkpoint save/load (7-module LoRA)
//!
//! # Contract: C-WGPU-CKPT-001, C-WGPU-CKPT-002
4
/// The seven LoRA adapters belonging to one transformer layer.
///
/// There is one adapter per projection, in the order Q/K/V/O/gate/up/down.
/// The type derives `Serialize`/`Deserialize` so it can be stored inside a
/// checkpoint, and `Clone` so a slice of layer sets can be copied into one.
#[cfg(feature = "gpu")]
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct LoraLayerSet {
    /// Adapter for the Q projection.
    pub q: super::wgpu_nf4::LoraAdapter,
    /// Adapter for the K projection.
    pub k: super::wgpu_nf4::LoraAdapter,
    /// Adapter for the V projection.
    pub v: super::wgpu_nf4::LoraAdapter,
    /// Adapter for the O projection.
    pub o: super::wgpu_nf4::LoraAdapter,
    /// Adapter for the gate projection.
    pub gate: super::wgpu_nf4::LoraAdapter,
    /// Adapter for the up projection.
    pub up: super::wgpu_nf4::LoraAdapter,
    /// Adapter for the down projection.
    pub down: super::wgpu_nf4::LoraAdapter,
}
17
#[cfg(feature = "gpu")]
impl LoraLayerSet {
    /// Creates the seven adapters of one layer, all with the same `rank`.
    ///
    /// Each adapter is built with `LoraAdapter::new(rank, x, y)` where `(x, y)` is:
    ///
    /// | adapter      | `(x, y)`        |
    /// |--------------|-----------------|
    /// | `q`          | `(h, q_dim)`    |
    /// | `k`, `v`     | `(h, kv_dim)`   |
    /// | `o`          | `(q_dim, h)`    |
    /// | `gate`, `up` | `(h, i_size)`   |
    /// | `down`       | `(i_size, h)`   |
    ///
    /// NOTE(review): `(x, y)` presumably means (input dim, output dim) — confirm
    /// against `LoraAdapter::new`.
    pub fn new(rank: u32, h: u32, q_dim: u32, kv_dim: u32, i_size: u32) -> Self {
        use super::wgpu_nf4::LoraAdapter;
        // Only the two trailing arguments differ between adapters, so capture `rank` once.
        let adapter = |first: u32, second: u32| LoraAdapter::new(rank, first, second);
        // Struct-literal fields are evaluated top to bottom, so adapters are still
        // constructed in the order q, k, v, o, gate, up, down.
        Self {
            q: adapter(h, q_dim),
            k: adapter(h, kv_dim),
            v: adapter(h, kv_dim),
            o: adapter(q_dim, h),
            gate: adapter(h, i_size),
            up: adapter(h, i_size),
            down: adapter(i_size, h),
        }
    }

    /// Returns the sum of `num_params()` over all seven adapters.
    pub fn num_params(&self) -> usize {
        [&self.q, &self.k, &self.v, &self.o, &self.gate, &self.up, &self.down]
            .iter()
            .map(|adapter| adapter.num_params())
            .sum()
    }
}
42
/// Serialized checkpoint payload: per-layer LoRA adapters plus metadata.
///
/// `save_lora_checkpoint` writes this structure as JSON and
/// `load_lora_checkpoint` parses it back.
#[cfg(feature = "gpu")]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct LoraCheckpointV2 {
    /// Step value supplied when the checkpoint was saved.
    pub step: u32,
    /// Rank value supplied when the checkpoint was saved.
    pub rank: u32,
    /// Alpha value supplied when the checkpoint was saved
    /// (presumably the LoRA scaling factor — confirm with the trainer).
    pub alpha: f32,
    /// Loss value supplied when the checkpoint was saved.
    pub loss: f32,
    /// Hidden size of the model; checked against the model on load.
    pub hidden_size: u32,
    /// Number of entries in `layers` at save time.
    pub num_layers: u32,
    /// One adapter set per layer.
    pub layers: Vec<LoraLayerSet>,
}
55
/// Writes all LoRA adapters and their training metadata to a JSON checkpoint.
///
/// The file is named `lora-checkpoint-step{step}.json` and is created inside
/// `output_dir`, which is created first (including missing parents). A one-line
/// summary with the serialized size is printed to stderr.
///
/// # Errors
///
/// Returns a descriptive `String` when:
/// - `hidden_size` or `lora.len()` does not fit in the `u32` metadata fields,
/// - the output directory cannot be created,
/// - serialization fails, or
/// - the file cannot be written.
///
/// NOTE(review): serde_json is believed to encode non-finite floats as `null`,
/// so a checkpoint saved with a NaN/inf `loss` may fail to parse on load — confirm.
#[cfg(feature = "gpu")]
pub fn save_lora_checkpoint(
    lora: &[LoraLayerSet],
    hidden_size: usize,
    output_dir: &std::path::Path,
    step: u32,
    loss: f32,
    rank: u32,
    alpha: f32,
) -> Result<std::path::PathBuf, String> {
    // The checkpoint stores both values as u32. Reject anything that does not fit
    // rather than truncating with `as`, which would record wrong metadata.
    // The fully qualified path keeps this independent of the crate's edition prelude.
    let hidden_size = <u32 as std::convert::TryFrom<usize>>::try_from(hidden_size)
        .map_err(|_| format!("hidden_size {hidden_size} does not fit in u32"))?;
    let num_layers = <u32 as std::convert::TryFrom<usize>>::try_from(lora.len())
        .map_err(|_| format!("Layer count {} does not fit in u32", lora.len()))?;

    std::fs::create_dir_all(output_dir).map_err(|e| format!("Cannot create output dir: {e}"))?;
    let ckpt = LoraCheckpointV2 {
        step,
        rank,
        alpha,
        loss,
        hidden_size,
        num_layers,
        // `LoraCheckpointV2` owns its layers, so the adapters are copied here.
        layers: lora.to_vec(),
    };
    let filename = format!("lora-checkpoint-step{step}.json");
    let path = output_dir.join(&filename);
    let json = serde_json::to_string(&ckpt).map_err(|e| format!("Serialize: {e}"))?;
    std::fs::write(&path, &json).map_err(|e| format!("Write: {e}"))?;
    let mb = json.len() as f64 / 1024.0 / 1024.0;
    eprintln!("  Saved checkpoint: {} ({mb:.1} MB)", path.display());
    Ok(path)
}
84
/// Restores LoRA adapters from a JSON checkpoint written by `save_lora_checkpoint`.
///
/// The first `num_layers` entries of `lora` are overwritten with the layers
/// stored in the checkpoint; any further entries are left untouched. A one-line
/// summary is printed to stderr.
///
/// Returns the `(step, loss)` pair recorded in the checkpoint.
///
/// # Errors
///
/// Returns a descriptive `String` when:
/// - `lora` holds fewer than `num_layers` entries,
/// - the file cannot be read or parsed,
/// - the checkpoint's layer count differs from `num_layers`, or
/// - the checkpoint's hidden size differs from `hidden_size`.
///
/// Nothing in `lora` is modified when an error is returned.
#[cfg(feature = "gpu")]
pub fn load_lora_checkpoint(
    lora: &mut [LoraLayerSet],
    num_layers: usize,
    hidden_size: usize,
    checkpoint_path: &std::path::Path,
) -> Result<(u32, f32), String> {
    // A destination shorter than `num_layers` used to panic on out-of-bounds
    // indexing while copying; report it as an error before touching the disk.
    if lora.len() < num_layers {
        return Err(format!(
            "LoRA slice holds {} layers, need at least {}",
            lora.len(),
            num_layers
        ));
    }

    let json = std::fs::read_to_string(checkpoint_path).map_err(|e| format!("Read: {e}"))?;
    let ckpt: LoraCheckpointV2 = serde_json::from_str(&json).map_err(|e| format!("Parse: {e}"))?;
    if ckpt.layers.len() != num_layers {
        return Err(format!("Checkpoint {} layers, model {}", ckpt.layers.len(), num_layers));
    }
    // Compare in u64: both conversions are lossless, whereas `hidden_size as u32`
    // could truncate an oversized value into a false match.
    if u64::from(ckpt.hidden_size) != hidden_size as u64 {
        return Err(format!("Checkpoint h={}, model h={hidden_size}", ckpt.hidden_size));
    }

    // `ckpt.layers` has exactly `num_layers` items, so `zip` overwrites exactly
    // the first `num_layers` slots of `lora`.
    for (slot, layer) in lora.iter_mut().zip(ckpt.layers) {
        *slot = layer;
    }
    eprintln!(
        "  Loaded checkpoint: step={}, loss={:.3}, {} layers",
        ckpt.step, ckpt.loss, num_layers
    );
    Ok((ckpt.step, ckpt.loss))
}
109
#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;

    /// Returns a scratch directory unique to this test process and `tag`, so
    /// concurrent test runs on the same machine do not share checkpoint files.
    fn scratch_dir(tag: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("entrenar-ckpt-v2-{tag}-{}", std::process::id()))
    }

    /// Saving and then loading must return the recorded metadata and adapters.
    #[test]
    fn test_checkpoint_round_trip() {
        let lora: Vec<LoraLayerSet> = (0..2).map(|_| LoraLayerSet::new(4, 8, 8, 8, 16)).collect();
        let tmpdir = scratch_dir("rt");
        // 2.5 is exactly representable and does not trip clippy's `approx_constant`.
        let path = save_lora_checkpoint(&lora, 8, &tmpdir, 42, 2.5, 4, 8.0).expect("save");
        let saved = path.exists();
        let mut lora2: Vec<LoraLayerSet> =
            (0..2).map(|_| LoraLayerSet::new(4, 8, 8, 8, 16)).collect();
        let loaded = load_lora_checkpoint(&mut lora2, 2, 8, &path);
        // Clean up before asserting so a failing assertion does not leak the directory.
        let _ = std::fs::remove_dir_all(&tmpdir);

        assert!(saved);
        let (step, loss) = loaded.expect("load");
        assert_eq!(step, 42);
        assert!((loss - 2.5).abs() < 1e-5);
        assert_eq!(lora[0].q.a, lora2[0].q.a);
        assert_eq!(lora[0].gate.b, lora2[0].gate.b);
    }

    /// A checkpoint with a different hidden size must be rejected.
    #[test]
    fn test_checkpoint_dimension_mismatch() {
        let lora = vec![LoraLayerSet::new(4, 8, 8, 8, 16)];
        let tmpdir = scratch_dir("mm");
        let path = save_lora_checkpoint(&lora, 8, &tmpdir, 1, 5.0, 4, 8.0).expect("save");
        // Same layer count as the checkpoint, so only the hidden-size check can fire.
        let mut lora2 = vec![LoraLayerSet::new(4, 16, 16, 16, 32)];
        let result = load_lora_checkpoint(&mut lora2, 1, 16, &path);
        let _ = std::fs::remove_dir_all(&tmpdir);

        let err = result.expect_err("hidden-size mismatch must be rejected");
        assert!(err.contains("h=8"), "unexpected error: {err}");
    }

    /// A checkpoint with a different number of layers must be rejected.
    #[test]
    fn test_checkpoint_layer_count_mismatch() {
        let lora = vec![LoraLayerSet::new(4, 8, 8, 8, 16)];
        let tmpdir = scratch_dir("lc");
        let path = save_lora_checkpoint(&lora, 8, &tmpdir, 1, 5.0, 4, 8.0).expect("save");
        // Same hidden size as the checkpoint, so only the layer-count check can fire.
        let mut lora2 =
            vec![LoraLayerSet::new(4, 8, 8, 8, 16), LoraLayerSet::new(4, 8, 8, 8, 16)];
        let result = load_lora_checkpoint(&mut lora2, 2, 8, &path);
        let _ = std::fs::remove_dir_all(&tmpdir);

        assert!(result.is_err());
    }
}