use super::serialize::TensorDef;
use crate::torch::tensors::ExclusiveTensor;
use crate::utils::sequence::Sequence;
use ndarray::{azip, ArrayViewMut, Axis, IxDyn, Slice};
use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use std::iter;
use std::iter::{Fuse, FusedIterator};
use std::ops::{AddAssign, Bound, Mul};
use std::rc::Rc;
use tch::{kind::Element, Device, IndexOp, Kind, Tensor};
use thiserror::Error;
/// Errors that can occur when packing sequences into a packed structure or tensor.
#[derive(Error, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum PackingError {
    /// Sequence lengths / batch sizes were not in monotonic decreasing order.
    #[error("sequences lengths or batch sizes increased; should be monotonic decreasing")]
    Increasing,
    /// The input tensor had fewer dimensions than required (`expected`).
    #[error("input tensor has < {expected} dimensions")]
    TooFewDimensions { expected: u8 },
}
/// A tensor of packed sequences.
///
/// The first dimension of `tensor` is the packed dimension: elements are laid
/// out step-major (step 0 of every sequence, then step 1 of every sequence,
/// ...), with `structure` recording how many sequences contribute to each step.
#[must_use]
#[serde_as]
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct PackedTensor {
    // Flat packed data; the first dimension's size equals `structure.len()`.
    #[serde_as(as = "TensorDef")]
    tensor: Tensor,
    // Per-step batch sizes describing the packing of `tensor`'s first dimension.
    structure: PackedStructure,
}
impl Clone for PackedTensor {
    /// Clone by shallow-copying the tensor (the clone shares its storage)
    /// and deep-copying the (cheap) structure.
    fn clone(&self) -> Self {
        let tensor = self.tensor.shallow_clone();
        let structure = self.structure.clone();
        Self { tensor, structure }
    }
}
impl PackedTensor {
#[inline]
pub fn from_parts(tensor: Tensor, structure: PackedStructure) -> Self {
assert_eq!(
structure.len() as i64,
*tensor
.size()
.first()
.expect("tensor must have at least 1 dimension"),
"structure length does not match tensor first dimension size"
);
Self { tensor, structure }
}
pub fn from_aligned_tensor(tensor: &Tensor) -> Result<Self, PackingError> {
let mut size = tensor.size();
if size.len() < 2 {
return Err(PackingError::TooFewDimensions { expected: 2 });
}
let sequence_length = size.remove(0);
let batch_size = size[0];
size[0] *= sequence_length;
Ok(Self {
tensor: tensor.reshape(&size),
structure: PackedStructure::Aligned {
sequence_length: sequence_length.try_into().unwrap(),
batch_size: batch_size.try_into().unwrap(),
},
})
}
#[inline]
pub fn from_sorted_sequences<'a, I, E>(slices: I) -> Result<Self, PackingError>
where
I: IntoIterator<Item = &'a [E]>,
I::IntoIter: Clone,
E: 'a + tch::kind::Element + Copy,
{
let sequences = slices.into_iter();
let structure =
PackedStructure::from_sorted_sequence_lengths(sequences.clone().map(<[E]>::len))?;
let data: Vec<_> = PackedSeqIter::from_sorted(sequences).copied().collect();
let tensor = Tensor::of_slice(&data);
Ok(Self { tensor, structure })
}
#[allow(clippy::missing_const_for_fn)] #[inline]
pub fn into_tensor(self) -> Tensor {
self.tensor
}
#[inline]
pub const fn tensor(&self) -> &Tensor {
&self.tensor
}
#[inline]
pub fn tensor_mut(&mut self) -> &mut Tensor {
&mut self.tensor
}
#[must_use]
#[inline]
pub const fn structure(&self) -> &PackedStructure {
&self.structure
}
#[must_use]
pub fn kind(&self) -> Kind {
self.tensor.kind()
}
#[must_use]
pub fn device(&self) -> Device {
self.tensor.device()
}
pub fn batch_sizes_tensor(&self) -> Tensor {
self.structure.batch_sizes_tensor()
}
#[must_use]
pub fn first_batch_size(&self) -> Option<i64> {
self.structure.first_batch_size()
}
#[inline]
pub fn batch_map<F: FnOnce(Tensor) -> Tensor>(self, f: F) -> Self {
Self {
tensor: f(self.tensor),
structure: self.structure,
}
}
#[inline]
pub fn batch_map_ref<'a, F: FnOnce(&'a Tensor) -> Tensor>(&'a self, f: F) -> Self {
Self {
tensor: f(&self.tensor),
structure: self.structure.clone(),
}
}
pub fn view_trim_start(&self, n: usize) -> Self {
let (to_remove, structure) = match &self.structure {
PackedStructure::Aligned {
sequence_length,
batch_size,
} => {
let n = n.min(*sequence_length);
let to_remove = n * *batch_size;
let new_structure = PackedStructure::Aligned {
sequence_length: *sequence_length - n,
batch_size: *batch_size,
};
(to_remove as i64, new_structure)
}
PackedStructure::Ragged(batch_sizes) => {
let to_remove = batch_sizes.as_slice()[..n].iter().copied().sum();
let new_structure = PackedStructure::Ragged(batch_sizes.clone().trim(n));
(to_remove, new_structure)
}
};
let tensor = self.tensor.i(to_remove..);
Self { tensor, structure }
}
pub fn trim_end(&self, n: usize) -> Self {
match &self.structure {
PackedStructure::Aligned {
sequence_length,
batch_size,
} => {
let n = n.min(*sequence_length);
let tensor = self.tensor.i(..(n * *batch_size) as i64);
let structure = PackedStructure::Aligned {
sequence_length: *sequence_length - n,
batch_size: *batch_size,
};
Self { tensor, structure }
}
PackedStructure::Ragged(batch_sizes) => {
let new_batch_sizes = batch_sizes.clone().trim(n);
let (old_group_sizes, new_group_sizes): (Vec<_>, Vec<_>) =
GroupBatchesForResize::new(
batch_sizes.as_slice().iter().copied(),
new_batch_sizes.as_slice().iter().copied(),
)
.unzip();
let groups = self.tensor.split_with_sizes(&old_group_sizes, 0);
let new_groups: Vec<_> = groups
.iter()
.zip(new_group_sizes)
.map(|(group, new_size)| group.i(..new_size))
.collect();
let new_tensor = Tensor::cat(&new_groups, 0);
Self {
tensor: new_tensor,
structure: PackedStructure::Ragged(new_batch_sizes),
}
}
}
}
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
pub fn discounted_cumsum_from_end<T>(&self, discount: T) -> Self
where
T: Mul + AddAssign<<T as Mul>::Output> + Copy + Element,
{
let mut out = ExclusiveTensor::<T, _>::try_copy_from(self.tensor()).unwrap();
match &self.structure {
PackedStructure::Ragged(batch_sizes) => {
inplace_discounted_cumsum_from_end(
out.array_view_mut(),
discount,
batch_sizes.as_slice().iter().map(|b| *b as usize).rev(),
);
}
PackedStructure::Aligned {
sequence_length,
batch_size,
} => {
inplace_discounted_cumsum_from_end(
out.array_view_mut(),
discount,
iter::repeat(*batch_size).take(*sequence_length),
);
}
}
Self {
tensor: out.into_tensor().to_device(self.tensor.device()),
structure: self.structure.clone(),
}
}
}
/// In-place discounted cumulative sum from the end of packed sequences.
///
/// `array`'s first axis is the packed (step-major) dimension and
/// `rev_batch_sizes` yields the per-step batch sizes in REVERSE order (last
/// step first). For each step, the already-accumulated values of the following
/// step (times `discount`) are added to its leading rows; iterating from the
/// end makes the additions cumulative.
///
/// # Panics
/// If the batch sizes do not sum to the length of `array`'s first axis
/// (checked by the final assertion), or if a later step is larger than an
/// earlier one (the slice below would be out of range).
#[allow(clippy::cast_possible_wrap)]
fn inplace_discounted_cumsum_from_end<I, T>(
    mut array: ArrayViewMut<T, IxDyn>,
    discount: T,
    rev_batch_sizes: I,
) where
    I: IntoIterator<Item = usize>,
    T: Mul + AddAssign<<T as Mul>::Output> + Copy,
{
    // `offset` is the start index of the most recently processed (later) step.
    let mut offset = array.shape()[0];
    for batch_size in rev_batch_sizes {
        // Split off the previously processed step so it can be read while the
        // current step is sliced mutably out of the remaining view.
        let (left, prev_batch) = array.split_at(Axis(0), offset);
        array = left;
        offset -= batch_size;
        let prev_batch_size = prev_batch.shape()[0];
        // Batch sizes are monotonic decreasing over time, so the later step's
        // rows align with the first `prev_batch_size` rows of this step.
        let batch_part = array.slice_axis_mut(
            Axis(0),
            Slice {
                start: offset as isize,
                end: Some((offset + prev_batch_size) as isize),
                step: 1,
            },
        );
        azip!((a in batch_part, b in &prev_batch) *a += *b * discount);
    }
    assert_eq!(
        offset, 0,
        "batch sizes do not match array first dimension length"
    );
}
/// The packing layout of a packed tensor's first dimension.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum PackedStructure {
    /// Variable-length sequences, described by a per-step batch-size list.
    Ragged(SharedBatchSizes),
    /// Equal-length sequences: every step has the same batch size.
    Aligned {
        sequence_length: usize,
        batch_size: usize,
    },
}
impl PackedStructure {
    /// Build a ragged structure from per-step batch sizes.
    ///
    /// # Errors
    /// [`PackingError::Increasing`] if the batch sizes are not monotonic decreasing.
    pub fn from_batch_sizes<I: IntoIterator<Item = usize>>(
        batch_sizes: I,
    ) -> Result<Self, PackingError> {
        SharedBatchSizes::from_batch_sizes(batch_sizes).map(Self::Ragged)
    }

