pub mod object;
pub mod string;
pub use object::*;
pub use string::*;
use std::{
ffi::{CStr, CString},
str,
};
pub use crate::{
context::{Context, DeviceType},
errors::*,
function::Function,
module::Module,
ndarray::NDArray,
};
pub use function::{ArgValue, RetValue};
pub use tvm_sys::byte_array::ByteArray;
pub use tvm_sys::datatype::DataType;
use tvm_sys::ffi;
pub use tvm_macros::external;
/// Evaluates a TVM C-API call (inside `unsafe`) and converts its integer status into a `Result`.
///
/// A status of `0` yields `Ok(())`. Any other status yields `Err(..)` built via `.into()` from
/// the message returned by [`get_last_error`].
#[macro_export]
macro_rules! tvm_call {
    ($e:expr) => {{
        // Bind the status first so temporaries in `$e` are dropped before branching,
        // exactly as they are at the end of an `if` condition.
        let status = unsafe { $e };
        match status {
            0 => Ok(()),
            _ => Err($crate::get_last_error().into()),
        }
    }};
}
/// Evaluates a TVM C-API call (inside `unsafe`) and panics if it returns a non-zero status.
///
/// # Panics
/// Panics with the message from [`get_last_error`] when the status is not `0`.
#[macro_export]
macro_rules! check_call {
    ($e:expr) => {{
        let status = unsafe { $e };
        // The message is only formatted on failure, matching the original `panic!`.
        assert!(status == 0, "{}", $crate::get_last_error());
    }};
}
/// Returns the most recent error message reported by the TVM runtime via `TVMGetLastError`.
///
/// If the message is not valid UTF-8, the fixed string `"Invalid UTF-8 message"` is returned instead.
///
/// NOTE(review): the `'static` lifetime is asserted here but not established by this code. The
/// buffer is owned by the TVM runtime and may be overwritten by a later error. Callers should
/// copy the message promptly. TODO confirm against the TVM C API.
pub fn get_last_error() -> &'static str {
    // SAFETY: assumes `TVMGetLastError` returns a non-null pointer to a NUL-terminated string
    // that outlives this call (TODO confirm against the TVM C API).
    let raw = unsafe { CStr::from_ptr(ffi::TVMGetLastError()) };
    raw.to_str().unwrap_or("Invalid UTF-8 message")
}
/// Records `err`'s `Display` text as the TVM runtime's last error via `TVMAPISetLastError`.
///
/// Interior NUL bytes cannot be represented in a C string. Instead of panicking (as a bare
/// `CString::new(..).unwrap()` would), they are stripped from the message, so reporting an
/// error never panics itself. The `?Sized` bound also lets callers pass `&dyn Error`.
pub(crate) fn set_last_error<E: std::error::Error + ?Sized>(err: &E) {
    let c_string = CString::new(err.to_string()).unwrap_or_else(|nul_err| {
        let mut bytes = nul_err.into_vec();
        bytes.retain(|&b| b != 0);
        CString::new(bytes).expect("all NUL bytes were removed above")
    });
    // SAFETY: `c_string` is a valid NUL-terminated string that lives until the end of this
    // function. NOTE(review): `c_string` is dropped on return, so this relies on
    // `TVMAPISetLastError` copying the message. Presumably it does; confirm against the TVM C API.
    unsafe {
        ffi::TVMAPISetLastError(c_string.as_ptr());
    }
}
pub mod array;
pub mod context;
pub mod errors;
pub mod function;
pub mod map;
pub mod module;
pub mod ndarray;
mod to_function;
/// Returns the TVM version string from `ffi::TVM_VERSION`.
///
/// C string macros turned into Rust constants are typically NUL-terminated byte arrays. Any
/// trailing `\0` bytes are therefore trimmed so they do not leak into the returned `&str`.
/// If the bytes are not valid UTF-8, `"Invalid UTF-8 string"` is returned.
pub fn version() -> &'static str {
    match str::from_utf8(ffi::TVM_VERSION) {
        Ok(s) => s.trim_end_matches('\0'),
        Err(_) => "Invalid UTF-8 string",
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ByteArray, Context, DataType};
    use std::{convert::TryInto, str::FromStr};

    /// Smoke test: prints the runtime version reported by `version()`.
    #[test]
    fn print_version() {
        println!("TVM version: {}", version());
    }

    /// An error recorded with `set_last_error` is read back by `get_last_error`.
    /// The read-back is trimmed before comparison because the runtime may add surrounding whitespace.
    #[test]
    fn set_error() {
        let err = errors::NDArrayError::EmptyArray;
        set_last_error(&err);
        assert_eq!(get_last_error().trim(), err.to_string());
    }

    /// A `ByteArray` survives a round trip through `RetValue`.
    #[test]
    fn bytearray() {
        let w = vec![1u8, 2, 3, 4, 5];
        let v = ByteArray::from(w.as_slice());
        let tvm: ByteArray = RetValue::from(v).try_into().unwrap();
        // Compare against the original bytes directly; no intermediate copy is needed.
        assert_eq!(tvm.data(), w.as_slice());
    }

    /// A `DataType` survives a round trip through `RetValue`.
    #[test]
    fn ty() {
        let t = DataType::from_str("int32").unwrap();
        let tvm: DataType = RetValue::from(t).try_into().unwrap();
        assert_eq!(tvm, t);
    }

    /// A `Context` survives a round trip through `RetValue`.
    #[test]
    fn ctx() {
        let c = Context::from_str("gpu").unwrap();
        let tvm: Context = RetValue::from(c).try_into().unwrap();
        assert_eq!(tvm, c);
    }
}