use crate::error::{AprenderError, Result};
use crate::format::gguf::reader::GgufReader;
use crate::format::rosetta::FormatType;
use crate::format::v2::{AprV2Reader, AprV2ReaderRef, TensorIndexEntry};
use crate::format::HEADER_SIZE;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Read};
use std::path::Path;
/// Metadata (and optional summary statistics) for a single tensor in a model file.
#[derive(Debug, Clone)]
pub struct TensorInfo {
    /// Tensor name as recorded in the file's index/metadata.
    pub name: String,
    /// Dimension sizes of the tensor.
    pub shape: Vec<usize>,
    /// Data type name (e.g. "f32", or a GGML type name for GGUF files).
    pub dtype: String,
    /// Size of the tensor payload in bytes.
    pub size_bytes: usize,
    /// Mean of the element values; `None` unless stats were computed.
    pub mean: Option<f32>,
    /// Standard deviation; `None` unless stats were computed.
    pub std: Option<f32>,
    /// Minimum element value; `None` unless stats were computed.
    pub min: Option<f32>,
    /// Maximum element value; `None` unless stats were computed.
    pub max: Option<f32>,
    /// Count of NaN elements; `None` unless stats were computed.
    pub nan_count: Option<usize>,
    /// Count of infinite elements; `None` unless stats were computed.
    pub inf_count: Option<usize>,
}
/// Aggregate result of listing the tensors in a model file.
#[derive(Debug, Clone)]
pub struct TensorListResult {
    /// Source file path; empty when listing from an in-memory buffer.
    pub file: String,
    /// Detected container format revision (e.g. "v1", "v2").
    pub format_version: String,
    /// Number of tensors matching the filter (not capped by the limit).
    pub tensor_count: usize,
    /// Total size in bytes across all matching tensors (not capped by the limit).
    pub total_size_bytes: usize,
    /// Per-tensor details; truncated at the configured limit.
    pub tensors: Vec<TensorInfo>,
}
/// Options controlling tensor listing: statistics, name filtering, result cap.
#[derive(Debug, Clone)]
pub struct TensorListOptions {
    /// When true, decode each tensor to f32 and compute summary statistics.
    pub compute_stats: bool,
    /// Optional name filter: treated as a glob when it contains `*`/`?`,
    /// otherwise as a plain substring match.
    pub filter: Option<String>,
    /// Maximum number of `TensorInfo` entries returned in the result.
    pub limit: usize,
}
impl Default for TensorListOptions {
    /// Defaults: no statistics, no name filter, effectively unlimited results.
    fn default() -> Self {
        Self {
            compute_stats: false,
            filter: None,
            // `usize::MAX` serves as "no limit".
            limit: usize::MAX,
        }
    }
}
impl TensorListOptions {
    /// Creates options with the defaults (no stats, no filter, no limit).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: enables per-tensor statistics computation.
    #[must_use]
    pub fn with_stats(mut self) -> Self {
        self.compute_stats = true;
        self
    }

    /// Builder: restricts listing to names matching `pattern`
    /// (glob when it contains `*`/`?`, otherwise substring).
    #[must_use]
    pub fn with_filter(mut self, pattern: impl Into<String>) -> Self {
        self.filter = Some(pattern.into());
        self
    }

    /// Builder: caps the number of tensors returned.
    #[must_use]
    pub fn with_limit(mut self, limit: usize) -> Self {
        self.limit = limit;
        self
    }