    /// Build a ragged structure from sequence lengths sorted in decreasing order.
    ///
    /// # Errors
    /// [`PackingError::Increasing`] if the lengths increase at any point.
    pub fn from_sorted_sequence_lengths<I: IntoIterator<Item = usize>>(
        lengths: I,
    ) -> Result<Self, PackingError> {
        SharedBatchSizes::from_sorted_sequence_lengths(lengths).map(Self::Ragged)
    }

    /// Batch size of each step as a 1-D `i64` CPU tensor.
    pub fn batch_sizes_tensor(&self) -> Tensor {
        match self {
            Self::Ragged(batch_sizes) => batch_sizes.tensor(),
            Self::Aligned {
                sequence_length,
                batch_size,
            } => Tensor::full(
                &[*sequence_length as i64],
                *batch_size as i64,
                (Kind::Int64, Device::Cpu),
            ),
        }
    }

    /// Batch size of the first step, or `None` if there are no steps.
    #[must_use]
    pub fn first_batch_size(&self) -> Option<i64> {
        match self {
            Self::Ragged(batch_sizes) => batch_sizes.as_slice().first().copied(),
            Self::Aligned {
                sequence_length,
                batch_size,
            } => {
                if *sequence_length > 0 {
                    Some(*batch_size as i64)
                } else {
                    None
                }
            }
        }
    }

