pub use native_neural_network::attention::AttentionMask;
pub use native_neural_network::attention::AttentionShape;
/// Error type reported by the std-facing attention helpers in this module.
///
/// Each variant corresponds one-to-one to the identically named variant of
/// `native_neural_network::attention::AttentionError`, from which it is
/// converted via `From`.
///
/// The enum is fieldless, so it is cheap to copy and can be compared directly
/// (e.g. in `assert_eq!` or `matches!`-free equality checks).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AttentionStdError {
    /// Converted from the backend's `AttentionError::ShapeMismatch`.
    ShapeMismatch,
    /// Converted from the backend's `AttentionError::BufferTooSmall`.
    BufferTooSmall,
    /// Converted from the backend's `AttentionError::InvalidDim`.
    InvalidDim,
}
/// Converts a backend attention error into the std-facing error type by
/// mapping every backend variant onto the variant with the same name.
impl From<native_neural_network::attention::AttentionError> for AttentionStdError {
    fn from(err: native_neural_network::attention::AttentionError) -> Self {
        // Function-local alias so the match arms stay short and readable.
        type Native = native_neural_network::attention::AttentionError;

        // Deliberately no wildcard arm: a new backend variant must be mapped
        // explicitly rather than being swallowed silently.
        match err {
            Native::ShapeMismatch => Self::ShapeMismatch,
            Native::BufferTooSmall => Self::BufferTooSmall,
            Native::InvalidDim => Self::InvalidDim,
        }
    }
}
/// Runs scaled dot-product attention by delegating to
/// `native_neural_network::attention::scaled_dot_product_attention_f32`.
///
/// Every argument is forwarded to the backend unchanged and in the same order;
/// this wrapper only adapts the error type.
///
/// # Parameters
///
/// * `q`, `k`, `v` - input `f32` buffers (presumably query, key and value;
///   the expected layout is defined by the backend - confirm there).
/// * `shape` - attention dimensions, passed through as-is.
/// * `out` - mutable buffer handed to the backend (presumably the result
///   destination).
/// * `scratch_scores` - mutable scratch buffer handed to the backend.
/// * `mask` - attention mask, passed through as-is.
///
/// # Errors
///
/// Any error returned by the backend is converted into the matching
/// [`AttentionStdError`] variant and returned.
pub fn scaled_dot_product_attention(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    shape: AttentionShape,
    out: &mut [f32],
    scratch_scores: &mut [f32],
    mask: AttentionMask,
) -> Result<(), AttentionStdError> {
    // `?` performs the backend-error -> `AttentionStdError` conversion.
    native_neural_network::attention::scaled_dot_product_attention_f32(
        q,
        k,
        v,
        shape,
        out,
        scratch_scores,
        mask,
    )?;
    Ok(())
}
/// Renders the error as `AttentionStdError::<Variant>`, where `<Variant>` is
/// the variant's `Debug` representation.
impl core::fmt::Display for AttentionStdError {
    fn fmt(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Type-name prefix followed by the `Debug` form of the variant.
        formatter.write_fmt(format_args!("AttentionStdError::{:?}", self))
    }
}
/// Marks [`AttentionStdError`] as a standard error type, so it can be boxed as
/// `Box<dyn std::error::Error>` and propagated with `?` by std-based callers.
///
/// No trait methods are overridden, so the trait's default implementations
/// apply.
impl std::error::Error for AttentionStdError {}