1use axonml_core::dtype::{Numeric, Scalar};
17use axonml_core::error::{Error, Result};
18
19use crate::shape::{numel, Shape};
20use crate::tensor::Tensor;
21
/// Describes how a single tensor dimension is selected when slicing.
#[derive(Debug, Clone, Copy)]
pub enum SliceSpec {
    /// Pick one position along the axis (the axis is removed).
    Index(isize),
    /// Pick a half-open range with an optional start/stop and a step.
    Range {
        /// First position; `None` means the beginning of the axis.
        start: Option<isize>,
        /// One-past-last position; `None` means the end of the axis.
        stop: Option<isize>,
        /// Stride between selected positions.
        step: isize,
    },
    /// Keep the entire axis unchanged.
    All,
    /// Insert a new axis of length 1 at this position.
    NewAxis,
}

impl SliceSpec {
    /// Builds the range `[start, stop)` with unit step.
    #[must_use]
    pub fn range(start: isize, stop: isize) -> Self {
        Self::range_step(start, stop, 1)
    }

    /// Builds the range `[start, stop)` advancing by `step`.
    #[must_use]
    pub fn range_step(start: isize, stop: isize, step: isize) -> Self {
        Self::Range { start: Some(start), stop: Some(stop), step }
    }

    /// Builds the open-ended range `[start, ..)` with unit step.
    #[must_use]
    pub fn from(start: isize) -> Self {
        Self::Range { start: Some(start), stop: None, step: 1 }
    }

    /// Builds the range `[.., stop)` from the beginning with unit step.
    #[must_use]
    pub fn to(stop: isize) -> Self {
        Self::Range { start: None, stop: Some(stop), step: 1 }
    }
}
87
88impl<T: Scalar> Tensor<T> {
93 pub fn slice_dim0(&self, start: usize, end: usize) -> Result<Self> {
99 if self.ndim() == 0 {
100 return Err(Error::invalid_operation("Cannot slice a scalar"));
101 }
102
103 let dim_size = self.shape[0];
104 if start > end || end > dim_size {
105 return Err(Error::IndexOutOfBounds {
106 index: end,
107 size: dim_size,
108 });
109 }
110
111 let mut new_shape = self.shape.clone();
112 new_shape[0] = end - start;
113
114 let new_offset = self.offset + start * self.strides[0] as usize;
115
116 Ok(Self {
117 storage: self.storage.clone(),
118 shape: new_shape,
119 strides: self.strides.clone(),
120 offset: new_offset,
121 })
122 }
123
124 pub fn select(&self, dim: usize, index: usize) -> Result<Self> {
132 if dim >= self.ndim() {
133 return Err(Error::InvalidDimension {
134 index: dim as i64,
135 ndim: self.ndim(),
136 });
137 }
138
139 if index >= self.shape[dim] {
140 return Err(Error::IndexOutOfBounds {
141 index,
142 size: self.shape[dim],
143 });
144 }
145
146 let mut new_shape = self.shape.clone();
147 new_shape.remove(dim);
148
149 let mut new_strides = self.strides.clone();
150 new_strides.remove(dim);
151
152 let new_offset = self.offset + index * self.strides[dim] as usize;
153
154 Ok(Self {
155 storage: self.storage.clone(),
156 shape: new_shape,
157 strides: new_strides,
158 offset: new_offset,
159 })
160 }
161
162 pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
169 if dim >= self.ndim() {
170 return Err(Error::InvalidDimension {
171 index: dim as i64,
172 ndim: self.ndim(),
173 });
174 }
175
176 if start + length > self.shape[dim] {
177 return Err(Error::IndexOutOfBounds {
178 index: start + length,
179 size: self.shape[dim],
180 });
181 }
182
183 let mut new_shape = self.shape.clone();
184 new_shape[dim] = length;
185
186 let new_offset = self.offset + start * self.strides[dim] as usize;
187
188 Ok(Self {
189 storage: self.storage.clone(),
190 shape: new_shape,
191 strides: self.strides.clone(),
192 offset: new_offset,
193 })
194 }
195
196 pub fn chunk(&self, chunks: usize, dim: usize) -> Result<Vec<Self>> {
202 if dim >= self.ndim() {
203 return Err(Error::InvalidDimension {
204 index: dim as i64,
205 ndim: self.ndim(),
206 });
207 }
208
209 let dim_size = self.shape[dim];
210 let chunk_size = dim_size.div_ceil(chunks);
211 let mut result = Vec::with_capacity(chunks);
212
213 let mut start = 0;
214 while start < dim_size {
215 let length = (chunk_size).min(dim_size - start);
216 result.push(self.narrow(dim, start, length)?);
217 start += length;
218 }
219
220 Ok(result)
221 }
222
223 pub fn split(&self, sizes: &[usize], dim: usize) -> Result<Vec<Self>> {
229 if dim >= self.ndim() {
230 return Err(Error::InvalidDimension {
231 index: dim as i64,
232 ndim: self.ndim(),
233 });
234 }
235
236 let total: usize = sizes.iter().sum();
237 if total != self.shape[dim] {
238 return Err(Error::invalid_operation(format!(
239 "Split sizes {} don't sum to dimension size {}",
240 total, self.shape[dim]
241 )));
242 }
243
244 let mut result = Vec::with_capacity(sizes.len());
245 let mut start = 0;
246
247 for &size in sizes {
248 result.push(self.narrow(dim, start, size)?);
249 start += size;
250 }
251
252 Ok(result)
253 }
254}
255
256impl<T: Numeric> Tensor<T> {
261 pub fn gather(&self, dim: usize, indices: &Tensor<i64>) -> Result<Self> {
267 if dim >= self.ndim() {
268 return Err(Error::InvalidDimension {
269 index: dim as i64,
270 ndim: self.ndim(),
271 });
272 }
273
274 let output_shape = indices.shape();
277 let mut output_data = vec![T::zero(); numel(output_shape)];
278
279 let indices_data = indices.to_vec();
280 let self_data = self.to_vec();
281
282 for (out_idx, &index) in indices_data.iter().enumerate() {
283 let index = index as usize;
284 if index >= self.shape[dim] {
285 return Err(Error::IndexOutOfBounds {
286 index,
287 size: self.shape[dim],
288 });
289 }
290 output_data[out_idx] = self_data[index];
292 }
293
294 Tensor::from_vec(output_data, output_shape)
295 }
296
297 pub fn masked_select(&self, mask: &[bool]) -> Result<Self> {
302 if mask.len() != self.numel() {
303 return Err(Error::shape_mismatch(&[mask.len()], &[self.numel()]));
304 }
305
306 let data = self.to_vec();
307 let selected: Vec<T> = data
308 .into_iter()
309 .zip(mask.iter())
310 .filter(|(_, &m)| m)
311 .map(|(v, _)| v)
312 .collect();
313
314 let len = selected.len();
315 Tensor::from_vec(selected, &[len])
316 }
317
318 pub fn masked_fill_(&self, mask: &[bool], value: T) -> Result<()> {
324 if mask.len() != self.numel() {
325 return Err(Error::shape_mismatch(&[mask.len()], &[self.numel()]));
326 }
327
328 if !self.is_contiguous() {
329 return Err(Error::NotContiguous);
330 }
331
332 {
333 let mut guard = self.storage.as_slice_mut();
334 for (idx, &m) in mask.iter().enumerate() {
335 if m {
336 guard[self.offset + idx] = value;
337 }
338 }
339 }
340
341 Ok(())
342 }
343}
344
345pub fn cat<T: Scalar>(tensors: &[Tensor<T>], dim: usize) -> Result<Tensor<T>> {
355 if tensors.is_empty() {
356 return Err(Error::invalid_operation("Cannot concatenate empty list"));
357 }
358
359 let first = &tensors[0];
360 let ndim = first.ndim();
361
362 if dim >= ndim {
363 return Err(Error::InvalidDimension {
364 index: dim as i64,
365 ndim,
366 });
367 }
368
369 for t in tensors.iter().skip(1) {
371 if t.ndim() != ndim {
372 return Err(Error::invalid_operation(
373 "All tensors must have same number of dimensions",
374 ));
375 }
376 for (d, (&s1, &s2)) in first.shape().iter().zip(t.shape().iter()).enumerate() {
377 if d != dim && s1 != s2 {
378 return Err(Error::shape_mismatch(first.shape(), t.shape()));
379 }
380 }
381 }
382
383 let mut output_shape = Shape::from_slice(first.shape());
385 output_shape[dim] = tensors.iter().map(|t| t.shape()[dim]).sum();
386
387 let total_numel = numel(&output_shape);
389 let mut output_data = vec![T::zeroed(); total_numel];
390
391 let mut offset = 0;
393 for t in tensors {
394 let data = t.to_vec();
395 for val in data {
396 output_data[offset] = val;
397 offset += 1;
398 }
399 }
400
401 Tensor::from_vec(output_data, &output_shape)
402}
403
404pub fn stack<T: Scalar>(tensors: &[Tensor<T>], dim: usize) -> Result<Tensor<T>> {
410 if tensors.is_empty() {
411 return Err(Error::invalid_operation("Cannot stack empty list"));
412 }
413
414 let unsqueezed: Result<Vec<Tensor<T>>> =
416 tensors.iter().map(|t| t.unsqueeze(dim as i64)).collect();
417
418 cat(&unsqueezed?, dim)
419}
420
#[cfg(test)]
mod tests {
    use super::*;

