mod dummy;
use std::sync::Arc;
use crate::dummy::{DummyDevice, DummyElementwiseAddition, client};
use cubecl_runtime::server::CubeCount;
use cubecl_runtime::{ComputeRuntime, server::Bindings};
use dummy::*;
#[allow(unused)]
use serial_test::serial;
/// Compute runtime instantiated with the dummy device, server and channel
/// provided by the local `dummy` module.
type Runtime = ComputeRuntime<DummyDevice, dummy::DummyServer, dummy::DummyChannel>;
/// Data uploaded through `create` must read back unchanged via `read_one`.
#[test]
fn created_resource_is_the_same_when_read() {
    let compute_client = client(&DummyDevice);
    let expected = vec![0, 1, 2];

    // Upload the data, then read it back through the handle's binding.
    let handle = compute_client.create(&expected);
    let actual = compute_client.read_one(handle.binding());

    assert_eq!(expected, actual);
}
/// A buffer requested through `empty` reads back with the requested length.
#[test]
fn empty_allocates_memory() {
    let compute_client = client(&DummyDevice);
    let requested_len = 4;

    let handle = compute_client.empty(requested_len);
    let contents = compute_client.read_one(handle.binding());

    // Only the length is checked; the contents of the buffer are not inspected.
    assert_eq!(contents.len(), requested_len);
}
/// Executes `DummyElementwiseAddition` on `[0, 1, 2]` and `[4, 4, 4]` and
/// checks that the output buffer reads back as `[4, 5, 6]`.
#[test]
fn execute_elementwise_addition() {
    let compute_client = client(&DummyDevice);

    let lhs = compute_client.create(&[0, 1, 2]);
    let rhs = compute_client.create(&[4, 4, 4]);
    let out = compute_client.empty(3);

    // `out` is cloned so the original handle is still available for the
    // read-back after the kernel has been launched.
    let buffers = vec![lhs.binding(), rhs.binding(), out.clone().binding()];
    let bindings = Bindings::new().with_buffers(buffers);

    compute_client.execute(
        Arc::new(DummyElementwiseAddition),
        CubeCount::Static(1, 1, 1),
        bindings,
    );

    let result = compute_client.read_one(out.binding());
    assert_eq!(result, vec![4, 5, 6]);
}
/// Runs the dummy addition set through `autotune_execute` and checks that the
/// output buffer holds the element-wise sum `[4, 5, 6]`.
#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_basic_addition_execution() {
    // Reset the shared tuner before the test body runs.
    TEST_TUNER.clear();

    let compute_client = client(&DummyDevice);

    // One `[1, 3]` shape for each of lhs, rhs and out.
    let shapes = vec![vec![1, 3]; 3];

    let lhs = compute_client.create(&[0, 1, 2]);
    let rhs = compute_client.create(&[4, 4, 4]);
    let out = compute_client.empty(3);
    // `out` is cloned so it can still be read back afterwards.
    let bindings = vec![lhs.binding(), rhs.binding(), out.clone().binding()];

    let addition_set = dummy::addition_set(compute_client.clone(), shapes);
    autotune_execute(&compute_client, &addition_set, bindings);

    let result = compute_client.read_one(out.binding());
    assert_eq!(result, vec![4, 5, 6]);
}
/// Runs the dummy multiplication set through `autotune_execute` and checks
/// that the output buffer holds the element-wise product `[0, 4, 8]`.
#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_basic_multiplication_execution() {
    // Reset the shared tuner before the test body runs.
    TEST_TUNER.clear();

    let compute_client = client(&DummyDevice);

    // One `[1, 3]` shape for each of lhs, rhs and out.
    let shapes = vec![vec![1, 3]; 3];

    let lhs = compute_client.create(&[0, 1, 2]);
    let rhs = compute_client.create(&[4, 4, 4]);
    let out = compute_client.empty(3);
    // `out` is cloned so it can still be read back afterwards.
    let bindings = vec![lhs.binding(), rhs.binding(), out.clone().binding()];

    let multiplication_set = dummy::multiplication_set(compute_client.clone(), shapes);
    autotune_execute(&compute_client, &multiplication_set, bindings);

    let result = compute_client.read_one(out.binding());
    assert_eq!(result, vec![0, 4, 8]);
}
/// Executes two cache-test sets (shapes `[1, 3]` and `[1, 4]`) back to back on
/// the same client and expects the second output to equal the second lhs input.
///
/// NOTE(review): per the test name, the two shapes are presumably mapped to the
/// same autotune key, so the second run reuses the first run's cached choice;
/// the key derivation lives in the `dummy` module — confirm there.
#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_cache_same_key_return_a_cache_hit() {
    // Reset the shared tuner before the test body runs.
    TEST_TUNER.clear();

    let runtime = Runtime::new();
    let compute_client = runtime.client(&DummyDevice, dummy::init_client);

    // First operation set: three elements per buffer.
    let first_shapes = vec![vec![1, 3]; 3];
    let first_lhs = compute_client.create(&[0, 1, 2]);
    let first_rhs = compute_client.create(&[4, 4, 4]);
    let first_out = compute_client.empty(3);
    let first_bindings = vec![first_lhs.binding(), first_rhs.binding(), first_out.binding()];

    // Second operation set: four elements per buffer. Its output handle is
    // cloned because it is read back at the end of the test.
    let second_shapes = vec![vec![1, 4]; 3];
    let second_lhs = compute_client.create(&[0, 1, 2, 3]);
    let second_rhs = compute_client.create(&[5, 6, 7, 8]);
    let second_out = compute_client.empty(4);
    let second_bindings = vec![
        second_lhs.binding(),
        second_rhs.binding(),
        second_out.clone().binding(),
    ];

    let first_set =
        dummy::cache_test_set(compute_client.clone(), first_shapes, first_bindings, false);
    let second_set =
        dummy::cache_test_set(compute_client.clone(), second_shapes, second_bindings, false);

    // The bindings were already handed to the sets, so none are passed here.
    autotune_execute(&compute_client, &first_set, vec![]);
    autotune_execute(&compute_client, &second_set, vec![]);

    // Expected contents match the second lhs input.
    let result = compute_client.read_one(second_out.binding());
    assert_eq!(result, vec![0, 1, 2, 3]);
}
/// Executes two cache-test sets (shapes `[1, 3]` and `[1, 5]`) back to back on
/// the same client and expects the second output to equal the second rhs input.
///
/// NOTE(review): per the test name, the two shapes are presumably mapped to
/// different autotune keys, so the second run cannot reuse the first run's
/// cached choice; the key derivation lives in the `dummy` module — confirm there.
#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_cache_different_keys_return_a_cache_miss() {
    // Reset the shared tuner before the test body runs.
    TEST_TUNER.clear();

    let compute_client = client(&DummyDevice);

    // First operation set: three elements per buffer.
    let first_shapes = vec![vec![1, 3]; 3];
    let first_lhs = compute_client.create(&[0, 1, 2]);
    let first_rhs = compute_client.create(&[4, 4, 4]);
    let first_out = compute_client.empty(3);
    let first_bindings = vec![first_lhs.binding(), first_rhs.binding(), first_out.binding()];

    // Second operation set: five elements per buffer. Its output handle is
    // cloned because it is read back at the end of the test.
    let second_shapes = vec![vec![1, 5]; 3];
    let second_lhs = compute_client.create(&[0, 1, 2, 3, 4]);
    let second_rhs = compute_client.create(&[5, 6, 7, 8, 9]);
    let second_out = compute_client.empty(5);
    let second_bindings = vec![
        second_lhs.binding(),
        second_rhs.binding(),
        second_out.clone().binding(),
    ];

    let first_set =
        dummy::cache_test_set(compute_client.clone(), first_shapes, first_bindings, false);
    let second_set =
        dummy::cache_test_set(compute_client.clone(), second_shapes, second_bindings, false);

    // The bindings were already handed to the sets, so none are passed here.
    autotune_execute(&compute_client, &first_set, vec![]);
    autotune_execute(&compute_client, &second_set, vec![]);

    // Expected contents match the second rhs input.
    let result = compute_client.read_one(second_out.binding());
    assert_eq!(result, vec![5, 6, 7, 8, 9]);
}
/// Executes one cache-test set, clears the tuner, then executes a second set
/// built with a different final flag, and expects the second output to equal
/// the second rhs input.
///
/// NOTE(review): the last argument of `dummy::cache_test_set` presumably
/// controls the set's checksum (per the test name) — confirm in the `dummy`
/// module.
#[test]
#[serial]
#[cfg(feature = "std")]
fn autotune_cache_different_checksums_return_a_cache_miss() {
    // Reset the shared tuner before the test body runs.
    TEST_TUNER.clear();

    let runtime = Runtime::new();
    let compute_client = runtime.client(&DummyDevice, dummy::init_client);

    // First operation set: three elements per buffer, flag `false`.
    let first_shapes = vec![vec![1, 3]; 3];
    let first_lhs = compute_client.create(&[0, 1, 2]);
    let first_rhs = compute_client.create(&[4, 4, 4]);
    let first_out = compute_client.empty(3);
    let first_bindings = vec![first_lhs.binding(), first_rhs.binding(), first_out.binding()];
    let first_set =
        dummy::cache_test_set(compute_client.clone(), first_shapes, first_bindings, false);
    autotune_execute(&compute_client, &first_set, vec![]);

    // Clear the tuner again between the two executions.
    TEST_TUNER.clear();

    // Second operation set: four elements per buffer, flag `true`. Its output
    // handle is cloned because it is read back at the end of the test.
    let second_shapes = vec![vec![1, 4]; 3];
    let second_lhs = compute_client.create(&[0, 1, 2, 3]);
    let second_rhs = compute_client.create(&[5, 6, 7, 8]);
    let second_out = compute_client.empty(4);
    let second_bindings = vec![
        second_lhs.binding(),
        second_rhs.binding(),
        second_out.clone().binding(),
    ];
    let second_set =
        dummy::cache_test_set(compute_client.clone(), second_shapes, second_bindings, true);
    autotune_execute(&compute_client, &second_set, vec![]);

    // Expected contents match the second rhs input.
    let result = compute_client.read_one(second_out.binding());
    assert_eq!(result, vec![5, 6, 7, 8]);
}