    /// Returns true when `name` passes the configured filter.
    ///
    /// With no filter set, every name matches. A pattern containing `*` or
    /// `?` is evaluated as a glob; anything else is a substring test.
    pub fn matches_filter(&self, name: &str) -> bool {
        let Some(pattern) = self.filter.as_deref() else {
            return true;
        };
        if pattern.contains('*') || pattern.contains('?') {
            glob_match(pattern, name)
        } else {
            name.contains(pattern)
        }
    }
}
/// Byte-wise glob matcher supporting `*` (any run, including empty) and `?`
/// (exactly one byte).
///
/// Iterative two-pointer algorithm with single-star backtracking: when a
/// mismatch occurs after a `*`, the span the star consumed is widened by one
/// byte and matching resumes. Note `?` matches one *byte*, so a multi-byte
/// UTF-8 character needs several `?`s; tensor names are typically ASCII.
fn glob_match(pattern: &str, text: &str) -> bool {
    let pat = pattern.as_bytes();
    let txt = text.as_bytes();
    let mut p = 0;
    let mut t = 0;
    // Most recent '*' position and the text index where it started matching.
    let mut backtrack: Option<(usize, usize)> = None;
    while t < txt.len() {
        if p < pat.len() && (pat[p] == b'?' || pat[p] == txt[t]) {
            // Literal or single-char wildcard match: advance both.
            p += 1;
            t += 1;
        } else if p < pat.len() && pat[p] == b'*' {
            // Tentatively let '*' match the empty string; remember it.
            backtrack = Some((p, t));
            p += 1;
        } else if let Some((star_p, star_t)) = backtrack {
            // Mismatch: widen the last '*' by one byte and retry.
            backtrack = Some((star_p, star_t + 1));
            p = star_p + 1;
            t = star_t + 1;
        } else {
            return false;
        }
    }
    // Any leftover pattern must be all '*' (each can match the empty suffix).
    pat[p..].iter().all(|&b| b == b'*')
}
/// Magic byte prefixes for the supported APR container revisions.
const MAGIC_APRN: [u8; 4] = *b"APRN";
const MAGIC_APR1: [u8; 4] = *b"APR1";
const MAGIC_APR2: [u8; 4] = *b"APR2";
const MAGIC_APR0: [u8; 4] = *b"APR\0";

/// Maps a magic prefix to its APR format revision string, or `None` when the
/// bytes are not a recognized APR magic.
fn detect_format(magic: &[u8; 4]) -> Option<&'static str> {
    match *magic {
        MAGIC_APRN | MAGIC_APR1 => Some("v1"),
        MAGIC_APR2 | MAGIC_APR0 => Some("v2"),
        _ => None,
    }
}

