#![allow(
clippy::unwrap_used,
clippy::expect_used,
clippy::identity_op,
non_camel_case_types
)]
use prism_tensor::dtype::{
Dtype, TensorDtypeRegistry, BF16, BOOL, C128, C64, F16, F32, F4_E2M1, F64, F8_E4M3,
F8_E4M3_FNUZ, F8_E5M2, F8_E5M2_FNUZ, I16, I32, I4, I64, I8, IQ1_M, IQ1_S, IQ2_S, IQ2_XS,
IQ2_XXS, IQ3_S, IQ3_XXS, IQ4_NL, IQ4_XS, Q2_K, Q3_K, Q4_0, Q4_1, Q4_K, Q5_0, Q5_1, Q5_K, Q6_K,
Q8_0, Q8_1, Q8_K, U16, U32, U4, U64, U8,
};
use uor_foundation::pipeline::shape_iri_registry::ShapeRegistryProvider;
use uor_foundation::pipeline::ConstrainedTypeShape;
/// Number of bits in one byte; converts storage bit widths into byte counts.
const BITS_PER_BYTE: usize = u8::BITS as usize;
/// Elements per K-series super-block (the "block-256" checked below).
const QK_K: usize = 256;
/// Elements per block for the legacy 4-bit format (named after ggml's constant).
const QK4_0: usize = 32;
/// Elements per block for the legacy 5-bit format (named after ggml's constant).
const QK5_0: usize = 32;
/// Elements per block for the legacy 8-bit format (named after ggml's constant).
const QK8_0: usize = 32;
/// Elements per IQ4_NL block.
const QK4_NL: usize = 32;
/// Byte size of a ggml half-precision scalar. Rust has no stable `f16`, so the
/// 16-bit storage width is taken from `u16`.
const GGML_HALF: usize = std::mem::size_of::<u16>();
/// Byte size of a `uint16_t` field in a ggml block.
const GGML_U16: usize = std::mem::size_of::<u16>();
/// Byte size of a single-precision `float` field in a ggml block.
const GGML_FLOAT: usize = std::mem::size_of::<f32>();
/// Byte size of an `int16_t` field in a ggml block.
const GGML_I16: usize = std::mem::size_of::<i16>();
/// Packed scale bytes used by the Q3_K/Q4_K/Q5_K formulas (ggml's `K_SCALE_SIZE`).
const K_SCALE_SIZE: usize = 12;
/// Scale byte count used by the IQ3_S formula (ggml's `IQ3S_N_SCALE`).
const IQ3S_N_SCALE: usize = 4;
#[test]
fn continuous_float_block_bytes() {
    // (BLOCK_BYTES, storage width in bits) for each continuous float dtype.
    let storage_bits: [(usize, usize); 4] = [
        (F32::BLOCK_BYTES, 32),
        (F16::BLOCK_BYTES, 16),
        (BF16::BLOCK_BYTES, 16),
        (F64::BLOCK_BYTES, 64),
    ];
    for (block_bytes, bits) in storage_bits {
        assert_eq!(block_bytes, bits / BITS_PER_BYTE);
    }

    // Continuous floats are never packed: each block holds a single value.
    let block_elems = [
        F32::BLOCK_ELEMS,
        F16::BLOCK_ELEMS,
        BF16::BLOCK_ELEMS,
        F64::BLOCK_ELEMS,
    ];
    block_elems
        .iter()
        .for_each(|&be| assert_eq!(be, 1, "continuous floats are one element per block"));
}
#[test]
fn onnx_float8_block_bytes() {
    // Every 8-bit float variant (plain and FNUZ) occupies exactly one byte per block.
    let expected = 8 / BITS_PER_BYTE;
    let fp8_block_bytes = [
        F8_E4M3::BLOCK_BYTES,
        F8_E4M3_FNUZ::BLOCK_BYTES,
        F8_E5M2::BLOCK_BYTES,
        F8_E5M2_FNUZ::BLOCK_BYTES,
    ];
    for actual in fp8_block_bytes {
        assert_eq!(actual, expected);
    }
}
#[test]
fn onnx_complex_block_bytes() {
    // A complex block is two components: 32-bit ones for C64, 64-bit ones for C128.
    let c64_component = 32 / BITS_PER_BYTE;
    let c128_component = 64 / BITS_PER_BYTE;
    assert_eq!(C64::BLOCK_BYTES, 2 * c64_component);
    assert_eq!(C128::BLOCK_BYTES, 2 * c128_component);
}
#[test]
fn integer_block_bytes() {
    // (BLOCK_BYTES, expected byte width) for every fixed-width integer dtype and BOOL.
    let cases: [(usize, usize); 9] = [
        (I8::BLOCK_BYTES, 1),
        (I16::BLOCK_BYTES, 2),
        (I32::BLOCK_BYTES, 4),
        (I64::BLOCK_BYTES, 8),
        (U8::BLOCK_BYTES, 1),
        (U16::BLOCK_BYTES, 2),
        (U32::BLOCK_BYTES, 4),
        (U64::BLOCK_BYTES, 8),
        (BOOL::BLOCK_BYTES, 1),
    ];
    for (actual, width) in cases {
        assert_eq!(actual, width);
    }
}
#[test]
fn packed_nibble_block_bytes() {
    // 4-bit dtypes pack two elements into one byte: (BLOCK_BYTES, BLOCK_ELEMS).
    let layouts = [
        (I4::BLOCK_BYTES, I4::BLOCK_ELEMS),
        (U4::BLOCK_BYTES, U4::BLOCK_ELEMS),
        (F4_E2M1::BLOCK_BYTES, F4_E2M1::BLOCK_ELEMS),
    ];
    for &(bytes, _) in &layouts {
        assert_eq!(bytes, 1);
    }
    for &(_, elems) in &layouts {
        assert_eq!(elems, 2);
    }
}
#[test]
fn ggml_legacy_block32_bytes() {
    // ggml declares a separate block-size constant for each legacy format. They
    // all equal 32, but naming them lets each formula below match the
    // corresponding ggml static_assert term-for-term, instead of borrowing
    // another format's constant.
    const QK4_1: usize = 32;
    const QK5_1: usize = 32;
    const QK8_1: usize = 32;
    // The Q5_0/Q5_1 high-bit field is sized as one 32-bit word in ggml's assert.
    const GGML_U32: usize = 32 / BITS_PER_BYTE;

    assert_eq!(Q4_0::BLOCK_BYTES, GGML_HALF + QK4_0 / 2);
    assert_eq!(Q4_0::BLOCK_BYTES, 18);
    assert_eq!(Q4_1::BLOCK_BYTES, 2 * GGML_HALF + QK4_1 / 2);
    assert_eq!(Q4_1::BLOCK_BYTES, 20);
    assert_eq!(Q5_0::BLOCK_BYTES, GGML_HALF + GGML_U32 + QK5_0 / 2);
    assert_eq!(Q5_0::BLOCK_BYTES, 22);
    assert_eq!(Q5_1::BLOCK_BYTES, 2 * GGML_HALF + GGML_U32 + QK5_1 / 2);
    assert_eq!(Q5_1::BLOCK_BYTES, 24);
    assert_eq!(Q8_0::BLOCK_BYTES, GGML_HALF + QK8_0);
    assert_eq!(Q8_0::BLOCK_BYTES, 34);
    assert_eq!(Q8_1::BLOCK_BYTES, 2 * GGML_HALF + QK8_1);
    assert_eq!(Q8_1::BLOCK_BYTES, 36);

    // Every legacy format quantizes 32 elements per block.
    for be in [
        Q4_0::BLOCK_ELEMS,
        Q4_1::BLOCK_ELEMS,
        Q5_0::BLOCK_ELEMS,
        Q5_1::BLOCK_ELEMS,
        Q8_0::BLOCK_ELEMS,
        Q8_1::BLOCK_ELEMS,
    ] {
        assert_eq!(be, 32, "legacy quantization is block-32");
    }
}
#[test]
fn ggml_kseries_block256_bytes() {
    // Each row: (BLOCK_BYTES, ggml size formula, literal size in bytes).
    let layouts: [(usize, usize, usize); 6] = [
        (Q2_K::BLOCK_BYTES, 2 * GGML_HALF + QK_K / 16 + QK_K / 4, 84),
        (Q3_K::BLOCK_BYTES, GGML_HALF + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, 110),
        (Q4_K::BLOCK_BYTES, 2 * GGML_HALF + K_SCALE_SIZE + QK_K / 2, 144),
        (Q5_K::BLOCK_BYTES, 2 * GGML_HALF + K_SCALE_SIZE + QK_K / 2 + QK_K / 8, 176),
        (Q6_K::BLOCK_BYTES, GGML_HALF + QK_K / 16 + 3 * QK_K / 4, 210),
        (Q8_K::BLOCK_BYTES, GGML_FLOAT + QK_K + QK_K / 16 * GGML_I16, 292),
    ];
    for (actual, formula, literal) in layouts {
        assert_eq!(actual, formula);
        assert_eq!(actual, literal);
    }

    // All K-series formats share the QK_K super-block size.
    let super_block_elems = [
        Q2_K::BLOCK_ELEMS,
        Q3_K::BLOCK_ELEMS,
        Q4_K::BLOCK_ELEMS,
        Q5_K::BLOCK_ELEMS,
        Q6_K::BLOCK_ELEMS,
        Q8_K::BLOCK_ELEMS,
    ];
    for be in super_block_elems {
        assert_eq!(be, QK_K, "K-series quantization is block-256");
    }
}
#[test]
fn ggml_iqseries_block_bytes() {
    // Each row: (BLOCK_BYTES, ggml size formula, literal size in bytes).
    let layouts: [(usize, usize, usize); 9] = [
        (IQ1_S::BLOCK_BYTES, GGML_HALF + QK_K / 8 + QK_K / 16, 50),
        (IQ1_M::BLOCK_BYTES, QK_K / 8 + QK_K / 16 + QK_K / 32, 56),
        (IQ2_XXS::BLOCK_BYTES, GGML_HALF + QK_K / 8 * GGML_U16, 66),
        (IQ2_XS::BLOCK_BYTES, GGML_HALF + QK_K / 8 * GGML_U16 + QK_K / 32, 74),
        (IQ2_S::BLOCK_BYTES, GGML_HALF + QK_K / 4 + QK_K / 16, 82),
        (IQ3_XXS::BLOCK_BYTES, GGML_HALF + 3 * (QK_K / 8), 98),
        (IQ3_S::BLOCK_BYTES, GGML_HALF + 13 * (QK_K / 32) + IQ3S_N_SCALE, 110),
        (IQ4_NL::BLOCK_BYTES, GGML_HALF + QK4_NL / 2, 18),
        (IQ4_XS::BLOCK_BYTES, GGML_HALF + GGML_U16 + QK_K / 64 + QK_K / 2, 136),
    ];
    for (actual, formula, literal) in layouts {
        assert_eq!(actual, formula);
        assert_eq!(actual, literal);
    }

    // IQ4_NL's byte formula is driven by QK4_NL, so its element count is pinned too.
    assert_eq!(IQ4_NL::BLOCK_ELEMS, QK4_NL);
}
#[test]
fn equal_block_bytes_content_address_identically() {
    // Every 16-bit scalar dtype must report the same site count as F16.
    let f16_sites = <F16 as ConstrainedTypeShape>::SITE_COUNT;
    for other_sites in [
        <BF16 as ConstrainedTypeShape>::SITE_COUNT,
        <I16 as ConstrainedTypeShape>::SITE_COUNT,
        <U16 as ConstrainedTypeShape>::SITE_COUNT,
    ] {
        assert_eq!(f16_sites, other_sites);
    }

    // Formats from different quantization families that share a byte width must
    // also share a site count.
    assert_eq!(Q4_0::BLOCK_BYTES, IQ4_NL::BLOCK_BYTES);
    let q4_0_sites = <Q4_0 as ConstrainedTypeShape>::SITE_COUNT;
    let iq4_nl_sites = <IQ4_NL as ConstrainedTypeShape>::SITE_COUNT;
    assert_eq!(q4_0_sites, iq4_nl_sites);

    assert_eq!(Q3_K::BLOCK_BYTES, IQ3_S::BLOCK_BYTES);
    let q3_k_sites = <Q3_K as ConstrainedTypeShape>::SITE_COUNT;
    let iq3_s_sites = <IQ3_S as ConstrainedTypeShape>::SITE_COUNT;
    assert_eq!(q3_k_sites, iq3_s_sites);

    // Shapes are identified by the generic ConstrainedType IRI rather than a per-dtype one.
    let f32_iri = <F32 as ConstrainedTypeShape>::IRI;
    assert_eq!(f32_iri, "https://uor.foundation/type/ConstrainedType");
    let q4_0_iri = <Q4_0 as ConstrainedTypeShape>::IRI;
    let iq4_nl_iri = <IQ4_NL as ConstrainedTypeShape>::IRI;
    assert_eq!(q4_0_iri, iq4_nl_iri);
}
#[test]
fn site_count_equals_block_bytes() {
    // Spot-check, across three dtype families, that a shape's site count equals
    // its dtype's BLOCK_BYTES.
    let samples = [
        (<F32 as ConstrainedTypeShape>::SITE_COUNT, F32::BLOCK_BYTES),
        (<Q6_K as ConstrainedTypeShape>::SITE_COUNT, Q6_K::BLOCK_BYTES),
        (<IQ4_XS as ConstrainedTypeShape>::SITE_COUNT, IQ4_XS::BLOCK_BYTES),
    ];
    for (sites, block_bytes) in samples {
        assert_eq!(sites, block_bytes);
    }
}
#[test]
fn registry_carries_every_dtype() {
    // Per-family dtype counts. Naming each family makes it obvious which count to
    // bump when a dtype is added, instead of editing an unlabeled sum.
    const CONTINUOUS_FLOATS: usize = 4; // F32, F16, BF16, F64
    const FLOAT8: usize = 4; // F8_E4M3, F8_E4M3_FNUZ, F8_E5M2, F8_E5M2_FNUZ
    const COMPLEX: usize = 2; // C64, C128
    const SIGNED_INTS: usize = 4; // I8, I16, I32, I64
    const UNSIGNED_INTS: usize = 4; // U8, U16, U32, U64
    const BOOLEAN: usize = 1; // BOOL
    const PACKED_NIBBLES: usize = 3; // I4, U4, F4_E2M1
    const GGML_LEGACY: usize = 6; // Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, Q8_1
    const GGML_K_SERIES: usize = 6; // Q2_K, Q3_K, Q4_K, Q5_K, Q6_K, Q8_K
    const GGML_IQ_SERIES: usize = 9; // IQ1_S, IQ1_M, IQ2_XXS, IQ2_XS, IQ2_S, IQ3_XXS, IQ3_S, IQ4_NL, IQ4_XS
    const EXPECTED: usize = CONTINUOUS_FLOATS
        + FLOAT8
        + COMPLEX
        + SIGNED_INTS
        + UNSIGNED_INTS
        + BOOLEAN
        + PACKED_NIBBLES
        + GGML_LEGACY
        + GGML_K_SERIES
        + GGML_IQ_SERIES;
    // Guard against a family count drifting without the overall total being revisited.
    assert_eq!(EXPECTED, 43);

    let registry = <TensorDtypeRegistry as ShapeRegistryProvider>::REGISTRY;
    assert_eq!(registry.len(), EXPECTED);
}
#[test]
fn name_constants_match_struct_names() {
    // Compare each dtype's NAME constant with the spelling of its own type identifier.
    macro_rules! assert_named {
        ($($ty:ident),+ $(,)?) => {
            $(assert_eq!($ty::NAME, stringify!($ty));)+
        };
    }
    assert_named!(F32, BF16, Q4_0, IQ2_XXS);
}