use candle::{DType, Device, Result, Tensor};
use super::transformer::SEQ_MULTI_OF;
/// Model-ready inputs assembled by [`prepare_inputs`]: the latent tensor plus
/// padded caption features with their validity mask and original lengths.
#[derive(Debug, Clone)]
pub struct PreparedInputs {
/// Latent tensor with an extra axis inserted at dim 2 by `prepare_inputs`
/// (presumably a frame/time axis — confirm against the transformer's input layout).
pub latents: Tensor,
/// Per-sample text embeddings, zero-padded to a common length that is a
/// multiple of `SEQ_MULTI_OF`; shape `(batch, padded_len, dim)`.
pub cap_feats: Tensor,
/// U8 mask over `cap_feats`: 1 for real tokens, 0 for padding; shape `(batch, padded_len)`.
pub cap_mask: Tensor,
/// Unpadded sequence length of each sample, in batch order.
pub text_lengths: Vec<usize>,
}
/// Number of elements needed to round `ori_len` up to the next multiple of
/// `SEQ_MULTI_OF` (zero when `ori_len` is already aligned).
#[inline]
pub fn compute_padding_len(ori_len: usize) -> usize {
    match ori_len % SEQ_MULTI_OF {
        0 => 0,
        rem => SEQ_MULTI_OF - rem,
    }
}
/// Pads a batch of variable-length 2-D text embeddings (`(seq_len, dim)` each)
/// to a shared length — the batch maximum rounded up to a multiple of
/// `SEQ_MULTI_OF` — and builds a matching U8 validity mask.
///
/// Returns `(cap_feats, cap_mask, lengths)` where `cap_feats` is
/// `(batch, padded_len, dim)`, `cap_mask` is `(batch, padded_len)` with 1 for
/// real tokens and 0 for padding, and `lengths` holds each sample's original
/// sequence length.
///
/// # Errors
/// Fails if `text_embeddings` is empty, if any tensor is not 2-D, or if a
/// candle operation fails (e.g. mismatched feature dims at stack time).
pub fn pad_text_embeddings(
    text_embeddings: &[Tensor],
    pad_value: f32,
    device: &Device,
) -> Result<(Tensor, Tensor, Vec<usize>)> {
    if text_embeddings.is_empty() {
        candle::bail!("text_embeddings cannot be empty");
    }
    // Feature dim and dtype are taken from the first sample; the rest are
    // assumed to match (stack/cat will error otherwise).
    let first = &text_embeddings[0];
    let dim = first.dim(1)?;
    let dtype = first.dtype();

    let lengths = text_embeddings
        .iter()
        .map(|t| t.dim(0))
        .collect::<Result<Vec<_>>>()?;
    // Non-empty slice guaranteed above, so max() always yields a value.
    let max_len = lengths.iter().copied().max().unwrap();
    let target_len = max_len + compute_padding_len(max_len);

    let mut feats = Vec::with_capacity(text_embeddings.len());
    let mut masks = Vec::with_capacity(text_embeddings.len());
    for (emb, &len) in text_embeddings.iter().zip(lengths.iter()) {
        let tail = target_len - len;
        if tail == 0 {
            // Already at target length: keep the embedding and use an all-ones mask.
            feats.push(emb.clone());
            masks.push(Tensor::ones((len,), DType::U8, device)?);
        } else {
            // Append `tail` rows of pad_value, converted to the embedding dtype.
            let fill = Tensor::full(pad_value, (tail, dim), device)?.to_dtype(dtype)?;
            feats.push(Tensor::cat(&[emb, &fill], 0)?);
            let ones = Tensor::ones((len,), DType::U8, device)?;
            let zeros = Tensor::zeros((tail,), DType::U8, device)?;
            masks.push(Tensor::cat(&[&ones, &zeros], 0)?);
        }
    }

    let cap_feats = Tensor::stack(&feats, 0)?;
    let cap_mask = Tensor::stack(&masks, 0)?;
    Ok((cap_feats, cap_mask, lengths))
}
/// Assembles a [`PreparedInputs`] from raw latents and per-sample text
/// embeddings: pads the embeddings with zeros via [`pad_text_embeddings`]
/// and inserts an extra axis at dim 2 of the latents (presumably a
/// frame/time axis — confirm against the transformer's expected layout).
pub fn prepare_inputs(
    latents: &Tensor,
    text_embeddings: &[Tensor],
    device: &Device,
) -> Result<PreparedInputs> {
    let (cap_feats, cap_mask, text_lengths) = pad_text_embeddings(text_embeddings, 0.0, device)?;
    let latents = latents.unsqueeze(2)?;
    Ok(PreparedInputs {
        latents,
        cap_feats,
        cap_mask,
        text_lengths,
    })
}
pub fn create_attention_mask(
valid_len: usize,
total_len: usize,
device: &Device,
) -> Result<Tensor> {
let valid = Tensor::ones((valid_len,), DType::U8, device)?;
if valid_len < total_len {
let invalid = Tensor::zeros((total_len - valid_len,), DType::U8, device)?;
Tensor::cat(&[&valid, &invalid], 0)
} else {
Ok(valid)
}
}
/// Replicates a single `(seq_len, dim)` text embedding across a batch,
/// returning a contiguous `(batch_size, seq_len, dim)` tensor.
///
/// # Errors
/// Fails if `text_embedding` is not 2-D or if broadcasting fails.
pub fn batch_text_embedding(text_embedding: &Tensor, batch_size: usize) -> Result<Tensor> {
    let (seq_len, dim) = text_embedding.dims2()?;
    // Add a leading batch axis, then broadcast it out to batch_size copies.
    let with_batch_axis = text_embedding.unsqueeze(0)?;
    let broadcast = with_batch_axis.broadcast_as((batch_size, seq_len, dim))?;
    // Materialize the broadcast view so downstream ops get contiguous memory.
    broadcast.contiguous()
}
/// Replicates a 1-D mask of length `seq_len` across a batch, returning a
/// contiguous `(batch_size, seq_len)` tensor.
///
/// # Errors
/// Fails if the dim lookup or broadcasting fails.
pub fn batch_mask(mask: &Tensor, batch_size: usize) -> Result<Tensor> {
    let seq_len = mask.dim(0)?;
    // Insert a batch axis, broadcast, and materialize as contiguous memory.
    let with_batch_axis = mask.unsqueeze(0)?;
    let broadcast = with_batch_axis.broadcast_as((batch_size, seq_len))?;
    broadcast.contiguous()
}