use std::cell::Cell;
use flodl::nn::{Dropout, GeluApprox, LayerNorm, Linear, Module, Parameter, GELU};
use flodl::{DType, Device, Result, Tensor, TensorOptions, Variable};
use crate::path::prefix_params;
/// Hyper-parameters for one DeBERTa-v2 transformer layer.
///
/// Field names mirror the Hugging Face `DebertaV2Config` keys so values can be
/// copied straight from a model's `config.json`.
#[derive(Debug, Clone, Copy)]
pub struct DebertaV2LayerConfig {
    /// Width of the residual stream; must be divisible by `num_attention_heads`
    /// (asserted in `DisentangledSelfAttention::on_device`).
    pub hidden_size: i64,
    /// Number of attention heads; each head gets `hidden_size / num_attention_heads` dims.
    pub num_attention_heads: i64,
    /// Width of the feed-forward expansion in `DebertaV2Intermediate`.
    pub intermediate_size: i64,
    /// Dropout applied after dense projections and to relative-position embeddings.
    pub hidden_dropout_prob: f64,
    /// Dropout applied to the attention probabilities.
    pub attention_probs_dropout_prob: f64,
    /// Epsilon for the two LayerNorms in the layer.
    pub layer_norm_eps: f64,
    /// Bucket count for log-bucketed relative positions (attention span).
    pub position_buckets: i64,
    /// Maximum relative distance used by the log-bucket compression.
    pub max_relative_positions: i64,
    /// GELU variant used in the feed-forward activation.
    pub hidden_act: GeluApprox,
}
/// Map signed relative positions to log-bucketed indices.
///
/// Mirrors Hugging Face's `make_log_bucket_position`: offsets with
/// `|rel| <= bucket_size / 2` pass through unchanged, while farther offsets
/// are compressed logarithmically (sign preserved) so the bucket magnitude
/// never exceeds `bucket_size - 1`.
fn make_log_bucket_position(
    rel_pos: &Tensor,
    bucket_size: i64,
    max_position: i64,
) -> Result<Tensor> {
    let mid = bucket_size / 2;
    let opts = TensorOptions { dtype: DType::Float32, device: rel_pos.device() };
    let rel_f = rel_pos.to_dtype(DType::Float32)?;
    let sign = rel_f.sign()?;
    let magnitude = rel_f.abs()?;
    // Positions strictly inside (-mid, mid) are replaced by mid - 1 before the
    // log so it stays well-defined for them; their compressed value is never
    // selected by the final `where_cond` anyway.
    let is_near = rel_f
        .lt_scalar(mid as f64)?
        .logical_and(&rel_f.gt_scalar(-(mid as f64))?)?;
    let near_fill = Tensor::full(&rel_pos.shape(), (mid - 1) as f64, opts)?;
    let abs_pos = Tensor::where_cond(&is_near, &near_fill, &magnitude)?;
    // ln(|p| / mid) / ln((max_position - 1) / mid) lies in [0, 1] for in-range
    // offsets; scale it into bucket indices in (mid, 2 * mid).
    let log_denom = ((max_position as f64 - 1.0) / mid as f64).ln();
    let compressed = abs_pos
        .div_scalar(mid as f64)?
        .log()?
        .div_scalar(log_denom)?
        .clamp(0.0, 1.0)?
        .mul_scalar((mid - 1) as f64)?
        .ceil()?
        .add_scalar(mid as f64)?;
    // Near-range offsets (|rel| <= mid) keep their raw value; the rest take
    // the signed compressed index.
    let near_range = magnitude.le_scalar(mid as f64)?;
    let signed_log = compressed.mul(&sign)?;
    Tensor::where_cond(&near_range, &rel_f, &signed_log)?.to_dtype(DType::Int64)
}
/// Build a `[1, seq_len, seq_len]` tensor of relative positions `q - k`,
/// log-bucketed via [`make_log_bucket_position`] when bucketing is enabled.
///
/// Matches Hugging Face's `build_relative_position`: bucketing is applied only
/// when both `position_buckets` and `max_relative_positions` are positive;
/// otherwise the raw signed offsets are returned unchanged. (The previous
/// version bucketed unconditionally, which would divide by `bucket_size / 2 ==
/// 0` for `position_buckets <= 1`.) Existing callers always pass positive
/// values, so their behavior is identical.
pub fn build_relative_position(
    seq_len: i64,
    position_buckets: i64,
    max_relative_positions: i64,
    device: Device,
) -> Result<Tensor> {
    let i64_opts = TensorOptions { dtype: DType::Int64, device };
    let ids = Tensor::arange(0.0, seq_len as f64, 1.0, i64_opts)?;
    // rel[q, k] = q - k: a column of ids minus a row of ids.
    let q = ids.unsqueeze(-1)?.expand(&[seq_len, seq_len])?.contiguous()?;
    let k = ids.unsqueeze(0)?.expand(&[seq_len, seq_len])?.contiguous()?;
    let rel = q.sub(&k)?;
    let rel = if position_buckets > 0 && max_relative_positions > 0 {
        make_log_bucket_position(&rel, position_buckets, max_relative_positions)?
    } else {
        rel
    };
    rel.unsqueeze(0)
}
/// DeBERTa-v2 disentangled self-attention: content-to-content scores plus
/// content-to-position (c2p) and position-to-content (p2c) bias terms.
pub struct DisentangledSelfAttention {
    // Content projections; `query_proj` and `key_proj` are reused to project
    // the relative-position embeddings in `disentangled_bias`.
    query_proj: Linear,
    key_proj: Linear,
    value_proj: Linear,
    // Head layout: hidden_size = num_heads * head_dim (asserted at build time).
    num_heads: i64,
    head_dim: i64,
    // Dropout on attention probabilities.
    attn_dropout: Dropout,
    // Dropout on the relative-position embeddings (hidden_dropout_prob).
    pos_dropout: Dropout,
    // Attention span: the rel-embedding table is sliced to 2 * position_buckets
    // rows and gather indices are clamped to that range.
    position_buckets: i64,
    #[allow(dead_code)]
    max_relative_positions: i64,
    // Mirrors the dropouts' training flag; written by `set_training` but not
    // read anywhere in this file.
    training: Cell<bool>,
}
impl DisentangledSelfAttention {
pub fn on_device(config: &DebertaV2LayerConfig, device: Device) -> Result<Self> {
assert!(
config.hidden_size % config.num_attention_heads == 0,
"hidden_size ({}) must be divisible by num_attention_heads ({})",
config.hidden_size, config.num_attention_heads,
);
let head_dim = config.hidden_size / config.num_attention_heads;
Ok(DisentangledSelfAttention {
query_proj: Linear::on_device(config.hidden_size, config.hidden_size, device)?,
key_proj: Linear::on_device(config.hidden_size, config.hidden_size, device)?,
value_proj: Linear::on_device(config.hidden_size, config.hidden_size, device)?,
num_heads: config.num_attention_heads,
head_dim,
attn_dropout: Dropout::new(config.attention_probs_dropout_prob),
pos_dropout: Dropout::new(config.hidden_dropout_prob),
position_buckets: config.position_buckets,
max_relative_positions: config.max_relative_positions,
training: Cell::new(true),
})
}
fn split_heads(&self, x: &Variable) -> Result<Variable> {
let shape = x.shape();
let batch = shape[0];
let seq = shape[1];
x.reshape(&[batch, seq, self.num_heads, self.head_dim])?
.transpose(1, 2)?
.reshape(&[batch * self.num_heads, seq, self.head_dim])
}
fn merge_heads(&self, x: &Variable, batch: i64) -> Result<Variable> {
let shape = x.shape();
let seq = shape[1];
x.reshape(&[batch, self.num_heads, seq, self.head_dim])?
.transpose(1, 2)?
.reshape(&[batch, seq, self.num_heads * self.head_dim])
}
fn disentangled_bias(
&self,
query_layer: &Variable, key_layer: &Variable, relative_pos: &Tensor, rel_embeddings: &Variable, scale: f64,
) -> Result<Variable> {
let att_span = self.position_buckets;
let two_span = att_span * 2;
let rel = rel_embeddings
.narrow(0, 0, two_span)?
.unsqueeze(0)?;
let rel = self.pos_dropout.forward(&rel)?;
let bh = query_layer.shape()[0]; let batch = bh / self.num_heads;
let pos_key = self.split_heads(&self.key_proj.forward(&rel)?)? .repeat(&[batch, 1, 1])?; let pos_query = self.split_heads(&self.query_proj.forward(&rel)?)?
.repeat(&[batch, 1, 1])?;
let c2p_scores = query_layer.matmul(&pos_key.transpose(-1, -2)?)?; let c2p_pos = relative_pos
.add_scalar(att_span as f64)?
.clamp(0.0, (two_span - 1) as f64)?
.to_dtype(DType::Int64)?;
let s = c2p_scores.shape()[1]; let c2p_idx = c2p_pos
.squeeze(0)?
.expand(&[bh, s, s])?
.contiguous()?;
let c2p_att = c2p_scores.gather(-1, &c2p_idx)?;
let p2c_scores = key_layer.matmul(&pos_query.transpose(-1, -2)?)?; let p2c_pos = relative_pos
.mul_scalar(-1.0)?
.add_scalar(att_span as f64)?
.clamp(0.0, (two_span - 1) as f64)?
.to_dtype(DType::Int64)?;
let p2c_idx = p2c_pos
.squeeze(0)?
.expand(&[bh, s, s])?
.contiguous()?;
let p2c_att = p2c_scores.gather(-1, &p2c_idx)?.transpose(-1, -2)?;
let scaled_c2p = c2p_att.div_scalar(scale)?;
let scaled_p2c = p2c_att.div_scalar(scale)?;
scaled_c2p.add(&scaled_p2c)
}
pub fn forward(
&self,
hidden_states: &Variable,
attention_mask: &Variable,
relative_pos: &Tensor,
rel_embeddings: &Variable,
) -> Result<Variable> {
let batch = hidden_states.shape()[0];
let seq = hidden_states.shape()[1];
let q = self.split_heads(&self.query_proj.forward(hidden_states)?)?;
let k = self.split_heads(&self.key_proj.forward(hidden_states)?)?;
let v = self.split_heads(&self.value_proj.forward(hidden_states)?)?;
let scale = ((self.head_dim as f64) * 3.0).sqrt();
let kt = k.transpose(-1, -2)?;
let c2c = q.matmul(&kt)?.div_scalar(scale)?;
let bias = self.disentangled_bias(&q, &k, relative_pos, rel_embeddings, scale)?;
let scores = c2c.add(&bias)?;
let scores = scores.reshape(&[batch, self.num_heads, seq, seq])?;
let scores = scores.add(attention_mask)?;
let probs = scores.softmax(-1)?;
let probs = self.attn_dropout.forward(&probs)?;
let probs = probs.reshape(&[batch * self.num_heads, seq, seq])?;
let context = probs.matmul(&v)?; self.merge_heads(&context, batch)
}
pub fn parameters(&self) -> Vec<Parameter> {
let mut out = Vec::new();
out.extend(prefix_params("query_proj", self.query_proj.parameters()));
out.extend(prefix_params("key_proj", self.key_proj.parameters()));
out.extend(prefix_params("value_proj", self.value_proj.parameters()));
out
}
pub fn set_training(&self, training: bool) {
self.training.set(training);
self.attn_dropout.set_training(training);
self.pos_dropout.set_training(training);
}
}
/// Post-attention projection with residual connection: dense -> dropout ->
/// add residual -> LayerNorm.
pub struct DebertaV2SelfOutput {
    dense: Linear,
    layer_norm: LayerNorm,
    dropout: Dropout,
}
impl DebertaV2SelfOutput {
    /// Allocate the square output projection, LayerNorm, and dropout on `device`.
    pub fn on_device(config: &DebertaV2LayerConfig, device: Device) -> Result<Self> {
        let dense = Linear::on_device(config.hidden_size, config.hidden_size, device)?;
        let layer_norm =
            LayerNorm::on_device_with_eps(config.hidden_size, config.layer_norm_eps, device)?;
        Ok(DebertaV2SelfOutput {
            dense,
            layer_norm,
            dropout: Dropout::new(config.hidden_dropout_prob),
        })
    }

