#![allow(clippy::too_many_arguments)]
#![allow(clippy::similar_names)]
use super::{AprKVCache, AprTransformerConfig};
use crate::error::{RealizarError, Result};
/// A 2-D weight matrix stored in GGML-style Q4_0 quantized form.
///
/// Q4_0 packs 32 weights per block: a 2-byte scale followed by 16 bytes of
/// packed 4-bit values (18 bytes per block total).
#[derive(Debug, Clone)]
pub struct QuantizedAprTensorQ4 {
    /// Raw Q4_0 block bytes; length is `expected_bytes(in_dim * out_dim)`.
    pub data: Vec<u8>,
    /// Input (column) dimension of the logical matrix.
    pub in_dim: usize,
    /// Output (row) dimension of the logical matrix.
    pub out_dim: usize,
}

impl QuantizedAprTensorQ4 {
    /// Bytes per Q4_0 block (2-byte scale + 16 bytes of packed nibbles).
    const Q4_0_BLOCK_BYTES: usize = 18;
    /// Number of weights encoded by one Q4_0 block.
    const Q4_0_BLOCK_SIZE: usize = 32;

    /// Wraps pre-quantized `data` with its logical dimensions.
    /// No validation is performed on `data.len()`.
    #[must_use]
    pub fn new(data: Vec<u8>, in_dim: usize, out_dim: usize) -> Self {
        Self {
            data,
            in_dim,
            out_dim,
        }
    }

    /// Creates an all-zero tensor sized for `in_dim * out_dim` Q4_0 elements.
    #[must_use]
    pub fn zeros(in_dim: usize, out_dim: usize) -> Self {
        // Delegate to `expected_bytes` so the Q4_0 size computation has a
        // single source of truth (previously duplicated here).
        let data = vec![0u8; Self::expected_bytes(in_dim * out_dim)];
        Self {
            data,
            in_dim,
            out_dim,
        }
    }

    /// Returns the Q4_0 byte footprint for `num_elements` weights:
    /// the number of 32-element blocks (rounded up) times 18 bytes.
    #[must_use]
    pub fn expected_bytes(num_elements: usize) -> usize {
        num_elements.div_ceil(Self::Q4_0_BLOCK_SIZE) * Self::Q4_0_BLOCK_BYTES
    }
}
/// Weights for one transformer layer, with the large matrices in Q4_0 form
/// and the small normalization vectors kept in f32.
#[derive(Debug, Clone)]
pub struct QuantizedAprLayerQ4 {
/// Normalization weights applied before attention (f32).
pub attn_norm_weight: Vec<f32>,
/// Fused query/key/value projection matrix (Q4_0).
pub qkv_weight: QuantizedAprTensorQ4,
/// Attention output projection matrix (Q4_0).
pub attn_output_weight: QuantizedAprTensorQ4,
/// Feed-forward up projection matrix (Q4_0).
pub ffn_up_weight: QuantizedAprTensorQ4,
/// Feed-forward down projection matrix (Q4_0).
pub ffn_down_weight: QuantizedAprTensorQ4,
/// Optional gate projection — presumably present only for gated FFN
/// architectures (e.g. SwiGLU-style); confirm against the loader.
pub ffn_gate_weight: Option<QuantizedAprTensorQ4>,
/// Optional separate normalization weights applied before the FFN;
/// `None` when the architecture has no distinct FFN norm.
pub ffn_norm_weight: Option<Vec<f32>>,
}
/// A complete transformer model with per-layer weights in Q4_0 form.
#[derive(Debug, Clone)]
pub struct QuantizedAprTransformerQ4 {
/// Model hyperparameters (dimensions, layer count, etc.).
pub config: AprTransformerConfig,
/// Token embedding table, kept unquantized in f32.
pub token_embedding: Vec<f32>,
/// Per-layer quantized weights, in forward-pass order.
pub layers: Vec<QuantizedAprLayerQ4>,
/// Final normalization weights applied before the LM head (f32).
pub output_norm_weight: Vec<f32>,
/// Output (vocabulary) projection matrix (Q4_0).
pub lm_head_weight: QuantizedAprTensorQ4,
}
/// Pre-allocated scratch buffers reused across inference steps so the hot
/// path performs no per-token heap allocation. Sizes are fixed by
/// [`AprInferenceScratch::from_config`].
#[derive(Debug)]
pub struct AprInferenceScratch {
/// Residual-stream activation (`hidden_dim`).
pub hidden: Vec<f32>,
/// Normalized activation fed into attention (`hidden_dim`).
pub normed: Vec<f32>,
/// Fused QKV projection output (`hidden_dim * 3`).
pub qkv_out: Vec<f32>,
/// Query slice of the QKV output (`hidden_dim`).
pub q: Vec<f32>,
/// Key slice of the QKV output (`hidden_dim`).
pub k: Vec<f32>,
/// Value slice of the QKV output (`hidden_dim`).
pub v: Vec<f32>,
/// Attention block output (`hidden_dim`).
pub attn_out: Vec<f32>,
/// Normalized input to the feed-forward block (`hidden_dim`).
pub ffn_input: Vec<f32>,
/// FFN up-projection output (`intermediate_dim`).
pub ffn_up: Vec<f32>,
/// FFN gate-projection output (`intermediate_dim`).
pub ffn_gate: Vec<f32>,
/// FFN down-projection output (`hidden_dim`).
pub ffn_out: Vec<f32>,
}
impl AprInferenceScratch {
    /// Allocates every scratch buffer at the sizes implied by `config`:
    /// hidden-dim buffers, a fused `3 * hidden_dim` QKV buffer, and
    /// intermediate-dim FFN buffers. All buffers start zero-filled.
    #[must_use]
    pub fn from_config(config: &AprTransformerConfig) -> Self {
        let hidden = config.hidden_dim;
        let inter = config.intermediate_dim;
        // Small helper so each field reads as "buffer of length N".
        let zeroed = |len: usize| vec![0.0_f32; len];
        Self {
            hidden: zeroed(hidden),
            normed: zeroed(hidden),
            qkv_out: zeroed(hidden * 3),
            q: zeroed(hidden),
            k: zeroed(hidden),
            v: zeroed(hidden),
            attn_out: zeroed(hidden),
            ffn_input: zeroed(hidden),
            ffn_up: zeroed(inter),
            ffn_gate: zeroed(inter),
            ffn_out: zeroed(hidden),
        }
    }

    /// Resets every buffer to zero in place, keeping the allocations so the
    /// scratch space can be reused for the next inference step.
    pub fn clear(&mut self) {
        let buffers: [&mut Vec<f32>; 11] = [
            &mut self.hidden,
            &mut self.normed,
            &mut self.qkv_out,
            &mut self.q,
            &mut self.k,
            &mut self.v,
            &mut self.attn_out,
            &mut self.ffn_input,
            &mut self.ffn_up,
            &mut self.ffn_gate,
            &mut self.ffn_out,
        ];
        for buf in buffers {
            buf.fill(0.0);
        }
    }
}
// Textual includes: these sibling files are spliced into this module at
// compile time, so their contents can use the types defined above and any
// private items of this module directly.
include!("attention_kernels.rs");
include!("q4_simd_tests.rs");