use crate::tensor_desc::TensorDesc;
use crate::{drop_using_function, try_unsafe, util::Result, InferenceError};
use crate::{Layout, Precision};
use openvino_sys::{
self, dimensions_t, ie_blob_buffer__bindgen_ty_1, ie_blob_buffer_t, ie_blob_byte_size,
ie_blob_free, ie_blob_get_buffer, ie_blob_get_dims, ie_blob_get_layout, ie_blob_get_precision,
ie_blob_make_memory, ie_blob_size, ie_blob_t,
};
use std::convert::TryFrom;
/// A handle to an OpenVINO blob (`ie_blob_t`): a buffer of tensor data whose layout, dimensions
/// and precision can be queried with [`Blob::tensor_desc`].
pub struct Blob {
    /// Raw pointer to the underlying OpenVINO blob. It is released with `ie_blob_free` (see the
    /// `drop_using_function!` invocation below).
    pub(crate) instance: *mut ie_blob_t,
}
// Tie `Blob`'s cleanup to `ie_blob_free`; the crate's `drop_using_function!` macro presumably
// generates the `Drop` impl that calls it on `instance`.
drop_using_function!(Blob, ie_blob_free);
impl Blob {
    /// Create a new [`Blob`] described by `description` and filled with a copy of `data`.
    ///
    /// The blob is allocated with [`Blob::allocate`] and `data` is then copied into its buffer.
    ///
    /// # Errors
    ///
    /// Returns an error if allocating the blob, querying its size, or accessing its buffer fails
    /// in the underlying OpenVINO library.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` differs from the byte size of the blob described by `description`.
    pub fn new(description: &TensorDesc, data: &[u8]) -> Result<Self> {
        let mut blob = Self::allocate(description)?;
        let blob_len = blob.byte_len()?;
        assert_eq!(
            blob_len,
            data.len(),
            "The data to initialize ({} bytes) must be the same as the blob size ({} bytes).",
            data.len(),
            blob_len
        );
        let buffer = blob.buffer_mut()?;
        buffer.copy_from_slice(data);
        Ok(blob)
    }

    /// Allocate a new [`Blob`] whose memory is created by OpenVINO (`ie_blob_make_memory`)
    /// according to `description`; no data is written into it here.
    ///
    /// # Errors
    ///
    /// Returns an error if `ie_blob_make_memory` fails.
    pub fn allocate(description: &TensorDesc) -> Result<Self> {
        let mut instance = std::ptr::null_mut();
        try_unsafe!(ie_blob_make_memory(
            std::ptr::addr_of!(description.instance),
            std::ptr::addr_of_mut!(instance)
        ))?;
        Ok(Self { instance })
    }

    /// Read this blob's layout, dimensions and precision back into a [`TensorDesc`].
    ///
    /// # Errors
    ///
    /// Returns an error if querying the layout, dimensions or precision fails.
    pub fn tensor_desc(&self) -> Result<TensorDesc> {
        let blob = self.instance as *const ie_blob_t;

        let mut layout = Layout::ANY;
        try_unsafe!(ie_blob_get_layout(blob, std::ptr::addr_of_mut!(layout)))?;

        let mut dimensions = dimensions_t {
            ranks: 0,
            dims: [0; 8usize],
        };
        try_unsafe!(ie_blob_get_dims(blob, std::ptr::addr_of_mut!(dimensions)))?;

        let mut precision = Precision::UNSPECIFIED;
        try_unsafe!(ie_blob_get_precision(
            blob,
            std::ptr::addr_of_mut!(precision)
        ))?;

        // Only the first `ranks` entries of the fixed-size `dims` array are meaningful;
        // `TensorDesc::new` takes its rank from the slice length, so passing the whole array
        // would always yield a rank-8 description. Clamp in case OpenVINO reports a bogus rank.
        let rank = usize::try_from(dimensions.ranks)
            .unwrap_or(usize::MAX)
            .min(dimensions.dims.len());
        Ok(TensorDesc::new(layout, &dimensions.dims[..rank], precision))
    }

