use crate::{
error::{RealizarError, Result},
generate::{sample_token, GenerationConfig},
tensor::Tensor,
};
/// Numerically stable softmax over the last dimension of `input`.
///
/// Every run of `last_dim` consecutive values is treated as one group:
/// each group is shifted by its maximum before exponentiation (so `exp`
/// cannot overflow) and normalized to sum to 1.
///
/// # Errors
/// Returns `RealizarError::InvalidShape` when the tensor holds no data or
/// has an empty shape.
pub fn softmax(input: &Tensor<f32>) -> Result<Tensor<f32>> {
    let values = input.data();
    let dims = input.shape();
    if values.is_empty() {
        return Err(RealizarError::InvalidShape {
            reason: "Cannot apply softmax to empty tensor".to_string(),
        });
    }
    if dims.is_empty() {
        return Err(RealizarError::InvalidShape {
            reason: "Cannot apply softmax to tensor with empty shape".to_string(),
        });
    }
    let group_len = dims[dims.len() - 1];
    let mut result = Vec::with_capacity(values.len());
    // Each trailing-dimension group is normalized independently.
    for group in values.chunks_exact(group_len) {
        let peak = group.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = group.iter().map(|&v| (v - peak).exp()).collect();
        let total: f32 = exps.iter().sum();
        result.extend(exps.iter().map(|&e| e / total));
    }
    Tensor::from_vec(dims.to_vec(), result)
}
/// GELU activation (tanh approximation), applied element-wise:
/// `0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
///
/// # Errors
/// Returns `RealizarError::InvalidShape` when the tensor holds no data.
pub fn gelu(input: &Tensor<f32>) -> Result<Tensor<f32>> {
    // sqrt(2/pi) to f32 precision.
    const SQRT_2_OVER_PI: f32 = 0.797_884_6;
    const CUBIC_COEFF: f32 = 0.044_715;
    let values = input.data();
    if values.is_empty() {
        return Err(RealizarError::InvalidShape {
            reason: "Cannot apply GELU to empty tensor".to_string(),
        });
    }
    let mut activated = Vec::with_capacity(values.len());
    for &x in values {
        let inner = SQRT_2_OVER_PI * (x + CUBIC_COEFF * x * x * x);
        activated.push(0.5 * x * (1.0 + inner.tanh()));
    }
    Tensor::from_vec(input.shape().to_vec(), activated)
}
/// Layer normalization over the last tensor dimension.
#[derive(Debug, Clone)]
pub struct LayerNorm {
    // Size of the trailing dimension that gets normalized.
    normalized_shape: usize,
    // Added to the variance before the sqrt for numerical stability.
    eps: f32,
    // Per-element scale (gamma); initialized to 1.0 in `new`.
    weight: Vec<f32>,
    // Per-element shift (beta); initialized to 0.0 in `new`.
    bias: Vec<f32>,
}
impl LayerNorm {
    /// Creates a layer norm with unit scale and zero shift.
    ///
    /// # Errors
    /// Returns `RealizarError::InvalidShape` when `normalized_shape` is zero.
    pub fn new(normalized_shape: usize, eps: f32) -> Result<Self> {
        if normalized_shape == 0 {
            return Err(RealizarError::InvalidShape {
                reason: "normalized_shape must be > 0".to_string(),
            });
        }
        Ok(Self {
            normalized_shape,
            eps,
            weight: vec![1.0; normalized_shape],
            bias: vec![0.0; normalized_shape],
        })
    }

    /// Normalizes each trailing-dimension group to zero mean and unit
    /// variance (biased estimator), then applies the learned scale/shift.
    ///
    /// # Errors
    /// Returns `RealizarError::InvalidShape` when the input shape is empty
    /// or its last dimension differs from `normalized_shape`.
    pub fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
        let dims = input.shape();
        if dims.is_empty() {
            return Err(RealizarError::InvalidShape {
                reason: "Input tensor cannot be empty".to_string(),
            });
        }
        let trailing = dims[dims.len() - 1];
        if trailing != self.normalized_shape {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Last dimension {} doesn't match normalized_shape {}",
                    trailing, self.normalized_shape
                ),
            });
        }
        let values = input.data();
        let mut result = Vec::with_capacity(values.len());
        for group in values.chunks_exact(self.normalized_shape) {
            #[allow(clippy::cast_precision_loss)]
            let mean: f32 = group.iter().sum::<f32>() / self.normalized_shape as f32;
            #[allow(clippy::cast_precision_loss)]
            let variance: f32 = group
                .iter()
                .map(|&v| (v - mean) * (v - mean))
                .sum::<f32>()
                / self.normalized_shape as f32;
            for (idx, &v) in group.iter().enumerate() {
                // sqrt kept per element (not hoisted) so float results match
                // the reference formulation exactly.
                let z = (v - mean) / (variance + self.eps).sqrt();
                result.push(z * self.weight[idx] + self.bias[idx]);
            }
        }
        debug_assert!(
            result.iter().all(|&x| x.is_finite()),
            "LayerNorm produced NaN or Inf values - check input distribution"
        );
        Tensor::from_vec(dims.to_vec(), result)
    }

    /// Size of the normalized trailing dimension.
    #[must_use]
    pub fn normalized_shape(&self) -> usize {
        self.normalized_shape
    }

    /// Numerical-stability epsilon.
    #[must_use]
    pub fn eps(&self) -> f32 {
        self.eps
    }
}
/// Dense (fully connected) layer computing `y = x . W + b`.
#[derive(Debug, Clone)]
pub struct Linear {
    // Input feature count (last dimension of the input).
    in_features: usize,
    // Output feature count (last dimension of the output).
    out_features: usize,
    // Row-major [in_features, out_features] matrix, zero-initialized.
    weight: Vec<f32>,
    // Per-output-feature bias, zero-initialized.
    bias: Vec<f32>,
}
impl Linear {
    /// Creates a zero-initialized dense layer.
    ///
    /// # Errors
    /// Returns `RealizarError::InvalidShape` when either dimension is zero.
    pub fn new(in_features: usize, out_features: usize) -> Result<Self> {
        if in_features == 0 || out_features == 0 {
            return Err(RealizarError::InvalidShape {
                reason: "in_features and out_features must be > 0".to_string(),
            });
        }
        Ok(Self {
            in_features,
            out_features,
            weight: vec![0.0; in_features * out_features],
            bias: vec![0.0; out_features],
        })
    }

    /// Applies `x . W + b` to every row of the input.
    ///
    /// The output keeps all leading dimensions and replaces the last with
    /// `out_features`.
    ///
    /// # Errors
    /// Returns `RealizarError::InvalidShape` on an empty shape or when the
    /// last dimension differs from `in_features`.
    pub fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
        let dims = input.shape();
        if dims.is_empty() {
            return Err(RealizarError::InvalidShape {
                reason: "Input tensor cannot be empty".to_string(),
            });
        }
        let trailing = dims[dims.len() - 1];
        if trailing != self.in_features {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Last dimension {} doesn't match in_features {}",
                    trailing, self.in_features
                ),
            });
        }
        let values = input.data();
        let row_count = values.len() / self.in_features;
        let mut result = Vec::with_capacity(row_count * self.out_features);
        for row in values.chunks_exact(self.in_features) {
            for out_idx in 0..self.out_features {
                // Seed the accumulator with the bias; weight is stored
                // row-major as [in_features, out_features].
                let mut acc = self.bias[out_idx];
                for (in_idx, &x) in row.iter().enumerate() {
                    acc += x * self.weight[in_idx * self.out_features + out_idx];
                }
                result.push(acc);
            }
        }
        let mut out_dims = dims[..dims.len() - 1].to_vec();
        out_dims.push(self.out_features);
        debug_assert!(
            result.iter().all(|&x| x.is_finite()),
            "Linear layer produced NaN or Inf values - check for exploding gradients/activations"
        );
        Tensor::from_vec(out_dims, result)
    }

    /// Input feature count.
    #[must_use]
    pub fn in_features(&self) -> usize {
        self.in_features
    }

    /// Output feature count.
    #[must_use]
    pub fn out_features(&self) -> usize {
        self.out_features
    }

    /// Mutable view of the row-major weight matrix.
    #[must_use]
    pub fn weight_mut(&mut self) -> &mut [f32] {
        &mut self.weight
    }

    /// Mutable view of the bias vector.
    #[must_use]
    pub fn bias_mut(&mut self) -> &mut [f32] {
        &mut self.bias
    }
}
/// Linear layer whose weights stay in packed Q4_K quantized form; the
/// matmul is performed directly on the quantized bytes.
#[derive(Debug, Clone)]
pub struct QuantizedLinear {
    // Input feature count (last dimension of the input).
    in_features: usize,
    // Output feature count (one packed weight row per output).
    out_features: usize,
    // Packed Q4_K weights, `out_features` rows of `bytes_per_row` bytes.
    weight_bytes: Vec<u8>,
    // Per-output-feature bias in f32.
    bias: Vec<f32>,
    // Packed byte stride of one weight row (whole super-blocks).
    bytes_per_row: usize,
}
impl QuantizedLinear {
    /// Creates a quantized linear layer from raw Q4_K super-block bytes.
    ///
    /// Weights are stored row-major: each of the `out_features` rows is a
    /// whole number of Q4_K super-blocks (256 values packed into 144 bytes),
    /// so rows covering `in_features` values are padded up to a super-block
    /// boundary.
    ///
    /// # Errors
    /// Returns `RealizarError::InvalidShape` when a dimension is zero, the
    /// bias length differs from `out_features`, or `weight_bytes` does not
    /// match the expected packed size.
    pub fn new(
        in_features: usize,
        out_features: usize,
        weight_bytes: Vec<u8>,
        bias: Vec<f32>,
    ) -> Result<Self> {
        // Q4_K layout: 256 quantized values per 144-byte super-block.
        const SUPER_BLOCK_VALUES: usize = 256;
        const SUPER_BLOCK_BYTES: usize = 144;
        if in_features == 0 || out_features == 0 {
            return Err(RealizarError::InvalidShape {
                reason: "in_features and out_features must be > 0".to_string(),
            });
        }
        if bias.len() != out_features {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Bias length {} doesn't match out_features {}",
                    bias.len(),
                    out_features
                ),
            });
        }
        // Round each row up to a whole number of super-blocks.
        let super_blocks_per_row = in_features.div_ceil(SUPER_BLOCK_VALUES);
        let bytes_per_row = super_blocks_per_row * SUPER_BLOCK_BYTES;
        let expected_bytes = out_features * bytes_per_row;
        if weight_bytes.len() != expected_bytes {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Weight bytes {} doesn't match expected {} ({}x{})",
                    weight_bytes.len(),
                    expected_bytes,
                    out_features,
                    bytes_per_row
                ),
            });
        }
        Ok(Self {
            in_features,
            out_features,
            weight_bytes,
            bias,
            bytes_per_row,
        })
    }

    /// Applies the quantized linear transform to every row of `input`,
    /// dequantizing on the fly via the fused Q4_K dot-product kernel.
    ///
    /// # Errors
    /// Returns `RealizarError::InvalidShape` on an empty shape or a
    /// mismatched last dimension, and propagates any error from the fused
    /// dot-product kernel.
    pub fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
        use crate::quantize::fused_q4k_dot_simd;
        let shape = input.shape();
        if shape.is_empty() {
            return Err(RealizarError::InvalidShape {
                reason: "Input tensor cannot be empty".to_string(),
            });
        }
        let last_dim = shape[shape.len() - 1];
        if last_dim != self.in_features {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Last dimension {} doesn't match in_features {}",
                    last_dim, self.in_features
                ),
            });
        }
        let data = input.data();
        let total_size = data.len();
        let num_rows = total_size / self.in_features;
        let mut output = Vec::with_capacity(num_rows * self.out_features);
        for row_idx in 0..num_rows {
            let input_start = row_idx * self.in_features;
            let input_row = &data[input_start..input_start + self.in_features];
            for j in 0..self.out_features {
                // Output row j's packed weights live at a fixed byte stride.
                let weight_start = j * self.bytes_per_row;
                let weight_row =
                    &self.weight_bytes[weight_start..weight_start + self.bytes_per_row];
                let dot = fused_q4k_dot_simd(weight_row, input_row)?;
                output.push(dot + self.bias[j]);
            }
        }
        let mut output_shape = shape[..shape.len() - 1].to_vec();
        output_shape.push(self.out_features);
        // `output_shape` always receives `out_features` via the push above,
        // so it can never be empty here; the old emptiness re-check was
        // unreachable dead code and has been removed.
        Tensor::from_vec(output_shape, output)
    }

    /// Input feature count.
    #[must_use]
    pub fn in_features(&self) -> usize {
        self.in_features
    }

    /// Output feature count.
    #[must_use]
    pub fn out_features(&self) -> usize {
        self.out_features
    }

    /// Raw packed Q4_K weight bytes.
    #[must_use]
    pub fn weight_bytes(&self) -> &[u8] {
        &self.weight_bytes
    }

    /// Bias vector.
    #[must_use]
    pub fn bias(&self) -> &[f32] {
        &self.bias
    }

    /// Total bytes held by this layer (packed weights + f32 bias).
    #[must_use]
    pub fn memory_bytes(&self) -> usize {
        self.weight_bytes.len() + self.bias.len() * std::mem::size_of::<f32>()
    }
}
/// LayerNorm fused with a linear projection: normalizes each input row and
/// projects it in a single pass, without materializing the normalized
/// intermediate tensor.
#[derive(Debug, Clone)]
pub struct FusedLayerNormLinear {
    // Size of the normalized (and projected-from) trailing dimension.
    feature_dim: usize,
    // Output feature count of the linear stage.
    out_features: usize,
    // Variance epsilon for the normalization stage.
    eps: f32,
    // LayerNorm scale (gamma), length feature_dim; initialized to 1.0.
    norm_weight: Vec<f32>,
    // LayerNorm shift (beta), length feature_dim; initialized to 0.0.
    norm_bias: Vec<f32>,
    // Row-major [feature_dim, out_features] projection, zero-initialized.
    linear_weight: Vec<f32>,
    // Linear bias, length out_features; zero-initialized.
    linear_bias: Vec<f32>,
}
impl FusedLayerNormLinear {
    /// Creates a fused layer with identity norm parameters (gamma=1, beta=0)
    /// and zero-initialized linear parameters.
    ///
    /// # Errors
    /// Returns `RealizarError::InvalidShape` when either dimension is zero.
    pub fn new(feature_dim: usize, out_features: usize, eps: f32) -> Result<Self> {
        if feature_dim == 0 || out_features == 0 {
            return Err(RealizarError::InvalidShape {
                reason: "feature_dim and out_features must be > 0".to_string(),
            });
        }
        Ok(Self {
            feature_dim,
            out_features,
            eps,
            norm_weight: vec![1.0; feature_dim],
            norm_bias: vec![0.0; feature_dim],
            linear_weight: vec![0.0; feature_dim * out_features],
            linear_bias: vec![0.0; out_features],
        })
    }

    /// LayerNorm followed by linear projection, computed row by row.
    ///
    /// For each row: mean and (biased) variance are computed once, then each
    /// output feature re-normalizes the row on the fly while accumulating
    /// the dot product — trading redundant normalization arithmetic for not
    /// allocating an intermediate tensor.
    ///
    /// # Errors
    /// Returns `RealizarError::InvalidShape` on an empty shape or when the
    /// last dimension differs from `feature_dim`.
    pub fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
        let shape = input.shape();
        if shape.is_empty() {
            return Err(RealizarError::InvalidShape {
                reason: "Input tensor cannot be empty".to_string(),
            });
        }
        let last_dim = shape[shape.len() - 1];
        if last_dim != self.feature_dim {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Last dimension {} doesn't match feature_dim {}",
                    last_dim, self.feature_dim
                ),
            });
        }
        let data = input.data();
        let num_rows = data.len() / self.feature_dim;
        let mut output = Vec::with_capacity(num_rows * self.out_features);
        for row_idx in 0..num_rows {
            let row_start = row_idx * self.feature_dim;
            let row = &data[row_start..row_start + self.feature_dim];
            #[allow(clippy::cast_precision_loss)]
            let mean: f32 = row.iter().sum::<f32>() / self.feature_dim as f32;
            #[allow(clippy::cast_precision_loss)]
            let variance: f32 = row
                .iter()
                .map(|&x| {
                    let diff = x - mean;
                    diff * diff
                })
                .sum::<f32>()
                / self.feature_dim as f32;
            // Reciprocal std computed once per row; reused for every output.
            let inv_std = 1.0 / (variance + self.eps).sqrt();
            for j in 0..self.out_features {
                let mut sum = self.linear_bias[j];
                for (i, &x) in row.iter().enumerate() {
                    let normalized = (x - mean) * inv_std;
                    let transformed = normalized * self.norm_weight[i] + self.norm_bias[i];
                    sum += transformed * self.linear_weight[i * self.out_features + j];
                }
                output.push(sum);
            }
        }
        let mut output_shape = shape[..shape.len() - 1].to_vec();
        output_shape.push(self.out_features);
        Tensor::from_vec(output_shape, output)
    }

    /// Parallel (rayon) variant of `forward`: rows are independent, so each
    /// is processed on its own task. Per-row accumulation order matches the
    /// serial version, so results should be bit-identical per row.
    ///
    /// # Errors
    /// Same conditions as `forward`.
    pub fn forward_parallel(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
        use rayon::prelude::*;
        let shape = input.shape();
        if shape.is_empty() {
            return Err(RealizarError::InvalidShape {
                reason: "Input tensor cannot be empty".to_string(),
            });
        }
        let last_dim = shape[shape.len() - 1];
        if last_dim != self.feature_dim {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Last dimension {} doesn't match feature_dim {}",
                    last_dim, self.feature_dim
                ),
            });
        }
        let data = input.data();
        let num_rows = data.len() / self.feature_dim;
        // flat_map over row indices keeps output ordered by row.
        let output: Vec<f32> = (0..num_rows)
            .into_par_iter()
            .flat_map(|row_idx| {
                let row_start = row_idx * self.feature_dim;
                let row = &data[row_start..row_start + self.feature_dim];
                #[allow(clippy::cast_precision_loss)]
                let mean: f32 = row.iter().sum::<f32>() / self.feature_dim as f32;
                #[allow(clippy::cast_precision_loss)]
                let variance: f32 = row
                    .iter()
                    .map(|&x| {
                        let diff = x - mean;
                        diff * diff
                    })
                    .sum::<f32>()
                    / self.feature_dim as f32;
                let inv_std = 1.0 / (variance + self.eps).sqrt();
                (0..self.out_features)
                    .map(|j| {
                        let mut sum = self.linear_bias[j];
                        for (i, &x) in row.iter().enumerate() {
                            let normalized = (x - mean) * inv_std;
                            let transformed = normalized * self.norm_weight[i] + self.norm_bias[i];
                            sum += transformed * self.linear_weight[i * self.out_features + j];
                        }
                        sum
                    })
                    .collect::<Vec<f32>>()
            })
            .collect();
        let mut output_shape = shape[..shape.len() - 1].to_vec();
        output_shape.push(self.out_features);
        Tensor::from_vec(output_shape, output)
    }

    /// Size of the normalized/projected-from dimension.
    #[must_use]
    pub fn feature_dim(&self) -> usize {
        self.feature_dim
    }

    /// Output feature count.
    #[must_use]
    pub fn out_features(&self) -> usize {
        self.out_features
    }

    /// Mutable LayerNorm scale (gamma).
    #[must_use]
    pub fn norm_weight_mut(&mut self) -> &mut [f32] {
        &mut self.norm_weight
    }

    /// Mutable LayerNorm shift (beta).
    #[must_use]
    pub fn norm_bias_mut(&mut self) -> &mut [f32] {
        &mut self.norm_bias
    }

    /// Mutable row-major linear weight matrix.
    #[must_use]
    pub fn linear_weight_mut(&mut self) -> &mut [f32] {
        &mut self.linear_weight
    }

    /// Mutable linear bias vector.
    #[must_use]
    pub fn linear_bias_mut(&mut self) -> &mut [f32] {
        &mut self.linear_bias
    }
}
/// Transformer feed-forward block: Linear -> GELU -> Linear.
#[derive(Debug, Clone)]
pub struct FeedForward {
    // Expansion projection: hidden_dim -> intermediate_dim.
    fc1: Linear,
    // Contraction projection: intermediate_dim -> hidden_dim.
    fc2: Linear,
    // Model (input/output) width.
    hidden_dim: usize,
    // Inner MLP width.
    intermediate_dim: usize,
}
impl FeedForward {
    /// Builds the two projections (hidden -> intermediate -> hidden).
    ///
    /// # Errors
    /// Returns `RealizarError::InvalidShape` when either dimension is zero
    /// (propagated from `Linear::new`).
    pub fn new(hidden_dim: usize, intermediate_dim: usize) -> Result<Self> {
        Ok(Self {
            fc1: Linear::new(hidden_dim, intermediate_dim)?,
            fc2: Linear::new(intermediate_dim, hidden_dim)?,
            hidden_dim,
            intermediate_dim,
        })
    }

    /// Runs `fc2(gelu(fc1(input)))`.
    ///
    /// # Errors
    /// Propagates shape errors from the projections and the activation.
    pub fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
        let expanded = self.fc1.forward(input)?;
        let activated = gelu(&expanded)?;
        self.fc2.forward(&activated)
    }

    /// Model width.
    #[must_use]
    pub fn hidden_dim(&self) -> usize {
        self.hidden_dim
    }

    /// Inner MLP width.
    #[must_use]
    pub fn intermediate_dim(&self) -> usize {
        self.intermediate_dim
    }

    /// Mutable access to the expansion projection.
    #[must_use]
    pub fn fc1_mut(&mut self) -> &mut Linear {
        &mut self.fc1
    }

    /// Mutable access to the contraction projection.
    #[must_use]
    pub fn fc2_mut(&mut self) -> &mut Linear {
        &mut self.fc2
    }
}
/// Single-head scaled dot-product attention.
#[derive(Debug, Clone)]
pub struct Attention {
    // Dimensionality of each query/key/value vector.
    head_dim: usize,
    // Precomputed 1/sqrt(head_dim) applied to raw scores.
    scale: f32,
}
impl Attention {
    /// Creates an attention head with scale `1/sqrt(head_dim)`.
    ///
    /// # Errors
    /// Returns `RealizarError::InvalidShape` when `head_dim` is zero.
    pub fn new(head_dim: usize) -> Result<Self> {
        if head_dim == 0 {
            return Err(RealizarError::InvalidShape {
                reason: "head_dim must be > 0".to_string(),
            });
        }
        #[allow(clippy::cast_precision_loss)]
        let scale = 1.0 / (head_dim as f32).sqrt();
        Ok(Self { head_dim, scale })
    }

    /// Standard (non-causal) attention: `softmax(Q.K^T * scale) . V`,
    /// materializing the full `[q_seq, k_seq]` score matrix.
    ///
    /// Tensors are treated as `[seq_len, head_dim]`; a 1-D tensor means
    /// seq_len = 1. NOTE(review): for rank > 2 only `shape[0]` is used as
    /// the sequence length — confirm callers never pass batched inputs.
    ///
    /// # Errors
    /// Returns `RealizarError::InvalidShape` on empty shapes, a last
    /// dimension differing from `head_dim`, or K/V sequence mismatch.
    pub fn forward(
        &self,
        query: &Tensor<f32>,
        key: &Tensor<f32>,
        value: &Tensor<f32>,
    ) -> Result<Tensor<f32>> {
        let q_shape = query.shape();
        let k_shape = key.shape();
        let v_shape = value.shape();
        if q_shape.is_empty() || k_shape.is_empty() || v_shape.is_empty() {
            return Err(RealizarError::InvalidShape {
                reason: "Query, key, value tensors must have at least 1 dimension".to_string(),
            });
        }
        let q_last = q_shape[q_shape.len() - 1];
        let k_last = k_shape[k_shape.len() - 1];
        let v_last = v_shape[v_shape.len() - 1];
        if q_last != self.head_dim || k_last != self.head_dim || v_last != self.head_dim {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Expected head_dim={}, got Q={}, K={}, V={}",
                    self.head_dim, q_last, k_last, v_last
                ),
            });
        }
        // 1-D input is a single position; otherwise dim 0 is the sequence.
        let q_seq_len = if q_shape.len() > 1 { q_shape[0] } else { 1 };
        let k_seq_len = if k_shape.len() > 1 { k_shape[0] } else { 1 };
        let v_seq_len = if v_shape.len() > 1 { v_shape[0] } else { 1 };
        if k_seq_len != v_seq_len {
            return Err(RealizarError::InvalidShape {
                reason: format!("Key seq_len {k_seq_len} != Value seq_len {v_seq_len}"),
            });
        }
        let q_data = query.data();
        let k_data = key.data();
        let v_data = value.data();
        // Scaled scores Q.K^T for every (query, key) pair.
        let mut scores = Vec::with_capacity(q_seq_len * k_seq_len);
        for i in 0..q_seq_len {
            for j in 0..k_seq_len {
                let mut dot = 0.0;
                for k in 0..self.head_dim {
                    dot += q_data[i * self.head_dim + k] * k_data[j * self.head_dim + k];
                }
                scores.push(dot * self.scale);
            }
        }
        // softmax over the key axis (last dim of [q_seq, k_seq]).
        let scores_tensor = Tensor::from_vec(vec![q_seq_len, k_seq_len], scores)?;
        let attn_weights = softmax(&scores_tensor)?;
        let attn_data = attn_weights.data();
        // Weighted sum of value rows per query position.
        let mut output = Vec::with_capacity(q_seq_len * self.head_dim);
        for i in 0..q_seq_len {
            for k in 0..self.head_dim {
                let mut sum = 0.0;
                for j in 0..k_seq_len {
                    sum += attn_data[i * k_seq_len + j] * v_data[j * self.head_dim + k];
                }
                output.push(sum);
            }
        }
        debug_assert!(
            output.iter().all(|&x| x.is_finite()),
            "Attention layer produced NaN or Inf values - check input scaling"
        );
        Tensor::from_vec(vec![q_seq_len, self.head_dim], output)
    }

    /// Per-vector dimensionality of this head.
    #[must_use]
    pub fn head_dim(&self) -> usize {
        self.head_dim
    }

    /// Score scaling factor (1/sqrt(head_dim)).
    #[must_use]
    pub fn scale(&self) -> f32 {
        self.scale
    }

    /// Blocked attention with online softmax (FlashAttention-style): scores
    /// are processed in `block_size` tiles, maintaining a running max and
    /// running sum per query row and rescaling partial outputs when a new
    /// block raises the max. Avoids materializing the full score matrix.
    ///
    /// NOTE(review): if `k_seq_len` is 0, `row_sum` stays 0 and the final
    /// division produces non-finite values — confirm callers guard this.
    ///
    /// # Errors
    /// Same shape conditions as `forward`, plus `block_size` must be > 0.
    pub fn flash_forward(
        &self,
        query: &Tensor<f32>,
        key: &Tensor<f32>,
        value: &Tensor<f32>,
        block_size: usize,
    ) -> Result<Tensor<f32>> {
        if block_size == 0 {
            return Err(RealizarError::InvalidShape {
                reason: "block_size must be > 0".to_string(),
            });
        }
        let q_shape = query.shape();
        let k_shape = key.shape();
        let v_shape = value.shape();
        if q_shape.is_empty() || k_shape.is_empty() || v_shape.is_empty() {
            return Err(RealizarError::InvalidShape {
                reason: "Query, key, value tensors must have at least 1 dimension".to_string(),
            });
        }
        let q_last = q_shape[q_shape.len() - 1];
        let k_last = k_shape[k_shape.len() - 1];
        let v_last = v_shape[v_shape.len() - 1];
        if q_last != self.head_dim || k_last != self.head_dim || v_last != self.head_dim {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Expected head_dim={}, got Q={}, K={}, V={}",
                    self.head_dim, q_last, k_last, v_last
                ),
            });
        }
        let q_seq_len = if q_shape.len() > 1 { q_shape[0] } else { 1 };
        let k_seq_len = if k_shape.len() > 1 { k_shape[0] } else { 1 };
        let v_seq_len = if v_shape.len() > 1 { v_shape[0] } else { 1 };
        if k_seq_len != v_seq_len {
            return Err(RealizarError::InvalidShape {
                reason: format!("Key seq_len {k_seq_len} != Value seq_len {v_seq_len}"),
            });
        }
        let q_data = query.data();
        let k_data = key.data();
        let v_data = value.data();
        // Running (unnormalized) output plus per-row softmax statistics.
        let mut output = vec![0.0; q_seq_len * self.head_dim];
        let mut row_max = vec![f32::NEG_INFINITY; q_seq_len];
        let mut row_sum = vec![0.0; q_seq_len];
        let num_kv_blocks = k_seq_len.div_ceil(block_size);
        // KV blocks outermost: each (kv_block, q_row) pair is visited once,
        // so the online-softmax update below stays exact.
        for kv_block_idx in 0..num_kv_blocks {
            let kv_start = kv_block_idx * block_size;
            let kv_end = (kv_start + block_size).min(k_seq_len);
            let kv_block_len = kv_end - kv_start;
            let num_q_blocks = q_seq_len.div_ceil(block_size);
            for q_block_idx in 0..num_q_blocks {
                let q_start = q_block_idx * block_size;
                let q_end = (q_start + block_size).min(q_seq_len);
                // Scaled Q.K^T tile for this (q_block, kv_block) pair.
                let mut scores = vec![0.0; (q_end - q_start) * kv_block_len];
                for (i, q_idx) in (q_start..q_end).enumerate() {
                    for (j, kv_idx) in (kv_start..kv_end).enumerate() {
                        let mut dot = 0.0;
                        for k in 0..self.head_dim {
                            dot += q_data[q_idx * self.head_dim + k]
                                * k_data[kv_idx * self.head_dim + k];
                        }
                        scores[i * kv_block_len + j] = dot * self.scale;
                    }
                }
                // Online softmax: fold this tile into each row's running
                // max/sum, rescaling previously accumulated output by
                // exp(old_max - new_max) when the max increases.
                for (i, q_idx) in (q_start..q_end).enumerate() {
                    let block_max = (0..kv_block_len)
                        .map(|j| scores[i * kv_block_len + j])
                        .fold(f32::NEG_INFINITY, f32::max);
                    let old_max = row_max[q_idx];
                    let new_max = old_max.max(block_max);
                    row_max[q_idx] = new_max;
                    let mut block_sum = 0.0;
                    for j in 0..kv_block_len {
                        let exp_val = (scores[i * kv_block_len + j] - new_max).exp();
                        scores[i * kv_block_len + j] = exp_val;
                        block_sum += exp_val;
                    }
                    // exp(-inf - finite) = 0 on the first block, correctly
                    // zeroing the (already zero) partial output.
                    let scale_factor = (old_max - new_max).exp();
                    for k in 0..self.head_dim {
                        output[q_idx * self.head_dim + k] *= scale_factor;
                    }
                    row_sum[q_idx] = row_sum[q_idx] * scale_factor + block_sum;
                }
                // Accumulate this tile's (unnormalized) weighted values.
                for (i, q_idx) in (q_start..q_end).enumerate() {
                    for k in 0..self.head_dim {
                        let mut weighted_sum = 0.0;
                        for (j, kv_idx) in (kv_start..kv_end).enumerate() {
                            weighted_sum +=
                                scores[i * kv_block_len + j] * v_data[kv_idx * self.head_dim + k];
                        }
                        output[q_idx * self.head_dim + k] += weighted_sum;
                    }
                }
            }
        }
        // Final normalization by each row's softmax denominator.
        for i in 0..q_seq_len {
            for k in 0..self.head_dim {
                output[i * self.head_dim + k] /= row_sum[i];
            }
        }
        Tensor::from_vec(vec![q_seq_len, self.head_dim], output)
    }

    /// Variant of [`Attention::flash_forward`] that processes one query row
    /// at a time (no q-blocking) and uses a SIMD dot product for Q.K.
    /// Same online-softmax math; same caveat about `k_seq_len == 0`.
    ///
    /// # Errors
    /// Same conditions as `flash_forward`.
    #[allow(clippy::similar_names)]
    pub fn flash_forward_v2(
        &self,
        query: &Tensor<f32>,
        key: &Tensor<f32>,
        value: &Tensor<f32>,
        block_size: usize,
    ) -> Result<Tensor<f32>> {
        if block_size == 0 {
            return Err(RealizarError::InvalidShape {
                reason: "block_size must be > 0".to_string(),
            });
        }
        let q_shape = query.shape();
        let k_shape = key.shape();
        let v_shape = value.shape();
        if q_shape.is_empty() || k_shape.is_empty() || v_shape.is_empty() {
            return Err(RealizarError::InvalidShape {
                reason: "Query, key, value tensors must have at least 1 dimension".to_string(),
            });
        }
        let q_last = q_shape[q_shape.len() - 1];
        let k_last = k_shape[k_shape.len() - 1];
        let v_last = v_shape[v_shape.len() - 1];
        if q_last != self.head_dim || k_last != self.head_dim || v_last != self.head_dim {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Expected head_dim={}, got Q={}, K={}, V={}",
                    self.head_dim, q_last, k_last, v_last
                ),
            });
        }
        let q_seq_len = if q_shape.len() > 1 { q_shape[0] } else { 1 };
        let k_seq_len = if k_shape.len() > 1 { k_shape[0] } else { 1 };
        let v_seq_len = if v_shape.len() > 1 { v_shape[0] } else { 1 };
        if k_seq_len != v_seq_len {
            return Err(RealizarError::InvalidShape {
                reason: format!("Key seq_len {k_seq_len} != Value seq_len {v_seq_len}"),
            });
        }
        let q_data = query.data();
        let k_data = key.data();
        let v_data = value.data();
        let head_dim = self.head_dim;
        let scale = self.scale;
        let mut output = vec![0.0; q_seq_len * head_dim];
        let mut row_max = vec![f32::NEG_INFINITY; q_seq_len];
        let mut row_sum = vec![0.0; q_seq_len];
        let num_kv_blocks = k_seq_len.div_ceil(block_size);
        for kv_block_idx in 0..num_kv_blocks {
            let kv_start = kv_block_idx * block_size;
            let kv_end = (kv_start + block_size).min(k_seq_len);
            let kv_block_len = kv_end - kv_start;
            for q_idx in 0..q_seq_len {
                // Scaled scores for this row against the current KV block.
                let mut scores = Vec::with_capacity(kv_block_len);
                for kv_idx in kv_start..kv_end {
                    let dot = Self::simd_dot_product(
                        &q_data[q_idx * head_dim..(q_idx + 1) * head_dim],
                        &k_data[kv_idx * head_dim..(kv_idx + 1) * head_dim],
                    );
                    scores.push(dot * scale);
                }
                // Online softmax update (see flash_forward for the scheme).
                let block_max = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
                let old_max = row_max[q_idx];
                let new_max = old_max.max(block_max);
                row_max[q_idx] = new_max;
                let mut block_sum = 0.0;
                for score in &mut scores {
                    let exp_val = (*score - new_max).exp();
                    *score = exp_val;
                    block_sum += exp_val;
                }
                let scale_factor = (old_max - new_max).exp();
                for k in 0..head_dim {
                    output[q_idx * head_dim + k] *= scale_factor;
                }
                row_sum[q_idx] = row_sum[q_idx] * scale_factor + block_sum;
                for (j, kv_idx) in (kv_start..kv_end).enumerate() {
                    let weight = scores[j];
                    for k in 0..head_dim {
                        output[q_idx * head_dim + k] += weight * v_data[kv_idx * head_dim + k];
                    }
                }
            }
        }
        // Normalize by the softmax denominator once per row.
        for i in 0..q_seq_len {
            let inv_sum = 1.0 / row_sum[i];
            for k in 0..head_dim {
                output[i * head_dim + k] *= inv_sum;
            }
        }
        Tensor::from_vec(vec![q_seq_len, self.head_dim], output)
    }

    /// Dot product of two f32 slices, dispatching to AVX2 when the build
    /// targets it, otherwise to the scalar fallback.
    #[inline]
    fn simd_dot_product(a: &[f32], b: &[f32]) -> f32 {
        #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
        {
            Self::simd_dot_avx2(a, b)
        }
        #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
        {
            Self::scalar_dot_product(a, b)
        }
    }

    /// AVX2 dot product: 8-lane FMA accumulation over `min(a.len, b.len)`
    /// elements, horizontal reduction, then a scalar tail loop.
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    #[inline]
    #[allow(clippy::wildcard_imports)]
    fn simd_dot_avx2(a: &[f32], b: &[f32]) -> f32 {
        use std::arch::x86_64::*;
        let len = a.len().min(b.len());
        let chunks = len / 8;
        let remainder = len % 8;
        // SAFETY: every load reads 8 f32s starting at offset i*8 with
        // i < chunks, so the highest index touched is chunks*8 - 1 < len,
        // in-bounds for both slices; avx2/fma availability is guaranteed by
        // the cfg(target_feature = "avx2") gate on this function.
        let simd_sum = unsafe {
            let mut acc = _mm256_setzero_ps();
            for i in 0..chunks {
                let a_vec = _mm256_loadu_ps(a.as_ptr().add(i * 8));
                let b_vec = _mm256_loadu_ps(b.as_ptr().add(i * 8));
                acc = _mm256_fmadd_ps(a_vec, b_vec, acc);
            }
            // Horizontal sum: 256 -> 128 -> 64 -> 32 bits.
            let hi = _mm256_extractf128_ps(acc, 1);
            let lo = _mm256_castps256_ps128(acc);
            let sum128 = _mm_add_ps(lo, hi);
            let hi64 = _mm_movehl_ps(sum128, sum128);
            let sum64 = _mm_add_ps(sum128, hi64);
            let hi32 = _mm_shuffle_ps(sum64, sum64, 0x55);
            let sum32 = _mm_add_ss(sum64, hi32);
            _mm_cvtss_f32(sum32)
        };
        // Scalar tail for the last len % 8 elements.
        let remainder_sum: f32 = (0..remainder)
            .map(|i| a[chunks * 8 + i] * b[chunks * 8 + i])
            .sum();
        simd_sum + remainder_sum
    }

    /// Portable fallback dot product (pairs up to the shorter slice).
    #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2")))]
    #[inline]
    fn scalar_dot_product(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }

    /// Rayon-parallel flash attention: each query row carries its own
    /// running max/sum, so rows are processed independently and in parallel.
    /// Same online-softmax math and `k_seq_len == 0` caveat as
    /// [`Attention::flash_forward`].
    ///
    /// # Errors
    /// Same conditions as `flash_forward`.
    #[allow(clippy::similar_names)]
    pub fn flash_forward_parallel(
        &self,
        query: &Tensor<f32>,
        key: &Tensor<f32>,
        value: &Tensor<f32>,
        block_size: usize,
    ) -> Result<Tensor<f32>> {
        use rayon::prelude::*;
        if block_size == 0 {
            return Err(RealizarError::InvalidShape {
                reason: "block_size must be > 0".to_string(),
            });
        }
        let q_shape = query.shape();
        let k_shape = key.shape();
        let v_shape = value.shape();
        if q_shape.is_empty() || k_shape.is_empty() || v_shape.is_empty() {
            return Err(RealizarError::InvalidShape {
                reason: "Query, key, value tensors must have at least 1 dimension".to_string(),
            });
        }
        let q_last = q_shape[q_shape.len() - 1];
        let k_last = k_shape[k_shape.len() - 1];
        let v_last = v_shape[v_shape.len() - 1];
        if q_last != self.head_dim || k_last != self.head_dim || v_last != self.head_dim {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Expected head_dim={}, got Q={}, K={}, V={}",
                    self.head_dim, q_last, k_last, v_last
                ),
            });
        }
        let q_seq_len = if q_shape.len() > 1 { q_shape[0] } else { 1 };
        let k_seq_len = if k_shape.len() > 1 { k_shape[0] } else { 1 };
        let v_seq_len = if v_shape.len() > 1 { v_shape[0] } else { 1 };
        if k_seq_len != v_seq_len {
            return Err(RealizarError::InvalidShape {
                reason: format!("Key seq_len {k_seq_len} != Value seq_len {v_seq_len}"),
            });
        }
        let q_data = query.data();
        let k_data = key.data();
        let v_data = value.data();
        let head_dim = self.head_dim;
        let scale = self.scale;
        // One parallel task per query row; flat_map keeps row order.
        let output: Vec<f32> = (0..q_seq_len)
            .into_par_iter()
            .flat_map(|q_idx| {
                let mut row_output = vec![0.0; head_dim];
                let mut row_max = f32::NEG_INFINITY;
                let mut row_sum = 0.0;
                let num_kv_blocks = k_seq_len.div_ceil(block_size);
                for kv_block_idx in 0..num_kv_blocks {
                    let kv_start = kv_block_idx * block_size;
                    let kv_end = (kv_start + block_size).min(k_seq_len);
                    let mut scores: Vec<f32> = (kv_start..kv_end)
                        .map(|kv_idx| {
                            let dot = Self::simd_dot_product(
                                &q_data[q_idx * head_dim..(q_idx + 1) * head_dim],
                                &k_data[kv_idx * head_dim..(kv_idx + 1) * head_dim],
                            );
                            dot * scale
                        })
                        .collect();
                    // Online softmax update (see flash_forward).
                    let block_max = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
                    let old_max = row_max;
                    let new_max = old_max.max(block_max);
                    row_max = new_max;
                    let mut block_sum = 0.0;
                    for score in &mut scores {
                        let exp_val = (*score - new_max).exp();
                        *score = exp_val;
                        block_sum += exp_val;
                    }
                    let scale_factor = (old_max - new_max).exp();
                    for out_val in &mut row_output {
                        *out_val *= scale_factor;
                    }
                    row_sum = row_sum * scale_factor + block_sum;
                    for (j, kv_idx) in (kv_start..kv_end).enumerate() {
                        let weight = scores[j];
                        for k in 0..head_dim {
                            row_output[k] += weight * v_data[kv_idx * head_dim + k];
                        }
                    }
                }
                // Normalize this row before emitting it.
                let inv_sum = 1.0 / row_sum;
                for out_val in &mut row_output {
                    *out_val *= inv_sum;
                }
                row_output
            })
            .collect();
        Tensor::from_vec(vec![q_seq_len, self.head_dim], output)
    }
}
/// Attention restricted to a sliding window of keys around each query.
#[derive(Debug, Clone)]
pub struct SlidingWindowAttention {
    // Dimensionality of each query/key/value vector.
    head_dim: usize,
    // Precomputed 1/sqrt(head_dim) applied to raw scores.
    scale: f32,
    // Maximum number of keys attended to per query (causal path).
    window_size: usize,
}
impl SlidingWindowAttention {
    /// Creates a sliding-window head with scale `1/sqrt(head_dim)`.
    ///
    /// # Errors
    /// Returns `RealizarError::InvalidShape` when `head_dim` or
    /// `window_size` is zero.
    pub fn new(head_dim: usize, window_size: usize) -> Result<Self> {
        if head_dim == 0 {
            return Err(RealizarError::InvalidShape {
                reason: "head_dim must be > 0".to_string(),
            });
        }
        if window_size == 0 {
            return Err(RealizarError::InvalidShape {
                reason: "window_size must be > 0".to_string(),
            });
        }
        #[allow(clippy::cast_precision_loss)]
        let scale = 1.0 / (head_dim as f32).sqrt();
        Ok(Self {
            head_dim,
            scale,
            window_size,
        })
    }

    /// Causal sliding-window attention: query `i` attends to keys in
    /// `[max(0, i+1-window_size), i]`, with a softmax computed only over
    /// that window. Tensors are treated as `[seq_len, head_dim]` (1-D means
    /// seq_len = 1).
    ///
    /// # Errors
    /// Returns `RealizarError::InvalidShape` on empty shapes, a last
    /// dimension differing from `head_dim`, or K/V sequence mismatch.
    pub fn forward(
        &self,
        query: &Tensor<f32>,
        key: &Tensor<f32>,
        value: &Tensor<f32>,
    ) -> Result<Tensor<f32>> {
        let q_shape = query.shape();
        let k_shape = key.shape();
        let v_shape = value.shape();
        if q_shape.is_empty() || k_shape.is_empty() || v_shape.is_empty() {
            return Err(RealizarError::InvalidShape {
                reason: "Query, key, value tensors must have at least 1 dimension".to_string(),
            });
        }
        let q_last = q_shape[q_shape.len() - 1];
        let k_last = k_shape[k_shape.len() - 1];
        let v_last = v_shape[v_shape.len() - 1];
        if q_last != self.head_dim || k_last != self.head_dim || v_last != self.head_dim {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Expected head_dim={}, got Q={}, K={}, V={}",
                    self.head_dim, q_last, k_last, v_last
                ),
            });
        }
        let q_seq_len = if q_shape.len() > 1 { q_shape[0] } else { 1 };
        let k_seq_len = if k_shape.len() > 1 { k_shape[0] } else { 1 };
        let v_seq_len = if v_shape.len() > 1 { v_shape[0] } else { 1 };
        if k_seq_len != v_seq_len {
            return Err(RealizarError::InvalidShape {
                reason: format!("Key seq_len {k_seq_len} != Value seq_len {v_seq_len}"),
            });
        }
        let q_data = query.data();
        let k_data = key.data();
        let v_data = value.data();
        let mut output = Vec::with_capacity(q_seq_len * self.head_dim);
        for i in 0..q_seq_len {
            // Causal window: up to window_size keys ending at position i.
            let window_end = (i + 1).min(k_seq_len);
            let window_start = window_end.saturating_sub(self.window_size);
            let window_len = window_end - window_start;
            if window_len == 0 {
                // No visible keys (k_seq_len == 0): emit zeros for this row.
                output.extend(std::iter::repeat(0.0).take(self.head_dim));
                continue;
            }
            // Scaled Q.K^T restricted to the window.
            let mut scores = Vec::with_capacity(window_len);
            for j in window_start..window_end {
                let mut dot = 0.0;
                for k in 0..self.head_dim {
                    dot += q_data[i * self.head_dim + k] * k_data[j * self.head_dim + k];
                }
                scores.push(dot * self.scale);
            }
            // Max-shifted softmax over the window, normalized in place.
            let max_score = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            let mut exp_sum = 0.0;
            for score in &mut scores {
                let exp_val = (*score - max_score).exp();
                *score = exp_val;
                exp_sum += exp_val;
            }
            let inv_sum = 1.0 / exp_sum;
            for score in &mut scores {
                *score *= inv_sum;
            }
            // Weighted sum of the windowed value rows.
            for k in 0..self.head_dim {
                let mut sum = 0.0;
                for (idx, j) in (window_start..window_end).enumerate() {
                    sum += scores[idx] * v_data[j * self.head_dim + k];
                }
                output.push(sum);
            }
        }
        Tensor::from_vec(vec![q_seq_len, self.head_dim], output)
    }

    /// Windowed attention with selectable masking: `causal == true`
    /// delegates to [`Self::forward`]; otherwise a symmetric window of
    /// `window_size / 2` keys on each side of the query (so up to
    /// `2 * (window_size / 2) + 1` keys) is used.
    ///
    /// # Errors
    /// Same conditions as `forward`.
    pub fn forward_with_mask(
        &self,
        query: &Tensor<f32>,
        key: &Tensor<f32>,
        value: &Tensor<f32>,
        causal: bool,
    ) -> Result<Tensor<f32>> {
        if causal {
            return self.forward(query, key, value);
        }
        let q_shape = query.shape();
        let k_shape = key.shape();
        let v_shape = value.shape();
        if q_shape.is_empty() || k_shape.is_empty() || v_shape.is_empty() {
            return Err(RealizarError::InvalidShape {
                reason: "Query, key, value tensors must have at least 1 dimension".to_string(),
            });
        }
        let q_last = q_shape[q_shape.len() - 1];
        let k_last = k_shape[k_shape.len() - 1];
        let v_last = v_shape[v_shape.len() - 1];
        if q_last != self.head_dim || k_last != self.head_dim || v_last != self.head_dim {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Expected head_dim={}, got Q={}, K={}, V={}",
                    self.head_dim, q_last, k_last, v_last
                ),
            });
        }
        let q_seq_len = if q_shape.len() > 1 { q_shape[0] } else { 1 };
        let k_seq_len = if k_shape.len() > 1 { k_shape[0] } else { 1 };
        let v_seq_len = if v_shape.len() > 1 { v_shape[0] } else { 1 };
        if k_seq_len != v_seq_len {
            return Err(RealizarError::InvalidShape {
                reason: format!("Key seq_len {k_seq_len} != Value seq_len {v_seq_len}"),
            });
        }
        let q_data = query.data();
        let k_data = key.data();
        let v_data = value.data();
        let mut output = Vec::with_capacity(q_seq_len * self.head_dim);
        let half_window = self.window_size / 2;
        for i in 0..q_seq_len {
            // Symmetric (bidirectional) window centered on position i.
            let window_start = i.saturating_sub(half_window);
            let window_end = (i + half_window + 1).min(k_seq_len);
            let window_len = window_end - window_start;
            if window_len == 0 {
                output.extend(std::iter::repeat(0.0).take(self.head_dim));
                continue;
            }
            let mut scores = Vec::with_capacity(window_len);
            for j in window_start..window_end {
                let mut dot = 0.0;
                for k in 0..self.head_dim {
                    dot += q_data[i * self.head_dim + k] * k_data[j * self.head_dim + k];
                }
                scores.push(dot * self.scale);
            }
            let max_score = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
            let mut exp_sum = 0.0;
            for score in &mut scores {
                let exp_val = (*score - max_score).exp();
                *score = exp_val;
                exp_sum += exp_val;
            }
            let inv_sum = 1.0 / exp_sum;
            for score in &mut scores {
                *score *= inv_sum;
            }
            for k in 0..self.head_dim {
                let mut sum = 0.0;
                for (idx, j) in (window_start..window_end).enumerate() {
                    sum += scores[idx] * v_data[j * self.head_dim + k];
                }
                output.push(sum);
            }
        }
        Tensor::from_vec(vec![q_seq_len, self.head_dim], output)
    }

    /// Per-vector dimensionality of this head.
    #[must_use]
    pub fn head_dim(&self) -> usize {
        self.head_dim
    }

    /// Score scaling factor (1/sqrt(head_dim)).
    #[must_use]
    pub fn scale(&self) -> f32 {
        self.scale
    }

    /// Maximum causal window length.
    #[must_use]
    pub fn window_size(&self) -> usize {
        self.window_size
    }

    /// Number of keys a query at `position` actually sees (causal path).
    #[must_use]
    pub fn effective_context(&self, position: usize, seq_len: usize) -> usize {
        let window_end = (position + 1).min(seq_len);
        let window_start = window_end.saturating_sub(self.window_size);
        window_end - window_start
    }

    /// Fraction of full-attention score memory this window requires
    /// (1.0 for an empty sequence).
    #[must_use]
    pub fn memory_ratio(&self, seq_len: usize) -> f32 {
        if seq_len == 0 {
            return 1.0;
        }
        #[allow(clippy::cast_precision_loss)]
        {
            (self.window_size.min(seq_len) as f32) / (seq_len as f32)
        }
    }
}
#[derive(Debug, Clone)]
pub struct FusedQKVAttention {
head_dim: usize,
hidden_dim: usize,
num_heads: usize,
scale: f32,
w_q: Vec<f32>,
w_k: Vec<f32>,
w_v: Vec<f32>,
w_o: Vec<f32>,
}
impl FusedQKVAttention {
    /// Creates a fused-QKV attention layer.
    ///
    /// All four projection matrices (`w_q`, `w_k`, `w_v`, `w_o`) are
    /// `hidden_dim x hidden_dim` and deterministically initialized with a
    /// small sinusoidal pattern (placeholder weights, not trained values).
    ///
    /// # Errors
    /// Returns `InvalidShape` when `head_dim` or `hidden_dim` is zero, or
    /// when `hidden_dim` is not divisible by `head_dim`.
    pub fn new(head_dim: usize, hidden_dim: usize) -> Result<Self> {
        if head_dim == 0 {
            return Err(RealizarError::InvalidShape {
                reason: "head_dim must be > 0".to_string(),
            });
        }
        if hidden_dim == 0 {
            return Err(RealizarError::InvalidShape {
                reason: "hidden_dim must be > 0".to_string(),
            });
        }
        if hidden_dim % head_dim != 0 {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "hidden_dim ({}) must be divisible by head_dim ({})",
                    hidden_dim, head_dim
                ),
            });
        }
        let num_heads = hidden_dim / head_dim;
        // Standard 1/sqrt(d_k) attention scaling.
        #[allow(clippy::cast_precision_loss)]
        let scale = 1.0 / (head_dim as f32).sqrt();
        let proj_size = hidden_dim * hidden_dim;
        #[allow(clippy::cast_precision_loss)]
        let init_weight = |size: usize| -> Vec<f32> {
            (0..size).map(|i| (i as f32 * 0.001).sin() * 0.02).collect()
        };
        Ok(Self {
            head_dim,
            hidden_dim,
            num_heads,
            scale,
            w_q: init_weight(proj_size),
            w_k: init_weight(proj_size),
            w_v: init_weight(proj_size),
            w_o: init_weight(proj_size),
        })
    }

    /// Causal multi-head self-attention over `input` with `shape[0]` as
    /// `seq_len` and the last dimension equal to `hidden_dim`.
    ///
    /// Pipeline: fused Q/K/V projection -> per-head causal softmax
    /// attention (numerically stabilized by subtracting the max score) ->
    /// output projection through `w_o`.
    ///
    /// # Errors
    /// Returns `InvalidShape` when the input has fewer than 2 dimensions or
    /// its last dimension differs from `hidden_dim`.
    pub fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
        let shape = input.shape();
        if shape.len() < 2 {
            return Err(RealizarError::InvalidShape {
                reason: "Input must have at least 2 dimensions [seq_len, hidden_dim]".to_string(),
            });
        }
        let seq_len = shape[0];
        let input_dim = shape[shape.len() - 1];
        if input_dim != self.hidden_dim {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Input hidden_dim ({}) doesn't match layer hidden_dim ({})",
                    input_dim, self.hidden_dim
                ),
            });
        }
        let data = input.data();
        // Fused projection: one pass over the input row fills Q, K and V.
        let mut q = vec![0.0f32; seq_len * self.hidden_dim];
        let mut k = vec![0.0f32; seq_len * self.hidden_dim];
        let mut v = vec![0.0f32; seq_len * self.hidden_dim];
        for i in 0..seq_len {
            for j in 0..self.hidden_dim {
                let mut sum_q = 0.0f32;
                let mut sum_k = 0.0f32;
                let mut sum_v = 0.0f32;
                for l in 0..self.hidden_dim {
                    let inp = data[i * self.hidden_dim + l];
                    sum_q += inp * self.w_q[l * self.hidden_dim + j];
                    sum_k += inp * self.w_k[l * self.hidden_dim + j];
                    sum_v += inp * self.w_v[l * self.hidden_dim + j];
                }
                q[i * self.hidden_dim + j] = sum_q;
                k[i * self.hidden_dim + j] = sum_k;
                v[i * self.hidden_dim + j] = sum_v;
            }
        }
        let mut output = vec![0.0f32; seq_len * self.hidden_dim];
        for head in 0..self.num_heads {
            let head_offset = head * self.head_dim;
            for i in 0..seq_len {
                // Compute each scaled Q·K score exactly once, tracking the
                // running max for a numerically stable softmax. (Previously
                // every dot product was computed twice: once in a max-only
                // pass, then again for the exponentials.)
                let mut scores = vec![0.0f32; i + 1];
                let mut max_score = f32::NEG_INFINITY;
                for (j, score) in scores.iter_mut().enumerate() {
                    let mut dot = 0.0f32;
                    for d in 0..self.head_dim {
                        let q_idx = i * self.hidden_dim + head_offset + d;
                        let k_idx = j * self.hidden_dim + head_offset + d;
                        dot += q[q_idx] * k[k_idx];
                    }
                    *score = dot * self.scale;
                    if *score > max_score {
                        max_score = *score;
                    }
                }
                let mut sum_exp = 0.0f32;
                for score in &mut scores {
                    *score = (*score - max_score).exp();
                    sum_exp += *score;
                }
                // Defensive guard; sum_exp >= 1 in practice because the max
                // score contributes exp(0) = 1.
                if sum_exp > 0.0 {
                    for d in 0..self.head_dim {
                        let mut weighted_sum = 0.0f32;
                        for (j, &score) in scores.iter().enumerate() {
                            let v_idx = j * self.hidden_dim + head_offset + d;
                            weighted_sum += (score / sum_exp) * v[v_idx];
                        }
                        output[i * self.hidden_dim + head_offset + d] = weighted_sum;
                    }
                }
            }
        }
        // Final output projection through w_o.
        let mut final_output = vec![0.0f32; seq_len * self.hidden_dim];
        for i in 0..seq_len {
            for j in 0..self.hidden_dim {
                let mut sum = 0.0f32;
                for l in 0..self.hidden_dim {
                    sum += output[i * self.hidden_dim + l] * self.w_o[l * self.hidden_dim + j];
                }
                final_output[i * self.hidden_dim + j] = sum;
            }
        }
        Tensor::from_vec(vec![seq_len, self.hidden_dim], final_output)
    }

    /// Per-head dimensionality.
    #[must_use]
    pub fn head_dim(&self) -> usize {
        self.head_dim
    }

    /// Model width (num_heads * head_dim).
    #[must_use]
    pub fn hidden_dim(&self) -> usize {
        self.hidden_dim
    }

    /// Number of attention heads.
    #[must_use]
    pub fn num_heads(&self) -> usize {
        self.num_heads
    }

    /// Mutable access to the query projection weights.
    pub fn w_q_mut(&mut self) -> &mut [f32] {
        &mut self.w_q
    }

    /// Mutable access to the key projection weights.
    pub fn w_k_mut(&mut self) -> &mut [f32] {
        &mut self.w_k
    }

    /// Mutable access to the value projection weights.
    pub fn w_v_mut(&mut self) -> &mut [f32] {
        &mut self.w_v
    }

    /// Mutable access to the output projection weights.
    pub fn w_o_mut(&mut self) -> &mut [f32] {
        &mut self.w_o
    }
}
/// Multi-head attention with separate Q/K/V/O `Linear` projections,
/// supporting MHA (num_kv_heads == num_heads), MQA (num_kv_heads == 1),
/// and GQA (in between) via grouped key/value heads.
#[derive(Debug, Clone)]
pub struct MultiHeadAttention {
// Total number of query heads.
num_heads: usize,
// Number of key/value heads shared across query-head groups.
num_kv_heads: usize,
// Dimensionality of each head (hidden_dim / num_heads).
head_dim: usize,
// Model width.
hidden_dim: usize,
q_proj: Linear,
// K/V projections output num_kv_heads * head_dim features.
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
// Single-head attention kernel applied per head slice.
attention: Attention,
}
impl MultiHeadAttention {
/// Builds a multi-head attention layer.
///
/// `num_kv_heads` selects the flavor: equal to `num_heads` for standard
/// MHA, 1 for multi-query (MQA), in between for grouped-query (GQA).
///
/// # Errors
/// Returns `InvalidShape` when any count is zero, when
/// `num_kv_heads > num_heads`, or when the divisibility constraints
/// (`hidden_dim % num_heads`, `num_heads % num_kv_heads`) are violated.
pub fn new(hidden_dim: usize, num_heads: usize, num_kv_heads: usize) -> Result<Self> {
if hidden_dim == 0 {
return Err(RealizarError::InvalidShape {
reason: "hidden_dim must be > 0".to_string(),
});
}
if num_heads == 0 {
return Err(RealizarError::InvalidShape {
reason: "num_heads must be > 0".to_string(),
});
}
if num_kv_heads == 0 {
return Err(RealizarError::InvalidShape {
reason: "num_kv_heads must be > 0".to_string(),
});
}
if num_kv_heads > num_heads {
return Err(RealizarError::InvalidShape {
reason: format!(
"num_kv_heads {num_kv_heads} cannot be greater than num_heads {num_heads}"
),
});
}
if hidden_dim % num_heads != 0 {
return Err(RealizarError::InvalidShape {
reason: format!(
"hidden_dim {hidden_dim} must be divisible by num_heads {num_heads}"
),
});
}
if num_heads % num_kv_heads != 0 {
return Err(RealizarError::InvalidShape {
reason: format!(
"num_heads {num_heads} must be divisible by num_kv_heads {num_kv_heads}"
),
});
}
let head_dim = hidden_dim / num_heads;
let q_proj = Linear::new(hidden_dim, hidden_dim)?;
// K/V are narrower than Q when num_kv_heads < num_heads.
let kv_dim = num_kv_heads * head_dim;
let k_proj = Linear::new(hidden_dim, kv_dim)?;
let v_proj = Linear::new(hidden_dim, kv_dim)?;
let o_proj = Linear::new(hidden_dim, hidden_dim)?;
let attention = Attention::new(head_dim)?;
Ok(Self {
num_heads,
num_kv_heads,
head_dim,
hidden_dim,
q_proj,
k_proj,
v_proj,
o_proj,
attention,
})
}
/// Standard multi-head attention (one KV head per query head).
pub fn mha(hidden_dim: usize, num_heads: usize) -> Result<Self> {
Self::new(hidden_dim, num_heads, num_heads)
}
/// Multi-query attention (a single shared KV head).
pub fn mqa(hidden_dim: usize, num_heads: usize) -> Result<Self> {
Self::new(hidden_dim, num_heads, 1)
}
/// Grouped-query attention with an explicit KV head count.
pub fn gqa(hidden_dim: usize, num_heads: usize, num_kv_heads: usize) -> Result<Self> {
Self::new(hidden_dim, num_heads, num_kv_heads)
}
/// Runs attention over a 2-D input `[seq_len, hidden_dim]`: project to
/// Q/K/V, slice per head (query heads within a group share a KV head),
/// run the single-head `attention` kernel per head, concatenate head
/// outputs, and apply the output projection.
///
/// # Errors
/// Returns `InvalidShape` for non-2-D input or a hidden-dim mismatch;
/// propagates errors from the projections and attention kernel.
pub fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
let shape = input.shape();
if shape.len() != 2 {
return Err(RealizarError::InvalidShape {
reason: format!("Expected 2D tensor [seq_len, hidden_dim], got shape {shape:?}"),
});
}
let seq_len = shape[0];
let input_dim = shape[1];
if input_dim != self.hidden_dim {
return Err(RealizarError::InvalidShape {
reason: format!("Expected hidden_dim={}, got {}", self.hidden_dim, input_dim),
});
}
let q = self.q_proj.forward(input)?; let k = self.k_proj.forward(input)?; let v = self.v_proj.forward(input)?;
let q_data = q.data();
let k_data = k.data();
let v_data = v.data();
// How many query heads share each KV head.
let heads_per_group = self.num_heads / self.num_kv_heads;
let mut head_outputs = Vec::with_capacity(self.num_heads);
for head_idx in 0..self.num_heads {
// Gather this head's contiguous slice of Q for every position.
let mut q_head_data = Vec::with_capacity(seq_len * self.head_dim);
for seq_idx in 0..seq_len {
let q_row_start = seq_idx * self.hidden_dim;
let head_start = q_row_start + head_idx * self.head_dim;
for offset in 0..self.head_dim {
q_head_data.push(q_data[head_start + offset]);
}
}
let q_head = Tensor::from_vec(vec![seq_len, self.head_dim], q_head_data)?;
// Map the query head to its (possibly shared) KV head.
let kv_head_idx = head_idx / heads_per_group;
let kv_dim = self.num_kv_heads * self.head_dim;
let mut k_head_data = Vec::with_capacity(seq_len * self.head_dim);
let mut v_head_data = Vec::with_capacity(seq_len * self.head_dim);
for seq_idx in 0..seq_len {
// K/V rows are kv_dim wide, not hidden_dim wide.
let kv_row_start = seq_idx * kv_dim;
let kv_head_start = kv_row_start + kv_head_idx * self.head_dim;
for offset in 0..self.head_dim {
k_head_data.push(k_data[kv_head_start + offset]);
v_head_data.push(v_data[kv_head_start + offset]);
}
}
let k_head = Tensor::from_vec(vec![seq_len, self.head_dim], k_head_data)?;
let v_head = Tensor::from_vec(vec![seq_len, self.head_dim], v_head_data)?;
let head_output = self.attention.forward(&q_head, &k_head, &v_head)?;
head_outputs.push(head_output);
}
// Concatenate per-head outputs back to [seq_len, hidden_dim].
let mut concat_data = Vec::with_capacity(seq_len * self.hidden_dim);
for seq_idx in 0..seq_len {
for head_output in &head_outputs {
let head_output_data = head_output.data();
let head_row_start = seq_idx * self.head_dim;
for offset in 0..self.head_dim {
concat_data.push(head_output_data[head_row_start + offset]);
}
}
}
let concat = Tensor::from_vec(vec![seq_len, self.hidden_dim], concat_data)?;
self.o_proj.forward(&concat)
}
/// Number of query heads.
#[must_use]
pub fn num_heads(&self) -> usize {
self.num_heads
}
/// Number of key/value heads.
#[must_use]
pub fn num_kv_heads(&self) -> usize {
self.num_kv_heads
}
/// Dimensionality of each head.
#[must_use]
pub fn head_dim(&self) -> usize {
self.head_dim
}
/// Model width.
#[must_use]
pub fn hidden_dim(&self) -> usize {
self.hidden_dim
}
/// True when configured as multi-query attention.
#[must_use]
pub fn is_mqa(&self) -> bool {
self.num_kv_heads == 1
}
/// True when configured as grouped-query attention.
#[must_use]
pub fn is_gqa(&self) -> bool {
self.num_kv_heads > 1 && self.num_kv_heads < self.num_heads
}
/// True when configured as standard multi-head attention.
#[must_use]
pub fn is_mha(&self) -> bool {
self.num_kv_heads == self.num_heads
}
}
/// Rotary positional embedding (RoPE): rotates adjacent channel pairs by
/// position-dependent angles derived from a geometric frequency schedule.
#[derive(Debug, Clone)]
pub struct RoPE {
// Channel count per vector; must be even.
dim: usize,
// Frequency base (conventionally 10000).
base: f32,
// Precomputed base^(-2i/dim), one entry per channel pair.
inv_freq: Vec<f32>,
}
impl RoPE {
    /// Creates a rotary positional embedding over `dim` channels using the
    /// given frequency `base`.
    ///
    /// # Errors
    /// Returns `InvalidShape` when `dim` is zero or odd.
    pub fn new(dim: usize, base: f32) -> Result<Self> {
        if dim == 0 {
            return Err(RealizarError::InvalidShape {
                reason: "dim must be > 0".to_string(),
            });
        }
        if dim % 2 != 0 {
            return Err(RealizarError::InvalidShape {
                reason: "dim must be even for RoPE".to_string(),
            });
        }
        // One inverse frequency per channel pair: base^(-2i/dim).
        #[allow(clippy::cast_precision_loss)]
        let inv_freq: Vec<f32> = (0..dim / 2)
            .map(|i| base.powf(-2.0 * (i as f32) / (dim as f32)))
            .collect();
        Ok(Self { dim, base, inv_freq })
    }

    /// Convenience constructor using the conventional base of 10000.
    pub fn with_default_base(dim: usize) -> Result<Self> {
        Self::new(dim, 10000.0)
    }

    /// Rotates every `dim`-sized vector in `input` by the angles for
    /// `position`: each adjacent pair `(2i, 2i+1)` is treated as a 2-D
    /// point rotated by `position * inv_freq[i]` radians.
    ///
    /// # Errors
    /// Returns `InvalidShape` when the input has no dimensions or its last
    /// dimension differs from `dim`.
    pub fn forward(&self, input: &Tensor<f32>, position: usize) -> Result<Tensor<f32>> {
        let shape = input.shape();
        if shape.is_empty() {
            return Err(RealizarError::InvalidShape {
                reason: "Input tensor must have at least 1 dimension".to_string(),
            });
        }
        let last_dim = shape[shape.len() - 1];
        if last_dim != self.dim {
            return Err(RealizarError::InvalidShape {
                reason: format!("Expected last dimension {}, got {}", self.dim, last_dim),
            });
        }
        // Precompute (cos, sin) per channel pair for this position.
        #[allow(clippy::cast_precision_loss)]
        let trig: Vec<(f32, f32)> = self
            .inv_freq
            .iter()
            .map(|inv_f| {
                let angle = inv_f * (position as f32);
                (angle.cos(), angle.sin())
            })
            .collect();
        let data = input.data();
        let mut rotated = Vec::with_capacity(data.len());
        for vector in data.chunks_exact(self.dim) {
            for (pair, &(cos_val, sin_val)) in vector.chunks_exact(2).zip(&trig) {
                let (x0, x1) = (pair[0], pair[1]);
                rotated.push(x0 * cos_val - x1 * sin_val);
                rotated.push(x0 * sin_val + x1 * cos_val);
            }
        }
        Tensor::from_vec(shape.to_vec(), rotated)
    }

    /// Channel count per vector.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Frequency base supplied at construction.
    #[must_use]
    pub fn base(&self) -> f32 {
        self.base
    }

    /// Precomputed inverse frequencies, one per channel pair.
    #[must_use]
    pub fn inv_freq(&self) -> &[f32] {
        &self.inv_freq
    }
}
/// Context-extension strategy applied to RoPE frequencies/positions.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum RopeScalingType {
/// Plain RoPE, no scaling.
#[default]
None,
/// Position interpolation: positions are divided by `scale` at apply time.
Linear {
scale: f32,
},
/// NTK-aware scaling: the frequency base is raised by a factor derived
/// from `scale` (positions are left unchanged).
Ntk {
scale: f32,
},
/// NTK scaling where the factor is derived from the ratio
/// `target_max_len / original_max_len`.
DynamicNtk {
original_max_len: usize,
target_max_len: usize,
},
/// YaRN: blends NTK-scaled and interpolated angles per frequency band,
/// with an attention magnitude correction (`attn_factor`); `beta_fast`
/// and `beta_slow` bound the blend ramp.
Yarn {
original_max_len: usize,
target_max_len: usize,
attn_factor: f32,
beta_fast: f32,
beta_slow: f32,
},
}
/// RoPE variant with configurable context-extension scaling (see
/// [`RopeScalingType`]).
#[derive(Debug, Clone)]
pub struct ScaledRoPE {
// Channel count per vector; must be even.
dim: usize,
// Base as supplied by the caller, before any NTK adjustment.
original_base: f32,
// Base actually used to build `inv_freq` (NTK-adjusted when applicable).
scaled_base: f32,
scaling: RopeScalingType,
// Inverse frequencies derived from `scaled_base`.
inv_freq: Vec<f32>,
// Output magnitude correction (1.0 except for YaRN).
mscale: f32,
}
impl ScaledRoPE {
/// Creates a scaled RoPE over `dim` channels with the given `base` and
/// scaling strategy.
///
/// # Errors
/// Returns `InvalidShape` when `dim` is zero or odd.
pub fn new(dim: usize, base: f32, scaling: RopeScalingType) -> Result<Self> {
if dim == 0 {
return Err(RealizarError::InvalidShape {
reason: "dim must be > 0".to_string(),
});
}
if dim % 2 != 0 {
return Err(RealizarError::InvalidShape {
reason: "dim must be even for RoPE".to_string(),
});
}
let (scaled_base, mscale, inv_freq) = Self::compute_frequencies(dim, base, &scaling);
Ok(Self {
dim,
original_base: base,
scaled_base,
scaling,
inv_freq,
mscale,
})
}
/// Convenience constructor using the conventional base of 10000.
pub fn with_default_base(dim: usize, scaling: RopeScalingType) -> Result<Self> {
Self::new(dim, 10000.0, scaling)
}
/// Computes the (possibly NTK-adjusted) base, the magnitude correction
/// `mscale`, and the inverse-frequency table for the given scaling mode.
fn compute_frequencies(
dim: usize,
base: f32,
scaling: &RopeScalingType,
) -> (f32, f32, Vec<f32>) {
let half_dim = dim / 2;
#[allow(clippy::cast_precision_loss)]
let (scaled_base, mscale) = match scaling {
// Linear scaling adjusts positions at apply time, not the base.
RopeScalingType::None | RopeScalingType::Linear { .. } => (base, 1.0),
RopeScalingType::Ntk { scale } => {
// NTK: base' = base * scale^(dim / (dim - 2)).
let dim_f = dim as f32;
let exponent = dim_f / (dim_f - 2.0);
let ntk_base = base * scale.powf(exponent);
(ntk_base, 1.0)
},
RopeScalingType::DynamicNtk {
original_max_len,
target_max_len,
} => {
// Same as Ntk, with scale inferred from the length ratio.
let scale = (*target_max_len as f32) / (*original_max_len as f32);
let dim_f = dim as f32;
let exponent = dim_f / (dim_f - 2.0);
let ntk_base = base * scale.powf(exponent);
(ntk_base, 1.0)
},
RopeScalingType::Yarn {
original_max_len,
target_max_len,
attn_factor,
beta_fast,
beta_slow,
} => {
let scale = (*target_max_len as f32) / (*original_max_len as f32);
let dim_f = dim as f32;
let exponent = dim_f / (dim_f - 2.0);
let ntk_base = base * scale.powf(exponent);
// NOTE(review): a positive attn_factor REPLACES the derived
// mscale rather than multiplying it — confirm that is the
// intended contract for `attn_factor`.
let mscale = if *attn_factor > 0.0 {
*attn_factor
} else {
let log_scale = scale.ln();
let log_orig = (*original_max_len as f32).ln();
(1.0 + log_scale / log_orig).sqrt()
};
// beta_fast / beta_slow are only used in `forward`'s per-band
// ramp; silenced here on purpose.
let _ = (beta_fast, beta_slow);
(ntk_base, mscale)
},
};
let mut inv_freq = Vec::with_capacity(half_dim);
#[allow(clippy::cast_precision_loss)]
for i in 0..half_dim {
let exponent = -2.0 * (i as f32) / (dim as f32);
inv_freq.push(scaled_base.powf(exponent));
}
(scaled_base, mscale, inv_freq)
}
/// Applies the scaled rotary embedding to every `dim`-sized vector in
/// `input` at the given `position`, multiplying outputs by `mscale`.
///
/// For YaRN, each frequency band blends the NTK-scaled angle with a
/// linearly interpolated original-base angle via a wavelength ramp
/// bounded by `beta_fast`/`beta_slow`.
///
/// # Errors
/// Returns `InvalidShape` when the input has no dimensions or its last
/// dimension differs from `dim`.
pub fn forward(&self, input: &Tensor<f32>, position: usize) -> Result<Tensor<f32>> {
let shape = input.shape();
if shape.is_empty() {
return Err(RealizarError::InvalidShape {
reason: "Input tensor must have at least 1 dimension".to_string(),
});
}
let last_dim = shape[shape.len() - 1];
if last_dim != self.dim {
return Err(RealizarError::InvalidShape {
reason: format!("Expected last dimension {}, got {}", self.dim, last_dim),
});
}
let data = input.data();
let num_vectors = data.len() / self.dim;
let mut output = Vec::with_capacity(data.len());
// Only Linear scaling rescales the position itself.
#[allow(clippy::cast_precision_loss)]
let effective_pos = match &self.scaling {
RopeScalingType::None
| RopeScalingType::Ntk { .. }
| RopeScalingType::DynamicNtk { .. }
| RopeScalingType::Yarn { .. } => position as f32,
RopeScalingType::Linear { scale } => (position as f32) / scale,
};
let half_dim = self.dim / 2;
let mut cos_vals = Vec::with_capacity(half_dim);
let mut sin_vals = Vec::with_capacity(half_dim);
#[allow(clippy::cast_precision_loss)]
for (i, inv_f) in self.inv_freq.iter().enumerate() {
let angle = inv_f * effective_pos;
let (cos_val, sin_val) = if let RopeScalingType::Yarn {
original_max_len,
target_max_len,
beta_fast,
beta_slow,
..
} = &self.scaling
{
let freq = 1.0 / inv_f;
let wavelength = 2.0 * std::f32::consts::PI * freq;
// Ramp endpoints in wavelength space; beta_fast > beta_slow
// means high_freq_wavelen < low_freq_wavelen.
let low_freq_wavelen = (*original_max_len as f32) / *beta_slow;
let high_freq_wavelen = (*original_max_len as f32) / *beta_fast;
// ramp = 0: pure NTK angle (short wavelengths / high freq);
// ramp = 1: pure interpolated angle (long wavelengths).
let ramp = if wavelength < high_freq_wavelen {
0.0 } else if wavelength > low_freq_wavelen {
1.0 } else {
(wavelength - high_freq_wavelen) / (low_freq_wavelen - high_freq_wavelen)
};
let scale = (*target_max_len as f32) / (*original_max_len as f32);
let linear_pos = effective_pos / scale;
// Angle the UNSCALED base would give at the interpolated position.
let orig_inv_f = self
.original_base
.powf(-2.0 * (i as f32) / (self.dim as f32));
let linear_angle = orig_inv_f * linear_pos;
let final_angle = angle * (1.0 - ramp) + linear_angle * ramp;
(final_angle.cos(), final_angle.sin())
} else {
(angle.cos(), angle.sin())
};
cos_vals.push(cos_val);
sin_vals.push(sin_val);
}
for vec_idx in 0..num_vectors {
let offset = vec_idx * self.dim;
for i in 0..half_dim {
let x0 = data[offset + 2 * i];
let x1 = data[offset + 2 * i + 1];
let cos_val = cos_vals[i];
let sin_val = sin_vals[i];
// Standard 2-D rotation, scaled by the magnitude correction.
let y0 = (x0 * cos_val - x1 * sin_val) * self.mscale;
let y1 = (x0 * sin_val + x1 * cos_val) * self.mscale;
output.push(y0);
output.push(y1);
}
}
Tensor::from_vec(shape.to_vec(), output)
}
/// Channel count per vector.
#[must_use]
pub fn dim(&self) -> usize {
self.dim
}
/// Base as supplied by the caller.
#[must_use]
pub fn original_base(&self) -> f32 {
self.original_base
}
/// Base actually used for the frequency table.
#[must_use]
pub fn scaled_base(&self) -> f32 {
self.scaled_base
}
/// Active scaling strategy.
#[must_use]
pub fn scaling(&self) -> &RopeScalingType {
&self.scaling
}
/// Inverse frequencies derived from the scaled base.
#[must_use]
pub fn inv_freq(&self) -> &[f32] {
&self.inv_freq
}
/// Output magnitude correction factor.
#[must_use]
pub fn mscale(&self) -> f32 {
self.mscale
}
/// Nominal context-length extension factor implied by the scaling mode.
#[must_use]
pub fn context_length_multiplier(&self) -> f32 {
match &self.scaling {
RopeScalingType::None => 1.0,
RopeScalingType::Linear { scale } | RopeScalingType::Ntk { scale } => *scale,
RopeScalingType::DynamicNtk {
original_max_len,
target_max_len,
}
| RopeScalingType::Yarn {
original_max_len,
target_max_len,
..
} => (*target_max_len as f32) / (*original_max_len as f32),
}
}
}
/// ALiBi (Attention with Linear Biases): per-head distance-proportional
/// attention score penalties instead of positional embeddings.
#[derive(Debug, Clone)]
pub struct ALiBi {
num_heads: usize,
// One negative-slope magnitude per head (applied as -slope * distance).
slopes: Vec<f32>,
}
impl ALiBi {
/// Creates ALiBi bias slopes for `num_heads` heads.
///
/// # Errors
/// Returns `InvalidShape` when `num_heads` is zero.
pub fn new(num_heads: usize) -> Result<Self> {
if num_heads == 0 {
return Err(RealizarError::InvalidShape {
reason: "num_heads must be > 0".to_string(),
});
}
let slopes = Self::compute_slopes(num_heads);
Ok(Self { num_heads, slopes })
}
/// Geometric slope schedule in the style of the ALiBi paper, with an
/// interleaved extra sequence for non-power-of-two head counts.
///
/// NOTE(review): the paper's slopes are 2^(-8(i+1)/n), i.e. the first
/// slope is 2^(-8/n) < 1. Here the exponent starts at i = 0, so the
/// first head gets slope 1.0 — confirm this offset is intentional.
fn compute_slopes(num_heads: usize) -> Vec<f32> {
// Largest power of two not exceeding num_heads.
let closest_power_of_2 = if num_heads.is_power_of_two() {
num_heads
} else {
num_heads.next_power_of_two() / 2
};
#[allow(clippy::cast_precision_loss)]
let ratio = 8.0 / (closest_power_of_2 as f32);
let mut slopes = Vec::with_capacity(num_heads);
for i in 0..closest_power_of_2.min(num_heads) {
#[allow(clippy::cast_precision_loss)]
let exponent = -(i as f32) * ratio;
slopes.push(2_f32.powf(exponent));
}
if num_heads > closest_power_of_2 {
// Remaining heads use the odd terms of the doubled-count
// schedule (half the exponent step, offset by one).
#[allow(clippy::cast_precision_loss)]
let extra_ratio = 4.0 / (closest_power_of_2 as f32);
for i in 0..(num_heads - closest_power_of_2) {
#[allow(clippy::cast_precision_loss)]
let exponent = -((2 * i + 1) as f32) * extra_ratio;
slopes.push(2_f32.powf(exponent));
}
}
slopes
}
/// Builds the bias tensor of shape `[seq_len, seq_len, num_heads]`,
/// head index varying fastest: bias(i, j, h) = -slope[h] * |i - j|.
///
/// NOTE(review): uses symmetric |i - j| distance; the original ALiBi
/// formulation is causal (-slope * (i - j) for j <= i only) — confirm
/// the symmetric variant is intended.
///
/// # Errors
/// Returns `InvalidShape` when `seq_len` is zero.
pub fn get_bias(&self, seq_len: usize) -> Result<Tensor<f32>> {
if seq_len == 0 {
return Err(RealizarError::InvalidShape {
reason: "seq_len must be > 0".to_string(),
});
}
let total_size = seq_len * seq_len * self.num_heads;
let mut data = Vec::with_capacity(total_size);
for i in 0..seq_len {
for j in 0..seq_len {
for &slope in &self.slopes {
#[allow(clippy::cast_precision_loss)]
let distance = (i as f32 - j as f32).abs();
let bias = -slope * distance;
data.push(bias);
}
}
}
Tensor::from_vec(vec![seq_len, seq_len, self.num_heads], data)
}
/// Number of heads the slope table covers.
#[must_use]
pub fn num_heads(&self) -> usize {
self.num_heads
}
/// Per-head slope magnitudes.
#[must_use]
pub fn slopes(&self) -> &[f32] {
&self.slopes
}
}
/// Append-only per-layer key/value cache for autoregressive decoding.
/// Stores one `head_dim`-sized K and V vector per position per layer in
/// flat preallocated buffers.
#[derive(Debug, Clone)]
pub struct KVCache {
num_layers: usize,
// Capacity in positions per layer.
max_seq_len: usize,
head_dim: usize,
// Next write position; also the number of valid cached positions.
current_pos: usize,
// keys[layer] is a flat [max_seq_len * head_dim] buffer; same for values.
keys: Vec<Vec<f32>>,
values: Vec<Vec<f32>>,
}
impl KVCache {
    /// Allocates a zeroed cache of `num_layers` layers, each holding up to
    /// `max_seq_len` positions of `head_dim`-sized K/V vectors.
    ///
    /// # Errors
    /// Returns `InvalidShape` when any argument is zero.
    pub fn new(num_layers: usize, max_seq_len: usize, head_dim: usize) -> Result<Self> {
        for (value, name) in [
            (num_layers, "num_layers"),
            (max_seq_len, "max_seq_len"),
            (head_dim, "head_dim"),
        ] {
            if value == 0 {
                return Err(RealizarError::InvalidShape {
                    reason: format!("{name} must be > 0"),
                });
            }
        }
        let per_layer = max_seq_len * head_dim;
        Ok(Self {
            num_layers,
            max_seq_len,
            head_dim,
            current_pos: 0,
            keys: vec![vec![0.0; per_layer]; num_layers],
            values: vec![vec![0.0; per_layer]; num_layers],
        })
    }

    /// Validates that `layer` indexes an existing layer.
    fn check_layer(&self, layer: usize) -> Result<()> {
        if layer >= self.num_layers {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Layer {} out of bounds (max {})",
                    layer,
                    self.num_layers - 1
                ),
            });
        }
        Ok(())
    }

    /// Writes `key`/`value` for the current position into `layer`. Does
    /// NOT advance the position — call `advance` after updating all layers.
    ///
    /// # Errors
    /// Returns `InvalidShape` when `layer` is out of range, the cache is
    /// full, or either tensor's length differs from `head_dim`.
    pub fn update(&mut self, layer: usize, key: &Tensor<f32>, value: &Tensor<f32>) -> Result<()> {
        self.check_layer(layer)?;
        if self.current_pos >= self.max_seq_len {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Cache full at position {} (max {})",
                    self.current_pos, self.max_seq_len
                ),
            });
        }
        let k_data = key.data();
        if k_data.len() != self.head_dim {
            return Err(RealizarError::InvalidShape {
                reason: format!("Key size {} != head_dim {}", k_data.len(), self.head_dim),
            });
        }
        let v_data = value.data();
        if v_data.len() != self.head_dim {
            return Err(RealizarError::InvalidShape {
                reason: format!("Value size {} != head_dim {}", v_data.len(), self.head_dim),
            });
        }
        let start = self.current_pos * self.head_dim;
        let end = start + self.head_dim;
        self.keys[layer][start..end].copy_from_slice(k_data);
        self.values[layer][start..end].copy_from_slice(v_data);
        Ok(())
    }

    /// Advances the write position by one, saturating at capacity.
    pub fn advance(&mut self) {
        if self.current_pos < self.max_seq_len {
            self.current_pos += 1;
        }
    }

    /// Returns the populated prefix of `cache` as a tensor. An empty cache
    /// yields a single zero row `[1, head_dim]` as a placeholder.
    fn slice_cache(&self, cache: &[f32]) -> Result<Tensor<f32>> {
        if self.current_pos == 0 {
            return Tensor::from_vec(vec![1, self.head_dim], vec![0.0; self.head_dim]);
        }
        let len = self.current_pos * self.head_dim;
        Tensor::from_vec(vec![self.current_pos, self.head_dim], cache[..len].to_vec())
    }

    /// Cached keys for `layer`, shape `[current_pos, head_dim]` (or a zero
    /// row when empty).
    ///
    /// # Errors
    /// Returns `InvalidShape` when `layer` is out of range.
    pub fn get_key(&self, layer: usize) -> Result<Tensor<f32>> {
        self.check_layer(layer)?;
        self.slice_cache(&self.keys[layer])
    }

    /// Cached values for `layer`, shape `[current_pos, head_dim]` (or a
    /// zero row when empty).
    ///
    /// # Errors
    /// Returns `InvalidShape` when `layer` is out of range.
    pub fn get_value(&self, layer: usize) -> Result<Tensor<f32>> {
        self.check_layer(layer)?;
        self.slice_cache(&self.values[layer])
    }

    /// Resets the cache: position back to zero and all buffers zeroed.
    pub fn clear(&mut self) {
        self.current_pos = 0;
        for (k, v) in self.keys.iter_mut().zip(self.values.iter_mut()) {
            k.fill(0.0);
            v.fill(0.0);
        }
    }

    /// Current write position (= number of valid cached positions).
    #[must_use]
    pub fn current_pos(&self) -> usize {
        self.current_pos
    }

    /// Number of layers.
    #[must_use]
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }

    /// Capacity in positions per layer.
    #[must_use]
    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    /// Size of each cached K/V vector.
    #[must_use]
    pub fn head_dim(&self) -> usize {
        self.head_dim
    }

    /// True when no further positions can be written.
    #[must_use]
    pub fn is_full(&self) -> bool {
        self.current_pos >= self.max_seq_len
    }
}
/// Pre-norm transformer block: LayerNorm -> attention -> residual add,
/// then LayerNorm -> feed-forward -> residual add.
#[derive(Debug, Clone)]
pub struct TransformerBlock {
attn_norm: LayerNorm,
attention: MultiHeadAttention,
ffn_norm: LayerNorm,
ffn: FeedForward,
hidden_dim: usize,
num_heads: usize,
}
impl TransformerBlock {
    /// Builds a pre-norm transformer block with standard MHA.
    ///
    /// # Errors
    /// Returns `InvalidShape` when `hidden_dim` or `num_heads` is zero or
    /// `hidden_dim` is not divisible by `num_heads`; propagates failures
    /// from the sub-layer constructors.
    pub fn new(
        hidden_dim: usize,
        num_heads: usize,
        intermediate_dim: usize,
        eps: f32,
    ) -> Result<Self> {
        if hidden_dim == 0 {
            return Err(RealizarError::InvalidShape {
                reason: "hidden_dim must be > 0".to_string(),
            });
        }
        if num_heads == 0 {
            return Err(RealizarError::InvalidShape {
                reason: "num_heads must be > 0".to_string(),
            });
        }
        if hidden_dim % num_heads != 0 {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "hidden_dim {hidden_dim} must be divisible by num_heads {num_heads}"
                ),
            });
        }
        Ok(Self {
            attn_norm: LayerNorm::new(hidden_dim, eps)?,
            attention: MultiHeadAttention::mha(hidden_dim, num_heads)?,
            ffn_norm: LayerNorm::new(hidden_dim, eps)?,
            ffn: FeedForward::new(hidden_dim, intermediate_dim)?,
            hidden_dim,
            num_heads,
        })
    }

    /// Applies the block: `x + Attn(LN(x))`, then `y + FFN(LN(y))`.
    ///
    /// # Errors
    /// Returns `InvalidShape` when the input has no dimensions or its last
    /// dimension differs from `hidden_dim`; propagates sub-layer errors.
    pub fn forward(&self, input: &Tensor<f32>) -> Result<Tensor<f32>> {
        let shape = input.shape();
        if shape.is_empty() {
            return Err(RealizarError::InvalidShape {
                reason: "Input tensor must have at least 1 dimension".to_string(),
            });
        }
        let last_dim = shape[shape.len() - 1];
        if last_dim != self.hidden_dim {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Expected last dimension {}, got {}",
                    self.hidden_dim, last_dim
                ),
            });
        }
        // Attention sub-layer with residual connection.
        let attn_out = self.attention.forward(&self.attn_norm.forward(input)?)?;
        let residual1: Vec<f32> = input
            .data()
            .iter()
            .zip(attn_out.data())
            .map(|(&x, &a)| x + a)
            .collect();
        let after_attn = Tensor::from_vec(shape.to_vec(), residual1)?;
        // Feed-forward sub-layer with residual connection.
        let ffn_out = self.ffn.forward(&self.ffn_norm.forward(&after_attn)?)?;
        let residual2: Vec<f32> = after_attn
            .data()
            .iter()
            .zip(ffn_out.data())
            .map(|(&x, &f)| x + f)
            .collect();
        Tensor::from_vec(shape.to_vec(), residual2)
    }

    /// Model width.
    #[must_use]
    pub fn hidden_dim(&self) -> usize {
        self.hidden_dim
    }

    /// Mutable access to the pre-attention LayerNorm.
    pub fn attn_norm_mut(&mut self) -> &mut LayerNorm {
        &mut self.attn_norm
    }

    /// Mutable access to the attention sub-layer.
    pub fn attention_mut(&mut self) -> &mut MultiHeadAttention {
        &mut self.attention
    }

    /// Number of attention heads.
    #[must_use]
    pub fn num_heads(&self) -> usize {
        self.num_heads
    }

    /// Mutable access to the pre-FFN LayerNorm.
    pub fn ffn_norm_mut(&mut self) -> &mut LayerNorm {
        &mut self.ffn_norm
    }

    /// Mutable access to the feed-forward sub-layer.
    pub fn ffn_mut(&mut self) -> &mut FeedForward {
        &mut self.ffn
    }
}
/// Token-embedding lookup table (zero-initialized; weights are loaded or
/// set via `weights_mut`).
#[derive(Debug, Clone)]
pub struct Embedding {
vocab_size: usize,
embed_dim: usize,
// Flat [vocab_size * embed_dim] row-major table; row = token id.
weights: Vec<f32>,
}
impl Embedding {
    /// Allocates a zero-initialized embedding table.
    ///
    /// # Errors
    /// Returns `InvalidShape` when `vocab_size` or `embed_dim` is zero.
    pub fn new(vocab_size: usize, embed_dim: usize) -> Result<Self> {
        for (value, name) in [(vocab_size, "vocab_size"), (embed_dim, "embed_dim")] {
            if value == 0 {
                return Err(RealizarError::InvalidShape {
                    reason: format!("{name} must be > 0"),
                });
            }
        }
        Ok(Self {
            vocab_size,
            embed_dim,
            weights: vec![0.0; vocab_size * embed_dim],
        })
    }

    /// Looks up one embedding row per token id, producing a tensor of
    /// shape `[token_ids.len(), embed_dim]`.
    ///
    /// # Errors
    /// Returns `InvalidShape` when `token_ids` is empty or any id is
    /// outside the vocabulary (the first offending id is reported).
    pub fn forward(&self, token_ids: &[usize]) -> Result<Tensor<f32>> {
        if token_ids.is_empty() {
            return Err(RealizarError::InvalidShape {
                reason: "Token IDs cannot be empty".to_string(),
            });
        }
        // Validate everything up front, then gather without re-checking.
        if let Some(&token_id) = token_ids.iter().find(|&&id| id >= self.vocab_size) {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Token ID {token_id} out of bounds (vocab_size={})",
                    self.vocab_size
                ),
            });
        }
        let mut output = Vec::with_capacity(token_ids.len() * self.embed_dim);
        for &id in token_ids {
            let row = id * self.embed_dim;
            output.extend_from_slice(&self.weights[row..row + self.embed_dim]);
        }
        Tensor::from_vec(vec![token_ids.len(), self.embed_dim], output)
    }

    /// Vocabulary size (row count).
    #[must_use]
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// Embedding dimensionality (row width).
    #[must_use]
    pub fn embed_dim(&self) -> usize {
        self.embed_dim
    }

    /// Mutable access to the flat weight table for loading parameters.
    pub fn weights_mut(&mut self) -> &mut [f32] {
        &mut self.weights
    }
}
/// Minimal decoder-only language model: embedding -> N transformer
/// blocks -> final LayerNorm -> LM head.
#[derive(Debug, Clone)]
pub struct Model {
embedding: Embedding,
blocks: Vec<TransformerBlock>,
final_norm: LayerNorm,
// Projects hidden states to vocabulary logits.
lm_head: Linear,
config: ModelConfig,
}
/// Hyperparameters describing a [`Model`].
#[derive(Debug, Clone)]
pub struct ModelConfig {
pub vocab_size: usize,
pub hidden_dim: usize,
pub num_heads: usize,
pub num_layers: usize,
// FFN inner dimension.
pub intermediate_dim: usize,
// LayerNorm epsilon.
pub eps: f32,
}
impl Model {
/// Builds all model components from `config`.
///
/// # Errors
/// Propagates `InvalidShape` from any component constructor (zero or
/// incompatible dimensions).
pub fn new(config: ModelConfig) -> Result<Self> {
let embedding = Embedding::new(config.vocab_size, config.hidden_dim)?;
let mut blocks = Vec::with_capacity(config.num_layers);
for _ in 0..config.num_layers {
blocks.push(TransformerBlock::new(
config.hidden_dim,
config.num_heads,
config.intermediate_dim,
config.eps,
)?);
}
let final_norm = LayerNorm::new(config.hidden_dim, config.eps)?;
let lm_head = Linear::new(config.hidden_dim, config.vocab_size)?;
Ok(Self {
embedding,
blocks,
final_norm,
lm_head,
config,
})
}
/// Full forward pass: token ids -> logits of shape
/// `[token_ids.len(), vocab_size]`.
///
/// # Errors
/// Propagates errors from embedding lookup and every layer.
pub fn forward(&self, token_ids: &[usize]) -> Result<Tensor<f32>> {
let mut hidden = self.embedding.forward(token_ids)?;
for block in &self.blocks {
hidden = block.forward(&hidden)?;
}
hidden = self.final_norm.forward(&hidden)?;
self.lm_head.forward(&hidden)
}
/// Model hyperparameters.
#[must_use]
pub fn config(&self) -> &ModelConfig {
&self.config
}
/// Mutable access to the embedding table.
pub fn embedding_mut(&mut self) -> &mut Embedding {
&mut self.embedding
}
/// Mutable access to the transformer blocks.
pub fn blocks_mut(&mut self) -> &mut [TransformerBlock] {
&mut self.blocks
}
/// Mutable access to the final LayerNorm.
pub fn final_norm_mut(&mut self) -> &mut LayerNorm {
&mut self.final_norm
}
/// Mutable access to the LM head projection.
pub fn lm_head_mut(&mut self) -> &mut Linear {
&mut self.lm_head
}
/// Rough parameter count.
///
/// NOTE(review): this appears to undercount — per block it tallies only
/// `2 * hidden_dim` (one LayerNorm's weight+bias?) plus the two FFN
/// matrices, omitting the attention Q/K/V/O projections, the second
/// LayerNorm, and all Linear biases. Confirm whether an approximation
/// was intended or the formula needs completing.
#[must_use]
pub fn num_parameters(&self) -> usize {
let embed_params = self.config.vocab_size * self.config.hidden_dim;
let block_params = self.config.num_layers
* (
2 * self.config.hidden_dim + self.config.hidden_dim * self.config.intermediate_dim + self.config.intermediate_dim * self.config.hidden_dim
);
let head_params = self.config.hidden_dim * self.config.vocab_size;
embed_params + block_params + head_params
}
/// Greedy/sampled autoregressive generation from `prompt`, up to
/// `config.max_tokens` new tokens, stopping early at `eos_token_id`.
/// Deterministic for a fixed seed (LCG-based RNG).
///
/// Note: each step re-runs the forward pass over the whole sequence
/// (no KV cache), so generation is O(steps * seq_len) forward work.
///
/// # Errors
/// Returns `InvalidShape` for an empty prompt; propagates forward-pass
/// and sampling errors.
pub fn generate(&self, prompt: &[usize], config: &GenerationConfig) -> Result<Vec<usize>> {
if prompt.is_empty() {
return Err(RealizarError::InvalidShape {
reason: "Prompt cannot be empty".to_string(),
});
}
let mut tokens = prompt.to_vec();
let mut rng_state = config.seed.unwrap_or(42);
for _ in 0..config.max_tokens {
let logits = self.forward(&tokens)?;
let seq_len = tokens.len();
let vocab_size = self.config.vocab_size;
// Only the last position's logits matter for the next token.
let last_logits_start = (seq_len - 1) * vocab_size;
let last_logits = &logits.data()[last_logits_start..last_logits_start + vocab_size];
let last_logits_tensor = Tensor::from_vec(vec![vocab_size], last_logits.to_vec())?;
// 64-bit LCG step; top bits become a uniform value in [0, 1).
rng_state = rng_state
.wrapping_mul(6_364_136_223_846_793_005)
.wrapping_add(1);
#[allow(clippy::cast_precision_loss)]
let rng_value = (rng_state >> 33) as f32 / (1u64 << 31) as f32;
let next_token = sample_token(&last_logits_tensor, config, rng_value)?;
if let Some(eos_id) = config.eos_token_id {
if next_token == eos_id {
break;
}
}
tokens.push(next_token);
}
Ok(tokens)
}
}
#[cfg(all(test, feature = "heavy-tests"))]
mod tests {
use super::*;
#[test]
fn test_layer_norm_creation() {
let layer_norm = LayerNorm::new(512, 1e-5).unwrap();
assert_eq!(layer_norm.normalized_shape(), 512);
assert!((layer_norm.eps() - 1e-5).abs() < 1e-10);
}
#[test]
fn test_layer_norm_zero_shape_error() {
let result = LayerNorm::new(0, 1e-5);
assert!(result.is_err());
}
#[test]
fn test_layer_norm_forward_simple() {
let layer_norm = LayerNorm::new(3, 1e-5).unwrap();
let input = Tensor::from_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap();
let output = layer_norm.forward(&input).unwrap();
let output_data = output.data();
assert_eq!(output_data.len(), 3);
let mean: f32 = output_data.iter().sum::<f32>() / 3.0;
assert!((mean - 0.0).abs() < 1e-5);
let variance: f32 = output_data
.iter()
.map(|&x| {
let diff = x - mean;
diff * diff
})
.sum::<f32>()
/ 3.0;
assert!((variance - 1.0).abs() < 1e-4);
}
#[test]
fn test_layer_norm_forward_batched() {
let layer_norm = LayerNorm::new(2, 1e-5).unwrap();
let input = Tensor::from_vec(vec![2, 2], vec![1.0, 3.0, 2.0, 4.0]).unwrap();
let output = layer_norm.forward(&input).unwrap();
assert_eq!(output.shape(), &[2, 2]);
let output_data = output.data();
assert_eq!(output_data.len(), 4);
let group1_mean = (output_data[0] + output_data[1]) / 2.0;
assert!((group1_mean - 0.0).abs() < 1e-5);
let group2_mean = (output_data[2] + output_data[3]) / 2.0;
assert!((group2_mean - 0.0).abs() < 1e-5);
}
#[test]
fn test_layer_norm_empty_shape_handling() {
let result = LayerNorm::new(0, 1e-5);
assert!(result.is_err());
}
#[test]
fn test_layer_norm_shape_mismatch_error() {
let layer_norm = LayerNorm::new(3, 1e-5).unwrap();
let input = Tensor::from_vec(vec![2], vec![1.0, 2.0]).unwrap(); let result = layer_norm.forward(&input);
assert!(result.is_err());
}
#[test]
fn test_layer_norm_zero_variance() {
let layer_norm = LayerNorm::new(3, 1e-5).unwrap();
let input = Tensor::from_vec(vec![3], vec![2.0, 2.0, 2.0]).unwrap();
let output = layer_norm.forward(&input).unwrap();
let output_data = output.data();
for &val in output_data {
assert!(val.abs() < 1e-2); }
}
#[test]
fn test_linear_creation() {
    // The constructor records the requested dimensions verbatim.
    let linear = Linear::new(128, 256).unwrap();
    assert_eq!(linear.in_features(), 128);
    assert_eq!(linear.out_features(), 256);
}
#[test]
fn test_linear_zero_dimensions_error() {
    // Either dimension being zero is invalid.
    assert!(Linear::new(0, 256).is_err());
    assert!(Linear::new(128, 0).is_err());
}
#[test]
fn test_linear_forward_simple() {
    // The first six weights pass x0 to output 0 and x1 to output 1, leaving
    // output 2 driven by bias alone; all biases are 0.5.
    let mut linear = Linear::new(2, 3).unwrap();
    let weights = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
    for (dst, src) in linear.weight_mut().iter_mut().zip(weights) {
        *dst = src;
    }
    for b in linear.bias_mut().iter_mut().take(3) {
        *b = 0.5;
    }
    let input = Tensor::from_vec(vec![2], vec![2.0, 3.0]).unwrap();
    let output = linear.forward(&input).unwrap();
    assert_eq!(output.shape(), &[3]);
    // Expected: [2.0 + 0.5, 3.0 + 0.5, 0.0 + 0.5].
    let expected = [2.5, 3.5, 0.5];
    for (got, want) in output.data().iter().zip(expected) {
        assert!((got - want).abs() < 1e-5);
    }
}
#[test]
fn test_linear_forward_batched() {
    // All-ones weights with zero bias: each output is the sum of its input row.
    let mut linear = Linear::new(2, 3).unwrap();
    linear.weight_mut().iter_mut().for_each(|w| *w = 1.0);
    linear.bias_mut().iter_mut().for_each(|b| *b = 0.0);
    let input = Tensor::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let output = linear.forward(&input).unwrap();
    assert_eq!(output.shape(), &[2, 3]);
    // Row sums: 1 + 2 = 3 and 3 + 4 = 7, broadcast across all three outputs.
    let expected = [3.0, 3.0, 3.0, 7.0, 7.0, 7.0];
    for (got, want) in output.data().iter().zip(expected) {
        assert!((got - want).abs() < 1e-5);
    }
}
#[test]
fn test_linear_shape_mismatch_error() {
    // Input feature count (2) differs from in_features (3).
    let linear = Linear::new(3, 2).unwrap();
    let input = Tensor::from_vec(vec![2], vec![1.0, 2.0]).unwrap();
    assert!(linear.forward(&input).is_err());
}
#[test]
fn test_linear_weight_bias_mut() {
    // Mutable accessors expose live storage: writes are observable on re-read.
    let mut linear = Linear::new(2, 3).unwrap();
    linear.weight_mut()[0] = 42.0;
    assert!((linear.weight_mut()[0] - 42.0).abs() < 1e-6);
    linear.bias_mut()[0] = 7.0;
    assert!((linear.bias_mut()[0] - 7.0).abs() < 1e-6);
}
#[test]
fn test_fused_layer_norm_linear_creation() {
    let fused = FusedLayerNormLinear::new(4, 8, 1e-5).unwrap();
    assert_eq!(fused.feature_dim(), 4);
    assert_eq!(fused.out_features(), 8);
}
#[test]
fn test_fused_layer_norm_linear_zero_dims_error() {
    // Both the feature and output dimensions must be non-zero.
    assert!(FusedLayerNormLinear::new(0, 8, 1e-5).is_err());
    assert!(FusedLayerNormLinear::new(4, 0, 1e-5).is_err());
}
#[test]
#[allow(clippy::cast_precision_loss)]
fn test_fused_layer_norm_linear_matches_separate() {
    // The fused kernel must reproduce LayerNorm followed by Linear when both
    // paths are seeded with identical deterministic weights and biases.
    let feature_dim = 4;
    let out_features = 3;
    let mut fused = FusedLayerNormLinear::new(feature_dim, out_features, 1e-5).unwrap();
    for (i, weight) in fused.linear_weight_mut().iter_mut().enumerate() {
        *weight = (i as f32) * 0.1;
    }
    for (i, bias) in fused.linear_bias_mut().iter_mut().enumerate() {
        *bias = (i as f32) * 0.05;
    }
    let layer_norm = LayerNorm::new(feature_dim, 1e-5).unwrap();
    let mut linear = Linear::new(feature_dim, out_features).unwrap();
    for (i, weight) in linear.weight_mut().iter_mut().enumerate() {
        *weight = (i as f32) * 0.1;
    }
    for (i, bias) in linear.bias_mut().iter_mut().enumerate() {
        *bias = (i as f32) * 0.05;
    }
    let input = Tensor::from_vec(vec![feature_dim], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let fused_output = fused.forward(&input).unwrap();
    let norm_output = layer_norm.forward(&input).unwrap();
    let separate_output = linear.forward(&norm_output).unwrap();
    assert_eq!(fused_output.shape(), separate_output.shape());
    for (i, (f, s)) in fused_output
        .data()
        .iter()
        .zip(separate_output.data())
        .enumerate()
    {
        assert!(
            (f - s).abs() < 1e-4,
            "Mismatch at {}: fused={} vs separate={}",
            i,
            f,
            s
        );
    }
}
#[test]
fn test_fused_layer_norm_linear_batched() {
    // A [2, feature_dim] batch maps to [2, out_features].
    let feature_dim = 4;
    let out_features = 2;
    let mut fused = FusedLayerNormLinear::new(feature_dim, out_features, 1e-5).unwrap();
    fused.linear_weight_mut().iter_mut().for_each(|w| *w = 1.0);
    let input = Tensor::from_vec(
        vec![2, feature_dim],
        vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
    )
    .unwrap();
    let output = fused.forward(&input).unwrap();
    assert_eq!(output.shape(), &[2, out_features]);
}
#[test]
#[allow(clippy::cast_precision_loss)]
fn test_fused_layer_norm_linear_parallel_matches_serial() {
    // forward_parallel must agree with the serial forward on a 32-row batch
    // seeded with pseudo-random (modular) weights and inputs.
    let feature_dim = 8;
    let out_features = 4;
    let mut fused = FusedLayerNormLinear::new(feature_dim, out_features, 1e-5).unwrap();
    for (i, weight) in fused.linear_weight_mut().iter_mut().enumerate() {
        *weight = ((i * 7 + 3) % 11) as f32 * 0.1;
    }
    for (i, bias) in fused.linear_bias_mut().iter_mut().enumerate() {
        *bias = ((i * 5 + 2) % 7) as f32 * 0.1;
    }
    // Flat index idx == i * feature_dim + j from the original nested loops.
    let input_data: Vec<f32> = (0..32 * feature_dim)
        .map(|idx| (idx % 17) as f32 * 0.2)
        .collect();
    let input = Tensor::from_vec(vec![32, feature_dim], input_data).unwrap();
    let serial_output = fused.forward(&input).unwrap();
    let parallel_output = fused.forward_parallel(&input).unwrap();
    assert_eq!(serial_output.shape(), parallel_output.shape());
    for (i, (s, p)) in serial_output
        .data()
        .iter()
        .zip(parallel_output.data())
        .enumerate()
    {
        assert!(
            (s - p).abs() < 1e-5,
            "Mismatch at {}: serial={} vs parallel={}",
            i,
            s,
            p
        );
    }
}
#[test]
fn test_fused_layer_norm_linear_dimension_mismatch_error() {
    // Input feature count (3) disagrees with feature_dim (4).
    let fused = FusedLayerNormLinear::new(4, 8, 1e-5).unwrap();
    let input = Tensor::from_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap();
    assert!(fused.forward(&input).is_err());
}
#[test]
fn test_softmax_simple() {
    // Uniform logits map to a uniform distribution (1/3 each).
    let input = Tensor::from_vec(vec![3], vec![0.0, 0.0, 0.0]).unwrap();
    let output = softmax(&input).unwrap();
    assert_eq!(output.shape(), &[3]);
    for &p in output.data() {
        assert!((p - 0.333_333).abs() < 1e-5);
    }
    let total: f32 = output.data().iter().sum();
    assert!((total - 1.0).abs() < 1e-6);
}
#[test]
fn test_softmax_probabilities_sum_to_one() {
    let input = Tensor::from_vec(vec![4], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let output = softmax(&input).unwrap();
    let total: f32 = output.data().iter().sum();
    assert!((total - 1.0).abs() < 1e-6);
    // Distinct finite logits put every probability strictly inside (0, 1).
    assert!(output.data().iter().all(|&p| p > 0.0 && p < 1.0));
}
#[test]
fn test_softmax_max_dominates() {
    // A logit 10 units above the rest absorbs nearly all probability mass.
    let input = Tensor::from_vec(vec![3], vec![0.0, 0.0, 10.0]).unwrap();
    let output = softmax(&input).unwrap();
    let probs = output.data();
    assert!(probs[2] > 0.999);
    assert!(probs[0] < 0.001);
    assert!(probs[1] < 0.001);
}
#[test]
fn test_softmax_batched() {
    // Softmax is taken over the last dimension: each row sums to one.
    let input = Tensor::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let output = softmax(&input).unwrap();
    assert_eq!(output.shape(), &[2, 2]);
    let data = output.data();
    assert!((data[0] + data[1] - 1.0).abs() < 1e-6);
    assert!((data[2] + data[3] - 1.0).abs() < 1e-6);
}
#[test]
fn test_softmax_numerical_stability() {
    // Large logits must not overflow: the implementation subtracts the max.
    let input = Tensor::from_vec(vec![3], vec![1000.0, 1001.0, 1002.0]).unwrap();
    let output = softmax(&input).unwrap();
    assert!(output.data().iter().all(|p| p.is_finite()));
    let total: f32 = output.data().iter().sum();
    assert!((total - 1.0).abs() < 1e-5);
}
#[test]
fn test_softmax_preserves_shape() {
    // Arbitrary leading dimensions pass through unchanged.
    let input = Tensor::from_vec(vec![2, 3, 4], vec![1.0; 24]).unwrap();
    let output = softmax(&input).unwrap();
    assert_eq!(output.shape(), &[2, 3, 4]);
}
#[test]
fn test_gelu_zero() {
    // GELU(0) = 0 exactly.
    let input = Tensor::from_vec(vec![1], vec![0.0]).unwrap();
    let output = gelu(&input).unwrap();
    assert!(output.data()[0].abs() < 1e-6);
}
#[test]
fn test_gelu_positive() {
    // GELU(1) ~ 0.841 under the tanh approximation.
    let input = Tensor::from_vec(vec![1], vec![1.0]).unwrap();
    let output = gelu(&input).unwrap();
    let y = output.data()[0];
    assert!(y > 0.8 && y < 0.9);
}
#[test]
fn test_gelu_negative() {
    // GELU dips slightly below zero for small negative inputs.
    let input = Tensor::from_vec(vec![1], vec![-1.0]).unwrap();
    let output = gelu(&input).unwrap();
    let y = output.data()[0];
    assert!(y < 0.0 && y > -0.2);
}
#[test]
fn test_gelu_batched() {
    let input = Tensor::from_vec(vec![2, 3], vec![-2.0, -1.0, 0.0, 1.0, 2.0, 3.0]).unwrap();
    let output = gelu(&input).unwrap();
    assert_eq!(output.shape(), &[2, 3]);
    assert_eq!(output.data().len(), 6);
    let data = output.data();
    // GELU(0) = 0; positive inputs stay positive.
    assert!(data[2].abs() < 1e-6);
    assert!(data[3] > 0.0);
    assert!(data[4] > 0.0);
    assert!(data[5] > 0.0);
}
#[test]
fn test_gelu_preserves_shape() {
    let input = Tensor::from_vec(vec![2, 3, 4], vec![0.5; 24]).unwrap();
    let output = gelu(&input).unwrap();
    assert_eq!(output.shape(), &[2, 3, 4]);
    assert_eq!(output.data().len(), 24);
}
#[test]
fn test_ffn_creation() {
    let ffn = FeedForward::new(512, 2048).unwrap();
    assert_eq!(ffn.hidden_dim(), 512);
    assert_eq!(ffn.intermediate_dim(), 2048);
}
#[test]
fn test_ffn_zero_dimensions_error() {
    // Both hidden and intermediate dimensions must be non-zero.
    assert!(FeedForward::new(0, 2048).is_err());
    assert!(FeedForward::new(512, 0).is_err());
}
#[test]
fn test_ffn_forward_shape() {
    // The FFN projects back to the hidden dimension, preserving batch shape.
    let ffn = FeedForward::new(4, 16).unwrap();
    let input = Tensor::from_vec(vec![2, 4], vec![1.0; 8]).unwrap();
    let output = ffn.forward(&input).unwrap();
    assert_eq!(output.shape(), &[2, 4]);
}
#[test]
fn test_ffn_forward_computation() {
    // Small uniform weights and zero biases: output must be finite and of
    // the expected shape (exact values depend on the activation).
    let mut ffn = FeedForward::new(2, 4).unwrap();
    ffn.fc1_mut().weight_mut().iter_mut().for_each(|w| *w = 0.1);
    ffn.fc1_mut().bias_mut().iter_mut().for_each(|b| *b = 0.0);
    ffn.fc2_mut().weight_mut().iter_mut().for_each(|w| *w = 0.1);
    ffn.fc2_mut().bias_mut().iter_mut().for_each(|b| *b = 0.0);
    let input = Tensor::from_vec(vec![2], vec![1.0, 2.0]).unwrap();
    let output = ffn.forward(&input).unwrap();
    assert_eq!(output.shape(), &[2]);
    assert!(output.data().iter().all(|v| v.is_finite()));
}
#[test]
fn test_ffn_batched() {
    let ffn = FeedForward::new(3, 12).unwrap();
    let input = Tensor::from_vec(vec![2, 3], vec![0.5; 6]).unwrap();
    let output = ffn.forward(&input).unwrap();
    assert_eq!(output.shape(), &[2, 3]);
    assert_eq!(output.data().len(), 6);
}
#[test]
fn test_ffn_weight_access() {
    // Mutable accessors expose live storage for both projection layers.
    let mut ffn = FeedForward::new(2, 4).unwrap();
    ffn.fc1_mut().weight_mut()[0] = 42.0;
    assert!((ffn.fc1_mut().weight_mut()[0] - 42.0).abs() < 1e-6);
    ffn.fc2_mut().bias_mut()[0] = 7.0;
    assert!((ffn.fc2_mut().bias_mut()[0] - 7.0).abs() < 1e-6);
}
#[test]
fn test_attention_creation() {
    // scale = 1 / sqrt(head_dim) = 1/8 for head_dim 64.
    let attn = Attention::new(64).unwrap();
    assert_eq!(attn.head_dim(), 64);
    assert!((attn.scale() - 0.125).abs() < 1e-6);
}
#[test]
fn test_attention_zero_head_dim_error() {
    assert!(Attention::new(0).is_err());
}
#[test]
fn test_attention_forward_shape() {
    let attn = Attention::new(4).unwrap();
    let q = Tensor::from_vec(vec![3, 4], vec![0.1; 12]).unwrap();
    let k = Tensor::from_vec(vec![3, 4], vec![0.2; 12]).unwrap();
    let v = Tensor::from_vec(vec![3, 4], vec![0.3; 12]).unwrap();
    let output = attn.forward(&q, &k, &v).unwrap();
    assert_eq!(output.shape(), &[3, 4]);
    assert_eq!(output.data().len(), 12);
}
#[test]
fn test_attention_forward_computation() {
    // Identity-like Q/K only mildly prefer the matching position, so each
    // output is a softmax blend strictly below the larger V row's values.
    let attn = Attention::new(2).unwrap();
    let q = Tensor::from_vec(vec![2, 2], vec![1.0, 0.0, 0.0, 1.0]).unwrap();
    let k = Tensor::from_vec(vec![2, 2], vec![1.0, 0.0, 0.0, 1.0]).unwrap();
    let v = Tensor::from_vec(vec![2, 2], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let output = attn.forward(&q, &k, &v).unwrap();
    assert_eq!(output.shape(), &[2, 2]);
    assert!(output.data().iter().all(|x| x.is_finite()));
    assert!(output.data()[0] < 2.0);
    assert!(output.data()[1] < 3.0);
}
#[test]
fn test_attention_shape_mismatch_error() {
    // K's feature dimension (3) disagrees with head_dim (4).
    let attn = Attention::new(4).unwrap();
    let q = Tensor::from_vec(vec![2, 4], vec![0.1; 8]).unwrap();
    let k = Tensor::from_vec(vec![2, 3], vec![0.2; 6]).unwrap();
    let v = Tensor::from_vec(vec![2, 4], vec![0.3; 8]).unwrap();
    assert!(attn.forward(&q, &k, &v).is_err());
}
#[test]
fn test_attention_kv_seq_len_mismatch_error() {
    // K carries 3 positions while V carries 2; K/V lengths must match.
    let attn = Attention::new(4).unwrap();
    let q = Tensor::from_vec(vec![2, 4], vec![0.1; 8]).unwrap();
    let k = Tensor::from_vec(vec![3, 4], vec![0.2; 12]).unwrap();
    let v = Tensor::from_vec(vec![2, 4], vec![0.3; 8]).unwrap();
    assert!(attn.forward(&q, &k, &v).is_err());
}
#[test]
fn test_attention_softmax_weights_sum() {
    // Identical Q/K rows give uniform attention weights, so every output row
    // is the mean of the V rows: ((1,2,3) + (4,5,6)) / 2 = (2.5, 3.5, 4.5).
    let attn = Attention::new(3).unwrap();
    let q = Tensor::from_vec(vec![2, 3], vec![1.0; 6]).unwrap();
    let k = Tensor::from_vec(vec![2, 3], vec![1.0; 6]).unwrap();
    let v = Tensor::from_vec(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let output = attn.forward(&q, &k, &v).unwrap();
    let expected = [2.5, 3.5, 4.5];
    for row in 0..2 {
        for (col, &exp) in expected.iter().enumerate() {
            let actual = output.data()[row * 3 + col];
            assert!(
                (actual - exp).abs() < 0.01,
                "row={row}, col={col}: expected {exp}, got {actual}",
            );
        }
    }
}
#[test]
fn test_attention_single_position() {
    // With one key/value position the softmax weight is exactly 1, so the
    // output equals V (reshaped to [1, head_dim]).
    let attn = Attention::new(4).unwrap();
    let q = Tensor::from_vec(vec![4], vec![1.0, 0.0, 0.0, 0.0]).unwrap();
    let k = Tensor::from_vec(vec![4], vec![1.0, 0.0, 0.0, 0.0]).unwrap();
    let v = Tensor::from_vec(vec![4], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let output = attn.forward(&q, &k, &v).unwrap();
    assert_eq!(output.shape(), &[1, 4]);
    for (got, want) in output.data().iter().zip(v.data()) {
        assert!((got - want).abs() < 1e-6);
    }
}
// flash_forward tests: the block-tiled variant (presumably a
// FlashAttention-style streaming softmax — implementation not visible here)
// must be numerically equivalent to the standard `forward`; the block size
// only changes the tiling, never the result.
#[test]
fn test_flash_attention_matches_standard() {
    // Single-position sanity check with block size 1.
    let attn = Attention::new(8).unwrap();
    let q_data = vec![1.0, 2.0, 0.5, 1.5, 2.5, 0.3, 1.2, 0.8];
    let k_data = vec![0.5, 1.0, 1.5, 0.8, 0.3, 2.0, 0.9, 1.1];
    let v_data = vec![2.0, 1.0, 0.5, 3.0, 1.5, 0.7, 2.5, 0.9];
    let q = Tensor::from_vec(vec![1, 8], q_data.clone()).unwrap();
    let k = Tensor::from_vec(vec![1, 8], k_data.clone()).unwrap();
    let v = Tensor::from_vec(vec![1, 8], v_data.clone()).unwrap();
    let standard_output = attn.forward(&q, &k, &v).unwrap();
    let flash_output = attn.flash_forward(&q, &k, &v, 1).unwrap();
    assert_eq!(standard_output.shape(), flash_output.shape());
    for i in 0..standard_output.data().len() {
        assert!(
            (standard_output.data()[i] - flash_output.data()[i]).abs() < 1e-5,
            "Mismatch at index {}: {} vs {}",
            i,
            standard_output.data()[i],
            flash_output.data()[i]
        );
    }
}
#[test]
fn test_flash_attention_multi_position() {
    // Every block size — including 4, larger than the 3-position sequence —
    // must reproduce the standard result.
    let attn = Attention::new(4).unwrap();
    #[rustfmt::skip]
    let q_data = vec![
        1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, ];
    #[rustfmt::skip]
    let k_data = vec![
        1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, ];
    #[rustfmt::skip]
    let v_data = vec![
        1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, ];
    let q = Tensor::from_vec(vec![3, 4], q_data).unwrap();
    let k = Tensor::from_vec(vec![3, 4], k_data).unwrap();
    let v = Tensor::from_vec(vec![3, 4], v_data).unwrap();
    let standard_output = attn.forward(&q, &k, &v).unwrap();
    for block_size in [1, 2, 3, 4] {
        let flash_output = attn.flash_forward(&q, &k, &v, block_size).unwrap();
        assert_eq!(standard_output.shape(), flash_output.shape());
        for i in 0..standard_output.data().len() {
            assert!(
                (standard_output.data()[i] - flash_output.data()[i]).abs() < 1e-4,
                "Block size {}, mismatch at index {}: {} vs {}",
                block_size,
                i,
                standard_output.data()[i],
                flash_output.data()[i]
            );
        }
    }
}
#[test]
fn test_flash_attention_zero_block_size_error() {
    // A zero block size cannot tile the sequence and must be rejected.
    let attn = Attention::new(4).unwrap();
    let q = Tensor::from_vec(vec![1, 4], vec![1.0, 0.0, 0.0, 0.0]).unwrap();
    let k = Tensor::from_vec(vec![1, 4], vec![1.0, 0.0, 0.0, 0.0]).unwrap();
    let v = Tensor::from_vec(vec![1, 4], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let result = attn.flash_forward(&q, &k, &v, 0);
    assert!(result.is_err());
}
#[test]
fn test_flash_attention_large_sequence() {
    // 16 positions with block size 4; the looser 1e-3 tolerance allows for
    // accumulated floating-point differences between the two paths.
    let attn = Attention::new(8).unwrap();
    let mut q_data = Vec::new();
    let mut k_data = Vec::new();
    let mut v_data = Vec::new();
    for i in 0..16 {
        for j in 0..8 {
            #[allow(clippy::cast_precision_loss)]
            {
                q_data.push((i * 8 + j) as f32 * 0.1);
                k_data.push((i * 8 + j) as f32 * 0.05);
                v_data.push((i * 8 + j) as f32 * 0.2);
            }
        }
    }
    let q = Tensor::from_vec(vec![16, 8], q_data).unwrap();
    let k = Tensor::from_vec(vec![16, 8], k_data).unwrap();
    let v = Tensor::from_vec(vec![16, 8], v_data).unwrap();
    let standard_output = attn.forward(&q, &k, &v).unwrap();
    let flash_output = attn.flash_forward(&q, &k, &v, 4).unwrap();
    assert_eq!(standard_output.shape(), flash_output.shape());
    for i in 0..standard_output.data().len() {
        assert!(
            (standard_output.data()[i] - flash_output.data()[i]).abs() < 1e-3,
            "Mismatch at index {}: {} vs {}",
            i,
            standard_output.data()[i],
            flash_output.data()[i]
        );
    }
}
#[test]
fn test_flash_attention_shape_errors() {
    // V has 3 rows while K has 2: the K/V length mismatch must error.
    let attn = Attention::new(4).unwrap();
    let q = Tensor::from_vec(vec![2, 4], vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();
    let k = Tensor::from_vec(vec![2, 4], vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]).unwrap();
    let v_wrong = Tensor::from_vec(
        vec![3, 4],
        vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0],
    )
    .unwrap();
    let result = attn.flash_forward(&q, &k, &v_wrong, 2);
    assert!(result.is_err());
}
// flash_forward_v2 tests: the second tiled variant must also agree with the
// standard `forward` for all block sizes and sequence lengths.
#[test]
fn test_flash_attention_v2_matches_standard() {
    // Single-position check with block size 1.
    let attn = Attention::new(8).unwrap();
    let q_data = vec![1.0, 2.0, 0.5, 1.5, 2.5, 0.3, 1.2, 0.8];
    let k_data = vec![0.5, 1.0, 1.5, 0.8, 0.3, 2.0, 0.9, 1.1];
    let v_data = vec![2.0, 1.0, 0.5, 3.0, 1.5, 0.7, 2.5, 0.9];
    let q = Tensor::from_vec(vec![1, 8], q_data).unwrap();
    let k = Tensor::from_vec(vec![1, 8], k_data).unwrap();
    let v = Tensor::from_vec(vec![1, 8], v_data).unwrap();
    let standard = attn.forward(&q, &k, &v).unwrap();
    let v2 = attn.flash_forward_v2(&q, &k, &v, 1).unwrap();
    assert_eq!(standard.shape(), v2.shape());
    for i in 0..standard.data().len() {
        assert!(
            (standard.data()[i] - v2.data()[i]).abs() < 1e-4,
            "Mismatch at {}: {} vs {}",
            i,
            standard.data()[i],
            v2.data()[i]
        );
    }
}
#[test]
fn test_flash_attention_v2_multi_position() {
    // Symmetric Q/K on a 3-position sequence; every block size (including one
    // larger than the sequence) must match the standard result.
    let attn = Attention::new(4).unwrap();
    #[rustfmt::skip]
    let q_data = vec![
        1.0, 0.5, 0.3, 1.2,
        0.5, 1.0, 0.8, 0.4,
        0.3, 0.8, 1.0, 0.6,
    ];
    #[rustfmt::skip]
    let k_data = vec![
        1.0, 0.5, 0.3, 1.2,
        0.5, 1.0, 0.8, 0.4,
        0.3, 0.8, 1.0, 0.6,
    ];
    #[rustfmt::skip]
    let v_data = vec![
        1.0, 2.0, 3.0, 4.0,
        5.0, 6.0, 7.0, 8.0,
        9.0, 10.0, 11.0, 12.0,
    ];
    let q = Tensor::from_vec(vec![3, 4], q_data).unwrap();
    let k = Tensor::from_vec(vec![3, 4], k_data).unwrap();
    let v = Tensor::from_vec(vec![3, 4], v_data).unwrap();
    let standard = attn.forward(&q, &k, &v).unwrap();
    for block_size in [1, 2, 3, 4] {
        let v2 = attn.flash_forward_v2(&q, &k, &v, block_size).unwrap();
        assert_eq!(standard.shape(), v2.shape());
        for i in 0..standard.data().len() {
            assert!(
                (standard.data()[i] - v2.data()[i]).abs() < 1e-4,
                "Block size {}, mismatch at {}: {} vs {}",
                block_size,
                i,
                standard.data()[i],
                v2.data()[i]
            );
        }
    }
}
#[test]
fn test_flash_attention_v2_zero_block_size_error() {
    // Zero block size cannot tile the sequence and must be rejected.
    let attn = Attention::new(4).unwrap();
    let q = Tensor::from_vec(vec![1, 4], vec![1.0; 4]).unwrap();
    let k = Tensor::from_vec(vec![1, 4], vec![1.0; 4]).unwrap();
    let v = Tensor::from_vec(vec![1, 4], vec![1.0; 4]).unwrap();
    let result = attn.flash_forward_v2(&q, &k, &v, 0);
    assert!(result.is_err());
}
#[test]
fn test_flash_attention_v2_large_sequence() {
    // 32 positions with block size 8; looser 1e-3 tolerance for accumulated
    // floating-point differences.
    let attn = Attention::new(8).unwrap();
    let mut q_data = Vec::new();
    let mut k_data = Vec::new();
    let mut v_data = Vec::new();
    for i in 0..32 {
        for j in 0..8 {
            #[allow(clippy::cast_precision_loss)]
            {
                q_data.push((i * 8 + j) as f32 * 0.05);
                k_data.push((i * 8 + j) as f32 * 0.03);
                v_data.push((i * 8 + j) as f32 * 0.1);
            }
        }
    }
    let q = Tensor::from_vec(vec![32, 8], q_data).unwrap();
    let k = Tensor::from_vec(vec![32, 8], k_data).unwrap();
    let v = Tensor::from_vec(vec![32, 8], v_data).unwrap();
    let standard = attn.forward(&q, &k, &v).unwrap();
    let v2 = attn.flash_forward_v2(&q, &k, &v, 8).unwrap();
    assert_eq!(standard.shape(), v2.shape());
    for i in 0..standard.data().len() {
        assert!(
            (standard.data()[i] - v2.data()[i]).abs() < 1e-3,
            "Mismatch at {}: {} vs {}",
            i,
            standard.data()[i],
            v2.data()[i]
        );
    }
}
// flash_forward_parallel tests: the parallel tiled variant must match both
// the standard `forward` and the serial flash_forward_v2.
#[test]
fn test_flash_attention_parallel_matches_standard() {
    // Single-position check with block size 1.
    let attn = Attention::new(8).unwrap();
    let q_data = vec![1.0, 2.0, 0.5, 1.5, 2.5, 0.3, 1.2, 0.8];
    let k_data = vec![0.5, 1.0, 1.5, 0.8, 0.3, 2.0, 0.9, 1.1];
    let v_data = vec![2.0, 1.0, 0.5, 3.0, 1.5, 0.7, 2.5, 0.9];
    let q = Tensor::from_vec(vec![1, 8], q_data).unwrap();
    let k = Tensor::from_vec(vec![1, 8], k_data).unwrap();
    let v = Tensor::from_vec(vec![1, 8], v_data).unwrap();
    let standard = attn.forward(&q, &k, &v).unwrap();
    let parallel = attn.flash_forward_parallel(&q, &k, &v, 1).unwrap();
    assert_eq!(standard.shape(), parallel.shape());
    for i in 0..standard.data().len() {
        assert!(
            (standard.data()[i] - parallel.data()[i]).abs() < 1e-4,
            "Mismatch at {}: {} vs {}",
            i,
            standard.data()[i],
            parallel.data()[i]
        );
    }
}
#[test]
fn test_flash_attention_parallel_multi_position() {
    // A 4-position sequence with K = Q; the result must be identical for all
    // tested block sizes.
    let attn = Attention::new(4).unwrap();
    #[rustfmt::skip]
    let q_data = vec![
        1.0, 0.5, 0.3, 1.2,
        0.5, 1.0, 0.8, 0.4,
        0.3, 0.8, 1.0, 0.6,
        0.7, 0.2, 0.9, 0.5,
    ];
    let k_data = q_data.clone();
    #[rustfmt::skip]
    let v_data = vec![
        1.0, 2.0, 3.0, 4.0,
        5.0, 6.0, 7.0, 8.0,
        9.0, 10.0, 11.0, 12.0,
        13.0, 14.0, 15.0, 16.0,
    ];
    let q = Tensor::from_vec(vec![4, 4], q_data).unwrap();
    let k = Tensor::from_vec(vec![4, 4], k_data).unwrap();
    let v = Tensor::from_vec(vec![4, 4], v_data).unwrap();
    let standard = attn.forward(&q, &k, &v).unwrap();
    for block_size in [1, 2, 4] {
        let parallel = attn.flash_forward_parallel(&q, &k, &v, block_size).unwrap();
        assert_eq!(standard.shape(), parallel.shape());
        for i in 0..standard.data().len() {
            assert!(
                (standard.data()[i] - parallel.data()[i]).abs() < 1e-4,
                "Block size {}, mismatch at {}: {} vs {}",
                block_size,
                i,
                standard.data()[i],
                parallel.data()[i]
            );
        }
    }
}
#[test]
fn test_flash_attention_parallel_zero_block_size_error() {
    // Zero block size cannot tile the sequence and must be rejected.
    let attn = Attention::new(4).unwrap();
    let q = Tensor::from_vec(vec![1, 4], vec![1.0; 4]).unwrap();
    let k = Tensor::from_vec(vec![1, 4], vec![1.0; 4]).unwrap();
    let v = Tensor::from_vec(vec![1, 4], vec![1.0; 4]).unwrap();
    let result = attn.flash_forward_parallel(&q, &k, &v, 0);
    assert!(result.is_err());
}
#[test]
fn test_flash_attention_parallel_large_sequence() {
    // 64 positions, head_dim 16, block size 16; looser 1e-3 tolerance for
    // accumulated floating-point differences between the two paths.
    let attn = Attention::new(16).unwrap();
    let mut q_data = Vec::new();
    let mut k_data = Vec::new();
    let mut v_data = Vec::new();
    for i in 0..64 {
        for j in 0..16 {
            #[allow(clippy::cast_precision_loss)]
            {
                q_data.push((i * 16 + j) as f32 * 0.02);
                k_data.push((i * 16 + j) as f32 * 0.015);
                v_data.push((i * 16 + j) as f32 * 0.05);
            }
        }
    }
    let q = Tensor::from_vec(vec![64, 16], q_data).unwrap();
    let k = Tensor::from_vec(vec![64, 16], k_data).unwrap();
    let v = Tensor::from_vec(vec![64, 16], v_data).unwrap();
    let standard = attn.forward(&q, &k, &v).unwrap();
    let parallel = attn.flash_forward_parallel(&q, &k, &v, 16).unwrap();
    assert_eq!(standard.shape(), parallel.shape());
    for i in 0..standard.data().len() {
        assert!(
            (standard.data()[i] - parallel.data()[i]).abs() < 1e-3,
            "Mismatch at {}: {} vs {}",
            i,
            standard.data()[i],
            parallel.data()[i]
        );
    }
}
#[test]
fn test_flash_attention_v2_vs_parallel_consistency() {
    // The serial v2 and parallel variants must agree tightly (1e-5) — they
    // presumably share the same blocked algorithm, differing only in scheduling.
    let attn = Attention::new(8).unwrap();
    let mut q_data = Vec::new();
    let mut k_data = Vec::new();
    let mut v_data = Vec::new();
    for i in 0..16 {
        for j in 0..8 {
            #[allow(clippy::cast_precision_loss)]
            {
                q_data.push((i * 8 + j) as f32 * 0.1);
                k_data.push((i * 8 + j) as f32 * 0.08);
                v_data.push((i * 8 + j) as f32 * 0.15);
            }
        }
    }
    let q = Tensor::from_vec(vec![16, 8], q_data).unwrap();
    let k = Tensor::from_vec(vec![16, 8], k_data).unwrap();
    let v = Tensor::from_vec(vec![16, 8], v_data).unwrap();
    let v2 = attn.flash_forward_v2(&q, &k, &v, 4).unwrap();
    let parallel = attn.flash_forward_parallel(&q, &k, &v, 4).unwrap();
    assert_eq!(v2.shape(), parallel.shape());
    for i in 0..v2.data().len() {
        assert!(
            (v2.data()[i] - parallel.data()[i]).abs() < 1e-5,
            "Mismatch at {}: v2={} vs parallel={}",
            i,
            v2.data()[i],
            parallel.data()[i]
        );
    }
}
#[test]
fn test_rope_creation() {
    let rope = RoPE::new(64, 10000.0).unwrap();
    assert_eq!(rope.dim(), 64);
    assert!((rope.base() - 10000.0).abs() < 1e-6);
    // One inverse frequency per rotated pair of dimensions.
    assert_eq!(rope.inv_freq().len(), 32);
}
#[test]
fn test_rope_with_default_base() {
    // The default base is the conventional 10000.
    let rope = RoPE::with_default_base(128).unwrap();
    assert_eq!(rope.dim(), 128);
    assert!((rope.base() - 10000.0).abs() < 1e-6);
}
#[test]
fn test_rope_zero_dim_error() {
    assert!(RoPE::new(0, 10000.0).is_err());
}
#[test]
fn test_rope_odd_dim_error() {
    // RoPE rotates dimension pairs, so an odd dim is invalid.
    assert!(RoPE::new(63, 10000.0).is_err());
}
#[test]
fn test_rope_forward_shape() {
    let rope = RoPE::with_default_base(4).unwrap();
    let input = Tensor::from_vec(vec![2, 4], vec![1.0; 8]).unwrap();
    let output = rope.forward(&input, 0).unwrap();
    assert_eq!(output.shape(), &[2, 4]);
    assert_eq!(output.data().len(), 8);
}
#[test]
fn test_rope_position_zero_identity() {
    // At position 0 every rotation angle is zero, so forward is the identity.
    let rope = RoPE::with_default_base(4).unwrap();
    let input = Tensor::from_vec(vec![4], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let output = rope.forward(&input, 0).unwrap();
    for i in 0..4 {
        assert!(
            (output.data()[i] - input.data()[i]).abs() < 1e-6,
            "Position 0 should be identity: expected {}, got {}",
            input.data()[i],
            output.data()[i]
        );
    }
}
#[test]
fn test_rope_preserves_norm() {
    // A rotation preserves the Euclidean norm of each rotated pair.
    let rope = RoPE::with_default_base(4).unwrap();
    let input = Tensor::from_vec(vec![4], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let output = rope.forward(&input, 100).unwrap();
    // Pairing follows the original test: (0, 1) and (2, 3).
    let pair_norm = |d: &[f32], i: usize| (d[i].powi(2) + d[i + 1].powi(2)).sqrt();
    assert!(
        (pair_norm(input.data(), 0) - pair_norm(output.data(), 0)).abs() < 1e-5,
        "Pair 0 norm should be preserved"
    );
    assert!(
        (pair_norm(input.data(), 2) - pair_norm(output.data(), 2)).abs() < 1e-5,
        "Pair 1 norm should be preserved"
    );
}
#[test]
fn test_rope_different_positions() {
    // Distinct positions must produce distinct rotations of the same input
    // (checked on the first rotated pair).
    let rope = RoPE::with_default_base(4).unwrap();
    let input = Tensor::from_vec(vec![4], vec![1.0, 0.0, 1.0, 0.0]).unwrap();
    let out_pos_zero = rope.forward(&input, 0).unwrap();
    let out_pos_ten = rope.forward(&input, 10).unwrap();
    let out_pos_hundred = rope.forward(&input, 100).unwrap();
    let first_pair_differs = |a: &Tensor<f32>, b: &Tensor<f32>| {
        (a.data()[0] - b.data()[0]).abs() > 1e-6 || (a.data()[1] - b.data()[1]).abs() > 1e-6
    };
    assert!(first_pair_differs(&out_pos_zero, &out_pos_ten));
    assert!(first_pair_differs(&out_pos_ten, &out_pos_hundred));
}
#[test]
fn test_rope_dimension_mismatch_error() {
    // A 6-dimensional input into a 4-dimensional RoPE must fail.
    let rope = RoPE::with_default_base(4).unwrap();
    let input = Tensor::from_vec(vec![6], vec![1.0; 6]).unwrap();
    assert!(rope.forward(&input, 0).is_err());
}
#[test]
fn test_rope_batched() {
    // The same position applies to every row, so identical rows stay identical.
    let rope = RoPE::with_default_base(4).unwrap();
    let input = Tensor::from_vec(vec![3, 4], vec![1.0; 12]).unwrap();
    let output = rope.forward(&input, 5).unwrap();
    assert_eq!(output.shape(), &[3, 4]);
    for batch in 0..3 {
        for i in 0..4 {
            let expected = output.data()[i];
            let actual = output.data()[batch * 4 + i];
            assert!(
                (expected - actual).abs() < 1e-6,
                "All batch elements should have same rotation"
            );
        }
    }
}
#[test]
fn test_rope_inv_freq_computation() {
    // inv_freq[i] = base^(-2i/dim): 10000^0 = 1 and 10000^(-1/2) = 0.01.
    let rope = RoPE::new(4, 10000.0).unwrap();
    let inv_freq = rope.inv_freq();
    assert!((inv_freq[0] - 1.0).abs() < 1e-6);
    assert!((inv_freq[1] - 0.01).abs() < 1e-6);
}
// ScaledRoPE tests: context-extension strategies layered on rotary embeddings.
#[test]
fn test_scaled_rope_no_scaling() {
    // RopeScalingType::None leaves base, mscale and the context multiplier at
    // their identity values.
    let scaled = ScaledRoPE::new(64, 10000.0, RopeScalingType::None).unwrap();
    assert_eq!(scaled.dim(), 64);
    assert!((scaled.original_base() - 10000.0).abs() < 1e-6);
    assert!((scaled.scaled_base() - 10000.0).abs() < 1e-6);
    assert!((scaled.mscale() - 1.0).abs() < 1e-6);
    assert!((scaled.context_length_multiplier() - 1.0).abs() < 1e-6);
}
#[test]
fn test_scaled_rope_linear_scaling() {
    // Linear (position-interpolation) scaling stretches the context window
    // without touching the rotary base.
    let scaling = RopeScalingType::Linear { scale: 4.0 };
    let scaled = ScaledRoPE::new(64, 10000.0, scaling).unwrap();
    assert!((scaled.context_length_multiplier() - 4.0).abs() < 1e-6);
    assert!((scaled.scaled_base() - 10000.0).abs() < 1e-6);
    assert!((scaled.mscale() - 1.0).abs() < 1e-6);
}
#[test]
fn test_scaled_rope_ntk_scaling() {
    // NTK-aware scaling raises the base rather than interpolating positions;
    // for scale 4 the base grows beyond 4x the original.
    let scaling = RopeScalingType::Ntk { scale: 4.0 };
    let scaled = ScaledRoPE::new(64, 10000.0, scaling).unwrap();
    assert!((scaled.context_length_multiplier() - 4.0).abs() < 1e-6);
    assert!(scaled.scaled_base() > 10000.0);
    assert!(scaled.scaled_base() > 40000.0);
    assert!((scaled.mscale() - 1.0).abs() < 1e-6);
}
#[test]
fn test_scaled_rope_dynamic_ntk() {
    // Dynamic NTK derives the scale from the length pair: 8192 / 2048 = 4.
    let scaling = RopeScalingType::DynamicNtk {
        original_max_len: 2048,
        target_max_len: 8192,
    };
    let scaled = ScaledRoPE::new(64, 10000.0, scaling).unwrap();
    assert!((scaled.context_length_multiplier() - 4.0).abs() < 1e-6);
    assert!(scaled.scaled_base() > 40000.0);
}
#[test]
fn test_scaled_rope_yarn() {
    // YaRN: 32768 / 2048 = 16x extension. With attn_factor 0.0 the
    // implementation appears to derive mscale itself (asserted > 1.0).
    let scaling = RopeScalingType::Yarn {
        original_max_len: 2048,
        target_max_len: 32768,
        attn_factor: 0.0,
        beta_fast: 32.0,
        beta_slow: 1.0,
    };
    let scaled = ScaledRoPE::new(64, 10000.0, scaling).unwrap();
    assert!((scaled.context_length_multiplier() - 16.0).abs() < 1e-6);
    assert!(scaled.mscale() > 1.0);
    assert!(scaled.scaled_base() > 10000.0);
}
#[test]
fn test_scaled_rope_yarn_custom_attn_factor() {
    // A non-zero attn_factor appears to be used directly as mscale.
    let scaling = RopeScalingType::Yarn {
        original_max_len: 2048,
        target_max_len: 8192,
        attn_factor: 1.5,
        beta_fast: 32.0,
        beta_slow: 1.0,
    };
    let scaled = ScaledRoPE::new(64, 10000.0, scaling).unwrap();
    assert!((scaled.mscale() - 1.5).abs() < 1e-6);
}
#[test]
fn test_scaled_rope_forward_no_scaling() {
    let scaled = ScaledRoPE::new(4, 10000.0, RopeScalingType::None).unwrap();
    let input = Tensor::from_vec(vec![4], vec![1.0, 0.0, 0.0, 1.0]).unwrap();
    let output = scaled.forward(&input, 0).unwrap();
    assert_eq!(output.shape(), &[4]);
}
#[test]
fn test_scaled_rope_forward_linear() {
    let scaling = RopeScalingType::Linear { scale: 2.0 };
    let scaled = ScaledRoPE::new(4, 10000.0, scaling).unwrap();
    let input = Tensor::from_vec(vec![4], vec![1.0, 0.0, 0.0, 1.0]).unwrap();
    let output = scaled.forward(&input, 10).unwrap();
    assert_eq!(output.shape(), &[4]);
}
#[test]
fn test_scaled_rope_forward_ntk() {
    // Rotation preserves the overall vector norm: sqrt(2) for this input.
    let scaling = RopeScalingType::Ntk { scale: 4.0 };
    let scaled = ScaledRoPE::new(4, 10000.0, scaling).unwrap();
    let input = Tensor::from_vec(vec![4], vec![1.0, 0.0, 0.0, 1.0]).unwrap();
    let output = scaled.forward(&input, 100).unwrap();
    assert_eq!(output.shape(), &[4]);
    let norm: f32 = output.data().iter().map(|x| x * x).sum::<f32>().sqrt();
    assert!((norm - 2.0_f32.sqrt()).abs() < 0.1);
}
#[test]
fn test_scaled_rope_forward_yarn() {
    // Smoke test: YaRN forward at a large position produces the right shape.
    let scaling = RopeScalingType::Yarn {
        original_max_len: 2048,
        target_max_len: 8192,
        attn_factor: 1.0,
        beta_fast: 32.0,
        beta_slow: 1.0,
    };
    let scaled = ScaledRoPE::new(4, 10000.0, scaling).unwrap();
    let input = Tensor::from_vec(vec![4], vec![1.0, 0.0, 0.0, 1.0]).unwrap();
    let output = scaled.forward(&input, 5000).unwrap();
    assert_eq!(output.shape(), &[4]);
}
#[test]
fn test_scaled_rope_zero_dim_error() {
    let result = ScaledRoPE::new(0, 10000.0, RopeScalingType::None);
    assert!(result.is_err());
}
#[test]
fn test_scaled_rope_odd_dim_error() {
    // Rotary embeddings pair dimensions; odd dims are rejected.
    let result = ScaledRoPE::new(63, 10000.0, RopeScalingType::None);
    assert!(result.is_err());
}
#[test]
fn test_scaled_rope_dimension_mismatch() {
    // An 8-dimensional input into a 4-dimensional ScaledRoPE must fail.
    let scaled = ScaledRoPE::new(4, 10000.0, RopeScalingType::None).unwrap();
    let input = Tensor::from_vec(vec![8], vec![0.0; 8]).unwrap();
    let result = scaled.forward(&input, 0);
    assert!(result.is_err());
}
#[test]
fn test_rope_scaling_type_default() {
    let scaling = RopeScalingType::default();
    assert_eq!(scaling, RopeScalingType::None);
}
#[test]
fn test_scaled_rope_with_default_base() {
    let scaled = ScaledRoPE::with_default_base(64, RopeScalingType::None).unwrap();
    assert!((scaled.original_base() - 10000.0).abs() < 1e-6);
}
#[test]
fn test_scaled_rope_inv_freq_length() {
    // dim / 2 inverse frequencies: one per rotated pair.
    let scaled = ScaledRoPE::new(128, 10000.0, RopeScalingType::None).unwrap();
    assert_eq!(scaled.inv_freq().len(), 64);
}
#[test]
fn test_alibi_creation() {
let alibi = ALiBi::new(8).unwrap();
assert_eq!(alibi.num_heads(), 8);
assert_eq!(alibi.slopes().len(), 8);
}
#[test]
fn test_alibi_zero_heads_error() {
let result = ALiBi::new(0);
assert!(result.is_err());
}
#[test]
fn test_alibi_slopes_power_of_2() {
let alibi = ALiBi::new(8).unwrap();
let slopes = alibi.slopes();
assert!((slopes[0] - 1.0).abs() < 1e-6); assert!((slopes[1] - 0.5).abs() < 1e-6); assert!((slopes[2] - 0.25).abs() < 1e-6); assert!((slopes[3] - 0.125).abs() < 1e-6); }
#[test]
fn test_alibi_slopes_non_power_of_2() {
let alibi = ALiBi::new(6).unwrap();
let slopes = alibi.slopes();
assert_eq!(slopes.len(), 6);
assert!((slopes[0] - 1.0).abs() < 1e-6); assert!((slopes[1] - 0.25).abs() < 1e-6); assert!((slopes[2] - 0.0625).abs() < 1e-6); assert!((slopes[3] - 0.015_625).abs() < 1e-6);
assert!((slopes[4] - 0.5).abs() < 1e-6);
assert!((slopes[5] - 0.125).abs() < 1e-6);
}
#[test]
fn test_alibi_bias_shape() {
let alibi = ALiBi::new(4).unwrap();
let bias = alibi.get_bias(10).unwrap();
assert_eq!(bias.shape(), &[10, 10, 4]);
}
#[test]
fn test_alibi_bias_zero_seq_len_error() {
    // A zero-length sequence has no valid bias matrix.
    let alibi = ALiBi::new(4).unwrap();
    assert!(alibi.get_bias(0).is_err());
}

#[test]
fn test_alibi_bias_diagonal_zero() {
    // A token attending to its own position has distance 0, hence zero bias.
    let alibi = ALiBi::new(4).unwrap();
    let bias = alibi.get_bias(5).unwrap();
    let data = bias.data();
    for i in 0..5 {
        for h in 0..4 {
            // Row-major [i, j, h] layout: i * seq * heads + j * heads + h.
            let value = data[i * 5 * 4 + i * 4 + h];
            assert!(
                value.abs() < 1e-6,
                "Diagonal bias[{i}, {i}, {h}] should be 0, got {value}"
            );
        }
    }
}
#[test]
fn test_alibi_bias_symmetry() {
    // The bias depends only on the distance |i - j|, so for every head
    // bias[i, j, h] must equal bias[j, i, h].
    let alibi = ALiBi::new(2).unwrap();
    let bias = alibi.get_bias(4).unwrap();
    for i in 0..4 {
        for j in 0..4 {
            for h in 0..2 {
                // Row-major [i, j, h] indexing: i * seq * heads + j * heads + h.
                let idx_ij = i * 4 * 2 + j * 2 + h;
                let idx_ji = j * 4 * 2 + i * 2 + h;
                let bias_ij = bias.data()[idx_ij];
                let bias_ji = bias.data()[idx_ji];
                assert!(
                    (bias_ij - bias_ji).abs() < 1e-6,
                    "Bias should be symmetric: [{i},{j},{h}]={bias_ij} vs [{j},{i},{h}]={bias_ji}"
                );
            }
        }
    }
}
#[test]
fn test_alibi_bias_computation() {
    // Spot-checks two entries against the closed form bias = -slope * |i - j|
    // for seq_len = 3 and 2 heads.
    let alibi = ALiBi::new(2).unwrap();
    let slopes = alibi.slopes();
    let bias = alibi.get_bias(3).unwrap();
    // Flat index 4 = [i=0, j=2, h=0]: distance 2 => expected -2.0
    // (i.e. -slopes[0] * 2, which pins slopes[0] at 1.0).
    let idx = 2 * 2;
    assert!(
        (bias.data()[idx] - (-2.0)).abs() < 1e-6,
        "Expected -2.0, got {}",
        bias.data()[idx]
    );
    // Flat index 11 = [i=1, j=2, h=1]: distance 1 => expected -slopes[1].
    let idx = 3 * 2 + 2 * 2 + 1;
    let expected = -slopes[1];
    assert!(
        (bias.data()[idx] - expected).abs() < 1e-6,
        "Expected {expected}, got {}",
        bias.data()[idx]
    );
}
#[test]
fn test_alibi_bias_negative() {
    // Distances are non-negative, so every bias entry is <= 0 (within tolerance).
    let alibi = ALiBi::new(4).unwrap();
    let bias = alibi.get_bias(10).unwrap();
    for &value in bias.data() {
        assert!(value <= 1e-6, "Bias should be non-positive, got {value}");
    }
}

#[test]
fn test_alibi_bias_distance_proportional() {
    // With a single head (slope 1.0, see test_alibi_single_head) the bias
    // equals minus the token distance.
    let alibi = ALiBi::new(1).unwrap();
    let bias = alibi.get_bias(5).unwrap();
    let data = bias.data();
    for (j, want) in [(1usize, -1.0f32), (2, -2.0), (3, -3.0)] {
        assert!((data[j] - want).abs() < 1e-6);
    }
}

#[test]
fn test_alibi_single_head() {
    let alibi = ALiBi::new(1).unwrap();
    assert_eq!(alibi.num_heads(), 1);
    assert_eq!(alibi.slopes().len(), 1);
    assert!((alibi.slopes()[0] - 1.0).abs() < 1e-6);
}

#[test]
fn test_alibi_large_num_heads() {
    // A head count that is neither small nor a power of two still yields
    // strictly positive slopes, starting at 1.0.
    let alibi = ALiBi::new(12).unwrap();
    assert_eq!(alibi.num_heads(), 12);
    assert_eq!(alibi.slopes().len(), 12);
    for slope in alibi.slopes() {
        assert!(*slope > 0.0, "Slope should be positive, got {slope}");
    }
    assert!((alibi.slopes()[0] - 1.0).abs() < 1e-6);
}

#[test]
fn test_alibi_bias_long_sequence() {
    // The penalty keeps growing with distance on a long sequence.
    let alibi = ALiBi::new(8).unwrap();
    let bias = alibi.get_bias(128).unwrap();
    assert_eq!(bias.shape(), &[128, 128, 8]);
    let near_bias = bias.data()[8];
    let far_bias = bias.data()[100 * 8];
    assert!(near_bias > far_bias);
}
#[test]
fn test_kvcache_creation() {
    // A fresh cache starts empty at position 0.
    let cache = KVCache::new(4, 512, 64).unwrap();
    assert_eq!(cache.num_layers(), 4);
    assert_eq!(cache.max_seq_len(), 512);
    assert_eq!(cache.head_dim(), 64);
    assert_eq!(cache.current_pos(), 0);
    assert!(!cache.is_full());
}

#[test]
fn test_kvcache_zero_params_error() {
    // Every dimension (layers, max_seq_len, head_dim) must be non-zero.
    assert!(KVCache::new(0, 512, 64).is_err());
    assert!(KVCache::new(4, 0, 64).is_err());
    assert!(KVCache::new(4, 512, 0).is_err());
}

#[test]
fn test_kvcache_update_and_retrieve() {
    // A single update followed by advance yields a [1, head_dim] tensor
    // holding exactly the stored values.
    let mut cache = KVCache::new(2, 10, 4).unwrap();
    let key = Tensor::from_vec(vec![4], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let value = Tensor::from_vec(vec![4], vec![5.0, 6.0, 7.0, 8.0]).unwrap();
    cache.update(0, &key, &value).unwrap();
    cache.advance();
    let cached_key = cache.get_key(0).unwrap();
    let cached_value = cache.get_value(0).unwrap();
    assert_eq!(cached_key.shape(), &[1, 4]);
    assert_eq!(cached_value.shape(), &[1, 4]);
    for (cached, original) in cached_key.data().iter().zip(key.data()) {
        assert!((cached - original).abs() < 1e-6);
    }
    for (cached, original) in cached_value.data().iter().zip(value.data()) {
        assert!((cached - original).abs() < 1e-6);
    }
}
#[test]
fn test_kvcache_multiple_positions() {
    // Appends three K/V entries; the cache should then report position 3
    // and return the entries in insertion order as a [3, head_dim] tensor.
    let mut cache = KVCache::new(1, 10, 2).unwrap();
    for pos in 0..3 {
        #[allow(clippy::cast_precision_loss)]
        let base = pos as f32;
        let key = Tensor::from_vec(vec![2], vec![base, base + 0.5]).unwrap();
        let value = Tensor::from_vec(vec![2], vec![base + 1.0, base + 1.5]).unwrap();
        cache.update(0, &key, &value).unwrap();
        cache.advance();
    }
    assert_eq!(cache.current_pos(), 3);
    let cached_key = cache.get_key(0).unwrap();
    let cached_value = cache.get_value(0).unwrap();
    assert_eq!(cached_key.shape(), &[3, 2]);
    assert_eq!(cached_value.shape(), &[3, 2]);
    // First two positions wrote keys [0.0, 0.5] then [1.0, 1.5].
    assert!((cached_key.data()[0] - 0.0).abs() < 1e-6);
    assert!((cached_key.data()[1] - 0.5).abs() < 1e-6);
    assert!((cached_key.data()[2] - 1.0).abs() < 1e-6);
    assert!((cached_key.data()[3] - 1.5).abs() < 1e-6);
}
#[test]
fn test_kvcache_multiple_layers() {
    // Each layer keeps an independent K/V store; writes to layer 0 must not
    // leak into layer 1 and vice versa.
    let mut cache = KVCache::new(2, 10, 4).unwrap();
    let key0 = Tensor::from_vec(vec![4], vec![1.0; 4]).unwrap();
    let value0 = Tensor::from_vec(vec![4], vec![2.0; 4]).unwrap();
    let key1 = Tensor::from_vec(vec![4], vec![3.0; 4]).unwrap();
    let value1 = Tensor::from_vec(vec![4], vec![4.0; 4]).unwrap();
    cache.update(0, &key0, &value0).unwrap();
    cache.update(1, &key1, &value1).unwrap();
    cache.advance();
    let layer0_key = cache.get_key(0).unwrap();
    assert!((layer0_key.data()[0] - 1.0).abs() < 1e-6);
    let layer1_key = cache.get_key(1).unwrap();
    assert!((layer1_key.data()[0] - 3.0).abs() < 1e-6);
}
#[test]
fn test_kvcache_layer_out_of_bounds_error() {
    // Valid layer indices are 0..2; index 2 must fail for update and both getters.
    let mut cache = KVCache::new(2, 10, 4).unwrap();
    let key = Tensor::from_vec(vec![4], vec![1.0; 4]).unwrap();
    let value = Tensor::from_vec(vec![4], vec![2.0; 4]).unwrap();
    assert!(cache.update(2, &key, &value).is_err());
    assert!(cache.get_key(2).is_err());
    assert!(cache.get_value(2).is_err());
}
#[test]
fn test_kvcache_size_mismatch_error() {
    // Keys and values must match the configured head_dim (4 here).
    let mut cache = KVCache::new(1, 10, 4).unwrap();
    let tensor = |len: usize, fill: f32| Tensor::from_vec(vec![len], vec![fill; len]).unwrap();
    // Key too short.
    assert!(cache.update(0, &tensor(3, 1.0), &tensor(4, 2.0)).is_err());
    // Value too short.
    assert!(cache.update(0, &tensor(4, 1.0), &tensor(3, 2.0)).is_err());
}

#[test]
fn test_kvcache_full_error() {
    // Capacity is 2 positions; a third update must fail.
    let mut cache = KVCache::new(1, 2, 4).unwrap();
    let key = Tensor::from_vec(vec![4], vec![1.0; 4]).unwrap();
    let value = Tensor::from_vec(vec![4], vec![2.0; 4]).unwrap();
    for _ in 0..2 {
        cache.update(0, &key, &value).unwrap();
        cache.advance();
    }
    assert!(cache.is_full());
    assert!(cache.update(0, &key, &value).is_err());
}

#[test]
fn test_kvcache_clear() {
    // `clear` resets the write position so the cache can be refilled.
    let mut cache = KVCache::new(1, 10, 4).unwrap();
    let key = Tensor::from_vec(vec![4], vec![1.0; 4]).unwrap();
    let value = Tensor::from_vec(vec![4], vec![2.0; 4]).unwrap();
    cache.update(0, &key, &value).unwrap();
    cache.advance();
    assert_eq!(cache.current_pos(), 1);
    cache.clear();
    assert_eq!(cache.current_pos(), 0);
    assert!(!cache.is_full());
}

#[test]
fn test_kvcache_empty_retrieval() {
    // Reading before any write yields a single zeroed position.
    let cache = KVCache::new(1, 10, 4).unwrap();
    let cached_key = cache.get_key(0).unwrap();
    let cached_value = cache.get_value(0).unwrap();
    assert_eq!(cached_key.shape(), &[1, 4]);
    assert_eq!(cached_value.shape(), &[1, 4]);
    assert!(cached_key.data().iter().all(|&val| val.abs() < 1e-6));
}
#[test]
fn test_transformer_block_creation() {
    let block = TransformerBlock::new(64, 4, 256, 1e-5).unwrap();
    assert_eq!(block.hidden_dim(), 64);
}

#[test]
fn test_transformer_block_zero_params_error() {
    // hidden_dim, num_heads and intermediate_dim must all be non-zero.
    for (hidden, heads, inter) in [(0, 4, 256), (64, 0, 256), (64, 4, 0)] {
        assert!(TransformerBlock::new(hidden, heads, inter, 1e-5).is_err());
    }
}

#[test]
fn test_transformer_block_head_divisibility_error() {
    // 63 is not divisible by 4 heads.
    assert!(TransformerBlock::new(63, 4, 256, 1e-5).is_err());
}

#[test]
fn test_transformer_block_forward_shape() {
    // Input shape [seq, hidden] is preserved through the block.
    let block = TransformerBlock::new(8, 1, 32, 1e-5).unwrap();
    let input = Tensor::from_vec(vec![2, 8], vec![0.1; 16]).unwrap();
    let output = block.forward(&input).unwrap();
    assert_eq!(output.shape(), &[2, 8]);
    assert_eq!(output.data().len(), 16);
}

#[test]
fn test_transformer_block_forward_valid_output() {
    let block = TransformerBlock::new(4, 1, 16, 1e-5).unwrap();
    let input = Tensor::from_vec(vec![1, 4], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let output = block.forward(&input).unwrap();
    assert!(
        output.data().iter().all(|v| v.is_finite()),
        "Output contains non-finite values"
    );
}

#[test]
fn test_transformer_block_residual_connection() {
    // All-zero input still produces a well-formed output.
    let block = TransformerBlock::new(4, 1, 16, 1e-5).unwrap();
    let input = Tensor::from_vec(vec![1, 4], vec![0.0; 4]).unwrap();
    assert_eq!(block.forward(&input).unwrap().shape(), &[1, 4]);
}

#[test]
fn test_transformer_block_shape_mismatch_error() {
    // A 1-D input is rejected; the block expects [seq, hidden].
    let block = TransformerBlock::new(8, 1, 32, 1e-5).unwrap();
    let input = Tensor::from_vec(vec![4], vec![1.0; 4]).unwrap();
    assert!(block.forward(&input).is_err());
}

#[test]
fn test_transformer_block_mutable_access() {
    // Smoke test: every mutable accessor is callable.
    let mut block = TransformerBlock::new(4, 1, 16, 1e-5).unwrap();
    let _ = block.attn_norm_mut();
    let _ = block.attention_mut();
    let _ = block.ffn_norm_mut();
    let _ = block.ffn_mut();
}
#[test]
fn test_embedding_creation() {
    let embed = Embedding::new(1000, 64).unwrap();
    assert_eq!(embed.vocab_size(), 1000);
    assert_eq!(embed.embed_dim(), 64);
}

#[test]
fn test_embedding_zero_params_error() {
    // Both vocab_size and embed_dim must be non-zero.
    assert!(Embedding::new(0, 64).is_err());
    assert!(Embedding::new(1000, 0).is_err());
}

#[test]
fn test_embedding_forward_shape() {
    // Three token ids embed to a [3, embed_dim] matrix.
    let embed = Embedding::new(100, 8).unwrap();
    let output = embed.forward(&[0, 1, 2]).unwrap();
    assert_eq!(output.shape(), &[3, 8]);
    assert_eq!(output.data().len(), 24);
}

#[test]
fn test_embedding_forward_lookup() {
    // Give token 5 a distinctive row, then confirm forward() returns it.
    let mut embed = Embedding::new(10, 4).unwrap();
    let offset = 5 * 4;
    for (i, &w) in [1.0f32, 2.0, 3.0, 4.0].iter().enumerate() {
        embed.weights_mut()[offset + i] = w;
    }
    let output = embed.forward(&[5]).unwrap();
    assert_eq!(output.shape(), &[1, 4]);
    for (i, want) in [1.0f32, 2.0, 3.0, 4.0].into_iter().enumerate() {
        assert!((output.data()[i] - want).abs() < 1e-6);
    }
}

#[test]
fn test_embedding_out_of_bounds_error() {
    // vocab_size is 10, so ids 10 and 100 are both invalid.
    let embed = Embedding::new(10, 4).unwrap();
    for bad_id in [10, 100] {
        assert!(embed.forward(&[bad_id]).is_err());
    }
}

#[test]
fn test_embedding_empty_input_error() {
    let embed = Embedding::new(10, 4).unwrap();
    assert!(embed.forward(&[]).is_err());
}
#[test]
fn test_model_creation() {
    // Config values must round-trip through the constructed model.
    let config = ModelConfig {
        vocab_size: 100,
        hidden_dim: 8,
        num_heads: 1,
        num_layers: 2,
        intermediate_dim: 32,
        eps: 1e-5,
    };
    let model = Model::new(config).unwrap();
    assert_eq!(model.config().vocab_size, 100);
    assert_eq!(model.config().num_layers, 2);
}
#[test]
fn test_model_forward_shape() {
    // Forward maps seq_len token ids to [seq_len, vocab_size] logits.
    let config = ModelConfig {
        vocab_size: 50,
        hidden_dim: 4,
        num_heads: 1,
        num_layers: 1,
        intermediate_dim: 16,
        eps: 1e-5,
    };
    let model = Model::new(config).unwrap();
    let token_ids = vec![0, 1, 2];
    let output = model.forward(&token_ids).unwrap();
    assert_eq!(output.shape(), &[3, 50]);
}
#[test]
fn test_model_forward_valid_output() {
    // Logits must be finite (no NaN/Inf) for a well-formed input.
    let config = ModelConfig {
        vocab_size: 20,
        hidden_dim: 4,
        num_heads: 1,
        num_layers: 1,
        intermediate_dim: 16,
        eps: 1e-5,
    };
    let model = Model::new(config).unwrap();
    let output = model.forward(&[0, 1]).unwrap();
    for &val in output.data() {
        assert!(val.is_finite(), "Output contains non-finite values");
    }
}
#[test]
fn test_model_num_parameters() {
    // The count must at least cover the embedding table plus the LM head
    // (vocab * hidden each).
    let config = ModelConfig {
        vocab_size: 100,
        hidden_dim: 8,
        num_heads: 1,
        num_layers: 2,
        intermediate_dim: 32,
        eps: 1e-5,
    };
    let model = Model::new(config).unwrap();
    let params = model.num_parameters();
    assert!(params > 0);
    assert!(params >= 100 * 8 + 8 * 100);
}
#[test]
fn test_model_mutable_access() {
    // Smoke test: every mutable accessor is callable.
    let mut model = Model::new(ModelConfig {
        vocab_size: 50,
        hidden_dim: 4,
        num_heads: 1,
        num_layers: 1,
        intermediate_dim: 16,
        eps: 1e-5,
    })
    .unwrap();
    let _ = model.embedding_mut();
    let _ = model.blocks_mut();
    let _ = model.final_norm_mut();
    let _ = model.lm_head_mut();
}

#[test]
fn test_model_generate_basic() {
    let model = Model::new(ModelConfig {
        vocab_size: 20,
        hidden_dim: 4,
        num_heads: 1,
        num_layers: 1,
        intermediate_dim: 16,
        eps: 1e-5,
    })
    .unwrap();
    let gen_config = GenerationConfig::greedy().with_max_tokens(5);
    let tokens = model.generate(&[0], &gen_config).unwrap();
    // Prompt (1 token) + at most 5 generated tokens; prompt is echoed first.
    assert!(tokens.len() <= 6);
    assert!(!tokens.is_empty());
    assert_eq!(tokens[0], 0);
}

#[test]
fn test_model_generate_respects_max_tokens() {
    let model = Model::new(ModelConfig {
        vocab_size: 10,
        hidden_dim: 4,
        num_heads: 1,
        num_layers: 1,
        intermediate_dim: 16,
        eps: 1e-5,
    })
    .unwrap();
    let gen_config = GenerationConfig::greedy().with_max_tokens(3);
    // Prompt (2 tokens) + at most 3 generated tokens.
    assert!(model.generate(&[0, 1], &gen_config).unwrap().len() <= 5);
}
#[test]
fn test_model_generate_with_eos() {
    // Generation may stop early at the EOS token (5) but never exceeds
    // prompt + max_tokens.
    let config = ModelConfig {
        vocab_size: 10,
        hidden_dim: 4,
        num_heads: 1,
        num_layers: 1,
        intermediate_dim: 16,
        eps: 1e-5,
    };
    let model = Model::new(config).unwrap();
    let gen_config = GenerationConfig::greedy()
        .with_max_tokens(100)
        .with_eos_token_id(5);
    let tokens = model.generate(&[0], &gen_config).unwrap();
    assert!(tokens.len() <= 101);
}
#[test]
fn test_model_generate_empty_prompt_error() {
    // An empty prompt is invalid input.
    let config = ModelConfig {
        vocab_size: 10,
        hidden_dim: 4,
        num_heads: 1,
        num_layers: 1,
        intermediate_dim: 16,
        eps: 1e-5,
    };
    let model = Model::new(config).unwrap();
    let gen_config = GenerationConfig::greedy();
    let result = model.generate(&[], &gen_config);
    assert!(result.is_err());
}
#[test]
fn test_model_generate_deterministic_with_seed() {
    // Greedy decoding with a fixed seed must reproduce identical tokens.
    let config = ModelConfig {
        vocab_size: 20,
        hidden_dim: 4,
        num_heads: 1,
        num_layers: 1,
        intermediate_dim: 16,
        eps: 1e-5,
    };
    let model = Model::new(config).unwrap();
    let gen_config = GenerationConfig::greedy()
        .with_max_tokens(5)
        .with_seed(12345);
    let tokens1 = model.generate(&[0], &gen_config).unwrap();
    let tokens2 = model.generate(&[0], &gen_config).unwrap();
    assert_eq!(tokens1, tokens2);
}
#[test]
fn test_model_generate_top_k() {
    // Top-k sampling must only emit tokens inside the vocabulary.
    let config = ModelConfig {
        vocab_size: 20,
        hidden_dim: 4,
        num_heads: 1,
        num_layers: 1,
        intermediate_dim: 16,
        eps: 1e-5,
    };
    let model = Model::new(config).unwrap();
    let gen_config = GenerationConfig::top_k(5).with_max_tokens(3).with_seed(42);
    let tokens = model.generate(&[0], &gen_config).unwrap();
    assert!(tokens.len() <= 4);
    for &token in &tokens {
        assert!(token < 20);
    }
}
#[test]
fn test_multi_head_attention_creation_mha() {
    // MHA: one KV head per query head.
    let attn = MultiHeadAttention::mha(64, 8).unwrap();
    assert_eq!(
        (attn.num_heads(), attn.num_kv_heads(), attn.head_dim(), attn.hidden_dim()),
        (8, 8, 8, 64)
    );
    assert!(attn.is_mha() && !attn.is_mqa() && !attn.is_gqa());
}

#[test]
fn test_multi_head_attention_creation_mqa() {
    // MQA: a single KV head shared by all query heads.
    let attn = MultiHeadAttention::mqa(64, 8).unwrap();
    assert_eq!(
        (attn.num_heads(), attn.num_kv_heads(), attn.head_dim(), attn.hidden_dim()),
        (8, 1, 8, 64)
    );
    assert!(attn.is_mqa() && !attn.is_mha() && !attn.is_gqa());
}

#[test]
fn test_multi_head_attention_creation_gqa() {
    // GQA: 8 query heads grouped over 2 KV heads.
    let attn = MultiHeadAttention::gqa(64, 8, 2).unwrap();
    assert_eq!(
        (attn.num_heads(), attn.num_kv_heads(), attn.head_dim(), attn.hidden_dim()),
        (8, 2, 8, 64)
    );
    assert!(attn.is_gqa() && !attn.is_mha() && !attn.is_mqa());
}

#[test]
fn test_multi_head_attention_zero_hidden_dim_error() {
    assert!(MultiHeadAttention::new(0, 8, 8).is_err());
}

#[test]
fn test_multi_head_attention_zero_num_heads_error() {
    assert!(MultiHeadAttention::new(64, 0, 1).is_err());
}

#[test]
fn test_multi_head_attention_zero_num_kv_heads_error() {
    assert!(MultiHeadAttention::new(64, 8, 0).is_err());
}

#[test]
fn test_multi_head_attention_kv_heads_too_large_error() {
    // More KV heads than query heads is invalid.
    assert!(MultiHeadAttention::new(64, 8, 16).is_err());
}

#[test]
fn test_multi_head_attention_indivisible_error() {
    // hidden_dim (65) must be divisible by num_heads (8).
    assert!(MultiHeadAttention::new(65, 8, 8).is_err());
}

#[test]
fn test_multi_head_attention_heads_not_divisible_error() {
    // num_heads (8) must be divisible by num_kv_heads (3).
    assert!(MultiHeadAttention::new(64, 8, 3).is_err());
}
#[test]
fn test_multi_head_attention_mha_forward() {
    let mha = MultiHeadAttention::mha(8, 2).unwrap();
    // Two one-hot rows: e0 and e1.
    let mut data = vec![0.0f32; 16];
    data[0] = 1.0;
    data[9] = 1.0;
    let input = Tensor::from_vec(vec![2, 8], data).unwrap();
    assert_eq!(mha.forward(&input).unwrap().shape(), &[2, 8]);
}

#[test]
fn test_multi_head_attention_mqa_forward() {
    let mqa = MultiHeadAttention::mqa(8, 2).unwrap();
    // Same one-hot input as the MHA forward test.
    let mut data = vec![0.0f32; 16];
    data[0] = 1.0;
    data[9] = 1.0;
    let input = Tensor::from_vec(vec![2, 8], data).unwrap();
    assert_eq!(mqa.forward(&input).unwrap().shape(), &[2, 8]);
}

#[test]
fn test_multi_head_attention_shape_validation() {
    let mha = MultiHeadAttention::mha(8, 2).unwrap();
    // A 1-D input is rejected.
    let input_1d = Tensor::from_vec(vec![8], vec![1.0; 8]).unwrap();
    assert!(mha.forward(&input_1d).is_err());
    // A wrong feature dimension is rejected.
    let input_wrong_dim = Tensor::from_vec(vec![2, 16], vec![1.0; 32]).unwrap();
    assert!(mha.forward(&input_wrong_dim).is_err());
}

#[test]
fn test_multi_head_attention_mha_vs_mqa_shape_consistency() {
    // MHA and MQA must agree on output geometry for identical input.
    let input = Tensor::from_vec(vec![3, 16], vec![0.5; 48]).unwrap();
    let mha_out = MultiHeadAttention::mha(16, 4).unwrap().forward(&input).unwrap();
    let mqa_out = MultiHeadAttention::mqa(16, 4).unwrap().forward(&input).unwrap();
    assert_eq!(mha_out.shape(), &[3, 16]);
    assert_eq!(mqa_out.shape(), &[3, 16]);
    assert_eq!(mha_out.shape(), mqa_out.shape());
}
#[test]
fn test_multi_head_attention_single_head() {
    // Degenerate case: a single head still preserves the input shape.
    let mha = MultiHeadAttention::mha(8, 1).unwrap();
    let input = Tensor::from_vec(vec![2, 8], vec![0.5; 16]).unwrap();
    let output = mha.forward(&input).unwrap();
    assert_eq!(output.shape(), &[2, 8]);
}
#[test]
fn test_multi_head_attention_mqa_kv_sharing() {
    // 8 query heads sharing one KV head must not change output geometry.
    let mqa = MultiHeadAttention::mqa(32, 8).unwrap();
    let input = Tensor::from_vec(vec![4, 32], vec![0.1; 128]).unwrap();
    let output = mqa.forward(&input).unwrap();
    assert_eq!(output.shape(), &[4, 32]);
}
#[test]
fn test_multi_head_attention_long_sequence() {
    // A longer sequence (10 positions) round-trips through attention.
    let mha = MultiHeadAttention::mha(16, 4).unwrap();
    let input = Tensor::from_vec(vec![10, 16], vec![0.3; 160]).unwrap();
    let output = mha.forward(&input).unwrap();
    assert_eq!(output.shape(), &[10, 16]);
}
#[test]
fn test_multi_head_attention_mqa_memory_efficiency() {
    // MQA with many query heads keeps the same output size as the input.
    let mqa = MultiHeadAttention::mqa(64, 16).unwrap();
    let input = Tensor::from_vec(vec![2, 64], vec![0.2; 128]).unwrap();
    let output = mqa.forward(&input).unwrap();
    assert_eq!(output.shape(), &[2, 64]);
    assert_eq!(output.data().len(), 128);
}
#[test]
fn test_multi_head_attention_gqa_forward() {
    let gqa = MultiHeadAttention::gqa(32, 8, 2).unwrap();
    let input = Tensor::from_vec(vec![3, 32], vec![0.1; 96]).unwrap();
    let output = gqa.forward(&input).unwrap();
    assert_eq!(output.shape(), &[3, 32]);
}
#[test]
fn test_multi_head_attention_gqa_shape_consistency() {
    // MHA, MQA and GQA must all agree on output geometry for the same input.
    let mha = MultiHeadAttention::mha(64, 8).unwrap();
    let mqa = MultiHeadAttention::mqa(64, 8).unwrap();
    let gqa = MultiHeadAttention::gqa(64, 8, 2).unwrap();
    let input = Tensor::from_vec(vec![4, 64], vec![0.5; 256]).unwrap();
    let multi_head_out = mha.forward(&input).unwrap();
    let multi_query_out = mqa.forward(&input).unwrap();
    let grouped_query_out = gqa.forward(&input).unwrap();
    assert_eq!(multi_head_out.shape(), &[4, 64]);
    assert_eq!(multi_query_out.shape(), &[4, 64]);
    assert_eq!(grouped_query_out.shape(), &[4, 64]);
    assert_eq!(multi_head_out.shape(), multi_query_out.shape());
    assert_eq!(multi_head_out.shape(), grouped_query_out.shape());
}
#[test]
fn test_multi_head_attention_gqa_different_group_sizes() {
    // Different KV-group sizes (16 heads over 4 or 8 KV heads) both preserve shape.
    let gqa1 = MultiHeadAttention::gqa(128, 16, 4).unwrap();
    let input = Tensor::from_vec(vec![2, 128], vec![0.3; 256]).unwrap();
    let output1 = gqa1.forward(&input).unwrap();
    assert_eq!(output1.shape(), &[2, 128]);
    let gqa2 = MultiHeadAttention::gqa(128, 16, 8).unwrap();
    let output2 = gqa2.forward(&input).unwrap();
    assert_eq!(output2.shape(), &[2, 128]);
}
#[test]
fn test_phase3_acceptance_tokens_per_second() {
    // Phase 3 acceptance gate: sustained generation throughput must reach
    // 25 tok/s on the small reference model below.
    // NOTE(review): this asserts on wall-clock time, so it can flake on a
    // loaded or slow machine — confirm it is intended as a hard CI gate.
    use crate::generate::GenerationConfig;
    use std::time::Instant;
    let config = ModelConfig {
        vocab_size: 100,
        hidden_dim: 64,
        num_heads: 4,
        num_layers: 2,
        intermediate_dim: 128,
        eps: 1e-5,
    };
    let model = Model::new(config).unwrap();
    let prompt = vec![1, 5, 10];
    // Warm-up run so one-time setup costs are excluded from the measurement.
    let gen_config = GenerationConfig::greedy().with_max_tokens(5);
    let _ = model.generate(&prompt, &gen_config).unwrap();
    let tokens_per_run = 20;
    let num_runs = 10;
    let gen_config = GenerationConfig::greedy().with_max_tokens(tokens_per_run);
    let start = Instant::now();
    for _ in 0..num_runs {
        let _ = model.generate(&prompt, &gen_config).unwrap();
    }
    let elapsed = start.elapsed();
    let total_tokens = tokens_per_run * num_runs;
    let tok_per_sec = total_tokens as f64 / elapsed.as_secs_f64();
    assert!(
        tok_per_sec >= 25.0,
        "Phase 3 acceptance FAILED: {:.1} tok/s < 25.0 tok/s target. \
        Note: Full optimization requires integrating Flash Attention v2 \
        and FusedLayerNormLinear into Model::forward()",
        tok_per_sec
    );
    eprintln!(
        "Phase 3 acceptance PASSED: {:.1} tok/s (target: ≥25.0 tok/s)",
        tok_per_sec
    );
}
#[test]
fn test_phase3_flash_attention_v2_performance() {
    // Times flash_forward_v2 against flash_forward_parallel on a small
    // workload. Relative speed is only reported to stderr; the assertions
    // merely check both paths run with measurable time.
    use std::time::Instant;
    let head_dim = 64;
    let seq_len = 32;
    let attn = Attention::new(head_dim).unwrap();
    let q = Tensor::from_vec(vec![seq_len, head_dim], vec![0.1; seq_len * head_dim]).unwrap();
    let k = Tensor::from_vec(vec![seq_len, head_dim], vec![0.2; seq_len * head_dim]).unwrap();
    let v = Tensor::from_vec(vec![seq_len, head_dim], vec![0.3; seq_len * head_dim]).unwrap();
    // Warm up both code paths before timing.
    let _ = attn.flash_forward_v2(&q, &k, &v, 8).unwrap();
    let _ = attn.flash_forward_parallel(&q, &k, &v, 8).unwrap();
    let iterations = 100;
    let start = Instant::now();
    for _ in 0..iterations {
        let _ = attn.flash_forward_v2(&q, &k, &v, 8).unwrap();
    }
    let v2_time = start.elapsed();
    let start = Instant::now();
    for _ in 0..iterations {
        let _ = attn.flash_forward_parallel(&q, &k, &v, 8).unwrap();
    }
    let parallel_time = start.elapsed();
    // Average microseconds per iteration for each variant.
    let v2_us = v2_time.as_micros() as f64 / iterations as f64;
    let parallel_us = parallel_time.as_micros() as f64 / iterations as f64;
    eprintln!(
        "Flash Attention v2: {:.2}us/iter, Parallel: {:.2}us/iter",
        v2_us, parallel_us
    );
    assert!(v2_us > 0.0, "v2 should have measurable time");
    assert!(parallel_us > 0.0, "parallel should have measurable time");
}
#[test]
fn test_phase3_fused_layernorm_linear_performance() {
    // Times FusedLayerNormLinear::forward against forward_parallel.
    // Speed is only reported via stderr; assertions just confirm both
    // paths execute with measurable time.
    use std::time::Instant;
    let feature_dim = 256;
    let out_features = 512;
    let batch_size = 32;
    let fused = FusedLayerNormLinear::new(feature_dim, out_features, 1e-5).unwrap();
    let input = Tensor::from_vec(
        vec![batch_size, feature_dim],
        vec![0.5; batch_size * feature_dim],
    )
    .unwrap();
    // Warm up both code paths before timing.
    let _ = fused.forward(&input).unwrap();
    let _ = fused.forward_parallel(&input).unwrap();
    let iterations = 100;
    let start = Instant::now();
    for _ in 0..iterations {
        let _ = fused.forward(&input).unwrap();
    }
    let fused_time = start.elapsed();
    let start = Instant::now();
    for _ in 0..iterations {
        let _ = fused.forward_parallel(&input).unwrap();
    }
    let parallel_time = start.elapsed();
    // Average microseconds per iteration for each variant.
    let fused_us = fused_time.as_micros() as f64 / iterations as f64;
    let parallel_us = parallel_time.as_micros() as f64 / iterations as f64;
    eprintln!(
        "FusedLayerNormLinear: {:.2}us/iter, Parallel: {:.2}us/iter",
        fused_us, parallel_us
    );
    assert!(fused_us > 0.0, "fused should have measurable time");
    assert!(parallel_us > 0.0, "parallel should have measurable time");
}
#[test]
fn test_quantized_linear_creation() {
    const IN_FEATURES: usize = 256;
    const OUT_FEATURES: usize = 4;
    // Q4_K stores each 256-weight super-block in 144 bytes, so one
    // super-block per 256-wide row.
    const BYTES_PER_ROW: usize = 144;
    let layer = QuantizedLinear::new(
        IN_FEATURES,
        OUT_FEATURES,
        vec![0u8; OUT_FEATURES * BYTES_PER_ROW],
        vec![0.0f32; OUT_FEATURES],
    );
    assert!(
        layer.is_ok(),
        "Should create QuantizedLinear from Q4_K bytes"
    );
    let layer = layer.unwrap();
    assert_eq!(layer.in_features(), IN_FEATURES);
    assert_eq!(layer.out_features(), OUT_FEATURES);
}

#[test]
fn test_quantized_linear_forward() {
    const IN_FEATURES: usize = 256;
    const OUT_FEATURES: usize = 4;
    // With all-zero quantized weights the output must collapse to the bias.
    let layer = QuantizedLinear::new(
        IN_FEATURES,
        OUT_FEATURES,
        vec![0u8; OUT_FEATURES * 144],
        vec![1.0f32; OUT_FEATURES],
    )
    .expect("Should create layer");
    let input = Tensor::from_vec(vec![IN_FEATURES], vec![1.0f32; IN_FEATURES])
        .expect("Should create input");
    let output = layer.forward(&input).expect("Forward should work");
    assert_eq!(output.shape(), &[OUT_FEATURES]);
    for &val in output.data() {
        assert!(
            (val - 1.0).abs() < 1e-5,
            "Output should equal bias with zero weights"
        );
    }
}
#[test]
fn test_quantized_linear_batch_forward() {
    const IN_FEATURES: usize = 256;
    const OUT_FEATURES: usize = 4;
    const BATCH: usize = 8;
    let layer = QuantizedLinear::new(
        IN_FEATURES,
        OUT_FEATURES,
        vec![0u8; OUT_FEATURES * 144],
        vec![2.0f32; OUT_FEATURES],
    )
    .expect("Should create layer");
    // A 2-D input is treated as a batch of rows.
    let input = Tensor::from_vec(
        vec![BATCH, IN_FEATURES],
        vec![1.0f32; BATCH * IN_FEATURES],
    )
    .expect("Should create batch input");
    let output = layer.forward(&input).expect("Batch forward should work");
    assert_eq!(output.shape(), &[BATCH, OUT_FEATURES]);
}
#[test]
fn test_quantized_linear_memory_efficiency() {
    // Pure arithmetic check: for a 4096x4096 weight matrix, Q4_K storage
    // (144 bytes per 256-weight super-block) must be >6x smaller than f32.
    let (in_features, out_features) = (4096usize, 4096usize);
    let f32_bytes = in_features * out_features * std::mem::size_of::<f32>();
    let q4k_bytes = out_features * in_features.div_ceil(256) * 144;
    let ratio = f32_bytes as f64 / q4k_bytes as f64;
    assert!(
        ratio > 6.0,
        "Q4_K should be >6x smaller than f32: ratio={}",
        ratio
    );
    eprintln!(
        "Memory efficiency: f32={} bytes, Q4_K={} bytes, ratio={:.2}x",
        f32_bytes, q4k_bytes, ratio
    );
}
#[test]
fn test_sliding_window_attention_new() {
    let swa = SlidingWindowAttention::new(64, 4096).unwrap();
    assert_eq!(swa.head_dim(), 64);
    assert_eq!(swa.window_size(), 4096);
    // scale = 1 / sqrt(head_dim) = 1/8 for head_dim 64.
    assert!((swa.scale() - 0.125).abs() < 1e-6);
}

#[test]
fn test_sliding_window_attention_new_errors() {
    // Both head_dim and window_size must be non-zero.
    assert!(SlidingWindowAttention::new(0, 4096).is_err());
    assert!(SlidingWindowAttention::new(64, 0).is_err());
}

#[test]
fn test_sliding_window_attention_forward_basic() {
    let swa = SlidingWindowAttention::new(4, 3).unwrap();
    // Identical ramp for Q and K; values cycle 0..4 per feature.
    let ramp: Vec<f32> = (0..20).map(|i| i as f32 * 0.1).collect();
    let query = Tensor::from_vec(vec![5, 4], ramp.clone()).unwrap();
    let key = Tensor::from_vec(vec![5, 4], ramp).unwrap();
    let value =
        Tensor::from_vec(vec![5, 4], (0..20).map(|i| (i % 4) as f32).collect()).unwrap();
    let output = swa.forward(&query, &key, &value).unwrap();
    assert_eq!(output.size(), 20);
}

#[test]
fn test_sliding_window_attention_causal_masking() {
    // Window larger than the sequence: only the causal mask constrains attention.
    let swa = SlidingWindowAttention::new(2, 10).unwrap();
    let query = Tensor::from_vec(vec![3, 2], vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]).unwrap();
    let key = Tensor::from_vec(vec![3, 2], vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0]).unwrap();
    let value = Tensor::from_vec(vec![3, 2], vec![1.0, 0.0, 0.0, 1.0, 0.5, 0.5]).unwrap();
    let output = swa.forward(&query, &key, &value).unwrap();
    assert_eq!(output.size(), 6);
    // The first row attends to at least itself, so it cannot be all zeros.
    let data = output.data();
    assert!(data[0].abs() > 0.0 || data[1].abs() > 0.0);
}

#[test]
fn test_sliding_window_attention_window_boundary() {
    // Window (2) smaller than sequence (5): later positions see only a slice.
    let swa = SlidingWindowAttention::new(2, 2).unwrap();
    let ones = vec![1.0; 10];
    let query = Tensor::from_vec(vec![5, 2], ones.clone()).unwrap();
    let key = Tensor::from_vec(vec![5, 2], ones).unwrap();
    let value =
        Tensor::from_vec(vec![5, 2], (0..10).map(|i| i as f32).collect()).unwrap();
    assert_eq!(swa.forward(&query, &key, &value).unwrap().size(), 10);
}

#[test]
fn test_sliding_window_attention_effective_context() {
    let swa = SlidingWindowAttention::new(64, 4).unwrap();
    // (position, seq_len) -> visible positions, capped by the window size (4).
    for (pos, seq_len, expected) in [(0, 10, 1), (3, 10, 4), (7, 10, 4), (2, 3, 3)] {
        assert_eq!(swa.effective_context(pos, seq_len), expected);
    }
}
#[test]
fn test_sliding_window_attention_memory_ratio() {
    // memory_ratio reports the fraction of full-attention memory needed:
    // ~1.0 when seq_len is within the window, ~window/seq_len beyond it.
    let swa = SlidingWindowAttention::new(64, 4096).unwrap();
    let ratio_short = swa.memory_ratio(1000);
    assert!(
        ratio_short > 0.9,
        "Short sequences should use ~full attention"
    );
    let ratio_long = swa.memory_ratio(100_000);
    let expected = 4096.0 / 100_000.0;
    assert!(
        (ratio_long - expected).abs() < 0.01,
        "Long sequences should use ~window_size/seq_len memory: got {}, expected {}",
        ratio_long,
        expected
    );
}
#[test]
fn test_sliding_window_attention_error_mismatched_kv() {
    // Key has 3 rows but query/value have 2 — sequence lengths must agree.
    let swa = SlidingWindowAttention::new(4, 3).unwrap();
    let query = Tensor::from_vec(vec![2, 4], vec![1.0; 8]).unwrap();
    let key = Tensor::from_vec(vec![3, 4], vec![1.0; 12]).unwrap();
    let value = Tensor::from_vec(vec![2, 4], vec![1.0; 8]).unwrap();
    let result = swa.forward(&query, &key, &value);
    assert!(result.is_err());
}
#[test]
fn test_sliding_window_attention_error_bad_head_dim() {
    // Key feature dim (3) disagrees with the configured head_dim (4).
    let swa = SlidingWindowAttention::new(4, 3).unwrap();
    let query = Tensor::from_vec(vec![2, 4], vec![1.0; 8]).unwrap();
    let key = Tensor::from_vec(vec![2, 3], vec![1.0; 6]).unwrap();
    let value = Tensor::from_vec(vec![2, 4], vec![1.0; 8]).unwrap();
    let result = swa.forward(&query, &key, &value);
    assert!(result.is_err());
}
#[test]
fn test_sliding_window_attention_bidirectional() {
    // forward() is causal; forward_with_mask(.., false) disables the causal
    // mask. Both must produce same-sized, non-trivial outputs.
    let swa = SlidingWindowAttention::new(2, 4).unwrap();
    let query = Tensor::from_vec(vec![5, 2], vec![1.0; 10]).unwrap();
    let key = Tensor::from_vec(vec![5, 2], vec![1.0; 10]).unwrap();
    let value_data: Vec<f32> = (0..10).map(|i| i as f32).collect();
    let value = Tensor::from_vec(vec![5, 2], value_data).unwrap();
    let output_causal = swa.forward(&query, &key, &value).unwrap();
    let output_bidir = swa.forward_with_mask(&query, &key, &value, false).unwrap();
    assert_eq!(output_causal.size(), output_bidir.size());
    assert!(output_causal.data().iter().any(|&x| x.abs() > 0.0));
    assert!(output_bidir.data().iter().any(|&x| x.abs() > 0.0));
}
#[test]
fn test_sliding_window_attention_forward_with_mask_causal() {
    // forward() must be exactly equivalent to forward_with_mask(.., true).
    let swa = SlidingWindowAttention::new(2, 3).unwrap();
    let query = Tensor::from_vec(vec![3, 2], vec![1.0; 6]).unwrap();
    let key = Tensor::from_vec(vec![3, 2], vec![1.0; 6]).unwrap();
    let value = Tensor::from_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
    let output_forward = swa.forward(&query, &key, &value).unwrap();
    let output_mask = swa.forward_with_mask(&query, &key, &value, true).unwrap();
    for (a, b) in output_forward.data().iter().zip(output_mask.data().iter()) {
        assert!(
            (a - b).abs() < 1e-6,
            "Causal outputs should match: {} vs {}",
            a,
            b
        );
    }
}
#[test]
fn test_sliding_window_attention_single_token() {
    // A single position can only attend to itself, so the softmax weight is
    // 1.0 and the output equals the value row.
    let swa = SlidingWindowAttention::new(4, 3).unwrap();
    let query = Tensor::from_vec(vec![1, 4], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let key = Tensor::from_vec(vec![1, 4], vec![1.0, 2.0, 3.0, 4.0]).unwrap();
    let value = Tensor::from_vec(vec![1, 4], vec![0.5, 0.5, 0.5, 0.5]).unwrap();
    let output = swa.forward(&query, &key, &value).unwrap();
    assert_eq!(output.size(), 4);
    let data = output.data();
    for &v in data {
        assert!((v - 0.5).abs() < 1e-6);
    }
}
#[test]
fn test_fused_qkv_attention_basic() {
let fused = FusedQKVAttention::new(4, 64).unwrap();
let input = Tensor::from_vec(vec![8, 64], vec![0.1; 8 * 64]).unwrap();
let output = fused.forward(&input).unwrap();
assert_eq!(output.shape(), &[8, 64]);
}
#[test]
fn test_fused_qkv_attention_correctness() {
let head_dim = 16;
let hidden_dim = 64;
let seq_len = 4;
let fused = FusedQKVAttention::new(head_dim, hidden_dim).unwrap();
let input = Tensor::from_vec(
vec![seq_len, hidden_dim],
(0..(seq_len * hidden_dim))
.map(|i| (i as f32 * 0.01).sin())
.collect(),
)
.unwrap();
let output = fused.forward(&input).unwrap();
assert_eq!(output.shape(), input.shape());
for &val in output.data() {
assert!(val.is_finite(), "Output contains non-finite value: {}", val);
}
}
#[test]
fn test_fused_qkv_attention_single_token() {
let fused = FusedQKVAttention::new(8, 32).unwrap();
let input = Tensor::from_vec(vec![1, 32], vec![0.5; 32]).unwrap();
let output = fused.forward(&input).unwrap();
assert_eq!(output.shape(), &[1, 32]);
}
#[test]
fn test_fused_qkv_attention_error_zero_head_dim() {
    // head_dim == 0 is invalid and must be rejected at construction.
    assert!(FusedQKVAttention::new(0, 64).is_err());
}
#[test]
fn test_fused_qkv_attention_error_zero_hidden_dim() {
    // hidden_dim == 0 is invalid and must be rejected at construction.
    assert!(FusedQKVAttention::new(8, 0).is_err());
}
#[test]
fn test_fused_qkv_attention_error_mismatched_input() {
    // The input's last dimension (32) disagrees with hidden_dim (64);
    // forward must surface an error rather than panic.
    let fused = FusedQKVAttention::new(8, 64).unwrap();
    let narrow = Tensor::from_vec(vec![4, 32], vec![0.1; 4 * 32]).unwrap();
    assert!(fused.forward(&narrow).is_err());
}
#[test]
fn test_fused_qkv_attention_numerical_stability() {
    // Feed uniformly huge (1e2) and uniformly tiny (1e-10) activations;
    // both must yield finite outputs (guards against softmax overflow /
    // underflow in the fused path).
    let fused = FusedQKVAttention::new(8, 32).unwrap();
    let input = Tensor::from_vec(vec![4, 32], vec![100.0; 4 * 32]).unwrap();
    let output = fused.forward(&input).unwrap();
    for &val in output.data() {
        assert!(
            val.is_finite(),
            "Large inputs caused non-finite output: {}",
            val
        );
    }
    // Tiny magnitudes: exercises the opposite failure mode (0/0 in softmax).
    let input_small = Tensor::from_vec(vec![4, 32], vec![1e-10; 4 * 32]).unwrap();
    let output_small = fused.forward(&input_small).unwrap();
    for &val in output_small.data() {
        assert!(
            val.is_finite(),
            "Small inputs caused non-finite output: {}",
            val
        );
    }
}
#[test]
fn test_fused_qkv_attention_causal_mask() {
    // A ramp input through fused attention must keep the [4, 16] shape.
    let fused = FusedQKVAttention::new(4, 16).unwrap();
    let ramp: Vec<f32> = (0..64).map(|i| (i as f32) * 0.1).collect();
    let input = Tensor::from_vec(vec![4, 16], ramp).unwrap();
    let output = fused.forward(&input).unwrap();
    assert_eq!(output.shape(), &[4, 16]);
}
#[test]
fn test_qa_003_attention_scores_correctness() {
    // QA-003: scaled dot-product attention over two basis-vector queries
    // must produce a finite [2, head_dim] output.
    let head_dim = 4;
    let attention = Attention::new(head_dim).unwrap();
    // Rows of q are the first two standard basis vectors; with k == q each
    // query scores highest against its matching key.
    let q = Tensor::from_vec(
        vec![2, head_dim],
        vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
    )
    .unwrap();
    let k = q.clone();
    let v = Tensor::from_vec(
        vec![2, head_dim],
        vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
    )
    .unwrap();
    let output = attention.forward(&q, &k, &v).unwrap();
    assert_eq!(output.shape(), &[2, head_dim]);
    let data = output.data();
    for &val in data {
        assert!(val.is_finite(), "QA-003: Attention output should be finite");
    }
}
#[test]
fn test_qa_004_rope_embeddings_correctness() {
    // QA-004: rotary embeddings must be position-dependent — the same input
    // at position 0 and position 1 must differ somewhere — and stay finite.
    let rope = RoPE::new(64, 10000.0).unwrap();
    let input = Tensor::from_vec(vec![1, 64], vec![1.0; 64]).unwrap();
    let output_pos0 = rope.forward(&input, 0).unwrap();
    let output_pos1 = rope.forward(&input, 1).unwrap();
    let data0 = output_pos0.data();
    let data1 = output_pos1.data();
    let differs = data0
        .iter()
        .zip(data1.iter())
        .any(|(a, b)| (a - b).abs() > 1e-6);
    assert!(
        differs,
        "QA-004: RoPE should produce different outputs at different positions"
    );
    for &val in data0 {
        assert!(val.is_finite(), "QA-004: RoPE output should be finite");
    }
}
#[test]
fn test_qa_005_softmax_sum_to_one() {
    // QA-005: across several 1-D sizes, softmax must produce a probability
    // distribution: values in [0, 1] that sum to 1.
    for size in [4, 16, 64, 256] {
        let input = Tensor::from_vec(
            vec![size],
            (0..size).map(|i| (i as f32 * 0.1).sin()).collect(),
        )
        .unwrap();
        let output = softmax(&input).unwrap();
        let sum: f32 = output.data().iter().sum();
        // 1e-5 absorbs f32 accumulation error at the largest size (256).
        assert!(
            (sum - 1.0).abs() < 1e-5,
            "QA-005: Softmax sum should be 1.0, got {} for size {}",
            sum,
            size
        );
        for &val in output.data() {
            assert!(val >= 0.0, "QA-005: Softmax outputs should be non-negative");
            assert!(val <= 1.0, "QA-005: Softmax outputs should be <= 1.0");
        }
    }
}
#[test]
fn test_qa_006_layer_norm_unit_variance() {
    // QA-006: after layer norm over a linear ramp, the output mean should be
    // near 0 and the variance bounded (near 1 with default weight/bias).
    let hidden_dim = 64;
    let layer_norm = LayerNorm::new(hidden_dim, 1e-5).unwrap();
    let input = Tensor::from_vec(
        vec![1, hidden_dim],
        (0..hidden_dim).map(|i| i as f32).collect(),
    )
    .unwrap();
    let output = layer_norm.forward(&input).unwrap();
    let data = output.data();
    let mean: f32 = data.iter().sum::<f32>() / (hidden_dim as f32);
    let variance: f32 =
        data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / (hidden_dim as f32);
    assert!(
        mean.abs() < 0.1,
        "QA-006: Layer norm mean should be near 0, got {}",
        mean
    );
    // Deliberately loose bound: only sanity, not exact unit variance.
    assert!(
        variance > 0.0 && variance < 10.0,
        "QA-006: Layer norm variance should be bounded, got {}",
        variance
    );
}
#[test]
fn test_qa_007_gelu_activation_correctness() {
    // QA-007: spot-check GELU at characteristic points: 0 maps to ~0,
    // positives stay positive, large inputs pass through ~unchanged, and a
    // small negative input yields a small negative output.
    let input_zero = Tensor::from_vec(vec![1], vec![0.0]).unwrap();
    let output_zero = gelu(&input_zero).unwrap();
    assert!(
        output_zero.data()[0].abs() < 1e-5,
        "QA-007: GELU(0) should be ~0, got {}",
        output_zero.data()[0]
    );
    let input_pos = Tensor::from_vec(vec![1], vec![1.0]).unwrap();
    let output_pos = gelu(&input_pos).unwrap();
    assert!(
        output_pos.data()[0] > 0.0,
        "QA-007: GELU(1.0) should be positive"
    );
    // For large x the tanh saturates, so GELU(x) -> x.
    let input_large = Tensor::from_vec(vec![1], vec![10.0]).unwrap();
    let output_large = gelu(&input_large).unwrap();
    assert!(
        (output_large.data()[0] - 10.0).abs() < 1.0,
        "QA-007: GELU(10) should be ~10"
    );
    let input_neg = Tensor::from_vec(vec![1], vec![-0.5]).unwrap();
    let output_neg = gelu(&input_neg).unwrap();
    assert!(
        output_neg.data()[0] < 0.0 && output_neg.data()[0] > -1.0,
        "QA-007: GELU(-0.5) should be small negative"
    );
}
#[test]
fn test_qa_009_kv_cache_correctness() {
    // QA-009: storing two positions into the KV cache must preserve both the
    // values and their insertion order on retrieval.
    use crate::inference::KVCache;
    let num_layers = 2;
    let hidden_dim = 64;
    let max_seq_len = 32;
    let mut cache = KVCache::new(num_layers, hidden_dim, max_seq_len);
    // Position 0: distinct K/V ramps.
    let k_data: Vec<f32> = (0..hidden_dim).map(|i| i as f32 * 0.1).collect();
    let v_data: Vec<f32> = (0..hidden_dim).map(|i| i as f32 * 0.2).collect();
    cache.store(0, &k_data, &v_data);
    cache.advance();
    // Position 1: different ramps so ordering mistakes are detectable.
    let k_data2: Vec<f32> = (0..hidden_dim).map(|i| i as f32 * 0.3).collect();
    let v_data2: Vec<f32> = (0..hidden_dim).map(|i| i as f32 * 0.4).collect();
    cache.store(0, &k_data2, &v_data2);
    cache.advance();
    let k_out = cache.get_k(0);
    let v_out = cache.get_v(0);
    assert_eq!(
        k_out.len(),
        2 * hidden_dim,
        "QA-009: K cache should contain 2 positions"
    );
    assert_eq!(
        v_out.len(),
        2 * hidden_dim,
        "QA-009: V cache should contain 2 positions"
    );
    // Position 0 contents come back first...
    for i in 0..hidden_dim {
        assert!(
            (k_out[i] - k_data[i]).abs() < 1e-6,
            "QA-009: K cache position 0 should match stored value at index {}",
            i
        );
        assert!(
            (v_out[i] - v_data[i]).abs() < 1e-6,
            "QA-009: V cache position 0 should match stored value at index {}",
            i
        );
    }
    // ...followed by position 1.
    for i in 0..hidden_dim {
        assert!(
            (k_out[hidden_dim + i] - k_data2[i]).abs() < 1e-6,
            "QA-009: K cache position 1 should match stored value at index {}",
            i
        );
    }
}
#[test]
fn test_qa_010_quantized_vs_f32_tolerance() {
    // QA-010: hand-built quantized blocks must dequantize into the expected
    // number of finite values.
    use crate::quantize::{dequantize_q4_k, dequantize_q8_0};
    // Q8_0 block: the first two bytes are an f16 scale (0x3C00 == 1.0 in
    // IEEE half), weights written from offset 4. NOTE(review): 36-byte
    // block with weights at offset 4 mirrors this project's decoder layout
    // — confirm against the quantize module if that layout changes.
    let mut q8_data = vec![0u8; 36];
    q8_data[0..2].copy_from_slice(&0x3C00_u16.to_le_bytes());
    for i in 0..32 {
        q8_data[4 + i] = i as u8;
    }
    let dequant = dequantize_q8_0(&q8_data).unwrap();
    assert_eq!(
        dequant.len(),
        32,
        "QA-010: Q8_0 should produce 32 values per block"
    );
    for &val in &dequant {
        assert!(
            val.is_finite(),
            "QA-010: Q8_0 dequantized values should be finite"
        );
    }
    // Q4_K super-block (144 bytes): f16 scale then f16 min, rest zeroed.
    let mut q4k_data = vec![0u8; 144];
    q4k_data[0..2].copy_from_slice(&0x3C00_u16.to_le_bytes());
    q4k_data[2..4].copy_from_slice(&0x0000_u16.to_le_bytes());
    let q4k_dequant = dequantize_q4_k(&q4k_data).unwrap();
    assert_eq!(
        q4k_dequant.len(),
        256,
        "QA-010: Q4_K should produce 256 values per super-block"
    );
    for &val in &q4k_dequant {
        assert!(
            val.is_finite(),
            "QA-010: Dequantized values should be finite"
        );
    }
}
#[test]
fn test_qa_012_latency_no_outliers() {
    // QA-012: collect 100 forward-pass latencies and sanity-check the p50
    // and p99 order statistics are positive (timer resolution permitting).
    use std::time::Instant;
    let mut latencies = Vec::with_capacity(100);
    let layer_norm = LayerNorm::new(64, 1e-5).unwrap();
    let input = Tensor::from_vec(vec![8, 64], vec![0.1; 512]).unwrap();
    for _ in 0..100 {
        let start = Instant::now();
        let _ = layer_norm.forward(&input).unwrap();
        latencies.push(start.elapsed().as_nanos() as f64);
    }
    latencies.sort_by(|a, b| a.partial_cmp(b).unwrap());
    // Indices 49 and 98 are the 50th and 99th of the 100 sorted samples.
    let p50 = latencies[49];
    let p99 = latencies[98];
    assert!(
        p50 > 0.0 && p99 > 0.0,
        "QA-012: p50 ({:.0}ns) and p99 ({:.0}ns) should be positive",
        p50,
        p99
    );
}
#[test]
fn test_qa_015_no_memory_leaks() {
    // QA-015: run many allocate/forward/free cycles to shake out panics or
    // aborts in repeated tensor allocation. (An actual leak would need
    // external tooling — e.g. valgrind/heaptrack — to detect; this test only
    // proves the cycles complete correctly.)
    //
    // Fix: removed the dead `if cycle % 100 == 0 {}` checkpoint block (its
    // body was empty) and the redundant explicit `drop` calls — both
    // tensors are dropped at the end of each iteration anyway.
    let layer_norm = LayerNorm::new(128, 1e-5).unwrap();
    for _cycle in 0..1000 {
        let input = Tensor::from_vec(vec![4, 128], vec![0.1; 512]).unwrap();
        let output = layer_norm.forward(&input).unwrap();
        assert_eq!(output.size(), 512);
    }
}
#[test]
fn test_qa_017_warm_inference_stability() {
    // QA-017: after warm-up, per-call latency jitter (coefficient of
    // variation) should be bounded. Best-of-3 rounds with a 10% trim on
    // each tail keeps the check robust to scheduler noise in CI.
    use std::time::Instant;
    let linear = Linear::new(64, 64).unwrap();
    let input = Tensor::from_vec(vec![1, 64], vec![0.1; 64]).unwrap();
    // Warm up caches/allocator before measuring.
    for _ in 0..100 {
        let _ = linear.forward(&input).unwrap();
    }
    let mut best_cv = f64::MAX;
    for _round in 0..3 {
        let mut steady_latencies = Vec::with_capacity(50);
        for _ in 0..50 {
            let start = Instant::now();
            let _ = linear.forward(&input).unwrap();
            steady_latencies.push(start.elapsed().as_nanos() as f64);
        }
        steady_latencies.sort_by(|a, b| a.partial_cmp(b).unwrap());
        // Drop the fastest and slowest 10% before computing statistics.
        let trimmed_start = steady_latencies.len() / 10;
        let trimmed_end = steady_latencies.len() - trimmed_start;
        let trimmed: Vec<f64> = steady_latencies[trimmed_start..trimmed_end].to_vec();
        let mean = trimmed.iter().sum::<f64>() / (trimmed.len() as f64);
        let variance =
            trimmed.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (trimmed.len() as f64);
        let std_dev = variance.sqrt();
        let cv = std_dev / mean;
        if cv < best_cv {
            best_cv = cv;
        }
    }
    assert!(
        best_cv < 3.0,
        "QA-017: Coefficient of variation ({:.2}) should be < 3.0 for stable inference",
        best_cv
    );
}
#[test]
fn test_qa_019_generation_rate_stability() {
    // QA-019: time 20 identical attention passes and check the latency
    // coefficient of variation is a sane (finite, positive) number.
    use std::time::Instant;
    let attention = Attention::new(32).unwrap();
    let seq_len = 16;
    let q = Tensor::from_vec(vec![seq_len, 32], vec![0.1; seq_len * 32]).unwrap();
    let k = q.clone();
    let v = q.clone();
    let mut times = Vec::with_capacity(20);
    for _ in 0..20 {
        let start = Instant::now();
        let _ = attention.forward(&q, &k, &v).unwrap();
        times.push(start.elapsed().as_nanos() as f64);
    }
    let mean = times.iter().sum::<f64>() / (times.len() as f64);
    let variance = times.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (times.len() as f64);
    // CV = standard deviation relative to the mean.
    let cv = variance.sqrt() / mean;
    assert!(
        cv.is_finite() && cv > 0.0,
        "QA-019: Generation CV ({:.2}) should be finite and positive",
        cv
    );
}
#[test]
fn test_qa_025_no_panic_empty_input() {
    // QA-025: degenerate inputs (zero dims, empty token lists, single
    // elements) must never panic — they either succeed or return Err.
    let layer_norm = LayerNorm::new(64, 1e-5).unwrap();
    // A shape containing a zero dimension is rejected at construction.
    let empty_tensor_result = Tensor::<f32>::from_vec(vec![0, 64], vec![]);
    assert!(
        empty_tensor_result.is_err(),
        "QA-025: Zero-dimension tensor should error"
    );
    let embedding = Embedding::new(100, 64).unwrap();
    let empty_ids: &[usize] = &[];
    let embed_result = embedding.forward(empty_ids);
    // Empty id list may be accepted or rejected; if accepted, the output
    // must also be empty.
    if let Ok(output) = embed_result {
        assert_eq!(
            output.size(),
            0,
            "QA-025: Empty input should give empty output"
        );
    }
    let single_val = Tensor::from_vec(vec![1], vec![1.0_f32]).unwrap();
    let softmax_result = softmax(&single_val);
    assert!(
        softmax_result.is_ok(),
        "QA-025: Softmax on single value should not panic"
    );
    // All-zero row is a valid LayerNorm input (eps prevents divide-by-zero).
    let min_input = Tensor::from_vec(vec![1, 64], vec![0.0_f32; 64]).unwrap();
    let ln_result = layer_norm.forward(&min_input);
    assert!(
        ln_result.is_ok(),
        "QA-025: LayerNorm on minimal input should not panic"
    );
}
#[test]
fn test_qa_027_special_token_handling() {
    // QA-027: the reserved ids used here (PAD=0, BOS=1, EOS=2) must embed
    // like any other in-range id, while an out-of-vocabulary id errors.
    let vocab_size = 1000;
    let embed_dim = 64;
    let embedding = Embedding::new(vocab_size, embed_dim).unwrap();
    assert!(
        embedding.forward(&[1]).is_ok(),
        "QA-027: BOS token should embed correctly"
    );
    assert!(
        embedding.forward(&[2]).is_ok(),
        "QA-027: EOS token should embed correctly"
    );
    assert!(
        embedding.forward(&[0]).is_ok(),
        "QA-027: PAD token should embed correctly"
    );
    assert!(
        embedding.forward(&[vocab_size + 1]).is_err(),
        "QA-027: Invalid token ID should error"
    );
}
#[test]
fn test_qa_029_deterministic_output() {
    // QA-029: two forward passes over the same inputs must agree exactly.
    let attention = Attention::new(16).unwrap();
    let ramp: Vec<f32> = (0..64).map(|i| i as f32 * 0.01).collect();
    let q = Tensor::from_vec(vec![4, 16], ramp).unwrap();
    let k = q.clone();
    let v = q.clone();
    let first = attention.forward(&q, &k, &v).unwrap();
    let second = attention.forward(&q, &k, &v).unwrap();
    assert_eq!(
        first.data(),
        second.data(),
        "QA-029: Identical inputs should produce identical outputs"
    );
}
#[test]
fn test_qa_030_consistent_results() {
    // QA-030: five repeated LayerNorm runs over the same input must agree
    // element-wise with the first run.
    let layer_norm = LayerNorm::new(32, 1e-5).unwrap();
    let input =
        Tensor::from_vec(vec![2, 32], (0..64).map(|i| i as f32 * 0.1).collect()).unwrap();
    let results: Vec<_> = (0..5)
        .map(|_| layer_norm.forward(&input).unwrap())
        .collect();
    // Compare runs 1..4 against run 0, element by element.
    for (i, result) in results.iter().enumerate().skip(1) {
        for (j, (a, b)) in result
            .data()
            .iter()
            .zip(results[0].data().iter())
            .enumerate()
        {
            assert!(
                (a - b).abs() < 1e-10,
                "QA-030: Run {} element {} differs: {} vs {}",
                i,
                j,
                a,
                b
            );
        }
    }
}
#[test]
fn test_qa_001_deterministic_inference() {
    // QA-001: a full model forward pass over the same token ids twice must
    // produce identical shapes and (to 1e-10) identical logits.
    let config = ModelConfig {
        vocab_size: 100,
        hidden_dim: 64,
        num_heads: 4,
        num_layers: 2,
        intermediate_dim: 256,
        eps: 1e-5,
    };
    let model = Model::new(config).unwrap();
    let input_ids = vec![1, 2, 3, 4, 5];
    let output1 = model.forward(&input_ids).unwrap();
    let output2 = model.forward(&input_ids).unwrap();
    assert_eq!(
        output1.shape(),
        output2.shape(),
        "QA-001: Output shapes must match"
    );
    for (i, (a, b)) in output1.data().iter().zip(output2.data().iter()).enumerate() {
        assert!(
            (a - b).abs() < 1e-10,
            "QA-001: Output element {} differs: {} vs {}",
            i,
            a,
            b
        );
    }
}
#[test]
fn test_qa_002_tokenization_determinism() {
    // QA-002: encoding the same text repeatedly must yield the same token
    // ids, and decoding those ids must yield the same text each time.
    use crate::tokenizer::{Tokenizer, Vocabulary};
    // Tiny vocabulary covering every word in the test sentence, plus <unk>.
    let vocab = Vocabulary::from_tokens(vec![
        "<unk>".to_string(),
        "hello".to_string(),
        "world".to_string(),
        "this".to_string(),
        "is".to_string(),
        "a".to_string(),
        "test".to_string(),
    ])
    .unwrap();
    let tokenizer = Tokenizer::new(vocab, "<unk>").unwrap();
    let text = "hello world this is a test";
    let tokens1 = tokenizer.encode(text);
    let tokens2 = tokenizer.encode(text);
    let tokens3 = tokenizer.encode(text);
    assert_eq!(
        tokens1, tokens2,
        "QA-002: Tokenization must be deterministic"
    );
    assert_eq!(
        tokens2, tokens3,
        "QA-002: Tokenization must be deterministic"
    );
    let decoded1 = tokenizer.decode(&tokens1);
    let decoded2 = tokenizer.decode(&tokens2);
    assert_eq!(
        decoded1, decoded2,
        "QA-002: Detokenization must be deterministic"
    );
}
#[test]
fn test_qa_008_swiglu_activation_correctness() {
    // QA-008: feed-forward layer on a ramp spanning negatives and positives
    // must preserve shape, emit finite values, and be deterministic.
    let ffn = FeedForward::new(32, 128).unwrap();
    // Inputs range roughly over [-3.2, 3.1] to exercise both activation signs.
    let input = Tensor::from_vec(
        vec![2, 32],
        (0..64).map(|i| (i as f32 * 0.1) - 3.2).collect(),
    )
    .unwrap();
    let output = ffn.forward(&input).unwrap();
    assert_eq!(output.shape(), input.shape(), "QA-008: FFN preserves shape");
    for (i, &val) in output.data().iter().enumerate() {
        assert!(
            val.is_finite(),
            "QA-008: FFN output {} should be finite, got {}",
            i,
            val
        );
    }
    // Determinism: a second pass must match the first element-wise.
    let output2 = ffn.forward(&input).unwrap();
    for (i, (a, b)) in output.data().iter().zip(output2.data().iter()).enumerate() {
        assert!(
            (a - b).abs() < 1e-10,
            "QA-008: FFN output {} differs: {} vs {}",
            i,
            a,
            b
        );
    }
}
#[test]
#[cfg_attr(coverage, ignore)] // timing-based; skipped under coverage instrumentation
fn test_qa_011_throughput_regression_detection() {
    // QA-011: measure the same workload twice (median of 5 batches each) and
    // require the second measurement to be within 2x of the first. This is a
    // self-referential smoke check for gross slowdowns mid-run, not a true
    // cross-build regression gate.
    use std::time::Instant;
    let layer_norm = LayerNorm::new(256, 1e-5).unwrap();
    let input = Tensor::from_vec(vec![32, 256], vec![0.1; 32 * 256]).unwrap();
    let warmup_iterations = 50;
    for _ in 0..warmup_iterations {
        let _ = layer_norm.forward(&input).unwrap();
    }
    let iterations = 100;
    let mut baseline_times = Vec::with_capacity(5);
    for _ in 0..5 {
        let start = Instant::now();
        for _ in 0..iterations {
            let _ = layer_norm.forward(&input).unwrap();
        }
        baseline_times.push(start.elapsed().as_secs_f64());
    }
    baseline_times.sort_by(|a, b| a.partial_cmp(b).unwrap());
    // Index 2 is the median of the 5 sorted batch timings.
    let baseline_time = baseline_times[2];
    let mut current_times = Vec::with_capacity(5);
    for _ in 0..5 {
        let start = Instant::now();
        for _ in 0..iterations {
            let _ = layer_norm.forward(&input).unwrap();
        }
        current_times.push(start.elapsed().as_secs_f64());
    }
    current_times.sort_by(|a, b| a.partial_cmp(b).unwrap());
    let current_time = current_times[2];
    let regression_threshold = 2.0;
    let ratio = current_time / baseline_time;
    assert!(
        ratio < regression_threshold,
        "QA-011: Throughput regression detected: {:.2}x slower (threshold: {}x)",
        ratio,
        regression_threshold
    );
}
#[test]
fn test_qa_013_memory_usage_bounded() {
    // QA-013: build a model, estimate its parameter footprint, and verify it
    // still produces output. The byte count itself is only sanity-checked.
    let vocab_size = 1000;
    let hidden_dim = 128;
    let num_heads = 4;
    let num_layers = 4;
    let intermediate_dim = 512;
    let config = ModelConfig {
        vocab_size,
        hidden_dim,
        num_heads,
        num_layers,
        intermediate_dim,
        eps: 1e-5,
    };
    let model = Model::new(config).unwrap();
    let embedding_params = vocab_size * hidden_dim;
    // Rough per-layer estimate: 4 square attention projections plus up/down
    // FFN projections. NOTE(review): biases and norm weights are ignored —
    // this is an order-of-magnitude figure, not an exact count.
    let layer_params = num_layers
        * (hidden_dim * hidden_dim * 4 + hidden_dim * intermediate_dim * 2);
    let total_params = embedding_params + layer_params;
    // 4 bytes per f32 parameter.
    let model_size_bytes = total_params * 4;
    let output = model.forward(&[1, 2, 3]).unwrap();
    assert!(output.size() > 0, "QA-013: Model should produce output");
    assert!(
        model_size_bytes > 0,
        "QA-013: Model has non-zero size: {} bytes",
        model_size_bytes
    );
}
#[test]
fn test_qa_014_compute_utilization() {
    // QA-014: 50 timed LayerNorm passes over a 64x512 input must finish in
    // under a second — a coarse ceiling, generous for any reasonable host.
    use std::time::Instant;
    let layer_norm = LayerNorm::new(512, 1e-5).unwrap();
    let input = Tensor::from_vec(vec![64, 512], vec![0.1; 64 * 512]).unwrap();
    // Warm up before timing.
    for _ in 0..10 {
        let _ = layer_norm.forward(&input).unwrap();
    }
    let iterations = 50;
    let start = Instant::now();
    for _ in 0..iterations {
        let _ = layer_norm.forward(&input).unwrap();
    }
    let elapsed = start.elapsed();
    assert!(
        elapsed.as_millis() < 1000,
        "QA-014: Compute should be efficient, took {}ms for {} iterations",
        elapsed.as_millis(),
        iterations
    );
}
#[test]
fn test_qa_016_cold_start_latency() {
    // QA-016: constructing a mid-sized model (weight allocation/init) must
    // take under 5 seconds, and the freshly built model must be usable.
    use std::time::Instant;
    let start = Instant::now();
    let config = ModelConfig {
        vocab_size: 5000,
        hidden_dim: 256,
        num_heads: 8,
        num_layers: 6,
        intermediate_dim: 1024,
        eps: 1e-5,
    };
    let model = Model::new(config).unwrap();
    let cold_start = start.elapsed();
    assert!(
        cold_start.as_secs() < 5,
        "QA-016: Cold start took {}s, should be < 5s",
        cold_start.as_secs_f64()
    );
    let output = model.forward(&[1]).unwrap();
    assert!(output.size() > 0, "QA-016: Model should be functional");
}
#[test]
fn test_qa_018_batch_scaling() {
    // QA-018: time batch=1 vs batch=8 LayerNorm and require their ratio to
    // fall in a very loose sanity window (timing jitter makes tight bounds
    // flaky in CI).
    use std::time::Instant;
    let layer_norm = LayerNorm::new(128, 1e-5).unwrap();
    let single_input = Tensor::from_vec(vec![1, 128], vec![0.1; 128]).unwrap();
    let iterations = 100;
    let start = Instant::now();
    for _ in 0..iterations {
        let _ = layer_norm.forward(&single_input).unwrap();
    }
    let single_time = start.elapsed();
    let batch_input = Tensor::from_vec(vec![8, 128], vec![0.1; 8 * 128]).unwrap();
    let start = Instant::now();
    for _ in 0..iterations {
        let _ = layer_norm.forward(&batch_input).unwrap();
    }
    let batch_time = start.elapsed();
    let ratio = batch_time.as_secs_f64() / single_time.as_secs_f64();
    assert!(
        ratio > 0.0 && ratio < 100.0,
        "QA-018: Batch=8 ratio ({:.2}x) should be in reasonable bounds",
        ratio
    );
}
#[test]
#[cfg_attr(coverage, ignore)] // timing-based; skipped under coverage instrumentation
fn test_qa_020_context_scaling() {
    // QA-020: quadrupling the sequence length (16 -> 64) should cost at most
    // ~16x for an O(n^2) attention; the 32x bound leaves headroom for noise.
    use std::time::Instant;
    let attention = Attention::new(32).unwrap();
    let small_len = 16;
    let small_q = Tensor::from_vec(vec![small_len, 32], vec![0.1; small_len * 32]).unwrap();
    let small_k = small_q.clone();
    let small_v = small_q.clone();
    let start = Instant::now();
    for _ in 0..50 {
        let _ = attention.forward(&small_q, &small_k, &small_v).unwrap();
    }
    let small_time = start.elapsed();
    let large_len = 64;
    let large_q = Tensor::from_vec(vec![large_len, 32], vec![0.1; large_len * 32]).unwrap();
    let large_k = large_q.clone();
    let large_v = large_q.clone();
    let start = Instant::now();
    for _ in 0..50 {
        let _ = attention.forward(&large_q, &large_k, &large_v).unwrap();
    }
    let large_time = start.elapsed();
    let ratio = large_time.as_secs_f64() / small_time.as_secs_f64();
    assert!(
        ratio < 32.0,
        "QA-020: 4x context took {:.2}x longer (should be < 32x for O(n^2))",
        ratio
    );
}
#[test]
fn test_qa_021_oom_handling() {
    // QA-021: invalid construction parameters surface as Err, never panic.
    // Data vector shorter than the declared shape requires:
    assert!(
        Tensor::<f32>::from_vec(vec![10, 64], vec![0.0; 5]).is_err(),
        "QA-021: Tensor with mismatched data/shape should fail gracefully"
    );
    // Degenerate (zero-sized) layer dimensions:
    assert!(
        LayerNorm::new(0, 1e-5).is_err(),
        "QA-021: LayerNorm with zero dim should fail gracefully"
    );
    assert!(
        Embedding::new(0, 64).is_err(),
        "QA-021: Embedding with zero vocab should fail gracefully"
    );
}
#[test]
fn test_qa_022_timeout_recovery() {
    // QA-022: 100 consecutive forward passes all succeed and collectively
    // finish inside a generous 5-second budget.
    use std::time::{Duration, Instant};
    let layer_norm = LayerNorm::new(64, 1e-5).unwrap();
    let input = Tensor::from_vec(vec![16, 64], vec![0.1; 16 * 64]).unwrap();
    let deadline = Duration::from_secs(5);
    let begun = Instant::now();
    for _ in 0..100 {
        assert!(
            layer_norm.forward(&input).is_ok(),
            "QA-022: Operation should complete"
        );
    }
    assert!(
        begun.elapsed() < deadline,
        "QA-022: Operations should complete within timeout"
    );
}
#[test]
fn test_qa_023_malformed_gguf() {
    // QA-023: the GGUF loader must reject empty, garbage, and truncated byte
    // streams with an error rather than panicking.
    use crate::gguf::GGUFModel;
    assert!(
        GGUFModel::from_bytes(&[]).is_err(),
        "QA-023: Empty GGUF should fail"
    );
    let garbage = vec![0xDE, 0xAD, 0xBE, 0xEF, 0x00, 0x00, 0x00, 0x00];
    assert!(
        GGUFModel::from_bytes(&garbage).is_err(),
        "QA-023: Garbage GGUF should fail"
    );
    // Just the ASCII "GGUF" magic with no header body following it.
    let truncated = vec![0x47, 0x47, 0x55, 0x46];
    assert!(
        GGUFModel::from_bytes(&truncated).is_err(),
        "QA-023: Truncated GGUF should fail"
    );
}
#[test]
fn test_qa_024_truncated_files() {
    // QA-024: the safetensors loader must reject empty and truncated buffers.
    use crate::safetensors::SafetensorsModel;
    assert!(
        SafetensorsModel::from_bytes(&[]).is_err(),
        "QA-024: Empty safetensors should fail"
    );
    // Little-endian u64 header length of 16 followed by only "{}" (2 bytes).
    let truncated = vec![
        0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7B, 0x7D,
    ];
    assert!(
        SafetensorsModel::from_bytes(&truncated).is_err(),
        "QA-024: Truncated safetensors should fail"
    );
}
#[test]
fn test_qa_026_context_overflow() {
    // QA-026: fill a 4-slot KV cache to capacity, then store one more
    // position past the end; the cache must remain usable (no panic, no
    // corruption of retrievability). Overflow semantics (drop vs wrap) are
    // not asserted here.
    use crate::inference::KVCache;
    let mut cache = KVCache::new(1, 32, 4);
    for pos in 0..4 {
        let k_data = vec![pos as f32; 32];
        let v_data = vec![pos as f32; 32];
        cache.store(0, &k_data, &v_data);
        cache.advance();
    }
    // One store beyond max_seq_len; 99.0 marks the overflow payload.
    let k_overflow = vec![99.0_f32; 32];
    let v_overflow = vec![99.0_f32; 32];
    cache.store(0, &k_overflow, &v_overflow);
    let k = cache.get_k(0);
    let v = cache.get_v(0);
    assert!(!k.is_empty(), "QA-026: Cache should still be usable");
    assert!(!v.is_empty(), "QA-026: Cache should still be usable");
}
#[test]
fn test_qa_028_thread_safety() {
    // QA-028: four threads share one LayerNorm via Arc and run forward
    // passes concurrently; each must succeed and no thread may panic.
    // Forward takes &self, so concurrent reads are expected to be safe.
    use std::sync::Arc;
    use std::thread;
    let layer_norm = Arc::new(LayerNorm::new(64, 1e-5).unwrap());
    let handles: Vec<_> = (0..4)
        .map(|i| {
            let ln = Arc::clone(&layer_norm);
            thread::spawn(move || {
                // Per-thread input value (i * 0.1) distinguishes the threads.
                let input =
                    Tensor::from_vec(vec![4, 64], vec![(i as f32) * 0.1; 4 * 64]).unwrap();
                for _ in 0..10 {
                    let result = ln.forward(&input);
                    assert!(
                        result.is_ok(),
                        "QA-028: Thread {} inference should succeed",
                        i
                    );
                }
            })
        })
        .collect();
    for handle in handles {
        handle.join().expect("QA-028: Thread should not panic");
    }
}
#[test]
fn test_imp_001_q4k_simd_dequantize() {
    // IMP-001: the SIMD Q4_K dequantizer must agree with the scalar one on
    // both small (4-block) and larger (64-block) inputs.
    use crate::quantize::{dequantize_q4_k, dequantize_q4_k_simd};
    // Four 144-byte Q4_K super-blocks; each gets an f16 scale of 1.0
    // (0x3C00) and an f16 min of 0.0 in its first four bytes.
    let mut data = vec![0u8; 144 * 4];
    for i in 0..4 {
        let offset = i * 144;
        data[offset..offset + 2].copy_from_slice(&0x3C00_u16.to_le_bytes());
        data[offset + 2..offset + 4].copy_from_slice(&0x0000_u16.to_le_bytes());
    }
    let scalar = dequantize_q4_k(&data).unwrap();
    let simd = dequantize_q4_k_simd(&data).unwrap();
    assert_eq!(
        scalar.len(),
        simd.len(),
        "IMP-001: SIMD output length should match scalar"
    );
    // Element-wise agreement within f32 tolerance.
    for (i, (s, p)) in scalar.iter().zip(simd.iter()).enumerate() {
        assert!(
            (s - p).abs() < 1e-4,
            "IMP-001: SIMD value {} differs: scalar={}, simd={}",
            i,
            s,
            p
        );
    }
    // 64 all-zero blocks exercise any remainder/tail handling in the SIMD path.
    let large_data = vec![0u8; 144 * 64];
    let scalar_large = dequantize_q4_k(&large_data).unwrap();
    let simd_large = dequantize_q4_k_simd(&large_data).unwrap();
    assert_eq!(
        scalar_large.len(),
        simd_large.len(),
        "IMP-001: Large data SIMD output length should match scalar"
    );
}
#[test]
fn test_imp_002_mmap_weight_streaming() {
    // IMP-002: write f32 weights to a temp file, memory-map it, and read the
    // first value back through the mapping.
    let temp_dir = std::env::temp_dir();
    let temp_file = temp_dir.join("test_mmap_weights.bin");
    let weight_data: Vec<f32> = (0..1024).map(|i| i as f32 * 0.001).collect();
    // Serialize as little-endian f32 bytes, matching the read below.
    let bytes: Vec<u8> = weight_data.iter().flat_map(|f| f.to_le_bytes()).collect();
    std::fs::write(&temp_file, &bytes).expect("IMP-002: Should write temp file");
    let file = std::fs::File::open(&temp_file).expect("IMP-002: Should open file");
    // SAFETY: mapping a file is unsafe because another process could mutate
    // or truncate it underneath us; this test owns the freshly written temp
    // file for its whole lifetime, so that hazard does not apply here.
    let mmap = unsafe { memmap2::Mmap::map(&file) };
    assert!(mmap.is_ok(), "IMP-002: Memory mapping should succeed");
    let mmap = mmap.unwrap();
    assert_eq!(
        mmap.len(),
        bytes.len(),
        "IMP-002: Mmap size should match file size"
    );
    // First f32 in the file was 0 * 0.001 == 0.0.
    let first_value = f32::from_le_bytes([mmap[0], mmap[1], mmap[2], mmap[3]]);
    assert!(
        (first_value - 0.0).abs() < 1e-6,
        "IMP-002: First value should be 0.0"
    );
    // Best-effort cleanup; failure to remove is not a test failure.
    std::fs::remove_file(&temp_file).ok();
}
#[test]
fn test_imp_003_fused_attention() {
    // IMP-003: fused QKV attention preserves shape and completes 50
    // iterations within a generous 5-second budget.
    use std::time::Instant;
    let head_dim = 32;
    let hidden_dim = 64;
    let seq_len = 16;
    let fused = FusedQKVAttention::new(head_dim, hidden_dim).unwrap();
    // Constructed only to prove both paths coexist; intentionally unused.
    let _attention = Attention::new(head_dim).unwrap();
    let input =
        Tensor::from_vec(vec![seq_len, hidden_dim], vec![0.1; seq_len * hidden_dim]).unwrap();
    let fused_output = fused.forward(&input).unwrap();
    assert_eq!(
        fused_output.shape(),
        &[seq_len, hidden_dim],
        "IMP-003: Fused attention should preserve shape"
    );
    let iterations = 50;
    let start = Instant::now();
    for _ in 0..iterations {
        let _ = fused.forward(&input).unwrap();
    }
    let fused_time = start.elapsed();
    assert!(
        fused_time.as_millis() < 5000,
        "IMP-003: Fused attention {} iterations should complete in <5s",
        iterations
    );
}
#[test]
fn test_imp_004_kv_cache_layout() {
    // IMP-004: store 32 positions across 4 layers, verify each layer's cache
    // holds 32 * hidden_dim contiguous values, then check reset clears it.
    use crate::inference::KVCache;
    let num_layers = 4;
    let hidden_dim = 64;
    let max_seq_len = 128;
    let mut cache = KVCache::new(num_layers, hidden_dim, max_seq_len);
    for pos in 0..32 {
        for layer in 0..num_layers {
            // Encode (position, layer) into the values so layout mix-ups
            // would be visible on retrieval.
            let k_data = vec![pos as f32 + layer as f32 * 0.1; hidden_dim];
            let v_data = vec![pos as f32 * 2.0 + layer as f32 * 0.1; hidden_dim];
            cache.store(layer, &k_data, &v_data);
        }
        cache.advance();
    }
    for layer in 0..num_layers {
        let k = cache.get_k(layer);
        let v = cache.get_v(layer);
        assert!(
            !k.is_empty(),
            "IMP-004: K cache for layer {} should be non-empty",
            layer
        );
        assert!(
            !v.is_empty(),
            "IMP-004: V cache for layer {} should be non-empty",
            layer
        );
        assert_eq!(
            k.len(),
            32 * hidden_dim,
            "IMP-004: K cache should have correct size"
        );
    }
    // After reset, the cache either reports empty or is fully zeroed —
    // both are acceptable clear semantics.
    cache.reset();
    let k_after_reset = cache.get_k(0);
    assert!(
        k_after_reset.is_empty() || k_after_reset.iter().all(|&x| x == 0.0),
        "IMP-004: Cache should be empty or zeroed after reset"
    );
}
#[test]
fn test_imp_005_batch_prefill() {
    // IMP-005: run several prompts through the model back-to-back and
    // require an aggregate prefill throughput above 10 tokens/second — a
    // floor low enough to pass on any functional build.
    use std::time::Instant;
    let config = ModelConfig {
        vocab_size: 1000,
        hidden_dim: 64,
        num_heads: 4,
        num_layers: 2,
        intermediate_dim: 256,
        eps: 1e-5,
    };
    let model = Model::new(config).unwrap();
    let prompts = vec![
        vec![1, 2, 3, 4, 5],
        vec![10, 20, 30],
        vec![100, 200, 300, 400],
    ];
    let start = Instant::now();
    for prompt in &prompts {
        let output = model.forward(prompt).unwrap();
        assert!(
            output.size() > 0,
            "IMP-005: Batch prefill should produce output"
        );
    }
    let prefill_time = start.elapsed();
    // Throughput = total prompt tokens / wall time for all prompts.
    let total_tokens: usize = prompts.iter().map(std::vec::Vec::len).sum();
    let throughput = total_tokens as f64 / prefill_time.as_secs_f64();
    assert!(
        throughput > 10.0,
        "IMP-005: Prefill throughput {:.1} tok/s should be >10",
        throughput
    );
}
#[test]
fn test_imp_006_wgpu_matmul() {
    // IMP-006: a [4, 64] input through Linear(64 -> 128) yields [4, 128].
    let linear = Linear::new(64, 128).unwrap();
    let input = Tensor::from_vec(vec![4, 64], vec![0.1; 256]).unwrap();
    let result = linear.forward(&input).unwrap();
    assert_eq!(
        result.shape(),
        &[4, 128],
        "IMP-006: Matrix multiply should work"
    );
}
#[test]
fn test_imp_007_gpu_buffer_pool() {
    // IMP-007: 100 repeated forward passes keep producing outputs matching
    // the input size (exercises any internal buffer reuse across calls).
    let layer_norm = LayerNorm::new(64, 1e-5).unwrap();
    let input = Tensor::from_vec(vec![8, 64], vec![0.1; 8 * 64]).unwrap();
    for iteration in 0..100 {
        let output = layer_norm.forward(&input).unwrap();
        assert_eq!(
            output.size(),
            input.size(),
            "IMP-007: Iteration {} should produce correct output",
            iteration
        );
    }
}
#[test]
fn test_imp_008_async_dispatch() {
    // IMP-008: 50 chained two-layer forward passes must finish within a
    // generous 2-second budget.
    use std::time::Instant;
    let linear1 = Linear::new(64, 64).unwrap();
    let linear2 = Linear::new(64, 64).unwrap();
    let input = Tensor::from_vec(vec![4, 64], vec![0.1; 4 * 64]).unwrap();
    let start = Instant::now();
    for _ in 0..50 {
        // Output of the first layer feeds the second (pipelined dispatch).
        let mid = linear1.forward(&input).unwrap();
        let _ = linear2.forward(&mid).unwrap();
    }
    let elapsed = start.elapsed();
    assert!(
        elapsed.as_millis() < 2000,
        "IMP-008: Pipelined ops should complete efficiently"
    );
}
#[test]
fn test_imp_009_transformer_gpu() {
use std::time::Instant;
let hidden_dim = 64;
let intermediate_dim = 256;
let block = TransformerBlock::new(hidden_dim, 4, intermediate_dim, 1e-5).unwrap();
let input = Tensor::from_vec(vec![8, hidden_dim], vec![0.1; 8 * hidden_dim]).unwrap();
let start = Instant::now();
for _ in 0..10 {
let _ = block.forward(&input).unwrap();
}
let elapsed = start.elapsed();
let avg_latency_ms = elapsed.as_millis() as f64 / 10.0;
assert!(
avg_latency_ms < 500.0,
"IMP-009: Transformer block latency {:.1}ms should be reasonable",
avg_latency_ms
);
}
#[test]
fn test_imp_010_streaming_overlap() {
    // IMP-010: per-token latency (embed + project) across 20 tokens should
    // have a coefficient of variation under 5 — i.e. token-by-token
    // streaming cost stays roughly stable.
    use std::time::Instant;
    let embedding = Embedding::new(100, 64).unwrap();
    let linear = Linear::new(64, 100).unwrap();
    let mut latencies = Vec::new();
    for token_id in 0..20 {
        let start = Instant::now();
        let embedded = embedding.forward(&[token_id]).unwrap();
        let _ = linear.forward(&embedded).unwrap();
        latencies.push(start.elapsed().as_micros() as f64);
    }
    let mean: f64 = latencies.iter().sum::<f64>() / latencies.len() as f64;
    let variance: f64 =
        latencies.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / latencies.len() as f64;
    let std_dev = variance.sqrt();
    // CV = relative dispersion of the per-token latencies.
    let cv = std_dev / mean;
    assert!(
        cv < 5.0,
        "IMP-010: Token latency CV {:.2} should be <5.0",
        cv
    );
}
#[test]
fn test_imp_011_fused_q4k_matmul() {
    // IMP-011: an all-zero 144-byte Q4_K super-block dequantizes into 256
    // values, and a dot product with a dense activation vector is finite.
    use crate::quantize::dequantize_q4_k;
    let q4k_data = vec![0u8; 144];
    let weights = dequantize_q4_k(&q4k_data).unwrap();
    assert_eq!(
        weights.len(),
        256,
        "IMP-011: Should dequantize to 256 values"
    );
    let input = vec![0.1f32; 256];
    let dot = weights
        .iter()
        .zip(input.iter())
        .map(|(w, i)| w * i)
        .sum::<f32>();
    assert!(
        dot.is_finite(),
        "IMP-011: Fused Q4K matmul should produce finite result"
    );
}
#[test]
fn test_imp_012_q5k_q6k_dequant() {
    // IMP-012: zeroed Q5_K (176-byte) and Q6_K (210-byte) super-blocks both
    // dequantize successfully into 256 values each.
    use crate::quantize::{dequantize_q5_k, dequantize_q6_k};
    let q5k_data = vec![0u8; 176];
    let q5k_result = dequantize_q5_k(&q5k_data);
    assert!(
        q5k_result.is_ok(),
        "IMP-012: Q5_K dequantization should work"
    );
    assert_eq!(
        q5k_result.unwrap().len(),
        256,
        "IMP-012: Q5_K should produce 256 values"
    );
    let q6k_data = vec![0u8; 210];
    let q6k_result = dequantize_q6_k(&q6k_data);
    assert!(
        q6k_result.is_ok(),
        "IMP-012: Q6_K dequantization should work"
    );
    assert_eq!(
        q6k_result.unwrap().len(),
        256,
        "IMP-012: Q6_K should produce 256 values"
    );
}
#[test]
fn test_imp_013_int8_matmul() {
    // IMP-013: symmetric INT8 quantize/dequantize round-trip of weights in
    // roughly [-0.5, 0.5) must keep absolute error under 0.01 (half the
    // quantization step is ~0.002 here), and basic i8->i32 accumulation works.
    let weights_f32: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) / 256.0).collect();
    let max_abs = weights_f32.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
    // Symmetric scale maps max_abs onto the i8 extreme 127.
    let scale = max_abs / 127.0;
    let weights_i8: Vec<i8> = weights_f32
        .iter()
        .map(|&x| (x / scale).round() as i8)
        .collect();
    let weights_dequant: Vec<f32> = weights_i8.iter().map(|&x| x as f32 * scale).collect();
    for (orig, dequant) in weights_f32.iter().zip(weights_dequant.iter()) {
        let error = (orig - dequant).abs();
        assert!(
            error < 0.01,
            "IMP-013: INT8 quantization error should be < 1%"
        );
    }
    // Widen to i32 before summing so the accumulation cannot overflow i8.
    let input_i8: Vec<i8> = vec![64; 16];
    let sum: i32 = input_i8.iter().map(|&x| x as i32).sum();
    assert!(sum > 0, "IMP-013: INT8 operations should work");
}
#[test]
fn test_imp_014_mixed_precision() {
    // IMP-014: dequantize a zeroed Q4_0 block (18 bytes -> 32 f32 weights)
    // and accumulate against f32 activations; the result must be finite and
    // bounded.
    use crate::quantize::dequantize_q4_0;
    let q4_data = vec![0u8; 18];
    let weights_f32 = dequantize_q4_0(&q4_data).unwrap();
    assert_eq!(
        weights_f32.len(),
        32,
        "IMP-014: Q4_0 block should produce 32 values"
    );
    let activations: Vec<f32> = (0..32).map(|i| i as f32 * 0.1).collect();
    // Dot product: dequantized weights x f32 activations.
    let result: f32 = weights_f32
        .iter()
        .zip(activations.iter())
        .map(|(w, a)| w * a)
        .sum();
    assert!(
        result.is_finite(),
        "IMP-014: Mixed precision should produce finite result"
    );
    // Largest single product term stays well within f32 range.
    let max_result = weights_f32
        .iter()
        .zip(activations.iter())
        .map(|(w, a)| (w * a).abs())
        .fold(0.0f32, f32::max);
    assert!(
        max_result < 1000.0,
        "IMP-014: Mixed precision should not overflow"
    );
}
#[test]
fn test_imp_015_weight_clustering() {
    // IMP-015: partition 1024 weights into contiguous fixed-size clusters
    // and verify nothing is lost and each cluster has the intended size.
    let weights: Vec<f32> = (0..1024).map(|i| i as f32 * 0.001).collect();
    // 64 f32s per cluster; 1024 divides evenly so no remainder handling needed.
    let cluster_size = 64;
    let num_clusters = weights.len() / cluster_size;
    let clustered: Vec<Vec<f32>> = (0..num_clusters)
        .map(|c| {
            let start = c * cluster_size;
            weights[start..start + cluster_size].to_vec()
        })
        .collect();
    let total_elements: usize = clustered.iter().map(std::vec::Vec::len).sum();
    assert_eq!(
        total_elements,
        weights.len(),
        "IMP-015: Clustering should preserve all weights"
    );
    for cluster in &clustered {
        assert_eq!(
            cluster.len(),
            cluster_size,
            "IMP-015: Each cluster should be cache-line sized"
        );
    }
    // 64-byte cache line holds 16 f32s; a 64-element cluster spans 4 lines.
    let cache_line_bytes = 64;
    let floats_per_line = cache_line_bytes / 4;
    assert!(
        cluster_size >= floats_per_line,
        "IMP-015: Cluster size should span multiple cache lines for efficiency"
    );
}
#[test]
fn test_imp_016_flash_attention() {
    // Tiled (flash-style) attention over identical Q/K/V tensors.
    let seq_len = 64;
    let head_dim = 32;
    let attention = Attention::new(head_dim).unwrap();
    let q = Tensor::from_vec(vec![seq_len, head_dim], vec![0.1; seq_len * head_dim]).unwrap();
    let k = q.clone();
    let v = q.clone();
    // Block size of 16 exercises the tiling path.
    let result = attention.flash_forward(&q, &k, &v, 16);
    assert!(result.is_ok(), "IMP-016: Flash attention should succeed");
    assert_eq!(
        result.unwrap().shape(),
        &[seq_len, head_dim],
        "IMP-016: Flash attention should preserve shape"
    );
}
#[test]
fn test_imp_017_gqa_inference() {
    // Grouped-query shape: 4 query rows attend over only 2 KV rows.
    let attention = Attention::new(32).unwrap();
    let q = Tensor::from_vec(vec![4, 32], vec![0.1; 4 * 32]).unwrap();
    let k = Tensor::from_vec(vec![2, 32], vec![0.2; 2 * 32]).unwrap();
    let v = Tensor::from_vec(vec![2, 32], vec![0.3; 2 * 32]).unwrap();
    // Mismatched Q/KV lengths may legitimately be rejected; only a
    // successful forward is required to produce output.
    if let Ok(output) = attention.forward(&q, &k, &v) {
        assert!(output.size() > 0, "IMP-017: GQA should produce output");
    }
}
#[test]
fn test_imp_018_sliding_window() {
    // Full attention over a 256-token sequence as the baseline.
    let head_dim = 32;
    let window_size = 128;
    let seq_len = 256;
    let attention = Attention::new(head_dim).unwrap();
    let fill = |value: f32| {
        Tensor::from_vec(vec![seq_len, head_dim], vec![value; seq_len * head_dim]).unwrap()
    };
    let q = fill(0.1);
    let k = fill(0.2);
    let v = fill(0.3);
    assert!(
        attention.forward(&q, &k, &v).is_ok(),
        "IMP-018: Sliding window attention should work"
    );
    // A windowed score matrix is seq_len x window_size instead of
    // seq_len x seq_len, so it must need strictly less memory.
    let memory_estimate = seq_len * window_size * 4;
    assert!(
        memory_estimate < seq_len * seq_len * 4,
        "IMP-018: Window should reduce memory"
    );
}
#[test]
fn test_imp_019_alibi_positions() {
    // ALiBi yields an additive, strictly non-positive positional bias.
    let num_heads = 4;
    let seq_len = 8;
    let bias = ALiBi::new(num_heads).unwrap().get_bias(seq_len).unwrap();
    assert_eq!(
        bias.shape(),
        &[seq_len, seq_len, num_heads],
        "IMP-019: ALiBi bias should have correct shape"
    );
    assert!(
        bias.data().iter().all(|&val| val <= 0.0),
        "IMP-019: ALiBi bias should be <= 0"
    );
}
#[test]
fn test_imp_020_sparse_attention() {
    // Dense attention baseline for the sparsity comparison below.
    let head_dim = 32;
    let seq_len = 64;
    let attention = Attention::new(head_dim).unwrap();
    let build = |value: f32| {
        Tensor::from_vec(vec![seq_len, head_dim], vec![value; seq_len * head_dim]).unwrap()
    };
    let q = build(0.1);
    let k = build(0.2);
    let v = build(0.3);
    assert!(
        attention.forward(&q, &k, &v).is_ok(),
        "IMP-020: Attention baseline should work"
    );
    // A 50% sparsity pattern halves the score computations.
    let full_ops = seq_len * seq_len;
    let sparse_ops = full_ops / 2;
    assert!(
        sparse_ops < full_ops,
        "IMP-020: Sparse should have fewer operations"
    );
}
#[test]
fn test_imp_021_continuous_batching() {
    use std::sync::Arc;
    // Tiny model shared across threads to emulate concurrent batching.
    let config = ModelConfig {
        vocab_size: 100,
        hidden_dim: 32,
        num_heads: 2,
        num_layers: 1,
        intermediate_dim: 64,
        eps: 1e-5,
    };
    let model = Arc::new(Model::new(config).unwrap());
    // Five concurrent forward passes over slightly different prompts.
    let workers: Vec<_> = (0..5)
        .map(|i| {
            let shared = Arc::clone(&model);
            std::thread::spawn(move || {
                let tokens = vec![1, 2, 3 + i];
                shared.forward(&tokens).is_ok()
            })
        })
        .collect();
    let successes: Vec<bool> = workers
        .into_iter()
        .filter_map(|handle| handle.join().ok())
        .collect();
    assert_eq!(
        successes.len(),
        5,
        "IMP-021: All concurrent requests should complete"
    );
    assert!(
        successes.iter().all(|&s| s),
        "IMP-021: All concurrent requests should succeed"
    );
}
#[test]
fn test_imp_022_speculative_decode() {
    // Target model that "verifies" draft tokens one at a time.
    let config = ModelConfig {
        vocab_size: 100,
        hidden_dim: 32,
        num_heads: 2,
        num_layers: 1,
        intermediate_dim: 64,
        eps: 1e-5,
    };
    let target_model = Model::new(config.clone()).unwrap();
    let draft_tokens = vec![1, 2, 3, 4, 5];
    // A draft token counts as accepted when the target forward succeeds.
    let accepted = draft_tokens
        .iter()
        .filter(|&&token| target_model.forward(&[token]).is_ok())
        .count();
    let acceptance_rate = accepted as f64 / draft_tokens.len() as f64;
    assert!(
        acceptance_rate >= 0.7,
        "IMP-022: Acceptance rate {:.0}% should be >= 70%",
        acceptance_rate * 100.0
    );
}
#[test]
fn test_imp_023_tensor_parallel() {
    // Even sharding of a hidden vector across simulated devices.
    let hidden_dim = 64;
    let num_gpus = 2;
    let shard_size = hidden_dim / num_gpus;
    assert_eq!(
        shard_size * num_gpus,
        hidden_dim,
        "IMP-023: Hidden dim should be divisible by num_gpus"
    );
    let input = vec![0.1f32; hidden_dim];
    let shards: Vec<&[f32]> = input.chunks(shard_size).collect();
    assert_eq!(
        shards.len(),
        num_gpus,
        "IMP-023: Should have correct number of shards"
    );
    // Every shard must carry exactly its share of the vector.
    for shard in &shards {
        assert_eq!(
            shard.len(),
            shard_size,
            "IMP-023: Each shard should have correct size"
        );
    }
}
#[test]
fn test_imp_024_weight_caching() {
    use std::time::Instant;
    // Time the first ("cold") model construction.
    let cold_start = Instant::now();
    let config = ModelConfig {
        vocab_size: 500,
        hidden_dim: 64,
        num_heads: 4,
        num_layers: 2,
        intermediate_dim: 256,
        eps: 1e-5,
    };
    let model = Model::new(config.clone()).unwrap();
    let cold_time = cold_start.elapsed();
    // Time a second ("warm") construction with the same config; any
    // weight caching should make this no slower than the cold path.
    let warm_start = Instant::now();
    let _model2 = Model::new(config).unwrap();
    let warm_time = warm_start.elapsed();
    // Both constructions only need to stay under a generous 1s budget.
    assert!(
        cold_time.as_millis() < 1000,
        "IMP-024: Cold start {:.0}ms should be <1s",
        cold_time.as_millis()
    );
    assert!(
        warm_time.as_millis() < 1000,
        "IMP-024: Warm start {:.0}ms should be <1s",
        warm_time.as_millis()
    );
    // The first model must still be usable after the second was built.
    let output = model.forward(&[1, 2, 3]).unwrap();
    assert!(output.size() > 0, "IMP-024: Model should be functional");
}
#[test]
fn test_imp_025_onnx_export() {
    #[derive(Debug)]
    #[allow(dead_code)]
    struct OnnxNode {
        name: String,
        op_type: String,
        inputs: Vec<String>,
        outputs: Vec<String>,
    }
    #[derive(Debug)]
    struct OnnxGraph {
        nodes: Vec<OnnxNode>,
        inputs: Vec<String>,
        outputs: Vec<String>,
    }
    // Small helper so each node is declared as one compact line.
    fn node(name: &str, op_type: &str, inputs: &[&str], outputs: &[&str]) -> OnnxNode {
        OnnxNode {
            name: name.to_string(),
            op_type: op_type.to_string(),
            inputs: inputs.iter().map(|s| (*s).to_string()).collect(),
            outputs: outputs.iter().map(|s| (*s).to_string()).collect(),
        }
    }
    // One pre-norm transformer block: LN -> Attention -> residual add,
    // then LN -> FFN (MatMul) -> residual add.
    let graph = OnnxGraph {
        inputs: vec!["input".to_string()],
        outputs: vec!["output".to_string()],
        nodes: vec![
            node("ln1", "LayerNormalization", &["input"], &["ln1_out"]),
            node("attn", "Attention", &["ln1_out"], &["attn_out"]),
            node("add1", "Add", &["input", "attn_out"], &["residual1"]),
            node("ln2", "LayerNormalization", &["residual1"], &["ln2_out"]),
            node("ffn", "MatMul", &["ln2_out"], &["ffn_out"]),
            node("add2", "Add", &["residual1", "ffn_out"], &["output"]),
        ],
    };
    assert_eq!(graph.inputs.len(), 1, "IMP-025: Should have one input");
    assert_eq!(graph.outputs.len(), 1, "IMP-025: Should have one output");
    assert_eq!(
        graph.nodes.len(),
        6,
        "IMP-025: Transformer block should have 6 ops"
    );
    // Topological validity: every node input must already be produced by
    // the graph inputs or by an earlier node.
    let mut defined_tensors: std::collections::HashSet<String> =
        graph.inputs.iter().cloned().collect();
    for node in &graph.nodes {
        for input in &node.inputs {
            assert!(
                defined_tensors.contains(input),
                "IMP-025: Node {} input {} should be defined",
                node.name,
                input
            );
        }
        defined_tensors.extend(node.outputs.iter().cloned());
    }
    // And every declared graph output must be produced by some node.
    for output in &graph.outputs {
        assert!(
            defined_tensors.contains(output),
            "IMP-025: Graph output {} should be defined",
            output
        );
    }
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_026_gguf_gpu_weight_loading() {
    use crate::gpu::{GpuModel, GpuModelConfig};
    // Synthetic GGUF-shaped config; weights come from the config path
    // rather than a file on disk.
    let config = GpuModelConfig {
        vocab_size: 256,
        hidden_dim: 64,
        num_heads: 4,
        num_kv_heads: 4,
        num_layers: 2,
        intermediate_dim: 128,
        eps: 1e-5,
    };
    let mut model = GpuModel::from_gguf_config(config.clone())
        .expect("IMP-026: Should create GpuModel from config");
    let _ = model.has_gpu();
    // The model must echo back the config it was built from.
    let model_config = model.config();
    assert_eq!(
        model_config.vocab_size, config.vocab_size,
        "IMP-026: vocab_size should match"
    );
    assert_eq!(
        model_config.hidden_dim, config.hidden_dim,
        "IMP-026: hidden_dim should match"
    );
    assert_eq!(
        model_config.num_layers, config.num_layers,
        "IMP-026: num_layers should match"
    );
    // A forward pass over three tokens yields one logit row per token.
    let token_ids = vec![1, 2, 3];
    let logits = model.forward_gpu_owned(&token_ids);
    assert!(
        logits.is_ok(),
        "IMP-026: Forward pass should succeed with loaded weights"
    );
    assert_eq!(
        logits.unwrap().len(),
        token_ids.len() * config.vocab_size,
        "IMP-026: Logits should have shape [seq_len, vocab_size]"
    );
    // Canonical GGUF tensor names the loader is expected to understand.
    for name in [
        "token_embd.weight",
        "blk.0.attn_norm.weight",
        "blk.0.attn_qkv.weight",
        "blk.0.attn_output.weight",
        "blk.0.ffn_up.weight",
        "blk.0.ffn_down.weight",
        "output_norm.weight",
        "output.weight",
    ] {
        assert!(
            !name.is_empty(),
            "IMP-026: Tensor name {} should follow GGUF convention",
            name
        );
    }
}
#[test]
#[cfg(feature = "gpu")]
#[ignore = "Enable when real GGUF available"]
fn test_imp_026_real_gguf_gpu_loading() {
    use crate::gguf::MappedGGUFModel;
    use crate::gpu::GpuModel;
    // The model path is overridable via env so CI can point at a local file.
    let model_path = std::env::var("GGUF_MODEL_PATH")
        .unwrap_or_else(|_| "models/phi-2-q4_k_m.gguf".to_string());
    if !std::path::Path::new(&model_path).exists() {
        eprintln!("IMP-026: Skipping - GGUF model not found at {}", model_path);
        return;
    }
    let mapped_model =
        MappedGGUFModel::from_path(&model_path).expect("IMP-026: Should load GGUF model");
    let mut model = GpuModel::from_mapped_gguf(&mapped_model)
        .expect("IMP-026: Should convert to GPU model");
    let _ = model.has_gpu();
    let prompt = vec![1, 2, 3];
    let logits = model
        .forward_gpu_owned(&prompt)
        .expect("IMP-026: Forward pass should work");
    // All-zero logits would mean the weights were never actually loaded.
    assert!(
        logits.iter().any(|&x| x.abs() > 1e-10),
        "IMP-026: Logits should not be all zeros (weights loaded)"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_027_gpu_text_generation() {
    use crate::gpu::{GpuGenerateConfig, GpuModel, GpuModelConfig};
    let config = GpuModelConfig {
        vocab_size: 256,
        hidden_dim: 64,
        num_heads: 4,
        num_kv_heads: 4,
        num_layers: 2,
        intermediate_dim: 128,
        eps: 1e-5,
    };
    let mut model = GpuModel::from_gguf_config(config).expect("IMP-027: Should create model");
    let prompt = vec![1, 2, 3];
    // Greedy decoding: output must start with the prompt and add at most
    // max_tokens new ids.
    let gen_config = GpuGenerateConfig::deterministic(5);
    let tokens = model
        .generate(&prompt, &gen_config)
        .expect("IMP-027: Generate should succeed");
    assert!(
        tokens.len() >= prompt.len(),
        "IMP-027: Generated tokens should include prompt"
    );
    assert!(
        tokens.len() <= prompt.len() + 5,
        "IMP-027: Should not exceed max_tokens"
    );
    assert_eq!(
        &tokens[..prompt.len()],
        &prompt,
        "IMP-027: Output should start with prompt"
    );
    // Using the first generated id as a stop token must truncate output
    // right at the prompt (the stop token itself is excluded).
    let first_generated = tokens[prompt.len()];
    let gen_config_stop =
        GpuGenerateConfig::deterministic(10).with_stop_tokens(vec![first_generated]);
    let tokens_stopped = model
        .generate(&prompt, &gen_config_stop)
        .expect("IMP-027: Generate with stop should succeed");
    assert_eq!(
        tokens_stopped.len(),
        prompt.len(),
        "IMP-027: Should stop on stop token (not include it)"
    );
    // Temperature sampling must also prefix the output with the prompt.
    let gen_config_sample = GpuGenerateConfig::with_sampling(3, 0.7, 10);
    let tokens_sampled = model
        .generate(&prompt, &gen_config_sample)
        .expect("IMP-027: Generate with sampling should succeed");
    assert!(
        tokens_sampled.len() >= prompt.len(),
        "IMP-027: Sampled tokens should include prompt"
    );
    // An empty prompt is invalid input.
    assert!(
        model.generate(&[], &gen_config).is_err(),
        "IMP-027: Empty prompt should return error"
    );
    // Defaults: greedy decoding (top_k 1, temperature 0), 64 max tokens.
    let default_config = GpuGenerateConfig::default();
    assert_eq!(
        default_config.max_tokens, 64,
        "IMP-027: Default max_tokens should be 64"
    );
    assert_eq!(
        default_config.temperature, 0.0,
        "IMP-027: Default temperature should be 0.0"
    );
    assert_eq!(
        default_config.top_k, 1,
        "IMP-027: Default top_k should be 1"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_028_real_forward_pass() {
    use crate::gpu::{GpuModel, GpuModelConfig};
    let config = GpuModelConfig {
        vocab_size: 256,
        hidden_dim: 64,
        num_heads: 4,
        num_kv_heads: 4,
        num_layers: 2,
        intermediate_dim: 128,
        eps: 1e-5,
    };
    let mut model =
        GpuModel::from_gguf_config(config.clone()).expect("IMP-028: Should create model");
    let tokens = vec![1, 2, 3, 4, 5];
    let logits = model
        .forward_gpu(&tokens)
        .expect("IMP-028: Forward pass should succeed");
    // One row of vocab_size logits per input token.
    assert_eq!(
        logits.len(),
        tokens.len() * config.vocab_size,
        "IMP-028: Logits shape should be [seq_len, vocab_size]"
    );
    assert!(
        logits.iter().any(|&x| x.abs() > 1e-10),
        "IMP-028: Logits should not be all zeros"
    );
    assert!(
        logits.iter().all(|&x| x.is_finite()),
        "IMP-028: All logits should be finite"
    );
    // Softmax over the final position must form a valid distribution.
    let last_row = &logits[(tokens.len() - 1) * config.vocab_size..];
    let max_logit = last_row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exp_sum: f32 = last_row.iter().map(|&x| (x - max_logit).exp()).sum();
    let prob_sum: f32 = last_row
        .iter()
        .map(|&x| (x - max_logit).exp() / exp_sum)
        .sum();
    assert!(
        (prob_sum - 1.0).abs() < 1e-5,
        "IMP-028: Softmax probabilities should sum to 1.0 (got {})",
        prob_sum
    );
    // A single-token sequence produces exactly one logit row.
    let single_token = vec![42];
    let single_logits = model
        .forward_gpu(&single_token)
        .expect("IMP-028: Single token forward should work");
    assert_eq!(
        single_logits.len(),
        config.vocab_size,
        "IMP-028: Single token should produce vocab_size logits"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_029_text_generation() {
    use crate::gpu::{GpuGenerateConfig, GpuModel, GpuModelConfig};
    // Two identically-configured models must generate identical output
    // under deterministic decoding.
    let make_config = || GpuModelConfig {
        vocab_size: 256,
        hidden_dim: 64,
        num_heads: 4,
        num_kv_heads: 4,
        num_layers: 2,
        intermediate_dim: 128,
        eps: 1e-5,
    };
    let mut model =
        GpuModel::from_gguf_config(make_config()).expect("IMP-029: Should create model");
    let prompt = vec![1, 2, 3];
    let gen_config = GpuGenerateConfig::deterministic(20);
    let tokens = model
        .generate(&prompt, &gen_config)
        .expect("IMP-029: Generation should succeed");
    assert!(
        tokens.len() > prompt.len(),
        "IMP-029: Should generate at least one token"
    );
    assert!(
        tokens.len() <= prompt.len() + 20,
        "IMP-029: Should respect max_tokens"
    );
    let mut model2 = GpuModel::from_gguf_config(make_config())
        .expect("IMP-029: Should create second model");
    let tokens2 = model2
        .generate(&prompt, &gen_config)
        .expect("IMP-029: Second generation should succeed");
    assert_eq!(
        tokens, tokens2,
        "IMP-029: Deterministic generation should be reproducible"
    );
    // Every generated id must be a valid vocab index.
    for &token in &tokens {
        assert!(
            token < 256,
            "IMP-029: Token {} should be within vocab size",
            token
        );
    }
    // Stopping on the first generated id truncates output at the prompt.
    let stop_token = tokens[prompt.len()];
    let gen_config_stop =
        GpuGenerateConfig::deterministic(50).with_stop_tokens(vec![stop_token]);
    let tokens_stopped = model
        .generate(&prompt, &gen_config_stop)
        .expect("IMP-029: Generation with stop should succeed");
    assert_eq!(
        tokens_stopped.len(),
        prompt.len(),
        "IMP-029: Should stop before adding stop token"
    );
    // A long max_tokens budget must still terminate and produce output.
    let long_config = GpuGenerateConfig::deterministic(100);
    let long_tokens = model
        .generate(&prompt, &long_config)
        .expect("IMP-029: Long generation should complete");
    assert!(
        long_tokens.len() >= prompt.len(),
        "IMP-029: Long generation should produce output"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_030_benchmark_harness() {
    use crate::gpu::{GpuGenerateConfig, GpuModel, GpuModelConfig};
    use std::time::Instant;
    let config = GpuModelConfig {
        vocab_size: 256,
        hidden_dim: 64,
        num_heads: 4,
        num_kv_heads: 4,
        num_layers: 2,
        intermediate_dim: 128,
        eps: 1e-5,
    };
    let mut model = GpuModel::from_gguf_config(config).expect("IMP-030: Should create model");
    let prompt = vec![1, 2, 3, 4, 5];
    let gen_config = GpuGenerateConfig::deterministic(10);
    // Warmup runs so caches/allocations settle before timing begins.
    for _ in 0..5 {
        let _ = model.generate(&prompt, &gen_config);
    }
    // Measure tokens/sec over several timed runs.
    let num_runs = 5;
    let mut throughputs = Vec::with_capacity(num_runs);
    for _ in 0..num_runs {
        let start = Instant::now();
        let tokens = model
            .generate(&prompt, &gen_config)
            .expect("IMP-030: Generation should succeed");
        let elapsed = start.elapsed();
        let generated = tokens.len() - prompt.len();
        let throughput = generated as f64 / elapsed.as_secs_f64();
        throughputs.push(throughput);
    }
    // Coefficient of variation (std-dev / mean) gauges run-to-run noise.
    let mean: f64 = throughputs.iter().sum::<f64>() / throughputs.len() as f64;
    let variance: f64 =
        throughputs.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / throughputs.len() as f64;
    let std_dev = variance.sqrt();
    let cv = std_dev / mean;
    assert!(
        mean > 0.0,
        "IMP-030: Mean throughput should be positive (got {})",
        mean
    );
    // Loose bound: timing jitter is expected, but CV >= 1.0 would mean
    // the measurements are pure noise.
    assert!(
        cv < 1.0,
        "IMP-030: CV ({:.2}) should be < 1.0 for reasonable reproducibility",
        cv
    );
    // A second identically-configured model should emit the same number
    // of tokens under deterministic decoding.
    let mut model2 = GpuModel::from_gguf_config(GpuModelConfig {
        vocab_size: 256,
        hidden_dim: 64,
        num_heads: 4,
        num_kv_heads: 4,
        num_layers: 2,
        intermediate_dim: 128,
        eps: 1e-5,
    })
    .expect("IMP-030: Should create model");
    let tokens1 = model.generate(&prompt, &gen_config).unwrap();
    let tokens2 = model2.generate(&prompt, &gen_config).unwrap();
    assert_eq!(
        tokens1.len(),
        tokens2.len(),
        "IMP-030: Deterministic runs should produce same token count"
    );
    // Shape of a single benchmark record, as a reporting harness would
    // emit it; all fields must be populated by one timed run.
    #[allow(clippy::items_after_statements)]
    #[derive(Debug)]
    struct BenchmarkResult {
        model_name: String,
        prompt_tokens: usize,
        generated_tokens: usize,
        total_time_ms: f64,
        throughput_tok_s: f64,
    }
    let start = Instant::now();
    let tokens = model.generate(&prompt, &gen_config).unwrap();
    let elapsed = start.elapsed();
    let result = BenchmarkResult {
        model_name: "test-model".to_string(),
        prompt_tokens: prompt.len(),
        generated_tokens: tokens.len() - prompt.len(),
        total_time_ms: elapsed.as_secs_f64() * 1000.0,
        throughput_tok_s: (tokens.len() - prompt.len()) as f64 / elapsed.as_secs_f64(),
    };
    assert!(
        !result.model_name.is_empty(),
        "IMP-030: Model name should be set"
    );
    assert!(
        result.prompt_tokens > 0,
        "IMP-030: Prompt tokens should be tracked"
    );
    assert!(
        result.generated_tokens > 0,
        "IMP-030: Generated tokens should be tracked"
    );
    assert!(
        result.total_time_ms > 0.0,
        "IMP-030: Time should be measured"
    );
    assert!(
        result.throughput_tok_s > 0.0,
        "IMP-030: Throughput should be calculated"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_031_forward_with_cache() {
    use crate::gpu::{GpuModel, GpuModelConfig, StreamingKVCache};
    let config = GpuModelConfig {
        vocab_size: 256,
        hidden_dim: 64,
        num_heads: 4,
        num_kv_heads: 4,
        num_layers: 2,
        intermediate_dim: 128,
        eps: 1e-5,
    };
    let mut model =
        GpuModel::from_gguf_config(config.clone()).expect("IMP-031: Should create model");
    // Empty streaming KV cache sized for up to 512 positions.
    let max_seq_len = 512;
    let head_dim = config.hidden_dim / config.num_heads;
    let mut cache =
        StreamingKVCache::new(config.num_layers, max_seq_len, config.num_heads, head_dim);
    let prompt = vec![1, 2, 3, 4, 5];
    let final_logits = model
        .forward_gpu_with_cache(&prompt, &mut cache)
        .expect("IMP-031: forward_with_cache should succeed");
    // The cached forward only returns logits for the final position.
    assert_eq!(
        final_logits.len(),
        config.vocab_size,
        "IMP-031: Should return logits for final position only (got {}, expected {})",
        final_logits.len(),
        config.vocab_size
    );
    assert_eq!(
        cache.len(),
        prompt.len(),
        "IMP-031: KV cache should contain {} positions (got {})",
        prompt.len(),
        cache.len()
    );
    // Layer 0 must hold real (non-zero) keys and values for the prompt.
    let (cached_k, cached_v) = cache.get_range(0, 0, prompt.len());
    let k_magnitude: f32 = cached_k.iter().map(|x| x.abs()).sum();
    let v_magnitude: f32 = cached_v.iter().map(|x| x.abs()).sum();
    assert!(k_magnitude > 0.0, "IMP-031: Cached keys should be non-zero");
    assert!(v_magnitude > 0.0, "IMP-031: Cached values should be non-zero");
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_032_forward_incremental() {
    use crate::gpu::{GpuModel, GpuModelConfig, StreamingKVCache};
    let config = GpuModelConfig {
        vocab_size: 256,
        hidden_dim: 64,
        num_heads: 4,
        num_kv_heads: 4,
        num_layers: 2,
        intermediate_dim: 128,
        eps: 1e-5,
    };
    let mut model =
        GpuModel::from_gguf_config(config.clone()).expect("IMP-032: Should create model");
    let max_seq_len = 512;
    let head_dim = config.hidden_dim / config.num_heads;
    let mut cache =
        StreamingKVCache::new(config.num_layers, max_seq_len, config.num_heads, head_dim);
    // Prime the cache with the whole prompt in one pass.
    let prompt = vec![1, 2, 3, 4, 5];
    let _ = model
        .forward_gpu_with_cache(&prompt, &mut cache)
        .expect("IMP-032: Initial forward should succeed");
    let primed_len = cache.len();
    // Decoding one new token must return a full logit row and extend the
    // cache by exactly one position.
    let new_token = 42usize;
    let logits = model
        .forward_gpu_incremental(new_token, &mut cache)
        .expect("IMP-032: Incremental forward should succeed");
    assert_eq!(
        logits.len(),
        config.vocab_size,
        "IMP-032: Incremental should return vocab_size logits"
    );
    assert_eq!(
        cache.len(),
        primed_len + 1,
        "IMP-032: Cache should grow by 1 position"
    );
    // Repeated incremental steps keep growing the cache one at a time.
    for token in [10, 20, 30] {
        let before = cache.len();
        let step_logits = model
            .forward_gpu_incremental(token, &mut cache)
            .expect("IMP-032: Repeated incremental should succeed");
        assert_eq!(step_logits.len(), config.vocab_size);
        assert_eq!(cache.len(), before + 1);
    }
    // Prompt + 42 + three loop tokens = prompt.len() + 4 positions total.
    assert_eq!(
        cache.len(),
        prompt.len() + 4,
        "IMP-032: Final cache length should match all tokens"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_033_generate_with_cache() {
    use crate::gpu::{GpuGenerateConfig, GpuModel, GpuModelConfig};
    use std::time::Instant;
    let config = GpuModelConfig {
        vocab_size: 256,
        hidden_dim: 64,
        num_heads: 4,
        num_kv_heads: 4,
        num_layers: 2,
        intermediate_dim: 128,
        eps: 1e-5,
    };
    let mut model = GpuModel::from_gguf_config(config).expect("IMP-033: Should create model");
    let prompt = vec![1, 2, 3, 4, 5];
    let gen_config = GpuGenerateConfig::deterministic(50);
    // Warmup before any timing is taken.
    for _ in 0..3 {
        let _ = model.generate(&prompt, &gen_config);
    }
    // Time the KV-cached generation path.
    let start = Instant::now();
    let tokens = model
        .generate_with_cache(&prompt, &gen_config)
        .expect("IMP-033: generate_with_cache should succeed");
    let cached_time = start.elapsed();
    assert!(
        tokens.len() > prompt.len(),
        "IMP-033: Should generate new tokens"
    );
    // Time the naive (no-cache) path for comparison.
    let start = Instant::now();
    let _ = model
        .generate(&prompt, &gen_config)
        .expect("IMP-033: Regular generate should succeed");
    let naive_time = start.elapsed();
    // Loose sanity bound only: single-run timings are too noisy for a
    // strict speedup assertion, so just require the cached path not be
    // drastically slower than naive.
    let speedup = naive_time.as_secs_f64() / cached_time.as_secs_f64();
    assert!(
        speedup > 0.5,
        "IMP-033: Cached generation speedup ({:.2}x) should be reasonable",
        speedup
    );
    // Deterministic decoding through the cache must be reproducible.
    let tokens1 = model
        .generate_with_cache(&prompt, &gen_config)
        .expect("IMP-033: Should generate");
    let tokens2 = model
        .generate_with_cache(&prompt, &gen_config)
        .expect("IMP-033: Should generate again");
    assert_eq!(
        tokens1, tokens2,
        "IMP-033: Deterministic generation should produce same output"
    );
    // A longer budget must yield a substantially longer output.
    let long_config = GpuGenerateConfig::deterministic(100);
    let long_tokens = model
        .generate_with_cache(&prompt, &long_config)
        .expect("IMP-033: Long generation should complete");
    assert!(
        long_tokens.len() >= prompt.len() + 50,
        "IMP-033: Long generation should produce substantial output"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_034_preallocated_attention() {
    use crate::gpu::{AttentionBuffers, GpuModel, GpuModelConfig};
    let config = GpuModelConfig {
        vocab_size: 256,
        hidden_dim: 64,
        num_heads: 4,
        num_kv_heads: 4,
        num_layers: 2,
        intermediate_dim: 128,
        eps: 1e-5,
    };
    let max_seq_len = 512;
    // Scratch buffers sized once up front for the attention kernel.
    let buffers = AttentionBuffers::new(&config, max_seq_len);
    assert_eq!(
        buffers.q_buffer.len(),
        config.hidden_dim,
        "IMP-034: Q buffer should be hidden_dim"
    );
    assert_eq!(
        buffers.scores_buffer.len(),
        config.num_heads * max_seq_len,
        "IMP-034: Scores buffer should be num_heads * max_seq_len"
    );
    assert_eq!(
        buffers.output_buffer.len(),
        config.hidden_dim,
        "IMP-034: Output buffer should be hidden_dim"
    );
    // A model built with buffers must report them and still decode.
    let mut model = GpuModel::with_attention_buffers(config.clone(), max_seq_len)
        .expect("IMP-034: Should create model with buffers");
    assert!(
        model.has_attention_buffers(),
        "IMP-034: Model should have attention buffers"
    );
    let prompt = vec![1, 2, 3, 4, 5];
    let gen_config = crate::gpu::GpuGenerateConfig::deterministic(10);
    let generated = model
        .generate_optimized(&prompt, &gen_config)
        .expect("IMP-034: Optimized generation should work");
    assert!(
        generated.len() > prompt.len(),
        "IMP-034: Should generate tokens with pre-allocated buffers"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_035_batched_multihead() {
    use crate::gpu::{GpuModel, GpuModelConfig};
    use std::time::Instant;
    let config = GpuModelConfig {
        vocab_size: 256,
        hidden_dim: 128,
        num_heads: 8,
        num_kv_heads: 8,
        num_layers: 4,
        intermediate_dim: 256,
        eps: 1e-5,
    };
    let mut model = GpuModel::with_attention_buffers(config.clone(), 256)
        .expect("IMP-035: Should create model");
    let prompt = vec![1, 2, 3, 4, 5, 6, 7, 8];
    let gen_config = crate::gpu::GpuGenerateConfig::deterministic(32);
    // Warmup before timing.
    for _ in 0..3 {
        let _ = model.generate_optimized(&prompt, &gen_config);
    }
    // Time the batched multi-head path against the original cached path.
    let start = Instant::now();
    let _ = model.generate_optimized(&prompt, &gen_config);
    let optimized_time = start.elapsed();
    let start = Instant::now();
    let _ = model.generate_with_cache(&prompt, &gen_config);
    let original_time = start.elapsed();
    // Informational only - no assertion, since single-run timings are
    // too noisy to gate CI on.
    let speedup = original_time.as_secs_f64() / optimized_time.as_secs_f64();
    eprintln!(
        "IMP-035: Batched multihead speedup: {:.2}x (optimized: {:?}, original: {:?})",
        speedup, optimized_time, original_time
    );
}
#[test]
#[cfg(feature = "gpu")]
#[cfg_attr(coverage, ignore)] // timing-sensitive; skipped under coverage instrumentation
fn test_imp_036_optimized_kv_access() {
    use crate::gpu::{GpuModel, GpuModelConfig, StreamingKVCache};
    use std::time::Instant;
    let config = GpuModelConfig {
        vocab_size: 256,
        hidden_dim: 128,
        num_heads: 8,
        num_kv_heads: 8,
        num_layers: 4,
        intermediate_dim: 256,
        eps: 1e-5,
    };
    let mut model = GpuModel::with_attention_buffers(config.clone(), 256)
        .expect("IMP-036: Should create model");
    let head_dim = config.hidden_dim / config.num_heads;
    let mut kv_cache =
        StreamingKVCache::new(config.num_layers, 256, config.num_heads, head_dim);
    // Prime the cache with a prompt plus a few incremental steps.
    let prompt = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
    let _ = model.forward_gpu_with_cache(&prompt, &mut kv_cache);
    for token in [11, 12, 13] {
        let _ = model.forward_gpu_incremental(token, &mut kv_cache);
    }
    // Time 10 decode steps through the optimized KV-access path.
    let mut optimized_times = Vec::with_capacity(10);
    for token in 20..30 {
        let start = Instant::now();
        let _ = model.forward_gpu_incremental_optimized(token, &mut kv_cache);
        optimized_times.push(start.elapsed().as_secs_f64());
    }
    // And 10 through the original path, on the same (growing) cache.
    let mut original_times = Vec::with_capacity(10);
    for token in 30..40 {
        let start = Instant::now();
        let _ = model.forward_gpu_incremental(token, &mut kv_cache);
        original_times.push(start.elapsed().as_secs_f64());
    }
    // Compare medians to damp outliers from scheduler noise.
    optimized_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    original_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let optimized_median = optimized_times[optimized_times.len() / 2];
    let original_median = original_times[original_times.len() / 2];
    let speedup = original_median / optimized_median;
    // Only guard against a severe regression; real speedups vary by host.
    assert!(
        speedup >= 0.5,
        "IMP-036: Optimized KV access speedup ({:.2}x) should be >= 0.5x (no major regression)",
        speedup
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_037_fused_qkv() {
    use crate::gpu::{GpuModel, GpuModelConfig};
    use std::time::Instant;
    let config = GpuModelConfig {
        vocab_size: 256,
        hidden_dim: 128,
        num_heads: 8,
        num_kv_heads: 8,
        num_layers: 4,
        intermediate_dim: 256,
        eps: 1e-5,
    };
    let mut model = GpuModel::with_attention_buffers(config.clone(), 256)
        .expect("IMP-037: Should create model");
    assert!(
        model.has_fused_qkv(),
        "IMP-037: Model should have fused QKV projection"
    );
    // The fused projection returns all three hidden_dim-sized outputs
    // from a single input vector.
    let input = vec![0.1f32; config.hidden_dim];
    let (q_fused, k_fused, v_fused) = model
        .fused_qkv_projection(&input)
        .expect("IMP-037: Fused QKV projection should work");
    assert_eq!(q_fused.len(), config.hidden_dim, "IMP-037: Q output size");
    assert_eq!(k_fused.len(), config.hidden_dim, "IMP-037: K output size");
    assert_eq!(v_fused.len(), config.hidden_dim, "IMP-037: V output size");
    // Warmup, then time fused vs regular generation.
    let prompt = vec![1, 2, 3, 4, 5, 6, 7, 8];
    let gen_config = crate::gpu::GpuGenerateConfig::deterministic(16);
    for _ in 0..3 {
        let _ = model.generate_optimized(&prompt, &gen_config);
    }
    let start = Instant::now();
    let _ = model.generate_with_fused_qkv(&prompt, &gen_config);
    let fused_time = start.elapsed();
    let start = Instant::now();
    let _ = model.generate_optimized(&prompt, &gen_config);
    let regular_time = start.elapsed();
    // Informational only; single-run timings are too noisy to assert on.
    let speedup = regular_time.as_secs_f64() / fused_time.as_secs_f64();
    eprintln!(
        "IMP-037: Fused QKV speedup: {:.2}x (fused: {:?}, regular: {:?})",
        speedup, fused_time, regular_time
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_038_simd_softmax() {
    use crate::gpu::{scalar_softmax, simd_softmax};
    use std::time::Instant;
    // SIMD and scalar softmax must agree elementwise on a small input.
    let input = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let simd_result = simd_softmax(&input);
    let scalar_result = scalar_softmax(&input);
    assert_eq!(
        simd_result.len(),
        input.len(),
        "IMP-038: Output size matches"
    );
    // Softmax output must be a probability distribution.
    let sum: f32 = simd_result.iter().sum();
    assert!(
        (sum - 1.0).abs() < 1e-5,
        "IMP-038: SIMD softmax should sum to 1.0, got {}",
        sum
    );
    for (i, (simd, scalar)) in simd_result.iter().zip(scalar_result.iter()).enumerate() {
        assert!(
            (simd - scalar).abs() < 1e-5,
            "IMP-038: SIMD softmax[{}] ({}) should match scalar ({})",
            i,
            simd,
            scalar
        );
    }
    // Timing comparison on a larger input; warmup both paths first.
    let large_input: Vec<f32> = (0..1024).map(|i| i as f32 * 0.01).collect();
    for _ in 0..10 {
        let _ = simd_softmax(&large_input);
        let _ = scalar_softmax(&large_input);
    }
    let start = Instant::now();
    for _ in 0..100 {
        let _ = simd_softmax(&large_input);
    }
    let simd_time = start.elapsed();
    let start = Instant::now();
    for _ in 0..100 {
        let _ = scalar_softmax(&large_input);
    }
    let scalar_time = start.elapsed();
    // Speedup is informational only; no assertion because the ratio is
    // host-dependent.
    let speedup = scalar_time.as_secs_f64() / simd_time.as_secs_f64();
    let _ = speedup;
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_039_fused_attn_proj() {
    use crate::gpu::{GpuModel, GpuModelConfig, StreamingKVCache};
    use std::time::Instant;
    let config = GpuModelConfig {
        vocab_size: 256,
        hidden_dim: 128,
        num_heads: 8,
        num_kv_heads: 8,
        num_layers: 4,
        intermediate_dim: 256,
        eps: 1e-5,
    };
    let mut model = GpuModel::with_attention_buffers(config.clone(), 256)
        .expect("IMP-039: Should create model");
    let head_dim = config.hidden_dim / config.num_heads;
    let mut kv_cache =
        StreamingKVCache::new(config.num_layers, 256, config.num_heads, head_dim);
    // Prime the cache with a prompt first.
    let prompt = vec![1, 2, 3, 4, 5, 6, 7, 8];
    let _ = model.forward_gpu_with_cache(&prompt, &mut kv_cache);
    assert!(
        model.has_fused_attn_proj(),
        "IMP-039: Model should have fused attention projection"
    );
    // Warmup decode steps before any timing.
    for token in 10..15 {
        let _ = model.forward_gpu_incremental_optimized(token, &mut kv_cache);
    }
    // Time 10 steps through the fused attention+projection path.
    let mut fused_times = Vec::with_capacity(10);
    for token in 20..30 {
        let start = Instant::now();
        let _ = model.forward_with_fused_attn_proj(token, &mut kv_cache);
        fused_times.push(start.elapsed().as_secs_f64());
    }
    // And 10 through the regular optimized path on the same cache.
    let mut regular_times = Vec::with_capacity(10);
    for token in 30..40 {
        let start = Instant::now();
        let _ = model.forward_gpu_incremental_optimized(token, &mut kv_cache);
        regular_times.push(start.elapsed().as_secs_f64());
    }
    // Medians damp scheduler noise; the ratio is informational only.
    fused_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    regular_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let fused_median = fused_times[fused_times.len() / 2];
    let regular_median = regular_times[regular_times.len() / 2];
    let speedup = regular_median / fused_median;
    let _ = speedup;
}
#[test]
fn test_imp_040_contiguous_attention() {
    use crate::gpu::{ContiguousAttentionBuffer, GpuModelConfig};
    let config = GpuModelConfig {
        vocab_size: 256,
        hidden_dim: 128,
        num_heads: 8,
        num_kv_heads: 8,
        num_layers: 4,
        intermediate_dim: 256,
        eps: 1e-5,
    };
    let max_seq_len = 256;
    let head_dim = config.hidden_dim / config.num_heads;
    // One flat allocation backing the Q/K/V/O working set.
    let mut buffer = ContiguousAttentionBuffer::new(max_seq_len, config.num_heads, head_dim);
    assert!(
        buffer.is_contiguous(),
        "IMP-040: Buffer should be contiguous"
    );
    // Each view spans max_seq_len x num_heads x head_dim elements.
    let expected_view_len = max_seq_len * config.num_heads * head_dim;
    let (q_view, k_view, v_view, o_view) = buffer.get_views();
    assert_eq!(
        q_view.len(),
        expected_view_len,
        "IMP-040: Q view should have correct size"
    );
    assert_eq!(
        k_view.len(),
        expected_view_len,
        "IMP-040: K view should have correct size"
    );
    assert_eq!(
        v_view.len(),
        expected_view_len,
        "IMP-040: V view should have correct size"
    );
    assert_eq!(
        o_view.len(),
        expected_view_len,
        "IMP-040: O view should have correct size"
    );
    // Resetting must not disturb the layout.
    buffer.reset();
    assert!(
        buffer.is_contiguous(),
        "IMP-040: Buffer should remain contiguous after reset"
    );
}
#[test]
fn test_imp_041_vectorized_rope() {
    use crate::gpu::{scalar_rope, simd_rope};
    use std::time::Instant;
    // 64 positions of a 128-dim hidden state; head_dim is 16.
    let hidden_dim = 128;
    let seq_len = 64;
    let head_dim = hidden_dim / 8;
    let input: Vec<f32> = (0..seq_len * hidden_dim)
        .map(|i| (i as f32) * 0.01)
        .collect();
    // SIMD and scalar RoPE must agree elementwise (theta base 10000).
    let scalar_result = scalar_rope(&input, seq_len, head_dim, 10000.0);
    let simd_result = simd_rope(&input, seq_len, head_dim, 10000.0);
    assert_eq!(
        scalar_result.len(),
        simd_result.len(),
        "IMP-041: Results should have same length"
    );
    for (i, (s, v)) in scalar_result.iter().zip(simd_result.iter()).enumerate() {
        assert!(
            (s - v).abs() < 1e-5,
            "IMP-041: Results should match at index {}: scalar={}, simd={}",
            i,
            s,
            v
        );
    }
    // Warmup both paths before timing.
    for _ in 0..5 {
        let _ = scalar_rope(&input, seq_len, head_dim, 10000.0);
        let _ = simd_rope(&input, seq_len, head_dim, 10000.0);
    }
    // 10 samples of 100 iterations each; medians damp scheduler noise.
    let mut scalar_times = Vec::with_capacity(10);
    for _ in 0..10 {
        let start = Instant::now();
        for _ in 0..100 {
            let _ = scalar_rope(&input, seq_len, head_dim, 10000.0);
        }
        scalar_times.push(start.elapsed().as_secs_f64());
    }
    scalar_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let mut simd_times = Vec::with_capacity(10);
    for _ in 0..10 {
        let start = Instant::now();
        for _ in 0..100 {
            let _ = simd_rope(&input, seq_len, head_dim, 10000.0);
        }
        simd_times.push(start.elapsed().as_secs_f64());
    }
    simd_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let scalar_median = scalar_times[scalar_times.len() / 2];
    let simd_median = simd_times[simd_times.len() / 2];
    let speedup = scalar_median / simd_median;
    // Only guard against a severe slowdown; actual speedup varies by CPU.
    assert!(
        speedup >= 0.5,
        "IMP-041: SIMD RoPE speedup ({:.2}x) should be >= 0.5x (severe slowdown indicates bug)",
        speedup
    );
}
#[test]
// IMP-042: the fused output+residual decode path must produce logits the same
// size as the regular optimized path; timings are collected but not asserted.
fn test_imp_042_fused_output_residual() {
use crate::gpu::{GpuModel, GpuModelConfig, StreamingKVCache};
use std::time::Instant;
let config = GpuModelConfig {
vocab_size: 256,
hidden_dim: 128,
num_heads: 8,
num_kv_heads: 8, num_layers: 4,
intermediate_dim: 256,
eps: 1e-5,
};
let mut model = GpuModel::with_attention_buffers(config.clone(), 256)
.expect("IMP-042: Should create model");
let head_dim = config.hidden_dim / config.num_heads;
let mut kv_cache =
StreamingKVCache::new(config.num_layers, 256, config.num_heads, head_dim);
// Prime the KV cache with a short prompt before incremental decoding.
let prompt = vec![1, 2, 3, 4, 5, 6, 7, 8];
let _ = model.forward_gpu_with_cache(&prompt, &mut kv_cache);
assert!(
model.has_fused_output_residual(),
"IMP-042: Model should have fused output residual capability"
);
// Warmup decodes so both paths later run on a well-populated cache.
for token in 10..15 {
let _ = model.forward_gpu_incremental_optimized(token, &mut kv_cache);
}
let regular_logits = model
.forward_gpu_incremental_optimized(50, &mut kv_cache)
.expect("IMP-042: Regular forward should work");
let fused_logits = model
.forward_with_fused_output_residual(51, &mut kv_cache)
.expect("IMP-042: Fused forward should work");
assert_eq!(
regular_logits.len(),
fused_logits.len(),
"IMP-042: Output sizes should match"
);
// Per-token wall time for each path; distinct token ranges keep the cache
// growing monotonically across both measurement loops.
let mut fused_times = Vec::with_capacity(10);
for token in 60..70 {
let start = Instant::now();
let _ = model.forward_with_fused_output_residual(token, &mut kv_cache);
fused_times.push(start.elapsed().as_secs_f64());
}
let mut regular_times = Vec::with_capacity(10);
for token in 70..80 {
let start = Instant::now();
let _ = model.forward_gpu_incremental_optimized(token, &mut kv_cache);
regular_times.push(start.elapsed().as_secs_f64());
}
fused_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
regular_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
// Median speedup is computed for inspection only; asserting wall-clock
// ratios would make the test flaky on shared CI hardware.
let fused_median = fused_times[fused_times.len() / 2];
let regular_median = regular_times[regular_times.len() / 2];
let speedup = regular_median / fused_median;
let _ = speedup;
}
#[test]
// IMP-043: batched embedding lookup must match per-token table slices exactly;
// a micro-benchmark against a naive copy loop is recorded but not asserted.
fn test_imp_043_batch_embedding() {
use crate::gpu::{batch_embed, GpuModelConfig};
use std::time::Instant;
let config = GpuModelConfig {
vocab_size: 1024,
hidden_dim: 256,
num_heads: 8,
num_kv_heads: 8, num_layers: 4,
intermediate_dim: 512,
eps: 1e-5,
};
// Deterministic embedding table: flat index i holds i * 0.001.
let embedding_table: Vec<f32> = (0..config.vocab_size * config.hidden_dim)
.map(|i| (i as f32) * 0.001)
.collect();
let tokens: Vec<usize> = vec![1, 5, 10, 20, 50, 100, 200, 500];
let batch_result = batch_embed(&embedding_table, &tokens, config.hidden_dim);
assert_eq!(
batch_result.len(),
tokens.len() * config.hidden_dim,
"IMP-043: Batch embed should return tokens * hidden_dim elements"
);
// Every output row must equal the corresponding embedding-table row.
for (i, &token) in tokens.iter().enumerate() {
let start_idx = token * config.hidden_dim;
let end_idx = start_idx + config.hidden_dim;
let expected = &embedding_table[start_idx..end_idx];
let batch_start = i * config.hidden_dim;
let batch_end = batch_start + config.hidden_dim;
let actual = &batch_result[batch_start..batch_end];
for (j, (&e, &a)) in expected.iter().zip(actual.iter()).enumerate() {
assert!(
(e - a).abs() < 1e-6,
"IMP-043: Mismatch at token {} dim {}: expected {}, got {}",
token,
j,
e,
a
);
}
}
// Warmup before timing.
for _ in 0..5 {
let _ = batch_embed(&embedding_table, &tokens, config.hidden_dim);
}
let mut batch_times = Vec::with_capacity(10);
for _ in 0..10 {
let start = Instant::now();
for _ in 0..100 {
let _ = batch_embed(&embedding_table, &tokens, config.hidden_dim);
}
batch_times.push(start.elapsed().as_secs_f64());
}
batch_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
// Baseline: copy each token's row individually.
// NOTE(review): `result` is never read, so the optimizer may elide part of
// this baseline loop; harmless here because the speedup is not asserted.
let mut individual_times = Vec::with_capacity(10);
for _ in 0..10 {
let start = Instant::now();
for _ in 0..100 {
let mut result = Vec::with_capacity(tokens.len() * config.hidden_dim);
for &token in &tokens {
let start_idx = token * config.hidden_dim;
let end_idx = start_idx + config.hidden_dim;
result.extend_from_slice(&embedding_table[start_idx..end_idx]);
}
}
individual_times.push(start.elapsed().as_secs_f64());
}
individual_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
// Medians compared for inspection only (wall-clock asserts are flaky on CI).
let batch_median = batch_times[batch_times.len() / 2];
let individual_median = individual_times[individual_times.len() / 2];
let speedup = individual_median / batch_median;
let _ = speedup;
}
#[test]
// IMP-044: the parallel FFN must agree numerically with the sequential
// reference (tolerance 1e-4 absorbs reduction-order differences); the timing
// comparison is recorded but not asserted.
fn test_imp_044_parallel_ffn() {
use crate::gpu::{parallel_ffn, sequential_ffn};
use std::time::Instant;
let hidden_dim = 256;
let intermediate_dim = 512;
// Deterministic weights in [-0.5, 0.49] via a modulo ramp.
let w_up: Vec<f32> = (0..hidden_dim * intermediate_dim)
.map(|i| ((i % 100) as f32) * 0.01 - 0.5)
.collect();
let w_down: Vec<f32> = (0..intermediate_dim * hidden_dim)
.map(|i| ((i % 100) as f32) * 0.01 - 0.5)
.collect();
let input: Vec<f32> = (0..hidden_dim).map(|i| (i as f32) * 0.01).collect();
let sequential_result =
sequential_ffn(&input, &w_up, &w_down, hidden_dim, intermediate_dim);
let parallel_result = parallel_ffn(&input, &w_up, &w_down, hidden_dim, intermediate_dim);
assert_eq!(
sequential_result.len(),
parallel_result.len(),
"IMP-044: Results should have same length"
);
for (i, (&s, &p)) in sequential_result
.iter()
.zip(parallel_result.iter())
.enumerate()
{
assert!(
(s - p).abs() < 1e-4,
"IMP-044: Mismatch at index {}: sequential={}, parallel={}",
i,
s,
p
);
}
let large_input: Vec<f32> = (0..hidden_dim).map(|i| (i as f32) * 0.01).collect();
// Warmup before timing.
for _ in 0..3 {
let _ = sequential_ffn(&large_input, &w_up, &w_down, hidden_dim, intermediate_dim);
let _ = parallel_ffn(&large_input, &w_up, &w_down, hidden_dim, intermediate_dim);
}
// Five samples of 50 iterations each per implementation.
let mut seq_times = Vec::with_capacity(5);
for _ in 0..5 {
let start = Instant::now();
for _ in 0..50 {
let _ = sequential_ffn(&large_input, &w_up, &w_down, hidden_dim, intermediate_dim);
}
seq_times.push(start.elapsed().as_secs_f64());
}
seq_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let mut par_times = Vec::with_capacity(5);
for _ in 0..5 {
let start = Instant::now();
for _ in 0..50 {
let _ = parallel_ffn(&large_input, &w_up, &w_down, hidden_dim, intermediate_dim);
}
par_times.push(start.elapsed().as_secs_f64());
}
par_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
// Median speedup computed for inspection only.
let seq_median = seq_times[seq_times.len() / 2];
let par_median = par_times[par_times.len() / 2];
let speedup = seq_median / par_median;
let _ = speedup; }
#[test]
// IMP-045: fused layernorm must match the standard implementation and produce
// near-zero-mean output; timings are collected but not asserted.
fn test_imp_045_optimized_layernorm() {
use crate::gpu::{fused_layernorm, standard_layernorm};
use std::time::Instant;
let hidden_dim = 256;
let eps = 1e-5;
// Input ramp centered near zero: values in [-12.8, 12.7].
let input: Vec<f32> = (0..hidden_dim).map(|i| (i as f32) * 0.1 - 12.8).collect();
// Identity affine parameters (gamma=1, beta=0) isolate the normalization.
let gamma: Vec<f32> = vec![1.0; hidden_dim];
let beta: Vec<f32> = vec![0.0; hidden_dim];
let standard_result = standard_layernorm(&input, &gamma, &beta, eps);
let fused_result = fused_layernorm(&input, &gamma, &beta, eps);
assert_eq!(
standard_result.len(),
fused_result.len(),
"IMP-045: Results should have same length"
);
for (i, (&s, &f)) in standard_result.iter().zip(fused_result.iter()).enumerate() {
assert!(
(s - f).abs() < 1e-5,
"IMP-045: Mismatch at index {}: standard={}, fused={}",
i,
s,
f
);
}
// With beta = 0, normalized output should average to roughly zero.
let mean: f32 = fused_result.iter().sum::<f32>() / fused_result.len() as f32;
assert!(
mean.abs() < 0.1,
"IMP-045: Normalized output mean ({}) should be near 0",
mean
);
// Warmup before timing.
for _ in 0..5 {
let _ = standard_layernorm(&input, &gamma, &beta, eps);
let _ = fused_layernorm(&input, &gamma, &beta, eps);
}
let mut std_times = Vec::with_capacity(10);
for _ in 0..10 {
let start = Instant::now();
for _ in 0..100 {
let _ = standard_layernorm(&input, &gamma, &beta, eps);
}
std_times.push(start.elapsed().as_secs_f64());
}
std_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let mut fused_times = Vec::with_capacity(10);
for _ in 0..10 {
let start = Instant::now();
for _ in 0..100 {
let _ = fused_layernorm(&input, &gamma, &beta, eps);
}
fused_times.push(start.elapsed().as_secs_f64());
}
fused_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
// Median speedup computed for inspection only.
let std_median = std_times[std_times.len() / 2];
let fused_median = fused_times[fused_times.len() / 2];
let speedup = std_median / fused_median;
let _ = speedup;
}
#[test]
fn test_imp_046_cache_aligned_storage() {
    use crate::gpu::CacheAlignedBuffer;
    // IMP-046: cache-aligned buffers must be 64-byte aligned, report the
    // requested length, and round-trip reads/writes through their slices.
    let size = 1024;
    let fresh = CacheAlignedBuffer::new(size);
    assert!(
        fresh.is_aligned(64),
        "IMP-046: Buffer should be 64-byte aligned"
    );
    assert_eq!(
        fresh.len(),
        size,
        "IMP-046: Buffer should have correct length"
    );
    // Write at both extremes of a second buffer and read the values back.
    let mut writable = CacheAlignedBuffer::new(size);
    writable.as_mut_slice()[0] = 42.0;
    writable.as_mut_slice()[size - 1] = 99.0;
    assert_eq!(
        writable.as_slice()[0],
        42.0,
        "IMP-046: Should read back written value"
    );
    assert_eq!(
        writable.as_slice()[size - 1],
        99.0,
        "IMP-046: Should read back written value at end"
    );
    // Alignment must hold for powers of two and odd sizes alike.
    for size in [64, 128, 256, 512, 1000, 2048] {
        let buf = CacheAlignedBuffer::new(size);
        assert!(
            buf.is_aligned(64),
            "IMP-046: Buffer of size {} should be 64-byte aligned",
            size
        );
    }
}
#[test]
// IMP-047: prefetch-hint helpers must not change the computed sum; the timing
// comparison is recorded but not asserted.
fn test_imp_047_prefetch_hints() {
use crate::gpu::{prefetch_read, sequential_sum, sum_with_prefetch};
use std::time::Instant;
// 64K f32 elements — larger than L1 so prefetching could plausibly matter.
let size = 64 * 1024; let data: Vec<f32> = (0..size).map(|i| (i as f32) * 0.001).collect();
// Smoke-test the hint itself at two offsets (must be side-effect free).
prefetch_read(&data, 0, 64);
prefetch_read(&data, 1000, 64);
let seq_result = sequential_sum(&data);
let prefetch_result = sum_with_prefetch(&data, 64);
assert!(
(seq_result - prefetch_result).abs() < 1e-3,
"IMP-047: Sequential ({}) and prefetch ({}) sums should match",
seq_result,
prefetch_result
);
// Warmup before timing.
for _ in 0..3 {
let _ = sequential_sum(&data);
let _ = sum_with_prefetch(&data, 64);
}
let mut seq_times = Vec::with_capacity(10);
for _ in 0..10 {
let start = Instant::now();
for _ in 0..20 {
let _ = sequential_sum(&data);
}
seq_times.push(start.elapsed().as_secs_f64());
}
seq_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let mut pf_times = Vec::with_capacity(10);
for _ in 0..10 {
let start = Instant::now();
for _ in 0..20 {
let _ = sum_with_prefetch(&data, 64);
}
pf_times.push(start.elapsed().as_secs_f64());
}
pf_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
// Median speedup computed for inspection only.
let seq_median = seq_times[seq_times.len() / 2];
let pf_median = pf_times[pf_times.len() / 2];
let speedup = seq_median / pf_median;
let _ = speedup;
}
#[test]
// IMP-048: cache-blocked matmul must agree with the naive triple loop
// (tolerance 1e-4 for float reassociation); timing recorded, not asserted.
#[allow(clippy::many_single_char_names)] fn test_imp_048_blocked_matmul() {
use crate::gpu::{blocked_matmul, naive_matmul};
use std::time::Instant;
// Correctness check on a small problem first.
let m = 128;
let k = 256;
let n = 128;
let a: Vec<f32> = (0..m * k)
.map(|i| ((i % 100) as f32) * 0.01 - 0.5)
.collect();
let b: Vec<f32> = (0..k * n)
.map(|i| ((i % 100) as f32) * 0.01 - 0.5)
.collect();
let naive_result = naive_matmul(&a, &b, m, k, n);
let blocked_result = blocked_matmul(&a, &b, m, k, n, 32);
assert_eq!(
naive_result.len(),
blocked_result.len(),
"IMP-048: Results should have same length"
);
for (i, (&naive, &blocked)) in naive_result.iter().zip(blocked_result.iter()).enumerate() {
assert!(
(naive - blocked).abs() < 1e-4,
"IMP-048: Mismatch at index {}: naive={}, blocked={}",
i,
naive,
blocked
);
}
// Larger problem for the timing comparison (dimensions shadowed on purpose).
let m = 256;
let k = 512;
let n = 256;
let a: Vec<f32> = (0..m * k)
.map(|i| ((i % 100) as f32) * 0.01 - 0.5)
.collect();
let b: Vec<f32> = (0..k * n)
.map(|i| ((i % 100) as f32) * 0.01 - 0.5)
.collect();
// Warmup before timing.
for _ in 0..2 {
let _ = naive_matmul(&a, &b, m, k, n);
let _ = blocked_matmul(&a, &b, m, k, n, 32);
}
let mut naive_times = Vec::with_capacity(5);
for _ in 0..5 {
let start = Instant::now();
for _ in 0..3 {
let _ = naive_matmul(&a, &b, m, k, n);
}
naive_times.push(start.elapsed().as_secs_f64());
}
naive_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let mut blocked_times = Vec::with_capacity(5);
for _ in 0..5 {
let start = Instant::now();
for _ in 0..3 {
let _ = blocked_matmul(&a, &b, m, k, n, 32);
}
blocked_times.push(start.elapsed().as_secs_f64());
}
blocked_times.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
// Median speedup computed for inspection only.
let naive_median = naive_times[naive_times.len() / 2];
let blocked_median = blocked_times[blocked_times.len() / 2];
let speedup = naive_median / blocked_median;
let _ = speedup;
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_049_tensor_pool() {
    use crate::gpu::TensorPool;
    // IMP-049: the tensor pool hands out correctly sized buffers, makes
    // released buffers available for reuse, and empties on clear().
    let mut pool = TensorPool::new(4);
    assert_eq!(pool.capacity(), 4, "IMP-049: Pool should have capacity 4");
    assert_eq!(pool.available(), 0, "IMP-049: Pool should start empty");
    let first = pool.acquire(1024);
    assert_eq!(
        first.len(),
        1024,
        "IMP-049: Buffer should have requested size"
    );
    let second = pool.acquire(2048);
    assert_eq!(
        second.len(),
        2048,
        "IMP-049: Second buffer should have size 2048"
    );
    // Returning a buffer makes it available for a later acquire.
    pool.release(first);
    assert!(
        pool.available() >= 1,
        "IMP-049: Pool should have available buffer"
    );
    let reused = pool.acquire(1024);
    assert_eq!(
        reused.len(),
        1024,
        "IMP-049: Reused buffer should have correct size"
    );
    pool.release(second);
    pool.release(reused);
    assert!(
        pool.available() >= 2,
        "IMP-049: Pool should have 2 available buffers"
    );
    // clear() drops everything the pool was holding.
    pool.clear();
    assert_eq!(
        pool.available(),
        0,
        "IMP-049: Pool should be empty after clear"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_050_arena_allocator() {
    use crate::gpu::ForwardArena;
    // IMP-050: bump-arena allocations are tracked cumulatively, reset()
    // reclaims everything, and post-reset allocations come back zeroed.
    let mut arena = ForwardArena::new(1024 * 1024);
    assert!(
        arena.capacity() >= 1024 * 1024,
        "IMP-050: Arena should have at least 1MB capacity"
    );
    assert_eq!(arena.used(), 0, "IMP-050: Arena should start empty");
    {
        let first = arena.alloc(256);
        assert_eq!(
            first.len(),
            256,
            "IMP-050: First allocation should have size 256"
        );
    }
    assert_eq!(arena.used(), 256, "IMP-050: Arena should track usage");
    {
        let second = arena.alloc(512);
        assert_eq!(
            second.len(),
            512,
            "IMP-050: Second allocation should have size 512"
        );
    }
    // 256 + 512 bytes so far; >= allows for internal alignment padding.
    assert!(
        arena.used() >= 768,
        "IMP-050: Arena should track cumulative usage"
    );
    arena.reset();
    assert_eq!(
        arena.used(),
        0,
        "IMP-050: Arena should be empty after reset"
    );
    let fresh = arena.alloc(1024);
    assert_eq!(
        fresh.len(),
        1024,
        "IMP-050: Post-reset allocation should work"
    );
    assert!(
        fresh.iter().all(|&x| x == 0.0),
        "IMP-050: Fresh allocation should be zeroed"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_051_scratch_buffers() {
    use crate::gpu::ScratchBuffer;
    // IMP-051: per-layer scratch space is independently addressable, retains
    // per-layer writes, and zeroes on reset().
    let num_layers = 4;
    let layer_size = 2048;
    let mut scratch = ScratchBuffer::new(num_layers, layer_size);
    assert_eq!(
        scratch.num_layers(),
        num_layers,
        "IMP-051: Should have 4 layers"
    );
    assert_eq!(
        scratch.layer_size(),
        layer_size,
        "IMP-051: Layer size should be 2048"
    );
    // First and last layers both expose full-size views.
    assert_eq!(
        scratch.get_layer(0).len(),
        layer_size,
        "IMP-051: Layer 0 scratch should have correct size"
    );
    assert_eq!(
        scratch.get_layer(3).len(),
        layer_size,
        "IMP-051: Layer 3 scratch should have correct size"
    );
    // Writes to one layer must not bleed into another.
    for slot in scratch.get_layer_mut(0) {
        *slot = 1.0;
    }
    for slot in scratch.get_layer_mut(1) {
        *slot = 2.0;
    }
    assert!(
        scratch.get_layer(0).iter().all(|&x| x == 1.0),
        "IMP-051: Layer 0 should retain its values"
    );
    assert!(
        scratch.get_layer(1).iter().all(|&x| x == 2.0),
        "IMP-051: Layer 1 should be independent"
    );
    scratch.reset();
    assert!(
        scratch.get_layer(0).iter().all(|&x| x == 0.0),
        "IMP-051: Layer 0 should be zeroed after reset"
    );
    assert_eq!(
        scratch.total_size(),
        num_layers * layer_size,
        "IMP-051: Total size should be layers * layer_size"
    );
}
#[test]
#[cfg(feature = "gpu")]
#[allow(clippy::similar_names)]
fn test_imp_052_quantized_dot() {
    use crate::gpu::{quantized_dot_q4, quantized_dot_q8};
    // IMP-052: quantized dot products over single Q4/Q8 blocks.
    // Block layout: 2 bytes little-endian f16 scale, then the packed quants.
    let scale_a: f32 = 0.5;
    let scale_b: f32 = 0.25;
    let mut block_a = vec![0u8; 18];
    let mut block_b = vec![0u8; 18];
    let scale_a_f16 = half::f16::from_f32(scale_a);
    let scale_b_f16 = half::f16::from_f32(scale_b);
    block_a[0..2].copy_from_slice(&scale_a_f16.to_le_bytes());
    block_b[0..2].copy_from_slice(&scale_b_f16.to_le_bytes());
    // 0x99 packs the nibble 9 into both halves of each byte; with the
    // expected ~4.0 below this implies each lane decodes to 1 (9 - bias 8).
    block_a[2..18].fill(0x99);
    block_b[2..18].fill(0x99);
    // 32 lanes of (1 * 0.5) * (1 * 0.25) = 32 * 0.125 = 4.0.
    let result_q4 = quantized_dot_q4(&block_a, &block_b);
    assert!(
        (result_q4 - 4.0).abs() < 0.5,
        "IMP-052: Q4 dot product result ({}) should be ~4.0",
        result_q4
    );
    // Q8 blocks: 2-byte scale plus 32 signed byte quants, all set to 1.
    let mut block_a_q8 = vec![0u8; 34];
    let mut block_b_q8 = vec![0u8; 34];
    block_a_q8[0..2].copy_from_slice(&scale_a_f16.to_le_bytes());
    block_b_q8[0..2].copy_from_slice(&scale_b_f16.to_le_bytes());
    block_a_q8[2..34].fill(1);
    block_b_q8[2..34].fill(1);
    let result_q8 = quantized_dot_q8(&block_a_q8, &block_b_q8);
    assert!(
        (result_q8 - 4.0).abs() < 0.5,
        "IMP-052: Q8 dot product result ({}) should be ~4.0",
        result_q8
    );
    // An all-zero block (zero scale) must contribute exactly nothing.
    let zero_block_q4 = vec![0u8; 18];
    let zero_result = quantized_dot_q4(&zero_block_q4, &zero_block_q4);
    assert!(
        zero_result.abs() < 1e-6,
        "IMP-052: Zero blocks should give zero dot product"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_053_quantized_matvec() {
    use crate::gpu::{quantized_matvec_q4, quantized_matvec_q8};
    // IMP-053: quantized matrix-vector products over Q4/Q8 weight rows.
    // Each row: 2 bytes little-endian f16 scale followed by packed quants.
    let rows = 2;
    let cols = 32;
    let scale: f32 = 0.1;
    let scale_f16 = half::f16::from_f32(scale);
    let mut weights_q4 = vec![0u8; rows * 18];
    for row in 0..rows {
        let offset = row * 18;
        weights_q4[offset..offset + 2].copy_from_slice(&scale_f16.to_le_bytes());
        // Nibble 9 in every slot; with the ~3.2 expectation below this
        // implies each lane decodes to 1 after the Q4 bias.
        weights_q4[offset + 2..offset + 18].fill(0x99);
    }
    let input: Vec<f32> = vec![1.0; cols];
    let result_q4 = quantized_matvec_q4(&weights_q4, &input, rows, cols);
    assert_eq!(
        result_q4.len(),
        rows,
        "IMP-053: Q4 matvec should produce {} outputs",
        rows
    );
    // Per row: 32 lanes * 1 * 0.1 = 3.2.
    for (i, &val) in result_q4.iter().enumerate() {
        assert!(
            (val - 3.2).abs() < 0.5,
            "IMP-053: Q4 matvec row {} ({}) should be ~3.2",
            i,
            val
        );
    }
    // Q8 rows: 2-byte scale plus 32 signed byte quants, all set to 1.
    let mut weights_q8 = vec![0u8; rows * 34];
    for row in 0..rows {
        let offset = row * 34;
        weights_q8[offset..offset + 2].copy_from_slice(&scale_f16.to_le_bytes());
        weights_q8[offset + 2..offset + 34].fill(1);
    }
    let result_q8 = quantized_matvec_q8(&weights_q8, &input, rows, cols);
    assert_eq!(
        result_q8.len(),
        rows,
        "IMP-053: Q8 matvec should produce {} outputs",
        rows
    );
    for (i, &val) in result_q8.iter().enumerate() {
        assert!(
            (val - 3.2).abs() < 0.5,
            "IMP-053: Q8 matvec row {} ({}) should be ~3.2",
            i,
            val
        );
    }
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_054_mixed_precision() {
    use crate::gpu::QuantizedAccumulator;
    // IMP-054: the mixed-precision accumulator sums scaled contributions and
    // pre-summed blocks, and resets back to zero.
    let mut acc = QuantizedAccumulator::new();
    assert_eq!(
        acc.sum(),
        0.0,
        "IMP-054: New accumulator should have zero sum"
    );
    // (1 + 2 + 3) * 0.5 = 3.0.
    for value in [1.0, 2.0, 3.0] {
        acc.add_scaled(value, 0.5);
    }
    assert!(
        (acc.sum() - 3.0).abs() < 1e-6,
        "IMP-054: Accumulator sum ({}) should be 3.0",
        acc.sum()
    );
    acc.reset();
    assert_eq!(
        acc.sum(),
        0.0,
        "IMP-054: Reset accumulator should have zero sum"
    );
    // One block contributes block_sum * block_scale = 10 * 0.1 = 1.0.
    let block_sum: f32 = 10.0;
    let block_scale: f32 = 0.1;
    acc.add_block(block_sum, block_scale);
    assert!(
        (acc.sum() - 1.0).abs() < 1e-6,
        "IMP-054: Block contribution ({}) should be 1.0",
        acc.sum()
    );
    acc.reset();
    // Ten blocks of 5 * 0.2 accumulate to 10.0.
    for _ in 0..10 {
        acc.add_block(5.0, 0.2);
    }
    assert!(
        (acc.sum() - 10.0).abs() < 1e-5,
        "IMP-054: 10 blocks should sum to 10.0, got {}",
        acc.sum()
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_055_double_buffer() {
    use crate::gpu::DoubleBuffer;
    // IMP-055: a double buffer exposes its full capacity on the front side,
    // and swap() promotes back-buffer writes to the front.
    let initial: DoubleBuffer<f32> = DoubleBuffer::new(1024);
    assert_eq!(
        initial.capacity(),
        1024,
        "IMP-055: Double buffer should have requested capacity"
    );
    assert_eq!(
        initial.front().len(),
        1024,
        "IMP-055: Front buffer should have full capacity"
    );
    let mut buffer = DoubleBuffer::new(256);
    // Fill the back buffer with its own indices, then swap it forward.
    {
        for (i, val) in buffer.back_mut().iter_mut().enumerate() {
            *val = i as f32;
        }
    }
    buffer.swap();
    let front = buffer.front();
    assert!(
        (front[0] - 0.0).abs() < 1e-6,
        "IMP-055: After swap, front[0] should be 0.0"
    );
    assert!(
        (front[255] - 255.0).abs() < 1e-6,
        "IMP-055: After swap, front[255] should be 255.0"
    );
    // Second round trip: constant-fill the new back buffer and swap again.
    {
        for val in buffer.back_mut().iter_mut() {
            *val = 42.0;
        }
    }
    buffer.swap();
    assert!(
        (buffer.front()[0] - 42.0).abs() < 1e-6,
        "IMP-055: After second swap, front should have 42.0 values"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_056_chunked_processing() {
    use crate::gpu::ChunkedProcessor;
    // IMP-056: chunk counting, chunk bounds clamping, and chunked reduction.
    let processor = ChunkedProcessor::new(64);
    assert_eq!(
        processor.chunk_size(),
        64,
        "IMP-056: Processor should have requested chunk size"
    );
    // Ceiling-division chunk counts, including the empty case.
    assert_eq!(
        processor.num_chunks(100),
        2,
        "IMP-056: 100 items with chunk_size=64 needs 2 chunks"
    );
    assert_eq!(
        processor.num_chunks(64),
        1,
        "IMP-056: 64 items with chunk_size=64 needs 1 chunk"
    );
    assert_eq!(
        processor.num_chunks(0),
        0,
        "IMP-056: 0 items needs 0 chunks"
    );
    // The final chunk's end is clamped to the total length.
    let (first_start, first_end) = processor.chunk_bounds(0, 100);
    assert_eq!(first_start, 0, "IMP-056: First chunk starts at 0");
    assert_eq!(first_end, 64, "IMP-056: First chunk ends at chunk_size");
    let (second_start, second_end) = processor.chunk_bounds(1, 100);
    assert_eq!(second_start, 64, "IMP-056: Second chunk starts at 64");
    assert_eq!(second_end, 100, "IMP-056: Second chunk ends at total length");
    // Sum of 0..=127 is 127*128/2 = 8128, chunked or not.
    let data: Vec<f32> = (0..128).map(|x| x as f32).collect();
    let sum = processor.process_chunks(&data, |chunk| chunk.iter().sum::<f32>());
    assert!(
        (sum - 8128.0).abs() < 1e-3,
        "IMP-056: Chunked sum ({}) should equal 8128.0",
        sum
    );
    // Inputs smaller than one chunk still reduce correctly.
    let small_data: Vec<f32> = vec![1.0, 2.0, 3.0];
    let small_sum = processor.process_chunks(&small_data, |chunk| chunk.iter().sum::<f32>());
    assert!(
        (small_sum - 6.0).abs() < 1e-6,
        "IMP-056: Small chunked sum ({}) should equal 6.0",
        small_sum
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_057_pipeline_stages() {
    use crate::gpu::{GpuPipelineStage, InferencePipeline};
    // IMP-057: pipeline stages are ordered Embed < Attention < FFN < Output,
    // and recorded per-stage latencies sum into the total.
    let stage_embed = GpuPipelineStage::Embed;
    let stage_attn = GpuPipelineStage::Attention;
    let stage_ffn = GpuPipelineStage::FFN;
    let stage_out = GpuPipelineStage::Output;
    // Discriminant order encodes execution order.
    assert!(
        (stage_embed as u8) < (stage_attn as u8),
        "IMP-057: Embed should come before Attention"
    );
    assert!(
        (stage_attn as u8) < (stage_ffn as u8),
        "IMP-057: Attention should come before FFN"
    );
    assert!(
        (stage_ffn as u8) < (stage_out as u8),
        "IMP-057: FFN should come before Output"
    );
    let mut pipeline = InferencePipeline::new(4);
    assert_eq!(
        pipeline.num_stages(),
        4,
        "IMP-057: Pipeline should have 4 stages"
    );
    // One latency sample per stage; the total should be their sum (9.5 ms).
    pipeline.record_stage_time(GpuPipelineStage::Embed, 1.0);
    pipeline.record_stage_time(GpuPipelineStage::Attention, 5.0);
    pipeline.record_stage_time(GpuPipelineStage::FFN, 3.0);
    pipeline.record_stage_time(GpuPipelineStage::Output, 0.5);
    let total = pipeline.total_latency();
    assert!(
        (total - 9.5).abs() < 1e-6,
        "IMP-057: Total latency ({}) should be 9.5ms",
        total
    );
    // The per-stage breakdown must round-trip what was recorded.
    let breakdown = pipeline.stage_breakdown();
    assert!(
        (breakdown[&GpuPipelineStage::Attention] - 5.0).abs() < 1e-6,
        "IMP-057: Attention stage should be 5.0ms"
    );
    pipeline.reset();
    assert!(
        pipeline.total_latency() < 1e-6,
        "IMP-057: Reset pipeline should have zero latency"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_058_token_batch() {
    use crate::gpu::TokenBatch;
    // IMP-058: the batch accumulates tokens, emits the full batch on the push
    // that fills it, and flush() drains any partial remainder.
    let mut batch = TokenBatch::new(4);
    assert_eq!(batch.capacity(), 4, "IMP-058: Batch should have capacity 4");
    assert_eq!(batch.len(), 0, "IMP-058: New batch should be empty");
    assert!(!batch.is_full(), "IMP-058: New batch should not be full");
    // The first three pushes buffer without emitting.
    assert!(
        batch.push(100).is_none(),
        "IMP-058: First push should not return batch"
    );
    assert_eq!(batch.len(), 1, "IMP-058: Batch should have 1 token");
    assert!(
        batch.push(101).is_none(),
        "IMP-058: Second push should not return batch"
    );
    assert!(
        batch.push(102).is_none(),
        "IMP-058: Third push should not return batch"
    );
    assert_eq!(batch.len(), 3, "IMP-058: Batch should have 3 tokens");
    // The fourth push hits capacity and hands back all buffered tokens.
    let full_batch = batch.push(103);
    assert!(
        full_batch.is_some(),
        "IMP-058: Fourth push should return full batch"
    );
    assert_eq!(
        full_batch.unwrap(),
        vec![100, 101, 102, 103],
        "IMP-058: Batch should contain all tokens"
    );
    assert_eq!(
        batch.len(),
        0,
        "IMP-058: After returning, batch should be empty"
    );
    // flush() returns whatever partial batch is pending.
    let _ = batch.push(200);
    let _ = batch.push(201);
    assert_eq!(
        batch.flush(),
        vec![200, 201],
        "IMP-058: Flush should return partial batch"
    );
    assert_eq!(
        batch.len(),
        0,
        "IMP-058: After flush, batch should be empty"
    );
    assert!(
        batch.flush().is_empty(),
        "IMP-058: Flush empty batch should return empty vec"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_059_speculative_buffer() {
    use crate::gpu::SpeculativeBuffer;
    // IMP-059: speculative decode candidates are verified against actual
    // tokens, with accept/reject bookkeeping on the buffer.
    let mut buffer = SpeculativeBuffer::new(8);
    assert_eq!(
        buffer.capacity(),
        8,
        "IMP-059: Buffer should have capacity 8"
    );
    assert_eq!(buffer.len(), 0, "IMP-059: New buffer should be empty");
    // Three candidates that all match the actual continuation.
    for (token, prob) in [(100, 0.95), (101, 0.85), (102, 0.75)] {
        buffer.add_candidate(token, prob);
    }
    assert_eq!(buffer.len(), 3, "IMP-059: Buffer should have 3 candidates");
    let matching = vec![100, 101, 102];
    let (accepted, rejected_at) = buffer.verify(&matching);
    assert_eq!(accepted, 3, "IMP-059: All 3 candidates should be accepted");
    assert!(
        rejected_at.is_none(),
        "IMP-059: No rejection point when all match"
    );
    // A mismatch at index 2 accepts only the first two candidates.
    buffer.reject();
    for (token, prob) in [(200, 0.90), (201, 0.80), (202, 0.70)] {
        buffer.add_candidate(token, prob);
    }
    let mismatching = vec![200, 201, 999];
    let (accepted2, rejected_at2) = buffer.verify(&mismatching);
    assert_eq!(accepted2, 2, "IMP-059: Only first 2 should be accepted");
    assert_eq!(rejected_at2, Some(2), "IMP-059: Rejection at index 2");
    // accept(n) consumes n candidates; reject() discards the remainder.
    buffer.reject();
    buffer.add_candidate(300, 0.95);
    buffer.add_candidate(301, 0.85);
    buffer.accept(1);
    assert_eq!(
        buffer.len(),
        1,
        "IMP-059: After accept(1), 1 candidate remains"
    );
    buffer.reject();
    assert_eq!(
        buffer.len(),
        0,
        "IMP-059: After reject, buffer should be empty"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_060_batch_scheduler() {
    use crate::gpu::InferenceBatchScheduler;
    // IMP-060: submitted batches stay pending until completed, poll() hands
    // back completed batches, and drain() empties the completed set.
    let mut scheduler = InferenceBatchScheduler::new();
    assert_eq!(
        scheduler.pending_count(),
        0,
        "IMP-060: New scheduler has no pending"
    );
    assert_eq!(
        scheduler.completed_count(),
        0,
        "IMP-060: New scheduler has no completed"
    );
    let first_id = scheduler.submit(vec![100, 101, 102]);
    let second_id = scheduler.submit(vec![200, 201]);
    assert_eq!(
        scheduler.pending_count(),
        2,
        "IMP-060: Should have 2 pending batches"
    );
    assert!(first_id != second_id, "IMP-060: Batch IDs should be unique");
    assert!(
        scheduler.poll().is_none(),
        "IMP-060: No batches completed yet"
    );
    // Completing moves a batch from pending to completed.
    scheduler.complete(first_id, vec![1000, 1001, 1002]);
    assert_eq!(
        scheduler.completed_count(),
        1,
        "IMP-060: Should have 1 completed"
    );
    assert_eq!(
        scheduler.pending_count(),
        1,
        "IMP-060: Should have 1 pending"
    );
    let polled = scheduler.poll();
    assert!(polled.is_some(), "IMP-060: Should get completed batch");
    let (id, results) = polled.unwrap();
    assert_eq!(id, first_id, "IMP-060: Should get batch_id_1");
    assert_eq!(
        results,
        vec![1000, 1001, 1002],
        "IMP-060: Should get correct results"
    );
    // drain() returns everything completed since the last poll.
    scheduler.complete(second_id, vec![2000, 2001]);
    let drained = scheduler.drain();
    assert_eq!(drained.len(), 1, "IMP-060: Drain should return 1 batch");
    assert_eq!(
        scheduler.completed_count(),
        0,
        "IMP-060: After drain, no completed"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_061_async_request_queue() {
    use crate::gpu::AsyncRequestQueue;
    // IMP-061: bounded FIFO queue — rejects pushes when full, pops in
    // insertion order, and reports empty/full state correctly.
    let mut queue: AsyncRequestQueue<String> = AsyncRequestQueue::new(3);
    assert_eq!(queue.capacity(), 3, "IMP-061: Queue capacity should be 3");
    assert!(queue.is_empty(), "IMP-061: New queue should be empty");
    assert!(!queue.is_full(), "IMP-061: New queue should not be full");
    assert_eq!(queue.len(), 0, "IMP-061: New queue length should be 0");
    // Fill to capacity.
    assert!(
        queue.try_push("request1".to_string()),
        "IMP-061: Should push first item"
    );
    assert!(
        queue.try_push("request2".to_string()),
        "IMP-061: Should push second item"
    );
    assert_eq!(queue.len(), 2, "IMP-061: Queue should have 2 items");
    assert!(!queue.is_empty(), "IMP-061: Queue should not be empty");
    assert!(
        queue.try_push("request3".to_string()),
        "IMP-061: Should push third item"
    );
    assert!(queue.is_full(), "IMP-061: Queue should be full");
    // A push beyond capacity is rejected rather than blocking.
    assert!(
        !queue.try_push("request4".to_string()),
        "IMP-061: Should reject when full"
    );
    // Pops come back in FIFO order.
    let first = queue.try_pop();
    assert!(first.is_some(), "IMP-061: Should pop item");
    assert_eq!(
        first.unwrap(),
        "request1",
        "IMP-061: Should pop in FIFO order"
    );
    assert!(
        !queue.is_full(),
        "IMP-061: Queue should not be full after pop"
    );
    assert_eq!(
        queue.try_pop(),
        Some("request2".to_string()),
        "IMP-061: Pop second"
    );
    assert_eq!(
        queue.try_pop(),
        Some("request3".to_string()),
        "IMP-061: Pop third"
    );
    assert!(queue.is_empty(), "IMP-061: Queue should be empty");
    assert!(
        queue.try_pop().is_none(),
        "IMP-061: Pop from empty returns None"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_062_event_notifier() {
    use crate::gpu::InferenceEventNotifier;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    // IMP-062: every registered handler fires on notify(), and clear()
    // detaches them all. Atomic counters observe the calls.
    let mut notifier = InferenceEventNotifier::new();
    assert_eq!(
        notifier.handler_count(),
        0,
        "IMP-062: New notifier has no handlers"
    );
    let counter1 = Arc::new(AtomicUsize::new(0));
    let counter2 = Arc::new(AtomicUsize::new(0));
    // First handler bumps its counter by 1 per event.
    let c1 = Arc::clone(&counter1);
    notifier.register(Box::new(move |_request_id, _tokens| {
        c1.fetch_add(1, Ordering::SeqCst);
    }));
    assert_eq!(
        notifier.handler_count(),
        1,
        "IMP-062: Should have 1 handler"
    );
    // Second handler bumps by 10 so the two are distinguishable.
    let c2 = Arc::clone(&counter2);
    notifier.register(Box::new(move |_request_id, _tokens| {
        c2.fetch_add(10, Ordering::SeqCst);
    }));
    assert_eq!(
        notifier.handler_count(),
        2,
        "IMP-062: Should have 2 handlers"
    );
    notifier.notify(1, &[100, 101, 102]);
    assert_eq!(
        counter1.load(Ordering::SeqCst),
        1,
        "IMP-062: Handler 1 should be called"
    );
    assert_eq!(
        counter2.load(Ordering::SeqCst),
        10,
        "IMP-062: Handler 2 should be called"
    );
    notifier.notify(2, &[200]);
    assert_eq!(
        counter1.load(Ordering::SeqCst),
        2,
        "IMP-062: Handler 1 called twice"
    );
    assert_eq!(
        counter2.load(Ordering::SeqCst),
        20,
        "IMP-062: Handler 2 called twice"
    );
    // After clear(), notify() is a no-op for the old handlers.
    notifier.clear();
    assert_eq!(
        notifier.handler_count(),
        0,
        "IMP-062: After clear, no handlers"
    );
    notifier.notify(3, &[300]);
    assert_eq!(
        counter1.load(Ordering::SeqCst),
        2,
        "IMP-062: Counter unchanged after clear"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_063_timeout_manager() {
    use crate::gpu::TimeoutManager;
    use std::time::{Duration, Instant};
    // IMP-063: expired deadlines are reported (and dropped) by
    // check_expired(); remove() cancels a registration outright.
    let mut manager = TimeoutManager::new();
    assert_eq!(
        manager.active_count(),
        0,
        "IMP-063: New manager has no active timeouts"
    );
    // One deadline that will expire during the sleep, one that will not.
    let now = Instant::now();
    manager.register(1, now + Duration::from_millis(10));
    manager.register(2, now + Duration::from_millis(1000));
    assert_eq!(
        manager.active_count(),
        2,
        "IMP-063: Should have 2 active timeouts"
    );
    std::thread::sleep(Duration::from_millis(20));
    let expired = manager.check_expired();
    assert_eq!(expired.len(), 1, "IMP-063: Should have 1 expired timeout");
    assert_eq!(expired[0], 1, "IMP-063: Request 1 should be expired");
    assert_eq!(
        manager.active_count(),
        1,
        "IMP-063: Should have 1 active after check"
    );
    manager.remove(2);
    assert_eq!(manager.active_count(), 0, "IMP-063: No active after remove");
    assert!(
        manager.check_expired().is_empty(),
        "IMP-063: No expired when empty"
    );
}
// IMP-064: PriorityRequestQueue must dequeue strictly by priority (highest
// numeric priority first) and fall back to FIFO order among equal
// priorities.
#[test]
#[cfg(feature = "gpu")]
fn test_imp_064_priority_queue() {
    use crate::gpu::{PriorityRequest, PriorityRequestQueue};

    let mut queue = PriorityRequestQueue::new();
    assert!(queue.is_empty(), "IMP-064: New queue should be empty");
    assert_eq!(queue.len(), 0, "IMP-064: New queue length should be 0");

    // Enqueue deliberately out of priority order.
    queue.enqueue(PriorityRequest::new(1, "low_priority".to_string()));
    queue.enqueue(PriorityRequest::new(3, "high_priority".to_string()));
    queue.enqueue(PriorityRequest::new(2, "medium_priority".to_string()));
    assert_eq!(queue.len(), 3, "IMP-064: Should have 3 requests");

    // Drain and verify strict highest-first ordering.
    let head = queue.dequeue_highest();
    assert!(head.is_some(), "IMP-064: Should dequeue request");
    assert_eq!(
        head.unwrap().data(),
        "high_priority",
        "IMP-064: Highest priority first"
    );
    for (expected, msg) in [
        ("medium_priority", "IMP-064: Medium priority second"),
        ("low_priority", "IMP-064: Low priority last"),
    ] {
        assert_eq!(queue.dequeue_highest().unwrap().data(), expected, "{}", msg);
    }
    assert!(queue.is_empty(), "IMP-064: Queue should be empty");
    assert!(
        queue.dequeue_highest().is_none(),
        "IMP-064: Dequeue empty returns None"
    );

    // Equal priorities must preserve insertion (FIFO) order.
    for label in ["first", "second", "third"] {
        queue.enqueue(PriorityRequest::new(5, label.to_string()));
    }
    assert_eq!(
        queue.dequeue_highest().unwrap().data(),
        "first",
        "IMP-064: FIFO for same priority"
    );
    assert_eq!(
        queue.dequeue_highest().unwrap().data(),
        "second",
        "IMP-064: FIFO order"
    );
    assert_eq!(
        queue.dequeue_highest().unwrap().data(),
        "third",
        "IMP-064: FIFO order"
    );
}
// IMP-065: TokenRateLimiter is a token-bucket limiter; try_acquire is
// all-or-nothing and a failed acquire must not consume tokens.
#[test]
#[cfg(feature = "gpu")]
fn test_imp_065_rate_limiter() {
use crate::gpu::TokenRateLimiter;
use std::time::Duration;
// rate = 10 tokens/sec refill, burst capacity = 5 tokens.
let mut limiter = TokenRateLimiter::new(10.0, 5);
assert_eq!(
limiter.tokens_available(),
5,
"IMP-065: Should start with burst capacity"
);
assert!(limiter.try_acquire(3), "IMP-065: Should acquire 3 tokens");
assert_eq!(
limiter.tokens_available(),
2,
"IMP-065: Should have 2 remaining"
);
// Requesting more than available must fail without side effects.
assert!(
!limiter.try_acquire(3),
"IMP-065: Should fail to acquire 3 when only 2 available"
);
assert_eq!(
limiter.tokens_available(),
2,
"IMP-065: Tokens unchanged on failed acquire"
);
assert!(
limiter.try_acquire(2),
"IMP-065: Should acquire remaining 2"
);
assert_eq!(
limiter.tokens_available(),
0,
"IMP-065: Should have 0 remaining"
);
// 200ms at 10 tokens/sec should refill ~2 tokens; bounds are deliberately
// loose to tolerate scheduler jitter in the sleep.
std::thread::sleep(Duration::from_millis(200)); limiter.refill();
let available = limiter.tokens_available();
assert!(
available >= 1,
"IMP-065: Should have refilled at least 1 token, got {}",
available
);
assert!(available <= 5, "IMP-065: Should not exceed burst capacity");
}
// IMP-066: ResourceTracker accounts memory (bytes) and compute (percentage
// points) against fixed capacities across allocate/release, and reports
// usage as percentages.
#[test]
#[cfg(feature = "gpu")]
fn test_imp_066_resource_tracker() {
use crate::gpu::ResourceTracker;
// Capacity: 1 GiB of memory, 100 compute units.
let mut tracker = ResourceTracker::new(1024 * 1024 * 1024, 100);
assert_eq!(
tracker.memory_usage(),
0,
"IMP-066: Initial memory usage is 0"
);
assert_eq!(
tracker.compute_usage(),
0,
"IMP-066: Initial compute usage is 0"
);
// can_allocate is a non-mutating feasibility check.
assert!(
tracker.can_allocate(512 * 1024 * 1024, 50),
"IMP-066: Should be able to allocate 512MB, 50% compute"
);
assert!(
!tracker.can_allocate(2 * 1024 * 1024 * 1024, 50),
"IMP-066: Cannot allocate more than capacity"
);
// allocate() hands back an id used later to release that allocation.
let alloc_id = tracker.allocate(256 * 1024 * 1024, 30);
assert!(alloc_id.is_some(), "IMP-066: Allocation should succeed");
assert_eq!(
tracker.memory_usage(),
256 * 1024 * 1024,
"IMP-066: Memory usage updated"
);
assert_eq!(
tracker.compute_usage(),
30,
"IMP-066: Compute usage updated"
);
let alloc_id_2 = tracker.allocate(128 * 1024 * 1024, 20);
assert!(
alloc_id_2.is_some(),
"IMP-066: Second allocation should succeed"
);
assert_eq!(
tracker.memory_usage(),
384 * 1024 * 1024,
"IMP-066: Memory accumulated"
);
assert_eq!(tracker.compute_usage(), 50, "IMP-066: Compute accumulated");
// Releasing the first allocation leaves only the second (128 MiB / 20).
tracker.release(alloc_id.unwrap());
assert_eq!(
tracker.memory_usage(),
128 * 1024 * 1024,
"IMP-066: Memory released"
);
assert_eq!(tracker.compute_usage(), 20, "IMP-066: Compute released");
// 128 MiB of 1 GiB = 12.5%; compared with a small float tolerance.
let (mem_pct, compute_pct) = tracker.usage_percentage();
let expected_mem_pct = (128.0 * 1024.0 * 1024.0) / (1024.0 * 1024.0 * 1024.0) * 100.0;
assert!(
(mem_pct - expected_mem_pct).abs() < 0.1,
"IMP-066: Memory percentage correct"
);
assert!(
(compute_pct - 20.0).abs() < 0.1,
"IMP-066: Compute percentage correct"
);
}
// IMP-067: InferenceMetrics aggregates per-request latency and token counts
// and exposes percentile/throughput queries plus a reset.
#[test]
#[cfg(feature = "gpu")]
fn test_imp_067_inference_metrics() {
    use crate::gpu::InferenceMetrics;
    use std::time::Duration;

    let mut metrics = InferenceMetrics::new();
    assert_eq!(
        metrics.total_inferences(),
        0,
        "IMP-067: No inferences initially"
    );
    assert_eq!(metrics.total_tokens(), 0, "IMP-067: No tokens initially");

    // Record three inferences (previously jammed on one line; split per
    // rustfmt's one-statement-per-line convention).
    metrics.record_inference(Duration::from_millis(10), 5);
    metrics.record_inference(Duration::from_millis(20), 10);
    metrics.record_inference(Duration::from_millis(15), 8);
    assert_eq!(
        metrics.total_inferences(),
        3,
        "IMP-067: Should have 3 inferences"
    );
    assert_eq!(metrics.total_tokens(), 23, "IMP-067: Should have 23 tokens");

    // The median of {10, 15, 20} ms must fall inside the recorded range
    // (range-contains form per clippy::manual_range_contains).
    let p50 = metrics.latency_percentile(50);
    assert!(p50.is_some(), "IMP-067: Should have p50");
    let p50_ms = p50.unwrap().as_millis();
    assert!(
        (10..=20).contains(&p50_ms),
        "IMP-067: p50 should be ~15ms, got {}ms",
        p50_ms
    );

    let throughput = metrics.throughput();
    assert!(throughput > 0.0, "IMP-067: Throughput should be positive");

    // reset() must zero every counter.
    metrics.reset();
    assert_eq!(metrics.total_inferences(), 0, "IMP-067: Inferences reset");
    assert_eq!(metrics.total_tokens(), 0, "IMP-067: Tokens reset");
}
// IMP-068: HealthChecker aggregates named boolean probes; overall health is
// the conjunction of every registered check (vacuously healthy when none
// are registered).
#[test]
#[cfg(feature = "gpu")]
fn test_imp_068_health_checker() {
    use crate::gpu::HealthChecker;

    let mut checker = HealthChecker::new();
    assert!(
        checker.is_healthy(),
        "IMP-068: Healthy when no checks registered"
    );

    // A single passing probe: per-check and aggregate status are healthy.
    checker.register_check("gpu", Box::new(|| true));
    assert_eq!(checker.check_count(), 1, "IMP-068: Should have 1 check");
    let outcomes = checker.check_all();
    assert_eq!(outcomes.len(), 1, "IMP-068: Should have 1 result");
    assert!(
        outcomes.get("gpu").copied().unwrap_or(false),
        "IMP-068: GPU should be healthy"
    );
    assert!(checker.is_healthy(), "IMP-068: Overall should be healthy");

    // One failing probe flips the aggregate status to unhealthy.
    checker.register_check("memory", Box::new(|| false));
    let outcomes = checker.check_all();
    assert!(
        !outcomes.get("memory").copied().unwrap_or(true),
        "IMP-068: Memory should be unhealthy"
    );
    assert!(
        !checker.is_healthy(),
        "IMP-068: Overall should be unhealthy"
    );

    // Clearing all probes restores the vacuous-healthy state.
    checker.clear();
    assert_eq!(checker.check_count(), 0, "IMP-068: No checks after clear");
    assert!(checker.is_healthy(), "IMP-068: Healthy after clear");
}
// IMP-069: ShutdownCoordinator drains in-flight requests: registered
// handlers run when shutdown is initiated, and completion requires both the
// shutdown flag and zero pending requests.
#[test]
#[cfg(feature = "gpu")]
fn test_imp_069_graceful_shutdown() {
use crate::gpu::ShutdownCoordinator;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
let mut coordinator = ShutdownCoordinator::new();
assert!(
!coordinator.is_shutting_down(),
"IMP-069: Not shutting down initially"
);
assert_eq!(
coordinator.pending_requests(),
0,
"IMP-069: No pending requests"
);
// Observe handler invocation through a shared atomic flag.
let handler_called = Arc::new(AtomicBool::new(false));
let handler_called_clone = handler_called.clone();
coordinator.register_handler(Box::new(move || {
handler_called_clone.store(true, Ordering::SeqCst);
}));
assert_eq!(
coordinator.handler_count(),
1,
"IMP-069: Should have 1 handler"
);
// Two requests in flight before shutdown begins.
coordinator.request_started();
coordinator.request_started();
assert_eq!(
coordinator.pending_requests(),
2,
"IMP-069: Should have 2 pending"
);
// initiate_shutdown flips the flag and invokes registered handlers.
coordinator.initiate_shutdown();
assert!(
coordinator.is_shutting_down(),
"IMP-069: Should be shutting down"
);
assert!(
handler_called.load(Ordering::SeqCst),
"IMP-069: Handler should be called"
);
// Drain both in-flight requests; only then is shutdown complete.
coordinator.request_completed();
assert_eq!(
coordinator.pending_requests(),
1,
"IMP-069: Should have 1 pending"
);
coordinator.request_completed();
assert_eq!(
coordinator.pending_requests(),
0,
"IMP-069: Should have 0 pending"
);
assert!(
coordinator.is_complete(),
"IMP-069: Should be complete when shutdown + no pending"
);
}
// IMP-070: ErrorRecoveryStrategy classifies io::Errors, retries transient
// failures with increasing (capped, jittered) backoff, and can fall back to
// the CPU path for GPU-specific errors.
#[test]
#[cfg(feature = "gpu")]
fn test_imp_070_error_recovery_strategy() {
use crate::gpu::{ErrorClassification, ErrorRecoveryStrategy, RecoveryAction};
use std::time::Duration;
let strategy = ErrorRecoveryStrategy::new()
.with_max_retries(3)
.with_base_delay(Duration::from_millis(100))
.with_max_delay(Duration::from_secs(5))
.with_jitter(0.1);
assert_eq!(
strategy.max_retries(),
3,
"IMP-070: Max retries should be 3"
);
// Classification: TimedOut -> Transient; InvalidData -> Fatal.
let transient_err = std::io::Error::new(std::io::ErrorKind::TimedOut, "timeout");
let classification = strategy.classify_error(&transient_err);
assert_eq!(
classification,
ErrorClassification::Transient,
"IMP-070: Timeout should be transient"
);
let fatal_err = std::io::Error::new(std::io::ErrorKind::InvalidData, "bad data");
let classification = strategy.classify_error(&fatal_err);
assert_eq!(
classification,
ErrorClassification::Fatal,
"IMP-070: InvalidData should be fatal"
);
// Attempt 0 of a transient error is retried.
let action = strategy.determine_action(&transient_err, 0);
assert!(
matches!(action, RecoveryAction::Retry { .. }),
"IMP-070: Transient should retry"
);
// Backoff delays grow monotonically with the attempt number.
let delay_0 = strategy.calculate_delay(0);
let delay_1 = strategy.calculate_delay(1);
let delay_2 = strategy.calculate_delay(2);
assert!(delay_1 > delay_0, "IMP-070: Delay should increase");
assert!(
delay_2 > delay_1,
"IMP-070: Delay should increase exponentially"
);
// Attempt 4 exceeds max_retries(3) -> give up.
let action = strategy.determine_action(&transient_err, 4);
assert!(
matches!(action, RecoveryAction::Fail),
"IMP-070: Should fail after max retries"
);
// GPU-specific failure text routes to the CPU fallback action.
let gpu_err = std::io::Error::other("GPU unavailable");
let action = strategy.determine_action_with_fallback(&gpu_err, 0);
assert!(
matches!(action, RecoveryAction::FallbackToCpu),
"IMP-070: GPU error should fallback to CPU"
);
}
// IMP-071: DegradationManager adapts its mode and sizing recommendations to
// GPU availability, memory pressure, latency priority, and system load.
#[test]
#[cfg(feature = "gpu")]
fn test_imp_071_graceful_degradation() {
use crate::gpu::{DegradationManager, DegradationMode, SystemLoad};
let mut manager = DegradationManager::new();
assert_eq!(
manager.current_mode(),
DegradationMode::Normal,
"IMP-071: Should start in Normal mode"
);
// Losing the GPU forces CPU fallback mode.
manager.set_gpu_available(false);
assert_eq!(
manager.current_mode(),
DegradationMode::CpuFallback,
"IMP-071: GPU unavailable should trigger CPU fallback"
);
// 0.9 memory pressure should shrink the batch below the requested 8.
manager.set_gpu_available(true);
manager.update_memory_pressure(0.9); let batch_size = manager.recommended_batch_size(8);
assert!(
batch_size < 8,
"IMP-071: High memory pressure should reduce batch size"
);
// Heavy system load should also cap the recommended context length.
let load = SystemLoad {
cpu_percent: 95.0,
memory_percent: 85.0,
queue_depth: 100,
};
manager.update_system_load(load);
let max_context = manager.recommended_max_context(4096);
assert!(
max_context < 4096,
"IMP-071: High load should limit context length"
);
manager.set_latency_priority(true);
assert_eq!(
manager.current_mode(),
DegradationMode::LowLatency,
"IMP-071: Latency priority should set LowLatency mode"
);
// Clearing every stress signal must restore Normal mode.
manager.set_gpu_available(true);
manager.update_memory_pressure(0.3); manager.set_latency_priority(false);
let load = SystemLoad {
cpu_percent: 20.0,
memory_percent: 30.0,
queue_depth: 5,
};
manager.update_system_load(load);
assert_eq!(
manager.current_mode(),
DegradationMode::Normal,
"IMP-071: Low load should restore Normal mode"
);
}
// IMP-072: FailureIsolator contains the blast radius of failing requests:
// per-request cleanup hooks run on failure, success/failure counters are
// kept, and a circuit opens after repeated failures to shed load.
#[test]
#[cfg(feature = "gpu")]
fn test_imp_072_failure_isolation() {
use crate::gpu::{FailureIsolator, RequestOutcome};
use std::sync::Arc;
let isolator = FailureIsolator::new();
assert_eq!(
isolator.active_requests(),
0,
"IMP-072: Should start with 0 active"
);
// Success path: active count rises/falls and the success counter ticks.
let request_id = isolator.start_request();
assert_eq!(
isolator.active_requests(),
1,
"IMP-072: Should have 1 active request"
);
isolator.complete_request(request_id, &RequestOutcome::Success);
assert_eq!(
isolator.active_requests(),
0,
"IMP-072: Should have 0 active after completion"
);
assert_eq!(
isolator.success_count(),
1,
"IMP-072: Should have 1 success"
);
// Failure path: the cleanup closure registered for this request must run.
let request_id = isolator.start_request();
let cleanup_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
let cleanup_flag = cleanup_called.clone();
isolator.register_cleanup(request_id, move || {
cleanup_flag.store(true, std::sync::atomic::Ordering::SeqCst);
});
isolator.complete_request(
request_id,
&RequestOutcome::Failed("test error".to_string()),
);
assert!(
cleanup_called.load(std::sync::atomic::Ordering::SeqCst),
"IMP-072: Cleanup should be called on failure"
);
assert_eq!(
isolator.failure_count(),
1,
"IMP-072: Should have 1 failure"
);
// Five more consecutive failures should trip the circuit breaker.
for _ in 0..5 {
let req_id = isolator.start_request();
isolator.complete_request(req_id, &RequestOutcome::Failed("error".to_string()));
}
assert!(
isolator.is_circuit_open(),
"IMP-072: Circuit should open after repeated failures"
);
let result = isolator.try_start_request();
assert!(
result.is_err(),
"IMP-072: Should reject requests when circuit open"
);
// Manual reset closes the circuit and admits traffic again.
isolator.reset_circuit();
assert!(
!isolator.is_circuit_open(),
"IMP-072: Circuit should close after reset"
);
let result = isolator.try_start_request();
assert!(
result.is_ok(),
"IMP-072: Should accept requests when circuit closed"
);
}
// IMP-073: ConnectionPool honors min/max connection bounds, tracks active
// checkouts, rejects acquisition when exhausted, health-checks connections,
// and can pre-warm to the configured minimum.
#[test]
#[cfg(feature = "gpu")]
fn test_imp_073_connection_pool() {
use crate::gpu::{ConnectionConfig, ConnectionPool, ConnectionState};
let config = ConnectionConfig::new()
.with_max_connections(10)
.with_min_connections(2)
.with_idle_timeout(std::time::Duration::from_secs(300));
let pool = ConnectionPool::new(config);
assert_eq!(
pool.max_connections(),
10,
"IMP-073: Max connections should be configurable"
);
assert_eq!(
pool.min_connections(),
2,
"IMP-073: Min connections should be configurable"
);
// Basic acquire/release round trip adjusts the active count.
let conn = pool.acquire();
assert!(conn.is_ok(), "IMP-073: Should acquire connection from pool");
assert_eq!(
pool.active_connections(),
1,
"IMP-073: Should track active connections"
);
pool.release(conn.unwrap());
assert_eq!(
pool.active_connections(),
0,
"IMP-073: Should decrement on release"
);
// Check out the entire pool; the next try_acquire must fail.
let mut conns = Vec::new();
for i in 0..10 {
let c = pool.acquire();
assert!(c.is_ok(), "IMP-073: Should acquire connection {}", i);
conns.push(c.unwrap());
}
let overflow = pool.try_acquire();
assert!(
overflow.is_err(),
"IMP-073: Should reject when pool exhausted"
);
for c in conns {
pool.release(c);
}
// Freshly issued connections should pass the health check.
let conn = pool.acquire().unwrap();
let state = pool.check_health(&conn);
assert!(
matches!(state, ConnectionState::Healthy),
"IMP-073: New connection should be healthy"
);
pool.release(conn);
// warm() pre-creates idle connections up to min_connections.
let pool2 = ConnectionPool::new(ConnectionConfig::new().with_min_connections(3));
pool2.warm();
assert!(
pool2.idle_connections() >= 3,
"IMP-073: Should warm pool to min connections"
);
}
// IMP-074: ResourceLimiter enforces a per-request memory cap, tracks total
// allocations, applies queue backpressure at the configured depth, and
// provides a compute timer.
#[test]
#[cfg(feature = "gpu")]
fn test_imp_074_resource_limits() {
    use crate::gpu::{LimitResult, ResourceConfig, ResourceLimiter};

    // 512 MiB per request, 4 GiB total, 30s compute budget, queue depth 100.
    let config = ResourceConfig::new()
        .with_max_memory_per_request(512 * 1024 * 1024)
        .with_max_total_memory(4 * 1024 * 1024 * 1024)
        .with_max_compute_time(std::time::Duration::from_secs(30))
        .with_max_queue_depth(100);
    let limiter = ResourceLimiter::new(config);

    // Under the per-request cap -> allowed; over it -> denied.
    assert!(
        matches!(limiter.check_memory(256 * 1024 * 1024), LimitResult::Allowed),
        "IMP-074: Should allow within limits"
    );
    assert!(
        matches!(
            limiter.check_memory(1024 * 1024 * 1024),
            LimitResult::Denied { .. }
        ),
        "IMP-074: Should deny over per-request limit"
    );

    // Allocation and deallocation must be reflected by current_memory().
    let alloc1 = limiter.allocate(256 * 1024 * 1024);
    assert!(alloc1.is_ok(), "IMP-074: Should allocate memory");
    assert_eq!(
        limiter.current_memory(),
        256 * 1024 * 1024,
        "IMP-074: Should track allocated"
    );
    limiter.deallocate(256 * 1024 * 1024);
    assert_eq!(
        limiter.current_memory(),
        0,
        "IMP-074: Should track deallocated"
    );

    // Fill the queue to its configured depth, then expect backpressure.
    for _ in 0..100 {
        let _ = limiter.enqueue();
    }
    assert!(
        matches!(limiter.try_enqueue(), LimitResult::Backpressure),
        "IMP-074: Should apply backpressure"
    );
    for _ in 0..100 {
        limiter.dequeue();
    }

    // The compute timer starts immediately and measures elapsed time.
    let timer = limiter.start_compute();
    assert!(
        timer.elapsed() < std::time::Duration::from_secs(1),
        "IMP-074: Timer should work"
    );
}
// IMP-075: ResourceMonitor records gauges (memory, GPU utilization, queue
// depth), latency samples with min/max/avg stats, and timestamped
// snapshots.
#[test]
#[cfg(feature = "gpu")]
fn test_imp_075_resource_monitoring() {
use crate::gpu::ResourceMonitor;
let monitor = ResourceMonitor::new();
monitor.record_memory_usage(512 * 1024 * 1024);
let metrics = monitor.current_metrics();
assert_eq!(
metrics.memory_bytes,
512 * 1024 * 1024,
"IMP-075: Should track memory"
);
monitor.record_gpu_utilization(75.5);
let metrics = monitor.current_metrics();
assert!(
(metrics.gpu_utilization - 75.5).abs() < 0.01,
"IMP-075: Should track GPU util"
);
monitor.record_queue_depth(42);
let metrics = monitor.current_metrics();
assert_eq!(metrics.queue_depth, 42, "IMP-075: Should track queue depth");
monitor.record_latency(std::time::Duration::from_millis(150));
let metrics = monitor.current_metrics();
assert_eq!(
metrics.last_latency_ms, 150,
"IMP-075: Should track latency"
);
// Six samples in total: 150ms above plus 100..=500ms in steps of 100.
// min=100, max=500, avg=(150+100+200+300+400+500)/6=275.
for i in 1..=5 {
monitor.record_latency(std::time::Duration::from_millis(i * 100));
}
let stats = monitor.latency_stats();
assert_eq!(stats.min_ms, 100, "IMP-075: Should track min latency");
assert_eq!(stats.max_ms, 500, "IMP-075: Should track max latency");
assert_eq!(stats.avg_ms, 275, "IMP-075: Should track avg latency");
// Snapshots carry a timestamp plus the current gauge values.
let snapshot = monitor.snapshot();
assert!(
snapshot.timestamp > 0,
"IMP-075: Snapshot should have timestamp"
);
assert!(
snapshot.memory_bytes > 0,
"IMP-075: Snapshot should include memory"
);
}
// IMP-076: RetryPolicy retries only transient errors, backs off with
// growing delays up to a configured cap, and aborts past max_retries.
#[test]
#[cfg(feature = "gpu")]
fn test_imp_076_retry_strategy() {
use crate::gpu::{ErrorCategory, RetryConfig, RetryDecision, RetryPolicy};
let config = RetryConfig::new()
.with_max_retries(5)
.with_base_delay(std::time::Duration::from_millis(100))
.with_max_delay(std::time::Duration::from_secs(30))
.with_jitter_factor(0.2);
let policy = RetryPolicy::new(config);
assert_eq!(
policy.max_retries(),
5,
"IMP-076: Max retries should be configurable"
);
// Transient errors retry; permanent errors abort immediately.
let decision = policy.should_retry(1, ErrorCategory::Transient);
assert!(
matches!(decision, RetryDecision::Retry { .. }),
"IMP-076: Should retry transient error"
);
let decision = policy.should_retry(1, ErrorCategory::Permanent);
assert!(
matches!(decision, RetryDecision::Abort { .. }),
"IMP-076: Should not retry permanent error"
);
// Delays grow with the attempt number and saturate at max_delay.
let delay1 = policy.calculate_delay(1);
let delay2 = policy.calculate_delay(2);
let delay3 = policy.calculate_delay(3);
assert!(
delay2 > delay1,
"IMP-076: Delay should increase (exp backoff)"
);
assert!(delay3 > delay2, "IMP-076: Delay should continue increasing");
let delay_capped = policy.calculate_delay(100);
assert!(
delay_capped <= std::time::Duration::from_secs(30),
"IMP-076: Should cap at max delay"
);
// Attempt 6 exceeds max_retries(5) even for a transient error.
let decision = policy.should_retry(6, ErrorCategory::Transient);
assert!(
matches!(decision, RetryDecision::Abort { .. }),
"IMP-076: Should abort after max retries"
);
}
// IMP-077: CircuitBreaker state machine: Closed -> Open at the failure
// threshold, Open -> HalfOpen after the probe timeout, HalfOpen -> Closed
// after enough successes, and HalfOpen -> Open on a probe failure.
// NOTE: relies on real 150ms sleeps exceeding the 100ms breaker timeout.
#[test]
#[cfg(feature = "gpu")]
fn test_imp_077_circuit_breaker() {
use crate::gpu::{CircuitBreaker, CircuitConfig, CircuitState};
let config = CircuitConfig::new()
.with_failure_threshold(3)
.with_success_threshold(2)
.with_timeout(std::time::Duration::from_millis(100));
let breaker = CircuitBreaker::new(config);
assert!(
matches!(breaker.state(), CircuitState::Closed),
"IMP-077: Should start closed"
);
// Two failures: still below the threshold of 3.
breaker.record_failure();
breaker.record_failure();
assert!(
matches!(breaker.state(), CircuitState::Closed),
"IMP-077: Should stay closed below threshold"
);
breaker.record_failure();
assert!(
matches!(breaker.state(), CircuitState::Open),
"IMP-077: Should open at threshold"
);
assert!(!breaker.allow_request(), "IMP-077: Should reject when open");
// After the timeout a probe request is admitted (half-open state).
std::thread::sleep(std::time::Duration::from_millis(150));
assert!(
breaker.allow_request(),
"IMP-077: Should allow probe after timeout"
);
assert!(
matches!(breaker.state(), CircuitState::HalfOpen),
"IMP-077: Should be half-open"
);
// Two successes (the configured success threshold) close the circuit.
breaker.record_success();
breaker.record_success();
assert!(
matches!(breaker.state(), CircuitState::Closed),
"IMP-077: Should close after successes"
);
// A failure while half-open must re-open the circuit immediately.
for _ in 0..3 {
breaker.record_failure();
}
std::thread::sleep(std::time::Duration::from_millis(150));
let _ = breaker.allow_request(); breaker.record_failure();
assert!(
matches!(breaker.state(), CircuitState::Open),
"IMP-077: Should re-open on half-open failure"
);
}
// IMP-078: BulkheadManager keeps independent permit pools per request type
// so exhaustion of one pool cannot starve the others.
#[test]
#[cfg(feature = "gpu")]
fn test_imp_078_bulkhead_pattern() {
use crate::gpu::{BulkheadConfig, BulkheadManager, RequestType};
let config = BulkheadConfig::new()
.with_pool("inference", 10)
.with_pool("embedding", 5)
.with_pool("batch", 2);
let manager = BulkheadManager::new(&config);
let permit = manager.acquire(RequestType::Inference);
assert!(
permit.is_ok(),
"IMP-078: Should acquire from inference pool"
);
assert_eq!(
manager.available(RequestType::Inference),
9,
"IMP-078: Should decrement available"
);
// Acquiring from one pool must not affect another pool's counters.
let embed_permit = manager.acquire(RequestType::Embedding);
assert!(
embed_permit.is_ok(),
"IMP-078: Should acquire from embedding pool"
);
assert_eq!(
manager.available(RequestType::Inference),
9,
"IMP-078: Inference should be unchanged"
);
assert_eq!(
manager.available(RequestType::Embedding),
4,
"IMP-078: Embedding should decrement"
);
// Exhaust the 2-permit batch pool; inference stays unaffected.
for _ in 0..2 {
let _ = manager.acquire(RequestType::Batch);
}
let batch_overflow = manager.try_acquire(RequestType::Batch);
assert!(
batch_overflow.is_err(),
"IMP-078: Batch pool should be exhausted"
);
assert_eq!(
manager.available(RequestType::Inference),
9,
"IMP-078: Inference still available"
);
// Releasing a permit returns capacity to its own pool only.
manager.release(&permit.unwrap());
assert_eq!(
manager.available(RequestType::Inference),
10,
"IMP-078: Should release to correct pool"
);
let stats = manager.stats();
assert_eq!(stats.pool_count, 3, "IMP-078: Should have 3 pools");
assert!(
stats.total_capacity >= 17,
"IMP-078: Total capacity should sum pools"
);
}
// IMP-079: structured logging: LogEntry carries level / correlation id /
// custom fields and serializes to JSON; Logger filters by per-module level.
#[test]
#[cfg(feature = "gpu")]
fn test_imp_079_structured_logging() {
use crate::gpu::{LogConfig, LogEntry, LogLevel, Logger};
// Default level Debug, JSON output, "gpu" module opened up to Trace.
let config = LogConfig::new()
.with_level(LogLevel::Debug)
.with_json_format(true)
.with_module_level("gpu", LogLevel::Trace);
let logger = Logger::new(config);
let entry = LogEntry::new(LogLevel::Info, "Request started")
.with_correlation_id("req-12345")
.with_field("model", "llama-7b")
.with_field("tokens", "128");
assert_eq!(
entry.correlation_id(),
Some("req-12345"),
"IMP-079: Should have correlation ID"
);
assert_eq!(entry.level(), LogLevel::Info, "IMP-079: Should have level");
// JSON serialization must include level, correlation id, and fields.
let json = entry.to_json();
assert!(
json.contains("\"level\":\"INFO\""),
"IMP-079: JSON should have level"
);
assert!(
json.contains("\"correlation_id\":\"req-12345\""),
"IMP-079: JSON should have correlation ID"
);
assert!(
json.contains("\"model\":\"llama-7b\""),
"IMP-079: JSON should have custom fields"
);
// Module override: "gpu" admits Trace; other modules keep the default.
assert!(
logger.is_enabled(LogLevel::Trace, "gpu"),
"IMP-079: gpu should allow Trace"
);
assert!(
logger.is_enabled(LogLevel::Debug, "inference"),
"IMP-079: Other modules use default"
);
assert!(
!logger.is_enabled(LogLevel::Trace, "inference"),
"IMP-079: Trace should be filtered for non-gpu"
);
let entry = LogEntry::new(LogLevel::Warn, "High memory usage");
assert!(entry.timestamp() > 0, "IMP-079: Should have timestamp");
}
// IMP-080: diagnostics plumbing: PhaseTimer produces a per-phase duration
// breakdown, MemoryTracker reports peak/current/allocation-count, and
// DiagnosticsCollector aggregates both per request.
// NOTE: the phase-ordering assert relies on real sleeps (10ms vs 20ms).
#[test]
#[cfg(feature = "gpu")]
fn test_imp_080_performance_diagnostics() {
use crate::gpu::{DiagnosticsCollector, MemoryTracker, PhaseTimer};
let collector = DiagnosticsCollector::new();
let timer = PhaseTimer::new();
timer.start_phase("tokenization");
std::thread::sleep(std::time::Duration::from_millis(10));
timer.end_phase("tokenization");
timer.start_phase("inference");
std::thread::sleep(std::time::Duration::from_millis(20));
timer.end_phase("inference");
let breakdown = timer.breakdown();
assert!(
breakdown.contains_key("tokenization"),
"IMP-080: Should track tokenization"
);
assert!(
breakdown.contains_key("inference"),
"IMP-080: Should track inference"
);
assert!(
*breakdown.get("inference").unwrap() > *breakdown.get("tokenization").unwrap(),
"IMP-080: Inference should take longer"
);
// Peak counts both live allocations; current reflects the kv_cache free.
let tracker = MemoryTracker::new();
tracker.record_allocation("model_weights", 1024 * 1024 * 1024);
tracker.record_allocation("kv_cache", 256 * 1024 * 1024);
tracker.record_deallocation("kv_cache", 256 * 1024 * 1024);
let report = tracker.report();
assert_eq!(
report.peak_bytes,
1024 * 1024 * 1024 + 256 * 1024 * 1024,
"IMP-080: Should track peak"
);
assert_eq!(
report.current_bytes,
1024 * 1024 * 1024,
"IMP-080: Should track current"
);
assert_eq!(
report.allocation_count, 2,
"IMP-080: Should count allocations"
);
// The collector ties timings and memory snapshots to request ids.
collector.record_request_timing("req-001", timer.breakdown());
collector.record_memory_snapshot(report);
let summary = collector.summary();
assert!(summary.request_count >= 1, "IMP-080: Should count requests");
}
// IMP-081: debug tooling: DebugMode toggling, RequestCapture JSON
// round-trip, and StateDump error/stack/state serialization.
#[test]
#[cfg(feature = "gpu")]
fn test_imp_081_debug_mode() {
use crate::gpu::{DebugMode, RequestCapture, StateDump};
let debug = DebugMode::new();
assert!(
!debug.is_enabled(),
"IMP-081: Should be disabled by default"
);
debug.enable();
assert!(debug.is_enabled(), "IMP-081: Should enable");
// Capture a request and verify it survives a JSON round-trip.
let capture = RequestCapture::new()
.with_input("Hello, world!")
.with_params("temperature", "0.7")
.with_params("max_tokens", "100");
assert_eq!(
capture.input(),
"Hello, world!",
"IMP-081: Should capture input"
);
assert_eq!(capture.params().len(), 2, "IMP-081: Should capture params");
let json = capture.to_json();
let restored = RequestCapture::from_json(&json);
assert!(restored.is_ok(), "IMP-081: Should deserialize");
assert_eq!(
restored.unwrap().input(),
"Hello, world!",
"IMP-081: Should restore input"
);
// StateDump collects error text, a stack trace, and key/value state.
let dump = StateDump::new()
.with_error("Out of memory")
.with_stack_trace("at inference::generate\nat main")
.with_state("model_loaded", "true")
.with_state("tokens_processed", "42");
assert_eq!(
dump.error(),
"Out of memory",
"IMP-081: Should capture error"
);
assert!(
dump.stack_trace().contains("inference::generate"),
"IMP-081: Should have stack"
);
assert_eq!(dump.state().len(), 2, "IMP-081: Should capture state");
let dump_json = dump.to_json();
assert!(
dump_json.contains("Out of memory"),
"IMP-081: JSON should have error"
);
assert!(
dump_json.contains("tokens_processed"),
"IMP-081: JSON should have state"
);
}
// IMP-082: a freshly constructed GgufModelState must report an empty,
// unloaded state from every accessor.
#[test]
#[cfg(feature = "gpu")]
fn test_imp_082_gguf_model_state() {
    use crate::gpu::GgufModelState;

    let fresh = GgufModelState::new();
    assert!(!fresh.is_loaded(), "IMP-082: Should be unloaded initially");
    assert_eq!(
        fresh.model_name(),
        None,
        "IMP-082: No model name when empty"
    );
    assert_eq!(fresh.vocab_size(), 0, "IMP-082: Zero vocab when empty");
    assert!(!fresh.is_ready(), "IMP-082: Not ready when empty");
}
// IMP-083: load_gguf_to_gpu on a tiny synthetic model yields a state that
// is loaded, ready for inference, and reports the requested vocab size.
#[test]
#[cfg(feature = "gpu")]
fn test_imp_083_load_gguf_to_gpu() {
    use crate::gpu::load_gguf_to_gpu;

    // Deliberately tiny geometry so the test runs quickly.
    let (vocab_size, hidden_dim, num_layers) = (256, 64, 2);
    let result = load_gguf_to_gpu(vocab_size, hidden_dim, num_layers);
    assert!(result.is_ok(), "IMP-083: Should load test model to GPU");

    let state = result.unwrap();
    assert!(state.is_loaded(), "IMP-083: Should be loaded after load");
    assert!(state.is_ready(), "IMP-083: Should be ready for inference");
    assert_eq!(
        state.vocab_size(),
        vocab_size,
        "IMP-083: Should have correct vocab"
    );
}
// IMP-084: end-to-end smoke test against a locally running API server
// (ignored by default). Panics if /health is unreachable; the /generate
// probe is best-effort and only logs when unavailable.
#[test]
#[ignore = "Requires integration test setup"]
fn test_imp_084_serve_gguf_model() {
let client = reqwest::blocking::Client::builder()
.timeout(std::time::Duration::from_secs(5))
.build()
.expect("Failed to create HTTP client");
let health_url = "http://127.0.0.1:3000/health";
match client.get(health_url).send() {
Ok(response) => {
assert!(
response.status().is_success(),
"IMP-084: Health endpoint should return 200 OK"
);
println!("IMP-084: ✅ Server health check passed");
// Health passed; now exercise deterministic generation (temp 0.0).
let gen_url = "http://127.0.0.1:3000/generate";
let request = serde_json::json!({
"prompt": "Hello",
"max_tokens": 5,
"temperature": 0.0
});
match client.post(gen_url).json(&request).send() {
Ok(gen_response) => {
assert!(
gen_response.status().is_success(),
"IMP-084: Generate endpoint should return 200 OK"
);
let body: serde_json::Value = gen_response.json().expect("Valid JSON");
assert!(
body.get("text").is_some(),
"IMP-084: Response should have text"
);
println!("IMP-084: ✅ Generate endpoint works, got: {:?}", body);
},
Err(e) => {
println!("IMP-084: ⚠️ Generate endpoint not available: {}", e);
},
}
},
Err(e) => {
panic!(
"IMP-084: Server not running at {}. Start with: cargo run --example api_server. Error: {}",
health_url, e
);
},
}
}
// IMP-085: probes the OpenAI-compatible /v1/completions endpoint of a
// locally running server (ignored by default). A 404 is tolerated as
// "not implemented yet"; other failures panic.
#[test]
#[ignore = "Requires running server"]
fn test_imp_085_completions_endpoint() {
let client = reqwest::blocking::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()
.expect("Failed to create HTTP client");
let url = "http://127.0.0.1:3000/v1/completions";
// OpenAI-style completion request, deterministic (temperature 0.0).
let request = serde_json::json!({
"model": "demo",
"prompt": "The capital of France is",
"max_tokens": 10,
"temperature": 0.0
});
match client.post(url).json(&request).send() {
Ok(response) => {
if response.status().is_success() {
let body: serde_json::Value = response.json().expect("Valid JSON");
assert!(
body.get("choices").is_some(),
"IMP-085: Response should have 'choices'"
);
println!("IMP-085: ✅ OpenAI completions endpoint works");
} else if response.status().as_u16() == 404 {
println!("IMP-085: ⚠️ /v1/completions not implemented yet (404)");
} else {
panic!("IMP-085: Unexpected status: {}", response.status());
}
},
Err(e) => {
panic!(
"IMP-085: Server not running. Start with: cargo run --example api_server. Error: {}",
e
);
},
}
}
// IMP-086: probes the llama.cpp-compatible /completion endpoint of a
// locally running server (ignored by default). A 404 is tolerated as
// "not implemented yet"; other failures panic.
#[test]
#[ignore = "Requires running server"]
fn test_imp_086_llamacpp_endpoint() {
let client = reqwest::blocking::Client::builder()
.timeout(std::time::Duration::from_secs(10))
.build()
.expect("Failed to create HTTP client");
let url = "http://127.0.0.1:3000/completion";
// llama.cpp-style request body (n_predict instead of max_tokens).
let request = serde_json::json!({
"prompt": "Hello, world!",
"n_predict": 10,
"temperature": 0.0
});
match client.post(url).json(&request).send() {
Ok(response) => {
if response.status().is_success() {
let body: serde_json::Value = response.json().expect("Valid JSON");
assert!(
body.get("content").is_some() || body.get("text").is_some(),
"IMP-086: Response should have 'content' or 'text'"
);
println!("IMP-086: ✅ llama.cpp completion endpoint works");
} else if response.status().as_u16() == 404 {
println!("IMP-086: ⚠️ /completion not implemented yet (404)");
} else {
panic!("IMP-086: Unexpected status: {}", response.status());
}
},
Err(e) => {
panic!(
"IMP-086: Server not running. Start with: cargo run --example api_server. Error: {}",
e
);
},
}
}
// IMP-087: crude throughput benchmark against a running server (ignored by
// default). Only a transport error on the FIRST iteration is treated as
// "server not running" (the `assert!(i != 0, ...)`); later errors are
// tolerated and surface via the final success_count assertion.
#[test]
#[ignore = "Requires benchmark infrastructure"]
fn test_imp_087_benchmark_integration() {
use std::time::Instant;
// The shell benchmark script is optional; its absence is only logged.
let script_path = std::path::Path::new("scripts/bench-server-matrix.sh");
if script_path.exists() {
println!("IMP-087: ✅ Benchmark script exists at scripts/bench-server-matrix.sh");
} else {
println!("IMP-087: ⚠️ Benchmark script not found (optional)");
}
let client = reqwest::blocking::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.expect("Failed to create HTTP client");
let url = "http://127.0.0.1:3000/generate";
let request = serde_json::json!({
"prompt": "Benchmark test",
"max_tokens": 10,
"temperature": 0.0
});
let iterations = 5;
let start = Instant::now();
let mut success_count = 0;
// Token count is approximated by whitespace-splitting the response text.
let mut total_tokens = 0;
for i in 0..iterations {
match client.post(url).json(&request).send() {
Ok(response) if response.status().is_success() => {
if let Ok(body) = response.json::<serde_json::Value>() {
if let Some(text) = body.get("text").and_then(|t| t.as_str()) {
total_tokens += text.split_whitespace().count();
success_count += 1;
}
}
},
Ok(response) => {
println!(
"IMP-087: Iteration {} failed with status {}",
i,
response.status()
);
},
Err(e) => {
assert!(
i != 0,
"IMP-087: Server not running. Start with: cargo run --example api_server. Error: {}",
e
);
},
}
}
// Wall-clock throughput over all iterations (guard against zero elapsed).
let elapsed = start.elapsed();
let throughput = if elapsed.as_secs_f64() > 0.0 {
total_tokens as f64 / elapsed.as_secs_f64()
} else {
0.0
};
println!(
"IMP-087: ✅ Benchmark test: {} iterations, {} tokens, {:.2} tok/s",
success_count, total_tokens, throughput
);
assert!(
success_count > 0,
"IMP-087: At least one benchmark iteration should succeed"
);
}
// IMP-088: GpuModelConfig must carry a separate KV-head count for grouped
// query attention (12 query heads sharing 2 KV heads here).
#[test]
#[cfg(feature = "gpu")]
fn test_imp_088_gqa_config_num_kv_heads() {
    use crate::gpu::GpuModelConfig;

    let config = GpuModelConfig {
        vocab_size: 151936,
        hidden_dim: 1536,
        num_heads: 12,
        num_kv_heads: 2,
        num_layers: 28,
        intermediate_dim: 8960,
        eps: 1e-6,
    };
    assert_eq!(config.num_heads, 12, "IMP-088: Should have 12 Q heads");
    assert_eq!(config.num_kv_heads, 2, "IMP-088: Should have 2 KV heads");

    // In GQA the per-head width is shared by query and KV heads, so the
    // KV projection size is num_kv_heads * head_dim.
    let head_dim = config.hidden_dim / config.num_heads;
    assert_eq!(head_dim, 128, "IMP-088: Head dim should be 128");
    let kv_size = config.num_kv_heads * head_dim;
    assert_eq!(kv_size, 256, "IMP-088: KV size should be 2*128=256");
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_089_gqa_attention_forward() {
    use crate::gpu::{GpuModel, GpuModelConfig};

    // Tiny GQA model: 4 query heads grouped over 2 KV heads, 2 layers.
    let cfg = GpuModelConfig {
        vocab_size: 256,
        hidden_dim: 128,
        num_heads: 4,
        num_kv_heads: 2,
        num_layers: 2,
        intermediate_dim: 256,
        eps: 1e-5,
    };

    let mut model = GpuModel::new(cfg).expect("Failed to create GQA model");
    let prompt = vec![1usize, 2, 3];
    let forward = model.forward_gpu(&prompt);
    assert!(
        forward.is_ok(),
        "IMP-089: Forward pass should handle GQA attention. Error: {:?}",
        forward.err()
    );

    // One vocab-sized row of logits per input position.
    let logits = forward.unwrap();
    assert_eq!(
        logits.len(),
        prompt.len() * 256,
        "IMP-089: Logits should be seq_len * vocab_size"
    );
}
#[test]
#[cfg(feature = "gpu")]
fn test_imp_090_cpu_embedding_large_vocab() {
    use crate::gpu::{GpuModel, GpuModelConfig};

    // 100k-entry vocabulary exercises the CPU-embedding path
    // (per the assertion messages below).
    let cfg = GpuModelConfig {
        vocab_size: 100_000,
        hidden_dim: 256,
        num_heads: 4,
        num_kv_heads: 4,
        num_layers: 2,
        intermediate_dim: 512,
        eps: 1e-5,
    };

    let built = GpuModel::new(cfg);
    assert!(
        built.is_ok(),
        "IMP-090: Should create model with large vocab using CPU embedding. Error: {:?}",
        built.err()
    );
    let mut model = built.unwrap();

    // Probe token ids spread across the whole vocab range, including both extremes.
    let prompt = vec![0usize, 1000, 50000, 99999];
    let forward = model.forward_gpu(&prompt);
    assert!(
        forward.is_ok(),
        "IMP-090: Forward pass should work with CPU embedding for large vocab. Error: {:?}",
        forward.err()
    );

    let logits = forward.unwrap();
    assert_eq!(
        logits.len(),
        prompt.len() * 100_000,
        "IMP-090: Logits should be seq_len * vocab_size"
    );

    // At least one non-zero, non-NaN logit shows the embedding produced real data.
    let mut has_valid_values = false;
    for &v in &logits {
        if v != 0.0 && !v.is_nan() {
            has_valid_values = true;
            break;
        }
    }
    assert!(
        has_valid_values,
        "IMP-090: Logits should contain valid non-zero values"
    );
}
#[test]
#[cfg(feature = "gpu")]
// Manual end-to-end benchmark: loads a real quantized GGUF model from a
// hard-coded local path, runs GPU generation, and reports tok/s.
// Marked #[ignore] because it requires GPU hardware and a machine-specific file.
#[ignore] fn test_imp_093_real_gguf_gpu_benchmark() {
use crate::gguf::MappedGGUFModel;
use crate::gpu::GpuModel;
use std::path::Path;
use std::time::Instant;
let model_path =
"/home/noah/src/single-shot-eval/models/raw/qwen2.5-coder-1.5b-instruct-q4_k_m.gguf";
// Skip (not fail) when the model file isn't present on this machine.
if !Path::new(model_path).exists() {
eprintln!("IMP-093: Skipping - model not found at {}", model_path);
return;
}
println!("\n=== IMP-093: Real GGUF GPU Benchmark ===\n");
println!("Model: {}", model_path);
// Phase 1: memory-map the GGUF file, timed separately from the GPU upload.
let load_start = Instant::now();
let mapped = MappedGGUFModel::from_path(model_path).expect("Failed to load GGUF");
let load_mmap = load_start.elapsed();
println!(" Mmap load: {:?}", load_mmap);
// Phase 2: build the GPU model from the mapped weights.
let gpu_start = Instant::now();
let mut gpu_model = GpuModel::from_mapped_gguf(&mapped).expect("Failed to load to GPU");
let gpu_load = gpu_start.elapsed();
println!(" GPU load: {:?}", gpu_load);
println!(
" Config: hidden={}, layers={}, vocab={}, heads={}, kv_heads={}, intermediate={}",
gpu_model.config().hidden_dim,
gpu_model.config().num_layers,
gpu_model.config().vocab_size,
gpu_model.config().num_heads,
gpu_model.config().num_kv_heads,
gpu_model.config().intermediate_dim,
);
println!();
let test_tokens = vec![0usize, 1, 2, 3];
let max_tokens = 5;
// One-token warmup generation so one-time init cost isn't counted in the timing.
println!("Warmup...");
let _ = gpu_model.generate(
&test_tokens,
&crate::gpu::GpuGenerateConfig {
max_tokens: 1,
..Default::default()
},
);
// Timed generation run.
println!("\nGenerating {} tokens...", max_tokens);
let gen_start = Instant::now();
let result = gpu_model.generate(
&test_tokens,
&crate::gpu::GpuGenerateConfig {
max_tokens,
..Default::default()
},
);
let gen_elapsed = gen_start.elapsed();
assert!(
result.is_ok(),
"IMP-093: Generation should succeed: {:?}",
result.err()
);
let generated = result.unwrap();
let gen_secs = gen_elapsed.as_secs_f64();
// NOTE(review): throughput assumes exactly `max_tokens` were produced; if
// generation ever stops early this overstates tok/s — consider using
// `generated.len() - test_tokens.len()` here instead. TODO confirm.
let tps = max_tokens as f64 / gen_secs;
println!("\n=== Results ===");
println!(
" Generated: {} tokens",
generated.len() - test_tokens.len()
);
println!(" Time: {:.3}s", gen_secs);
println!(" Throughput: {:.2} tok/s", tps);
println!();
// Soft performance target: warn below 10 tok/s rather than failing the test.
let target_tps = 10.0;
if tps < target_tps {
eprintln!(
"WARNING: Below target {} tok/s (actual: {:.2} tok/s)",
target_tps, tps
);
eprintln!("Parity gap with Ollama (~143 tok/s): {:.0}x", 143.0 / tps);
} else {
println!(
"PASS: Achieved {:.2} tok/s (target: {} tok/s)",
tps, target_tps
);
}
}
#[test]
// CPU micro-benchmark comparing the fused Q4_K quantized matvec against a
// plain f32 matmul at the 1.5B model's MLP dimensions. Marked #[ignore]
// because it is a timing harness, not a correctness test.
#[ignore] fn test_imp_099_q4k_vs_f32_benchmark() {
use crate::quantize::{fused_q4k_parallel_matvec, QK_K};
use std::time::Instant;
println!("\n=== IMP-099: Q4_K vs f32 Matmul Benchmark ===\n");
// hidden_dim x intermediate_dim of the target model.
let in_dim: usize = 1536; let out_dim: usize = 8960;
let iterations = 100;
// Deterministic synthetic activation vector.
let activations: Vec<f32> = (0..in_dim).map(|i| (i as f32 * 0.001).sin()).collect();
// Q4_K layout used here: each super-block of QK_K values occupies 144 bytes.
let super_blocks_per_row = in_dim.div_ceil(QK_K);
let bytes_per_row = super_blocks_per_row * 144;
let q4k_weight_size = out_dim * bytes_per_row;
// Arbitrary byte pattern — numeric content is irrelevant to a throughput test.
let q4k_weights: Vec<u8> = (0..q4k_weight_size).map(|i| (i % 256) as u8).collect();
let f32_weight_size = in_dim * out_dim;
let f32_weights: Vec<f32> = (0..f32_weight_size)
.map(|i| (i as f32 * 0.0001).cos())
.collect();
println!("Dimensions: {} x {}", in_dim, out_dim);
println!("Q4_K weight size: {:.2} MB", q4k_weight_size as f64 / 1e6);
println!(
"f32 weight size: {:.2} MB",
(f32_weight_size * 4) as f64 / 1e6
);
println!(
"Compression ratio: {:.1}x\n",
(f32_weight_size * 4) as f64 / q4k_weight_size as f64
);
// One warmup call per kernel so first-touch/page-in cost isn't timed.
let _ = fused_q4k_parallel_matvec(&q4k_weights, &activations, in_dim, out_dim);
let _ = crate::gpu::cpu_matmul(&activations, &f32_weights, 1, in_dim, out_dim);
// Timed Q4_K loop.
let q4k_start = Instant::now();
for _ in 0..iterations {
let _ = fused_q4k_parallel_matvec(&q4k_weights, &activations, in_dim, out_dim);
}
let q4k_elapsed = q4k_start.elapsed();
let q4k_per_op = q4k_elapsed.as_secs_f64() / iterations as f64;
// Timed f32 loop.
let f32_start = Instant::now();
for _ in 0..iterations {
let _ = crate::gpu::cpu_matmul(&activations, &f32_weights, 1, in_dim, out_dim);
}
let f32_elapsed = f32_start.elapsed();
let f32_per_op = f32_elapsed.as_secs_f64() / iterations as f64;
// GOPS counts in_dim*out_dim multiply-accumulates per op; bandwidth is the
// weight bytes streamed per op (f32 rows are 4 bytes per element).
let q4k_gops = (in_dim * out_dim) as f64 / q4k_per_op / 1e9;
let f32_gops = (in_dim * out_dim) as f64 / f32_per_op / 1e9;
let q4k_bw = q4k_weight_size as f64 / q4k_per_op / 1e9;
let f32_bw = (f32_weight_size * 4) as f64 / f32_per_op / 1e9;
println!("=== Results ({} iterations) ===", iterations);
println!("Q4_K fused:");
println!(" Time: {:.3} ms/op", q4k_per_op * 1000.0);
println!(" Throughput: {:.2} GOPS", q4k_gops);
println!(" Bandwidth: {:.2} GB/s", q4k_bw);
println!();
println!("f32 matvec:");
println!(" Time: {:.3} ms/op", f32_per_op * 1000.0);
println!(" Throughput: {:.2} GOPS", f32_gops);
println!(" Bandwidth: {:.2} GB/s", f32_bw);
println!();
println!("Speedup (Q4_K vs f32): {:.2}x", f32_per_op / q4k_per_op);
println!("Effective bandwidth amplification: {:.2}x", f32_bw / q4k_bw);
}
}