use candle_core::Tensor;
use snafu::{ensure, ResultExt, Snafu};
#[derive(Debug, Snafu)]
pub enum AttentionMaskError {
#[snafu(display("Cannot concatenate masks"))]
ConcatMasks { source: candle_core::Error },
#[snafu(display("Attention mask must be 2D, was {}D", n_dims))]
InvalidDims { n_dims: usize },
}
/// A two-dimensional boolean attention mask.
///
/// Instances built through [`AttentionMask::new`] are guaranteed to wrap a
/// 2-D tensor. The field is crate-visible, so code inside this crate can
/// still construct the struct directly without that check.
#[derive(Clone, Debug)]
pub struct AttentionMask {
    /// The wrapped mask tensor. Presumably laid out as (batch, seq) with a
    /// boolean-like dtype — NOTE(review): confirm against callers.
    pub(crate) bool_mask: Tensor,
}
impl AttentionMask {
    /// Wraps `bool_mask` as an attention mask after checking its rank.
    ///
    /// # Errors
    ///
    /// Returns [`AttentionMaskError::InvalidDims`] when `bool_mask` does not
    /// have exactly two dimensions.
    pub fn new(bool_mask: Tensor) -> Result<Self, AttentionMaskError> {
        let rank = bool_mask.dims().len();
        ensure!(rank == 2, InvalidDimsSnafu { n_dims: rank });

        Ok(Self { bool_mask })
    }

    /// Borrows the underlying mask tensor.
    pub fn bool_mask(&self) -> &Tensor {
        &self.bool_mask
    }

    /// Returns a new mask with `other` appended after `self` along dimension 1.
    ///
    /// Neither input is modified.
    ///
    /// # Errors
    ///
    /// Returns [`AttentionMaskError::ConcatMasks`] if `candle_core` cannot
    /// concatenate the two tensors (for example, when their sizes along
    /// dimension 0 differ).
    pub fn extend(&self, other: &Self) -> Result<Self, AttentionMaskError> {
        // Order matters: `self` comes first, `other` second, along dim 1.
        let pieces = [&self.bool_mask, &other.bool_mask];
        let joined = Tensor::cat(&pieces, 1).context(ConcatMasksSnafu)?;

        Ok(Self { bool_mask: joined })
    }
}