rnn/inference/
inference.rs1use crate::engine::{forward_dense_plan_big_kernel, required_batch_scratch_len, ForwardError};
2use crate::layers::LayerPlan;
3
/// Failure modes reported by the inference entry points in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InferenceError {
    /// The plan did not report an input or output size, the engine rejected
    /// the plan, or a softmax normaliser was not a finite positive number.
    InvalidPlan,
    /// A slice was empty or its length disagreed with its counterpart, or the
    /// engine reported a shape mismatch.
    ShapeMismatch,
    /// The batch buffers do not hold exactly `batch_size` rows of the plan's
    /// input/output size, or that element count overflowed `usize`.
    BatchMismatch,
    /// The scratch buffer is shorter than required, or the required length
    /// could not be determined.
    ScratchTooSmall,
}
11
12pub fn softmax_stable(logits: &[f32], out: &mut [f32]) -> Result<(), InferenceError> {
13 if logits.is_empty() || out.len() != logits.len() {
14 return Err(InferenceError::ShapeMismatch);
15 }
16
17 let mut max_v = logits[0];
18 for value in logits.iter().skip(1) {
19 if *value > max_v {
20 max_v = *value;
21 }
22 }
23
24 let mut sum = 0.0f32;
25 for i in 0..logits.len() {
26 let e = crate::math::expf(logits[i] - max_v);
27 out[i] = e;
28 sum += e;
29 }
30
31 if !sum.is_finite() || sum <= 0.0 {
32 return Err(InferenceError::InvalidPlan);
33 }
34
35 let inv_sum = 1.0 / sum;
36 for value in out {
37 *value *= inv_sum;
38 }
39
40 Ok(())
41}
42
43pub fn forward_dense_batch(
44 plan: &LayerPlan<'_>,
45 input_batch: &[f32],
46 output_batch: &mut [f32],
47 batch_size: usize,
48 scratch_batch: &mut [f32],
49) -> Result<(), InferenceError> {
50 let input_size = plan.input_size().ok_or(InferenceError::InvalidPlan)?;
51 let output_size = plan.output_size().ok_or(InferenceError::InvalidPlan)?;
52
53 let expected_in = batch_size.checked_mul(input_size).ok_or(InferenceError::BatchMismatch)?;
54 let expected_out = batch_size.checked_mul(output_size).ok_or(InferenceError::BatchMismatch)?;
55
56 if input_batch.len() != expected_in || output_batch.len() != expected_out {
57 return Err(InferenceError::BatchMismatch);
58 }
59
60 let needed = required_batch_scratch_len(plan, batch_size).ok_or(InferenceError::ScratchTooSmall)?;
61 if scratch_batch.len() < needed {
62 return Err(InferenceError::ScratchTooSmall);
63 }
64
65 forward_dense_plan_big_kernel(plan, input_batch, output_batch, batch_size, scratch_batch)
66 .map_err(map_forward_error)?;
67
68 Ok(())
69}
70
71fn map_forward_error(err: ForwardError) -> InferenceError {
72 match err {
73 ForwardError::InvalidPlan => InferenceError::InvalidPlan,
74 ForwardError::ShapeMismatch => InferenceError::ShapeMismatch,
75 ForwardError::ScratchTooSmall => InferenceError::ScratchTooSmall,
76 }
77}