use candle_core::{Device, IndexOp, Tensor};
use snafu::{ResultExt, Snafu};
/// Configuration from which an [`AttentionLinearBiases`] module is built.
///
/// Values are set through the consuming builder methods that share the field
/// names; `Default` yields 12 attention heads, non-causal, non-inverted.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AttentionLinearBiasesConfig {
/// Number of attention heads; one slope is calculated per head.
n_attention_heads: usize,
/// When `true`, distances are a single row counting up to zero at the last
/// key position; otherwise a full matrix of negated absolute distances.
is_causal: bool,
/// When `true`, `seq_len - 1` is added to every distance before it is
/// scaled by the slopes.
is_inverted: bool,
}
impl AttentionLinearBiasesConfig {
    /// Build the attention linear biases module described by this configuration.
    ///
    /// The per-head slopes are calculated once here and stored in the module.
    ///
    /// # Errors
    ///
    /// Returns an error when the slopes cannot be calculated.
    pub fn build(&self) -> Result<AttentionLinearBiases, AttentionLinearBiasesError> {
        AttentionLinearBiases::calculate_slopes(self.n_attention_heads).map(|slopes| {
            AttentionLinearBiases {
                slopes,
                is_causal: self.is_causal,
                is_inverted: self.is_inverted,
            }
        })
    }

    /// Set the number of attention heads.
    pub fn n_attention_heads(self, n_attention_heads: usize) -> Self {
        Self {
            n_attention_heads,
            ..self
        }
    }

    /// Set whether the biases are calculated for causal attention.
    pub fn is_causal(self, is_causal: bool) -> Self {
        Self { is_causal, ..self }
    }

