use crate::error::Result;
use crate::tensor::Tensor;
/// Builds a square `seq_len` x `seq_len` additive mask.
///
/// Every cell strictly above the main diagonal (column index greater than
/// row index) holds `-1e9`; the diagonal and everything below it hold `0.0`.
/// The buffer is laid out row by row and handed to [`Tensor::new`] with the
/// shape `[seq_len, seq_len]`.
///
/// If `Tensor::new` reports an error, the result of
/// `Tensor::zeros(&[seq_len, seq_len])` is returned instead.
pub fn causal_mask(seq_len: usize) -> Tensor {
    let dims = [seq_len, seq_len];
    let mut cells = vec![0.0; seq_len * seq_len];

    // `max(1)` keeps the chunk size non-zero when `seq_len == 0`; the buffer
    // is empty in that case, so no row is visited at all.
    for (row_idx, row) in cells.chunks_mut(seq_len.max(1)).enumerate() {
        // Everything to the right of the diagonal is masked out.
        row[row_idx + 1..].fill(-1e9);
    }

    Tensor::new(cells, dims.to_vec()).unwrap_or_else(|_| Tensor::zeros(&dims))
}
/// Applies an additive mask to `scores`.
///
/// This is a thin wrapper that forwards to `scores.add(mask)` and returns
/// that call's result unchanged.
///
/// # Errors
///
/// Returns whatever error `Tensor::add` produces.
/// NOTE(review): presumably this covers incompatible shapes between `scores`
/// and `mask` — confirm against the `Tensor::add` implementation.
pub fn apply_mask(scores: &Tensor, mask: &Tensor) -> Result<Tensor> {
scores.add(mask)
}
/// Builds a `lengths.len()` x `max_len` additive padding mask.
///
/// Row `b` holds `0.0` in its first `lengths[b]` columns and `-1e9` in every
/// column from `lengths[b]` up to `max_len`. A length that is greater than or
/// equal to `max_len` leaves its row entirely at `0.0`.
///
/// The buffer is laid out row by row and handed to [`Tensor::new`] with the
/// shape `[lengths.len(), max_len]`. If `Tensor::new` reports an error, the
/// result of `Tensor::zeros(&[lengths.len(), max_len])` is returned instead.
pub fn padding_mask(lengths: &[usize], max_len: usize) -> Tensor {
    let dims = [lengths.len(), max_len];
    let mut cells = vec![0.0; lengths.len() * max_len];

    // `max(1)` keeps the chunk size non-zero when `max_len == 0`; the buffer
    // is empty in that case, so no row is visited at all.
    for (row, &valid) in cells.chunks_mut(max_len.max(1)).zip(lengths) {
        // `get_mut` yields `None` when `valid` exceeds the row width, which
        // leaves such rows untouched.
        if let Some(padding) = row.get_mut(valid..) {
            padding.fill(-1e9);
        }
    }

    Tensor::new(cells, dims.to_vec()).unwrap_or_else(|_| Tensor::zeros(&dims))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A causal mask for length 4 reports a 4 x 4 shape.
    #[test]
    fn test_causal_mask_shape() {
        let causal = causal_mask(4);
        assert_eq!(causal.shape(), &[4, 4]);
    }

    /// The diagonal and lower triangle are zero; above the diagonal is masked.
    #[test]
    fn test_causal_mask_structure() {
        let causal = causal_mask(3);

        // Cells on or below the diagonal must stay at zero.
        for (row, col) in [(0, 0), (2, 0), (2, 2)] {
            let value = causal.get2d(row, col).unwrap();
            assert!(value.abs() < 1e-10);
        }

        // A cell above the diagonal carries the large negative fill value.
        let above_diagonal = causal.get2d(0, 1).unwrap();
        assert!(above_diagonal < -1e8);
    }

    /// The first row (length 2 of 4) is zero for two columns, then masked.
    #[test]
    fn test_padding_mask() {
        let padded = padding_mask(&[2, 3], 4);
        assert_eq!(padded.shape(), &[2, 4]);

        // Columns inside the valid length stay at zero.
        for col in [0, 1] {
            let value = padded.get2d(0, col).unwrap();
            assert!(value.abs() < 1e-10);
        }

        // The first padded column carries the large negative fill value.
        let first_padded = padded.get2d(0, 2).unwrap();
        assert!(first_padded < -1e8);
    }
}