    /// Return the number of elements in the blob, as reported by `ie_blob_size`.
    ///
    /// # Errors
    ///
    /// Returns an error if `ie_blob_size` fails.
    ///
    /// # Panics
    ///
    /// Panics if OpenVINO reports a size that cannot be represented as a `usize`.
    pub fn len(&mut self) -> Result<usize> {
        let mut size = 0;
        try_unsafe!(ie_blob_size(self.instance, std::ptr::addr_of_mut!(size)))?;
        Ok(usize::try_from(size).expect("OpenVINO reported a blob size that does not fit in usize"))
    }

    /// Return the size of the blob's data in bytes, as reported by `ie_blob_byte_size`.
    ///
    /// # Errors
    ///
    /// Returns an error if `ie_blob_byte_size` fails.
    ///
    /// # Panics
    ///
    /// Panics if OpenVINO reports a size that cannot be represented as a `usize`.
    pub fn byte_len(&mut self) -> Result<usize> {
        let mut size = 0;
        try_unsafe!(ie_blob_byte_size(
            self.instance,
            std::ptr::addr_of_mut!(size)
        ))?;
        Ok(usize::try_from(size)
            .expect("OpenVINO reported a blob byte size that does not fit in usize"))
    }

    /// Return a read-only view of the blob's data as bytes.
    ///
    /// An empty blob yields an empty slice.
    ///
    /// # Errors
    ///
    /// Returns an error if fetching the buffer or its size fails.
    ///
    /// # Panics
    ///
    /// Panics if OpenVINO returns a null buffer for a non-empty blob.
    pub fn buffer(&mut self) -> Result<&[u8]> {
        let (data, size) = self.raw_buffer()?;
        if size == 0 {
            // Never build a slice from a possibly-null pointer, even with length 0.
            return Ok(&[]);
        }
        // SAFETY: `data` is non-null (checked in `raw_buffer`) and points at this blob's buffer of
        // `size` bytes; this assumes OpenVINO keeps that buffer valid while the blob is alive. The
        // returned slice borrows `self`, so it cannot outlive the blob.
        Ok(unsafe { std::slice::from_raw_parts(data, size) })
    }

    /// Return a mutable view of the blob's data as bytes.
    ///
    /// An empty blob yields an empty slice.
    ///
    /// # Errors
    ///
    /// Returns an error if fetching the buffer or its size fails.
    ///
    /// # Panics
    ///
    /// Panics if OpenVINO returns a null buffer for a non-empty blob.
    pub fn buffer_mut(&mut self) -> Result<&mut [u8]> {
        let (data, size) = self.raw_buffer()?;
        if size == 0 {
            // Never build a slice from a possibly-null pointer, even with length 0.
            return Ok(&mut []);
        }
        // SAFETY: as in `buffer`; additionally, the slice borrows `self` mutably, so no other
        // view of the buffer can be obtained through this `Blob` while it is alive.
        Ok(unsafe { std::slice::from_raw_parts_mut(data, size) })
    }

    /// Return a mutable view of the blob's data reinterpreted as elements of type `T`.
    ///
    /// The slice holds `byte_len() / size_of::<T>()` elements. Trailing bytes that do not form a
    /// whole `T` are excluded.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the blob's bytes are valid values of `T` (typically, that `T`
    /// matches the blob's precision).
    ///
    /// # Errors
    ///
    /// Returns an error if fetching the buffer or its size fails.
    ///
    /// # Panics
    ///
    /// Panics in these cases:
    /// - `T` is zero-sized.
    /// - OpenVINO returns a null buffer for a non-empty blob.
    /// - The buffer is not aligned for `T`.
    pub unsafe fn buffer_mut_as_type<T>(&mut self) -> Result<&mut [T]> {
        let element_size = std::mem::size_of::<T>();
        assert!(
            element_size != 0,
            "cannot view a blob buffer as a slice of a zero-sized type"
        );
        let (data, size) = self.raw_buffer()?;
        let len = size / element_size;
        if len == 0 {
            return Ok(&mut []);
        }
        let data = data.cast::<T>();
        // A misaligned pointer passed to `from_raw_parts_mut` is undefined behavior; fail loudly.
        assert!(
            (data as usize) % std::mem::align_of::<T>() == 0,
            "the blob buffer is not sufficiently aligned for the requested element type"
        );
        // SAFETY: `data` is non-null and aligned for `T` (checked above), and `len * size_of::<T>()`
        // does not exceed the blob's byte size. The caller guarantees the bytes are valid `T`s.
        // The slice borrows `self` mutably, so it cannot outlive or alias through this `Blob`.
        Ok(std::slice::from_raw_parts_mut(data, len))
    }