    /// Set whether the distances are inverted (offset by `seq_len - 1`).
    pub fn is_inverted(self, is_inverted: bool) -> Self {
        Self {
            is_inverted,
            ..self
        }
    }
}
impl Default for AttentionLinearBiasesConfig {
    /// Default configuration: 12 attention heads, non-causal, non-inverted.
    fn default() -> Self {
        const DEFAULT_N_ATTENTION_HEADS: usize = 12;

        Self {
            n_attention_heads: DEFAULT_N_ATTENTION_HEADS,
            is_causal: false,
            is_inverted: false,
        }
    }
}
/// Errors returned when building or applying [`AttentionLinearBiases`].
///
/// Every variant wraps the underlying `candle_core` error as its source.
#[derive(Debug, Snafu)]
pub enum AttentionLinearBiasesError {
/// Adding the biases to the attention scores failed (used by `forward`).
#[snafu(display("Cannot apply attention linear biases"))]
ApplyBiases { source: candle_core::Error },
/// Building the bias tensor for a sequence length failed.
#[snafu(display("Cannot calculate biases"))]
CalculateBiases { source: candle_core::Error },
/// Calculating the per-head slopes failed.
#[snafu(display("Cannot calculate slopes"))]
CalculateSlopes { source: candle_core::Error },
}
/// Attention linear biases: per-head slopes multiplied by token distances and
/// added to attention scores.
///
/// Instances are created with `AttentionLinearBiasesConfig::build`.
#[derive(Clone, Debug)]
pub struct AttentionLinearBiases {
/// Per-head slopes, created on the CPU and reshaped to `(1, n_heads, 1, 1)`.
slopes: Tensor,
/// Whether distances are calculated as a single causal row rather than a
/// full `seq_len x seq_len` matrix.
is_causal: bool,
/// Whether `seq_len - 1` is added to the distances before scaling.
is_inverted: bool,
}
impl AttentionLinearBiases {
#[allow(clippy::unnecessary_cast)]
fn calculate_slopes(n_attention_heads: usize) -> Result<Tensor, AttentionLinearBiasesError> {
fn slopes_with_step(
n_attention_heads: usize,
step: usize,
) -> Result<Tensor, AttentionLinearBiasesError> {
let ratio = 2.0f32.powf(-8.0f32 / n_attention_heads as f32);
let slope = (1..n_attention_heads + 1)
.step_by(step)
.map(|x| ratio.powi(x as i32))
.collect::<Vec<_>>();
Tensor::new(slope, &Device::Cpu).context(CalculateSlopesSnafu)
}
let k = 1 << ((usize::BITS - (n_attention_heads as usize).leading_zeros()) - 1);
let mut slopes = slopes_with_step(k, 1)?;
if n_attention_heads != k {
let remaining = n_attention_heads - k;
let slopes_rest = slopes_with_step(2 * k, 2)?
.i(..remaining)
.context(CalculateSlopesSnafu)?;
slopes = Tensor::cat(&[slopes, slopes_rest], 0).context(CalculateSlopesSnafu)?;
}
slopes.reshape((1, (), 1, 1)).context(CalculateSlopesSnafu)
}
fn calculate_biases(&self, seq_len: usize) -> Result<Tensor, AttentionLinearBiasesError> {
let mut distances = if self.is_causal {
Tensor::arange(1 - seq_len as i64, 1, &Device::Cpu).context(CalculateBiasesSnafu)?
} else {
Tensor::arange(0, seq_len as i64, &Device::Cpu)
.and_then(|xs| xs.broadcast_sub(&xs.reshape(((), 1))?))
.and_then(|xs| xs.abs())
.and_then(|xs| xs.broadcast_mul(&Tensor::new(-1i64, &Device::Cpu)?))
.and_then(|xs| xs.reshape((1, 1, seq_len, seq_len)))
.context(CalculateBiasesSnafu)?
};
if self.is_inverted {
distances = Tensor::new(seq_len as i64 - 1i64, &Device::Cpu)
.and_then(|xs| distances.broadcast_add(&xs))
.context(CalculateBiasesSnafu)?;
}
distances
.to_dtype(self.slopes.dtype())
.and_then(|xs| xs.broadcast_mul(&self.slopes))
.context(CalculateBiasesSnafu)
}
pub fn forward(&self, attention_scores: &Tensor) -> Result<Tensor, AttentionLinearBiasesError> {
let (_, _, _, key_len) = attention_scores.shape().dims4().context(ApplyBiasesSnafu)?;
let biases = self
.calculate_biases(key_len)?
.to_dtype(attention_scores.dtype())
.and_then(|xs| xs.to_device(attention_scores.device()))
.context(ApplyBiasesSnafu)?;
attention_scores.add(&biases).context(ApplyBiasesSnafu)
}
}
// Unit tests for the slopes and for the biases that `forward` adds to
// all-zero attention scores (so the output equals the biases themselves).
#[cfg(test)]
mod tests {
use candle_core::{DType, Device, Tensor};
use ndarray::array;
use super::AttentionLinearBiasesConfig;
use crate::util::tests::assert_tensor_eq;
// Slopes for a power-of-two and a non-power-of-two number of heads.
#[test]
fn test_attention_linear_biases_slopes() {
// 8 heads: ratio 2^(-8/8) = 0.5, slopes are 0.5^1 ..= 0.5^8.
let pow2_biases = AttentionLinearBiasesConfig::default()
.n_attention_heads(8)
.build()
.unwrap();
assert_tensor_eq!(
&pow2_biases.slopes,
array![0.5f32, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625,]
.into_shape((1, 8, 1, 1))
.unwrap(),
epsilon = 1e-4,
);
// 12 heads: the 8 slopes above, followed by the first 4 odd powers of
// 2^(-8/16) = 2^(-0.5).
let non_pow2_biases = AttentionLinearBiasesConfig::default()
.n_attention_heads(12)
.build()
.unwrap();
assert_tensor_eq!(
&non_pow2_biases.slopes,
array![
0.5f32,
0.25,
0.125,
0.0625,
0.03125,
0.015625,
0.0078125,
0.00390625,
0.7071067811865476,
0.35355339059327384,
0.17677669529663692,
0.08838834764831849,
]
.into_shape((1, 12, 1, 1))
.unwrap(),
epsilon = 1e-4,
);
}
// Causal biases for 4 heads (slopes 0.25^1 ..= 0.25^4) and 3 key positions.
#[test]
fn test_attention_linear_biases_causal() {
let device = Device::Cpu;
let causal = AttentionLinearBiasesConfig::default()
.n_attention_heads(4)
.is_causal(true)
.build()
.unwrap();
// Distances are [-2, -1, 0], scaled by each head's slope.
assert_tensor_eq!(
&causal
.forward(&Tensor::zeros((1, 4, 1, 3), DType::F32, &device).unwrap())
.unwrap(),
array![
-0.5000f32,
-0.2500,
0.0000,
-0.1250,
-0.0625,
0.0000,
-0.03125,
-0.015625,
0.0000,
-0.0078125,
-0.00390625,
0.0000,
]
.into_shape((1, 4, 1, 3))
.unwrap(),
epsilon = 1e-4,
);
// Inverted: seq_len - 1 = 2 is added, so distances become [0, 1, 2].
let inverted = AttentionLinearBiasesConfig::default()
.n_attention_heads(4)
.is_causal(true)
.is_inverted(true)
.build()
.unwrap();
assert_tensor_eq!(
&inverted
.forward(&Tensor::zeros((1, 4, 1, 3), DType::F32, &device).unwrap())
.unwrap(),
array![
0.0000f32, 0.2500, 0.5000, 0.0000, 0.0625, 0.1250, 0.0000, 0.015625, 0.03125,
0.0000, 0.00390625, 0.0078125,
]
.into_shape((1, 4, 1, 3))
.unwrap(),
epsilon = 1e-4,
);
}
// Non-causal biases for 4 heads and a 3 x 3 query/key grid.
#[test]
fn test_attention_linear_biases_non_causal() {
let device = Device::Cpu;
let non_causal = AttentionLinearBiasesConfig::default()
.n_attention_heads(4)
.build()
.unwrap();
// Distances are -|j - i|, scaled by each head's slope.
assert_tensor_eq!(
&non_causal
.forward(&Tensor::zeros((1, 4, 3, 3), DType::F32, &device).unwrap())
.unwrap(),
array![
0.0000f32,
-0.2500,
-0.5000,
-0.2500,
0.0000,
-0.2500,
-0.5000,
-0.2500,
0.0000,
0.0000,
-0.0625,
-0.1250,
-0.0625,
0.0000,
-0.0625,
-0.1250,
-0.0625,
0.0000,
0.0000,
-0.015625,
-0.03125,
-0.015625,
0.0000,
-0.015625,
-0.03125,
-0.015625,
0.0000,
0.0000,
-0.00390625,
-0.0078125,
-0.00390625,
0.0000,
-0.00390625,
-0.0078125,
-0.00390625,
0.0000,
]
.into_shape((1, 4, 3, 3))
.unwrap(),
epsilon = 1e-4,
);
// Inverted: seq_len - 1 = 2 is added, so distances become 2 - |j - i|.
let inverted = AttentionLinearBiasesConfig::default()
.n_attention_heads(4)
.is_inverted(true)
.build()
.unwrap();
assert_tensor_eq!(
&inverted
.forward(&Tensor::zeros((1, 4, 3, 3), DType::F32, &device).unwrap())
.unwrap(),
array![
0.5000f32, 0.2500, 0.0000, 0.2500, 0.5000, 0.2500, 0.0000, 0.2500, 0.5000, 0.1250,
0.0625, 0.0000, 0.0625, 0.1250, 0.0625, 0.0000, 0.0625, 0.1250, 0.03125, 0.015625,
0.0000, 0.015625, 0.03125, 0.015625, 0.0000, 0.015625, 0.03125, 0.0078125,
0.00390625, 0.0000, 0.00390625, 0.0078125, 0.00390625, 0.0000, 0.00390625,
0.0078125,
]
.into_shape((1, 4, 3, 3))
.unwrap(),
epsilon = 1e-4,
);
}
}