Skip to main content

iree_embedded/
tensor.rs

1//! Device buffers with shape and dtype, wrapping `iree_hal_buffer_view_t`.
2
3use core::ffi::c_void;
4
5use crate::{Device, Result, check};
6use iree_embedded_sys as sys;
7
8/// A device buffer with shape and element type: an input to or output from
9/// [`Context::invoke`](crate::Context::invoke).
10pub struct Tensor {
11    raw: *mut sys::iree_hal_buffer_view_t,
12}
13
14impl Tensor {
15    /// Allocate a device-local f32 buffer view, copying `data` in.
16    pub fn from_f32(device: &Device, shape: &[usize], data: &[f32]) -> Result<Self> {
17        Self::from_bytes(
18            device,
19            shape,
20            sys::IREE_HAL_ELEMENT_TYPE_FLOAT_32,
21            data.as_ptr() as *const c_void,
22            core::mem::size_of_val(data),
23        )
24    }
25
26    /// Allocate a device-local u8 buffer view, copying `data` in (e.g. an
27    /// int8/uint8 quantized model input).
28    pub fn from_u8(device: &Device, shape: &[usize], data: &[u8]) -> Result<Self> {
29        Self::from_bytes(
30            device,
31            shape,
32            sys::IREE_HAL_ELEMENT_TYPE_UINT_8,
33            data.as_ptr() as *const c_void,
34            data.len(),
35        )
36    }
37
38    /// Copy the buffer contents back to the host as f32.
39    pub fn read_into_f32(&self, device: &Device, out: &mut [f32]) -> Result<()> {
40        self.read_bytes(
41            device,
42            out.as_mut_ptr() as *mut c_void,
43            core::mem::size_of_val(out),
44        )
45    }
46
47    /// Copy the buffer contents back to the host as u8.
48    pub fn read_into_u8(&self, device: &Device, out: &mut [u8]) -> Result<()> {
49        self.read_bytes(device, out.as_mut_ptr() as *mut c_void, out.len())
50    }
51
52    fn from_bytes(
53        device: &Device,
54        shape: &[usize],
55        element_type: u32,
56        data: *const c_void,
57        len: usize,
58    ) -> Result<Self> {
59        let dims: heapless::Vec<sys::iree_hal_dim_t, 8> =
60            shape.iter().map(|&d| d as sys::iree_hal_dim_t).collect();
61        let params = sys::iree_hal_buffer_params_t {
62            usage: sys::IREE_HAL_BUFFER_USAGE_DEFAULT as _,
63            type_: sys::IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL as _,
64            ..unsafe { core::mem::zeroed() }
65        };
66        let mut raw = core::ptr::null_mut();
67        // SAFETY: data/len describe a valid initial-contents span; dims is valid.
68        unsafe {
69            check(sys::iree_hal_buffer_view_allocate_buffer_copy(
70                device.raw(),
71                sys::iree_hal_device_allocator(device.raw()),
72                dims.len() as sys::iree_host_size_t,
73                dims.as_ptr(),
74                element_type as _,
75                sys::IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR as _,
76                params,
77                sys::iree_make_const_byte_span(data, len),
78                &mut raw,
79            ))?;
80        }
81        Ok(Tensor { raw })
82    }
83
84    fn read_bytes(&self, device: &Device, out: *mut c_void, len: usize) -> Result<()> {
85        // SAFETY: out/len describe a valid mutable span; the buffer outlives the call.
86        unsafe {
87            check(sys::iree_hal_device_transfer_d2h(
88                device.raw(),
89                sys::iree_hal_buffer_view_buffer(self.raw),
90                0,
91                out,
92                len as sys::iree_device_size_t,
93                sys::IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT as _,
94                sys::iree_infinite_timeout(),
95            ))?;
96        }
97        Ok(())
98    }
99
100    pub(crate) fn raw(&self) -> *mut sys::iree_hal_buffer_view_t {
101        self.raw
102    }
103
104    /// Wrap a buffer view whose reference this `Tensor` now owns.
105    pub(crate) fn from_raw(raw: *mut sys::iree_hal_buffer_view_t) -> Self {
106        Tensor { raw }
107    }
108}
109
110impl Drop for Tensor {
111    fn drop(&mut self) {
112        // SAFETY: raw is an owned buffer-view reference.
113        unsafe { sys::iree_hal_buffer_view_release(self.raw) };
114    }
115}