use crate::blip2::config::{Blip2Config, Blip2QFormerConfig, Blip2TextConfig, Blip2VisionConfig};
use trustformers_core::{
device::Device,
kernels::fused_ops::ActivationType,
layers::{
attention::{AttentionConfig, MultiHeadAttention},
embedding::Embedding,
layernorm::LayerNorm,
linear::Linear,
},
tensor::{DType, Tensor},
traits::Layer,
};
/// Core BLIP-2 model: vision encoder + Q-Former bridge + feature projections.
#[derive(Debug, Clone)]
pub struct Blip2Model {
    pub config: Blip2Config,
    pub vision_model: Blip2VisionModel,
    pub qformer_model: Blip2QFormerModel,
    /// Projects Q-Former outputs into the vision feature space.
    pub vision_projection: Linear,
    /// Projects Q-Former outputs into the text feature space.
    pub text_projection: Linear,
    /// Learnable query tokens, shape `[num_query_tokens, qformer_hidden]`.
    pub query_tokens: Tensor,
    pub query_layer_norm: LayerNorm,
    device: Device,
}
impl Blip2Model {
    /// Constructs the full BLIP-2 stack (vision encoder, Q-Former, projection
    /// heads, learnable query tokens) for the given device.
    pub fn new_with_device(
        config: Blip2Config,
        device: Device,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let vision_model = Blip2VisionModel::new_with_device(config.vision_config.clone(), device)?;
        let qformer_model =
            Blip2QFormerModel::new_with_device(config.qformer_config.clone(), device)?;
        // Maps Q-Former hidden size -> vision hidden size; applied to the
        // Q-Former pooler output in `forward`/`get_image_features`.
        let vision_projection = Linear::new(
            config.qformer_config.hidden_size,
            config.vision_config.hidden_size,
            false,
        );
        // Maps Q-Former hidden size -> text-model hidden size.
        let text_projection = Linear::new(
            config.qformer_config.hidden_size,
            config.text_config.hidden_size,
            false,
        );
        // Randomly initialized learnable queries; expanded per batch in
        // `get_query_embeddings`.
        let query_tokens =
            Tensor::randn(&[config.num_query_tokens, config.qformer_config.hidden_size])?;
        let query_layer_norm = LayerNorm::new(
            vec![config.qformer_config.hidden_size],
            config.qformer_config.layer_norm_eps as f32,
        )?;
        Ok(Self {
            config,
            vision_model,
            qformer_model,
            vision_projection,
            text_projection,
            query_tokens,
            query_layer_norm,
            device,
        })
    }
    /// CPU convenience constructor.
    pub fn new(config: Blip2Config) -> Result<Self, Box<dyn std::error::Error>> {
        Self::new_with_device(config, Device::CPU)
    }
    /// Device this model was constructed for.
    pub fn device(&self) -> Device {
        self.device
    }
    /// Joint forward pass: encodes `pixel_values` (if given), runs the Q-Former
    /// over text + query tokens, and projects the pooled output into the image
    /// and text feature spaces.
    pub fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
    ) -> Result<Blip2Output, Box<dyn std::error::Error>> {
        let batch_size = input_ids.shape()[0];
        let image_embeds = if let Some(pixel_values) = pixel_values {
            Some(self.vision_model.forward(pixel_values)?)
        } else {
            None
        };
        let query_embeds = self.get_query_embeddings(batch_size)?;
        let qformer_outputs = self.qformer_model.forward(
            input_ids,
            image_embeds.as_ref(),
            Some(&query_embeds),
            attention_mask,
        )?;
        // Image features are only produced when an image was actually encoded.
        let image_features = if image_embeds.is_some() {
            Some(self.vision_projection.forward(qformer_outputs.pooler_output.clone())?)
        } else {
            None
        };
        let text_features = self.text_projection.forward(qformer_outputs.pooler_output.clone())?;
        Ok(Blip2Output {
            last_hidden_state: qformer_outputs.last_hidden_state,
            pooler_output: qformer_outputs.pooler_output,
            image_features,
            text_features,
            logits: qformer_outputs.logits,
        })
    }
    /// Normalizes the learnable query tokens and broadcasts them to
    /// `[batch, num_query_tokens, hidden]`.
    fn get_query_embeddings(
        &self,
        batch_size: usize,
    ) -> Result<Tensor, Box<dyn std::error::Error>> {
        let query_embeds = self.query_layer_norm.forward(self.query_tokens.clone())?;
        let expanded_shape = vec![
            batch_size,
            self.config.num_query_tokens,
            self.config.qformer_config.hidden_size,
        ];
        let expanded_embeds = query_embeds.unsqueeze(0)?.broadcast_to(&expanded_shape)?;
        Ok(expanded_embeds)
    }
    /// Image-only feature extraction: vision encoder -> Q-Former (with a dummy
    /// single-token text input) -> vision projection of the pooled output.
    pub fn get_image_features(
        &self,
        pixel_values: &Tensor,
    ) -> Result<Tensor, Box<dyn std::error::Error>> {
        let image_embeds = self.vision_model.forward(pixel_values)?;
        let batch_size = pixel_values.shape()[0];
        let query_embeds = self.get_query_embeddings(batch_size)?;
        // Dummy text token; only the query tokens carry signal here.
        let input_ids = Tensor::zeros(&[batch_size, 1])?;
        let qformer_outputs = self.qformer_model.forward(
            &input_ids,
            Some(&image_embeds),
            Some(&query_embeds),
            None,
        )?;
        self.vision_projection
            .forward(qformer_outputs.pooler_output.clone())
            .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
    }
    /// Text-only feature extraction: Q-Former over text + queries, then the
    /// text projection of the pooled output.
    pub fn get_text_features(
        &self,
        input_ids: &Tensor,
    ) -> Result<Tensor, Box<dyn std::error::Error>> {
        let batch_size = input_ids.shape()[0];
        let query_embeds = self.get_query_embeddings(batch_size)?;
        let qformer_outputs =
            self.qformer_model.forward(input_ids, None, Some(&query_embeds), None)?;
        self.text_projection
            .forward(qformer_outputs.pooler_output.clone())
            .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
    }
}
/// BLIP-2 with a language-model head for conditional text generation.
pub struct Blip2ForConditionalGeneration {
    pub blip2_model: Blip2Model,
    /// Decoder-only (OPT-style) or encoder-decoder (T5-style) backend, chosen
    /// by `config.use_decoder_only_language_model`.
    pub language_model: Box<dyn LanguageModel>,
    /// Maps Q-Former hidden states into the language model's hidden space.
    pub language_projection: Linear,
    device: Device,
}
impl Blip2ForConditionalGeneration {
    /// Builds the BLIP-2 core plus the configured language-model backend.
    pub fn new_with_device(
        config: Blip2Config,
        device: Device,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let blip2_model = Blip2Model::new_with_device(config.clone(), device)?;
        // Backend selection: decoder-only (OPT) vs encoder-decoder (T5).
        let language_model: Box<dyn LanguageModel> = if config.use_decoder_only_language_model {
            Box::new(Blip2OptLanguageModel::new_with_device(
                config.text_config.clone(),
                device,
            )?)
        } else {
            Box::new(Blip2T5LanguageModel::new_with_device(
                config.text_config.clone(),
                device,
            )?)
        };
        let language_projection = Linear::new(
            config.qformer_config.hidden_size,
            config.text_config.hidden_size,
            false,
        );
        Ok(Self {
            blip2_model,
            language_model,
            language_projection,
            device,
        })
    }
    /// CPU convenience constructor.
    pub fn new(config: Blip2Config) -> Result<Self, Box<dyn std::error::Error>> {
        Self::new_with_device(config, Device::CPU)
    }
    /// Device this model was constructed for.
    pub fn device(&self) -> Device {
        self.device
    }
    /// Conditional-generation forward pass. When `pixel_values` is given, the
    /// image is encoded, passed through the Q-Former, projected, and handed to
    /// the language model as prefix hidden states.
    pub fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        labels: Option<&Tensor>,
    ) -> Result<Blip2ConditionalGenerationOutput, Box<dyn std::error::Error>> {
        let batch_size = input_ids.shape()[0];
        let image_features = if let Some(pixel_values) = pixel_values {
            let image_embeds = self.blip2_model.vision_model.forward(pixel_values)?;
            let query_embeds = self.blip2_model.get_query_embeddings(batch_size)?;
            // Dummy single-token text input; the queries carry the signal.
            let dummy_input_ids = Tensor::zeros(&[batch_size, 1])?;
            let qformer_outputs = self.blip2_model.qformer_model.forward(
                &dummy_input_ids,
                Some(&image_embeds),
                Some(&query_embeds),
                None,
            )?;
            let projected =
                self.language_projection.forward(qformer_outputs.last_hidden_state.clone())?;
            Some(projected)
        } else {
            None
        };
        let language_outputs = self.language_model.forward(
            input_ids,
            image_features.as_ref(),
            attention_mask,
            labels,
        )?;
        Ok(Blip2ConditionalGenerationOutput {
            loss: language_outputs.loss,
            logits: language_outputs.logits,
            hidden_states: language_outputs.hidden_states,
            image_features,
        })
    }
    /// Autoregressive generation conditioned on an image. Seeds from
    /// `input_ids` when given, otherwise from the BOS token; stops after
    /// `max_length` steps or on EOS.
    pub fn generate(
        &self,
        pixel_values: &Tensor,
        input_ids: Option<&Tensor>,
        max_length: usize,
        temperature: f32,
        top_p: f32,
    ) -> Result<Tensor, Box<dyn std::error::Error>> {
        let batch_size = pixel_values.shape()[0];
        let image_embeds = self.blip2_model.vision_model.forward(pixel_values)?;
        let query_embeds = self.blip2_model.get_query_embeddings(batch_size)?;
        let dummy_input_ids = Tensor::zeros(&[batch_size, 1])?;
        let qformer_outputs = self.blip2_model.qformer_model.forward(
            &dummy_input_ids,
            Some(&image_embeds),
            Some(&query_embeds),
            None,
        )?;
        let image_features =
            self.language_projection.forward(qformer_outputs.last_hidden_state.clone())?;
        let mut generated_ids = if let Some(input_ids) = input_ids {
            input_ids.clone()
        } else {
            Tensor::full(
                self.blip2_model.config.text_config.bos_token_id as f32,
                vec![batch_size, 1],
            )?
        };
        for _ in 0..max_length {
            let outputs =
                self.language_model.forward(&generated_ids, Some(&image_features), None, None)?;
            // Logits for the last position only.
            let next_token_logits = outputs.logits.select(1, -1)?;
            let next_token = self.sample_token(&next_token_logits, temperature, top_p)?;
            generated_ids = Tensor::concat(&[generated_ids, next_token.clone()], 1)?;
            if self.check_eos_token(&next_token)? {
                break;
            }
        }
        Ok(generated_ids)
    }
    /// Selects the next token.
    ///
    /// NOTE(review): this is effectively greedy decoding — the argmax of the
    /// softmax is unaffected by `temperature`, and `top_p` is ignored; true
    /// nucleus sampling is not implemented yet.
    fn sample_token(
        &self,
        logits: &Tensor,
        temperature: f32,
        _top_p: f32,
    ) -> Result<Tensor, Box<dyn std::error::Error>> {
        let scaled_logits = logits.div_scalar(temperature)?;
        let probabilities = scaled_logits.softmax(-1)?;
        let token_id = probabilities.argmax(-1)?;
        token_id
            .unsqueeze_i64(-1)
            .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
    }
    /// True when the sampled token equals the configured EOS id.
    ///
    /// NOTE(review): `item::<i32>()` reads a single scalar, so this assumes
    /// batch size 1 — confirm intended behavior for batched generation.
    fn check_eos_token(&self, token: &Tensor) -> Result<bool, Box<dyn std::error::Error>> {
        let token_id = token.item::<i32>()?;
        Ok(token_id == self.blip2_model.config.text_config.eos_token_id)
    }
}
/// ViT-style vision encoder used by BLIP-2.
#[derive(Debug, Clone)]
pub struct Blip2VisionModel {
    pub config: Blip2VisionConfig,
    pub patch_embedding: Blip2PatchEmbedding,
    /// Learnable CLS token embedding, shape `[hidden_size]`.
    pub class_embedding: Tensor,
    /// Learnable position table, shape `[seq_len, hidden_size]`.
    pub position_embedding: Tensor,
    pub layers: Vec<Blip2VisionTransformerLayer>,
    pub layer_norm: LayerNorm,
    pub pooler: Option<Linear>,
    device: Device,
}
impl Blip2VisionModel {
    /// Builds patch/class/position embeddings and the transformer stack.
    pub fn new_with_device(
        config: Blip2VisionConfig,
        device: Device,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let patch_embedding = Blip2PatchEmbedding::new_with_device(&config, device)?;
        let class_embedding = Tensor::randn(&[config.hidden_size])?;
        // One position row per token; config.seq_len() must equal
        // 1 (CLS) + number of patches for the add in `forward` to line up.
        let position_embedding = Tensor::randn(&[config.seq_len(), config.hidden_size])?;
        let mut layers = Vec::new();
        for _ in 0..config.num_hidden_layers {
            layers.push(Blip2VisionTransformerLayer::new_with_device(
                &config, device,
            )?);
        }
        let layer_norm = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps as f32)?;
        let pooler = Some(Linear::new(config.hidden_size, config.hidden_size, true));
        Ok(Self {
            config,
            patch_embedding,
            class_embedding,
            position_embedding,
            layers,
            layer_norm,
            pooler,
            device,
        })
    }
    /// CPU convenience constructor.
    pub fn new(config: Blip2VisionConfig) -> Result<Self, Box<dyn std::error::Error>> {
        Self::new_with_device(config, Device::CPU)
    }
    /// Device this model was constructed for.
    pub fn device(&self) -> Device {
        self.device
    }
    /// Encodes an image batch `[B, C, H, W]` into per-token hidden states.
    pub fn forward(&self, pixel_values: &Tensor) -> Result<Tensor, Box<dyn std::error::Error>> {
        let batch_size = pixel_values.shape()[0];
        let patch_embeds = self.patch_embedding.forward(pixel_values)?;
        // Prepend the CLS token to the patch sequence.
        let class_token = self.class_embedding.unsqueeze(0)?.unsqueeze(0)?;
        let class_token = class_token
            .broadcast_to(&[batch_size, 1, self.config.hidden_size])?
            .contiguous()?;
        let embeddings = Tensor::concat(&[class_token, patch_embeds], 1)?;
        let position_embeds = self.position_embedding.unsqueeze(0)?;
        let position_embeds = position_embeds
            .broadcast_to(&[batch_size, self.config.seq_len(), self.config.hidden_size])?
            .contiguous()?;
        let embeddings = embeddings.add(&position_embeds)?;
        let mut hidden_states = embeddings;
        for layer in &self.layers {
            hidden_states = layer.forward(&hidden_states)?;
        }
        let hidden_states = self.layer_norm.forward(hidden_states)?;
        if let Some(pooler) = &self.pooler {
            // Pool on the CLS token, then broadcast the pooled vector back over
            // every sequence position.
            // NOTE(review): returning the same pooled vector at all positions
            // discards per-patch detail — confirm downstream consumers expect
            // this rather than the raw token sequence.
            let cls_token = hidden_states.select(1, 0)?;
            let pooled = pooler.forward(cls_token)?;
            let pooled = pooled.tanh()?;
            let broadcasted = pooled.unsqueeze(1)?.broadcast_to(&[
                batch_size,
                hidden_states.shape()[1],
                self.config.hidden_size,
            ])?;
            Ok(broadcasted.contiguous()?)
        } else {
            Ok(hidden_states)
        }
    }
}
/// Splits an image into fixed-size patches and linearly projects each patch.
#[derive(Debug, Clone)]
pub struct Blip2PatchEmbedding {
    /// Linear map from a flattened patch (`C * p * p`) to the hidden size.
    pub projection: Linear,
    pub patch_size: usize,
    pub hidden_size: usize,
    device: Device,
}
impl Blip2PatchEmbedding {
    /// Builds the patch projection for the given vision configuration.
    pub fn new_with_device(
        config: &Blip2VisionConfig,
        device: Device,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let projection = Linear::new(
            config.patch_size * config.patch_size * config.num_channels,
            config.hidden_size,
            true,
        );
        Ok(Self {
            projection,
            patch_size: config.patch_size,
            hidden_size: config.hidden_size,
            device,
        })
    }
    /// CPU convenience constructor.
    pub fn new(config: &Blip2VisionConfig) -> Result<Self, Box<dyn std::error::Error>> {
        Self::new_with_device(config, Device::CPU)
    }
    /// Device this layer was constructed for.
    pub fn device(&self) -> Device {
        self.device
    }
    /// Turns `[B, C, H, W]` pixels into `[B, num_patches, hidden]` embeddings
    /// by extracting non-overlapping patches and projecting each one.
    pub fn forward(&self, pixel_values: &Tensor) -> Result<Tensor, Box<dyn std::error::Error>> {
        let batch_size = pixel_values.shape()[0];
        let channels = pixel_values.shape()[1];
        let height = pixel_values.shape()[2];
        let width = pixel_values.shape()[3];
        let num_patches_h = height / self.patch_size;
        let num_patches_w = width / self.patch_size;
        let num_patches = num_patches_h * num_patches_w;
        // [B, C, nH, p, nW, p]
        let patches = pixel_values.reshape(&[
            batch_size,
            channels,
            num_patches_h,
            self.patch_size,
            num_patches_w,
            self.patch_size,
        ])?;
        // Bring the patch-grid axes forward: [B, nH, nW, C, p, p].
        let patches = patches.permute(&[0, 2, 4, 1, 3, 5])?;
        // Materialize the permuted layout so the reshape below is valid. This
        // replaces the previous host round-trip through to_vec_f32/from_vec,
        // which copied every pixel to a Vec<f32> and back just to flatten.
        let patches = patches.contiguous()?.reshape(&[
            batch_size,
            num_patches,
            channels * self.patch_size * self.patch_size,
        ])?;
        self.projection
            .forward(patches)
            .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
    }
}
/// One pre-LayerNorm ViT encoder block (self-attention + MLP).
#[derive(Debug, Clone)]
pub struct Blip2VisionTransformerLayer {
    pub self_attention: MultiHeadAttention,
    pub layer_norm1: LayerNorm,
    pub mlp: Blip2MLP,
    pub layer_norm2: LayerNorm,
    device: Device,
}
impl Blip2VisionTransformerLayer {
    /// Builds a single pre-LayerNorm ViT encoder block.
    pub fn new_with_device(
        config: &Blip2VisionConfig,
        device: Device,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        // Collected attention hyperparameters (head_dim / max_seq_len are not
        // consumed by MultiHeadAttention::new but kept for documentation).
        let attn_cfg = AttentionConfig {
            hidden_size: config.hidden_size,
            num_heads: config.num_attention_heads,
            head_dim: config.hidden_size / config.num_attention_heads,
            dropout_prob: config.attention_dropout as f32,
            bias: true,
            max_seq_len: None,
        };
        let self_attention = MultiHeadAttention::new(
            attn_cfg.hidden_size,
            attn_cfg.num_heads,
            attn_cfg.dropout_prob,
            attn_cfg.bias,
        )?;
        let layer_norm1 = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps as f32)?;
        let mlp = Blip2MLP::new_with_device(
            config.hidden_size,
            config.intermediate_size,
            &config.hidden_act,
            device,
        )?;
        let layer_norm2 = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps as f32)?;
        Ok(Self {
            self_attention,
            layer_norm1,
            mlp,
            layer_norm2,
            device,
        })
    }
    /// CPU convenience constructor.
    pub fn new(config: &Blip2VisionConfig) -> Result<Self, Box<dyn std::error::Error>> {
        Self::new_with_device(config, Device::CPU)
    }
    /// Device this layer was constructed for.
    pub fn device(&self) -> Device {
        self.device
    }
    /// Pre-LN block: LN -> self-attention -> residual, LN -> MLP -> residual.
    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor, Box<dyn std::error::Error>> {
        let normed = self.layer_norm1.forward(hidden_states.clone())?;
        let attn_out =
            self.self_attention.forward_attention(&normed, &normed, &normed, None, false)?;
        let after_attn = hidden_states.add(&attn_out)?;
        let normed = self.layer_norm2.forward(after_attn.clone())?;
        let mlp_out = self.mlp.forward(&normed)?;
        Ok(after_attn.add(&mlp_out)?)
    }
}
/// Q-Former: a small transformer bridging image features to the text side.
#[derive(Debug, Clone)]
pub struct Blip2QFormerModel {
    pub config: Blip2QFormerConfig,
    pub embeddings: Blip2QFormerEmbeddings,
    pub encoder_layers: Vec<Blip2QFormerLayer>,
    pub pooler: Linear,
    device: Device,
}
impl Blip2QFormerModel {
    /// Builds the Q-Former: embeddings, encoder stack (cross-attention on every
    /// `cross_attention_frequency`-th layer), and a pooler head.
    pub fn new_with_device(
        config: Blip2QFormerConfig,
        device: Device,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let embeddings = Blip2QFormerEmbeddings::new_with_device(&config, device)?;
        let mut encoder_layers = Vec::new();
        for layer_idx in 0..config.num_hidden_layers {
            // Cross-attention to the image encoder only on every
            // `cross_attention_frequency`-th layer (including layer 0).
            let has_cross_attention = layer_idx % config.cross_attention_frequency == 0;
            encoder_layers.push(Blip2QFormerLayer::new_with_device(
                &config,
                has_cross_attention,
                device,
            )?);
        }
        let pooler = Linear::new(config.hidden_size, config.hidden_size, true);
        Ok(Self {
            config,
            embeddings,
            encoder_layers,
            pooler,
            device,
        })
    }
    /// CPU convenience constructor.
    pub fn new(config: Blip2QFormerConfig) -> Result<Self, Box<dyn std::error::Error>> {
        Self::new_with_device(config, Device::CPU)
    }
    /// Device this model was constructed for.
    pub fn device(&self) -> Device {
        self.device
    }
    /// Q-Former forward pass over text ids, optional image-encoder states, and
    /// optional (prepended) query embeddings.
    pub fn forward(
        &self,
        input_ids: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        query_embeds: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
    ) -> Result<Blip2QFormerOutput, Box<dyn std::error::Error>> {
        let mut hidden_states = self.embeddings.forward(input_ids)?;
        // Learnable query tokens go in front of the text embeddings.
        if let Some(query_embeds) = query_embeds {
            hidden_states = Tensor::concat(&[query_embeds.clone(), hidden_states], 1)?;
        }
        // (A dead no-op `if let` on encoder_hidden_states was removed here; the
        // encoder states are consumed by the cross-attention layers below.)
        for layer in self.encoder_layers.iter() {
            hidden_states = layer.forward(&hidden_states, encoder_hidden_states, attention_mask)?;
        }
        // Pool on the first token (the first query when queries are present).
        let pooler_output = self.pooler.forward(hidden_states.select(1, 0)?)?;
        let pooler_output = pooler_output.tanh()?;
        // FIXME: this vocab head is rebuilt with fresh random weights on every
        // call, so `logits` are not reproducible across calls and can never be
        // trained or loaded from a checkpoint. It should be a struct field
        // created once in `new_with_device`; left in place here to avoid
        // changing the struct layout in this edit.
        let logits = Linear::new(self.config.hidden_size, self.config.vocab_size, false)
            .forward(hidden_states.clone())?;
        Ok(Blip2QFormerOutput {
            last_hidden_state: hidden_states,
            pooler_output,
            logits,
        })
    }
}
/// BERT-style embeddings (word + position + token type) for the Q-Former.
#[derive(Debug, Clone)]
pub struct Blip2QFormerEmbeddings {
    pub word_embeddings: Embedding,
    pub position_embeddings: Embedding,
    pub token_type_embeddings: Embedding,
    pub layer_norm: LayerNorm,
    /// Dropout probability carried from config; not applied in `forward` yet.
    pub dropout: f64,
    device: Device,
}
impl Blip2QFormerEmbeddings {
pub fn new_with_device(
config: &Blip2QFormerConfig,
device: Device,
) -> Result<Self, Box<dyn std::error::Error>> {
let word_embeddings = Embedding::new(config.vocab_size, config.hidden_size, None)?;
let position_embeddings =
Embedding::new(config.max_position_embeddings, config.hidden_size, None)?;
let token_type_embeddings =
Embedding::new(config.type_vocab_size, config.hidden_size, None)?;
let layer_norm = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps as f32)?;
Ok(Self {
word_embeddings,
position_embeddings,
token_type_embeddings,
layer_norm,
dropout: config.hidden_dropout_prob,
device,
})
}
pub fn new(config: &Blip2QFormerConfig) -> Result<Self, Box<dyn std::error::Error>> {
Self::new_with_device(config, Device::CPU)
}
pub fn device(&self) -> Device {
self.device
}
pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor, Box<dyn std::error::Error>> {
let seq_length = input_ids.shape()[1];
let batch_size = input_ids.shape()[0];
let input_ids_vec: Vec<u32> =
input_ids.to_vec_f32()?.into_iter().map(|x| x as u32).collect();
let word_embeds = self.word_embeddings.forward(input_ids_vec)?;
let word_embeds = if word_embeds.shape().len() == 2
&& word_embeds.shape()[0] == batch_size * seq_length
{
word_embeds.reshape(&[batch_size, seq_length, word_embeds.shape()[1]])?
} else {
word_embeds
};
let position_ids = Tensor::range(0, seq_length as i64, DType::I64)?
.unsqueeze(0)?
.broadcast_to(&[batch_size, seq_length])?;
let position_ids_vec: Vec<u32> =
position_ids.to_vec_f32()?.into_iter().map(|x| x as u32).collect();
let position_embeds = self.position_embeddings.forward(position_ids_vec)?;
let position_embeds = if position_embeds.shape().len() == 2
&& position_embeds.shape()[0] == batch_size * seq_length
{
position_embeds.reshape(&[batch_size, seq_length, position_embeds.shape()[1]])?
} else {
position_embeds
};
let token_type_ids = Tensor::zeros(&[batch_size, seq_length])?;
let token_type_ids_vec: Vec<u32> =
token_type_ids.to_vec_f32()?.into_iter().map(|x| x as u32).collect();
let token_type_embeds = self.token_type_embeddings.forward(token_type_ids_vec)?;
let token_type_embeds = if token_type_embeds.shape().len() == 2
&& token_type_embeds.shape()[0] == batch_size * seq_length
{
token_type_embeds.reshape(&[batch_size, seq_length, token_type_embeds.shape()[1]])?
} else {
token_type_embeds
};
let embeddings = word_embeds.add(&position_embeds)?.add(&token_type_embeds)?;
let embeddings = self.layer_norm.forward(embeddings)?;
Ok(embeddings)
}
}
/// One Q-Former block: self-attention, optional cross-attention to the image
/// encoder, and an MLP.
#[derive(Debug, Clone)]
pub struct Blip2QFormerLayer {
    pub self_attention: MultiHeadAttention,
    /// Present only on layers that attend to the image encoder.
    pub cross_attention: Option<MultiHeadAttention>,
    /// Maps encoder states (width hard-coded to 1408) to the Q-Former hidden size.
    pub encoder_projection: Option<Linear>,
    pub layer_norm1: LayerNorm,
    pub layer_norm2: LayerNorm,
    pub layer_norm3: Option<LayerNorm>,
    pub mlp: Blip2MLP,
    device: Device,
}
impl Blip2QFormerLayer {
    /// Builds one Q-Former block. `has_cross_attention` adds a cross-attention
    /// sub-layer plus its LayerNorm and an encoder-width projection.
    pub fn new_with_device(
        config: &Blip2QFormerConfig,
        has_cross_attention: bool,
        device: Device,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let attention_config = AttentionConfig {
            hidden_size: config.hidden_size,
            num_heads: config.num_attention_heads,
            head_dim: config.hidden_size / config.num_attention_heads,
            dropout_prob: config.attention_probs_dropout_prob as f32,
            bias: true,
            max_seq_len: None,
        };
        let self_attention = MultiHeadAttention::new(
            attention_config.hidden_size,
            attention_config.num_heads,
            attention_config.dropout_prob,
            attention_config.bias,
        )?;
        let (cross_attention, encoder_projection) = if has_cross_attention {
            (
                Some(MultiHeadAttention::new(
                    attention_config.hidden_size,
                    attention_config.num_heads,
                    attention_config.dropout_prob,
                    attention_config.bias,
                )?),
                // FIXME: the encoder (vision) width is hard-coded to 1408; it
                // should come from configuration so other vision backbones can
                // be used.
                Some(Linear::new(1408, config.hidden_size, false)),
            )
        } else {
            (None, None)
        };
        let layer_norm1 = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps as f32)?;
        let layer_norm2 = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps as f32)?;
        let layer_norm3 = if has_cross_attention {
            Some(LayerNorm::new(
                vec![config.hidden_size],
                config.layer_norm_eps as f32,
            )?)
        } else {
            None
        };
        let mlp = Blip2MLP::new_with_device(
            config.hidden_size,
            config.intermediate_size,
            &config.hidden_act,
            device,
        )?;
        Ok(Self {
            self_attention,
            cross_attention,
            encoder_projection,
            layer_norm1,
            layer_norm2,
            layer_norm3,
            mlp,
            device,
        })
    }
    /// CPU convenience constructor.
    pub fn new(
        config: &Blip2QFormerConfig,
        has_cross_attention: bool,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        Self::new_with_device(config, has_cross_attention, Device::CPU)
    }
    /// Device this layer was constructed for.
    pub fn device(&self) -> Device {
        self.device
    }
    /// Self-attention, then (only on cross-attention layers AND when encoder
    /// states are provided) cross-attention against projected encoder states,
    /// then the MLP — each wrapped in a pre-LN residual connection.
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor, Box<dyn std::error::Error>> {
        let residual = hidden_states.clone();
        let hidden_states = self.layer_norm1.forward(hidden_states.clone())?;
        let attention_output = self.self_attention.forward_attention(
            &hidden_states,
            &hidden_states,
            &hidden_states,
            attention_mask,
            false,
        )?;
        let hidden_states = residual.add(&attention_output)?;
        let hidden_states = if let (
            Some(cross_attention),
            Some(encoder_hidden_states),
            Some(layer_norm3),
            Some(encoder_proj),
        ) = (
            &self.cross_attention,
            encoder_hidden_states,
            &self.layer_norm3,
            &self.encoder_projection,
        ) {
            let residual = hidden_states.clone();
            let hidden_states_norm = layer_norm3.forward(hidden_states.clone())?;
            // Bring the encoder states into the Q-Former's hidden width before
            // using them as keys/values.
            let projected_encoder = encoder_proj.forward(encoder_hidden_states.clone())?;
            let cross_attention_output = cross_attention.forward_attention(
                &hidden_states_norm,
                &projected_encoder,
                &projected_encoder,
                None,
                false,
            )?;
            residual.add(&cross_attention_output)?
        } else {
            hidden_states
        };
        let residual = hidden_states.clone();
        let hidden_states = self.layer_norm2.forward(hidden_states)?;
        let mlp_output = self.mlp.forward(&hidden_states)?;
        let hidden_states = residual.add(&mlp_output)?;
        Ok(hidden_states)
    }
}
/// Two-layer feed-forward network with a configurable activation in between.
#[derive(Debug, Clone)]
pub struct Blip2MLP {
    pub linear1: Linear,
    pub linear2: Linear,
    pub activation: ActivationType,
    device: Device,
}
impl Blip2MLP {
    /// Creates the two linear layers and resolves the activation by name.
    pub fn new_with_device(
        hidden_size: usize,
        intermediate_size: usize,
        activation: &str,
        device: Device,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        // Unrecognized activation names fall back to GELU.
        let act = match activation {
            "relu" => ActivationType::ReLU,
            "silu" | "swish" => ActivationType::SiLU,
            "tanh" => ActivationType::Tanh,
            "sigmoid" => ActivationType::Sigmoid,
            _ => ActivationType::GELU,
        };
        Ok(Self {
            linear1: Linear::new(hidden_size, intermediate_size, true),
            linear2: Linear::new(intermediate_size, hidden_size, true),
            activation: act,
            device,
        })
    }
    /// CPU convenience constructor.
    pub fn new(
        hidden_size: usize,
        intermediate_size: usize,
        activation: &str,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        Self::new_with_device(hidden_size, intermediate_size, activation, Device::CPU)
    }
    /// Device this module was constructed for.
    pub fn device(&self) -> Device {
        self.device
    }
    /// linear1 -> activation -> linear2.
    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor, Box<dyn std::error::Error>> {
        use trustformers_core::ops::activations::*;
        let projected = self.linear1.forward(hidden_states.clone())?;
        let activated = match self.activation {
            ActivationType::GELU => gelu(&projected)?,
            ActivationType::ReLU => relu(&projected)?,
            ActivationType::SiLU => silu(&projected)?,
            ActivationType::Tanh => tanh(&projected)?,
            ActivationType::Sigmoid => sigmoid(&projected)?,
        };
        Ok(self.linear2.forward(activated)?)
    }
}
/// Common interface for the text backends used by BLIP-2 conditional
/// generation (OPT-style decoder-only, T5-style encoder-decoder).
pub trait LanguageModel: Send + Sync {
    /// Runs the language model.
    ///
    /// * `encoder_hidden_states` — optional projected image features that are
    ///   incorporated as extra context.
    /// * `labels` — when present, a loss is returned (currently a placeholder
    ///   in both implementations).
    fn forward(
        &self,
        input_ids: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        labels: Option<&Tensor>,
    ) -> Result<LanguageModelOutput, Box<dyn std::error::Error>>;
}
/// OPT-style decoder-only language model backend.
#[derive(Debug, Clone)]
pub struct Blip2OptLanguageModel {
    pub config: Blip2TextConfig,
    pub embeddings: Embedding,
    pub layers: Vec<Blip2OptLayer>,
    pub layer_norm: LayerNorm,
    pub lm_head: Linear,
    device: Device,
}
impl Blip2OptLanguageModel {
    /// Builds a decoder-only (OPT-style) language model on the given device.
    pub fn new_with_device(
        config: Blip2TextConfig,
        device: Device,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let embeddings = Embedding::new(config.vocab_size, config.hidden_size, None)?;
        // Construct the layer stack sequentially; collect short-circuits on the
        // first construction error.
        let layers = (0..config.num_hidden_layers)
            .map(|_| Blip2OptLayer::new_with_device(&config, device))
            .collect::<Result<Vec<_>, _>>()?;
        let layer_norm = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps as f32)?;
        let lm_head = Linear::new(config.hidden_size, config.vocab_size, false);
        Ok(Self {
            config,
            embeddings,
            layers,
            layer_norm,
            lm_head,
            device,
        })
    }
    /// CPU convenience constructor.
    pub fn new(config: Blip2TextConfig) -> Result<Self, Box<dyn std::error::Error>> {
        Self::new_with_device(config, Device::CPU)
    }
    /// Device this model was constructed for.
    pub fn device(&self) -> Device {
        self.device
    }
}
impl LanguageModel for Blip2OptLanguageModel {
    /// Decoder-only forward pass. Optional `encoder_hidden_states` (projected
    /// image features) are prepended to the token embeddings, so the returned
    /// logits also cover those prefix positions.
    fn forward(
        &self,
        input_ids: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        labels: Option<&Tensor>,
    ) -> Result<LanguageModelOutput, Box<dyn std::error::Error>> {
        let batch_size = input_ids.shape()[0];
        let seq_length = input_ids.shape()[1];
        let input_ids_vec: Vec<u32> =
            input_ids.to_vec_f32()?.into_iter().map(|x| x as u32).collect();
        let hidden_states = self.embeddings.forward(input_ids_vec)?;
        // The embedding layer may return a flat [batch*seq, hidden] matrix;
        // restore the batch dimension in that case.
        let mut hidden_states = if hidden_states.shape().len() == 2
            && hidden_states.shape()[0] == batch_size * seq_length
        {
            hidden_states.reshape(&[batch_size, seq_length, hidden_states.shape()[1]])?
        } else {
            hidden_states
        };
        if let Some(encoder_hidden_states) = encoder_hidden_states {
            hidden_states = Tensor::concat(&[encoder_hidden_states.clone(), hidden_states], 1)?;
        }
        for layer in &self.layers {
            hidden_states = layer.forward(&hidden_states, attention_mask)?;
        }
        let hidden_states = self.layer_norm.forward(hidden_states)?;
        let logits = self.lm_head.forward(hidden_states.clone())?;
        // NOTE(review): placeholder loss — a real cross-entropy over `labels`
        // is not implemented yet.
        let loss = if let Some(_labels) = labels {
            Some(Tensor::scalar(1.0)?)
        } else {
            None
        };
        Ok(LanguageModelOutput {
            loss,
            logits,
            hidden_states,
        })
    }
}
/// One pre-LayerNorm decoder block, shared by both language-model backends.
#[derive(Debug, Clone)]
pub struct Blip2OptLayer {
    pub self_attention: MultiHeadAttention,
    pub layer_norm1: LayerNorm,
    pub mlp: Blip2MLP,
    pub layer_norm2: LayerNorm,
    device: Device,
}
impl Blip2OptLayer {
    /// Builds a single pre-LayerNorm decoder block (self-attention + MLP).
    pub fn new_with_device(
        config: &Blip2TextConfig,
        device: Device,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        // Collected attention hyperparameters (head_dim / max_seq_len are not
        // consumed by MultiHeadAttention::new but kept for documentation).
        let attn_cfg = AttentionConfig {
            hidden_size: config.hidden_size,
            num_heads: config.num_attention_heads,
            head_dim: config.hidden_size / config.num_attention_heads,
            dropout_prob: config.attention_dropout as f32,
            bias: true,
            max_seq_len: None,
        };
        let self_attention = MultiHeadAttention::new(
            attn_cfg.hidden_size,
            attn_cfg.num_heads,
            attn_cfg.dropout_prob,
            attn_cfg.bias,
        )?;
        let layer_norm1 = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps as f32)?;
        let mlp = Blip2MLP::new_with_device(
            config.hidden_size,
            config.intermediate_size,
            &config.hidden_act,
            device,
        )?;
        let layer_norm2 = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps as f32)?;
        Ok(Self {
            self_attention,
            layer_norm1,
            mlp,
            layer_norm2,
            device,
        })
    }
    /// CPU convenience constructor.
    pub fn new(config: &Blip2TextConfig) -> Result<Self, Box<dyn std::error::Error>> {
        Self::new_with_device(config, Device::CPU)
    }
    /// Device this layer was constructed for.
    pub fn device(&self) -> Device {
        self.device
    }
    /// Pre-LN block: LN -> self-attention -> residual, LN -> MLP -> residual.
    pub fn forward(
        &self,
        hidden_states: &Tensor,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor, Box<dyn std::error::Error>> {
        let normed = self.layer_norm1.forward(hidden_states.clone())?;
        let attn_out = self.self_attention.forward_attention(
            &normed,
            &normed,
            &normed,
            attention_mask,
            false,
        )?;
        let after_attn = hidden_states.add(&attn_out)?;
        let normed = self.layer_norm2.forward(after_attn.clone())?;
        let mlp_out = self.mlp.forward(&normed)?;
        Ok(after_attn.add(&mlp_out)?)
    }
}
/// T5-style encoder-decoder language model backend (reuses the same block type
/// as the OPT path for both stacks).
#[derive(Debug, Clone)]
pub struct Blip2T5LanguageModel {
    pub config: Blip2TextConfig,
    pub embeddings: Embedding,
    pub encoder_layers: Vec<Blip2OptLayer>,
    pub decoder_layers: Vec<Blip2OptLayer>,
    pub layer_norm: LayerNorm,
    pub lm_head: Linear,
    device: Device,
}
impl Blip2T5LanguageModel {
    /// Builds a T5-style encoder-decoder language model on the given device.
    pub fn new_with_device(
        config: Blip2TextConfig,
        device: Device,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let embeddings = Embedding::new(config.vocab_size, config.hidden_size, None)?;
        // Build both stacks interleaved (encoder layer, then decoder layer),
        // preserving the original construction order.
        let mut encoder_layers = Vec::with_capacity(config.num_hidden_layers);
        let mut decoder_layers = Vec::with_capacity(config.num_hidden_layers);
        for _ in 0..config.num_hidden_layers {
            encoder_layers.push(Blip2OptLayer::new_with_device(&config, device)?);
            decoder_layers.push(Blip2OptLayer::new_with_device(&config, device)?);
        }
        let layer_norm = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps as f32)?;
        let lm_head = Linear::new(config.hidden_size, config.vocab_size, false);
        Ok(Self {
            config,
            embeddings,
            encoder_layers,
            decoder_layers,
            layer_norm,
            lm_head,
            device,
        })
    }
    /// CPU convenience constructor.
    pub fn new(config: Blip2TextConfig) -> Result<Self, Box<dyn std::error::Error>> {
        Self::new_with_device(config, Device::CPU)
    }
    /// Device this model was constructed for.
    pub fn device(&self) -> Device {
        self.device
    }
}
impl LanguageModel for Blip2T5LanguageModel {
    /// Encoder-decoder forward pass: encoder stack over the token embeddings,
    /// optional image-feature prefix concatenated before the decoder stack,
    /// final LayerNorm and LM head.
    fn forward(
        &self,
        input_ids: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
        labels: Option<&Tensor>,
    ) -> Result<LanguageModelOutput, Box<dyn std::error::Error>> {
        let batch_size = input_ids.shape()[0];
        let seq_length = input_ids.shape()[1];
        let input_ids_vec: Vec<u32> =
            input_ids.to_vec_f32()?.into_iter().map(|x| x as u32).collect();
        let hidden_states = self.embeddings.forward(input_ids_vec)?;
        // BUGFIX: the embedding layer can return a flat [batch*seq, hidden]
        // matrix for batched ids (the OPT implementation performs this exact
        // reshape); without restoring the batch dimension, the encoder layers
        // run on the wrong shape and the dim-1 concat with the 3-D
        // encoder_hidden_states below misaligns.
        let mut hidden_states = if hidden_states.shape().len() == 2
            && hidden_states.shape()[0] == batch_size * seq_length
        {
            hidden_states.reshape(&[batch_size, seq_length, hidden_states.shape()[1]])?
        } else {
            hidden_states
        };
        for layer in &self.encoder_layers {
            hidden_states = layer.forward(&hidden_states, attention_mask)?;
        }
        if let Some(encoder_hidden_states) = encoder_hidden_states {
            hidden_states = Tensor::concat(&[encoder_hidden_states.clone(), hidden_states], 1)?;
        }
        for layer in &self.decoder_layers {
            hidden_states = layer.forward(&hidden_states, attention_mask)?;
        }
        let hidden_states = self.layer_norm.forward(hidden_states)?;
        let logits = self.lm_head.forward(hidden_states.clone())?;
        // NOTE(review): placeholder loss — a real cross-entropy over `labels`
        // is not implemented yet.
        let loss = if let Some(_labels) = labels {
            Some(Tensor::scalar(1.0)?)
        } else {
            None
        };
        Ok(LanguageModelOutput {
            loss,
            logits,
            hidden_states,
        })
    }
}
/// Output of [`Blip2Model::forward`].
#[derive(Debug, Clone)]
pub struct Blip2Output {
    pub last_hidden_state: Tensor,
    pub pooler_output: Tensor,
    /// Present only when pixel values were supplied.
    pub image_features: Option<Tensor>,
    pub text_features: Tensor,
    pub logits: Tensor,
}
/// Output of [`Blip2ForConditionalGeneration::forward`].
#[derive(Debug, Clone)]
pub struct Blip2ConditionalGenerationOutput {
    /// Present only when labels were supplied (currently a placeholder value).
    pub loss: Option<Tensor>,
    pub logits: Tensor,
    pub hidden_states: Tensor,
    /// Present only when pixel values were supplied.
    pub image_features: Option<Tensor>,
}
/// Output of [`Blip2QFormerModel::forward`].
#[derive(Debug, Clone)]
pub struct Blip2QFormerOutput {
    pub last_hidden_state: Tensor,
    pub pooler_output: Tensor,
    pub logits: Tensor,
}
/// Output shared by all [`LanguageModel`] backends.
#[derive(Debug, Clone)]
pub struct LanguageModelOutput {
    /// Present only when labels were supplied (currently a placeholder value).
    pub loss: Option<Tensor>,
    pub logits: Tensor,
    pub hidden_states: Tensor,
}
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): the full-model tests are #[ignore]d — presumably because
    // constructing entire weight sets is too slow for routine runs; confirm.
    #[test]
    #[ignore]
    fn test_blip2_model_creation() {
        let config = Blip2Config::default();
        let model = Blip2Model::new(config);
        assert!(model.is_ok());
    }
    #[test]
    #[ignore]
    fn test_blip2_vision_model() {
        let config = Blip2VisionConfig::default();
        let model = Blip2VisionModel::new(config);
        assert!(model.is_ok());
    }
    #[test]
    #[ignore]
    fn test_blip2_qformer_model() {
        let config = Blip2QFormerConfig::default();
        let model = Blip2QFormerModel::new(config);
        assert!(model.is_ok());
    }
    #[test]
    fn test_blip2_patch_embedding() {
        let config = Blip2VisionConfig::default();
        let embedding = Blip2PatchEmbedding::new(&config);
        assert!(embedding.is_ok());
    }
    #[test]
    fn test_blip2_mlp() {
        let mlp = Blip2MLP::new(768, 3072, "gelu");
        assert!(mlp.is_ok());
    }
    #[test]
    #[ignore]
    fn test_blip2_opt_language_model() {
        let config = Blip2TextConfig::opt_2_7b();
        let model = Blip2OptLanguageModel::new(config);
        assert!(model.is_ok());
    }
    #[test]
    #[ignore]
    fn test_blip2_t5_language_model() {
        let config = Blip2TextConfig::flan_t5_xl();
        let model = Blip2T5LanguageModel::new(config);
        assert!(model.is_ok());
    }
}