relearn/torch/
packed.rs

1//! Packed Tensors
2use super::serialize::TensorDef;
3use crate::torch::tensors::ExclusiveTensor;
4use crate::utils::sequence::Sequence;
5use ndarray::{azip, ArrayViewMut, Axis, IxDyn, Slice};
6use once_cell::sync::OnceCell;
7use serde::{Deserialize, Serialize};
8use serde_with::serde_as;
9use std::iter;
10use std::iter::{Fuse, FusedIterator};
11use std::ops::{AddAssign, Bound, Mul};
12use std::rc::Rc;
13use tch::{kind::Element, Device, IndexOp, Kind, Tensor};
14use thiserror::Error;
15
16/// Error involving packing data.
17#[derive(Error, Debug, Copy, Clone, PartialEq, Eq, Hash)]
18pub enum PackingError {
19    #[error("sequences lengths or batch sizes increased; should be monotonic decreasing")]
20    Increasing,
21    #[error("input tensor has < {expected} dimensions")]
22    TooFewDimensions { expected: u8 },
23}
24
/// A packed tensor.
///
/// A packed tensor represents a set of heterogeneous-length sequences.
/// The sequences are arranged along the first dimension of the tensor and are stored interleaved:
/// the first steps from all sequences followed by the second steps, etc.
///
/// The sequences are packed in order from longest to shortest.
///
/// For example, the sequences `[0, 1, 2, 3]`, `[10, 11]`, `[100, 101]` are packed as
/// `[0, 10, 100, 1, 11, 101, 2, 3]`.
#[must_use]
#[serde_as]
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct PackedTensor {
    /// The packed tensor data. Must have at least one dimension.
    ///
    /// Invariant: the length of the first dimension equals `structure.len()`
    /// (checked by `from_parts`; other constructors maintain it).
    #[serde_as(as = "TensorDef")]
    tensor: Tensor,
    /// The packed structure of `tensor`.
    structure: PackedStructure,
}
45
46impl Clone for PackedTensor {
47    fn clone(&self) -> Self {
48        Self {
49            tensor: self.tensor.shallow_clone(),
50            structure: self.structure.clone(),
51        }
52    }
53}
54
55impl PackedTensor {
56    /// Construct from a packed data [`Tensor`] and a `PackedStructure` describing it.
57    ///
58    /// # Panics
59    /// * If the tensor is 0-dimensional.
60    /// * If the length of the first dimension does not match `structure.len()`.
61    #[inline]
62    pub fn from_parts(tensor: Tensor, structure: PackedStructure) -> Self {
63        assert_eq!(
64            structure.len() as i64,
65            *tensor
66                .size()
67                .first()
68                .expect("tensor must have at least 1 dimension"),
69            "structure length does not match tensor first dimension size"
70        );
71        Self { tensor, structure }
72    }
73
74    /// Construct from an aligned tensor with equal-length sequences.
75    ///
76    /// # Inputs
77    /// * `tensor`: A tensor with shape `[SEQUENCE_LEN, NUM_SEQUENCES, ...]`.
78    ///
79    /// # Returns
80    /// Returns an error if the input tensor has less than 2 dimensions.
81    pub fn from_aligned_tensor(tensor: &Tensor) -> Result<Self, PackingError> {
82        let mut size = tensor.size();
83        if size.len() < 2 {
84            return Err(PackingError::TooFewDimensions { expected: 2 });
85        }
86        let sequence_length = size.remove(0);
87        let batch_size = size[0];
88
89        size[0] *= sequence_length;
90        Ok(Self {
91            tensor: tensor.reshape(&size),
92            structure: PackedStructure::Aligned {
93                sequence_length: sequence_length.try_into().unwrap(),
94                batch_size: batch_size.try_into().unwrap(),
95            },
96        })
97    }
98
99    /// Construct a 1D packed tensor from slices sorted in monotonic decreasing order of length.
100    ///
101    /// Returns an error if any slice is longer than the sequence before it.
102    #[inline]
103    pub fn from_sorted_sequences<'a, I, E>(slices: I) -> Result<Self, PackingError>
104    where
105        I: IntoIterator<Item = &'a [E]>,
106        I::IntoIter: Clone,
107        E: 'a + tch::kind::Element + Copy,
108    {
109        let sequences = slices.into_iter();
110        let structure =
111            PackedStructure::from_sorted_sequence_lengths(sequences.clone().map(<[E]>::len))?;
112        let data: Vec<_> = PackedSeqIter::from_sorted(sequences).copied().collect();
113        let tensor = Tensor::of_slice(&data);
114        Ok(Self { tensor, structure })
115    }
116
117    /// Convert into the underlying packed [`Tensor`] object.
118    #[allow(clippy::missing_const_for_fn)] // false positive
119    #[inline]
120    pub fn into_tensor(self) -> Tensor {
121        self.tensor
122    }
123
124    /// Reference the underlying packed [`Tensor`] object.
125    #[inline]
126    pub const fn tensor(&self) -> &Tensor {
127        &self.tensor
128    }
129
130    /// Mutably reference the underlying packed [`Tensor`] object.
131    #[inline]
132    pub fn tensor_mut(&mut self) -> &mut Tensor {
133        &mut self.tensor
134    }
135
136    /// Reference the packed structure.
137    #[must_use]
138    #[inline]
139    pub const fn structure(&self) -> &PackedStructure {
140        &self.structure
141    }
142
143    /// The tensor [`Kind`] (data type).
144    #[must_use]
145    pub fn kind(&self) -> Kind {
146        self.tensor.kind()
147    }
148
149    /// The tensor [`Device`].
150    #[must_use]
151    pub fn device(&self) -> Device {
152        self.tensor.device()
153    }
154
155    /// A [`Tensor`] with the packed batch sizes.
156    ///
157    /// Has type `i64` and is on the CPU device.
158    pub fn batch_sizes_tensor(&self) -> Tensor {
159        self.structure.batch_sizes_tensor()
160    }
161
162    /// Batch size of the first step if any. The largest batch size.
163    #[must_use]
164    pub fn first_batch_size(&self) -> Option<i64> {
165        self.structure.first_batch_size()
166    }
167
168    /// Transform the stored tensor with a function that preserves the packed sequence structure.
169    ///
170    /// The sequence structure is stored along the first dimension of the tensor. This dimension
171    /// must be preserved; its length must not change and its semantics should be preserved.
172    /// Other methods may panic if the function changes the sequence structure.
173    #[inline]
174    pub fn batch_map<F: FnOnce(Tensor) -> Tensor>(self, f: F) -> Self {
175        Self {
176            tensor: f(self.tensor),
177            structure: self.structure,
178        }
179    }
180
181    /// Map the stored tensor with a function that preserves the packed sequence structure.
182    ///
183    /// The sequence structure is stored along the first dimension of the tensor. This dimension
184    /// must be preserved; its length must not change and its semantics should be preserved.
185    /// Other methods may panic if the function changes the sequence structure.
186    #[inline]
187    pub fn batch_map_ref<'a, F: FnOnce(&'a Tensor) -> Tensor>(&'a self, f: F) -> Self {
188        Self {
189            tensor: f(&self.tensor),
190            structure: self.structure.clone(),
191        }
192    }
193
194    /// View the packed tensor with the first `n` items removed from each sequence.
195    pub fn view_trim_start(&self, n: usize) -> Self {
196        let (to_remove, structure) = match &self.structure {
197            PackedStructure::Aligned {
198                sequence_length,
199                batch_size,
200            } => {
201                let n = n.min(*sequence_length);
202                let to_remove = n * *batch_size;
203                let new_structure = PackedStructure::Aligned {
204                    sequence_length: *sequence_length - n,
205                    batch_size: *batch_size,
206                };
207                (to_remove as i64, new_structure)
208            }
209            PackedStructure::Ragged(batch_sizes) => {
210                let to_remove = batch_sizes.as_slice()[..n].iter().copied().sum();
211                let new_structure = PackedStructure::Ragged(batch_sizes.clone().trim(n));
212                (to_remove, new_structure)
213            }
214        };
215        let tensor = self.tensor.i(to_remove..);
216
217        Self { tensor, structure }
218    }
219
220    /// Copy a packed tensor without the last `n` elements of each sequence.
221    ///
222    /// Sequences with `n` or fewer elements are removed.
223    pub fn trim_end(&self, n: usize) -> Self {
224        match &self.structure {
225            PackedStructure::Aligned {
226                sequence_length,
227                batch_size,
228            } => {
229                let n = n.min(*sequence_length);
230                let tensor = self.tensor.i(..(n * *batch_size) as i64);
231                let structure = PackedStructure::Aligned {
232                    sequence_length: *sequence_length - n,
233                    batch_size: *batch_size,
234                };
235                Self { tensor, structure }
236            }
237            PackedStructure::Ragged(batch_sizes) => {
238                let new_batch_sizes = batch_sizes.clone().trim(n);
239                let (old_group_sizes, new_group_sizes): (Vec<_>, Vec<_>) =
240                    GroupBatchesForResize::new(
241                        batch_sizes.as_slice().iter().copied(),
242                        new_batch_sizes.as_slice().iter().copied(),
243                    )
244                    .unzip();
245
246                // Split the tensor into groups based on the old sizes
247                let groups = self.tensor.split_with_sizes(&old_group_sizes, 0);
248
249                // Resize each group into the new group size
250                // If batch_sizes is monotonic decreasing (which it should be) then each new group
251                // size will be less than or equal to the old group size.
252                let new_groups: Vec<_> = groups
253                    .iter()
254                    .zip(new_group_sizes)
255                    .map(|(group, new_size)| group.i(..new_size))
256                    .collect();
257
258                // Collect the groups back into a single tensor
259                let new_tensor = Tensor::cat(&new_groups, 0);
260
261                Self {
262                    tensor: new_tensor,
263                    structure: PackedStructure::Ragged(new_batch_sizes),
264                }
265            }
266        }
267    }
268
269    /// Discounted cumulative sum from sequence end to start for a tensor.
270    ///
271    /// For each element `x[i]` in the sequence `x[0] ... x[N]`,
272    /// returns `y[i] = sum_{j in i..N} discount ** (j - i) * x[j]`
273    ///
274    /// # Warning
275    /// Does not preserve gradients.
276    ///
277    /// # Panics
278    /// If the type of `discount` does not match the tensor data type.
279    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
280    pub fn discounted_cumsum_from_end<T>(&self, discount: T) -> Self
281    where
282        T: Mul + AddAssign<<T as Mul>::Output> + Copy + Element,
283    {
284        let mut out = ExclusiveTensor::<T, _>::try_copy_from(self.tensor()).unwrap();
285        match &self.structure {
286            PackedStructure::Ragged(batch_sizes) => {
287                inplace_discounted_cumsum_from_end(
288                    out.array_view_mut(),
289                    discount,
290                    batch_sizes.as_slice().iter().map(|b| *b as usize).rev(),
291                );
292            }
293            PackedStructure::Aligned {
294                sequence_length,
295                batch_size,
296            } => {
297                inplace_discounted_cumsum_from_end(
298                    out.array_view_mut(),
299                    discount,
300                    iter::repeat(*batch_size).take(*sequence_length),
301                );
302            }
303        }
304        Self {
305            tensor: out.into_tensor().to_device(self.tensor.device()),
306            structure: self.structure.clone(),
307        }
308    }
309}
310
/// In-place discounted cumulative sum from sequence end to start over packed data.
///
/// `array` holds the packed data with steps along the first axis and `rev_batch_sizes`
/// yields the packed batch sizes from the last step to the first.
/// Each step accumulates `discount` times the following step's values for the sequences
/// that are present in both steps.
///
/// # Panics
/// * If `array` is 0-dimensional.
/// * If the batch sizes do not sum to the first dimension length of `array`.
#[allow(clippy::cast_possible_wrap)]
fn inplace_discounted_cumsum_from_end<I, T>(
    mut array: ArrayViewMut<T, IxDyn>,
    discount: T,
    rev_batch_sizes: I, // Batch sizes in reverse order
) where
    I: IntoIterator<Item = usize>,
    T: Mul + AddAssign<<T as Mul>::Output> + Copy,
{
    // Everything to the right of offset is complete, to the left is incomplete
    let mut offset = array.shape()[0]; // Panics if array is 0-dimensional
    for batch_size in rev_batch_sizes {
        // Split off the most recently completed batch (the step after the current one).
        // On the first iteration `prev_batch` is empty: the final step has no successor.
        let (left, prev_batch) = array.split_at(Axis(0), offset);
        array = left;
        offset -= batch_size;

        // Only the first `prev_batch_size` sequences of the current batch continue into the
        // next step (batch sizes are monotonic decreasing in forward order).
        let prev_batch_size = prev_batch.shape()[0];
        let batch_part = array.slice_axis_mut(
            Axis(0),
            Slice {
                start: offset as isize,
                end: Some((offset + prev_batch_size) as isize),
                step: 1,
            },
        );
        // y[i] += discount * y[i + 1] for each sequence present at both steps
        azip!((a in batch_part, b in &prev_batch) *a += *b * discount);
    }
    assert_eq!(
        offset, 0,
        "batch sizes do not match array first dimension length"
    );
}
343
/// Information about a packed tensor structure.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum PackedStructure {
    /// Heterogeneous sequence lengths, described by per-step batch sizes.
    Ragged(SharedBatchSizes),
    /// All sequences have the same length.
    Aligned {
        /// Length of every sequence
        sequence_length: usize,
        /// Number of sequences
        batch_size: usize,
    },
}
356
357impl PackedStructure {
358    /// Construct from an iterator of monotonic decreasing batch sizes.
359    ///
360    /// Returns an error if any batch size is greater than the previous batch size.
361    pub fn from_batch_sizes<I: IntoIterator<Item = usize>>(
362        batch_sizes: I,
363    ) -> Result<Self, PackingError> {
364        Ok(Self::Ragged(SharedBatchSizes::from_batch_sizes(
365            batch_sizes,
366        )?))
367    }
368
369    /// Construct from an iterator of monotonic decreasing sequence lengths.
370    ///
371    /// Returns an error if any length is greater than the previous length.
372    pub fn from_sorted_sequence_lengths<I: IntoIterator<Item = usize>>(
373        lengths: I,
374    ) -> Result<Self, PackingError> {
375        Ok(Self::Ragged(
376            SharedBatchSizes::from_sorted_sequence_lengths(lengths)?,
377        ))
378    }
379
380    /// A [`Tensor`] with the packed batch sizes.
381    ///
382    /// Has type `i64` and is on the CPU device.
383    pub fn batch_sizes_tensor(&self) -> Tensor {
384        match self {
385            Self::Ragged(batch_sizes) => batch_sizes.tensor(),
386            Self::Aligned {
387                sequence_length,
388                batch_size,
389            } => Tensor::full(
390                &[*sequence_length as i64],
391                *batch_size as i64,
392                (Kind::Int64, Device::Cpu),
393            ),
394        }
395    }
396
397    /// Batch size of the first step if any. The largest batch size.
398    #[must_use]
399    pub fn first_batch_size(&self) -> Option<i64> {
400        match self {
401            Self::Ragged(batch_sizes) => batch_sizes.as_slice().first().copied(),
402            Self::Aligned {
403                sequence_length,
404                batch_size,
405            } => {
406                if *sequence_length > 0 {
407                    Some(*batch_size as _)
408                } else {
409                    None
410                }
411            }
412        }
413    }
414
415    /// The total number of elements across all sequences represented by this structure.
416    #[must_use]
417    pub fn len(&self) -> usize {
418        match self {
419            Self::Ragged(batch_sizes) => batch_sizes.len(),
420            Self::Aligned {
421                sequence_length,
422                batch_size,
423            } => sequence_length * batch_size,
424        }
425    }
426
427    /// Whether the total number of elements across all sequences is zero.
428    #[must_use]
429    pub fn is_empty(&self) -> bool {
430        match self {
431            Self::Ragged(batch_sizes) => batch_sizes.is_empty(),
432            Self::Aligned {
433                sequence_length,
434                batch_size,
435            } => *sequence_length == 0 || *batch_size == 0,
436        }
437    }
438
439    /// Structure resulting from removing `n` items from each sequence.
440    ///
441    /// Any sequences with length less than `n` are reduced to length `0`.
442    #[allow(clippy::missing_const_for_fn)] // false positive
443    #[must_use]
444    pub fn trim(self, n: usize) -> Self {
445        match self {
446            Self::Ragged(batch_sizes) => Self::Ragged(batch_sizes.trim(n)),
447            Self::Aligned {
448                sequence_length,
449                batch_size,
450            } => Self::Aligned {
451                sequence_length: sequence_length.saturating_sub(n),
452                batch_size,
453            },
454        }
455    }
456}
457
/// Slice of a reference-counted batch sizes vector.
///
/// The value of the `i`th index of the batch sizes vector is the number of sequences with length
/// at least `i + 1`. It is the number of index-`i` steps that appear in the packed tensor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedBatchSizes {
    /// Shared full batch-sizes vector; `start`/`end` select the visible sub-slice.
    root: Rc<BatchSizes>,
    start: usize,       // inclusive
    end: Option<usize>, // exclusive; `None` means "to the end of `root`"
}
468
impl AsRef<[i64]> for SharedBatchSizes {
    // Delegates to `as_slice`, exposing only the `start..end` window of `root`.
    #[inline]
    fn as_ref(&self) -> &[i64] {
        self.as_slice()
    }
}
475
// Compare by visible slice contents; allows comparison against any `[i64]`-like value.
impl<T: AsRef<[i64]>> PartialEq<T> for SharedBatchSizes {
    #[inline]
    fn eq(&self, other: &T) -> bool {
        self.as_ref() == other.as_ref()
    }
}

