use crate::tensor::Tensor;
use crate::config::SubjectLayersConfig;
use super::subject_layers::SubjectLayers;
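/// A projection head that is either a plain (possibly multi-layer) perceptron or a
/// set of subject-conditioned `SubjectLayers`.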
#[derive(Debug, Clone)]
pub enum Projector {
Mlp(MlpProjector),
SubjectLayers(SubjectLayers),
}
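/// A stack of linear layers with optional layer normalization and GELU activations.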
#[derive(Debug, Clone)]
pub struct MlpProjector {
pub layers: Vec<ProjectorLayer>,
}
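/// One layer of an `MlpProjector`: a linear weight and bias, optional layer-norm
/// parameters, and a flag controlling whether a GELU activation is applied.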
#[derive(Debug, Clone)]
pub struct ProjectorLayer {
pub weight: Tensor,
pub bias: Tensor,
pub ln_weight: Option<Tensor>,
pub ln_bias: Option<Tensor>,
pub has_activation: bool,
}
impl Projector {
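/// Builds a single zero-initialized linear projection from `in_dim` to `out_dim`.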
pub fn new_linear(in_dim: usize, out_dim: usize) -> Self {
Self::Mlp(MlpProjector::new_linear(in_dim, out_dim))
}
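/// Builds an MLP projection with one hidden layer per entry in `hidden_sizes`;
/// hidden layers use GELU and, when `has_norm` is set, layer normalization. A
/// final linear layer maps to `out_dim`.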
pub fn new_mlp(
in_dim: usize,
out_dim: usize,
hidden_sizes: &[usize],
has_norm: bool,
) -> Self {
Self::Mlp(MlpProjector::new_mlp(in_dim, out_dim, hidden_sizes, has_norm))
}
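/// Builds a subject-conditioned projection backed by `SubjectLayers`, mapping
/// `in_channels` to `out_channels` according to `config`.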
pub fn new_subject_layers(
in_channels: usize,
out_channels: usize,
config: &SubjectLayersConfig,
) -> Self {
Self::SubjectLayers(SubjectLayers::new(in_channels, out_channels, config))
}
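/// Applies the projector without subject information (subject ids are `None`).
/// For the `SubjectLayers` variant the input must have at least two dimensions;
/// a 3-D input has its last two axes swapped before and after the subject layers
/// so that `SubjectLayers` sees the expected channel axis.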
pub fn forward(&self, x: &Tensor) -> Tensor {
match self {
Self::Mlp(mlp) => mlp.forward(x),
Self::SubjectLayers(sl) => {
let nd = x.ndim();
assert!(nd >= 2);
if nd == 3 {
let perm = x.permute(&[0, 2, 1]);
let out = sl.forward(&perm, None);
out.permute(&[0, 2, 1])
} else {
sl.forward(x, None)
}
}
}
}
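/// Like `forward`, but passes `subject_ids` through to the `SubjectLayers`
/// variant; the MLP variant ignores them.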
pub fn forward_with_subjects(&self, x: &Tensor, subject_ids: Option<&[usize]>) -> Tensor {
match self {
Self::Mlp(mlp) => mlp.forward(x),
Self::SubjectLayers(sl) => {
let nd = x.ndim();
if nd == 3 {
let perm = x.permute(&[0, 2, 1]);
let out = sl.forward(&perm, subject_ids);
out.permute(&[0, 2, 1])
} else {
sl.forward(x, subject_ids)
}
}
}
}
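/// Returns a mutable reference to the inner `MlpProjector`, if any.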
pub fn as_mlp_mut(&mut self) -> Option<&mut MlpProjector> {
match self {
Self::Mlp(mlp) => Some(mlp),
_ => None,
}
}
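/// Returns a mutable reference to the inner `SubjectLayers`, if any.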
pub fn as_subject_layers_mut(&mut self) -> Option<&mut SubjectLayers> {
match self {
Self::SubjectLayers(sl) => Some(sl),
_ => None,
}
}
}
impl MlpProjector {
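/// A single zero-initialized linear layer with no normalization or activation.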
pub fn new_linear(in_dim: usize, out_dim: usize) -> Self {
Self {
layers: vec![ProjectorLayer {
weight: Tensor::zeros(&[in_dim, out_dim]),
bias: Tensor::zeros(&[out_dim]),
ln_weight: None,
ln_bias: None,
has_activation: false,
}],
}
}
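/// Hidden layers are zero-initialized linear layers with GELU and, when
/// `has_norm` is set, layer-norm parameters (weights initialized to ones, biases
/// to zeros); the final layer is linear only.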
pub fn new_mlp(
in_dim: usize,
out_dim: usize,
hidden_sizes: &[usize],
has_norm: bool,
) -> Self {
let mut layers = Vec::new();
let mut prev_dim = in_dim;
for &h in hidden_sizes {
layers.push(ProjectorLayer {
weight: Tensor::zeros(&[prev_dim, h]),
bias: Tensor::zeros(&[h]),
ln_weight: if has_norm { Some(Tensor::ones(&[h])) } else { None },
ln_bias: if has_norm { Some(Tensor::zeros(&[h])) } else { None },
has_activation: true,
});
prev_dim = h;
}
layers.push(ProjectorLayer {
weight: Tensor::zeros(&[prev_dim, out_dim]),
bias: Tensor::zeros(&[out_dim]),
ln_weight: None,
ln_bias: None,
has_activation: false,
});
Self { layers }
}
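/// Flattens all leading dimensions into a single batch axis, applies each layer
/// as linear -> optional layer norm -> optional GELU, then restores the leading
/// dimensions with the new feature size as the last axis.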
pub fn forward(&self, x: &Tensor) -> Tensor {
let nd = x.ndim();
let d = *x.shape.last().unwrap();
let batch: usize = x.shape[..nd - 1].iter().product();
let batch_shape = x.shape[..nd - 1].to_vec();
let mut current = x.reshape(&[batch, d]);
for layer in &self.layers {
current = current.matmul(&layer.weight).add_bias(&layer.bias);
if let (Some(w), Some(b)) = (&layer.ln_weight, &layer.ln_bias) {
current = current.layer_norm(w, b, 1e-5);
}
if layer.has_activation {
current = current.gelu();
}
}
let out_dim = *current.shape.last().unwrap();
let mut out_shape = batch_shape;
out_shape.push(out_dim);
current.reshape(&out_shape)
}
}
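// A minimal usage sketch, written as a test module. It relies only on items
// defined in this file plus `Tensor::zeros` and the `shape` field already used
// above, and it checks shape propagation only: the zero-initialized weights make
// the output values trivial, so no numerical behavior is asserted.
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn mlp_forward_keeps_leading_dims() {
// [batch, time, features] input through a 64 -> 128 -> 32 MLP with layer norm.
let proj = Projector::new_mlp(64, 32, &[128], true);
let x = Tensor::zeros(&[2, 10, 64]);
let y = proj.forward(&x);
assert_eq!(&y.shape[..], &[2, 10, 32]);
}

#[test]
fn linear_projector_maps_last_dim() {
// A plain 2-D input through the single-layer linear projector.
let proj = Projector::new_linear(16, 8);
let x = Tensor::zeros(&[4, 16]);
let y = proj.forward(&x);
assert_eq!(&y.shape[..], &[4, 8]);
}
}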