use burn::prelude::*;
/// Builds an additive causal (look-ahead) mask of shape `[batch_size, seq_len, seq_len]`.
///
/// Entry `[b, i, j]` is `-1e9` when `j > i` (strictly above the main diagonal)
/// and zero otherwise. Every batch slice is an identical copy of the same
/// `seq_len x seq_len` matrix.
///
/// NOTE(review): presumably meant to be added to attention scores before a
/// softmax so that query `i` cannot attend to later keys — confirm at call sites.
pub fn causal_mask<B: Backend>(
    batch_size: usize,
    seq_len: usize,
    device: &B::Device,
) -> Tensor<B, 3> {
    // Finite large-negative fill value used for masked positions.
    const MASKED_FILL: f64 = -1e9;

    let ones = Tensor::<B, 2>::ones([seq_len, seq_len], device);
    // `triu(1)` zeroes the main diagonal and everything below it, leaving 1s
    // only where col > row.
    let future_positions = ones.triu(1);
    let square_mask = future_positions.mul_scalar(MASKED_FILL);

    // Prepend a batch axis of size 1, then tile it `batch_size` times.
    let batched: Tensor<B, 3> = square_mask.unsqueeze_dim::<3>(0);
    batched.repeat_dim(0, batch_size)
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type TestBackend = NdArray;

    /// Masked entries are filled with -1e9; anything below this threshold is
    /// treated as masked.
    const MASKED_THRESHOLD: f32 = -1e8;

    /// Asserts that `data`, a row-major flattening of a
    /// `[batch_size, seq_len, seq_len]` mask, is large-negative exactly where
    /// `col > row` and zero everywhere else, in every batch slice.
    fn assert_causal(data: &[f32], batch_size: usize, seq_len: usize) {
        assert_eq!(data.len(), batch_size * seq_len * seq_len);
        for (idx, &value) in data.iter().enumerate() {
            let row = (idx / seq_len) % seq_len;
            let col = idx % seq_len;
            if col > row {
                assert!(
                    value < MASKED_THRESHOLD,
                    "entry {idx} (row {row}, col {col}) should be masked, got {value}"
                );
            } else {
                // Unmasked entries come out as -0.0 (0 * -1e9), which compares equal to 0.0.
                assert_eq!(
                    value, 0.0,
                    "entry {idx} (row {row}, col {col}) should be unmasked"
                );
            }
        }
    }

    #[test]
    fn test_causal_mask_shape() {
        let device = Default::default();
        let mask = causal_mask::<TestBackend>(2, 8, &device);
        assert_eq!(mask.dims(), [2, 8, 8]);
    }

    #[test]
    fn test_causal_mask_values() {
        let device = Default::default();
        let mask = causal_mask::<TestBackend>(1, 4, &device);
        // `TensorData::to_vec` is already flat (row-major), so no reshape is needed.
        let data: Vec<f32> = mask.into_data().to_vec().unwrap();
        assert_causal(&data, 1, 4);
    }

    #[test]
    fn test_causal_mask_repeated_across_batch() {
        let device = Default::default();
        let mask = causal_mask::<TestBackend>(3, 5, &device);
        let data: Vec<f32> = mask.into_data().to_vec().unwrap();
        assert_causal(&data, 3, 5);
    }
}