use std::borrow::Cow;
use crate::dtype::Dtype;
use crate::error::{Result, TensogramError};
use tensogram_encodings::ByteOrder;
// Per-task chunk size for the parallel scan, and the minimum input size at
// which the parallel path is attempted at all (smaller inputs stay sequential).
const PAR_CHUNK_BYTES: usize = 64 * 1024;
/// Per-anomaly boolean masks recording which elements of a buffer held
/// NaN / +Inf / -Inf before substitution.
///
/// Each mask vector is allocated lazily on the first matching anomaly, so
/// the all-finite fast path never allocates.
#[derive(Debug, Clone, Default)]
pub struct MaskSet {
    /// `true` at positions that contained a NaN; `None` if no NaN was seen.
    pub nan: Option<Vec<bool>>,
    /// `true` at positions that contained +Inf; `None` if none was seen.
    pub pos_inf: Option<Vec<bool>>,
    /// `true` at positions that contained -Inf; `None` if none was seen.
    pub neg_inf: Option<Vec<bool>>,
    /// Element count each allocated mask vector is sized to.
    pub n_elements: usize,
}

/// Which of the three mask vectors a classified anomaly belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MaskKind {
    Nan,
    PosInf,
    NegInf,
}

impl MaskSet {
    /// A mask set with no anomalies recorded, sized for `n_elements`.
    pub fn empty(n_elements: usize) -> Self {
        Self {
            nan: None,
            pos_inf: None,
            neg_inf: None,
            n_elements,
        }
    }

    /// True when no anomaly of any kind has been recorded.
    pub fn is_empty(&self) -> bool {
        self.nan.is_none() && self.pos_inf.is_none() && self.neg_inf.is_none()
    }

    /// The `Option<Vec<bool>>` slot backing the given mask kind.
    fn slot_for(&mut self, kind: MaskKind) -> &mut Option<Vec<bool>> {
        match kind {
            MaskKind::Nan => &mut self.nan,
            MaskKind::PosInf => &mut self.pos_inf,
            MaskKind::NegInf => &mut self.neg_inf,
        }
    }

    /// Record an anomaly of `kind` at `index`, lazily allocating the mask.
    fn set_bit(&mut self, kind: MaskKind, index: usize) {
        let n = self.n_elements;
        let slot = self.slot_for(kind);
        let mask = slot.get_or_insert_with(|| vec![false; n]);
        mask[index] = true;
    }

    /// Fold a per-chunk mask set (bit indices local to the chunk) into
    /// `self`, offsetting every set bit by `base`.
    #[cfg(feature = "threads")]
    fn merge_chunk(&mut self, other: MaskSet, base: usize) {
        for (kind, bits) in [
            (MaskKind::Nan, other.nan),
            (MaskKind::PosInf, other.pos_inf),
            (MaskKind::NegInf, other.neg_inf),
        ] {
            let Some(bits) = bits else { continue };
            // Resolve the destination mask once per kind instead of once per
            // set bit (going through `set_bit` would re-match the kind and
            // re-check the Option for every element). A `Some` chunk mask
            // always carries at least one set bit, so allocating here
            // unconditionally matches the old lazy behavior.
            let n = self.n_elements;
            let mask = self.slot_for(kind).get_or_insert_with(|| vec![false; n]);
            for (i, &b) in bits.iter().enumerate() {
                if b {
                    mask[base + i] = true;
                }
            }
        }
    }

