use crate::flamingo::config::{
FlamingoConfig, FlamingoLanguageConfig, FlamingoPerceiverConfig, FlamingoVisionConfig,
FlamingoXAttentionConfig,
};
use trustformers_core::{
device::Device,
errors::TrustformersError,
kernels::fused_ops::ActivationType,
layers::{
attention::{AttentionConfig, MultiHeadAttention},
conv2d::Conv2d,
embedding::Embedding,
layernorm::LayerNorm,
linear::Linear,
},
tensor::{Tensor, TensorType},
traits::Layer,
};
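/// Flamingo-style vision-language model: a vision encoder feeds a Perceiver
/// resampler whose fixed-length latents condition a language model through
/// gated cross-attention layers.
///
/// A minimal usage sketch, assuming inputs shaped as in the tests at the
/// bottom of this file:
///
/// ```ignore
/// let config = FlamingoConfig::flamingo_3b();
/// let model = FlamingoModel::new(config)?;
/// let output = model.forward_train(&input_ids, &attention_mask, Some(&pixel_values), None, None)?;
/// ```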
#[derive(Debug, Clone)]
pub struct FlamingoModel {
pub config: FlamingoConfig,
pub vision_encoder: FlamingoVisionEncoder,
pub language_model: FlamingoLanguageModel,
pub perceiver_resampler: PerceiverResampler,
pub vision_projection: Linear,
pub media_token_embedding: Embedding,
device: Device,
}
impl FlamingoModel {
pub fn new_with_device(
config: FlamingoConfig,
device: Device,
) -> Result<Self, Box<dyn std::error::Error>> {
let vision_encoder =
FlamingoVisionEncoder::new_with_device(config.vision_config.clone(), device)?;
let language_model = FlamingoLanguageModel::new_with_device(
config.language_config.clone(),
config.cross_attention_config.clone(),
config.cross_attention_layers.clone(),
device,
)?;
let perceiver_resampler = PerceiverResampler::new_with_device(
config.perceiver_config.clone(),
config.vision_config.hidden_size,
device,
)?;
let vision_projection = Linear::new(
config.perceiver_config.latent_dim,
config.vision_language_dim,
false,
);
let media_token_embedding = Embedding::new(
config.media_token_length,
config.language_config.hidden_size,
None,
)?;
Ok(Self {
config,
vision_encoder,
language_model,
perceiver_resampler,
vision_projection,
media_token_embedding,
device,
})
}
pub fn new(config: FlamingoConfig) -> Result<Self, Box<dyn std::error::Error>> {
Self::new_with_device(config, Device::CPU)
}
pub fn device(&self) -> Device {
self.device
}
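    /// Training forward pass: encodes `pixel_values` (when provided), runs the
    /// language model with cross-attention to the resulting vision features,
    /// and computes a shifted cross-entropy loss when `labels` are given.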
pub fn forward_train(
&self,
input_ids: &Tensor,
attention_mask: &Tensor,
pixel_values: Option<&Tensor>,
        media_locations: Option<&Tensor>,
        labels: Option<&Tensor>,
) -> Result<FlamingoOutput, Box<dyn std::error::Error>> {
let vision_features = if let Some(pixel_values) = pixel_values {
Some(self.encode_vision(pixel_values)?)
} else {
None
};
let language_output = self.language_model.forward(
input_ids,
attention_mask,
vision_features.as_ref(),
media_locations,
)?;
let loss = if let Some(labels) = labels {
Some(self.compute_language_modeling_loss(&language_output.logits, labels)?)
} else {
None
};
Ok(FlamingoOutput {
logits: language_output.logits,
vision_features,
cross_attention_weights: language_output.cross_attention_weights,
loss,
})
}
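    /// Autoregressive decoding with optional few-shot media context: appends
    /// up to `max_new_tokens` tokens (greedy argmax, or multinomial sampling
    /// at `temperature` when `do_sample` is set), stopping early once every
    /// sequence's latest token is the EOS id.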
pub fn generate_with_shots(
&self,
input_ids: &Tensor,
attention_mask: &Tensor,
pixel_values: Option<&Tensor>,
media_locations: Option<&Tensor>,
max_new_tokens: usize,
temperature: f64,
do_sample: bool,
) -> Result<Tensor, Box<dyn std::error::Error>> {
let batch_size = input_ids.shape()[0];
let mut current_ids = input_ids.clone();
let mut current_mask = attention_mask.clone();
let vision_features = if let Some(pixel_values) = pixel_values {
Some(self.encode_vision(pixel_values)?)
} else {
None
};
for _ in 0..max_new_tokens {
let output = self.language_model.forward(
                &current_ids,
                &current_mask,
vision_features.as_ref(),
media_locations,
)?;
let logits_shape = output.logits.shape();
            let seq_axis = logits_shape.len() - 2;
            let seq_len = logits_shape[seq_axis];
            let logits = output.logits.slice(seq_axis, seq_len - 1, seq_len)?.squeeze_i64(-2)?;
            let next_token = if do_sample {
self.sample_token(&logits, temperature)?
} else {
logits.argmax(-1)?
};
let concat_axis = current_ids.shape().len() - 1;
current_ids = Tensor::concat(
&[current_ids.clone(), next_token.unsqueeze_i64(-1)?],
concat_axis,
)?;
let new_mask = Tensor::ones_dtype(TensorType::F32, &[batch_size, 1])?;
let mask_axis = current_mask.shape().len() - 1;
current_mask = Tensor::concat(&[current_mask.clone(), new_mask], mask_axis)?;
if self.all_sequences_finished(&next_token)? {
break;
}
}
Ok(current_ids)
}
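    /// Vision pipeline: vision encoder -> Perceiver resampler -> linear
    /// projection into `vision_language_dim`.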
pub fn encode_vision(
&self,
pixel_values: &Tensor,
) -> Result<Tensor, Box<dyn std::error::Error>> {
let vision_features = self.vision_encoder.forward(pixel_values)?;
let resampled_features = self.perceiver_resampler.forward(&vision_features)?;
let projected_features = self.vision_projection.forward(resampled_features)?;
Ok(projected_features)
}
fn compute_language_modeling_loss(
&self,
logits: &Tensor,
labels: &Tensor,
) -> Result<Tensor, Box<dyn std::error::Error>> {
let logits_shape = logits.shape();
let logits_axis = logits_shape.len() - 2;
let seq_len = logits_shape[logits_axis];
let shift_logits = logits.slice(logits_axis, 0, seq_len - 1)?;
let labels_shape = labels.shape();
let labels_axis = labels_shape.len() - 1;
let labels_len = labels_shape[labels_axis];
let shift_labels = labels.slice(labels_axis, 1, labels_len)?;
let loss = shift_logits.cross_entropy(&shift_labels, "mean")?;
Ok(loss)
}
fn sample_token(
&self,
logits: &Tensor,
temperature: f64,
) -> Result<Tensor, Box<dyn std::error::Error>> {
let scaled_logits = (logits / temperature)?;
let probs = scaled_logits.softmax(-1)?;
let sampled = probs.multinomial(1, true)?;
Ok(sampled)
}
fn all_sequences_finished(&self, tokens: &Tensor) -> Result<bool, Box<dyn std::error::Error>> {
let eos_mask = tokens.eq_scalar(self.config.language_config.eos_token_id as f64)?;
let all_eos = eos_mask.all()?;
if let Tensor::F32(arr) = all_eos {
let val = arr.iter().next().unwrap_or(&0.0);
Ok(*val > 0.5)
} else {
Ok(false)
}
}
}
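/// ViT-style vision encoder: non-overlapping Conv2d patch embedding (kernel
/// and stride both equal to `patch_size`), a prepended class token, learned
/// position embeddings, and a stack of pre-norm transformer layers framed by
/// pre- and post-LayerNorms.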
#[derive(Debug, Clone)]
pub struct FlamingoVisionEncoder {
pub config: FlamingoVisionConfig,
pub patch_embedding: Conv2d,
pub class_embedding: Tensor,
pub position_embedding: Tensor,
pub pre_layer_norm: LayerNorm,
pub layers: Vec<FlamingoVisionLayer>,
pub post_layer_norm: LayerNorm,
device: Device,
}
impl FlamingoVisionEncoder {
pub fn new_with_device(
config: FlamingoVisionConfig,
device: Device,
) -> Result<Self, Box<dyn std::error::Error>> {
let patch_embedding = Conv2d::new(
config.num_channels,
config.hidden_size,
(config.patch_size, config.patch_size),
(config.patch_size, config.patch_size),
(0, 0),
false,
)?;
let class_embedding = Tensor::randn(&[config.hidden_size])?;
let position_embedding = Tensor::randn(&[config.seq_len(), config.hidden_size])?;
let pre_layer_norm =
LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps as f32)?;
let post_layer_norm =
LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps as f32)?;
let mut layers = Vec::new();
for _ in 0..config.num_hidden_layers {
layers.push(FlamingoVisionLayer::new_with_device(&config, device)?);
}
Ok(Self {
config,
patch_embedding,
class_embedding,
position_embedding,
pre_layer_norm,
layers,
post_layer_norm,
device,
})
}
pub fn new(config: FlamingoVisionConfig) -> Result<Self, Box<dyn std::error::Error>> {
Self::new_with_device(config, Device::CPU)
}
pub fn device(&self) -> Device {
self.device
}
pub fn forward(&self, pixel_values: &Tensor) -> Result<Tensor, Box<dyn std::error::Error>> {
let batch_size = pixel_values.shape()[0];
let patch_embeds = self.patch_embedding.forward(pixel_values.clone())?;
let patch_embeds = patch_embeds.flatten(2, -1)?.transpose_i64(-1, -2)?;
let class_embeds = self.class_embedding.unsqueeze(0)?.unsqueeze(0)?.broadcast_to(&[
batch_size,
1,
self.config.hidden_size,
])?;
let embeddings = Tensor::concat(&[class_embeds, patch_embeds], 1)?;
let embeddings = (&embeddings + &self.position_embedding.unsqueeze(0)?)?;
let mut hidden_states = self.pre_layer_norm.forward(embeddings)?;
for layer in &self.layers {
hidden_states = layer.forward(&hidden_states)?;
}
hidden_states = self.post_layer_norm.forward(hidden_states)?;
Ok(hidden_states)
}
}
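/// One pre-norm vision transformer block: LayerNorm -> self-attention ->
/// residual add, then LayerNorm -> MLP -> residual add.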
#[derive(Debug, Clone)]
pub struct FlamingoVisionLayer {
pub self_attention: MultiHeadAttention,
pub mlp: FlamingoMLP,
pub layer_norm1: LayerNorm,
pub layer_norm2: LayerNorm,
device: Device,
}
impl FlamingoVisionLayer {
pub fn new_with_device(
config: &FlamingoVisionConfig,
device: Device,
) -> Result<Self, Box<dyn std::error::Error>> {
let attention_config = AttentionConfig {
hidden_size: config.hidden_size,
num_heads: config.num_attention_heads,
head_dim: config.hidden_size / config.num_attention_heads,
dropout_prob: config.attention_dropout as f32,
bias: true,
max_seq_len: None,
};
let self_attention = MultiHeadAttention::new(
attention_config.hidden_size,
attention_config.num_heads,
attention_config.dropout_prob,
attention_config.bias,
)?;
let mlp = FlamingoMLP::new_with_device(
config.hidden_size,
config.intermediate_size,
&config.hidden_act,
device,
)?;
let layer_norm1 = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps as f32)?;
let layer_norm2 = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps as f32)?;
Ok(Self {
self_attention,
mlp,
layer_norm1,
layer_norm2,
device,
})
}
pub fn new(config: &FlamingoVisionConfig) -> Result<Self, Box<dyn std::error::Error>> {
Self::new_with_device(config, Device::CPU)
}
pub fn device(&self) -> Device {
self.device
}
pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor, Box<dyn std::error::Error>> {
let normed_states = self.layer_norm1.forward(hidden_states.clone())?;
let attention_output = self.self_attention.forward(normed_states)?;
let hidden_states = (hidden_states + &attention_output)?;
let normed_states = self.layer_norm2.forward(hidden_states.clone())?;
let mlp_output = self.mlp.forward(&normed_states)?;
let hidden_states = (&hidden_states + &mlp_output)?;
Ok(hidden_states)
}
}
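/// Decoder-style language model whose blocks at the indices listed in
/// `cross_attention_layers` additionally cross-attend to vision features;
/// the remaining blocks are plain self-attention transformer layers.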
#[derive(Debug, Clone)]
pub struct FlamingoLanguageModel {
pub config: FlamingoLanguageConfig,
pub token_embedding: Embedding,
pub position_embedding: Embedding,
pub layers: Vec<FlamingoLanguageLayer>,
pub final_layer_norm: LayerNorm,
pub lm_head: Linear,
pub cross_attention_layers: Vec<usize>,
device: Device,
}
impl FlamingoLanguageModel {
pub fn new_with_device(
config: FlamingoLanguageConfig,
cross_attention_config: FlamingoXAttentionConfig,
cross_attention_layers: Vec<usize>,
device: Device,
) -> Result<Self, Box<dyn std::error::Error>> {
let token_embedding = Embedding::new(config.vocab_size, config.hidden_size, None)?;
let position_embedding =
Embedding::new(config.max_position_embeddings, config.hidden_size, None)?;
let mut layers = Vec::new();
for i in 0..config.num_hidden_layers {
let has_cross_attention = cross_attention_layers.contains(&i);
layers.push(FlamingoLanguageLayer::new_with_device(
&config,
&cross_attention_config,
has_cross_attention,
device,
)?);
}
let final_layer_norm =
LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps as f32)?;
let lm_head = Linear::new(config.hidden_size, config.vocab_size, false);
Ok(Self {
config,
token_embedding,
position_embedding,
layers,
final_layer_norm,
lm_head,
cross_attention_layers,
device,
})
}
pub fn new(
config: FlamingoLanguageConfig,
cross_attention_config: FlamingoXAttentionConfig,
cross_attention_layers: Vec<usize>,
) -> Result<Self, Box<dyn std::error::Error>> {
Self::new_with_device(
config,
cross_attention_config,
cross_attention_layers,
Device::CPU,
)
}
pub fn device(&self) -> Device {
self.device
}
pub fn forward(
&self,
input_ids: &Tensor,
attention_mask: &Tensor,
vision_features: Option<&Tensor>,
media_locations: Option<&Tensor>,
) -> Result<FlamingoLanguageOutput, Box<dyn std::error::Error>> {
let seq_len = input_ids.shape()[1];
let input_ids_vec: Vec<u32> =
input_ids.to_vec_f32()?.into_iter().map(|x| x as u32).collect();
let mut hidden_states = self.token_embedding.forward(input_ids_vec)?;
let position_ids = Tensor::range(0, seq_len as i64, TensorType::I64)?.unsqueeze(0)?;
let position_ids_vec: Vec<u32> =
position_ids.to_vec_f32()?.into_iter().map(|x| x as u32).collect();
let position_embeddings = self.position_embedding.forward(position_ids_vec)?;
hidden_states = (&hidden_states + &position_embeddings)?;
let mut cross_attention_weights = Vec::new();
for (i, layer) in self.layers.iter().enumerate() {
let layer_output = layer.forward(
&hidden_states,
attention_mask,
vision_features,
media_locations,
)?;
hidden_states = layer_output.hidden_states;
if self.cross_attention_layers.contains(&i) {
if let Some(attn_weights) = layer_output.cross_attention_weights {
cross_attention_weights.push(attn_weights);
}
}
}
hidden_states = self.final_layer_norm.forward(hidden_states)?;
let logits = self.lm_head.forward(hidden_states.clone())?;
Ok(FlamingoLanguageOutput {
logits,
hidden_states,
cross_attention_weights,
})
}
}
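/// One language-model block: pre-norm self-attention and MLP sub-layers,
/// plus an optional gated cross-attention sub-layer (with its own
/// `layer_norm3`) on blocks that condition on vision features.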
#[derive(Debug, Clone)]
pub struct FlamingoLanguageLayer {
pub self_attention: MultiHeadAttention,
pub cross_attention: Option<GatedCrossAttention>,
pub mlp: FlamingoMLP,
pub layer_norm1: LayerNorm,
pub layer_norm2: LayerNorm,
    pub layer_norm3: Option<LayerNorm>,
    device: Device,
}
impl FlamingoLanguageLayer {
pub fn new_with_device(
config: &FlamingoLanguageConfig,
cross_attention_config: &FlamingoXAttentionConfig,
has_cross_attention: bool,
device: Device,
) -> Result<Self, Box<dyn std::error::Error>> {
let attention_config = AttentionConfig {
hidden_size: config.hidden_size,
num_heads: config.num_attention_heads,
head_dim: config.hidden_size / config.num_attention_heads,
dropout_prob: config.attention_dropout as f32,
bias: true,
max_seq_len: None,
};
let self_attention = MultiHeadAttention::new(
attention_config.hidden_size,
attention_config.num_heads,
attention_config.dropout_prob,
attention_config.bias,
)?;
let mlp = FlamingoMLP::new_with_device(
config.hidden_size,
config.intermediate_size,
&config.hidden_act,
device,
)?;
let layer_norm1 = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps as f32)?;
let layer_norm2 = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps as f32)?;
let (cross_attention, layer_norm3) = if has_cross_attention {
let cross_attn = GatedCrossAttention::new_with_device(
config.hidden_size,
cross_attention_config.clone(),
device,
)?;
let norm3 = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps as f32)?;
(Some(cross_attn), Some(norm3))
} else {
(None, None)
};
Ok(Self {
self_attention,
cross_attention,
mlp,
layer_norm1,
layer_norm2,
layer_norm3,
device,
})
}
pub fn new(
config: &FlamingoLanguageConfig,
cross_attention_config: &FlamingoXAttentionConfig,
has_cross_attention: bool,
) -> Result<Self, Box<dyn std::error::Error>> {
Self::new_with_device(
config,
cross_attention_config,
has_cross_attention,
Device::CPU,
)
}
pub fn device(&self) -> Device {
self.device
}
pub fn forward(
&self,
hidden_states: &Tensor,
attention_mask: &Tensor,
vision_features: Option<&Tensor>,
media_locations: Option<&Tensor>,
) -> Result<FlamingoLanguageLayerOutput, Box<dyn std::error::Error>> {
let normed_states = self.layer_norm1.forward(hidden_states.clone())?;
let attention_output = self.self_attention.forward_self_attention(
&normed_states,
Some(attention_mask),
            false,
        )?;
let hidden_states = (hidden_states + &attention_output)?;
let mut cross_attention_weights = None;
let hidden_states = if let (Some(cross_attention), Some(vision_features)) =
(&self.cross_attention, &vision_features)
{
let layer_norm3 = self.layer_norm3.as_ref().ok_or_else(|| {
TrustformersError::invalid_config(
"layer_norm3 must be present when cross_attention is used".to_string(),
)
})?;
let normed_states = layer_norm3.forward(hidden_states.clone())?;
let cross_output =
cross_attention.forward(&normed_states, vision_features, media_locations)?;
cross_attention_weights = cross_output.attention_weights;
(&hidden_states + &cross_output.hidden_states)?
} else {
hidden_states
};
let normed_states = self.layer_norm2.forward(hidden_states.clone())?;
let mlp_output = self.mlp.forward(&normed_states)?;
let hidden_states = (&hidden_states + &mlp_output)?;
Ok(FlamingoLanguageLayerOutput {
hidden_states,
cross_attention_weights,
})
}
}
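/// Gated cross-attention block: text hidden states supply the queries while
/// resampled vision features supply keys and values; an optional
/// `media_locations` mask controls which text positions may attend to the
/// media. The attention output is modulated element-wise by a learned gate
/// (`gate_proj` followed by the activation named in `config.gating_type`,
/// defaulting to tanh). Note that the original Flamingo paper gates with a
/// single tanh-scaled scalar initialized at zero so training starts from the
/// frozen LM's behavior; this variant learns a per-feature gate instead.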
#[derive(Debug, Clone)]
pub struct GatedCrossAttention {
pub config: FlamingoXAttentionConfig,
pub q_proj: Linear,
pub k_proj: Linear,
pub v_proj: Linear,
pub o_proj: Linear,
pub gate_proj: Linear,
pub layer_norm: LayerNorm,
device: Device,
}
impl GatedCrossAttention {
pub fn new_with_device(
hidden_size: usize,
config: FlamingoXAttentionConfig,
device: Device,
) -> Result<Self, Box<dyn std::error::Error>> {
let q_proj = Linear::new(hidden_size, config.cross_attention_dim, config.use_bias);
let k_proj = Linear::new(
config.cross_attention_dim,
config.cross_attention_dim,
config.use_bias,
);
let v_proj = Linear::new(
config.cross_attention_dim,
config.cross_attention_dim,
config.use_bias,
);
let o_proj = Linear::new(config.cross_attention_dim, hidden_size, config.use_bias);
let gate_proj = Linear::new(hidden_size, hidden_size, true);
let layer_norm = LayerNorm::new(vec![hidden_size], config.layer_norm_eps as f32)?;
Ok(Self {
config,
q_proj,
k_proj,
v_proj,
o_proj,
gate_proj,
layer_norm,
device,
})
}
pub fn new(
hidden_size: usize,
config: FlamingoXAttentionConfig,
) -> Result<Self, Box<dyn std::error::Error>> {
Self::new_with_device(hidden_size, config, Device::CPU)
}
pub fn device(&self) -> Device {
self.device
}
pub fn forward(
&self,
hidden_states: &Tensor,
vision_features: &Tensor,
media_locations: Option<&Tensor>,
) -> Result<GatedCrossAttentionOutput, Box<dyn std::error::Error>> {
let batch_size = hidden_states.shape()[0];
let seq_len = hidden_states.shape()[1];
let vision_seq_len = vision_features.shape()[1];
let hidden_states = hidden_states.contiguous()?;
let vision_features = vision_features.contiguous()?;
let queries = self.q_proj.forward(hidden_states.clone())?;
let keys = self.k_proj.forward(vision_features.clone())?;
let values = self.v_proj.forward(vision_features.clone())?;
let head_dim = self.config.cross_attention_dim / self.config.num_heads;
let queries = queries
.reshape(&[batch_size, seq_len, self.config.num_heads, head_dim])?
.transpose(1, 2)?
.contiguous()?;
let keys = keys
.reshape(&[batch_size, vision_seq_len, self.config.num_heads, head_dim])?
.transpose(1, 2)?
.contiguous()?;
let values = values
.reshape(&[batch_size, vision_seq_len, self.config.num_heads, head_dim])?
.transpose(1, 2)?
.contiguous()?;
let scale = (head_dim as f64).sqrt().recip();
let keys_t = keys.transpose_i64(-1, -2)?.contiguous()?;
let attention_scores = (&queries.matmul(&keys_t)? * scale)?;
let attention_scores = if let Some(media_mask) = media_locations {
let expanded_mask = media_mask
.unsqueeze(1)?
.unsqueeze_i64(-1)?
.broadcast_to(&attention_scores.shape())?;
let large_neg =
Tensor::full_with_dtype(&attention_scores.shape(), -1e9, TensorType::F32)?;
attention_scores.where_cond(&expanded_mask, &large_neg)?
} else {
attention_scores
};
let attention_weights = attention_scores.softmax(-1)?;
let attention_output = attention_weights.contiguous()?.matmul(&values)?;
let attention_output = attention_output.transpose(1, 2)?.contiguous()?.reshape(&[
batch_size,
seq_len,
self.config.cross_attention_dim,
])?;
let output = self.o_proj.forward(attention_output)?;
let gate = match self.config.gating_type.as_str() {
"tanh" => self.gate_proj.forward(hidden_states.clone())?.tanh()?,
"sigmoid" => self.gate_proj.forward(hidden_states.clone())?.sigmoid()?,
"relu" => self.gate_proj.forward(hidden_states.clone())?.relu()?,
_ => self.gate_proj.forward(hidden_states.clone())?.tanh()?,
};
let output = output.contiguous()?;
let gate = gate.contiguous()?;
let gated_output = (&output * &gate)?;
Ok(GatedCrossAttentionOutput {
hidden_states: gated_output,
attention_weights: Some(attention_weights),
})
}
}
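/// Perceiver resampler: compresses a variable-length sequence of vision
/// tokens of shape `[batch, seq_len, input_dim]` into a fixed set of latents
/// of shape `[batch, num_latents, latent_dim]` by repeatedly cross-attending
/// learned latent queries to the (linearly projected) vision features.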
#[derive(Debug, Clone)]
pub struct PerceiverResampler {
pub config: FlamingoPerceiverConfig,
pub latent_queries: Tensor,
pub input_projection: Linear,
pub layers: Vec<PerceiverLayer>,
pub output_projection: Linear,
device: Device,
}
impl PerceiverResampler {
pub fn new_with_device(
config: FlamingoPerceiverConfig,
input_dim: usize,
device: Device,
) -> Result<Self, Box<dyn std::error::Error>> {
let latent_queries = Tensor::randn(&[config.num_latents, config.latent_dim])?;
let input_projection = Linear::new(input_dim, config.latent_dim, false);
let output_projection = Linear::new(config.latent_dim, config.latent_dim, false);
let mut layers = Vec::new();
for _ in 0..config.num_layers {
layers.push(PerceiverLayer::new_with_device(&config, device)?);
}
Ok(Self {
config,
latent_queries,
input_projection,
layers,
output_projection,
device,
})
}
pub fn new(
config: FlamingoPerceiverConfig,
input_dim: usize,
) -> Result<Self, Box<dyn std::error::Error>> {
Self::new_with_device(config, input_dim, Device::CPU)
}
pub fn device(&self) -> Device {
self.device
}
pub fn forward(&self, vision_features: &Tensor) -> Result<Tensor, Box<dyn std::error::Error>> {
let batch_size = vision_features.shape()[0];
let projected_features = self.input_projection.forward(vision_features.clone())?;
let mut latents = self.latent_queries.unsqueeze(0)?.broadcast_to(&[
batch_size,
self.config.num_latents,
self.config.latent_dim,
])?;
for layer in &self.layers {
latents = layer.forward(&latents, &projected_features)?;
}
let output = self.output_projection.forward(latents)?;
Ok(output)
}
}
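/// One resampler block: the latents cross-attend to the vision features,
/// then self-attend, then pass through an MLP, with a pre-norm and residual
/// connection around each sub-layer.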
#[derive(Debug, Clone)]
pub struct PerceiverLayer {
pub cross_attention: MultiHeadAttention,
pub self_attention: MultiHeadAttention,
pub mlp: FlamingoMLP,
pub layer_norm1: LayerNorm,
pub layer_norm2: LayerNorm,
pub layer_norm3: LayerNorm,
device: Device,
}
impl PerceiverLayer {
pub fn new_with_device(
config: &FlamingoPerceiverConfig,
device: Device,
) -> Result<Self, Box<dyn std::error::Error>> {
let cross_attention_config = AttentionConfig {
hidden_size: config.latent_dim,
num_heads: config.num_heads,
head_dim: config.latent_dim / config.num_heads,
dropout_prob: config.attention_dropout as f32,
bias: true,
max_seq_len: None,
};
let self_attention_config = AttentionConfig {
hidden_size: config.latent_dim,
num_heads: config.num_heads,
head_dim: config.latent_dim / config.num_heads,
dropout_prob: config.attention_dropout as f32,
bias: true,
max_seq_len: None,
};
let cross_attention = MultiHeadAttention::from_config(cross_attention_config)?;
let self_attention = MultiHeadAttention::from_config(self_attention_config)?;
let mlp = FlamingoMLP::new_with_device(
config.latent_dim,
config.mlp_hidden_size,
"gelu",
device,
)?;
let layer_norm1 = LayerNorm::new(vec![config.latent_dim], config.layer_norm_eps as f32)?;
let layer_norm2 = LayerNorm::new(vec![config.latent_dim], config.layer_norm_eps as f32)?;
let layer_norm3 = LayerNorm::new(vec![config.latent_dim], config.layer_norm_eps as f32)?;
Ok(Self {
cross_attention,
self_attention,
mlp,
layer_norm1,
layer_norm2,
layer_norm3,
device,
})
}
pub fn new(config: &FlamingoPerceiverConfig) -> Result<Self, Box<dyn std::error::Error>> {
Self::new_with_device(config, Device::CPU)
}
pub fn device(&self) -> Device {
self.device
}
pub fn forward(
&self,
latents: &Tensor,
features: &Tensor,
) -> Result<Tensor, Box<dyn std::error::Error>> {
let normed_latents = self.layer_norm1.forward(latents.clone())?;
let cross_attn_output =
self.cross_attention.forward_cross(&normed_latents, features, features, None)?;
let latents = (latents + &cross_attn_output)?;
let normed_latents = self.layer_norm2.forward(latents.clone())?;
let self_attn_output = self.self_attention.forward(normed_latents)?;
let latents = (&latents + &self_attn_output)?;
let normed_latents = self.layer_norm3.forward(latents.clone())?;
let mlp_output = self.mlp.forward(&normed_latents)?;
let latents = (&latents + &mlp_output)?;
Ok(latents)
}
}
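/// Two-layer feed-forward block with a configurable activation; unrecognized
/// activation names fall back to GELU.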
#[derive(Debug, Clone)]
pub struct FlamingoMLP {
pub fc1: Linear,
pub fc2: Linear,
pub activation: ActivationType,
device: Device,
}
impl FlamingoMLP {
pub fn new_with_device(
hidden_size: usize,
intermediate_size: usize,
activation: &str,
device: Device,
) -> Result<Self, Box<dyn std::error::Error>> {
let fc1 = Linear::new(hidden_size, intermediate_size, true);
let fc2 = Linear::new(intermediate_size, hidden_size, true);
let activation = match activation {
"relu" => ActivationType::ReLU,
"gelu" | "quick_gelu" => ActivationType::GELU,
"silu" | "swish" => ActivationType::SiLU,
"tanh" => ActivationType::Tanh,
"sigmoid" => ActivationType::Sigmoid,
            _ => ActivationType::GELU,
        };
Ok(Self {
fc1,
fc2,
activation,
device,
})
}
pub fn new(
hidden_size: usize,
intermediate_size: usize,
activation: &str,
) -> Result<Self, Box<dyn std::error::Error>> {
Self::new_with_device(hidden_size, intermediate_size, activation, Device::CPU)
}
pub fn device(&self) -> Device {
self.device
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor, Box<dyn std::error::Error>> {
let x = self.fc1.forward(x.clone())?;
let x = match self.activation {
ActivationType::ReLU => x.relu()?,
ActivationType::GELU => x.gelu()?,
ActivationType::SiLU => x.silu()?,
ActivationType::Tanh => x.tanh()?,
ActivationType::Sigmoid => x.sigmoid()?,
};
let x = self.fc2.forward(x)?;
Ok(x)
}
}
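/// Top-level model output: next-token logits, the resampled vision features
/// (when pixel values were supplied), cross-attention weights collected from
/// the cross-attention layers, and the language-modeling loss (when labels
/// were supplied).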
#[derive(Debug, Clone)]
pub struct FlamingoOutput {
pub logits: Tensor,
pub vision_features: Option<Tensor>,
pub cross_attention_weights: Vec<Tensor>,
pub loss: Option<Tensor>,
}
#[derive(Debug, Clone)]
pub struct FlamingoLanguageOutput {
pub logits: Tensor,
pub hidden_states: Tensor,
pub cross_attention_weights: Vec<Tensor>,
}
#[derive(Debug, Clone)]
pub struct FlamingoLanguageLayerOutput {
pub hidden_states: Tensor,
pub cross_attention_weights: Option<Tensor>,
}
#[derive(Debug, Clone)]
pub struct GatedCrossAttentionOutput {
pub hidden_states: Tensor,
pub attention_weights: Option<Tensor>,
}
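/// Extension trait standing in for a dedicated cross-attention entry point on
/// `MultiHeadAttention`. The implementation below is a placeholder rather
/// than real attention; see the comments inside `forward_cross`.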
trait MultiHeadAttentionExt {
fn forward_cross(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
attention_mask: Option<&Tensor>,
) -> Result<Tensor, Box<dyn std::error::Error>>;
}
impl MultiHeadAttentionExt for MultiHeadAttention {
fn forward_cross(
&self,
query: &Tensor,
key: &Tensor,
value: &Tensor,
_attention_mask: Option<&Tensor>,
) -> Result<Tensor, Box<dyn std::error::Error>> {
        let batch_size = query.shape()[0];
        let seq_len_q = query.shape()[1];
        let seq_len_kv = key.shape()[1];
        let _hidden_size = query.shape()[2];
        // Placeholder attention: an all-ones score matrix softmaxes to uniform
        // weights, so this simply averages the value vectors over the key axis
        // rather than computing scaled dot-product attention from the
        // query/key projections.
        let dummy_weights =
            Tensor::ones_dtype(TensorType::F32, &[batch_size, seq_len_q, seq_len_kv])?;
        let normalized_weights = dummy_weights.softmax(-1)?;
        let output = normalized_weights.matmul(value)?;
        Ok(output)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
    #[ignore]
    fn test_flamingo_model_creation() {
let config = FlamingoConfig::flamingo_3b();
let model = FlamingoModel::new(config);
assert!(model.is_ok());
}
#[test]
    #[ignore]
    fn test_flamingo_vision_encoder() {
let config = FlamingoVisionConfig::clip_vit_l();
let encoder = FlamingoVisionEncoder::new(config).expect("operation failed");
let batch_size = 2;
let pixel_values = Tensor::randn(&[batch_size, 3, 224, 224]).expect("operation failed");
let output = encoder.forward(&pixel_values);
assert!(output.is_ok());
let output = output.expect("operation failed");
assert_eq!(output.shape()[0], batch_size);
        assert_eq!(output.shape()[1], encoder.config.seq_len());
        assert_eq!(output.shape()[2], encoder.config.hidden_size);
}
#[test]
#[ignore = "scirs2-core SIMD non-contiguous array panic"]
fn test_perceiver_resampler() {
let config = FlamingoPerceiverConfig {
            num_latents: 32,
            latent_dim: 512,
            num_layers: 2,
            num_heads: 8,
            mlp_hidden_size: 2048,
            ..FlamingoPerceiverConfig::default()
};
        let input_dim = 512;
        let resampler =
            PerceiverResampler::new(config.clone(), input_dim).expect("operation failed");
        let batch_size = 1;
        let seq_len = 64;
        let vision_features =
Tensor::randn(&[batch_size, seq_len, input_dim]).expect("operation failed");
let output = resampler.forward(&vision_features);
assert!(output.is_ok(), "Forward pass failed: {:?}", output.err());
let output = output.expect("operation failed");
assert_eq!(
output.shape(),
&[batch_size, config.num_latents, config.latent_dim]
);
drop(output);
drop(vision_features);
drop(resampler);
std::hint::black_box(());
}
#[test]
fn test_gated_cross_attention() {
let hidden_size = 256;
let config = FlamingoXAttentionConfig {
cross_attention_dim: 256,
num_heads: 4,
head_dim: 64,
..FlamingoXAttentionConfig::default()
};
let cross_attn =
GatedCrossAttention::new(hidden_size, config.clone()).expect("operation failed");
        let batch_size = 1;
        let seq_len = 8;
let vision_seq_len = 16;
let hidden_states =
Tensor::randn(&[batch_size, seq_len, hidden_size]).expect("operation failed");
let vision_features =
Tensor::randn(&[batch_size, vision_seq_len, config.cross_attention_dim])
.expect("operation failed");
let output = cross_attn.forward(&hidden_states, &vision_features, None);
assert!(output.is_ok(), "Forward failed: {:?}", output.err());
let output = output.expect("operation failed");
assert_eq!(
output.hidden_states.shape(),
&[batch_size, seq_len, hidden_size]
);
assert!(output.attention_weights.is_some());
}
#[test]
    #[ignore]
    fn test_flamingo_language_model() {
let language_config = FlamingoLanguageConfig::chinchilla_1b();
let cross_attention_config = FlamingoXAttentionConfig::default();
let cross_attention_layers = vec![1, 3, 5];
let model = FlamingoLanguageModel::new(
language_config.clone(),
cross_attention_config.clone(),
cross_attention_layers.clone(),
)
.expect("operation failed");
let batch_size = 2;
let seq_len = 10;
let input_ids = Tensor::randint(
0,
language_config.vocab_size as i64,
&[batch_size, seq_len],
TensorType::I64,
)
.expect("operation failed");
let attention_mask = Tensor::ones(&[batch_size, seq_len]).expect("operation failed");
let output = model.forward(&input_ids, &attention_mask, None, None);
assert!(output.is_ok());
let output = output.expect("operation failed");
assert_eq!(
output.logits.shape(),
&[batch_size, seq_len, language_config.vocab_size]
);
let vision_features =
Tensor::randn(&[batch_size, 64, cross_attention_config.cross_attention_dim])
.expect("operation failed");
let output_with_vision =
model.forward(&input_ids, &attention_mask, Some(&vision_features), None);
assert!(output_with_vision.is_ok());
let output_with_vision = output_with_vision.expect("operation failed");
assert_eq!(
output_with_vision.logits.shape(),
&[batch_size, seq_len, language_config.vocab_size]
);
assert!(!output_with_vision.cross_attention_weights.is_empty());
}
#[test]
    #[ignore]
    fn test_flamingo_end_to_end() {
let config = FlamingoConfig::flamingo_3b();
let model = FlamingoModel::new(config.clone()).expect("operation failed");
let batch_size = 1;
let seq_len = 20;
let input_ids = Tensor::randint(
0,
config.language_config.vocab_size as i64,
&[batch_size, seq_len],
TensorType::I64,
)
.expect("operation failed");
let attention_mask = Tensor::ones(&[batch_size, seq_len]).expect("operation failed");
let pixel_values = Tensor::randn(&[batch_size, 3, 224, 224]).expect("operation failed");
let output =
model.forward_train(&input_ids, &attention_mask, Some(&pixel_values), None, None);
assert!(output.is_ok());
let output = output.expect("operation failed");
assert_eq!(
output.logits.shape(),
&[batch_size, seq_len, config.language_config.vocab_size]
);
assert!(output.vision_features.is_some());
assert!(!output.cross_attention_weights.is_empty());
let vision_features = output.vision_features.expect("operation failed");
assert_eq!(vision_features.shape()[0], batch_size);
assert_eq!(vision_features.shape()[1], config.media_token_length);
assert_eq!(vision_features.shape()[2], config.vision_language_dim);
}
#[test]
    #[ignore]
    fn test_flamingo_generation() {
let config = FlamingoConfig::flamingo_3b();
let model = FlamingoModel::new(config.clone()).expect("operation failed");
let batch_size = 1;
let seq_len = 10;
let input_ids = Tensor::randint(
0,
config.language_config.vocab_size as i64,
&[batch_size, seq_len],
TensorType::I64,
)
.expect("operation failed");
let attention_mask = Tensor::ones(&[batch_size, seq_len]).expect("operation failed");
let pixel_values = Tensor::randn(&[batch_size, 3, 224, 224]).expect("operation failed");
let generated = model.generate_with_shots(
&input_ids,
&attention_mask,
Some(&pixel_values),
None,
            5,     // max_new_tokens
            1.0,   // temperature
            false, // do_sample
        );
assert!(generated.is_ok());
let generated = generated.expect("operation failed");
assert_eq!(generated.shape()[0], batch_size);
        assert!(generated.shape()[1] > seq_len);
    }
#[test]
    #[ignore]
    fn test_flamingo_with_media_locations() {
let config = FlamingoConfig::flamingo_3b();
let model = FlamingoModel::new(config.clone()).expect("operation failed");
let batch_size = 1;
let seq_len = 20;
let input_ids = Tensor::randint(
0,
config.language_config.vocab_size as i64,
&[batch_size, seq_len],
TensorType::I64,
)
.expect("operation failed");
let attention_mask = Tensor::ones(&[batch_size, seq_len]).expect("operation failed");
let pixel_values = Tensor::randn(&[batch_size, 3, 224, 224]).expect("operation failed");
let mut media_locations = Tensor::zeros(&[batch_size, seq_len]).expect("operation failed");
for i in 0..5 {
media_locations = media_locations.set_scalar(&[0, i], 1.0).expect("operation failed");
}
let output = model.forward_train(
&input_ids,
&attention_mask,
Some(&pixel_values),
Some(&media_locations),
None,
);
assert!(output.is_ok());
let output = output.expect("operation failed");
assert_eq!(
output.logits.shape(),
&[batch_size, seq_len, config.language_config.vocab_size]
);
assert!(output.vision_features.is_some());
}
#[test]
    #[ignore]
    fn test_flamingo_different_configs() {
let configs = vec![
FlamingoConfig::flamingo_3b(),
FlamingoConfig::flamingo_9b(),
FlamingoConfig::open_flamingo(),
];
for config in configs {
let model = FlamingoModel::new(config);
assert!(model.is_ok(), "Failed to create model with config");
}
}
#[test]
    #[ignore]
    fn test_flamingo_language_layer() {
let language_config = FlamingoLanguageConfig::chinchilla_1b();
let cross_attention_config = FlamingoXAttentionConfig::default();
let layer_with_cross =
FlamingoLanguageLayer::new(&language_config, &cross_attention_config, true)
.expect("operation failed");
assert!(layer_with_cross.cross_attention.is_some());
assert!(layer_with_cross.layer_norm3.is_some());
let layer_without_cross =
FlamingoLanguageLayer::new(&language_config, &cross_attention_config, false)
.expect("operation failed");
assert!(layer_without_cross.cross_attention.is_none());
assert!(layer_without_cross.layer_norm3.is_none());
let batch_size = 2;
let seq_len = 10;
let hidden_states = Tensor::randn(&[batch_size, seq_len, language_config.hidden_size])
.expect("operation failed");
let attention_mask = Tensor::ones(&[batch_size, seq_len]).expect("operation failed");
let output = layer_without_cross.forward(&hidden_states, &attention_mask, None, None);
assert!(output.is_ok());
let output = output.expect("operation failed");
assert_eq!(output.hidden_states.shape(), hidden_states.shape());
assert!(output.cross_attention_weights.is_none());
}
}