burn_std/tensor/
slice.rs

1//! Tensor slice utilities.
2
3use crate::Shape;
4use crate::indexing::AsIndex;
5use alloc::format;
6use alloc::vec::Vec;
7use core::fmt::{Display, Formatter};
8use core::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
9use core::str::FromStr;
10
/// Trait for slice arguments that can be converted into a list of slices.
///
/// This allows the `slice` method to accept both single slices (from `s![..]`)
/// and arrays of slices (from `s![.., ..]` or `[0..5, 1..3]`).
pub trait SliceArg {
    /// Converts `self` into a vec of slices, clamped to the dimensions of `shape`.
    ///
    /// Returns a [Slice] for each dimension in `shape`.
    fn into_slices(self, shape: &Shape) -> Vec<Slice>;
}
20
21impl<S: Into<Slice> + Clone> SliceArg for &[S] {
22    fn into_slices(self, shape: &Shape) -> Vec<Slice> {
23        assert!(
24            self.len() <= shape.num_dims(),
25            "Too many slices provided for shape, got {} but expected at most {}",
26            self.len(),
27            shape.num_dims()
28        );
29
30        shape
31            .iter()
32            .enumerate()
33            .map(|(i, dim_size)| {
34                let slice = if i >= self.len() {
35                    Slice::full()
36                } else {
37                    self[i].clone().into()
38                };
39                // Apply shape clamping by converting to range and back
40                let clamped_range = slice.to_range(*dim_size);
41                Slice::new(
42                    clamped_range.start as isize,
43                    Some(clamped_range.end as isize),
44                    slice.step(),
45                )
46            })
47            .collect::<Vec<_>>()
48    }
49}
50
51impl SliceArg for &Vec<Slice> {
52    fn into_slices(self, shape: &Shape) -> Vec<Slice> {
53        self.as_slice().into_slices(shape)
54    }
55}
56
57impl<const R: usize, T> SliceArg for [T; R]
58where
59    T: Into<Slice> + Clone,
60{
61    fn into_slices(self, shape: &Shape) -> Vec<Slice> {
62        self.as_slice().into_slices(shape)
63    }
64}
65
66impl<T> SliceArg for T
67where
68    T: Into<Slice>,
69{
70    fn into_slices(self, shape: &Shape) -> Vec<Slice> {
71        let slice: Slice = self.into();
72        [slice].as_slice().into_slices(shape)
73    }
74}
75
76/// Slice argument constructor for tensor indexing.
77///
78/// The `s![]` macro is used to create multi-dimensional slice specifications for tensors.
/// It converts various range syntax forms into a single [`Slice`] (one dimension) or a
/// fixed-size array of [`Slice`]s (several dimensions) that can be used with
/// `tensor.slice()` and `tensor.slice_assign()` operations.
81///
82/// # Syntax Overview
83///
84/// ## Basic Forms
85///
/// * **`s![index]`** - Select a single index (equivalent to the one-element range `index..=index`)
87/// * **`s![range]`** - Slice a range of elements
88/// * **`s![range;step]`** - Slice a range with a custom step
89/// * **`s![dim1, dim2, ...]`** - Multiple dimensions, each can be any of the above forms
90///
91/// ## Range Types
92///
93/// All standard Rust range types are supported:
94/// * **`a..b`** - From `a` (inclusive) to `b` (exclusive)
95/// * **`a..=b`** - From `a` to `b` (both inclusive)
96/// * **`a..`** - From `a` to the end
97/// * **`..b`** - From the beginning to `b` (exclusive)
98/// * **`..=b`** - From the beginning to `b` (inclusive)
99/// * **`..`** - The full range (all elements)
100///
101/// ## Negative Indices
102///
103/// Negative indices count from the end of the axis:
104/// * **`-1`** refers to the last element
105/// * **`-2`** refers to the second-to-last element
106/// * And so on...
107///
108/// This works in all range forms: `s![-3..-1]`, `s![-2..]`, `s![..-1]`
109///
110/// ## Step Syntax
111///
112/// Steps control the stride between selected elements:
113/// * **`;step`** after a range specifies the step
114/// * **Positive steps** select every nth element going forward
115/// * **Negative steps** select every nth element going backward
116/// * Default step is `1` when not specified
117/// * Step cannot be `0`
118///
119/// ### Negative Step Behavior
120///
121/// With negative steps, the range bounds still specify *which* elements to include,
122/// but the traversal order is reversed:
123///
124/// * `s![0..5;-1]` selects indices `[4, 3, 2, 1, 0]` (not `[0, 1, 2, 3, 4]`)
125/// * `s![2..8;-2]` selects indices `[7, 5, 3]` (starting from 7, going backward by 2)
126/// * `s![..;-1]` reverses the entire axis
127///
128/// This matches the semantics of NumPy and the ndarray crate.
129///
130/// # Examples
131///
132/// ## Basic Slicing
133///
134/// ```rust,ignore
135/// use burn_tensor::{Tensor, s};
136///
137/// # fn example<B: Backend>(tensor: Tensor<B, 3>) {
138/// // Select rows 0-5 (exclusive)
139/// let subset = tensor.slice(s![0..5, .., ..]);
140///
141/// // Select the last row
142/// let last_row = tensor.slice(s![-1, .., ..]);
143///
144/// // Select columns 2, 3, 4
145/// let cols = tensor.slice(s![.., 2..5, ..]);
146///
147/// // Select a single element at position [1, 2, 3]
148/// let element = tensor.slice(s![1, 2, 3]);
149/// # }
150/// ```
151///
152/// ## Slicing with Steps
153///
154/// ```rust,ignore
155/// use burn_tensor::{Tensor, s};
156///
157/// # fn example<B: Backend>(tensor: Tensor<B, 2>) {
158/// // Select every 2nd row
159/// let even_rows = tensor.slice(s![0..10;2, ..]);
160///
161/// // Select every 3rd column
162/// let cols = tensor.slice(s![.., 0..9;3]);
163///
164/// // Select every 2nd element in reverse order
/// let reversed_even = tensor.slice(s![0..10;-2, ..]);
166/// # }
167/// ```
168///
169/// ## Reversing Dimensions
170///
171/// ```rust,ignore
172/// use burn_tensor::{Tensor, s};
173///
174/// # fn example<B: Backend>(tensor: Tensor<B, 2>) {
175/// // Reverse the first dimension
176/// let reversed = tensor.slice(s![..;-1, ..]);
177///
178/// // Reverse both dimensions
179/// let fully_reversed = tensor.slice(s![..;-1, ..;-1]);
180///
181/// // Reverse a specific range
182/// let range_reversed = tensor.slice(s![2..8;-1, ..]);
183/// # }
184/// ```
185///
186/// ## Complex Multi-dimensional Slicing
187///
188/// ```rust,ignore
189/// use burn_tensor::{Tensor, s};
190///
191/// # fn example<B: Backend>(tensor: Tensor<B, 4>) {
192/// // Mix of different slice types
193/// let complex = tensor.slice(s![
194///     0..10;2,    // Every 2nd element from 0 to 10
195///     ..,         // All elements in dimension 1
196///     5..15;-3,   // Every 3rd element from 14 down to 5
197///     -1          // Last element in dimension 3
198/// ]);
199///
200/// // Using inclusive ranges
201/// let inclusive = tensor.slice(s![2..=5, 1..=3, .., ..]);
202///
203/// // Negative indices with steps
204/// let from_end = tensor.slice(s![-5..-1;2, .., .., ..]);
205/// # }
206/// ```
207///
208/// ## Slice Assignment
209///
210/// ```rust,ignore
211/// use burn_tensor::{Tensor, s};
212///
213/// # fn example<B: Backend>(tensor: Tensor<B, 2>, values: Tensor<B, 2>) {
214/// // Assign to every 2nd row
215/// let tensor = tensor.slice_assign(s![0..10;2, ..], values);
216///
217/// // Assign to a reversed slice
218/// let tensor = tensor.slice_assign(s![..;-1, 0..5], values);
219/// # }
220/// ```
#[macro_export]
macro_rules! s {
    // Empty - should not happen
    [] => {
        compile_error!("Empty slice specification")
    };

    // Single expression with step: expands to one `Slice` (not an array).
    // The step is cast with `as isize`, exactly like the `@internal` arms below, so the
    // set of accepted step types does not depend on the position of the dimension.
    [$range:expr; $step:expr] => {
        {
            // Reversed literal ranges are allowed as user input to the macro.
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::tensor::Slice::from_range_stepped($range, $step as isize)
            }
        }
    };

    // Single expression without step (no comma after): expands to one `Slice`.
    [$range:expr] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::tensor::Slice::from($range)
            }
        }
    };

    // Two or more expressions with first having step: seed the accumulator and recurse.
    [$range:expr; $step:expr, $($rest:tt)*] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::s!(@internal [$crate::tensor::Slice::from_range_stepped($range, $step as isize)] $($rest)*)
            }
        }
    };

    // Two or more expressions with first not having step: seed the accumulator and recurse.
    [$range:expr, $($rest:tt)*] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::s!(@internal [$crate::tensor::Slice::from($range)] $($rest)*)
            }
        }
    };

    // Internal: finished parsing, emit the accumulated slices as an array.
    (@internal [$($acc:expr),*]) => {
        [$($acc),*]
    };

    // Internal: parse range with step followed by comma
    (@internal [$($acc:expr),*] $range:expr; $step:expr, $($rest:tt)*) => {
        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from_range_stepped($range, $step as isize)] $($rest)*)
    };

    // Internal: parse range with step at end
    (@internal [$($acc:expr),*] $range:expr; $step:expr) => {
        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from_range_stepped($range, $step as isize)])
    };

    // Internal: parse range without step followed by comma
    (@internal [$($acc:expr),*] $range:expr, $($rest:tt)*) => {
        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from($range)] $($rest)*)
    };

    // Internal: parse range without step at end
    (@internal [$($acc:expr),*] $range:expr) => {
        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from($range)])
    };
}
293
294/// A slice specification for a single tensor dimension.
295///
296/// This struct represents a range with an optional step, used for advanced indexing
297/// operations on tensors. It is typically created using the [`s!`] macro rather than
298/// constructed directly.
299///
300/// # Fields
301///
302/// * `start` - The starting index (inclusive). Negative values count from the end.
303/// * `end` - The ending index (exclusive). `None` means to the end of the dimension.
304/// * `step` - The stride between elements. Must be non-zero.
305///
306/// # Index Interpretation
307///
308/// - **Positive indices**: Count from the beginning (0-based)
309/// - **Negative indices**: Count from the end (-1 is the last element)
310/// - **Bounds checking**: Indices are clamped to valid ranges
311///
312/// # Step Behavior
313///
314/// - **Positive step**: Traverse forward through the range
315/// - **Negative step**: Traverse backward through the range
316/// - **Step size**: Determines how many elements to skip
317///
318/// # Examples
319///
320/// While you typically use the [`s!`] macro, you can also construct slices directly:
321///
322/// ```rust,ignore
323/// use burn_tensor::Slice;
324///
325/// // Equivalent to s![2..8]
326/// let slice1 = Slice::new(2, Some(8), 1);
327///
328/// // Equivalent to s![0..10;2]
329/// let slice2 = Slice::new(0, Some(10), 2);
330///
331/// // Equivalent to s![..;-1] (reverse)
332/// let slice3 = Slice::new(0, None, -1);
333/// ```
334///
335/// See also the [`s!`] macro for the preferred way to create slices.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct Slice {
    /// Slice start index (inclusive). Negative values count from the end of the dimension.
    pub start: isize,
    /// Slice end index (exclusive). `None` means "up to the end of the dimension".
    pub end: Option<isize>,
    /// Step between elements (default: 1). Negative values traverse in reverse.
    pub step: isize,
}
345
/// Defines an [`Iterator`] over a [`Slice`].
///
/// Yields raw `isize` positions, beginning at the slice's `start` and advancing by its
/// `step`. No dimension size is involved, so negative values are yielded as-is.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct SliceIter {
    // The slice being traversed; supplies the `end` bound and the `step`.
    slice: Slice,
    // The value that the next call to `next` will consider yielding.
    current: isize,
}
352
353impl Iterator for SliceIter {
354    type Item = isize;
355
356    fn next(&mut self) -> Option<Self::Item> {
357        let next = self.current;
358        self.current += self.slice.step;
359
360        if let Some(end) = self.slice.end {
361            if self.slice.is_reversed() {
362                if next <= end {
363                    return None;
364                }
365            } else if next >= end {
366                return None;
367            }
368        }
369
370        Some(next)
371    }
372}
373
374/// Note: Unbounded [`Slice`]s produce infinite iterators.
375impl IntoIterator for Slice {
376    type Item = isize;
377    type IntoIter = SliceIter;
378
379    fn into_iter(self) -> Self::IntoIter {
380        SliceIter {
381            slice: self,
382            current: self.start,
383        }
384    }
385}
386
387impl Default for Slice {
388    fn default() -> Self {
389        Self::full()
390    }
391}
392
393impl Slice {
394    /// Creates a new slice with start, end, and step
395    pub const fn new(start: isize, end: Option<isize>, step: isize) -> Self {
396        assert!(step != 0, "Step cannot be zero");
397        Self { start, end, step }
398    }
399
400    /// Creates a slice that represents the full range.
401    pub const fn full() -> Self {
402        Self::new(0, None, 1)
403    }
404
405    /// Creates a slice that represents a single index
406    pub fn index(idx: isize) -> Self {
407        Self {
408            start: idx,
409            end: handle_signed_inclusive_end(idx),
410            step: 1,
411        }
412    }
413
414    /// Converts the slice to a vector.
415    pub fn into_vec(self) -> Vec<isize> {
416        assert!(
417            self.end.is_some(),
418            "Slice must have an end to convert to a vector: {self:?}"
419        );
420        self.into_iter().collect()
421    }
422
423    /// Clips the slice to a maximum size.
424    ///
425    /// # Example
426    ///
427    /// ```rust,ignore
428    /// assert_eq!(
429    ///     Slice::new(0, None, 1).bound_to(10),
430    ///     Slice::new(0, Some(10), 1));
431    /// assert_eq!(
432    ///     Slice::new(0, Some(5), 1).bound_to(10),
433    ///     Slice::new(0, Some(5), 1));
434    /// assert_eq!(
435    ///     Slice::new(0, None, -1).bound_to(10),
436    ///     Slice::new(0, Some(-11), -1));
437    /// assert_eq!(
438    ///     Slice::new(0, Some(-5), -1).bound_to(10),
439    ///     Slice::new(0, Some(-5), -1));
440    /// ```
441    pub fn bound_to(self, size: usize) -> Self {
442        let mut bounds = size as isize;
443
444        if let Some(end) = self.end {
445            if end > 0 {
446                bounds = end.min(bounds);
447            } else {
448                bounds = end.max(-(bounds + 1));
449            }
450        } else if self.is_reversed() {
451            bounds = -(bounds + 1);
452        }
453
454        Self {
455            end: Some(bounds),
456            ..self
457        }
458    }
459
460    /// Creates a slice with a custom step
461    pub fn with_step(start: isize, end: Option<isize>, step: isize) -> Self {
462        assert!(step != 0, "Step cannot be zero");
463        Self { start, end, step }
464    }
465
466    /// Creates a slice from a range with a specified step
467    pub fn from_range_stepped<R: Into<Slice>>(range: R, step: isize) -> Self {
468        assert!(step != 0, "Step cannot be zero");
469        let mut slice = range.into();
470        slice.step = step;
471        slice
472    }
473
474    /// Returns the step of the slice
475    pub fn step(&self) -> isize {
476        self.step
477    }
478
479    /// Returns the range for this slice given a dimension size
480    pub fn range(&self, size: usize) -> Range<usize> {
481        self.to_range(size)
482    }
483
484    /// Convert this slice to a range for a dimension of the given size.
485    ///
486    /// # Arguments
487    ///
488    /// * `size` - The size of the dimension to slice.
489    ///
490    /// # Returns
491    ///
492    /// A `Range<usize>` representing the slice bounds.
493    pub fn to_range(&self, size: usize) -> Range<usize> {
494        // Always return a valid range with start <= end
495        // The step information will be handled separately
496        let start = convert_signed_index(self.start, size);
497        let end = match self.end {
498            Some(end) => convert_signed_index(end, size),
499            None => size,
500        };
501        start..end
502    }
503
504    /// Converts the slice into a range and step tuple
505    pub fn to_range_and_step(&self, size: usize) -> (Range<usize>, isize) {
506        let range = self.to_range(size);
507        (range, self.step)
508    }
509
510    /// Returns true if the step is negative
511    pub fn is_reversed(&self) -> bool {
512        self.step < 0
513    }
514
515    /// Calculates the output size for this slice operation
516    pub fn output_size(&self, dim_size: usize) -> usize {
517        let range = self.to_range(dim_size);
518        // Handle empty slices (start >= end)
519        if range.start >= range.end {
520            return 0;
521        }
522        let len = range.end - range.start;
523        if self.step.unsigned_abs() == 1 {
524            len
525        } else {
526            len.div_ceil(self.step.unsigned_abs())
527        }
528    }
529}
530
/// Resolves a possibly negative `index` against a dimension of length `size`.
///
/// Negative indices count back from `size` and are clamped at `0`; non-negative indices
/// are clamped at `size`. The result is therefore always within `0..=size`.
fn convert_signed_index(index: isize, size: usize) -> usize {
    if index < 0 {
        // Stay in unsigned arithmetic: casting `size` to `isize` would wrap for sizes
        // above `isize::MAX`, and the signed addition could overflow.
        size.saturating_sub(index.unsigned_abs())
    } else {
        (index as usize).min(size)
    }
}
538
/// Converts an inclusive end index into the exclusive `end` stored in a [`Slice`].
///
/// An inclusive end of `-1` (the last element) has no exclusive negative counterpart,
/// because `0` would denote the start of the dimension, so it maps to `None`
/// ("to the end"). An inclusive end of `isize::MAX` also maps to `None` rather than
/// overflowing.
fn handle_signed_inclusive_end(end: isize) -> Option<isize> {
    match end {
        -1 => None,
        // `checked_add` yields `None` on overflow, i.e. an unbounded end.
        end => end.checked_add(1),
    }
}
545
546impl<I: AsIndex> From<Range<I>> for Slice {
547    fn from(r: Range<I>) -> Self {
548        Self {
549            start: r.start.index(),
550            end: Some(r.end.index()),
551            step: 1,
552        }
553    }
554}
555
556impl<I: AsIndex + Copy> From<RangeInclusive<I>> for Slice {
557    fn from(r: RangeInclusive<I>) -> Self {
558        Self {
559            start: (*r.start()).index(),
560            end: handle_signed_inclusive_end((*r.end()).index()),
561            step: 1,
562        }
563    }
564}
565
566impl<I: AsIndex> From<RangeFrom<I>> for Slice {
567    fn from(r: RangeFrom<I>) -> Self {
568        Self {
569            start: r.start.index(),
570            end: None,
571            step: 1,
572        }
573    }
574}
575
576impl<I: AsIndex> From<RangeTo<I>> for Slice {
577    fn from(r: RangeTo<I>) -> Self {
578        Self {
579            start: 0,
580            end: Some(r.end.index()),
581            step: 1,
582        }
583    }
584}
585
586impl<I: AsIndex> From<RangeToInclusive<I>> for Slice {
587    fn from(r: RangeToInclusive<I>) -> Self {
588        Self {
589            start: 0,
590            end: handle_signed_inclusive_end(r.end.index()),
591            step: 1,
592        }
593    }
594}
595
596impl From<RangeFull> for Slice {
597    fn from(_: RangeFull) -> Self {
598        Self {
599            start: 0,
600            end: None,
601            step: 1,
602        }
603    }
604}
605
606impl From<usize> for Slice {
607    fn from(i: usize) -> Self {
608        Slice::index(i as isize)
609    }
610}
611
612impl From<isize> for Slice {
613    fn from(i: isize) -> Self {
614        Slice::index(i)
615    }
616}
617
618impl From<i32> for Slice {
619    fn from(i: i32) -> Self {
620        Slice::index(i as isize)
621    }
622}
623
624impl From<i64> for Slice {
625    fn from(i: i64) -> Self {
626        Slice::index(i as isize)
627    }
628}
629
630impl Display for Slice {
631    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
632        if self.step == 1
633            && let Some(end) = self.end
634            && self.start == end - 1
635        {
636            f.write_fmt(format_args!("{}", self.start))
637        } else {
638            if self.start != 0 {
639                f.write_fmt(format_args!("{}", self.start))?;
640            }
641            f.write_str("..")?;
642            if let Some(end) = self.end {
643                f.write_fmt(format_args!("{}", end))?;
644            }
645            if self.step != 1 {
646                f.write_fmt(format_args!(";{}", self.step))?;
647            }
648            Ok(())
649        }
650    }
651}
652
653impl FromStr for Slice {
654    type Err = crate::ExpressionError;
655
656    fn from_str(source: &str) -> Result<Self, Self::Err> {
657        let mut s = source.trim();
658
659        let parse_int = |v: &str| -> Result<isize, Self::Err> {
660            v.parse::<isize>().map_err(|e| {
661                crate::ExpressionError::parse_error(
662                    format!("Invalid integer: '{v}': {}", e),
663                    source,
664                )
665            })
666        };
667
668        let mut start: isize = 0;
669        let mut end: Option<isize> = None;
670        let mut step: isize = 1;
671
672        if let Some((head, tail)) = s.split_once(";") {
673            step = parse_int(tail)?;
674            s = head;
675        }
676
677        if s.is_empty() {
678            return Err(crate::ExpressionError::parse_error(
679                "Empty expression",
680                source,
681            ));
682        }
683
684        if let Some((start_s, end_s)) = s.split_once("..") {
685            if !start_s.is_empty() {
686                start = parse_int(start_s)?;
687            }
688            if !end_s.is_empty() {
689                if let Some(end_s) = end_s.strip_prefix('=') {
690                    end = Some(parse_int(end_s)? + 1);
691                } else {
692                    end = Some(parse_int(end_s)?);
693                }
694            }
695        } else {
696            start = parse_int(s)?;
697            end = Some(start + 1);
698        }
699
700        if step == 0 {
701            return Err(crate::ExpressionError::invalid_expression(
702                "Step cannot be zero",
703                source,
704            ));
705        }
706
707        Ok(Slice::new(start, end, step))
708    }
709}
710
// Unit tests for `Slice` formatting, parsing, sizing, bounding, iteration, and for the
// `SliceArg::into_slices` conversions driven by the `s!` macro.
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::string::ToString;
    use alloc::vec;

    // `Display` output for full, single-index, ranged and stepped slices.
    #[test]
    fn test_slice_to_str() {
        assert_eq!(Slice::new(0, None, 1).to_string(), "..");

        assert_eq!(Slice::new(0, Some(1), 1).to_string(), "0");

        assert_eq!(Slice::new(0, Some(10), 1).to_string(), "..10");
        assert_eq!(Slice::new(1, Some(10), 1).to_string(), "1..10");

        assert_eq!(Slice::new(-3, Some(10), -2).to_string(), "-3..10;-2");
    }

    // `FromStr` parsing of valid expressions, followed by the error cases
    // (zero step, empty input, malformed integers).
    #[test]
    fn test_slice_from_str() {
        assert_eq!("1".parse::<Slice>(), Ok(Slice::new(1, Some(2), 1)));
        assert_eq!("..".parse::<Slice>(), Ok(Slice::new(0, None, 1)));
        assert_eq!("..3".parse::<Slice>(), Ok(Slice::new(0, Some(3), 1)));
        assert_eq!("..=3".parse::<Slice>(), Ok(Slice::new(0, Some(4), 1)));

        assert_eq!("-12..3".parse::<Slice>(), Ok(Slice::new(-12, Some(3), 1)));
        assert_eq!("..;-1".parse::<Slice>(), Ok(Slice::new(0, None, -1)));

        assert_eq!("..=3;-2".parse::<Slice>(), Ok(Slice::new(0, Some(4), -2)));

        assert_eq!(
            "..;0".parse::<Slice>(),
            Err(crate::ExpressionError::invalid_expression(
                "Step cannot be zero",
                "..;0"
            ))
        );

        assert_eq!(
            "".parse::<Slice>(),
            Err(crate::ExpressionError::parse_error("Empty expression", ""))
        );
        assert_eq!(
            "a".parse::<Slice>(),
            Err(crate::ExpressionError::parse_error(
                "Invalid integer: 'a': invalid digit found in string",
                "a"
            ))
        );
        assert_eq!(
            "..a".parse::<Slice>(),
            Err(crate::ExpressionError::parse_error(
                "Invalid integer: 'a': invalid digit found in string",
                "..a"
            ))
        );
        assert_eq!(
            "a:b:c".parse::<Slice>(),
            Err(crate::ExpressionError::parse_error(
                "Invalid integer: 'a:b:c': invalid digit found in string",
                "a:b:c"
            ))
        );
    }

    // `output_size` for unit, larger and negative steps, and for an empty range.
    #[test]
    fn test_slice_output_size() {
        // Test the output_size method directly
        assert_eq!(Slice::new(0, Some(10), 1).output_size(10), 10);
        assert_eq!(Slice::new(0, Some(10), 2).output_size(10), 5);
        assert_eq!(Slice::new(0, Some(10), 3).output_size(10), 4); // ceil(10/3)
        assert_eq!(Slice::new(0, Some(10), -1).output_size(10), 10);
        assert_eq!(Slice::new(0, Some(10), -2).output_size(10), 5);
        assert_eq!(Slice::new(2, Some(8), -3).output_size(10), 2); // ceil(6/3)
        assert_eq!(Slice::new(5, Some(5), 1).output_size(10), 0); // empty range
    }

    // `bound_to` fills in or keeps the end for forward and reversed slices.
    #[test]
    fn test_bound_to() {
        assert_eq!(
            Slice::new(0, None, 1).bound_to(10),
            Slice::new(0, Some(10), 1)
        );
        assert_eq!(
            Slice::new(0, Some(5), 1).bound_to(10),
            Slice::new(0, Some(5), 1)
        );

        assert_eq!(
            Slice::new(0, None, -1).bound_to(10),
            Slice::new(0, Some(-11), -1)
        );
        assert_eq!(
            Slice::new(0, Some(-5), -1).bound_to(10),
            Slice::new(0, Some(-5), -1)
        );
    }

    // Iteration over bounded, reversed, unbounded (truncated with `take`) and
    // `bound_to`-limited slices.
    #[test]
    fn test_slice_iter() {
        assert_eq!(
            Slice::new(2, Some(3), 1).into_iter().collect::<Vec<_>>(),
            vec![2]
        );
        assert_eq!(
            Slice::new(3, Some(-1), -1).into_iter().collect::<Vec<_>>(),
            vec![3, 2, 1, 0]
        );

        assert_eq!(Slice::new(3, Some(-1), -1).into_vec(), vec![3, 2, 1, 0]);

        assert_eq!(
            Slice::new(3, None, 2)
                .into_iter()
                .take(3)
                .collect::<Vec<_>>(),
            vec![3, 5, 7]
        );
        assert_eq!(
            Slice::new(3, None, 2)
                .bound_to(8)
                .into_iter()
                .collect::<Vec<_>>(),
            vec![3, 5, 7]
        );
    }

    // `into_vec` must panic on a slice without an end.
    #[test]
    #[should_panic(
        expected = "Slice must have an end to convert to a vector: Slice { start: 0, end: None, step: 1 }"
    )]
    fn test_unbound_slice_into_vec() {
        Slice::new(0, None, 1).into_vec();
    }

    // Fewer specifications than dimensions: the remaining dimensions get the full range.
    #[test]
    fn into_slices_should_return_for_all_shape_dims() {
        let slice = s![1];
        let shape = Shape::new([2, 3, 1]);

        let slices = slice.into_slices(&shape);

        assert_eq!(slices.len(), shape.len());

        assert_eq!(slices[0], Slice::new(1, Some(2), 1));
        assert_eq!(slices[1], Slice::new(0, Some(3), 1));
        assert_eq!(slices[2], Slice::new(0, Some(1), 1));

        let slice = s![1, 0..2];
        let slices = slice.into_slices(&shape);

        assert_eq!(slices.len(), shape.len());

        assert_eq!(slices[0], Slice::new(1, Some(2), 1));
        assert_eq!(slices[1], Slice::new(0, Some(2), 1));
        assert_eq!(slices[2], Slice::new(0, Some(1), 1));

        let slice = s![..];
        let slices = slice.into_slices(&shape);

        assert_eq!(slices.len(), shape.len());

        assert_eq!(slices[0], Slice::new(0, Some(2), 1));
        assert_eq!(slices[1], Slice::new(0, Some(3), 1));
        assert_eq!(slices[2], Slice::new(0, Some(1), 1));
    }

    // One specification per dimension.
    #[test]
    fn into_slices_all_dimensions() {
        let slice = s![1, ..2, ..];
        let shape = Shape::new([2, 3, 1]);

        let slices = slice.into_slices(&shape);

        assert_eq!(slices.len(), shape.len());

        assert_eq!(slices[0], Slice::new(1, Some(2), 1));
        assert_eq!(slices[1], Slice::new(0, Some(2), 1));
        assert_eq!(slices[2], Slice::new(0, Some(1), 1));
    }

    // A zero-sized dimension resolves to the empty range `0..0`.
    #[test]
    fn into_slices_supports_empty_dimensions() {
        let slice = s![.., 1, ..];
        let shape = Shape::new([0, 3, 1]);

        let slices = slice.into_slices(&shape);

        assert_eq!(slices.len(), shape.len());

        assert_eq!(slices[0], Slice::new(0, Some(0), 1));
        assert_eq!(slices[1], Slice::new(1, Some(2), 1));
        assert_eq!(slices[2], Slice::new(0, Some(1), 1));
    }

    // More specifications than dimensions must panic.
    #[test]
    #[should_panic = "Too many slices provided for shape"]
    fn into_slices_should_match_shape_rank() {
        let slice = s![.., 1, ..];
        let shape = Shape::new([3, 1]);

        let _ = slice.into_slices(&shape);
    }
}