use crate::layers::{LayerError, LayerPlan, LayerSpec};
/// Errors returned by the dense forward-pass entry points in this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ForwardError {
    /// The layer plan failed validation or does not report the sizes the pass needs.
    InvalidPlan,
    /// Buffer lengths, the batch size, or layer dimensions do not line up.
    ShapeMismatch,
    /// The caller-provided scratch buffer is shorter than the pass requires
    /// (or the required length overflows `usize`).
    ScratchTooSmall,
}

// `core::fmt` keeps this usable whether or not the crate links `std`.
impl core::fmt::Display for ForwardError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let msg = match self {
            ForwardError::InvalidPlan => "invalid layer plan",
            ForwardError::ShapeMismatch => "buffer or layer shape mismatch",
            ForwardError::ScratchTooSmall => "scratch buffer too small",
        };
        f.write_str(msg)
    }
}
/// Runs a single-sample forward pass through every dense layer of `plan`.
///
/// This is [`forward_dense_plan_big_kernel`] with a batch size of one: `input` must hold
/// exactly `plan.input_size()` values, `output` exactly `plan.output_size()` values, and
/// `scratch` at least [`required_batch_scratch_len`]`(plan, 1)` values.
///
/// # Errors
///
/// Exactly the errors [`forward_dense_plan_big_kernel`] returns for `batch_size == 1`.
pub fn forward_dense_plan(
    plan: &LayerPlan<'_>,
    input: &[f32],
    output: &mut [f32],
    scratch: &mut [f32],
) -> Result<(), ForwardError> {
    // The batched kernel performs plan validation, shape checks, and scratch sizing
    // itself. With a batch of one its length checks reduce to
    // `input.len() == input_size` and `output.len() == output_size`, so repeating them
    // here would only validate the plan twice and could classify errors differently.
    forward_dense_plan_big_kernel(plan, input, output, 1, scratch)
}
/// Number of `f32` scratch values [`forward_dense_plan_big_kernel`] needs for
/// `batch_size` rows: two buffers of `batch_size * max_width` values each.
///
/// Returns `None` when the plan reports no max width or the size overflows `usize`.
pub fn required_batch_scratch_len(plan: &LayerPlan<'_>, batch_size: usize) -> Option<usize> {
    plan.max_width()
        .and_then(|width| width.checked_mul(batch_size))
        .and_then(|lane| lane.checked_mul(2))
}
/// Runs a batched forward pass through every dense layer of `plan`.
///
/// `input_batch` holds `batch_size` rows of `plan.input_size()` values packed back to
/// back, and `output_batch` receives `batch_size` rows of `plan.output_size()` values in
/// the same packed layout. `scratch` must hold at least
/// [`required_batch_scratch_len`]`(plan, batch_size)` values. It is split into two buffers
/// of `batch_size * max_width` values that alternate as layer source and destination,
/// with every row stored at a fixed stride of `max_width`.
///
/// # Errors
///
/// * [`ForwardError::InvalidPlan`] if validation fails, the plan reports no input size,
///   output size, or max width, a layer's weight/bias range lies outside the plan's
///   parameter storage, or a row width exceeds the max width.
/// * [`ForwardError::ShapeMismatch`] if `batch_size` is zero, a buffer length does not
///   match `batch_size` rows, or consecutive layers (including the last layer and
///   `plan.output_size()`) do not chain.
/// * [`ForwardError::ScratchTooSmall`] if `scratch` is shorter than required or the
///   required length overflows `usize`.
pub fn forward_dense_plan_big_kernel(
    plan: &LayerPlan<'_>,
    input_batch: &[f32],
    output_batch: &mut [f32],
    batch_size: usize,
    scratch: &mut [f32],
) -> Result<(), ForwardError> {
    plan.validate().map_err(map_layer_error)?;
    if batch_size == 0 {
        return Err(ForwardError::ShapeMismatch);
    }
    let in_size = plan.input_size().ok_or(ForwardError::InvalidPlan)?;
    let out_size = plan.output_size().ok_or(ForwardError::InvalidPlan)?;
    let expected_in = batch_size.checked_mul(in_size).ok_or(ForwardError::ShapeMismatch)?;
    let expected_out = batch_size.checked_mul(out_size).ok_or(ForwardError::ShapeMismatch)?;
    if input_batch.len() != expected_in || output_batch.len() != expected_out {
        return Err(ForwardError::ShapeMismatch);
    }
    let max_width = plan.max_width().ok_or(ForwardError::InvalidPlan)?;
    // Rows live at a stride of `max_width`; a wider row would overwrite its neighbour
    // and eventually index past the buffer, so reject it instead of panicking.
    if in_size > max_width || out_size > max_width {
        return Err(ForwardError::InvalidPlan);
    }
    let lane = batch_size.checked_mul(max_width).ok_or(ForwardError::ScratchTooSmall)?;
    let needed = lane.checked_mul(2).ok_or(ForwardError::ScratchTooSmall)?;
    if scratch.len() < needed {
        return Err(ForwardError::ScratchTooSmall);
    }
    let (buf_a, rest) = scratch.split_at_mut(lane);
    let buf_b = &mut rest[..lane];

    // Scatter the packed input rows into the strided layout of `buf_a`.
    for b in 0..batch_size {
        let src_off = b * in_size;
        let dst_off = b * max_width;
        buf_a[dst_off..dst_off + in_size]
            .copy_from_slice(&input_batch[src_off..src_off + in_size]);
    }

    let mut cur_len = in_size;
    // `true` while the current activations live in `buf_a`.
    let mut in_a = true;
    for layer in plan.layers {
        match layer {
            LayerSpec::Dense(d) => {
                if cur_len != d.input_size {
                    return Err(ForwardError::ShapeMismatch);
                }
                if d.output_size > max_width {
                    return Err(ForwardError::InvalidPlan);
                }
                let w_len = d.weight_len().ok_or(ForwardError::InvalidPlan)?;
                // Checked lookups: a bad offset becomes an error rather than a slice
                // panic (or a wrapped `offset + len` in release builds).
                let w_end = d.weight_offset.checked_add(w_len).ok_or(ForwardError::InvalidPlan)?;
                let weights = plan
                    .weights
                    .get(d.weight_offset..w_end)
                    .ok_or(ForwardError::InvalidPlan)?;
                let b_end = d
                    .bias_offset
                    .checked_add(d.output_size)
                    .ok_or(ForwardError::InvalidPlan)?;
                let biases = plan
                    .biases
                    .get(d.bias_offset..b_end)
                    .ok_or(ForwardError::InvalidPlan)?;
                let params = DenseKernelParams { weights, biases, activation: d.activation };
                let (src, dst): (&[f32], &mut [f32]) = if in_a {
                    (&*buf_a, &mut *buf_b)
                } else {
                    (&*buf_b, &mut *buf_a)
                };
                dense_forward_batch_kernel(
                    src,
                    dst,
                    batch_size,
                    max_width,
                    cur_len,
                    d.output_size,
                    params,
                );
                cur_len = d.output_size;
                in_a = !in_a;
            }
        }
    }
    // The last layer must produce exactly the advertised output width; otherwise the
    // copy below would read stale scratch values.
    if cur_len != out_size {
        return Err(ForwardError::ShapeMismatch);
    }

    // Gather the strided result rows back into the packed output.
    let final_src: &[f32] = if in_a { &*buf_a } else { &*buf_b };
    for b in 0..batch_size {
        let src_off = b * max_width;
        let dst_off = b * out_size;
        output_batch[dst_off..dst_off + out_size]
            .copy_from_slice(&final_src[src_off..src_off + out_size]);
    }
    Ok(())
}
/// Borrowed parameters of a single dense layer, handed to the batch kernel.
struct DenseKernelParams<'p> {
    /// Weight matrix read row-major as `out_size` rows of `in_size` values.
    weights: &'p [f32],
    /// One bias per output unit, added before the activation.
    biases: &'p [f32],
    /// Activation applied to every biased output.
    activation: crate::activations::ActivationKind,
}
/// Applies one dense layer to every row of a strided batch.
///
/// Row `b` starts at `b * stride` in both `src` and `dst`. The first `in_size` values of
/// each `src` row are read, and the first `out_size` values of the matching `dst` row are
/// overwritten with
/// `activation(biases[o] + sum_i weights[o * in_size + i] * src[b * stride + i])`.
/// The weights are therefore read as an `out_size x in_size` row-major matrix.
///
/// # Panics
///
/// Panics if a slice is too short for the given sizes; callers validate this first.
fn dense_forward_batch_kernel(
    src: &[f32],
    dst: &mut [f32],
    batch_size: usize,
    stride: usize,
    in_size: usize,
    out_size: usize,
    params: DenseKernelParams,
) {
    let DenseKernelParams { weights, biases, activation } = params;
    for b in 0..batch_size {
        let base = b * stride;
        // Slice each row once so the inner loops run without per-element bounds checks.
        let x = &src[base..base + in_size];
        let y = &mut dst[base..base + out_size];
        for (o, (y_o, &bias)) in y.iter_mut().zip(&biases[..out_size]).enumerate() {
            let row = &weights[o * in_size..(o + 1) * in_size];
            // Start from the bias and accumulate in index order, so results are
            // bit-identical to a plain `acc += w * x` loop.
            let acc = row.iter().zip(x).fold(bias, |acc, (&w, &v)| acc + w * v);
            *y_o = activation.apply(acc);
        }
    }
}
/// Translates a plan-validation failure into the forward-pass error space.
///
/// Buffer-size problems become `ScratchTooSmall`, shape and chaining problems become
/// `ShapeMismatch`, and structural plan problems become `InvalidPlan`.
fn map_layer_error(e: LayerError) -> ForwardError {
    match e {
        LayerError::BufferTooSmall => ForwardError::ScratchTooSmall,
        LayerError::InvalidShape | LayerError::IncompatibleChain => ForwardError::ShapeMismatch,
        LayerError::EmptyPlan | LayerError::InvalidRange | LayerError::CountMismatch => {
            ForwardError::InvalidPlan
        }
    }
}