burn_tensor/tensor/api/slice.rs
1use alloc::vec::Vec;
2
3use crate::Shape;
4use crate::indexing::AsIndex;
5use core::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
6
7/// Trait for slice arguments that can be converted into an array of slices.
8/// This allows the `slice` method to accept both single slices (from `s![..]`)
9/// and arrays of slices (from `s![.., ..]` or `[0..5, 1..3]`).
10pub trait SliceArg<const D2: usize> {
11 /// Convert to an array of slices with clamping to shape dimensions
12 fn into_slices(self, shape: Shape) -> [Slice; D2];
13}
14
15impl<const D2: usize, T> SliceArg<D2> for [T; D2]
16where
17 T: Into<Slice>,
18{
19 fn into_slices(self, shape: Shape) -> [Slice; D2] {
20 self.into_iter()
21 .enumerate()
22 .map(|(i, s)| {
23 let slice: Slice = s.into();
24 // Apply shape clamping by converting to range and back
25 let clamped_range = slice.to_range(shape[i]);
26 Slice::new(
27 clamped_range.start as isize,
28 Some(clamped_range.end as isize),
29 slice.step(),
30 )
31 })
32 .collect::<Vec<_>>()
33 .try_into()
34 .unwrap()
35 }
36}
37
38impl<T> SliceArg<1> for T
39where
40 T: Into<Slice>,
41{
42 fn into_slices(self, shape: Shape) -> [Slice; 1] {
43 let slice: Slice = self.into();
44 let clamped_range = slice.to_range(shape[0]);
45 [Slice::new(
46 clamped_range.start as isize,
47 Some(clamped_range.end as isize),
48 slice.step(),
49 )]
50 }
51}
52
/// Slice argument constructor for tensor indexing.
///
/// The `s![]` macro builds slice specifications for `tensor.slice()` and
/// `tensor.slice_assign()`. A specification for a single dimension expands to one
/// `Slice`; a comma-separated specification expands to an array holding one `Slice`
/// per dimension.
///
/// # Syntax Overview
///
/// ## Basic Forms
///
/// * **`s![index]`** - Select a single element along an axis (a one-element slice
///   starting at `index`)
/// * **`s![range]`** - Slice a range of elements
/// * **`s![range;step]`** - Slice a range with a custom step
/// * **`s![dim1, dim2, ...]`** - Multiple dimensions, each can be any of the above forms
///
/// ## Range Types
///
/// All standard Rust range types are supported:
/// * **`a..b`** - From `a` (inclusive) to `b` (exclusive)
/// * **`a..=b`** - From `a` to `b` (both inclusive)
/// * **`a..`** - From `a` to the end
/// * **`..b`** - From the beginning to `b` (exclusive)
/// * **`..=b`** - From the beginning to `b` (inclusive)
/// * **`..`** - The full range (all elements)
///
/// ## Negative Indices
///
/// Negative indices count from the end of the axis:
/// * **`-1`** refers to the last element
/// * **`-2`** refers to the second-to-last element
/// * And so on...
///
/// This works in all range forms: `s![-3..-1]`, `s![-2..]`, `s![..-1]`
///
/// ## Step Syntax
///
/// Steps control the stride between selected elements:
/// * **`;step`** after a range specifies the step
/// * **Positive steps** select every nth element going forward
/// * **Negative steps** select every nth element going backward
/// * Default step is `1` when not specified
/// * Step cannot be `0` (constructing the slice panics)
/// * The step expression is cast with `as isize`, in every position
///
/// ### Negative Step Behavior
///
/// With negative steps, the range bounds still specify *which* elements to include,
/// but the traversal order is reversed:
///
/// * `s![0..5;-1]` selects indices `[4, 3, 2, 1, 0]` (not `[0, 1, 2, 3, 4]`)
/// * `s![2..8;-2]` selects indices `[7, 5, 3]` (starting from 7, going backward by 2)
/// * `s![..;-1]` reverses the entire axis
///
/// This matches the semantics of the `s!` macro of the ndarray crate. It differs
/// from NumPy, where the bounds are swapped for negative steps (`a[0:5:-1]` is empty).
///
/// # Examples
///
/// ## Basic Slicing
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, s};
///
/// # fn example<B: Backend>(tensor: Tensor<B, 3>) {
/// // Select rows 0-5 (exclusive)
/// let subset = tensor.slice(s![0..5, .., ..]);
///
/// // Select the last row
/// let last_row = tensor.slice(s![-1, .., ..]);
///
/// // Select columns 2, 3, 4
/// let cols = tensor.slice(s![.., 2..5, ..]);
///
/// // Select a single element at position [1, 2, 3]
/// let element = tensor.slice(s![1, 2, 3]);
/// # }
/// ```
///
/// ## Slicing with Steps
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, s};
///
/// # fn example<B: Backend>(tensor: Tensor<B, 2>) {
/// // Select every 2nd row
/// let even_rows = tensor.slice(s![0..10;2, ..]);
///
/// // Select every 3rd column
/// let cols = tensor.slice(s![.., 0..9;3]);
///
/// // Select every 2nd row among rows 0-9, in reverse order (9, 7, 5, 3, 1)
/// let reversed_even = tensor.slice(s![0..10;-2, ..]);
/// # }
/// ```
///
/// ## Reversing Dimensions
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, s};
///
/// # fn example<B: Backend>(tensor: Tensor<B, 2>) {
/// // Reverse the first dimension
/// let reversed = tensor.slice(s![..;-1, ..]);
///
/// // Reverse both dimensions
/// let fully_reversed = tensor.slice(s![..;-1, ..;-1]);
///
/// // Reverse a specific range
/// let range_reversed = tensor.slice(s![2..8;-1, ..]);
/// # }
/// ```
///
/// ## Complex Multi-dimensional Slicing
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, s};
///
/// # fn example<B: Backend>(tensor: Tensor<B, 4>) {
/// // Mix of different slice types
/// let complex = tensor.slice(s![
///     0..10;2,    // Every 2nd element from 0 to 10
///     ..,         // All elements in dimension 1
///     5..15;-3,   // Every 3rd element from 14 down to 5
///     -1          // Last element in dimension 3
/// ]);
///
/// // Using inclusive ranges
/// let inclusive = tensor.slice(s![2..=5, 1..=3, .., ..]);
///
/// // Negative indices with steps
/// let from_end = tensor.slice(s![-5..-1;2, .., .., ..]);
/// # }
/// ```
///
/// ## Slice Assignment
///
/// ```rust,ignore
/// use burn_tensor::{Tensor, s};
///
/// # fn example<B: Backend>(tensor: Tensor<B, 2>, values: Tensor<B, 2>) {
/// // Assign to every 2nd row
/// let tensor = tensor.slice_assign(s![0..10;2, ..], values);
///
/// // Assign to a reversed slice
/// let tensor = tensor.slice_assign(s![..;-1, 0..5], values);
/// # }
/// ```
#[macro_export]
macro_rules! s {
    // Empty specification: rejected at compile time.
    [] => {
        compile_error!("Empty slice specification")
    };

    // NOTE on the `#[allow(clippy::reversed_empty_ranges)]` blocks below: a literal
    // range such as `2..-1` looks reversed to clippy, but is a meaningful slice
    // once negative indices are resolved against the dimension size.

    // Single expression with step
    [$range:expr; $step:expr] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                // Cast like the `@internal` rules do, so that a step is accepted
                // with the same types in every position.
                $crate::Slice::from_range_stepped($range, $step as isize)
            }
        }
    };

    // Single expression without step (no comma after)
    [$range:expr] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::Slice::from($range)
            }
        }
    };

    // Two or more expressions with first having step
    [$range:expr; $step:expr, $($rest:tt)*] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::s!(@internal [$crate::Slice::from_range_stepped($range, $step as isize)] $($rest)*)
            }
        }
    };

    // Two or more expressions with first not having step
    [$range:expr, $($rest:tt)*] => {
        {
            #[allow(clippy::reversed_empty_ranges)]
            {
                $crate::s!(@internal [$crate::Slice::from($range)] $($rest)*)
            }
        }
    };

    // Internal: finished parsing, emit the accumulated slices as an array
    (@internal [$($acc:expr),*]) => {
        [$($acc),*]
    };

    // Internal: parse range with step followed by comma
    (@internal [$($acc:expr),*] $range:expr; $step:expr, $($rest:tt)*) => {
        $crate::s!(@internal [$($acc,)* $crate::Slice::from_range_stepped($range, $step as isize)] $($rest)*)
    };

    // Internal: parse range with step at end
    (@internal [$($acc:expr),*] $range:expr; $step:expr) => {
        $crate::s!(@internal [$($acc,)* $crate::Slice::from_range_stepped($range, $step as isize)])
    };

    // Internal: parse range without step followed by comma
    (@internal [$($acc:expr),*] $range:expr, $($rest:tt)*) => {
        $crate::s!(@internal [$($acc,)* $crate::Slice::from($range)] $($rest)*)
    };

    // Internal: parse range without step at end
    (@internal [$($acc:expr),*] $range:expr) => {
        $crate::s!(@internal [$($acc,)* $crate::Slice::from($range)])
    };
}
270
271/// A slice specification for a single tensor dimension.
272///
273/// This struct represents a range with an optional step, used for advanced indexing
274/// operations on tensors. It is typically created using the [`s!`] macro rather than
275/// constructed directly.
276///
277/// # Fields
278///
279/// * `start` - The starting index (inclusive). Negative values count from the end.
280/// * `end` - The ending index (exclusive). `None` means to the end of the dimension.
281/// * `step` - The stride between elements. Must be non-zero.
282///
283/// # Index Interpretation
284///
285/// - **Positive indices**: Count from the beginning (0-based)
286/// - **Negative indices**: Count from the end (-1 is the last element)
287/// - **Bounds checking**: Indices are clamped to valid ranges
288///
289/// # Step Behavior
290///
291/// - **Positive step**: Traverse forward through the range
292/// - **Negative step**: Traverse backward through the range
293/// - **Step size**: Determines how many elements to skip
294///
295/// # Examples
296///
297/// While you typically use the [`s!`] macro, you can also construct slices directly:
298///
299/// ```rust,ignore
300/// use burn_tensor::Slice;
301///
302/// // Equivalent to s![2..8]
303/// let slice1 = Slice::new(2, Some(8), 1);
304///
305/// // Equivalent to s![0..10;2]
306/// let slice2 = Slice::new(0, Some(10), 2);
307///
308/// // Equivalent to s![..;-1] (reverse)
309/// let slice3 = Slice::new(0, None, -1);
310/// ```
311///
312/// See also the [`s!`] macro for the preferred way to create slices.
313#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
314pub struct Slice {
315 /// Slice start index.
316 pub start: isize,
317 /// Slice end index (exclusive).
318 pub end: Option<isize>,
319 /// Step between elements (default: 1).
320 pub step: isize,
321}
322
323impl Default for Slice {
324 fn default() -> Self {
325 Self::full()
326 }
327}
328
329impl Slice {
330 /// Creates a new slice with start, end, and step
331 pub const fn new(start: isize, end: Option<isize>, step: isize) -> Self {
332 assert!(step != 0, "Step cannot be zero");
333 Self { start, end, step }
334 }
335
336 /// Creates a slice that represents the full range.
337 pub const fn full() -> Self {
338 Self::new(0, None, 1)
339 }
340
341 /// Creates a slice that represents a single index
342 pub fn index(idx: isize) -> Self {
343 Self {
344 start: idx,
345 end: handle_signed_inclusive_end(idx),
346 step: 1,
347 }
348 }
349
350 /// Creates a slice with a custom step
351 pub fn with_step(start: isize, end: Option<isize>, step: isize) -> Self {
352 assert!(step != 0, "Step cannot be zero");
353 Self { start, end, step }
354 }
355
356 /// Creates a slice from a range with a specified step
357 pub fn from_range_stepped<R: Into<Slice>>(range: R, step: isize) -> Self {
358 assert!(step != 0, "Step cannot be zero");
359 let mut slice = range.into();
360 slice.step = step;
361 slice
362 }
363
364 /// Returns the step of the slice
365 pub fn step(&self) -> isize {
366 self.step
367 }
368
369 /// Returns the range for this slice given a dimension size
370 pub fn range(&self, size: usize) -> Range<usize> {
371 self.to_range(size)
372 }
373
374 /// Convert this slice to a range for a dimension of the given size.
375 ///
376 /// # Arguments
377 ///
378 /// * `size` - The size of the dimension to slice.
379 ///
380 /// # Returns
381 ///
382 /// A `Range<usize>` representing the slice bounds.
383 pub fn to_range(&self, size: usize) -> Range<usize> {
384 // Always return a valid range with start <= end
385 // The step information will be handled separately
386 let start = convert_signed_index(self.start, size);
387 let end = match self.end {
388 Some(end) => convert_signed_index(end, size),
389 None => size,
390 };
391 start..end
392 }
393
394 /// Converts the slice into a range and step tuple
395 pub fn to_range_and_step(&self, size: usize) -> (Range<usize>, isize) {
396 let range = self.to_range(size);
397 (range, self.step)
398 }
399
400 /// Returns true if the step is negative
401 pub fn is_reversed(&self) -> bool {
402 self.step < 0
403 }
404
405 /// Calculates the output size for this slice operation
406 pub fn output_size(&self, dim_size: usize) -> usize {
407 let range = self.to_range(dim_size);
408 let len = range.end - range.start;
409 if self.step.unsigned_abs() == 1 {
410 len
411 } else {
412 len.div_ceil(self.step.unsigned_abs())
413 }
414 }
415}
416
/// Resolves a possibly negative index against a dimension of length `size`.
///
/// Non-negative indices are clamped to `size`. Negative indices count back from
/// the end of the dimension and are clamped to `0`.
fn convert_signed_index(index: isize, size: usize) -> usize {
    if index < 0 {
        // Working on the magnitude avoids the lossy `size as isize` cast (which
        // wraps for sizes above `isize::MAX`) and is well defined for `isize::MIN`.
        size.saturating_sub(index.unsigned_abs())
    } else {
        index.unsigned_abs().min(size)
    }
}
424
/// Converts an inclusive end index into the exclusive `end` stored in a `Slice`.
///
/// Returns `None` (an open end) when no exclusive index can express the bound:
/// for `-1`, whose successor `0` would no longer count from the end, and for
/// `isize::MAX`, whose successor does not fit in an `isize`.
fn handle_signed_inclusive_end(end: isize) -> Option<isize> {
    match end {
        -1 => None,
        // `checked_add` maps the overflowing `isize::MAX` case to an open end
        // instead of panicking (debug) or wrapping (release).
        end => end.checked_add(1),
    }
}
431
432impl<I: AsIndex> From<Range<I>> for Slice {
433 fn from(r: Range<I>) -> Self {
434 Self {
435 start: r.start.index(),
436 end: Some(r.end.index()),
437 step: 1,
438 }
439 }
440}
441
442impl<I: AsIndex + Copy> From<RangeInclusive<I>> for Slice {
443 fn from(r: RangeInclusive<I>) -> Self {
444 Self {
445 start: (*r.start()).index(),
446 end: handle_signed_inclusive_end((*r.end()).index()),
447 step: 1,
448 }
449 }
450}
451
452impl<I: AsIndex> From<RangeFrom<I>> for Slice {
453 fn from(r: RangeFrom<I>) -> Self {
454 Self {
455 start: r.start.index(),
456 end: None,
457 step: 1,
458 }
459 }
460}
461
462impl<I: AsIndex> From<RangeTo<I>> for Slice {
463 fn from(r: RangeTo<I>) -> Self {
464 Self {
465 start: 0,
466 end: Some(r.end.index()),
467 step: 1,
468 }
469 }
470}
471
472impl<I: AsIndex> From<RangeToInclusive<I>> for Slice {
473 fn from(r: RangeToInclusive<I>) -> Self {
474 Self {
475 start: 0,
476 end: handle_signed_inclusive_end(r.end.index()),
477 step: 1,
478 }
479 }
480}
481
482impl From<RangeFull> for Slice {
483 fn from(_: RangeFull) -> Self {
484 Self {
485 start: 0,
486 end: None,
487 step: 1,
488 }
489 }
490}
491
492impl From<usize> for Slice {
493 fn from(i: usize) -> Self {
494 Slice::index(i as isize)
495 }
496}
497
498impl From<isize> for Slice {
499 fn from(i: isize) -> Self {
500 Slice::index(i)
501 }
502}
503
504impl From<i32> for Slice {
505 fn from(i: i32) -> Self {
506 Slice::index(i as isize)
507 }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `Slice::output_size` over forward, backward and empty slices.
    #[test]
    fn test_slice_output_size() {
        // (start, end, step, dimension size, expected number of selected elements)
        let cases: [(isize, isize, isize, usize, usize); 7] = [
            (0, 10, 1, 10, 10),
            (0, 10, 2, 10, 5),
            (0, 10, 3, 10, 4), // ceil(10/3)
            (0, 10, -1, 10, 10),
            (0, 10, -2, 10, 5),
            (2, 8, -3, 10, 2), // ceil(6/3)
            (5, 5, 1, 10, 0),  // empty range
        ];

        for (start, end, step, dim_size, expected) in cases {
            let slice = Slice::new(start, Some(end), step);
            assert_eq!(slice.output_size(dim_size), expected);
        }
    }
}
525}