use crate::cogvlm::config::{CogVideoConfig, CogVlmConfig, CogVlmVisionConfig};
use scirs2_core::ndarray::{Array3, ArrayD, IxDyn};
use trustformers_core::{
device::Device,
errors::{Result, TrustformersError},
layers::{Embedding, FeedForward, LayerNorm, Linear, MultiHeadAttention},
tensor::Tensor,
traits::{Layer, Model},
VectorizedRoPE,
};
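/// Vision tower for CogVLM: a ViT-style transformer that turns pixel values into a
/// sequence of patch features via patch/position embeddings, a stack of encoder
/// layers, and a final layer norm.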
pub struct CogVlmVisionTransformer {
#[allow(dead_code)]
config: CogVlmVisionConfig,
embeddings: CogVlmVisionEmbeddings,
encoder: CogVlmVisionEncoder,
layernorm: LayerNorm,
device: Device,
}
impl CogVlmVisionTransformer {
pub fn new(config: CogVlmVisionConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: CogVlmVisionConfig, device: Device) -> Result<Self> {
let embeddings = CogVlmVisionEmbeddings::new_with_device(config.clone(), device)?;
let encoder = CogVlmVisionEncoder::new_with_device(config.clone(), device)?;
let layernorm = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps)?;
Ok(Self {
config,
embeddings,
encoder,
layernorm,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn parameter_count(&self) -> usize {
self.embeddings.parameter_count()
+ self.encoder.parameter_count()
+ self.layernorm.parameter_count()
}
}
impl Layer for CogVlmVisionTransformer {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, pixel_values: Self::Input) -> Result<Self::Output> {
let hidden_states = self.embeddings.forward(pixel_values)?;
let hidden_states = self.encoder.forward(hidden_states)?;
let normalized_states = self.layernorm.forward(hidden_states)?;
Ok(normalized_states)
}
}
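/// Patch and position embeddings for the vision tower: flattened image patches are
/// projected to `hidden_size`, a learned CLS token is prepended, and learned
/// position embeddings are added.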
pub struct CogVlmVisionEmbeddings {
config: CogVlmVisionConfig,
patch_embedding: Linear,
position_embedding: Embedding,
cls_token: Tensor,
device: Device,
}
impl CogVlmVisionEmbeddings {
pub fn new(config: CogVlmVisionConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: CogVlmVisionConfig, device: Device) -> Result<Self> {
let patch_size = config.patch_size;
let patch_embedding = Linear::new(
config.num_channels * patch_size * patch_size,
config.hidden_size,
true,
);
let num_patches = (config.image_size / patch_size).pow(2);
// One extra position for the prepended CLS token.
let num_positions = num_patches + 1;
let position_embedding = Embedding::new(num_positions, config.hidden_size, None)?;
let cls_token = Tensor::randn(&[1, 1, config.hidden_size])?;
Ok(Self {
config,
patch_embedding,
position_embedding,
cls_token,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn parameter_count(&self) -> usize {
self.patch_embedding.parameter_count()
+ self.position_embedding.parameter_count()
+ self.cls_token.len()
}
}
impl Layer for CogVlmVisionEmbeddings {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, pixel_values: Self::Input) -> Result<Self::Output> {
let batch_size = pixel_values.shape()[0];
let patch_size = self.config.patch_size;
let _image_size = self.config.image_size;
let patches = extract_patches(&pixel_values, patch_size)?;
let patch_embeds = self.patch_embedding.forward(patches)?;
let cls_tokens = self.cls_token.broadcast_to(&[batch_size, 1, self.config.hidden_size])?;
let embeddings = Tensor::concat(&[cls_tokens, patch_embeds], 1)?;
let seq_len = embeddings.shape()[1];
let position_ids: Vec<u32> = (0..seq_len).map(|i| i as u32).collect();
let position_embeds = self.position_embedding.forward(position_ids)?;
embeddings.add(&position_embeds)
}
}
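/// Stack of `num_hidden_layers` pre-norm transformer blocks applied to the patch
/// embeddings.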
pub struct CogVlmVisionEncoder {
layers: Vec<CogVlmVisionLayer>,
device: Device,
}
impl CogVlmVisionEncoder {
pub fn new(config: CogVlmVisionConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: CogVlmVisionConfig, device: Device) -> Result<Self> {
let mut layers = Vec::new();
for _ in 0..config.num_hidden_layers {
layers.push(CogVlmVisionLayer::new_with_device(config.clone(), device)?);
}
Ok(Self { layers, device })
}
pub fn device(&self) -> Device {
self.device
}
pub fn parameter_count(&self) -> usize {
self.layers.iter().map(|layer| layer.parameter_count()).sum()
}
}
impl Layer for CogVlmVisionEncoder {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, hidden_states: Self::Input) -> Result<Self::Output> {
let mut hidden_states = hidden_states;
for layer in &self.layers {
hidden_states = layer.forward(hidden_states)?;
}
Ok(hidden_states)
}
}
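/// Single vision encoder block: pre-norm multi-head self-attention and a pre-norm
/// feed-forward MLP, each wrapped in a residual connection.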
pub struct CogVlmVisionLayer {
layernorm1: LayerNorm,
attention: MultiHeadAttention,
layernorm2: LayerNorm,
mlp: FeedForward,
#[allow(dead_code)]
config: CogVlmVisionConfig,
device: Device,
}
impl CogVlmVisionLayer {
pub fn new(config: CogVlmVisionConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: CogVlmVisionConfig, device: Device) -> Result<Self> {
let layernorm1 = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps)?;
let attention = MultiHeadAttention::new(
config.hidden_size,
config.num_attention_heads,
config.attention_dropout,
true,
)?;
let layernorm2 = LayerNorm::new(vec![config.hidden_size], config.layer_norm_eps)?;
let mlp = FeedForward::new(config.hidden_size, config.intermediate_size, config.dropout)?;
Ok(Self {
layernorm1,
attention,
layernorm2,
mlp,
config,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn parameter_count(&self) -> usize {
self.layernorm1.parameter_count()
+ self.attention.parameter_count()
+ self.layernorm2.parameter_count()
+ self.mlp.parameter_count()
}
}
impl Layer for CogVlmVisionLayer {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, hidden_states: Self::Input) -> Result<Self::Output> {
let normalized = self.layernorm1.forward(hidden_states.clone())?;
let attn_output = self.attention.forward(normalized)?;
let hidden_states = hidden_states.add(&attn_output)?;
let normalized = self.layernorm2.forward(hidden_states.clone())?;
let mlp_output = self.mlp.forward(normalized)?;
hidden_states.add(&mlp_output)
}
}
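/// Visual expert block used alongside each language layer. It holds separate
/// attention/MLP experts for the language and vision streams plus a cross-attention;
/// the current `forward` runs the language expert path and a cross-attention +
/// vision-MLP path over the language hidden states, then sums the two results.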
#[derive(Debug, Clone)]
pub struct VisualExpert {
#[allow(dead_code)]
config: CogVlmConfig,
language_expert_attention: MultiHeadAttention,
language_expert_mlp: FeedForward,
vision_expert_attention: MultiHeadAttention,
vision_expert_mlp: FeedForward,
cross_attention: MultiHeadAttention,
device: Device,
}
impl VisualExpert {
pub fn new(config: CogVlmConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: CogVlmConfig, device: Device) -> Result<Self> {
let language_expert_attention =
MultiHeadAttention::new(config.hidden_size, config.num_attention_heads, 0.0, true)?;
let language_expert_mlp =
FeedForward::new(config.hidden_size, config.intermediate_size, 0.0)?;
let vision_expert_attention = MultiHeadAttention::new(
config.cross_hidden_size,
config.num_attention_heads,
0.0,
true,
)?;
let vision_expert_mlp =
FeedForward::new(config.cross_hidden_size, config.intermediate_size, 0.0)?;
let cross_attention = MultiHeadAttention::new(
config.hidden_size,
config.num_attention_heads,
0.0,
false,
)?;
Ok(Self {
config,
language_expert_attention,
language_expert_mlp,
vision_expert_attention,
vision_expert_mlp,
cross_attention,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn parameter_count(&self) -> usize {
self.language_expert_attention.parameter_count()
+ self.language_expert_mlp.parameter_count()
+ self.vision_expert_attention.parameter_count()
+ self.vision_expert_mlp.parameter_count()
+ self.cross_attention.parameter_count()
}
}
impl Layer for VisualExpert {
type Input = (Tensor, Tensor);
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let (language_hidden_states, _vision_hidden_states) = input;
let lang_attn = self.language_expert_attention.forward(language_hidden_states.clone())?;
let lang_residual = language_hidden_states.add(&lang_attn)?;
let lang_mlp = self.language_expert_mlp.forward(lang_residual.clone())?;
let lang_output = lang_residual.add(&lang_mlp)?;
let cross_attn = self.cross_attention.forward(language_hidden_states)?;
let vision_mlp = self.vision_expert_mlp.forward(cross_attn.clone())?;
let vision_output = cross_attn.add(&vision_mlp)?;
lang_output.add(&vision_output)
}
}
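/// Full CogVLM model: vision transformer, vision-to-language projection, one visual
/// expert per language layer, and the language model that produces logits.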
pub struct CogVlmModel {
config: CogVlmConfig,
vision_model: CogVlmVisionTransformer,
language_model: CogVlmLanguageModel,
visual_experts: Vec<VisualExpert>,
vision_projection: Linear,
device: Device,
}
impl CogVlmModel {
pub fn new(config: CogVlmConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: CogVlmConfig, device: Device) -> Result<Self> {
let vision_model =
CogVlmVisionTransformer::new_with_device(config.vision_config.clone(), device)?;
let language_model = CogVlmLanguageModel::new_with_device(config.clone(), device)?;
let mut visual_experts = Vec::new();
for _ in 0..config.num_hidden_layers {
visual_experts.push(VisualExpert::new_with_device(config.clone(), device)?);
}
let vision_projection =
Linear::new(config.vision_config.hidden_size, config.hidden_size, false);
Ok(Self {
config,
vision_model,
language_model,
visual_experts,
vision_projection,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
#[allow(dead_code)]
fn process_multi_batch_embeddings(
&self,
input_ids_vec: &[u32],
batch_size: usize,
seq_len: usize,
) -> Result<Tensor> {
println!(
"Processing multi-batch embeddings: batch_size={}, seq_len={}",
batch_size, seq_len
);
let mut batch_embeddings = Vec::new();
for batch_idx in 0..batch_size {
let start_idx = batch_idx * seq_len;
let end_idx = start_idx + seq_len;
let batch_tokens: Vec<u32> = input_ids_vec[start_idx..end_idx].to_vec();
let batch_embedding_2d = self.language_model.embeddings.forward(batch_tokens)?;
batch_embeddings.push(batch_embedding_2d);
}
let first_embedding = &batch_embeddings[0];
let embedding_shape = first_embedding.shape();
let hidden_size = embedding_shape[1];
for (i, embedding) in batch_embeddings.iter().enumerate() {
let shape = embedding.shape();
if shape[0] != seq_len || shape[1] != hidden_size {
return Err(TrustformersError::tensor_op_error(
&format!(
"Batch {} embedding shape mismatch: expected [{}, {}], got [{}, {}]",
i, seq_len, hidden_size, shape[0], shape[1]
),
"process_multi_batch_embeddings",
));
}
}
self.stack_batch_embeddings(batch_embeddings, batch_size, seq_len, hidden_size)
}
fn stack_batch_embeddings(
&self,
batch_embeddings: Vec<Tensor>,
batch_size: usize,
seq_len: usize,
hidden_size: usize,
) -> Result<Tensor> {
let mut combined_data = Vec::with_capacity(batch_size * seq_len * hidden_size);
for embedding in batch_embeddings {
match embedding {
Tensor::F32(array) => {
for row in array.rows() {
let row_slice = row.as_slice().ok_or_else(|| {
TrustformersError::tensor_op_error(
"stack_batch_embeddings",
"Row is not contiguous in memory",
)
})?;
combined_data.extend_from_slice(row_slice);
}
},
_ => {
return Err(TrustformersError::tensor_op_error(
"Unsupported tensor type for embedding stacking",
"stack_batch_embeddings",
))
},
}
}
let array_3d = Array3::from_shape_vec((batch_size, seq_len, hidden_size), combined_data)
.map_err(|e| {
TrustformersError::tensor_op_error(
&format!("Failed to create 3D tensor: {}", e),
"stack_batch_embeddings",
)
})?;
Ok(Tensor::F32(array_3d.into_dyn()))
}
}
impl Model for CogVlmModel {
type Config = CogVlmConfig;
type Input = CogVlmInput;
type Output = CogVlmOutput;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let CogVlmInput {
pixel_values,
input_ids,
attention_mask,
position_ids,
token_type_ids,
images_seq_mask,
images_emb_mask,
} = input;
// Text-only path: without pixel inputs, return zero-filled placeholders for the
// hidden states and logits instead of running the language model.
if pixel_values.is_none() {
let batch_size = input_ids.shape()[0];
let seq_len = input_ids.shape()[1];
let hidden_size = self.config.hidden_size;
let vocab_size = self.config.vocab_size;
let last_hidden_state = Tensor::zeros(&[batch_size, seq_len, hidden_size])?;
let logits = Tensor::zeros(&[batch_size, seq_len, vocab_size])?;
return Ok(CogVlmOutput {
last_hidden_state,
logits,
hidden_states: vec![],
attentions: vec![],
});
}
let vision_outputs = if let Some(pixels) = pixel_values {
let vision_features = self.vision_model.forward(pixels)?;
let projected_features = self.vision_projection.forward(vision_features)?;
Some(projected_features)
} else {
None
};
let language_outputs = self.language_model.forward(CogVlmLanguageInput {
input_ids,
attention_mask,
position_ids,
token_type_ids,
vision_features: vision_outputs,
images_seq_mask,
images_emb_mask,
visual_experts: self.visual_experts.clone(),
})?;
Ok(CogVlmOutput {
last_hidden_state: language_outputs.last_hidden_state,
logits: language_outputs.logits,
hidden_states: language_outputs.hidden_states,
attentions: language_outputs.attentions,
})
}
fn load_pretrained(&mut self, reader: &mut dyn std::io::Read) -> Result<()> {
use trustformers_core::errors::invalid_input;
let mut buffer = Vec::new();
reader
.read_to_end(&mut buffer)
.map_err(|e| invalid_input(format!("Failed to read CogVLM weights: {}", e)))?;
if buffer.is_empty() {
return Err(invalid_input("CogVLM weight file is empty"));
}
// NOTE: the weight buffer is read and validated but not yet deserialized into model parameters.
Ok(())
}
fn get_config(&self) -> &Self::Config {
&self.config
}
fn num_parameters(&self) -> usize {
let vision_params = self.vision_model.parameter_count();
let language_params = self.language_model.parameter_count();
let visual_expert_params: usize =
self.visual_experts.iter().map(|expert| expert.parameter_count()).sum();
let projection_params = self.vision_projection.parameter_count();
vision_params + language_params + visual_expert_params + projection_params
}
}
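/// Language backbone: token embeddings, a stack of decoder-style layers with
/// optional visual-expert enhancement, a final norm, and an LM head that maps hidden
/// states to vocabulary logits.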
pub struct CogVlmLanguageModel {
#[allow(dead_code)]
config: CogVlmConfig,
embeddings: Embedding,
layers: Vec<CogVlmLanguageLayer>,
norm: LayerNorm,
lm_head: Linear,
device: Device,
}
impl CogVlmLanguageModel {
pub fn new(config: CogVlmConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: CogVlmConfig, device: Device) -> Result<Self> {
let embeddings = Embedding::new(config.vocab_size, config.hidden_size, None)?;
let mut layers = Vec::new();
for _ in 0..config.num_hidden_layers {
layers.push(CogVlmLanguageLayer::new_with_device(
config.clone(),
device,
)?);
}
let norm = LayerNorm::new(vec![config.hidden_size], config.rms_norm_eps)?;
let lm_head = Linear::new(config.hidden_size, config.vocab_size, false);
Ok(Self {
config,
embeddings,
layers,
norm,
lm_head,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn parameter_count(&self) -> usize {
self.embeddings.parameter_count()
+ self.layers.iter().map(|layer| layer.parameter_count()).sum::<usize>()
+ self.norm.parameter_count()
+ self.lm_head.parameter_count()
}
fn process_multi_batch_embeddings(
&self,
input_ids_vec: &[u32],
batch_size: usize,
seq_len: usize,
) -> Result<Tensor> {
let mut batch_embeddings = Vec::new();
for batch_idx in 0..batch_size {
let start_idx = batch_idx * seq_len;
let end_idx = start_idx + seq_len;
if end_idx > input_ids_vec.len() {
return Err(TrustformersError::tensor_op_error(
&format!("Batch {} extends beyond input_ids length", batch_idx),
"process_multi_batch_embeddings",
));
}
let batch_tokens: Vec<u32> = input_ids_vec[start_idx..end_idx].to_vec();
let batch_embeddings_2d = self.embeddings.forward(batch_tokens)?;
batch_embeddings.push(batch_embeddings_2d);
}
let first_batch_shape = batch_embeddings[0].shape();
if first_batch_shape.len() != 2 {
return Err(TrustformersError::tensor_op_error(
&format!(
"Expected 2D embeddings from single batch, got {}D",
first_batch_shape.len()
),
"process_multi_batch_embeddings",
));
}
let seq_len_emb = first_batch_shape[0];
let hidden_size = first_batch_shape[1];
if seq_len_emb != seq_len {
return Err(TrustformersError::tensor_op_error(
&format!(
"Embedding sequence length {} doesn't match expected {}",
seq_len_emb, seq_len
),
"process_multi_batch_embeddings",
));
}
let output_shape = vec![batch_size, seq_len, hidden_size];
let mut output_data = Vec::with_capacity(batch_size * seq_len * hidden_size);
for batch_emb in batch_embeddings {
match batch_emb {
Tensor::F32(array) => {
let array_slice = array.as_slice().ok_or_else(|| {
TrustformersError::tensor_op_error(
"process_multi_batch_embeddings",
"Array is not contiguous in memory",
)
})?;
output_data.extend_from_slice(array_slice)
},
_ => {
return Err(TrustformersError::tensor_op_error(
"Unsupported tensor type in batch embeddings",
"process_multi_batch_embeddings",
))
},
}
}
Tensor::from_vec(output_data, &output_shape)
}
}
impl Layer for CogVlmLanguageModel {
type Input = CogVlmLanguageInput;
type Output = CogVlmLanguageOutput;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let CogVlmLanguageInput {
input_ids,
attention_mask,
position_ids,
token_type_ids: _,
vision_features,
images_seq_mask: _,
images_emb_mask,
visual_experts,
} = input;
let input_shape = input_ids.shape();
if input_shape.len() != 2 {
return Err(TrustformersError::tensor_op_error(
&format!(
"input_ids must be 2D [batch, seq_len], got {}D tensor",
input_shape.len()
),
"cogvlm_forward",
));
}
let batch_size = input_shape[0];
let seq_len = input_shape[1];
let input_ids_vec: Vec<u32> = match &input_ids {
Tensor::F32(array) => array.iter().map(|&x| x as u32).collect(),
_ => {
return Err(TrustformersError::tensor_op_error(
"Unsupported tensor type for input_ids",
"cogvlm_forward",
))
},
};
if input_ids_vec.len() != batch_size * seq_len {
return Err(TrustformersError::tensor_op_error(
&format!(
"Token count {} doesn't match expected batch_size × seq_len ({} × {})",
input_ids_vec.len(),
batch_size,
seq_len
),
"cogvlm_forward",
));
}
let mut hidden_states = if batch_size == 1 {
let seq_tokens: Vec<u32> = input_ids_vec;
let embeddings_2d = self.embeddings.forward(seq_tokens)?;
embeddings_2d.unsqueeze(0)?.contiguous()?
} else {
self.process_multi_batch_embeddings(&input_ids_vec, batch_size, seq_len)?
};
if let (Some(vision_feats), Some(img_mask)) = (&vision_features, &images_emb_mask) {
hidden_states =
inject_vision_features(hidden_states, vision_feats.clone(), img_mask.clone())?;
}
let mut all_hidden_states = Vec::new();
let mut all_attentions = Vec::new();
for (i, layer) in self.layers.iter().enumerate() {
let layer_output = if let Some(vision_feats) = &vision_features {
let visual_expert = &visual_experts[i];
let enhanced_states =
visual_expert.forward((hidden_states.clone(), vision_feats.clone()))?;
layer.forward(CogVlmLayerInput {
hidden_states: enhanced_states,
attention_mask: attention_mask.clone(),
position_ids: position_ids.clone(),
})?
} else {
layer.forward(CogVlmLayerInput {
hidden_states: hidden_states.clone(),
attention_mask: attention_mask.clone(),
position_ids: position_ids.clone(),
})?
};
hidden_states = layer_output.hidden_states;
all_hidden_states.push(hidden_states.clone());
all_attentions.push(layer_output.attentions);
}
let last_hidden_state = self.norm.forward(hidden_states)?;
let logits = self.lm_head.forward(last_hidden_state.clone())?;
Ok(CogVlmLanguageOutput {
last_hidden_state,
logits,
hidden_states: all_hidden_states,
attentions: all_attentions,
})
}
}
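/// Single language layer: pre-norm self-attention and feed-forward blocks with
/// residual connections. The RoPE module is constructed but not yet applied inside
/// `forward`.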
pub struct CogVlmLanguageLayer {
#[allow(dead_code)]
config: CogVlmConfig,
self_attn: MultiHeadAttention,
mlp: FeedForward,
input_layernorm: LayerNorm,
post_attention_layernorm: LayerNorm,
#[allow(dead_code)]
rope: VectorizedRoPE,
device: Device,
}
impl CogVlmLanguageLayer {
pub fn new(config: CogVlmConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: CogVlmConfig, device: Device) -> Result<Self> {
let self_attn =
MultiHeadAttention::new(config.hidden_size, config.num_attention_heads, 0.0, true)?;
let mlp = FeedForward::new(config.hidden_size, config.intermediate_size, 0.0)?;
let input_layernorm = LayerNorm::new(vec![config.hidden_size], config.rms_norm_eps)?;
let post_attention_layernorm =
LayerNorm::new(vec![config.hidden_size], config.rms_norm_eps)?;
let rope = VectorizedRoPE::new(
config.head_dim(),
config.max_position_embeddings,
config.rope_theta,
)?;
Ok(Self {
config,
self_attn,
mlp,
input_layernorm,
post_attention_layernorm,
rope,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn parameter_count(&self) -> usize {
self.self_attn.parameter_count()
+ self.mlp.parameter_count()
+ self.input_layernorm.parameter_count()
+ self.post_attention_layernorm.parameter_count()
}
}
impl Layer for CogVlmLanguageLayer {
type Input = CogVlmLayerInput;
type Output = CogVlmLayerOutput;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let CogVlmLayerInput {
hidden_states,
attention_mask: _,
position_ids: _,
} = input;
let normalized_states = self.input_layernorm.forward(hidden_states.clone())?;
let attn_output = self.self_attn.forward(normalized_states)?;
let hidden_states = hidden_states.add(&attn_output)?;
let normalized_states = self.post_attention_layernorm.forward(hidden_states.clone())?;
let mlp_output = self.mlp.forward(normalized_states)?;
let hidden_states = hidden_states.add(&mlp_output)?;
Ok(CogVlmLayerOutput {
hidden_states,
attentions: attn_output,
})
}
}
#[derive(Debug)]
pub struct CogVlmInput {
pub pixel_values: Option<Tensor>,
pub input_ids: Tensor,
pub attention_mask: Option<Tensor>,
pub position_ids: Option<Tensor>,
pub token_type_ids: Option<Tensor>,
pub images_seq_mask: Option<Tensor>,
pub images_emb_mask: Option<Tensor>,
}
#[derive(Debug)]
pub struct CogVlmOutput {
pub last_hidden_state: Tensor,
pub logits: Tensor,
pub hidden_states: Vec<Tensor>,
pub attentions: Vec<Tensor>,
}
#[derive(Debug)]
pub struct CogVlmLanguageInput {
pub input_ids: Tensor,
pub attention_mask: Option<Tensor>,
pub position_ids: Option<Tensor>,
pub token_type_ids: Option<Tensor>,
pub vision_features: Option<Tensor>,
pub images_seq_mask: Option<Tensor>,
pub images_emb_mask: Option<Tensor>,
pub visual_experts: Vec<VisualExpert>,
}
#[derive(Debug)]
pub struct CogVlmLanguageOutput {
pub last_hidden_state: Tensor,
pub logits: Tensor,
pub hidden_states: Vec<Tensor>,
pub attentions: Vec<Tensor>,
}
#[derive(Debug)]
pub struct CogVlmLayerInput {
pub hidden_states: Tensor,
pub attention_mask: Option<Tensor>,
pub position_ids: Option<Tensor>,
}
#[derive(Debug)]
pub struct CogVlmLayerOutput {
pub hidden_states: Tensor,
pub attentions: Tensor,
}
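/// CogVideo wrapper: encodes video frames with a temporal encoder and reuses the
/// CogVLM base model for the language side.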
pub struct CogVideoModel {
config: CogVideoConfig,
base_model: CogVlmModel,
temporal_encoder: TemporalEncoder,
device: Device,
}
impl CogVideoModel {
pub fn new(config: CogVideoConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: CogVideoConfig, device: Device) -> Result<Self> {
let base_model = CogVlmModel::new_with_device(config.base_config.clone(), device)?;
let temporal_encoder = TemporalEncoder::new_with_device(config.clone(), device)?;
Ok(Self {
config,
base_model,
temporal_encoder,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
}
impl Model for CogVideoModel {
type Config = CogVideoConfig;
type Input = CogVideoInput;
type Output = CogVlmOutput;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let CogVideoInput {
video_frames,
input_ids,
attention_mask,
position_ids,
token_type_ids,
} = input;
// Temporal features are computed but not yet injected into the base model;
// the base model runs in text-only mode here.
let _temporal_features = self.temporal_encoder.forward(video_frames)?;
self.base_model.forward(CogVlmInput {
pixel_values: None,
input_ids,
attention_mask,
position_ids,
token_type_ids,
images_seq_mask: None,
images_emb_mask: None,
})
}
fn load_pretrained(&mut self, reader: &mut dyn std::io::Read) -> Result<()> {
use trustformers_core::errors::invalid_input;
let mut buffer = Vec::new();
reader
.read_to_end(&mut buffer)
.map_err(|e| invalid_input(format!("Failed to read CogVLM weights: {}", e)))?;
if buffer.is_empty() {
return Err(invalid_input("CogVLM weight file is empty"));
}
Ok(())
}
fn get_config(&self) -> &Self::Config {
&self.config
}
fn num_parameters(&self) -> usize {
self.base_model.num_parameters() + self.temporal_encoder.parameter_count()
}
}
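/// Projects per-frame features to `temporal_hidden_size` (by truncation or
/// zero-padding) and runs them through a stack of temporal attention layers.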
pub struct TemporalEncoder {
config: CogVideoConfig,
temporal_layers: Vec<TemporalLayer>,
device: Device,
}
impl TemporalEncoder {
pub fn new(config: CogVideoConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: CogVideoConfig, device: Device) -> Result<Self> {
let mut temporal_layers = Vec::new();
for _ in 0..config.temporal_num_layers {
temporal_layers.push(TemporalLayer::new_with_device(config.clone(), device)?);
}
Ok(Self {
config,
temporal_layers,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn parameter_count(&self) -> usize {
self.temporal_layers.iter().map(|layer| layer.parameter_count()).sum::<usize>()
}
}
impl Layer for TemporalEncoder {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, video_frames: Self::Input) -> Result<Self::Output> {
let batch_size = video_frames.shape()[0];
let num_frames = video_frames.shape()[1];
let frame_shape = &video_frames.shape()[2..];
let mut merged_shape = vec![batch_size * num_frames];
merged_shape.extend_from_slice(frame_shape);
let reshaped_frames = video_frames.reshape(&merged_shape)?;
let flattened_size = frame_shape.iter().product::<usize>();
let flattened_frames =
reshaped_frames.reshape(&[batch_size * num_frames, flattened_size])?;
let temporal_hidden_size = self.config.temporal_hidden_size;
let projected_frames = if flattened_size >= temporal_hidden_size {
flattened_frames.slice(1, 0, temporal_hidden_size)?
} else {
let padding_size = temporal_hidden_size - flattened_size;
let zero_padding = Tensor::zeros(&[batch_size * num_frames, padding_size])?;
Tensor::concat(&[flattened_frames, zero_padding], 1)?
};
let mut temporal_states = projected_frames.reshape(&[
batch_size,
num_frames,
self.config.temporal_hidden_size,
])?;
for layer in &self.temporal_layers {
temporal_states = layer.forward(temporal_states)?;
}
Ok(temporal_states)
}
}
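/// Temporal transformer block: pre-norm attention and MLP over the frame dimension,
/// with residual connections. The head count is chosen to divide the temporal hidden
/// size evenly.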
pub struct TemporalLayer {
attention: MultiHeadAttention,
mlp: FeedForward,
norm1: LayerNorm,
norm2: LayerNorm,
device: Device,
}
impl TemporalLayer {
pub fn new(config: CogVideoConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: CogVideoConfig, device: Device) -> Result<Self> {
let num_heads = if config
.temporal_hidden_size
.is_multiple_of(config.base_config.num_attention_heads)
{
config.base_config.num_attention_heads
} else {
// Fall back to the first head count in this list that evenly divides the temporal hidden size.
let mut suitable_heads = 8;
for heads in [16, 32, 8, 4, 2, 1].iter() {
if config.temporal_hidden_size.is_multiple_of(*heads) {
suitable_heads = *heads;
break;
}
}
suitable_heads
};
let attention = MultiHeadAttention::new(config.temporal_hidden_size, num_heads, 0.0, true)?;
let mlp = FeedForward::new(
config.temporal_hidden_size,
config.temporal_hidden_size * 4,
0.0,
)?;
let norm1 = LayerNorm::new(
vec![config.temporal_hidden_size],
config.base_config.rms_norm_eps,
)?;
let norm2 = LayerNorm::new(
vec![config.temporal_hidden_size],
config.base_config.rms_norm_eps,
)?;
Ok(Self {
attention,
mlp,
norm1,
norm2,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn parameter_count(&self) -> usize {
self.attention.parameter_count()
+ self.mlp.parameter_count()
+ self.norm1.parameter_count()
+ self.norm2.parameter_count()
}
}
impl Layer for TemporalLayer {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, hidden_states: Self::Input) -> Result<Self::Output> {
let normalized = self.norm1.forward(hidden_states.clone())?;
let attn_output = self.attention.forward(normalized)?;
let hidden_states = hidden_states.add(&attn_output)?;
let normalized = self.norm2.forward(hidden_states.clone())?;
let mlp_output = self.mlp.forward(normalized)?;
hidden_states.add(&mlp_output)
}
}
#[derive(Debug)]
pub struct CogVideoInput {
pub video_frames: Tensor,
pub input_ids: Tensor,
pub attention_mask: Option<Tensor>,
pub position_ids: Option<Tensor>,
pub token_type_ids: Option<Tensor>,
}
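/// Reshapes `[batch, channels, height, width]` pixel values into
/// `[batch, num_patches, channels * patch_size * patch_size]`.
///
/// Note: this is a plain reshape, not a true space-to-depth rearrangement; it
/// assumes the memory layout already groups each patch's values contiguously.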
fn extract_patches(pixel_values: &Tensor, patch_size: usize) -> Result<Tensor> {
let batch_size = pixel_values.shape()[0];
let channels = pixel_values.shape()[1];
let height = pixel_values.shape()[2];
let width = pixel_values.shape()[3];
let num_patches_h = height / patch_size;
let num_patches_w = width / patch_size;
let num_patches = num_patches_h * num_patches_w;
pixel_values.reshape(&[batch_size, num_patches, channels * patch_size * patch_size])
}
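/// Blends projected vision features into the language hidden states at the sequence
/// positions flagged by `image_mask`.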
fn inject_vision_features(
hidden_states: Tensor,
vision_features: Tensor,
image_mask: Tensor,
) -> Result<Tensor> {
let mut result = hidden_states;
let vision_positions = find_vision_positions(&image_mask)?;
if !vision_positions.is_empty() {
result = inject_vision_at_positions(result, vision_features, &vision_positions)?;
}
Ok(result)
}
fn find_vision_positions(image_mask: &Tensor) -> Result<Vec<usize>> {
let mask_data = image_mask.to_vec_f32()?;
let positions: Vec<usize> = mask_data
.iter()
.enumerate()
.filter_map(|(i, &val)| if val > 0.5 { Some(i) } else { None })
.collect();
Ok(positions)
}
fn inject_vision_at_positions(
hidden_states: Tensor,
vision_features: Tensor,
positions: &[usize],
) -> Result<Tensor> {
match (&hidden_states, &vision_features) {
(Tensor::F32(hidden_arr), Tensor::F32(vision_arr)) => {
let hidden_shape = hidden_arr.shape();
let vision_shape = vision_arr.shape();
if hidden_shape.len() != 3 || vision_shape.len() != 3 {
return Err(TrustformersError::shape_error(
"Expected 3D tensors for vision feature injection".to_string(),
));
}
let batch_size = hidden_shape[0];
let seq_len = hidden_shape[1];
let hidden_size = hidden_shape[2];
let vision_seq_len = vision_shape[1];
let vision_hidden_size = vision_shape[2];
let mut result_data = hidden_arr.iter().cloned().collect::<Vec<f32>>();
let vision_slice = vision_arr.as_slice().ok_or_else(|| {
TrustformersError::tensor_op_error(
"inject_vision_at_positions",
"Vision array is not contiguous in memory",
)
})?;
for (pos_idx, &position) in positions.iter().enumerate() {
if position < seq_len && pos_idx < vision_seq_len {
for batch_idx in 0..batch_size {
for hidden_idx in 0..hidden_size.min(vision_hidden_size) {
let hidden_idx_flat = batch_idx * seq_len * hidden_size
+ position * hidden_size
+ hidden_idx;
let vision_idx_flat = batch_idx * vision_seq_len * vision_hidden_size
+ pos_idx * vision_hidden_size
+ hidden_idx;
if hidden_idx_flat < result_data.len() {
let vision_val = vision_slice[vision_idx_flat];
let hidden_val = result_data[hidden_idx_flat];
let blend_factor = compute_blend_factor(vision_val, hidden_val);
result_data[hidden_idx_flat] =
hidden_val * (1.0 - blend_factor) + vision_val * blend_factor;
}
}
}
}
}
let result_array =
ArrayD::from_shape_vec(IxDyn(hidden_shape), result_data).map_err(|_| {
TrustformersError::shape_error("Failed to create injected tensor".to_string())
})?;
Ok(Tensor::F32(result_array))
},
_ => {
Ok(hidden_states)
},
}
}
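/// Heuristic blend weight based on the relative magnitudes of the vision and text
/// activations, squashed with `tanh` and clamped to `[0.1, 0.9]`.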
fn compute_blend_factor(vision_val: f32, hidden_val: f32) -> f32 {
let vision_magnitude = vision_val.abs();
let hidden_magnitude = hidden_val.abs();
// Small epsilon avoids division by zero when both activations are zero.
let total_magnitude = vision_magnitude + hidden_magnitude + 1e-8;
let vision_ratio = vision_magnitude / total_magnitude;
let blend_factor = vision_ratio.tanh() * 0.5 + 0.3;
blend_factor.clamp(0.1, 0.9)
}