/// Errors reported by the beam-search wrappers in this module.
///
/// Each variant mirrors a variant of
/// `native_neural_network::beam_search::BeamError`, and the native error is
/// converted into this type via the `From` impl for `BeamStdError`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BeamStdError {
    /// Mirrors `BeamError::ShapeMismatch` from the native crate.
    ShapeMismatch,
    /// Mirrors `BeamError::Empty` from the native crate.
    Empty,
}

// A public error type should be printable and usable as `dyn Error`
// (e.g. boxed, or propagated with `?` into `Box<dyn Error>`).
impl std::fmt::Display for BeamStdError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            BeamStdError::ShapeMismatch => "beam search error: shape mismatch",
            BeamStdError::Empty => "beam search error: empty input",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for BeamStdError {}
/// Translates the native crate's beam-search error into [`BeamStdError`],
/// one variant to the matching variant. Because this is a `From` impl,
/// `?` performs the conversion automatically in functions returning
/// `Result<_, BeamStdError>`.
impl From<native_neural_network::beam_search::BeamError> for BeamStdError {
    fn from(err: native_neural_network::beam_search::BeamError) -> Self {
        match err {
            native_neural_network::beam_search::BeamError::Empty => Self::Empty,
            native_neural_network::beam_search::BeamError::ShapeMismatch => Self::ShapeMismatch,
        }
    }
}
/// Selects the top beams from `scores` by delegating to
/// `native_neural_network::beam_search::select_top_beams` with `beam_width`
/// and the caller-provided `out_indices` buffer.
///
/// On success, returns the `usize` produced by the native call unchanged.
/// NOTE(review): presumably this is the number of indices written into
/// `out_indices` — confirm against the upstream crate.
///
/// # Errors
///
/// Any error from the native call is converted into the corresponding
/// [`BeamStdError`] variant.
pub fn select_top_beams(
    scores: &[f32],
    beam_width: usize,
    out_indices: &mut [usize],
) -> Result<usize, BeamStdError> {
    // `?` routes the native error through `From`, same as an explicit `into()`.
    let selected =
        native_neural_network::beam_search::select_top_beams(scores, beam_width, out_indices)?;
    Ok(selected)
}
/// Prunes `scores` against `threshold` by delegating to
/// `native_neural_network::beam_search::prune_by_threshold`, writing into the
/// caller-provided `out_indices` buffer.
///
/// Returns the native call's `usize` result unchanged; this wrapper performs no
/// validation of its own. NOTE(review): presumably the result is the number of
/// surviving indices written to `out_indices` — confirm upstream.
pub fn prune_by_threshold(scores: &[f32], threshold: f32, out_indices: &mut [usize]) -> usize {
    let (src, dst) = (scores, out_indices);
    native_neural_network::beam_search::prune_by_threshold(src, threshold, dst)
}
/// Applies log-softmax to `scores` in place by delegating to
/// `native_neural_network::beam_search::log_softmax_in_place_f32`.
///
/// Returns the native call's `bool` unchanged. NOTE(review): presumably `true`
/// means the transform was applied and `false` means it was not — confirm the
/// exact contract upstream.
pub fn log_softmax_in_place(scores: &mut [f32]) -> bool {
    let values: &mut [f32] = scores;
    native_neural_network::beam_search::log_softmax_in_place_f32(values)
}