candle_core/
test_utils.rs

1use crate::{Result, Tensor};
2
/// Generates one `#[test]` function per compute backend, each running the same test body.
///
/// Arguments, in order:
/// - `$body`: a function taking `&Device` and returning `Result<()>`; it is the shared test body.
/// - `$cpu`: name of the generated test that runs `$body` on `Device::Cpu` (always emitted).
/// - `$cuda`: name of the generated test that runs `$body` on `Device::new_cuda(0)`; only
///   emitted when the `cuda` feature is enabled.
/// - `$metal`: name of the generated test that runs `$body` on `Device::new_metal(0)`; only
///   emitted when the `metal` feature is enabled.
///
/// If creating a CUDA/Metal device fails, the error is propagated with `?`, so that generated
/// test fails rather than being skipped.
///
/// The expansion names `Device` and `Result` without a path, so both must be in scope where the
/// macro is invoked (whatever `Result` alias is in scope there becomes the tests' return type).
#[macro_export]
macro_rules! test_device {
    // TODO: Derive the last three arguments (the per-backend test names) from `$body`
    // automatically once identifier concatenation (concat_idents) is stable.
    // https://github.com/rust-lang/rust/issues/29599
    ($body:ident, $cpu:ident, $cuda:ident, $metal:ident) => {
        #[test]
        fn $cpu() -> Result<()> {
            let device = Device::Cpu;
            $body(&device)
        }

        #[cfg(feature = "cuda")]
        #[test]
        fn $cuda() -> Result<()> {
            let device = Device::new_cuda(0)?;
            $body(&device)
        }

        #[cfg(feature = "metal")]
        #[test]
        fn $metal() -> Result<()> {
            let device = Device::new_metal(0)?;
            $body(&device)
        }
    };
}
26
27pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
28    let b = 10f32.powi(digits);
29    let t = t.to_vec0::<f32>()?;
30    Ok(f32::round(t * b) / b)
31}
32
33pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
34    let b = 10f32.powi(digits);
35    let t = t.to_vec1::<f32>()?;
36    let t = t.iter().map(|t| f32::round(t * b) / b).collect();
37    Ok(t)
38}
39
40pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
41    let b = 10f32.powi(digits);
42    let t = t.to_vec2::<f32>()?;
43    let t = t
44        .iter()
45        .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
46        .collect();
47    Ok(t)
48}
49
50pub fn to_vec3_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
51    let b = 10f32.powi(digits);
52    let t = t.to_vec3::<f32>()?;
53    let t = t
54        .iter()
55        .map(|t| {
56            t.iter()
57                .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
58                .collect()
59        })
60        .collect();
61    Ok(t)
62}