use crate::dtype::Dtype;
use crate::error::{Result, TensogramError};
use tensogram_encodings::ByteOrder;
const PAR_CHUNK_BYTES: usize = 64 * 1024;
/// Validates raw tensor bytes against the strict NaN/Inf policy.
///
/// Returns `Ok(())` immediately when neither `reject_nan` nor `reject_inf`
/// is set, and for integer/bitmask dtypes, which cannot encode NaN or Inf.
/// Float and complex dtypes are decoded element by element (honoring
/// `byte_order`) and the first offending value is reported as an error.
///
/// `parallel` only takes effect for buffers of at least `PAR_CHUNK_BYTES`;
/// smaller inputs are always scanned sequentially.
pub(crate) fn scan(
    data: &[u8],
    dtype: Dtype,
    byte_order: ByteOrder,
    reject_nan: bool,
    reject_inf: bool,
    parallel: bool,
) -> Result<()> {
    if !reject_nan && !reject_inf {
        return Ok(());
    }
    let go_parallel = parallel && data.len() >= PAR_CHUNK_BYTES;
    // Map each float-bearing dtype to its element width and sequential
    // scanner; integer dtypes have nothing to check.
    type SeqScan = fn(&[u8], ByteOrder, bool, bool, usize) -> Result<()>;
    let (element_size, seq_scan): (usize, SeqScan) = match dtype {
        Dtype::Float32 => (4, scan_f32_seq),
        Dtype::Float64 => (8, scan_f64_seq),
        Dtype::Float16 => (2, scan_f16_seq),
        Dtype::Bfloat16 => (2, scan_bf16_seq),
        Dtype::Complex64 => (8, scan_complex64_seq),
        Dtype::Complex128 => (16, scan_complex128_seq),
        Dtype::Int8
        | Dtype::Int16
        | Dtype::Int32
        | Dtype::Int64
        | Dtype::Uint8
        | Dtype::Uint16
        | Dtype::Uint32
        | Dtype::Uint64
        | Dtype::Bitmask => return Ok(()),
    };
    scan_dispatch(
        data,
        element_size,
        byte_order,
        reject_nan,
        reject_inf,
        go_parallel,
        seq_scan,
    )
}
#[inline]
fn scan_dispatch<F>(
data: &[u8],
element_size: usize,
byte_order: ByteOrder,
reject_nan: bool,
reject_inf: bool,
go_parallel: bool,
seq_scan: F,
) -> Result<()>
where
F: Fn(&[u8], ByteOrder, bool, bool, usize) -> Result<()> + Send + Sync + Copy,
{
if !go_parallel {
return seq_scan(data, byte_order, reject_nan, reject_inf, 0);
}
#[cfg(feature = "threads")]
{
use rayon::prelude::*;
let chunk_elements = PAR_CHUNK_BYTES / element_size;
data.par_chunks(PAR_CHUNK_BYTES)
.enumerate()
.try_for_each(|(chunk_idx, chunk)| {
let base = chunk_idx * chunk_elements;
seq_scan(chunk, byte_order, reject_nan, reject_inf, base)
})
}
#[cfg(not(feature = "threads"))]
{
let _ = element_size; seq_scan(data, byte_order, reject_nan, reject_inf, 0)
}
}
/// Sequentially decodes `data` as float32 values and reports the first
/// NaN/Inf found, offsetting element indices by `base_index`.
fn scan_f32_seq(
    data: &[u8],
    byte_order: ByteOrder,
    reject_nan: bool,
    reject_inf: bool,
    base_index: usize,
) -> Result<()> {
    for (offset, raw) in data.chunks_exact(4).enumerate() {
        // `chunks_exact(4)` guarantees 4-byte slices; skip defensively.
        let Ok(bytes) = <[u8; 4]>::try_from(raw) else {
            continue;
        };
        let value = match byte_order {
            ByteOrder::Big => f32::from_be_bytes(bytes),
            ByteOrder::Little => f32::from_le_bytes(bytes),
        };
        if reject_nan && value.is_nan() {
            return Err(nan_err(base_index + offset, "float32", None));
        }
        if reject_inf && value.is_infinite() {
            return Err(inf_err(
                base_index + offset,
                "float32",
                None,
                value.is_sign_positive(),
            ));
        }
    }
    Ok(())
}
/// Sequentially decodes `data` as float64 values and reports the first
/// NaN/Inf found, offsetting element indices by `base_index`.
fn scan_f64_seq(
    data: &[u8],
    byte_order: ByteOrder,
    reject_nan: bool,
    reject_inf: bool,
    base_index: usize,
) -> Result<()> {
    for (offset, raw) in data.chunks_exact(8).enumerate() {
        // `chunks_exact(8)` guarantees 8-byte slices; skip defensively.
        let Ok(bytes) = <[u8; 8]>::try_from(raw) else {
            continue;
        };
        let value = match byte_order {
            ByteOrder::Big => f64::from_be_bytes(bytes),
            ByteOrder::Little => f64::from_le_bytes(bytes),
        };
        if reject_nan && value.is_nan() {
            return Err(nan_err(base_index + offset, "float64", None));
        }
        if reject_inf && value.is_infinite() {
            return Err(inf_err(
                base_index + offset,
                "float64",
                None,
                value.is_sign_positive(),
            ));
        }
    }
    Ok(())
}
/// Scans IEEE-754 binary16 values without widening to f32: an all-ones
/// exponent field marks Inf (zero mantissa) or NaN (non-zero mantissa).
fn scan_f16_seq(
    data: &[u8],
    byte_order: ByteOrder,
    reject_nan: bool,
    reject_inf: bool,
    base_index: usize,
) -> Result<()> {
    const EXP_MASK: u16 = 0x7C00; // exponent bits 14..10
    const MANTISSA_MASK: u16 = 0x03FF;
    const SIGN_MASK: u16 = 0x8000;
    for (offset, raw) in data.chunks_exact(2).enumerate() {
        let Ok(bytes) = <[u8; 2]>::try_from(raw) else {
            continue;
        };
        let bits = match byte_order {
            ByteOrder::Big => u16::from_be_bytes(bytes),
            ByteOrder::Little => u16::from_le_bytes(bytes),
        };
        // Finite values (exponent not all ones) need no further checks.
        if bits & EXP_MASK != EXP_MASK {
            continue;
        }
        if bits & MANTISSA_MASK != 0 {
            if reject_nan {
                return Err(nan_err(base_index + offset, "float16", None));
            }
        } else if reject_inf {
            return Err(inf_err(
                base_index + offset,
                "float16",
                None,
                bits & SIGN_MASK == 0,
            ));
        }
    }
    Ok(())
}
/// Scans bfloat16 values without widening to f32: an all-ones exponent
/// field marks Inf (zero mantissa) or NaN (non-zero mantissa).
fn scan_bf16_seq(
    data: &[u8],
    byte_order: ByteOrder,
    reject_nan: bool,
    reject_inf: bool,
    base_index: usize,
) -> Result<()> {
    const EXP_MASK: u16 = 0x7F80; // exponent bits 14..7
    const MANTISSA_MASK: u16 = 0x007F;
    const SIGN_MASK: u16 = 0x8000;
    for (offset, raw) in data.chunks_exact(2).enumerate() {
        let Ok(bytes) = <[u8; 2]>::try_from(raw) else {
            continue;
        };
        let bits = match byte_order {
            ByteOrder::Big => u16::from_be_bytes(bytes),
            ByteOrder::Little => u16::from_le_bytes(bytes),
        };
        // Finite values (exponent not all ones) need no further checks.
        if bits & EXP_MASK != EXP_MASK {
            continue;
        }
        if bits & MANTISSA_MASK != 0 {
            if reject_nan {
                return Err(nan_err(base_index + offset, "bfloat16", None));
            }
        } else if reject_inf {
            return Err(inf_err(
                base_index + offset,
                "bfloat16",
                None,
                bits & SIGN_MASK == 0,
            ));
        }
    }
    Ok(())
}
/// Scans interleaved (real, imaginary) float32 pairs. The reported element
/// index counts complex values, and the error names the offending component
/// (real is checked before imaginary, NaN before Inf).
fn scan_complex64_seq(
    data: &[u8],
    byte_order: ByteOrder,
    reject_nan: bool,
    reject_inf: bool,
    base_index: usize,
) -> Result<()> {
    for (offset, raw) in data.chunks_exact(8).enumerate() {
        let (re_raw, im_raw) = raw.split_at(4);
        let (Ok(re_bytes), Ok(im_bytes)) =
            (<[u8; 4]>::try_from(re_raw), <[u8; 4]>::try_from(im_raw))
        else {
            continue;
        };
        let decode = |bytes: [u8; 4]| match byte_order {
            ByteOrder::Big => f32::from_be_bytes(bytes),
            ByteOrder::Little => f32::from_le_bytes(bytes),
        };
        let (real, imag) = (decode(re_bytes), decode(im_bytes));
        if reject_nan {
            if real.is_nan() {
                return Err(nan_err(base_index + offset, "complex64", Some("real")));
            }
            if imag.is_nan() {
                return Err(nan_err(base_index + offset, "complex64", Some("imaginary")));
            }
        }
        if reject_inf {
            if real.is_infinite() {
                return Err(inf_err(
                    base_index + offset,
                    "complex64",
                    Some("real"),
                    real.is_sign_positive(),
                ));
            }
            if imag.is_infinite() {
                return Err(inf_err(
                    base_index + offset,
                    "complex64",
                    Some("imaginary"),
                    imag.is_sign_positive(),
                ));
            }
        }
    }
    Ok(())
}
/// Scans interleaved (real, imaginary) float64 pairs. The reported element
/// index counts complex values, and the error names the offending component
/// (real is checked before imaginary, NaN before Inf).
fn scan_complex128_seq(
    data: &[u8],
    byte_order: ByteOrder,
    reject_nan: bool,
    reject_inf: bool,
    base_index: usize,
) -> Result<()> {
    for (offset, raw) in data.chunks_exact(16).enumerate() {
        let (re_raw, im_raw) = raw.split_at(8);
        let (Ok(re_bytes), Ok(im_bytes)) =
            (<[u8; 8]>::try_from(re_raw), <[u8; 8]>::try_from(im_raw))
        else {
            continue;
        };
        let decode = |bytes: [u8; 8]| match byte_order {
            ByteOrder::Big => f64::from_be_bytes(bytes),
            ByteOrder::Little => f64::from_le_bytes(bytes),
        };
        let (real, imag) = (decode(re_bytes), decode(im_bytes));
        if reject_nan {
            if real.is_nan() {
                return Err(nan_err(base_index + offset, "complex128", Some("real")));
            }
            if imag.is_nan() {
                return Err(nan_err(base_index + offset, "complex128", Some("imaginary")));
            }
        }
        if reject_inf {
            if real.is_infinite() {
                return Err(inf_err(
                    base_index + offset,
                    "complex128",
                    Some("real"),
                    real.is_sign_positive(),
                ));
            }
            if imag.is_infinite() {
                return Err(inf_err(
                    base_index + offset,
                    "complex128",
                    Some("imaginary"),
                    imag.is_sign_positive(),
                ));
            }
        }
    }
    Ok(())
}
/// Builds the strict-NaN violation error; `component` names the complex
/// part ("real"/"imaginary") when the offending dtype is complex.
fn nan_err(index: usize, dtype: &str, component: Option<&str>) -> TensogramError {
    let suffix = match component {
        Some(part) => format!(" ({part} component)"),
        None => String::new(),
    };
    TensogramError::Encoding(format!(
        "strict-NaN check: NaN at element {index} of {dtype} array{suffix}"
    ))
}
/// Builds the strict-Inf violation error; `positive` selects the sign shown
/// and `component` names the complex part when applicable.
fn inf_err(index: usize, dtype: &str, component: Option<&str>, positive: bool) -> TensogramError {
    let sign = match positive {
        true => "+Inf",
        false => "-Inf",
    };
    let suffix = match component {
        Some(part) => format!(" ({part} component)"),
        None => String::new(),
    };
    TensogramError::Encoding(format!(
        "strict-Inf check: {sign} at element {index} of {dtype} array{suffix}"
    ))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Invokes `scan` with the platform's native byte order.
    fn run(data: &[u8], dtype: Dtype, nan: bool, inf: bool, parallel: bool) -> Result<()> {
        scan(data, dtype, ByteOrder::native(), nan, inf, parallel)
    }

    fn f32_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_ne_bytes()).collect()
    }

    fn f64_bytes(values: &[f64]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_ne_bytes()).collect()
    }

    #[test]
    fn short_circuits_when_both_flags_off() {
        let data = f32_bytes(&[f32::NAN, f32::INFINITY]);
        run(&data, Dtype::Float32, false, false, false)
            .expect("both flags off must short-circuit");
    }

    #[test]
    fn integer_dtype_is_zero_cost() {
        let data = vec![0xFF_u8; 64];
        run(&data, Dtype::Uint32, true, true, false)
            .expect("integer dtypes must short-circuit");
    }

    #[test]
    fn bitmask_dtype_is_zero_cost() {
        let data = vec![0xFF_u8; 16];
        run(&data, Dtype::Bitmask, true, true, false)
            .expect("bitmask dtype must short-circuit");
    }

    #[test]
    fn sequential_f32_reports_global_first_nan() {
        let data = f32_bytes(&[1.0, 2.0, 3.0, f32::NAN, 5.0, f32::NAN]);
        let err = run(&data, Dtype::Float32, true, false, false).unwrap_err();
        assert!(err.to_string().contains("element 3"));
    }

    #[test]
    fn sequential_f64_reports_positive_inf_sign() {
        let data = f64_bytes(&[1.0, f64::INFINITY]);
        let err = run(&data, Dtype::Float64, false, true, false).unwrap_err();
        assert!(err.to_string().contains("+Inf"));
    }

    #[test]
    fn sequential_f64_reports_negative_inf_sign() {
        let data = f64_bytes(&[1.0, f64::NEG_INFINITY]);
        let err = run(&data, Dtype::Float64, false, true, false).unwrap_err();
        assert!(err.to_string().contains("-Inf"));
    }

    #[test]
    fn float16_quiet_nan_detected() {
        // 0x7E00: all-ones exponent, non-zero mantissa -> quiet NaN.
        let data = 0x7E00_u16.to_ne_bytes().to_vec();
        let err = run(&data, Dtype::Float16, true, false, false).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("NaN"));
        assert!(msg.contains("float16"));
    }

    #[test]
    fn float16_inf_detected() {
        // 0x7C00: all-ones exponent, zero mantissa, sign clear -> +Inf.
        let data = 0x7C00_u16.to_ne_bytes().to_vec();
        let err = run(&data, Dtype::Float16, false, true, false).unwrap_err();
        assert!(err.to_string().contains("+Inf"));
    }

    #[test]
    fn bfloat16_quiet_nan_detected() {
        // 0x7FC0: all-ones exponent, non-zero mantissa -> quiet NaN.
        let data = 0x7FC0_u16.to_ne_bytes().to_vec();
        let err = run(&data, Dtype::Bfloat16, true, false, false).unwrap_err();
        assert!(err.to_string().contains("NaN"));
    }

    #[test]
    fn complex64_nan_in_imag_detected() {
        let data: Vec<u8> = [1.0_f32, f32::NAN]
            .iter()
            .flat_map(|v| v.to_ne_bytes())
            .collect();
        let err = run(&data, Dtype::Complex64, true, false, false).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("imaginary"), "got: {msg}");
    }

    #[test]
    fn finite_data_passes_when_both_flags_on() {
        let data = f64_bytes(&[0.0, -0.0, 1.0, -1.0, f64::MIN, f64::MAX]);
        run(&data, Dtype::Float64, true, true, false)
            .expect("finite data must pass both flags");
    }

    #[cfg(feature = "threads")]
    #[test]
    fn parallel_path_rejects_nan() {
        // Large enough (128 KiB) to actually take the parallel path.
        let mut values = vec![1.0_f64; 16_384];
        values[9_001] = f64::NAN;
        let data = f64_bytes(&values);
        let err = run(&data, Dtype::Float64, true, false, true).unwrap_err();
        assert!(err.to_string().contains("NaN"));
    }
}