#![cfg(feature = "runtime")]
use eerie::runtime::{
self,
error::RuntimeError,
hal::{Buffer, BufferMapping, BufferParams, BufferView, ElementType, Encoding, Tensor},
vm::{FunctionLinkage, List, ToRef, ToValue, Undefined, Value},
};
use half::f16;
use log::info;
use test_log::test;
/// Creates the `local-sync` HAL driver and its default device.
///
/// The registry and driver are handed back together with the device. Callers bind them as
/// `_registry` / `_driver`, presumably so they outlive the device — TODO confirm whether the
/// bindings actually require that.
fn local_sync_device() -> (
    runtime::hal::DriverRegistry,
    runtime::hal::Driver,
    runtime::hal::Device,
) {
    let drivers = runtime::hal::DriverRegistry::with_available_drivers().unwrap();
    let local_sync = drivers.create_driver("local-sync").unwrap();
    let default_device = local_sync.create_default_device().unwrap();
    (drivers, local_sync, default_device)
}
/// Smoke test: a VM instance can be constructed without error.
#[test]
fn test_instance() {
    let _instance = runtime::vm::Instance::new().unwrap();
}
/// Builds a VM context whose only module is the HAL module bound to the `local-sync` device.
#[test]
fn test_context_with_hal_module() {
    let instance = runtime::vm::Instance::new().unwrap();
    let (_registry, _driver, device) = local_sync_device();
    let hal = runtime::vm::Module::hal(&instance, &device).unwrap();
    let _context = runtime::vm::Context::with_modules(&instance, &[&hal]).unwrap();
}
/// Exercises the driver/device metadata accessors.
///
/// Checks that device enumeration, id, capability query and trim succeed, and that querying an
/// unknown key reports an error.
#[test]
fn device_metadata() {
    let (_registry, driver, device) = local_sync_device();

    let devices = driver.available_devices().unwrap();
    assert!(!devices.is_empty());
    let first = &devices[0];
    assert_eq!(first.ordinal, 0);
    // At least one of name/path must be populated.
    assert!(!(first.name.is_empty() && first.path.is_empty()));

    assert!(!device.id().is_empty());
    device.capabilities().unwrap();

    let missing = device.query_i64("eerie.missing.query.category", "missing-key");
    assert!(missing.is_err());

    device.trim().unwrap();
}
/// Pushes the scalars 1..=4 into a VM value list and reads the first one back.
///
/// The list is dropped before the read value is inspected, so the test checks that the value
/// stays usable after its source list is gone.
#[test]
fn dynamic_list() {
    let instance = runtime::vm::Instance::new().unwrap();
    let mut scalars = List::<Value<i32>>::new(4, &instance).unwrap();
    for n in 1_i32..=4 {
        scalars.push_value(n.to_value()).unwrap();
    }
    let first = scalars.get_value::<i32>(0).unwrap();
    drop(scalars);
    assert_eq!(first.get(), 1);
}
/// Round-trips a buffer view through a VM ref list and verifies its contents.
///
/// The same ref is pushed twice. Element 0 is read back as a `BufferView<f32>`, mapped for
/// reading, and must still hold the original host data.
#[test]
fn ref_list() {
    let instance = runtime::vm::Instance::new().unwrap();
    let (_registry, _driver, device) = local_sync_device();
    let mut list = List::<Undefined>::new(4, &instance).unwrap();
    let buffer = BufferView::<f32>::from_host(
        &device,
        &[2, 2],
        Encoding::DenseRowMajor,
        &[1.0, 2.0, 3.0, 4.0],
    )
    .unwrap();
    info!("buffer: {:?}", buffer);
    let buffer_ref = buffer.to_ref(&instance).unwrap();
    list.push_ref(&buffer_ref).unwrap();
    list.push_ref(&buffer_ref).unwrap();
    let buffer_ref_2 = list.get_ref::<BufferView<f32>>(0).unwrap();
    let buffer_2 = buffer_ref_2.to_buffer_view().unwrap();
    info!("buffer_ref_2: {:?}", buffer_2);
    let mapping = BufferMapping::map_read(&buffer_2).unwrap();
    info!("mapping: {:?}", mapping.data());
    // Assert rather than only log, so a ref that resolves to the wrong or corrupted buffer
    // actually fails the test.
    assert_eq!(mapping.data(), [1.0_f32, 2.0, 3.0, 4.0].as_slice());
}
/// Uploads a 2x2 `f16` buffer from the host and checks the mapped bytes decode to the same
/// half-precision values.
#[test]
fn fp16_buffer() {
    let (_registry, _driver, device) = local_sync_device();
    let values = [1.0_f32, 2.0, 3.0, 4.0].map(f16::from_f32);

    let buffer =
        BufferView::<f16>::from_host(&device, &[2, 2], Encoding::DenseRowMajor, &values).unwrap();

    let mapping = BufferMapping::map_read(&buffer).unwrap();
    assert_eq!(mapping.data(), &values);
}
/// Checks the shape/type accessors of a host-initialized 2x2 `f32` buffer view.
#[test]
fn buffer_metadata() {
    let (_registry, _driver, device) = local_sync_device();
    let view = BufferView::<f32>::from_host(
        &device,
        &[2, 2],
        Encoding::DenseRowMajor,
        &[1.0, 2.0, 3.0, 4.0],
    )
    .unwrap();

    assert_eq!(view.rank(), 2);
    assert_eq!(view.shape(), vec![2, 2]);
    for axis in 0..2 {
        assert_eq!(view.dim(axis), 2);
    }
    assert_eq!(view.element_count(), 4);

    let f32_size = core::mem::size_of::<f32>();
    assert_eq!(view.element_size(), f32_size);
    assert_eq!(view.element_type(), ElementType::Float32);
    assert_eq!(view.encoding(), Encoding::DenseRowMajor);
}
/// Allocates a raw device buffer and wraps it in a typed view.
///
/// Checks that the view's underlying raw buffer reports the same properties. Then writes/reads
/// through the view and takes a half-length subspan of the raw buffer.
#[test]
fn raw_buffer_allocation_and_view() {
    let (_registry, _driver, device) = local_sync_device();
    let f32_size = core::mem::size_of::<f32>();
    let byte_len = 4 * f32_size;

    let raw = Buffer::allocate(&device, byte_len, BufferParams::default()).unwrap();
    assert_eq!(raw.byte_offset(), 0);
    assert_eq!(raw.byte_length(), byte_len);
    assert!(raw.allocation_size() >= raw.byte_length());
    assert_ne!(raw.memory_type(), 0);
    assert_ne!(raw.allowed_access(), 0);
    assert_ne!(raw.allowed_usage(), 0);

    let typed = BufferView::<f32>::from_buffer(&raw, &[4], Encoding::DenseRowMajor).unwrap();
    let backing = typed.raw_buffer();
    assert_eq!(backing.byte_length(), raw.byte_length());
    assert_eq!(backing.memory_type(), raw.memory_type());
    assert_eq!(backing.allowed_access(), raw.allowed_access());
    assert_eq!(backing.allowed_usage(), raw.allowed_usage());

    typed.write_from_slice(&device, &[1.0, 2.0, 3.0, 4.0]).unwrap();
    assert_eq!(typed.read_to_vec(&device).unwrap(), vec![1.0, 2.0, 3.0, 4.0]);

    let half_len = 2 * f32_size;
    let first_half = raw.subspan(0, half_len).unwrap();
    assert_eq!(first_half.byte_length(), half_len);
}
/// A 2x3 shape paired with only four elements must be rejected with `InvalidArgument`.
#[test]
fn buffer_shape_mismatch_is_rejected() {
    let (_registry, _driver, device) = local_sync_device();
    let result = BufferView::<f32>::from_host(
        &device,
        &[2, 3],
        Encoding::DenseRowMajor,
        &[1.0, 2.0, 3.0, 4.0],
    );
    assert!(matches!(result, Err(RuntimeError::InvalidArgument(_))));
}
/// Creates a 2x2 tensor, reads it, and overwrites it from the host.
///
/// Then copies it device-side into a zeroed tensor and checks that the copy carries the
/// overwritten values.
#[test]
fn tensor_read_write_and_copy() {
    let (_registry, _driver, device) = local_sync_device();
    let initial = [1.0_f32, 2.0, 3.0, 4.0];
    let updated = [5.0_f32, 6.0, 7.0, 8.0];

    let source = Tensor::<f32>::from_slice(&device, &[2, 2], &initial).unwrap();
    assert_eq!(source.shape(), vec![2, 2]);
    assert_eq!(source.read_to_vec(&device).unwrap(), initial);

    source.write_from_slice(&device, &updated).unwrap();
    assert_eq!(source.read_to_vec(&device).unwrap(), updated);

    let target = Tensor::<f32>::from_slice(&device, &[2, 2], &[0.0; 4]).unwrap();
    source
        .as_buffer_view()
        .copy_to(&device, target.as_buffer_view())
        .unwrap();
    assert_eq!(target.read_to_vec(&device).unwrap(), updated);
}
/// Runs the checked-in VMVX `simple_mul` fixture on `local-sync` and spot-checks `i * i`.
#[test]
fn append_module_from_vmvx_fixture() {
    let squares = run_mul(include_bytes!("mul_vmvx.vmfb"), "local-sync");
    for &(index, expected) in &[(0_usize, 0.0_f32), (7, 49.0), (99, 9801.0)] {
        assert_eq!(squares[index], expected);
    }
}
/// Loads the VMVX `simple_mul` fixture, inspects module/function metadata, then calls the
/// function through the tensor-level `invoke_tensors` API.
#[test]
fn function_metadata_and_tensor_invoke() {
    let vmfb = include_bytes!("mul_vmvx.vmfb");
    let instance = runtime::vm::Instance::new().unwrap();
    // Use the shared helper like the other tests instead of re-spelling the `local-sync` setup.
    // Registry and driver stay bound so they live until the end of the test, as before.
    let (_registry, _driver, device) = local_sync_device();
    let hal_module = runtime::vm::Module::hal(&instance, &device).unwrap();
    let bytecode_module = runtime::vm::Module::bytecode(&instance, vmfb).unwrap();
    let context =
        runtime::vm::Context::with_modules(&instance, &[&hal_module, &bytecode_module]).unwrap();
    let function = context.resolve_function("arithmetic.simple_mul").unwrap();

    // Module-level metadata.
    assert_eq!(bytecode_module.name(), "arithmetic");
    let module_signature = bytecode_module.signature();
    assert!(module_signature.export_function_count >= 1);
    assert!(bytecode_module.lookup_attr("missing.module.attr").is_none());
    // Indexing one past the last attribute is expected to yield `Ok(None)`, not an error.
    assert!(bytecode_module
        .attr(module_signature.attr_count)
        .unwrap()
        .is_none());

    // Lookup by export name.
    let export_ref = bytecode_module
        .lookup_export_function("simple_mul")
        .unwrap();
    assert_eq!(export_ref.name(), "simple_mul");
    assert_eq!(export_ref.signature().argument_count, 2);
    assert!(export_ref.lookup_attr("missing.function.attr").is_none());

    // Lookup by name + explicit linkage.
    let linked_ref = bytecode_module
        .lookup_function("simple_mul", FunctionLinkage::Export)
        .unwrap();
    assert_eq!(linked_ref.name(), "simple_mul");

    // The context-resolved function.
    assert_eq!(function.name(), "simple_mul");
    let signature = function.signature();
    assert_eq!(signature.argument_count, 2);
    assert_eq!(signature.result_count, 1);
    assert!(function.lookup_attr("missing.function.attr").is_none());

    // Invoke with the same tensor as both operands, so the output is `i * i`.
    let input_data = Vec::from_iter((0..100).map(|i| i as f32));
    let input = Tensor::<f32>::from_slice(&device, &[100], input_data.as_slice()).unwrap();
    let outputs = function.invoke_tensors(&[&input, &input], 1).unwrap();
    let output = outputs[0].read_to_vec(&device).unwrap();
    assert_eq!(output[0], 0.0);
    assert_eq!(output[7], 49.0);
    assert_eq!(output[99], 9801.0);
}
fn run_mul(vmfb: &[u8], driver_name: &str) -> Vec<f32> {
let instance = runtime::vm::Instance::new().unwrap();
let registry = runtime::hal::DriverRegistry::with_available_drivers().unwrap();
let driver = registry.create_driver(driver_name).unwrap();
let device = driver.create_default_device().unwrap();
let hal_module = runtime::vm::Module::hal(&instance, &device).unwrap();
let bytecode_module = runtime::vm::Module::bytecode(&instance, vmfb).unwrap();
let context =
runtime::vm::Context::with_modules(&instance, &[&hal_module, &bytecode_module]).unwrap();
let function = context.resolve_function("arithmetic.simple_mul").unwrap();
let input_data = Vec::from_iter((0..100).map(|i| i as f32));
let input = BufferView::<f32>::from_host(
&device,
&[100],
Encoding::DenseRowMajor,
input_data.as_slice(),
)
.unwrap();
let mut input_list = List::<Undefined>::new(2, &instance).unwrap();
input_list
.push_ref(&input.to_ref(&instance).unwrap())
.unwrap();
input_list
.push_ref(&input.to_ref(&instance).unwrap())
.unwrap();
let mut output_list = List::<Undefined>::new(1, &instance).unwrap();
function.invoke(&input_list, &mut output_list).unwrap();
let output_ref = output_list.get_ref::<BufferView<f32>>(0).unwrap();
output_ref
.to_buffer_view()
.unwrap()
.read_to_vec(&device)
.unwrap()
}
#[cfg(feature = "compiler")]
mod integration_tests {
    use eerie::compiler;
    use log::info;
    use std::path::Path;
    use std::sync::{Mutex, PoisonError};

