1use axonml_core::dtype::{Numeric, Scalar};
18use axonml_core::error::{Error, Result};
19
20use crate::shape::{Shape, numel};
21use crate::tensor::Tensor;
22
/// Describes how a single tensor dimension is indexed when slicing.
///
/// NOTE(review): the consumer of these specs is not visible in this file;
/// per-variant semantics below follow numpy/PyTorch conventions — confirm
/// against the slicing implementation.
#[derive(Debug, Clone, Copy)]
pub enum SliceSpec {
    /// Select a single position along the dimension.
    Index(isize),
    /// Select a half-open range `[start, stop)`, advancing by `step`.
    Range {
        /// First index included; `None` means "from the beginning".
        start: Option<isize>,
        /// First index excluded; `None` means "through the end".
        stop: Option<isize>,
        /// Stride between selected elements.
        step: isize,
    },
    /// Keep the entire dimension as-is.
    All,
    /// Insert a new axis of size 1 at this position (numpy `newaxis`-style).
    NewAxis,
}
46
47impl SliceSpec {
48 #[must_use]
50 pub fn range(start: isize, stop: isize) -> Self {
51 Self::Range {
52 start: Some(start),
53 stop: Some(stop),
54 step: 1,
55 }
56 }
57
58 #[must_use]
60 pub fn range_step(start: isize, stop: isize, step: isize) -> Self {
61 Self::Range {
62 start: Some(start),
63 stop: Some(stop),
64 step,
65 }
66 }
67
68 #[must_use]
70 pub fn from(start: isize) -> Self {
71 Self::Range {
72 start: Some(start),
73 stop: None,
74 step: 1,
75 }
76 }
77
78 #[must_use]
80 pub fn to(stop: isize) -> Self {
81 Self::Range {
82 start: None,
83 stop: Some(stop),
84 step: 1,
85 }
86 }
87}
88
89impl<T: Scalar> Tensor<T> {
94 pub fn slice_dim0(&self, start: usize, end: usize) -> Result<Self> {
100 if self.ndim() == 0 {
101 return Err(Error::invalid_operation("Cannot slice a scalar"));
102 }
103
104 let dim_size = self.shape[0];
105 if start > end || end > dim_size {
106 return Err(Error::IndexOutOfBounds {
107 index: end,
108 size: dim_size,
109 });
110 }
111
112 let mut new_shape = self.shape.clone();
113 new_shape[0] = end - start;
114
115 let new_offset = self.offset + start * self.strides[0] as usize;
116
117 Ok(Self {
118 storage: self.storage.clone(),
119 shape: new_shape,
120 strides: self.strides.clone(),
121 offset: new_offset,
122 })
123 }
124
125 pub fn select(&self, dim: usize, index: usize) -> Result<Self> {
133 if dim >= self.ndim() {
134 return Err(Error::InvalidDimension {
135 index: dim as i64,
136 ndim: self.ndim(),
137 });
138 }
139
140 if index >= self.shape[dim] {
141 return Err(Error::IndexOutOfBounds {
142 index,
143 size: self.shape[dim],
144 });
145 }
146
147 let mut new_shape = self.shape.clone();
148 new_shape.remove(dim);
149
150 let mut new_strides = self.strides.clone();
151 new_strides.remove(dim);
152
153 let new_offset = self.offset + index * self.strides[dim] as usize;
154
155 Ok(Self {
156 storage: self.storage.clone(),
157 shape: new_shape,
158 strides: new_strides,
159 offset: new_offset,
160 })
161 }
162
163 pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
170 if dim >= self.ndim() {
171 return Err(Error::InvalidDimension {
172 index: dim as i64,
173 ndim: self.ndim(),
174 });
175 }
176
177 if start + length > self.shape[dim] {
178 return Err(Error::IndexOutOfBounds {
179 index: start + length,
180 size: self.shape[dim],
181 });
182 }
183
184 let mut new_shape = self.shape.clone();
185 new_shape[dim] = length;
186
187 let new_offset = self.offset + start * self.strides[dim] as usize;
188
189 Ok(Self {
190 storage: self.storage.clone(),
191 shape: new_shape,
192 strides: self.strides.clone(),
193 offset: new_offset,
194 })
195 }
196
197 pub fn chunk(&self, chunks: usize, dim: usize) -> Result<Vec<Self>> {
203 if dim >= self.ndim() {
204 return Err(Error::InvalidDimension {
205 index: dim as i64,
206 ndim: self.ndim(),
207 });
208 }
209
210 let dim_size = self.shape[dim];
211 let chunk_size = dim_size.div_ceil(chunks);
212 let mut result = Vec::with_capacity(chunks);
213
214 let mut start = 0;
215 while start < dim_size {
216 let length = (chunk_size).min(dim_size - start);
217 result.push(self.narrow(dim, start, length)?);
218 start += length;
219 }
220
221 Ok(result)
222 }
223
224 pub fn split(&self, sizes: &[usize], dim: usize) -> Result<Vec<Self>> {
230 if dim >= self.ndim() {
231 return Err(Error::InvalidDimension {
232 index: dim as i64,
233 ndim: self.ndim(),
234 });
235 }
236
237 let total: usize = sizes.iter().sum();
238 if total != self.shape[dim] {
239 return Err(Error::invalid_operation(format!(
240 "Split sizes {} don't sum to dimension size {}",
241 total, self.shape[dim]
242 )));
243 }
244
245 let mut result = Vec::with_capacity(sizes.len());
246 let mut start = 0;
247
248 for &size in sizes {
249 result.push(self.narrow(dim, start, size)?);
250 start += size;
251 }
252
253 Ok(result)
254 }
255}
256
257impl<T: Numeric> Tensor<T> {
262 pub fn gather(&self, dim: usize, indices: &Tensor<i64>) -> Result<Self> {
268 if dim >= self.ndim() {
269 return Err(Error::InvalidDimension {
270 index: dim as i64,
271 ndim: self.ndim(),
272 });
273 }
274
275 let output_shape = indices.shape();
278 let mut output_data = vec![T::zero(); numel(output_shape)];
279
280 let indices_data = indices.to_vec();
281 let self_data = self.to_vec();
282
283 for (out_idx, &index) in indices_data.iter().enumerate() {
284 let index = index as usize;
285 if index >= self.shape[dim] {
286 return Err(Error::IndexOutOfBounds {
287 index,
288 size: self.shape[dim],
289 });
290 }
291 output_data[out_idx] = self_data[index];
293 }
294
295 Tensor::from_vec(output_data, output_shape)
296 }
297
298 pub fn masked_select(&self, mask: &[bool]) -> Result<Self> {
303 if mask.len() != self.numel() {
304 return Err(Error::shape_mismatch(&[mask.len()], &[self.numel()]));
305 }
306
307 let data = self.to_vec();
308 let selected: Vec<T> = data
309 .into_iter()
310 .zip(mask.iter())
311 .filter(|(_, m)| **m)
312 .map(|(v, _)| v)
313 .collect();
314
315 let len = selected.len();
316 Tensor::from_vec(selected, &[len])
317 }
318
319 pub fn masked_fill_(&self, mask: &[bool], value: T) -> Result<()> {
325 if mask.len() != self.numel() {
326 return Err(Error::shape_mismatch(&[mask.len()], &[self.numel()]));
327 }
328
329 if !self.is_contiguous() {
330 return Err(Error::NotContiguous);
331 }
332
333 {
334 let mut guard = self.storage.as_slice_mut();
335 for (idx, &m) in mask.iter().enumerate() {
336 if m {
337 guard[self.offset + idx] = value;
338 }
339 }
340 }
341
342 Ok(())
343 }
344}
345
346pub fn cat<T: Scalar>(tensors: &[Tensor<T>], dim: usize) -> Result<Tensor<T>> {
356 if tensors.is_empty() {
357 return Err(Error::invalid_operation("Cannot concatenate empty list"));
358 }
359
360 let first = &tensors[0];
361 let ndim = first.ndim();
362
363 if dim >= ndim {
364 return Err(Error::InvalidDimension {
365 index: dim as i64,
366 ndim,
367 });
368 }
369
370 for t in tensors.iter().skip(1) {
372 if t.ndim() != ndim {
373 return Err(Error::invalid_operation(
374 "All tensors must have same number of dimensions",
375 ));
376 }
377 for (d, (&s1, &s2)) in first.shape().iter().zip(t.shape().iter()).enumerate() {
378 if d != dim && s1 != s2 {
379 return Err(Error::shape_mismatch(first.shape(), t.shape()));
380 }
381 }
382 }
383
384 let mut output_shape = Shape::from_slice(first.shape());
386 output_shape[dim] = tensors.iter().map(|t| t.shape()[dim]).sum();
387
388 let total_numel = numel(&output_shape);
390 let mut output_data = vec![T::zeroed(); total_numel];
391
392 let mut offset = 0;
394 for t in tensors {
395 let data = t.to_vec();
396 for val in data {
397 output_data[offset] = val;
398 offset += 1;
399 }
400 }
401
402 Tensor::from_vec(output_data, &output_shape)
403}
404
405pub fn stack<T: Scalar>(tensors: &[Tensor<T>], dim: usize) -> Result<Tensor<T>> {
411 if tensors.is_empty() {
412 return Err(Error::invalid_operation("Cannot stack empty list"));
413 }
414
415 let unsqueezed: Result<Vec<Tensor<T>>> =
417 tensors.iter().map(|t| t.unsqueeze(dim as i64)).collect();
418
419 cat(&unsqueezed?, dim)
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_slice_dim0() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = Tensor::<f32>::from_vec(data, &[3, 2]).unwrap();

        let view = tensor.slice_dim0(1, 3).unwrap();
        assert_eq!(view.shape(), &[2, 2]);
        assert_eq!(view.get(&[0, 0]).unwrap(), 3.0);
    }

    #[test]
    fn test_select() {
        let tensor =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();

        let row = tensor.select(0, 1).unwrap();
        assert_eq!(row.shape(), &[3]);
        assert_eq!(row.to_vec(), vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_narrow() {
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();

        let window = tensor.narrow(0, 1, 3).unwrap();
        assert_eq!(window.shape(), &[3]);
        assert_eq!(window.to_vec(), vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_chunk() {
        let tensor =
            Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();

        let parts = tensor.chunk(3, 0).unwrap();
        assert_eq!(parts.len(), 3);

        let expected = [vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        for (part, want) in parts.iter().zip(expected.iter()) {
            assert_eq!(&part.to_vec(), want);
        }
    }

    #[test]
    fn test_cat() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap();

        let joined = cat(&[lhs, rhs], 0).unwrap();
        assert_eq!(joined.shape(), &[4]);
        assert_eq!(joined.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_stack() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap();

        let stacked = stack(&[lhs, rhs], 0).unwrap();
        assert_eq!(stacked.shape(), &[2, 2]);
    }
}