transfer_learning/main.rs

use flodl::*;
fn main() -> Result<()> {
    let opts = TensorOptions::default();

    println!("=== Phase 1: Train encoder ===");

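    // Encoder flow: 4 -> 16 -> 4, with the 16-dim hidden activation tagged "encoded".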
    let encoder = FlowBuilder::from(Linear::new(4, 16)?)
        .through(GELU)
        .through(LayerNorm::new(16)?)
        .tag("encoded")
        .through(Linear::new(16, 4)?)
        .build()?;

    let params = encoder.parameters();
    let mut opt = Adam::new(&params, 0.01);
    encoder.train();
    for epoch in 0..50 {
        let x = Tensor::randn(&[32, 4], opts)?;
        let y = Tensor::randn(&[32, 4], opts)?;
        let input = Variable::new(x, true);
        let target = Variable::new(y, false);

        opt.zero_grad();
        let pred = encoder.forward(&input)?;
        let loss = mse_loss(&pred, &target)?;
        loss.backward()?;
        clip_grad_norm(&params, 1.0)?;
        opt.step()?;

        if epoch % 10 == 0 {
            println!(" epoch {:>3}: loss={:.4}", epoch, loss.item()?);
        }
    }

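    // Persist the encoder's named parameters and buffers to a checkpoint file.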
    let ckpt_path = "pretrained_encoder.fdl";
    let named = encoder.named_parameters();
    let named_bufs = encoder.named_buffers();
    save_checkpoint_file(ckpt_path, &named, &named_bufs, None)?;
    println!("Encoder saved to {}", ckpt_path);

    println!("\n=== Phase 2: Transfer to new architecture ===");

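    // New model: the same encoder stem as before (so its parameters should line
    // up with the checkpoint), followed by a new, randomly initialized head.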
    let model = FlowBuilder::from(Linear::new(4, 16)?)
        .through(GELU)
        .through(LayerNorm::new(16)?)
        .tag("encoded")
        .also(Linear::new(16, 16)?)
        .through(Linear::new(16, 2)?)
        .build()?;

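    // Partial load: whatever lines up with the checkpoint is restored; the
    // report lists what was loaded, skipped, and missing.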
    let named2 = model.named_parameters();
    let named_bufs2 = model.named_buffers();
    let report = load_checkpoint_file(ckpt_path, &named2, &named_bufs2, None)?;

    println!("Loaded {} parameters, skipped {}, missing {}",
        report.loaded.len(), report.skipped.len(), report.missing.len());

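    // Freeze the first three parameters (the pretrained encoder) so that only
    // the new head is updated in Phase 3.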
    let all_params = model.parameters();
    for (i, p) in all_params.iter().enumerate() {
        if i < 3 {
            p.freeze()?;
        }
    }
    println!("Encoder layers frozen");

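    // Hand the head optimizer only the parameters that are still trainable.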
    let trainable: Vec<Parameter> = all_params
        .iter()
        .filter(|p| !p.is_frozen())
        .cloned()
        .collect();

    let mut opt2 = Adam::new(&trainable, 0.005);
    model.train();

    println!("\n=== Phase 3: Fine-tune new head ===");

    for epoch in 0..50 {
        let x = Tensor::randn(&[32, 4], opts)?;
        let y = Tensor::randn(&[32, 2], opts)?;
        let input = Variable::new(x, true);
        let target = Variable::new(y, false);

        opt2.zero_grad();
        let pred = model.forward(&input)?;
        let loss = mse_loss(&pred, &target)?;
        loss.backward()?;
        clip_grad_norm(&trainable, 1.0)?;
        opt2.step()?;

        if epoch % 10 == 0 {
            println!(" epoch {:>3}: loss={:.4}", epoch, loss.item()?);
        }
    }

    println!("\n=== Phase 4: Full fine-tune (unfrozen) ===");

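    // Unfreeze everything and keep training the whole model at a lower learning rate.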
    for p in &all_params {
        p.unfreeze()?;
    }

    let mut opt3 = Adam::new(&all_params, 0.001);
    for epoch in 0..30 {
        let x = Tensor::randn(&[32, 4], opts)?;
        let y = Tensor::randn(&[32, 2], opts)?;
        let input = Variable::new(x, true);
        let target = Variable::new(y, false);

        opt3.zero_grad();
        let pred = model.forward(&input)?;
        let loss = mse_loss(&pred, &target)?;
        loss.backward()?;
        clip_grad_norm(&all_params, 1.0)?;
        opt3.step()?;

        if epoch % 10 == 0 {
            println!(" epoch {:>3}: loss={:.4}", epoch, loss.item()?);
        }
    }

    println!("\nTransfer learning complete.");

    std::fs::remove_file(ckpt_path).ok();

    Ok(())
}