// Slice equality is reflexive, so the `Eq` contract holds.
impl Eq for SharedBatchSizes {}
484
485impl SharedBatchSizes {
486    /// Construct from an iterator of monotonic decreasing batch sizes.
487    ///
488    /// Returns an error if any batch size is greater than the previous batch size.
489    pub fn from_batch_sizes<I: IntoIterator<Item = usize>>(
490        batch_sizes: I,
491    ) -> Result<Self, PackingError> {
492        Ok(Self {
493            root: Rc::new(BatchSizes::from_batch_sizes(batch_sizes)?),
494            start: 0,
495            end: None,
496        })
497    }
498
499    /// Construct from an iterator of monotonic decreasing sequence lengths.
500    ///
501    /// Returns an error if any length is greater than the previous length.
502    pub fn from_sorted_sequence_lengths<I: IntoIterator<Item = usize>>(
503        lengths: I,
504    ) -> Result<Self, PackingError> {
505        Ok(Self {
506            root: Rc::new(BatchSizes::from_sorted_sequence_lengths(lengths)?),
507            start: 0,
508            end: None,
509        })
510    }
511
512    /// View batch sizes as a slice
513    #[inline]
514    pub fn as_slice(&self) -> &[i64] {
515        let start = Bound::Included(self.start);
516        let end = self.end.map_or(Bound::Unbounded, Bound::Excluded);
517        &self.root.as_slice()[(start, end)]
518    }
519
520    /// Batch size as an `i64` `[Tensor]`. The underlying data is cached and shared between calls.
521    #[inline]
522    pub fn tensor(&self) -> Tensor {
523        let root_tensor = self.root.as_tensor();
524
525        if self.start == 0 && self.end.is_none() {
526            root_tensor.shallow_clone()
527        } else {
528            let end = self.end.map(|i| i as i64);
529            root_tensor.slice(0, self.start as i64, end, 1)
530        }
531    }
532
533    /// The total number of elements across all sequences represented by this structure.
534    #[must_use]
535    pub fn len(&self) -> usize {
536        self.as_slice()
537            .iter()
538            // Batch sizes are initalized from usize and should still fit
539            .map(|x| usize::try_from(*x).unwrap())
540            .sum()
541    }
542
543    /// Whether the total number of elements across all sequences is zero.
544    #[must_use]
545    pub fn is_empty(&self) -> bool {
546        self.as_slice().iter().all(|x| *x == 0)
547    }
548
549    /// Batch sizes resulting from removing `n` values from each sequence.
550    #[must_use]
551    pub const fn trim(mut self, n: usize) -> Self {
552        self.start += n;
553        self
554    }
555}
556
/// Owned batch-sizes vector with a lazily-cached tensor form.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct BatchSizes {
    /// Batch sizes. Non-negative and monotonic decreasing.
    ///
    /// These would be `usize` except `tch` expects `i64`.
    batch_sizes: Vec<i64>,

    /// Cached `batch_sizes` as a non-negative i64 tensor on the CPU device.
    /// Created on first use in `as_tensor`; skipped when serializing.
    #[serde(skip)]
    batch_sizes_tensor: OnceCell<Tensor>,
}
568
impl AsRef<[i64]> for BatchSizes {
    // Expose the full batch-sizes vector as a slice.
    #[inline]
    fn as_ref(&self) -> &[i64] {
        self.as_slice()
    }
}
575
576impl BatchSizes {
577    /// Construct from an iterator of monotonic decreasing batch sizes.
578    ///
579    /// Returns an error if any batch size is greater than the previous batch size.
580    pub fn from_batch_sizes<I: IntoIterator<Item = usize>>(
581        batch_sizes: I,
582    ) -> Result<Self, PackingError> {
583        let mut prev = usize::MAX;
584        let batch_sizes: Vec<_> = batch_sizes
585            .into_iter()
586            .map(|x| {
587                if x > prev {
588                    Err(PackingError::Increasing)
589                } else {
590                    prev = x;
591                    Ok(x as i64)
592                }
593            })
594            .collect::<Result<_, _>>()?;
595        Ok(Self {
596            batch_sizes,
597            batch_sizes_tensor: OnceCell::new(),
598        })
599    }
600
601    /// Construct from an iterator of monotonic decreasing sequence lengths.
602    ///
603    /// Returns an an error if any length is greater than the previous length.
604    pub fn from_sorted_sequence_lengths<I: IntoIterator<Item = usize>>(
605        lengths: I,
606    ) -> Result<Self, PackingError> {
607        let mut lengths = lengths.into_iter().enumerate().peekable();
608
609        let (_, max_seq_len) = lengths.peek().copied().unwrap_or((0, 0));
610        let mut batch_sizes = vec![0; max_seq_len];
611
612        while let Some((i, seq_len)) = lengths.next() {
613            // `batch_size = i + 1` sequences have length at least `seq_len`.
614            // Record this as the batch size for all lengths down to the length of the next seq
615            let (_, next_len) = lengths.peek().copied().unwrap_or((0, 0));
616            if next_len > seq_len {
617                return Err(PackingError::Increasing);
618            }
619            batch_sizes[next_len..seq_len].fill((i + 1) as i64);
620        }
621        Ok(Self {
622            batch_sizes,
623            batch_sizes_tensor: OnceCell::new(),
624        })
625    }
626
627    /// View batch sizes as a slice
628    #[inline]
629    pub fn as_slice(&self) -> &[i64] {
630        self.batch_sizes.as_slice()
631    }
632
633    /// View batch sizes as an `i64` `[Tensor]` (cached).
634    #[inline]
635    pub fn as_tensor(&self) -> &Tensor {
636        self.batch_sizes_tensor
637            .get_or_init(|| Tensor::of_slice(&self.batch_sizes))
638    }
639
640    /// The total number of elements across all sequences represented by this structure.
641    #[inline]
642    pub fn len(&self) -> usize {
643        self.batch_sizes
644            .iter()
645            // Batch sizes are initalized from usize and should still fit
646            .map(|x| usize::try_from(*x).unwrap())
647            .sum()
648    }
649
650    /// Whether the total number of elements across all sequences is zero.
651    #[inline]
652    pub fn is_empty(&self) -> bool {
653        self.batch_sizes.iter().all(|x| *x == 0)
654    }
655}
656
/// Collect batches into groups where only the last batch in a group changes size.
///
/// This is designed for resizes that affect only the end of a batch.
///
/// If `old_batch_sizes` and `new_batch_sizes` have different lengths then any excess are grouped
/// together and it is assumed that the corresponding new/old batch size is 0.
///
/// # Args
/// * `old_batch_sizes` - An iterator over the old/original batch sizes.
/// * `new_batch_sizes` - An iterator over the new batch sizes.
///
/// # Items
/// Yields pairs `(old_group_size, new_group_size)` each representing a group of elements to be
/// resized only at the end of the group.
struct GroupBatchesForResize<A, B> {
    old_batch_sizes: Fuse<A>,
    new_batch_sizes: Fuse<B>,
}

