candle_core/
test_utils.rs1use crate::{Result, Tensor};
2
/// Generates one `#[test]` function per backend, each running `$fn_name` on that backend.
///
/// Arguments, in order:
/// - `$fn_name`: the function under test. It is called with a `&Device`, and its return value
///   is returned from the generated test, so it must produce `Result<()>`.
/// - `$test_cpu`: name of the generated CPU test (always compiled).
/// - `$test_cuda`: name of the generated CUDA test (compiled only with the `cuda` feature).
/// - `$test_metal`: name of the generated Metal test (compiled only with the `metal` feature).
///
/// A trailing comma after the last name is accepted.
///
/// The expansion refers to `Device` and `Result` unqualified, so both must be in scope where
/// the macro is invoked. The CUDA and Metal tests create their device with index `0`. A failure
/// to create it is propagated with `?` and fails that test.
///
/// ```ignore
/// test_device!(matmul, matmul_cpu, matmul_cuda, matmul_metal);
/// ```
#[macro_export]
macro_rules! test_device {
    ($fn_name:ident, $test_cpu:ident, $test_cuda:ident, $test_metal:ident $(,)?) => {
        #[test]
        fn $test_cpu() -> Result<()> {
            $fn_name(&Device::Cpu)
        }

        #[cfg(feature = "cuda")]
        #[test]
        fn $test_cuda() -> Result<()> {
            $fn_name(&Device::new_cuda(0)?)
        }

        #[cfg(feature = "metal")]
        #[test]
        fn $test_metal() -> Result<()> {
            $fn_name(&Device::new_metal(0)?)
        }
    };
}
26
27pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
28 let b = 10f32.powi(digits);
29 let t = t.to_vec0::<f32>()?;
30 Ok(f32::round(t * b) / b)
31}
32
33pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
34 let b = 10f32.powi(digits);
35 let t = t.to_vec1::<f32>()?;
36 let t = t.iter().map(|t| f32::round(t * b) / b).collect();
37 Ok(t)
38}
39
40pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
41 let b = 10f32.powi(digits);
42 let t = t.to_vec2::<f32>()?;
43 let t = t
44 .iter()
45 .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
46 .collect();
47 Ok(t)
48}
49
50pub fn to_vec3_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
51 let b = 10f32.powi(digits);
52 let t = t.to_vec3::<f32>()?;
53 let t = t
54 .iter()
55 .map(|t| {
56 t.iter()
57 .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
58 .collect()
59 })
60 .collect();
61 Ok(t)
62}