    /// Total number of packed elements (the packed dimension length).
    #[must_use]
    pub fn len(&self) -> usize {
        match self {
            Self::Ragged(batch_sizes) => batch_sizes.len(),
            Self::Aligned {
                sequence_length,
                batch_size,
            } => sequence_length * batch_size,
        }
    }

    /// Whether the structure contains no packed elements.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        match self {
            Self::Ragged(batch_sizes) => batch_sizes.is_empty(),
            Self::Aligned {
                sequence_length,
                batch_size,
            } => *sequence_length == 0 || *batch_size == 0,
        }
    }

    /// The structure with the first `n` steps removed.
    #[allow(clippy::missing_const_for_fn)]
    #[must_use]
    pub fn trim(self, n: usize) -> Self {
        match self {
            Self::Ragged(batch_sizes) => Self::Ragged(batch_sizes.trim(n)),
            Self::Aligned {
                sequence_length,
                batch_size,
            } => Self::Aligned {
                sequence_length: sequence_length.saturating_sub(n),
                batch_size,
            },
        }
    }
}
/// A cheaply-cloneable, trimmable view into a shared [`BatchSizes`] list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedBatchSizes {
    // Shared underlying batch-size list.
    root: Rc<BatchSizes>,
    // Start index (inclusive) of the view into `root`.
    start: usize,
    // End index (exclusive) of the view; `None` means to the end of `root`.
    end: Option<usize>,
}
impl AsRef<[i64]> for SharedBatchSizes {
    /// Borrow the viewed batch sizes as a slice.
    #[inline]
    fn as_ref(&self) -> &[i64] {
        self.as_slice()
    }
}

