1use std::{fmt::Display, ops::Deref};
2use crate::{AutogradMetaT, Dim, Error, IntTensor, NumDType, Result, WithDType};
3use super::Tensor;
4
5impl<T: WithDType> Tensor<T> {
6 pub fn indexes(&self, indexers: &[Indexer]) -> Result<Self> {
7 let mut x = self.clone();
8 let mut current_dim = 0;
9 for indexer in indexers.iter() {
10 x = match indexer {
11 Indexer::Select(n) => x.narrow(current_dim, *n, 1)?.squeeze(current_dim)?,
12 Indexer::Slice(range) => {
13 let out = x.slice(current_dim, range)?;
14 current_dim += 1;
15 out
16 }
17 Indexer::Boolean(index) => {
18 if index.dims1()? != x.dims()[current_dim] {
19 return Err(Error::BooleanIndexShouldLikeVector(index.shape().clone()))
20 }
21
22 let index = index.to_vec()?
25 .into_iter()
26 .enumerate()
27 .filter(|(_, v)| *v)
28 .map(|(i, _)| i as u32)
29 .collect::<Vec<_>>();
30 let index = Tensor::new(index)?;
31 let out = x.index_select(index, current_dim)?;
32 current_dim += 1;
33 out
34 }
35 };
36 }
37 Ok(x)
38 }
39
40 pub fn get(&self, i: usize) -> Result<Self> {
51 let dims = self.dims();
52 if dims.is_empty() {
53 Ok(self.clone())
54 } else {
55 self.narrow(0, i, 1)?.reshape(&dims[1..])
56 }
57 }
58
59 pub fn index_select<D: Dim>(&self, indexes: impl Into<IntTensor>, dim: D) -> Result<Self> {
60 let indexes: IntTensor = indexes.into();
61 let dim = dim.to_index(self.shape(), "index-select")?;
62 let indexes_len = indexes.shape().dims1()?;
63 let mut dims = self.dims().to_vec();
64 dims[dim] = indexes_len;
65 let meta = T::AutogradMeta::on_index_select_op(self, &indexes, dim);
66 let storage = match indexes {
67 IntTensor::I32(indexes) => self.storage_read()?.index_select(
68 self.layout(),
69 indexes.storage_read()?.deref(),
70 indexes.layout(),
71 dim,
72 )?,
73 IntTensor::U32(indexes) => self.storage_read()?.index_select(
74 self.layout(),
75 indexes.storage_read()?.deref(),
76 indexes.layout(),
77 dim,
78 )?,
79 IntTensor::U8(indexes) => self.storage_read()?.index_select(
80 self.layout(),
81 indexes.storage_read()?.deref(),
82 indexes.layout(),
83 dim,
84 )?,
85 };
86
87 Ok(Self::from_storage(storage, dims, meta))
88 }
89
90 pub fn gather<D: Dim>(&self, indexes: impl Into<IntTensor>, dim: D) -> Result<Self> {
102 let indexes = indexes.into();
103 let dim = dim.to_index(self.shape(), "gather")?;
104 let self_dims = self.dims();
105 let indexes_dims = indexes.dims();
106 let mismatch = if indexes_dims.len() != self_dims.len() {
107 true
108 } else {
109 let mut mismatch = false;
110 for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
111 if i != dim && d1 < d2 {
112 mismatch = true;
113 break;
114 }
115 }
116 mismatch
117 };
118 if mismatch {
119 Err(Error::ShapeMismatchBinaryOp {
120 op: "gather",
121 lhs: self.shape().clone(),
122 rhs: indexes.shape().clone(),
123 })?
124 }
125
126 let storage = match &indexes {
127 IntTensor::I32(idx) => self.storage_read()?.gather(self.layout(), idx.storage_read()?.deref(), idx.layout(), dim)?,
128 IntTensor::U32(idx) => self.storage_read()?.gather(self.layout(), idx.storage_read()?.deref(), idx.layout(), dim)?,
129 IntTensor::U8(idx) => self.storage_read()?.gather(self.layout(), idx.storage_read()?.deref(), idx.layout(), dim)?,
130 };
131
132 let meta = T::AutogradMeta::on_gather_op(self, &indexes, dim);
133 Ok(Self::from_storage(storage, indexes.shape(), meta))
134 }
135}
136
137impl<T: NumDType> Tensor<T> {
138 pub fn index_add<D: Dim>(&self, indexes: impl Into<IntTensor>, source: &Tensor<T>, dim: D) -> Result<Self> {
139 let indexes: IntTensor = indexes.into();
140 let dim = dim.to_index(self.shape(), "index-add")?;
141
142 let source_dims = source.dims();
143 let self_dims = self.dims();
144 if source_dims.len() != self_dims.len() {
145 return Err(Error::ShapeMismatchBinaryOp {
146 op: "index-add",
147 lhs: self.shape().clone(),
148 rhs: source.shape().clone()
149 }.into());
150 }
151
152 let indexes_len = indexes.shape().dims1()?;
153 for (i, (&d_self, &d_src)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
154 if i == dim {
155 if d_src != indexes_len {
156 return Err(Error::ShapeMismatchBinaryOp { op: "index-add (dim mismatch)", lhs: self.shape().clone(), rhs: source.shape().clone() }.into());
157 }
158 } else if d_self != d_src {
159 return Err(Error::ShapeMismatchBinaryOp { op: "index-add", lhs: self.shape().clone(), rhs: source.shape().clone() }.into());
160 }
161 }
162
163 let storage = match &indexes {
164 IntTensor::I32(idx) => self.storage_read()?.index_add(
165 self.layout(),
166 idx.storage_read()?.deref(),
167 idx.layout(),
168 source.storage_read()?.deref(),
169 source.layout(),
170 dim,
171 )?,
172 IntTensor::U32(idx) => self.storage_read()?.index_add(
173 self.layout(),
174 idx.storage_read()?.deref(),
175 idx.layout(),
176 source.storage_read()?.deref(),
177 source.layout(),
178 dim,
179 )?,
180 IntTensor::U8(idx) => self.storage_read()?.index_add(
181 self.layout(),
182 idx.storage_read()?.deref(),
183 idx.layout(),
184 source.storage_read()?.deref(),
185 source.layout(),
186 dim,
187 )?,
188 };
189
190 let meta = T::AutogradMeta::on_index_add_op(self, &indexes, source, dim);
191 Ok(Self::from_storage(storage, self_dims.to_vec(), meta))
192 }
193
194 pub fn scatter_add<D: Dim>(&self, indexes: impl Into<IntTensor>, source: &Self, dim: D) -> Result<Self> {
195 let indexes = indexes.into();
196 let dim = dim.to_index(self.shape(), "scatter-add")?;
197 self.scatter_checks(&indexes, source, dim)?;
198
199 let storage = match &indexes {
200 IntTensor::I32(idx) => self.storage_read()?.scatter_add(
201 self.layout(),
202 idx.storage_read()?.deref(),
203 idx.layout(),
204 source.storage_read()?.deref(),
205 source.layout(),
206 dim,
207 )?,
208 IntTensor::U32(idx) => self.storage_read()?.scatter_add(
209 self.layout(),
210 idx.storage_read()?.deref(),
211 idx.layout(),
212 source.storage_read()?.deref(),
213 source.layout(),
214 dim,
215 )?,
216 IntTensor::U8(idx) => self.storage_read()?.scatter_add(
217 self.layout(),
218 idx.storage_read()?.deref(),
219 idx.layout(),
220 source.storage_read()?.deref(),
221 source.layout(),
222 dim,
223 )?,
224 };
225
226 let meta = T::AutogradMeta::on_scatter_add_op(self, &indexes, source, dim);
227 Ok(Self::from_storage(storage, self.shape(), meta))
228 }
229
230 fn scatter_checks(&self, indexes: &IntTensor, source: &Self, dim: usize) -> Result<()> {
231 let source_dims = source.dims();
232 let self_dims = self.dims();
233 let mismatch = if source_dims.len() != self_dims.len() {
234 true
235 } else {
236 let mut mismatch = false;
237 for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
238 if i != dim && d1 != d2 {
239 mismatch = true;
240 break;
241 }
242 }
243 mismatch
244 };
245 if mismatch {
246 Err(Error::ShapeMismatchBinaryOp {
247 op: "scatter (self, src)",
248 lhs: self.shape().clone(),
249 rhs: source.shape().clone(),
250 })?
251 }
252 if indexes.dims() != source.dims() {
253 Err(Error::ShapeMismatchBinaryOp {
254 op: "scatter (indexes, src)",
255 lhs: indexes.shape().clone(),
256 rhs: source.shape().clone(),
257 })?
258 }
259 Ok(())
260 }
261}
262
263impl<T: WithDType> Tensor<T> {
264 pub fn matrix_get(&self, row: usize, col: usize) -> Result<T> {
265 self.index((row, col))?.to_scalar()
266 }
267
268 pub fn matrix_set(&self, row: usize, col: usize, val: T) -> Result<()> {
269 self.index((row, col))?.set_scalar(val)
270 }
271
272 pub fn vector_get(&self, n: usize) -> Result<T> {
273 self.index(n)?.to_scalar()
274 }
275}
276
277#[derive(Debug, Clone, PartialEq, Eq)]
278pub enum Indexer {
279 Select(usize),
280 Slice(Slice),
281 Boolean(Tensor<bool>),
282}
283
/// A `start:end:step` range over one dimension, as built by the `s!` macro.
///
/// `end == None` means open-ended. Iterating a slice with a negative `end`
/// also yields without bound, because the dimension size is not known here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Slice {
    /// First position produced.
    pub start: usize,
    /// Exclusive upper bound, or `None` for no bound.
    pub end: Option<isize>,
    /// Distance between consecutive positions.
    pub step: usize,
}
290
291impl Slice {
292 pub fn new(start: usize, end: Option<isize>, step: usize) -> Self {
293 Self { start, end, step }
294 }
295
296 pub fn len(&self) -> usize {
297 self.clone().count()
298 }
299}
300
301impl Iterator for Slice {
302 type Item = usize;
303 fn next(&mut self) -> Option<Self::Item> {
304 match self.end {
305 Some(end) if end < 0 => {
306 let value = self.start;
307 self.start += self.step;
308 Some(value)
309 }
310 Some(end) => {
311 if self.start < end as usize {
312 let value = self.start;
313 self.start += self.step;
314 Some(value)
315 } else {
316 None
317 }
318 }
319 None => {
320 let value = self.start;
321 self.start += self.step;
322 Some(value)
323 }
324 }
325 }
326}
327
328impl Display for Slice {
329 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330 let step_part = match self.step {
331 1 => format!(""),
332 _ => format!(":{}", self.step),
333 };
334 match self.end {
335 Some(end) => write!(f, "{}:{}{}", self.start, end, step_part),
336 None => write!(f, "{}:{}", self.start, step_part),
337 }
338 }
339}
340
341impl From<usize> for Indexer {
342 fn from(index: usize) -> Self {
343 Indexer::Select(index)
344 }
345}
346
347impl From<Slice> for Indexer {
348 fn from(value: Slice) -> Self {
349 Indexer::Slice(value)
350 }
351}
352
353impl From<&Tensor<bool>> for Indexer {
354 fn from(value: &Tensor<bool>) -> Self {
355 Indexer::Boolean(value.clone())
356 }
357}
358
359impl From<Tensor<bool>> for Indexer {
360 fn from(value: Tensor<bool>) -> Self {
361 Indexer::Boolean(value)
362 }
363}
364
365impl From<std::ops::Range<usize>> for Indexer {
366 fn from(value: std::ops::Range<usize>) -> Self {
367 let range = Slice::new(value.start, Some(value.end as isize), 1);
368 range.into()
369 }
370}
371
372impl From<std::ops::RangeFrom<usize>> for Indexer {
373 fn from(value: std::ops::RangeFrom<usize>) -> Self {
374 let range = Slice::new(value.start, None, 1);
375 range.into()
376 }
377}
378
379impl From<std::ops::RangeFull> for Indexer {
380 fn from(_: std::ops::RangeFull) -> Self {
381 let range = Slice::new(0, None, 1);
382 range.into()
383 }
384}
385
386pub trait IndexOp<T, D: WithDType> {
387 fn index(&self, index: T) -> Result<Tensor<D>>;
388}
389
390impl<I: Into<Indexer>, D: WithDType> IndexOp<I, D> for Tensor<D> {
391 fn index(&self, index: I) -> Result<Tensor<D>> {
392 self.indexes(&[index.into()])
393 }
394}
395
396impl<I: Into<Indexer>, D: WithDType> IndexOp<(I,), D> for Tensor<D> {
397 fn index(&self, (index,): (I,)) -> Result<Tensor<D>> {
398 self.indexes(&[index.into()])
399 }
400}
401
/// Implements `IndexOp` for a tuple of indexers. The identifiers passed in
/// name both the generic type parameters and the destructured tuple fields.
macro_rules! index_op_tuple {
    ($($idx:ident),+) => {
        #[allow(non_snake_case)]
        impl<$($idx,)+ D: WithDType> IndexOp<($($idx,)+), D> for Tensor<D>
        where
            $($idx: Into<Indexer>,)+
        {
            fn index(&self, ($($idx,)+): ($($idx,)+)) -> Result<Tensor<D>> {
                let indexers: &[Indexer] = &[$($idx.into()),+];
                self.indexes(indexers)
            }
        }
    };
}
415
416index_op_tuple!(I1, I2);
417index_op_tuple!(I1, I2, I3);
418index_op_tuple!(I1, I2, I3, I4);
419index_op_tuple!(I1, I2, I3, I4, I5);
420
421impl<I: Into<Indexer>, D: WithDType> IndexOp<Vec<I>, D> for Tensor<D> {
422 fn index(&self, index: Vec<I>) -> Result<Tensor<D>> {
423 let indexs = index.into_iter().map(|i| i.into()).collect::<Vec<Indexer>>();
424 self.indexes(&indexs)
425 }
426}
427
/// Builds a [`Slice`] with Python-like `start:end:step` syntax.
///
/// Accepted forms are `s!(a:b)`, `s!(a:b:c)`, `s!(a:)`, `s!(a::c)`, `s!(:b)`,
/// `s!(:b:c)`, `s!(::c)` and `s!(:)`. An omitted start is `0`, an omitted end
/// is `None` (open-ended), and an omitted step is `1`.
///
/// NOTE(review): the expansion names `Slice` unqualified, so `Slice` must be in
/// scope wherever the macro is invoked. Start and step are converted with
/// `as usize`, so a negative value wraps to a huge `usize` instead of erroring.
#[macro_export]
macro_rules! s {
    // Arm order is significant (the first matching arm wins); keep the
    // `a:b` / `a:b:c` and `:b` / `:b:c` pairs in this order.
    // start:end
    ($start:tt : $end:expr) => {
        Slice::new($start as usize, Some($end as isize), 1)
    };
    // start:end:step
    ($start:tt : $end:tt : $step:expr) => {
        Slice::new($start as usize, Some($end as isize), $step as usize)
    };
    // start:  (open-ended)
    ($start:tt :) => {
        Slice::new($start as usize, None, 1)
    };
    // start::step  (open-ended; `::` is a single token)
    ($start:tt :: $step:expr) => {
        Slice::new($start as usize, None, $step as usize)
    };
    // :end
    (: $end:tt) => {
        Slice::new(0, Some($end as isize), 1)
    };
    // :end:step
    (: $end:tt : $step:expr) => {
        Slice::new(0, Some($end as isize), $step as usize)
    };
    // ::step  (open-ended)
    (:: $step:expr) => {
        Slice::new(0, None, $step as usize)
    };
    // :  (whole dimension)
    (:) => {
        Slice::new(0, None, 1)
    };
}
463
#[cfg(test)]
#[allow(unused)]
mod test {
    use crate::DType;
    use super::*;

