use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
/// A batch of variable-length sequences stored without padding.
///
/// As produced by [`pack_padded_sequence`], the rows of `data` are grouped by
/// time step: first step 0 of every sequence, then step 1 of every sequence
/// that is longer than one step, and so on. Within a time step the rows follow
/// the order given by `sorted_indices`.
#[derive(Debug, Clone)]
pub struct PackedSequence<T: Float> {
/// Packed values as a 2-D tensor of shape `[sum(batch_sizes), features]`.
pub data: Tensor<T>,
/// `batch_sizes[t]` is the number of sequences whose length is greater than
/// `t`, i.e. how many rows of `data` belong to time step `t`.
pub batch_sizes: Vec<usize>,
/// Original batch indices ordered by descending sequence length.
pub sorted_indices: Vec<usize>,
/// Inverse of `sorted_indices`: `unsorted_indices[orig]` is the position of
/// original batch entry `orig` within `sorted_indices`.
pub unsorted_indices: Vec<usize>,
}
/// Packs a padded batch of variable-length sequences into a [`PackedSequence`].
///
/// `input` must be 3-D, interpreted as `[batch, seq, features]` when
/// `batch_first` is `true` and as `[seq, batch, features]` otherwise. The flat
/// buffer returned by `input.data()` is indexed as a dense row-major array of
/// that shape.
///
/// `lengths[i]` is the number of valid time steps of batch entry `i`. Sequences
/// are ordered by descending length (entries of equal length keep their
/// original relative order) and their valid steps are copied time step by time
/// step into a `[total_steps, features]` tensor. The result's `requires_grad`
/// flag is copied from `input`.
///
/// # Errors
///
/// Returns [`FerrotorchError::InvalidArgument`] when
/// * `input` is not 3-D,
/// * `lengths.len()` differs from the batch dimension,
/// * the batch dimension is zero,
/// * any length is zero or exceeds the sequence dimension, or
/// * `enforce_sorted` is `true` and `lengths` is not in descending order.
///
/// Errors from `input.data()` and `Tensor::from_storage` are propagated.
pub fn pack_padded_sequence<T: Float>(
    input: &Tensor<T>,
    lengths: &[usize],
    batch_first: bool,
    enforce_sorted: bool,
) -> FerrotorchResult<PackedSequence<T>> {
    let rank = input.ndim();
    if rank != 3 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("pack_padded_sequence: expected 3-D input, got {}-D", rank),
        });
    }

    let dims = input.shape();
    let features = dims[2];
    let (batch, max_seq_len) = if batch_first {
        (dims[0], dims[1])
    } else {
        (dims[1], dims[0])
    };

    if lengths.len() != batch {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "pack_padded_sequence: lengths.len() ({}) != batch size ({})",
                lengths.len(),
                batch,
            ),
        });
    }
    if batch == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "pack_padded_sequence: batch size must be >= 1".into(),
        });
    }

    // Report the first length that falls outside [1, max_seq_len].
    let first_invalid = lengths
        .iter()
        .copied()
        .enumerate()
        .find(|&(_, len)| len == 0 || len > max_seq_len);
    if let Some((i, len)) = first_invalid {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "pack_padded_sequence: lengths[{i}] = {len} is invalid \
                 (must be in [1, {max_seq_len}])",
            ),
        });
    }

    if enforce_sorted && lengths.windows(2).any(|pair| pair[0] < pair[1]) {
        return Err(FerrotorchError::InvalidArgument {
            message: "pack_padded_sequence: lengths must be sorted in \
                      descending order when enforce_sorted=true"
                .into(),
        });
    }

    // Stable sort by descending length: ties keep their original batch order.
    let mut sorted_indices: Vec<usize> = (0..batch).collect();
    sorted_indices.sort_by_key(|&idx| std::cmp::Reverse(lengths[idx]));

    // Inverse permutation: where each original batch entry ended up.
    let mut unsorted_indices = vec![0usize; batch];
    for (sorted_pos, &orig) in sorted_indices.iter().enumerate() {
        unsorted_indices[orig] = sorted_pos;
    }

    // `sorted_lengths` is descending, so the sequences still active at step `t`
    // form a prefix; its length is the batch size for that step.
    let sorted_lengths: Vec<usize> = sorted_indices.iter().map(|&idx| lengths[idx]).collect();
    let longest = sorted_lengths[0];
    let batch_sizes: Vec<usize> = (0..longest)
        .map(|t| sorted_lengths.partition_point(|&len| len > t))
        .collect();

    let total_elements: usize = batch_sizes.iter().sum();
    let input_data = input.data()?;

    // Element strides of the batch and time axes in the flat input buffer.
    let (batch_stride, step_stride) = if batch_first {
        (max_seq_len * features, features)
    } else {
        (features, batch * features)
    };

    let mut packed_data: Vec<T> = Vec::with_capacity(total_elements * features);
    for (t, &active) in batch_sizes.iter().enumerate() {
        for &orig in &sorted_indices[..active] {
            let start = orig * batch_stride + t * step_stride;
            packed_data.extend_from_slice(&input_data[start..start + features]);
        }
    }

    let data = Tensor::from_storage(
        TensorStorage::cpu(packed_data),
        vec![total_elements, features],
        input.requires_grad(),
    )?;

    Ok(PackedSequence {
        data,
        batch_sizes,
        sorted_indices,
        unsorted_indices,
    })
}
/// Expands a [`PackedSequence`] back into a padded 3-D tensor.
///
/// The output has shape `[batch, max_seq_len, features]` when `batch_first` is
/// `true` and `[max_seq_len, batch, features]` otherwise, where `batch` is
/// `packed.batch_sizes[0]` and `max_seq_len` is `packed.batch_sizes.len()`.
/// Positions past the end of a sequence are filled with `padding_value`. Rows
/// are written back to their original batch positions via
/// `packed.sorted_indices`. The output's `requires_grad` flag is copied from
/// `packed.data`.
///
/// Returns the padded tensor together with the length of every sequence, in
/// original batch order.
///
/// # Errors
///
/// Returns [`FerrotorchError::InvalidArgument`] when
/// * `batch_sizes` is empty or starts with zero,
/// * `packed.data` is not 2-D,
/// * the first dimension of `packed.data` differs from the sum of `batch_sizes`,
/// * `batch_sizes` is not non-increasing, or
/// * `sorted_indices` is not a permutation of `0..batch`.
///
/// Errors from `packed.data.data()` and `Tensor::from_storage` are propagated.
pub fn pad_packed_sequence<T: Float>(
    packed: &PackedSequence<T>,
    batch_first: bool,
    padding_value: T,
) -> FerrotorchResult<(Tensor<T>, Vec<usize>)> {
    let batch = packed.batch_sizes.first().copied().unwrap_or(0);
    if batch == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "pad_packed_sequence: empty PackedSequence".into(),
        });
    }
    let max_seq_len = packed.batch_sizes.len();
    let packed_data = packed.data.data()?;

    // Check the rank before touching `shape()[..]`; indexing first would panic
    // on a 0-D tensor instead of reporting the error.
    if packed.data.ndim() != 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: "pad_packed_sequence: packed data must be 2-D [total, features]".into(),
        });
    }
    let total_elements = packed.data.shape()[0];
    let features = packed.data.shape()[1];

    let expected_total: usize = packed.batch_sizes.iter().sum();
    if total_elements != expected_total {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "pad_packed_sequence: total elements {} != sum of batch_sizes {}",
                total_elements, expected_total,
            ),
        });
    }

    // The fields of `PackedSequence` are public, so the metadata may have been
    // built by hand. Validate it here so malformed values surface as errors
    // rather than out-of-bounds panics or silently corrupted output.
    if packed.batch_sizes.windows(2).any(|w| w[0] < w[1]) {
        return Err(FerrotorchError::InvalidArgument {
            message: "pad_packed_sequence: batch_sizes must be non-increasing".into(),
        });
    }
    if packed.sorted_indices.len() != batch {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "pad_packed_sequence: sorted_indices.len() ({}) != batch size ({})",
                packed.sorted_indices.len(),
                batch,
            ),
        });
    }
    let mut seen = vec![false; batch];
    for &idx in &packed.sorted_indices {
        if idx >= batch || seen[idx] {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "pad_packed_sequence: sorted_indices must be a permutation of 0..{batch}",
                ),
            });
        }
        seen[idx] = true;
    }

    // A sequence at sorted position `s` is present in every time step whose
    // batch size exceeds `s`, so counting those steps recovers its length.
    let mut sorted_lengths = vec![0usize; batch];
    for &bs in &packed.batch_sizes {
        for len in &mut sorted_lengths[..bs] {
            *len += 1;
        }
    }

    let numel = batch * max_seq_len * features;
    let mut output_data = vec![padding_value; numel];

    // Packed rows are consumed strictly in order; each one is scattered to its
    // (original batch index, time step) slot in the padded buffer.
    let mut data_offset = 0;
    for (t, &bs) in packed.batch_sizes.iter().enumerate() {
        for &orig_batch_idx in &packed.sorted_indices[..bs] {
            let out_offset = if batch_first {
                orig_batch_idx * max_seq_len * features + t * features
            } else {
                t * batch * features + orig_batch_idx * features
            };
            output_data[out_offset..out_offset + features]
                .copy_from_slice(&packed_data[data_offset..data_offset + features]);
            data_offset += features;
        }
    }

    let out_shape = if batch_first {
        vec![batch, max_seq_len, features]
    } else {
        vec![max_seq_len, batch, features]
    };
    let tensor = Tensor::from_storage(
        TensorStorage::cpu(output_data),
        out_shape,
        packed.data.requires_grad(),
    )?;

    // Map the per-sorted-position lengths back to original batch order.
    let mut original_lengths = vec![0usize; batch];
    for (&orig_idx, &len) in packed.sorted_indices.iter().zip(&sorted_lengths) {
        original_lengths[orig_idx] = len;
    }

    Ok((tensor, original_lengths))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `[batch, max_seq_len, features]` f32 tensor whose element at
    /// `(b, t, f)` is `b * 100 + t * 10 + f`, so each value encodes its own
    /// coordinates.
    fn make_test_input(batch: usize, max_seq_len: usize, features: usize) -> Tensor<f32> {
        let values: Vec<f32> = (0..batch)
            .flat_map(|b| {
                (0..max_seq_len).flat_map(move |t| {
                    (0..features).map(move |f| (b * 100 + t * 10 + f) as f32)
                })
            })
            .collect();
        Tensor::from_storage(
            TensorStorage::cpu(values),
            vec![batch, max_seq_len, features],
            false,
        )
        .unwrap()
    }

    /// Lengths `[5, 3, 2]` give per-step batch sizes `[3, 3, 2, 1, 1]`.
    #[test]
    fn test_batch_sizes_5_3_2() {
        let padded = make_test_input(3, 5, 4);
        let packed = pack_padded_sequence(&padded, &[5, 3, 2], true, true).unwrap();
        assert_eq!(packed.batch_sizes, vec![3, 3, 2, 1, 1]);
    }

    /// Pack then pad with `batch_first = true` restores valid steps and fills
    /// the rest with the padding value.
    #[test]
    fn test_round_trip_batch_first() {
        let (batch, max_seq, feat): (usize, usize, usize) = (3, 5, 4);
        let lengths = [5, 3, 2];
        let input = make_test_input(batch, max_seq, feat);
        let input_data = input.data().unwrap().to_vec();

        let packed = pack_padded_sequence(&input, &lengths, true, true).unwrap();
        let (output, out_lengths) = pad_packed_sequence(&packed, true, 0.0f32).unwrap();

        assert_eq!(out_lengths, &[5, 3, 2]);
        assert_eq!(output.shape(), &[batch, max_seq, feat]);

        let output_data = output.data().unwrap();
        for (b, &valid) in lengths.iter().enumerate() {
            for t in 0..max_seq {
                for f in 0..feat {
                    let idx = (b * max_seq + t) * feat + f;
                    if t >= valid {
                        assert_eq!(
                            output_data[idx], 0.0,
                            "expected padding=0.0 at b={b} t={t} f={f}"
                        );
                    } else {
                        assert_eq!(
                            output_data[idx], input_data[idx],
                            "mismatch at b={b} t={t} f={f}"
                        );
                    }
                }
            }
        }
    }

    /// Same round trip for the `[seq, batch, features]` layout with a
    /// non-zero padding value.
    #[test]
    fn test_round_trip_seq_first() {
        let (batch, max_seq, feat): (usize, usize, usize) = (3, 4, 2);
        let lengths = [4, 2, 1];
        let data: Vec<f32> = (0..max_seq)
            .flat_map(|t| {
                (0..batch)
                    .flat_map(move |b| (0..feat).map(move |f| (t * 100 + b * 10 + f) as f32))
            })
            .collect();
        let input = Tensor::from_storage(
            TensorStorage::cpu(data.clone()),
            vec![max_seq, batch, feat],
            false,
        )
        .unwrap();

        let packed = pack_padded_sequence(&input, &lengths, false, true).unwrap();
        let (output, out_lengths) = pad_packed_sequence(&packed, false, -1.0f32).unwrap();

        assert_eq!(out_lengths, &lengths);
        assert_eq!(output.shape(), &[max_seq, batch, feat]);

        let output_data = output.data().unwrap();
        for t in 0..max_seq {
            for (b, &valid) in lengths.iter().enumerate() {
                for f in 0..feat {
                    let idx = (t * batch + b) * feat + f;
                    if t >= valid {
                        assert_eq!(
                            output_data[idx], -1.0,
                            "expected padding=-1.0 at t={t} b={b} f={f}"
                        );
                    } else {
                        assert_eq!(output_data[idx], data[idx], "mismatch at t={t} b={b} f={f}");
                    }
                }
            }
        }
    }

    /// `enforce_sorted = true` rejects lengths that are not descending.
    #[test]
    fn test_enforce_sorted_rejects_unsorted() {
        let input = make_test_input(3, 5, 2);
        let result = pack_padded_sequence(&input, &[2, 5, 3], true, true);
        assert!(result.is_err(), "should reject unsorted lengths");

        let err_msg = result.unwrap_err().to_string();
        assert!(
            err_msg.contains("descending"),
            "error should mention descending: {err_msg}"
        );
    }

    /// `enforce_sorted = false` sorts internally and records the permutation.
    #[test]
    fn test_enforce_sorted_false_accepts_unsorted() {
        let input = make_test_input(3, 5, 2);
        let packed = pack_padded_sequence(&input, &[2, 5, 3], true, false).unwrap();
        assert_eq!(packed.sorted_indices, vec![1, 2, 0]);
        assert_eq!(packed.batch_sizes, vec![3, 3, 2, 1, 1]);
    }

    /// A single full-length sequence packs to exactly the input data.
    #[test]
    fn test_single_sequence() {
        let input = make_test_input(1, 3, 2);
        let packed = pack_padded_sequence(&input, &[3], true, true).unwrap();

        assert_eq!(packed.batch_sizes, vec![1, 1, 1]);
        assert_eq!(packed.sorted_indices, vec![0]);
        assert_eq!(packed.unsorted_indices, vec![0]);

        let packed_flat = packed.data.data().unwrap();
        let input_flat = input.data().unwrap();
        assert_eq!(packed_flat, input_flat);

        let (output, lens) = pad_packed_sequence(&packed, true, 0.0f32).unwrap();
        assert_eq!(lens, vec![3]);
        assert_eq!(output.data().unwrap(), input.data().unwrap());
    }

    /// With equal lengths nothing is padded, so the round trip is the identity.
    #[test]
    fn test_all_same_length() {
        let (batch, seq_len, feat): (usize, usize, usize) = (4, 3, 2);
        let lengths = [3; 4];
        let input = make_test_input(batch, seq_len, feat);

        let packed = pack_padded_sequence(&input, &lengths, true, true).unwrap();
        assert_eq!(packed.batch_sizes, vec![4, 4, 4]);

        let (output, out_lengths) = pad_packed_sequence(&packed, true, 0.0f32).unwrap();
        assert_eq!(out_lengths, vec![3; batch]);
        assert_eq!(output.data().unwrap(), input.data().unwrap());
    }

    /// Packed rows are ordered time step by time step, not sequence by sequence.
    #[test]
    fn test_packed_data_order() {
        let values = vec![10.0f32, 20.0, 30.0, 40.0, 50.0, 0.0];
        let input =
            Tensor::from_storage(TensorStorage::cpu(values), vec![2, 3, 1], false).unwrap();
        let packed = pack_padded_sequence(&input, &[3, 2], true, true).unwrap();

        assert_eq!(packed.batch_sizes, vec![2, 2, 1]);
        let packed_flat = packed.data.data().unwrap();
        assert_eq!(packed_flat, &[10.0, 40.0, 20.0, 50.0, 30.0]);
    }

    /// Fewer lengths than batch entries is an error.
    #[test]
    fn test_error_lengths_mismatch_batch() {
        let input = make_test_input(3, 5, 2);
        assert!(pack_padded_sequence(&input, &[5, 3], true, false).is_err());
    }

    /// A zero length is an error.
    #[test]
    fn test_error_zero_length() {
        let input = make_test_input(2, 3, 2);
        assert!(pack_padded_sequence(&input, &[3, 0], true, false).is_err());
    }

    /// A length beyond the sequence dimension is an error.
    #[test]
    fn test_error_length_exceeds_max() {
        let input = make_test_input(2, 3, 2);
        assert!(pack_padded_sequence(&input, &[3, 4], true, false).is_err());
    }

    /// Only 3-D inputs are accepted.
    #[test]
    fn test_error_non_3d_input() {
        let storage = TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]);
        let input = Tensor::from_storage(storage, vec![2, 2], false).unwrap();
        assert!(pack_padded_sequence(&input, &[2, 1], true, false).is_err());
    }

    /// Unsorted lengths round-trip back to their original batch positions.
    #[test]
    fn test_unsorted_round_trip() {
        let (batch, max_seq, feat): (usize, usize, usize) = (3, 5, 2);
        let lengths = [3, 5, 2];
        let input = make_test_input(batch, max_seq, feat);
        let input_data = input.data().unwrap().to_vec();

        let packed = pack_padded_sequence(&input, &lengths, true, false).unwrap();
        let (output, out_lengths) = pad_packed_sequence(&packed, true, 0.0f32).unwrap();
        assert_eq!(out_lengths, &[3, 5, 2]);

        let output_data = output.data().unwrap();
        for (b, &valid) in lengths.iter().enumerate() {
            for t in 0..max_seq {
                for f in 0..feat {
                    let idx = (b * max_seq + t) * feat + f;
                    if t >= valid {
                        assert_eq!(
                            output_data[idx], 0.0,
                            "expected padding at b={b} t={t} f={f}"
                        );
                    } else {
                        assert_eq!(
                            output_data[idx], input_data[idx],
                            "mismatch at b={b} t={t} f={f}"
                        );
                    }
                }
            }
        }
    }

    /// The functions are generic over the float type; exercise `f64`.
    #[test]
    fn test_f64_pack_unpack() {
        let values = vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let input =
            Tensor::from_storage(TensorStorage::cpu(values), vec![2, 3, 1], false).unwrap();
        let packed = pack_padded_sequence(&input, &[3, 2], true, true).unwrap();
        let (output, lens) = pad_packed_sequence(&packed, true, 0.0f64).unwrap();
        assert_eq!(lens, &[3, 2]);

        // The second sequence has length 2, so its third step is padding.
        let out = output.data().unwrap();
        let expected = [1.0f64, 2.0, 3.0, 4.0, 5.0, 0.0];
        for (i, &want) in expected.iter().enumerate() {
            assert_eq!(out[i], want);
        }
    }
}