impl<A, B> GroupBatchesForResize<A, B>
where
    A: Iterator,
    B: Iterator,
{
    /// Create a grouping iterator over old and new batch sizes.
    pub fn new<IA, IB>(old_batch_sizes: IA, new_batch_sizes: IB) -> Self
    where
        IA: IntoIterator<IntoIter = A>,
        IB: IntoIterator<IntoIter = B>,
    {
        // Fuse both iterators so an exhausted one keeps yielding `None` in `next`.
        Self {
            old_batch_sizes: old_batch_sizes.into_iter().fuse(),
            new_batch_sizes: new_batch_sizes.into_iter().fuse(),
        }
    }
}

impl<A, B> Iterator for GroupBatchesForResize<A, B>
where
    A: Iterator<Item = i64>,
    B: Iterator<Item = i64>,
{
    type Item = (i64, i64);

    fn next(&mut self) -> Option<Self::Item> {
        // Running totals for the current group.
        // They stay equal batch-by-batch until a size change ends the group.
        let mut old_total = 0;
        let mut new_total = 0;
        loop {
            let old_item = self.old_batch_sizes.next();
            let new_item = self.new_batch_sizes.next();
            if old_item.is_none() && new_item.is_none() {
                break;
            }
            // In the tail (exactly one side exhausted) the missing size counts as 0.
            let in_tail = old_item.is_none() || new_item.is_none();
            let old = old_item.unwrap_or(0);
            let new = new_item.unwrap_or(0);
            old_total += old;
            new_total += new;

            // End the group when the sizes differ.
            // Consecutive tail batches stay merged: the whole tail is added/removed, so
            // the change is still a suffix operation on the group, just one that spans
            // multiple batches.
            if !in_tail && old != new {
                break;
            }
        }
        if old_total == 0 && new_total == 0 {
            None
        } else {
            Some((old_total, new_total))
        }
    }
}