impl<T: AsRef<[i64]>> PartialEq<T> for SharedBatchSizes {
    /// Two batch-size lists are equal when their viewed slices are equal.
    #[inline]
    fn eq(&self, other: &T) -> bool {
        self.as_slice() == other.as_ref()
    }
}

impl Eq for SharedBatchSizes {}
impl SharedBatchSizes {
    /// Create from per-step batch sizes.
    ///
    /// # Errors
    /// [`PackingError::Increasing`] if the batch sizes are not monotonic decreasing.
    pub fn from_batch_sizes<I: IntoIterator<Item = usize>>(
        batch_sizes: I,
    ) -> Result<Self, PackingError> {
        let root = BatchSizes::from_batch_sizes(batch_sizes)?;
        Ok(Self {
            root: Rc::new(root),
            start: 0,
            end: None,
        })
    }

    /// Create from sequence lengths sorted in monotonic decreasing order.
    ///
    /// # Errors
    /// [`PackingError::Increasing`] if the lengths increase at any point.
    pub fn from_sorted_sequence_lengths<I: IntoIterator<Item = usize>>(
        lengths: I,
    ) -> Result<Self, PackingError> {
        let root = BatchSizes::from_sorted_sequence_lengths(lengths)?;
        Ok(Self {
            root: Rc::new(root),
            start: 0,
            end: None,
        })
    }

    /// The viewed batch sizes as a slice.
    #[inline]
    pub fn as_slice(&self) -> &[i64] {
        let range = (
            Bound::Included(self.start),
            match self.end {
                Some(end) => Bound::Excluded(end),
                None => Bound::Unbounded,
            },
        );
        &self.root.as_slice()[range]
    }

    /// The viewed batch sizes as a 1-D `i64` tensor.
    ///
    /// A full view shallow-clones the cached root tensor; a partial view
    /// slices it (sharing storage).
    #[inline]
    pub fn tensor(&self) -> Tensor {
        let root_tensor = self.root.as_tensor();
        if self.start == 0 && self.end.is_none() {
            return root_tensor.shallow_clone();
        }
        root_tensor.slice(0, self.start as i64, self.end.map(|i| i as i64), 1)
    }

    /// Total number of packed elements: the sum of the viewed batch sizes.
    #[must_use]
    pub fn len(&self) -> usize {
        self.as_slice()
            .iter()
            .map(|&x| usize::try_from(x).unwrap())
            .sum()
    }

    /// Whether there are no packed elements (every viewed batch size is zero).
    #[must_use]
    pub fn is_empty(&self) -> bool {
        !self.as_slice().iter().any(|&x| x != 0)
    }

