1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use std::ops::{Index, IndexMut};
use tensor::Tensor;
use traits::TensorTrait;

// Vector indexing
/// Read-only element access through a slice of indices.
///
/// The slice is handed to `ravel_index`, and the value it returns is used
/// directly as a position in `data`.
///
/// # Panics
///
/// Panics when the tensor's `canonical` flag is not set. Further panics may
/// come from `ravel_index` itself or from the bounds check on `data`.
impl<'b, T: TensorTrait> Index<&'b [usize]> for Tensor<T> {
    type Output = T;

    fn index(&self, ii: &'b [usize]) -> &T {
        // Only tensors flagged as canonical may be indexed by slice.
        assert!(self.canonical);
        let flat = self.ravel_index(ii);
        &self.data[flat]
    }
}

impl<'b, T: Copy> IndexMut<&'b [usize]> for Tensor<T> {
    fn index_mut<'a>(&'a mut self, ii: &'b [usize]) -> &'a mut T {
        assert!(self.canonical);
        let index = self.ravel_index(ii);
        &mut self.mem_slice_mut()[index]
    }
}

impl<'b, T: Copy> Index<&'b Vec<usize>> for Tensor<T> {
    type Output = T;
    fn index<'a>(&'a self, ii: &'b Vec<usize>) -> &'a T {
        assert!(self.canonical);
        let index = self.ravel_index(&ii[..]);
        &self.data[index]
    }
}

/// Mutable element access through a borrowed `Vec` of indices.
///
/// # Panics
///
/// Panics when the tensor's `canonical` flag is not set. Further panics may
/// come from `ravel_index` itself or from the bounds check on the slice
/// returned by `mem_slice_mut`.
impl<'b, T: TensorTrait> IndexMut<&'b Vec<usize>> for Tensor<T> {
    fn index_mut(&mut self, ii: &'b Vec<usize>) -> &mut T {
        // The tensor is required to be canonical already; nothing is
        // converted in place here.
        assert!(self.canonical);
        let flat = self.ravel_index(ii.as_slice());
        &mut self.mem_slice_mut()[flat]
    }
}


/*
// Flattened indexing (this will index it as one-dimensional)
impl<T: TensorTrait> Index<usize> for Tensor<T> {
    type Output = T;
    fn index<'a>(&'a self, _index: usize) -> &'a T {
        &self.data[self.mem_offset + _index]
    }
}

impl<T: TensorTrait> IndexMut<usize> for Tensor<T> {
    fn index_mut<'a>(&'a mut self, _index: usize) -> &'a mut T {
        //self.harden_inplace();
        let offset = self.mem_offset;
        &mut self.slice_mut()[offset + _index]
    }
}
*/

// 1-D indexing
/// Read-only element access for one-dimensional tensors via a 1-tuple.
///
/// The position in `data` is `mem_offset + index * strides[0]`, computed in
/// `isize` because the strides are signed.
///
/// # Panics
///
/// Panics when `ndim()` is not 1, or when the computed position fails the
/// bounds check on `data`.
impl<T: TensorTrait> Index<(usize,)> for Tensor<T> {
    type Output = T;

    fn index(&self, _index: (usize,)) -> &T {
        assert!(self.ndim() == 1);
        let (i,) = _index;
        let base = self.mem_offset as isize;
        let pos = base + i as isize * self.strides[0];
        &self.data[pos as usize]
    }
}

/// Mutable element access for one-dimensional tensors via a 1-tuple.
///
/// The position is `mem_offset + index * strides[0]`, computed in `isize`
/// because the strides are signed, and is looked up in the slice returned by
/// `mem_slice_mut`.
///
/// # Panics
///
/// Panics when `ndim()` is not 1, or when the computed position fails the
/// bounds check on that slice.
impl<T: TensorTrait> IndexMut<(usize,)> for Tensor<T> {
    fn index_mut(&mut self, _index: (usize,)) -> &mut T {
        assert!(self.ndim() == 1);
        // Everything is read from `self` before the mutable borrow starts.
        let base = self.mem_offset as isize;
        let stride = self.strides[0];
        let pos = (base + _index.0 as isize * stride) as usize;
        &mut self.mem_slice_mut()[pos]
    }
}

