use crate::dtype::Dtype;
use crate::error::{Result, TensogramError};
use crate::types::{DataObjectDescriptor, MaskDescriptor};
use tensogram_encodings::ByteOrder;
use tensogram_encodings::bitmask;
/// Rewrites canonical NaN/±Inf bit patterns into `decoded_payload` at the
/// element positions flagged by the descriptor's companion masks.
///
/// A no-op when the descriptor carries no mask metadata.
///
/// # Errors
/// - `Framing` when the dtype has zero byte width (bitmask payloads cannot
///   carry companion masks) or when the payload length differs from
///   `n_elements * elem_size`.
/// - `Metadata` when `n_elements * elem_size` overflows `usize`.
/// - Any error from decoding an individual mask blob.
pub(crate) fn restore_non_finite_into(
    decoded_payload: &mut [u8],
    descriptor: &DataObjectDescriptor,
    mask_region: &[u8],
    output_byte_order: ByteOrder,
) -> Result<()> {
    let masks = match descriptor.masks.as_ref() {
        Some(m) => m,
        None => return Ok(()),
    };
    let n_elements = element_count(descriptor)?;
    let elem_size = descriptor.dtype.byte_width();
    if elem_size == 0 {
        return Err(TensogramError::Framing(
            "bitmask-companion masks cannot be restored on bitmask-dtype payloads".to_string(),
        ));
    }
    let expected_len = n_elements.checked_mul(elem_size).ok_or_else(|| {
        TensogramError::Metadata("n_elements * elem_size overflows usize".to_string())
    })?;
    if decoded_payload.len() != expected_len {
        return Err(TensogramError::Framing(format!(
            "decoded payload length {} does not match descriptor n_elements * elem_size ({} * {} = {})",
            decoded_payload.len(),
            n_elements,
            elem_size,
            expected_len,
        )));
    }
    // Mask offsets are absolute; the supplied region starts at the smallest one.
    let region_base = smallest_mask_offset(masks);
    for (md, kind) in each_mask_kind(masks) {
        let bits = decode_one_mask_at(md, mask_region, region_base, n_elements)?;
        write_canonical_non_finite(decoded_payload, descriptor.dtype, output_byte_order, &bits, kind);
    }
    Ok(())
}
/// Yields each present mask descriptor paired with the non-finite kind it
/// encodes, always in the fixed order: NaN, +Inf, -Inf.
fn each_mask_kind(
    masks: &crate::types::MasksMetadata,
) -> impl Iterator<Item = (&MaskDescriptor, Kind)> {
    let nan = masks.nan.as_ref().map(|m| (m, Kind::Nan));
    let pos = masks.pos_inf.as_ref().map(|m| (m, Kind::PosInf));
    let neg = masks.neg_inf.as_ref().map(|m| (m, Kind::NegInf));
    nan.into_iter().chain(pos).chain(neg)
}
/// Computes the total element count from the descriptor's shape (an empty
/// shape yields 1, i.e. a scalar).
///
/// # Errors
/// `Metadata` when the dimension product overflows `u64` or the result does
/// not fit in `usize`.
fn element_count(desc: &DataObjectDescriptor) -> Result<usize> {
    let mut product: u64 = 1;
    for &dim in &desc.shape {
        product = product
            .checked_mul(dim)
            .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
    }
    usize::try_from(product)
        .map_err(|_| TensogramError::Metadata("element count overflows usize".to_string()))
}
/// Decodes one mask descriptor's blob out of `mask_region` into a
/// per-element boolean vector of length `n_elements`.
///
/// `md.offset` and `md.length` are absolute positions; `mask_region` is
/// assumed to begin at `mask_region_base` (the smallest offset among the
/// object's masks), so the blob is sliced at `offset - base`.
///
/// # Errors
/// - `Framing` when the offset lies below the region base, the slice bounds
///   overflow `usize`, or the slice extends past `mask_region`.
/// - `Encoding` when the codec rejects the blob or the method is unknown.
fn decode_one_mask_at(
    md: &MaskDescriptor,
    mask_region: &[u8],
    mask_region_base: u64,
    n_elements: usize,
) -> Result<Vec<bool>> {
    let offset = u64_to_usize(md.offset, "mask.offset")?;
    let length = u64_to_usize(md.length, "mask.length")?;
    let base = u64_to_usize(mask_region_base, "mask_region_base")?;
    // checked_sub doubles as the "offset below base" validation.
    let start = offset.checked_sub(base).ok_or_else(|| {
        TensogramError::Framing(format!(
            "mask.offset {} less than mask_region_base {}",
            md.offset, mask_region_base
        ))
    })?;
    let end = start.checked_add(length).ok_or_else(|| {
        TensogramError::Framing(format!(
            "mask.offset + length overflow (offset={offset}, length={length})"
        ))
    })?;
    if end > mask_region.len() {
        return Err(TensogramError::Framing(format!(
            "mask slice end {end} exceeds mask_region length {}",
            mask_region.len()
        )));
    }
    let blob = &mask_region[start..end];
    decode_blob(&md.method, blob, n_elements).map_err(|e| match e {
        // UnknownMethod already names the offending method; wrapping it in
        // the generic context would print the name twice.
        bitmask::MaskError::UnknownMethod(_) => TensogramError::Encoding(e.to_string()),
        other => TensogramError::Encoding(format!("bitmask decode ({}): {other}", md.method)),
    })
}
/// Converts `v` to `usize`, naming the offending field in the error when the
/// value does not fit (relevant on 32-bit targets).
fn u64_to_usize(v: u64, name: &str) -> Result<usize> {
    match usize::try_from(v) {
        Ok(n) => Ok(n),
        Err(_) => Err(TensogramError::Framing(format!("{name} {v} overflows usize"))),
    }
}
/// Dispatches `blob` to the bitmask codec named by `method`, decoding
/// exactly `n_elements` booleans.
///
/// Recognized methods: "none", "rle", "roaring", "lz4", "zstd", and
/// "blosc2". Any other name yields `MaskError::UnknownMethod`.
fn decode_blob(
    method: &str,
    blob: &[u8],
    n_elements: usize,
) -> std::result::Result<Vec<bool>, bitmask::MaskError> {
    match method {
        "none" => bitmask::codecs::decode_none(blob, n_elements),
        "rle" => bitmask::rle::decode(blob, n_elements),
        "roaring" => bitmask::roaring::decode(blob, n_elements),
        "lz4" => bitmask::codecs::decode_lz4(blob, n_elements),
        "zstd" => bitmask::codecs::decode_zstd(blob, n_elements),
        #[cfg(feature = "blosc2")]
        "blosc2" => bitmask::codecs::decode_blosc2(blob, n_elements),
        // Keep this arm so a "blosc2" blob is reported as feature-disabled
        // rather than unknown when the feature is compiled out.
        #[cfg(not(feature = "blosc2"))]
        "blosc2" => Err(bitmask::MaskError::FeatureDisabled { method: "blosc2" }),
        other => Err(bitmask::MaskError::UnknownMethod(other.to_string())),
    }
}
/// Which class of non-finite value a companion mask flags.
#[derive(Debug, Clone, Copy)]
enum Kind {
    /// Quiet-NaN positions.
    Nan,
    /// Positive-infinity positions.
    PosInf,
    /// Negative-infinity positions.
    NegInf,
}
/// Overwrites elements of `buf` flagged in `bits` with the canonical
/// non-finite bit pattern for (`dtype`, `kind`), serialized in `byte_order`.
///
/// Integer and bitmask dtypes have no non-finite representation, so the call
/// is a no-op for them. Unflagged elements are left untouched. If `buf` is
/// shorter than `bits` requires, writing stops at the last full element
/// (release builds); debug builds assert the sizes agree.
fn write_canonical_non_finite(
    buf: &mut [u8],
    dtype: Dtype,
    byte_order: ByteOrder,
    bits: &[bool],
    kind: Kind,
) {
    let elem_size = dtype.byte_width();
    // checked_mul keeps the assertion itself from tripping a debug-mode
    // arithmetic-overflow panic before the intended message can fire.
    debug_assert!(
        bits.len()
            .checked_mul(elem_size)
            .map_or(false, |need| buf.len() >= need),
        "write_canonical_non_finite: buf {} < bits {} * elem_size {}",
        buf.len(),
        bits.len(),
        elem_size,
    );
    // `None` means a dtype with no non-finite encoding (ints, bitmask).
    let Some(pattern) = CanonicalPattern::new(dtype, kind, byte_order) else {
        return;
    };
    let element_bytes = pattern.as_slice();
    // Pairing flags with exact-size chunks removes the per-element index
    // arithmetic and bounds checks of the old `i * elem_size` loop; `zip`
    // stops at the shorter side, matching the old truncating `break`.
    // `pattern` is only `Some` for dtypes with elem_size > 0, so
    // `chunks_exact_mut` is never called with a zero chunk size.
    for (chunk, &is_set) in buf.chunks_exact_mut(elem_size).zip(bits) {
        if is_set {
            chunk.copy_from_slice(element_bytes);
        }
    }
}
/// Fixed-size scratch holding one element's canonical non-finite byte
/// pattern.
struct CanonicalPattern {
    // Pattern storage; only the leading `len` bytes are meaningful.
    bytes: [u8; 16],
    // Valid prefix length: 2, 4, 8, or 16 depending on dtype.
    len: usize,
}
impl CanonicalPattern {
    /// Builds the canonical pattern for one element of `dtype` carrying the
    /// non-finite value `kind`, serialized in `byte_order`.
    ///
    /// Returns `None` for integer and bitmask dtypes, which cannot represent
    /// non-finite values. Complex dtypes repeat the scalar pattern in both
    /// the real and imaginary components.
    fn new(dtype: Dtype, kind: Kind, byte_order: ByteOrder) -> Option<Self> {
        let mut bytes = [0u8; 16];
        let len = match dtype {
            Dtype::Float32 => {
                bytes[..4].copy_from_slice(&f32_canonical(kind, byte_order));
                4
            }
            Dtype::Float64 => {
                bytes[..8].copy_from_slice(&f64_canonical(kind, byte_order));
                8
            }
            Dtype::Float16 => {
                bytes[..2].copy_from_slice(&half_canonical(kind, byte_order, Half::F16));
                2
            }
            Dtype::Bfloat16 => {
                bytes[..2].copy_from_slice(&half_canonical(kind, byte_order, Half::BF16));
                2
            }
            Dtype::Complex64 => {
                let scalar = f32_canonical(kind, byte_order);
                for component in bytes[..8].chunks_exact_mut(4) {
                    component.copy_from_slice(&scalar);
                }
                8
            }
            Dtype::Complex128 => {
                let scalar = f64_canonical(kind, byte_order);
                for component in bytes[..16].chunks_exact_mut(8) {
                    component.copy_from_slice(&scalar);
                }
                16
            }
            Dtype::Int8
            | Dtype::Int16
            | Dtype::Int32
            | Dtype::Int64
            | Dtype::Uint8
            | Dtype::Uint16
            | Dtype::Uint32
            | Dtype::Uint64
            | Dtype::Bitmask => return None,
        };
        Some(Self { bytes, len })
    }

