use crate::format::f16_safety::F16_MIN_NORMAL;
use crate::format::gguf::dequant::{dequantize_q4_k, dequantize_q6_k};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io::{Read, Write};
/// Computes the CRC-32 checksum of `data`.
///
/// Uses the reflected polynomial `0xEDB8_8320` with an initial register of
/// all ones and a final bitwise inversion, driven by a 256-entry lookup
/// table that is evaluated at compile time.
fn crc32(data: &[u8]) -> u32 {
    // Reflected form of the CRC-32 generator polynomial.
    const POLY: u32 = 0xEDB8_8320;

    // Remainder obtained by pushing the single byte `index` through eight
    // shift/xor steps.
    const fn table_entry(index: u32) -> u32 {
        let mut value = index;
        let mut bit = 0;
        while bit < 8 {
            value = if value & 1 == 1 {
                (value >> 1) ^ POLY
            } else {
                value >> 1
            };
            bit += 1;
        }
        value
    }

    // `while` is used because `for` loops are not available in const context.
    const TABLE: [u32; 256] = {
        let mut entries = [0u32; 256];
        let mut n = 0usize;
        while n < 256 {
            entries[n] = table_entry(n as u32);
            n += 1;
        }
        entries
    };

    let register = data.iter().fold(u32::MAX, |acc, &byte| {
        // Only the low eight bits of the xor select the table slot.
        let slot = (acc ^ u32::from(byte)) as u8;
        (acc >> 8) ^ TABLE[usize::from(slot)]
    });
    !register
}
/// Converts an `f32` into a 16-bit half-precision bit pattern.
///
/// Thin wrapper that forwards to `trueno::f32_to_f16`; rounding and
/// overflow behaviour are whatever that function implements
/// (NOTE(review): not visible from this file — confirm in `trueno`).
fn f32_to_f16(value: f32) -> u16 {
trueno::f32_to_f16(value)
}
/// Expands a 16-bit half-precision bit pattern into an `f32`.
///
/// Thin wrapper that forwards to `trueno::f16_to_f32`. Callers in this file
/// (see `dequantize_q4`) check the result for NaN/infinity themselves, so
/// this function is presumably not expected to sanitise its output.
fn f16_to_f32(bits: u16) -> f32 {
trueno::f16_to_f32(bits)
}
/// Expands block-quantized 4-bit data into exactly `element_count` `f32` values.
///
/// Each block holds up to 32 values and is read as a little-endian f16 scale
/// (2 bytes) followed by 16 bytes of packed nibbles, low nibble first. Every
/// nibble is recentred by subtracting 8 and multiplied by the block scale.
///
/// A scale that is NaN, infinite, or smaller in magnitude than
/// `F16_MIN_NORMAL` is replaced by `0.0`. If `data` ends early, decoding
/// stops and the rest of the output is filled with `0.0`.
fn dequantize_q4(data: &[u8], element_count: usize) -> Vec<f32> {
    const BLOCK_SIZE: usize = 32;
    const SCALE_BYTES: usize = 2;
    const PACKED_BYTES: usize = BLOCK_SIZE / 2;

    let mut out = Vec::with_capacity(element_count);
    let mut cursor = 0usize;
    let mut left = element_count;

    while left > 0 {
        // A block needs a complete scale header; stop at the first one that is missing.
        let header = match data.get(cursor..cursor + SCALE_BYTES) {
            Some(bytes) => bytes,
            None => break,
        };
        let decoded = f16_to_f32(u16::from_le_bytes([header[0], header[1]]));
        let scale = if !decoded.is_finite() || decoded.abs() < F16_MIN_NORMAL {
            0.0
        } else {
            decoded
        };
        cursor += SCALE_BYTES;

        // Decode as many nibbles as the block still owes, limited by the
        // bytes actually present in `data`.
        let wanted = left.min(BLOCK_SIZE);
        let packed_end = data.len().min(cursor + PACKED_BYTES);
        let packed = &data[cursor..packed_end];
        out.extend(
            packed
                .iter()
                .flat_map(|&byte| [byte & 0x0F, byte >> 4])
                .take(wanted)
                .map(|nibble| f32::from(nibble as i8 - 8) * scale),
        );

        // The cursor always advances by a full block, even for a short final block.
        cursor += PACKED_BYTES;
        left = left.saturating_sub(BLOCK_SIZE);
    }

    // Pad whatever could not be decoded so the length contract always holds.
    out.resize(element_count, 0.0);
    out
}
/// Magic bytes for the v2 format: ASCII `APR` followed by a NUL byte.
pub const MAGIC_V2: [u8; 4] = [0x41, 0x50, 0x52, 0x00];
/// Format version pair `(2, 0)` — presumably `(major, minor)`; confirm
/// against the header read/write code.
pub const VERSION_V2: (u8, u8) = (2, 0);
/// Header size for the v2 format: 64 (presumably bytes as serialized).
pub const HEADER_SIZE_V2: usize = 64;
/// Alignment constant: 64. Presumably the byte boundary that tensor data is
/// padded to — verify in the writer.
pub const ALIGNMENT: usize = 64;
/// LZ4 block size: 64 KiB (65 536).
pub const LZ4_BLOCK_SIZE: usize = 64 * 1024;
/// Upper bound on metadata size: 16 MiB (16 777 216).
pub const MAX_METADATA_SIZE: usize = 16 * 1024 * 1024;
/// Upper bound on tensor-name length: 256 (presumably bytes — confirm
/// where the limit is enforced).
pub const MAX_TENSOR_NAME_LEN: usize = 256;
/// Bit set stored in a single `u16`, one bit per format feature.
///
/// Values are immutable: [`with`](Self::with) and
/// [`without`](Self::without) return a modified copy.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct AprV2Flags(u16);

