#![cfg(feature = "runtime")]
use eerie::runtime::{
self,
hal::BufferView,
vm::{List, ToRef, ToValue, Value},
};
use half::f16;
use log::{debug, info};
use test_log::test;
#[test]
fn test_instance() {
let mut driver_registry = runtime::hal::DriverRegistry::new();
debug!("DriverRegistry created");
let options =
runtime::api::InstanceOptions::new(&mut driver_registry).use_all_available_drivers();
debug!("InstanceOptions created");
runtime::api::Instance::new(&options).unwrap();
}
#[test]
fn test_session() {
let mut driver_registry = runtime::hal::DriverRegistry::new();
debug!("DriverRegistry created");
let options =
runtime::api::InstanceOptions::new(&mut driver_registry).use_all_available_drivers();
debug!("InstanceOptions created");
let instance = runtime::api::Instance::new(&options).unwrap();
debug!("Instance created");
let device = instance
.try_create_default_device("local-sync")
.expect("Failed to create device");
debug!("Device created");
let session = runtime::api::Session::create_with_device(
&instance,
&runtime::api::SessionOptions::default(),
&device,
)
.expect("Failed to create session");
session.trim().expect("Failed to trim session");
}
#[test]
fn dynamic_list() {
let instance = runtime::api::Instance::new(
&runtime::api::InstanceOptions::new(&mut runtime::hal::DriverRegistry::new())
.use_all_available_drivers(),
)
.unwrap();
let list = runtime::vm::DynamicList::<Value<i32>>::new(4, &instance).unwrap();
list.push_value(1.to_value()).unwrap();
list.push_value(2.to_value()).unwrap();
list.push_value(3.to_value()).unwrap();
list.push_value(4.to_value()).unwrap();
let val = list.get_value::<i32>(0).unwrap();
drop(list);
assert_eq!(val.from_value(), 1);
}
#[test]
fn ref_list() {
let instance = runtime::api::Instance::new(
&runtime::api::InstanceOptions::new(&mut runtime::hal::DriverRegistry::new())
.use_all_available_drivers(),
)
.unwrap();
let device = instance
.try_create_default_device("local-sync")
.expect("Failed to create device");
let session = runtime::api::Session::create_with_device(
&instance,
&runtime::api::SessionOptions::default(),
&device,
)
.unwrap();
let list =
runtime::vm::DynamicList::<runtime::vm::Ref<BufferView<f32>>>::new(4, &instance).unwrap();
let buffer = BufferView::<f32>::new(
&session,
&[2, 2],
runtime::hal::EncodingType::DenseRowMajor,
&[1.0, 2.0, 3.0, 4.0],
)
.unwrap();
info!("buffer: {:?}", buffer);
let buffer_ref = buffer.to_ref(&instance).unwrap();
list.push_ref(&buffer_ref).unwrap();
list.push_ref(&buffer_ref).unwrap();
let buffer_ref_2: runtime::vm::Ref<BufferView<f32>> = list.get_ref(0).unwrap();
info!("buffer_ref_2: {:?}", buffer_ref_2.to_buffer_view(&session));
let mapping = runtime::hal::BufferMapping::new(buffer_ref_2.to_buffer_view(&session)).unwrap();
info!("mapping: {:?}", mapping.data());
}
#[test]
fn fp16_buffer() {
let instance = runtime::api::Instance::new(
&runtime::api::InstanceOptions::new(&mut runtime::hal::DriverRegistry::new())
.use_all_available_drivers(),
)
.unwrap();
let device = instance
.try_create_default_device("local-sync")
.expect("Failed to create device");
let session = runtime::api::Session::create_with_device(
&instance,
&runtime::api::SessionOptions::default(),
&device,
)
.unwrap();
let buffer = BufferView::<f16>::new(
&session,
&[2, 2],
runtime::hal::EncodingType::DenseRowMajor,
&[
f16::from_f32(1.0),
f16::from_f32(2.0),
f16::from_f32(3.0),
f16::from_f32(4.0),
],
)
.unwrap();
let mapping = runtime::hal::BufferMapping::new(buffer).unwrap();
assert_eq!(
mapping.data(),
&[
f16::from_f32(1.0),
f16::from_f32(2.0),
f16::from_f32(3.0),
f16::from_f32(4.0)
]
);
}
#[cfg(feature = "compiler")]
mod integration_tests {
use eerie::compiler;
use eerie::runtime;
use eerie::runtime::hal::{BufferMapping, BufferView, EncodingType};
use eerie::runtime::vm::{List, ToRef};
use log::{debug, info};
use std::path::Path;
use std::sync::Mutex;
static COMPILER: Mutex<Option<compiler::Compiler>> = Mutex::new(None);
fn init_compiler() {
let mut global_compiler = COMPILER.lock().unwrap();
if global_compiler.is_none() {
let compiler = compiler::Compiler::new().unwrap();
*global_compiler = Some(compiler);
}
}
#[test]
fn append_module() {
init_compiler();
let compiler = COMPILER.lock().unwrap();
let mut compiler_session = compiler.as_ref().unwrap().create_session();
compiler_session
.set_flags(vec!["--iree-hal-target-backends=llvm-cpu".to_string()])
.unwrap();
let source = compiler_session
.create_source_from_file(Path::new("tests/mul.mlir"))
.unwrap();
let mut invocation = compiler_session.create_invocation();
let mut output = compiler::MemBufferOutput::new(compiler.as_ref().unwrap()).unwrap();
invocation
.parse_source(source)
.unwrap()
.set_verify_ir(true)
.set_compile_to_phase("end")
.unwrap()
.pipeline(compiler::Pipeline::Std)
.unwrap()
.output_vm_byte_code(&mut output)
.unwrap();
let vmfb = output.map_memory().unwrap();
let mut driver_registry = runtime::hal::DriverRegistry::new();
let options =
runtime::api::InstanceOptions::new(&mut driver_registry).use_all_available_drivers();
let instance = runtime::api::Instance::new(&options).unwrap();
let device = instance
.try_create_default_device("local-sync")
.expect("Failed to create device");
let session = runtime::api::Session::create_with_device(
&instance,
&runtime::api::SessionOptions::default(),
&device,
)
.unwrap();
unsafe { session.append_module_from_memory(vmfb) }.unwrap();
let func = session.lookup_function("arithmetic.simple_mul").unwrap();
debug!("Function found!");
let mut call = runtime::api::Call::new(&session, &func).unwrap();
let vec = Vec::from_iter((0..100).map(|i| i as f32));
let input = BufferView::<f32>::new(
&session,
&[100],
EncodingType::DenseRowMajor,
vec.as_slice(),
)
.unwrap();
info!("Input: {:?}", input);
call.inputs_push_back_buffer_view(&input).unwrap();
call.inputs_push_back_buffer_view(&input).unwrap();
call.invoke().unwrap();
let output = call.outputs_pop_front_buffer_view::<f32>().unwrap();
info!("Output: {:?}", output);
let input = BufferView::<f32>::new(
&session,
&[100],
EncodingType::DenseRowMajor,
vec.as_slice(),
)
.unwrap();
let input_list =
runtime::vm::DynamicList::<runtime::vm::Ref<BufferView<f32>>>::new(1, &instance)
.unwrap();
input_list
.push_ref(&input.to_ref(&instance).unwrap())
.unwrap();
input_list
.push_ref(&input.to_ref(&instance).unwrap())
.unwrap();
let output_list =
runtime::vm::DynamicList::<runtime::vm::Ref<BufferView<f32>>>::new(0, &instance)
.unwrap();
let function = session.lookup_function("arithmetic.simple_mul").unwrap();
function.invoke(&input_list, &output_list).unwrap();
let output = output_list.get_ref(0).unwrap();
let output_mapping: BufferMapping<f32> =
runtime::hal::BufferMapping::new(output.to_buffer_view(&session)).unwrap();
info!("Output: {:?}", output_mapping.data());
}
}