use super::format_trait::QuantBlockFormat;
use crate::error::{RealizarError, Result};
use std::borrow::Cow;
/// Below this many output rows the matvec runs serially — rayon dispatch
/// overhead would dominate for small outputs.
const PARALLEL_THRESHOLD: usize = 256;
/// Number of output rows handed to each rayon task as one contiguous tile.
const MIDI_TILE_M: usize = 64;
/// Returns the activation slice zero-padded out to `padded_len`.
///
/// Borrows the input unchanged when it is already exactly `padded_len` long;
/// otherwise allocates a zero-filled buffer and copies the activations into
/// its prefix (panics if `activations.len() > padded_len`, as the slice copy
/// would be out of range — same as the original behavior).
#[inline]
fn pad_activations_generic(activations: &[f32], padded_len: usize) -> Cow<'_, [f32]> {
    if activations.len() != padded_len {
        let mut buf = vec![0.0f32; padded_len];
        buf[..activations.len()].copy_from_slice(activations);
        return Cow::Owned(buf);
    }
    Cow::Borrowed(activations)
}
pub type FusedDotFn = fn(&[u8], &[f32]) -> Result<f32>;
/// Quantized matrix–vector product: writes `dot(row_o, activations)` into
/// `output[o]` for each of the `out_dim` rows stored contiguously in
/// `weight_data`.
///
/// Runs serially below `PARALLEL_THRESHOLD` output rows; otherwise rayon
/// processes the output in `MIDI_TILE_M`-row tiles. Activations are
/// zero-padded to a whole number of `F` super-blocks before `dot_fn` runs.
///
/// # Errors
/// Returns `RealizarError::InvalidShape` when `weight_data` is shorter than
/// `out_dim` rows, when `activations.len() != in_dim`, or when
/// `output.len() < out_dim`.
#[allow(clippy::similar_names)]
pub fn generic_parallel_matvec_into<F: QuantBlockFormat>(
    weight_data: &[u8],
    activations: &[f32],
    in_dim: usize,
    out_dim: usize,
    output: &mut [f32],
    dot_fn: FusedDotFn,
) -> Result<()> {
    let sb_per_row = in_dim.div_ceil(F::ELEMENTS_PER_SUPERBLOCK);
    let row_bytes = sb_per_row * F::SUPERBLOCK_BYTES;

    // Validate all three buffer shapes up front.
    let needed_bytes = out_dim * row_bytes;
    if weight_data.len() < needed_bytes {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "{} weight data too small: need {} bytes for {}x{}, have {}",
                F::FORMAT_ID,
                needed_bytes,
                out_dim,
                in_dim,
                weight_data.len()
            ),
        });
    }
    if activations.len() != in_dim {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Activation length {} doesn't match in_dim {}",
                activations.len(),
                in_dim
            ),
        });
    }
    if output.len() < out_dim {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Output buffer too small: need {}, have {}",
                out_dim,
                output.len()
            ),
        });
    }

    // Zero-pad so every row sees a whole multiple of the super-block width.
    let padded = pad_activations_generic(activations, sb_per_row * F::ELEMENTS_PER_SUPERBLOCK);

    if out_dim < PARALLEL_THRESHOLD {
        // Small output: rayon dispatch overhead would dominate — stay serial.
        for (row, slot) in output[..out_dim].iter_mut().enumerate() {
            let start = row * row_bytes;
            // NOTE(review): dot-kernel errors are mapped to 0.0; the size
            // checks above presumably make them unreachable — confirm.
            *slot = dot_fn(&weight_data[start..start + row_bytes], &padded).unwrap_or(0.0);
        }
    } else {
        use rayon::prelude::*;
        output[..out_dim]
            .par_chunks_mut(MIDI_TILE_M)
            .enumerate()
            .for_each(|(tile_idx, tile)| {
                let first_row = tile_idx * MIDI_TILE_M;
                for (offset, slot) in tile.iter_mut().enumerate() {
                    let start = (first_row + offset) * row_bytes;
                    // NOTE(review): same silent error-to-0.0 mapping as the
                    // serial path — error propagation across rayon workers
                    // was evidently traded away; confirm intent.
                    *slot = dot_fn(&weight_data[start..start + row_bytes], &padded)
                        .unwrap_or(0.0);
                }
            });
    }
    Ok(())
}
/// Allocating convenience wrapper around [`generic_parallel_matvec_into`]:
/// returns a freshly allocated `Vec<f32>` of length `out_dim` holding the
/// matvec result.
///
/// # Errors
/// Propagates any `InvalidShape` error from the `_into` variant.
#[allow(clippy::similar_names)]
pub fn generic_parallel_matvec<F: QuantBlockFormat>(
    weight_data: &[u8],
    activations: &[f32],
    in_dim: usize,
    out_dim: usize,
    dot_fn: FusedDotFn,
) -> Result<Vec<f32>> {
    let mut out = vec![0.0f32; out_dim];
    generic_parallel_matvec_into::<F>(weight_data, activations, in_dim, out_dim, &mut out, dot_fn)
        .map(|()| out)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantize::format_trait::{Q4K, Q6K};
    use crate::quantize::generic_dot::generic_fused_dot_scalar;

    /// Builds an all-zero weight buffer sized for `out_dim` rows of `in_dim`
    /// elements using 256 elements / 144 bytes per super-block — these
    /// literals presumably mirror `Q4K`'s constants; TODO confirm.
    fn create_q4k_test_weights(out_dim: usize, in_dim: usize) -> Vec<u8> {
        let super_blocks_per_row = in_dim.div_ceil(256);
        let bytes_per_row = super_blocks_per_row * 144;
        vec![0u8; out_dim * bytes_per_row]
    }

    // Adapters giving the generic scalar dot kernels the `FusedDotFn` shape.
    fn q4k_scalar_dot(data: &[u8], acts: &[f32]) -> Result<f32> {
        generic_fused_dot_scalar::<Q4K>(data, acts)
    }
    fn q6k_scalar_dot(data: &[u8], acts: &[f32]) -> Result<f32> {
        generic_fused_dot_scalar::<Q6K>(data, acts)
    }

    /// Happy path: zeroed Q4_K weights must succeed and yield all-zero output.
    #[test]
    fn test_generic_matvec_q4k_basic() {
        let in_dim = 256;
        let out_dim = 64;
        let weights = create_q4k_test_weights(out_dim, in_dim);
        let acts = vec![1.0f32; in_dim];
        let mut output = vec![0.0f32; out_dim];
        let result = generic_parallel_matvec_into::<Q4K>(
            &weights,
            &acts,
            in_dim,
            out_dim,
            &mut output,
            q4k_scalar_dot,
        );
        assert!(result.is_ok());
        assert!(output.iter().all(|&v| v == 0.0));
    }

    /// Happy path for Q6_K; 210 is presumably `Q6K`'s super-block byte size
    /// — TODO confirm against the trait constants.
    #[test]
    fn test_generic_matvec_q6k_basic() {
        let in_dim: usize = 256;
        let out_dim: usize = 32;
        let super_blocks_per_row = in_dim.div_ceil(256);
        let bytes_per_row = super_blocks_per_row * 210;
        let weights = vec![0u8; out_dim * bytes_per_row];
        let acts = vec![1.0f32; in_dim];
        let mut output = vec![0.0f32; out_dim];
        let result = generic_parallel_matvec_into::<Q6K>(
            &weights,
            &acts,
            in_dim,
            out_dim,
            &mut output,
            q6k_scalar_dot,
        );
        assert!(result.is_ok());
    }

    /// An undersized weight buffer must be rejected with an `InvalidShape` error.
    #[test]
    fn test_generic_matvec_weight_too_small() {
        let weights = vec![0u8; 100]; let acts = vec![1.0f32; 256];
        let mut output = vec![0.0f32; 64];
        let result = generic_parallel_matvec_into::<Q4K>(
            &weights,
            &acts,
            256,
            64,
            &mut output,
            q4k_scalar_dot,
        );
        assert!(result.is_err());
    }

    /// Activation length differing from `in_dim` must be rejected.
    #[test]
    fn test_generic_matvec_activation_mismatch() {
        let weights = create_q4k_test_weights(64, 256);
        let acts = vec![1.0f32; 128]; let mut output = vec![0.0f32; 64];
        let result = generic_parallel_matvec_into::<Q4K>(
            &weights,
            &acts,
            256,
            64,
            &mut output,
            q4k_scalar_dot,
        );
        assert!(result.is_err());
    }

    /// An output buffer shorter than `out_dim` must be rejected.
    #[test]
    fn test_generic_matvec_output_too_small() {
        let weights = create_q4k_test_weights(64, 256);
        let acts = vec![1.0f32; 256];
        let mut output = vec![0.0f32; 32];
        let result = generic_parallel_matvec_into::<Q4K>(
            &weights,
            &acts,
            256,
            64,
            &mut output,
            q4k_scalar_dot,
        );
        assert!(result.is_err());
    }

    /// The allocating wrapper must return a vector of exactly `out_dim` elements.
    #[test]
    fn test_generic_matvec_allocating_variant() {
        let weights = create_q4k_test_weights(64, 256);
        let acts = vec![1.0f32; 256];
        let result = generic_parallel_matvec::<Q4K>(&weights, &acts, 256, 64, q4k_scalar_dot);
        assert!(result.is_ok());
        assert_eq!(result.expect("should succeed").len(), 64);
    }

    /// out_dim = 512 >= PARALLEL_THRESHOLD, so this exercises the rayon tile path.
    #[test]
    fn test_generic_matvec_parallel_threshold() {
        let in_dim = 256;
        let out_dim = 512;
        let weights = create_q4k_test_weights(out_dim, in_dim);
        let acts = vec![1.0f32; in_dim];
        let mut output = vec![0.0f32; out_dim];
        let result = generic_parallel_matvec_into::<Q4K>(
            &weights,
            &acts,
            in_dim,
            out_dim,
            &mut output,
            q4k_scalar_dot,
        );
        assert!(result.is_ok());
    }

    /// in_dim = 200 is not a multiple of 256, exercising activation zero-padding.
    #[test]
    fn test_generic_matvec_padding() {
        let in_dim: usize = 200;
        let out_dim: usize = 16;
        let super_blocks_per_row = in_dim.div_ceil(256);
        let bytes_per_row = super_blocks_per_row * 144;
        let weights = vec![0u8; out_dim * bytes_per_row];
        let acts = vec![1.0f32; in_dim];
        let mut output = vec![0.0f32; out_dim];
        let result = generic_parallel_matvec_into::<Q4K>(
            &weights,
            &acts,
            in_dim,
            out_dim,
            &mut output,
            q4k_scalar_dot,
        );
        assert!(result.is_ok());
    }
}