//! Tensor indexing for `lumen_core`: the [`Indexer`] / [`Slice`] types, the
//! `s!` slice-construction macro, and index/gather/scatter operations on
//! [`Tensor`].

1use std::{fmt::Display, ops::Deref};
2use crate::{AutogradMetaT, Dim, Error, IntTensor, NumDType, Result, WithDType};
3use super::Tensor;
4
impl<T: WithDType> Tensor<T> {
    /// Applies a sequence of indexing operations, one per leading dimension.
    ///
    /// `Select` fixes a single position and removes that dimension, `Slice`
    /// keeps a strided sub-range, and `Boolean` keeps the positions where a
    /// 1-D mask is `true`. Dimensions past the last indexer are untouched.
    pub fn indexes(&self, indexers: &[Indexer]) -> Result<Self> {
        let mut x = self.clone();
        // Cursor into the dims of `x` for the next indexer. `Select` squeezes
        // the dimension away, so it does not advance the cursor; the other
        // variants keep the dimension and advance it.
        let mut current_dim = 0;
        for indexer in indexers.iter() {
            x = match indexer {
                Indexer::Select(n) => x.narrow(current_dim, *n, 1)?.squeeze(current_dim)?,
                Indexer::Slice(range) => {
                    let out = x.slice(current_dim, range)?;
                    current_dim += 1;
                    out
                }
                Indexer::Boolean(index) => {
                    // The mask must be a vector whose length equals the size
                    // of the dimension it filters.
                    if index.dims1()? != x.dims()[current_dim] {
                        return Err(Error::BooleanIndexShouldLikeVector(index.shape().clone()))
                    }

                    // Convert the boolean mask into the integer positions of
                    // its `true` entries, then reuse `index_select`.
                    // TODO: better
                    let index = index.to_vec()?
                        .into_iter()
                        .enumerate()
                        .filter(|(_, v)| *v)
                        .map(|(i, _)| i as u32)
                        .collect::<Vec<_>>();
                    let index = Tensor::new(index)?;
                    let out = x.index_select(index, current_dim)?;
                    current_dim += 1;
                    out
                }
            };
        }
        Ok(x)
    }

    /// Returns the sub-tensor fixing the index at `i` on the first dimension.
    ///
    /// For a 0-d (scalar) tensor this returns a clone regardless of `i`.
    ///
    /// ```rust
    /// use lumen_core::Tensor;
    /// let tensor = Tensor::<f32>::new(&[[0f32, 1.], [2., 3.], [4., 5.]]).unwrap();
    /// let t = tensor.get(0).unwrap();
    /// assert_eq!(t.to_vec().unwrap(), &[0., 1.]);
    /// let t = tensor.get(1).unwrap();
    /// assert_eq!(t.to_vec().unwrap(), &[2., 3.]);
    /// ```
    pub fn get(&self, i: usize) -> Result<Self> {
        let dims = self.dims();
        if dims.is_empty() {
            Ok(self.clone())
        } else {
            // Narrow keeps a size-1 leading dim; reshape drops it.
            self.narrow(0, i, 1)?.reshape(&dims[1..])
        }
    }

    /// Picks entries along `dim` in the order given by the 1-D integer tensor
    /// `indexes` (duplicates and reordering are allowed).
    ///
    /// The output shape equals `self`'s, except `dims[dim] == indexes.len()`.
    pub fn index_select<D: Dim>(&self, indexes: impl Into<IntTensor>, dim: D) -> Result<Self> {   
        let indexes: IntTensor = indexes.into();
        let dim = dim.to_index(self.shape(), "index-select")?;
        // `dims1` enforces that the index tensor is a vector.
        let indexes_len = indexes.shape().dims1()?;
        let mut dims = self.dims().to_vec();
        dims[dim] = indexes_len;
        // Autograd metadata is recorded before the storage op runs.
        let meta = T::AutogradMeta::on_index_select_op(self, &indexes, dim);
        // Dispatch on the concrete integer dtype of the index tensor.
        let storage = match indexes {
            IntTensor::I32(indexes) => self.storage_read()?.index_select(
                self.layout(),
                indexes.storage_read()?.deref(),
                indexes.layout(),
                dim,
            )?,
            IntTensor::U32(indexes) => self.storage_read()?.index_select(
                self.layout(),
                indexes.storage_read()?.deref(),
                indexes.layout(),
                dim,
            )?,
            IntTensor::U8(indexes) => self.storage_read()?.index_select(
                self.layout(),
                indexes.storage_read()?.deref(),
                indexes.layout(),
                dim,
            )?,
        };

        Ok(Self::from_storage(storage, dims, meta))
    }

