pub mod config;
pub mod model;
pub use config::{
FlamingoConfig, FlamingoLanguageConfig, FlamingoPerceiverConfig, FlamingoVisionConfig,
FlamingoXAttentionConfig,
};
pub use model::{
FlamingoLanguageLayer, FlamingoLanguageLayerOutput, FlamingoLanguageModel,
FlamingoLanguageOutput, FlamingoMLP, FlamingoModel, FlamingoOutput, FlamingoVisionEncoder,
FlamingoVisionLayer, GatedCrossAttention, GatedCrossAttentionOutput, PerceiverLayer,
PerceiverResampler,
};
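// Typical usage (a minimal sketch, mirroring the tests below; error handling
// elided — the constructors and forward passes here all return `Result`):
//
//     let config = FlamingoConfig::flamingo_3b();
//     let model = FlamingoModel::new(config)?;
//     let output = model.forward_train(
//         &input_ids,
//         &attention_mask,
//         Some(&pixel_values),
//         Some(&media_locations),
//         None,
//     )?;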
#[cfg(test)]
mod tests {
use super::*;
use trustformers_core::{Tensor, TensorType};
// Smoke test: every re-exported config constructs and a 3B model builds.
#[test]
#[ignore]
fn test_flamingo_module_imports() {
let _config = FlamingoConfig::default();
let _vision_config = FlamingoVisionConfig::default();
let _language_config = FlamingoLanguageConfig::default();
let _perceiver_config = FlamingoPerceiverConfig::default();
let _x_attention_config = FlamingoXAttentionConfig::default();
let config = FlamingoConfig::flamingo_3b();
let model = FlamingoModel::new(config);
assert!(model.is_ok());
}
// End-to-end on the 3B variant: training forward pass, output shapes, then generation.
#[test]
#[ignore]
fn test_flamingo_3b_end_to_end() {
let config = FlamingoConfig::flamingo_3b();
let model = FlamingoModel::new(config.clone()).expect("failed to build Flamingo-3B model");
let batch_size = 1;
let seq_len = 50;
let input_ids = Tensor::randint(
0,
config.language_config.vocab_size as i64,
&[batch_size, seq_len],
TensorType::I64,
)
.expect("failed to create input_ids");
let attention_mask = Tensor::ones(&[batch_size, seq_len]).expect("failed to create attention_mask");
let pixel_values = Tensor::randn(&[batch_size, 3, 224, 224]).expect("failed to create pixel_values");
let mut media_locations = Tensor::zeros(&[batch_size, seq_len]).expect("failed to create media_locations");
// Mark three 5-token media spans so cross-attention has image positions to attend over.
for &start in &[5, 20, 35] {
for i in 0..5 {
if start + i < seq_len {
media_locations = media_locations
.set_scalar(&[0, start + i], 1.0)
.expect("failed to set media location");
}
}
}
let train_output = model.forward_train(
&input_ids,
&attention_mask,
Some(&pixel_values),
Some(&media_locations),
None,
);
assert!(train_output.is_ok());
let train_output = train_output.expect("forward_train failed");
assert_eq!(
train_output.logits.shape(),
&[batch_size, seq_len, config.language_config.vocab_size]
);
assert!(train_output.vision_features.is_some());
assert!(!train_output.cross_attention_weights.is_empty());
let vision_features = train_output.vision_features.expect("vision_features missing from training output");
assert_eq!(vision_features.shape()[0], batch_size);
assert_eq!(vision_features.shape()[1], config.media_token_length);
assert_eq!(vision_features.shape()[2], config.vision_language_dim);
// Trailing args appear to be (max new tokens, temperature, sample flag).
let generated = model.generate_with_shots(
&input_ids,
&attention_mask,
Some(&pixel_values),
Some(&media_locations),
10,
1.0,
false,
);
assert!(generated.is_ok());
let generated = generated.expect("generate_with_shots failed");
assert_eq!(generated.shape()[0], batch_size);
assert!(generated.shape()[1] > seq_len);
assert!(generated.shape()[1] <= seq_len + 10);
}
// The 9B configuration exposes the expected hyperparameters and encodes vision.
#[test]
#[ignore]
fn test_flamingo_9b_configuration() {
let config = FlamingoConfig::flamingo_9b();
let model = FlamingoModel::new(config.clone()).expect("failed to build Flamingo-9B model");
assert_eq!(config.language_config.vocab_size, 32000);
assert_eq!(config.language_config.hidden_size, 4096);
assert_eq!(config.language_config.num_hidden_layers, 32);
assert_eq!(config.vision_language_dim, 4096);
assert_eq!(config.perceiver_config.latent_dim, 4096);
assert_eq!(config.cross_attention_config.cross_attention_dim, 4096);
assert_eq!(config.num_shots, 8);
assert_eq!(config.max_seq_length, 4096);
let batch_size = 1;
let seq_len = 32;
// Built but unused: this test only exercises the vision path.
let _input_ids = Tensor::randint(
0,
config.language_config.vocab_size as i64,
&[batch_size, seq_len],
TensorType::I64,
)
.expect("failed to create input_ids");
let _attention_mask = Tensor::ones(&[batch_size, seq_len]).expect("failed to create attention_mask");
let pixel_values = Tensor::randn(&[batch_size, 3, 224, 224]).expect("failed to create pixel_values");
let vision_features = model.encode_vision(&pixel_values);
assert!(vision_features.is_ok());
let vision_features = vision_features.expect("encode_vision failed");
assert_eq!(vision_features.shape()[0], batch_size);
assert_eq!(vision_features.shape()[1], config.media_token_length);
assert_eq!(vision_features.shape()[2], config.vision_language_dim);
}
// OpenFlamingo variant: distinct vocabulary size and GELU activation.
#[test]
#[ignore]
fn test_flamingo_open_source_variant() {
let config = FlamingoConfig::open_flamingo();
assert_eq!(config.language_config.vocab_size, 50432);
assert_eq!(config.language_config.hidden_size, 4096);
assert_eq!(config.language_config.hidden_act, "gelu");
assert_eq!(config.vision_language_dim, 4096);
assert_eq!(config.num_shots, 8);
let model = FlamingoModel::new(config);
assert!(model.is_ok());
}
// Each major component constructs on its own from the shared 3B config.
#[test]
#[ignore]
fn test_flamingo_components_separately() {
let config = FlamingoConfig::flamingo_3b();
let vision_encoder = FlamingoVisionEncoder::new(config.vision_config.clone());
assert!(vision_encoder.is_ok());
let language_model = FlamingoLanguageModel::new(
config.language_config.clone(),
config.cross_attention_config.clone(),
config.cross_attention_layers.clone(),
);
assert!(language_model.is_ok());
let perceiver = PerceiverResampler::new(
config.perceiver_config.clone(),
config.vision_config.hidden_size,
);
assert!(perceiver.is_ok());
let cross_attn = GatedCrossAttention::new(
config.language_config.hidden_size,
config.cross_attention_config.clone(),
);
assert!(cross_attn.is_ok());
}
// Cross-attention layers appear exactly at the configured layer indices.
#[test]
#[ignore]
fn test_flamingo_cross_attention_layers() {
let config = FlamingoConfig::flamingo_9b();
let model = FlamingoModel::new(config.clone()).expect("failed to build Flamingo-9B model");
assert_eq!(
model.language_model.cross_attention_layers,
config.cross_attention_layers
);
for &layer_idx in &config.cross_attention_layers {
assert!(layer_idx < config.language_config.num_hidden_layers);
}
for (i, layer) in model.language_model.layers.iter().enumerate() {
if config.cross_attention_layers.contains(&i) {
assert!(layer.cross_attention.is_some());
assert!(layer.layer_norm3.is_some());
} else {
assert!(layer.cross_attention.is_none());
assert!(layer.layer_norm3.is_none());
}
}
}
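// The Perceiver resampler compresses a variable-length vision sequence to a
// fixed set of learned latents.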
#[test]
fn test_flamingo_perceiver_functionality() {
let config = FlamingoPerceiverConfig {
num_latents: 8,
latent_dim: 128,
num_layers: 1,
num_heads: 2,
mlp_hidden_size: 256,
..FlamingoPerceiverConfig::default()
};
let input_dim = 128;
let perceiver = PerceiverResampler::new(config.clone(), input_dim)
.expect("failed to build PerceiverResampler");
let batch_size = 1;
let input_seq_len = 16;
let vision_features = Tensor::randn(&[batch_size, input_seq_len, input_dim])
.expect("failed to create vision_features");
let output = perceiver.forward(&vision_features).expect("perceiver forward failed");
assert_eq!(
output.shape(),
&[batch_size, config.num_latents, config.latent_dim]
);
// The resampler compresses the sequence: fewer latents than input tokens.
assert!(config.num_latents < input_seq_len);
}
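// Every named variant round-trips through serde_json with key fields preserved.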
#[test]
fn test_flamingo_config_serialization() {
let configs = vec![
FlamingoConfig::flamingo_3b(),
FlamingoConfig::flamingo_9b(),
FlamingoConfig::open_flamingo(),
];
for config in configs {
let json = serde_json::to_string(&config);
assert!(json.is_ok());
let json_str = json.expect("config serialization failed");
let deserialized: Result<FlamingoConfig, _> = serde_json::from_str(&json_str);
assert!(deserialized.is_ok());
let deserialized = deserialized.expect("config deserialization failed");
assert_eq!(config.media_token_length, deserialized.media_token_length);
assert_eq!(config.vision_language_dim, deserialized.vision_language_dim);
assert_eq!(
config.use_gated_cross_attention,
deserialized.use_gated_cross_attention
);
assert_eq!(
config.language_config.vocab_size,
deserialized.language_config.vocab_size
);
assert_eq!(
config.vision_config.hidden_size,
deserialized.vision_config.hidden_size
);
assert_eq!(
config.perceiver_config.num_latents,
deserialized.perceiver_config.num_latents
);
assert_eq!(
config.cross_attention_config.cross_attention_dim,
deserialized.cross_attention_config.cross_attention_dim
);
assert_eq!(
config.cross_attention_layers,
deserialized.cross_attention_layers
);
}
}
// Simulates an interleaved few-shot prompt: three (image, text) examples plus a query.
#[test]
#[ignore]
fn test_flamingo_few_shot_simulation() {
let config = FlamingoConfig::flamingo_3b();
let model = FlamingoModel::new(config.clone()).expect("failed to build Flamingo-3B model");
let batch_size = 1;
let num_examples = 3;
let text_per_example = 15;
let media_tokens_per_example = 5;
let seq_len = num_examples * (text_per_example + media_tokens_per_example) + 10;
let input_ids = Tensor::randint(
0,
config.language_config.vocab_size as i64,
&[batch_size, seq_len],
TensorType::I64,
)
.expect("failed to create input_ids");
let attention_mask = Tensor::ones(&[batch_size, seq_len]).expect("failed to create attention_mask");
let pixel_values = Tensor::randn(&[batch_size, 3, 224, 224]).expect("failed to create pixel_values");
let mut media_locations = Tensor::zeros(&[batch_size, seq_len]).expect("failed to create media_locations");
// Each example begins with its media tokens, followed by its text tokens.
for example in 0..num_examples {
let start_pos = example * (text_per_example + media_tokens_per_example);
for i in 0..media_tokens_per_example {
if start_pos + i < seq_len {
media_locations = media_locations
.set_scalar(&[0, start_pos + i], 1.0)
.expect("failed to set media location");
}
}
}
}
let output = model.forward_train(
&input_ids,
&attention_mask,
Some(&pixel_values),
Some(&media_locations),
None,
);
assert!(output.is_ok());
let output = output.expect("forward_train failed");
assert_eq!(
output.logits.shape(),
&[batch_size, seq_len, config.language_config.vocab_size]
);
assert!(output.vision_features.is_some());
assert!(!output.cross_attention_weights.is_empty());
let generated = model.generate_with_shots(
&input_ids,
&attention_mask,
Some(&pixel_values),
Some(&media_locations),
15,
0.8,
true,
);
assert!(generated.is_ok());
let generated = generated.expect("generate_with_shots failed");
assert!(generated.shape()[1] > seq_len);
}
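// Every supported gating nonlinearity yields a well-formed cross-attention output.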
#[test]
fn test_flamingo_gating_mechanisms() {
let mut config = FlamingoXAttentionConfig::default();
let hidden_size = 2048;
for gating_type in ["tanh", "sigmoid", "relu"] {
config.gating_type = gating_type.to_string();
let cross_attn = GatedCrossAttention::new(hidden_size, config.clone())
.expect("failed to build GatedCrossAttention");
let batch_size = 1;
let seq_len = 10;
let vision_seq_len = 64;
let hidden_states =
Tensor::randn(&[batch_size, seq_len, hidden_size]).expect("failed to create hidden_states");
let vision_features =
Tensor::randn(&[batch_size, vision_seq_len, config.cross_attention_dim])
.expect("failed to create vision_features");
let output = cross_attn.forward(&hidden_states, &vision_features, None);
assert!(output.is_ok(), "Gating type {} failed", gating_type);
let output = output.expect("cross-attention forward failed");
// The gated residual output keeps the language-model hidden size.
assert_eq!(
output.hidden_states.shape(),
&[batch_size, seq_len, hidden_size]
);
}
}
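// A minimal extra check (a sketch; assumes the `Default` config, constructed in
// test_flamingo_module_imports above, also yields a buildable model).
#[test]
#[ignore]
fn test_flamingo_default_config_builds() {
let config = FlamingoConfig::default();
let model = FlamingoModel::new(config);
assert!(model.is_ok());
}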
}