use burn::module::{Param, ParamId};
use burn::nn::Linear;
use burn::prelude::*;
use crate::model::linear_zeros;
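
/// Three-layer MLP head: Linear -> ReLU -> Linear -> ReLU -> Linear.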
#[derive(Module, Debug)]
pub struct MLPHead<B: Backend> {
    pub lin1: Linear<B>,
    pub lin2: Linear<B>,
    pub lin3: Linear<B>,
}

impl<B: Backend> MLPHead<B> {
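    /// All three layers start zero-initialized via `linear_zeros` (real
    /// weights are presumably loaded from a checkpoint afterwards).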
    pub fn new(
        in_features: usize,
        hidden_dim: usize,
        out_features: usize,
        device: &B::Device,
    ) -> Self {
        Self {
            lin1: linear_zeros(in_features, hidden_dim, true, device),
            lin2: linear_zeros(hidden_dim, hidden_dim, true, device),
            lin3: linear_zeros(hidden_dim, out_features, true, device),
        }
    }
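
    /// `[batch, in_features]` -> `[batch, out_features]`.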
    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = self.lin1.forward(x);
        let x = burn::tensor::activation::relu(x);
        let x = self.lin2.forward(x);
        let x = burn::tensor::activation::relu(x);
        self.lin3.forward(x)
    }
}
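
/// Classification head: mean-pools the encoder tokens, normalizes with
/// `fc_norm` (LayerNorm), then projects to `num_classes` logits.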
#[derive(Module, Debug)]
pub struct ClassificationHead<B: Backend> {
    pub fc_norm: burn::nn::LayerNorm<B>,
    pub head: Linear<B>,
    pub num_classes: usize,
}

impl<B: Backend> ClassificationHead<B> {
    pub fn new(embed_dim: usize, num_classes: usize, device: &B::Device) -> Self {
        Self {
            fc_norm: burn::nn::LayerNormConfig::new(embed_dim)
                .with_epsilon(1e-6)
                .init(device),
            head: linear_zeros(embed_dim, num_classes, true, device),
            num_classes,
        }
    }
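
    /// `[batch, tokens, embed_dim]` -> `[batch, num_classes]`.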
    pub fn forward(&self, encoder_output: Tensor<B, 3>) -> Tensor<B, 2> {
        let [b, _n, d] = encoder_output.dims();
        // `mean_dim` keeps the reduced axis ([B, N, D] -> [B, 1, D]), so
        // reshape it away to get the pooled [B, D] representation.
        let pooled = encoder_output.mean_dim(1).reshape([b, d]);
        let normed = self.fc_norm.forward(pooled);
        self.head.forward(normed)
    }
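
    /// Loads `head` and `fc_norm` parameters from `wm` under `prefix`;
    /// entries missing from the weight map keep their initial values.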
    pub fn load_weights(
        &mut self,
        wm: &mut crate::weights::WeightMap,
        prefix: &str,
        device: &B::Device,
    ) -> anyhow::Result<()> {
        if wm.has(&format!("{prefix}.head.weight")) {
            let w: Tensor<B, 2> = wm.take(&format!("{prefix}.head.weight"), device)?;
            // Checkpoint weights use the PyTorch [out, in] layout; Burn's
            // `Linear` stores [in, out], hence the transpose.
            self.head.weight = Param::initialized(ParamId::new(), w.transpose());
        }
        if wm.has(&format!("{prefix}.head.bias")) {
            let b: Tensor<B, 1> = wm.take(&format!("{prefix}.head.bias"), device)?;
            self.head.bias = Some(Param::initialized(ParamId::new(), b));
        }
        if wm.has(&format!("{prefix}.fc_norm.weight")) {
            let w: Tensor<B, 1> = wm.take(&format!("{prefix}.fc_norm.weight"), device)?;
            self.fc_norm.gamma = Param::initialized(ParamId::new(), w);
        }
        if wm.has(&format!("{prefix}.fc_norm.bias")) {
            let b: Tensor<B, 1> = wm.take(&format!("{prefix}.fc_norm.bias"), device)?;
            // Unlike `Linear::bias`, `LayerNorm::beta` is a plain `Param`,
            // not an `Option`.
            self.fc_norm.beta = Param::initialized(ParamId::new(), b);
        }
        Ok(())
    }
}
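
/// Argmax over the class dimension: `[batch, num_classes]` logits to
/// `[batch]` predicted class indices.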
pub fn predict_classes<B: Backend>(logits: Tensor<B, 2>) -> Tensor<B, 1, Int> {
    let [b, _c] = logits.dims();
    // `argmax` keeps the reduced axis ([B, C] -> [B, 1]); reshape to [B].
    logits.argmax(1).reshape([b])
}
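
// A minimal shape-check sketch; it assumes `burn`'s `ndarray` feature is
// enabled so `burn::backend::NdArray` is available. Swap in whichever
// backend this crate actually compiles with.
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    #[test]
    fn classification_head_shapes() {
        let device = Default::default();
        let head = ClassificationHead::<NdArray>::new(16, 10, &device);
        // Two sequences of five tokens each, embedding dim 16.
        let tokens = Tensor::<NdArray, 3>::zeros([2, 5, 16], &device);
        let logits = head.forward(tokens);
        assert_eq!(logits.dims(), [2, 10]);
        assert_eq!(predict_classes(logits).dims(), [2]);
    }
}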