Skip to main content

rnn/engine/
engine.rs

1use crate::layers::{LayerError, LayerPlan, LayerSpec};
2
/// Reasons a forward pass is rejected instead of being executed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ForwardError {
    /// The plan failed validation, or it could not report a size the forward
    /// pass needs (input size, output size, maximum width, weight length).
    InvalidPlan,
    /// A buffer length does not match the sizes the plan and batch size imply,
    /// the batch size is zero, or a layer's input size does not match the
    /// width produced by the previous step.
    ShapeMismatch,
    /// The scratch buffer is shorter than required, or the required length
    /// could not be computed.
    ScratchTooSmall,
}

// `core::fmt` rather than `std::fmt` so the impl is usable with or without `std`.
impl core::fmt::Display for ForwardError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            ForwardError::InvalidPlan => "invalid layer plan",
            ForwardError::ShapeMismatch => "buffer or layer shape mismatch",
            ForwardError::ScratchTooSmall => "scratch buffer too small",
        };
        f.write_str(message)
    }
}
9
10pub fn forward_dense_plan(plan: &LayerPlan<'_>, input: &[f32], output: &mut [f32], scratch: &mut [f32]) -> Result<(), ForwardError> {
11    plan.validate().map_err(map_layer_error)?;
12
13    let in_size = plan.input_size().ok_or(ForwardError::InvalidPlan)?;
14    let out_size = plan.output_size().ok_or(ForwardError::InvalidPlan)?;
15    if input.len() != in_size || output.len() != out_size {
16        return Err(ForwardError::ShapeMismatch);
17    }
18
19    let needed = required_batch_scratch_len(plan, 1).ok_or(ForwardError::ScratchTooSmall)?;
20    if scratch.len() < needed {
21        return Err(ForwardError::ScratchTooSmall);
22    }
23
24    forward_dense_plan_big_kernel(plan, input, output, 1, scratch)
25}
26
27pub fn required_batch_scratch_len(plan: &LayerPlan<'_>, batch_size: usize) -> Option<usize> {
28    let max_width = plan.max_width()?;
29    batch_size.checked_mul(max_width)?.checked_mul(2)
30}
31
32pub fn forward_dense_plan_big_kernel(
33    plan: &LayerPlan<'_>,
34    input_batch: &[f32],
35    output_batch: &mut [f32],
36    batch_size: usize,
37    scratch: &mut [f32],
38) -> Result<(), ForwardError> {
39    plan.validate().map_err(map_layer_error)?;
40    if batch_size == 0 {
41        return Err(ForwardError::ShapeMismatch);
42    }
43
44    let in_size = plan.input_size().ok_or(ForwardError::InvalidPlan)?;
45    let out_size = plan.output_size().ok_or(ForwardError::InvalidPlan)?;
46
47    let expected_in = batch_size.checked_mul(in_size).ok_or(ForwardError::ShapeMismatch)?;
48    let expected_out = batch_size.checked_mul(out_size).ok_or(ForwardError::ShapeMismatch)?;
49    if input_batch.len() != expected_in || output_batch.len() != expected_out {
50        return Err(ForwardError::ShapeMismatch);
51    }
52
53    let max_width = plan.max_width().ok_or(ForwardError::InvalidPlan)?;
54    let lane = batch_size.checked_mul(max_width).ok_or(ForwardError::ScratchTooSmall)?;
55    let needed = lane.checked_mul(2).ok_or(ForwardError::ScratchTooSmall)?;
56    if scratch.len() < needed {
57        return Err(ForwardError::ScratchTooSmall);
58    }
59
60    let (buf_a, buf_b) = scratch.split_at_mut(lane);
61
62    for b in 0..batch_size {
63        let src_off = b * in_size;
64        let dst_off = b * max_width;
65        buf_a[dst_off..dst_off + in_size].copy_from_slice(&input_batch[src_off..src_off + in_size]);
66    }
67
68    let mut cur_len = in_size;
69    let mut use_a_as_src = true;
70
71    for layer in plan.layers {
72        match layer {
73            LayerSpec::Dense(d) => {
74                if cur_len != d.input_size {
75                    return Err(ForwardError::ShapeMismatch);
76                }
77
78                let w_len = d.weight_len().ok_or(ForwardError::InvalidPlan)?;
79                let w = &plan.weights[d.weight_offset..d.weight_offset + w_len];
80                let b = &plan.biases[d.bias_offset..d.bias_offset + d.output_size];
81
82                if use_a_as_src {
83                    let src = &buf_a[..lane];
84                    let dst = &mut buf_b[..lane];
85                    dense_forward_batch_kernel(
86                        src,
87                        dst,
88                        batch_size,
89                        max_width,
90                        cur_len,
91                        d.output_size,
92                        w,
93                        b,
94                        d.activation,
95                    );
96                } else {
97                    let src = &buf_b[..lane];
98                    let dst = &mut buf_a[..lane];
99                    dense_forward_batch_kernel(
100                        src,
101                        dst,
102                        batch_size,
103                        max_width,
104                        cur_len,
105                        d.output_size,
106                        w,
107                        b,
108                        d.activation,
109                    );
110                }
111
112                cur_len = d.output_size;
113                use_a_as_src = !use_a_as_src;
114            }
115        }
116    }
117
118    let final_src = if use_a_as_src { &buf_a[..lane] } else { &buf_b[..lane] };
119    for b in 0..batch_size {
120        let src_off = b * max_width;
121        let dst_off = b * out_size;
122        output_batch[dst_off..dst_off + out_size].copy_from_slice(&final_src[src_off..src_off + out_size]);
123    }
124    Ok(())
125}
126
127fn dense_forward_batch_kernel(
128    src: &[f32],
129    dst: &mut [f32],
130    batch_size: usize,
131    stride: usize,
132    in_size: usize,
133    out_size: usize,
134    weights: &[f32],
135    biases: &[f32],
136    activation: crate::activations::ActivationKind,
137) {
138    let mut o = 0usize;
139    while o < out_size {
140        let row_off = o * in_size;
141        let mut b = 0usize;
142        while b < batch_size {
143            let base = b * stride;
144            let mut acc = biases[o];
145            let mut i = 0usize;
146            while i < in_size {
147                acc += weights[row_off + i] * src[base + i];
148                i += 1;
149            }
150            dst[base + o] = activation.apply(acc);
151            b += 1;
152        }
153        o += 1;
154    }
155}
156
157fn map_layer_error(_e: LayerError) -> ForwardError {
158    ForwardError::InvalidPlan
159}