use crate::tensor::Tensor;

/// A projection head: a stack of linear layers, each optionally
/// followed by layer normalization and a GELU activation.
#[derive(Debug, Clone)]
pub struct Projector {
    pub layers: Vec<ProjectorLayer>,
}

/// A single projector layer: `y = x @ weight + bias`, then optional
/// layer norm (`ln_weight`/`ln_bias`) and optional GELU.
#[derive(Debug, Clone)]
pub struct ProjectorLayer {
    /// Weight matrix of shape `[in_dim, out_dim]`.
    pub weight: Tensor,
    /// Bias vector of shape `[out_dim]`.
    pub bias: Tensor,
    /// Layer-norm scale, present only if this layer is normalized.
    pub ln_weight: Option<Tensor>,
    /// Layer-norm shift, present only if this layer is normalized.
    pub ln_bias: Option<Tensor>,
    /// Apply GELU after the affine transform (and layer norm, if any).
    pub has_activation: bool,
}
impl Projector {
    /// Builds a single-layer linear projector with zero-initialized
    /// weights, no normalization, and no activation.
    pub fn new_linear(in_dim: usize, out_dim: usize) -> Self {
        Self {
            layers: vec![ProjectorLayer {
                weight: Tensor::zeros(&[in_dim, out_dim]),
                bias: Tensor::zeros(&[out_dim]),
                ln_weight: None,
                ln_bias: None,
                has_activation: false,
            }],
        }
    }
    /// Runs the projector over the last dimension of `x`.
    ///
    /// `x` may have any number of leading (batch) dimensions; they are
    /// flattened into a single batch axis for the matmuls and restored
    /// on the output, so shape `[.., in_dim]` maps to `[.., out_dim]`.
    pub fn forward(&self, x: &Tensor) -> Tensor {
        let nd = x.ndim();
        let d = *x.shape.last().unwrap();
        // Flatten all leading dims into a single batch axis: [batch, d].
        let batch: usize = x.shape[..nd - 1].iter().product();
        let batch_shape = x.shape[..nd - 1].to_vec();
        let mut current = x.reshape(&[batch, d]);
        for layer in &self.layers {
            // Affine transform, then optional layer norm, then optional GELU.
            current = current.matmul(&layer.weight).add_bias(&layer.bias);
            if let (Some(w), Some(b)) = (&layer.ln_weight, &layer.ln_bias) {
                current = current.layer_norm(w, b, 1e-5);
            }
            if layer.has_activation {
                current = current.gelu();
            }
        }
        // Restore the original leading dims with the new feature dim.
        let out_dim = *current.shape.last().unwrap();
        let mut out_shape = batch_shape;
        out_shape.push(out_dim);
        current.reshape(&out_shape)
    }
}
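
// A minimal usage sketch, assuming this crate's `Tensor::zeros` and a
// public `shape` field indexable as a slice (both as used above). The
// test names and shapes are illustrative, not part of the original module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn linear_projector_preserves_batch_shape() {
        // A single linear layer projecting 4-dim features to 8-dim.
        let proj = Projector::new_linear(4, 8);
        // Input with batch shape [2, 3] and feature dim 4.
        let x = Tensor::zeros(&[2, 3, 4]);
        let y = proj.forward(&x);
        // Leading (batch) dims are preserved; only the last dim changes.
        assert_eq!(&y.shape[..], &[2, 3, 8]);
    }

    #[test]
    fn two_layer_mlp_projector() {
        // Layers can also be built by hand; here a hypothetical
        // 4 -> 16 -> 8 MLP with layer norm and GELU on the hidden layer.
        let proj = Projector {
            layers: vec![
                ProjectorLayer {
                    weight: Tensor::zeros(&[4, 16]),
                    bias: Tensor::zeros(&[16]),
                    ln_weight: Some(Tensor::zeros(&[16])),
                    ln_bias: Some(Tensor::zeros(&[16])),
                    has_activation: true,
                },
                ProjectorLayer {
                    weight: Tensor::zeros(&[16, 8]),
                    bias: Tensor::zeros(&[8]),
                    ln_weight: None,
                    ln_bias: None,
                    has_activation: false,
                },
            ],
        };
        let y = proj.forward(&Tensor::zeros(&[5, 4]));
        assert_eq!(&y.shape[..], &[5, 8]);
    }
}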