1use core::ffi::c_void;
4
5use crate::{Device, Result, check};
6use iree_embedded_sys as sys;
7
/// Handle wrapping a raw IREE HAL buffer-view pointer.
///
/// The wrapped pointer is handed to `iree_hal_buffer_view_release` when the
/// value is dropped.
pub struct Tensor {
    /// Raw buffer-view pointer: either written by the IREE allocation call in
    /// `from_bytes` or supplied by the caller of `from_raw`.
    raw: *mut sys::iree_hal_buffer_view_t,
}
13
14impl Tensor {
15 pub fn from_f32(device: &Device, shape: &[usize], data: &[f32]) -> Result<Self> {
17 Self::from_bytes(
18 device,
19 shape,
20 sys::IREE_HAL_ELEMENT_TYPE_FLOAT_32,
21 data.as_ptr() as *const c_void,
22 core::mem::size_of_val(data),
23 )
24 }
25
26 pub fn from_u8(device: &Device, shape: &[usize], data: &[u8]) -> Result<Self> {
29 Self::from_bytes(
30 device,
31 shape,
32 sys::IREE_HAL_ELEMENT_TYPE_UINT_8,
33 data.as_ptr() as *const c_void,
34 data.len(),
35 )
36 }
37
38 pub fn read_into_f32(&self, device: &Device, out: &mut [f32]) -> Result<()> {
40 self.read_bytes(
41 device,
42 out.as_mut_ptr() as *mut c_void,
43 core::mem::size_of_val(out),
44 )
45 }
46
47 pub fn read_into_u8(&self, device: &Device, out: &mut [u8]) -> Result<()> {
49 self.read_bytes(device, out.as_mut_ptr() as *mut c_void, out.len())
50 }
51
52 fn from_bytes(
53 device: &Device,
54 shape: &[usize],
55 element_type: u32,
56 data: *const c_void,
57 len: usize,
58 ) -> Result<Self> {
59 let dims: heapless::Vec<sys::iree_hal_dim_t, 8> =
60 shape.iter().map(|&d| d as sys::iree_hal_dim_t).collect();
61 let params = sys::iree_hal_buffer_params_t {
62 usage: sys::IREE_HAL_BUFFER_USAGE_DEFAULT as _,
63 type_: sys::IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL as _,
64 ..unsafe { core::mem::zeroed() }
65 };
66 let mut raw = core::ptr::null_mut();
67 unsafe {
69 check(sys::iree_hal_buffer_view_allocate_buffer_copy(
70 device.raw(),
71 sys::iree_hal_device_allocator(device.raw()),
72 dims.len() as sys::iree_host_size_t,
73 dims.as_ptr(),
74 element_type as _,
75 sys::IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR as _,
76 params,
77 sys::iree_make_const_byte_span(data, len),
78 &mut raw,
79 ))?;
80 }
81 Ok(Tensor { raw })
82 }
83
84 fn read_bytes(&self, device: &Device, out: *mut c_void, len: usize) -> Result<()> {
85 unsafe {
87 check(sys::iree_hal_device_transfer_d2h(
88 device.raw(),
89 sys::iree_hal_buffer_view_buffer(self.raw),
90 0,
91 out,
92 len as sys::iree_device_size_t,
93 sys::IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT as _,
94 sys::iree_infinite_timeout(),
95 ))?;
96 }
97 Ok(())
98 }
99
100 pub(crate) fn raw(&self) -> *mut sys::iree_hal_buffer_view_t {
101 self.raw
102 }
103
104 pub(crate) fn from_raw(raw: *mut sys::iree_hal_buffer_view_t) -> Self {
106 Tensor { raw }
107 }
108}
109
110impl Drop for Tensor {
111 fn drop(&mut self) {
112 unsafe { sys::iree_hal_buffer_view_release(self.raw) };
114 }
115}