/// Returns true when `magic` is any recognized APR magic sequence.
#[must_use]
pub fn is_valid_apr_magic(magic: &[u8; 4]) -> bool {
    detect_format(magic).is_some()
}
/// Lists tensors from an in-memory model file, auto-detecting the container
/// format (GGUF, SafeTensors, APR v1/v2) from its leading bytes.
///
/// # Errors
/// Returns `AprenderError::FormatError` when the buffer is too small or its
/// magic bytes do not match any supported format.
pub fn list_tensors_from_bytes(
    data: &[u8],
    options: TensorListOptions,
) -> Result<TensorListResult> {
    if data.len() < 4 {
        return Err(AprenderError::FormatError {
            message: "File too small to contain model header".to_string(),
        });
    }
    // GGUF files open with the literal ASCII magic "GGUF".
    if data.starts_with(b"GGUF") {
        return list_tensors_gguf(data, options);
    }
    // SafeTensors: a little-endian u64 header length, then a JSON object.
    // The 100 MB sanity cap guards against misreading arbitrary bytes as a
    // huge header length.
    if data.len() >= 10 {
        let mut len_bytes = [0u8; 8];
        len_bytes.copy_from_slice(&data[0..8]);
        let header_len = u64::from_le_bytes(len_bytes);
        if header_len < 100_000_000 && data[8..10] == *b"{\"" {
            return list_tensors_safetensors(data, options);
        }
    }
    // Otherwise expect an APR magic in the first four bytes.
    let magic: [u8; 4] = data[0..4]
        .try_into()
        .map_err(|_| AprenderError::FormatError {
            message: "Failed to read magic bytes".to_string(),
        })?;
    let Some(format_version) = detect_format(&magic) else {
        return Err(AprenderError::FormatError {
            message: format!(
                "Unknown model format: magic bytes {:02x}{:02x}{:02x}{:02x}. \
                 Supported formats: APR (.apr), GGUF (.gguf), SafeTensors (.safetensors)",
                magic[0], magic[1], magic[2], magic[3]
            ),
        });
    };
    match format_version {
        "v2" => list_tensors_v2(data, options),
        "v1" => list_tensors_v1(data, options),
        other => Err(AprenderError::FormatError {
            message: format!("Unsupported format version: {other}"),
        }),
    }
}
/// Builds a `TensorInfo` for one v2 index entry, optionally decoding the
/// tensor to f32 to attach summary statistics.
fn build_v2_tensor_info(
    reader: &AprV2Reader,
    name: &str,
    entry: &TensorIndexEntry,
    compute_stats: bool,
) -> TensorInfo {
    let mut info = tensor_info_from_entry(entry);
    // Statistics require materializing the tensor as f32; a `None` from the
    // reader simply leaves the stats fields unset.
    let decoded = if compute_stats {
        reader.get_tensor_as_f32(name)
    } else {
        None
    };
    if let Some(values) = decoded {
        compute_tensor_stats(&mut info, &values);
    }
    info
}
/// Lists tensors from an APR v2 buffer, honoring the filter/limit/stats
/// settings in `options`.
fn list_tensors_v2(data: &[u8], options: TensorListOptions) -> Result<TensorListResult> {
    let reader = AprV2Reader::from_bytes(data).map_err(|e| AprenderError::FormatError {
        message: format!("Failed to parse APR v2: {e}"),
    })?;
    let mut tensors = Vec::new();
    let mut total_size = 0usize;
    let mut total_matching = 0usize;
    for name in reader.tensor_names() {
        if !options.matches_filter(name) {
            continue;
        }
        let Some(entry) = reader.get_tensor(name) else {
            continue;
        };
        // Count and size every match; only the detailed list is capped.
        total_size += entry.size as usize;
        total_matching += 1;
        if tensors.len() < options.limit {
            tensors.push(build_v2_tensor_info(
                &reader,
                name,
                entry,
                options.compute_stats,
            ));
        }
    }
    Ok(TensorListResult {
        // File path is unknown when working from bytes; caller may fill it in.
        file: String::new(),
        format_version: "v2".to_string(),
        tensor_count: total_matching,
        total_size_bytes: total_size,
        tensors,
    })
}
/// Variant of `list_tensors_v2` built on the borrowing `AprV2ReaderRef`,
/// intended for zero-copy use over memory-mapped buffers.
fn list_tensors_v2_mmap(data: &[u8], options: TensorListOptions) -> Result<TensorListResult> {
    let reader = AprV2ReaderRef::from_bytes(data).map_err(|e| AprenderError::FormatError {
        message: format!("Failed to parse APR v2: {e}"),
    })?;
    let mut collected = Vec::new();
    let mut size_sum = 0usize;
    let mut match_count = 0usize;
    for name in reader.tensor_names() {
        if !options.matches_filter(name) {
            continue;
        }
        let Some(entry) = reader.get_tensor(name) else {
            continue;
        };
        // Totals cover every match; the detailed list is capped at the limit.
        size_sum += entry.size as usize;
        match_count += 1;
        if collected.len() < options.limit {
            let mut info = tensor_info_from_entry(entry);
            if options.compute_stats {
                if let Some(values) = reader.get_tensor_as_f32(name) {
                    compute_tensor_stats(&mut info, &values);
                }
            }
            collected.push(info);
        }
    }
    Ok(TensorListResult {
        file: String::new(),
        format_version: "v2".to_string(),
        tensor_count: match_count,
        total_size_bytes: size_sum,
        tensors: collected,
    })
}
fn parse_shape_array(shape_val: &serde_json::Value) -> Vec<usize> {
shape_val.as_array().map_or(Vec::new(), |arr| {
arr.iter()
.filter_map(|v| v.as_u64().map(|n| n as usize))
.collect()
})
}
/// Collects tensor descriptions from APR v1 metadata's `tensor_shapes` map.
///
/// Returns `(tensors, matching_count, total_bytes)`. The tensor list is
/// truncated at `options.limit` while the two counts cover every match.
/// Sizes assume f32 storage (4 bytes per element) since the dtype is
/// hard-coded to "f32" here.
fn extract_tensors_from_metadata_with_counts(
    metadata: &HashMap<String, serde_json::Value>,
    options: &TensorListOptions,
) -> (Vec<TensorInfo>, usize, usize) {
    let Some(shapes) = metadata.get("tensor_shapes").and_then(|s| s.as_object()) else {
        return (Vec::new(), 0, 0);
    };
    let mut tensors = Vec::new();
    let mut matching = 0usize;
    let mut total_bytes = 0usize;
    for (name, shape_val) in shapes {
        if !options.matches_filter(name) {
            continue;
        }
        let shape = parse_shape_array(shape_val);
        let size_bytes = shape.iter().product::<usize>() * 4;
        total_bytes += size_bytes;
        matching += 1;
        if tensors.len() < options.limit {
            tensors.push(TensorInfo {
                name: name.clone(),
                shape,
                dtype: "f32".to_string(),
                size_bytes,
                mean: None,
                std: None,
                min: None,
                max: None,
                nan_count: None,
                inf_count: None,
            });
        }
    }
    (tensors, matching, total_bytes)
}
/// Lists tensors from an APR v1 buffer using its JSON (or MessagePack)
/// metadata block.
fn list_tensors_v1(data: &[u8], options: TensorListOptions) -> Result<TensorListResult> {
    if data.len() < HEADER_SIZE {
        return Err(AprenderError::FormatError {
            message: "APR v1 file too small for header".to_string(),
        });
    }
    // Metadata length is a little-endian u32 at byte offset 8.
    let metadata_size = u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize;
    let metadata_end = HEADER_SIZE + metadata_size;
    if data.len() < metadata_end {
        return Err(AprenderError::FormatError {
            message: "APR v1 file too small for metadata".to_string(),
        });
    }
    let metadata_bytes = &data[HEADER_SIZE..metadata_end];
    // Try JSON first, then MessagePack; unparseable metadata degrades to an
    // empty tensor list rather than a hard error (best-effort listing).
    let metadata: HashMap<String, serde_json::Value> = serde_json::from_slice(metadata_bytes)
        .or_else(|_| rmp_serde::from_slice(metadata_bytes))
        .unwrap_or_default();
    let (tensors, tensor_count, total_size_bytes) =
        extract_tensors_from_metadata_with_counts(&metadata, &options);
    Ok(TensorListResult {
        file: String::new(),
        format_version: "v1".to_string(),
        tensor_count,
        total_size_bytes,
        tensors,
    })
}
/// Returns the human-readable name of a GGML tensor dtype id.
///
/// Ids without a name in the table (4 and 5, plus anything >= 31) yield
/// "unknown".
pub(crate) fn ggml_dtype_name(dtype: u32) -> &'static str {
    match dtype {
        0 => "F32",
        1 => "F16",
        2 => "Q4_0",
        3 => "Q4_1",
        6 => "Q5_0",
        7 => "Q5_1",
        8 => "Q8_0",
        9 => "Q8_1",
        10 => "Q2_K",
        11 => "Q3_K",
        12 => "Q4_K",
        13 => "Q5_K",
        14 => "Q6_K",
        15 => "Q8_K",
        16 => "IQ2_XXS",
        17 => "IQ2_XS",
        18 => "IQ3_XXS",
        19 => "IQ1_S",
        20 => "IQ4_NL",
        21 => "IQ3_S",
        22 => "IQ2_S",
        23 => "IQ4_XS",
        24 => "I8",
        25 => "I16",
        26 => "BF16",
        27 => "I32",
        28 => "I64",
        29 => "F64",
        30 => "IQ1_M",
        _ => "unknown",
    }
}
include!("safetensors.rs");
include!("tensors_safetensors.rs");