    /// Gather values across the target dimension.
    ///
    /// # Arguments
    ///
    /// * `self` - The input tensor.
    /// * `indexes` - The indices of elements to gather, this should have same number of dimensions as `self`
    ///   and indexes.dims()[d] <= self.dims()[d] for all dimensions d != dim
    /// * `dim` - the target dimension.
    ///
    /// The resulting tensor has the same shape as `indexes` and use values from `self` indexed on
    /// dimension `dim` by the values in `indexes`.
    pub fn gather<D: Dim>(&self, indexes: impl Into<IntTensor>, dim: D) -> Result<Self> {
        let indexes = indexes.into();
        let dim = dim.to_index(self.shape(), "gather")?;
        let self_dims = self.dims();
        let indexes_dims = indexes.dims();
        // `indexes` must have the same rank as `self`, and on every dimension
        // other than `dim` it must not be larger than `self`.
        let mismatch = if indexes_dims.len() != self_dims.len() {
            true
        } else {
            let mut mismatch = false;
            for (i, (&d1, &d2)) in self_dims.iter().zip(indexes_dims.iter()).enumerate() {
                if i != dim && d1 < d2 {
                    mismatch = true;
                    break;
                }
            }
            mismatch
        };
        if mismatch {
            // `Err(...)?` returns early with the error converted via `From`.
            Err(Error::ShapeMismatchBinaryOp {
                op: "gather",
                lhs: self.shape().clone(),
                rhs: indexes.shape().clone(),
            })?
        }

        let storage = match &indexes {
            IntTensor::I32(idx) => self.storage_read()?.gather(self.layout(), idx.storage_read()?.deref(), idx.layout(), dim)?,
            IntTensor::U32(idx) => self.storage_read()?.gather(self.layout(), idx.storage_read()?.deref(), idx.layout(), dim)?,
            IntTensor::U8(idx) => self.storage_read()?.gather(self.layout(), idx.storage_read()?.deref(), idx.layout(), dim)?,
        };

        let meta = T::AutogradMeta::on_gather_op(self, &indexes, dim);
        // Output takes the shape of the index tensor.
        Ok(Self::from_storage(storage, indexes.shape(), meta))
    }
}
136
impl<T: NumDType> Tensor<T> {
    /// Returns a copy of `self` where slices of `source` have been added at
    /// the positions given by `indexes` along `dim`:
    /// `out[..., indexes[i], ...] += source[..., i, ...]`.
    ///
    /// `indexes` must be a 1-D tensor whose length equals
    /// `source.dims()[dim]`; every other dimension of `source` must match
    /// `self` exactly. Repeated indices accumulate.
    pub fn index_add<D: Dim>(&self, indexes: impl Into<IntTensor>, source: &Tensor<T>, dim: D) -> Result<Self> {
        let indexes: IntTensor = indexes.into();
        let dim = dim.to_index(self.shape(), "index-add")?;

        let source_dims = source.dims();
        let self_dims = self.dims();
        // Ranks must agree before the per-dimension checks below.
        if source_dims.len() != self_dims.len() {
             return Err(Error::ShapeMismatchBinaryOp { 
                 op: "index-add", 
                 lhs: self.shape().clone(), 
                 rhs: source.shape().clone() 
             }.into());
        }

        // `dims1` enforces that the index tensor is a vector.
        let indexes_len = indexes.shape().dims1()?;
        for (i, (&d_self, &d_src)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
            if i == dim {
                // Along the target dimension, source provides one slice per index.
                if d_src != indexes_len {
                    return Err(Error::ShapeMismatchBinaryOp { op: "index-add (dim mismatch)", lhs: self.shape().clone(), rhs: source.shape().clone() }.into());
                }
            } else if d_self != d_src {
                return Err(Error::ShapeMismatchBinaryOp { op: "index-add", lhs: self.shape().clone(), rhs: source.shape().clone() }.into());
            }
        }

        // Dispatch on the concrete integer dtype of the index tensor.
        let storage = match &indexes {
            IntTensor::I32(idx) => self.storage_read()?.index_add(
                self.layout(),
                idx.storage_read()?.deref(),
                idx.layout(),
                source.storage_read()?.deref(),
                source.layout(),
                dim,
            )?,
            IntTensor::U32(idx) => self.storage_read()?.index_add(
                self.layout(),
                idx.storage_read()?.deref(),
                idx.layout(),
                source.storage_read()?.deref(),
                source.layout(),
                dim,
            )?,
            IntTensor::U8(idx) => self.storage_read()?.index_add(
                self.layout(),
                idx.storage_read()?.deref(),
                idx.layout(),
                source.storage_read()?.deref(),
                source.layout(),
                dim,
            )?,
        };

        let meta = T::AutogradMeta::on_index_add_op(self, &indexes, source, dim);
        // Output keeps `self`'s shape.
        Ok(Self::from_storage(storage, self_dims.to_vec(), meta))
    } 