impl<A, B> FusedIterator for GroupBatchesForResize<A, B>
where
    A: Iterator<Item = i64>,
    B: Iterator<Item = i64>,
{
}
739
/// Iterator that packs together the elements of multiple sequences.
///
/// Does not allocate any heap memory.
///
/// # Example
/// ```
/// use relearn::torch::packed::PackedSeqIter;
///
/// let sequences: [&[_]; 3] = [&[0, 1, 2, 3], &[10, 11], &[100, 101]];
/// let packed: Vec<_> = PackedSeqIter::from_sorted(&sequences).copied().collect();
/// assert_eq!(packed, vec![0, 10, 100, 1, 11, 101, 2, 3]);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PackedSeqIter<I> {
    /// Initial copy of the sequence iterator. Never modified.
    sequences: I,

    /// Current offset within the sequences.
    /// Every element before this offset in every sequence has already been yielded.
    offset: usize,
    /// Iterator of sequences for the current offset.
    /// Reset from a clone of `sequences` each time `offset` advances.
    sequences_iter: I,
}
762
763impl<I> PackedSeqIter<I>
764where
765    I: Iterator + Clone,
766    <I as Iterator>::Item: Sequence,
767{
768    /// Initialize from sequences sorted in monotonic decreasing order of length.
769    pub fn from_sorted<T: IntoIterator<IntoIter = I>>(into_sequences: T) -> Self {
770        let sequences = into_sequences.into_iter();
771        assert!(
772            sequences
773                .clone()
774                .zip(sequences.clone().skip(1))
775                .all(|(a, b)| a.len() >= b.len()),
776            "sequences not in monotonic decreasing order of length"
777        );
778        let sequences_iter = sequences.clone();
779        Self {
780            sequences,
781            offset: 0,
782            sequences_iter,
783        }
784    }
785}
786
impl<I> Iterator for PackedSeqIter<I>
where
    I: Iterator + Clone,
    <I as Iterator>::Item: Sequence,
{
    type Item = <I::Item as Sequence>::Item;

    fn next(&mut self) -> Option<Self::Item> {
        // Try the next sequence at the current offset.
        if let Some(value) = self
            .sequences_iter
            .next()
            .and_then(|seq| seq.get(self.offset))
        {
            Some(value)
        } else {
            // Increment offset and restart the loop
            self.offset += 1;
            self.sequences_iter = self.sequences.clone();
            // If this fails then there are no more items left:
            // sequences are sorted longest-first, so if the first sequence has no element
            // at `offset` then no sequence does.
            self.sequences_iter
                .next()
                .and_then(|seq| seq.get(self.offset))
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Total size of all elements at index `offset` or later.
        // `take_while` is valid because lengths are monotonic decreasing.
        let level_size: usize = self
            .sequences
            .clone()
            .map(|seq| seq.len().saturating_sub(self.offset))
            .take_while(|&size| size > 0)
            .sum();
        let size = if level_size == 0 {
            // This is handled because sequences_iter ends up incorrectly indicating that one
            // element has been emitted when the iterator has been fully exhausted.
            0
        } else {
            // Subtract the number of elements emitted so far at this offset level.
            // The count difference between the pristine and partially-consumed iterators
            // is how many sequences have been visited at the current offset.
            // NOTE: Could use size_hint() for these but already iterating over sequences to
            // calculate level_size.
            level_size - (self.sequences.clone().count() - self.sequences_iter.clone().count())
        };
        (size, Some(size))
    }
}
833
// `size_hint` returns identical lower and upper bounds, satisfying the
// `ExactSizeIterator` contract.
impl<I> ExactSizeIterator for PackedSeqIter<I>
where
    I: ExactSizeIterator + Clone,
    <I as Iterator>::Item: Sequence,
{
}
840
#[cfg(test)]
mod packed_seq_iter {
    use super::*;

    #[test]
    fn iter() {
        // Sequences of indices into `data`: 0..4, 4..6, 6..8.
        let data = [0, 1, 2, 3, 10, 11, 100, 101];
        let ranges = [0..4, 4..6, 6..8];
        let packed: Vec<_> = PackedSeqIter::from_sorted(&ranges)
            .map(|i| data[i])
            .collect();
        assert_eq!(packed, vec![0, 10, 100, 1, 11, 101, 2, 3]);
    }

    #[test]
    fn size_hint() {
        // 4 + 2 + 2 = 8 elements in total.
        let ranges = [0..4, 4..6, 6..8];
        let iter = PackedSeqIter::from_sorted(&ranges);
        assert_eq!(iter.size_hint(), (8, Some(8)));
    }

    #[test]
    fn size_hint_after_next() {
        // The hint shrinks by one for each element emitted.
        let ranges = [0..4, 4..6, 6..8];
        let mut iter = PackedSeqIter::from_sorted(&ranges);
        let _ = iter.next();
        assert_eq!(iter.size_hint(), (7, Some(7)));
        let _ = iter.next();
        assert_eq!(iter.size_hint(), (6, Some(6)));
    }
}
873
#[cfg(test)]
mod batch_sizes {
    use super::*;

    #[test]
    fn from_sorted() {
        // Lengths 4, 2, 2 => 3 sequences at steps 0-1, 1 sequence at steps 2-3.
        let result = BatchSizes::from_sorted_sequence_lengths([4, 2, 2]).unwrap();
        assert_eq!(result.batch_sizes, [3, 3, 1, 1]);
    }

    #[test]
    fn from_increasing() {
        // 5 > 4 violates the monotonic decreasing requirement.
        let result = BatchSizes::from_sorted_sequence_lengths([4, 5, 2]);
        assert_eq!(result.unwrap_err(), PackingError::Increasing);
    }
}
892
#[cfg(test)]
#[allow(clippy::needless_pass_by_value)]
mod packed_tensor {
    use super::*;
    use rstest::{fixture, rstest};

    /// Packed tensor representing sequences: `[0, 1, 2, 3]`, `[10, 11]`, and `[100, 101]`.
    // NOTE: rstest binds fixtures by argument name, so tests below take a
    // parameter literally named `packed_tensor`.
    #[fixture]
    fn packed_tensor() -> PackedTensor {
        let sequences = [&[0, 1, 2, 3] as &[_], &[10, 11], &[100, 101]];
        PackedTensor::from_sorted_sequences(sequences).unwrap()
    }

    #[test]
    fn from_sorted_sequences() {
        let packed =
            PackedTensor::from_sorted_sequences([&[0, 1, 2, 3] as &[_], &[10, 11], &[100, 101]])
                .unwrap();
        // Data is interleaved step-by-step across sequences.
        let expected_data = Tensor::of_slice(&[0, 10, 100, 1, 11, 101, 2, 3]);
        assert_eq!(packed.tensor(), &expected_data);
        // Batch size at each step = number of sequences still active there.
        let expected_batch_sizes = Tensor::of_slice(&[3, 3, 1, 1]);
        assert_eq!(packed.batch_sizes_tensor(), expected_batch_sizes);
    }

    #[rstest]
    fn view_trim_start_n1(packed_tensor: PackedTensor) {
        // Dropping the first step of every sequence leaves [1,2,3], [11], [101].
        let expected =
            PackedTensor::from_sorted_sequences([&[1, 2, 3] as &[_], &[11], &[101]]).unwrap();
        assert_eq!(packed_tensor.view_trim_start(1), expected);
    }

    #[rstest]
    fn view_trim_start_n3(packed_tensor: PackedTensor) {
        // Trimming 3 steps exhausts the short sequences; only [3] survives.
        let expected = PackedTensor::from_sorted_sequences([&[3] as &[_]]).unwrap();
        assert_eq!(packed_tensor.view_trim_start(3), expected);
    }

    #[rstest]
    fn view_trim_start_is_view(packed_tensor: PackedTensor) {
        // The trimmed result shares storage: in-place negation of the view
        // must be visible through the original packed tensor.
        let mut view = packed_tensor.view_trim_start(1);
        let _ = view.tensor.neg_();

        let expected = PackedTensor::from_sorted_sequences([
            &[0, -1, -2, -3] as &[_],
            &[10, -11],
            &[100, -101],
        ])
        .unwrap();
        assert_eq!(packed_tensor, expected);
    }

    #[rstest]
    fn trim_end_n1(packed_tensor: PackedTensor) {
        // Dropping the last step of every sequence leaves [0,1,2], [10], [100].
        let expected =
            PackedTensor::from_sorted_sequences([&[0, 1, 2] as &[_], &[10], &[100]]).unwrap();
        assert_eq!(packed_tensor.trim_end(1), expected);
    }

    #[rstest]
    fn trim_end_n3(packed_tensor: PackedTensor) {
        // Trimming 3 from the end leaves only the longest sequence's first step.
        let expected = PackedTensor::from_sorted_sequences([&[0] as &[_]]).unwrap();
        assert_eq!(packed_tensor.trim_end(3), expected);
    }

    #[rstest]
    fn trim_end_is_copy(packed_tensor: PackedTensor) {
        // trim_end copies the data, so mutating the result must leave the
        // original untouched.
        let mut copy = packed_tensor.trim_end(1);
        let _ = copy.tensor.neg_();

        let unchanged =
            PackedTensor::from_sorted_sequences([&[0, 1, 2, 3] as &[_], &[10, 11], &[100, 101]])
                .unwrap();
        assert_eq!(packed_tensor, unchanged);
    }

    #[test]
    fn discounted_cumsum_from_end() {
        let input = PackedTensor::from_sorted_sequences([
            &[1.0, 2.0, 3.0, 4.0] as &[_],
            &[5.0, 6.0],
            &[7.0, 8.0],
        ])
        .unwrap();

        let cumsum = input.discounted_cumsum_from_end(0.1);

        // Each element is its own value plus 0.1 times the cumsum one step
        // later: [1.234, 2.34, 3.4, 4], [5.6, 6], [7.8, 8].
        let expected = PackedTensor::from_sorted_sequences([
            &[1.234, 2.34, 3.4, 4.0] as &[_],
            &[5.6, 6.0],
            &[7.8, 8.0],
        ])
        .unwrap();
        assert_eq!(cumsum.structure, expected.structure);
        // Floating-point result, so compare with a tolerance.
        let all_close = cumsum
            .tensor
            .isclose(&expected.tensor, 1e-8, 1e-8, false)
            .all();
        assert!(
            bool::from(all_close),
            "result: {:?}\nexpected: {:?}",
            cumsum,
            expected,
        );
    }

    #[rstest]
    fn batch_sizes_tensor_values(packed_tensor: PackedTensor) {
        assert_eq!(
            packed_tensor.structure.batch_sizes_tensor(),
            Tensor::of_slice(&[3, 3, 1, 1])
        );
    }

    #[rstest]
    fn batch_sizes_tensor_device_cpu(packed_tensor: PackedTensor) {
        // Batch sizes are metadata and must live on the CPU.
        let device = packed_tensor.structure.batch_sizes_tensor().device();
        assert_eq!(device, tch::Device::Cpu);
    }
}