    /// Drop the first `n` steps from the view.
    ///
    /// NOTE(review): `n` is not bounds-checked here; `as_slice` panics later
    /// if the view is advanced past the end — callers must clamp.
    #[must_use]
    pub const fn trim(mut self, n: usize) -> Self {
        self.start += n;
        self
    }
}
/// A validated (monotonic decreasing) list of per-step batch sizes with a
/// lazily-created tensor copy.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct BatchSizes {
    // Batch size of each packed step; monotonic decreasing.
    batch_sizes: Vec<i64>,
    // Lazily-initialized tensor copy of `batch_sizes`; skipped by serde and
    // rebuilt on first use after deserialization.
    #[serde(skip)]
    batch_sizes_tensor: OnceCell<Tensor>,
}
impl AsRef<[i64]> for BatchSizes {
    /// Borrow the batch sizes as a slice.
    #[inline]
    fn as_ref(&self) -> &[i64] {
        &self.batch_sizes
    }
}
impl BatchSizes {
    /// Create from per-step batch sizes.
    ///
    /// # Errors
    /// [`PackingError::Increasing`] if any batch size exceeds the one before it.
    pub fn from_batch_sizes<I: IntoIterator<Item = usize>>(
        batch_sizes: I,
    ) -> Result<Self, PackingError> {
        let mut sizes = Vec::new();
        let mut prev = usize::MAX;
        for size in batch_sizes {
            if size > prev {
                return Err(PackingError::Increasing);
            }
            prev = size;
            sizes.push(size as i64);
        }
        Ok(Self {
            batch_sizes: sizes,
            batch_sizes_tensor: OnceCell::new(),
        })
    }

    /// Create from sequence lengths sorted in monotonic decreasing order.
    ///
    /// The resulting `batch_sizes[t]` is the number of sequences longer than `t`.
    ///
    /// # Errors
    /// [`PackingError::Increasing`] if the lengths increase at any point.
    pub fn from_sorted_sequence_lengths<I: IntoIterator<Item = usize>>(
        lengths: I,
    ) -> Result<Self, PackingError> {
        let mut lengths = lengths.into_iter().enumerate().peekable();
        // The first (largest) length fixes the number of steps.
        let max_seq_len = lengths.peek().map_or(0, |&(_, len)| len);
        let mut batch_sizes = vec![0; max_seq_len];
        while let Some((i, seq_len)) = lengths.next() {
            let next_len = lengths.peek().map_or(0, |&(_, len)| len);
            if next_len > seq_len {
                return Err(PackingError::Increasing);
            }
            // Steps in `next_len..seq_len` are reached by exactly the first
            // `i + 1` sequences.
            batch_sizes[next_len..seq_len].fill((i + 1) as i64);
        }
        Ok(Self {
            batch_sizes,
            batch_sizes_tensor: OnceCell::new(),
        })
    }

    /// The batch sizes as a slice.
    #[inline]
    pub fn as_slice(&self) -> &[i64] {
        &self.batch_sizes
    }

    /// The batch sizes as a 1-D `i64` tensor, created on first use.
    #[inline]
    pub fn as_tensor(&self) -> &Tensor {
        self.batch_sizes_tensor
            .get_or_init(|| Tensor::of_slice(&self.batch_sizes))
    }

    /// Total number of packed elements (the sum of the batch sizes).
    #[inline]
    pub fn len(&self) -> usize {
        self.batch_sizes
            .iter()
            .map(|&x| usize::try_from(x).unwrap())
            .sum()
    }