    /// Returns a copy of `self` where each element of `source` has been added
    /// at the position obtained by replacing its `dim` coordinate with the
    /// corresponding entry of `indexes`. Repeated targets accumulate.
    ///
    /// `indexes` must have the same shape as `source`; see `scatter_checks`
    /// for the full shape contract.
    pub fn scatter_add<D: Dim>(&self, indexes: impl Into<IntTensor>, source: &Self, dim: D) -> Result<Self> {
        let indexes = indexes.into();
        let dim = dim.to_index(self.shape(), "scatter-add")?;
        self.scatter_checks(&indexes, source, dim)?;

        // Dispatch on the concrete integer dtype of the index tensor.
        let storage = match &indexes {
            IntTensor::I32(idx) => self.storage_read()?.scatter_add(
                self.layout(),
                idx.storage_read()?.deref(),
                idx.layout(),
                source.storage_read()?.deref(),
                source.layout(),
                dim,
            )?,
            IntTensor::U32(idx) => self.storage_read()?.scatter_add(
                self.layout(),
                idx.storage_read()?.deref(),
                idx.layout(),
                source.storage_read()?.deref(),
                source.layout(),
                dim,
            )?,
            IntTensor::U8(idx) => self.storage_read()?.scatter_add(
                self.layout(),
                idx.storage_read()?.deref(),
                idx.layout(),
                source.storage_read()?.deref(),
                source.layout(),
                dim,
            )?,
        };

        let meta = T::AutogradMeta::on_scatter_add_op(self, &indexes, source, dim);
        Ok(Self::from_storage(storage, self.shape(), meta))
    }

    /// Validates the shape contract for `scatter_add`: `source` has the same
    /// rank as `self` and matches it on every dimension except `dim`, and
    /// `indexes` has exactly the same shape as `source`.
    fn scatter_checks(&self, indexes: &IntTensor, source: &Self, dim: usize) -> Result<()> {
        let source_dims = source.dims();
        let self_dims = self.dims();
        let mismatch = if source_dims.len() != self_dims.len() {
            true
        } else {
            let mut mismatch = false;
            for (i, (&d1, &d2)) in self_dims.iter().zip(source_dims.iter()).enumerate() {
                if i != dim && d1 != d2 {
                    mismatch = true;
                    break;
                }
            }
            mismatch
        };
        if mismatch {
            // `Err(...)?` returns early with the error converted via `From`.
            Err(Error::ShapeMismatchBinaryOp {
                op: "scatter (self, src)",
                lhs: self.shape().clone(),
                rhs: source.shape().clone(),
            })?
        }
        if indexes.dims() != source.dims() {
            Err(Error::ShapeMismatchBinaryOp {
                op: "scatter (indexes, src)",
                lhs: indexes.shape().clone(),
                rhs: source.shape().clone(),
            })?
        }
        Ok(())
    }
}
262
263impl<T: WithDType> Tensor<T> {
264    pub fn matrix_get(&self, row: usize, col: usize) -> Result<T> {
265        self.index((row, col))?.to_scalar()
266    }
267
268    pub fn matrix_set(&self, row: usize, col: usize, val: T) -> Result<()> {
269        self.index((row, col))?.set_scalar(val)
270    }
271
272    pub fn vector_get(&self, n: usize) -> Result<T> {
273        self.index(n)?.to_scalar()
274    }
275}
276
/// A single indexing operation along one dimension, in the style of
/// NumPy/PyTorch indexing. Consumed by [`Tensor::indexes`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Indexer {
    /// Pick one position; the dimension is removed from the result.
    Select(usize),
    /// Keep a strided sub-range; the dimension is kept (possibly resized).
    Slice(Slice),
    /// Keep positions where a 1-D boolean mask is `true`; the dimension is
    /// kept with a mask-dependent size.
    Boolean(Tensor<bool>),
}
283
/// A strided half-open range `start..end` with step `step`, mirroring
/// Python's `start:end:step` slice syntax (see the `s!` macro).
///
/// * `end == None` means "no upper bound"; the consumer (e.g. `Tensor::slice`)
///   bounds it by the dimension length.
/// * A negative `end` is presumably an offset from the end of the dimension,
///   resolved by the consumer — the iterator here treats it as unbounded.
/// * `step` is expected to be non-zero.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Slice {
    pub start: usize, 
    pub end: Option<isize>, 
    pub step: usize
}

impl Slice {
    /// Creates a slice from its raw parts.
    pub fn new(start: usize, end: Option<isize>, step: usize) -> Self {
        Self { start, end, step }
    }

