burn_tensor/tensor/api/
slice.rs1use core::ops::{Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};

/// Builds tensor slice specifications from range-like arguments.
///
/// - With a single argument, expands to `Slice::from(arg)` and yields one `Slice`.
/// - With several comma-separated arguments, yields an array `[Slice; N]` with one entry
///   per argument, in order.
///
/// Each argument may be anything `Slice` has a `From` conversion for: a range (`a..b`,
/// `a..=b`, `a..`, `..b`, `..=b`, `..`) over `usize`/`isize`/`i32`, or a single integer
/// index. Negative values count back from the end of the axis.
#[macro_export]
macro_rules! s {
    // A lone argument produces a single `Slice`, not a one-element array.
    ($single:expr) => {
        $crate::Slice::from($single)
    };

    ($($each:expr),+) => {
        [$($crate::Slice::from($each)),+]
    };
}

32#[derive(new, Clone, Debug)]
40pub struct Slice {
41 start: isize,
43 end: Option<isize>,
45}
46
47impl Slice {
48 pub fn index(idx: isize) -> Self {
50 Self {
51 start: idx,
52 end: handle_signed_inclusive_end(idx),
53 }
54 }
55
56 pub(crate) fn into_range(self, size: usize) -> Range<usize> {
57 let start = convert_signed_index(self.start, size);
58
59 let end = match self.end {
60 Some(end) => convert_signed_index(end, size),
61 None => size,
62 };
63
64 start..end
65 }
66}
67
/// Resolves a possibly negative index against an axis of length `size`.
///
/// Non-negative indices are clamped to `size`; negative indices count back from `size`
/// and are clamped to `0`. The result is always within `0..=size`.
fn convert_signed_index(index: isize, size: usize) -> usize {
    // Work on the magnitude in `usize` so no `as` cast can wrap, even for axis lengths
    // above `isize::MAX`.
    let magnitude = index.unsigned_abs();
    if index < 0 {
        size.saturating_sub(magnitude)
    } else {
        magnitude.min(size)
    }
}

/// Converts an inclusive end index into the exclusive end stored in a `Slice`.
///
/// - `-1` (the last element) maps to `None`, i.e. "to the end of the axis", because
///   `-1 + 1 == 0` would instead denote the start of the axis.
/// - `isize::MAX` also maps to `None`: the exclusive end `isize::MAX + 1` is not
///   representable, and `None` resolves to the axis length, which is where such an end
///   would be clamped anyway.
/// - Every other `end` maps to `Some(end + 1)`.
fn handle_signed_inclusive_end(end: isize) -> Option<isize> {
    match end {
        -1 => None,
        // `checked_add` yields `None` only for `isize::MAX`, so this cannot overflow.
        end => end.checked_add(1),
    }
}

/// Conversion of an integer index into the signed `isize` index used by `Slice`.
pub trait IndexConversion {
    /// Returns this value as a signed index.
    fn index(self) -> isize;
}

impl IndexConversion for usize {
    fn index(self) -> isize {
        // `self as isize` would wrap values above `isize::MAX` to negative numbers, which
        // are interpreted as counting back from the end of the axis (`usize::MAX` would
        // become `-1`, the last element). Saturate instead: such an index is past the end
        // of the axis and is clamped when the slice is resolved.
        isize::try_from(self).unwrap_or(isize::MAX)
    }
}

95impl IndexConversion for isize {
96 fn index(self) -> isize {
97 self
98 }
99}
100
101impl IndexConversion for i32 {
103 fn index(self) -> isize {
104 self as isize
105 }
106}
107
108impl<I: IndexConversion> From<Range<I>> for Slice {
109 fn from(r: Range<I>) -> Self {
110 Self {
111 start: r.start.index(),
112 end: Some(r.end.index()),
113 }
114 }
115}
116
117impl<I: IndexConversion + Copy> From<RangeInclusive<I>> for Slice {
118 fn from(r: RangeInclusive<I>) -> Self {
119 Self {
120 start: (*r.start()).index(),
121 end: handle_signed_inclusive_end((*r.end()).index()),
122 }
123 }
124}
125
126impl<I: IndexConversion> From<RangeFrom<I>> for Slice {
127 fn from(r: RangeFrom<I>) -> Self {
128 Self {
129 start: r.start.index(),
130 end: None,
131 }
132 }
133}
134
135impl<I: IndexConversion> From<RangeTo<I>> for Slice {
136 fn from(r: RangeTo<I>) -> Self {
137 Self {
138 start: 0,
139 end: Some(r.end.index()),
140 }
141 }
142}
143
144impl<I: IndexConversion> From<RangeToInclusive<I>> for Slice {
145 fn from(r: RangeToInclusive<I>) -> Self {
146 Self {
147 start: 0,
148 end: handle_signed_inclusive_end(r.end.index()),
149 }
150 }
151}
152
153impl From<RangeFull> for Slice {
154 fn from(_: RangeFull) -> Self {
155 Self {
156 start: 0,
157 end: None,
158 }
159 }
160}
161
162impl From<usize> for Slice {
163 fn from(i: usize) -> Self {
164 Slice::index(i as isize)
165 }
166}
167
168impl From<isize> for Slice {
169 fn from(i: isize) -> Self {
170 Slice::index(i)
171 }
172}
173
174impl From<i32> for Slice {
175 fn from(i: i32) -> Self {
176 Slice::index(i as isize)
177 }
178}