
transfer_learning/main.rs

//! Transfer learning — freeze a pretrained encoder and train a new head.
//!
//! Demonstrates checkpoint save/load, partial loading into a different
//! architecture, parameter freezing, and restricting the optimizer to the
//! trainable parameters, with a different learning rate per phase.
//! (The data is random noise; the point is the mechanics, not the loss.)
//!
//! Run: `cargo run --example transfer_learning`

use flodl::*;

fn main() -> Result<()> {
    let opts = TensorOptions::default();

    // --- Phase 1: Train an encoder on a proxy task ---
    println!("=== Phase 1: Train encoder ===");

    let encoder = FlowBuilder::from(Linear::new(4, 16)?)
        .through(GELU)
        .through(LayerNorm::new(16)?)
        .tag("encoded")
        .through(Linear::new(16, 4)?)
        .build()?;
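    // Phase 2 rebuilds this exact Linear -> GELU -> LayerNorm prefix, so
    // the saved parameter names line up and the encoder transfers
    // one-for-one at load time.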

    let params = encoder.parameters();
    let mut opt = Adam::new(&params, 0.01);
    encoder.train();

    for epoch in 0..50 {
        let x = Tensor::randn(&[32, 4], opts)?;
        let y = Tensor::randn(&[32, 4], opts)?;
        let input = Variable::new(x, true);
        let target = Variable::new(y, false);

        opt.zero_grad();
        let pred = encoder.forward(&input)?;
        let loss = mse_loss(&pred, &target)?;
        loss.backward()?;
        clip_grad_norm(&params, 1.0)?;
        opt.step()?;

        if epoch % 10 == 0 {
            println!("  epoch {:>3}: loss={:.4}", epoch, loss.item()?);
        }
    }

    // Save the pretrained encoder.
    let ckpt_path = "pretrained_encoder.fdl";
    let named = encoder.named_parameters();
    let named_bufs = encoder.named_buffers();
    save_checkpoint_file(ckpt_path, &named, &named_bufs, None)?;
    println!("Encoder saved to {}", ckpt_path);
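    // LayerNorm keeps no running statistics (unlike BatchNorm), so the
    // buffer map is most likely empty here; saving buffers anyway keeps
    // the checkpointing pattern general.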

    // --- Phase 2: Build a new model and load encoder weights ---
    println!("\n=== Phase 2: Transfer to new architecture ===");

    // New architecture reuses the encoder layers but adds a different head.
    let model = FlowBuilder::from(Linear::new(4, 16)?)
        .through(GELU)
        .through(LayerNorm::new(16)?)
        .tag("encoded")
        .also(Linear::new(16, 16)?)    // new: residual connection
        .through(Linear::new(16, 2)?)  // new: different output dim
        .build()?;

    // Partial load: matching names get loaded, new layers keep random init.
    let named2 = model.named_parameters();
    let named_bufs2 = model.named_buffers();
    let report = load_checkpoint_file(ckpt_path, &named2, &named_bufs2, None)?;

    println!("Loaded {} parameters, skipped {}, missing {}",
        report.loaded.len(), report.skipped.len(), report.missing.len());
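
    // The two head layers exist only in the new model, so they keep their
    // random init. A quick sanity check (a sketch, assuming `report.missing`
    // lists the model parameters the checkpoint had no entry for):
    // for name in &report.missing {
    //     println!("  kept random init: {}", name);
    // }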

    // Freeze the encoder parameters. The encoder prefix is Linear, GELU,
    // LayerNorm; GELU is parameter-free, so that is four tensors in total
    // (Linear weight + bias, LayerNorm gain + bias).
    let all_params = model.parameters();
    for (i, p) in all_params.iter().enumerate() {
        if i < 4 {
            p.freeze()?;
        }
    }
    println!("Encoder layers frozen");
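
    // Index-based freezing is brittle if the architecture changes. A more
    // robust sketch, assuming `named_parameters()` yields name/parameter
    // pairs and `report.loaded` lists the transferred names, freezes
    // exactly what came from the checkpoint:
    // for (name, p) in &named2 {
    //     if report.loaded.contains(name) {
    //         p.freeze()?;
    //     }
    // }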

    // Restrict the optimizer to the trainable parameters so the frozen
    // encoder is never updated; the new head trains at a relatively high LR.
    let trainable: Vec<Parameter> = all_params
        .iter()
        .filter(|p| !p.is_frozen())
        .cloned()
        .collect();

    let mut opt2 = Adam::new(&trainable, 0.005);
    model.train();

    // --- Phase 3: Fine-tune the new head ---
    println!("\n=== Phase 3: Fine-tune new head ===");

    for epoch in 0..50 {
        let x = Tensor::randn(&[32, 4], opts)?;
        let y = Tensor::randn(&[32, 2], opts)?;
        let input = Variable::new(x, true);
        let target = Variable::new(y, false);

        opt2.zero_grad();
        let pred = model.forward(&input)?;
        let loss = mse_loss(&pred, &target)?;
        loss.backward()?;
        clip_grad_norm(&trainable, 1.0)?;
        opt2.step()?;

        if epoch % 10 == 0 {
            println!("  epoch {:>3}: loss={:.4}", epoch, loss.item()?);
        }
    }

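    // The frozen encoder received no updates in Phase 3, so its weights
    // still match the checkpoint exactly; only the new layers moved.
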
    // --- Phase 4: Unfreeze and fine-tune everything ---
    println!("\n=== Phase 4: Full fine-tune (unfrozen) ===");

    for p in &all_params {
        p.unfreeze()?;
    }

    let mut opt3 = Adam::new(&all_params, 0.001); // lower LR for full model
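    // `opt2` only tracked the head, so a fresh optimizer is built over the
    // full list. A common refinement at this stage is a per-group learning
    // rate: a smaller step for the pretrained encoder than for the new head.
    // With a single-LR `Adam`, one way is two optimizers over disjoint
    // slices (a sketch; the 4-parameter encoder split mirrors the freeze
    // loop above):
    // let (enc, head) = all_params.split_at(4);
    // let mut opt_enc = Adam::new(&enc.to_vec(), 0.0001);
    // let mut opt_head = Adam::new(&head.to_vec(), 0.001);
    // ...calling zero_grad() and step() on both each iteration.
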
    for epoch in 0..30 {
        let x = Tensor::randn(&[32, 4], opts)?;
        let y = Tensor::randn(&[32, 2], opts)?;
        let input = Variable::new(x, true);
        let target = Variable::new(y, false);

        opt3.zero_grad();
        let pred = model.forward(&input)?;
        let loss = mse_loss(&pred, &target)?;
        loss.backward()?;
        clip_grad_norm(&all_params, 1.0)?;
        opt3.step()?;

        if epoch % 10 == 0 {
            println!("  epoch {:>3}: loss={:.4}", epoch, loss.item()?);
        }
    }

    println!("\nTransfer learning complete.");

    // Clean up.
    std::fs::remove_file(ckpt_path).ok();
    Ok(())
}