use std::fs::File;
use std::path::Path;
use serde::{Deserialize, Serialize};
use crate::error::{RealizarError, Result};
mod config;
mod convert;
mod dequant;
mod generation;
mod helpers;
mod loader;
mod q4_simd;
pub use config::{
AprKVCache, AprTransformerConfig, AprTransformerLayer, GenerateConfig, Q4KLayerWeights,
};
use dequant::{
dequantize_apr_q4_native, dequantize_apr_q8_native, dequantize_q4_k_apr, dequantize_q6_k_apr,
dequantize_q8_0_apr, f16_to_f32,
};
use helpers::{matmul_q4k_rowmajor, matmul_q6k_rowmajor, simd_add_weighted, simd_dot_f32};
pub use loader::{
AprQuantizationType, MmapAprTransformer, QuantizedAprTransformer, APR_TRANSFORMER_HEADER_SIZE,
};
pub use q4_simd::{
AprInferenceScratch, QuantizedAprLayerQ4, QuantizedAprTensorQ4, QuantizedAprTransformerQ4,
};
mod benchmark;
pub use benchmark::{
AprBenchmarkResult, AprBenchmarkRunner, AprLoadResult, AprParityComparison, AprPrefillResult,
APR_CPU_DECODE_THRESHOLD_TOK_S, APR_PARITY_THRESHOLD_PCT, APR_PREFILL_THRESHOLD_TOK_S,
};
/// Dequantizes a row-major, block-quantized 2-D tensor into a dense `Vec<f32>`.
///
/// Every row is stored as `cols.div_ceil(block_elems)` consecutive blocks of
/// `block_bytes` bytes, so a row whose width is not a multiple of
/// `block_elems` carries padding in its last block. Each block is decoded by
/// `dequant_block` into `block_elems` floats and only the first `cols` values
/// of a row are kept.
///
/// # Arguments
///
/// * `data` - quantized bytes, rows stored back to back.
/// * `dims` - tensor shape; `dims[0]` is the row count, `dims[1]` the column count.
/// * `block_elems` - number of values produced by one block.
/// * `block_bytes` - number of bytes occupied by one block.
/// * `dequant_block` - callback decoding one block into a `block_elems`-long slice.
///
/// # Returns
///
/// `dims[0] * dims[1]` values in row-major order. If `data` ends before a row
/// is complete, that row and every following row are returned as zeros
/// (best-effort behaviour rather than an error).
///
/// # Panics
///
/// Panics if `dims` has fewer than two entries or if `block_elems` is zero.
fn dequant_perrow(
    data: &[u8],
    dims: &[usize],
    block_elems: usize,
    block_bytes: usize,
    dequant_block: impl Fn(&[u8], &mut [f32]),
) -> Vec<f32> {
    let rows = dims[0];
    let cols = dims[1];
    let blocks_per_row = cols.div_ceil(block_elems);
    let bytes_per_row = blocks_per_row * block_bytes;
    let total = rows * cols;

    let mut result = Vec::with_capacity(total);
    // One scratch row shared by all iterations instead of a fresh allocation
    // per row. It holds the padded row (whole blocks), hence it may be wider
    // than `cols`.
    let mut scratch = vec![0.0f32; blocks_per_row * block_elems];

    for row in 0..rows {
        let row_start = row * bytes_per_row;
        // Truncated input: pad the remainder of the tensor with zeros.
        let Some(row_bytes) = data.get(row_start..row_start + bytes_per_row) else {
            result.resize(total, 0.0);
            return result;
        };

        // The callback is not required to write every element, so hand it a
        // zeroed buffer exactly as a freshly allocated one would be.
        scratch.fill(0.0);
        for b in 0..blocks_per_row {
            let src = &row_bytes[b * block_bytes..(b + 1) * block_bytes];
            let dst = &mut scratch[b * block_elems..(b + 1) * block_elems];
            dequant_block(src, dst);
        }

        // Drop the padding values of the last block.
        result.extend_from_slice(&scratch[..cols]);
    }
    result
}
/// Decodes one 256-value Q6_K super-block from `block` into `out`.
///
/// Byte layout read by this function:
/// * `0..128`   - low 4 bits of each quant (`ql`),
/// * `128..192` - high 2 bits of each quant, four per byte (`qh`),
/// * `192..208` - sixteen signed 8-bit sub-block scales,
/// * `208..210` - little-endian f16 super-block scale.
///
/// Each 6-bit quant is rebuilt as `low | (high << 4)`, re-centred by
/// subtracting 32 and multiplied by the super-block scale and its sub-block
/// scale.
///
/// # Panics
///
/// Panics if `block` is shorter than 210 bytes or `out` holds fewer than 256
/// elements.
fn dequant_q6k_block(block: &[u8], out: &mut [f32]) {
    let low_nibbles = block.get(0..128).expect("Q6K block requires 128 ql bytes");
    let high_pairs = block
        .get(128..192)
        .expect("Q6K block requires 64 qh bytes at offset 128");

    // Scales are stored as raw bytes but are signed; the wrap is intentional.
    #[allow(clippy::cast_possible_wrap)]
    let scales: [i8; 16] = std::array::from_fn(|i| block[192 + i] as i8);

    let d = dequant::f16_to_f32(u16::from_le_bytes([block[208], block[209]]));

    // Rebuild a 6-bit quant from its 4 low and 2 high bits and re-centre it.
    let centred = |low: u8, high: u8| -> f32 { (i32::from(low | (high << 4)) - 32) as f32 };

    // The super-block is processed as two halves of 128 values; each half
    // consumes 64 `ql` bytes, 32 `qh` bytes and 8 scales.
    for half in 0..2usize {
        let base = 128 * half;
        let sc = &scales[8 * half..];
        let ql_half = &low_nibbles[64 * half..];
        let qh_half = &high_pairs[32 * half..];

        for lane in 0..32usize {
            let group = lane / 16;
            let lo_a = ql_half[lane];
            let lo_b = ql_half[lane + 32];
            let hi = qh_half[lane];

            let q1 = centred(lo_a & 0xF, hi & 3);
            let q2 = centred(lo_b & 0xF, (hi >> 2) & 3);
            let q3 = centred(lo_a >> 4, (hi >> 4) & 3);
            let q4 = centred(lo_b >> 4, (hi >> 6) & 3);

            out[base + lane] = d * f32::from(sc[group]) * q1;
            out[base + lane + 32] = d * f32::from(sc[group + 2]) * q2;
            out[base + lane + 64] = d * f32::from(sc[group + 4]) * q3;
            out[base + lane + 96] = d * f32::from(sc[group + 6]) * q4;
        }
    }
}
/// Decodes one 256-value Q4_K super-block from `block` into `out`.
///
/// Byte layout read by this function:
/// * `0..2`    - little-endian f16 scale multiplier (`d`),
/// * `2..4`    - little-endian f16 minimum multiplier (`dmin`),
/// * `4..16`   - packed sub-block scales/minimums, unpacked through
///   `dequant::extract_scale_min_apr`,
/// * `16..144` - 4-bit quants, two per byte.
///
/// The quants are consumed in four groups of 32 bytes. Each group yields 64
/// outputs: first the 32 low nibbles, then the 32 high nibbles, each set with
/// its own scale/minimum pair.
///
/// # Panics
///
/// Panics if `block` is shorter than 144 bytes or `out` holds fewer than 256
/// elements.
fn dequant_q4k_block(block: &[u8], out: &mut [f32]) {
    let d = dequant::f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
    let dmin = dequant::f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
    let scales = block
        .get(4..16)
        .expect("Q4K block requires 12 scale bytes at offset 4");
    let qs = block
        .get(16..144)
        .expect("Q4K block requires 128 qs bytes at offset 16");

    for (group, quants) in qs.chunks_exact(32).enumerate() {
        // Two scale/min entries per group: one for the low nibbles, one for
        // the high nibbles.
        let scale_idx = 2 * group;

        let (sc_lo, min_lo) = dequant::extract_scale_min_apr(scales, scale_idx);
        let scale_lo = d * sc_lo;
        let offset_lo = dmin * min_lo;

        let (sc_hi, min_hi) = dequant::extract_scale_min_apr(scales, scale_idx + 1);
        let scale_hi = d * sc_hi;
        let offset_hi = dmin * min_hi;

        let base = 64 * group;
        for (i, &byte) in quants.iter().enumerate() {
            out[base + i] = scale_lo * (byte & 0xF) as f32 - offset_lo;
        }
        for (i, &byte) in quants.iter().enumerate() {
            out[base + 32 + i] = scale_hi * (byte >> 4) as f32 - offset_hi;
        }
    }
}
/// Summary statistics over a slice of `f32` activation values.
///
/// NaN and infinite values are counted separately and excluded from
/// `min`, `max`, `mean` and `std_dev`.
#[derive(Debug, Clone, Default)]
pub struct ActivationStats {
    /// Smallest finite value; `0.0` when there is no finite value.
    pub min: f32,
    /// Largest finite value; `0.0` when there is no finite value.
    pub max: f32,
    /// Arithmetic mean of the finite values; `0.0` when there is none.
    pub mean: f32,
    /// Sample standard deviation (divisor `n - 1`) of the finite values;
    /// `0.0` when there are fewer than two.
    pub std_dev: f32,
    /// Number of NaN values.
    pub nan_count: usize,
    /// Number of `+inf` / `-inf` values.
    pub inf_count: usize,
    /// Number of values equal to zero (`0.0` or `-0.0`).
    pub zero_count: usize,
    /// Total number of values inspected, including NaN and infinite ones.
    pub count: usize,
}

