use crate::error::{RealizarError, Result};
pub mod activation;
pub mod bsum_precompute;
pub mod contract_tests;
pub mod dequant;
pub mod encode;
pub mod format_trait;
pub mod fused_gate_up;
pub mod fused_k;
pub mod fused_q5k_q6k;
pub(crate) mod gemv_pool;
pub mod generic_dot;
pub mod generic_matvec;
pub mod parallel_dequant;
pub mod parallel_k;
pub mod simd;
pub mod types;
pub use types::{
detect_simd_backend, DequantStats, Q4_0Block, Q4_KBlock, Q5_KBlock, Q6_KBlock, Q8KSuperBlock,
Q8_0Block, SimdBackend, BLOCK_SIZE, QK_K,
};
pub use dequant::{
dequantize_f16, dequantize_q2_k, dequantize_q4_0, dequantize_q4_1, dequantize_q4_k,
dequantize_q5_0, dequantize_q5_1, dequantize_q5_k, dequantize_q6_k, dequantize_q8_0,
f16_to_f32,
};
pub use fused_k::{fused_q4k_dot, fused_q4k_dot_simd, fused_q4k_q8k_dot, fused_q4k_q8k_dot_simd};
pub use fused_q5k_q6k::{
fused_q4k_q8_dot, fused_q5k_dot, fused_q5k_dot_simd, fused_q6k_dot, fused_q6k_dot_simd,
};
pub use parallel_k::{
fused_q4k_parallel_matvec, fused_q4k_parallel_matvec_into, fused_q4k_q8k_ffn_up_gate_into,
fused_q4k_q8k_parallel_matvec_into, fused_q4k_tiled_matvec, fused_q5k_parallel_matvec,
fused_q5k_parallel_matvec_into, fused_q6k_parallel_matvec, fused_q6k_parallel_matvec_into,
};
pub use activation::{
fused_rmsnorm_ffn_up_gate, fused_rmsnorm_q4_0_matmul, fused_swiglu_simd,
quantize_activations_q8_0, quantize_rmsnorm_q8_0, quantize_rmsnorm_q8_0_into, softmax_simd,
};
pub use parallel_dequant::{
apply_rope_rotation_simd, dequantize_q4_k_parallel, dequantize_q4_k_simd,
dequantize_q8_0_parallel, dequantize_q8_0_simd,
};
pub use simd::{extract_scale_min, extract_scale_min_from_slice, read_f16};
pub use format_trait::{Q4_0Fmt, Q8_0Fmt, QuantBlockFormat, QuantFamily, Q4K, Q5K, Q6K};
pub use generic_dot::{compute_bsums, generic_fused_dot_scalar};
pub use generic_matvec::{generic_parallel_matvec, generic_parallel_matvec_into};
pub use fused_gate_up::{
fused_gate_up_q4k_into, fused_gate_up_q5k_into, fused_gate_up_q6k_into,
generic_fused_gate_up_matvec_into,
};
pub use bsum_precompute::{fused_q4k_q8k_parallel_matvec_with_bsums_into, precompute_q8k_bsums};
pub use encode::{
dequantize_q4_k_to_f32,
dequantize_q5_k_to_f32,
dequantize_q6_k_to_f32,
quantize_q4_k,
quantize_q4_k_matrix,
quantize_q5_k,
quantize_q5_k_matrix,
quantize_q6_k,
quantize_q6_k_matrix,
transpose_q4k_for_matmul,
transpose_q5k_for_matmul,
transpose_q6k_for_matmul,
F16_MIN_NORMAL,
};
/// Lookup table holding the `f32` value of every possible half-precision bit pattern.
///
/// Entry `i` is `half::f16::from_bits(i).to_f32()`, for all 65536 values of a `u16`.
/// The table is built once, on first access, and kept on the heap for the rest of the
/// program's lifetime.
static F16_TO_F32_LUT: std::sync::LazyLock<Box<[f32; 65536]>> = std::sync::LazyLock::new(|| {
    let mut table = Box::new([0.0f32; 65536]);
    // Walk the slots in order, pairing each with its own index expressed as a u16
    // bit pattern; the inclusive range yields exactly 65536 values.
    for (slot, bits) in table.iter_mut().zip(0..=u16::MAX) {
        *slot = half::f16::from_bits(bits).to_f32();
    }
    table
});
/// Converts raw half-precision bits to `f32` via the precomputed [`F16_TO_F32_LUT`].
///
/// The first call (from any thread) triggers construction of the table. The lookup
/// itself cannot go out of bounds: a `u16` index is always below the table's 65536
/// entries.
#[inline]
pub(crate) fn f16_to_f32_lut(bits: u16) -> f32 {
    let index = usize::from(bits);
    F16_TO_F32_LUT[index]
}
/// Quantizes `activations` into caller-provided Q8_K buffers, one super-block of 256
/// values at a time.
///
/// For super-block `i`, [`Q8KSuperBlock::quantize_into`] is called with the `i`-th
/// 256-element chunk of `activations`, `scales[i]`, and `quants[i * 256..(i + 1) * 256]`.
/// Any elements of `scales` or `quants` beyond what the input needs are left untouched
/// by this function.
///
/// # Errors
///
/// * [`RealizarError::FormatError`] if `activations.len()` is not a multiple of 256.
/// * [`RealizarError::InvalidShape`] if `scales` has fewer entries than there are
///   super-blocks, or if `quants` is shorter than `activations`.
pub fn quantize_activations_q8k_into(
    activations: &[f32],
    scales: &mut [f32],
    quants: &mut [i8],
) -> Result<()> {
    let total = activations.len();
    if !total.is_multiple_of(256) {
        return Err(RealizarError::FormatError {
            reason: format!(
                "Q8_K quantization requires length multiple of 256, got {}",
                total
            ),
        });
    }

    let superblock_count = total / 256;
    let scale_slots = scales.len();
    if scale_slots < superblock_count {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Scales buffer too small: need {}, have {}",
                superblock_count, scale_slots
            ),
        });
    }

    let quant_slots = quants.len();
    if quant_slots < total {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Quants buffer too small: need {}, have {}",
                total, quant_slots
            ),
        });
    }

    // The input chunks drive the zip, so iteration stops after `superblock_count`
    // steps even when the output buffers are larger than required.
    let inputs = activations.chunks_exact(256);
    let outputs = quants.chunks_exact_mut(256);
    for ((src, scale), dst) in inputs.zip(scales.iter_mut()).zip(outputs) {
        Q8KSuperBlock::quantize_into(src, scale, dst);
    }

    Ok(())
}
/// Quantizes `values` into a vector of [`Q8_0Block`]s, one block per 32 consecutive
/// inputs, in input order.
///
/// An empty slice yields an empty vector.
///
/// # Errors
///
/// Returns [`RealizarError::FormatError`] if `values.len()` is not a multiple of 32.
pub fn quantize_to_q8_blocks(values: &[f32]) -> Result<Vec<Q8_0Block>> {
    let len = values.len();
    if !len.is_multiple_of(32) {
        return Err(RealizarError::FormatError {
            reason: format!(
                "Q8_0 quantization requires length multiple of 32, got {}",
                len
            ),
        });
    }

    let mut blocks: Vec<Q8_0Block> = Vec::with_capacity(len / 32);
    for chunk in values.chunks_exact(32) {
        // `chunks_exact(32)` only ever yields 32-element slices, so this conversion
        // failing would indicate a broken invariant rather than bad input.
        let arr: [f32; 32] = chunk.try_into().expect("chunk is exactly 32 elements");
        blocks.push(Q8_0Block::quantize(&arr));
    }
    Ok(blocks)
}
/// Dequantizes `blocks` back to `f32`, concatenating each block's
/// [`Q8_0Block::dequantize`] output in block order.
///
/// Space for 32 values per block is reserved up front, so no reallocation is needed
/// as long as each block produces at most that many values.
pub fn dequantize_q8_blocks(blocks: &[Q8_0Block]) -> Vec<f32> {
    let mut values = Vec::with_capacity(blocks.len() * 32);
    blocks
        .iter()
        .for_each(|block| values.extend_from_slice(&block.dequantize()));
    values
}
/// Q4_K weight data held as one flat vector per field, plus a super-block count.
///
/// NOTE(review): the code that builds and reads this struct is not in this part of
/// the file (presumably in one of the `include!`d files), so the per-field notes
/// below are inferred from the field names — confirm them against that code.
#[derive(Debug, Clone)]
pub struct InterleavedQ4K {
/// Presumably the `d` scale of each super-block, already widened to `f32` — TODO confirm.
pub d: Vec<f32>,
/// Presumably the `dmin` (minimum) scale of each super-block as `f32` — TODO confirm.
pub dmin: Vec<f32>,
/// Presumably the packed per-sub-block scale/min bytes of all super-blocks — TODO confirm
/// the bytes-per-super-block stride.
pub scales: Vec<u8>,
/// Presumably the packed 4-bit quantized values of all super-blocks — TODO confirm the
/// ordering (the type name suggests an interleaved layout).
pub qs: Vec<u8>,
/// Number of super-blocks; presumably the logical length of the vectors above — TODO confirm.
pub num_super_blocks: usize,
}
// The files below are spliced textually into this module by `include!` (paths are
// resolved relative to this file). Their items therefore live directly in this
// module's namespace and see the imports and re-exports declared above, rather than
// forming separate submodules.
// NOTE(review): their contents are not visible here; judging by the file names they
// presumably hold the Q4_0 code and the fused Q4_0×Q8_0 / Q8_0×Q8_0 kernels — confirm.
include!("product.rs");
include!("q4_0.rs");
include!("fused_q4_0_q8_0.rs");
include!("fused_q8_0_q8_0.rs");