    /// Number of indices this slice yields.
    ///
    /// Computed in closed form when the slice is bounded (`end >= 0`); for an
    /// unbounded slice (`end` is `None` or negative) there is no finite
    /// answer and this method does not terminate (same as counting the
    /// iterator) — resolve `end` against a dimension length first.
    ///
    /// # Panics
    /// Panics (division by zero) on a bounded, non-empty slice with `step == 0`.
    pub fn len(&self) -> usize {
        match self.end {
            Some(end) if end >= 0 => {
                let end = end as usize;
                if self.start >= end {
                    0
                } else {
                    // ceil((end - start) / step) without cloning the iterator.
                    (end - self.start + self.step - 1) / self.step
                }
            }
            // NOTE(review): preserved behavior — counting an unbounded
            // iterator diverges.
            _ => self.clone().count(),
        }
    }

    /// True when the slice yields no indices. Unbounded slices are never empty.
    pub fn is_empty(&self) -> bool {
        matches!(self.end, Some(end) if end >= 0 && self.start >= end as usize)
    }
}

impl Iterator for Slice {
    type Item = usize;

    /// Yields `start`, `start + step`, `start + 2*step`, …, stopping before
    /// `end` when `end` is a non-negative bound; otherwise unbounded.
    fn next(&mut self) -> Option<Self::Item> {
        let bound = match self.end {
            Some(end) if end >= 0 => Some(end as usize),
            // A negative end cannot be resolved here (no dimension length):
            // treat it, like `None`, as unbounded.
            _ => None,
        };
        if let Some(bound) = bound {
            if self.start >= bound {
                return None;
            }
        }
        let value = self.start;
        self.start += self.step;
        Some(value)
    }
}

impl Display for Slice {
    /// Formats in Python slice notation: `start:end`, `start:end:step`,
    /// `start:` or `start::step`; the `:step` suffix is omitted for `step == 1`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let step_part = match self.step {
            1 => String::new(),
            _ => format!(":{}", self.step),
        };
        match self.end {
            Some(end) => write!(f, "{}:{}{}", self.start, end, step_part),
            None => write!(f, "{}:{}", self.start, step_part),
        }
    }
}
340
341impl From<usize> for Indexer {
342    fn from(index: usize) -> Self {
343        Indexer::Select(index)
344    }
345}
346
347impl From<Slice> for Indexer {
348    fn from(value: Slice) -> Self {
349        Indexer::Slice(value)
350    }
351}
352
353impl From<&Tensor<bool>> for Indexer {
354    fn from(value: &Tensor<bool>) -> Self {
355        Indexer::Boolean(value.clone())
356    }
357}
358
359impl From<Tensor<bool>> for Indexer {
360    fn from(value: Tensor<bool>) -> Self {
361        Indexer::Boolean(value)
362    }
363}
364
365impl From<std::ops::Range<usize>> for Indexer {
366    fn from(value: std::ops::Range<usize>) -> Self {
367        let range = Slice::new(value.start, Some(value.end as isize), 1);
368        range.into()
369    }
370}
371
372impl From<std::ops::RangeFrom<usize>> for Indexer {
373    fn from(value: std::ops::RangeFrom<usize>) -> Self {
374        let range = Slice::new(value.start, None, 1);
375        range.into()
376    }
377}
378
379impl From<std::ops::RangeFull> for Indexer {
380    fn from(_: std::ops::RangeFull) -> Self {
381        let range = Slice::new(0, None, 1);
382        range.into()
383    }
384}
385
/// Ergonomic indexing entry point: `index` accepts a single value convertible
/// to [`Indexer`], a tuple of up to five such values, or a `Vec` of them
/// (see the impls below).
pub trait IndexOp<T, D: WithDType> {
    fn index(&self, index: T) -> Result<Tensor<D>>;
}
389
390impl<I: Into<Indexer>, D: WithDType> IndexOp<I, D> for Tensor<D> {
391    fn index(&self, index: I) -> Result<Tensor<D>> {
392        self.indexes(&[index.into()])
393    }
394}
395
396impl<I: Into<Indexer>, D: WithDType> IndexOp<(I,), D> for Tensor<D> {
397    fn index(&self, (index,): (I,)) -> Result<Tensor<D>> {
398        self.indexes(&[index.into()])
399    }
400}
401
/// Implements [`IndexOp`] for a tuple of indexer-convertible values, one
/// tuple element per leading dimension.
macro_rules! index_op_tuple {
    ($($t:ident),+) => {
        #[allow(non_snake_case)]
        impl<$($t),*, D: WithDType> IndexOp<($($t,)*), D> for Tensor<D>
        where
            $($t: Into<Indexer>,)*
        {
            fn index(&self, ($($t,)*): ($($t,)*)) -> Result<Tensor<D>> {
                self.indexes(&[$($t.into(),)*])
            }
        }
    };
}

