1use axonml_core::dtype::{Numeric, Scalar};
27use axonml_core::error::{Error, Result};
28
29use crate::shape::{Shape, numel};
30use crate::tensor::Tensor;
31
/// A per-axis slicing directive describing how a single dimension of a
/// tensor should be indexed.
#[derive(Debug, Clone, Copy)]
pub enum SliceSpec {
    /// Select a single position along the axis.
    Index(isize),
    /// A strided range; a `None` endpoint means "from the beginning" /
    /// "to the end" of the axis.
    Range {
        start: Option<isize>,
        stop: Option<isize>,
        step: isize,
    },
    /// Keep the whole axis as-is.
    All,
    /// Insert a new axis at this position (cf. NumPy `newaxis`).
    NewAxis,
}

impl SliceSpec {
    /// Builds `start..stop` with unit step.
    #[must_use]
    pub fn range(start: isize, stop: isize) -> Self {
        Self::Range { start: Some(start), stop: Some(stop), step: 1 }
    }

    /// Builds `start..stop` advancing by `step`.
    #[must_use]
    pub fn range_step(start: isize, stop: isize, step: isize) -> Self {
        Self::Range { start: Some(start), stop: Some(stop), step }
    }

    /// Builds an open-ended range from `start` to the end, unit step.
    #[must_use]
    pub fn from(start: isize) -> Self {
        Self::Range { start: Some(start), stop: None, step: 1 }
    }

