burn_tensor/tensor/api/slice.rs

1use core::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
2
/// Creates a slice specification for tensor indexing operations.
///
/// This macro simplifies the creation of tensor slices by allowing various range types
/// to be used together in a concise way. It supports all standard Rust range types
/// as well as negative indexing for accessing elements from the end of a dimension.
///
/// A single argument without a trailing comma expands to one `Slice`. Comma-separated
/// arguments expand to an array of `Slice`s, one per dimension. A trailing comma is
/// accepted and always produces an array, so `s![0..2,]` is a one-element array.
///
/// # Examples
///
/// ```rust,ignore
/// // Basic slicing
/// let slice = tensor.slice(s![0..5, .., 3]);
///
/// // Using negative indices (counting from the end)
/// let last_row = tensor.slice(s![-1, ..]);
///
/// // Mixed range types
/// let complex_slice = tensor.slice(s![2..5, .., 0..=3, -2..]);
///
/// // Trailing commas are allowed
/// let with_comma = tensor.slice(s![0..2, 1,]);
/// ```
#[macro_export]
macro_rules! s {
    [$range:expr] => {
        $crate::Slice::from($range)
    };

    [$($range:expr),+ $(,)?] => {
        [$($crate::Slice::from($range)),+]
    };
}
31
32/// A slice (range).
33///
34/// - `end` is an exclusive index.
35/// - Negative `start` or `end` indices are counted from the back of the axis.
36/// - If `end` is `None`, the slice extends to the end of the axis.
37///
38/// See also the [`s![]`](s!) macro.
39#[derive(new, Clone, Debug)]
40pub struct Slice {
41    /// Slice start index.
42    start: isize,
43    /// Slice end index (exclusive).
44    end: Option<isize>,
45}
46
47impl Slice {
48    /// Creates a slice that represents a single index
49    pub fn index(idx: isize) -> Self {
50        Self {
51            start: idx,
52            end: handle_signed_inclusive_end(idx),
53        }
54    }
55
56    pub(crate) fn into_range(self, size: usize) -> Range<usize> {
57        let start = convert_signed_index(self.start, size);
58
59        let end = match self.end {
60            Some(end) => convert_signed_index(end, size),
61            None => size,
62        };
63
64        start..end
65    }
66}
67
/// Resolves a possibly negative `index` against an axis of length `size`.
///
/// Non-negative indices are clamped to `size`. Negative indices count from the
/// back (`-1` is `size - 1`) and saturate at `0` when they reach past the front.
fn convert_signed_index(index: isize, size: usize) -> usize {
    if index < 0 {
        // Stay in `usize` arithmetic: casting `size` to `isize` would wrap for
        // sizes above `isize::MAX` and produce a bogus (clamped-to-zero) result.
        size.saturating_sub(index.unsigned_abs())
    } else {
        // `index` is non-negative here, so the cast is lossless.
        (index as usize).min(size)
    }
}
75
/// Converts an inclusive end bound into the equivalent exclusive one.
///
/// Returns `None` ("run to the end of the axis") in two cases, because the
/// exclusive bound would lie at or past the end of any axis:
/// - `end` is `-1`, the last element.
/// - `end + 1` would overflow `isize`.
fn handle_signed_inclusive_end(end: isize) -> Option<isize> {
    match end {
        // `-1 + 1 == 0` would wrongly mean "stop before the first element".
        -1 => None,
        // Only `isize::MAX` overflows, and it is past the end of any axis anyway.
        end => end.checked_add(1),
    }
}

/// A helper trait to convert different index types to a slice index.
///
/// Non-negative results index from the front of an axis; negative results
/// count from the back.
pub trait IndexConversion {
    /// Converts into a slice index.
    fn index(self) -> isize;
}

impl IndexConversion for usize {
    /// Converts to `isize`, saturating at `isize::MAX`.
    ///
    /// A plain `as` cast would wrap values above `isize::MAX` into negative
    /// numbers, which slicing then treats as counting from the back of the axis
    /// (`usize::MAX` would become `-1`, the last element). Saturating keeps such
    /// values past the end of the axis, where they are clamped to its length.
    fn index(self) -> isize {
        self.min(isize::MAX as usize) as isize
    }
}
94
95impl IndexConversion for isize {
96    fn index(self) -> isize {
97        self
98    }
99}
100
101// Default integer type
102impl IndexConversion for i32 {
103    fn index(self) -> isize {
104        self as isize
105    }
106}
107
108impl<I: IndexConversion> From<Range<I>> for Slice {
109    fn from(r: Range<I>) -> Self {
110        Self {
111            start: r.start.index(),
112            end: Some(r.end.index()),
113        }
114    }
115}
116
117impl<I: IndexConversion + Copy> From<RangeInclusive<I>> for Slice {
118    fn from(r: RangeInclusive<I>) -> Self {
119        Self {
120            start: (*r.start()).index(),
121            end: handle_signed_inclusive_end((*r.end()).index()),
122        }
123    }
124}
125
126impl<I: IndexConversion> From<RangeFrom<I>> for Slice {
127    fn from(r: RangeFrom<I>) -> Self {
128        Self {
129            start: r.start.index(),
130            end: None,
131        }
132    }
133}
134
135impl<I: IndexConversion> From<RangeTo<I>> for Slice {
136    fn from(r: RangeTo<I>) -> Self {
137        Self {
138            start: 0,
139            end: Some(r.end.index()),
140        }
141    }
142}
143
144impl<I: IndexConversion> From<RangeToInclusive<I>> for Slice {
145    fn from(r: RangeToInclusive<I>) -> Self {
146        Self {
147            start: 0,
148            end: handle_signed_inclusive_end(r.end.index()),
149        }
150    }
151}
152
153impl From<RangeFull> for Slice {
154    fn from(_: RangeFull) -> Self {
155        Self {
156            start: 0,
157            end: None,
158        }
159    }
160}
161
162impl From<usize> for Slice {
163    fn from(i: usize) -> Self {
164        Slice::index(i as isize)
165    }
166}
167
168impl From<isize> for Slice {
169    fn from(i: isize) -> Self {
170        Slice::index(i)
171    }
172}
173
174impl From<i32> for Slice {
175    fn from(i: i32) -> Self {
176        Slice::index(i as isize)
177    }
178}