#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use std::sync::{Arc, Mutex};
use candle_core::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{LayerNorm, Linear};
use mistralrs_quant::{MatMul, ShardedVarBuilder};
use crate::{
layers::{self, layer_norm, GetFloatInfo},
layers_masker::masked_fill,
utils::unvarbuilder::UnVarBuilder,
};
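/// Default (height, width) patch grid used to pre-compute the cached 2D sin/cos position embedding.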
const DEFAULT_MAX_SIZE: (usize, usize) = (70, 70);
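/// Build a 2D sin/cos position embedding of shape `(grid_h, grid_w, embed_dim)` for the given
/// grid size, on the requested device and dtype.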
fn get_2d_sincos_pos_embed(
embed_dim: usize,
image_size: (usize, usize),
device: &Device,
dtype: DType,
) -> Result<Tensor> {
let (grid_h_size, grid_w_size) = image_size;
let grid_h = Tensor::arange(0f32, grid_h_size as f32, device)?;
let grid_w = Tensor::arange(0f32, grid_w_size as f32, device)?;
let grid = Tensor::meshgrid(&[grid_w, grid_h], true)?;
let grid = Tensor::stack(&grid, 0)?;
get_2d_sincos_pos_embed_from_grid(embed_dim, &grid)?.to_dtype(dtype)
}
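/// Split `embed_dim` evenly between the height and width axes of the grid and concatenate the
/// two 1D embeddings along the last dimension.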
fn get_2d_sincos_pos_embed_from_grid(embed_dim: usize, grid: &Tensor) -> Result<Tensor> {
assert_eq!(embed_dim % 2, 0);
let emb_h = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, &grid.i(0)?)?;
let emb_w = get_1d_sincos_pos_embed_from_grid_new(embed_dim / 2, &grid.i(1)?)?;
Tensor::cat(&[emb_h, emb_w], D::Minus1)
}
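/// 1D sin/cos embedding: scale each position by a geometric sequence of inverse frequencies,
/// then concatenate the sine and cosine of the result along the last dimension.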
fn get_1d_sincos_pos_embed_from_grid_new(embed_dim: usize, pos: &Tensor) -> Result<Tensor> {
let inv_freq: Vec<_> = (0..embed_dim)
.step_by(2)
.map(|i| 1f32 / 10_000f32.powf(i as f32 / embed_dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
let omega = Tensor::from_vec(inv_freq, (1, inv_freq_len), pos.device())?;
let (h, w) = pos.dims2()?;
    let mut out = pos
        .reshape(((), 1))?
        .matmul(&omega.reshape((1, ()))?)?;
out = out.reshape((h, w, ()))?;
let emb_sin = out.sin()?;
let emb_cos = out.cos()?;
Tensor::cat(&[emb_sin, emb_cos], D::Minus1)
}
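/// Cached 2D sin/cos position embedding together with the grid size it was computed for; the
/// cache is grown lazily in `Resampler::forward` when a larger target grid is encountered.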
struct SinCos2dPosEmbed {
pos_embed: Tensor,
max_size: (usize, usize),
}
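/// Cross-attention resampler: a fixed set of learned queries attends over a variable number of
/// image patch features and compresses them into `num_queries` output embeddings per image.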
pub struct Resampler {
query: Tensor,
kv_proj: Option<Linear>,
proj: Tensor,
ln_q: LayerNorm,
ln_kv: LayerNorm,
ln_post: LayerNorm,
attn: MultiheadAttention,
sincos_pos_embed: Arc<Mutex<SinCos2dPosEmbed>>,
embed_dim: usize,
}
impl Resampler {
pub fn new(
num_queries: usize,
embed_dim: usize,
num_heads: usize,
kv_dim: usize,
_adaptive: bool,
max_size: Option<(usize, usize)>,
vb: ShardedVarBuilder,
) -> Result<Self> {
let max_size = max_size.unwrap_or(DEFAULT_MAX_SIZE);
let query = vb.get((num_queries, embed_dim), "query")?;
let kv_proj = if kv_dim != embed_dim {
Some(layers::linear_no_bias(kv_dim, embed_dim, vb.pp("kv_proj"))?)
} else {
None
};
let ln_q = layer_norm(embed_dim, 1e-6, vb.pp("ln_q"))?;
let ln_kv = layer_norm(embed_dim, 1e-6, vb.pp("ln_kv"))?;
let ln_post = layer_norm(embed_dim, 1e-6, vb.pp("ln_post"))?;
let proj = vb.get((embed_dim, embed_dim), "proj")?;
let attn = MultiheadAttention::new(embed_dim, num_heads, vb.pp("attn"))?;
let pos_embed = Arc::new(Mutex::new(SinCos2dPosEmbed {
pos_embed: get_2d_sincos_pos_embed(embed_dim, max_size, vb.device(), vb.dtype())?,
max_size,
}));
Ok(Self {
query,
kv_proj,
proj,
ln_q,
ln_kv,
ln_post,
attn,
sincos_pos_embed: pos_embed,
embed_dim,
})
}
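    /// `x` holds the image patch features of shape `(bs, max_patch_len, kv_dim)`, padded along
    /// the patch dimension; `tgt_sizes_vec[i]` is the `(h, w)` patch grid of image `i`. Padded
    /// positions are excluded from attention via a key padding mask.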
pub fn forward(&self, x: &Tensor, tgt_sizes_vec: &[Vec<u32>]) -> Result<Tensor> {
let mut pos_embed_cache = self.sincos_pos_embed.lock().unwrap();
let bs = x.dim(0)?;
let device = x.device();
assert_eq!(bs, tgt_sizes_vec.len());
let tgt_sizes_vec_0 = tgt_sizes_vec.iter().map(|x| x[0]).collect::<Vec<_>>();
let tgt_sizes_vec_1 = tgt_sizes_vec.iter().map(|x| x[1]).collect::<Vec<_>>();
let patch_len = tgt_sizes_vec_0
.iter()
.zip(&tgt_sizes_vec_1)
.map(|(x, y)| x * y)
.collect::<Vec<_>>();
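        // Grow the cached 2D sin/cos position embedding if any image's patch grid exceeds the
        // cached grid size.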
{
let max_h = *tgt_sizes_vec_0.iter().max().unwrap() as usize;
let max_w = *tgt_sizes_vec_1.iter().max().unwrap() as usize;
if max_h > pos_embed_cache.max_size.0 || max_w > pos_embed_cache.max_size.1 {
pos_embed_cache.max_size = (
max_h.max(pos_embed_cache.max_size.0),
max_w.max(pos_embed_cache.max_size.1),
);
pos_embed_cache.pos_embed = get_2d_sincos_pos_embed(
self.embed_dim,
pos_embed_cache.max_size,
device,
x.dtype(),
)?;
}
}
let max_patch_len = *patch_len.iter().max().unwrap() as usize;
let mut key_padding_mask = Tensor::zeros((bs, max_patch_len), DType::U8, device)?;
let mut pos_embed = Vec::new();
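        // For each image, crop the cached position embedding to its patch grid and mark its
        // padding positions in the key padding mask.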
for (i, tgt_sizes_vec_i) in tgt_sizes_vec.iter().enumerate().take(bs) {
let (tgt_h, tgt_w) = (tgt_sizes_vec_i[0] as usize, tgt_sizes_vec_i[1] as usize);
pos_embed.push(
pos_embed_cache
.pos_embed
.i((..tgt_h, ..tgt_w, ..))?
.reshape((tgt_h * tgt_w, ()))?,
);
let n = patch_len[i] as usize;
if n != max_patch_len {
key_padding_mask = key_padding_mask.slice_assign(
&[i..i + 1, n..max_patch_len],
&Tensor::ones((1, max_patch_len - n), DType::U8, device)?,
)?;
}
}
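        // Pad the per-image position embeddings to a common length and stack them into a batch.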
let lens = pos_embed
.iter()
.map(|emb| emb.dim(0))
.collect::<Result<Vec<_>>>()?;
        let max_len = lens.into_iter().max().expect("no pixel values");
pos_embed = pos_embed
.into_iter()
.map(|emb| emb.pad_with_zeros(0, 0, max_len - emb.dim(0)?))
.collect::<Result<Vec<_>>>()?;
let pos_embed = Tensor::stack(&pos_embed, 0)?;
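        // Project keys/values into the embedding dimension if needed, then layer-normalize.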
let mut x = if let Some(kv_proj) = &self.kv_proj {
x.apply(kv_proj)?
} else {
x.clone()
};
x = x.apply(&self.ln_kv)?;
let q = self.query.apply(&self.ln_q)?;
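        // The learned queries attend over the image features; the position embedding is added to
        // the keys only.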
let mut out = self.attn.forward(
&self.repeat_q_bs(&q, bs)?,
&(&x + &pos_embed)?,
&x,
Some(key_padding_mask),
None,
)?;
out = out.apply(&self.ln_post)?;
out.broadcast_matmul(&self.proj)
}
fn repeat_q_bs(&self, q: &Tensor, n: usize) -> Result<Tensor> {
q.unsqueeze(0)?.repeat((n, 1, 1))
}
pub fn residual_tensors(&self) -> Vec<(String, Tensor)> {
let uvb = UnVarBuilder::new();
let uvb_attn = uvb.pp("attn");
uvb_attn.pp("out_proj").add(&self.attn.out_proj);
uvb_attn.add_tensor("in_proj_weight", self.attn.in_proj_weight.clone());
uvb_attn.add_tensor("in_proj_bias", self.attn.in_proj_bias.clone());
uvb.pp("ln_kv").add(&self.ln_kv);
uvb.pp("ln_post").add(&self.ln_post);
uvb.pp("ln_q").add(&self.ln_q);
uvb.add_tensor("proj", self.proj.clone());
uvb.add_tensor("query", self.query.clone());
uvb.to_safetensors()
}
}
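/// Multi-head attention that loads the packed `in_proj_weight`/`in_proj_bias` layout used by
/// PyTorch's `nn.MultiheadAttention` and splits it into separate q/k/v projections; the packed
/// tensors are kept around for `residual_tensors`.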
struct MultiheadAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
out_proj: Linear,
num_heads: usize,
head_dim: usize,
in_proj_weight: Tensor,
in_proj_bias: Tensor,
}
impl MultiheadAttention {
fn new(embed_dim: usize, num_heads: usize, vb: ShardedVarBuilder) -> Result<Self> {
let in_proj_bias = vb.get(embed_dim * 3, "in_proj_bias")?;
let in_proj_weight = vb.get((embed_dim * 3, embed_dim), "in_proj_weight")?;
let q_proj = Linear::new(
in_proj_weight.i(0..embed_dim)?,
Some(in_proj_bias.i(0..embed_dim)?),
);
let k_proj = Linear::new(
in_proj_weight.i(embed_dim..embed_dim * 2)?,
Some(in_proj_bias.i(embed_dim..embed_dim * 2)?),
);
let v_proj = Linear::new(
in_proj_weight.i(embed_dim * 2..embed_dim * 3)?,
Some(in_proj_bias.i(embed_dim * 2..embed_dim * 3)?),
);
let out_proj = layers::linear(embed_dim, embed_dim, vb.pp("out_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
out_proj,
num_heads,
head_dim: embed_dim / num_heads,
in_proj_weight,
in_proj_bias,
})
}
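    /// Nonzero entries in `key_padding_mask` / `attn_mask` mark positions that are masked out of
    /// the attention scores.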
fn forward(
&self,
q: &Tensor,
k: &Tensor,
v: &Tensor,
key_padding_mask: Option<Tensor>,
mut attn_mask: Option<Tensor>,
) -> Result<Tensor> {
let (bs, q_seq, _) = q.dims3()?;
let (_, kv_seq, _) = k.dims3()?;
let mut q = q.apply(&self.q_proj)?;
let mut k = k.apply(&self.k_proj)?;
let mut v = v.apply(&self.v_proj)?;
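        // Fold the key padding mask into the attention mask, replicated once per head.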
if let Some(mut key_padding_mask) = key_padding_mask {
key_padding_mask = key_padding_mask
.reshape((bs, 1, 1, kv_seq))?
.repeat((1, self.num_heads, 1, 1))?
.reshape((bs * self.num_heads, 1, kv_seq))?;
if let Some(attn_mask) = attn_mask.as_mut() {
*attn_mask = attn_mask.broadcast_add(&key_padding_mask)?;
} else {
attn_mask = Some(key_padding_mask);
}
}
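        // Split heads: (bs, seq, embed_dim) -> (bs, num_heads, seq, head_dim).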
q = q
.reshape((bs, q_seq, self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
k = k
.reshape((bs, kv_seq, self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
v = v
.reshape((bs, kv_seq, self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
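        // Scaled dot-product attention; masked positions are filled with the dtype minimum
        // before the softmax.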
let mut y = {
let mut att =
MatMul.matmul_affine_mul(&q, &k.t()?, (1. / self.head_dim as f64).sqrt())?;
att = match attn_mask {
Some(mask) => {
let mask = mask.reshape((bs, self.num_heads, (), kv_seq))?;
masked_fill(&att, &mask, att.dtype().finfo()?.min as f32)?
}
None => att,
};
att = candle_nn::ops::softmax_last_dim(&att)?;
MatMul.matmul(&att, &v)?
};
y = y.transpose(1, 2)?.reshape((bs, q_seq, ()))?;
y.apply(&self.out_proj)
}
}