use crate::error::{Result, TensorError};
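
/// Fused linear layer + ReLU: `output = max(0, input * weight^T + bias)`.
///
/// `input` is interpreted as `[..batch_dims, in_features]` and `weight` as
/// `[out_features, in_features]` in row-major order. The optional `bias` must
/// have `out_features` elements. Returns the output buffer together with its
/// shape `[..batch_dims, out_features]`.
///
/// # Errors
///
/// Returns a `TensorError` when a shape is empty or mismatched, or when a
/// buffer length does not match the product of its shape.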
pub fn fused_linear_relu(
input: &[f32],
input_shape: &[usize],
weight: &[f32],
weight_shape: &[usize],
bias: Option<&[f32]>,
) -> Result<(Vec<f32>, Vec<usize>)> {
    if input_shape.is_empty() {
return Err(TensorError::InvalidShape {
operation: "fused_linear_relu".to_string(),
reason: "input_shape must have at least 1 dimension".to_string(),
shape: Some(input_shape.to_vec()),
context: None,
});
}
if weight_shape.len() != 2 {
return Err(TensorError::InvalidShape {
operation: "fused_linear_relu".to_string(),
reason: format!("weight_shape must be 2D [out, in], got {weight_shape:?}"),
shape: Some(weight_shape.to_vec()),
context: None,
});
}
let in_features = *input_shape.last().expect("checked above");
let out_features = weight_shape[0];
let weight_in = weight_shape[1];
if in_features != weight_in {
return Err(TensorError::ShapeMismatch {
operation: "fused_linear_relu".to_string(),
expected: format!("in_features={in_features}"),
got: format!("weight in_features={weight_in}"),
context: None,
});
}
if let Some(b) = bias {
if b.len() != out_features {
return Err(TensorError::ShapeMismatch {
operation: "fused_linear_relu".to_string(),
expected: format!("bias length={out_features}"),
got: format!("bias length={}", b.len()),
context: None,
});
}
}
let input_expected: usize = input_shape.iter().product();
let weight_expected: usize = weight_shape.iter().product();
if input.len() != input_expected {
return Err(TensorError::InvalidArgument {
operation: "fused_linear_relu".to_string(),
reason: format!(
"input length {} != product of input_shape {:?} ({})",
input.len(),
input_shape,
input_expected
),
context: None,
});
}
if weight.len() != weight_expected {
return Err(TensorError::InvalidArgument {
operation: "fused_linear_relu".to_string(),
reason: format!(
"weight length {} != product of weight_shape {:?} ({})",
weight.len(),
weight_shape,
weight_expected
),
context: None,
});
}
    // Product of the leading (batch) dimensions: the empty product is 1 for 1-D
    // input, and a zero-sized batch dimension yields an empty output.
    let effective_batch: usize = input_shape[..input_shape.len() - 1].iter().product();
let out_len = effective_batch * out_features;
let mut output = vec![0.0f32; out_len];
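    // Row-by-row GEMV: dot every weight row with the current input row, add the
    // optional bias, and apply ReLU in the same pass.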
for n in 0..effective_batch {
let input_offset = n * in_features;
let output_offset = n * out_features;
for o in 0..out_features {
let weight_offset = o * in_features;
let mut sum = 0.0f32;
for f in 0..in_features {
sum += input[input_offset + f] * weight[weight_offset + f];
}
if let Some(b) = bias {
sum += b[o];
}
output[output_offset + o] = sum.max(0.0);
}
}
let mut out_shape = if input_shape.len() > 1 {
input_shape[..input_shape.len() - 1].to_vec()
} else {
Vec::new()
};
out_shape.push(out_features);
Ok((output, out_shape))
}
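
/// Fused layer normalization + linear projection (no bias on the linear step).
///
/// Each row of `input` (shape `[..batch_dims, in_features]`) is normalized to
/// zero mean and unit variance, scaled by `norm_weight` and shifted by
/// `norm_bias`, then multiplied by `weight` (`[out_features, in_features]`,
/// row-major). Returns the output buffer and its shape
/// `[..batch_dims, out_features]`.
///
/// # Errors
///
/// Returns a `TensorError` when a shape is empty or mismatched, or when a
/// buffer length does not match the product of its shape.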
pub fn fused_layer_norm_linear(
input: &[f32],
shape: &[usize],
weight: &[f32],
weight_shape: &[usize],
norm_weight: &[f32],
norm_bias: &[f32],
eps: f32,
) -> Result<(Vec<f32>, Vec<usize>)> {
    if shape.is_empty() {
return Err(TensorError::InvalidShape {
operation: "fused_layer_norm_linear".to_string(),
reason: "shape must have at least 1 dimension".to_string(),
shape: Some(shape.to_vec()),
context: None,
});
}
if weight_shape.len() != 2 {
return Err(TensorError::InvalidShape {
operation: "fused_layer_norm_linear".to_string(),
reason: format!("weight_shape must be 2D [out, in], got {weight_shape:?}"),
shape: Some(weight_shape.to_vec()),
context: None,
});
}
let in_features = *shape.last().expect("checked above");
let out_features = weight_shape[0];
let weight_in = weight_shape[1];
if in_features != weight_in {
return Err(TensorError::ShapeMismatch {
operation: "fused_layer_norm_linear".to_string(),
expected: format!("in_features={in_features}"),
got: format!("weight in_features={weight_in}"),
context: None,
});
}
if norm_weight.len() != in_features {
return Err(TensorError::ShapeMismatch {
operation: "fused_layer_norm_linear".to_string(),
expected: format!("norm_weight length={in_features}"),
got: format!("{}", norm_weight.len()),
context: None,
});
}
if norm_bias.len() != in_features {
return Err(TensorError::ShapeMismatch {
operation: "fused_layer_norm_linear".to_string(),
expected: format!("norm_bias length={in_features}"),
got: format!("{}", norm_bias.len()),
context: None,
});
}
let input_expected: usize = shape.iter().product();
let weight_expected: usize = weight_shape.iter().product();
if input.len() != input_expected {
return Err(TensorError::InvalidArgument {
operation: "fused_layer_norm_linear".to_string(),
reason: format!(
"input length {} != product of shape {:?} ({})",
input.len(),
shape,
input_expected
),
context: None,
});
}
if weight.len() != weight_expected {
return Err(TensorError::InvalidArgument {
operation: "fused_layer_norm_linear".to_string(),
reason: format!(
"weight length {} != product of weight_shape {:?} ({})",
weight.len(),
weight_shape,
weight_expected
),
context: None,
});
}
    // Product of the leading (batch) dimensions: the empty product is 1 for 1-D
    // input, and a zero-sized batch dimension yields an empty output.
    let effective_batch: usize = shape[..shape.len() - 1].iter().product();
let out_len = effective_batch * out_features;
let mut output = vec![0.0f32; out_len];
let mut normed_row = vec![0.0f32; in_features];
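    // For each batch row: compute mean and variance, normalize into `normed_row`,
    // then project the normalized row through the weight matrix.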
for n in 0..effective_batch {
let input_offset = n * in_features;
let output_offset = n * out_features;
let row = &input[input_offset..input_offset + in_features];
let mean = row.iter().sum::<f32>() / (in_features as f32);
let var = row
.iter()
.map(|&x| {
let diff = x - mean;
diff * diff
})
.sum::<f32>()
/ (in_features as f32);
let inv_std = 1.0 / (var + eps).sqrt();
        for (f, (&x, &nw)) in row.iter().zip(norm_weight.iter()).enumerate() {
normed_row[f] = nw * (x - mean) * inv_std + norm_bias[f];
}
for o in 0..out_features {
let weight_offset = o * in_features;
let mut sum = 0.0f32;
for f in 0..in_features {
sum += normed_row[f] * weight[weight_offset + f];
}
output[output_offset + o] = sum;
}
}
let mut out_shape = if shape.len() > 1 {
shape[..shape.len() - 1].to_vec()
} else {
Vec::new()
};
out_shape.push(out_features);
Ok((output, out_shape))
}
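
/// Layer normalization over the last dimension of `shape`.
///
/// Each row of `in_features` elements is normalized to zero mean and unit
/// variance (with `eps` added to the variance for numerical stability), then
/// scaled by `weight` and shifted by `bias`. The output has the same length
/// and layout as `input`.
///
/// # Errors
///
/// Returns a `TensorError` when `shape` is empty, when `weight` or `bias`
/// length does not equal the last dimension, or when `input` length does not
/// match the product of `shape`.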
pub fn layer_norm(
input: &[f32],
shape: &[usize],
weight: &[f32],
bias: &[f32],
eps: f32,
) -> Result<Vec<f32>> {
if shape.is_empty() {
return Err(TensorError::InvalidShape {
operation: "layer_norm".to_string(),
reason: "shape must have at least 1 dimension".to_string(),
shape: Some(shape.to_vec()),
context: None,
});
}
let in_features = *shape.last().expect("checked above");
if weight.len() != in_features {
return Err(TensorError::ShapeMismatch {
operation: "layer_norm".to_string(),
expected: format!("weight length={in_features}"),
got: format!("{}", weight.len()),
context: None,
});
}
if bias.len() != in_features {
return Err(TensorError::ShapeMismatch {
operation: "layer_norm".to_string(),
expected: format!("bias length={in_features}"),
got: format!("{}", bias.len()),
context: None,
});
}
let total: usize = shape.iter().product();
if input.len() != total {
return Err(TensorError::InvalidArgument {
operation: "layer_norm".to_string(),
reason: format!(
"input length {} != product of shape {:?} ({})",
input.len(),
shape,
total
),
context: None,
});
}
let batch = total / in_features;
let mut output = vec![0.0f32; total];
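    // Normalize each row of `in_features` elements independently.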
for n in 0..batch {
let offset = n * in_features;
let row = &input[offset..offset + in_features];
let mean = row.iter().sum::<f32>() / (in_features as f32);
let var = row
.iter()
.map(|&x| {
let diff = x - mean;
diff * diff
})
.sum::<f32>()
/ (in_features as f32);
let inv_std = 1.0 / (var + eps).sqrt();
for f in 0..in_features {
output[offset + f] = weight[f] * (row[f] - mean) * inv_std + bias[f];
}
}
Ok(output)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_layer_norm_mean_zero_std_one() {
let input = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = vec![2, 3];
        let weight = vec![1.0f32; 3];
let bias = vec![0.0f32; 3];
let output = layer_norm(&input, &shape, &weight, &bias, 1e-5)
.expect("layer_norm should succeed");
assert_eq!(output.len(), input.len());
for n in 0..2 {
let row = &output[n * 3..(n + 1) * 3];
let mean: f32 = row.iter().sum::<f32>() / 3.0;
let var: f32 = row.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / 3.0;
assert!(
mean.abs() < 1e-5,
"sample {n} mean should be ~0, got {mean}"
);
assert!(
(var.sqrt() - 1.0).abs() < 1e-4,
"sample {n} std should be ~1, got {}",
var.sqrt()
);
}
}
#[test]
fn test_layer_norm_affine_transform() {
let input = vec![10.0f32, 20.0, 30.0];
let shape = vec![3];
let weight = vec![2.0f32; 3];
let bias_vec = vec![1.0f32; 3];
let output = layer_norm(&input, &shape, &weight, &bias_vec, 1e-5)
.expect("layer_norm should succeed");
let mean: f32 = output.iter().sum::<f32>() / 3.0;
assert!((mean - 1.0).abs() < 1e-4, "affine output mean should be beta=1, got {mean}");
}
#[test]
fn test_layer_norm_constant_input() {
let input = vec![5.0f32; 4];
let shape = vec![4];
let weight = vec![1.0f32; 4];
let bias_vec = vec![3.0f32; 4];
let output = layer_norm(&input, &shape, &weight, &bias_vec, 1e-5)
.expect("layer_norm should succeed on zero-variance input");
for (i, &o) in output.iter().enumerate() {
assert!(
o.is_finite(),
"output[{i}] must be finite for zero-variance input, got {o}"
);
}
}
#[test]
fn test_fused_linear_relu_known_weights() {
let input = vec![1.0f32, 2.0];
let input_shape = vec![1, 2];
let weight = vec![1.0f32, 0.0, 0.0, 1.0, -1.0, -1.0];
let weight_shape = vec![3, 2];
let bias = vec![0.0f32, 0.0, 10.0];
let (out, out_shape) =
fused_linear_relu(&input, &input_shape, &weight, &weight_shape, Some(&bias))
.expect("fused_linear_relu should succeed");
assert_eq!(out_shape, vec![1, 3]);
assert_eq!(out.len(), 3);
assert!((out[0] - 1.0).abs() < 1e-6, "out[0] expected 1.0, got {}", out[0]);
assert!((out[1] - 2.0).abs() < 1e-6, "out[1] expected 2.0, got {}", out[1]);
assert!((out[2] - 7.0).abs() < 1e-6, "out[2] expected 7.0, got {}", out[2]);
}
#[test]
fn test_fused_linear_relu_negative_clipped() {
let input = vec![-1.0f32, -2.0];
let input_shape = vec![1, 2];
        let weight = vec![1.0f32, 1.0, 1.0, 1.0];
        let weight_shape = vec![2, 2];
let (out, _) =
fused_linear_relu(&input, &input_shape, &weight, &weight_shape, None)
.expect("fused_linear_relu should succeed");
for (i, &x) in out.iter().enumerate() {
assert_eq!(x, 0.0, "ReLU of negative input must be 0, got out[{i}]={x}");
}
}
#[test]
fn test_fused_linear_relu_no_bias() {
let input = vec![3.0f32];
let input_shape = vec![1, 1];
let weight = vec![2.0f32];
let weight_shape = vec![1, 1];
let (out, out_shape) =
fused_linear_relu(&input, &input_shape, &weight, &weight_shape, None)
.expect("fused_linear_relu should succeed");
assert_eq!(out_shape, vec![1, 1]);
assert!((out[0] - 6.0).abs() < 1e-6, "expected 6.0, got {}", out[0]);
}
#[test]
fn test_fused_linear_relu_shape_error() {
let input = vec![1.0f32, 2.0];
let input_shape = vec![1, 2];
        let weight = vec![1.0f32, 0.0, 0.0];
        let weight_shape = vec![1, 3];
let result = fused_linear_relu(&input, &input_shape, &weight, &weight_shape, None);
assert!(result.is_err(), "mismatched weight dimensions must return Err");
}
#[test]
fn test_fused_layer_norm_linear_output_shape() {
let batch = 4;
let in_feat = 8;
let out_feat = 3;
let input: Vec<f32> = (0..batch * in_feat).map(|i| i as f32).collect();
let shape = vec![batch, in_feat];
let weight: Vec<f32> = vec![0.1f32; out_feat * in_feat];
let weight_shape = vec![out_feat, in_feat];
let norm_weight = vec![1.0f32; in_feat];
let norm_bias = vec![0.0f32; in_feat];
let (out, out_shape) = fused_layer_norm_linear(
&input,
&shape,
&weight,
&weight_shape,
&norm_weight,
&norm_bias,
1e-5,
)
.expect("fused_layer_norm_linear should succeed");
assert_eq!(out_shape, vec![batch, out_feat]);
assert_eq!(out.len(), batch * out_feat);
}
#[test]
fn test_fused_layer_norm_linear_no_nan() {
let batch = 2;
let in_feat = 4;
let out_feat = 4;
let input: Vec<f32> = (0..batch * in_feat).map(|i| (i as f32 - 3.5) * 2.0).collect();
let shape = vec![batch, in_feat];
let weight: Vec<f32> = vec![0.5f32; out_feat * in_feat];
let weight_shape = vec![out_feat, in_feat];
let norm_weight = vec![1.0f32; in_feat];
let norm_bias = vec![0.0f32; in_feat];
let (out, _) = fused_layer_norm_linear(
&input,
&shape,
&weight,
&weight_shape,
&norm_weight,
&norm_bias,
1e-5,
)
.expect("fused_layer_norm_linear should succeed");
for (i, &x) in out.iter().enumerate() {
assert!(x.is_finite(), "output[{i}] must be finite, got {x}");
}
}
}