// entrenar/train/transformer_trainer/wgpu_checkpoint.rs
#[cfg(feature = "gpu")]
/// The seven LoRA adapters attached to a single transformer layer.
///
/// One `LoraLayerSet` is stored per layer inside [`LoraCheckpointV2::layers`].
/// Field order is significant: serde writes the fields in declaration order.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct LoraLayerSet {
    /// Adapter for the query (`q`) projection.
    pub q: super::wgpu_nf4::LoraAdapter,
    /// Adapter for the key (`k`) projection.
    pub k: super::wgpu_nf4::LoraAdapter,
    /// Adapter for the value (`v`) projection.
    pub v: super::wgpu_nf4::LoraAdapter,
    /// Adapter for the attention output (`o`) projection.
    pub o: super::wgpu_nf4::LoraAdapter,
    /// Adapter for the MLP gate projection.
    pub gate: super::wgpu_nf4::LoraAdapter,
    /// Adapter for the MLP up projection.
    pub up: super::wgpu_nf4::LoraAdapter,
    /// Adapter for the MLP down projection.
    pub down: super::wgpu_nf4::LoraAdapter,
}

#[cfg(feature = "gpu")]
impl LoraLayerSet {
    /// Creates a fresh adapter for each of the layer's seven projections.
    ///
    /// Assuming `LoraAdapter::new` takes `(rank, in_dim, out_dim)` (TODO confirm
    /// against `wgpu_nf4`), the shapes are:
    /// * `q`: `h -> q_dim`, `k`/`v`: `h -> kv_dim`, `o`: `q_dim -> h`
    /// * `gate`/`up`: `h -> i_size`, `down`: `i_size -> h`
    ///
    /// `rank` is shared by all seven adapters.
    pub fn new(rank: u32, h: u32, q_dim: u32, kv_dim: u32, i_size: u32) -> Self {
        use super::wgpu_nf4::LoraAdapter;

        // Five of the seven adapters read from the hidden state.
        let from_hidden = |out_dim: u32| LoraAdapter::new(rank, h, out_dim);

        // Struct-literal fields are evaluated in the order written. Keep
        // q, k, v, o, gate, up, down in case `LoraAdapter::new` draws from shared
        // RNG state (unverified), so initialisation stays reproducible.
        Self {
            q: from_hidden(q_dim),
            k: from_hidden(kv_dim),
            v: from_hidden(kv_dim),
            o: LoraAdapter::new(rank, q_dim, h),
            gate: from_hidden(i_size),
            up: from_hidden(i_size),
            down: LoraAdapter::new(rank, i_size, h),
        }
    }

    /// Sum of `LoraAdapter::num_params` over all seven adapters.
    pub fn num_params(&self) -> usize {
        let adapters = [&self.q, &self.k, &self.v, &self.o, &self.gate, &self.up, &self.down];
        adapters.iter().map(|adapter| adapter.num_params()).sum()
    }
}

#[cfg(feature = "gpu")]
/// On-disk JSON layout written by `save_lora_checkpoint` and read by
/// `load_lora_checkpoint`.
///
/// Field order is significant: serde writes the fields in declaration order.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct LoraCheckpointV2 {
    /// Training step at which the checkpoint was taken.
    pub step: u32,
    /// LoRA rank recorded at save time (not checked on load).
    pub rank: u32,
    /// LoRA alpha recorded at save time (not checked on load).
    pub alpha: f32,
    /// Loss value recorded at save time.
    pub loss: f32,
    /// Model hidden size; a load rejects checkpoints whose value differs.
    pub hidden_size: u32,
    /// Number of layer sets, written as `layers.len()`.
    pub num_layers: u32,
    /// One adapter set per transformer layer, in layer order.
    pub layers: Vec<LoraLayerSet>,
}