    /// Whether there are no packed elements (every batch size is zero).
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.batch_sizes.iter().all(|&x| x == 0)
    }
}
/// Pairs an old and a new batch-size list into contiguous groups for resizing
/// (used by `PackedTensor::trim_end`).
///
/// Each yielded item is `(old_group_total, new_group_total)`: a maximal run of
/// steps whose old and new sizes agree, extended by the first step where they
/// differ. Within such a group, the rows to keep form a contiguous prefix.
struct GroupBatchesForResize<A, B> {
    old_batch_sizes: Fuse<A>,
    new_batch_sizes: Fuse<B>,
}
impl<A, B> GroupBatchesForResize<A, B>
where
    A: Iterator,
    B: Iterator,
{
    /// Pair an old and a new batch-size sequence for grouped resizing.
    pub fn new<IA, IB>(old_batch_sizes: IA, new_batch_sizes: IB) -> Self
    where
        IA: IntoIterator<IntoIter = A>,
        IB: IntoIterator<IntoIter = B>,
    {
        // Fusing lets `next` keep polling both sides after either ends.
        let old_batch_sizes = old_batch_sizes.into_iter().fuse();
        let new_batch_sizes = new_batch_sizes.into_iter().fuse();
        Self {
            old_batch_sizes,
            new_batch_sizes,
        }
    }
}
impl<A, B> Iterator for GroupBatchesForResize<A, B>
where
    A: Iterator<Item = i64>,
    B: Iterator<Item = i64>,
{
    /// `(total old size, total new size)` of one group of consecutive steps.
    type Item = (i64, i64);
    fn next(&mut self) -> Option<Self::Item> {
        let mut old_group_size = 0;
        let mut new_group_size = 0;
        loop {
            // Pull one step from each side; an exhausted side counts as 0 and
            // marks the tail, after which everything goes into this group.
            let (old, new, tail) = match (self.old_batch_sizes.next(), self.new_batch_sizes.next())
            {
                (Some(old), Some(new)) => (old, new, false),
                (Some(old), None) => (old, 0, true),
                (None, Some(new)) => (0, new, true),
                (None, None) => break,
            };
            old_group_size += old;
            new_group_size += new,;
            // A step where old != new ends the group (that step is included).
            if !tail && old != new {
                break;
            }
        }
        // (0, 0) means both inputs were already exhausted: iteration is done.
        if (old_group_size, new_group_size) == (0, 0) {
            None
        } else {
            Some((old_group_size, new_group_size))
        }
    }
}
// Once `next` returns `None`, both inner (fused) iterators stay empty, so
// every later call sums to (0, 0) and returns `None` again — a valid
// `FusedIterator`.
impl<A, B> FusedIterator for GroupBatchesForResize<A, B>
where
    A: Iterator<Item = i64>,
    B: Iterator<Item = i64>,
{
}
/// Iterator over sequence elements in packed (step-major) order:
/// step 0 of every sequence, then step 1 of every sequence, and so on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PackedSeqIter<I> {
    // Pristine copy of the sequence iterator, re-cloned at each new step.
    sequences: I,
    // Index of the current step within each sequence.
    offset: usize,
    // Iterator over the sequences still to visit at the current step.
    sequences_iter: I,
}
impl<I> PackedSeqIter<I>
where
    I: Iterator + Clone,
    <I as Iterator>::Item: Sequence,
{
    /// Create from sequences sorted in monotonic decreasing order of length.
    ///
    /// # Panics
    /// If the sequences are not sorted by decreasing length.
    pub fn from_sorted<T: IntoIterator<IntoIter = I>>(into_sequences: T) -> Self {
        let sequences = into_sequences.into_iter();
        // Compare each adjacent pair of lengths to validate the ordering.
        let is_sorted = sequences
            .clone()
            .zip(sequences.clone().skip(1))
            .all(|(a, b)| a.len() >= b.len());
        assert!(
            is_sorted,
            "sequences not in monotonic decreasing order of length"
        );
        Self {
            sequences_iter: sequences.clone(),
            sequences,
            offset: 0,
        }
    }
}
impl<I> Iterator for PackedSeqIter<I>
where
    I: Iterator + Clone,
    <I as Iterator>::Item: Sequence,
{
    type Item = <I::Item as Sequence>::Item;
    fn next(&mut self) -> Option<Self::Item> {
        // Try the next sequence at the current step offset.
        if let Some(value) = self
            .sequences_iter
            .next()
            .and_then(|seq| seq.get(self.offset))
        {
            Some(value)
        } else {
            // The current step is exhausted: since sequences are sorted by
            // decreasing length, once one sequence lacks this offset the rest
            // do too (or we ran out of sequences). Advance to the next step
            // and restart from the first (longest) sequence.
            self.offset += 1;
            self.sequences_iter = self.sequences.clone();
            self.sequences_iter
                .next()
                .and_then(|seq| seq.get(self.offset))
        }
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Elements remaining from the current step onward, as if the current
        // step had not yet started.
        let level_size: usize = self
            .sequences
            .clone()
            .map(|seq| seq.len().saturating_sub(self.offset))
            .take_while(|&size| size > 0)
            .sum();
        let size = if level_size == 0 {
            0
        } else {
            // Subtract the sequences already visited within the current step.
            level_size - (self.sequences.clone().count() - self.sequences_iter.clone().count())
        };
        (size, Some(size))
    }
}
// `size_hint` computes matching exact bounds, so the iterator is exact-size
// whenever the underlying sequence iterator is.
impl<I> ExactSizeIterator for PackedSeqIter<I>
where
    I: ExactSizeIterator + Clone,
    <I as Iterator>::Item: Sequence,
{
}
#[cfg(test)]
mod packed_seq_iter {
    use super::*;

