use axonml_core::dtype::{Numeric, Scalar};
use axonml_core::error::{Error, Result};

use crate::shape::{numel, Shape};
use crate::tensor::Tensor;

/// Per-dimension slice specification.
#[derive(Debug, Clone, Copy)]
pub enum SliceSpec {
    /// Select a single index, removing the dimension.
    Index(isize),
    /// Select a strided half-open range `[start, stop)`.
    Range {
        /// Inclusive start; `None` means the beginning of the dimension.
        start: Option<isize>,
        /// Exclusive stop; `None` means the end of the dimension.
        stop: Option<isize>,
        /// Step between selected indices.
        step: isize,
    },
    /// Keep the whole dimension unchanged.
    All,
    /// Insert a new axis of size 1 at this position.
    NewAxis,
}

impl SliceSpec {
    /// Range `[start, stop)` with step 1.
    #[must_use]
    pub fn range(start: isize, stop: isize) -> Self {
        Self::Range {
            start: Some(start),
            stop: Some(stop),
            step: 1,
        }
    }

    /// Range `[start, stop)` with the given step.
    #[must_use]
    pub fn range_step(start: isize, stop: isize, step: isize) -> Self {
        Self::Range {
            start: Some(start),
            stop: Some(stop),
            step,
        }
    }

    /// Range from `start` (inclusive) to the end of the dimension.
    #[must_use]
    pub fn from(start: isize) -> Self {
        Self::Range {
            start: Some(start),
            stop: None,
            step: 1,
        }
    }

    /// Range from the beginning of the dimension to `stop` (exclusive).
    #[must_use]
    pub fn to(stop: isize) -> Self {
        Self::Range {
            start: None,
            stop: Some(stop),
            step: 1,
        }
    }
}
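
// Sketch: how Python-style slice syntax would correspond to `SliceSpec`
// values (illustrative only; the indexing entry point that consumes these
// specs is assumed to live elsewhere):
//     t[3]      -> SliceSpec::Index(3)
//     t[1:7]    -> SliceSpec::range(1, 7)
//     t[1:7:2]  -> SliceSpec::range_step(1, 7, 2)
//     t[2:]     -> SliceSpec::from(2)
//     t[:5]     -> SliceSpec::to(5)
//     t[:]      -> SliceSpec::All
//     t[None]   -> SliceSpec::NewAxis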

impl<T: Scalar> Tensor<T> {
    /// Returns a view of rows `start..end` along dimension 0.
    ///
    /// No data is copied: the view shares storage with `self` and only
    /// adjusts the shape and storage offset.
    pub fn slice_dim0(&self, start: usize, end: usize) -> Result<Self> {
        if self.ndim() == 0 {
            return Err(Error::invalid_operation("Cannot slice a scalar"));
        }

        let dim_size = self.shape[0];
        if start > end || end > dim_size {
            return Err(Error::IndexOutOfBounds {
                index: end,
                size: dim_size,
            });
        }

        let mut new_shape = self.shape.clone();
        new_shape[0] = end - start;

        let new_offset = self.offset + start * self.strides[0] as usize;

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: self.strides.clone(),
            offset: new_offset,
        })
    }

    /// Returns a view with dimension `dim` fixed at `index` and removed
    /// from the shape.
    pub fn select(&self, dim: usize, index: usize) -> Result<Self> {
        if dim >= self.ndim() {
            return Err(Error::InvalidDimension {
                index: dim as i64,
                ndim: self.ndim(),
            });
        }

        if index >= self.shape[dim] {
            return Err(Error::IndexOutOfBounds {
                index,
                size: self.shape[dim],
            });
        }

        let mut new_shape = self.shape.clone();
        new_shape.remove(dim);

        let mut new_strides = self.strides.clone();
        new_strides.remove(dim);

        let new_offset = self.offset + index * self.strides[dim] as usize;

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: new_strides,
            offset: new_offset,
        })
    }
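
    // Example: for a tensor `t` of shape [2, 3],
    //     t.select(0, 1)  // shape [3], the second row
    //     t.select(1, 2)  // shape [2], the third column
    // Both are zero-copy views into the same storage as `t`.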

    /// Returns a view of `length` elements along `dim`, starting at `start`.
    ///
    /// `slice_dim0(s, e)` is equivalent to `narrow(0, s, e - s)`.
    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
        if dim >= self.ndim() {
            return Err(Error::InvalidDimension {
                index: dim as i64,
                ndim: self.ndim(),
            });
        }

        if start + length > self.shape[dim] {
            return Err(Error::IndexOutOfBounds {
                index: start + length,
                size: self.shape[dim],
            });
        }

        let mut new_shape = self.shape.clone();
        new_shape[dim] = length;

        let new_offset = self.offset + start * self.strides[dim] as usize;

        Ok(Self {
            storage: self.storage.clone(),
            shape: new_shape,
            strides: self.strides.clone(),
            offset: new_offset,
        })
    }
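
    // Example: for a tensor `t` of shape [4, 5],
    //     t.narrow(1, 1, 3)  // shape [4, 3], columns 1..4
    // Unlike `select`, the narrowed dimension is kept (with the new length).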

    /// Splits the tensor into `chunks` roughly equal views along `dim`.
    ///
    /// Each chunk has `ceil(dim_size / chunks)` elements except possibly the
    /// last; fewer than `chunks` views may be returned when the dimension is
    /// too small.
    pub fn chunk(&self, chunks: usize, dim: usize) -> Result<Vec<Self>> {
        if dim >= self.ndim() {
            return Err(Error::InvalidDimension {
                index: dim as i64,
                ndim: self.ndim(),
            });
        }

        if chunks == 0 {
            return Err(Error::invalid_operation("Chunk count must be non-zero"));
        }

        let dim_size = self.shape[dim];
        let chunk_size = dim_size.div_ceil(chunks);
        let mut result = Vec::with_capacity(chunks);

        let mut start = 0;
        while start < dim_size {
            let length = chunk_size.min(dim_size - start);
            result.push(self.narrow(dim, start, length)?);
            start += length;
        }

        Ok(result)
    }
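
    // Example: a length-6 tensor chunked as `t.chunk(4, 0)` uses chunk size
    // ceil(6 / 4) = 2, so only 3 chunks come back, of sizes [2, 2, 2].
    // A length-5 tensor with `t.chunk(2, 0)` yields sizes [3, 2].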

    /// Splits the tensor into views of the given `sizes` along `dim`.
    ///
    /// The sizes must sum exactly to the dimension size.
    pub fn split(&self, sizes: &[usize], dim: usize) -> Result<Vec<Self>> {
        if dim >= self.ndim() {
            return Err(Error::InvalidDimension {
                index: dim as i64,
                ndim: self.ndim(),
            });
        }

        let total: usize = sizes.iter().sum();
        if total != self.shape[dim] {
            return Err(Error::invalid_operation(format!(
                "Split sizes {} don't sum to dimension size {}",
                total, self.shape[dim]
            )));
        }

        let mut result = Vec::with_capacity(sizes.len());
        let mut start = 0;

        for &size in sizes {
            result.push(self.narrow(dim, start, size)?);
            start += size;
        }

        Ok(result)
    }
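
    // Example: for a length-5 tensor, `t.split(&[2, 3], 0)` yields views of
    // lengths 2 and 3; `t.split(&[2, 2], 0)` fails because 2 + 2 != 5.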
}

impl<T: Numeric> Tensor<T> {
    /// Gathers values along `dim` according to an index tensor.
    ///
    /// For every position in `indices`, the output takes the element of
    /// `self` at the same multi-index, but with the coordinate at `dim`
    /// replaced by the gathered index. `indices` must have the same number
    /// of dimensions as `self` and must not exceed `self`'s size in any
    /// other dimension.
    pub fn gather(&self, dim: usize, indices: &Tensor<i64>) -> Result<Self> {
        if dim >= self.ndim() {
            return Err(Error::InvalidDimension {
                index: dim as i64,
                ndim: self.ndim(),
            });
        }

        let output_shape = indices.shape();
        if indices.ndim() != self.ndim() {
            return Err(Error::invalid_operation(
                "gather: indices must have the same number of dimensions as self",
            ));
        }
        for (d, (&i_sz, &s_sz)) in output_shape.iter().zip(self.shape().iter()).enumerate() {
            if d != dim && i_sz > s_sz {
                return Err(Error::shape_mismatch(output_shape, self.shape()));
            }
        }

        let mut output_data = vec![T::zero(); numel(output_shape)];

        let indices_data = indices.to_vec();
        let self_data = self.to_vec();

        for (out_idx, &index) in indices_data.iter().enumerate() {
            let index = index as usize;
            if index >= self.shape[dim] {
                return Err(Error::IndexOutOfBounds {
                    index,
                    size: self.shape[dim],
                });
            }
            // Decompose `out_idx` into a multi-index over the output shape,
            // substitute the gathered index at `dim`, and re-linearise over
            // `self`'s row-major shape. Indexing `self_data` with the raw
            // linear index would only be correct for 1-D tensors.
            let mut rem = out_idx;
            let mut src_idx = 0;
            let mut src_stride = 1;
            for d in (0..self.ndim()).rev() {
                let coord = rem % output_shape[d];
                rem /= output_shape[d];
                let src_coord = if d == dim { index } else { coord };
                src_idx += src_coord * src_stride;
                src_stride *= self.shape[d];
            }
            output_data[out_idx] = self_data[src_idx];
        }

        Tensor::from_vec(output_data, output_shape)
    }
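
    // Example: with `t = [[1, 2], [3, 4]]` and `idx = [[0, 0], [1, 0]]`,
    // `t.gather(1, &idx)` picks `out[i][j] = t[i][idx[i][j]]`, giving
    // `[[1, 1], [4, 3]]` (see `test_gather` below).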

    /// Returns a 1-D tensor of the elements where `mask` is `true`.
    ///
    /// The mask is given in flattened (row-major) element order and must
    /// have exactly `numel` entries.
    pub fn masked_select(&self, mask: &[bool]) -> Result<Self> {
        if mask.len() != self.numel() {
            return Err(Error::shape_mismatch(&[mask.len()], &[self.numel()]));
        }

        let data = self.to_vec();
        let selected: Vec<T> = data
            .into_iter()
            .zip(mask.iter())
            .filter(|(_, &m)| m)
            .map(|(v, _)| v)
            .collect();

        let len = selected.len();
        Tensor::from_vec(selected, &[len])
    }
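
    // Example: `t = [1, 2, 3, 4]` with mask `[true, false, true, false]`
    // selects `[1, 3]`; the result is always 1-D, whatever the input shape.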

    /// In place, sets every element whose mask entry is `true` to `value`.
    ///
    /// The mask is given in flattened (row-major) element order; the tensor
    /// must be contiguous.
    pub fn masked_fill_(&self, mask: &[bool], value: T) -> Result<()> {
        if mask.len() != self.numel() {
            return Err(Error::shape_mismatch(&[mask.len()], &[self.numel()]));
        }

        if !self.is_contiguous() {
            return Err(Error::NotContiguous);
        }

        {
            let mut guard = self.storage.as_slice_mut();
            for (idx, &m) in mask.iter().enumerate() {
                if m {
                    guard[self.offset + idx] = value;
                }
            }
        }

        Ok(())
    }
}

/// Concatenates tensors along an existing dimension.
///
/// All tensors must have the same number of dimensions and matching sizes
/// in every dimension except `dim`.
pub fn cat<T: Scalar>(tensors: &[Tensor<T>], dim: usize) -> Result<Tensor<T>> {
    if tensors.is_empty() {
        return Err(Error::invalid_operation("Cannot concatenate empty list"));
    }

    let first = &tensors[0];
    let ndim = first.ndim();

    if dim >= ndim {
        return Err(Error::InvalidDimension {
            index: dim as i64,
            ndim,
        });
    }

    for t in tensors.iter().skip(1) {
        if t.ndim() != ndim {
            return Err(Error::invalid_operation(
                "All tensors must have same number of dimensions",
            ));
        }
        for (d, (&s1, &s2)) in first.shape().iter().zip(t.shape().iter()).enumerate() {
            if d != dim && s1 != s2 {
                return Err(Error::shape_mismatch(first.shape(), t.shape()));
            }
        }
    }

    let mut output_shape = Shape::from_slice(first.shape());
    output_shape[dim] = tensors.iter().map(|t| t.shape()[dim]).sum();

    let total_numel = numel(&output_shape);
    let mut output_data = vec![T::zeroed(); total_numel];

    // Scatter each tensor into its slot by index math rather than appending
    // in storage order; a sequential copy is only correct for `dim == 0`.
    let mut dim_offset = 0;
    for t in tensors {
        let t_shape = t.shape();
        for (lin, val) in t.to_vec().into_iter().enumerate() {
            // Decompose `lin` over `t_shape`, shift the coordinate at `dim`,
            // and re-linearise over `output_shape`.
            let mut rem = lin;
            let mut out_idx = 0;
            let mut out_stride = 1;
            for d in (0..ndim).rev() {
                let mut coord = rem % t_shape[d];
                rem /= t_shape[d];
                if d == dim {
                    coord += dim_offset;
                }
                out_idx += coord * out_stride;
                out_stride *= output_shape[d];
            }
            output_data[out_idx] = val;
        }
        dim_offset += t_shape[dim];
    }

    Tensor::from_vec(output_data, &output_shape)
}
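
// Example: concatenating a [2, 2] tensor with a [2, 1] tensor along dim 1
// produces a [2, 3] tensor whose rows are the rows of each input, side by
// side (see `test_cat_dim1` below).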

/// Stacks tensors along a new dimension.
///
/// Each tensor is unsqueezed at `dim` and the results are concatenated, so
/// `n` tensors of shape `S` produce a result whose shape is `S` with `n`
/// inserted at position `dim`.
pub fn stack<T: Scalar>(tensors: &[Tensor<T>], dim: usize) -> Result<Tensor<T>> {
    if tensors.is_empty() {
        return Err(Error::invalid_operation("Cannot stack empty list"));
    }

    let unsqueezed: Result<Vec<Tensor<T>>> =
        tensors.iter().map(|t| t.unsqueeze(dim as i64)).collect();

    cat(&unsqueezed?, dim)
}
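
// Example: stacking three [4, 5] tensors with `stack(&ts, 0)` yields shape
// [3, 4, 5]; with `stack(&ts, 1)` it yields [4, 3, 5].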

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_slice_dim0() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();

        let s = t.slice_dim0(1, 3).unwrap();
        assert_eq!(s.shape(), &[2, 2]);
        assert_eq!(s.get(&[0, 0]).unwrap(), 3.0);
    }

    #[test]
    fn test_select() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();

        let s = t.select(0, 1).unwrap();
        assert_eq!(s.shape(), &[3]);
        assert_eq!(s.to_vec(), vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_narrow() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();

        let n = t.narrow(0, 1, 3).unwrap();
        assert_eq!(n.shape(), &[3]);
        assert_eq!(n.to_vec(), vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_chunk() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();

        let chunks = t.chunk(3, 0).unwrap();
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(chunks[1].to_vec(), vec![3.0, 4.0]);
        assert_eq!(chunks[2].to_vec(), vec![5.0, 6.0]);
    }

    #[test]
    fn test_cat() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap();

        let c = cat(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[4]);
        assert_eq!(c.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }
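
    // Exercises `cat` along a non-leading dimension, where elements must be
    // interleaved by index math rather than appended in storage order.
    #[test]
    fn test_cat_dim1() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![5.0, 6.0], &[2, 1]).unwrap();

        let c = cat(&[a, b], 1).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
        assert_eq!(c.to_vec(), vec![1.0, 2.0, 5.0, 3.0, 4.0, 6.0]);
    }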

    #[test]
    fn test_stack() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap();

        let c = stack(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[2, 2]);
    }
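
    // Extra coverage for the ops untested above. These sketches use only the
    // in-file API; `Tensor::<i64>::from_vec` additionally assumes `i64`
    // satisfies the same bound as other element types, which `gather`'s
    // `Tensor<i64>` index parameter implies.
    #[test]
    fn test_split() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();

        let parts = t.split(&[2, 3], 0).unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(parts[1].to_vec(), vec![3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_gather() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let idx = Tensor::<i64>::from_vec(vec![0, 0, 1, 0], &[2, 2]).unwrap();

        // out[i][j] = t[i][idx[i][j]]
        let g = t.gather(1, &idx).unwrap();
        assert_eq!(g.shape(), &[2, 2]);
        assert_eq!(g.to_vec(), vec![1.0, 1.0, 4.0, 3.0]);
    }

    #[test]
    fn test_masked_select_and_fill() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[4]).unwrap();
        let mask = [true, false, true, false];

        let s = t.masked_select(&mask).unwrap();
        assert_eq!(s.to_vec(), vec![1.0, 3.0]);

        t.masked_fill_(&mask, 9.0).unwrap();
        assert_eq!(t.to_vec(), vec![9.0, 2.0, 9.0, 4.0]);
    }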
}