ug/
common_tests.rs

1use crate::{Device, LazyBuffer as LB, Result};
2
3pub mod lb {
4    use super::*;
5
6    pub fn add<D: Device>(dev: &D) -> Result<()> {
7        let lhs = LB::cst(40., (5, 2), dev)?;
8        let rhs = LB::cst(2., (5, 2), dev)?;
9        let lb = lhs.binary(crate::lang::BinaryOp::Add, rhs)?;
10        let data = lb.data_vec::<f32>()?;
11        assert_eq!(data, [42., 42., 42., 42., 42., 42., 42., 42., 42., 42.]);
12        Ok(())
13    }
14
15    pub fn mm<D: Device>(dev: &D) -> Result<()> {
16        let lhs = LB::cst(1., (4, 2), dev)?;
17        let rhs = LB::cst(2., (2, 3), dev)?;
18        let lb = lhs.matmul(rhs)?;
19        let data = lb.data_vec::<f32>()?;
20        assert_eq!(data, [4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0]);
21        Ok(())
22    }
23
24    pub fn cat<D: Device>(dev: &D) -> Result<()> {
25        let arg0 = LB::copy([1f32, 2.0, 3.0, 4.0].as_slice(), (2, 2), dev)?;
26        let arg1 = LB::copy([1.1f32, 2.1, 3.1, 4.1].as_slice(), (2, 2), dev)?;
27        let arg2 = LB::copy([1.2f32, 2.2, 3.2, 4.2].as_slice(), (2, 2), dev)?;
28        let lb = LB::<D>::cat(&[&arg0, &arg1, &arg2], 0)?;
29        let data = lb.data_vec::<f32>()?;
30        assert_eq!(data, [1.0, 2.0, 3.0, 4.0, 1.1, 2.1, 3.1, 4.1, 1.2, 2.2, 3.2, 4.2]);
31        let lb = LB::<D>::cat(&[&arg0, &arg1, &arg2], 1)?;
32        let data = lb.data_vec::<f32>()?;
33        assert_eq!(data, [1.0, 2.0, 1.1, 2.1, 1.2, 2.2, 3.2, 4.2, 0.0, 0.0, 0.0, 0.0]);
34        Ok(())
35    }
36}