/// Parallel matrix-vector product for Q5_K-quantized weights, writing the
/// result into a caller-provided buffer.
///
/// Delegates to the shared generic driver, specialized for the `Q5K` block
/// format and its fused SIMD dot-product kernel.
///
/// # Errors
/// Propagates any validation error reported by `generic_parallel_matvec_into`.
#[allow(clippy::similar_names)]
pub fn fused_q5k_parallel_matvec_into(
    weight_data: &[u8],
    activations: &[f32],
    in_dim: usize,
    out_dim: usize,
    output: &mut [f32],
) -> Result<()> {
    generic_parallel_matvec_into::<Q5K>(
        weight_data, activations, in_dim, out_dim, output, fused_q5k_dot_simd,
    )
}
/// Parallel matrix-vector product for Q6_K-quantized weights, returning a
/// freshly allocated output vector.
///
/// Delegates to the shared generic driver, specialized for the `Q6K` block
/// format and its fused SIMD dot-product kernel.
///
/// # Errors
/// Propagates any validation error reported by `generic_parallel_matvec`.
#[allow(clippy::similar_names)]
pub fn fused_q6k_parallel_matvec(
    weight_data: &[u8],
    activations: &[f32],
    in_dim: usize,
    out_dim: usize,
) -> Result<Vec<f32>> {
    generic_parallel_matvec::<Q6K>(
        weight_data,
        activations,
        in_dim,
        out_dim,
        fused_q6k_dot_simd,
    )
}
/// Parallel matrix-vector product for Q6_K-quantized weights, writing the
/// result into a caller-provided buffer.
///
/// Delegates to the shared generic driver, specialized for the `Q6K` block
/// format and its fused SIMD dot-product kernel.
///
/// # Errors
/// Propagates any validation error reported by `generic_parallel_matvec_into`.
#[allow(clippy::similar_names)]
pub fn fused_q6k_parallel_matvec_into(
    weight_data: &[u8],
    activations: &[f32],
    in_dim: usize,
    out_dim: usize,
    output: &mut [f32],
) -> Result<()> {
    generic_parallel_matvec_into::<Q6K>(
        weight_data, activations, in_dim, out_dim, output, fused_q6k_dot_simd,
    )
}
/// Parallel matrix-vector product of a Q4_K-quantized weight matrix against
/// activations already quantized to Q8_K (`q8k_scales` + `q8k_quants`),
/// writing one `f32` per output row into `output[..out_dim]`.
///
/// On x86_64 with AVX-512F + AVX-512VNNI, full groups of 4 rows go through a
/// fused 4-row kernel; remaining rows (and all rows on other targets) use the
/// single-row SIMD dot, preferring a bsum-accelerated variant when the Q8_K
/// block sums could be precomputed.
///
/// # Errors
/// Returns `RealizarError::InvalidShape` when `weight_data` is smaller than
/// `out_dim` rows of Q4_K data, or when `output` has fewer than `out_dim`
/// elements.
pub fn fused_q4k_q8k_parallel_matvec_into(
    weight_data: &[u8],
    q8k_scales: &[f32],
    q8k_quants: &[i8],
    in_dim: usize,
    out_dim: usize,
    output: &mut [f32],
) -> Result<()> {
    use rayon::prelude::*;
    // Size in bytes of one Q4_K super-block (covers QK_K input elements).
    const SUPER_BLOCK_BYTES: usize = 144;
    let super_blocks_per_row = in_dim.div_ceil(QK_K);
    let bytes_per_row = super_blocks_per_row * SUPER_BLOCK_BYTES;
    let expected_weight_bytes = out_dim * bytes_per_row;
    // Validate buffer sizes up front so the per-row slicing below cannot panic.
    if weight_data.len() < expected_weight_bytes {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Weight data too small: need {} bytes, have {}",
                expected_weight_bytes,
                weight_data.len()
            ),
        });
    }
    if output.len() < out_dim {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Output buffer too small: need {}, have {}",
                out_dim,
                output.len()
            ),
        });
    }
    // Precompute Q8_K per-block sums once for the whole matvec; on failure we
    // silently fall back to the non-bsum dot kernel.
    let bsums = precompute_q8k_bsums(q8k_quants, super_blocks_per_row).ok();
    // Rows are split into 64-row parallel tiles; within a tile the fast path
    // processes 4 rows at a time.
    const MIDI_TILE_M: usize = 64;
    const MICRO_TILE_M: usize = 4;
    // The 4-row kernel requires both AVX-512F and AVX-512VNNI at runtime.
    #[cfg(target_arch = "x86_64")]
    let use_4row_kernel =
        is_x86_feature_detected!("avx512f") && is_x86_feature_detected!("avx512vnni");
    #[cfg(not(target_arch = "x86_64"))]
    let use_4row_kernel = false;
    if use_4row_kernel && out_dim >= MICRO_TILE_M {
        output[..out_dim]
            .par_chunks_mut(MIDI_TILE_M)
            .enumerate()
            .for_each(|(midi_idx, midi_chunk)| {
                let midi_start = midi_idx * MIDI_TILE_M;
                let midi_rows = midi_chunk.len();
                let full_micro_tiles = midi_rows / MICRO_TILE_M;
                let remainder = midi_rows % MICRO_TILE_M;
                for micro_idx in 0..full_micro_tiles {
                    let row_base = midi_start + micro_idx * MICRO_TILE_M;
                    // Base pointers for the 4 consecutive rows of this micro-tile.
                    let row_ptrs: [*const u8; 4] = [
                        weight_data.as_ptr().wrapping_add(row_base * bytes_per_row),
                        weight_data
                            .as_ptr()
                            .wrapping_add((row_base + 1) * bytes_per_row),
                        weight_data
                            .as_ptr()
                            .wrapping_add((row_base + 2) * bytes_per_row),
                        weight_data
                            .as_ptr()
                            .wrapping_add((row_base + 3) * bytes_per_row),
                    ];
                    // SAFETY: the length check against `expected_weight_bytes`
                    // above guarantees each of the 4 rows lies fully inside
                    // `weight_data`, and the required CPU features were
                    // runtime-detected via `use_4row_kernel`.
                    #[cfg(target_arch = "x86_64")]
                    let outputs = unsafe {
                        fused_q4k_q8k_dot_4rows_avx512vnni(
                            row_ptrs,
                            bytes_per_row,
                            q8k_scales,
                            q8k_quants,
                        )
                    };
                    // Dead on non-x86_64 (use_4row_kernel is false there); kept
                    // so the branch compiles on every target.
                    #[cfg(not(target_arch = "x86_64"))]
                    let outputs = [0.0f32; 4];
                    let local_base = micro_idx * MICRO_TILE_M;
                    midi_chunk[local_base] = outputs[0];
                    midi_chunk[local_base + 1] = outputs[1];
                    midi_chunk[local_base + 2] = outputs[2];
                    midi_chunk[local_base + 3] = outputs[3];
                }
                // Tail rows of the tile (fewer than 4): single-row kernels.
                for r in 0..remainder {
                    let row = midi_start + full_micro_tiles * MICRO_TILE_M + r;
                    let row_start = row * bytes_per_row;
                    let row_data = &weight_data[row_start..row_start + bytes_per_row];
                    let local_idx = full_micro_tiles * MICRO_TILE_M + r;
                    midi_chunk[local_idx] = if let Some(ref bs) = bsums {
                        fused_q4k_q8k_dot_with_bsums_simd(
                            row_data, q8k_scales, q8k_quants, bs,
                        )
                        .unwrap_or(0.0)
                    } else {
                        fused_q4k_q8k_dot_simd(row_data, q8k_scales, q8k_quants).unwrap_or(0.0)
                    };
                }
            });
    } else {
        // Scalar-per-row fallback path: no 4-row kernel available, or too few
        // rows to fill a micro-tile.
        output[..out_dim]
            .par_chunks_mut(MIDI_TILE_M)
            .enumerate()
            .for_each(|(midi_idx, midi_chunk)| {
                let midi_start = midi_idx * MIDI_TILE_M;
                for (local_idx, out) in midi_chunk.iter_mut().enumerate() {
                    let row = midi_start + local_idx;
                    let row_start = row * bytes_per_row;
                    let row_data = &weight_data[row_start..row_start + bytes_per_row];
                    *out = if let Some(ref bs) = bsums {
                        fused_q4k_q8k_dot_with_bsums_simd(
                            row_data, q8k_scales, q8k_quants, bs,
                        )
                        .unwrap_or(0.0)
                    } else {
                        fused_q4k_q8k_dot_simd(row_data, q8k_scales, q8k_quants).unwrap_or(0.0)
                    };
                }
            });
    }
    Ok(())
}
/// Fused FFN projection: computes both `up_weight * x` and `gate_weight * x`
/// in a single parallel pass over rows, where `x` is supplied pre-quantized
/// to Q8_K (`q8k_scales` + `q8k_quants`).
///
/// Both weight matrices share the same `in_dim`/`out_dim` layout, so each
/// row index addresses the same byte range in `up_weight` and `gate_weight`.
/// Results go into `up_output[..out_dim]` and `gate_output[..out_dim]`.
///
/// # Errors
/// Returns `RealizarError::InvalidShape` when either weight buffer is smaller
/// than `out_dim` rows of Q4_K data, or either output buffer holds fewer than
/// `out_dim` elements.
#[allow(clippy::too_many_arguments)]
pub fn fused_q4k_q8k_ffn_up_gate_into(
    up_weight: &[u8],
    gate_weight: &[u8],
    q8k_scales: &[f32],
    q8k_quants: &[i8],
    in_dim: usize,
    out_dim: usize,
    up_output: &mut [f32],
    gate_output: &mut [f32],
) -> Result<()> {
    use rayon::prelude::*;
    // Bytes per Q4_K super-block and rows handled per parallel task.
    const SUPER_BLOCK_BYTES: usize = 144;
    const MIDI_TILE_M: usize = 64;
    let super_blocks_per_row = in_dim.div_ceil(QK_K);
    let bytes_per_row = super_blocks_per_row * SUPER_BLOCK_BYTES;
    let expected_weight_bytes = out_dim * bytes_per_row;
    // Validate every buffer up front so the slicing below cannot panic.
    if up_weight.len() < expected_weight_bytes || gate_weight.len() < expected_weight_bytes {
        return Err(RealizarError::InvalidShape {
            reason: format!(
                "Weight data too small: need {} bytes",
                expected_weight_bytes
            ),
        });
    }
    if up_output.len() < out_dim || gate_output.len() < out_dim {
        return Err(RealizarError::InvalidShape {
            reason: format!("Output buffers too small: need {}", out_dim),
        });
    }
    // Precompute Q8_K per-block sums once; on failure fall back to the plain dot.
    let bsums = precompute_q8k_bsums(q8k_quants, super_blocks_per_row).ok();
    up_output[..out_dim]
        .par_chunks_mut(MIDI_TILE_M)
        .zip(gate_output[..out_dim].par_chunks_mut(MIDI_TILE_M))
        .enumerate()
        .for_each(|(tile_idx, (up_tile, gate_tile))| {
            // Single-row fused dot: uses the bsum-accelerated kernel when the
            // block sums are available, the plain SIMD kernel otherwise.
            let row_dot = |row: &[u8]| -> f32 {
                if let Some(ref bs) = bsums {
                    fused_q4k_q8k_dot_with_bsums_simd(row, q8k_scales, q8k_quants, bs)
                        .unwrap_or(0.0)
                } else {
                    fused_q4k_q8k_dot_simd(row, q8k_scales, q8k_quants).unwrap_or(0.0)
                }
            };
            let first_row = tile_idx * MIDI_TILE_M;
            for (offset, (up_out, gate_out)) in
                up_tile.iter_mut().zip(gate_tile.iter_mut()).enumerate()
            {
                // Both matrices share one row layout, so one byte range serves both.
                let start = (first_row + offset) * bytes_per_row;
                let end = start + bytes_per_row;
                *up_out = row_dot(&up_weight[start..end]);
                *gate_out = row_dot(&gate_weight[start..end]);
            }
        });
    Ok(())
}