    #[test]
    fn iter() {
        // Three index "sequences" into `data`, longest first.
        let data = [0, 1, 2, 3, 10, 11, 100, 101];
        let ranges = [0..4, 4..6, 6..8];
        let packed: Vec<_> = PackedSeqIter::from_sorted(&ranges)
            .map(|i| data[i])
            .collect();
        // Step-major order: step 0 of each range, then step 1, ...
        assert_eq!(packed, vec![0, 10, 100, 1, 11, 101, 2, 3]);
    }

    #[test]
    fn size_hint() {
        let ranges = [0..4, 4..6, 6..8];
        let iter = PackedSeqIter::from_sorted(&ranges);
        assert_eq!(iter.size_hint(), (8, Some(8)));
    }

    #[test]
    fn size_hint_after_next() {
        let ranges = [0..4, 4..6, 6..8];
        let mut iter = PackedSeqIter::from_sorted(&ranges);
        // The hint must shrink by one per consumed element.
        let _ = iter.next();
        assert_eq!(iter.size_hint(), (7, Some(7)));
        let _ = iter.next();
        assert_eq!(iter.size_hint(), (6, Some(6)));
    }
}
#[cfg(test)]
mod batch_sizes {
    use super::*;

    #[test]
    fn from_sorted() {
        // Lengths [4, 2, 2]: 3 sequences reach steps 0-1, 1 reaches steps 2-3.
        let sizes = BatchSizes::from_sorted_sequence_lengths([4, 2, 2]).unwrap();
        assert_eq!(sizes.batch_sizes, [3, 3, 1, 1]);
    }

    #[test]
    fn from_increasing() {
        // 5 > 4 violates the monotonic-decreasing requirement.
        let err = BatchSizes::from_sorted_sequence_lengths([4, 5, 2]).unwrap_err();
        assert_eq!(err, PackingError::Increasing);
    }
}
#[cfg(test)]
#[allow(clippy::needless_pass_by_value)]
mod packed_tensor {
    use super::*;
    use rstest::{fixture, rstest};

    /// Packed form of the sequences `[0, 1, 2, 3]`, `[10, 11]`, `[100, 101]`.
    #[fixture]
    fn packed_tensor() -> PackedTensor {
        PackedTensor::from_sorted_sequences([&[0, 1, 2, 3] as &[_], &[10, 11], &[100, 101]])
            .unwrap()
    }

    #[test]
    fn from_sorted_sequences() {
        let packed_tensor =
            PackedTensor::from_sorted_sequences([&[0, 1, 2, 3] as &[_], &[10, 11], &[100, 101]])
                .unwrap();
        // Step-major packing: step 0 of every sequence, then step 1, ...
        assert_eq!(
            packed_tensor.tensor(),
            &Tensor::of_slice(&[0, 10, 100, 1, 11, 101, 2, 3])
        );
        assert_eq!(
            packed_tensor.batch_sizes_tensor(),
            Tensor::of_slice(&[3, 3, 1, 1])
        );
    }

    #[rstest]
    fn view_trim_start_n1(packed_tensor: PackedTensor) {
        // Trimming 1 step drops the first element of each sequence.
        let actual = packed_tensor.view_trim_start(1);
        let expected =
            PackedTensor::from_sorted_sequences([&[1, 2, 3] as &[_], &[11], &[101]]).unwrap();
        assert_eq!(actual, expected);
    }