    /// Project `hidden`, apply dropout, add `residual`, then normalize.
    fn forward(&self, hidden: &Variable, residual: &Variable) -> Result<Variable> {
        let projected = self.dropout.forward(&self.dense.forward(hidden)?)?;
        self.layer_norm.forward(&projected.add(residual)?)
    }

    /// Parameters under the `dense` / `LayerNorm` prefixes (dropout has none).
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = prefix_params("dense", self.dense.parameters());
        params.extend(prefix_params("LayerNorm", self.layer_norm.parameters()));
        params
    }

    /// Dropout is the only training-sensitive component here.
    fn set_training(&self, training: bool) {
        self.dropout.set_training(training);
    }
}
/// Feed-forward up-projection: dense (hidden -> intermediate) followed by GELU.
pub struct DebertaV2Intermediate {
    dense: Linear,
    activation: GELU,
}
impl DebertaV2Intermediate {
    /// Allocate the up-projection and configure the GELU variant on `device`.
    pub fn on_device(config: &DebertaV2LayerConfig, device: Device) -> Result<Self> {
        let dense = Linear::on_device(config.hidden_size, config.intermediate_size, device)?;
        Ok(DebertaV2Intermediate {
            dense,
            activation: GELU::with_approximate(config.hidden_act),
        })
    }

