use crate::layers::{LayerError, LayerPlan, LayerPlanF64, LayerSpec};
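/// Errors returned by the forward-pass entry points in this module.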
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ForwardError {
InvalidPlan,
ShapeMismatch,
ScratchTooSmall,
}
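/// Per-layer weight/bias slices plus the activation, as consumed by the
/// f64 kernels below.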
struct KernelParamsF64<'a> {
weights: &'a [f64],
biases: &'a [f64],
activation: crate::activations::ActivationKind,
}
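/// Runs a single-sample forward pass through `plan` in f32.
///
/// `input` and `output` must match the plan's input and output sizes, and
/// `scratch` must hold at least `required_batch_scratch_len_f32(plan, 1)`
/// elements; it serves as ping-pong working space between layers.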
pub fn forward_plan_f32(
plan: &LayerPlan<'_>,
input: &[f32],
output: &mut [f32],
scratch: &mut [f32],
) -> Result<(), ForwardError> {
forward_plan_f32_impl(plan, input, output, scratch)
}
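/// Convenience alias for [`forward_plan_f32`].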
pub fn forward_plan(
plan: &LayerPlan<'_>,
input: &[f32],
output: &mut [f32],
scratch: &mut [f32],
) -> Result<(), ForwardError> {
forward_plan_f32(plan, input, output, scratch)
}
fn forward_plan_f32_impl(
plan: &LayerPlan<'_>,
input: &[f32],
output: &mut [f32],
scratch: &mut [f32],
) -> Result<(), ForwardError> {
plan.validate().map_err(map_layer_error)?;
let in_size = plan.input_size().ok_or(ForwardError::InvalidPlan)?;
let out_size = plan.output_size().ok_or(ForwardError::InvalidPlan)?;
if input.len() != in_size || output.len() != out_size {
return Err(ForwardError::ShapeMismatch);
}
let needed = required_batch_scratch_len_f32(plan, 1).ok_or(ForwardError::ScratchTooSmall)?;
if scratch.len() < needed {
return Err(ForwardError::ScratchTooSmall);
}
forward_plan_kernel_f32(plan, input, output, 1, scratch)
}
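/// f64 counterpart of [`forward_plan_f32`].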
pub fn forward_plan_f64(
plan: &LayerPlanF64<'_>,
input: &[f64],
output: &mut [f64],
scratch: &mut [f64],
) -> Result<(), ForwardError> {
plan.validate().map_err(map_layer_error)?;
let in_size = plan.input_size().ok_or(ForwardError::InvalidPlan)?;
let out_size = plan.output_size().ok_or(ForwardError::InvalidPlan)?;
if input.len() != in_size || output.len() != out_size {
return Err(ForwardError::ShapeMismatch);
}
let needed = required_batch_scratch_len_f64(plan, 1).ok_or(ForwardError::ScratchTooSmall)?;
if scratch.len() < needed {
return Err(ForwardError::ScratchTooSmall);
}
forward_plan_kernel_f64(plan, input, output, 1, scratch)
}
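/// Scratch length (in elements) required by the f32 kernels: two ping-pong
/// lanes of `batch_size * max_width` each. Returns `None` if the plan
/// reports no width or the computation overflows.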
pub fn required_batch_scratch_len_f32(plan: &LayerPlan<'_>, batch_size: usize) -> Option<usize> {
let max_width = plan.max_width()?;
batch_size.checked_mul(max_width)?.checked_mul(2)
}
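/// Convenience alias for [`required_batch_scratch_len_f32`].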
pub fn required_batch_scratch_len(plan: &LayerPlan<'_>, batch_size: usize) -> Option<usize> {
required_batch_scratch_len_f32(plan, batch_size)
}
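/// f64 counterpart of [`required_batch_scratch_len_f32`].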
pub fn required_batch_scratch_len_f64(plan: &LayerPlanF64<'_>, batch_size: usize) -> Option<usize> {
let max_width = plan.max_width()?;
batch_size.checked_mul(max_width)?.checked_mul(2)
}
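/// Batched f32 forward pass.
///
/// `input_batch` packs `batch_size` rows of `input_size` elements
/// contiguously; `output_batch` packs `batch_size` rows of `output_size`
/// elements. Internally each sample occupies a lane of `max_width`
/// elements, and the two halves of `scratch` are ping-ponged between
/// layers.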
pub fn forward_plan_kernel_f32(
plan: &LayerPlan<'_>,
input_batch: &[f32],
output_batch: &mut [f32],
batch_size: usize,
scratch: &mut [f32],
) -> Result<(), ForwardError> {
plan.validate().map_err(map_layer_error)?;
if batch_size == 0 {
return Err(ForwardError::ShapeMismatch);
}
let in_size = plan.input_size().ok_or(ForwardError::InvalidPlan)?;
let out_size = plan.output_size().ok_or(ForwardError::InvalidPlan)?;
let expected_in = batch_size
.checked_mul(in_size)
.ok_or(ForwardError::ShapeMismatch)?;
let expected_out = batch_size
.checked_mul(out_size)
.ok_or(ForwardError::ShapeMismatch)?;
if input_batch.len() != expected_in || output_batch.len() != expected_out {
return Err(ForwardError::ShapeMismatch);
}
let max_width = plan.max_width().ok_or(ForwardError::InvalidPlan)?;
let lane = batch_size
.checked_mul(max_width)
.ok_or(ForwardError::ScratchTooSmall)?;
let needed = lane.checked_mul(2).ok_or(ForwardError::ScratchTooSmall)?;
if scratch.len() < needed {
return Err(ForwardError::ScratchTooSmall);
}
let (buf_a, buf_b) = scratch.split_at_mut(lane);
for b in 0..batch_size {
let src_off = b * in_size;
let dst_off = b * max_width;
buf_a[dst_off..dst_off + in_size].copy_from_slice(&input_batch[src_off..src_off + in_size]);
}
let mut cur_len = in_size;
let mut use_a_as_src = true;
for layer in plan.layers {
match layer {
LayerSpec::Dense(d) => {
if cur_len != d.input_size {
return Err(ForwardError::ShapeMismatch);
}
let w_len = d.weight_len().ok_or(ForwardError::InvalidPlan)?;
let w = &plan.weights[d.weight_offset..d.weight_offset + w_len];
let b = &plan.biases[d.bias_offset..d.bias_offset + d.output_size];
let params = KernelParams {
weights: w,
biases: b,
activation: d.activation,
};
if use_a_as_src {
let src = &buf_a[..lane];
let dst = &mut buf_b[..lane];
forward_batch_kernel_dispatch(
src,
dst,
batch_size,
max_width,
cur_len,
d.output_size,
params,
);
} else {
let src = &buf_b[..lane];
let dst = &mut buf_a[..lane];
forward_batch_kernel_dispatch(
src,
dst,
batch_size,
max_width,
cur_len,
d.output_size,
params,
);
}
cur_len = d.output_size;
use_a_as_src = !use_a_as_src;
}
}
}
let final_src = if use_a_as_src {
&buf_a[..lane]
} else {
&buf_b[..lane]
};
for b in 0..batch_size {
let src_off = b * max_width;
let dst_off = b * out_size;
output_batch[dst_off..dst_off + out_size]
.copy_from_slice(&final_src[src_off..src_off + out_size]);
}
Ok(())
}
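/// Convenience alias for [`forward_plan_kernel_f32`].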
pub fn forward_plan_kernel(
plan: &LayerPlan<'_>,
input_batch: &[f32],
output_batch: &mut [f32],
batch_size: usize,
scratch: &mut [f32],
) -> Result<(), ForwardError> {
forward_plan_kernel_f32(plan, input_batch, output_batch, batch_size, scratch)
}
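/// f64 counterpart of [`forward_plan_kernel_f32`].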
pub fn forward_plan_kernel_f64(
plan: &LayerPlanF64<'_>,
input_batch: &[f64],
output_batch: &mut [f64],
batch_size: usize,
scratch: &mut [f64],
) -> Result<(), ForwardError> {
plan.validate().map_err(map_layer_error)?;
if batch_size == 0 {
return Err(ForwardError::ShapeMismatch);
}
let in_size = plan.input_size().ok_or(ForwardError::InvalidPlan)?;
let out_size = plan.output_size().ok_or(ForwardError::InvalidPlan)?;
let expected_in = batch_size
.checked_mul(in_size)
.ok_or(ForwardError::ShapeMismatch)?;
let expected_out = batch_size
.checked_mul(out_size)
.ok_or(ForwardError::ShapeMismatch)?;
if input_batch.len() != expected_in || output_batch.len() != expected_out {
return Err(ForwardError::ShapeMismatch);
}
let max_width = plan.max_width().ok_or(ForwardError::InvalidPlan)?;
let lane = batch_size
.checked_mul(max_width)
.ok_or(ForwardError::ScratchTooSmall)?;
let needed = lane.checked_mul(2).ok_or(ForwardError::ScratchTooSmall)?;
if scratch.len() < needed {
return Err(ForwardError::ScratchTooSmall);
}
let (buf_a, buf_b) = scratch.split_at_mut(lane);
for b in 0..batch_size {
let src_off = b * in_size;
let dst_off = b * max_width;
buf_a[dst_off..dst_off + in_size].copy_from_slice(&input_batch[src_off..src_off + in_size]);
}
let mut cur_len = in_size;
let mut use_a_as_src = true;
for layer in plan.layers {
match layer {
LayerSpec::Dense(d) => {
if cur_len != d.input_size {
return Err(ForwardError::ShapeMismatch);
}
let w_len = d.weight_len().ok_or(ForwardError::InvalidPlan)?;
let w = &plan.weights[d.weight_offset..d.weight_offset + w_len];
let b = &plan.biases[d.bias_offset..d.bias_offset + d.output_size];
let params_f64 = KernelParamsF64 {
weights: w,
biases: b,
activation: d.activation,
};
if use_a_as_src {
let src = &buf_a[..lane];
let dst = &mut buf_b[..lane];
forward_batch_kernel_f64_dispatch(
src,
dst,
batch_size,
max_width,
cur_len,
d.output_size,
params_f64,
);
} else {
let src = &buf_b[..lane];
let dst = &mut buf_a[..lane];
forward_batch_kernel_f64_dispatch(
src,
dst,
batch_size,
max_width,
cur_len,
d.output_size,
params_f64,
);
}
cur_len = d.output_size;
use_a_as_src = !use_a_as_src;
}
}
}
let final_src = if use_a_as_src {
&buf_a[..lane]
} else {
&buf_b[..lane]
};
for b in 0..batch_size {
let src_off = b * max_width;
let dst_off = b * out_size;
output_batch[dst_off..dst_off + out_size]
.copy_from_slice(&final_src[src_off..src_off + out_size]);
}
Ok(())
}
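/// Per-layer weight/bias slices plus the activation, as consumed by the
/// f32 kernels below.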
struct KernelParams<'a> {
weights: &'a [f32],
biases: &'a [f32],
activation: crate::activations::ActivationKind,
}
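/// Dot product unrolled by four, using independent accumulators to shorten
/// the floating-point dependency chain. The summation order differs from a
/// naive loop, so results may differ in the last bits.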
fn dot_unrolled_f32(weights: &[f32], input: &[f32]) -> f32 {
let len = weights.len();
let mut i = 0usize;
let mut acc0 = 0.0f32;
let mut acc1 = 0.0f32;
let mut acc2 = 0.0f32;
let mut acc3 = 0.0f32;
while i + 4 <= len {
acc0 += weights[i] * input[i];
acc1 += weights[i + 1] * input[i + 1];
acc2 += weights[i + 2] * input[i + 2];
acc3 += weights[i + 3] * input[i + 3];
i += 4;
}
let mut acc = (acc0 + acc1) + (acc2 + acc3);
while i < len {
acc += weights[i] * input[i];
i += 1;
}
acc
}
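/// f64 counterpart of [`dot_unrolled_f32`].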
fn dot_unrolled_f64(weights: &[f64], input: &[f64]) -> f64 {
let len = weights.len();
let mut i = 0usize;
let mut acc0 = 0.0f64;
let mut acc1 = 0.0f64;
let mut acc2 = 0.0f64;
let mut acc3 = 0.0f64;
while i + 4 <= len {
acc0 += weights[i] * input[i];
acc1 += weights[i + 1] * input[i + 1];
acc2 += weights[i + 2] * input[i + 2];
acc3 += weights[i + 3] * input[i + 3];
i += 4;
}
let mut acc = (acc0 + acc1) + (acc2 + acc3);
while i < len {
acc += weights[i] * input[i];
i += 1;
}
acc
}
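/// Scalar CPU kernel: for each output unit, computes `bias + w · x` for
/// every sample in the batch and applies the activation.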
fn forward_batch_kernel_cpu(
src: &[f32],
dst: &mut [f32],
batch_size: usize,
stride: usize,
in_size: usize,
out_size: usize,
    params: KernelParams<'_>,
) {
let KernelParams {
weights,
biases,
activation,
} = params;
let mut o = 0usize;
while o < out_size {
let row_off = o * in_size;
let weight_row = &weights[row_off..row_off + in_size];
let mut b = 0usize;
while b < batch_size {
let base = b * stride;
let input_row = &src[base..base + in_size];
let acc = biases[o] + dot_unrolled_f32(weight_row, input_row);
dst[base + o] = activation.apply(acc);
b += 1;
}
o += 1;
}
}
/// Dispatches one dense layer in f32: tries the GPU kernel first and falls
/// back to the scalar CPU kernel when the GPU path declines.
fn forward_batch_kernel_dispatch(
    src: &[f32],
    dst: &mut [f32],
    batch_size: usize,
    stride: usize,
    in_size: usize,
    out_size: usize,
    params: KernelParams<'_>,
) {
    // Reborrow `dst` here: moving the mutable reference into the invoke
    // struct would consume it and leave nothing for the CPU fallback.
    if super::try_invoke_gpu_kernel_f32(super::KernelInvokeF32 {
        src,
        dst: &mut *dst,
        batch_size,
        stride,
        in_size,
        out_size,
        weights: params.weights,
        biases: params.biases,
        activation: params.activation,
    }) {
        return;
    }
    forward_batch_kernel_cpu(src, dst, batch_size, stride, in_size, out_size, params);
}
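/// f64 counterpart of [`forward_batch_kernel_cpu`].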
fn forward_batch_kernel_f64_cpu(
src: &[f64],
dst: &mut [f64],
batch_size: usize,
stride: usize,
in_size: usize,
out_size: usize,
    params: KernelParamsF64<'_>,
) {
let KernelParamsF64 {
weights,
biases,
activation,
} = params;
let mut o = 0usize;
while o < out_size {
let row_off = o * in_size;
let weight_row = &weights[row_off..row_off + in_size];
let mut b = 0usize;
while b < batch_size {
let base = b * stride;
let input_row = &src[base..base + in_size];
let acc = biases[o] + dot_unrolled_f64(weight_row, input_row);
dst[base + o] = activation.apply_f64(acc);
b += 1;
}
o += 1;
}
}
/// Dispatches one dense layer in f64: tries the GPU kernel first and falls
/// back to the scalar CPU kernel when the GPU path declines.
fn forward_batch_kernel_f64_dispatch(
    src: &[f64],
    dst: &mut [f64],
    batch_size: usize,
    stride: usize,
    in_size: usize,
    out_size: usize,
    params: KernelParamsF64<'_>,
) {
    // Reborrow `dst` here: moving the mutable reference into the invoke
    // struct would consume it and leave nothing for the CPU fallback.
    if super::try_invoke_gpu_kernel_f64(super::KernelInvokeF64 {
        src,
        dst: &mut *dst,
        batch_size,
        stride,
        in_size,
        out_size,
        weights: params.weights,
        biases: params.biases,
        activation: params.activation,
    }) {
        return;
    }
    forward_batch_kernel_f64_cpu(src, dst, batch_size, stride, in_size, out_size, params);
}
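/// Collapses the fine-grained `LayerError` into the coarser [`ForwardError`].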
fn map_layer_error(e: LayerError) -> ForwardError {
match e {
LayerError::EmptyPlan | LayerError::InvalidRange | LayerError::CountMismatch => {
ForwardError::InvalidPlan
}
LayerError::InvalidShape | LayerError::IncompatibleChain => ForwardError::ShapeMismatch,
LayerError::BufferTooSmall => ForwardError::ScratchTooSmall,
}
}
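// A minimal sanity check for the unrolled dot products above. These tests
// touch only helpers defined in this file, so they make no assumptions
// about how `LayerPlan` values are constructed.
#[cfg(test)]
mod tests {
    use super::{dot_unrolled_f32, dot_unrolled_f64};

    #[test]
    fn unrolled_dot_f32_matches_naive() {
        // Length 7 exercises both the unrolled body and the scalar tail.
        let w: Vec<f32> = (0..7).map(|i| i as f32 * 0.5).collect();
        let x: Vec<f32> = (0..7).map(|i| 1.0 - i as f32 * 0.25).collect();
        let want: f32 = w.iter().zip(&x).map(|(a, b)| a * b).sum();
        assert!((dot_unrolled_f32(&w, &x) - want).abs() < 1e-5);
    }

    #[test]
    fn unrolled_dot_f64_matches_naive() {
        let w: Vec<f64> = (0..9).map(|i| i as f64 * 0.5).collect();
        let x: Vec<f64> = (0..9).map(|i| 1.0 - i as f64 * 0.25).collect();
        let want: f64 = w.iter().zip(&x).map(|(a, b)| a * b).sum();
        assert!((dot_unrolled_f64(&w, &x) - want).abs() < 1e-12);
    }
}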