    #[rstest]
    fn view_trim_start_n3(packed_tensor: PackedTensor) {
        // Trimming 3 steps leaves only the longest sequence's last element.
        let actual = packed_tensor.view_trim_start(3);
        let expected = PackedTensor::from_sorted_sequences([&[3] as &[_]]).unwrap();
        assert_eq!(actual, expected);
    }

    #[rstest]
    fn view_trim_start_is_view(packed_tensor: PackedTensor) {
        // Mutating the trimmed view must mutate the original tensor's data
        // (memory is shared); the untrimmed first elements stay unchanged.
        let mut trimmed = packed_tensor.view_trim_start(1);
        let _ = trimmed.tensor.neg_();
        let expected = PackedTensor::from_sorted_sequences([
            &[0, -1, -2, -3] as &[_],
            &[10, -11],
            &[100, -101],
        ])
        .unwrap();
        assert_eq!(packed_tensor, expected);
    }

    #[rstest]
    fn trim_end_n1(packed_tensor: PackedTensor) {
        // Trimming 1 step from the end drops each sequence's last element.
        let actual = packed_tensor.trim_end(1);
        let expected =
            PackedTensor::from_sorted_sequences([&[0, 1, 2] as &[_], &[10], &[100]]).unwrap();
        assert_eq!(actual, expected);
    }

    #[rstest]
    fn trim_end_n3(packed_tensor: PackedTensor) {
        // Trimming 3 steps from the end leaves only the longest sequence's
        // first element; the shorter sequences vanish entirely.
        let actual = packed_tensor.trim_end(3);
        let expected = PackedTensor::from_sorted_sequences([&[0] as &[_]]).unwrap();
        assert_eq!(actual, expected);
    }

    #[rstest]
    fn trim_end_is_copy(packed_tensor: PackedTensor) {
        // Mutating the trimmed result must NOT affect the original tensor.
        let mut trimmed = packed_tensor.trim_end(1);
        let _ = trimmed.tensor.neg_();
        let expected =
            PackedTensor::from_sorted_sequences([&[0, 1, 2, 3] as &[_], &[10, 11], &[100, 101]])
                .unwrap();
        assert_eq!(packed_tensor, expected);
    }

    #[test]
    fn discounted_cumsum_from_end() {
        let packed_tensor = PackedTensor::from_sorted_sequences([
            &[1.0, 2.0, 3.0, 4.0] as &[_],
            &[5.0, 6.0],
            &[7.0, 8.0],
        ])
        .unwrap();
        let cumsum = packed_tensor.discounted_cumsum_from_end(0.1);
        // Per sequence: out[t] = x[t] + 0.1 * out[t + 1];
        // e.g. 1 + 0.1 * (2 + 0.1 * (3 + 0.1 * 4)) = 1.234.
        let expected = PackedTensor::from_sorted_sequences([
            &[1.234, 2.34, 3.4, 4.0] as &[_],
            &[5.6, 6.0],
            &[7.8, 8.0],
        ])
        .unwrap();
        assert_eq!(cumsum.structure, expected.structure);
        // Floating-point results: compare with a tolerance, not exact equality.
        assert!(
            bool::from(
                cumsum
                    .tensor
                    .isclose(&expected.tensor, 1e-8, 1e-8, false)
                    .all()
            ),
            "result: {:?}\nexpected: {:?}",
            cumsum,
            expected,
        );
    }

    #[rstest]
    fn batch_sizes_tensor_values(packed_tensor: PackedTensor) {
        let actual = packed_tensor.structure.batch_sizes_tensor();
        let expected = Tensor::of_slice(&[3, 3, 1, 1]);
        assert_eq!(actual, expected);
    }

    #[rstest]
    fn batch_sizes_tensor_device_cpu(packed_tensor: PackedTensor) {
        // Batch-size tensors are always created on the CPU.
        let batch_sizes = packed_tensor.structure.batch_sizes_tensor();
        assert_eq!(batch_sizes.device(), tch::Device::Cpu);
    }
}