    /// Wrap an existing OpenVINO blob pointer in a [`Blob`].
    ///
    /// # Safety
    ///
    /// `instance` must point to a valid `ie_blob_t`. Ownership is transferred to the returned
    /// `Blob`, which releases it with `ie_blob_free` (registered via `drop_using_function!`).
    /// The caller must not free it.
    pub(crate) unsafe fn from_raw_pointer(instance: *mut ie_blob_t) -> Self {
        Self { instance }
    }

    /// Fetch this blob's data pointer (viewed as bytes) together with its size in bytes.
    ///
    /// Panics if the pointer is null while the reported size is non-zero.
    fn raw_buffer(&mut self) -> Result<(*mut u8, usize)> {
        let mut buffer = Blob::empty_buffer();
        try_unsafe!(ie_blob_get_buffer(
            self.instance,
            std::ptr::addr_of_mut!(buffer)
        ))?;
        let size = self.byte_len()?;
        // SAFETY: the `buffer` member of the union was initialized by `empty_buffer` (and is
        // presumably overwritten by `ie_blob_get_buffer`), so this reads an initialized pointer.
        let data = unsafe { buffer.__bindgen_anon_1.buffer }.cast::<u8>();
        assert!(
            size == 0 || !data.is_null(),
            "OpenVINO returned a null buffer for a non-empty blob ({} bytes)",
            size
        );
        Ok((data, size))
    }

    /// Build an `ie_blob_buffer_t` whose pointer is null, ready to be filled by
    /// `ie_blob_get_buffer`.
    fn empty_buffer() -> ie_blob_buffer_t {
        ie_blob_buffer_t {
            __bindgen_anon_1: ie_blob_buffer__bindgen_ty_1 {
                buffer: std::ptr::null_mut(),
            },
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Initializing a blob with the wrong number of bytes must trip the size check in
    /// `Blob::new`.
    #[test]
    #[should_panic(expected = "The data to initialize")]
    fn invalid_blob_size() {
        // Load the library first, as `buffer_conversion` does. Otherwise a panic from an unloaded
        // library could satisfy `should_panic` without ever reaching the size check. The
        // `expected` message pins the panic to that check.
        openvino_sys::library::load().expect("unable to find an OpenVINO shared library");
        let desc = TensorDesc::new(Layout::NHWC, &[1, 2, 2, 2], Precision::U8);
        let _ = Blob::new(&desc, &[0; 7]).unwrap();
    }

    /// Element count, byte count, and typed views of a `U16` blob must agree with each other.
    #[test]
    fn buffer_conversion() {
        openvino_sys::library::load().expect("unable to find an OpenVINO shared library");
        const LEN: usize = 200 * 100;
        let desc = TensorDesc::new(Layout::HW, &[200, 100], Precision::U16);
        let mut blob = Blob::new(&desc, &[0; LEN * 2]).unwrap();
        assert_eq!(blob.len().unwrap(), LEN);
        assert_eq!(
            blob.byte_len().unwrap(),
            LEN * 2,
            "we should have twice as many bytes (u16 = u8 * 2)"
        );
        assert_eq!(
            blob.buffer().unwrap().len(),
            LEN * 2,
            "we should have twice as many items (u16 = u8 * 2)"
        );
        assert_eq!(
            unsafe { blob.buffer_mut_as_type::<f32>() }.unwrap().len(),
            LEN / 2,
            "we should have half as many items (u16 = f32 / 2)"
        );
    }
}