use super::serialize::TensorDef;
use crate::torch::tensors::ExclusiveTensor;
use crate::utils::sequence::Sequence;
use ndarray::{azip, ArrayViewMut, Axis, IxDyn, Slice};
use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use std::iter;
use std::iter::{Fuse, FusedIterator};
use std::ops::{AddAssign, Bound, Mul};
use std::rc::Rc;
use tch::{kind::Element, Device, IndexOp, Kind, Tensor};
use thiserror::Error;

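/// Error packing sequences or batches into a [`PackedTensor`].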
#[derive(Error, Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum PackingError {
    #[error("sequence lengths or batch sizes increased; must be non-increasing")]
    Increasing,
    #[error("input tensor has < {expected} dimensions")]
    TooFewDimensions { expected: u8 },
}

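/// A tensor containing a batch of packed sequences.
///
/// Sequences are ordered from longest to shortest and stored step-major along
/// the first dimension: the first step of every sequence, then the second step
/// of every sequence that is long enough, and so on.
///
/// For example, the sequences `[0, 1, 2, 3]`, `[10, 11]`, and `[100, 101]`
/// pack into
///
/// ```text
/// [0, 10, 100, 1, 11, 101, 2, 3]
/// ```
///
/// with one batch size per step: `[3, 3, 1, 1]`.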
#[must_use]
#[serde_as]
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct PackedTensor {
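    /// The packed data; the first dimension indexes packed steps.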
    #[serde_as(as = "TensorDef")]
    tensor: Tensor,
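    /// How the first dimension of `tensor` divides into per-step batches.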
    structure: PackedStructure,
}

impl Clone for PackedTensor {
    fn clone(&self) -> Self {
        Self {
            tensor: self.tensor.shallow_clone(),
            structure: self.structure.clone(),
        }
    }
}

impl PackedTensor {
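    /// Create a `PackedTensor` from a data tensor and its packing structure.
    ///
    /// # Panics
    /// If the structure length does not match the size of the tensor's first
    /// dimension.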
    #[inline]
    pub fn from_parts(tensor: Tensor, structure: PackedStructure) -> Self {
        assert_eq!(
            structure.len() as i64,
            *tensor
                .size()
                .first()
                .expect("tensor must have at least 1 dimension"),
            "structure length does not match tensor first dimension size"
        );
        Self { tensor, structure }
    }

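    /// Create a `PackedTensor` from a tensor of aligned (equal-length)
    /// sequences with shape `[sequence_length, batch_size, ...]`.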
    pub fn from_aligned_tensor(tensor: &Tensor) -> Result<Self, PackingError> {
        let mut size = tensor.size();
        if size.len() < 2 {
            return Err(PackingError::TooFewDimensions { expected: 2 });
        }
        let sequence_length = size.remove(0);
        let batch_size = size[0];

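        // Flatten [sequence_length, batch_size, ...] into
        // [sequence_length * batch_size, ...]; this step-major order is
        // exactly the packed layout for equal-length sequences.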
        size[0] *= sequence_length;
        Ok(Self {
            tensor: tensor.reshape(&size),
            structure: PackedStructure::Aligned {
                sequence_length: sequence_length.try_into().unwrap(),
                batch_size: batch_size.try_into().unwrap(),
            },
        })
    }

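    /// Pack a set of sequence slices sorted in order of non-increasing length.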
    #[inline]
    pub fn from_sorted_sequences<'a, I, E>(slices: I) -> Result<Self, PackingError>
    where
        I: IntoIterator<Item = &'a [E]>,
        I::IntoIter: Clone,
        E: 'a + tch::kind::Element + Copy,
    {
        let sequences = slices.into_iter();
        let structure =
            PackedStructure::from_sorted_sequence_lengths(sequences.clone().map(<[E]>::len))?;
        let data: Vec<_> = PackedSeqIter::from_sorted(sequences).copied().collect();
        let tensor = Tensor::of_slice(&data);
        Ok(Self { tensor, structure })
    }

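    /// Convert into the underlying data tensor, discarding the structure.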
    #[allow(clippy::missing_const_for_fn)]
    #[inline]
    pub fn into_tensor(self) -> Tensor {
        self.tensor
    }

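    /// Get a reference to the underlying data tensor.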
    #[inline]
    pub const fn tensor(&self) -> &Tensor {
        &self.tensor
    }

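    /// Get a mutable reference to the underlying data tensor.
    ///
    /// Callers must not change the size of the first dimension.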
    #[inline]
    pub fn tensor_mut(&mut self) -> &mut Tensor {
        &mut self.tensor
    }

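    /// Get a reference to the packing structure.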
    #[must_use]
    #[inline]
    pub const fn structure(&self) -> &PackedStructure {
        &self.structure
    }

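    /// The data type of the tensor elements.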
    #[must_use]
    pub fn kind(&self) -> Kind {
        self.tensor.kind()
    }

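    /// The device on which the tensor is stored.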
    #[must_use]
    pub fn device(&self) -> Device {
        self.tensor.device()
    }

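    /// The batch size of each packed step, as a 1D int64 tensor.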
    pub fn batch_sizes_tensor(&self) -> Tensor {
        self.structure.batch_sizes_tensor()
    }

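    /// The batch size of the first step, if there is at least one step.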
    #[must_use]
    pub fn first_batch_size(&self) -> Option<i64> {
        self.structure.first_batch_size()
    }

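    /// Apply a function to the data tensor, keeping the same structure.
    ///
    /// The function must preserve the size of the first dimension.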
    #[inline]
    pub fn batch_map<F: FnOnce(Tensor) -> Tensor>(self, f: F) -> Self {
        Self {
            tensor: f(self.tensor),
            structure: self.structure,
        }
    }

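    /// Apply a function to a reference of the data tensor, keeping the same
    /// structure. The function must preserve the size of the first dimension.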
    #[inline]
    pub fn batch_map_ref<'a, F: FnOnce(&'a Tensor) -> Tensor>(&'a self, f: F) -> Self {
        Self {
            tensor: f(&self.tensor),
            structure: self.structure.clone(),
        }
    }

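    /// View this packed tensor with the first `n` steps of every sequence
    /// removed. The result shares memory with `self`.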
    pub fn view_trim_start(&self, n: usize) -> Self {
        let (to_remove, structure) = match &self.structure {
            PackedStructure::Aligned {
                sequence_length,
                batch_size,
            } => {
                let n = n.min(*sequence_length);
                let to_remove = n * *batch_size;
                let new_structure = PackedStructure::Aligned {
                    sequence_length: *sequence_length - n,
                    batch_size: *batch_size,
                };
                (to_remove as i64, new_structure)
            }
            PackedStructure::Ragged(batch_sizes) => {
                let to_remove = batch_sizes.as_slice()[..n].iter().copied().sum();
                let new_structure = PackedStructure::Ragged(batch_sizes.clone().trim(n));
                (to_remove, new_structure)
            }
        };
        let tensor = self.tensor.i(to_remove..);

        Self { tensor, structure }
    }

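    /// Copy this packed tensor with the last `n` steps of every sequence
    /// removed. Unlike [`Self::view_trim_start`], the result does not share
    /// memory with `self` in the ragged case.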
    pub fn trim_end(&self, n: usize) -> Self {
        match &self.structure {
            PackedStructure::Aligned {
                sequence_length,
                batch_size,
            } => {
                let n = n.min(*sequence_length);
                // Keep the first (sequence_length - n) steps of every sequence.
                let keep = *sequence_length - n;
                let tensor = self.tensor.i(..(keep * *batch_size) as i64);
                let structure = PackedStructure::Aligned {
                    sequence_length: keep,
                    batch_size: *batch_size,
                };
                Self { tensor, structure }
            }
            PackedStructure::Ragged(batch_sizes) => {
                let new_batch_sizes = batch_sizes.clone().trim(n);
                let (old_group_sizes, new_group_sizes): (Vec<_>, Vec<_>) =
                    GroupBatchesForResize::new(
                        batch_sizes.as_slice().iter().copied(),
                        new_batch_sizes.as_slice().iter().copied(),
                    )
                    .unzip();

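                // Split the packed tensor into contiguous groups of steps
                // within which the retained elements form a prefix.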
                let groups = self.tensor.split_with_sizes(&old_group_sizes, 0);

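                // Keep only the surviving prefix of each group.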
                let new_groups: Vec<_> = groups
                    .iter()
                    .zip(new_group_sizes)
                    .map(|(group, new_size)| group.i(..new_size))
                    .collect();

                let new_tensor = Tensor::cat(&new_groups, 0);

                Self {
                    tensor: new_tensor,
                    structure: PackedStructure::Ragged(new_batch_sizes),
                }
            }
        }
    }

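    /// The discounted cumulative sum from the end of each sequence.
    ///
    /// Returns `y` where `y[t] = x[t] + discount * y[t + 1]` within each
    /// sequence, with `y[t] = x[t]` at the final step.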
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    pub fn discounted_cumsum_from_end<T>(&self, discount: T) -> Self
    where
        T: Mul + AddAssign<<T as Mul>::Output> + Copy + Element,
    {
        let mut out = ExclusiveTensor::<T, _>::try_copy_from(self.tensor()).unwrap();
        match &self.structure {
            PackedStructure::Ragged(batch_sizes) => {
                inplace_discounted_cumsum_from_end(
                    out.array_view_mut(),
                    discount,
                    batch_sizes.as_slice().iter().map(|b| *b as usize).rev(),
                );
            }
            PackedStructure::Aligned {
                sequence_length,
                batch_size,
            } => {
                inplace_discounted_cumsum_from_end(
                    out.array_view_mut(),
                    discount,
                    iter::repeat(*batch_size).take(*sequence_length),
                );
            }
        }
        Self {
            tensor: out.into_tensor().to_device(self.tensor.device()),
            structure: self.structure.clone(),
        }
    }
}

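/// In-place discounted cumulative sum from the end of each packed sequence.
///
/// `rev_batch_sizes` must yield the per-step batch sizes from the last step to
/// the first.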
#[allow(clippy::cast_possible_wrap)]
fn inplace_discounted_cumsum_from_end<I, T>(
    mut array: ArrayViewMut<T, IxDyn>,
    discount: T,
    rev_batch_sizes: I,
) where
    I: IntoIterator<Item = usize>,
    T: Mul + AddAssign<<T as Mul>::Output> + Copy,
{
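    // Walk the packed steps from last to first, adding the discounted
    // (already-accumulated) values of step t + 1 into step t.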
    let mut offset = array.shape()[0];
    for batch_size in rev_batch_sizes {
        let (left, prev_batch) = array.split_at(Axis(0), offset);
        array = left;
        offset -= batch_size;

        let prev_batch_size = prev_batch.shape()[0];
        let batch_part = array.slice_axis_mut(
            Axis(0),
            Slice {
                start: offset as isize,
                end: Some((offset + prev_batch_size) as isize),
                step: 1,
            },
        );
        azip!((a in batch_part, b in &prev_batch) *a += *b * discount);
    }
    assert_eq!(
        offset, 0,
        "batch sizes do not match array first dimension length"
    );
}

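/// The sequence structure of a [`PackedTensor`]: the batch size of each step.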
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum PackedStructure {
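    /// Explicit per-step batch sizes for sequences of varying lengths.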
    Ragged(SharedBatchSizes),
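    /// Equal-length sequences: every step has the same batch size.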
    Aligned {
        sequence_length: usize,
        batch_size: usize,
    },
}

impl PackedStructure {
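    /// Create a ragged structure from per-step batch sizes, which must be
    /// non-increasing.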
    pub fn from_batch_sizes<I: IntoIterator<Item = usize>>(
        batch_sizes: I,
    ) -> Result<Self, PackingError> {
        Ok(Self::Ragged(SharedBatchSizes::from_batch_sizes(
            batch_sizes,
        )?))
    }

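    /// Create a ragged structure from sequence lengths sorted in
    /// non-increasing order.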
    pub fn from_sorted_sequence_lengths<I: IntoIterator<Item = usize>>(
        lengths: I,
    ) -> Result<Self, PackingError> {
        Ok(Self::Ragged(
            SharedBatchSizes::from_sorted_sequence_lengths(lengths)?,
        ))
    }

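    /// The batch size of each step as a 1D int64 tensor on the CPU.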
    pub fn batch_sizes_tensor(&self) -> Tensor {
        match self {
            Self::Ragged(batch_sizes) => batch_sizes.tensor(),
            Self::Aligned {
                sequence_length,
                batch_size,
            } => Tensor::full(
                &[*sequence_length as i64],
                *batch_size as i64,
                (Kind::Int64, Device::Cpu),
            ),
        }
    }

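    /// The batch size of the first step, if there is at least one step.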
    #[must_use]
    pub fn first_batch_size(&self) -> Option<i64> {
        match self {
            Self::Ragged(batch_sizes) => batch_sizes.as_slice().first().copied(),
            Self::Aligned {
                sequence_length,
                batch_size,
            } => {
                if *sequence_length > 0 {
                    Some(*batch_size as _)
                } else {
                    None
                }
            }
        }
    }

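    /// The total number of packed steps; equals the size of the first
    /// dimension of the associated data tensor.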
    #[must_use]
    pub fn len(&self) -> usize {
        match self {
            Self::Ragged(batch_sizes) => batch_sizes.len(),
            Self::Aligned {
                sequence_length,
                batch_size,
            } => sequence_length * batch_size,
        }
    }

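    /// Whether the structure contains no packed steps.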
    #[must_use]
    pub fn is_empty(&self) -> bool {
        match self {
            Self::Ragged(batch_sizes) => batch_sizes.is_empty(),
            Self::Aligned {
                sequence_length,
                batch_size,
            } => *sequence_length == 0 || *batch_size == 0,
        }
    }

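    /// Remove `n` steps from each sequence.
    ///
    /// The resulting structure is the same whether the steps are removed from
    /// the start or from the end of each sequence.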
    #[allow(clippy::missing_const_for_fn)]
    #[must_use]
    pub fn trim(self, n: usize) -> Self {
        match self {
            Self::Ragged(batch_sizes) => Self::Ragged(batch_sizes.trim(n)),
            Self::Aligned {
                sequence_length,
                batch_size,
            } => Self::Aligned {
                sequence_length: sequence_length.saturating_sub(n),
                batch_size,
            },
        }
    }
}

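/// A view into a shared, reference-counted list of per-step batch sizes.
///
/// The `start..end` range bounds the view within the root list so that
/// trimmed structures can share the root allocation.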
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedBatchSizes {
    root: Rc<BatchSizes>,
    start: usize,
    end: Option<usize>,
}

impl AsRef<[i64]> for SharedBatchSizes {
    #[inline]
    fn as_ref(&self) -> &[i64] {
        self.as_slice()
    }
}

impl<T: AsRef<[i64]>> PartialEq<T> for SharedBatchSizes {
    #[inline]
    fn eq(&self, other: &T) -> bool {
        self.as_ref() == other.as_ref()
    }
}

impl Eq for SharedBatchSizes {}

impl SharedBatchSizes {
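    /// Create from per-step batch sizes, which must be non-increasing.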
    pub fn from_batch_sizes<I: IntoIterator<Item = usize>>(
        batch_sizes: I,
    ) -> Result<Self, PackingError> {
        Ok(Self {
            root: Rc::new(BatchSizes::from_batch_sizes(batch_sizes)?),
            start: 0,
            end: None,
        })
    }

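    /// Create from sequence lengths sorted in non-increasing order.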
    pub fn from_sorted_sequence_lengths<I: IntoIterator<Item = usize>>(
        lengths: I,
    ) -> Result<Self, PackingError> {
        Ok(Self {
            root: Rc::new(BatchSizes::from_sorted_sequence_lengths(lengths)?),
            start: 0,
            end: None,
        })
    }

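    /// The viewed batch sizes as a slice.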
    #[inline]
    pub fn as_slice(&self) -> &[i64] {
        let start = Bound::Included(self.start);
        let end = self.end.map_or(Bound::Unbounded, Bound::Excluded);
        &self.root.as_slice()[(start, end)]
    }

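    /// The viewed batch sizes as a 1D int64 tensor, sharing the root tensor
    /// where possible.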
    #[inline]
    pub fn tensor(&self) -> Tensor {
        let root_tensor = self.root.as_tensor();

        if self.start == 0 && self.end.is_none() {
            root_tensor.shallow_clone()
        } else {
            let end = self.end.map(|i| i as i64);
            root_tensor.slice(0, self.start as i64, end, 1)
        }
    }

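    /// The total number of packed steps (the sum of the viewed batch sizes).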
    #[must_use]
    pub fn len(&self) -> usize {
        self.as_slice()
            .iter()
            .map(|x| usize::try_from(*x).unwrap())
            .sum()
    }

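    /// Whether the view contains no packed steps.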
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.as_slice().iter().all(|x| *x == 0)
    }

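    /// Remove the first `n` batch sizes from the view.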
    #[must_use]
    pub const fn trim(mut self, n: usize) -> Self {
        self.start += n;
        self
    }
}

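/// An owned list of per-step batch sizes with a lazily-created tensor copy.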
#[derive(Debug, PartialEq, Serialize, Deserialize)]
pub struct BatchSizes {
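    /// The batch size of each step; non-increasing.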
    batch_sizes: Vec<i64>,

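    /// Lazily-initialized tensor copy of `batch_sizes`.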
    #[serde(skip)]
    batch_sizes_tensor: OnceCell<Tensor>,
}

impl AsRef<[i64]> for BatchSizes {
    #[inline]
    fn as_ref(&self) -> &[i64] {
        self.as_slice()
    }
}

impl BatchSizes {
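    /// Create from per-step batch sizes; errors if the sizes ever increase.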
    pub fn from_batch_sizes<I: IntoIterator<Item = usize>>(
        batch_sizes: I,
    ) -> Result<Self, PackingError> {
        let mut prev = usize::MAX;
        let batch_sizes: Vec<_> = batch_sizes
            .into_iter()
            .map(|x| {
                if x > prev {
                    Err(PackingError::Increasing)
                } else {
                    prev = x;
                    Ok(x as i64)
                }
            })
            .collect::<Result<_, _>>()?;
        Ok(Self {
            batch_sizes,
            batch_sizes_tensor: OnceCell::new(),
        })
    }

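    /// Create from sequence lengths sorted in non-increasing order.
    ///
    /// `batch_sizes[t]` is the number of sequences with length greater than
    /// `t`.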
    pub fn from_sorted_sequence_lengths<I: IntoIterator<Item = usize>>(
        lengths: I,
    ) -> Result<Self, PackingError> {
        let mut lengths = lengths.into_iter().enumerate().peekable();

        let (_, max_seq_len) = lengths.peek().copied().unwrap_or((0, 0));
        let mut batch_sizes = vec![0; max_seq_len];

        while let Some((i, seq_len)) = lengths.next() {
            let (_, next_len) = lengths.peek().copied().unwrap_or((0, 0));
            if next_len > seq_len {
                return Err(PackingError::Increasing);
            }
            batch_sizes[next_len..seq_len].fill((i + 1) as i64);
        }
        Ok(Self {
            batch_sizes,
            batch_sizes_tensor: OnceCell::new(),
        })
    }

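    /// The batch sizes as a slice.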
    #[inline]
    pub fn as_slice(&self) -> &[i64] {
        self.batch_sizes.as_slice()
    }

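    /// The batch sizes as a 1D int64 tensor, created on first access.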
    #[inline]
    pub fn as_tensor(&self) -> &Tensor {
        self.batch_sizes_tensor
            .get_or_init(|| Tensor::of_slice(&self.batch_sizes))
    }

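    /// The total number of packed steps (the sum of the batch sizes).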
    #[inline]
    pub fn len(&self) -> usize {
        self.batch_sizes
            .iter()
            .map(|x| usize::try_from(*x).unwrap())
            .sum()
    }

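    /// Whether there are no packed steps.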
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.batch_sizes.iter().all(|x| *x == 0)
    }
}

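/// Iterator of group sizes used when resizing a packed tensor from one list
/// of batch sizes to another.
///
/// Yields `(old_group_size, new_group_size)` pairs, where each group covers a
/// run of consecutive steps chosen so that the elements retained from an old
/// group form a prefix of that group. A group ends at the first step whose
/// old and new batch sizes differ; once either list is exhausted, all
/// remaining steps form a single final group.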
struct GroupBatchesForResize<A, B> {
    old_batch_sizes: Fuse<A>,
    new_batch_sizes: Fuse<B>,
}

impl<A, B> GroupBatchesForResize<A, B>
where
    A: Iterator,
    B: Iterator,
{
    pub fn new<IA, IB>(old_batch_sizes: IA, new_batch_sizes: IB) -> Self
    where
        IA: IntoIterator<IntoIter = A>,
        IB: IntoIterator<IntoIter = B>,
    {
        Self {
            old_batch_sizes: old_batch_sizes.into_iter().fuse(),
            new_batch_sizes: new_batch_sizes.into_iter().fuse(),
        }
    }
}

impl<A, B> Iterator for GroupBatchesForResize<A, B>
where
    A: Iterator<Item = i64>,
    B: Iterator<Item = i64>,
{
    type Item = (i64, i64);

    fn next(&mut self) -> Option<Self::Item> {
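        // Accumulate steps into the current group until the old and new batch
        // sizes diverge.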
        let mut old_group_size = 0;
        let mut new_group_size = 0;
        loop {
            let (old, new, tail) = match (self.old_batch_sizes.next(), self.new_batch_sizes.next())
            {
                (Some(old), Some(new)) => (old, new, false),
                (Some(old), None) => (old, 0, true),
                (None, Some(new)) => (0, new, true),
                (None, None) => break,
            };
            old_group_size += old;
            new_group_size += new;

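            // Once either list is exhausted (`tail`), the remaining steps all
            // belong to the final group.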
            if !tail && old != new {
                break;
            }
        }
        if (old_group_size, new_group_size) == (0, 0) {
            None
        } else {
            Some((old_group_size, new_group_size))
        }
    }
}

impl<A, B> FusedIterator for GroupBatchesForResize<A, B>
where
    A: Iterator<Item = i64>,
    B: Iterator<Item = i64>,
{
}

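/// Iterator over the packed concatenation of a set of sequences.
///
/// The sequences must be sorted from longest to shortest. Values are yielded
/// step-major: the first element of every sequence, then the second element
/// of every sequence that is long enough, and so on.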
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PackedSeqIter<I> {
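    /// The sequences being packed.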
    sequences: I,

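    /// The current step: the offset within each sequence.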
    offset: usize,
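    /// Iterator over the sequences not yet visited at the current step.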
    sequences_iter: I,
}

impl<I> PackedSeqIter<I>
where
    I: Iterator + Clone,
    <I as Iterator>::Item: Sequence,
{
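    /// Create from sequences sorted in order of non-increasing length.
    ///
    /// # Panics
    /// If the sequences are not sorted by non-increasing length.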
    pub fn from_sorted<T: IntoIterator<IntoIter = I>>(into_sequences: T) -> Self {
        let sequences = into_sequences.into_iter();
        assert!(
            sequences
                .clone()
                .zip(sequences.clone().skip(1))
                .all(|(a, b)| a.len() >= b.len()),
            "sequences not sorted in order of non-increasing length"
        );
        let sequences_iter = sequences.clone();
        Self {
            sequences,
            offset: 0,
            sequences_iter,
        }
    }
}

impl<I> Iterator for PackedSeqIter<I>
where
    I: Iterator + Clone,
    <I as Iterator>::Item: Sequence,
{
    type Item = <I::Item as Sequence>::Item;

    fn next(&mut self) -> Option<Self::Item> {
        if let Some(value) = self
            .sequences_iter
            .next()
            .and_then(|seq| seq.get(self.offset))
        {
            Some(value)
        } else {
            self.offset += 1;
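            // The current step is exhausted; restart from the first (longest)
            // sequence at the next offset.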
            self.sequences_iter = self.sequences.clone();
            self.sequences_iter
                .next()
                .and_then(|seq| seq.get(self.offset))
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
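        // Number of elements remaining if the current step were freshly
        // started: for each sequence, everything at or beyond `offset`.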
        let level_size: usize = self
            .sequences
            .clone()
            .map(|seq| seq.len().saturating_sub(self.offset))
            .take_while(|&size| size > 0)
            .sum();
        let size = if level_size == 0 {
            0
        } else {
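            // Subtract the elements already yielded at the current step.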
            level_size - (self.sequences.clone().count() - self.sequences_iter.clone().count())
        };
        (size, Some(size))
    }
}

impl<I> ExactSizeIterator for PackedSeqIter<I>
where
    I: ExactSizeIterator + Clone,
    <I as Iterator>::Item: Sequence,
{
}

#[cfg(test)]
mod packed_seq_iter {
    use super::*;

    #[test]
    fn iter() {
        let data = [0, 1, 2, 3, 10, 11, 100, 101];
        let ranges = [0..4, 4..6, 6..8];
        let packed: Vec<_> = PackedSeqIter::from_sorted(&ranges)
            .map(|i| data[i])
            .collect();
        let expected = vec![0, 10, 100, 1, 11, 101, 2, 3];
        assert_eq!(packed, expected);
    }

    #[test]
    fn size_hint() {
        let ranges = [0..4, 4..6, 6..8];
        let packing_indices = PackedSeqIter::from_sorted(&ranges);
        assert_eq!(packing_indices.size_hint(), (8, Some(8)));
    }

    #[test]
    fn size_hint_after_next() {
        let ranges = [0..4, 4..6, 6..8];
        let mut packing_indices = PackedSeqIter::from_sorted(&ranges);
        let _ = packing_indices.next();
        assert_eq!(packing_indices.size_hint(), (7, Some(7)));
        let _ = packing_indices.next();
        assert_eq!(packing_indices.size_hint(), (6, Some(6)));
    }
}

#[cfg(test)]
mod batch_sizes {
    use super::*;

    #[test]
    fn from_sorted() {
        let batch_sizes = BatchSizes::from_sorted_sequence_lengths([4, 2, 2]).unwrap();
        assert_eq!(batch_sizes.batch_sizes, [3, 3, 1, 1]);
    }

    #[test]
    fn from_increasing() {
        assert_eq!(
            BatchSizes::from_sorted_sequence_lengths([4, 5, 2]).unwrap_err(),
            PackingError::Increasing
        );
    }
}

#[cfg(test)]
#[allow(clippy::needless_pass_by_value)]
mod packed_tensor {
    use super::*;
    use rstest::{fixture, rstest};

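    /// A packed tensor of the sequences `[0, 1, 2, 3]`, `[10, 11]`, `[100, 101]`.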
    #[fixture]
    fn packed_tensor() -> PackedTensor {
        PackedTensor::from_sorted_sequences([&[0, 1, 2, 3] as &[_], &[10, 11], &[100, 101]])
            .unwrap()
    }

    #[test]
    fn from_sorted_sequences() {
        let packed_tensor =
            PackedTensor::from_sorted_sequences([&[0, 1, 2, 3] as &[_], &[10, 11], &[100, 101]])
                .unwrap();
        assert_eq!(
            packed_tensor.tensor(),
            &Tensor::of_slice(&[0, 10, 100, 1, 11, 101, 2, 3])
        );
        assert_eq!(
            packed_tensor.batch_sizes_tensor(),
            Tensor::of_slice(&[3, 3, 1, 1])
        );
    }

    #[rstest]
    fn view_trim_start_n1(packed_tensor: PackedTensor) {
        let actual = packed_tensor.view_trim_start(1);
        let expected =
            PackedTensor::from_sorted_sequences([&[1, 2, 3] as &[_], &[11], &[101]]).unwrap();
        assert_eq!(actual, expected);
    }

    #[rstest]
    fn view_trim_start_n3(packed_tensor: PackedTensor) {
        let actual = packed_tensor.view_trim_start(3);
        let expected = PackedTensor::from_sorted_sequences([&[3] as &[_]]).unwrap();
        assert_eq!(actual, expected);
    }

    #[rstest]
    fn view_trim_start_is_view(packed_tensor: PackedTensor) {
        let mut trimmed = packed_tensor.view_trim_start(1);
        let _ = trimmed.tensor.neg_();

        let expected = PackedTensor::from_sorted_sequences([
            &[0, -1, -2, -3] as &[_],
            &[10, -11],
            &[100, -101],
        ])
        .unwrap();
        assert_eq!(packed_tensor, expected);
    }

    #[rstest]
    fn trim_end_n1(packed_tensor: PackedTensor) {
        let actual = packed_tensor.trim_end(1);
        let expected =
            PackedTensor::from_sorted_sequences([&[0, 1, 2] as &[_], &[10], &[100]]).unwrap();
        assert_eq!(actual, expected);
    }

    #[rstest]
    fn trim_end_n3(packed_tensor: PackedTensor) {
        let actual = packed_tensor.trim_end(3);
        let expected = PackedTensor::from_sorted_sequences([&[0] as &[_]]).unwrap();
        assert_eq!(actual, expected);
    }

    #[rstest]
    fn trim_end_is_copy(packed_tensor: PackedTensor) {
        let mut trimmed = packed_tensor.trim_end(1);
        let _ = trimmed.tensor.neg_();

        let expected =
            PackedTensor::from_sorted_sequences([&[0, 1, 2, 3] as &[_], &[10, 11], &[100, 101]])
                .unwrap();
        assert_eq!(packed_tensor, expected);
    }

    #[test]
    fn discounted_cumsum_from_end() {
        let packed_tensor = PackedTensor::from_sorted_sequences([
            &[1.0, 2.0, 3.0, 4.0] as &[_],
            &[5.0, 6.0],
            &[7.0, 8.0],
        ])
        .unwrap();

        let cumsum = packed_tensor.discounted_cumsum_from_end(0.1);

        let expected = PackedTensor::from_sorted_sequences([
            &[1.234, 2.34, 3.4, 4.0] as &[_],
            &[5.6, 6.0],
            &[7.8, 8.0],
        ])
        .unwrap();
        assert_eq!(cumsum.structure, expected.structure);
        assert!(
            bool::from(
                cumsum
                    .tensor
                    .isclose(&expected.tensor, 1e-8, 1e-8, false)
                    .all()
            ),
            "result: {:?}\nexpected: {:?}",
            cumsum,
            expected,
        );
    }

    #[rstest]
    fn batch_sizes_tensor_values(packed_tensor: PackedTensor) {
        let actual = packed_tensor.structure.batch_sizes_tensor();
        let expected = Tensor::of_slice(&[3, 3, 1, 1]);
        assert_eq!(actual, expected);
    }

    #[rstest]
    fn batch_sizes_tensor_device_cpu(packed_tensor: PackedTensor) {
        let batch_sizes = packed_tensor.structure.batch_sizes_tensor();
        assert_eq!(batch_sizes.device(), tch::Device::Cpu);
    }
}