impl AprV2Flags {
    /// Bit 0: LZ4 compression marker.
    pub const LZ4_COMPRESSED: u16 = 1 << 0;
    /// Bit 1: Zstandard compression marker.
    pub const ZSTD_COMPRESSED: u16 = 1 << 1;
    /// Bit 2: encryption marker.
    pub const ENCRYPTED: u16 = 1 << 2;
    /// Bit 3: signature marker.
    pub const SIGNED: u16 = 1 << 3;
    /// Bit 4: sharding marker.
    pub const SHARDED: u16 = 1 << 4;
    /// Bit 5: quantization marker.
    pub const QUANTIZED: u16 = 1 << 5;
    /// Bit 6: filterbank-present marker.
    pub const HAS_FILTERBANK: u16 = 1 << 6;
    /// Bit 7: model-card-present marker.
    pub const HAS_MODEL_CARD: u16 = 1 << 7;
    /// Bit 8: streaming marker.
    pub const STREAMING: u16 = 1 << 8;
    /// Bit 9: vocabulary-present marker.
    pub const HAS_VOCAB: u16 = 1 << 9;
    /// Bit 10: row-major layout marker.
    pub const LAYOUT_ROW_MAJOR: u16 = 1 << 10;
    /// Bit 11: column-major layout marker.
    pub const LAYOUT_COLUMN_MAJOR: u16 = 1 << 11;

    /// Creates a flag set with no bits set.
    #[must_use]
    pub const fn new() -> Self {
        Self::from_bits(0)
    }

    /// Wraps a raw bit pattern without validating it.
    #[must_use]
    pub const fn from_bits(bits: u16) -> Self {
        Self(bits)
    }

    /// Returns the raw bit pattern.
    #[must_use]
    pub const fn bits(self) -> u16 {
        let Self(raw) = self;
        raw
    }

    /// Returns `true` when every bit of `flag` is set in `self`.
    ///
    /// An empty mask (`0`) is therefore always contained.
    #[must_use]
    pub const fn contains(self, flag: u16) -> bool {
        let overlap = self.0 & flag;
        overlap == flag
    }

    /// Returns a copy with all bits of `flag` set.
    #[must_use]
    pub const fn with(self, flag: u16) -> Self {
        Self::from_bits(self.0 | flag)
    }

