1use crate::layers::{LayerError, LayerPlan, LayerSpec};
2
/// Errors reported by the dense forward-pass entry points in this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ForwardError {
    /// The layer plan failed validation or is missing shape information the
    /// forward pass needs (such as its input or output size).
    InvalidPlan,
    /// A caller-supplied buffer length (or the batch size) does not match the
    /// shapes implied by the plan.
    ShapeMismatch,
    /// The scratch buffer is shorter than the forward pass requires.
    ScratchTooSmall,
}

impl std::fmt::Display for ForwardError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            ForwardError::InvalidPlan => "invalid layer plan",
            ForwardError::ShapeMismatch => "buffer shape does not match the layer plan",
            ForwardError::ScratchTooSmall => "scratch buffer too small",
        };
        f.write_str(msg)
    }
}

// Lets callers propagate forward-pass failures through `Box<dyn Error>`,
// `anyhow`, or any other `std::error::Error`-based error chain.
impl std::error::Error for ForwardError {}
9
10pub fn forward_dense_plan(plan: &LayerPlan<'_>, input: &[f32], output: &mut [f32], scratch: &mut [f32]) -> Result<(), ForwardError> {
11 plan.validate().map_err(map_layer_error)?;
12
13 let in_size = plan.input_size().ok_or(ForwardError::InvalidPlan)?;
14 let out_size = plan.output_size().ok_or(ForwardError::InvalidPlan)?;
15 if input.len() != in_size || output.len() != out_size {
16 return Err(ForwardError::ShapeMismatch);
17 }
18
19 let needed = required_batch_scratch_len(plan, 1).ok_or(ForwardError::ScratchTooSmall)?;
20 if scratch.len() < needed {
21 return Err(ForwardError::ScratchTooSmall);
22 }
23
24 forward_dense_plan_big_kernel(plan, input, output, 1, scratch)
25}
26
27pub fn required_batch_scratch_len(plan: &LayerPlan<'_>, batch_size: usize) -> Option<usize> {
28 let max_width = plan.max_width()?;
29 batch_size.checked_mul(max_width)?.checked_mul(2)
30}
31
32pub fn forward_dense_plan_big_kernel(
33 plan: &LayerPlan<'_>,
34 input_batch: &[f32],
35 output_batch: &mut [f32],
36 batch_size: usize,
37 scratch: &mut [f32],
38) -> Result<(), ForwardError> {
39 plan.validate().map_err(map_layer_error)?;
40 if batch_size == 0 {
41 return Err(ForwardError::ShapeMismatch);
42 }
43
44 let in_size = plan.input_size().ok_or(ForwardError::InvalidPlan)?;
45 let out_size = plan.output_size().ok_or(ForwardError::InvalidPlan)?;
46
47 let expected_in = batch_size.checked_mul(in_size).ok_or(ForwardError::ShapeMismatch)?;
48 let expected_out = batch_size.checked_mul(out_size).ok_or(ForwardError::ShapeMismatch)?;
49 if input_batch.len() != expected_in || output_batch.len() != expected_out {
50 return Err(ForwardError::ShapeMismatch);
51 }
52
53 let max_width = plan.max_width().ok_or(ForwardError::InvalidPlan)?;
54 let lane = batch_size.checked_mul(max_width).ok_or(ForwardError::ScratchTooSmall)?;
55 let needed = lane.checked_mul(2).ok_or(ForwardError::ScratchTooSmall)?;
56 if scratch.len() < needed {
57 return Err(ForwardError::ScratchTooSmall);
58 }
59
60 let (buf_a, buf_b) = scratch.split_at_mut(lane);
61
62 for b in 0..batch_size {
63 let src_off = b * in_size;
64 let dst_off = b * max_width;
65 buf_a[dst_off..dst_off + in_size].copy_from_slice(&input_batch[src_off..src_off + in_size]);
66 }
67
68 let mut cur_len = in_size;
69 let mut use_a_as_src = true;
70
71 for layer in plan.layers {
72 match layer {
73 LayerSpec::Dense(d) => {
74 if cur_len != d.input_size {
75 return Err(ForwardError::ShapeMismatch);
76 }
77
78 let w_len = d.weight_len().ok_or(ForwardError::InvalidPlan)?;
79 let w = &plan.weights[d.weight_offset..d.weight_offset + w_len];
80 let b = &plan.biases[d.bias_offset..d.bias_offset + d.output_size];
81
82 if use_a_as_src {
83 let src = &buf_a[..lane];
84 let dst = &mut buf_b[..lane];
85 dense_forward_batch_kernel(
86 src,
87 dst,
88 batch_size,
89 max_width,
90 cur_len,
91 d.output_size,
92 w,
93 b,
94 d.activation,
95 );
96 } else {
97 let src = &buf_b[..lane];
98 let dst = &mut buf_a[..lane];
99 dense_forward_batch_kernel(
100 src,
101 dst,
102 batch_size,
103 max_width,
104 cur_len,
105 d.output_size,
106 w,
107 b,
108 d.activation,
109 );
110 }
111
112 cur_len = d.output_size;
113 use_a_as_src = !use_a_as_src;
114 }
115 }
116 }
117
118 let final_src = if use_a_as_src { &buf_a[..lane] } else { &buf_b[..lane] };
119 for b in 0..batch_size {
120 let src_off = b * max_width;
121 let dst_off = b * out_size;
122 output_batch[dst_off..dst_off + out_size].copy_from_slice(&final_src[src_off..src_off + out_size]);
123 }
124 Ok(())
125}
126
127fn dense_forward_batch_kernel(
128 src: &[f32],
129 dst: &mut [f32],
130 batch_size: usize,
131 stride: usize,
132 in_size: usize,
133 out_size: usize,
134 weights: &[f32],
135 biases: &[f32],
136 activation: crate::activations::ActivationKind,
137) {
138 let mut o = 0usize;
139 while o < out_size {
140 let row_off = o * in_size;
141 let mut b = 0usize;
142 while b < batch_size {
143 let base = b * stride;
144 let mut acc = biases[o];
145 let mut i = 0usize;
146 while i < in_size {
147 acc += weights[row_off + i] * src[base + i];
148 i += 1;
149 }
150 dst[base + o] = activation.apply(acc);
151 b += 1;
152 }
153 o += 1;
154 }
155}
156
157fn map_layer_error(_e: LayerError) -> ForwardError {
158 ForwardError::InvalidPlan
159}