use serde::Serialize;
use std::collections::BTreeMap;
pub use super::test_factory::{
build_minimal_llama_gguf, create_f32_embedding_data, create_f32_norm_weights, create_q4_0_data,
create_q8_0_data, GGUFBuilder,
};
/// One tensor's entry in the safetensors JSON header.
///
/// Serialized as `{"dtype": "...", "shape": [...], "data_offsets": [begin, end]}`
/// under the tensor's name in the header map built by `SafetensorsBuilder::build`.
#[derive(Debug, Clone, Serialize)]
struct SafetensorsTensorMeta {
    /// Element type tag written verbatim, e.g. "F32", "F16", "BF16".
    dtype: String,
    /// Tensor dimensions.
    shape: Vec<usize>,
    /// Byte range `[begin, end)` of this tensor's payload, relative to the
    /// start of the data section (the byte right after the JSON header).
    data_offsets: [usize; 2],
}
/// Incremental builder producing an in-memory safetensors file for tests.
///
/// Tensors are kept in insertion order; `build` lays their payloads out
/// back-to-back immediately after the JSON header.
pub struct SafetensorsBuilder {
    // Each entry: (tensor name, dtype tag string, shape, raw little-endian bytes).
    tensors: Vec<(String, String, Vec<usize>, Vec<u8>)>,
}
impl Default for SafetensorsBuilder {
fn default() -> Self {
Self::new()
}
}
impl SafetensorsBuilder {
    /// Creates a builder with no tensors.
    #[must_use]
    pub fn new() -> Self {
        Self { tensors: Vec::new() }
    }

    /// Appends an F32 tensor; `data` is encoded to little-endian bytes.
    #[must_use]
    pub fn add_f32_tensor(mut self, name: &str, shape: &[usize], data: &[f32]) -> Self {
        let mut bytes = Vec::with_capacity(data.len() * 4);
        for value in data {
            bytes.extend_from_slice(&value.to_le_bytes());
        }
        self.tensors
            .push((name.to_string(), "F32".to_string(), shape.to_vec(), bytes));
        self
    }

    /// Appends an F16 tensor from caller-encoded raw bytes.
    #[must_use]
    pub fn add_f16_tensor(mut self, name: &str, shape: &[usize], data: &[u8]) -> Self {
        let entry = (
            name.to_string(),
            "F16".to_string(),
            shape.to_vec(),
            data.to_vec(),
        );
        self.tensors.push(entry);
        self
    }

    /// Appends a BF16 tensor from caller-encoded raw bytes.
    #[must_use]
    pub fn add_bf16_tensor(mut self, name: &str, shape: &[usize], data: &[u8]) -> Self {
        let entry = (
            name.to_string(),
            "BF16".to_string(),
            shape.to_vec(),
            data.to_vec(),
        );
        self.tensors.push(entry);
        self
    }

    /// Serializes the collected tensors into safetensors bytes:
    /// an 8-byte little-endian JSON length, the JSON header, then each
    /// tensor payload concatenated in insertion order.
    ///
    /// NOTE(review): duplicate tensor names overwrite the metadata entry
    /// but both payloads are still appended, leaving an unreferenced gap —
    /// confirm callers always use unique names.
    #[must_use]
    pub fn build(self) -> Vec<u8> {
        // Walk tensors in insertion order so data_offsets match the
        // payload layout appended below.
        let mut metadata: BTreeMap<String, SafetensorsTensorMeta> = BTreeMap::new();
        let mut offset = 0usize;
        for (name, dtype, shape, payload) in &self.tensors {
            let end = offset + payload.len();
            let meta = SafetensorsTensorMeta {
                dtype: dtype.clone(),
                shape: shape.clone(),
                data_offsets: [offset, end],
            };
            metadata.insert(name.clone(), meta);
            offset = end;
        }
        let json = serde_json::to_string(&metadata).expect("JSON serialization");
        let json_bytes = json.as_bytes();
        let mut out = Vec::new();
        out.extend_from_slice(&(json_bytes.len() as u64).to_le_bytes());
        out.extend_from_slice(json_bytes);
        for (_, _, _, payload) in &self.tensors {
            out.extend_from_slice(payload);
        }
        out
    }