// Tuples of 2 through 5 indexers.
index_op_tuple!(I1, I2);
index_op_tuple!(I1, I2, I3);
index_op_tuple!(I1, I2, I3, I4);
index_op_tuple!(I1, I2, I3, I4, I5);
420
421impl<I: Into<Indexer>, D: WithDType> IndexOp<Vec<I>, D> for Tensor<D> {
422    fn index(&self, index: Vec<I>) -> Result<Tensor<D>> {
423        let indexs = index.into_iter().map(|i| i.into()).collect::<Vec<Indexer>>();
424        self.indexes(&indexs)
425    }
426}
427
/// Builds a [`Slice`] with Python-like `start:end:step` syntax, e.g.
/// `s!(1:10)`, `s!(1:10:2)`, `s!(3:)`, `s!(::2)`, `s!(:)`.
///
/// Omitted `start` defaults to `0`, omitted `end` to "unbounded"
/// (`None`), omitted `step` to `1`. `Slice` must be in scope at the call
/// site. Arm order matters: the `$end:tt` forms must be tried where a
/// trailing `: $step` can still follow.
#[macro_export]
macro_rules! s {
    // s!(start:end)
    ($start:tt : $end:expr) => {
        Slice::new($start as usize, Some($end as isize), 1)
    };
    // s!(start:end:step)
    ($start:tt : $end:tt : $step:expr) => {
        Slice::new($start as usize, Some($end as isize), $step as usize)
    };
    // s!(start:)
    ($start:tt :) => {
        Slice::new($start as usize, None, 1)
    };
    // s!(start::step)
    ($start:tt :: $step:expr) => {
        Slice::new($start as usize, None, $step as usize)
    };
    // s!(:$end)
    (: $end:tt) => {
        Slice::new(0, Some($end as isize), 1)
    };
    // s!(:$end:$step)
    (: $end:tt : $step:expr) => {
        Slice::new(0, Some($end as isize), $step as usize)
    };
    // s!(::$step)
    (:: $step:expr) => {
        Slice::new(0, None, $step as usize)
    };
    // s!(:)
    (:) => {
        Slice::new(0, None, 1)
    };
}
463
#[cfg(test)]
#[allow(unused)]
mod test {
    use crate::DType;
    use super::*;

    #[test]
    fn test_index_select_basic() {
        // [[ 0,  1,  2,  3],
        //  [ 4,  5,  6,  7],
        //  [ 8,  9, 10, 11]]
        let arr = Tensor::arange(0, 12).unwrap().reshape((3, 4)).unwrap();

        let indices = Tensor::new(&[0, 2]).unwrap();
        let selected = arr.index_select(indices, 0).unwrap();
        
        assert_eq!(selected.shape().dims(), &[2, 4]);
        let data = selected.to_vec().unwrap();
        assert_eq!(data, vec![0, 1, 2, 3, 8, 9, 10, 11]);

        let indices_col = Tensor::new(&[1]).unwrap();
        let selected_col = arr.index_select(indices_col, 1).unwrap();

        assert_eq!(selected_col.shape().dims(), &[3, 1]);
        let data_col = selected_col.to_vec().unwrap();
        assert_eq!(data_col, vec![1, 5, 9]);
    }

    #[test]
    fn test_boolean_index() {
        // [[ 0,  1,  2,  3],
        //  [ 4,  5,  6,  7],
        //  [ 8,  9, 10, 11]]
        let arr = Tensor::arange(0, 12).unwrap().reshape((3, 4)).unwrap();
        let index = Tensor::new([true, false, true]).unwrap();

        let selected = arr.index(index).unwrap();
        println!("{}", selected);
    }

    #[test]
    fn test_index_select_1d() {
        let scores = Tensor::<f64>::arange(0.0, 100.0).unwrap();
        let indices = Tensor::new(&[25, 34, 12, 90]).unwrap();

        let selected = scores.index_select(indices, 0).unwrap();
        println!("{}", selected);
    }

    #[test]
    fn test_index_select_duplicates_and_reorder() {
        let arr = Tensor::arange(0, 5).unwrap(); // [0, 1, 2, 3, 4]

        let indices = Tensor::new(&[4, 0, 0, 1]).unwrap();
        let selected = arr.index_select(indices, 0).unwrap();

        assert_eq!(selected.shape().dims(), &[4]);
        let data = selected.to_vec().unwrap();
        assert_eq!(data, vec![4, 0, 0, 1]);
    }