    /// Up-project then apply the activation.
    fn forward(&self, input: &Variable) -> Result<Variable> {
        self.activation.forward(&self.dense.forward(input)?)
    }

    /// Only the dense layer carries parameters.
    fn parameters(&self) -> Vec<Parameter> {
        prefix_params("dense", self.dense.parameters())
    }
}
/// Feed-forward down-projection with residual connection: dense
/// (intermediate -> hidden) -> dropout -> add residual -> LayerNorm.
pub struct DebertaV2Output {
    dense: Linear,
    layer_norm: LayerNorm,
    dropout: Dropout,
}
impl DebertaV2Output {
    /// Allocate the down-projection, LayerNorm, and dropout on `device`.
    pub fn on_device(config: &DebertaV2LayerConfig, device: Device) -> Result<Self> {
        let dense = Linear::on_device(config.intermediate_size, config.hidden_size, device)?;
        let layer_norm =
            LayerNorm::on_device_with_eps(config.hidden_size, config.layer_norm_eps, device)?;
        Ok(DebertaV2Output {
            dense,
            layer_norm,
            dropout: Dropout::new(config.hidden_dropout_prob),
        })
    }

    /// Project `input` back to hidden width, apply dropout, add `residual`,
    /// then normalize.
    fn forward(&self, input: &Variable, residual: &Variable) -> Result<Variable> {
        let projected = self.dropout.forward(&self.dense.forward(input)?)?;
        self.layer_norm.forward(&projected.add(residual)?)
    }

