use crate::deberta::config::DebertaConfig;
use scirs2_core::ndarray::{s, Array1, Array2, Array3, Array4, Axis, Ix2, Ix3};
use trustformers_core::device::Device;
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::layers::{
embedding::Embedding, feedforward::FeedForward, layernorm::LayerNorm, linear::Linear,
};
use trustformers_core::ops::activations::gelu;
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Layer;
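/// Token embedding block for DeBERTa: word-embedding lookup followed by
/// layer normalization. No absolute position embeddings are added here;
/// position information enters through the relative-position terms in the
/// encoder's disentangled attention.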
#[derive(Debug, Clone)]
pub struct DebertaEmbeddings {
pub word_embeddings: Embedding,
pub layer_norm: LayerNorm,
pub dropout: f32,
device: Device,
}
impl DebertaEmbeddings {
pub fn new(config: &DebertaConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &DebertaConfig, device: Device) -> Result<Self> {
Ok(Self {
word_embeddings: Embedding::new_with_device(
config.vocab_size,
config.hidden_size,
Some(config.pad_token_id as usize),
device,
)?,
layer_norm: LayerNorm::new_with_device(
vec![config.hidden_size],
config.layer_norm_eps,
device,
)?,
dropout: config.hidden_dropout_prob,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
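/// Looks up word embeddings for `input_ids`, applies layer norm, and folds
/// dropout in as a deterministic `1 - p` scaling (the inference-time
/// expectation of standard dropout) rather than sampling a random mask.
/// Returns a `[seq_len, hidden_size]` array.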
pub fn forward(&self, input_ids: &Array1<u32>) -> Result<Array2<f32>> {
let input_ids_slice = input_ids.as_slice().ok_or_else(|| {
TrustformersError::tensor_op_error("input_ids is not contiguous in memory", "forward")
})?;
let embeddings = self.word_embeddings.forward_ids(input_ids_slice)?;
let embeddings_2d = match embeddings {
Tensor::F32(arr) => arr
.into_dimensionality::<Ix2>()
.map_err(|e| TrustformersError::shape_error(e.to_string()))?,
_ => {
return Err(TrustformersError::tensor_op_error(
"Expected F32 tensor for word embeddings",
"embeddings",
))
},
};
let norm_input = Tensor::F32(embeddings_2d.clone().into_dyn());
let embeddings = self.layer_norm.forward(norm_input)?;
let embeddings_2d = match embeddings {
Tensor::F32(arr) => arr
.into_dimensionality::<Ix2>()
.map_err(|e| TrustformersError::shape_error(e.to_string()))?,
_ => {
return Err(TrustformersError::tensor_op_error(
"Expected F32 tensor after layer norm",
"embeddings",
))
},
};
Ok(embeddings_2d * (1.0 - self.dropout))
}
}
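/// Multi-head self-attention with DeBERTa-style disentangled relative
/// position terms. `pos_query_proj` backs the content-to-position ("c2p")
/// term, `pos_key_proj` backs position-to-content ("p2c") when
/// `share_att_key` is false, and `pos_proj` holds a relative-position
/// projection that the current forward pass allocates but does not yet use.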
#[derive(Debug, Clone)]
pub struct DebertaDisentangledSelfAttention {
pub query_proj: Linear,
pub key_proj: Linear,
pub value_proj: Linear,
pub pos_query_proj: Option<Linear>,
pub pos_key_proj: Option<Linear>,
pub pos_proj: Option<Linear>,
pub dropout: f32,
pub num_attention_heads: usize,
pub attention_head_size: usize,
pub all_head_size: usize,
pub max_relative_positions: i32,
pub pos_att_type: Vec<String>,
pub share_att_key: bool,
device: Device,
}
impl DebertaDisentangledSelfAttention {
pub fn new(config: &DebertaConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &DebertaConfig, device: Device) -> Result<Self> {
let attention_head_size = config.hidden_size / config.num_attention_heads;
let all_head_size = config.num_attention_heads * attention_head_size;
let pos_query_proj = if config.pos_att_type.contains(&"c2p".to_string()) {
Some(Linear::new_with_device(
config.hidden_size,
all_head_size,
true,
device,
))
} else {
None
};
let pos_key_proj =
if config.pos_att_type.contains(&"p2c".to_string()) && !config.share_att_key {
Some(Linear::new_with_device(
config.hidden_size,
all_head_size,
true,
device,
))
} else {
None
};
let pos_proj = if config.max_relative_positions > 0 {
Some(Linear::new_with_device(
config.max_relative_positions as usize * 2,
all_head_size,
false,
device,
))
} else {
None
};
Ok(Self {
query_proj: Linear::new_with_device(config.hidden_size, all_head_size, true, device),
key_proj: Linear::new_with_device(config.hidden_size, all_head_size, true, device),
value_proj: Linear::new_with_device(config.hidden_size, all_head_size, true, device),
pos_query_proj,
pos_key_proj,
pos_proj,
dropout: config.attention_probs_dropout_prob,
num_attention_heads: config.num_attention_heads,
attention_head_size,
all_head_size,
max_relative_positions: config.max_relative_positions,
pos_att_type: config.pos_att_type.clone(),
share_att_key: config.share_att_key,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
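/// Splits the last dimension of `[batch, seq, all_head_size]` into
/// `[batch, heads, seq, head_size]` so attention can be computed per head.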
fn transpose_for_scores(&self, x: &Array3<f32>) -> Result<Array4<f32>> {
let (batch_size, seq_len, _) = x.dim();
let reshaped = x
.to_shape((
batch_size,
seq_len,
self.num_attention_heads,
self.attention_head_size,
))
.map_err(|e| {
TrustformersError::shape_error(format!("Failed to reshape tensor: {}", e))
})?
.to_owned();
Ok(reshaped.permuted_axes([0, 2, 1, 3]))
}
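/// Builds the `[query_size, key_size]` matrix of relative distances
/// `i - j`, clamped to `[-max_relative_positions, max_relative_positions]`
/// when a positive maximum is configured.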
fn build_relative_position(&self, query_size: usize, key_size: usize) -> Array2<i32> {
let mut relative_positions = Array2::zeros((query_size, key_size));
for i in 0..query_size {
for j in 0..key_size {
let relative_pos = i as i32 - j as i32;
let clamped_pos = if self.max_relative_positions > 0 {
relative_pos.clamp(-self.max_relative_positions, self.max_relative_positions)
} else {
relative_pos
};
relative_positions[[i, j]] = clamped_pos;
}
}
relative_positions
}
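/// Runs scaled dot-product content-to-content attention, adds the
/// relative-position biases, masks, softmaxes, and aggregates values.
/// Note that the c2p/p2c contributions below are a simplified linear
/// placeholder (`relative_position * 0.01`) rather than the full
/// disentangled matmuls from the DeBERTa paper: the projected positional
/// queries are computed but currently unused.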
pub fn forward(
&self,
hidden_states: &Array3<f32>,
attention_mask: Option<&Array3<f32>>,
) -> Result<Array3<f32>> {
let (batch_size, seq_len, _hidden_size) = hidden_states.dim();
let query_input = Tensor::F32(hidden_states.clone().into_dyn());
let key_input = Tensor::F32(hidden_states.clone().into_dyn());
let value_input = Tensor::F32(hidden_states.clone().into_dyn());
let query_layer = self.query_proj.forward(query_input)?;
let key_layer = self.key_proj.forward(key_input)?;
let value_layer = self.value_proj.forward(value_input)?;
let query_layer = match query_layer {
Tensor::F32(arr) => arr
.into_dimensionality::<Ix3>()
.map_err(|e| TrustformersError::shape_error(e.to_string()))?,
_ => {
return Err(TrustformersError::tensor_op_error(
"Expected F32 tensor from query projection",
"attention",
))
},
};
let key_layer = match key_layer {
Tensor::F32(arr) => arr
.into_dimensionality::<Ix3>()
.map_err(|e| TrustformersError::shape_error(e.to_string()))?,
_ => {
return Err(TrustformersError::tensor_op_error(
"Expected F32 tensor from key projection",
"attention",
))
},
};
let value_layer = match value_layer {
Tensor::F32(arr) => arr
.into_dimensionality::<Ix3>()
.map_err(|e| TrustformersError::shape_error(e.to_string()))?,
_ => {
return Err(TrustformersError::tensor_op_error(
"Expected F32 tensor from value projection",
"attention",
))
},
};
let query_layer = self.transpose_for_scores(&query_layer)?;
let key_layer = self.transpose_for_scores(&key_layer)?;
let value_layer = self.transpose_for_scores(&value_layer)?;
let mut attention_scores =
Array4::zeros((batch_size, self.num_attention_heads, seq_len, seq_len));
for b in 0..batch_size {
for h in 0..self.num_attention_heads {
let q = query_layer.slice(s![b, h, .., ..]);
let k = key_layer.slice(s![b, h, .., ..]);
for i in 0..seq_len {
for j in 0..seq_len {
let score: f32 = q
.slice(s![i, ..])
.iter()
.zip(k.slice(s![j, ..]).iter())
.map(|(a, b)| a * b)
.sum();
attention_scores[[b, h, i, j]] =
score / (self.attention_head_size as f32).sqrt();
}
}
}
}
if self.pos_att_type.contains(&"c2p".to_string()) {
if let Some(pos_query_proj) = &self.pos_query_proj {
let pos_query_input = Tensor::F32(hidden_states.clone().into_dyn());
let pos_query_result = pos_query_proj.forward(pos_query_input)?;
let pos_query_layer = match pos_query_result {
Tensor::F32(arr) => arr
.into_dimensionality::<Ix3>()
.map_err(|e| TrustformersError::shape_error(e.to_string()))?,
_ => {
return Err(TrustformersError::tensor_op_error(
"Expected F32 tensor from pos query projection",
"attention",
))
},
};
let _pos_query_layer = self.transpose_for_scores(&pos_query_layer)?;
let relative_pos = self.build_relative_position(seq_len, seq_len);
for b in 0..batch_size {
for h in 0..self.num_attention_heads {
for i in 0..seq_len {
for j in 0..seq_len {
let pos_bias = relative_pos[[i, j]] as f32 * 0.01;
attention_scores[[b, h, i, j]] += pos_bias;
}
}
}
}
}
}
if self.pos_att_type.contains(&"p2c".to_string()) {
let relative_pos = self.build_relative_position(seq_len, seq_len);
for b in 0..batch_size {
for h in 0..self.num_attention_heads {
for i in 0..seq_len {
for j in 0..seq_len {
let pos_bias = relative_pos[[i, j]] as f32 * 0.01;
attention_scores[[b, h, i, j]] += pos_bias;
}
}
}
}
}
if let Some(mask) = attention_mask {
for b in 0..batch_size {
for h in 0..self.num_attention_heads {
for i in 0..seq_len {
for j in 0..seq_len {
if mask[[b, i, j]] == 0.0 {
attention_scores[[b, h, i, j]] = -10000.0;
}
}
}
}
}
}
let mut attention_probs =
Array4::zeros((batch_size, self.num_attention_heads, seq_len, seq_len));
for b in 0..batch_size {
for h in 0..self.num_attention_heads {
for i in 0..seq_len {
let mut max_val = f32::NEG_INFINITY;
for j in 0..seq_len {
max_val = max_val.max(attention_scores[[b, h, i, j]]);
}
let mut sum_exp = 0.0;
for j in 0..seq_len {
let exp_val = (attention_scores[[b, h, i, j]] - max_val).exp();
attention_probs[[b, h, i, j]] = exp_val;
sum_exp += exp_val;
}
for j in 0..seq_len {
attention_probs[[b, h, i, j]] /= sum_exp;
}
}
}
}
attention_probs *= 1.0 - self.dropout;
let mut context_layer = Array4::zeros((
batch_size,
self.num_attention_heads,
seq_len,
self.attention_head_size,
));
for b in 0..batch_size {
for h in 0..self.num_attention_heads {
for i in 0..seq_len {
for d in 0..self.attention_head_size {
let mut sum = 0.0;
for j in 0..seq_len {
sum += attention_probs[[b, h, i, j]] * value_layer[[b, h, j, d]];
}
context_layer[[b, h, i, d]] = sum;
}
}
}
}
let context_layer = context_layer.permuted_axes([0, 2, 1, 3]);
let context_layer = context_layer
.to_shape((batch_size, seq_len, self.all_head_size))
.map_err(|e| {
TrustformersError::shape_error(format!("Failed to reshape context layer: {}", e))
})?
.to_owned();
Ok(context_layer)
}
}
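/// Post-attention output block: dense projection, deterministic dropout
/// scaling, residual connection with the attention input, then layer norm.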
#[derive(Debug, Clone)]
pub struct DebertaSelfOutput {
pub dense: Linear,
pub layer_norm: LayerNorm,
pub dropout: f32,
device: Device,
}
impl DebertaSelfOutput {
pub fn new(config: &DebertaConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &DebertaConfig, device: Device) -> Result<Self> {
Ok(Self {
dense: Linear::new_with_device(config.hidden_size, config.hidden_size, true, device),
layer_norm: LayerNorm::new_with_device(
vec![config.hidden_size],
config.layer_norm_eps,
device,
)?,
dropout: config.hidden_dropout_prob,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn forward(
&self,
hidden_states: &Array3<f32>,
input_tensor: &Array3<f32>,
) -> Result<Array3<f32>> {
let dense_input = Tensor::F32(hidden_states.clone().into_dyn());
let dense_output = self.dense.forward(dense_input)?;
let hidden_states = match dense_output {
Tensor::F32(arr) => arr
.into_dimensionality::<Ix3>()
.map_err(|e| TrustformersError::shape_error(e.to_string()))?,
_ => {
return Err(TrustformersError::tensor_op_error(
"Expected F32 tensor from dense layer",
"dense_layer",
))
},
};
let hidden_states = hidden_states * (1.0 - self.dropout);
let residual = hidden_states + input_tensor;
let norm_input = Tensor::F32(residual.into_dyn());
let output = self.layer_norm.forward(norm_input)?;
let output = match output {
Tensor::F32(arr) => arr
.into_dimensionality::<Ix3>()
.map_err(|e| TrustformersError::shape_error(e.to_string()))?,
_ => {
return Err(TrustformersError::tensor_op_error(
"Expected F32 tensor from layer norm",
"layer_norm",
))
},
};
Ok(output)
}
}
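/// Disentangled self-attention followed by the residual output block.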
#[derive(Debug, Clone)]
pub struct DebertaAttention {
pub self_attention: DebertaDisentangledSelfAttention,
pub output: DebertaSelfOutput,
device: Device,
}
impl DebertaAttention {
pub fn new(config: &DebertaConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &DebertaConfig, device: Device) -> Result<Self> {
Ok(Self {
self_attention: DebertaDisentangledSelfAttention::new_with_device(config, device)?,
output: DebertaSelfOutput::new_with_device(config, device)?,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn forward(
&self,
hidden_states: &Array3<f32>,
attention_mask: Option<&Array3<f32>>,
) -> Result<Array3<f32>> {
let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?;
let attention_output = self.output.forward(&self_outputs, hidden_states)?;
Ok(attention_output)
}
}
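/// One encoder block: attention, feed-forward network, dropout scaling,
/// residual connection, and a final layer norm.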
#[derive(Debug, Clone)]
pub struct DebertaLayer {
pub attention: DebertaAttention,
pub feed_forward: FeedForward,
pub output_layer_norm: LayerNorm,
pub dropout: f32,
device: Device,
}
impl DebertaLayer {
pub fn new(config: &DebertaConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &DebertaConfig, device: Device) -> Result<Self> {
Ok(Self {
attention: DebertaAttention::new_with_device(config, device)?,
feed_forward: FeedForward::new_with_device(
config.hidden_size,
config.intermediate_size,
config.hidden_dropout_prob,
device,
),
output_layer_norm: LayerNorm::new_with_device(
vec![config.hidden_size],
config.layer_norm_eps,
device,
)?,
dropout: config.hidden_dropout_prob,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn forward(
&self,
hidden_states: &Array3<f32>,
attention_mask: Option<&Array3<f32>>,
) -> Result<Array3<f32>> {
let attention_output = self.attention.forward(hidden_states, attention_mask)?;
let ff_input = Tensor::F32(attention_output.clone().into_dyn());
let ff_output = self.feed_forward.forward(ff_input)?;
let ff_output = match ff_output {
Tensor::F32(arr) => arr
.into_dimensionality::<Ix3>()
.map_err(|e| TrustformersError::shape_error(e.to_string()))?,
_ => {
return Err(TrustformersError::tensor_op_error(
"Expected F32 tensor from feed forward",
"feed_forward",
))
},
};
let ff_output = ff_output * (1.0 - self.dropout);
let residual = &attention_output + &ff_output;
let norm_input = Tensor::F32(residual.into_dyn());
let output = self.output_layer_norm.forward(norm_input)?;
let output = match output {
Tensor::F32(arr) => arr
.into_dimensionality::<Ix3>()
.map_err(|e| TrustformersError::shape_error(e.to_string()))?,
_ => {
return Err(TrustformersError::tensor_op_error(
"Expected F32 tensor from layer norm",
"layer_norm",
))
},
};
Ok(output)
}
}
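/// Stack of `num_hidden_layers` encoder blocks applied sequentially.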
#[derive(Debug, Clone)]
pub struct DebertaEncoder {
pub layers: Vec<DebertaLayer>,
device: Device,
}
impl DebertaEncoder {
pub fn new(config: &DebertaConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: &DebertaConfig, device: Device) -> Result<Self> {
let mut layers = Vec::new();
for _ in 0..config.num_hidden_layers {
layers.push(DebertaLayer::new_with_device(config, device)?);
}
Ok(Self { layers, device })
}
pub fn device(&self) -> Device {
self.device
}
pub fn forward(
&self,
mut hidden_states: Array3<f32>,
attention_mask: Option<&Array3<f32>>,
) -> Result<Array3<f32>> {
for layer in &self.layers {
hidden_states = layer.forward(&hidden_states, attention_mask)?;
}
Ok(hidden_states)
}
}
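/// Full DeBERTa backbone: embeddings plus encoder. `forward` takes an
/// unbatched id sequence and returns hidden states with a singleton batch
/// axis, shaped `[1, seq_len, hidden_size]`.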
#[derive(Debug, Clone)]
pub struct DebertaModel {
pub embeddings: DebertaEmbeddings,
pub encoder: DebertaEncoder,
pub config: DebertaConfig,
device: Device,
}
impl DebertaModel {
pub fn new(config: DebertaConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: DebertaConfig, device: Device) -> Result<Self> {
Ok(Self {
embeddings: DebertaEmbeddings::new_with_device(&config, device)?,
encoder: DebertaEncoder::new_with_device(&config, device)?,
config,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn from_pretrained(model_name: &str) -> Result<Self> {
let config = DebertaConfig::from_pretrained_name(model_name);
Self::new(config)
}
pub fn forward(
&self,
input_ids: &Array1<u32>,
attention_mask: Option<&Array3<f32>>,
) -> Result<Array3<f32>> {
let embeddings = self.embeddings.forward(input_ids)?;
let hidden_states = embeddings.insert_axis(Axis(0));
let encoder_output = self.encoder.forward(hidden_states, attention_mask)?;
Ok(encoder_output)
}
}
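/// Sequence classification head on top of the backbone: the first-token
/// ([CLS]) hidden state is pooled through a dense layer and GELU, scaled
/// for dropout, then projected to `num_labels` logits.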
#[derive(Debug, Clone)]
pub struct DebertaForSequenceClassification {
pub deberta: DebertaModel,
pub pooler: Linear,
pub classifier: Linear,
pub dropout: f32,
pub num_labels: usize,
device: Device,
}
impl DebertaForSequenceClassification {
pub fn new(config: DebertaConfig, num_labels: usize) -> Result<Self> {
Self::new_with_device(config, num_labels, Device::CPU)
}
pub fn new_with_device(
config: DebertaConfig,
num_labels: usize,
device: Device,
) -> Result<Self> {
let dropout = config.classifier_dropout.unwrap_or(config.hidden_dropout_prob);
Ok(Self {
deberta: DebertaModel::new_with_device(config.clone(), device)?,
pooler: Linear::new_with_device(config.hidden_size, config.hidden_size, true, device),
classifier: Linear::new_with_device(config.hidden_size, num_labels, true, device),
dropout,
num_labels,
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn from_pretrained(model_name: &str, num_labels: usize) -> Result<Self> {
let config = DebertaConfig::from_pretrained_name(model_name);
Self::new(config, num_labels)
}
pub fn forward(
&self,
input_ids: &Array1<u32>,
attention_mask: Option<&Array3<f32>>,
) -> Result<Array2<f32>> {
let hidden_states = self.deberta.forward(input_ids, attention_mask)?;
let cls_hidden = hidden_states.slice(s![0, 0, ..]).to_owned();
let pooler_input = Tensor::F32(cls_hidden.insert_axis(Axis(0)).into_dyn());
let pooled_output = self.pooler.forward(pooler_input)?;
let pooled_output = match pooled_output {
Tensor::F32(arr) => arr
.into_dimensionality::<Ix2>()
.map_err(|e| TrustformersError::shape_error(e.to_string()))?,
_ => {
return Err(TrustformersError::tensor_op_error(
"Expected F32 tensor from pooler",
"pooler",
))
},
};
let pooled_tensor = Tensor::F32(pooled_output.into_dyn());
let pooled_output = gelu(&pooled_tensor)?;
let pooled_output = match pooled_output {
Tensor::F32(arr) => arr
.into_dimensionality::<Ix2>()
.map_err(|e| TrustformersError::shape_error(e.to_string()))?,
_ => {
return Err(TrustformersError::tensor_op_error(
"Expected F32 tensor from gelu",
"gelu",
))
},
};
let pooled_output = pooled_output * (1.0 - self.dropout);
let classifier_input = Tensor::F32(pooled_output.into_dyn());
let logits = self.classifier.forward(classifier_input)?;
let logits = match logits {
Tensor::F32(arr) => arr
.into_dimensionality::<Ix2>()
.map_err(|e| TrustformersError::shape_error(e.to_string()))?,
_ => {
return Err(TrustformersError::tensor_op_error(
"Expected F32 tensor from classifier",
"classifier",
))
},
};
Ok(logits)
}
}
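/// Masked language modeling head: projects every hidden state to
/// vocabulary logits, shaped `[1, seq_len, vocab_size]`.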
#[derive(Debug, Clone)]
pub struct DebertaForMaskedLM {
pub deberta: DebertaModel,
pub cls: Linear,
device: Device,
}
impl DebertaForMaskedLM {
pub fn new(config: DebertaConfig) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn new_with_device(config: DebertaConfig, device: Device) -> Result<Self> {
Ok(Self {
deberta: DebertaModel::new_with_device(config.clone(), device)?,
cls: Linear::new_with_device(config.hidden_size, config.vocab_size, true, device),
device,
})
}
pub fn device(&self) -> Device {
self.device
}
pub fn from_pretrained(model_name: &str) -> Result<Self> {
let config = DebertaConfig::from_pretrained_name(model_name);
Self::new(config)
}
pub fn forward(
&self,
input_ids: &Array1<u32>,
attention_mask: Option<&Array3<f32>>,
) -> Result<Array3<f32>> {
let hidden_states = self.deberta.forward(input_ids, attention_mask)?;
let cls_input = Tensor::F32(hidden_states.clone().into_dyn());
let prediction_scores = self.cls.forward(cls_input)?;
let prediction_scores = match prediction_scores {
Tensor::F32(arr) => arr
.into_dimensionality::<Ix3>()
.map_err(|e| TrustformersError::shape_error(e.to_string()))?,
_ => {
return Err(TrustformersError::tensor_op_error(
"Expected F32 tensor from cls layer",
"cls_layer",
))
},
};
Ok(prediction_scores)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::deberta::config::DebertaConfig;
use scirs2_core::ndarray::Array1;
use trustformers_core::traits::Config;
fn mini_config() -> DebertaConfig {
DebertaConfig {
vocab_size: 100,
hidden_size: 64,
num_hidden_layers: 1,
num_attention_heads: 4,
intermediate_size: 256,
hidden_act: "gelu".to_string(),
hidden_dropout_prob: 0.0,
attention_probs_dropout_prob: 0.0,
max_position_embeddings: 32,
type_vocab_size: 0,
initializer_range: 0.02,
layer_norm_eps: 1e-7,
pad_token_id: 0,
position_embedding_type: "relative_key_query".to_string(),
use_cache: true,
classifier_dropout: None,
relative_attention: true,
max_relative_positions: -1,
pos_att_type: vec!["p2c".to_string(), "c2p".to_string()],
norm_rel_ebd: "layer_norm".to_string(),
share_att_key: true,
model_type: "deberta".to_string(),
}
}
fn sample_ids(len: usize) -> Array1<u32> {
(0..len as u32).collect()
}
#[test]
fn test_deberta_embeddings_new_succeeds() {
let cfg = mini_config();
DebertaEmbeddings::new(&cfg).expect("DebertaEmbeddings::new should succeed");
}
#[test]
fn test_deberta_embeddings_forward_shape() {
let cfg = mini_config();
let emb = DebertaEmbeddings::new(&cfg).expect("DebertaEmbeddings::new failed");
let ids: Array1<u32> = sample_ids(6);
let out = emb.forward(&ids).expect("DebertaEmbeddings::forward failed");
assert_eq!(out.shape(), &[6, cfg.hidden_size]);
}
#[test]
fn test_deberta_disentangled_attention_new_with_relative() {
let cfg = mini_config();
DebertaDisentangledSelfAttention::new(&cfg)
.expect("DebertaDisentangledSelfAttention::new should succeed");
}
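// Sanity check for the relative-position matrix: with
// max_relative_positions <= 0 (as in mini_config) distances are the raw
// i - j offsets, and a positive maximum clamps them symmetrically.
#[test]
fn test_build_relative_position_offsets_and_clamping() {
let cfg = mini_config();
let attn = DebertaDisentangledSelfAttention::new(&cfg).expect("attention creation failed");
let rel = attn.build_relative_position(4, 4);
assert_eq!(rel[[0, 3]], -3);
assert_eq!(rel[[3, 0]], 3);
assert_eq!(rel[[2, 2]], 0);
let mut clamped_cfg = mini_config();
clamped_cfg.max_relative_positions = 2;
let attn =
DebertaDisentangledSelfAttention::new(&clamped_cfg).expect("attention creation failed");
let rel = attn.build_relative_position(4, 4);
assert_eq!(rel[[0, 3]], -2);
assert_eq!(rel[[3, 0]], 2);
}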
#[test]
fn test_deberta_disentangled_attention_new_without_relative() {
let mut cfg = mini_config();
cfg.relative_attention = false;
DebertaDisentangledSelfAttention::new(&cfg)
.expect("DebertaDisentangledSelfAttention without relative should succeed");
}
#[test]
fn test_deberta_disentangled_attention_pos_att_type_p2c_and_c2p() {
let cfg = mini_config();
let attn = DebertaDisentangledSelfAttention::new(&cfg).expect("attention creation failed");
assert!(
cfg.pos_att_type.contains(&"p2c".to_string()),
"pos_att_type should contain 'p2c'"
);
assert!(
cfg.pos_att_type.contains(&"c2p".to_string()),
"pos_att_type should contain 'c2p'"
);
assert!(attn.pos_query_proj.is_some() || attn.pos_key_proj.is_some());
}
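// Shape check for the per-head split: a minimal sketch assuming the
// mini_config geometry used above (hidden_size 64 over 4 heads of 16).
#[test]
fn test_transpose_for_scores_output_shape() {
let cfg = mini_config();
let attn = DebertaDisentangledSelfAttention::new(&cfg).expect("attention creation failed");
let x = scirs2_core::ndarray::Array3::<f32>::zeros((2, 5, cfg.hidden_size));
let out = attn.transpose_for_scores(&x).expect("transpose_for_scores failed");
assert_eq!(
out.shape(),
&[2, cfg.num_attention_heads, 5, cfg.hidden_size / cfg.num_attention_heads]
);
}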
#[test]
fn test_deberta_self_output_new_succeeds() {
let cfg = mini_config();
DebertaSelfOutput::new(&cfg).expect("DebertaSelfOutput::new should succeed");
}
#[test]
fn test_deberta_attention_new_succeeds() {
let cfg = mini_config();
DebertaAttention::new(&cfg).expect("DebertaAttention::new should succeed");
}
#[test]
fn test_deberta_layer_new_succeeds() {
let cfg = mini_config();
DebertaLayer::new(&cfg).expect("DebertaLayer::new should succeed");
}
#[test]
fn test_deberta_encoder_new_single_layer() {
let cfg = mini_config();
DebertaEncoder::new(&cfg).expect("DebertaEncoder::new should succeed");
}
#[test]
fn test_deberta_encoder_new_multi_layer() {
let mut cfg = mini_config();
cfg.num_hidden_layers = 2;
DebertaEncoder::new(&cfg).expect("DebertaEncoder with 2 layers should succeed");
}
#[test]
fn test_deberta_model_new_with_base_config() {
let cfg = mini_config();
DebertaModel::new(cfg).expect("DebertaModel::new should succeed");
}
#[test]
fn test_deberta_model_forward_output_shape() {
let cfg = mini_config();
let model = DebertaModel::new(cfg.clone()).expect("DebertaModel::new failed");
let ids: Array1<u32> = sample_ids(5);
let out = model.forward(&ids, None).expect("DebertaModel::forward failed");
assert_eq!(out.shape(), &[1, 5, cfg.hidden_size]);
}
#[test]
fn test_deberta_model_from_pretrained_deberta_base() {
let _model = DebertaModel::from_pretrained("deberta-base")
.expect("from_pretrained deberta-base should succeed");
}
#[test]
fn test_deberta_model_from_pretrained_deberta_large() {
let _model = DebertaModel::from_pretrained("deberta-large")
.expect("from_pretrained deberta-large should succeed");
}
#[test]
fn test_deberta_v2_xlarge_vocab_size() {
let cfg = DebertaConfig::xlarge();
assert_eq!(
cfg.vocab_size, 128100,
"DeBERTa-v2 xlarge should have vocab_size=128100"
);
}
#[test]
fn test_deberta_v3_large_vocab_size() {
let cfg = DebertaConfig::v3_large();
assert_eq!(
cfg.vocab_size, 128100,
"DeBERTa-v3 large should have vocab_size=128100"
);
}
#[test]
fn test_deberta_default_share_att_key_true() {
let cfg = DebertaConfig::default();
assert!(cfg.share_att_key, "share_att_key should default to true");
}
#[test]
fn test_deberta_seq_class_new_two_labels() {
let cfg = mini_config();
DebertaForSequenceClassification::new(cfg, 2)
.expect("DebertaForSequenceClassification with 2 labels failed");
}
#[test]
fn test_deberta_seq_class_forward_output_shape() {
let cfg = mini_config();
let model = DebertaForSequenceClassification::new(cfg, 2).expect("model creation failed");
let ids: Array1<u32> = sample_ids(4);
let out = model.forward(&ids, None).expect("forward should succeed");
assert_eq!(out.shape(), &[1, 2]);
}
#[test]
fn test_deberta_seq_class_three_labels_output_shape() {
let cfg = mini_config();
let model =
DebertaForSequenceClassification::new(cfg, 3).expect("model with 3 labels failed");
let ids: Array1<u32> = sample_ids(4);
let out = model.forward(&ids, None).expect("forward should succeed");
assert_eq!(out.shape(), &[1, 3]);
}
#[test]
fn test_deberta_masked_lm_new_succeeds() {
let cfg = mini_config();
DebertaForMaskedLM::new(cfg).expect("DebertaForMaskedLM::new should succeed");
}
#[test]
fn test_deberta_masked_lm_forward_output_shape() {
let cfg = mini_config();
let model = DebertaForMaskedLM::new(cfg.clone()).expect("model creation failed");
let ids: Array1<u32> = sample_ids(4);
let out = model.forward(&ids, None).expect("forward should succeed");
assert_eq!(out.shape(), &[1, 4, cfg.vocab_size]);
}
#[test]
fn test_deberta_mini_config_validates() {
let cfg = mini_config();
cfg.validate().expect("mini_config should be valid");
}
#[test]
fn test_deberta_base_config_validates() {
let cfg = DebertaConfig::base();
cfg.validate().expect("base config should be valid");
}
}