candle_core/
test_utils.rs1use crate::{Result, Tensor};
2
/// Generates per-backend `#[test]` functions for a device-generic test body.
///
/// Arguments, in order:
/// 1. the function to run; it is called with a reference to a device and its
///    return value becomes the test's return value,
/// 2. the name of the generated CPU test,
/// 3. the name of the generated CUDA test (only emitted with the `cuda` feature),
/// 4. the name of the generated Metal test (only emitted with the `metal` feature).
///
/// `Device` and `Result` are written unqualified in the expansion, so both
/// names must be in scope wherever the macro is invoked.
#[macro_export]
macro_rules! test_device {
    ($body_fn:ident, $cpu_name:ident, $cuda_name:ident, $metal_name:ident) => {
        #[test]
        fn $cpu_name() -> Result<()> {
            let device = Device::Cpu;
            $body_fn(&device)
        }

        #[cfg(feature = "cuda")]
        #[test]
        fn $cuda_name() -> Result<()> {
            // Device ordinal 0; a failure to create the device fails the test.
            let device = Device::new_cuda(0)?;
            $body_fn(&device)
        }

        #[cfg(feature = "metal")]
        #[test]
        fn $metal_name() -> Result<()> {
            // Device ordinal 0; a failure to create the device fails the test.
            let device = Device::new_metal(0)?;
            $body_fn(&device)
        }
    };
}
26
27pub fn assert_tensor_eq(t1: &Tensor, t2: &Tensor) -> Result<()> {
28 assert_eq!(t1.shape(), t2.shape());
29 let eq_tensor = t1.eq(t2)?.to_dtype(crate::DType::U32)?;
31 let all_equal = eq_tensor.sum_all()?;
32 assert_eq!(all_equal.to_scalar::<u32>()?, eq_tensor.elem_count() as u32);
33 Ok(())
34}
35
36pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
37 let b = 10f32.powi(digits);
38 let t = t.to_vec0::<f32>()?;
39 Ok(f32::round(t * b) / b)
40}
41
42pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
43 let b = 10f32.powi(digits);
44 let t = t.to_vec1::<f32>()?;
45 let t = t.iter().map(|t| f32::round(t * b) / b).collect();
46 Ok(t)
47}
48
49pub fn to_vec2_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<f32>>> {
50 let b = 10f32.powi(digits);
51 let t = t.to_vec2::<f32>()?;
52 let t = t
53 .iter()
54 .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
55 .collect();
56 Ok(t)
57}
58
59pub fn to_vec3_round(t: &Tensor, digits: i32) -> Result<Vec<Vec<Vec<f32>>>> {
60 let b = 10f32.powi(digits);
61 let t = t.to_vec3::<f32>()?;
62 let t = t
63 .iter()
64 .map(|t| {
65 t.iter()
66 .map(|t| t.iter().map(|t| f32::round(t * b) / b).collect())
67 .collect()
68 })
69 .collect();
70 Ok(t)
71}