// burn_std/tensor/slice.rs

1//! Tensor slice utilities.
2
3use crate::Shape;
4use crate::indexing::AsIndex;
5use alloc::vec::Vec;
6use core::fmt::{Display, Formatter};
7use core::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
8use core::str::FromStr;
9
/// Trait for slice arguments that can be converted into an array of slices.
///
/// This allows the `slice` method to accept both single slices (from `s![..]`)
/// and arrays of slices (from `s![.., ..]` or `[0..5, 1..3]`).
///
/// `D2` is the number of dimensions the argument addresses, i.e. the length of the
/// array returned by [`SliceArg::into_slices`].
pub trait SliceArg<const D2: usize> {
    /// Converts `self` into one [`Slice`] per addressed dimension.
    ///
    /// The implementations in this module resolve each slice's bounds against the
    /// matching entry of `shape` (negative and open-ended bounds become concrete,
    /// clamped indices) while keeping the original step.
    fn into_slices(self, shape: Shape) -> [Slice; D2];
}
17
18impl<const D2: usize, T> SliceArg<D2> for [T; D2]
19where
20    T: Into<Slice>,
21{
22    fn into_slices(self, shape: Shape) -> [Slice; D2] {
23        self.into_iter()
24            .enumerate()
25            .map(|(i, s)| {
26                let slice: Slice = s.into();
27                // Apply shape clamping by converting to range and back
28                let clamped_range = slice.to_range(shape[i]);
29                Slice::new(
30                    clamped_range.start as isize,
31                    Some(clamped_range.end as isize),
32                    slice.step(),
33                )
34            })
35            .collect::<Vec<_>>()
36            .try_into()
37            .unwrap()
38    }
39}
40
41impl<T> SliceArg<1> for T
42where
43    T: Into<Slice>,
44{
45    fn into_slices(self, shape: Shape) -> [Slice; 1] {
46        let slice: Slice = self.into();
47        let clamped_range = slice.to_range(shape[0]);
48        [Slice::new(
49            clamped_range.start as isize,
50            Some(clamped_range.end as isize),
51            slice.step(),
52        )]
53    }
54}
55
/// Slice argument constructor for tensor indexing.
///
/// The `s![]` macro is used to create multi-dimensional slice specifications for tensors.
/// A single dimension (`s![0..5]`) expands to one `Slice`, while several comma-separated
/// dimensions (`s![0..5, ..]`) expand to an array `[Slice; N]`. Both forms implement
/// `SliceArg`, so they can be passed to `tensor.slice()` and `tensor.slice_assign()`.
///
/// # Syntax Overview
///
/// ## Basic Forms
///
/// * **`s![index]`** - Select a single index (a length-1 range along that axis)
/// * **`s![range]`** - Slice a range of elements
/// * **`s![range;step]`** - Slice a range with a custom step
/// * **`s![dim1, dim2, ...]`** - Multiple dimensions, each can be any of the above forms
///
/// ## Range Types
///
/// All standard Rust range types are supported:
/// * **`a..b`** - From `a` (inclusive) to `b` (exclusive)
/// * **`a..=b`** - From `a` to `b` (both inclusive)
/// * **`a..`** - From `a` to the end
/// * **`..b`** - From the beginning to `b` (exclusive)
/// * **`..=b`** - From the beginning to `b` (inclusive)
/// * **`..`** - The full range (all elements)
///
/// ## Negative Indices
///
/// Negative indices count from the end of the axis:
/// * **`-1`** refers to the last element
/// * **`-2`** refers to the second-to-last element
/// * And so on...
///
/// This works in all range forms: `s![-3..-1]`, `s![-2..]`, `s![..-1]`
///
/// ## Step Syntax
///
/// Steps control the stride between selected elements:
/// * **`;step`** after a range specifies the step
/// * The step may be any integer expression; it is converted with `as isize`
/// * **Positive steps** select every nth element going forward
/// * **Negative steps** select every nth element going backward
/// * Default step is `1` when not specified
/// * Step cannot be `0` (building such a slice panics)
///
/// ### Negative Step Behavior
///
/// With negative steps, the range bounds still specify *which* elements to include,
/// but the traversal order is reversed:
///
/// * `s![0..5;-1]` selects indices `[4, 3, 2, 1, 0]` (not `[0, 1, 2, 3, 4]`)
/// * `s![2..8;-2]` selects indices `[7, 5, 3]` (starting from 7, going backward by 2)
/// * `s![..;-1]` reverses the entire axis
///
/// This matches the semantics of the ndarray crate's `s!` macro. It differs from
/// NumPy, where the bounds are written in traversal order (so `a[2:8:-2]` is empty
/// there).
///
/// # Examples
///
/// ## Basic Slicing
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, s};
///
/// # fn example<B: Backend>(tensor: Tensor<B, 3>) {
/// // Select rows 0-5 (exclusive)
/// let subset = tensor.slice(s![0..5, .., ..]);
///
/// // Select the last row
/// let last_row = tensor.slice(s![-1, .., ..]);
///
/// // Select columns 2, 3, 4
/// let cols = tensor.slice(s![.., 2..5, ..]);
///
/// // Select a single element at position [1, 2, 3]
/// let element = tensor.slice(s![1, 2, 3]);
/// # }
/// ```
///
/// ## Slicing with Steps
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, s};
///
/// # fn example<B: Backend>(tensor: Tensor<B, 2>) {
/// // Select every 2nd row
/// let even_rows = tensor.slice(s![0..10;2, ..]);
///
/// // Select every 3rd column
/// let cols = tensor.slice(s![.., 0..9;3]);
///
/// // Select every 2nd of rows 0..10 in reverse order (rows 9, 7, 5, 3, 1)
/// let reversed_odd = tensor.slice(s![0..10;-2, ..]);
/// # }
/// ```
///
/// ## Reversing Dimensions
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, s};
///
/// # fn example<B: Backend>(tensor: Tensor<B, 2>) {
/// // Reverse the first dimension
/// let reversed = tensor.slice(s![..;-1, ..]);
///
/// // Reverse both dimensions
/// let fully_reversed = tensor.slice(s![..;-1, ..;-1]);
///
/// // Reverse a specific range
/// let range_reversed = tensor.slice(s![2..8;-1, ..]);
/// # }
/// ```
///
/// ## Complex Multi-dimensional Slicing
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, s};
///
/// # fn example<B: Backend>(tensor: Tensor<B, 4>) {
/// // Mix of different slice types
/// let complex = tensor.slice(s![
///     0..10;2,    // Every 2nd element from 0 to 10
///     ..,         // All elements in dimension 1
///     5..15;-3,   // Every 3rd element from 14 down to 5
///     -1          // Last element in dimension 3
/// ]);
///
/// // Using inclusive ranges
/// let inclusive = tensor.slice(s![2..=5, 1..=3, .., ..]);
///
/// // Negative indices with steps
/// let from_end = tensor.slice(s![-5..-1;2, .., .., ..]);
/// # }
/// ```
///
/// ## Slice Assignment
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, s};
///
/// # fn example<B: Backend>(tensor: Tensor<B, 2>, values: Tensor<B, 2>) {
/// // Assign to every 2nd row
/// let tensor = tensor.slice_assign(s![0..10;2, ..], values.clone());
///
/// // Assign to a reversed slice
/// let tensor = tensor.slice_assign(s![..;-1, 0..5], values);
/// # }
/// ```
#[macro_export]
macro_rules! s {
    // Empty - should not happen
    [] => {
        compile_error!("Empty slice specification")
    };

    // Single expression with step.
    // The step is cast with `as isize` in every arm so that any integer type is
    // accepted consistently, whether or not the stepped range is the first one.
    // Reversed ranges such as `10..0` are legal (they select nothing), hence the
    // clippy allow on each expansion.
    [$range:expr; $step:expr] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::tensor::Slice::from_range_stepped($range, $step as isize)
            }
        }
    };

    // Single expression without step (no comma after)
    [$range:expr] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::tensor::Slice::from($range)
            }
        }
    };

    // Two or more expressions with first having step
    [$range:expr; $step:expr, $($rest:tt)*] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::s!(@internal [$crate::tensor::Slice::from_range_stepped($range, $step as isize)] $($rest)*)
            }
        }
    };

    // Two or more expressions with first not having step
    [$range:expr, $($rest:tt)*] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::s!(@internal [$crate::tensor::Slice::from($range)] $($rest)*)
            }
        }
    };

    // Internal: finished parsing
    (@internal [$($acc:expr),*]) => {
        [$($acc),*]
    };

    // Internal: parse range with step followed by comma
    (@internal [$($acc:expr),*] $range:expr; $step:expr, $($rest:tt)*) => {
        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from_range_stepped($range, $step as isize)] $($rest)*)
    };

    // Internal: parse range with step at end
    (@internal [$($acc:expr),*] $range:expr; $step:expr) => {
        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from_range_stepped($range, $step as isize)])
    };

    // Internal: parse range without step followed by comma
    (@internal [$($acc:expr),*] $range:expr, $($rest:tt)*) => {
        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from($range)] $($rest)*)
    };

    // Internal: parse range without step at end
    (@internal [$($acc:expr),*] $range:expr) => {
        $crate::s!(@internal [$($acc,)* $crate::tensor::Slice::from($range)])
    };
}
273
/// A slice specification for a single tensor dimension.
///
/// This struct represents a range with an optional step, used for advanced indexing
/// operations on tensors. It is typically created using the [`s!`] macro rather than
/// constructed directly.
///
/// # Fields
///
/// * `start` - The starting index (inclusive). Negative values count from the end.
/// * `end` - The ending index (exclusive). `None` means to the end of the dimension.
///   Negative values count from the end.
/// * `step` - The stride between elements. Must be non-zero: [`Slice::new`],
///   [`Slice::with_step`] and [`Slice::from_range_stepped`] panic on a zero step, but
///   the fields are public, so a value built with a struct literal is not checked.
///
/// # Index Interpretation
///
/// - **Positive indices**: Count from the beginning (0-based)
/// - **Negative indices**: Count from the end (-1 is the last element)
/// - **Bounds checking**: When the slice is resolved against a dimension size (see
///   [`Slice::to_range`]), out-of-range indices are clamped to `0..=size` rather than
///   causing a panic
///
/// # Step Behavior
///
/// - **Positive step**: Traverse forward through the range
/// - **Negative step**: Traverse backward through the range
/// - **Step size**: `|step|` is the distance between consecutive selected elements
///   (see [`Slice::output_size`])
///
/// # Examples
///
/// While you typically use the [`s!`] macro, you can also construct slices directly:
///
/// ```rust,ignore
/// use burn_tensor::Slice;
///
/// // Equivalent to s![2..8]
/// let slice1 = Slice::new(2, Some(8), 1);
///
/// // Equivalent to s![0..10;2]
/// let slice2 = Slice::new(0, Some(10), 2);
///
/// // Equivalent to s![..;-1] (reverse)
/// let slice3 = Slice::new(0, None, -1);
/// ```
///
/// See also the [`s!`] macro for the preferred way to create slices.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct Slice {
    /// Slice start index (inclusive); negative values count from the end.
    pub start: isize,
    /// Slice end index (exclusive); `None` means the end of the dimension.
    pub end: Option<isize>,
    /// Step between elements (default: 1); negative values traverse backward.
    pub step: isize,
}
325
326impl Default for Slice {
327    fn default() -> Self {
328        Self::full()
329    }
330}
331
332impl Slice {
333    /// Creates a new slice with start, end, and step
334    pub const fn new(start: isize, end: Option<isize>, step: isize) -> Self {
335        assert!(step != 0, "Step cannot be zero");
336        Self { start, end, step }
337    }
338
339    /// Creates a slice that represents the full range.
340    pub const fn full() -> Self {
341        Self::new(0, None, 1)
342    }
343
344    /// Creates a slice that represents a single index
345    pub fn index(idx: isize) -> Self {
346        Self {
347            start: idx,
348            end: handle_signed_inclusive_end(idx),
349            step: 1,
350        }
351    }
352
353    /// Creates a slice with a custom step
354    pub fn with_step(start: isize, end: Option<isize>, step: isize) -> Self {
355        assert!(step != 0, "Step cannot be zero");
356        Self { start, end, step }
357    }
358
359    /// Creates a slice from a range with a specified step
360    pub fn from_range_stepped<R: Into<Slice>>(range: R, step: isize) -> Self {
361        assert!(step != 0, "Step cannot be zero");
362        let mut slice = range.into();
363        slice.step = step;
364        slice
365    }
366
367    /// Returns the step of the slice
368    pub fn step(&self) -> isize {
369        self.step
370    }
371
372    /// Returns the range for this slice given a dimension size
373    pub fn range(&self, size: usize) -> Range<usize> {
374        self.to_range(size)
375    }
376
377    /// Convert this slice to a range for a dimension of the given size.
378    ///
379    /// # Arguments
380    ///
381    /// * `size` - The size of the dimension to slice.
382    ///
383    /// # Returns
384    ///
385    /// A `Range<usize>` representing the slice bounds.
386    pub fn to_range(&self, size: usize) -> Range<usize> {
387        // Always return a valid range with start <= end
388        // The step information will be handled separately
389        let start = convert_signed_index(self.start, size);
390        let end = match self.end {
391            Some(end) => convert_signed_index(end, size),
392            None => size,
393        };
394        start..end
395    }
396
397    /// Converts the slice into a range and step tuple
398    pub fn to_range_and_step(&self, size: usize) -> (Range<usize>, isize) {
399        let range = self.to_range(size);
400        (range, self.step)
401    }
402
403    /// Returns true if the step is negative
404    pub fn is_reversed(&self) -> bool {
405        self.step < 0
406    }
407
408    /// Calculates the output size for this slice operation
409    pub fn output_size(&self, dim_size: usize) -> usize {
410        let range = self.to_range(dim_size);
411        // Handle empty slices (start >= end)
412        if range.start >= range.end {
413            return 0;
414        }
415        let len = range.end - range.start;
416        if self.step.unsigned_abs() == 1 {
417            len
418        } else {
419            len.div_ceil(self.step.unsigned_abs())
420        }
421    }
422}
423
/// Resolves a possibly negative index against a dimension of `size` elements.
///
/// Non-negative indices are clamped to `size`; negative ones count back from
/// `size` (`-1` becomes `size - 1`) and are clamped at `0`.
fn convert_signed_index(index: isize, size: usize) -> usize {
    if index >= 0 {
        return size.min(index as usize);
    }
    let from_end = size as isize + index;
    if from_end < 0 {
        0
    } else {
        from_end as usize
    }
}
431
/// Converts an inclusive end index into the exclusive `end` stored in a [`Slice`].
///
/// `-1` (the last element) maps to `None`, i.e. "through the end of the dimension",
/// because `-1 + 1 == 0` would otherwise denote an empty range. `isize::MAX` also
/// maps to `None`: `end + 1` would overflow, and any exclusive end beyond the
/// dimension is clamped to its size when resolved anyway.
fn handle_signed_inclusive_end(end: isize) -> Option<isize> {
    match end {
        -1 => None,
        end => end.checked_add(1),
    }
}
438
439impl<I: AsIndex> From<Range<I>> for Slice {
440    fn from(r: Range<I>) -> Self {
441        Self {
442            start: r.start.index(),
443            end: Some(r.end.index()),
444            step: 1,
445        }
446    }
447}
448
449impl<I: AsIndex + Copy> From<RangeInclusive<I>> for Slice {
450    fn from(r: RangeInclusive<I>) -> Self {
451        Self {
452            start: (*r.start()).index(),
453            end: handle_signed_inclusive_end((*r.end()).index()),
454            step: 1,
455        }
456    }
457}
458
459impl<I: AsIndex> From<RangeFrom<I>> for Slice {
460    fn from(r: RangeFrom<I>) -> Self {
461        Self {
462            start: r.start.index(),
463            end: None,
464            step: 1,
465        }
466    }
467}
468
469impl<I: AsIndex> From<RangeTo<I>> for Slice {
470    fn from(r: RangeTo<I>) -> Self {
471        Self {
472            start: 0,
473            end: Some(r.end.index()),
474            step: 1,
475        }
476    }
477}
478
479impl<I: AsIndex> From<RangeToInclusive<I>> for Slice {
480    fn from(r: RangeToInclusive<I>) -> Self {
481        Self {
482            start: 0,
483            end: handle_signed_inclusive_end(r.end.index()),
484            step: 1,
485        }
486    }
487}
488
489impl From<RangeFull> for Slice {
490    fn from(_: RangeFull) -> Self {
491        Self {
492            start: 0,
493            end: None,
494            step: 1,
495        }
496    }
497}
498
499impl From<usize> for Slice {
500    fn from(i: usize) -> Self {
501        Slice::index(i as isize)
502    }
503}
504
505impl From<isize> for Slice {
506    fn from(i: isize) -> Self {
507        Slice::index(i)
508    }
509}
510
511impl From<i32> for Slice {
512    fn from(i: i32) -> Self {
513        Slice::index(i as isize)
514    }
515}
516
517impl Display for Slice {
518    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
519        if self.step == 1
520            && let Some(end) = self.end
521            && self.start == end - 1
522        {
523            f.write_fmt(format_args!("{}", self.start))
524        } else {
525            if self.start != 0 {
526                f.write_fmt(format_args!("{}", self.start))?;
527            }
528            f.write_str("..")?;
529            if let Some(end) = self.end {
530                f.write_fmt(format_args!("{}", end))?;
531            }
532            if self.step != 1 {
533                f.write_fmt(format_args!(";{}", self.step))?;
534            }
535            Ok(())
536        }
537    }
538}
539
540impl FromStr for Slice {
541    type Err = SliceExpressionError;
542
543    fn from_str(source: &str) -> Result<Self, Self::Err> {
544        let mut s = source.trim();
545
546        let parse_int = |v: &str| -> Result<isize, Self::Err> {
547            v.parse::<isize>().map_err(|e| {
548                SliceExpressionError::parse_error(format!("Invalid integer: '{v}': {}", e), source)
549            })
550        };
551
552        let mut start: isize = 0;
553        let mut end: Option<isize> = None;
554        let mut step: isize = 1;
555
556        if let Some((head, tail)) = s.split_once(";") {
557            step = parse_int(tail)?;
558            s = head;
559        }
560
561        if s.is_empty() {
562            return Err(SliceExpressionError::parse_error(
563                "Empty expression",
564                source,
565            ));
566        }
567
568        if let Some((start_s, end_s)) = s.split_once("..") {
569            if !start_s.is_empty() {
570                start = parse_int(start_s)?;
571            }
572            if !end_s.is_empty() {
573                if let Some(end_s) = end_s.strip_prefix('=') {
574                    end = Some(parse_int(end_s)? + 1);
575                } else {
576                    end = Some(parse_int(end_s)?);
577                }
578            }
579        } else {
580            start = parse_int(s)?;
581            end = Some(start + 1);
582        }
583
584        if step == 0 {
585            return Err(SliceExpressionError::invalid_expression(
586                "Step cannot be zero",
587                source,
588            ));
589        }
590
591        Ok(Slice::new(start, end, step))
592    }
593}
594
595#[allow(unused_imports)]
596use alloc::format;
597use alloc::string::String;
598
/// Error produced while parsing a slice expression (the `Err` type of `Slice`'s
/// `FromStr` implementation).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SliceExpressionError {
    /// The expression could not be parsed, e.g. it is empty or contains a value
    /// that is not a valid integer.
    ParseError {
        /// The error message.
        message: String,
        /// The source expression.
        source: String,
    },

    /// The expression is syntactically valid but not allowed, e.g. a zero step.
    InvalidExpression {
        /// The error message.
        message: String,
        /// The source expression.
        source: String,
    },
}
618
619impl SliceExpressionError {
620    /// Constructs a new `ParseError`.
621    ///
622    /// This function is a utility for creating instances where a parsing error needs to be represented,
623    /// encapsulating a descriptive error message and the source of the error.
624    ///
625    /// # Parameters
626    ///
627    /// - `message`: A value that can be converted into a `String`, representing a human-readable description
628    ///   of the parsing error.
629    /// - `source`: A value that can be converted into a `String`, typically identifying the origin or
630    ///   input that caused the parsing error.
631    pub fn parse_error(message: impl Into<String>, source: impl Into<String>) -> Self {
632        Self::ParseError {
633            message: message.into(),
634            source: source.into(),
635        }
636    }
637
638    /// Creates a new `InvalidExpression`.
639    ///
640    /// # Parameters
641    /// - `message`: A detailed message describing the nature of the invalid expression.
642    ///   Accepts any type that can be converted into a `String`.
643    /// - `source`: The source or context in which the invalid expression occurred.
644    ///   Accepts any type that can be converted into a `String`.
645    pub fn invalid_expression(message: impl Into<String>, source: impl Into<String>) -> Self {
646        Self::InvalidExpression {
647            message: message.into(),
648            source: source.into(),
649        }
650    }
651}
652
// Unit tests for slice error construction, `Display`/`FromStr` formatting and
// parsing, and `Slice::output_size`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alloc::string::ToString;

    // The `Debug` output of each error variant exposes both message and source.
    #[test]
    fn test_parse_error() {
        let err = SliceExpressionError::parse_error("test", "source");
        assert_eq!(
            format!("{:?}", err),
            "ParseError { message: \"test\", source: \"source\" }"
        );
    }

    #[test]
    fn test_invalid_expression() {
        let err = SliceExpressionError::invalid_expression("test", "source");
        assert_eq!(
            format!("{:?}", err),
            "InvalidExpression { message: \"test\", source: \"source\" }"
        );
    }

    // `Display` omits a zero start, a missing end and a unit step, and prints a
    // single-element unit-step slice as its index.
    #[test]
    fn test_slice_to_str() {
        assert_eq!(Slice::new(0, None, 1).to_string(), "..");

        assert_eq!(Slice::new(0, Some(1), 1).to_string(), "0");

        assert_eq!(Slice::new(0, Some(10), 1).to_string(), "..10");
        assert_eq!(Slice::new(1, Some(10), 1).to_string(), "1..10");

        assert_eq!(Slice::new(-3, Some(10), -2).to_string(), "-3..10;-2");
    }

    // `FromStr` accepts indices, ranges (including `..=`) and an optional `;step`,
    // and reports empty input, bad integers and zero steps as errors.
    #[test]
    fn test_slice_from_str() {
        assert_eq!("1".parse::<Slice>(), Ok(Slice::new(1, Some(2), 1)));
        assert_eq!("..".parse::<Slice>(), Ok(Slice::new(0, None, 1)));
        assert_eq!("..3".parse::<Slice>(), Ok(Slice::new(0, Some(3), 1)));
        assert_eq!("..=3".parse::<Slice>(), Ok(Slice::new(0, Some(4), 1)));

        assert_eq!("-12..3".parse::<Slice>(), Ok(Slice::new(-12, Some(3), 1)));
        assert_eq!("..;-1".parse::<Slice>(), Ok(Slice::new(0, None, -1)));

        assert_eq!("..=3;-2".parse::<Slice>(), Ok(Slice::new(0, Some(4), -2)));

        assert_eq!(
            "..;0".parse::<Slice>(),
            Err(SliceExpressionError::invalid_expression(
                "Step cannot be zero",
                "..;0"
            ))
        );

        assert_eq!(
            "".parse::<Slice>(),
            Err(SliceExpressionError::parse_error("Empty expression", ""))
        );
        assert_eq!(
            "a".parse::<Slice>(),
            Err(SliceExpressionError::parse_error(
                "Invalid integer: 'a': invalid digit found in string",
                "a"
            ))
        );
        assert_eq!(
            "..a".parse::<Slice>(),
            Err(SliceExpressionError::parse_error(
                "Invalid integer: 'a': invalid digit found in string",
                "..a"
            ))
        );
        assert_eq!(
            "a:b:c".parse::<Slice>(),
            Err(SliceExpressionError::parse_error(
                "Invalid integer: 'a:b:c': invalid digit found in string",
                "a:b:c"
            ))
        );
    }

    // `output_size` is `ceil(len / |step|)`, independent of the step's sign.
    #[test]
    fn test_slice_output_size() {
        // Test the output_size method directly
        assert_eq!(Slice::new(0, Some(10), 1).output_size(10), 10);
        assert_eq!(Slice::new(0, Some(10), 2).output_size(10), 5);
        assert_eq!(Slice::new(0, Some(10), 3).output_size(10), 4); // ceil(10/3)
        assert_eq!(Slice::new(0, Some(10), -1).output_size(10), 10);
        assert_eq!(Slice::new(0, Some(10), -2).output_size(10), 5);
        assert_eq!(Slice::new(2, Some(8), -3).output_size(10), 2); // ceil(6/3)
        assert_eq!(Slice::new(5, Some(5), 1).output_size(10), 0); // empty range
    }
}