Skip to main content

rnn/inference/
inference.rs

1use crate::engine::{forward_dense_plan_big_kernel, required_batch_scratch_len, ForwardError};
2use crate::layers::LayerPlan;
3
/// Errors reported by the inference entry points in this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InferenceError {
    /// The layer plan could not report its input or output size, the engine
    /// rejected the plan, or `softmax_stable` obtained a non-finite or
    /// non-positive normalisation sum.
    InvalidPlan,
    /// A buffer length did not match what the operation requires (for
    /// `softmax_stable`: empty logits or `out` of a different length), or the
    /// engine reported a shape mismatch.
    ShapeMismatch,
    /// A batch buffer does not hold exactly `batch_size` samples, or
    /// `batch_size` times a layer width overflowed `usize`.
    BatchMismatch,
    /// The scratch buffer is shorter than required, the required length could
    /// not be computed, or the engine reported insufficient scratch space.
    ScratchTooSmall,
}

// Uses `core::fmt` so the impl compiles with or without `std`.
impl core::fmt::Display for InferenceError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            InferenceError::InvalidPlan => "invalid layer plan or non-finite numerical result",
            InferenceError::ShapeMismatch => "buffer length mismatch",
            InferenceError::BatchMismatch => "batch buffer length does not match batch size",
            InferenceError::ScratchTooSmall => "scratch buffer too small",
        };
        f.write_str(message)
    }
}
11
12pub fn softmax_stable(logits: &[f32], out: &mut [f32]) -> Result<(), InferenceError> {
13    if logits.is_empty() || out.len() != logits.len() {
14        return Err(InferenceError::ShapeMismatch);
15    }
16
17    let mut max_v = logits[0];
18    for value in logits.iter().skip(1) {
19        if *value > max_v {
20            max_v = *value;
21        }
22    }
23
24    let mut sum = 0.0f32;
25    for i in 0..logits.len() {
26        let e = crate::math::expf(logits[i] - max_v);
27        out[i] = e;
28        sum += e;
29    }
30
31    if !sum.is_finite() || sum <= 0.0 {
32        return Err(InferenceError::InvalidPlan);
33    }
34
35    let inv_sum = 1.0 / sum;
36    for value in out {
37        *value *= inv_sum;
38    }
39
40    Ok(())
41}
42
43pub fn forward_dense_batch(
44    plan: &LayerPlan<'_>,
45    input_batch: &[f32],
46    output_batch: &mut [f32],
47    batch_size: usize,
48    scratch_batch: &mut [f32],
49) -> Result<(), InferenceError> {
50    let input_size = plan.input_size().ok_or(InferenceError::InvalidPlan)?;
51    let output_size = plan.output_size().ok_or(InferenceError::InvalidPlan)?;
52
53    let expected_in = batch_size.checked_mul(input_size).ok_or(InferenceError::BatchMismatch)?;
54    let expected_out = batch_size.checked_mul(output_size).ok_or(InferenceError::BatchMismatch)?;
55
56    if input_batch.len() != expected_in || output_batch.len() != expected_out {
57        return Err(InferenceError::BatchMismatch);
58    }
59
60    let needed = required_batch_scratch_len(plan, batch_size).ok_or(InferenceError::ScratchTooSmall)?;
61    if scratch_batch.len() < needed {
62        return Err(InferenceError::ScratchTooSmall);
63    }
64
65    forward_dense_plan_big_kernel(plan, input_batch, output_batch, batch_size, scratch_batch)
66        .map_err(map_forward_error)?;
67
68    Ok(())
69}
70
71fn map_forward_error(err: ForwardError) -> InferenceError {
72    match err {
73        ForwardError::InvalidPlan => InferenceError::InvalidPlan,
74        ForwardError::ShapeMismatch => InferenceError::ShapeMismatch,
75        ForwardError::ScratchTooSmall => InferenceError::ScratchTooSmall,
76    }
77}