use super::config::GGUFConfig;
const CACHE_LINE_BYTES: usize = 64;
const FLOATS_PER_CACHE_LINE: usize = CACHE_LINE_BYTES / std::mem::size_of::<f32>();
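/// Pre-allocated scratch space for one decode step, sized once from the
/// model config so the forward pass allocates nothing. The `q8k_*` buffers
/// stage activations quantized in QK_K-sized (256-element) blocks, one f32
/// scale per block, padded up to a whole number of blocks.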
#[derive(Debug)]
pub struct InferenceScratchBuffer {
pub hidden: Vec<f32>,
pub normed: Vec<f32>,
pub qkv: Vec<f32>,
pub q: Vec<f32>,
pub k: Vec<f32>,
pub v: Vec<f32>,
pub attn_out: Vec<f32>,
pub attn_proj: Vec<f32>,
pub ffn_up: Vec<f32>,
pub ffn_gate: Vec<f32>,
pub ffn_down: Vec<f32>,
pub logits: Vec<f32>,
pub q8k_hidden_scales: Vec<f32>,
pub q8k_hidden_quants: Vec<i8>,
pub q8k_inter_scales: Vec<f32>,
pub q8k_inter_quants: Vec<i8>,
}
impl InferenceScratchBuffer {
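    /// Sizes every buffer from the model dimensions. `qkv` spans
    /// `q_dim + 2 * kv_dim` floats, i.e. presumably a fused `[q | k | v]` row.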
#[must_use]
pub fn from_config(config: &GGUFConfig) -> Self {
let hidden_dim = config.hidden_dim;
let intermediate_dim = config.intermediate_dim;
let vocab_size = config.vocab_size;
let q_dim = config.q_dim();
let kv_dim = config.kv_dim();
let qkv_dim = q_dim + 2 * kv_dim;
const QK_K: usize = 256;
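        // Q8_K works on 256-element blocks with one scale each, so pad the
        // widest row up to a block multiple (e.g. 2570 floats -> 11 blocks
        // -> 2816 slots and 11 scales).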
let q8k_attn_dim = q_dim.max(hidden_dim);
let q8k_hidden_padded = q8k_attn_dim.div_ceil(QK_K) * QK_K;
let q8k_inter_padded = intermediate_dim.div_ceil(QK_K) * QK_K;
Self {
hidden: vec![0.0; hidden_dim],
normed: vec![0.0; hidden_dim],
qkv: vec![0.0; qkv_dim],
q: vec![0.0; q_dim],
k: vec![0.0; kv_dim],
v: vec![0.0; kv_dim],
attn_out: vec![0.0; q_dim],
attn_proj: vec![0.0; hidden_dim],
ffn_up: vec![0.0; intermediate_dim],
ffn_gate: vec![0.0; intermediate_dim],
ffn_down: vec![0.0; hidden_dim],
logits: vec![0.0; vocab_size],
q8k_hidden_scales: vec![0.0f32; q8k_hidden_padded / QK_K],
q8k_hidden_quants: vec![0i8; q8k_hidden_padded],
q8k_inter_scales: vec![0.0f32; q8k_inter_padded / QK_K],
q8k_inter_quants: vec![0i8; q8k_inter_padded],
}
}
#[inline]
    pub fn reset(&mut self) {
        self.hidden.fill(0.0);
        self.normed.fill(0.0);
    }
}
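/// Variant of [`InferenceScratchBuffer`] that additionally owns Q8 staging
/// buffers laid out in 32-element blocks (`q8_scales`/`q8_quants`) and an
/// `expanded_v` buffer sized to `q_dim`.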
#[derive(Debug)]
pub struct OwnedInferenceScratchBuffer {
pub qkv: Vec<f32>,
pub attn_out: Vec<f32>,
pub ffn_up: Vec<f32>,
pub ffn_gate: Vec<f32>,
pub ffn_down: Vec<f32>,
pub expanded_v: Vec<f32>,
pub logits: Vec<f32>,
pub q8_scales: Vec<f32>,
pub q8_quants: Vec<i8>,
pub q8k_hidden_scales: Vec<f32>,
pub q8k_hidden_quants: Vec<i8>,
pub q8k_inter_scales: Vec<f32>,
pub q8k_inter_quants: Vec<i8>,
}
impl OwnedInferenceScratchBuffer {
#[must_use]
pub fn from_config(config: &GGUFConfig) -> Self {
let hidden_dim = config.hidden_dim;
let q_dim = config.q_dim();
let kv_dim = config.kv_dim();
let qkv_dim = q_dim + 2 * kv_dim;
        // NOTE: sizes the FFN buffers for an expansion factor of up to
        // 6x `hidden_dim` rather than reading `config.intermediate_dim`.
        let intermediate_dim = hidden_dim * 6;
        let num_blocks = q_dim.max(hidden_dim).div_ceil(32);
const QK_K: usize = 256;
let q8k_attn_dim = q_dim.max(hidden_dim);
let q8k_hidden_padded = q8k_attn_dim.div_ceil(QK_K) * QK_K;
let q8k_inter_padded = intermediate_dim.div_ceil(QK_K) * QK_K;
Self {
qkv: vec![0.0f32; qkv_dim],
attn_out: vec![0.0f32; q_dim],
ffn_up: vec![0.0f32; intermediate_dim],
ffn_gate: vec![0.0f32; intermediate_dim],
ffn_down: vec![0.0f32; hidden_dim],
expanded_v: vec![0.0f32; q_dim],
logits: vec![0.0f32; config.vocab_size],
q8_scales: vec![0.0f32; num_blocks],
q8_quants: vec![0i8; num_blocks * 32],
q8k_hidden_scales: vec![0.0f32; q8k_hidden_padded / QK_K],
q8k_hidden_quants: vec![0i8; q8k_hidden_padded],
q8k_inter_scales: vec![0.0f32; q8k_inter_padded / QK_K],
q8k_inter_quants: vec![0i8; q8k_inter_padded],
}
}
    pub fn reset(&mut self) {
        // `fill` zeroes in place; `clear` would drop the lengths to zero and
        // break subsequent indexed writes into these fixed-size buffers.
        self.qkv.fill(0.0);
        self.attn_out.fill(0.0);
        self.ffn_up.fill(0.0);
        self.ffn_gate.fill(0.0);
        self.ffn_down.fill(0.0);
        self.expanded_v.fill(0.0);
        self.logits.fill(0.0);
        self.q8_scales.fill(0.0);
        self.q8_quants.fill(0);
        self.q8k_hidden_scales.fill(0.0);
        self.q8k_hidden_quants.fill(0);
        self.q8k_inter_scales.fill(0.0);
        self.q8k_inter_quants.fill(0);
    }
}
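/// Key/value cache laid out as two flat `[num_layers * layer_stride]`
/// arrays, with `layer_stride` rounded up to a whole number of cache lines
/// so every layer starts on a 64-byte boundary.
///
/// Write protocol: `append` once per layer at the current position, then
/// `advance` once to commit it.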
#[derive(Debug)]
pub struct ContiguousKVCache {
num_layers: usize,
hidden_dim: usize,
max_seq_len: usize,
seq_len: usize,
layer_stride: usize,
k_data: Vec<f32>,
v_data: Vec<f32>,
}
impl ContiguousKVCache {
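    /// Allocates both K and V planes up front: `num_layers * layer_stride`
    /// floats each, where the stride pads `max_seq_len * hidden_dim` to a
    /// cache-line multiple.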
#[must_use]
pub fn new(num_layers: usize, hidden_dim: usize, max_seq_len: usize) -> Self {
let raw_layer_size = max_seq_len * hidden_dim;
let layer_stride = Self::align_to_cache_line(raw_layer_size);
let total_size = num_layers * layer_stride;
Self {
num_layers,
hidden_dim,
max_seq_len,
seq_len: 0,
layer_stride,
k_data: vec![0.0f32; total_size],
v_data: vec![0.0f32; total_size],
}
}
#[inline]
    fn align_to_cache_line(size: usize) -> usize {
        // Round up to a whole number of 64-byte cache lines (in f32 units).
        size.div_ceil(FLOATS_PER_CACHE_LINE) * FLOATS_PER_CACHE_LINE
    }
#[must_use]
pub fn from_config(config: &GGUFConfig, max_seq_len: usize) -> Self {
Self::new(config.num_layers, config.hidden_dim, max_seq_len)
}
#[must_use]
pub const fn is_contiguous(&self) -> bool {
true
}
#[must_use]
pub fn is_cache_aligned(&self) -> bool {
self.layer_stride.is_multiple_of(FLOATS_PER_CACHE_LINE)
}
#[must_use]
pub fn layer_stride(&self) -> usize {
self.layer_stride
}
#[inline]
fn layer_offset(&self, layer: usize) -> usize {
layer * self.layer_stride
}
    pub fn append(&mut self, layer: usize, k: &[f32], v: &[f32]) {
        if layer >= self.num_layers || self.seq_len >= self.max_seq_len {
            return;
        }
        // Both slices must hold exactly one position's worth of values.
        debug_assert_eq!(k.len(), self.hidden_dim);
        debug_assert_eq!(v.len(), self.hidden_dim);
        let start = self.layer_offset(layer) + self.seq_len * self.hidden_dim;
        let end = start + self.hidden_dim;
        if end <= self.k_data.len() {
            self.k_data[start..end].copy_from_slice(k);
            self.v_data[start..end].copy_from_slice(v);
        }
    }
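    /// Commits the position written by this token's `append` calls; intended
    /// to be called once per token, after every layer has appended.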
pub fn advance(&mut self) {
if self.seq_len < self.max_seq_len {
self.seq_len += 1;
}
}
#[must_use]
pub fn get_k(&self, layer: usize) -> &[f32] {
if layer >= self.num_layers {
return &[];
}
let start = self.layer_offset(layer);
&self.k_data[start..start + self.seq_len * self.hidden_dim]
}
#[must_use]
pub fn get_v(&self, layer: usize) -> &[f32] {
if layer >= self.num_layers {
return &[];
}
let start = self.layer_offset(layer);
&self.v_data[start..start + self.seq_len * self.hidden_dim]
}
pub fn get_k_mut(&mut self, layer: usize) -> &mut [f32] {
if layer >= self.num_layers {
return &mut [];
}
let start = self.layer_offset(layer);
let len = self.seq_len * self.hidden_dim;
&mut self.k_data[start..start + len]
}
pub fn get_v_mut(&mut self, layer: usize) -> &mut [f32] {
if layer >= self.num_layers {
return &mut [];
}
let start = self.layer_offset(layer);
let len = self.seq_len * self.hidden_dim;
&mut self.v_data[start..start + len]
}
#[must_use]
pub fn len(&self) -> usize {
self.seq_len
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.seq_len == 0
}
pub fn reset(&mut self) {
self.seq_len = 0;
}
pub fn reset_and_zero(&mut self) {
self.seq_len = 0;
self.k_data.fill(0.0);
self.v_data.fill(0.0);
}
#[must_use]
pub fn max_len(&self) -> usize {
self.max_seq_len
}
#[must_use]
pub fn memory_bytes(&self) -> usize {
(self.k_data.len() + self.v_data.len()) * std::mem::size_of::<f32>()
}
    #[inline]
    pub fn prefetch_k(&self, layer: usize) {
        if layer < self.num_layers {
            Self::prefetch(&self.k_data, self.layer_offset(layer));
        }
    }
    #[inline]
    pub fn prefetch_v(&self, layer: usize) {
        if layer < self.num_layers {
            Self::prefetch(&self.v_data, self.layer_offset(layer));
        }
    }
    #[inline]
    fn prefetch(data: &[f32], offset: usize) {
        #[cfg(target_arch = "x86_64")]
        // SAFETY: prefetch is a pure hint; `offset` is in bounds by construction.
        unsafe {
            use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
            _mm_prefetch::<_MM_HINT_T0>(data.as_ptr().add(offset).cast());
        }
        #[cfg(not(target_arch = "x86_64"))]
        // Portable fallback: a forced read still pulls the line into cache.
        let _ = std::hint::black_box(data.get(offset).copied());
    }
}
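/// Lock-free dispatch telemetry: per-backend (CPU/GPU) dispatch counts,
/// latency sum, sum of squares, min and max in microseconds, and a
/// five-bucket latency histogram per backend. The methods are expected to
/// live in the included `latency.rs`.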
#[derive(Debug)]
pub struct DispatchMetrics {
cpu_dispatches: std::sync::atomic::AtomicUsize,
gpu_dispatches: std::sync::atomic::AtomicUsize,
cpu_latency_count: std::sync::atomic::AtomicUsize,
cpu_latency_sum_us: std::sync::atomic::AtomicU64,
gpu_latency_count: std::sync::atomic::AtomicUsize,
gpu_latency_sum_us: std::sync::atomic::AtomicU64,
cpu_latency_buckets: [std::sync::atomic::AtomicUsize; 5],
gpu_latency_buckets: [std::sync::atomic::AtomicUsize; 5],
cpu_latency_min_us: std::sync::atomic::AtomicU64,
cpu_latency_max_us: std::sync::atomic::AtomicU64,
gpu_latency_min_us: std::sync::atomic::AtomicU64,
gpu_latency_max_us: std::sync::atomic::AtomicU64,
cpu_latency_sum_sq_us: std::sync::atomic::AtomicU64,
gpu_latency_sum_sq_us: std::sync::atomic::AtomicU64,
start_time_ms: std::sync::atomic::AtomicU64,
}
include!("latency.rs");
include!("inference_types_config_default.rs");