    use super::run_mul;

    /// Compiler shared by every test in this module.
    ///
    /// It is created lazily by the first `compile_mul` call and reused afterwards.
    static COMPILER: Mutex<Option<compiler::Compiler>> = Mutex::new(None);

    /// Compiles `tests/mul.mlir` for `target_backend` and returns the VM bytecode module bytes.
    ///
    /// The global lock is held for the whole compilation, so concurrent tests compile one at a
    /// time.
    ///
    /// # Panics
    /// Panics if creating the compiler, setting flags, or any compilation step fails.
    fn compile_mul(target_backend: &str) -> Vec<u8> {
        // One lock covers both lazy initialization and use. If another test panicked while
        // holding the lock, continue with the inner value rather than failing every later test
        // with a PoisonError. A panic inside `Compiler::new` leaves the slot `None`, so the next
        // caller simply retries creation.
        let mut slot = COMPILER.lock().unwrap_or_else(PoisonError::into_inner);
        let shared: &compiler::Compiler =
            slot.get_or_insert_with(|| compiler::Compiler::new().unwrap());

        let mut session = shared.create_session();
        let mut flags = vec![format!("--iree-hal-target-backends={target_backend}")];
        if target_backend == "metal-spirv" {
            // Keep Metal output as source instead of a compiled metallib; presumably this avoids
            // needing the Metal toolchain at compile time — TODO confirm.
            flags.push("--iree-metal-compile-to-metallib=false".to_string());
        }
        session.set_flags(flags).unwrap();

        let source = session
            .create_source_from_file(Path::new("tests/mul.mlir"))
            .unwrap();
        let mut invocation = session.create_invocation();
        let mut output = compiler::MemBufferOutput::new(shared).unwrap();
        invocation
            .parse_source(source)
            .unwrap()
            .set_verify_ir(true)
            .set_compile_to_phase("end")
            .unwrap()
            .pipeline(compiler::Pipeline::Std)
            .unwrap()
            .output_vm_byte_code(&mut output)
            .unwrap();
        output.map_memory().unwrap().to_vec()
    }

    /// Compiles `simple_mul` for `llvm-cpu` and runs it on the `local-sync` driver.
    #[test]
    fn append_module() {
        let vmfb = compile_mul("llvm-cpu");
        let output = run_mul(&vmfb, "local-sync");
        info!("Output: {:?}", output);
        assert_eq!(output[0], 0.0);
        assert_eq!(output[7], 49.0);
        assert_eq!(output[99], 9801.0);
    }

    /// Opt-in Metal smoke test; runs only when `EERIE_TEST_METAL=1`.
    #[test]
    fn metal_smoke() {
        if std::env::var("EERIE_TEST_METAL").as_deref() != Ok("1") {
            // Log the skip so a silently "passing" test is distinguishable in the output.
            info!("skipping metal_smoke: set EERIE_TEST_METAL=1 to enable");
            return;
        }
        let vmfb = compile_mul("metal-spirv");
        let output = run_mul(&vmfb, "metal");
        assert_eq!(output[0], 0.0);
        assert_eq!(output[7], 49.0);
        assert_eq!(output[99], 9801.0);
    }
}