    /// Iterate over only the mask vectors that are actually allocated,
    /// in the fixed order nan, pos_inf, neg_inf.
    fn iter_bits(&self) -> impl Iterator<Item = &[bool]> {
        [
            self.nan.as_ref(),
            self.pos_inf.as_ref(),
            self.neg_inf.as_ref(),
        ]
        .into_iter()
        .filter_map(|m| m.map(|v| v.as_slice()))
    }
}
/// Scan raw tensor bytes for NaN / +Inf / -Inf values.
///
/// Integer dtypes and empty input pass through as `Cow::Borrowed` with an
/// empty [`MaskSet`]. For float dtypes, each non-finite element either:
/// * errors when the matching `allow_nan` / `allow_inf` flag is `false`, or
/// * is recorded in the returned [`MaskSet`] and zeroed in an owned copy.
///
/// When nothing non-finite is found the input is returned borrowed (no copy).
/// `parallel` only takes effect when the `threads` feature is compiled in and
/// the input is at least `PAR_CHUNK_BYTES` long.
///
/// NOTE(review): `data.len() / elem_size` truncates, so trailing bytes that
/// do not form a whole element are silently skipped by the scanners —
/// presumably callers guarantee the length is a multiple of the element
/// size; confirm upstream.
pub(crate) fn substitute_and_mask<'a>(
    data: &'a [u8],
    dtype: Dtype,
    byte_order: ByteOrder,
    allow_nan: bool,
    allow_inf: bool,
    parallel: bool,
) -> Result<(Cow<'a, [u8]>, MaskSet)> {
    // Element width in bytes; non-float dtypes cannot encode NaN/Inf and are
    // returned unchanged immediately.
    let elem_size = match dtype {
        Dtype::Float16 | Dtype::Bfloat16 => 2,
        Dtype::Float32 => 4,
        Dtype::Float64 | Dtype::Complex64 => 8,
        Dtype::Complex128 => 16,
        Dtype::Int8
        | Dtype::Int16
        | Dtype::Int32
        | Dtype::Int64
        | Dtype::Uint8
        | Dtype::Uint16
        | Dtype::Uint32
        | Dtype::Uint64
        | Dtype::Bitmask => {
            return Ok((Cow::Borrowed(data), MaskSet::empty(0)));
        }
    };
    if data.is_empty() {
        return Ok((Cow::Borrowed(data), MaskSet::empty(0)));
    }
    let n_elements = data.len() / elem_size;
    // Parallelism only pays off past one chunk's worth of data.
    let go_parallel = parallel && data.len() >= PAR_CHUNK_BYTES;
    if !go_parallel {
        return run_sequential(
            data, dtype, byte_order, allow_nan, allow_inf, elem_size, n_elements, 0,
        );
    }
    #[cfg(feature = "threads")]
    {
        run_parallel(
            data, dtype, byte_order, allow_nan, allow_inf, elem_size, n_elements,
        )
    }
    // Without the `threads` feature a parallel request degrades to sequential.
    #[cfg(not(feature = "threads"))]
    {
        run_sequential(
            data, dtype, byte_order, allow_nan, allow_inf, elem_size, n_elements, 0,
        )
    }
}
/// Single-threaded scan of `data`.
///
/// `base_index` offsets the element index reported in error messages (it is
/// nonzero when this runs on a chunk from the parallel path); mask bits are
/// always indexed locally from 0 within `masks`.
///
/// Returns the input borrowed when no anomaly was found; otherwise an owned
/// copy with every masked element's bytes zeroed.
#[allow(clippy::too_many_arguments)]
fn run_sequential<'a>(
    data: &'a [u8],
    dtype: Dtype,
    byte_order: ByteOrder,
    allow_nan: bool,
    allow_inf: bool,
    elem_size: usize,
    n_elements: usize,
    base_index: usize,
) -> Result<(Cow<'a, [u8]>, MaskSet)> {
    let mut masks = MaskSet::empty(n_elements);
    let mut saw_any = false;
    // Dispatch to the dtype-specific scanner; each one fills `masks` and
    // flips `saw_any`, or errors under strict mode.
    match dtype {
        Dtype::Float32 => scan_f32(
            data,
            byte_order,
            base_index,
            allow_nan,
            allow_inf,
            &mut masks,
            &mut saw_any,
        )?,
        Dtype::Float64 => scan_f64(
            data,
            byte_order,
            base_index,
            allow_nan,
            allow_inf,
            &mut masks,
            &mut saw_any,
        )?,
        Dtype::Float16 => scan_half(
            data,
            byte_order,
            base_index,
            allow_nan,
            allow_inf,
            &mut masks,
            &mut saw_any,
            &F16_LAYOUT,
        )?,
        Dtype::Bfloat16 => scan_half(
            data,
            byte_order,
            base_index,
            allow_nan,
            allow_inf,
            &mut masks,
            &mut saw_any,
            &BF16_LAYOUT,
        )?,
        Dtype::Complex64 => scan_complex64(
            data,
            byte_order,
            base_index,
            allow_nan,
            allow_inf,
            &mut masks,
            &mut saw_any,
        )?,
        Dtype::Complex128 => scan_complex128(
            data,
            byte_order,
            base_index,
            allow_nan,
            allow_inf,
            &mut masks,
            &mut saw_any,
        )?,
        Dtype::Int8
        | Dtype::Int16
        | Dtype::Int32
        | Dtype::Int64
        | Dtype::Uint8
        | Dtype::Uint16
        | Dtype::Uint32
        | Dtype::Uint64
        | Dtype::Bitmask => {
            // Unreachable in practice: substitute_and_mask filters non-float
            // dtypes before calling here. Pass through defensively.
            debug_assert!(false, "non-float dtype {dtype:?} reached run_sequential");
            return Ok((Cow::Borrowed(data), MaskSet::empty(n_elements)));
        }
    }
    // Fast path: clean data is returned borrowed, without copying.
    if !saw_any {
        return Ok((Cow::Borrowed(data), MaskSet::empty(n_elements)));
    }
    let mut out = data.to_vec();
    zero_at_mask_positions(&mut out, elem_size, &masks);
    Ok((Cow::Owned(out), masks))
}
/// Chunked scan using rayon: each `PAR_CHUNK_BYTES` chunk is scanned
/// independently (with element indices rebased for error messages), then the
/// per-chunk masks are merged into one array-wide [`MaskSet`].
///
/// `PAR_CHUNK_BYTES` (64 KiB) is a multiple of every float element size
/// (2/4/8/16), so every chunk but the last holds exactly `chunk_elements`
/// elements and `chunk_idx * chunk_elements` is an exact element base.
///
/// NOTE(review): `run_sequential` also builds a zeroed copy of each anomalous
/// chunk, which is discarded here (only the masks are kept) — substitution is
/// redone once over the whole buffer below. Wasted work, but not incorrect.
#[cfg(feature = "threads")]
#[allow(clippy::too_many_arguments)]
fn run_parallel<'a>(
    data: &'a [u8],
    dtype: Dtype,
    byte_order: ByteOrder,
    allow_nan: bool,
    allow_inf: bool,
    elem_size: usize,
    n_elements: usize,
) -> Result<(Cow<'a, [u8]>, MaskSet)> {
    use rayon::prelude::*;
    let chunk_elements = PAR_CHUNK_BYTES / elem_size;
    // Scan chunks in parallel; a strict-mode violation in any chunk
    // short-circuits the whole collect with that chunk's error.
    let per_chunk: Vec<MaskSet> = data
        .par_chunks(PAR_CHUNK_BYTES)
        .enumerate()
        .map(|(chunk_idx, chunk)| -> Result<MaskSet> {
            let base = chunk_idx * chunk_elements;
            let (_, masks) = run_sequential(
                chunk,
                dtype,
                byte_order,
                allow_nan,
                allow_inf,
                elem_size,
                chunk.len() / elem_size,
                base,
            )?;
            Ok(masks)
        })
        .collect::<Result<Vec<_>>>()?;
    // Merge chunk-local masks into array-global positions.
    let mut masks = MaskSet::empty(n_elements);
    let mut saw_any = false;
    for (chunk_idx, chunk_masks) in per_chunk.into_iter().enumerate() {
        if !chunk_masks.is_empty() {
            saw_any = true;
            let base = chunk_idx * chunk_elements;
            masks.merge_chunk(chunk_masks, base);
        }
    }
    if !saw_any {
        return Ok((Cow::Borrowed(data), MaskSet::empty(n_elements)));
    }
    let mut out = data.to_vec();
    zero_at_mask_positions(&mut out, elem_size, &masks);
    Ok((Cow::Owned(out), masks))
}
/// Zero the `elem_size`-byte span of every element flagged in any of the
/// allocated masks of `masks`.
fn zero_at_mask_positions(buf: &mut [u8], elem_size: usize, masks: &MaskSet) {
    for bits in masks.iter_bits() {
        let flagged = bits
            .iter()
            .enumerate()
            .filter(|&(_, &set)| set)
            .map(|(idx, _)| idx);
        for idx in flagged {
            let offset = idx * elem_size;
            for byte in &mut buf[offset..offset + elem_size] {
                *byte = 0;
            }
        }
    }
}
/// Scan a buffer of raw `f32` values and dispatch every non-finite one.
#[allow(clippy::too_many_arguments)]
fn scan_f32(
    chunk: &[u8],
    byte_order: ByteOrder,
    base_index: usize,
    allow_nan: bool,
    allow_inf: bool,
    masks: &mut MaskSet,
    saw_any: &mut bool,
) -> Result<()> {
    for (local, word) in chunk.chunks_exact(4).enumerate() {
        let &[a, b, c, d] = word else { continue };
        // Decode the integer bit pattern first; `f32::from_bits` on it is
        // equivalent to `f32::from_?e_bytes` on the same bytes.
        let bits = match byte_order {
            ByteOrder::Big => u32::from_be_bytes([a, b, c, d]),
            ByteOrder::Little => u32::from_le_bytes([a, b, c, d]),
        };
        let value = f32::from_bits(bits);
        let class =
            classify_scalar(value.is_nan(), value.is_infinite(), value.is_sign_positive());
        class.dispatch(
            local,
            base_index + local,
            "float32",
            None,
            allow_nan,
            allow_inf,
            masks,
            saw_any,
        )?;
    }
    Ok(())
}
/// Scan a buffer of raw `f64` values and dispatch every non-finite one.
#[allow(clippy::too_many_arguments)]
fn scan_f64(
    chunk: &[u8],
    byte_order: ByteOrder,
    base_index: usize,
    allow_nan: bool,
    allow_inf: bool,
    masks: &mut MaskSet,
    saw_any: &mut bool,
) -> Result<()> {
    for (local, word) in chunk.chunks_exact(8).enumerate() {
        let &[a, b, c, d, e, f, g, h] = word else {
            continue;
        };
        // Decode the integer bit pattern first; `f64::from_bits` on it is
        // equivalent to `f64::from_?e_bytes` on the same bytes.
        let bits = match byte_order {
            ByteOrder::Big => u64::from_be_bytes([a, b, c, d, e, f, g, h]),
            ByteOrder::Little => u64::from_le_bytes([a, b, c, d, e, f, g, h]),
        };
        let value = f64::from_bits(bits);
        let class =
            classify_scalar(value.is_nan(), value.is_infinite(), value.is_sign_positive());
        class.dispatch(
            local,
            base_index + local,
            "float64",
            None,
            allow_nan,
            allow_inf,
            masks,
            saw_any,
        )?;
    }
    Ok(())
}
/// Bit-layout parameters for a 16-bit float format, sufficient to detect
/// NaN/Inf: an all-ones exponent with a nonzero mantissa is NaN; with a zero
/// mantissa it is +/-Inf (sign taken from bit 15).
struct HalfLayout {
    // Number of mantissa bits (shift distance to reach the exponent field).
    exp_shift: u16,
    // Exponent field mask after shifting; also the all-ones exponent value.
    exp_all_ones: u16,
    // Mask selecting the mantissa bits.
    mantissa_mask: u16,
    // Dtype label used in error messages.
    dtype_name: &'static str,
}
// IEEE 754 binary16 (half): 1 sign, 5 exponent, 10 mantissa bits.
const F16_LAYOUT: HalfLayout = HalfLayout {
    exp_shift: 10,
    exp_all_ones: 0x1F,
    mantissa_mask: 0x03FF,
    dtype_name: "float16",
};
// bfloat16: 1 sign, 8 exponent, 7 mantissa bits (truncated float32).
const BF16_LAYOUT: HalfLayout = HalfLayout {
    exp_shift: 7,
    exp_all_ones: 0xFF,
    mantissa_mask: 0x7F,
    dtype_name: "bfloat16",
};
/// Classify a 16-bit float bit pattern using the field layout in `layout`.
/// Anything without an all-ones exponent is finite (normals, subnormals,
/// zeros); otherwise a nonzero mantissa means NaN and a zero mantissa means
/// infinity with the sign taken from bit 15.
fn classify_half(bits: u16, layout: &HalfLayout) -> Classification {
    let exponent = (bits >> layout.exp_shift) & layout.exp_all_ones;
    if exponent != layout.exp_all_ones {
        Classification::Finite
    } else if bits & layout.mantissa_mask != 0 {
        Classification::Nan
    } else if bits & 0x8000 == 0 {
        Classification::PosInf
    } else {
        Classification::NegInf
    }
}
/// Scan 16-bit float payloads (`f16` or `bf16`, chosen by `layout`) and
/// dispatch every non-finite one.
#[allow(clippy::too_many_arguments)]
fn scan_half(
    chunk: &[u8],
    byte_order: ByteOrder,
    base_index: usize,
    allow_nan: bool,
    allow_inf: bool,
    masks: &mut MaskSet,
    saw_any: &mut bool,
    layout: &HalfLayout,
) -> Result<()> {
    for (local, word) in chunk.chunks_exact(2).enumerate() {
        let &[first, second] = word else { continue };
        let raw = [first, second];
        let bits = match byte_order {
            ByteOrder::Big => u16::from_be_bytes(raw),
            ByteOrder::Little => u16::from_le_bytes(raw),
        };
        classify_half(bits, layout).dispatch(
            local,
            base_index + local,
            layout.dtype_name,
            None,
            allow_nan,
            allow_inf,
            masks,
            saw_any,
        )?;
    }
    Ok(())
}
/// Scan complex64 elements (`f32` real part followed by `f32` imaginary
/// part) and dispatch the highest-priority non-finite component of each.
#[allow(clippy::too_many_arguments)]
fn scan_complex64(
    chunk: &[u8],
    byte_order: ByteOrder,
    base_index: usize,
    allow_nan: bool,
    allow_inf: bool,
    masks: &mut MaskSet,
    saw_any: &mut bool,
) -> Result<()> {
    for (local, elem) in chunk.chunks_exact(8).enumerate() {
        // Real component first, imaginary second.
        let (re_part, im_part) = elem.split_at(4);
        let &[r0, r1, r2, r3] = re_part else { continue };
        let &[i0, i1, i2, i3] = im_part else { continue };
        let (real, imag) = match byte_order {
            ByteOrder::Big => (
                f32::from_be_bytes([r0, r1, r2, r3]),
                f32::from_be_bytes([i0, i1, i2, i3]),
            ),
            ByteOrder::Little => (
                f32::from_le_bytes([r0, r1, r2, r3]),
                f32::from_le_bytes([i0, i1, i2, i3]),
            ),
        };
        let (kind, component) = classify_complex(real, imag);
        kind.dispatch(
            local,
            base_index + local,
            "complex64",
            component,
            allow_nan,
            allow_inf,
            masks,
            saw_any,
        )?;
    }
    Ok(())
}
/// Scan complex128 elements (`f64` real part followed by `f64` imaginary
/// part) and dispatch the highest-priority non-finite component of each.
#[allow(clippy::too_many_arguments)]
fn scan_complex128(
    chunk: &[u8],
    byte_order: ByteOrder,
    base_index: usize,
    allow_nan: bool,
    allow_inf: bool,
    masks: &mut MaskSet,
    saw_any: &mut bool,
) -> Result<()> {
    for (local, elem) in chunk.chunks_exact(16).enumerate() {
        // Real component first, imaginary second.
        let (re_part, im_part) = elem.split_at(8);
        let &[r0, r1, r2, r3, r4, r5, r6, r7] = re_part else {
            continue;
        };
        let &[i0, i1, i2, i3, i4, i5, i6, i7] = im_part else {
            continue;
        };
        let (real, imag) = match byte_order {
            ByteOrder::Big => (
                f64::from_be_bytes([r0, r1, r2, r3, r4, r5, r6, r7]),
                f64::from_be_bytes([i0, i1, i2, i3, i4, i5, i6, i7]),
            ),
            ByteOrder::Little => (
                f64::from_le_bytes([r0, r1, r2, r3, r4, r5, r6, r7]),
                f64::from_le_bytes([i0, i1, i2, i3, i4, i5, i6, i7]),
            ),
        };
        let (kind, component) = classify_complex(real, imag);
        kind.dispatch(
            local,
            base_index + local,
            "complex128",
            component,
            allow_nan,
            allow_inf,
            masks,
            saw_any,
        )?;
    }
    Ok(())
}
/// IEEE-754 class of a scanned value, reduced to the cases this module
/// distinguishes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Classification {
    Finite,
    Nan,
    PosInf,
    NegInf,
}
impl Classification {
    /// Map the classification onto its mask; `None` for finite values.
    fn mask_kind(self) -> Option<MaskKind> {
        match self {
            Classification::Finite => None,
            Classification::Nan => Some(MaskKind::Nan),
            Classification::PosInf => Some(MaskKind::PosInf),
            Classification::NegInf => Some(MaskKind::NegInf),
        }
    }

