use std::collections::HashMap;
use std::fmt;
use crate::error::AnamnesisError;
/// Element dtype of a tensor, mirroring the safetensors dtype set.
///
/// `#[non_exhaustive]` so additional formats can be added without a
/// breaking change for downstream matchers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Dtype {
    /// 8-bit float, 4 exponent / 3 mantissa bits.
    F8E4M3,
    /// 8-bit float, 5 exponent / 2 mantissa bits.
    F8E5M2,
    /// bfloat16 (truncated IEEE single precision).
    BF16,
    /// IEEE half precision.
    F16,
    /// IEEE single precision.
    F32,
    /// IEEE double precision.
    F64,
    /// Boolean; one byte per element (see `byte_size`).
    Bool,
    U8,
    I8,
    U16,
    I16,
    U32,
    I32,
    U64,
    I64,
}
impl Dtype {
#[must_use]
pub const fn byte_size(self) -> usize {
match self {
Self::Bool | Self::U8 | Self::I8 | Self::F8E4M3 | Self::F8E5M2 => 1,
Self::U16 | Self::I16 | Self::F16 | Self::BF16 => 2,
Self::U32 | Self::I32 | Self::F32 => 4,
Self::U64 | Self::I64 | Self::F64 => 8,
}
}
#[must_use]
pub const fn is_quantized(self) -> bool {
matches!(self, Self::F8E4M3 | Self::F8E5M2)
}
#[must_use]
pub const fn is_floating_point(self) -> bool {
matches!(
self,
Self::F8E4M3 | Self::F8E5M2 | Self::BF16 | Self::F16 | Self::F32 | Self::F64
)
}
pub fn to_safetensors_dtype(self) -> crate::Result<safetensors::Dtype> {
match self {
Self::F8E4M3 => Ok(safetensors::Dtype::F8_E4M3),
Self::F8E5M2 => Ok(safetensors::Dtype::F8_E5M2),
Self::BF16 => Ok(safetensors::Dtype::BF16),
Self::F16 => Ok(safetensors::Dtype::F16),
Self::F32 => Ok(safetensors::Dtype::F32),
Self::F64 => Ok(safetensors::Dtype::F64),
Self::Bool => Ok(safetensors::Dtype::BOOL),
Self::U8 => Ok(safetensors::Dtype::U8),
Self::I8 => Ok(safetensors::Dtype::I8),
Self::U16 => Ok(safetensors::Dtype::U16),
Self::I16 => Ok(safetensors::Dtype::I16),
Self::U32 => Ok(safetensors::Dtype::U32),
Self::I32 => Ok(safetensors::Dtype::I32),
Self::U64 => Ok(safetensors::Dtype::U64),
Self::I64 => Ok(safetensors::Dtype::I64),
}
}
}
impl fmt::Display for Dtype {
    /// Formats with the canonical safetensors spelling (e.g. `F8_E4M3`, `BOOL`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::F8E4M3 => "F8_E4M3",
            Self::F8E5M2 => "F8_E5M2",
            Self::BF16 => "BF16",
            Self::F16 => "F16",
            Self::F32 => "F32",
            Self::F64 => "F64",
            Self::Bool => "BOOL",
            Self::U8 => "U8",
            Self::I8 => "I8",
            Self::U16 => "U16",
            Self::I16 => "I16",
            Self::U32 => "U32",
            Self::I32 => "I32",
            Self::U64 => "U64",
            Self::I64 => "I64",
        })
    }
}
impl TryFrom<safetensors::Dtype> for Dtype {
type Error = AnamnesisError;
fn try_from(st: safetensors::Dtype) -> std::result::Result<Self, Self::Error> {
match st {
safetensors::Dtype::F8_E4M3 => Ok(Self::F8E4M3),
safetensors::Dtype::F8_E5M2 => Ok(Self::F8E5M2),
safetensors::Dtype::BF16 => Ok(Self::BF16),
safetensors::Dtype::F16 => Ok(Self::F16),
safetensors::Dtype::F32 => Ok(Self::F32),
safetensors::Dtype::F64 => Ok(Self::F64),
safetensors::Dtype::BOOL => Ok(Self::Bool),
safetensors::Dtype::U8 => Ok(Self::U8),
safetensors::Dtype::I8 => Ok(Self::I8),
safetensors::Dtype::U16 => Ok(Self::U16),
safetensors::Dtype::I16 => Ok(Self::I16),
safetensors::Dtype::U32 => Ok(Self::U32),
safetensors::Dtype::I32 => Ok(Self::I32),
safetensors::Dtype::U64 => Ok(Self::U64),
safetensors::Dtype::I64 => Ok(Self::I64),
unknown => Err(AnamnesisError::Unsupported {
format: "safetensors".into(),
detail: format!("unknown dtype {unknown:?}"),
}),
}
}
}
/// Functional role of a tensor within a (possibly quantized) checkpoint,
/// as assigned by `classify_tensor`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum TensorRole {
    /// Quantized weight payload.
    Quantized,
    /// Scale used for dequantization (`*_scale`, `*_scale_inv`, `.scales`,
    /// `.absmax`, `.SCB`).
    Scale,
    /// Stored as-is; not part of any quantization scheme.
    Passthrough,
    /// Packed zero points (GPTQ/AWQ `.qzeros`).
    ZeroPoint,
    /// GPTQ group-index tensor (`.g_idx`).
    GroupIndex,
    /// bitsandbytes codebook (`.quant_map` / `.nested_quant_map`).
    QuantMap,
    /// bitsandbytes nested absmax used for double quantization.
    NestedScale,
    /// Serialized bitsandbytes quant-state blob.
    QuantState,
}
/// Classify a tensor into a [`TensorRole`] from its checkpoint name, with
/// the dtype as a fallback signal.
///
/// Name-suffix conventions take priority; an FP8 dtype with no recognized
/// suffix is still treated as a quantized payload.
fn classify_tensor(name: &str, dtype: Dtype) -> TensorRole {
    // Generic FP8-style scale naming (`*_scale`, `*_scale_inv`).
    if name.ends_with("_scale_inv") || name.ends_with("_scale") {
        return TensorRole::Scale;
    }
    // GPTQ and AWQ checkpoints share the qweight/qzeros/scales layout.
    #[cfg(any(feature = "gptq", feature = "awq"))]
    {
        if name.ends_with(".qweight") {
            return TensorRole::Quantized;
        }
        if name.ends_with(".qzeros") {
            return TensorRole::ZeroPoint;
        }
        if name.ends_with(".scales") {
            return TensorRole::Scale;
        }
    }
    // The group-index tensor only appears in GPTQ checkpoints.
    #[cfg(feature = "gptq")]
    if name.ends_with(".g_idx") {
        return TensorRole::GroupIndex;
    }
    // bitsandbytes companion tensors.
    #[cfg(feature = "bnb")]
    {
        if name.contains(".quant_state.bitsandbytes__") {
            return TensorRole::QuantState;
        }
        if name.ends_with(".weight.nested_quant_map") || name.ends_with(".weight.quant_map") {
            return TensorRole::QuantMap;
        }
        if name.ends_with(".weight.nested_absmax") {
            return TensorRole::NestedScale;
        }
        if name.ends_with(".weight.absmax") {
            return TensorRole::Scale;
        }
        #[allow(clippy::case_sensitive_file_extension_comparisons)]
        if name.ends_with(".SCB") {
            return TensorRole::Scale;
        }
        // bnb stores 4-bit payloads as U8 and LLM.int8() weights as I8.
        if matches!(dtype, Dtype::U8 | Dtype::I8) && name.ends_with(".weight") {
            return TensorRole::Quantized;
        }
    }
    // Fallback: FP8 dtypes are quantized payloads; everything else passes through.
    if dtype.is_quantized() {
        return TensorRole::Quantized;
    }
    TensorRole::Passthrough
}
/// Overall quantization scheme detected for a checkpoint by `detect_scheme`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum QuantScheme {
    /// FP8 (E4M3) with a 2-D scale grid (blockwise scaling).
    FineGrainedFp8,
    /// FP8 (E4M3) with one scale per row.
    PerChannelFp8,
    /// FP8 (E4M3) with a single scale per tensor.
    PerTensorFp8,
    /// No quantized tensors present.
    Unquantized,
    /// GPTQ packed-integer checkpoint.
    Gptq,
    /// AWQ packed-integer checkpoint.
    Awq,
    /// bitsandbytes 4-bit (NF4/FP4).
    Bnb4,
    /// bitsandbytes INT8 (LLM.int8()).
    BnbInt8,
}
impl fmt::Display for QuantScheme {
    /// Human-readable description of the scheme for logs and CLI output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::FineGrainedFp8 => "Fine-grained FP8 (E4M3), 128x128 blocks",
            Self::PerChannelFp8 => "Per-channel FP8 (E4M3), one scale per row",
            Self::PerTensorFp8 => "Per-tensor FP8 (E4M3)",
            Self::Unquantized => "Unquantized",
            Self::Gptq => "GPTQ",
            Self::Awq => "AWQ",
            Self::Bnb4 => "BitsAndBytes NF4/FP4 (4-bit, per-block absmax)",
            Self::BnbInt8 => "BitsAndBytes INT8 (LLM.int8(), per-row absmax)",
        })
    }
}
/// Detect the quantization scheme from the classified tensor roster.
///
/// Detection order matters: shape-based GPTQ/AWQ discrimination first, then
/// bitsandbytes markers (feature-gated), then FP8 scale-shape analysis.
/// Quantized tensors with no recognizable companions fall back to
/// `PerTensorFp8`.
fn detect_scheme(entries: &[TensorEntry]) -> QuantScheme {
    let has_quantized = entries.iter().any(|e| e.role == TensorRole::Quantized);
    if !has_quantized {
        return QuantScheme::Unquantized;
    }
    // GPTQ vs AWQ: both ship `<base>.qweight` + `<base>.scales`. Discriminate
    // on the trailing dimensions: equal qweight/scales columns -> GPTQ;
    // fewer qweight columns than scales columns -> AWQ.
    for entry in entries
        .iter()
        .filter(|e| e.role == TensorRole::Quantized && e.name.ends_with(".qweight"))
    {
        let base = entry.name.strip_suffix(".qweight");
        if let Some(base) = base {
            let scales_name = format!("{base}.scales");
            if let Some(scales) = entries.iter().find(|e| e.name == scales_name) {
                let qw_cols = entry.shape.last().copied().unwrap_or(0);
                let sc_cols = scales.shape.last().copied().unwrap_or(0);
                if qw_cols > 0 && sc_cols > 0 && qw_cols == sc_cols {
                    return QuantScheme::Gptq;
                } else if qw_cols > 0 && sc_cols > 0 && qw_cols < sc_cols {
                    return QuantScheme::Awq;
                }
            }
        }
    }
    #[cfg(feature = "bnb")]
    {
        // A codebook tensor is the definitive bnb 4-bit marker.
        let has_quant_map = entries.iter().any(|e| e.role == TensorRole::QuantMap);
        if has_quant_map {
            return QuantScheme::Bnb4;
        }
        // `.SCB` scales mark LLM.int8() checkpoints.
        let has_scb = entries.iter().any(|e| {
            #[allow(clippy::case_sensitive_file_extension_comparisons)]
            let is_scb = e.name.ends_with(".SCB");
            e.role == TensorRole::Scale && is_scb
        });
        if has_scb {
            return QuantScheme::BnbInt8;
        }
    }
    // FP8: classify by the shape of the first matching scale tensor,
    // preferring `_scale_inv` over `_scale`.
    for entry in entries.iter().filter(|e| e.role == TensorRole::Quantized) {
        for suffix in &["_scale_inv", "_scale"] {
            let expected = format!("{}{suffix}", entry.name);
            if let Some(scale) = entries
                .iter()
                .find(|s| s.name == expected && s.role == TensorRole::Scale)
            {
                if scale.shape.len() >= 2 {
                    // Option comparison: `Some(n) > Some(1)` iff n > 1, i.e.
                    // a real 2-D scale grid rather than an [r, 1] column.
                    if scale.shape.last().copied() > Some(1) {
                        return QuantScheme::FineGrainedFp8;
                    }
                    return QuantScheme::PerChannelFp8;
                }
                // Scalar or 1-D scale: a single scale for the whole tensor.
                return QuantScheme::PerTensorFp8;
            }
        }
    }
    // Quantized tensors but no matching scale tensors at all.
    QuantScheme::PerTensorFp8
}
/// GPTQ quantization parameters, read from metadata or inferred from shapes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GptqConfig {
    /// Quantization bit width (4 or 8 as inferred by `infer_gptq_config`).
    pub bits: u8,
    /// Number of input features sharing one scale/zero-point group.
    pub group_size: usize,
}
/// Borrowed companion tensors for one GPTQ-packed layer.
#[derive(Debug, Clone)]
pub struct GptqCompanions<'a> {
    /// Per-group dequantization scales (`<base>.scales`).
    pub scales: &'a TensorEntry,
    /// Packed zero points (`<base>.qzeros`).
    pub qzeros: &'a TensorEntry,
    /// Optional group-index tensor (`<base>.g_idx`).
    pub g_idx: Option<&'a TensorEntry>,
}
/// AWQ quantization parameters inferred from tensor shapes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AwqConfig {
    /// Quantization bit width (4 or 8 as inferred by `infer_awq_config`).
    pub bits: u8,
    /// Number of input features sharing one scale group.
    pub group_size: usize,
}
/// Borrowed companion tensors for one AWQ-packed layer.
#[derive(Debug, Clone)]
pub struct AwqCompanions<'a> {
    /// Per-group dequantization scales (`<base>.scales`).
    pub scales: &'a TensorEntry,
    /// Packed zero points (`<base>.qzeros`).
    pub qzeros: &'a TensorEntry,
}
/// bitsandbytes 4-bit quantization parameters inferred from tensor shapes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BnbConfig {
    /// Elements covered by one absmax block.
    pub block_size: usize,
    /// True when a nested absmax tensor indicates double quantization.
    pub double_quant: bool,
}
/// Borrowed companion tensors for one bitsandbytes 4-bit weight.
#[derive(Debug, Clone)]
pub struct Bnb4Companions<'a> {
    /// Per-block absmax scales (`<weight>.absmax`).
    pub absmax: &'a TensorEntry,
    /// Quantization codebook (`<weight>.quant_map`).
    pub quant_map: &'a TensorEntry,
    /// Absmax of the absmax (double quantization), if present.
    pub nested_absmax: Option<&'a TensorEntry>,
    /// Codebook for the nested quantization, if present.
    pub nested_quant_map: Option<&'a TensorEntry>,
    /// Serialized quant-state blob (`<weight>.quant_state.bitsandbytes__*`), if present.
    pub quant_state: Option<&'a TensorEntry>,
}
/// One tensor as described by the safetensors header.
#[derive(Debug, Clone)]
pub struct TensorEntry {
    /// Full tensor name, e.g. `model.layers.0.self_attn.q_proj.weight`.
    pub name: String,
    /// Element dtype.
    pub dtype: Dtype,
    /// Dimensions; empty for a scalar.
    pub shape: Vec<usize>,
    /// Byte range `(start, end)` of the payload within the data section.
    pub data_offsets: (usize, usize),
    /// Role assigned by `classify_tensor`.
    pub role: TensorRole,
}
impl TensorEntry {
    /// Total number of elements implied by `shape`.
    ///
    /// An empty shape counts as one element (a scalar). Saturates at
    /// `usize::MAX` if the product overflows.
    #[must_use]
    pub fn num_elements(&self) -> usize {
        let mut count = 1usize;
        for &dim in &self.shape {
            match count.checked_mul(dim) {
                Some(next) => count = next,
                None => return usize::MAX,
            }
        }
        count
    }
    /// Payload size in bytes, derived from `data_offsets`; zero if the
    /// offsets are inverted.
    #[must_use]
    pub fn byte_len(&self) -> usize {
        let (start, end) = self.data_offsets;
        end.saturating_sub(start)
    }
}
/// Parsed safetensors header: the classified tensor roster plus the detected
/// quantization scheme and any scheme-specific config that could be inferred.
#[derive(Debug, Clone)]
pub struct SafetensorsHeader {
    /// All tensor entries, sorted by name.
    pub tensors: Vec<TensorEntry>,
    /// Quantization scheme detected from the roster.
    pub scheme: QuantScheme,
    /// Free-form `__metadata__` key/value pairs from the file, if any.
    pub metadata: Option<HashMap<String, String>>,
    /// Serialized header size in bytes, as reported by `read_metadata`.
    pub header_size: usize,
    /// Populated only when `scheme == QuantScheme::Gptq` and inference succeeded.
    pub gptq_config: Option<GptqConfig>,
    /// Populated only when `scheme == QuantScheme::Awq` and inference succeeded.
    pub awq_config: Option<AwqConfig>,
    /// Populated only when `scheme == QuantScheme::Bnb4` and inference succeeded.
    pub bnb_config: Option<BnbConfig>,
}
impl SafetensorsHeader {
    /// Iterate over tensors holding quantized weight payloads.
    pub fn quantized_tensors(&self) -> impl Iterator<Item = &TensorEntry> {
        self.tensors
            .iter()
            .filter(|e| e.role == TensorRole::Quantized)
    }
    /// Iterate over tensors classified as dequantization scales.
    pub fn scale_tensors(&self) -> impl Iterator<Item = &TensorEntry> {
        self.tensors.iter().filter(|e| e.role == TensorRole::Scale)
    }
    /// Iterate over tensors stored as-is (no quantization role).
    pub fn passthrough_tensors(&self) -> impl Iterator<Item = &TensorEntry> {
        self.tensors
            .iter()
            .filter(|e| e.role == TensorRole::Passthrough)
    }
    /// Number of quantized tensors.
    #[must_use]
    pub fn quantized_count(&self) -> usize {
        self.quantized_tensors().count()
    }
    /// Number of scale tensors.
    #[must_use]
    pub fn scale_count(&self) -> usize {
        self.scale_tensors().count()
    }
    /// Number of passthrough tensors.
    #[must_use]
    pub fn passthrough_count(&self) -> usize {
        self.passthrough_tensors().count()
    }
    /// Find the FP8 scale tensor for `weight_name`, preferring the
    /// `_scale_inv` suffix over `_scale` when both exist.
    #[must_use]
    pub fn find_scale_for(&self, weight_name: &str) -> Option<&TensorEntry> {
        let scale_inv = format!("{weight_name}_scale_inv");
        let scale = format!("{weight_name}_scale");
        self.tensors
            .iter()
            .find(|e| e.name == scale_inv)
            .or_else(|| self.tensors.iter().find(|e| e.name == scale))
    }
    /// Iterate over zero-point tensors (GPTQ/AWQ `qzeros`).
    pub fn zeropoint_tensors(&self) -> impl Iterator<Item = &TensorEntry> {
        self.tensors
            .iter()
            .filter(|e| e.role == TensorRole::ZeroPoint)
    }
    /// Number of zero-point tensors.
    #[must_use]
    pub fn zeropoint_count(&self) -> usize {
        self.zeropoint_tensors().count()
    }
    /// Iterate over GPTQ group-index tensors.
    pub fn group_index_tensors(&self) -> impl Iterator<Item = &TensorEntry> {
        self.tensors
            .iter()
            .filter(|e| e.role == TensorRole::GroupIndex)
    }
    /// Number of group-index tensors.
    #[must_use]
    pub fn group_index_count(&self) -> usize {
        self.group_index_tensors().count()
    }
    /// Resolve the `scales`/`qzeros` (and optional `g_idx`) companions for a
    /// GPTQ `*.qweight` tensor. Returns `None` if the name lacks the
    /// `.qweight` suffix or a required companion is missing.
    #[must_use]
    pub fn find_gptq_companions(&self, qweight_name: &str) -> Option<GptqCompanions<'_>> {
        let base = qweight_name.strip_suffix(".qweight")?;
        let scales_name = format!("{base}.scales");
        let qzeros_name = format!("{base}.qzeros");
        let g_idx_name = format!("{base}.g_idx");
        let scales = self.tensors.iter().find(|e| e.name == scales_name)?;
        let qzeros = self.tensors.iter().find(|e| e.name == qzeros_name)?;
        let g_idx = self.tensors.iter().find(|e| e.name == g_idx_name);
        Some(GptqCompanions {
            scales,
            qzeros,
            g_idx,
        })
    }
    /// Resolve the `scales`/`qzeros` companions for an AWQ `*.qweight`
    /// tensor. Returns `None` if either required companion is missing.
    #[must_use]
    pub fn find_awq_companions(&self, qweight_name: &str) -> Option<AwqCompanions<'_>> {
        let base = qweight_name.strip_suffix(".qweight")?;
        let scales_name = format!("{base}.scales");
        let qzeros_name = format!("{base}.qzeros");
        let scales = self.tensors.iter().find(|e| e.name == scales_name)?;
        let qzeros = self.tensors.iter().find(|e| e.name == qzeros_name)?;
        Some(AwqCompanions { scales, qzeros })
    }
    /// Iterate over bitsandbytes codebook (`quant_map`) tensors.
    pub fn quant_map_tensors(&self) -> impl Iterator<Item = &TensorEntry> {
        self.tensors
            .iter()
            .filter(|e| e.role == TensorRole::QuantMap)
    }
    /// Number of codebook tensors.
    #[must_use]
    pub fn quant_map_count(&self) -> usize {
        self.quant_map_tensors().count()
    }
    /// Iterate over nested-absmax tensors (bitsandbytes double quantization).
    pub fn nested_scale_tensors(&self) -> impl Iterator<Item = &TensorEntry> {
        self.tensors
            .iter()
            .filter(|e| e.role == TensorRole::NestedScale)
    }
    /// Number of nested-absmax tensors.
    #[must_use]
    pub fn nested_scale_count(&self) -> usize {
        self.nested_scale_tensors().count()
    }
    /// Resolve bitsandbytes 4-bit companions for `weight_name`.
    ///
    /// `absmax` and `quant_map` are required (returns `None` when absent);
    /// the double-quant tensors and the quant-state blob are optional.
    #[must_use]
    pub fn find_bnb4_companions(&self, weight_name: &str) -> Option<Bnb4Companions<'_>> {
        let absmax_name = format!("{weight_name}.absmax");
        let quant_map_name = format!("{weight_name}.quant_map");
        let nested_absmax_name = format!("{weight_name}.nested_absmax");
        let nested_quant_map_name = format!("{weight_name}.nested_quant_map");
        let quant_state_prefix = format!("{weight_name}.quant_state.bitsandbytes__");
        let absmax = self.tensors.iter().find(|e| e.name == absmax_name)?;
        let quant_map = self.tensors.iter().find(|e| e.name == quant_map_name)?;
        let nested_absmax = self.tensors.iter().find(|e| e.name == nested_absmax_name);
        let nested_quant_map = self
            .tensors
            .iter()
            .find(|e| e.name == nested_quant_map_name);
        // The blob name embeds the format (e.g. `bitsandbytes__nf4`), so
        // match on the prefix rather than an exact name.
        let quant_state = self
            .tensors
            .iter()
            .find(|e| e.name.starts_with(&quant_state_prefix));
        Some(Bnb4Companions {
            absmax,
            quant_map,
            nested_absmax,
            nested_quant_map,
            quant_state,
        })
    }
    /// Find the LLM.int8() `SCB` scale for a `*.weight` tensor name
    /// (`<base>.weight` maps to `<base>.SCB`).
    #[must_use]
    pub fn find_bnb_int8_scb(&self, weight_name: &str) -> Option<&TensorEntry> {
        let base = weight_name.strip_suffix(".weight")?;
        let scb_name = format!("{base}.SCB");
        self.tensors.iter().find(|e| e.name == scb_name)
    }
}
/// Infer AWQ quantization parameters (bit width and group size) from shapes.
///
/// For each `.qweight`/`.scales` pair: the pack factor is the ratio of the
/// scales' trailing dimension to the qweight's trailing dimension, and must
/// equal `32 / bits` for a candidate bit width of 4 or 8; `group_size` is the
/// qweight's leading dimension divided by the scales' leading dimension.
///
/// Layers with missing companions or degenerate (zero-sized) shapes are
/// skipped so a later layer can still yield a configuration; previously a
/// single malformed layer aborted the whole scan. Returns `None` when no
/// layer yields a consistent configuration.
fn infer_awq_config(entries: &[TensorEntry]) -> Option<AwqConfig> {
    for entry in entries
        .iter()
        .filter(|e| e.role == TensorRole::Quantized && e.name.ends_with(".qweight"))
    {
        // The filter guarantees the suffix; skip defensively rather than abort.
        let Some(base) = entry.name.strip_suffix(".qweight") else {
            continue;
        };
        let scales_name = format!("{base}.scales");
        let Some(scales) = entries.iter().find(|e| e.name == scales_name) else {
            continue;
        };
        if entry.shape.len() < 2 || scales.shape.len() < 2 {
            continue;
        }
        let in_features = entry.shape.first().copied()?;
        let qw_cols = entry.shape.last().copied()?;
        let num_groups = scales.shape.first().copied()?;
        let out_features = scales.shape.last().copied()?;
        if qw_cols == 0 || out_features == 0 || num_groups == 0 || in_features == 0 {
            // Degenerate shapes on this layer: try the next candidate instead
            // of giving up on the whole checkpoint.
            continue;
        }
        if !out_features.is_multiple_of(qw_cols) {
            continue;
        }
        let pack_factor = out_features / qw_cols;
        for bits in [4u8, 8] {
            #[allow(clippy::as_conversions)]
            let expected_pf = 32 / bits as usize;
            if pack_factor == expected_pf && in_features.is_multiple_of(num_groups) {
                let group_size = in_features / num_groups;
                return Some(AwqConfig { bits, group_size });
            }
        }
    }
    None
}
/// Infer GPTQ quantization parameters.
///
/// Prefers explicit `gptq_bits`/`gptq_group_size` keys from the file
/// metadata when both parse. Otherwise derives them from shapes: for each
/// candidate bit width, `in_features = qweight_rows * (32 / bits)` must
/// divide evenly into the scales' group count, giving `group_size`.
///
/// Layers with missing companions or degenerate (zero-sized) shapes are
/// skipped so a later layer can still yield a configuration; previously a
/// single malformed layer aborted the whole scan. Returns `None` when neither
/// metadata nor any layer yields a consistent configuration.
fn infer_gptq_config(
    entries: &[TensorEntry],
    metadata: Option<&HashMap<String, String>>,
) -> Option<GptqConfig> {
    // Fast path: trust explicit metadata when both keys are present and valid.
    if let Some(meta) = metadata {
        let bits = meta.get("gptq_bits").and_then(|v| v.parse::<u8>().ok());
        let group_size = meta
            .get("gptq_group_size")
            .and_then(|v| v.parse::<usize>().ok());
        if let (Some(bits), Some(group_size)) = (bits, group_size) {
            return Some(GptqConfig { bits, group_size });
        }
    }
    for entry in entries
        .iter()
        .filter(|e| e.role == TensorRole::Quantized && e.name.ends_with(".qweight"))
    {
        // The filter guarantees the suffix; skip defensively rather than abort.
        let Some(base) = entry.name.strip_suffix(".qweight") else {
            continue;
        };
        let scales_name = format!("{base}.scales");
        let Some(scales) = entries.iter().find(|e| e.name == scales_name) else {
            continue;
        };
        if entry.shape.len() < 2 || scales.shape.len() < 2 {
            continue;
        }
        let qw_rows = entry.shape.first().copied()?;
        let num_groups = scales.shape.first().copied()?;
        let out_features = scales.shape.last().copied()?;
        if num_groups == 0 || qw_rows == 0 || out_features == 0 {
            // Degenerate shapes on this layer: try the next candidate instead
            // of giving up on the whole checkpoint.
            continue;
        }
        for bits in [4u8, 8] {
            #[allow(clippy::as_conversions)]
            let pack_factor = 32 / bits as usize;
            // Overflow means this bit-width candidate is nonsensical; try the next.
            let Some(in_features) = qw_rows.checked_mul(pack_factor) else {
                continue;
            };
            if in_features.is_multiple_of(num_groups) {
                let group_size = in_features / num_groups;
                return Some(GptqConfig { bits, group_size });
            }
        }
    }
    None
}
/// Infer bitsandbytes 4-bit parameters from a quantized `U8` weight and its
/// `.absmax` companion.
///
/// Each stored byte packs two 4-bit elements, so the logical element count is
/// `byte_len * 2`; the block size is that count divided by the number of
/// absmax entries. `double_quant` is set when a `.nested_absmax` tensor
/// exists for the same weight.
///
/// Layers with a missing or inconsistent absmax are skipped so a later layer
/// can still yield a configuration; previously a single inconsistent layer
/// aborted the whole scan. Returns `None` when no layer is usable.
fn infer_bnb_config(entries: &[TensorEntry]) -> Option<BnbConfig> {
    for entry in entries
        .iter()
        .filter(|e| e.role == TensorRole::Quantized && e.dtype == Dtype::U8)
    {
        let absmax_name = format!("{}.absmax", entry.name);
        let nested_name = format!("{}.nested_absmax", entry.name);
        let Some(absmax) = entries.iter().find(|e| e.name == absmax_name) else {
            continue;
        };
        // 4-bit packing: two quantized elements per stored byte.
        let Some(total_elements) = entry.byte_len().checked_mul(2) else {
            continue;
        };
        let absmax_count = absmax.num_elements();
        if absmax_count == 0 || !total_elements.is_multiple_of(absmax_count) {
            // Inconsistent absmax for this layer: try the next candidate
            // instead of giving up on the whole checkpoint.
            continue;
        }
        let block_size = total_elements / absmax_count;
        let double_quant = entries.iter().any(|e| e.name == nested_name);
        return Some(BnbConfig {
            block_size,
            double_quant,
        });
    }
    None
}
pub fn parse_safetensors_header(buffer: &[u8]) -> crate::Result<SafetensorsHeader> {
let (header_size, metadata) =
safetensors::SafeTensors::read_metadata(buffer).map_err(AnamnesisError::from)?;
let st_tensors = metadata.tensors();
let mut entries = Vec::with_capacity(st_tensors.len());
for (name, info) in &st_tensors {
let dtype = Dtype::try_from(info.dtype)?;
let role = classify_tensor(name, dtype);
entries.push(TensorEntry {
name: name.clone(),
dtype,
shape: info.shape.clone(),
data_offsets: info.data_offsets,
role,
});
}
entries.sort_by(|a, b| a.name.cmp(&b.name));
let scheme = detect_scheme(&entries);
let file_metadata = metadata.metadata().clone();
let gptq_config = if scheme == QuantScheme::Gptq {
infer_gptq_config(&entries, file_metadata.as_ref())
} else {
None
};
let awq_config = if scheme == QuantScheme::Awq {
infer_awq_config(&entries)
} else {
None
};
let bnb_config = if scheme == QuantScheme::Bnb4 {
infer_bnb_config(&entries)
} else {
None
};
Ok(SafetensorsHeader {
tensors: entries,
scheme,
metadata: file_metadata,
header_size,
gptq_config,
awq_config,
bnb_config,
})
}
#[cfg(test)]
#[allow(clippy::panic, clippy::indexing_slicing)]
mod tests {
    use super::*;
    #[test]
    fn dtype_byte_sizes() {
        assert_eq!(Dtype::F8E4M3.byte_size(), 1);
        assert_eq!(Dtype::F8E5M2.byte_size(), 1);
        assert_eq!(Dtype::U8.byte_size(), 1);
        assert_eq!(Dtype::I8.byte_size(), 1);
        assert_eq!(Dtype::Bool.byte_size(), 1);
        assert_eq!(Dtype::BF16.byte_size(), 2);
        assert_eq!(Dtype::F16.byte_size(), 2);
        assert_eq!(Dtype::U16.byte_size(), 2);
        assert_eq!(Dtype::I16.byte_size(), 2);
        assert_eq!(Dtype::F32.byte_size(), 4);
        assert_eq!(Dtype::U32.byte_size(), 4);
        assert_eq!(Dtype::I32.byte_size(), 4);
        assert_eq!(Dtype::F64.byte_size(), 8);
        assert_eq!(Dtype::U64.byte_size(), 8);
        assert_eq!(Dtype::I64.byte_size(), 8);
    }
    #[test]
    fn dtype_is_quantized() {
        assert!(Dtype::F8E4M3.is_quantized());
        assert!(Dtype::F8E5M2.is_quantized());
        assert!(!Dtype::BF16.is_quantized());
        assert!(!Dtype::F32.is_quantized());
        assert!(!Dtype::U8.is_quantized());
    }
    #[test]
    fn dtype_is_floating_point() {
        assert!(Dtype::F8E4M3.is_floating_point());
        assert!(Dtype::BF16.is_floating_point());
        assert!(Dtype::F32.is_floating_point());
        assert!(Dtype::F64.is_floating_point());
        assert!(!Dtype::U8.is_floating_point());
        assert!(!Dtype::I32.is_floating_point());
        assert!(!Dtype::Bool.is_floating_point());
    }
    #[test]
    fn dtype_display() {
        assert_eq!(Dtype::F8E4M3.to_string(), "F8_E4M3");
        assert_eq!(Dtype::BF16.to_string(), "BF16");
        assert_eq!(Dtype::F32.to_string(), "F32");
    }
    #[test]
    fn dtype_try_from_safetensors() {
        assert_eq!(
            Dtype::try_from(safetensors::Dtype::F8_E4M3).ok(),
            Some(Dtype::F8E4M3)
        );
        assert_eq!(
            Dtype::try_from(safetensors::Dtype::BF16).ok(),
            Some(Dtype::BF16)
        );
        assert_eq!(
            Dtype::try_from(safetensors::Dtype::F32).ok(),
            Some(Dtype::F32)
        );
    }
    #[test]
    fn classify_quantized_weight() {
        let role = classify_tensor("model.layers.0.self_attn.q_proj.weight", Dtype::F8E4M3);
        assert_eq!(role, TensorRole::Quantized);
    }
    #[test]
    fn classify_scale_inv() {
        let role = classify_tensor(
            "model.layers.0.self_attn.q_proj.weight_scale_inv",
            Dtype::F32,
        );
        assert_eq!(role, TensorRole::Scale);
    }
    #[test]
    fn classify_scale() {
        let role = classify_tensor("model.layers.0.self_attn.q_proj.weight_scale", Dtype::F32);
        assert_eq!(role, TensorRole::Scale);
    }
    #[test]
    fn classify_passthrough_norm() {
        let role = classify_tensor("model.norm.weight", Dtype::BF16);
        assert_eq!(role, TensorRole::Passthrough);
    }
    #[test]
    fn classify_passthrough_embedding() {
        let role = classify_tensor("model.embed_tokens.weight", Dtype::BF16);
        assert_eq!(role, TensorRole::Passthrough);
    }
    // Build a TensorEntry with a default 128x128 shape.
    fn make_entry(name: &str, dtype: Dtype, role: TensorRole) -> TensorEntry {
        make_entry_with_shape(name, dtype, role, vec![128, 128])
    }
    // Build a TensorEntry whose data_offsets span exactly the dense byte
    // length implied by `shape` and `dtype`.
    fn make_entry_with_shape(
        name: &str,
        dtype: Dtype,
        role: TensorRole,
        shape: Vec<usize>,
    ) -> TensorEntry {
        let num_elements: usize = shape.iter().product();
        let byte_len = num_elements * dtype.byte_size();
        TensorEntry {
            name: name.to_owned(),
            dtype,
            shape,
            data_offsets: (0, byte_len),
            role,
        }
    }
    #[test]
    fn num_elements_saturates_on_overflow() {
        let entry = TensorEntry {
            name: "huge".to_owned(),
            dtype: Dtype::F32,
            shape: vec![usize::MAX, 2],
            data_offsets: (0, 0),
            role: TensorRole::Passthrough,
        };
        assert_eq!(entry.num_elements(), usize::MAX);
    }
    #[test]
    fn num_elements_exact_on_normal_shape() {
        let entry = TensorEntry {
            name: "normal".to_owned(),
            dtype: Dtype::F32,
            shape: vec![16, 4096, 2048],
            data_offsets: (0, 0),
            role: TensorRole::Passthrough,
        };
        assert_eq!(entry.num_elements(), 16 * 4096 * 2048);
    }
    #[test]
    fn num_elements_empty_shape_is_one() {
        let entry = TensorEntry {
            name: "scalar".to_owned(),
            dtype: Dtype::F32,
            shape: vec![],
            data_offsets: (0, 0),
            role: TensorRole::Passthrough,
        };
        assert_eq!(entry.num_elements(), 1);
    }
    #[test]
    fn detect_unquantized() {
        let entries = vec![
            make_entry("model.norm.weight", Dtype::BF16, TensorRole::Passthrough),
            make_entry("lm_head.weight", Dtype::BF16, TensorRole::Passthrough),
        ];
        assert_eq!(detect_scheme(&entries), QuantScheme::Unquantized);
    }
    #[test]
    fn detect_fine_grained_fp8() {
        let entries = vec![
            make_entry("layer.0.weight", Dtype::F8E4M3, TensorRole::Quantized),
            make_entry("layer.0.weight_scale_inv", Dtype::F32, TensorRole::Scale),
            make_entry("model.norm.weight", Dtype::BF16, TensorRole::Passthrough),
        ];
        assert_eq!(detect_scheme(&entries), QuantScheme::FineGrainedFp8);
    }
    #[test]
    fn detect_per_tensor_fp8() {
        let entries = vec![
            make_entry("layer.0.weight", Dtype::F8E4M3, TensorRole::Quantized),
            make_entry("model.norm.weight", Dtype::BF16, TensorRole::Passthrough),
        ];
        assert_eq!(detect_scheme(&entries), QuantScheme::PerTensorFp8);
    }
    #[test]
    fn detect_per_tensor_fp8_with_scalar_scale_inv() {
        let entries = vec![
            make_entry("layer.0.weight", Dtype::F8E4M3, TensorRole::Quantized),
            make_entry_with_shape(
                "layer.0.weight_scale_inv",
                Dtype::BF16,
                TensorRole::Scale,
                vec![],
            ),
            make_entry_with_shape(
                "layer.0.activation_scale",
                Dtype::BF16,
                TensorRole::Scale,
                vec![],
            ),
            make_entry("model.norm.weight", Dtype::BF16, TensorRole::Passthrough),
        ];
        assert_eq!(detect_scheme(&entries), QuantScheme::PerTensorFp8);
    }
    #[test]
    fn detect_per_tensor_fp8_with_1d_scale_inv() {
        let entries = vec![
            make_entry("layer.0.weight", Dtype::F8E4M3, TensorRole::Quantized),
            make_entry_with_shape(
                "layer.0.weight_scale_inv",
                Dtype::BF16,
                TensorRole::Scale,
                vec![1],
            ),
        ];
        assert_eq!(detect_scheme(&entries), QuantScheme::PerTensorFp8);
    }
    #[test]
    fn detect_fine_grained_fp8_with_2d_scale_inv() {
        let entries = vec![
            make_entry_with_shape(
                "layer.0.weight",
                Dtype::F8E4M3,
                TensorRole::Quantized,
                vec![2048, 4096],
            ),
            make_entry_with_shape(
                "layer.0.weight_scale_inv",
                Dtype::BF16,
                TensorRole::Scale,
                vec![16, 32],
            ),
            make_entry("model.norm.weight", Dtype::BF16, TensorRole::Passthrough),
        ];
        assert_eq!(detect_scheme(&entries), QuantScheme::FineGrainedFp8);
    }
    #[test]
    fn find_scale_for_prefers_scale_inv() {
        let header = SafetensorsHeader {
            tensors: vec![
                make_entry("w", Dtype::F8E4M3, TensorRole::Quantized),
                make_entry("w_scale", Dtype::F32, TensorRole::Scale),
                make_entry("w_scale_inv", Dtype::F32, TensorRole::Scale),
            ],
            scheme: QuantScheme::FineGrainedFp8,
            metadata: None,
            header_size: 0,
            gptq_config: None,
            awq_config: None,
            bnb_config: None,
        };
        let found = header.find_scale_for("w");
        assert_eq!(found.map(|e| e.name.as_str()), Some("w_scale_inv"));
    }
    #[test]
    fn find_scale_for_falls_back_to_scale() {
        let header = SafetensorsHeader {
            tensors: vec![
                make_entry("w", Dtype::F8E4M3, TensorRole::Quantized),
                make_entry("w_scale", Dtype::F32, TensorRole::Scale),
            ],
            scheme: QuantScheme::PerTensorFp8,
            metadata: None,
            header_size: 0,
            gptq_config: None,
            awq_config: None,
            bnb_config: None,
        };
        let found = header.find_scale_for("w");
        assert_eq!(found.map(|e| e.name.as_str()), Some("w_scale"));
    }
    #[test]
    fn find_scale_for_returns_none_when_missing() {
        let header = SafetensorsHeader {
            tensors: vec![make_entry("w", Dtype::F8E4M3, TensorRole::Quantized)],
            scheme: QuantScheme::PerTensorFp8,
            metadata: None,
            header_size: 0,
            gptq_config: None,
            awq_config: None,
            bnb_config: None,
        };
        assert!(header.find_scale_for("w").is_none());
    }
    #[test]
    fn parse_minimal_safetensors() {
        use safetensors::tensor::serialize;
        // Two BF16 elements -> 4 bytes of backing data.
        let data: Vec<u8> = vec![0; 4];
        let tensors = vec![(
            "test_tensor",
            safetensors::tensor::TensorView::new(safetensors::Dtype::BF16, vec![2], &data)
                .unwrap_or_else(|e| panic!("failed to create TensorView: {e}")),
        )];
        let buffer = serialize(tensors, &None).unwrap_or_else(|e| panic!("serialize: {e}"));
        let header = parse_safetensors_header(&buffer).unwrap_or_else(|e| panic!("parse: {e}"));
        assert_eq!(header.tensors.len(), 1);
        assert_eq!(header.tensors[0].name, "test_tensor");
        assert_eq!(header.tensors[0].dtype, Dtype::BF16);
        assert_eq!(header.tensors[0].shape, vec![2]);
        assert_eq!(header.tensors[0].role, TensorRole::Passthrough);
        assert_eq!(header.scheme, QuantScheme::Unquantized);
    }
    #[test]
    fn parse_fp8_with_scale() {
        use safetensors::tensor::serialize;
        let weight_data: Vec<u8> = vec![0; 4];
        let scale_data: Vec<u8> = vec![0; 8];
        let tensors = vec![
            (
                "layer.weight",
                safetensors::tensor::TensorView::new(
                    safetensors::Dtype::F8_E4M3,
                    vec![2, 2],
                    &weight_data,
                )
                .unwrap_or_else(|e| panic!("weight TensorView: {e}")),
            ),
            (
                "layer.weight_scale_inv",
                safetensors::tensor::TensorView::new(
                    safetensors::Dtype::F32,
                    vec![1, 2],
                    &scale_data,
                )
                .unwrap_or_else(|e| panic!("scale TensorView: {e}")),
            ),
        ];
        let buffer = serialize(tensors, &None).unwrap_or_else(|e| panic!("serialize: {e}"));
        let header = parse_safetensors_header(&buffer).unwrap_or_else(|e| panic!("parse: {e}"));
        assert_eq!(header.tensors.len(), 2);
        assert_eq!(header.quantized_count(), 1);
        assert_eq!(header.scale_count(), 1);
        assert_eq!(header.passthrough_count(), 0);
        assert_eq!(header.scheme, QuantScheme::FineGrainedFp8);
        let scale = header.find_scale_for("layer.weight");
        assert_eq!(
            scale.map(|e| e.name.as_str()),
            Some("layer.weight_scale_inv")
        );
    }
}