    /// Returns a copy with all bits of `flag` cleared.
    #[must_use]
    pub const fn without(self, flag: u16) -> Self {
        let keep = !flag;
        Self::from_bits(self.0 & keep)
    }

    /// Reports whether the LZ4 bit is set.
    #[must_use]
    pub const fn is_lz4_compressed(self) -> bool {
        (self.0 & Self::LZ4_COMPRESSED) != 0
    }

    /// Reports whether the Zstandard bit is set.
    #[must_use]
    pub const fn is_zstd_compressed(self) -> bool {
        (self.0 & Self::ZSTD_COMPRESSED) != 0
    }

    /// Reports whether the encryption bit is set.
    #[must_use]
    pub const fn is_encrypted(self) -> bool {
        (self.0 & Self::ENCRYPTED) != 0
    }

    /// Reports whether the sharding bit is set.
    #[must_use]
    pub const fn is_sharded(self) -> bool {
        (self.0 & Self::SHARDED) != 0
    }

    /// Reports whether the quantization bit is set.
    #[must_use]
    pub const fn is_quantized(self) -> bool {
        (self.0 & Self::QUANTIZED) != 0
    }

    /// Reports whether the row-major layout bit is set.
    #[must_use]
    pub const fn is_row_major(self) -> bool {
        (self.0 & Self::LAYOUT_ROW_MAJOR) != 0
    }

    /// Reports whether the column-major layout bit is set.
    #[must_use]
    pub const fn is_column_major(self) -> bool {
        (self.0 & Self::LAYOUT_COLUMN_MAJOR) != 0
    }

    /// Returns `true` unless the column-major bit is set.
    ///
    /// The row-major bit is not consulted, so a flag set with neither layout
    /// bit is considered valid.
    #[must_use]
    pub const fn is_layout_valid(self) -> bool {
        (self.0 & Self::LAYOUT_COLUMN_MAJOR) == 0
    }
}
/// In-memory representation of the v2 file header.
///
/// The declared field sizes add up to 64 bytes (4 + 2 + 2 + 4 + 8 + 4 + 8 +
/// 8 + 4 + 20), matching `HEADER_SIZE_V2`. NOTE(review): with `#[repr(C)]`
/// the `u64` fields are 8-byte aligned, so the in-memory struct may contain
/// padding and need not be 64 bytes; presumably the header is serialized
/// field by field rather than by copying the struct — confirm in the
/// header implementation.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct AprV2Header {
/// Magic bytes identifying the format (compare with `MAGIC_V2`).
pub magic: [u8; 4],
/// Version pair (compare with `VERSION_V2`); presumably `(major, minor)`.
pub version: (u8, u8),
/// Feature flag bit set.
pub flags: AprV2Flags,
/// Number of tensors (presumably entries in the tensor index).
pub tensor_count: u32,
/// Offset of the metadata section (presumably bytes from file start).
pub metadata_offset: u64,
/// Size of the metadata section (presumably in bytes).
pub metadata_size: u32,
/// Offset of the tensor index (presumably bytes from file start).
pub tensor_index_offset: u64,
/// Offset of the tensor data region (presumably bytes from file start).
pub data_offset: u64,
/// Checksum value; which bytes it covers is not visible here — verify
/// against the header implementation.
pub checksum: u32,
/// Reserved bytes (20); no use is visible in this excerpt.
pub reserved: [u8; 20],
}
impl Default for AprV2Header {
/// Returns the same value as `AprV2Header::new()`, which is defined outside
/// this excerpt (presumably in the included header implementation).
fn default() -> Self {
Self::new()
}
}
// The remaining items of this module live in sibling files that are pasted
// in textually by `include!`. They therefore share this file's `use`
// declarations and private helpers (`crc32`, `f16_to_f32`, `dequantize_q4`, ...),
// so an import that looks unused here may be required by one of them.
include!("header_impl.rs");
include!("tensor_index_impl.rs");
include!("writer.rs");
include!("streaming_writer.rs");
include!("reader_impl.rs");
include!("v2format_error.rs");