    /// Builds a tiny two-tensor model (token embeddings + final norm)
    /// sized by `vocab_size` and `hidden_dim`, for format tests.
    #[must_use]
    pub fn minimal_model(vocab_size: usize, hidden_dim: usize) -> Vec<u8> {
        let embed_data = create_f32_embedding_data(vocab_size, hidden_dim);
        let norm_data = create_f32_norm_weights(hidden_dim);
        Self::new()
            .add_f32_tensor(
                "model.embed_tokens.weight",
                &[vocab_size, hidden_dim],
                &embed_data,
            )
            .add_f32_tensor("model.norm.weight", &[hidden_dim], &norm_data)
            .build()
    }
}
/// Magic bytes at the start of every APR file.
const APR_MAGIC: &[u8; 4] = b"APR\0";
/// Format version written into the header (v2.0).
const APR_VERSION_MAJOR: u8 = 2;
const APR_VERSION_MINOR: u8 = 0;
/// Fixed header size in bytes; `AprBuilder::build` asserts this layout.
const APR_HEADER_SIZE: usize = 64;
/// Alignment boundary for the metadata, index, and payload sections.
const APR_ALIGNMENT: usize = 64;
/// Incremental builder producing an in-memory APR v2 model file for tests.
pub struct AprBuilder {
    // JSON metadata key/value pairs; BTreeMap keeps keys sorted in the output.
    metadata: BTreeMap<String, serde_json::Value>,
    // Each entry: (name, shape, APR dtype code, raw payload bytes), insertion order.
    tensors: Vec<(String, Vec<usize>, u32, Vec<u8>)>,
}
impl Default for AprBuilder {
fn default() -> Self {
Self::new()
}
}
/// Dtype codes stored in each APR tensor-index entry.
/// NOTE(review): the values (F32=0, F16=1, Q4_0=2, Q8_0=8) appear to
/// mirror GGML's type codes — confirm against the APR reader.
pub const APR_DTYPE_F32: u32 = 0;
pub const APR_DTYPE_F16: u32 = 1;
pub const APR_DTYPE_Q4_0: u32 = 2;
pub const APR_DTYPE_Q8_0: u32 = 8;
impl AprBuilder {
    /// Creates a builder with no metadata keys and no tensors.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: BTreeMap::new(),
            tensors: Vec::new(),
        }
    }

    /// Sets the `architecture` metadata entry (e.g. "llama").
    #[must_use]
    pub fn architecture(mut self, arch: &str) -> Self {
        self.metadata
            .insert("architecture".to_string(), serde_json::json!(arch));
        self
    }

    /// Sets the `hidden_dim` metadata entry.
    #[must_use]
    pub fn hidden_dim(mut self, dim: usize) -> Self {
        self.metadata
            .insert("hidden_dim".to_string(), serde_json::json!(dim));
        self
    }

    /// Sets the `num_layers` metadata entry.
    #[must_use]
    pub fn num_layers(mut self, count: usize) -> Self {
        self.metadata
            .insert("num_layers".to_string(), serde_json::json!(count));
        self
    }

    /// Appends an F32 tensor; `data` is encoded to little-endian bytes.
    #[must_use]
    pub fn add_f32_tensor(mut self, name: &str, shape: &[usize], data: &[f32]) -> Self {
        let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
        self.tensors
            .push((name.to_string(), shape.to_vec(), APR_DTYPE_F32, bytes));
        self
    }

    /// Appends a Q4_0 tensor from caller-quantized raw bytes.
    #[must_use]
    pub fn add_q4_0_tensor(mut self, name: &str, shape: &[usize], data: &[u8]) -> Self {
        self.tensors.push((
            name.to_string(),
            shape.to_vec(),
            APR_DTYPE_Q4_0,
            data.to_vec(),
        ));
        self
    }

    /// Appends a Q8_0 tensor from caller-quantized raw bytes.
    #[must_use]
    pub fn add_q8_0_tensor(mut self, name: &str, shape: &[usize], data: &[u8]) -> Self {
        self.tensors.push((
            name.to_string(),
            shape.to_vec(),
            APR_DTYPE_Q8_0,
            data.to_vec(),
        ));
        self
    }

    /// Serializes the builder into APR v2 bytes.
    ///
    /// File layout (each section zero-padded to `APR_ALIGNMENT`):
    /// 64-byte header | JSON metadata | tensor index | tensor payloads.
    ///
    /// Each index entry is: name_len (u32), name bytes, ndims (u32),
    /// one u64 per dim, dtype (u32), payload offset (u64, relative to
    /// the data section), payload size (u64, unpadded).
    #[must_use]
    pub fn build(self) -> Vec<u8> {
        let mut data = Vec::new();
        let json = serde_json::to_string(&self.metadata).expect("JSON serialization");
        let json_bytes = json.as_bytes();
        let json_padded_len = json_bytes.len().div_ceil(APR_ALIGNMENT) * APR_ALIGNMENT;
        // First pass: encode the index while computing each payload's
        // aligned offset within the data section.
        let mut tensor_index = Vec::new();
        let mut tensor_data_offset = 0u64;
        for (name, shape, dtype, tensor_bytes) in &self.tensors {
            let name_bytes = name.as_bytes();
            tensor_index.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
            tensor_index.extend_from_slice(name_bytes);
            tensor_index.extend_from_slice(&(shape.len() as u32).to_le_bytes());
            for dim in shape {
                tensor_index.extend_from_slice(&(*dim as u64).to_le_bytes());
            }
            tensor_index.extend_from_slice(&dtype.to_le_bytes());
            tensor_index.extend_from_slice(&tensor_data_offset.to_le_bytes());
            // Size field is the real (unpadded) payload length.
            tensor_index.extend_from_slice(&(tensor_bytes.len() as u64).to_le_bytes());
            // Offsets advance by the alignment-rounded payload size.
            let aligned_size = tensor_bytes.len().div_ceil(APR_ALIGNMENT) * APR_ALIGNMENT;
            tensor_data_offset += aligned_size as u64;
        }
        let index_padded_len = tensor_index.len().div_ceil(APR_ALIGNMENT) * APR_ALIGNMENT;
        // Section offsets are absolute positions within the file.
        let metadata_offset = APR_HEADER_SIZE as u64;
        let tensor_index_offset = metadata_offset + json_padded_len as u64;
        let data_offset = tensor_index_offset + index_padded_len as u64;
        // 64-byte header: magic(4) + major(1) + minor(1) + reserved(2)
        // + tensor count(4) + metadata offset(8) + metadata len(4)
        // + index offset(8) + data offset(8) + reserved(4) + padding(20).
        data.extend_from_slice(APR_MAGIC);
        data.push(APR_VERSION_MAJOR);
        data.push(APR_VERSION_MINOR);
        data.extend_from_slice(&0u16.to_le_bytes());
        data.extend_from_slice(&(self.tensors.len() as u32).to_le_bytes());
        data.extend_from_slice(&metadata_offset.to_le_bytes());
        // Metadata length is the unpadded JSON byte count.
        data.extend_from_slice(&(json_bytes.len() as u32).to_le_bytes());
        data.extend_from_slice(&tensor_index_offset.to_le_bytes());
        data.extend_from_slice(&data_offset.to_le_bytes());
        data.extend_from_slice(&0u32.to_le_bytes());
        data.extend([0u8; 20]);
        assert_eq!(data.len(), APR_HEADER_SIZE);
        data.extend_from_slice(json_bytes);
        data.resize(APR_HEADER_SIZE + json_padded_len, 0);
        data.extend_from_slice(&tensor_index);
        data.resize(APR_HEADER_SIZE + json_padded_len + index_padded_len, 0);
        // Payloads, each zero-padded to the alignment boundary so the
        // offsets recorded in the first pass line up.
        for (_, _, _, tensor_bytes) in &self.tensors {
            let start = data.len();
            data.extend_from_slice(tensor_bytes);
            let aligned_end = start + tensor_bytes.len().div_ceil(APR_ALIGNMENT) * APR_ALIGNMENT;
            data.resize(aligned_end, 0);
        }
        data
    }

    /// Builds a tiny llama-style model (token embeddings + final norm)
    /// sized by `vocab_size` and `hidden_dim`, for format tests.
    #[must_use]
    pub fn minimal_model(vocab_size: usize, hidden_dim: usize) -> Vec<u8> {
        let embed_data = create_f32_embedding_data(vocab_size, hidden_dim);
        let norm_data = create_f32_norm_weights(hidden_dim);
        Self::new()
            .architecture("llama")
            .hidden_dim(hidden_dim)
            .num_layers(1)
            .add_f32_tensor("token_embd.weight", &[vocab_size, hidden_dim], &embed_data)
            .add_f32_tensor("output_norm.weight", &[hidden_dim], &norm_data)
            .build()
    }
}
/// Model container format detected from a file's leading bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FormatType {
    Gguf,
    SafeTensors,
    Apr,
    Unknown,
}
impl FormatType {
    /// Detects the container format from the first bytes of `data`.
    ///
    /// Requires at least 8 bytes; shorter input is `Unknown`. GGUF and
    /// APR are matched on their 4-byte magic; safetensors is inferred
    /// heuristically from a plausible little-endian header length
    /// followed by the start of a JSON object (`{"`).
    #[must_use]
    pub fn from_magic(data: &[u8]) -> Self {
        if data.len() < 8 {
            return Self::Unknown;
        }
        if data.starts_with(b"GGUF") {
            return Self::Gguf;
        }
        // Accept both the v2 magic ("APR\0") and the legacy "APR2" tag.
        if data.starts_with(b"APR\0") || data.starts_with(b"APR2") {
            return Self::Apr;
        }
        if data.len() >= 10 {
            // Safetensors: u64 LE header length, then the JSON header.
            let header_len = u64::from_le_bytes(
                data[0..8].try_into().expect("slice of length 8"),
            );
            // Sanity bound rejects headers claiming >= 100 MB.
            if header_len < 100_000_000 && data[8..10] == *b"{\"" {
                return Self::SafeTensors;
            }
        }
        Self::Unknown
    }
}
include!("format_factory_safetensors_builder.rs");