use std::collections::BTreeMap;
use tensogram::*;
fn make_global_meta() -> GlobalMetadata {
GlobalMetadata {
..Default::default()
}
}
/// Builds an `ntensor` descriptor for `shape`/`dtype` with no encoding,
/// filter, compression or masks, native byte order and empty params.
///
/// `strides` are the contiguous row-major strides (counted in elements)
/// implied by `shape`: the last axis has stride 1 and each earlier axis spans
/// the product of all later extents. An empty shape yields empty strides.
fn make_descriptor(shape: Vec<u64>, dtype: Dtype) -> DataObjectDescriptor {
    let ndim = shape.len() as u64;
    let mut strides = vec![1u64; shape.len()];
    // Right-to-left running product of the trailing extents.
    for axis in (1..shape.len()).rev() {
        strides[axis - 1] = strides[axis] * shape[axis];
    }
    DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim,
        shape,
        strides,
        dtype,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        masks: None,
        params: BTreeMap::new(),
    }
}
/// Serialises `values` into one flat byte buffer, each `f64` laid out in
/// native byte order, back to back.
fn f64_bytes(values: &[f64]) -> Vec<u8> {
    let mut out = Vec::with_capacity(values.len() * std::mem::size_of::<f64>());
    for value in values {
        out.extend_from_slice(&value.to_ne_bytes());
    }
    out
}
/// Reinterprets `bytes` as consecutive native-endian `f64` values.
///
/// # Panics
///
/// Panics if `bytes.len()` is not a multiple of 8. A trailing partial
/// element means the decoded payload has the wrong size; silently dropping
/// it (as `chunks_exact` alone would) could let a broken decode pass a test.
fn decode_f64_vec(bytes: &[u8]) -> Vec<f64> {
    const WIDTH: usize = std::mem::size_of::<f64>();
    let chunks = bytes.chunks_exact(WIDTH);
    assert!(
        chunks.remainder().is_empty(),
        "decoded payload of {} bytes is not a whole number of f64 values",
        bytes.len()
    );
    chunks
        .map(|chunk| {
            let mut raw = [0u8; WIDTH];
            raw.copy_from_slice(chunk);
            f64::from_ne_bytes(raw)
        })
        .collect()
}
/// A NaN inside the requested range is restored, and the finite values
/// sharing that range come back unchanged.
///
/// Range `(2, 4)` over `[1, NaN, 3, NaN, 5, 6, NaN, 8]` covers elements 2..6,
/// i.e. `[3, NaN, 5, 6]`; the NaNs at indices 1 and 6 lie outside it.
#[test]
fn decode_range_restores_nan_when_position_falls_inside_range() {
    let data = f64_bytes(&[1.0, f64::NAN, 3.0, f64::NAN, 5.0, 6.0, f64::NAN, 8.0]);
    let desc = make_descriptor(vec![8], Dtype::Float64);
    let options = EncodeOptions {
        allow_nan: true,
        hash_algorithm: None,
        small_mask_threshold_bytes: 0,
        ..Default::default()
    };
    let msg = encode(&make_global_meta(), &[(&desc, &data)], &options).unwrap();
    let (_desc_out, parts) = decode_range(&msg, 0, &[(2, 4)], &DecodeOptions::default()).unwrap();
    // One requested range must yield exactly one part.
    assert_eq!(parts.len(), 1);
    let got = decode_f64_vec(&parts[0]);
    assert_eq!(got.len(), 4);
    assert_eq!(got[0], 3.0);
    assert!(got[1].is_nan());
    assert_eq!(got[2], 5.0);
    assert_eq!(got[3], 6.0);
}
/// A range that contains no NaN positions decodes to plain finite values.
///
/// The source does contain NaNs (indices 1 and 6), but range `(2, 4)` covers
/// only elements 2..6, so every returned value must be finite.
#[test]
fn decode_range_without_any_masked_positions_returns_finite_values() {
    let data = f64_bytes(&[1.0, f64::NAN, 3.0, 4.0, 5.0, 6.0, f64::NAN, 8.0]);
    let desc = make_descriptor(vec![8], Dtype::Float64);
    let options = EncodeOptions {
        allow_nan: true,
        hash_algorithm: None,
        small_mask_threshold_bytes: 0,
        ..Default::default()
    };
    let msg = encode(&make_global_meta(), &[(&desc, &data)], &options).unwrap();
    let (_, parts) = decode_range(&msg, 0, &[(2, 4)], &DecodeOptions::default()).unwrap();
    // One requested range must yield exactly one part.
    assert_eq!(parts.len(), 1);
    let got = decode_f64_vec(&parts[0]);
    // Per-element check first so a stray NaN reports its position.
    for (i, v) in got.iter().enumerate() {
        assert!(!v.is_nan(), "pos {i} must be finite, got {v}");
    }
    assert_eq!(got, vec![3.0, 4.0, 5.0, 6.0]);
}
/// Each of several ranges restores only the non-finite values that fall
/// inside it: NaN at 0, +Inf at 3 and -Inf at 6 land in ranges 0, 1 and 2.
#[test]
fn decode_range_multiple_ranges_each_gets_correct_restoration() {
    let data = f64_bytes(&[
        f64::NAN,
        1.0,
        2.0,
        f64::INFINITY,
        4.0,
        5.0,
        f64::NEG_INFINITY,
        7.0,
    ]);
    let desc = make_descriptor(vec![8], Dtype::Float64);
    let options = EncodeOptions {
        allow_nan: true,
        allow_inf: true,
        hash_algorithm: None,
        small_mask_threshold_bytes: 0,
        ..Default::default()
    };
    let msg = encode(&make_global_meta(), &[(&desc, &data)], &options).unwrap();
    let (_, parts) = decode_range(
        &msg,
        0,
        &[(0, 2), (2, 3), (5, 3)],
        &DecodeOptions::default(),
    )
    .unwrap();
    assert_eq!(parts.len(), 3);

    // Range (0, 2): elements 0..2.
    let got0 = decode_f64_vec(&parts[0]);
    assert_eq!(got0.len(), 2);
    assert!(got0[0].is_nan());
    assert_eq!(got0[1], 1.0);

    // Range (2, 3): elements 2..5.
    let got1 = decode_f64_vec(&parts[1]);
    assert_eq!(got1.len(), 3);
    assert_eq!(got1[0], 2.0);
    assert!(got1[1].is_infinite() && got1[1].is_sign_positive());
    assert_eq!(got1[2], 4.0);

    // Range (5, 3): elements 5..8.
    let got2 = decode_f64_vec(&parts[2]);
    assert_eq!(got2.len(), 3);
    assert_eq!(got2[0], 5.0);
    assert!(got2[1].is_infinite() && got2[1].is_sign_negative());
    assert_eq!(got2[2], 7.0);
}
/// With `restore_non_finite: false`, the NaN position comes back as the
/// substituted zero, while the finite elements are still returned unchanged.
#[test]
fn decode_range_restore_off_returns_substituted_zeros() {
    let data = f64_bytes(&[1.0, f64::NAN, 3.0, 4.0]);
    let desc = make_descriptor(vec![4], Dtype::Float64);
    let options = EncodeOptions {
        allow_nan: true,
        hash_algorithm: None,
        small_mask_threshold_bytes: 0,
        ..Default::default()
    };
    let msg = encode(&make_global_meta(), &[(&desc, &data)], &options).unwrap();
    let decode_opts = DecodeOptions {
        restore_non_finite: false,
        ..Default::default()
    };
    let (_, parts) = decode_range(&msg, 0, &[(0, 4)], &decode_opts).unwrap();
    assert_eq!(parts.len(), 1);
    let got = decode_f64_vec(&parts[0]);
    assert_eq!(got.len(), 4);
    assert_eq!(got[0], 1.0);
    assert_eq!(got[1], 0.0, "element 1 must stay zero with restore off");
    assert_eq!(got[2], 3.0);
    assert_eq!(got[3], 4.0);
}
/// Sparse NaNs (every 101st element starting at 50) in a 5000-element
/// payload are restored in exactly the right places across several ranges,
/// including one that runs to the final element.
#[test]
fn decode_range_large_sparse_payload() {
    let n = 5_000;
    let mut values: Vec<f64> = (0..n).map(|i| i as f64).collect();
    for i in (50..n).step_by(101) {
        values[i] = f64::NAN;
    }
    let data = f64_bytes(&values);
    let desc = make_descriptor(vec![n as u64], Dtype::Float64);
    let options = EncodeOptions {
        allow_nan: true,
        hash_algorithm: None,
        ..Default::default()
    };
    let msg = encode(&make_global_meta(), &[(&desc, &data)], &options).unwrap();
    let ranges = [(100u64, 500u64), (2000, 300), (4000, 1000)];
    let (_, parts) = decode_range(&msg, 0, &ranges, &DecodeOptions::default()).unwrap();
    // Exactly one part per requested range: extra parts would otherwise go
    // unchecked and missing ones would panic on indexing below.
    assert_eq!(parts.len(), ranges.len());
    for (i, &(offset, count)) in ranges.iter().enumerate() {
        let got = decode_f64_vec(&parts[i]);
        assert_eq!(got.len(), count as usize);
        for (j, &v) in got.iter().enumerate() {
            let global = offset as usize + j;
            // Mirrors the NaN placement loop above.
            let expected_nan = (global >= 50) && (global - 50).is_multiple_of(101);
            if expected_nan {
                assert!(
                    v.is_nan(),
                    "range {i} pos {j} (global {global}) should be NaN"
                );
            } else {
                assert_eq!(v, global as f64, "finite mismatch at range {i} pos {j}");
            }
        }
    }
}
/// A payload with only finite values (no NaN/Inf options enabled) decodes a
/// sub-range verbatim.
#[test]
fn decode_range_without_masks_returns_expected_values() {
    let data = f64_bytes(&[1.0, 2.0, 3.0, 4.0, 5.0]);
    let desc = make_descriptor(vec![5], Dtype::Float64);
    let msg = encode(
        &make_global_meta(),
        &[(&desc, &data)],
        &EncodeOptions {
            hash_algorithm: None,
            ..Default::default()
        },
    )
    .unwrap();
    let (_, parts) = decode_range(&msg, 0, &[(1, 3)], &DecodeOptions::default()).unwrap();
    // One requested range must yield exactly one part.
    assert_eq!(parts.len(), 1);
    let got = decode_f64_vec(&parts[0]);
    assert_eq!(got, vec![2.0, 3.0, 4.0]);
}