// 2-D indexing
/// Read-only element access for two-dimensional tensors via a `(row, col)`
/// style pair.
///
/// The position in `data` is
/// `mem_offset + index.0 * strides[0] + index.1 * strides[1]`. The whole sum
/// is evaluated in `isize` and only the final result is converted to `usize`:
/// strides are signed, so the stride contribution alone may be negative even
/// when the final position is valid. Converting that partial sum to `usize`
/// first would wrap and make the addition with `mem_offset` overflow.
///
/// # Panics
///
/// Panics when `ndim()` is not 2, or when the computed position fails the
/// bounds check on `data`.
impl<T: TensorTrait> Index<(usize, usize)> for Tensor<T> {
    type Output = T;
    fn index<'a>(&'a self, _index: (usize, usize)) -> &'a T {
        assert!(self.ndim() == 2);
        let pos = self.mem_offset as isize
            + _index.0 as isize * self.strides[0]
            + _index.1 as isize * self.strides[1];
        &self.data[pos as usize]
    }
}
/// Mutable element access for two-dimensional tensors via a pair of indices.
///
/// The position is
/// `mem_offset + index.0 * strides[0] + index.1 * strides[1]`, looked up in
/// the slice returned by `mem_slice_mut`. The whole sum is evaluated in
/// `isize` and only the final result is converted to `usize`: strides are
/// signed, so the stride contribution alone may be negative even when the
/// final position is valid. Converting that partial sum to `usize` first
/// would wrap and make the addition with `mem_offset` overflow.
///
/// # Panics
///
/// Panics when `ndim()` is not 2, or when the computed position fails the
/// bounds check on that slice.
impl<T: TensorTrait> IndexMut<(usize, usize)> for Tensor<T> {
    fn index_mut<'a>(&'a mut self, _index: (usize, usize)) -> &'a mut T {
        assert!(self.ndim() == 2);
        // Compute the position before taking the mutable borrow.
        let pos = self.mem_offset as isize
            + _index.0 as isize * self.strides[0]
            + _index.1 as isize * self.strides[1];
        &mut self.mem_slice_mut()[pos as usize]
    }
}

// 3-D indexing
/// Read-only element access for three-dimensional tensors via a triple of
/// indices.
///
/// The position in `data` is `mem_offset` plus each index multiplied by the
/// stride of its axis, accumulated in `isize` because the strides are signed.
///
/// # Panics
///
/// Panics when `ndim()` is not 3, or when the computed position fails the
/// bounds check on `data`.
impl<T: TensorTrait> Index<(usize, usize, usize)> for Tensor<T> {
    type Output = T;

    fn index(&self, _index: (usize, usize, usize)) -> &T {
        assert!(self.ndim() == 3);
        let (i, j, k) = _index;
        // Accumulate axis by axis, starting from the base offset.
        let mut pos = self.mem_offset as isize;
        pos += i as isize * self.strides[0];
        pos += j as isize * self.strides[1];
        pos += k as isize * self.strides[2];
        &self.data[pos as usize]
    }
}

/// Mutable element access for three-dimensional tensors via a triple of
/// indices.
///
/// The position is `mem_offset` plus each index multiplied by the stride of
/// its own axis (`strides[0]`, `strides[1]`, `strides[2]`), matching the
/// read-only `Index<(usize, usize, usize)>` implementation. The sum is
/// evaluated in `isize` because the strides are signed, and the result is
/// looked up in the slice returned by `mem_slice_mut`.
///
/// # Panics
///
/// Panics when `ndim()` is not 3, or when the computed position fails the
/// bounds check on that slice.
impl<T: TensorTrait> IndexMut<(usize, usize, usize)> for Tensor<T> {
    fn index_mut<'a>(&'a mut self, _index: (usize, usize, usize)) -> &'a mut T {
        assert!(self.ndim() == 3);
        // Read everything needed from `self` before the mutable borrow starts.
        let offset = self.mem_offset as isize;
        // Axes are numbered from 0: the strides of a 3-D tensor are 0, 1 and 2.
        let (s0, s1, s2) = (self.strides[0], self.strides[1], self.strides[2]);
        let pos = offset
            + _index.0 as isize * s0
            + _index.1 as isize * s1
            + _index.2 as isize * s2;
        &mut self.mem_slice_mut()[pos as usize]
    }
}