use std::collections::HashMap;
use std::fs::{self, File};
use std::path::Path;
use serde::{Deserialize, Serialize};
use crate::error::{RealizarError, Result};
use crate::safetensors::find_sibling_file;
use crate::tokenizer::SentencePieceTokenizer;
#[cfg(feature = "cuda")]
mod cuda;
#[cfg(all(test, feature = "cuda"))]
mod cuda_tests;
pub mod dequant;
mod helpers;
mod model_data;
pub(crate) mod tokenizer;
pub use model_data::ModelData;
pub(crate) use dequant::{dequantize_f16, dequantize_q4_k, dequantize_q6_k, dequantize_q8_0};
pub use dequant::{dtype_to_ggml_qtype, f16_to_f32, is_quantized_dtype};
#[cfg(test)]
pub(crate) mod test_factory;
#[cfg(feature = "cuda")]
pub use cuda::AprV2ModelCuda;
#[cfg(feature = "cuda")]
use helpers::transpose_matrix;
use helpers::{apply_rope_norm, matmul, rms_norm, simple_attention};
pub use helpers::{detect_format, is_apr_file, simd_dot};
use tokenizer::bpe_encode;
pub use tokenizer::{byte_to_bpe_char, BpeTokenizer, SimpleTokenizer};
/// First three bytes of every `.apr` file: ASCII "APR" (0x41 0x50 0x52).
pub const MAGIC_PREFIX: [u8; 3] = [0x41, 0x50, 0x52];
/// Full four-byte magic: "APR" followed by a 0x00 version byte.
/// (`AprHeader::from_bytes` also accepts b'1' / b'2' as the fourth byte.)
pub const MAGIC: [u8; 4] = [0x41, 0x50, 0x52, 0x00];
/// Fixed size of the binary header in bytes; only the first 44 bytes
/// are currently decoded by `AprHeader::from_bytes`.
pub const HEADER_SIZE: usize = 64;
/// Alignment in bytes — presumably for tensor-data section padding;
/// not referenced in this file, confirm against the writer side.
pub const ALIGNMENT: usize = 64;
/// Bit flags stored in the `.apr` header (little-endian `u16` at header bytes 6..8).
#[derive(Debug, Clone, Copy, Default)]
pub struct AprFlags(u16);
impl AprFlags {
    /// Payload compressed with LZ4.
    pub const LZ4_COMPRESSED: u16 = 0x0001;
    /// Payload compressed with Zstandard.
    pub const ZSTD_COMPRESSED: u16 = 0x0002;
    /// Payload is encrypted.
    pub const ENCRYPTED: u16 = 0x0004;
    /// File carries a signature.
    pub const SIGNED: u16 = 0x0008;
    /// Model is split across multiple shard files.
    pub const SHARDED: u16 = 0x0010;
    /// Tensor data is quantized.
    pub const QUANTIZED: u16 = 0x0020;
    /// File embeds its own vocabulary.
    pub const HAS_VOCAB: u16 = 0x0200;
    /// Wrap a raw flag word as read from the header.
    #[must_use]
    pub const fn new(bits: u16) -> Self {
        Self(bits)
    }
    /// True when any bit of `mask` is set in the flag word.
    const fn any(self, mask: u16) -> bool {
        self.0 & mask != 0
    }
    /// True when either compression flag (LZ4 or zstd) is set.
    #[must_use]
    pub const fn is_compressed(&self) -> bool {
        self.any(Self::LZ4_COMPRESSED | Self::ZSTD_COMPRESSED)
    }
    /// True when the `LZ4_COMPRESSED` bit is set.
    #[must_use]
    pub const fn is_lz4(&self) -> bool {
        self.any(Self::LZ4_COMPRESSED)
    }
    /// True when the `ZSTD_COMPRESSED` bit is set.
    #[must_use]
    pub const fn is_zstd(&self) -> bool {
        self.any(Self::ZSTD_COMPRESSED)
    }
    /// True when the `ENCRYPTED` bit is set.
    #[must_use]
    pub const fn is_encrypted(&self) -> bool {
        self.any(Self::ENCRYPTED)
    }
    /// True when the `SIGNED` bit is set.
    #[must_use]
    pub const fn is_signed(&self) -> bool {
        self.any(Self::SIGNED)
    }
    /// True when the `SHARDED` bit is set.
    #[must_use]
    pub const fn is_sharded(&self) -> bool {
        self.any(Self::SHARDED)
    }
    /// True when the `QUANTIZED` bit is set.
    #[must_use]
    pub const fn is_quantized(&self) -> bool {
        self.any(Self::QUANTIZED)
    }
    /// True when the `HAS_VOCAB` bit is set.
    #[must_use]
    pub const fn has_vocab(&self) -> bool {
        self.any(Self::HAS_VOCAB)
    }
}
/// Decoded fixed-size `.apr` file header (see [`HEADER_SIZE`] and
/// [`AprHeader::from_bytes`] for the byte layout).
#[derive(Debug, Clone)]
pub struct AprHeader {
    // Raw bytes 0..4: "APR" prefix plus a version byte (0, b'1', or b'2').
    pub magic: [u8; 4],
    // Bytes 4..6, read as (major, minor).
    pub version: (u8, u8),
    // Bytes 6..8, little-endian u16 of feature bits (see `AprFlags`).
    pub flags: AprFlags,
    // Bytes 8..12 — number of tensor entries in the index.
    pub tensor_count: u32,
    // Bytes 12..20 — file offset of the metadata section (presumably
    // absolute from the start of the file; confirm against the writer).
    pub metadata_offset: u64,
    // Bytes 20..24 — size of the metadata section in bytes.
    pub metadata_size: u32,
    // Bytes 24..32 — file offset of the tensor index.
    pub tensor_index_offset: u64,
    // Bytes 32..40 — file offset where tensor data begins.
    pub data_offset: u64,
    // Bytes 40..44 — header checksum; verification is not shown in this file.
    pub checksum: u32,
}
impl AprHeader {
    /// Parse the fixed `.apr` header from the first [`HEADER_SIZE`] bytes of `data`.
    ///
    /// Layout (multi-byte fields little-endian): bytes 0..4 magic
    /// ("APR" + version byte), 4..6 (major, minor) version, 6..8 flags,
    /// 8..12 tensor count, 12..20 metadata offset, 20..24 metadata size,
    /// 24..32 tensor index offset, 32..40 data offset, 40..44 checksum.
    /// Bytes 44..64 are not decoded here.
    ///
    /// # Errors
    /// Returns `FormatError` when `data` is shorter than [`HEADER_SIZE`],
    /// the magic prefix is wrong, or the version byte is not 0 / b'1' / b'2';
    /// returns `UnsupportedOperation` for APR v1 files.
    pub fn from_bytes(data: &[u8]) -> Result<Self> {
        if data.len() < HEADER_SIZE {
            return Err(RealizarError::FormatError {
                reason: format!(
                    ".apr header too small: {} bytes (need {})",
                    data.len(),
                    HEADER_SIZE
                ),
            });
        }
        // From here on `data.len() >= HEADER_SIZE` (64), so every fixed-range
        // slice below is in bounds and the `try_into` conversions cannot fail.
        let read_u32 = |at: usize| {
            u32::from_le_bytes(data[at..at + 4].try_into().expect("length checked above"))
        };
        let read_u64 = |at: usize| {
            u64::from_le_bytes(data[at..at + 8].try_into().expect("length checked above"))
        };
        let magic: [u8; 4] = data[0..4].try_into().expect("length checked above");
        if magic[0..3] != MAGIC_PREFIX {
            return Err(RealizarError::FormatError {
                reason: format!(
                    "Invalid .apr magic: expected APR {:?}, got {:?}",
                    MAGIC_PREFIX,
                    &magic[0..3],
                ),
            });
        }
        // Byte 3 doubles as a coarse format marker: 0 and b'2' are accepted,
        // b'1' (legacy v1) is rejected below, anything else is invalid.
        let version_byte = magic[3];
        if !matches!(version_byte, 0 | b'1' | b'2') {
            return Err(RealizarError::FormatError {
                reason: format!(
                    "Invalid .apr version byte: expected 0, '1', or '2', got {}",
                    version_byte
                ),
            });
        }
        if version_byte == b'1' {
            return Err(RealizarError::UnsupportedOperation {
                operation: "load_apr_v1".to_string(),
                reason: "APR v1 format not supported for inference. \
                    Use 'apr convert model.apr -o model_v2.apr --format apr2' \
                    to convert to APR v2 format, or use the GGUF version."
                    .to_string(),
            });
        }
        Ok(Self {
            magic,
            version: (data[4], data[5]),
            flags: AprFlags::new(u16::from_le_bytes([data[6], data[7]])),
            tensor_count: read_u32(8),
            metadata_offset: read_u64(12),
            metadata_size: read_u32(20),
            tensor_index_offset: read_u64(24),
            data_offset: read_u64(32),
            checksum: read_u32(40),
        })
    }
}
/// One entry of the `.apr` tensor index, as decoded by
/// [`TensorEntry::from_binary`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorEntry {
    // Tensor name (UTF-8, lossily decoded from the on-disk bytes).
    pub name: String,
    // Dtype label, e.g. "q4", "q8", "F32", or a GGML quant-type name.
    pub dtype: String,
    // Dimension sizes, in the order stored on disk.
    pub shape: Vec<usize>,
    // Byte offset of the tensor payload — presumably relative to the
    // header's data_offset; confirm against the loader.
    pub offset: u64,
    // Payload size in bytes.
    pub size: u64,
}
impl TensorEntry {
    /// Decode one binary tensor-index entry, returning the entry together
    /// with the number of bytes consumed so the caller can advance to the
    /// next entry.
    ///
    /// On-disk layout: u16 name length, name bytes, u8 dtype id, u8 ndim,
    /// then `ndim` little-endian u64 dims, u64 offset, u64 size.
    ///
    /// # Errors
    /// Returns `FormatError` when `data` is shorter than the entry claims.
    pub fn from_binary(data: &[u8]) -> Result<(Self, usize)> {
        // Little-endian u64 at `at`; every call site bounds-checks first.
        let read_u64 = |at: usize| {
            let mut raw = [0u8; 8];
            raw.copy_from_slice(&data[at..at + 8]);
            u64::from_le_bytes(raw)
        };
        if data.len() < 4 {
            return Err(RealizarError::FormatError {
                reason: "Tensor entry too short".to_string(),
            });
        }
        let name_len = u16::from_le_bytes([data[0], data[1]]) as usize;
        let mut cursor = 2;
        // Need the name plus the dtype and ndim bytes that follow it.
        if data.len() < cursor + name_len + 2 {
            return Err(RealizarError::FormatError {
                reason: "Tensor entry truncated at name".to_string(),
            });
        }
        let name = String::from_utf8_lossy(&data[cursor..cursor + name_len]).to_string();
        cursor += name_len;
        let dtype_byte = data[cursor];
        cursor += 1;
        let dtype = match dtype_byte {
            // These ids map directly to labels rather than through the
            // GGML quant-type table.
            8 | 128 => "q4".to_string(),
            9 | 129 => "q8".to_string(),
            _ => crate::gguf::GgmlQuantType::from_id(u32::from(dtype_byte))
                .map_or_else(
                    || {
                        eprintln!("WARN: Unknown APR dtype byte {dtype_byte}, treating as F32");
                        "F32"
                    },
                    crate::gguf::GgmlQuantType::as_str,
                )
                .to_string(),
        };
        let ndim = data[cursor] as usize;
        cursor += 1;
        // Need `ndim` u64 dims plus the trailing offset and size fields.
        if data.len() < cursor + ndim * 8 + 16 {
            return Err(RealizarError::FormatError {
                reason: "Tensor entry truncated at shape".to_string(),
            });
        }
        let mut shape = Vec::with_capacity(ndim);
        for _ in 0..ndim {
            shape.push(read_u64(cursor) as usize);
            cursor += 8;
        }
        let offset = read_u64(cursor);
        cursor += 8;
        let size = read_u64(cursor);
        cursor += 8;
        Ok((
            Self {
                name,
                dtype,
                shape,
                offset,
                size,
            },
            cursor,
        ))
    }
    /// Number of scalar elements implied by `shape` (product of the dims;
    /// an empty shape yields 1).
    pub fn element_count(&self) -> usize {
        self.shape.iter().copied().product()
    }
}
// The remaining pieces of this module are textually included rather than
// declared as `mod`s — presumably so their code shares this file's scope
// (private items, `use` statements); confirm against those files.
include!("metadata.rs");
include!("tokenizer_loading.rs");
include!("special_tokens.rs");