    /// Parameters under the `dense` / `LayerNorm` prefixes (dropout has none).
    fn parameters(&self) -> Vec<Parameter> {
        let mut params = prefix_params("dense", self.dense.parameters());
        params.extend(prefix_params("LayerNorm", self.layer_norm.parameters()));
        params
    }

    /// Dropout is the only training-sensitive component here.
    fn set_training(&self, training: bool) {
        self.dropout.set_training(training);
    }
}
/// One full DeBERTa-v2 encoder layer: disentangled self-attention with its
/// residual output block, followed by the feed-forward pair.
pub struct DebertaV2TransformerLayer {
    attention_self: DisentangledSelfAttention,
    attention_output: DebertaV2SelfOutput,
    intermediate: DebertaV2Intermediate,
    output: DebertaV2Output,
}
impl DebertaV2TransformerLayer {
    /// Construct all four sub-modules from `config` on `device`.
    pub fn on_device(config: &DebertaV2LayerConfig, device: Device) -> Result<Self> {
        Ok(DebertaV2TransformerLayer {
            attention_self: DisentangledSelfAttention::on_device(config, device)?,
            attention_output: DebertaV2SelfOutput::on_device(config, device)?,
            intermediate: DebertaV2Intermediate::on_device(config, device)?,
            output: DebertaV2Output::on_device(config, device)?,
        })
    }

    /// Run attention + residual, then the feed-forward block + residual.
    /// Input and output shapes are both `[batch, seq, hidden]`.
    pub fn forward(
        &self,
        hidden_states: &Variable,
        attention_mask: &Variable,
        relative_pos: &Tensor,
        rel_embeddings: &Variable,
    ) -> Result<Variable> {
        let attn = self.attention_self.forward(
            hidden_states, attention_mask, relative_pos, rel_embeddings,
        )?;
        let attn_out = self.attention_output.forward(&attn, hidden_states)?;
        let ffn_mid = self.intermediate.forward(&attn_out)?;
        self.output.forward(&ffn_mid, &attn_out)
    }

    /// All parameters, namespaced to match Hugging Face checkpoint keys
    /// (`attention.self.*`, `attention.output.*`, `intermediate.*`, `output.*`).
    pub fn parameters(&self) -> Vec<Parameter> {
        let mut params = prefix_params("attention.self", self.attention_self.parameters());
        params.extend(prefix_params("attention.output", self.attention_output.parameters()));
        params.extend(prefix_params("intermediate", self.intermediate.parameters()));
        params.extend(prefix_params("output", self.output.parameters()));
        params
    }

