#![allow(clippy::too_many_arguments)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_possible_wrap)]
#![allow(non_camel_case_types)]

use std::fs::File;
use std::path::Path;

use memmap2::Mmap;
use serde::{Deserialize, Serialize};

use crate::apr::MAGIC;
use crate::error::{RealizarError, Result};

use super::{AprKVCache, AprTransformer, AprTransformerConfig};

/// Size in bytes of the fixed header at the start of an APR transformer file.
pub const APR_TRANSFORMER_HEADER_SIZE: usize = 64;
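
// Header layout, as read by `MmapAprTransformer::from_file` below (all
// multi-byte fields are little-endian; ranges are byte offsets into the
// 64-byte header):
//
//   0..4    magic                 4..8    version (u32)
//   8..12   hidden_dim            12..16  num_layers
//   16..20  num_heads             20..24  num_kv_heads
//   24..28  vocab_size            28..32  intermediate_dim
//   32..36  context_length        36..40  rope_theta (f32)
//   40..44  eps (f32)             44..48  tensor_data_offset
//
// Bytes 48..64 are not interpreted by this loader.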

/// Memory-mapped APR transformer model.
///
/// The file is mapped read-only; the configuration is parsed from the header,
/// and tensor bytes are read directly out of the mapping on demand.
#[derive(Debug)]
pub struct MmapAprTransformer {
    mmap: Mmap,
    pub config: AprTransformerConfig,
    tensor_data_offset: usize,
    is_mmap: bool,
}

impl MmapAprTransformer {
    /// Opens and memory-maps an APR transformer file, parsing its header.
    ///
    /// Returns an error if the file cannot be opened or mapped, is shorter
    /// than the fixed header, has an invalid magic value, or declares an
    /// unsupported version.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let file = File::open(path.as_ref()).map_err(|e| RealizarError::IoError {
            message: format!("Failed to open APR file: {e}"),
        })?;
        // SAFETY: the mapping is read-only; the file must not be modified
        // while the model is in use.
        let mmap = unsafe {
            Mmap::map(&file).map_err(|e| RealizarError::IoError {
                message: format!("Failed to mmap APR file: {e}"),
            })?
        };
        if mmap.len() < APR_TRANSFORMER_HEADER_SIZE {
            return Err(RealizarError::FormatError {
                reason: format!(
                    "APR file too small: {} bytes (need at least {})",
                    mmap.len(),
                    APR_TRANSFORMER_HEADER_SIZE
                ),
            });
        }
        let header_bytes = &mmap[..APR_TRANSFORMER_HEADER_SIZE];
        let magic = header_bytes
            .get(0..4)
            .expect("APR header validated to have required size above");
        if magic != MAGIC {
            return Err(RealizarError::FormatError {
                reason: format!("Invalid APR magic: expected {:?}, got {:?}", MAGIC, magic),
            });
        }
        let version = u32::from_le_bytes([
            header_bytes[4],
            header_bytes[5],
            header_bytes[6],
            header_bytes[7],
        ]);
        if version > 1 {
            return Err(RealizarError::FormatError {
                reason: format!("Unsupported APR version: {version}"),
            });
        }
        let hidden_dim = u32::from_le_bytes([
            header_bytes[8],
            header_bytes[9],
            header_bytes[10],
            header_bytes[11],
        ]) as usize;
        let num_layers = u32::from_le_bytes([
            header_bytes[12],
            header_bytes[13],
            header_bytes[14],
            header_bytes[15],
        ]) as usize;
        let num_heads = u32::from_le_bytes([
            header_bytes[16],
            header_bytes[17],
            header_bytes[18],
            header_bytes[19],
        ]) as usize;
        let num_kv_heads = u32::from_le_bytes([
            header_bytes[20],
            header_bytes[21],
            header_bytes[22],
            header_bytes[23],
        ]) as usize;
        let vocab_size = u32::from_le_bytes([
            header_bytes[24],
            header_bytes[25],
            header_bytes[26],
            header_bytes[27],
        ]) as usize;
        let intermediate_dim = u32::from_le_bytes([
            header_bytes[28],
            header_bytes[29],
            header_bytes[30],
            header_bytes[31],
        ]) as usize;
        let context_length = u32::from_le_bytes([
            header_bytes[32],
            header_bytes[33],
            header_bytes[34],
            header_bytes[35],
        ]) as usize;
        let rope_theta = f32::from_le_bytes([
            header_bytes[36],
            header_bytes[37],
            header_bytes[38],
            header_bytes[39],
        ]);
        let eps = f32::from_le_bytes([
            header_bytes[40],
            header_bytes[41],
            header_bytes[42],
            header_bytes[43],
        ]);
        let tensor_data_offset = u32::from_le_bytes([
            header_bytes[44],
            header_bytes[45],
            header_bytes[46],
            header_bytes[47],
        ]) as usize;
        let config = AprTransformerConfig {
            architecture: "apr".to_string(),
            hidden_dim,
            num_layers,
            num_heads,
            num_kv_heads,
            vocab_size,
            intermediate_dim,
            context_length,
            rope_theta,
            eps,
            eos_token_id: None,
            ..Default::default()
        };
        Ok(Self {
            mmap,
            config,
            tensor_data_offset,
            is_mmap: true,
        })
    }

    /// Whether the model weights are backed by a memory mapping.
    #[must_use]
    pub fn is_mmap(&self) -> bool {
        self.is_mmap
    }

    /// Returns a slice of raw tensor bytes.
    ///
    /// `offset` is relative to the start of the tensor data section, not to
    /// the start of the file; `len` is the number of bytes to read.
    pub fn get_tensor_bytes(&self, offset: usize, len: usize) -> Result<&[u8]> {
        let start = self.tensor_data_offset + offset;
        let end = start + len;
        if end > self.mmap.len() {
            return Err(RealizarError::FormatError {
                reason: format!(
                    "Tensor access out of bounds: offset={offset}, len={len}, file_size={}",
                    self.mmap.len()
                ),
            });
        }
        Ok(&self.mmap[start..end])
    }

    /// Reads `num_elements` little-endian `f32` values starting at `offset`
    /// (relative to the tensor data section) into an owned vector.
    pub fn get_tensor_f32(&self, offset: usize, num_elements: usize) -> Result<Vec<f32>> {
        let bytes = self.get_tensor_bytes(offset, num_elements * 4)?;
        let floats: Vec<f32> = bytes
            .chunks_exact(4)
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect();
        Ok(floats)
    }

    /// Total size of the mapped file in bytes.
    #[must_use]
    pub fn file_size(&self) -> usize {
        self.mmap.len()
    }

    /// Approximate parameter count implied by the configuration dimensions.
    #[must_use]
    pub fn num_parameters(&self) -> usize {
        let hidden = self.config.hidden_dim;
        let vocab = self.config.vocab_size;
        let layers = self.config.num_layers;
        let intermediate = self.config.intermediate_dim;
        // Token embedding plus output (LM head) projection.
        let embed_params = vocab * hidden * 2;
        // Per layer: norm, QKV projection, attention output projection, and
        // the two feed-forward projections.
        let layer_params = hidden
            + (hidden * 3 * hidden)
            + (hidden * hidden)
            + (hidden * intermediate)
            + (intermediate * hidden);
        // Final output norm.
        let norm_params = hidden;
        embed_params + (layers * layer_params) + norm_params
    }
}
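
// A minimal usage sketch (not compiled here): the path, tensor offset, and
// element count below are placeholders; real offsets come from whatever
// tensor index the surrounding code uses, which this module does not parse.
//
//     let model = MmapAprTransformer::from_file("model.apr")?;
//     assert!(model.is_mmap());
//     println!(
//         "{} params, {} bytes on disk",
//         model.num_parameters(),
//         model.file_size()
//     );
//     // Offsets passed to the accessors are relative to the tensor data
//     // section, not to the start of the file.
//     let first_row = model.get_tensor_f32(0, model.config.hidden_dim)?;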

/// Quantization format of the tensor data stored in an APR file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[allow(non_camel_case_types)]
pub enum AprQuantizationType {
    /// Unquantized 32-bit floats.
    #[default]
    F32,
    /// 4-bit "K" quantization, 256 values per block.
    Q4_K,
    /// 8-bit quantization, 32 values per block.
    Q8_0,
}

impl AprQuantizationType {
    /// Nominal bits used per weight for this format.
    #[must_use]
    pub fn bits_per_weight(&self) -> f64 {
        match self {
            Self::F32 => 32.0,
            Self::Q4_K => 4.5,
            Self::Q8_0 => 8.0,
        }
    }

    /// Size in bytes of one quantization block (a single value for `F32`).
    #[must_use]
    pub fn bytes_per_block(&self) -> usize {
        match self {
            Self::F32 => 4,
            Self::Q4_K => 144,
            Self::Q8_0 => 36,
        }
    }

    /// Number of weight values packed into one block.
    #[must_use]
    pub fn values_per_block(&self) -> usize {
        match self {
            Self::F32 => 1,
            Self::Q4_K => 256,
            Self::Q8_0 => 32,
        }
    }

    /// Encodes the quantization type as a single-byte tag.
    #[must_use]
    pub fn to_byte(&self) -> u8 {
        match self {
            Self::F32 => 0,
            Self::Q4_K => 1,
            Self::Q8_0 => 2,
        }
    }

    /// Decodes a quantization type from its byte tag; returns `None` for
    /// unrecognized values.
    #[must_use]
    pub fn from_byte(byte: u8) -> Option<Self> {
        match byte {
            0 => Some(Self::F32),
            1 => Some(Self::Q4_K),
            2 => Some(Self::Q8_0),
            _ => None,
        }
    }
}
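
// How the block constants relate (a sketch, not compiled here). For F32 and
// Q4_K the per-block byte count fully accounts for the reported bits per
// weight, e.g. 144 bytes * 8 / 256 values = 4.5 bits for Q4_K:
//
//     let q = AprQuantizationType::Q4_K;
//     assert_eq!(AprQuantizationType::from_byte(q.to_byte()), Some(q));
//     let bits = q.bytes_per_block() as f64 * 8.0 / q.values_per_block() as f64;
//     assert_eq!(bits, q.bits_per_weight()); // 4.5
//
// For Q8_0 the same ratio is 36 * 8 / 32 = 9.0 bits, while `bits_per_weight`
// reports the nominal 8.0; the extra bits are presumably per-block metadata
// such as a scale.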

/// APR transformer whose large weight matrices are held as quantized byte
/// buffers in memory, with the token embedding and norms kept as `f32`.
#[derive(Debug, Clone)]
pub struct QuantizedAprTransformer {
    config: AprTransformerConfig,
    quant_type: AprQuantizationType,
    token_embedding: Vec<f32>,
    layer_weights: Vec<Vec<u8>>,
    output_norm_weight: Vec<f32>,
    lm_head_weight: Vec<u8>,
}
include!("quantized_transformer.rs");
include!("loader_tests.rs");