use crate::deit::config::DeiTConfig;
use crate::deit::model::DeiTModel;
use scirs2_core::ndarray::{Array2, Array4, Ix2};
use std::collections::HashMap;
use trustformers_core::device::Device;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::layers::linear::Linear;
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Layer;
/// DeiT (Data-efficient image Transformer) with a classification head.
///
/// Wraps a [`DeiTModel`] backbone and projects its token representations to
/// class logits. When the config enables a distillation token, a second
/// linear head (`distill_head`) is allocated and its logits are averaged
/// with the CLS-head logits at inference time.
#[derive(Debug, Clone)]
pub struct DeiTForImageClassification {
// Vision-transformer backbone producing per-token hidden states.
pub deit: DeiTModel,
// Linear projection applied to the CLS token (hidden_size -> num_labels).
pub cls_head: Linear,
// Optional projection for the distillation token; present only when
// `config.use_distillation_token` is set (see `new_with_device`).
pub distill_head: Option<Linear>,
// Device all submodules were constructed on.
device: Device,
}
impl DeiTForImageClassification {
pub fn new(config: DeiTConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: DeiTConfig, device: Device) -> Result<Self> {
let distill_head = if config.use_distillation_token {
Some(Linear::new_with_device(
config.hidden_size,
config.num_labels,
true,
device,
))
} else {
None
};
Ok(Self {
cls_head: Linear::new_with_device(
config.hidden_size,
config.num_labels,
true,
device,
),
distill_head,
deit: DeiTModel::new_with_device(config, device)?,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn forward(&self, images: &Array4<f32>) -> Result<Array2<f32>> {
let hidden = self.deit.forward(images)?;
let cls_repr = hidden.slice(scirs2_core::ndarray::s![.., 0, ..]).to_owned();
let cls_tensor = Tensor::F32(cls_repr.into_dyn());
let cls_logits = match self.cls_head.forward(cls_tensor)? {
Tensor::F32(arr) => arr
.into_dimensionality::<Ix2>()
.map_err(|e| TrustformersError::shape_error(e.to_string()))?,
_ => {
return Err(TrustformersError::invalid_input_simple(
"Expected F32 from cls_head".to_string(),
))
},
};
if let Some(ref dist_head) = self.distill_head {
let dist_repr = hidden.slice(scirs2_core::ndarray::s![.., 1, ..]).to_owned();
let dist_tensor = Tensor::F32(dist_repr.into_dyn());
let dist_logits = match dist_head.forward(dist_tensor)? {
Tensor::F32(arr) => arr
.into_dimensionality::<Ix2>()
.map_err(|e| TrustformersError::shape_error(e.to_string()))?,
_ => {
return Err(TrustformersError::invalid_input_simple(
"Expected F32 from distill_head".to_string(),
))
},
};
Ok((cls_logits + dist_logits) * 0.5)
} else {
Ok(cls_logits)
}
}
pub fn forward_cls_only(&self, images: &Array4<f32>) -> Result<Array2<f32>> {
let cls_repr = self.deit.get_cls_output(images)?;
let cls_tensor = Tensor::F32(cls_repr.into_dyn());
match self.cls_head.forward(cls_tensor)? {
Tensor::F32(arr) => Ok(arr
.into_dimensionality::<Ix2>()
.map_err(|e| TrustformersError::shape_error(e.to_string()))?),
_ => Err(TrustformersError::invalid_input_simple(
"Expected F32 from cls_head".to_string(),
)),
}
}
pub fn weight_map(&self) -> HashMap<String, String> {
let mut map = HashMap::new();
map.insert(
"deit.embeddings.patch_embeddings.projection.weight".to_string(),
"deit.embeddings.patch_embeddings.projection.weight".to_string(),
);
map.insert(
"deit.embeddings.patch_embeddings.projection.bias".to_string(),
"deit.embeddings.patch_embeddings.projection.bias".to_string(),
);
map.insert(
"deit.embeddings.cls_token".to_string(),
"deit.embeddings.cls_token".to_string(),
);
if self.deit.config.use_distillation_token {
map.insert(
"deit.embeddings.distillation_token".to_string(),
"deit.embeddings.distillation_token".to_string(),
);
}
map.insert(
"deit.embeddings.position_embeddings".to_string(),
"deit.embeddings.position_embeddings".to_string(),
);
for i in 0..self.deit.config.num_hidden_layers {
let prefix = format!("deit.encoder.layer.{}", i);
map.insert(
format!("{}.attention.attention.query.weight", prefix),
format!("encoder.layers.{}.attention.attention.query.weight", i),
);
map.insert(
format!("{}.attention.attention.key.weight", prefix),
format!("encoder.layers.{}.attention.attention.key.weight", i),
);
map.insert(
format!("{}.attention.attention.value.weight", prefix),
format!("encoder.layers.{}.attention.attention.value.weight", i),
);
map.insert(
format!("{}.layernorm_before.weight", prefix),
format!("encoder.layers.{}.attention.layer_norm.weight", i),
);
map.insert(
format!("{}.layernorm_after.weight", prefix),
format!("encoder.layers.{}.mlp.layer_norm.weight", i),
);
map.insert(
format!("{}.intermediate.dense.weight", prefix),
format!("encoder.layers.{}.mlp.feed_forward.fc1.weight", i),
);
map.insert(
format!("{}.output.dense.weight", prefix),
format!("encoder.layers.{}.mlp.feed_forward.fc2.weight", i),
);
}
map.insert(
"deit.layernorm.weight".to_string(),
"deit.layer_norm.weight".to_string(),
);
map.insert(
"deit.layernorm.bias".to_string(),
"deit.layer_norm.bias".to_string(),
);
map.insert(
"classifier.weight".to_string(),
"cls_head.weight".to_string(),
);
map.insert("classifier.bias".to_string(), "cls_head.bias".to_string());
if self.deit.config.use_distillation_token {
map.insert(
"dist_head.weight".to_string(),
"distill_head.weight".to_string(),
);
map.insert(
"dist_head.bias".to_string(),
"distill_head.bias".to_string(),
);
}
map
}
}