/// Error returned by the batching wrappers in this module.
///
/// Each variant mirrors the same-named variant of
/// `native_neural_network::batching::BatchError`, from which it is converted.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BatchStdError {
    /// The upstream batching routine reported a shape mismatch.
    /// NOTE(review): presumably inputs vs. output buffer sizes — confirm upstream.
    ShapeMismatch,
    /// The upstream batching routine reported empty input.
    Empty,
}

// A public error type should be printable and usable as `dyn std::error::Error`
// (e.g. with `?` into `Box<dyn Error>`); fully qualified paths avoid new `use` lines.
impl std::fmt::Display for BatchStdError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            BatchStdError::ShapeMismatch => "shape mismatch",
            BatchStdError::Empty => "empty input",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for BatchStdError {}
/// Translates the upstream batching error into this module's [`BatchStdError`].
///
/// The mapping is one-to-one: every upstream variant becomes the variant of the
/// same name. The match is exhaustive, so a new upstream variant fails to compile here.
impl From<native_neural_network::batching::BatchError> for BatchStdError {
    fn from(e: native_neural_network::batching::BatchError) -> Self {
        use native_neural_network::batching::BatchError as Upstream;
        match e {
            Upstream::Empty => Self::Empty,
            Upstream::ShapeMismatch => Self::ShapeMismatch,
        }
    }
}
pub fn pad_sequences_u32(
sequences: &[&[u32]],
pad_id: u32,
out: &mut [u32],
max_len: usize,
) -> Result<(), BatchStdError> {
native_neural_network::batching::pad_sequences_u32(sequences, pad_id, out, max_len)
.map_err(|e| e.into())
}
/// Fills `out_mask` from `tokens` and `pad_id`, forwarding to
/// `native_neural_network::batching::make_padding_mask`.
///
/// NOTE(review): presumably one mask byte per token, distinguishing `pad_id` from
/// real tokens; the exact encoding (0/1 and which is which) should be confirmed upstream.
///
/// # Errors
/// Returns the upstream `BatchError`, converted into [`BatchStdError`].
pub fn make_padding_mask(
    tokens: &[u32],
    pad_id: u32,
    out_mask: &mut [u8],
) -> Result<(), BatchStdError> {
    // `?` performs the `BatchError -> BatchStdError` conversion via `From`.
    native_neural_network::batching::make_padding_mask(tokens, pad_id, out_mask)?;
    Ok(())
}
pub fn for_each_token_row(tokens: &[u32], row_width: usize, f: impl FnMut(usize, &[u32])) -> bool {
native_neural_network::batching::for_each_token_row(tokens, row_width, f)
}
pub fn count_non_pad(tokens: &[u32], pad_id: u32) -> usize {
native_neural_network::batching::count_non_pad(tokens, pad_id)
}
pub fn max_sequence_len(sequences: &[&[u32]]) -> Option<usize> {
native_neural_network::batching::max_sequence_len(sequences)
}
/// Writes the length of each sequence in `sequences` into `out`, forwarding to
/// `native_neural_network::batching::sequence_lengths`.
///
/// The returned `bool` is the only failure signal, so it is `#[must_use]`: dropping it
/// could leave a caller reading an `out` buffer the upstream call did not fill.
/// NOTE(review): presumably `false` means `out` is too short for `sequences` — confirm upstream.
#[must_use = "the returned flag is the only failure signal; ignoring it may leave `out` unchecked"]
pub fn sequence_lengths(sequences: &[&[u32]], out: &mut [usize]) -> bool {
    native_neural_network::batching::sequence_lengths(sequences, out)
}