// NOTE(review): the exact condition that triggers each variant is decided upstream in
// `native_neural_network::sampling` — confirm there before documenting it more precisely.
/// Error type returned by the `std`-facing sampling wrappers in this file.
///
/// Each variant mirrors the identically named variant of
/// `native_neural_network::sampling::SamplingError`, and the `From` conversion
/// maps them one-to-one. The type is a plain fieldless enum, so it is cheap to
/// copy, compare, and hash.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SamplingStdError {
    /// Mirrors upstream `SamplingError::Empty` (presumably an empty input slice).
    Empty,
    /// Mirrors upstream `SamplingError::ShapeMismatch` (presumably mismatched slice lengths).
    ShapeMismatch,
    /// Mirrors upstream `SamplingError::InvalidParameter` (presumably an out-of-range
    /// argument such as a temperature, `k`, or `p`).
    InvalidParameter,
}
/// Converts the upstream sampling error into this crate's `SamplingStdError`.
///
/// The mapping is variant-for-variant with identical names, and nothing is lost.
/// Because of this impl, `?` and `.into()` can be used on upstream results.
impl From<native_neural_network::sampling::SamplingError> for SamplingStdError {
    fn from(e: native_neural_network::sampling::SamplingError) -> Self {
        use native_neural_network::sampling::SamplingError as Upstream;

        // Exhaustive on purpose: if upstream adds a variant, this stops compiling
        // instead of silently collapsing it into an existing one.
        match e {
            Upstream::Empty => Self::Empty,
            Upstream::ShapeMismatch => Self::ShapeMismatch,
            Upstream::InvalidParameter => Self::InvalidParameter,
        }
    }
}
/// Temperature-scaled softmax over `logits`, written into `out`.
///
/// This is a thin delegate to `native_neural_network::sampling::softmax_temperature_f32`.
/// `out` is borrowed mutably and is filled by the upstream routine.
///
/// # Errors
///
/// Any upstream `SamplingError` is converted to the matching `SamplingStdError` variant.
// NOTE(review): the length/temperature requirements are enforced upstream; presumably
// `out.len()` must equal `logits.len()` and `temperature` must be positive — confirm.
pub fn softmax_temperature(
    logits: &[f32],
    temperature: f32,
    out: &mut [f32],
) -> Result<(), SamplingStdError> {
    // `?` routes the upstream error through `From<SamplingError>`.
    native_neural_network::sampling::softmax_temperature_f32(logits, temperature, out)?;
    Ok(())
}
pub fn argmax_sample(probabilities: &[f32]) -> Result<usize, SamplingStdError> {
native_neural_network::sampling::argmax_sample_f32(probabilities).map_err(|e| e.into())
}
/// Picks an index from `probabilities` using a cumulative-sum `threshold`.
///
/// This delegates to `native_neural_network::sampling::sample_from_cumulative_f32`.
///
/// # Errors
///
/// Any upstream `SamplingError` is converted to the matching `SamplingStdError` variant.
// NOTE(review): presumably `threshold` is a uniform draw in [0, 1) compared against the
// running sum of `probabilities` — confirm against upstream docs.
pub fn sample_from_cumulative(
    probabilities: &[f32],
    threshold: f32,
) -> Result<usize, SamplingStdError> {
    let index =
        native_neural_network::sampling::sample_from_cumulative_f32(probabilities, threshold)?;
    Ok(index)
}
/// Applies upstream top-`k` masking to `logits` in place.
///
/// This delegates to `native_neural_network::sampling::top_k_mask_f32`. `logits` is
/// borrowed mutably, so the upstream routine edits it directly.
///
/// # Errors
///
/// Any upstream `SamplingError` is converted to the matching `SamplingStdError` variant.
// NOTE(review): presumably entries outside the `k` largest are overwritten with
// `mask_value` (e.g. `f32::NEG_INFINITY`) — confirm upstream.
pub fn top_k_mask(logits: &mut [f32], k: usize, mask_value: f32) -> Result<(), SamplingStdError> {
    let upstream = native_neural_network::sampling::top_k_mask_f32(logits, k, mask_value);
    upstream.map_err(Into::into)
}
pub fn top_p_cutoff(probabilities: &[f32], p: f32) -> Result<usize, SamplingStdError> {
native_neural_network::sampling::top_p_cutoff_f32(probabilities, p).map_err(|e| e.into())
}
/// Human-readable form used by `to_string()` and `{}` formatting.
///
/// The output is `SamplingStdError::` followed by the variant's `Debug` name,
/// e.g. `SamplingStdError::Empty`.
impl core::fmt::Display for SamplingStdError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_fmt(format_args!("SamplingStdError::{:?}", self))
    }
}
impl std::error::Error for SamplingStdError {}