use super::format_trait::QuantBlockFormat;
use crate::error::{RealizarError, Result};
/// Scalar dot product between quantized weights and `f32` activations.
///
/// Every weight is dequantized individually through `F::dequant_value` and
/// multiplied with the activation at the same flat index. Products are added
/// to a single `f32` accumulator in ascending index order, super-block by
/// super-block.
///
/// # Arguments
/// * `weight_data` - raw quantized bytes; `F::validate_data_length` decides
///   how many super-blocks they contain.
/// * `activations` - at least `num_superblocks * F::ELEMENTS_PER_SUPERBLOCK`
///   values. Anything beyond that count is never read.
///
/// # Errors
/// * Any error returned by `F::validate_data_length(weight_data)` is
///   propagated unchanged.
/// * `RealizarError::InvalidShape` when `activations` is shorter than the
///   number of values the weight data requires.
pub fn generic_fused_dot_scalar<F: QuantBlockFormat>(
    weight_data: &[u8],
    activations: &[f32],
) -> Result<f32> {
    let superblock_count = F::validate_data_length(weight_data)?;
    let required_activations = superblock_count * F::ELEMENTS_PER_SUPERBLOCK;

    if activations.len() < required_activations {
        let reason = format!(
            "{}: activation length {} is less than expected {} values ({} super-blocks × {})",
            F::FORMAT_ID,
            activations.len(),
            required_activations,
            superblock_count,
            F::ELEMENTS_PER_SUPERBLOCK,
        );
        return Err(RealizarError::InvalidShape { reason });
    }

    // Thread one running sum through every super-block so the additions
    // happen in strictly ascending element order.
    let total = (0..superblock_count).fold(0.0f32, |running, block_no| {
        let byte_offset = block_no * F::SUPERBLOCK_BYTES;
        let block_bytes = &weight_data[byte_offset..byte_offset + F::SUPERBLOCK_BYTES];

        // In bounds: the length guard above covers all `required_activations`.
        let act_offset = block_no * F::ELEMENTS_PER_SUPERBLOCK;
        let block_acts = &activations[act_offset..act_offset + F::ELEMENTS_PER_SUPERBLOCK];

        block_acts
            .iter()
            .enumerate()
            .fold(running, |sum, (lane, &activation)| {
                sum + F::dequant_value(block_bytes, lane) * activation
            })
    });

    Ok(total)
}
/// Sums `activations` in consecutive groups of `elements_per_subblock`.
///
/// Returns one sum per group, in input order. When the input length is not a
/// multiple of `elements_per_subblock`, the final group is shorter and its sum
/// covers only the remaining elements. An empty input yields an empty vector.
///
/// # Panics
/// Panics if `elements_per_subblock` is `0` (the chunk size passed to
/// `slice::chunks` must be non-zero).
#[must_use]
pub fn compute_bsums(activations: &[f32], elements_per_subblock: usize) -> Vec<f32> {
    let subblocks = activations.chunks(elements_per_subblock);
    // `Chunks` knows its exact length, so the output can be sized up front.
    let mut bsums = Vec::with_capacity(subblocks.len());
    for subblock in subblocks {
        bsums.push(subblock.iter().sum::<f32>());
    }
    bsums
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantize::format_trait::{Q4_0Fmt, Q8_0Fmt, Q4K, Q5K, Q6K};

    // The byte counts below are the buffer sizes these tests expect each
    // format to accept as exactly one super-block (e.g. 144 bytes for Q4K).

    /// All-zero Q4K weights against all-one activations must give exactly 0.
    #[test]
    fn test_generic_dot_q4k_zero_data() {
        let weights = [0u8; 144];
        let ones = [1.0f32; 256];
        let outcome = generic_fused_dot_scalar::<Q4K>(&weights, &ones);
        assert!(outcome.is_ok());
        assert_eq!(outcome.expect("should succeed"), 0.0);
    }

    /// All-zero Q6K weights against all-one activations must give exactly 0.
    #[test]
    fn test_generic_dot_q6k_zero_data() {
        let weights = [0u8; 210];
        let ones = [1.0f32; 256];
        let outcome = generic_fused_dot_scalar::<Q6K>(&weights, &ones);
        assert!(outcome.is_ok());
        assert_eq!(outcome.expect("should succeed"), 0.0);
    }

    /// All-zero Q4_0 weights against all-one activations must give exactly 0.
    #[test]
    fn test_generic_dot_q4_0_zero_data() {
        let weights = [0u8; 18];
        let ones = [1.0f32; 32];
        let outcome = generic_fused_dot_scalar::<Q4_0Fmt>(&weights, &ones);
        assert!(outcome.is_ok());
        assert_eq!(outcome.expect("should succeed"), 0.0);
    }

    /// All-zero Q8_0 weights against all-one activations must give exactly 0.
    #[test]
    fn test_generic_dot_q8_0_zero_data() {
        let weights = [0u8; 34];
        let ones = [1.0f32; 32];
        let outcome = generic_fused_dot_scalar::<Q8_0Fmt>(&weights, &ones);
        assert!(outcome.is_ok());
        assert_eq!(outcome.expect("should succeed"), 0.0);
    }

    /// A 100-byte buffer is expected to be rejected for Q4K.
    #[test]
    fn test_generic_dot_invalid_data_length() {
        let weights = [0u8; 100];
        let zeros = [0.0f32; 256];
        let outcome = generic_fused_dot_scalar::<Q4K>(&weights, &zeros);
        assert!(outcome.is_err());
    }

    /// Empty weight data is expected to be rejected.
    #[test]
    fn test_generic_dot_empty_data() {
        let weights: [u8; 0] = [];
        let zeros = [0.0f32; 256];
        let outcome = generic_fused_dot_scalar::<Q4K>(&weights, &zeros);
        assert!(outcome.is_err());
    }

    /// One Q4K super-block with only 100 activations must be rejected.
    #[test]
    fn test_generic_dot_activations_too_short() {
        let weights = [0u8; 144];
        let zeros = [0.0f32; 100];
        let outcome = generic_fused_dot_scalar::<Q4K>(&weights, &zeros);
        assert!(outcome.is_err());
    }

    /// Three back-to-back Q4K super-blocks with matching activations succeed.
    #[test]
    fn test_generic_dot_multiple_superblocks() {
        let weights = [0u8; 144 * 3];
        let zeros = [0.0f32; 256 * 3];
        let outcome = generic_fused_dot_scalar::<Q4K>(&weights, &zeros);
        assert!(outcome.is_ok());
    }

    /// 256 ramp values in groups of 32: eight sums, the first being 0+1+...+31.
    #[test]
    fn test_compute_bsums_q4k_subblock_size() {
        let ramp: Vec<f32> = (0..256).map(|i| i as f32).collect();
        let sums = compute_bsums(&ramp, 32);
        assert_eq!(sums.len(), 8);
        assert!((sums[0] - 496.0).abs() < 0.001);
    }

    /// Uniform ones in groups of 32: every group sums to 32.
    #[test]
    fn test_compute_bsums_uniform() {
        let ones = [1.0f32; 256];
        let sums = compute_bsums(&ones, 32);
        assert_eq!(sums.len(), 8);
        assert!(sums.iter().all(|&s| (s - 32.0).abs() < 0.001));
    }

    /// Hand-built Q8_0 block whose dot product with all-one activations is ~32.
    #[test]
    fn test_generic_dot_q8_0_known_values() {
        let mut block = [0u8; 34];
        // NOTE(review): bytes 0..2 look like a little-endian f16 scale
        // (0x3C00 == 1.0); confirm against the Q8_0 layout.
        block[0] = 0x00;
        block[1] = 0x3C;
        // The remaining 32 bytes are each set to 1.
        block[2..34].fill(1);

        let ones = [1.0f32; 32];
        let outcome = generic_fused_dot_scalar::<Q8_0Fmt>(&block, &ones);
        assert!(outcome.is_ok());
        let val = outcome.expect("should succeed");
        assert!((val - 32.0).abs() < 0.1, "Expected ~32.0, got {val}");
    }

    /// All-zero Q5K weights against all-one activations must give exactly 0.
    #[test]
    fn test_generic_dot_q5k_zero_data() {
        let weights = [0u8; 176];
        let ones = [1.0f32; 256];
        let outcome = generic_fused_dot_scalar::<Q5K>(&weights, &ones);
        assert!(outcome.is_ok());
        assert_eq!(outcome.expect("should succeed"), 0.0);
    }
}