candle_core/
test_utils.rs

1use crate::{Result, Tensor};
2
/// Generates one `#[test]` wrapper per backend for a device-parameterised test function.
///
/// Arguments, in order:
/// - `$test_fn`: the function under test; it is called with a single device reference and its
///   return value becomes the return value of each generated test.
/// - `$cpu_test`: name of the generated test that passes `&Device::Cpu`.
/// - `$cuda_test`: name of the generated test that passes `&Device::new_cuda(0)?`; only compiled
///   when the `cuda` feature is enabled.
/// - `$metal_test`: name of the generated test that passes `&Device::new_metal(0)?`; only
///   compiled when the `metal` feature is enabled.
///
/// The expansion refers to `Result` and `Device` by their bare names, so both must be in scope
/// wherever the macro is invoked.
#[macro_export]
macro_rules! test_device {
    // TODO: Derive the three test names from `$test_fn` automatically once concat_idents is
    // stable. https://github.com/rust-lang/rust/issues/29599
    ($test_fn:ident, $cpu_test:ident, $cuda_test:ident, $metal_test:ident) => {
        #[test]
        fn $cpu_test() -> Result<()> {
            $test_fn(&Device::Cpu)
        }

        #[cfg(feature = "cuda")]
        #[test]
        fn $cuda_test() -> Result<()> {
            $test_fn(&Device::new_cuda(0)?)
        }

        #[cfg(feature = "metal")]
        #[test]
        fn $metal_test() -> Result<()> {
            $test_fn(&Device::new_metal(0)?)
        }
    };
}
26
27pub fn assert_tensor_eq(t1: &Tensor, t2: &Tensor) -> Result<()> {
28    assert_eq!(t1.shape(), t2.shape());
29    // Default U8 may not be large enough to hold the sum (`t.sum_all` defaults to the dtype of `t`)
30    let eq_tensor = t1.eq(t2)?.to_dtype(crate::DType::U32)?;
31    let all_equal = eq_tensor.sum_all()?;
32    assert_eq!(all_equal.to_scalar::<u32>()?, eq_tensor.elem_count() as u32);
33    Ok(())
34}
35
36pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
37    let b = 10f32.powi(digits);
38    let t = t.to_vec0::<f32>()?;
39    Ok(f32::round(t * b) / b)
40}
41
42pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
43    let b = 10f32.powi(digits);
44    let t = t.to_vec1::<f32>()?;
45    let t = t.iter().map(|t| f32::round(t * b) / b).collect();
46    Ok(t)
47}
48
49pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
50    let b = 10f32.powi(digits);
51    let t = t.to_vec2::<f32>()?;
52    let t = t
53        .iter()
54        .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
55        .collect();
56    Ok(t)
57}
58
59pub fn to_vec3_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
60    let b = 10f32.powi(digits);
61    let t = t.to_vec3::<f32>()?;
62    let t = t
63        .iter()
64        .map(|t| {
65            t.iter()
66                .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
67                .collect()
68        })
69        .collect();
70    Ok(t)
71}