    /// Propagate the training flag to every dropout-bearing sub-module.
    /// `intermediate` has no dropout, so it is intentionally skipped.
    pub fn set_training(&self, training: bool) {
        self.attention_self.set_training(training);
        self.attention_output.set_training(training);
        self.output.set_training(training);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tiny config (hidden 16, 4 heads, buckets 4) with all dropout disabled
    /// so forward passes are deterministic.
    fn mini_config() -> DebertaV2LayerConfig {
        DebertaV2LayerConfig {
            hidden_size: 16,
            num_attention_heads: 4,
            intermediate_size: 32,
            hidden_dropout_prob: 0.0,
            attention_probs_dropout_prob: 0.0,
            layer_norm_eps: 1e-7,
            position_buckets: 4,
            max_relative_positions: 8,
            hidden_act: GeluApprox::Exact,
        }
    }

    /// With bucket_size = 4 (mid = 2), offsets in (-2, 2) must come back
    /// unchanged from the log-bucket mapping.
    #[test]
    fn log_bucket_near_range_passthrough() {
        let dev = Device::CPU;
        let raw = Tensor::from_i64(&[-1, 0, 1], &[3], dev).unwrap();
        let bucketed = make_log_bucket_position(&raw, 4, 8).unwrap();
        let out = bucketed.to_i64_vec().unwrap();
        assert_eq!(out, vec![-1, 0, 1], "near-range values must pass through");
    }

    /// For seq_len = 3 every |q - k| <= 2 is within the near range of
    /// bucket_size = 4, so rel[q, k] must equal the raw offset q - k.
    #[test]
    fn build_relative_position_antisymmetric() {
        let dev = Device::CPU;
        let rel = build_relative_position(3, 4, 8, dev).unwrap();
        assert_eq!(rel.shape(), vec![1, 3, 3]);
        let data = rel.to_i64_vec().unwrap();
        let at = |q: usize, k: usize| data[q * 3 + k];
        for q in 0..3 {
            for k in 0..3 {
                assert_eq!(at(q, k), q as i64 - k as i64, "rel[{q}, {k}]");
            }
        }
    }

    /// Offsets beyond mid = 2 must be log-compressed into buckets that never
    /// exceed bucket_size - 1 = 3; the boundary value 2 passes through.
    #[test]
    fn log_bucket_far_range_compressed() {
        let dev = Device::CPU;
        let raw = Tensor::from_i64(&[2, 3, 4, 5, 6, 7], &[6], dev).unwrap();
        let bucketed = make_log_bucket_position(&raw, 4, 8).unwrap();
        let out = bucketed.to_i64_vec().unwrap();
        assert_eq!(out[0], 2);
        for (i, &v) in out[1..].iter().enumerate() {
            assert!(
                (2..=3).contains(&v),
                "out[{}] = {} not in [2, 3] for bucket_size=4",
                i + 1, v,
            );
        }
    }

    /// Parameter names must exactly match the Hugging Face checkpoint key
    /// layout so state dicts can be loaded by name.
    #[test]
    fn transformer_layer_param_keys() {
        let layer = DebertaV2TransformerLayer::on_device(&mini_config(), Device::CPU).unwrap();
        let names: Vec<String> = layer.parameters().into_iter().map(|p| p.name).collect();
        let expected = [
            "attention.self.query_proj.weight",
            "attention.self.query_proj.bias",
            "attention.self.key_proj.weight",
            "attention.self.key_proj.bias",
            "attention.self.value_proj.weight",
            "attention.self.value_proj.bias",
            "attention.output.dense.weight",
            "attention.output.dense.bias",
            "attention.output.LayerNorm.weight",
            "attention.output.LayerNorm.bias",
            "intermediate.dense.weight",
            "intermediate.dense.bias",
            "output.dense.weight",
            "output.dense.bias",
            "output.LayerNorm.weight",
            "output.LayerNorm.bias",
        ];
        assert_eq!(names.len(), expected.len(), "got {names:?}");
        for key in expected {
            assert!(names.iter().any(|n| n == key), "missing {key} in {names:?}");
        }
    }

    /// A full forward pass in eval mode must preserve the
    /// [batch, seq, hidden] shape.
    #[test]
    fn transformer_layer_forward_shape() {
        let cfg = mini_config();
        let dev = Device::CPU;
        let layer = DebertaV2TransformerLayer::on_device(&cfg, dev).unwrap();
        layer.set_training(false);
        let batch = 1;
        let seq = 3;
        let hidden = cfg.hidden_size;
        // Deterministic ramp input.
        let hidden_data: Vec<f32> = (0..(batch * seq * hidden) as usize)
            .map(|i| (i as f32) * 0.01)
            .collect();
        let x = Variable::new(
            Tensor::from_f32(&hidden_data, &[batch, seq, hidden], dev).unwrap(),
            false,
        );
        // All-zeros additive mask: every position attends everywhere.
        let mask = Variable::new(
            Tensor::zeros(
                &[batch, 1, seq, seq],
                TensorOptions { dtype: DType::Float32, device: dev },
            ).unwrap(),
            false,
        );
        let rel_pos = build_relative_position(
            seq, cfg.position_buckets, cfg.max_relative_positions, dev,
        ).unwrap();
        // Table of exactly 2 * position_buckets rows, as the bias slices it.
        let rel_emb_shape = [cfg.position_buckets * 2, hidden];
        let rel_emb_data: Vec<f32> = (0..(rel_emb_shape[0] * rel_emb_shape[1]) as usize)
            .map(|i| ((i as f32) * 0.003).sin())
            .collect();
        let rel_emb = Variable::new(
            Tensor::from_f32(&rel_emb_data, &rel_emb_shape, dev).unwrap(),
            false,
        );
        let out = layer.forward(&x, &mask, &rel_pos, &rel_emb).unwrap();
        assert_eq!(out.shape(), vec![batch, seq, hidden]);
    }

    /// 16 is not divisible by 3, so construction must hit the divisibility
    /// assert.
    #[test]
    #[should_panic(expected = "must be divisible")]
    fn hidden_size_must_divide_num_heads() {
        let mut cfg = mini_config();
        cfg.num_attention_heads = 3;
        let _ = DisentangledSelfAttention::on_device(&cfg, Device::CPU);
    }

    /// Masking a single key position must measurably change the layer output
    /// versus an all-attend mask (i.e. the additive mask is actually applied
    /// before softmax).
    #[test]
    fn attention_mask_is_applied() {
        let cfg = mini_config();
        let dev = Device::CPU;
        let layer = DebertaV2TransformerLayer::on_device(&cfg, dev).unwrap();
        layer.set_training(false);
        let batch = 1;
        let seq = 4;
        let hidden = cfg.hidden_size;
        let hidden_data: Vec<f32> = (0..(batch * seq * hidden) as usize)
            .map(|i| ((i as f32) * 0.017).sin())
            .collect();
        let x = Variable::new(
            Tensor::from_f32(&hidden_data, &[batch, seq, hidden], dev).unwrap(),
            false,
        );
        let rel_pos = build_relative_position(
            seq, cfg.position_buckets, cfg.max_relative_positions, dev,
        ).unwrap();
        let rel_emb_data: Vec<f32> = (0..((cfg.position_buckets * 2) * hidden) as usize)
            .map(|i| ((i as f32) * 0.003).sin())
            .collect();
        let rel_emb = Variable::new(
            Tensor::from_f32(&rel_emb_data, &[cfg.position_buckets * 2, hidden], dev).unwrap(),
            false,
        );
        let all_attend = Variable::new(
            Tensor::zeros(
                &[batch, 1, seq, seq],
                TensorOptions { dtype: DType::Float32, device: dev },
            ).unwrap(),
            false,
        );
        // Flat index 3 = (query 0, key 3): block one key for the first query.
        let mut mask_data = vec![0.0_f32; (batch * seq * seq) as usize];
        mask_data[3] = -1e4;
        let partial = Variable::new(
            Tensor::from_f32(&mask_data, &[batch, 1, seq, seq], dev).unwrap(),
            false,
        );
        let out_all = layer.forward(&x, &all_attend, &rel_pos, &rel_emb).unwrap();
        let out_mask = layer.forward(&x, &partial, &rel_pos, &rel_emb).unwrap();
        let a: Vec<f32> = out_all.data().to_f32_vec().unwrap();
        let b: Vec<f32> = out_mask.data().to_f32_vec().unwrap();
        let max_diff = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).fold(0.0_f32, f32::max);
        assert!(
            max_diff > 1e-5,
            "masking one key position must change the output; got max_diff = {max_diff}",
        );
    }
}