    /// The valid prefix of the pattern: exactly one element's bytes.
    fn as_slice(&self) -> &[u8] {
        &self.bytes[..self.len]
    }
}
/// Selects which 16-bit float encoding `half_canonical` should target.
#[derive(Clone, Copy)]
enum Half {
    /// IEEE-754 binary16.
    F16,
    /// bfloat16 (the high 16 bits of a binary32 pattern).
    BF16,
}
/// Canonical binary32 bit pattern for `kind`, serialized in `byte_order`.
fn f32_canonical(kind: Kind, byte_order: ByteOrder) -> [u8; 4] {
    let bits = match kind {
        Kind::Nan => 0x7FC0_0000u32,
        Kind::PosInf => 0x7F80_0000,
        Kind::NegInf => 0xFF80_0000,
    };
    if matches!(byte_order, ByteOrder::Big) {
        bits.to_be_bytes()
    } else {
        bits.to_le_bytes()
    }
}
fn f64_canonical(kind: Kind, byte_order: ByteOrder) -> [u8; 8] {
let bits: u64 = match kind {
Kind::Nan => 0x7FF8_0000_0000_0000,
Kind::PosInf => 0x7FF0_0000_0000_0000,
Kind::NegInf => 0xFFF0_0000_0000_0000,
};
match byte_order {
ByteOrder::Big => bits.to_be_bytes(),
ByteOrder::Little => bits.to_le_bytes(),
}
}
/// Canonical 16-bit non-finite pattern for `kind` in the chosen `half`
/// encoding, serialized in `byte_order`.
fn half_canonical(kind: Kind, byte_order: ByteOrder, half: Half) -> [u8; 2] {
    let bits: u16 = match half {
        Half::F16 => match kind {
            Kind::Nan => 0x7E00,
            Kind::PosInf => 0x7C00,
            Kind::NegInf => 0xFC00,
        },
        Half::BF16 => match kind {
            Kind::Nan => 0x7FC0,
            Kind::PosInf => 0x7F80,
            Kind::NegInf => 0xFF80,
        },
    };
    match byte_order {
        ByteOrder::Big => bits.to_be_bytes(),
        ByteOrder::Little => bits.to_le_bytes(),
    }
}
/// A fully decoded object: descriptor, raw payload bytes, and any decoded
/// non-finite masks.
#[derive(Debug, Clone)]
pub struct DecodedObjectWithMasks {
    /// Descriptor the payload was decoded against.
    pub descriptor: DataObjectDescriptor,
    /// Raw element bytes of the decoded object.
    pub payload: Vec<u8>,
    /// Per-kind boolean masks decoded from the mask region.
    pub masks: DecodedMaskSet,
}
/// Decoded per-element boolean masks, one optional vector per non-finite
/// kind. `None` means the object carried no mask of that kind.
#[derive(Debug, Clone, Default)]
pub struct DecodedMaskSet {
    /// Positions holding NaN, if a NaN mask was present.
    pub nan: Option<Vec<bool>>,
    /// Positions holding +Inf, if a +Inf mask was present.
    pub pos_inf: Option<Vec<bool>>,
    /// Positions holding -Inf, if a -Inf mask was present.
    pub neg_inf: Option<Vec<bool>>,
}
impl DecodedMaskSet {
pub fn is_empty(&self) -> bool {
self.nan.is_none() && self.pos_inf.is_none() && self.neg_inf.is_none()
}
fn slot_for(&mut self, kind: Kind) -> &mut Option<Vec<bool>> {
match kind {
Kind::Nan => &mut self.nan,
Kind::PosInf => &mut self.pos_inf,
Kind::NegInf => &mut self.neg_inf,
}
}
}
/// Restores canonical non-finite values into several payload slices
/// (`parts`), each covering the element range `(offset, count)` given in
/// `ranges`, using masks already decoded for the whole object.
///
/// A no-op when the descriptor carries no mask metadata or `mask_set` is
/// empty.
///
/// # Errors
/// `Framing` when `parts` and `ranges` disagree in count, the dtype has zero
/// byte width, a range overflows or runs past a decoded mask, or a part's
/// byte length does not equal `count * elem_size`.
pub(crate) fn restore_non_finite_into_ranges(
    parts: &mut [Vec<u8>],
    descriptor: &DataObjectDescriptor,
    ranges: &[(u64, u64)],
    mask_set: &DecodedMaskSet,
    output_byte_order: ByteOrder,
) -> Result<()> {
    if descriptor.masks.is_none() || mask_set.is_empty() {
        return Ok(());
    }
    if parts.len() != ranges.len() {
        return Err(TensogramError::Framing(format!(
            "range count mismatch: parts.len()={} but ranges.len()={}",
            parts.len(),
            ranges.len()
        )));
    }
    let elem_size = descriptor.dtype.byte_width();
    if elem_size == 0 {
        return Err(TensogramError::Framing(
            "bitmask-companion masks cannot be restored on bitmask-dtype payloads".to_string(),
        ));
    }
    // The kind lookup table is loop-invariant; build it once.
    let kinds = [
        (mask_set.nan.as_ref(), Kind::Nan),
        (mask_set.pos_inf.as_ref(), Kind::PosInf),
        (mask_set.neg_inf.as_ref(), Kind::NegInf),
    ];
    for (part, &(range_offset, range_count)) in parts.iter_mut().zip(ranges) {
        let start = u64_to_usize(range_offset, "range.offset")?;
        let count = u64_to_usize(range_count, "range.count")?;
        let end = start
            .checked_add(count)
            .ok_or_else(|| TensogramError::Framing("range offset+count overflow".to_string()))?;
        let expected_part_len = count.checked_mul(elem_size).ok_or_else(|| {
            TensogramError::Framing("range count * elem_size overflows usize".to_string())
        })?;
        if part.len() != expected_part_len {
            return Err(TensogramError::Framing(format!(
                "range part length {} does not match count * elem_size ({} * {} = {})",
                part.len(),
                count,
                elem_size,
                expected_part_len,
            )));
        }
        for &(maybe_bits, kind) in &kinds {
            let Some(bits) = maybe_bits else { continue };
            if end > bits.len() {
                return Err(TensogramError::Framing(format!(
                    "range end {end} exceeds mask length {} for descriptor shape",
                    bits.len()
                )));
            }
            write_canonical_non_finite(
                part,
                descriptor.dtype,
                output_byte_order,
                &bits[start..end],
                kind,
            );
        }
    }
    Ok(())
}
/// Decodes every mask attached to `descriptor` from `mask_region` into
/// boolean vectors, keyed by non-finite kind.
///
/// Returns an empty set when the descriptor has no mask metadata; any
/// decode failure propagates as an error.
pub(crate) fn decode_mask_set(
    descriptor: &DataObjectDescriptor,
    mask_region: &[u8],
) -> Result<DecodedMaskSet> {
    let mut decoded = DecodedMaskSet::default();
    if let Some(masks) = descriptor.masks.as_ref() {
        let n_elements = element_count(descriptor)?;
        // Mask offsets are absolute; the region starts at the smallest one.
        let base = smallest_mask_offset(masks);
        for (md, kind) in each_mask_kind(masks) {
            let bits = decode_one_mask_at(md, mask_region, base, n_elements)?;
            *decoded.slot_for(kind) = Some(bits);
        }
    }
    Ok(decoded)
}
/// The smallest absolute offset among the present masks, or `u64::MAX` when
/// none are present (callers only consume the value while iterating masks,
/// so the sentinel is never used as a real base).
fn smallest_mask_offset(masks: &crate::types::MasksMetadata) -> u64 {
    [
        masks.nan.as_ref(),
        masks.pos_inf.as_ref(),
        masks.neg_inf.as_ref(),
    ]
    .into_iter()
    .flatten()
    .map(|md| md.offset)
    .min()
    .unwrap_or(u64::MAX)
}
// Unit tests: canonical-pattern writing, validation errors on both restore
// entry points, and region-base-relative mask decoding.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{MaskDescriptor, MasksMetadata};
    use std::collections::BTreeMap;

    /// Builds a minimal contiguous descriptor with no masks attached.
    fn make_descriptor(shape: Vec<u64>, dtype: Dtype) -> DataObjectDescriptor {
        DataObjectDescriptor {
            obj_type: "ntensor".to_string(),
            ndim: shape.len() as u64,
            shape: shape.clone(),
            strides: {
                // Row-major element strides: last dimension has stride 1.
                let mut s = vec![1u64; shape.len()];
                for i in (0..shape.len().saturating_sub(1)).rev() {
                    s[i] = s[i + 1] * shape[i + 1];
                }
                s
            },
            dtype,
            byte_order: ByteOrder::native(),
            encoding: "none".to_string(),
            filter: "none".to_string(),
            compression: "none".to_string(),
            masks: None,
            params: BTreeMap::new(),
        }
    }

    // NaN pattern lands only at flagged positions; others stay zero.
    #[test]
    fn write_canonical_f64_nan_at_marked_positions() {
        let mut buf = vec![0u8; 8 * 4];
        let bits = vec![false, true, false, true];
        let dtype = Dtype::Float64;
        write_canonical_non_finite(&mut buf, dtype, ByteOrder::native(), &bits, Kind::Nan);
        let mut doubles = [0f64; 4];
        for i in 0..4 {
            let bytes = [
                buf[i * 8],
                buf[i * 8 + 1],
                buf[i * 8 + 2],
                buf[i * 8 + 3],
                buf[i * 8 + 4],
                buf[i * 8 + 5],
                buf[i * 8 + 6],
                buf[i * 8 + 7],
            ];
            doubles[i] = f64::from_ne_bytes(bytes);
        }
        assert_eq!(doubles[0], 0.0);
        assert!(doubles[1].is_nan());
        assert_eq!(doubles[2], 0.0);
        assert!(doubles[3].is_nan());
    }

    // -Inf pattern round-trips through native byte order.
    #[test]
    fn write_canonical_f32_neg_inf() {
        let mut buf = vec![0u8; 4 * 3];
        let bits = vec![false, true, false];
        write_canonical_non_finite(
            &mut buf,
            Dtype::Float32,
            ByteOrder::native(),
            &bits,
            Kind::NegInf,
        );
        let v = f32::from_ne_bytes([buf[4], buf[5], buf[6], buf[7]]);
        assert!(v.is_infinite() && v.is_sign_negative());
    }

    // Complex dtypes get the scalar pattern in both real and imaginary parts.
    #[test]
    fn write_canonical_complex64_writes_both_components() {
        let mut buf = vec![0u8; 8];
        let bits = vec![true];
        write_canonical_non_finite(
            &mut buf,
            Dtype::Complex64,
            ByteOrder::native(),
            &bits,
            Kind::Nan,
        );
        let real = f32::from_ne_bytes([buf[0], buf[1], buf[2], buf[3]]);
        let imag = f32::from_ne_bytes([buf[4], buf[5], buf[6], buf[7]]);
        assert!(real.is_nan());
        assert!(imag.is_nan());
    }

    // Without mask metadata the payload must pass through untouched.
    #[test]
    fn restore_non_finite_into_no_op_without_masks() {
        let desc = make_descriptor(vec![4], Dtype::Float64);
        let mut payload = vec![0u8; 32];
        let original = payload.clone();
        restore_non_finite_into(&mut payload, &desc, &[], ByteOrder::native()).unwrap();
        assert_eq!(payload, original);
    }

    // No mask metadata decodes to an empty set.
    #[test]
    fn decode_mask_set_empty_when_no_masks() {
        let desc = make_descriptor(vec![4], Dtype::Float64);
        let set = decode_mask_set(&desc, &[]).unwrap();
        assert!(set.is_empty());
    }

    // Bitmask dtype has zero byte width, so mask restoration must refuse it.
    #[test]
    fn restore_non_finite_into_rejects_bitmask_dtype() {
        let mut desc = make_descriptor(vec![4], Dtype::Bitmask);
        desc.masks = Some(MasksMetadata {
            nan: Some(MaskDescriptor {
                method: "none".to_string(),
                offset: 0,
                length: 1,
                params: BTreeMap::new(),
            }),
            ..Default::default()
        });
        let mut payload = vec![0u8; 1];
        let err = restore_non_finite_into(&mut payload, &desc, &[0u8; 1], ByteOrder::native())
            .unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("bitmask"), "got: {msg}");
    }

    // Payload shorter than n_elements * elem_size is a framing error.
    #[test]
    fn restore_non_finite_into_rejects_wrong_payload_length() {
        let mut desc = make_descriptor(vec![4], Dtype::Float64);
        desc.masks = Some(MasksMetadata {
            nan: Some(MaskDescriptor {
                method: "none".to_string(),
                offset: 32,
                length: 1,
                params: BTreeMap::new(),
            }),
            ..Default::default()
        });
        let mut short_payload = vec![0u8; 16];
        let mask_region: Vec<u8> = bitmask::codecs::encode_none(&[false, true, false, true]);
        let err =
            restore_non_finite_into(&mut short_payload, &desc, &mask_region, ByteOrder::native())
                .unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("decoded payload length"),
            "expected length-mismatch error, got: {msg}"
        );
    }

    // An offset below the region base must error, not wrap around.
    #[test]
    fn decode_one_mask_at_rejects_offset_below_region_base() {
        let md = MaskDescriptor {
            method: "none".to_string(),
            offset: 4,
            length: 1,
            params: BTreeMap::new(),
        };
        let mask_region = vec![0u8; 1];
        let mask_region_base = 10;
        let err = decode_one_mask_at(&md, &mask_region, mask_region_base, 4).unwrap_err();
        assert!(
            err.to_string().contains("mask.offset"),
            "expected offset-below-base error, got: {err}"
        );
    }

    // Same bitmask-dtype refusal applies to the range-based entry point.
    #[test]
    fn restore_non_finite_into_ranges_rejects_bitmask_dtype() {
        let mut desc = make_descriptor(vec![4], Dtype::Bitmask);
        desc.masks = Some(MasksMetadata {
            nan: Some(MaskDescriptor {
                method: "none".to_string(),
                offset: 0,
                length: 1,
                params: BTreeMap::new(),
            }),
            ..Default::default()
        });
        let mask_set = DecodedMaskSet {
            nan: Some(vec![true; 4]),
            pos_inf: None,
            neg_inf: None,
        };
        let mut parts = vec![vec![0u8; 1]];
        let err = restore_non_finite_into_ranges(
            &mut parts,
            &desc,
            &[(0, 1)],
            &mask_set,
            ByteOrder::native(),
        )
        .unwrap_err();
        assert!(err.to_string().contains("bitmask"), "got: {err}");
    }

    // parts and ranges must be the same length (1 part, 2 ranges here).
    #[test]
    fn restore_non_finite_into_ranges_rejects_parts_len_mismatch() {
        let mut desc = make_descriptor(vec![8], Dtype::Float64);
        desc.masks = Some(MasksMetadata {
            nan: Some(MaskDescriptor {
                method: "none".to_string(),
                offset: 64,
                length: 1,
                params: BTreeMap::new(),
            }),
            ..Default::default()
        });
        let mask_set = DecodedMaskSet {
            nan: Some(vec![true; 8]),
            pos_inf: None,
            neg_inf: None,
        };
        let mut parts = vec![vec![0u8; 24]];
        let err = restore_non_finite_into_ranges(
            &mut parts,
            &desc,
            &[(0, 3), (5, 3)],
            &mask_set,
            ByteOrder::native(),
        )
        .unwrap_err();
        assert!(
            err.to_string().contains("range count mismatch"),
            "got: {err}"
        );
    }

    // A range of 6 elements over a 4-element mask must be rejected.
    #[test]
    fn restore_non_finite_into_ranges_rejects_range_past_mask_end() {
        let mut desc = make_descriptor(vec![8], Dtype::Float64);
        desc.masks = Some(MasksMetadata {
            nan: Some(MaskDescriptor {
                method: "none".to_string(),
                offset: 64,
                length: 1,
                params: BTreeMap::new(),
            }),
            ..Default::default()
        });
        let mask_set = DecodedMaskSet {
            nan: Some(vec![true; 4]),
            pos_inf: None,
            neg_inf: None,
        };
        let mut parts = vec![vec![0u8; 48]];
        let err = restore_non_finite_into_ranges(
            &mut parts,
            &desc,
            &[(0, 6)],
            &mask_set,
            ByteOrder::native(),
        )
        .unwrap_err();
        assert!(err.to_string().contains("range end"), "got: {err}");
    }

    // Part of 24 bytes cannot hold a 4-element f64 range (needs 32 bytes).
    #[test]
    fn restore_non_finite_into_ranges_rejects_wrong_part_length() {
        let mut desc = make_descriptor(vec![8], Dtype::Float64);
        desc.masks = Some(MasksMetadata {
            nan: Some(MaskDescriptor {
                method: "none".to_string(),
                offset: 64,
                length: 1,
                params: BTreeMap::new(),
            }),
            ..Default::default()
        });
        let mask_set = DecodedMaskSet {
            nan: Some(vec![true; 8]),
            pos_inf: None,
            neg_inf: None,
        };
        let mut parts = vec![vec![0u8; 24]];
        let err = restore_non_finite_into_ranges(
            &mut parts,
            &desc,
            &[(0, 4)],
            &mask_set,
            ByteOrder::native(),
        )
        .unwrap_err();
        assert!(err.to_string().contains("range part length"), "got: {err}");
    }

    // UnknownMethod errors must not wrap the method name a second time.
    #[test]
    fn decode_one_mask_at_unknown_method_error_is_not_duplicated() {
        let md = MaskDescriptor {
            method: "bogus".to_string(),
            offset: 0,
            length: 1,
            params: BTreeMap::new(),
        };
        let mask_region = vec![0u8; 1];
        let err = decode_one_mask_at(&md, &mask_region, 0, 4).unwrap_err();
        let msg = err.to_string();
        let occurrences = msg.matches("bogus").count();
        assert_eq!(
            occurrences, 1,
            "message should name method once, got: {msg}"
        );
    }

    // A mask at absolute offset 32 supplied in a region starting at 32
    // (the smallest offset) must decode from relative position 0.
    #[test]
    fn decode_mask_set_restores_offsets_relative_to_region_base() {
        let mut masks = MasksMetadata::default();
        let bits = vec![true, false, true, false];
        let blob = bitmask::roaring::encode(&bits).unwrap();
        masks.nan = Some(MaskDescriptor {
            method: "roaring".to_string(),
            offset: 32,
            length: blob.len() as u64,
            params: BTreeMap::new(),
        });
        let mut desc = make_descriptor(vec![4], Dtype::Float64);
        desc.masks = Some(masks);
        let got = decode_mask_set(&desc, &blob).unwrap();
        assert_eq!(got.nan.unwrap(), bits);
    }
}