    #[test]
    fn test_index_select_out_of_bounds() {
        let arr = Tensor::arange(0, 10).unwrap();
        let indices = Tensor::new(&[0, 10]).unwrap(); 
        
        let result = arr.index_select(indices, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_index_add_basic() {
        let dst = Tensor::<i32>::zeros((3, 3)).unwrap();
        let src = Tensor::<i32>::ones((2, 3)).unwrap();
        let indices = Tensor::new(&[0, 2]).unwrap();
        
        let result = dst.index_add(indices, &src, 0).unwrap();
        
        // [[1, 1, 1],
        //  [0, 0, 0],
        //  [1, 1, 1]]
        let data = result.to_vec().unwrap();
        assert_eq!(data, vec![
            1, 1, 1, 
            0, 0, 0, 
            1, 1, 1
        ]);
    }

    #[test]
    fn test_index_add_accumulate() {
        let dst = Tensor::<i32>::zeros((5,)).unwrap(); // [0, 0, 0, 0, 0]        
        let src = Tensor::new(&[10, 20, 30]).unwrap();
        let indices = Tensor::new(&[1, 1, 3]).unwrap();
        
        let result = dst.index_add(indices, &src, 0).unwrap();
        
        // dst[0] = 0
        // dst[1] = 0 + 10 + 20 = 30
        // dst[2] = 0
        // dst[3] = 0 + 30 = 30
        // dst[4] = 0
        let data = result.to_vec().unwrap();
        assert_eq!(data, vec![0, 30, 0, 30, 0]);
    }

    #[test]
    fn test_index_add_dim_mismatch() {
        let dst = Tensor::<i32>::zeros((3, 3)).unwrap();        
        let src = Tensor::<i32>::ones((2, 3)).unwrap();
        let indices = Tensor::new(&[0, 1, 2]).unwrap();
        
        let result = dst.index_add(indices, &src, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_index_add_inner_dim() {
        // Shape: 2x3
        // [[0, 0, 0],
        //  [0, 0, 0]]
        let dst = Tensor::<i32>::zeros((2, 3)).unwrap();
        // Source: 2x1
        let src = Tensor::new(&[
            [5],
            [5]
        ]).unwrap(); // Shape 2x1
        let indices = Tensor::new(&[1]).unwrap();        
        let result = dst.index_add(indices, &src, 1).unwrap();
        
        // Expected:
        // [[0, 5, 0],
        //  [0, 5, 0]]
        let data = result.to_vec().unwrap();
        assert_eq!(data, vec![0, 5, 0, 0, 5, 0]);
    }


    #[test]
    fn test_gather_dim_1() {
        // Source: 2x2
        // [[1, 2],
        //  [3, 4]]
        let src = Tensor::new(&[
            [1, 2],
            [3, 4]
        ]).unwrap();
        
        // Gather along dim 1 (columns).
        // The indexes shape must match the output shape.
        // [[0, 0],  -> picks src[0,0], src[0,0]
        //  [1, 0]]  -> picks src[1,1], src[1,0]
        let indices = Tensor::new(&[
            [0, 0],
            [1, 0]
        ]).unwrap();

        let result = src.gather(&indices, 1).unwrap();

        // Expected result:
        // [[1, 1],
        //  [4, 3]]
        let data = result.to_vec().unwrap();
        assert_eq!(data, vec![1, 1, 4, 3]);
    }

    #[test]
    fn test_gather_dim_0() {
        // Source: 3x2
        // [[10, 20],
        //  [30, 40],
        //  [50, 60]]
        let src = Tensor::new(&[
            [10, 20],
            [30, 40],
            [50, 60]
        ]).unwrap();

        // Gather along dimension 0 (rows).
        // [[1, 2], -> src[1,0]=30, src[2,1]=60
        //  [0, 1]] -> src[0,0]=10, src[1,1]=40
        let indices = Tensor::new(&[
            [1, 2],
            [0, 1]
        ]).unwrap();

        let result = src.gather(&indices, 0).unwrap();

        // Expected result:
        // [[30, 60],
        //  [10, 40]]
        let data = result.to_vec().unwrap();
        assert_eq!(data, vec![30, 60, 10, 40]);
    }

    #[test]
    fn test_gather_3d() {
        // Shape: 2x2x2
        // Block 0: [[0, 1], [2, 3]]
        // Block 1: [[4, 5], [6, 7]]
        let src = Tensor::new(&[0, 1, 2, 3, 4, 5, 6, 7]).unwrap().reshape((2, 2, 2)).unwrap();

        // Gather along dim 1 (the middle dimension).
        // dim 0 and dim 2 are kept as-is; only the dim 1 index changes.
        // index shape: 2x1x2 (a reduced extraction along dim 1)
        let indices = Tensor::<u32>::zeros((2, 1, 2)).unwrap(); // all zeros

        // Logic:
        // out[0, 0, 0] = src[0, idx[0,0,0], 0] = src[0,0,0] = 0
        // out[0, 0, 1] = src[0, idx[0,0,1], 1] = src[0,0,1] = 1
        // out[1, 0, 0] = src[1, idx[1,0,0], 0] = src[1,0,0] = 4
        // out[1, 0, 1] = src[1, idx[1,0,1], 1] = src[1,0,1] = 5
        
        let result = src.gather(&indices, 1).unwrap();
        
        assert_eq!(result.dims(), &[2, 1, 2]);
        let data = result.to_vec().unwrap();
        assert_eq!(data, vec![0, 1, 4, 5]);
    }


    #[test]
    fn test_scatter_add_1d_accumulate() {
        // Histogram-style counting.
        let dst = Tensor::<i32>::zeros((5,)).unwrap();
        
        // Source and indices have identical shapes.
        let src = Tensor::new(&[1, 1, 1, 1]).unwrap();
        let indices = Tensor::new(&[0, 2, 0, 4]).unwrap();

        // scatter_add(dim=0)
        // dst[0] += 1
        // dst[2] += 1
        // dst[0] += 1 (accumulates)
        // dst[4] += 1
        let result = dst.scatter_add(indices, &src, 0).unwrap();

        // Expected: [2, 0, 1, 0, 1]
        let data = result.to_vec().unwrap();
        assert_eq!(data, vec![2, 0, 1, 0, 1]);
    }

    #[test]
    fn test_scatter_add_2d_dim1() {
        // Dst: 2x3 zeros
        let dst = Tensor::<i32>::zeros((2, 3)).unwrap();
        
        // Src: 2x2
        // [[10, 20],
        //  [30, 40]]
        let src = Tensor::new(&[
            [10, 20],
            [30, 40]
        ]).unwrap();

        // Indices: 2x2
        // [[0, 2], -> adds 10 to dst[0,0], 20 to dst[0,2]
        //  [1, 1]] -> adds 30 to dst[1,1], 40 to dst[1,1] (accumulated)
        let indices = Tensor::new(&[
            [0, 2],
            [1, 1]
        ]).unwrap();

        let result = dst.scatter_add(indices, &src, 1).unwrap();

        // Expected:
        // Row 0: [10, 0, 20]
        // Row 1: [0, 70, 0] (30+40=70)
        let data = result.to_vec().unwrap();
        assert_eq!(data, vec![10, 0, 20, 0, 70, 0]);
    }

    #[test]
    fn test_scatter_add_3d() {
        // Dst: 2x2x2 zeros
        let dst = Tensor::<i32>::zeros((2, 2, 2)).unwrap();
        
        // Src: 2x1x2 (dim 1 is smaller)
        let src = Tensor::ones((2, 1, 2)).unwrap(); // All ones
        
        // Indices: 2x1x2
        // Block 0: [[1, 0]] -> dst[0, 1, 0] += 1, dst[0, 0, 1] += 1
        // Block 1: [[0, 0]] -> dst[1, 0, 0] += 1, dst[1, 0, 1] += 1
        let indices = Tensor::new(&[1, 0, 0, 0]).unwrap().reshape((2, 1, 2)).unwrap();

        let result = dst.scatter_add(indices, &src, 1).unwrap();

        // Check specifics:
        // dst[0, 1, 0] should be 1
        // dst[0, 0, 1] should be 1
        // dst[1, 0, 0] should be 1
        // dst[1, 0, 1] should be 1
        // Others 0
        let res_vec = result.to_vec().unwrap();
        // Flattened index check:
        // 2x2x2 -> stride [4, 2, 1]
        // [0,1,0] -> 2 -> val 1
        // [0,0,1] -> 1 -> val 1
        // [1,0,0] -> 4 -> val 1
        // [1,0,1] -> 5 -> val 1
        
        // Vec: [0, 1, 1, 0, 1, 1, 0, 0]
        assert_eq!(res_vec, vec![0, 1, 1, 0, 1, 1, 0, 0]);
    }

    #[test]
    fn test_scatter_add_shape_mismatch() {
        let dst = Tensor::<i32>::zeros((2, 2)).unwrap();
        let src = Tensor::<i32>::ones((2, 2)).unwrap();
        // Indices shape mismatch with Src
        let indices = Tensor::new(&[0]).unwrap(); 

        let result = dst.scatter_add(indices, &src, 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_index_scalar_dim_reduction() {
        let arr = Tensor::arange(0, 125).unwrap().reshape((5, 5, 5)).unwrap();
        let sub = arr.index(1).unwrap();
        assert_eq!(sub.shape().dims(), &[5, 5]);

        let sub = arr.index((2, 3)).unwrap();
        assert_eq!(sub.shape().dims(), &[5]); 
    }

    #[test]
    fn test_index_range_basic() {
        let arr = Tensor::arange(0, 125).unwrap().reshape((5, 5, 5)).unwrap();

        let sub = arr.index(s!(1:3)).unwrap();
        assert_eq!(sub.shape().dims(), &[2, 5, 5]);

        let sub = arr.index((s!(1:3), s!(3:4), 1)).unwrap();
        assert_eq!(sub.shape().dims(), &[2, 1]);
    }

    #[test]
    fn test_index_full_and_mixed() {
        let arr = Tensor::<i32>::zeros((5, 5, 5)).unwrap();

        let sub = arr.index((s!(1:3), .., 1..2)).unwrap();
        assert_eq!(sub.shape().dims(), &[2, 5, 1]);

        let sub = arr.index((2, .., s!(0:2))).unwrap();
        assert_eq!(sub.shape().dims(), &[5, 2]);

        let sub = arr.index((s!(0:2), s!(2:5), s!(1:3))).unwrap();
        assert_eq!(sub.shape().dims(), &[2, 3, 2]);
    }

    #[test]
    fn test_index_out_of_bounds() {
        let arr = Tensor::<i32>::zeros((5, 5, 5)).unwrap();
        let result = arr.index(10);
        assert!(result.is_err());

        let result = arr.index(s!(3:10));
        assert!(result.is_err());
    }
    
    #[test]
    fn test_index_scalar_and_values() {
        let arr = Tensor::arange(0, 125).unwrap().reshape((5, 5, 5)).unwrap();

        let sub = arr.index(1).unwrap();
        let expected = Tensor::arange(25, 50).unwrap().reshape((5, 5)).unwrap();
        assert!(sub.allclose(&expected, 0.0, 0.0).unwrap());
    }

    #[test]
    fn test_index_range_values() {
        let arr = Tensor::arange(0, 125).unwrap().reshape((5, 5, 5)).unwrap();

        let sub = arr.index(s!(1:3)).unwrap();
        let expected = Tensor::arange(25, 75).unwrap().reshape((2, 5, 5)).unwrap();
        assert!(sub.allclose(&expected, 0.0, 0.0).unwrap());
    }

    #[test]
    fn test_index_mixed_values() {
        let arr = Tensor::arange(0, 125).unwrap().reshape((5, 5, 5)).unwrap();
        let sub = arr.index((2, 3)).unwrap();
        let expected = Tensor::arange(65, 70).unwrap();
        assert!(sub.allclose(&expected, 0.0, 0.0).unwrap());

        let sub = arr.index((s!(1:3), s!(3:5), 2)).unwrap();
        let mut vals = Vec::new();
        for i in 1..3 {
            for j in 3..5 {
                vals.push(i * 25 + j * 5 + 2);
            }
        }
        let expected = Tensor::from_vec(vals, (2, 2)).unwrap();
        assert!(sub.allclose(&expected, 0.0, 0.0).unwrap());
    }

    #[test]
    fn test_index_with_full_dim() {
        let arr = Tensor::arange(0, 125).unwrap().reshape((5, 5, 5)).unwrap();
        let sub = arr.index((s!(1:3), .., 1..2)).unwrap();

        let expected = arr.index((s!(1:3), s!(0:5), s!(1:2))).unwrap();
        assert!(sub.allclose(&expected, 0.0, 0.0).unwrap());
    }

    #[test]
    fn test_macro() {
        // Unused bindings: kept only to check that range expressions of these
        // types compile (module has #[allow(unused)]).
        let t = (0..12usize);
        let t = (2usize..);
        assert_eq!(s!(1:10), Slice {start:1, end: Some(10), step:1});

        assert!(
            s!(1:20).zip((1..20))
                .all(|(a, b)| a == b)
        );
    
        assert!(
            s!(1:13:3).zip((1..13).step_by(3))
                .all(|(a, b)| a == b)
        );
    
        assert!(
            s!(1:).zip((1..).take(100))
                .all(|(a, b)| a == b)
        );

        assert!(
            s!(1::2).zip((1..).step_by(2).take(100))
                .all(|(a, b)| a == b)
        );

        assert!(
            s!(:20).zip((0..20usize))
                .all(|(a, b)| a == b)
        );

        assert!(
            s!(:20:5).zip((0..20usize).step_by(5))
                .all(|(a, b)| a == b)
        );

        assert!(
            s!(::2).zip((0..).step_by(2).take(100))
                .all(|(a, b)| a == b)
        );

        assert!(
            s!(:).zip((0..).take(100))
                .all(|(a, b)| a == b)
        );
    }
}
917