candle_einops/backend.rs

use candle_core::{Shape, Tensor};

use crate::Operation;

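/// The set of tensor operations that the einops-style rearrange, reduce, and
/// repeat transformations are built from, implemented below for candle
/// tensors. Implementations return the concrete tensor type directly and
/// panic if the underlying backend call fails.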
pub trait Backend {
    type Output;
    /// Dimensions of the tensor.
    fn shape(self) -> Vec<usize>;
    /// Reshape to the given dimensions.
    fn reshape(self, shape: &[usize]) -> Self::Output;
    /// Permute the axes into the given order.
    fn transpose(self, axes: &[usize]) -> Self::Output;
    /// Reduce each listed axis with its associated operation.
    fn reduce_axes(self, axes_operations: &mut [(usize, Operation)]) -> Self::Output;
    /// Insert size-1 axes at the positions in `pos2len` and repeat each to
    /// its target length; `naxes` is the rank of the resulting tensor.
    fn add_axes(self, naxes: usize, pos2len: &[(usize, usize)]) -> Self::Output;
}

impl<T: AsRef<Tensor>> Backend for T {
    type Output = Tensor;

    fn shape(self) -> Vec<usize> {
        self.as_ref().dims().to_vec()
    }

    fn reshape(self, shape: &[usize]) -> Self::Output {
        let shape = Shape::from_dims(shape);
        self.as_ref().reshape(shape).unwrap()
    }

    fn transpose(self, axes: &[usize]) -> Self::Output {
        self.as_ref().permute(axes).unwrap()
    }

    fn reduce_axes(self, axes_operations: &mut [(usize, Operation)]) -> Self::Output {
        let mut output = self.as_ref().clone();

        axes_operations.sort_by_key(|(axis, _)| *axis);

        // Reduce from the highest axis down so the indices of the remaining
        // (lower) axes stay valid as reduced dimensions are removed.
        for (axis, operation) in axes_operations.iter().rev() {
            output = match operation {
                Operation::Min => output.min(*axis).unwrap(),
                Operation::Max => output.max(*axis).unwrap(),
                Operation::Sum => output.sum(&[*axis][..]).unwrap(),
                Operation::Mean => output.mean(&[*axis][..]).unwrap(),
                // TODO: implement prod
            };
        }

        output
    }

    fn add_axes(self, naxes: usize, pos2len: &[(usize, usize)]) -> Self::Output {
        let mut output = self.as_ref().clone();

        // One repeat factor per axis of the result; axes not listed in
        // `pos2len` keep factor 1.
        let mut repeats = vec![1; naxes];

        // Insert a size-1 axis at each requested position and record how many
        // times it must be tiled.
        for &(axis_pos, axis_len) in pos2len {
            output = output.unsqueeze(axis_pos).unwrap();
            repeats[axis_pos] = axis_len;
        }

        let shape = Shape::from_dims(&repeats[..]);
        output.repeat(shape).unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{Device, Result};

    #[test]
    fn reduce() -> Result<()> {
        let tests = vec![
            (
                Tensor::new(
                    &[
                        0.66984287f32,
                        0.52894678,
                        0.85415958,
                        0.17721198,
                        0.81804799,
                        0.80991797,
                        0.64868822,
                        0.96697902,
                        0.08047191,
                        0.46024353,
                        0.21955009,
                        0.31731976,
                        0.05446258,
                        0.39454557,
                        0.40949016,
                        0.21366165,
                        0.2357463,
                        0.93699481,
                        0.64522596,
                        0.4383618,
                        0.54871827,
                        0.87823442,
                        0.01261184,
                        0.90636503,
                    ],
                    &Device::Cpu,
                )?
                .reshape(&[4, 2, 3])?,
                [(0, Operation::Min)],
                Tensor::new(
                    &[
                        [0.05446258f32, 0.39454557, 0.08047191],
                        [0.17721198, 0.01261184, 0.31731976],
                    ],
                    &Device::Cpu,
                )?,
            ),
            (
                Tensor::new(
                    &[
                        0.66984287f32,
                        0.52894678,
                        0.85415958,
                        0.17721198,
                        0.81804799,
                        0.80991797,
                        0.64868822,
                        0.96697902,
                        0.08047191,
                        0.46024353,
                        0.21955009,
                        0.31731976,
                        0.05446258,
                        0.39454557,
                        0.40949016,
                        0.21366165,
                        0.2357463,
                        0.93699481,
                        0.64522596,
                        0.4383618,
                        0.54871827,
                        0.87823442,
                        0.01261184,
                        0.90636503,
                    ],
                    &Device::Cpu,
                )?
                .reshape(&[4, 2, 3])?,
                [(0, Operation::Max)],
                Tensor::new(
                    &[
                        [0.6698429f32, 0.966979, 0.8541596],
                        [0.87823445, 0.818048, 0.9369948],
                    ],
                    &Device::Cpu,
                )?,
            ),
        ];

        for (tensor, mut axes_operations, expected) in tests {
            assert_eq!(
                tensor.reduce_axes(&mut axes_operations).to_vec2::<f32>()?,
                expected.to_vec2::<f32>()?
            );
        }

        Ok(())
    }
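
    // Illustrative addition, not part of the original file: exercises the
    // Sum and Mean arms of reduce_axes on a tiny tensor with hand-computed
    // expectations.
    #[test]
    fn reduce_sum_mean() -> Result<()> {
        let tensor = Tensor::new(&[[1.0f32, 2.0], [3.0, 4.0]], &Device::Cpu)?;

        // Sum over axis 0: [1 + 3, 2 + 4] = [4, 6].
        let summed = Backend::reduce_axes(&tensor, &mut [(0, Operation::Sum)]);
        assert_eq!(summed.to_vec1::<f32>()?, vec![4.0, 6.0]);

        // Mean over axis 1: [(1 + 2) / 2, (3 + 4) / 2] = [1.5, 3.5].
        let meaned = Backend::reduce_axes(&tensor, &mut [(1, Operation::Mean)]);
        assert_eq!(meaned.to_vec1::<f32>()?, vec![1.5, 3.5]);

        Ok(())
    }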

    #[test]
    fn candle_transpose() -> Result<()> {
        let tests = vec![(
            Tensor::arange(0f32, (2 * 3 * 4) as f32, &Device::Cpu)?.reshape(&[2, 3, 4])?,
            &[2, 0, 1],
            Tensor::new(
                &[
                    [[0.0f32, 4.0, 8.0], [12.0, 16.0, 20.0]],
                    [[1.0, 5.0, 9.0], [13.0, 17.0, 21.0]],
                    [[2.0, 6.0, 10.0], [14.0, 18.0, 22.0]],
                    [[3.0, 7.0, 11.0], [15.0, 19.0, 23.0]],
                ],
                &Device::Cpu,
            )?,
        )];

        for (tensor, axes, expected) in tests {
            assert_eq!(
                Backend::transpose(&tensor, axes).to_vec3::<f32>()?,
                expected.to_vec3::<f32>()?
            );
        }

        Ok(())
    }

    #[test]
    fn candle_add_axes() -> Result<()> {
        let tests = vec![(
            Tensor::arange(0u8, 1 * 2 * 3, &Device::Cpu)?.reshape(&[1, 2, 3])?,
            5,
            &[(0, 5), (3, 3)],
            Tensor::new(
                vec![
                    0u8, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5, 0, 1, 2, 0, 1, 2, 0, 1,
                    2, 3, 4, 5, 3, 4, 5, 3, 4, 5, 0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 3,
                    4, 5, 0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5, 0, 1, 2, 0, 1, 2,
                    0, 1, 2, 3, 4, 5, 3, 4, 5, 3, 4, 5,
                ],
                &Device::Cpu,
            )?
            .reshape(&[5, 1, 2, 3, 3])?,
        )];

        for (tensor, naxes, pos2len, expected) in tests {
            assert_eq!(
                tensor
                    .add_axes(naxes, pos2len)
                    .flatten_all()?
                    .to_vec1::<u8>()?,
                expected.flatten_all()?.to_vec1::<u8>()?
            );
        }

        Ok(())
    }
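
    // Illustrative addition, not part of the original file: a sketch of how
    // the trait methods compose into an einops-style "h w -> (w h)"
    // rearrange, permuting first and then flattening.
    #[test]
    fn compose_transpose_reshape() -> Result<()> {
        // [[0, 1, 2], [3, 4, 5]]
        let tensor = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape(&[2, 3])?;

        // Permute [2, 3] -> [3, 2], then flatten to [6].
        let permuted = Backend::transpose(&tensor, &[1, 0]);
        let output = Backend::reshape(&permuted, &[6]);

        assert_eq!(output.to_vec1::<f32>()?, vec![0.0, 3.0, 1.0, 4.0, 2.0, 5.0]);

        Ok(())
    }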
}