    /// Either record this anomaly in `masks` (when permitted) or return the
    /// strict-mode error for it. Finite values are a no-op.
    ///
    /// `local_index` addresses the mask being built for the current chunk;
    /// `global_index` is the element's position in the whole array and is
    /// used only in error messages.
    #[allow(clippy::too_many_arguments)]
    fn dispatch(
        self,
        local_index: usize,
        global_index: usize,
        dtype_name: &'static str,
        component: Option<&'static str>,
        allow_nan: bool,
        allow_inf: bool,
        masks: &mut MaskSet,
        saw_any: &mut bool,
    ) -> Result<()> {
        let kind = match self.mask_kind() {
            None => return Ok(()),
            Some(k) => k,
        };
        let permitted = if kind == MaskKind::Nan {
            allow_nan
        } else {
            allow_inf
        };
        if !permitted {
            return Err(match kind {
                MaskKind::Nan => nan_err(global_index, dtype_name, component),
                MaskKind::PosInf => inf_err(global_index, dtype_name, component, true),
                MaskKind::NegInf => inf_err(global_index, dtype_name, component, false),
            });
        }
        masks.set_bit(kind, local_index);
        *saw_any = true;
        Ok(())
    }
}
/// Reduce the three scalar predicates to a [`Classification`]; NaN takes
/// priority over infinity.
fn classify_scalar(is_nan: bool, is_infinite: bool, is_sign_positive: bool) -> Classification {
    match (is_nan, is_infinite) {
        (true, _) => Classification::Nan,
        (false, true) if is_sign_positive => Classification::PosInf,
        (false, true) => Classification::NegInf,
        (false, false) => Classification::Finite,
    }
}
/// Minimal float surface needed by `classify_complex`, so one generic
/// routine covers both `f32` and `f64` components.
trait ComplexComponent: Copy {
    fn is_nan(self) -> bool;
    fn is_infinite(self) -> bool;
    fn is_sign_positive(self) -> bool;
}