    #[test]
    fn test_index_select_basic() {
        let arr = Tensor::arange(0, 12).unwrap().reshape((3, 4)).unwrap();

        let indices = Tensor::new(&[0, 2]).unwrap();
        let selected = arr.index_select(indices, 0).unwrap();

        assert_eq!(selected.shape().dims(), &[2, 4]);
        let data = selected.to_vec().unwrap();
        assert_eq!(data, vec![0, 1, 2, 3, 8, 9, 10, 11]);

        let indices_col = Tensor::new(&[1]).unwrap();
        let selected_col = arr.index_select(indices_col, 1).unwrap();

        assert_eq!(selected_col.shape().dims(), &[3, 1]);
        let data_col = selected_col.to_vec().unwrap();
        assert_eq!(data_col, vec![1, 5, 9]);
    }

    #[test]
    fn test_boolean_index() {
        let arr = Tensor::arange(0, 12).unwrap().reshape((3, 4)).unwrap();
        let index = Tensor::new([true, false, true]).unwrap();

        let selected = arr.index(index).unwrap();
        assert_eq!(selected.shape().dims(), &[2, 4]);
        assert_eq!(selected.to_vec().unwrap(), vec![0, 1, 2, 3, 8, 9, 10, 11]);
    }

    #[test]
    fn test_boolean_index_inner_dim() {
        let arr = Tensor::arange(0, 12).unwrap().reshape((3, 4)).unwrap();
        let mask = Tensor::new([true, false, false, true]).unwrap();

        let selected = arr.index((.., mask)).unwrap();
        assert_eq!(selected.shape().dims(), &[3, 2]);
        assert_eq!(selected.to_vec().unwrap(), vec![0, 3, 4, 7, 8, 11]);
    }

