use crate::engine::{forward_dense_plan_big_kernel, required_batch_scratch_len, ForwardError};
use crate::layers::LayerPlan;
/// Errors surfaced by the public inference entry points in this module.
///
/// Engine-level `ForwardError`s are translated into this type by
/// `map_forward_error` so callers only deal with one error enum.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InferenceError {
/// The layer plan is unusable (e.g. missing sizes) or produced a
/// non-finite/non-positive softmax normalizer.
InvalidPlan,
/// Input/output buffer lengths disagree with each other or are empty.
ShapeMismatch,
/// Batch buffer lengths do not match `batch_size * per_sample_size`,
/// or that product overflows `usize`.
BatchMismatch,
/// The caller-provided scratch buffer is smaller than required.
ScratchTooSmall,
}
/// Computes a numerically stable softmax of `logits` into `out`.
///
/// The running maximum is subtracted before exponentiation so that large
/// logits cannot overflow `expf`.
///
/// # Errors
/// * `ShapeMismatch` — `logits` is empty or `out` has a different length.
/// * `InvalidPlan` — the exponential sum is non-finite or non-positive
///   (e.g. NaN inputs), so no valid distribution exists.
pub fn softmax_stable(logits: &[f32], out: &mut [f32]) -> Result<(), InferenceError> {
    if logits.is_empty() || out.len() != logits.len() {
        return Err(InferenceError::ShapeMismatch);
    }
    // Fold with the original `>` comparison so NaN/tie handling matches a
    // manual scan exactly (first element wins unless strictly exceeded).
    let max_v = logits
        .iter()
        .skip(1)
        .fold(logits[0], |best, &v| if v > best { v } else { best });
    // Exponentiate the shifted logits, accumulating the normalizer as we go.
    let mut sum = 0.0f32;
    for (slot, &logit) in out.iter_mut().zip(logits.iter()) {
        let e = crate::math::expf(logit - max_v);
        *slot = e;
        sum += e;
    }
    if !sum.is_finite() || sum <= 0.0 {
        return Err(InferenceError::InvalidPlan);
    }
    // Multiply by the reciprocal once rather than dividing per element.
    let inv_sum = 1.0 / sum;
    out.iter_mut().for_each(|p| *p *= inv_sum);
    Ok(())
}
/// Runs a dense-layer forward pass over a whole batch in one engine call.
///
/// Validates all buffer geometry up front, then delegates to
/// `forward_dense_plan_big_kernel`, translating its error type via
/// `map_forward_error`.
///
/// # Errors
/// * `InvalidPlan` — the plan does not report input/output sizes.
/// * `BatchMismatch` — `batch_size * size` overflows `usize`, or the
///   input/output buffer lengths do not match the expected totals.
/// * `ScratchTooSmall` — the scratch requirement cannot be computed or
///   `scratch_batch` is shorter than required.
pub fn forward_dense_batch(
    plan: &LayerPlan<'_>,
    input_batch: &[f32],
    output_batch: &mut [f32],
    batch_size: usize,
    scratch_batch: &mut [f32],
) -> Result<(), InferenceError> {
    // A plan that cannot report its geometry is unusable.
    let input_size = plan.input_size().ok_or(InferenceError::InvalidPlan)?;
    let output_size = plan.output_size().ok_or(InferenceError::InvalidPlan)?;
    // checked_mul guards against usize overflow for pathological batch sizes.
    let expected_in = batch_size
        .checked_mul(input_size)
        .ok_or(InferenceError::BatchMismatch)?;
    let expected_out = batch_size
        .checked_mul(output_size)
        .ok_or(InferenceError::BatchMismatch)?;
    if input_batch.len() != expected_in || output_batch.len() != expected_out {
        return Err(InferenceError::BatchMismatch);
    }
    // `None` from the scratch sizing helper is treated the same as an
    // undersized buffer: we cannot prove the scratch is large enough.
    let needed =
        required_batch_scratch_len(plan, batch_size).ok_or(InferenceError::ScratchTooSmall)?;
    if scratch_batch.len() < needed {
        return Err(InferenceError::ScratchTooSmall);
    }
    // Return the mapped result directly instead of `?` followed by `Ok(())`.
    forward_dense_plan_big_kernel(plan, input_batch, output_batch, batch_size, scratch_batch)
        .map_err(map_forward_error)
}
fn map_forward_error(err: ForwardError) -> InferenceError {
match err {
ForwardError::InvalidPlan => InferenceError::InvalidPlan,
ForwardError::ShapeMismatch => InferenceError::ShapeMismatch,
ForwardError::ScratchTooSmall => InferenceError::ScratchTooSmall,
}
}