use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
/// A collection of tensors that share every dimension except one
/// ("ragged") dimension, along which their sizes may differ.
#[derive(Debug, Clone)]
pub struct NestedTensor<T: Float> {
// Component tensors; all share the same rank and agree on every
// dimension except `ragged_dim` (enforced by `NestedTensor::new`).
tensors: Vec<Tensor<T>>,
// Index of the dimension that is allowed to vary across components.
ragged_dim: usize,
}
impl<T: Float> NestedTensor<T> {
/// Constructs a `NestedTensor` after validating that all components have
/// the same rank and agree on every dimension except `ragged_dim`.
///
/// # Errors
/// * `InvalidArgument` if `tensors` is empty, `ragged_dim` is out of range,
///   or a component's rank differs from the first tensor's.
/// * `ShapeMismatch` if two components disagree on a non-ragged dimension.
pub fn new(tensors: Vec<Tensor<T>>, ragged_dim: usize) -> FerrotorchResult<Self> {
    if tensors.is_empty() {
        return Err(FerrotorchError::InvalidArgument {
            message: "NestedTensor requires at least one component tensor".into(),
        });
    }
    let ndim = tensors[0].ndim();
    if ragged_dim >= ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("ragged_dim {} out of range for {}-D tensors", ragged_dim, ndim),
        });
    }
    // Every component after the first must match tensor 0 exactly,
    // except possibly along the ragged dimension.
    for (idx, tensor) in tensors.iter().enumerate().skip(1) {
        if tensor.ndim() != ndim {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "tensor {} has {} dims but tensor 0 has {} dims",
                    idx,
                    tensor.ndim(),
                    ndim
                ),
            });
        }
        for dim in 0..ndim {
            if dim == ragged_dim {
                continue;
            }
            if tensor.shape()[dim] != tensors[0].shape()[dim] {
                return Err(FerrotorchError::ShapeMismatch {
                    message: format!(
                        "tensor {} has size {} on dim {} but tensor 0 has size {} (only dim {} may differ)",
                        idx,
                        tensor.shape()[dim],
                        dim,
                        tensors[0].shape()[dim],
                        ragged_dim,
                    ),
                });
            }
        }
    }
    Ok(Self { tensors, ragged_dim })
}
/// Number of component tensors.
#[inline]
pub fn num_components(&self) -> usize {
self.tensors.len()
}
/// Index of the dimension along which components may differ.
#[inline]
pub fn ragged_dim(&self) -> usize {
self.ragged_dim
}
/// Borrows the component tensors in order.
#[inline]
pub fn tensors(&self) -> &[Tensor<T>] {
&self.tensors
}
/// Rank shared by every component (read from the first tensor).
#[inline]
pub fn ndim(&self) -> usize {
self.tensors[0].ndim()
}
/// Shape of the first component. Non-ragged dims are shared by every
/// component; the entry at `ragged_dim` is only the first tensor's.
pub fn consistent_shape(&self) -> Vec<usize> {
self.tensors[0].shape().to_vec()
}
/// Per-component extent along the ragged dimension.
pub fn ragged_lengths(&self) -> Vec<usize> {
self.tensors
.iter()
.map(|t| t.shape()[self.ragged_dim])
.collect()
}
/// Densifies into a `[batch, ...]` tensor where the ragged dimension is
/// padded up to the longest component with `pad_value`.
///
/// The output shape is `[num_components]` followed by the component shape
/// with `ragged_dim` replaced by the maximum ragged length.
pub fn to_padded(&self, pad_value: T) -> FerrotorchResult<Tensor<T>> {
let batch = self.tensors.len();
let ndim = self.ndim();
// Longest extent along the ragged dimension across all components.
let max_len = self
.tensors
.iter()
.map(|t| t.shape()[self.ragged_dim])
.max()
.unwrap_or(0);
let mut out_shape = Vec::with_capacity(ndim + 1);
out_shape.push(batch);
for d in 0..ndim {
if d == self.ragged_dim {
out_shape.push(max_len);
} else {
out_shape.push(self.tensors[0].shape()[d]);
}
}
let numel: usize = out_shape.iter().product();
// Start fully padded; component values are scattered in below.
let mut data = vec![pad_value; numel];
// Row-major (C-contiguous) strides for the padded output.
let mut out_strides = vec![0usize; ndim + 1];
out_strides[ndim] = 1;
for d in (0..ndim).rev() {
out_strides[d] = out_strides[d + 1] * out_shape[d + 1];
}
for (b, t) in self.tensors.iter().enumerate() {
let t_data = t.data()?;
let t_shape = t.shape();
// Row-major strides for this component.
let mut t_strides = vec![0usize; ndim];
if ndim > 0 {
t_strides[ndim - 1] = 1;
for d in (0..ndim - 1).rev() {
t_strides[d] = t_strides[d + 1] * t_shape[d + 1];
}
}
let t_numel: usize = t_shape.iter().product();
// Decode each flat source index into coordinates, then re-encode it as
// a flat destination index in the padded layout. NOTE: if any dim is 0
// then t_numel == 0 and this loop never runs, so the zero strides
// produced above are never used as divisors.
for (flat, &val) in t_data.iter().enumerate().take(t_numel) {
let mut remaining = flat;
let mut out_flat = b * out_strides[0];
for d in 0..ndim {
let coord = remaining / t_strides[d];
remaining %= t_strides[d];
out_flat += coord * out_strides[d + 1];
}
data[out_flat] = val;
}
}
Tensor::from_storage(TensorStorage::cpu(data), out_shape, false)
}
/// Inverse of [`NestedTensor::to_padded`]: slices a dense `[batch, ...]`
/// tensor back into per-component tensors whose ragged extent is given by
/// `lengths`.
///
/// # Errors
/// `InvalidArgument` when the tensor has no batch dimension, `lengths`
/// does not match the batch size, `ragged_dim` is out of range for the
/// component rank, or any requested length exceeds the padded extent.
pub fn from_padded(
    tensor: &Tensor<T>,
    lengths: &[usize],
    ragged_dim: usize,
) -> FerrotorchResult<Self> {
    let full_shape = tensor.shape();
    if full_shape.is_empty() {
        return Err(FerrotorchError::InvalidArgument {
            message: "from_padded requires at least a batch dimension".into(),
        });
    }
    let batch = full_shape[0];
    if lengths.len() != batch {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "lengths has {} entries but batch dimension is {}",
                lengths.len(),
                batch
            ),
        });
    }
    let comp_ndim = full_shape.len() - 1;
    if ragged_dim >= comp_ndim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "ragged_dim {} out of range for {}-D component tensors",
                ragged_dim, comp_ndim
            ),
        });
    }
    // Robustness fix: reject lengths larger than the padded extent.
    // Previously such lengths produced out-of-range flat indices into the
    // padded buffer (a panic, or a read from a neighboring batch row).
    // This matches the validation done by PackedNestedTensor::from_padded.
    let padded_extent = full_shape[ragged_dim + 1];
    for (i, &len) in lengths.iter().enumerate() {
        if len > padded_extent {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "lengths[{i}] = {len} exceeds padded extent {padded_extent} on dim {ragged_dim}"
                ),
            });
        }
    }
    let padded_data = tensor.data()?;
    let full_ndim = full_shape.len();
    // Row-major strides of the padded input (full_ndim >= 2 is guaranteed
    // by the checks above, so the array is never empty).
    let mut full_strides = vec![0usize; full_ndim];
    full_strides[full_ndim - 1] = 1;
    for d in (0..full_ndim - 1).rev() {
        full_strides[d] = full_strides[d + 1] * full_shape[d + 1];
    }
    let mut tensors = Vec::with_capacity(batch);
    for (b, &len_b) in lengths.iter().enumerate() {
        // Component shape: padded shape minus the batch dim, with the
        // ragged dim truncated to this component's length.
        let mut comp_shape = Vec::with_capacity(comp_ndim);
        for d in 0..comp_ndim {
            comp_shape.push(if d == ragged_dim { len_b } else { full_shape[d + 1] });
        }
        // Row-major strides of the component being extracted.
        let mut comp_strides = vec![0usize; comp_ndim];
        if comp_ndim > 0 {
            comp_strides[comp_ndim - 1] = 1;
            for d in (0..comp_ndim - 1).rev() {
                comp_strides[d] = comp_strides[d + 1] * comp_shape[d + 1];
            }
        }
        let comp_numel: usize = comp_shape.iter().product();
        let mut comp_data = Vec::with_capacity(comp_numel);
        for flat in 0..comp_numel {
            // Decode the flat component index into coordinates and map
            // them to the corresponding flat index in the padded buffer.
            let mut remaining = flat;
            let mut full_flat = b * full_strides[0];
            for d in 0..comp_ndim {
                // Guard against zero strides arising from zero-sized dims.
                let coord = if comp_strides[d] > 0 {
                    remaining / comp_strides[d]
                } else {
                    0
                };
                if comp_strides[d] > 0 {
                    remaining %= comp_strides[d];
                }
                full_flat += coord * full_strides[d + 1];
            }
            comp_data.push(padded_data[full_flat]);
        }
        tensors.push(Tensor::from_storage(
            TensorStorage::cpu(comp_data),
            comp_shape,
            false,
        )?);
    }
    Self::new(tensors, ragged_dim)
}
}
/// Row-wise, numerically-stable softmax over a `rows x cols` row-major
/// buffer, computed in place.
///
/// A row whose exponentials sum to zero (every entry was `-inf`) is
/// filled with NaN, since the softmax is undefined there.
fn softmax_rows_inplace<T: Float>(data: &mut [T], rows: usize, cols: usize) {
    for r in 0..rows {
        let row = &mut data[r * cols..(r + 1) * cols];
        // Shift by the row maximum so exp() cannot overflow.
        let mut row_max = <T as num_traits::Float>::neg_infinity();
        for &x in row.iter() {
            if x > row_max {
                row_max = x;
            }
        }
        let mut total = <T as num_traits::Zero>::zero();
        for x in row.iter_mut() {
            let e = (*x - row_max).exp();
            *x = e;
            total += e;
        }
        if total == <T as num_traits::Zero>::zero() {
            for x in row.iter_mut() {
                *x = <T as num_traits::Float>::nan();
            }
        } else {
            for x in row.iter_mut() {
                *x = *x / total;
            }
        }
    }
}
/// Scaled dot-product attention applied independently to each component of
/// the nested inputs: `softmax(Q K^T / sqrt(d_k)) V` per component.
///
/// Each component must be 2-D: query `[seq_q, d_k]`, key `[seq_k, d_k]`,
/// value `[seq_k, d_v]`; the corresponding output component is
/// `[seq_q, d_v]`. The result keeps the query's ragged dimension.
///
/// # Errors
/// `InvalidArgument` when component counts or ranks disagree, and
/// `ShapeMismatch` when per-component dimensions are incompatible.
pub fn nested_scaled_dot_product_attention<T: Float>(
query: &NestedTensor<T>,
key: &NestedTensor<T>,
value: &NestedTensor<T>,
) -> FerrotorchResult<NestedTensor<T>> {
let n = query.num_components();
if key.num_components() != n || value.num_components() != n {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"query has {} components but key has {} and value has {}",
n,
key.num_components(),
value.num_components()
),
});
}
let mut outputs = Vec::with_capacity(n);
for i in 0..n {
let q = &query.tensors()[i];
let k = &key.tensors()[i];
let v = &value.tensors()[i];
if q.ndim() != 2 || k.ndim() != 2 || v.ndim() != 2 {
return Err(FerrotorchError::InvalidArgument {
message: format!(
"attention requires 2-D tensors, component {} has dims ({}, {}, {})",
i,
q.ndim(),
k.ndim(),
v.ndim()
),
});
}
let seq_q = q.shape()[0];
let d_k = q.shape()[1];
let seq_k = k.shape()[0];
let d_k2 = k.shape()[1];
let seq_k2 = v.shape()[0];
let d_v = v.shape()[1];
if d_k != d_k2 {
return Err(FerrotorchError::ShapeMismatch {
message: format!("component {}: query d_k={} but key d_k={}", i, d_k, d_k2),
});
}
if seq_k != seq_k2 {
return Err(FerrotorchError::ShapeMismatch {
message: format!(
"component {}: key seq_len={} but value seq_len={}",
i, seq_k, seq_k2
),
});
}
let q_data = q.data()?;
let k_data = k.data()?;
let v_data = v.data()?;
// Standard 1/sqrt(d_k) attention scaling.
let scale = T::from(d_k).unwrap().sqrt().recip();
// scores[qi * seq_k + ki] = (q row qi . k row ki) * scale
let mut scores = vec![<T as num_traits::Zero>::zero(); seq_q * seq_k];
for qi in 0..seq_q {
for ki in 0..seq_k {
let mut dot = <T as num_traits::Zero>::zero();
for di in 0..d_k {
dot += q_data[qi * d_k + di] * k_data[ki * d_k + di];
}
scores[qi * seq_k + ki] = dot * scale;
}
}
// Normalize each query row of scores into attention weights.
softmax_rows_inplace(&mut scores, seq_q, seq_k);
// out = scores @ V
let mut out = vec![<T as num_traits::Zero>::zero(); seq_q * d_v];
for qi in 0..seq_q {
for dvi in 0..d_v {
let mut acc = <T as num_traits::Zero>::zero();
for ki in 0..seq_k {
acc += scores[qi * seq_k + ki] * v_data[ki * d_v + dvi];
}
out[qi * d_v + dvi] = acc;
}
}
outputs.push(Tensor::from_storage(
TensorStorage::cpu(out),
vec![seq_q, d_v],
false,
)?);
}
NestedTensor::new(outputs, query.ragged_dim())
}
/// Ragged sequences stored contiguously in a single flat buffer, with
/// CSR-style offsets marking each component's span.
#[derive(Debug, Clone)]
pub struct PackedNestedTensor<T: Float> {
// Flat, contiguous storage for all components, concatenated in order.
data: Vec<T>,
// offsets[i]..offsets[i + 1] is component i's span in `data`; starts
// with 0 and has num_components + 1 entries.
offsets: Vec<usize>,
// Shape of the trailing (non-ragged) dimensions shared by every element.
tail_shape: Vec<usize>,
}
impl<T: Float> PackedNestedTensor<T> {
/// Packs pre-flattened sequences into one contiguous buffer.
///
/// `sequences[i]` must hold exactly `lengths[i] * tail_numel` elements,
/// where `tail_numel` is the product of `tail_shape` (treated as 1 when
/// `tail_shape` is empty).
///
/// # Errors
/// `InvalidArgument` when the input is empty, the lengths do not line up
/// with the sequences, or a length/tail product overflows `usize`.
pub fn from_sequences(
    sequences: Vec<Vec<T>>,
    lengths: &[usize],
    tail_shape: &[usize],
) -> FerrotorchResult<Self> {
    if sequences.is_empty() {
        return Err(FerrotorchError::InvalidArgument {
            message: "PackedNestedTensor requires at least one sequence".into(),
        });
    }
    if sequences.len() != lengths.len() {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "PackedNestedTensor: sequences has {} entries but lengths has {}",
                sequences.len(),
                lengths.len()
            ),
        });
    }
    // An empty tail shape means scalar elements, i.e. a tail size of 1.
    let tail_numel = tail_shape.iter().product::<usize>().max(1);
    // Validate every sequence and accumulate the total element count.
    let mut total = 0usize;
    for (i, (seq, &len)) in sequences.iter().zip(lengths.iter()).enumerate() {
        let expected = len
            .checked_mul(tail_numel)
            .ok_or_else(|| FerrotorchError::InvalidArgument {
                message: format!(
                    "PackedNestedTensor: length overflow in component {i} (length={len}, tail_numel={tail_numel})"
                ),
            })?;
        if seq.len() != expected {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "PackedNestedTensor: sequence {i} has {} elements but expected lengths[{i}] * tail_numel = {}*{} = {}",
                    seq.len(),
                    len,
                    tail_numel,
                    expected
                ),
            });
        }
        total += expected;
    }
    // Concatenate, recording a running offset after each component.
    let mut packed = Self {
        data: Vec::with_capacity(total),
        offsets: Vec::with_capacity(sequences.len() + 1),
        tail_shape: tail_shape.to_vec(),
    };
    packed.offsets.push(0);
    for seq in sequences {
        packed.data.extend(seq);
        packed.offsets.push(packed.data.len());
    }
    Ok(packed)
}
/// Converts a `NestedTensor` that is ragged along dimension 0 into packed
/// form by flattening each component into the shared buffer.
///
/// # Errors
/// `InvalidArgument` if the nested tensor is ragged along any dimension
/// other than 0, or has no components.
pub fn from_nested(nested: &NestedTensor<T>) -> FerrotorchResult<Self> {
    if nested.ragged_dim() != 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "PackedNestedTensor::from_nested requires ragged_dim == 0, got {}",
                nested.ragged_dim()
            ),
        });
    }
    let components = nested.tensors();
    let first = match components.first() {
        Some(t) => t,
        None => {
            return Err(FerrotorchError::InvalidArgument {
                message: "PackedNestedTensor::from_nested: no components".into(),
            })
        }
    };
    // Components agree on every dim but the leading one, so the tail
    // shape can be read off the first component.
    let tail_shape: Vec<usize> = first.shape()[1..].to_vec();
    let lengths: Vec<usize> = components.iter().map(|t| t.shape()[0]).collect();
    let mut sequences = Vec::with_capacity(components.len());
    for t in components {
        sequences.push(t.data()?.to_vec());
    }
    Self::from_sequences(sequences, &lengths, &tail_shape)
}
/// Unpacks into a `NestedTensor` ragged along dimension 0, copying each
/// component's slice into its own tensor.
pub fn to_nested(&self) -> FerrotorchResult<NestedTensor<T>> {
    let mut components = Vec::with_capacity(self.num_components());
    for i in 0..self.num_components() {
        // Component i has shape [length(i), ..tail_shape].
        let mut shape = Vec::with_capacity(1 + self.tail_shape.len());
        shape.push(self.length(i));
        shape.extend_from_slice(&self.tail_shape);
        let storage = TensorStorage::cpu(self.component_slice(i).to_vec());
        components.push(Tensor::from_storage(storage, shape, false)?);
    }
    NestedTensor::new(components, 0)
}
/// Number of packed components (`offsets` has one extra entry).
#[inline]
pub fn num_components(&self) -> usize {
self.offsets.len().saturating_sub(1)
}
/// CSR-style offsets into `data`; component `i` spans
/// `offsets[i]..offsets[i + 1]`.
#[inline]
pub fn offsets(&self) -> &[usize] {
&self.offsets
}
/// Shared trailing (non-ragged) dimensions of every element.
#[inline]
pub fn tail_shape(&self) -> &[usize] {
&self.tail_shape
}
/// The flat, contiguous backing buffer.
#[inline]
pub fn data(&self) -> &[T] {
&self.data
}
/// Ragged length of component `i` (its element count divided by the tail
/// size). Panics if `i + 1` is out of range for `offsets`.
#[inline]
pub fn length(&self, i: usize) -> usize {
let tail_numel: usize = self.tail_shape.iter().product::<usize>().max(1);
(self.offsets[i + 1] - self.offsets[i]) / tail_numel
}
/// Total number of scalar elements across all components.
#[inline]
pub fn total_numel(&self) -> usize {
self.data.len()
}
/// Flat slice holding component `i`'s elements. Panics if `i + 1` is out
/// of range for `offsets`.
pub fn component_slice(&self, i: usize) -> &[T] {
&self.data[self.offsets[i]..self.offsets[i + 1]]
}
/// Returns a new packed tensor with `f` applied elementwise; offsets and
/// tail shape are carried over unchanged.
pub fn map(&self, f: impl Fn(T) -> T) -> Self {
    let mut mapped = Vec::with_capacity(self.data.len());
    for &v in &self.data {
        mapped.push(f(v));
    }
    Self {
        data: mapped,
        offsets: self.offsets.clone(),
        tail_shape: self.tail_shape.clone(),
    }
}
/// Elementwise sum; operands must share offsets and tail shape.
pub fn add(&self, other: &Self) -> FerrotorchResult<Self> {
self.zip_with(other, "add", |a, b| a + b)
}
/// Elementwise difference; operands must share offsets and tail shape.
pub fn sub(&self, other: &Self) -> FerrotorchResult<Self> {
self.zip_with(other, "sub", |a, b| a - b)
}
/// Elementwise product; operands must share offsets and tail shape.
pub fn mul(&self, other: &Self) -> FerrotorchResult<Self> {
self.zip_with(other, "mul", |a, b| a * b)
}
/// Elementwise quotient; operands must share offsets and tail shape.
pub fn div(&self, other: &Self) -> FerrotorchResult<Self> {
self.zip_with(other, "div", |a, b| a / b)
}
/// Elementwise combination of two packed tensors that must share both
/// offsets and tail shape; `op_name` only appears in error messages.
fn zip_with(
    &self,
    other: &Self,
    op_name: &'static str,
    f: impl Fn(T, T) -> T,
) -> FerrotorchResult<Self> {
    if self.offsets != other.offsets {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "PackedNestedTensor::{op_name}: offsets mismatch ({:?} vs {:?})",
                self.offsets, other.offsets
            ),
        });
    }
    if self.tail_shape != other.tail_shape {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "PackedNestedTensor::{op_name}: tail shape mismatch ({:?} vs {:?})",
                self.tail_shape, other.tail_shape
            ),
        });
    }
    // Matching offsets imply matching data lengths, so zip covers both.
    let mut combined = Vec::with_capacity(self.data.len());
    for (&a, &b) in self.data.iter().zip(&other.data) {
        combined.push(f(a, b));
    }
    Ok(Self {
        data: combined,
        offsets: self.offsets.clone(),
        tail_shape: self.tail_shape.clone(),
    })
}
/// Sums each component's elements, returning one value per component.
pub fn sum_per_component(&self) -> Vec<T> {
    (0..self.num_components())
        .map(|i| {
            let mut total = <T as num_traits::Zero>::zero();
            for &v in self.component_slice(i) {
                total += v;
            }
            total
        })
        .collect()
}
/// Averages each component's elements; an empty component yields zero.
pub fn mean_per_component(&self) -> Vec<T> {
    let mut means = Vec::with_capacity(self.num_components());
    for i in 0..self.num_components() {
        let slice = self.component_slice(i);
        if slice.is_empty() {
            // Empty component: report zero instead of dividing by zero.
            means.push(<T as num_traits::Zero>::zero());
        } else {
            let mut total = <T as num_traits::Zero>::zero();
            for &v in slice {
                total += v;
            }
            means.push(total / T::from(slice.len()).unwrap());
        }
    }
    means
}
/// Densifies into a `[num_components, max_len, ..tail_shape]` tensor,
/// filling positions past each component's length with `pad_value`.
pub fn to_padded(&self, pad_value: T) -> FerrotorchResult<Tensor<T>> {
    let n = self.num_components();
    let max_len = (0..n).map(|i| self.length(i)).max().unwrap_or(0);
    let tail_numel = self.tail_shape.iter().product::<usize>().max(1);
    // Stride between consecutive batch entries in the padded layout.
    let row_stride = max_len * tail_numel;
    let mut out = vec![pad_value; n * row_stride];
    // Each component occupies a prefix of its padded row.
    for i in 0..n {
        let src = self.component_slice(i);
        let start = i * row_stride;
        out[start..start + src.len()].copy_from_slice(src);
    }
    let mut shape = vec![n, max_len];
    shape.extend_from_slice(&self.tail_shape);
    Tensor::from_storage(TensorStorage::cpu(out), shape, false)
}
/// Inverse of [`PackedNestedTensor::to_padded`]: extracts the first
/// `lengths[i]` rows of each batch entry from a dense
/// `[batch, max_len, ..tail]` tensor and re-packs them contiguously.
///
/// # Errors
/// `InvalidArgument` when the tensor has fewer than 2 dims, when `lengths`
/// does not match the batch dimension, or when any length exceeds the
/// padded sequence dimension.
pub fn from_padded(tensor: &Tensor<T>, lengths: &[usize]) -> FerrotorchResult<Self> {
    let shape = tensor.shape();
    if shape.len() < 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "PackedNestedTensor::from_padded: tensor must have at least 2 dims (batch, sequence), got {:?}",
                shape
            ),
        });
    }
    let n = shape[0];
    let max_len = shape[1];
    let tail_shape: Vec<usize> = shape[2..].to_vec();
    let tail_numel: usize = tail_shape.iter().product::<usize>().max(1);
    if lengths.len() != n {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "PackedNestedTensor::from_padded: lengths has {} entries but batch dim is {}",
                lengths.len(),
                n
            ),
        });
    }
    for (i, &len) in lengths.iter().enumerate() {
        if len > max_len {
            // BUG FIX: the message previously printed a hard-coded
            // "54,312" instead of the actual padded length.
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "PackedNestedTensor::from_padded: lengths[{i}] = {len} exceeds max_len = {max_len}"
                ),
            });
        }
    }
    let padded = tensor.data()?;
    let row_stride = max_len * tail_numel;
    // Copy only the valid prefix of every padded row into the flat buffer.
    let mut data = Vec::with_capacity(lengths.iter().sum::<usize>() * tail_numel);
    let mut offsets = Vec::with_capacity(n + 1);
    offsets.push(0);
    for (i, &len) in lengths.iter().enumerate() {
        let src_base = i * row_stride;
        data.extend_from_slice(&padded[src_base..src_base + len * tail_numel]);
        offsets.push(data.len());
    }
    Ok(Self {
        data,
        offsets,
        tail_shape,
    })
}
}
#[cfg(test)]
mod tests {
use super::*;
// Convenience constructor for small f32 CPU tensors in tests.
fn make_tensor(data: Vec<f32>, shape: Vec<usize>) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data), shape, false).unwrap()
}
// f64 variant of `make_tensor`.
fn make_tensor_f64(data: Vec<f64>, shape: Vec<usize>) -> Tensor<f64> {
Tensor::from_storage(TensorStorage::cpu(data), shape, false).unwrap()
}
// Components [3,2] and [2,2] ragged along dim 0 construct successfully.
#[test]
fn test_nested_construction() {
let t1 = make_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);
let t2 = make_tensor(vec![7.0, 8.0, 9.0, 10.0], vec![2, 2]);
let nt = NestedTensor::new(vec![t1, t2], 0).unwrap();
assert_eq!(nt.num_components(), 2);
assert_eq!(nt.ragged_dim(), 0);
assert_eq!(nt.ndim(), 2);
assert_eq!(nt.ragged_lengths(), vec![3, 2]);
}
// An empty component list is rejected.
#[test]
fn test_nested_rejects_empty() {
let result = NestedTensor::<f32>::new(vec![], 0);
assert!(result.is_err());
}
// Components differing on a non-ragged dim are rejected.
#[test]
fn test_nested_rejects_shape_mismatch() {
let t1 = make_tensor(vec![1.0; 6], vec![3, 2]);
let t2 = make_tensor(vec![1.0; 6], vec![2, 3]);
let result = NestedTensor::new(vec![t1, t2], 0);
assert!(result.is_err());
}
// to_padded/from_padded round-trips exactly when ragged along dim 0.
#[test]
fn test_to_padded_from_padded_ragged_dim_0() {
let t1 = make_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);
let t2 = make_tensor(vec![7.0, 8.0, 9.0, 10.0], vec![2, 2]);
let nt = NestedTensor::new(vec![t1, t2], 0).unwrap();
let padded = nt.to_padded(0.0).unwrap();
assert_eq!(padded.shape(), &[2, 3, 2]);
let lengths = nt.ragged_lengths();
let reconstructed = NestedTensor::from_padded(&padded, &lengths, 0).unwrap();
assert_eq!(reconstructed.num_components(), 2);
let r0 = reconstructed.tensors()[0].data().unwrap();
assert_eq!(r0, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let r1 = reconstructed.tensors()[1].data().unwrap();
assert_eq!(r1, &[7.0, 8.0, 9.0, 10.0]);
}
// Round trip with the ragged dimension in the interior (dim 1): padding
// appears at the end of each row of the shorter component.
#[test]
fn test_from_padded_round_trip_ragged_dim_1() {
let t1 = make_tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
let t2 = make_tensor(vec![7.0, 8.0, 9.0, 10.0], vec![2, 2]);
let nt = NestedTensor::new(vec![t1, t2], 1).unwrap();
let padded = nt.to_padded(0.0).unwrap();
assert_eq!(padded.shape(), &[2, 2, 3]);
let padded_data = padded.data().unwrap();
assert_eq!(padded_data[0], 1.0);
assert_eq!(padded_data[1], 2.0);
assert_eq!(padded_data[2], 3.0);
assert_eq!(padded_data[3], 4.0);
assert_eq!(padded_data[4], 5.0);
assert_eq!(padded_data[5], 6.0);
assert_eq!(padded_data[6], 7.0);
assert_eq!(padded_data[7], 8.0);
assert_eq!(padded_data[8], 0.0); assert_eq!(padded_data[9], 9.0);
assert_eq!(padded_data[10], 10.0);
assert_eq!(padded_data[11], 0.0);
let lengths = nt.ragged_lengths();
assert_eq!(lengths, vec![3, 2]);
let reconstructed = NestedTensor::from_padded(&padded, &lengths, 1).unwrap();
assert_eq!(reconstructed.num_components(), 2);
assert_eq!(reconstructed.tensors()[0].shape(), &[2, 3]);
assert_eq!(reconstructed.tensors()[1].shape(), &[2, 2]);
let r0 = reconstructed.tensors()[0].data().unwrap();
assert_eq!(r0, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let r1 = reconstructed.tensors()[1].data().unwrap();
assert_eq!(r1, &[7.0, 8.0, 9.0, 10.0]);
}
// Uniform inputs make every attention weight equal, so the output is the
// mean of the all-ones value rows: ~1.0 everywhere.
#[test]
fn test_sdpa_basic() {
let q = make_tensor(vec![1.0; 8], vec![2, 4]);
let k = make_tensor(vec![1.0; 12], vec![3, 4]);
let v = make_tensor(vec![1.0; 15], vec![3, 5]);
let qn = NestedTensor::new(vec![q], 0).unwrap();
let kn = NestedTensor::new(vec![k], 0).unwrap();
let vn = NestedTensor::new(vec![v], 0).unwrap();
let result = nested_scaled_dot_product_attention(&qn, &kn, &vn).unwrap();
assert_eq!(result.num_components(), 1);
assert_eq!(result.tensors()[0].shape(), &[2, 5]);
let out = result.tensors()[0].data().unwrap();
for &val in out.iter() {
assert!((val - 1.0).abs() < 1e-5, "expected ~1.0, got {val}");
}
}
// Same attention path compiles and runs for f64 components.
#[test]
fn test_sdpa_f64() {
let q = make_tensor_f64(vec![1.0; 8], vec![2, 4]);
let k = make_tensor_f64(vec![1.0; 12], vec![3, 4]);
let v = make_tensor_f64(vec![1.0; 15], vec![3, 5]);
let qn = NestedTensor::new(vec![q], 0).unwrap();
let kn = NestedTensor::new(vec![k], 0).unwrap();
let vn = NestedTensor::new(vec![v], 0).unwrap();
let result = nested_scaled_dot_product_attention(&qn, &kn, &vn).unwrap();
assert_eq!(result.num_components(), 1);
assert_eq!(result.tensors()[0].shape(), &[2, 5]);
}
// All -inf rows have a zero exponential sum; softmax fills them with NaN.
#[test]
fn test_softmax_all_neg_inf_produces_nan() {
let mut data = vec![f32::NEG_INFINITY; 6];
softmax_rows_inplace(&mut data, 2, 3);
for val in &data {
assert!(val.is_nan(), "expected NaN for all -inf input, got {val}");
}
}
// Softmax output sums to 1 and preserves the input ordering.
#[test]
fn test_softmax_normal_case() {
let mut data = vec![1.0f32, 2.0, 3.0];
softmax_rows_inplace(&mut data, 1, 3);
let sum: f32 = data.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-6,
"softmax should sum to 1, got {sum}"
);
assert!(data[0] < data[1]);
assert!(data[1] < data[2]);
}
// Scalar-element (empty tail) packing: offsets are cumulative lengths.
#[test]
fn packed_from_sequences_1d() {
let seqs = vec![
vec![1.0f32, 2.0, 3.0],
vec![4.0, 5.0, 6.0, 7.0, 8.0],
vec![9.0, 10.0],
];
let lengths = vec![3usize, 5, 2];
let pnt = PackedNestedTensor::from_sequences(seqs, &lengths, &[]).unwrap();
assert_eq!(pnt.num_components(), 3);
assert_eq!(pnt.offsets(), &[0, 3, 8, 10]);
assert_eq!(pnt.total_numel(), 10);
assert_eq!(pnt.length(0), 3);
assert_eq!(pnt.length(1), 5);
assert_eq!(pnt.length(2), 2);
assert_eq!(pnt.component_slice(0), &[1.0, 2.0, 3.0]);
assert_eq!(pnt.component_slice(1), &[4.0, 5.0, 6.0, 7.0, 8.0]);
assert_eq!(pnt.component_slice(2), &[9.0, 10.0]);
}
// With tail shape [4], length(i) divides element counts by the tail size.
#[test]
fn packed_from_sequences_with_tail_shape() {
let seqs = vec![
vec![
1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
], vec![13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0], ];
let lengths = vec![3usize, 2];
let tail = vec![4usize];
let pnt = PackedNestedTensor::from_sequences(seqs, &lengths, &tail).unwrap();
assert_eq!(pnt.num_components(), 2);
assert_eq!(pnt.offsets(), &[0, 12, 20]);
assert_eq!(pnt.length(0), 3);
assert_eq!(pnt.length(1), 2);
assert_eq!(pnt.tail_shape(), &[4]);
}
#[test]
fn packed_rejects_empty_sequences_list() {
let result = PackedNestedTensor::<f32>::from_sequences(vec![], &[], &[]);
assert!(result.is_err());
assert!(format!("{}", result.unwrap_err()).contains("at least one sequence"));
}
// A sequence whose element count disagrees with its length is rejected.
#[test]
fn packed_rejects_mismatched_sequence_length() {
let seqs = vec![vec![1.0f32, 2.0, 3.0]]; let lengths = vec![2usize]; let result = PackedNestedTensor::from_sequences(seqs, &lengths, &[]);
assert!(result.is_err());
}
#[test]
fn packed_rejects_mismatched_sequences_vs_lengths() {
let seqs = vec![vec![1.0f32, 2.0]];
let lengths = vec![2usize, 3];
let result = PackedNestedTensor::from_sequences(seqs, &lengths, &[]);
assert!(result.is_err());
}
// map applies elementwise (ReLU here) and keeps offsets unchanged.
#[test]
fn packed_map_applies_fn_to_every_element() {
let pnt = PackedNestedTensor::from_sequences(
vec![vec![1.0f32, -2.0, 3.0], vec![-4.0, 5.0]],
&[3usize, 2],
&[],
)
.unwrap();
let relu = pnt.map(|x: f32| x.max(0.0));
assert_eq!(relu.data(), &[1.0, 0.0, 3.0, 0.0, 5.0]);
assert_eq!(relu.offsets(), pnt.offsets());
}
// All four elementwise ops on identically-shaped packed tensors.
#[test]
fn packed_add_sub_mul_div() {
let a = PackedNestedTensor::from_sequences(
vec![vec![10.0f32, 20.0, 30.0], vec![40.0, 50.0]],
&[3usize, 2],
&[],
)
.unwrap();
let b = PackedNestedTensor::from_sequences(
vec![vec![1.0f32, 2.0, 3.0], vec![4.0, 5.0]],
&[3usize, 2],
&[],
)
.unwrap();
assert_eq!(a.add(&b).unwrap().data(), &[11.0, 22.0, 33.0, 44.0, 55.0]);
assert_eq!(a.sub(&b).unwrap().data(), &[9.0, 18.0, 27.0, 36.0, 45.0]);
assert_eq!(
a.mul(&b).unwrap().data(),
&[10.0, 40.0, 90.0, 160.0, 250.0]
);
assert_eq!(a.div(&b).unwrap().data(), &[10.0, 10.0, 10.0, 10.0, 10.0]);
}
#[test]
fn packed_add_rejects_mismatched_offsets() {
let a = PackedNestedTensor::from_sequences(
vec![vec![1.0f32, 2.0, 3.0]],
&[3usize],
&[],
)
.unwrap();
let b = PackedNestedTensor::from_sequences(
vec![vec![1.0f32, 2.0]],
&[2usize],
&[],
)
.unwrap();
let result = a.add(&b);
assert!(result.is_err());
assert!(format!("{}", result.unwrap_err()).contains("offsets mismatch"));
}
// Same offsets ([0, 4]) but different tail shapes must still be rejected.
#[test]
fn packed_add_rejects_mismatched_tail_shape() {
let a = PackedNestedTensor::from_sequences(
vec![vec![1.0f32, 2.0, 3.0, 4.0]],
&[2usize],
&[2], )
.unwrap();
let b = PackedNestedTensor::from_sequences(
vec![vec![1.0f32, 2.0, 3.0, 4.0]],
&[4usize],
&[], )
.unwrap();
let result = a.add(&b);
assert!(result.is_err());
assert!(format!("{}", result.unwrap_err()).contains("tail shape mismatch"));
}
#[test]
fn packed_sum_per_component() {
let pnt = PackedNestedTensor::from_sequences(
vec![
vec![1.0f32, 2.0, 3.0], vec![10.0, 20.0, 30.0, 40.0], vec![5.0], ],
&[3usize, 4, 1],
&[],
)
.unwrap();
let sums = pnt.sum_per_component();
assert_eq!(sums, vec![6.0, 100.0, 5.0]);
}
#[test]
fn packed_mean_per_component() {
let pnt = PackedNestedTensor::from_sequences(
vec![
vec![2.0f32, 4.0, 6.0], vec![10.0, 20.0, 30.0, 40.0], vec![7.0], ],
&[3usize, 4, 1],
&[],
)
.unwrap();
let means = pnt.mean_per_component();
assert_eq!(means, vec![4.0, 25.0, 7.0]);
}
// An empty component reports mean 0.0 rather than NaN.
#[test]
fn packed_mean_handles_empty_component_as_zero() {
let pnt = PackedNestedTensor::from_sequences(
vec![vec![1.0f32, 2.0], vec![]],
&[2usize, 0],
&[],
)
.unwrap();
let means = pnt.mean_per_component();
assert_eq!(means, vec![1.5, 0.0]);
}
// Rows shorter than max_len are right-padded with the pad value.
#[test]
fn packed_to_padded_pads_with_value() {
let pnt = PackedNestedTensor::from_sequences(
vec![
vec![1.0f32, 2.0, 3.0],
vec![4.0, 5.0],
vec![6.0, 7.0, 8.0, 9.0],
],
&[3usize, 2, 4],
&[],
)
.unwrap();
let padded = pnt.to_padded(-1.0).unwrap();
assert_eq!(padded.shape(), &[3, 4]);
let data = padded.data().unwrap();
assert_eq!(
data,
&[
1.0, 2.0, 3.0, -1.0, 4.0, 5.0, -1.0, -1.0, 6.0, 7.0, 8.0, 9.0, ]
);
}
// Padding with a tail shape pads whole tail-sized groups at a time.
#[test]
fn packed_to_padded_with_tail_shape() {
let pnt = PackedNestedTensor::from_sequences(
vec![
vec![1.0f32, 2.0, 3.0, 4.0], vec![5.0, 6.0], ],
&[2usize, 1],
&[2],
)
.unwrap();
let padded = pnt.to_padded(0.0).unwrap();
assert_eq!(padded.shape(), &[2, 2, 2]);
let data = padded.data().unwrap();
assert_eq!(
data,
&[
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, ]
);
}
// from_padded strips the pad values that to_padded inserted.
#[test]
fn packed_from_padded_inverse_of_to_padded() {
let orig = PackedNestedTensor::from_sequences(
vec![
vec![1.0f32, 2.0, 3.0],
vec![4.0, 5.0],
vec![6.0, 7.0, 8.0, 9.0],
],
&[3usize, 2, 4],
&[],
)
.unwrap();
let padded = orig.to_padded(-99.0).unwrap();
let recovered = PackedNestedTensor::from_padded(&padded, &[3, 2, 4]).unwrap();
assert_eq!(recovered.offsets(), orig.offsets());
assert_eq!(recovered.data(), orig.data());
}
// lengths[1] = 5 > max_len = 4 must be rejected with a clear message.
#[test]
fn packed_from_padded_rejects_length_exceeding_max_len() {
let data: Vec<f32> = vec![0.0; 12];
let t = make_tensor(data, vec![3, 4]);
let result = PackedNestedTensor::from_padded(&t, &[3, 5, 2]); assert!(result.is_err());
assert!(format!("{}", result.unwrap_err()).contains("exceeds max_len"));
}
#[test]
fn packed_from_padded_rejects_lengths_count_mismatch() {
let t = make_tensor(vec![0.0f32; 12], vec![3, 4]);
let result = PackedNestedTensor::from_padded(&t, &[3, 4]); assert!(result.is_err());
}
// NestedTensor -> packed -> NestedTensor round-trips 1-D components.
#[test]
fn packed_from_nested_and_back() {
let t1 = make_tensor(vec![1.0, 2.0, 3.0], vec![3]);
let t2 = make_tensor(vec![4.0, 5.0], vec![2]);
let nested = NestedTensor::new(vec![t1, t2], 0).unwrap();
let packed = PackedNestedTensor::from_nested(&nested).unwrap();
assert_eq!(packed.num_components(), 2);
assert_eq!(packed.data(), &[1.0, 2.0, 3.0, 4.0, 5.0]);
let round_trip = packed.to_nested().unwrap();
assert_eq!(round_trip.num_components(), 2);
assert_eq!(round_trip.tensors()[0].shape(), &[3]);
assert_eq!(round_trip.tensors()[1].shape(), &[2]);
assert_eq!(round_trip.tensors()[0].data().unwrap(), &[1.0, 2.0, 3.0]);
assert_eq!(round_trip.tensors()[1].data().unwrap(), &[4.0, 5.0]);
}
#[test]
fn packed_from_nested_rejects_non_zero_ragged_dim() {
let t1 = make_tensor(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
let t2 = make_tensor(vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0], vec![2, 3]);
let nested = NestedTensor::new(vec![t1, t2], 1).unwrap();
let result = PackedNestedTensor::from_nested(&nested);
assert!(result.is_err());
assert!(format!("{}", result.unwrap_err()).contains("ragged_dim == 0"));
}
// Spot-check that the packed API is generic over f64 as well.
#[test]
fn packed_f64_works_like_f32() {
let pnt = PackedNestedTensor::from_sequences(
vec![vec![1.0f64, 2.0], vec![3.0, 4.0, 5.0]],
&[2usize, 3],
&[],
)
.unwrap();
assert_eq!(pnt.sum_per_component(), vec![3.0, 12.0]);
let doubled = pnt.map(|x: f64| x * 2.0);
assert_eq!(doubled.data(), &[2.0, 4.0, 6.0, 8.0, 10.0]);
let dense = make_tensor_f64(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
assert_eq!(dense.shape(), &[2, 2]);
}
}