#[cfg(feature = "gpu")]
/// Serializes `lora` to `<output_dir>/lora-checkpoint-step{step}.json`.
///
/// The output directory is created if needed. The JSON is first written to a
/// `.tmp` sibling, flushed to disk, and then renamed over the final path. A
/// crash or I/O error mid-write therefore never leaves a truncated checkpoint
/// under the final name, and never damages an existing one.
///
/// Returns the path of the written checkpoint.
///
/// # Errors
///
/// Returns `Err` if:
/// * `hidden_size` or `lora.len()` does not fit in `u32`;
/// * the directory cannot be created;
/// * serialization fails;
/// * writing or renaming the file fails.
///
/// On a write or rename failure the temporary file is removed on a
/// best-effort basis.
pub fn save_lora_checkpoint(
    lora: &[LoraLayerSet],
    hidden_size: usize,
    output_dir: &std::path::Path,
    step: u32,
    loss: f32,
    rank: u32,
    alpha: f32,
) -> Result<std::path::PathBuf, String> {
    /// Writes `bytes` to `tmp`, syncs it, then atomically renames it to `dst`.
    fn write_then_rename(
        tmp: &std::path::Path,
        dst: &std::path::Path,
        bytes: &[u8],
    ) -> std::io::Result<()> {
        use std::io::Write;
        let mut file = std::fs::File::create(tmp)?;
        file.write_all(bytes)?;
        // Make the data durable before the rename publishes it.
        file.sync_all()?;
        std::fs::rename(tmp, dst)
    }

    // Checked narrowing: an `as` cast would silently store a wrong value.
    let hidden_size = u32::try_from(hidden_size)
        .map_err(|_| format!("hidden_size {hidden_size} does not fit in u32"))?;
    let num_layers = u32::try_from(lora.len())
        .map_err(|_| format!("{} layer sets do not fit in u32", lora.len()))?;

    std::fs::create_dir_all(output_dir).map_err(|e| format!("Cannot create output dir: {e}"))?;
    let ckpt = LoraCheckpointV2 {
        step,
        rank,
        alpha,
        loss,
        hidden_size,
        num_layers,
        layers: lora.to_vec(),
    };
    let json = serde_json::to_string(&ckpt).map_err(|e| format!("Serialize: {e}"))?;

    let filename = format!("lora-checkpoint-step{step}.json");
    let path = output_dir.join(&filename);
    let tmp_path = output_dir.join(format!("{filename}.tmp"));
    if let Err(e) = write_then_rename(&tmp_path, &path, json.as_bytes()) {
        // Best-effort cleanup; the original error is what the caller needs.
        let _ = std::fs::remove_file(&tmp_path);
        return Err(format!("Write: {e}"));
    }

    let mb = json.len() as f64 / 1024.0 / 1024.0;
    eprintln!(" Saved checkpoint: {} ({mb:.1} MB)", path.display());
    Ok(path)
}

#[cfg(feature = "gpu")]
/// Restores LoRA adapters from a checkpoint written by [`save_lora_checkpoint`].
///
/// Every element of `lora` is replaced with the matching layer set from the
/// JSON file at `checkpoint_path`. The function returns the checkpoint's
/// `(step, loss)`. `lora` is left untouched unless every check passes, so a
/// rejected checkpoint never produces a half-loaded model.
///
/// # Errors
///
/// Returns `Err` if:
/// * `lora.len() != num_layers`;
/// * the file cannot be read or parsed;
/// * its layer count or hidden size differs from the model;
/// * any layer's parameter count differs from the corresponding entry in `lora`
///   (e.g. a different rank or intermediate size).
///
/// The checkpoint's `rank` and `alpha` fields are not compared directly.
pub fn load_lora_checkpoint(
    lora: &mut [LoraLayerSet],
    num_layers: usize,
    hidden_size: usize,
    checkpoint_path: &std::path::Path,
) -> Result<(u32, f32), String> {
    // Without this check, a shorter slice would panic on indexing and a
    // longer one would keep stale adapters in its trailing layers.
    if lora.len() != num_layers {
        return Err(format!("Model has {} LoRA layer sets, expected {num_layers}", lora.len()));
    }

    let json = std::fs::read_to_string(checkpoint_path).map_err(|e| format!("Read: {e}"))?;
    let ckpt: LoraCheckpointV2 = serde_json::from_str(&json).map_err(|e| format!("Parse: {e}"))?;
    if ckpt.layers.len() != num_layers {
        return Err(format!("Checkpoint {} layers, model {}", ckpt.layers.len(), num_layers));
    }
    // Widen u32 -> usize rather than narrowing the caller's value.
    if ckpt.hidden_size as usize != hidden_size {
        return Err(format!("Checkpoint h={}, model h={hidden_size}", ckpt.hidden_size));
    }

    // Validate every layer before mutating any, to keep the load all-or-nothing.
    for (i, (current, stored)) in lora.iter().zip(&ckpt.layers).enumerate() {
        let (want, got) = (current.num_params(), stored.num_params());
        if want != got {
            return Err(format!("Checkpoint layer {i} has {got} LoRA params, model {want}"));
        }
    }

    for (dst, src) in lora.iter_mut().zip(ckpt.layers) {
        *dst = src;
    }
    eprintln!(
        " Loaded checkpoint: step={}, loss={:.3}, {} layers",
        ckpt.step, ckpt.loss, num_layers
    );
    Ok((ckpt.step, ckpt.loss))
}