    #[test]
    fn test_boolean_index_length_mismatch() {
        let arr = Tensor::arange(0, 12).unwrap().reshape((3, 4)).unwrap();
        let mask = Tensor::new([true, false]).unwrap();
        assert!(arr.index(mask).is_err());
    }

    #[test]
    fn test_index_select_1d() {
        let scores = Tensor::<f64>::arange(0.0, 100.0).unwrap();
        let indices = Tensor::new(&[25, 34, 12, 90]).unwrap();

        let selected = scores.index_select(indices, 0).unwrap();
        assert_eq!(selected.shape().dims(), &[4]);
        assert_eq!(selected.to_vec().unwrap(), vec![25.0, 34.0, 12.0, 90.0]);
    }

    #[test]
    fn test_index_select_duplicates_and_reorder() {
        let arr = Tensor::arange(0, 5).unwrap();
        let indices = Tensor::new(&[4, 0, 0, 1]).unwrap();
        let selected = arr.index_select(indices, 0).unwrap();

        assert_eq!(selected.shape().dims(), &[4]);
        let data = selected.to_vec().unwrap();
        assert_eq!(data, vec![4, 0, 0, 1]);
    }

    #[test]
    fn test_index_select_out_of_bounds() {
        let arr = Tensor::arange(0, 10).unwrap();
        let indices = Tensor::new(&[0, 10]).unwrap();

        let result = arr.index_select(indices, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_index_add_basic() {
        let dst = Tensor::<i32>::zeros((3, 3)).unwrap();
        let src = Tensor::<i32>::ones((2, 3)).unwrap();
        let indices = Tensor::new(&[0, 2]).unwrap();

        let result = dst.index_add(indices, &src, 0).unwrap();

        let data = result.to_vec().unwrap();
        assert_eq!(data, vec![
            1, 1, 1,
            0, 0, 0,
            1, 1, 1
        ]);
    }

    #[test]
    fn test_index_add_accumulate() {
        let dst = Tensor::<i32>::zeros((5,)).unwrap();
        let src = Tensor::new(&[10, 20, 30]).unwrap();
        let indices = Tensor::new(&[1, 1, 3]).unwrap();

        let result = dst.index_add(indices, &src, 0).unwrap();

        let data = result.to_vec().unwrap();
        assert_eq!(data, vec![0, 30, 0, 30, 0]);
    }

    #[test]
    fn test_index_add_dim_mismatch() {
        let dst = Tensor::<i32>::zeros((3, 3)).unwrap();
        let src = Tensor::<i32>::ones((2, 3)).unwrap();
        let indices = Tensor::new(&[0, 1, 2]).unwrap();

        let result = dst.index_add(indices, &src, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_index_add_inner_dim() {
        let dst = Tensor::<i32>::zeros((2, 3)).unwrap();
        let src = Tensor::new(&[
            [5],
            [5]
        ]).unwrap();
        let indices = Tensor::new(&[1]).unwrap();
        let result = dst.index_add(indices, &src, 1).unwrap();

        let data = result.to_vec().unwrap();
        assert_eq!(data, vec![0, 5, 0, 0, 5, 0]);
    }

    #[test]
    fn test_gather_dim_1() {
        let src = Tensor::new(&[
            [1, 2],
            [3, 4]
        ]).unwrap();

        let indices = Tensor::new(&[
            [0, 0],
            [1, 0]
        ]).unwrap();

        let result = src.gather(&indices, 1).unwrap();

        let data = result.to_vec().unwrap();
        assert_eq!(data, vec![1, 1, 4, 3]);
    }

    #[test]
    fn test_gather_dim_0() {
        let src = Tensor::new(&[
            [10, 20],
            [30, 40],
            [50, 60]
        ]).unwrap();

        let indices = Tensor::new(&[
            [1, 2],
            [0, 1]
        ]).unwrap();

        let result = src.gather(&indices, 0).unwrap();

        let data = result.to_vec().unwrap();
        assert_eq!(data, vec![30, 60, 10, 40]);
    }

    #[test]
    fn test_gather_3d() {
        let src = Tensor::new(&[0, 1, 2, 3, 4, 5, 6, 7]).unwrap().reshape((2, 2, 2)).unwrap();

        let indices = Tensor::<u32>::zeros((2, 1, 2)).unwrap();
        let result = src.gather(&indices, 1).unwrap();

        assert_eq!(result.dims(), &[2, 1, 2]);
        let data = result.to_vec().unwrap();
        assert_eq!(data, vec![0, 1, 4, 5]);
    }

    #[test]
    fn test_scatter_add_1d_accumulate() {
        let dst = Tensor::<i32>::zeros((5,)).unwrap();

        let src = Tensor::new(&[1, 1, 1, 1]).unwrap();
        let indices = Tensor::new(&[0, 2, 0, 4]).unwrap();

        let result = dst.scatter_add(indices, &src, 0).unwrap();

        let data = result.to_vec().unwrap();
        assert_eq!(data, vec![2, 0, 1, 0, 1]);
    }

    #[test]
    fn test_scatter_add_2d_dim1() {
        let dst = Tensor::<i32>::zeros((2, 3)).unwrap();

        let src = Tensor::new(&[
            [10, 20],
            [30, 40]
        ]).unwrap();

        let indices = Tensor::new(&[
            [0, 2],
            [1, 1]
        ]).unwrap();

        let result = dst.scatter_add(indices, &src, 1).unwrap();

        let data = result.to_vec().unwrap();
        assert_eq!(data, vec![10, 0, 20, 0, 70, 0]);
    }

    #[test]
    fn test_scatter_add_3d() {
        let dst = Tensor::<i32>::zeros((2, 2, 2)).unwrap();

        let src = Tensor::ones((2, 1, 2)).unwrap();
        let indices = Tensor::new(&[1, 0, 0, 0]).unwrap().reshape((2, 1, 2)).unwrap();

        let result = dst.scatter_add(indices, &src, 1).unwrap();

        let res_vec = result.to_vec().unwrap();
        assert_eq!(res_vec, vec![0, 1, 1, 0, 1, 1, 0, 0]);
    }

    #[test]
    fn test_scatter_add_shape_mismatch() {
        let dst = Tensor::<i32>::zeros((2, 2)).unwrap();
        let src = Tensor::<i32>::ones((2, 2)).unwrap();
        let indices = Tensor::new(&[0]).unwrap();

        let result = dst.scatter_add(indices, &src, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_index_scalar_dim_reduction() {
        let arr = Tensor::arange(0, 125).unwrap().reshape((5, 5, 5)).unwrap();
        let sub = arr.index(1).unwrap();
        assert_eq!(sub.shape().dims(), &[5, 5]);

        let sub = arr.index((2, 3)).unwrap();
        assert_eq!(sub.shape().dims(), &[5]);
    }

    #[test]
    fn test_index_range_basic() {
        let arr = Tensor::arange(0, 125).unwrap().reshape((5, 5, 5)).unwrap();

        let sub = arr.index(s!(1:3)).unwrap();
        assert_eq!(sub.shape().dims(), &[2, 5, 5]);

        let sub = arr.index((s!(1:3), s!(3:4), 1)).unwrap();
        assert_eq!(sub.shape().dims(), &[2, 1]);
    }

    #[test]
    fn test_index_full_and_mixed() {
        let arr = Tensor::<i32>::zeros((5, 5, 5)).unwrap();

        let sub = arr.index((s!(1:3), .., 1..2)).unwrap();
        assert_eq!(sub.shape().dims(), &[2, 5, 1]);

        let sub = arr.index((2, .., s!(0:2))).unwrap();
        assert_eq!(sub.shape().dims(), &[5, 2]);

        let sub = arr.index((s!(0:2), s!(2:5), s!(1:3))).unwrap();
        assert_eq!(sub.shape().dims(), &[2, 3, 2]);
    }

    #[test]
    fn test_index_out_of_bounds() {
        let arr = Tensor::<i32>::zeros((5, 5, 5)).unwrap();
        let result = arr.index(10);
        assert!(result.is_err());

        let result = arr.index(s!(3:10));
        assert!(result.is_err());
    }

    #[test]
    fn test_index_scalar_and_values() {
        let arr = Tensor::arange(0, 125).unwrap().reshape((5, 5, 5)).unwrap();

        let sub = arr.index(1).unwrap();
        let expected = Tensor::arange(25, 50).unwrap().reshape((5, 5)).unwrap();
        assert!(sub.allclose(&expected, 0.0, 0.0).unwrap());
    }

    #[test]
    fn test_index_range_values() {
        let arr = Tensor::arange(0, 125).unwrap().reshape((5, 5, 5)).unwrap();

        let sub = arr.index(s!(1:3)).unwrap();
        let expected = Tensor::arange(25, 75).unwrap().reshape((2, 5, 5)).unwrap();
        assert!(sub.allclose(&expected, 0.0, 0.0).unwrap());
    }

    #[test]
    fn test_index_mixed_values() {
        let arr = Tensor::arange(0, 125).unwrap().reshape((5, 5, 5)).unwrap();
        let sub = arr.index((2, 3)).unwrap();
        let expected = Tensor::arange(65, 70).unwrap();
        assert!(sub.allclose(&expected, 0.0, 0.0).unwrap());

        let sub = arr.index((s!(1:3), s!(3:5), 2)).unwrap();
        let mut vals = Vec::new();
        for i in 1..3 {
            for j in 3..5 {
                vals.push(i * 25 + j * 5 + 2);
            }
        }
        let expected = Tensor::from_vec(vals, (2, 2)).unwrap();
        assert!(sub.allclose(&expected, 0.0, 0.0).unwrap());
    }

    #[test]
    fn test_index_with_full_dim() {
        let arr = Tensor::arange(0, 125).unwrap().reshape((5, 5, 5)).unwrap();
        let sub = arr.index((s!(1:3), .., 1..2)).unwrap();

        let expected = arr.index((s!(1:3), s!(0:5), s!(1:2))).unwrap();
        assert!(sub.allclose(&expected, 0.0, 0.0).unwrap());
    }

    #[test]
    fn test_macro() {
        assert_eq!(s!(1:10), Slice { start: 1, end: Some(10), step: 1 });

        // Bounded slices must yield exactly the positions of the equivalent
        // range, including how many of them there are.
        assert_eq!(s!(1:20).collect::<Vec<_>>(), (1..20usize).collect::<Vec<_>>());
        assert_eq!(s!(1:13:3).collect::<Vec<_>>(), (1..13usize).step_by(3).collect::<Vec<_>>());
        assert_eq!(s!(:20).collect::<Vec<_>>(), (0..20usize).collect::<Vec<_>>());
        assert_eq!(s!(:20:5).collect::<Vec<_>>(), (0..20usize).step_by(5).collect::<Vec<_>>());

        // Open-ended slices never stop on their own; compare a prefix.
        assert!(s!(1:).take(100).eq((1usize..).take(100)));
        assert!(s!(1::2).take(100).eq((1usize..).step_by(2).take(100)));
        assert!(s!(::2).take(100).eq((0usize..).step_by(2).take(100)));
        assert!(s!(:).take(100).eq((0usize..).take(100)));
    }

    #[test]
    fn test_slice_len_and_display() {
        assert_eq!(s!(1:10).len(), 9);
        assert_eq!(s!(1:13:3).len(), 4);
        assert_eq!(s!(5:2).len(), 0);

        assert_eq!(s!(1:10).to_string(), "1:10");
        assert_eq!(s!(1:10:2).to_string(), "1:10:2");
        assert_eq!(s!(1:).to_string(), "1:");
        assert_eq!(s!(1::2).to_string(), "1::2");
    }
}
917