1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#[burn_tensor_testgen::testgen(full)]
mod tests {
    use super::*;
    use burn_tensor::{backend::Backend, Bool, Int, Shape, Tensor, TensorData};

    /// `TensorData::full` produces a `[2, 3]` buffer where every element is `2.0`.
    #[test]
    fn test_data_full() {
        let expected = TensorData::from([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]);
        let filled = TensorData::full([2, 3], 2.0);

        // NOTE(review): `false` presumably selects a non-strict comparison (dtype not
        // required to match) — confirm against `TensorData::assert_eq`.
        filled.assert_eq(&expected, false);
    }

    /// `full` on the test backend, exercised for a float tensor and an int tensor.
    #[test]
    fn test_tensor_full() {
        let device = Default::default();

        // Float tensor: [2, 3] filled with 2.1.
        let float_expected = TensorData::from([[2.1, 2.1, 2.1], [2.1, 2.1, 2.1]]);
        let float_actual = TestTensor::<2>::full([2, 3], 2.1, &device).into_data();
        float_actual.assert_eq(&float_expected, false);

        // Int tensor: [2, 2] filled with 2.
        let int_expected = TensorData::from([[2, 2], [2, 2]]);
        let int_actual = TestTensorInt::<2>::full([2, 2], 2, &device).into_data();
        int_actual.assert_eq(&int_expected, false);

        // TODO: once `full` supports bool tensors, add a `TestTensorBool::<2>` case that
        // expects an all-`true` [2, 2] tensor for `true` and an all-`false` one for `false`.
    }
}