#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;

    /// Scratch directory unique to this process, removed on drop so a failing
    /// assertion does not leave stale checkpoints behind.
    struct ScratchDir(std::path::PathBuf);

    impl ScratchDir {
        fn new(tag: &str) -> Self {
            let path = std::env::temp_dir()
                .join(format!("entrenar-ckpt-v2-{tag}-{}", std::process::id()));
            // Start from a clean slate if a previous run crashed.
            let _ = std::fs::remove_dir_all(&path);
            Self(path)
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// `count` layer sets with rank 4, q_dim = kv_dim = `h`, and the given `i_size`.
    fn layer_sets(count: usize, h: u32, i_size: u32) -> Vec<LoraLayerSet> {
        (0..count).map(|_| LoraLayerSet::new(4, h, h, h, i_size)).collect()
    }

    #[test]
    fn test_checkpoint_round_trip() {
        let dir = ScratchDir::new("round-trip");
        let lora = layer_sets(2, 8, 16);
        let path = save_lora_checkpoint(&lora, 8, &dir.0, 42, 2.5, 4, 8.0).expect("save");
        assert!(path.exists());

        let mut lora2 = layer_sets(2, 8, 16);
        let (step, loss) = load_lora_checkpoint(&mut lora2, 2, 8, &path).expect("load");
        assert_eq!(step, 42);
        assert!((loss - 2.5).abs() < 1e-6);
        assert_eq!(lora[0].q.a, lora2[0].q.a);
        assert_eq!(lora[0].gate.b, lora2[0].gate.b);
        assert_eq!(lora[1].down.a, lora2[1].down.a);
    }

    #[test]
    fn test_checkpoint_dimension_mismatch() {
        // Same layer count, different hidden size: exercises the `h` check.
        let dir = ScratchDir::new("hidden-mismatch");
        let path = save_lora_checkpoint(&layer_sets(1, 8, 16), 8, &dir.0, 1, 5.0, 4, 8.0)
            .expect("save");
        let mut lora2 = layer_sets(1, 16, 32);
        let err = load_lora_checkpoint(&mut lora2, 1, 16, &path).expect_err("hidden size differs");
        assert!(err.contains("h="), "unexpected error: {err}");
    }

    #[test]
    fn test_checkpoint_layer_count_mismatch() {
        // Same hidden size, different layer count: exercises the layer check.
        let dir = ScratchDir::new("layer-mismatch");
        let path = save_lora_checkpoint(&layer_sets(1, 8, 16), 8, &dir.0, 1, 5.0, 4, 8.0)
            .expect("save");
        let mut lora2 = layer_sets(2, 8, 16);
        let err = load_lora_checkpoint(&mut lora2, 2, 8, &path).expect_err("layer count differs");
        assert!(err.contains("layers"), "unexpected error: {err}");
    }
}