impl ActivationStats {
    /// Computes statistics for `data`.
    ///
    /// Sums are accumulated in `f64` and narrowed to `f32` at the end. An
    /// empty slice yields [`ActivationStats::default`]. A slice that contains
    /// only NaN/infinite values yields the counts with `min`, `max`, `mean`
    /// and `std_dev` all `0.0`, so the `+inf` / `-inf` accumulator seeds never
    /// leak into the result.
    #[must_use]
    pub fn from_slice(data: &[f32]) -> Self {
        if data.is_empty() {
            return Self::default();
        }

        let count = data.len();
        let mut min = f32::INFINITY;
        let mut max = f32::NEG_INFINITY;
        let mut sum = 0.0f64;
        let mut nan_count = 0usize;
        let mut inf_count = 0usize;
        let mut zero_count = 0usize;

        for &v in data {
            if v.is_nan() {
                nan_count += 1;
                continue;
            }
            if v.is_infinite() {
                inf_count += 1;
                continue;
            }
            if v == 0.0 {
                zero_count += 1;
            }
            min = min.min(v);
            max = max.max(v);
            sum += f64::from(v);
        }

        let valid_count = count - nan_count - inf_count;
        if valid_count == 0 {
            // Nothing finite was seen: report neutral values instead of the
            // accumulator seeds (which would give min = +inf > max = -inf).
            return Self {
                nan_count,
                inf_count,
                count,
                ..Self::default()
            };
        }

        let mean = (sum / valid_count as f64) as f32;

        // Second pass: squared deviations from the (f32-rounded) mean.
        let mean_wide = f64::from(mean);
        let var_sum = data
            .iter()
            .copied()
            .filter(|v| v.is_finite())
            .fold(0.0f64, |acc, v| {
                let diff = f64::from(v) - mean_wide;
                acc + diff * diff
            });

        let std_dev = if valid_count > 1 {
            (var_sum / (valid_count - 1) as f64).sqrt() as f32
        } else {
            0.0
        };

        Self {
            min,
            max,
            mean,
            std_dev,
            nan_count,
            inf_count,
            zero_count,
            count,
        }
    }
}
/// Activation statistics captured for a single transformer layer.
///
/// Bundles one [`ActivationStats`] snapshot per named stage together with the
/// index of the layer it belongs to.
///
/// NOTE(review): the exact point at which each snapshot is taken is decided by
/// the code that fills this struct (not visible here); the per-field
/// descriptions below are inferred from the field names — confirm.
#[derive(Debug, Clone)]
pub struct LayerActivation {
/// Index of the layer these statistics describe.
pub layer_idx: usize,
/// Presumably the output of the pre-attention normalization.
pub attn_norm_stats: ActivationStats,
/// Presumably the Q/K/V projection output.
pub qkv_stats: ActivationStats,
/// Presumably the attention block output.
pub attn_out_stats: ActivationStats,
/// Presumably the output of the pre-FFN normalization.
pub ffn_norm_stats: ActivationStats,
/// Presumably the feed-forward block output.
pub ffn_out_stats: ActivationStats,
/// Presumably the final output of the layer.
pub output_stats: ActivationStats,
}
/// Record returned by [`TracedForward::forward_traced`].
///
/// Holds the input token ids, one [`LayerActivation`] entry per recorded
/// layer, a few model-level [`ActivationStats`] snapshots and the raw logits.
///
/// NOTE(review): the meaning of the individual statistics is inferred from the
/// field names; the producing code is not visible here — confirm.
#[derive(Debug, Clone)]
pub struct ForwardTrace {
/// Token ids the trace was recorded for.
pub input_tokens: Vec<u32>,
/// Presumably statistics of the token-embedding output.
pub embed_stats: ActivationStats,
/// Per-layer statistics.
pub layer_activations: Vec<LayerActivation>,
/// Presumably statistics after the final normalization.
pub final_norm_stats: ActivationStats,
/// Statistics over the logits (presumably over the `logits` field — confirm).
pub logits_stats: ActivationStats,
/// Raw logit values.
pub logits: Vec<f32>,
}
/// Models that can run a forward pass while recording a [`ForwardTrace`].
pub trait TracedForward {
/// Runs a forward pass over `tokens` and returns the recorded trace.
///
/// Takes `&mut self`, so implementations may mutate internal state.
///
/// # Errors
///
/// Returns the crate's error type when the implementation fails; the
/// concrete failure conditions are implementation-defined.
fn forward_traced(&mut self, tokens: &[u32]) -> Result<ForwardTrace>;
}
/// APR transformer model: configuration plus weights.
///
/// Dense weights are stored as `f32` vectors. The optional `q4k_layers`,
/// `lm_head_weight_q6k` and `lm_head_weight_q4k` fields carry additional
/// data whose exact format is defined elsewhere; they are marked
/// `#[serde(default)]`, so they deserialize to `None` when absent from the
/// serialized form.
///
/// NOTE(review): tensor shapes and memory layouts are not visible in this
/// file; the per-field descriptions are inferred from names and types —
/// confirm against the loader / forward implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AprTransformer {
/// Model configuration.
pub config: AprTransformerConfig,
/// Token-embedding table, flattened.
pub token_embedding: Vec<f32>,
/// Per-layer weights.
pub layers: Vec<AprTransformerLayer>,
/// Weight of the output normalization.
pub output_norm_weight: Vec<f32>,
/// Optional bias of the output normalization.
pub output_norm_bias: Option<Vec<f32>>,
/// LM-head projection weight, flattened.
pub lm_head_weight: Vec<f32>,
/// Optional LM-head bias.
pub lm_head_bias: Option<Vec<f32>>,
/// Optional per-layer `Q4KLayerWeights` (presumably raw Q4_K weights — confirm).
#[serde(default)]
pub q4k_layers: Option<Vec<Q4KLayerWeights>>,
/// Optional raw bytes, presumably the LM-head weight in Q6_K form — confirm.
#[serde(default)]
pub lm_head_weight_q6k: Option<Vec<u8>>,
/// Optional raw bytes, presumably the LM-head weight in Q4_K form — confirm.
#[serde(default)]
pub lm_head_weight_q4k: Option<Vec<u8>>,
}
// `include!` splices the contents of these files into this module at compile
// time, so everything they define lives in this module's scope and can use
// its private items.
// NOTE(review): judging by the file names they hold the generation delegate
// methods and the `TracedForward` implementation — confirm.
include!("generation_delegates.rs");
include!("traced_forward.rs");