/// Forward every trait method to the inherent float method of the same name.
macro_rules! forward_complex_component {
    ($($float:ty),+) => {$(
        impl ComplexComponent for $float {
            fn is_nan(self) -> bool {
                <$float>::is_nan(self)
            }
            fn is_infinite(self) -> bool {
                <$float>::is_infinite(self)
            }
            fn is_sign_positive(self) -> bool {
                <$float>::is_sign_positive(self)
            }
        }
    )+};
}

forward_complex_component!(f32, f64);
/// Classify a complex element by its components. NaN takes priority over
/// infinity, +Inf over -Inf, and within the same class the real component
/// takes priority over the imaginary one. Returns which component matched.
fn classify_complex<F: ComplexComponent>(
    real: F,
    imag: F,
) -> (Classification, Option<&'static str>) {
    const REAL: Option<&str> = Some("real");
    const IMAG: Option<&str> = Some("imaginary");
    if real.is_nan() {
        (Classification::Nan, REAL)
    } else if imag.is_nan() {
        (Classification::Nan, IMAG)
    } else if real.is_infinite() && real.is_sign_positive() {
        (Classification::PosInf, REAL)
    } else if imag.is_infinite() && imag.is_sign_positive() {
        (Classification::PosInf, IMAG)
    } else if real.is_infinite() {
        (Classification::NegInf, REAL)
    } else if imag.is_infinite() {
        (Classification::NegInf, IMAG)
    } else {
        (Classification::Finite, None)
    }
}
/// Build the strict-mode error for a NaN found while `allow_nan` is false.
/// `component` names the complex component ("real"/"imaginary") when
/// applicable.
fn nan_err(index: usize, dtype: &str, component: Option<&str>) -> TensogramError {
    let suffix = match component {
        Some(part) => format!(" ({part} component)"),
        None => String::new(),
    };
    let message = format!(
        "strict-NaN check: NaN at element {index} of {dtype} array{suffix}; pass allow_nan=true to substitute with 0.0 and record positions in a mask"
    );
    TensogramError::Encoding(message)
}
/// Build the strict-mode error for an infinity found while `allow_inf` is
/// false; `positive` selects the sign shown in the message.
fn inf_err(index: usize, dtype: &str, component: Option<&str>, positive: bool) -> TensogramError {
    let sign = if positive { "+Inf" } else { "-Inf" };
    let suffix = match component {
        Some(part) => format!(" ({part} component)"),
        None => String::new(),
    };
    let message = format!(
        "strict-Inf check: {sign} at element {index} of {dtype} array{suffix}; pass allow_inf=true to substitute with 0.0 and record positions in a mask"
    );
    TensogramError::Encoding(message)
}
#[cfg(test)]
mod tests {
    //! Tests build fixtures with `to_ne_bytes` and scan with
    //! `ByteOrder::native()`, so they are host-endianness independent.
    use super::*;
    // Pack f32 values into a native-endian byte buffer.
    fn f32_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_ne_bytes()).collect()
    }
    // Pack f64 values into a native-endian byte buffer.
    fn f64_bytes(values: &[f64]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_ne_bytes()).collect()
    }
    #[test]
    fn non_float_dtype_zero_cost_borrowed() {
        let data = vec![0xFFu8; 64];
        let (out, masks) =
            substitute_and_mask(&data, Dtype::Uint32, ByteOrder::native(), true, true, false)
                .unwrap();
        assert!(matches!(out, Cow::Borrowed(_)));
        assert!(masks.is_empty());
    }
    #[test]
    fn finite_float_input_zero_cost_borrowed() {
        let data = f64_bytes(&[1.0, 2.0, 3.0, 4.0, -1.0]);
        let (out, masks) = substitute_and_mask(
            &data,
            Dtype::Float64,
            ByteOrder::native(),
            true,
            true,
            false,
        )
        .unwrap();
        assert!(matches!(out, Cow::Borrowed(_)));
        assert!(masks.is_empty());
    }
    #[test]
    fn disallowed_nan_errors_with_hint() {
        let data = f32_bytes(&[1.0, f32::NAN]);
        let err = substitute_and_mask(
            &data,
            Dtype::Float32,
            ByteOrder::native(),
            false,
            false,
            false,
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("element 1"), "got: {msg}");
        assert!(msg.contains("allow_nan=true"), "got: {msg}");
    }
    #[test]
    fn disallowed_pos_inf_errors_with_hint() {
        let data = f64_bytes(&[1.0, f64::INFINITY]);
        let err = substitute_and_mask(
            &data,
            Dtype::Float64,
            ByteOrder::native(),
            false,
            false,
            false,
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("+Inf"), "got: {msg}");
        assert!(msg.contains("allow_inf=true"), "got: {msg}");
    }
    #[test]
    fn allow_nan_produces_mask_and_substitutes_with_zero() {
        let data = f64_bytes(&[1.0, f64::NAN, 3.0, f64::NAN]);
        let (out, masks) = substitute_and_mask(
            &data,
            Dtype::Float64,
            ByteOrder::native(),
            true,
            false,
            false,
        )
        .unwrap();
        let got = out.as_ref();
        let expected = f64_bytes(&[1.0, 0.0, 3.0, 0.0]);
        assert_eq!(got, expected.as_slice());
        let nan = masks.nan.unwrap();
        assert_eq!(nan, vec![false, true, false, true]);
        assert!(masks.pos_inf.is_none());
        assert!(masks.neg_inf.is_none());
    }
    #[test]
    fn allow_inf_masks_both_signs_separately() {
        let data = f64_bytes(&[0.0, f64::INFINITY, f64::NEG_INFINITY, 2.0]);
        let (_, masks) = substitute_and_mask(
            &data,
            Dtype::Float64,
            ByteOrder::native(),
            false,
            true,
            false,
        )
        .unwrap();
        assert_eq!(masks.pos_inf.unwrap(), vec![false, true, false, false]);
        assert_eq!(masks.neg_inf.unwrap(), vec![false, false, true, false]);
        assert!(masks.nan.is_none());
    }
    // One complex64 element: NaN real, +Inf imaginary. Only the NaN mask is
    // set — classify_complex reports a single classification per element.
    #[test]
    fn complex64_nan_wins_over_inf_with_priority_rule() {
        let data: Vec<u8> = [f32::NAN, f32::INFINITY]
            .iter()
            .flat_map(|v| v.to_ne_bytes())
            .collect();
        let (_, masks) = substitute_and_mask(
            &data,
            Dtype::Complex64,
            ByteOrder::native(),
            true,
            true,
            false,
        )
        .unwrap();
        assert_eq!(masks.nan.unwrap(), vec![true]);
        assert!(masks.pos_inf.is_none());
        assert!(masks.neg_inf.is_none());
    }
    #[test]
    fn complex64_pos_inf_wins_over_neg_inf() {
        let data: Vec<u8> = [f32::NEG_INFINITY, f32::INFINITY]
            .iter()
            .flat_map(|v| v.to_ne_bytes())
            .collect();
        let (_, masks) = substitute_and_mask(
            &data,
            Dtype::Complex64,
            ByteOrder::native(),
            true,
            true,
            false,
        )
        .unwrap();
        assert_eq!(masks.pos_inf.unwrap(), vec![true]);
        assert!(masks.nan.is_none());
        assert!(masks.neg_inf.is_none());
    }
    #[test]
    fn complex64_substitution_zeroes_both_components() {
        let data: Vec<u8> = [1.0_f32, f32::NAN]
            .iter()
            .flat_map(|v| v.to_ne_bytes())
            .collect();
        let (out, _) = substitute_and_mask(
            &data,
            Dtype::Complex64,
            ByteOrder::native(),
            true,
            false,
            false,
        )
        .unwrap();
        let got = out.as_ref();
        let expected: Vec<u8> = [0.0_f32, 0.0_f32]
            .iter()
            .flat_map(|v| v.to_ne_bytes())
            .collect();
        assert_eq!(got, expected.as_slice());
    }
    #[test]
    fn float16_nan_detected_and_substituted() {
        // binary16 quiet NaN: all-ones exponent, nonzero mantissa.
        let bits: u16 = 0x7E00;
        let data = bits.to_ne_bytes().to_vec();
        let (out, masks) = substitute_and_mask(
            &data,
            Dtype::Float16,
            ByteOrder::native(),
            true,
            false,
            false,
        )
        .unwrap();
        assert_eq!(out.as_ref(), &[0u8, 0u8][..]);
        assert_eq!(masks.nan.unwrap(), vec![true]);
    }
    #[test]
    fn bfloat16_neg_inf_detected_with_sign() {
        // bfloat16 -Inf: sign bit, all-ones exponent, zero mantissa.
        let bits: u16 = 0xFF80;
        let data = bits.to_ne_bytes().to_vec();
        let (_, masks) = substitute_and_mask(
            &data,
            Dtype::Bfloat16,
            ByteOrder::native(),
            false,
            true,
            false,
        )
        .unwrap();
        assert_eq!(masks.neg_inf.unwrap(), vec![true]);
        assert!(masks.pos_inf.is_none());
    }
    #[test]
    fn mixed_nan_and_inf_build_separate_masks() {
        let data = f64_bytes(&[
            1.0,
            f64::NAN,
            2.0,
            f64::INFINITY,
            3.0,
            f64::NEG_INFINITY,
            4.0,
            f64::NAN,
        ]);
        let (out, masks) = substitute_and_mask(
            &data,
            Dtype::Float64,
            ByteOrder::native(),
            true,
            true,
            false,
        )
        .unwrap();
        let expected = f64_bytes(&[1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0]);
        assert_eq!(out.as_ref(), expected.as_slice());
        assert_eq!(
            masks.nan.unwrap(),
            vec![false, true, false, false, false, false, false, true]
        );
        assert_eq!(
            masks.pos_inf.unwrap(),
            vec![false, false, false, true, false, false, false, false]
        );
        assert_eq!(
            masks.neg_inf.unwrap(),
            vec![false, false, false, false, false, true, false, false]
        );
    }
    // 16,384 f64 values = 128 KiB, which exceeds PAR_CHUNK_BYTES so the
    // parallel=true call actually exercises run_parallel.
    #[cfg(feature = "threads")]
    #[test]
    fn parallel_path_matches_sequential() {
        let mut values: Vec<f64> = (0..16_384).map(|i| i as f64).collect();
        values[9_001] = f64::NAN;
        values[9_002] = f64::INFINITY;
        values[16_000] = f64::NEG_INFINITY;
        let data = f64_bytes(&values);
        let (seq_out, seq_masks) = substitute_and_mask(
            &data,
            Dtype::Float64,
            ByteOrder::native(),
            true,
            true,
            false,
        )
        .unwrap();
        let (par_out, par_masks) =
            substitute_and_mask(&data, Dtype::Float64, ByteOrder::native(), true, true, true)
                .unwrap();
        assert_eq!(seq_out.as_ref(), par_out.as_ref());
        assert_eq!(seq_masks.nan, par_masks.nan);
        assert_eq!(seq_masks.pos_inf, par_masks.pos_inf);
        assert_eq!(seq_masks.neg_inf, par_masks.neg_inf);
    }
    #[test]
    fn empty_input_passes_through() {
        let data: Vec<u8> = vec![];
        let (out, masks) = substitute_and_mask(
            &data,
            Dtype::Float64,
            ByteOrder::native(),
            true,
            true,
            false,
        )
        .unwrap();
        assert!(matches!(out, Cow::Borrowed(_)));
        assert!(masks.is_empty());
    }
}