// constensor-core 0.1.1
//
// Experimental ML framework featuring a graph-based JIT compiler.
// See the constensor-core crate documentation for API details.
#[cfg(feature = "cuda")]
use constensor_core::Cuda;
use constensor_core::{CompiledGraph, Cpu, Graph, GraphTensor, R2};

/// Generates a module named `$name` containing fused multiply-add tests
/// (`a * b + c` over 3x4 constant-filled tensors) for the device type `$dev`,
/// once for `f32` and once for `i32`.
macro_rules! test_for_device_fma {
    ($dev:ty, $name:ident) => {
        mod $name {
            use super::*;

            /// `2.0 * 3.0 + 4.0` on `f32` tensors must produce `10.0` in every element.
            #[test]
            fn float_fma() {
                type Tensor = GraphTensor<R2<3, 4>, f32, $dev>;

                let mut graph = Graph::empty();
                let lhs = Tensor::fill(&mut graph, 2.0);
                let rhs = Tensor::fill(&mut graph, 3.0);
                let bias = Tensor::fill(&mut graph, 4.0);
                // The result handle is unused; the op is recorded in `graph` and
                // `compile` produces the graph's output.
                let _out = lhs * rhs + bias;

                let compiled: CompiledGraph<R2<3, 4>, f32, $dev> = graph.compile().unwrap();
                let result = compiled.run().unwrap();
                let expected = vec![vec![10.0; 4]; 3];
                assert_eq!(result.data().unwrap().to_vec(), expected);
            }

            /// `2 * 3 + 4` on `i32` tensors must produce `10` in every element.
            #[test]
            fn integral_fma() {
                type Tensor = GraphTensor<R2<3, 4>, i32, $dev>;

                let mut graph = Graph::empty();
                let lhs = Tensor::fill(&mut graph, 2);
                let rhs = Tensor::fill(&mut graph, 3);
                let bias = Tensor::fill(&mut graph, 4);
                let _out = lhs * rhs + bias;

                let compiled: CompiledGraph<R2<3, 4>, i32, $dev> = graph.compile().unwrap();
                let result = compiled.run().unwrap();
                let expected = vec![vec![10; 4]; 3];
                assert_eq!(result.data().unwrap().to_vec(), expected);
            }
        }
    };
}

/// Generates a module named `$name` containing an `f16` fused multiply-add
/// test (`a * b + c` over 3x4 constant-filled tensors) for the device type `$dev`.
#[cfg(feature = "half")]
macro_rules! test_for_device_half_fma {
    ($dev:ty, $name:ident) => {
        mod $name {
            use super::*;

            /// `2 * 3 + 4` in `f16` must produce `10` in every element.
            /// All operands and results are small integers that `f16` represents
            /// exactly, so exact equality is appropriate.
            #[test]
            fn float_fma() {
                use half::f16;
                type Tensor = GraphTensor<R2<3, 4>, f16, $dev>;

                let mut graph = Graph::empty();
                let lhs = Tensor::fill(&mut graph, f16::from_f64_const(2.0));
                let rhs = Tensor::fill(&mut graph, f16::from_f64_const(3.0));
                let bias = Tensor::fill(&mut graph, f16::from_f64_const(4.0));
                let _out = lhs * rhs + bias;

                let compiled: CompiledGraph<R2<3, 4>, f16, $dev> = graph.compile().unwrap();
                let result = compiled.run().unwrap();
                let expected = vec![vec![f16::from_f64_const(10.0); 4]; 3];
                assert_eq!(result.data().unwrap().to_vec(), expected);
            }
        }
    };
}

/// Generates a module named `$name` containing a `bf16` fused multiply-add
/// test (`a * b + c` over 3x4 constant-filled tensors) for the device type `$dev`.
#[cfg(feature = "bfloat")]
macro_rules! test_for_device_bfloat_fma {
    ($dev:ty, $name:ident) => {
        mod $name {
            use super::*;

            /// `2 * 3 + 4` in `bf16` must produce `10` in every element.
            /// All operands and results are small integers that `bf16` represents
            /// exactly, so exact equality is appropriate.
            #[test]
            fn float_fma() {
                use half::bf16;
                type Tensor = GraphTensor<R2<3, 4>, bf16, $dev>;

                let mut graph = Graph::empty();
                let lhs = Tensor::fill(&mut graph, bf16::from_f64_const(2.0));
                let rhs = Tensor::fill(&mut graph, bf16::from_f64_const(3.0));
                let bias = Tensor::fill(&mut graph, bf16::from_f64_const(4.0));
                let _out = lhs * rhs + bias;

                let compiled: CompiledGraph<R2<3, 4>, bf16, $dev> = graph.compile().unwrap();
                let result = compiled.run().unwrap();
                let expected = vec![vec![bf16::from_f64_const(10.0); 4]; 3];
                assert_eq!(result.data().unwrap().to_vec(), expected);
            }
        }
    };
}

// Instantiate the FMA test modules for every device/dtype combination the
// enabled Cargo features allow. Each CUDA instantiation is gated on `cuda`
// plus the same feature that gates the definition of the macro it invokes.
test_for_device_fma!(Cpu, cpu_tests_fma);
#[cfg(feature = "cuda")]
test_for_device_fma!(Cuda<0>, cuda_tests_fma);

#[cfg(feature = "half")]
test_for_device_half_fma!(Cpu, cpu_tests_fma_half);
#[cfg(all(feature = "cuda", feature = "half"))]
test_for_device_half_fma!(Cuda<0>, cuda_tests_fma_half);

#[cfg(feature = "bfloat")]
test_for_device_bfloat_fma!(Cpu, cpu_tests_fma_bfloat);
// Gate on `bfloat`, not `half`: `test_for_device_bfloat_fma!` only exists when
// `bfloat` is enabled. Gating on `half` breaks the build for `cuda`+`half`
// without `bfloat`, and skips these tests for `cuda`+`bfloat` without `half`.
#[cfg(all(feature = "cuda", feature = "bfloat"))]
test_for_device_bfloat_fma!(Cuda<0>, cuda_tests_fma_bfloat);