    // Slicing dim 0 of a [3, 2] tensor keeps the trailing dimension and
    // shifts the view onto rows 1..3.
    #[test]
    fn test_slice_dim0() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2]).unwrap();

        let s = t.slice_dim0(1, 3).unwrap();
        assert_eq!(s.shape(), &[2, 2]);
        assert_eq!(s.get(&[0, 0]).unwrap(), 3.0);
    }

    // Selecting index 1 on dim 0 removes that axis and yields row 1.
    #[test]
    fn test_select() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();

        let s = t.select(0, 1).unwrap();
        assert_eq!(s.shape(), &[3]);
        assert_eq!(s.to_vec(), vec![4.0, 5.0, 6.0]);
    }

    // Narrowing keeps the rank and exposes a window of 3 elements
    // starting at position 1.
    #[test]
    fn test_narrow() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();

        let n = t.narrow(0, 1, 3).unwrap();
        assert_eq!(n.shape(), &[3]);
        assert_eq!(n.to_vec(), vec![2.0, 3.0, 4.0]);
    }

    // 6 elements split into 3 chunks of ceil(6/3) = 2 each.
    #[test]
    fn test_chunk() {
        let t = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]).unwrap();

        let chunks = t.chunk(3, 0).unwrap();
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].to_vec(), vec![1.0, 2.0]);
        assert_eq!(chunks[1].to_vec(), vec![3.0, 4.0]);
        assert_eq!(chunks[2].to_vec(), vec![5.0, 6.0]);
    }

    // Concatenating two 1-D tensors along dim 0 appends them.
    #[test]
    fn test_cat() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap();

        let c = cat(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[4]);
        assert_eq!(c.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
    }

    // Stacking two [2] tensors along a new leading axis yields [2, 2].
    #[test]
    fn test_stack() {
        let a = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).unwrap();
        let b = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).unwrap();

        let c = stack(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[2, 2]);
    }
}
485}