    /// Builds a range from the beginning up to `stop`, unit step.
    #[must_use]
    pub fn to(stop: isize) -> Self {
        Self::Range { start: None, stop: Some(stop), step: 1 }
    }
}
97
98impl<T: Scalar> Tensor<T> {
103 pub fn slice_dim0(&self, start: usize, end: usize) -> Result<Self> {
109 if self.ndim() == 0 {
110 return Err(Error::invalid_operation("Cannot slice a scalar"));
111 }
112
113 let dim_size = self.shape[0];
114 if start > end || end > dim_size {
115 return Err(Error::IndexOutOfBounds {
116 index: end,
117 size: dim_size,
118 });
119 }
120
121 let mut new_shape = self.shape.clone();
122 new_shape[0] = end - start;
123
124 let new_offset = self.offset + start * self.strides[0] as usize;
125
126 Ok(Self {
127 storage: self.storage.clone(),
128 shape: new_shape,
129 strides: self.strides.clone(),
130 offset: new_offset,
131 })
132 }
133
134 pub fn select(&self, dim: usize, index: usize) -> Result<Self> {
142 if dim >= self.ndim() {
143 return Err(Error::InvalidDimension {
144 index: dim as i64,
145 ndim: self.ndim(),
146 });
147 }
148
149 if index >= self.shape[dim] {
150 return Err(Error::IndexOutOfBounds {
151 index,
152 size: self.shape[dim],
153 });
154 }
155
156 let mut new_shape = self.shape.clone();
157 new_shape.remove(dim);
158
159 let mut new_strides = self.strides.clone();
160 new_strides.remove(dim);
161
162 let new_offset = self.offset + index * self.strides[dim] as usize;
163
164 Ok(Self {
165 storage: self.storage.clone(),
166 shape: new_shape,
167 strides: new_strides,
168 offset: new_offset,
169 })
170 }
171
172 pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
179 if dim >= self.ndim() {
180 return Err(Error::InvalidDimension {
181 index: dim as i64,
182 ndim: self.ndim(),
183 });
184 }
185
186 if start + length > self.shape[dim] {
187 return Err(Error::IndexOutOfBounds {
188 index: start + length,
189 size: self.shape[dim],
190 });
191 }
192
193 let mut new_shape = self.shape.clone();
194 new_shape[dim] = length;
195
196 let new_offset = self.offset + start * self.strides[dim] as usize;
197
198 Ok(Self {
199 storage: self.storage.clone(),
200 shape: new_shape,
201 strides: self.strides.clone(),
202 offset: new_offset,
203 })
204 }
205
206 pub fn chunk(&self, chunks: usize, dim: usize) -> Result<Vec<Self>> {
212 if dim >= self.ndim() {
213 return Err(Error::InvalidDimension {
214 index: dim as i64,
215 ndim: self.ndim(),
216 });
217 }
218
219 let dim_size = self.shape[dim];
220 let chunk_size = dim_size.div_ceil(chunks);
221 let mut result = Vec::with_capacity(chunks);
222
223 let mut start = 0;
224 while start < dim_size {
225 let length = (chunk_size).min(dim_size - start);
226 result.push(self.narrow(dim, start, length)?);
227 start += length;
228 }
229
230 Ok(result)
231 }
232
233 pub fn split(&self, sizes: &[usize], dim: usize) -> Result<Vec<Self>> {
239 if dim >= self.ndim() {
240 return Err(Error::InvalidDimension {
241 index: dim as i64,
242 ndim: self.ndim(),
243 });
244 }
245
246 let total: usize = sizes.iter().sum();
247 if total != self.shape[dim] {
248 return Err(Error::invalid_operation(format!(
249 "Split sizes {} don't sum to dimension size {}",
250 total, self.shape[dim]
251 )));
252 }
253
254 let mut result = Vec::with_capacity(sizes.len());
255 let mut start = 0;
256
257 for &size in sizes {
258 result.push(self.narrow(dim, start, size)?);
259 start += size;
260 }
261
262 Ok(result)
263 }
264}
265
266impl<T: Numeric> Tensor<T> {
271 pub fn gather(&self, dim: usize, indices: &Tensor<i64>) -> Result<Self> {
277 if dim >= self.ndim() {
278 return Err(Error::InvalidDimension {
279 index: dim as i64,
280 ndim: self.ndim(),
281 });
282 }
283
284 let output_shape = indices.shape();
287 let mut output_data = vec![T::zero(); numel(output_shape)];
288
289 let indices_data = indices.to_vec();
290 let self_data = self.to_vec();
291
292 for (out_idx, &index) in indices_data.iter().enumerate() {
293 let index = index as usize;
294 if index >= self.shape[dim] {
295 return Err(Error::IndexOutOfBounds {
296 index,
297 size: self.shape[dim],
298 });
299 }
300 output_data[out_idx] = self_data[index];
302 }
303
304 Tensor::from_vec(output_data, output_shape)
305 }
306
307 pub fn masked_select(&self, mask: &[bool]) -> Result<Self> {
312 if mask.len() != self.numel() {
313 return Err(Error::shape_mismatch(&[mask.len()], &[self.numel()]));
314 }
315
316 let data = self.to_vec();
317 let selected: Vec<T> = data
318 .into_iter()
319 .zip(mask.iter())
320 .filter(|(_, m)| **m)
321 .map(|(v, _)| v)
322 .collect();
323
324 let len = selected.len();
325 Tensor::from_vec(selected, &[len])
326 }
327
328 pub fn masked_fill_(&self, mask: &[bool], value: T) -> Result<()> {
334 if mask.len() != self.numel() {
335 return Err(Error::shape_mismatch(&[mask.len()], &[self.numel()]));
336 }
337
338 if !self.is_contiguous() {
339 return Err(Error::NotContiguous);
340 }
341
342 {
343 let mut guard = self.storage.as_slice_mut();
344 for (idx, &m) in mask.iter().enumerate() {
345 if m {
346 guard[self.offset + idx] = value;
347 }
348 }
349 }
350
351 Ok(())
352 }
353}
354
355pub fn cat<T: Scalar>(tensors: &[Tensor<T>], dim: usize) -> Result<Tensor<T>> {
365 if tensors.is_empty() {
366 return Err(Error::invalid_operation("Cannot concatenate empty list"));
367 }
368
369 let first = &tensors[0];
370 let ndim = first.ndim();
371
372 if dim >= ndim {
373 return Err(Error::InvalidDimension {
374 index: dim as i64,
375 ndim,
376 });
377 }
378
379 for t in tensors.iter().skip(1) {
381 if t.ndim() != ndim {
382 return Err(Error::invalid_operation(
383 "All tensors must have same number of dimensions",
384 ));
385 }
386 for (d, (&s1, &s2)) in first.shape().iter().zip(t.shape().iter()).enumerate() {
387 if d != dim && s1 != s2 {
388 return Err(Error::shape_mismatch(first.shape(), t.shape()));
389 }
390 }
391 }
392
393 let mut output_shape = Shape::from_slice(first.shape());
395 output_shape[dim] = tensors.iter().map(|t| t.shape()[dim]).sum();
396
397 let total_numel = numel(&output_shape);
399 let mut output_data = vec![T::zeroed(); total_numel];
400
401 let mut offset = 0;
403 for t in tensors {
404 let data = t.to_vec();
405 for val in data {
406 output_data[offset] = val;
407 offset += 1;
408 }
409 }
410
411 Tensor::from_vec(output_data, &output_shape)
412}
413
414pub fn stack<T: Scalar>(tensors: &[Tensor<T>], dim: usize) -> Result<Tensor<T>> {
420 if tensors.is_empty() {
421 return Err(Error::invalid_operation("Cannot stack empty list"));
422 }
423
424 let unsqueezed: Result<Vec<Tensor<T>>> =
426 tensors.iter().map(|t| t.unsqueeze(dim as i64)).collect();
427
428 cat(&unsqueezed?, dim)
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_slice_dim0() {
        let base = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();

        let view = base.slice_dim0(1, 3).unwrap();
        assert_eq!(view.shape(), &[2, 2]);
        assert_eq!(view.get(&[0, 0]).unwrap(), 3.0);
    }

    #[test]
    fn test_select() {
        let base = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();

        // Selecting row 1 drops the first axis entirely.
        let row = base.select(0, 1).unwrap();
        assert_eq!(row.shape(), &[3]);
        assert_eq!(row.to_vec(), vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_narrow() {
        let base = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();

        let window = base.narrow(0, 1, 3).unwrap();
        assert_eq!(window.shape(), &[3]);
        assert_eq!(window.to_vec(), vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_chunk() {
        let base = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();

        let parts = base.chunk(3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(parts[1].to_vec(), vec![3.0, 4.0]);
        assert_eq!(parts[2].to_vec(), vec![5.0, 6.0]);
    }

    #[test]
    fn test_cat() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap();

        let joined = cat(&[lhs, rhs], 0).unwrap();
        assert_eq!(joined.shape(), &[4]);
        assert_eq!(joined.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_stack() {
        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let rhs = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap();

        let stacked = stack(&[lhs, rhs], 0).unwrap();
        assert_eq!(stacked.shape(), &[2, 2]);
    }
}
495}