use flodl::*;
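
// Transfer-learning demo: pretrain an encoder, checkpoint it, rebuild the same
// trunk under a new head, reload the weights, then fine-tune in two stages.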
fn main() -> Result<()> {
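    // Fix the RNG seed so both training runs are reproducible.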
    manual_seed(42);
    let opts = TensorOptions::default();
println!("=== Phase 1: Train encoder ===");
    let encoder = FlowBuilder::from(Linear::new(4, 16)?)
        .through(GELU)
        .through(LayerNorm::new(16)?)
        .tag("encoded")
        .through(Linear::new(16, 4)?)
        .build()?;
    let params = encoder.parameters();
    let mut opt = Adam::new(&params, 0.01);
    encoder.train();
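    // Phase 1 loop: random inputs and targets stand in for a real dataset.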
    for epoch in 0..50 {
        let x = Tensor::randn(&[32, 4], opts)?;
        let y = Tensor::randn(&[32, 4], opts)?;
        let input = Variable::new(x, true);
        let target = Variable::new(y, false);
        opt.zero_grad();
        let pred = encoder.forward(&input)?;
        let loss = mse_loss(&pred, &target)?;
        loss.backward()?;
        clip_grad_norm(&params, 1.0)?;
        opt.step()?;
        if epoch % 10 == 0 {
            println!(" epoch {:>3}: loss={:.4}", epoch, loss.item()?);
        }
    }
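    // Persist every named parameter and buffer to a checkpoint file on disk.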
    let ckpt_path = "pretrained_encoder.fdl";
    let named = encoder.named_parameters();
    let named_bufs = encoder.named_buffers();
    save_checkpoint_file(ckpt_path, &named, &named_bufs, None)?;
    println!("Encoder saved to {}", ckpt_path);
println!("\n=== Phase 2: Transfer to new architecture ===");
    let model = FlowBuilder::from(Linear::new(4, 16)?)
        .through(GELU)
        .through(LayerNorm::new(16)?)
        .tag("encoded")
        .through(Linear::new(16, 16)?)
        .through(Linear::new(16, 2)?)
        .build()?;
    let named2 = model.named_parameters();
    let named_bufs2 = model.named_buffers();
    let report = load_checkpoint_file(ckpt_path, &named2, &named_bufs2, None)?;
    println!(
        "Loaded {} parameters, skipped {}, missing {}",
        report.loaded.len(),
        report.skipped.len(),
        report.missing.len()
    );
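    // Freeze the pretrained encoder tensors by their position in parameters().
    // Positional freezing is brittle; freezing by name is safer if the flow
    // layout ever changes.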
    let all_params = model.parameters();
    for (i, p) in all_params.iter().enumerate() {
        if i < 3 {
            p.freeze()?;
        }
    }
    println!("Encoder layers frozen");
    let trainable: Vec<Parameter> = all_params
        .iter()
        .filter(|p| !p.is_frozen())
        .cloned()
        .collect();
    let mut opt2 = Adam::new(&trainable, 0.005);
    model.train();
println!("\n=== Phase 3: Fine-tune new head ===");
    for epoch in 0..50 {
        let x = Tensor::randn(&[32, 4], opts)?;
        let y = Tensor::randn(&[32, 2], opts)?;
        let input = Variable::new(x, true);
        let target = Variable::new(y, false);
        opt2.zero_grad();
        let pred = model.forward(&input)?;
        let loss = mse_loss(&pred, &target)?;
        loss.backward()?;
        clip_grad_norm(&trainable, 1.0)?;
        opt2.step()?;
        if epoch % 10 == 0 {
            println!(" epoch {:>3}: loss={:.4}", epoch, loss.item()?);
        }
    }
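    // Unfreeze everything and continue with a lower learning rate (0.001 vs
    // 0.005), the usual pattern for letting pretrained weights adapt gently.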
println!("\n=== Phase 4: Full fine-tune (unfrozen) ===");
for p in &all_params {
p.unfreeze()?;
}
let mut opt3 = Adam::new(&all_params, 0.001); for epoch in 0..30 {
        let x = Tensor::randn(&[32, 4], opts)?;
        let y = Tensor::randn(&[32, 2], opts)?;
        let input = Variable::new(x, true);
        let target = Variable::new(y, false);
        opt3.zero_grad();
        let pred = model.forward(&input)?;
        let loss = mse_loss(&pred, &target)?;
        loss.backward()?;
        clip_grad_norm(&all_params, 1.0)?;
        opt3.step()?;
        if epoch % 10 == 0 {
            println!(" epoch {:>3}: loss={:.4}", epoch, loss.item()?);
        }
    }
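    // Wrap-up: report completion and delete the temporary checkpoint file.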
println!("\nTransfer learning complete.");
std::fs::remove_file(ckpt_path).ok();
Ok(())
}