Skip to main content

iree_embedded/
context.rs

1//! Loading a compiled module and invoking its functions.
2//!
3//! A `Context` binds the HAL module (wrapping the device) and the bytecode
4//! module (the `.vmfb`) so functions can be resolved and invoked.
5
6use core::marker::PhantomData;
7
8use crate::{Arena, Device, Instance, Result, Tensor, check};
9use iree_embedded_sys as sys;
10
11/// A resolved entry-point function. `iree_vm_function_t` is a plain value
12/// handle (not refcounted), so this is `Copy`.
13#[derive(Clone, Copy)]
14pub struct Function {
15    raw: sys::iree_vm_function_t,
16}
17
18/// A loaded model: the HAL module plus the bytecode module from a `.vmfb`,
19/// ready to resolve and invoke functions. Borrows its [`Instance`].
20pub struct Context<'i> {
21    raw: *mut sys::iree_vm_context_t,
22    _instance: PhantomData<&'i Instance>,
23}
24
25impl<'i> Context<'i> {
26    /// Load the model `vmfb` onto `device`, allocating from `arena`. The bytes
27    /// must outlive the context (use [`include_vmfb!`](crate::include_vmfb)).
28    pub fn new(
29        instance: &'i Instance,
30        device: &Device,
31        vmfb: &'static [u8],
32        arena: &Arena,
33    ) -> Result<Self> {
34        let alloc = arena.as_iree_allocator();
35        // SAFETY: all handles are created/owned here; out-pointers are valid.
36        unsafe {
37            // The HAL module is built over a device group; the group is only
38            // needed during module creation and released immediately after.
39            let mut group = core::ptr::null_mut();
40            check(sys::iree_hal_device_group_create_from_device(
41                device.raw(),
42                alloc,
43                &mut group,
44            ))?;
45            let mut hal_module = core::ptr::null_mut();
46            let status = sys::iree_hal_module_create(
47                instance.raw(),
48                sys::iree_hal_module_device_policy_default(),
49                group,
50                sys::IREE_HAL_MODULE_FLAG_SYNCHRONOUS as _,
51                sys::iree_hal_module_debug_sink_null(),
52                alloc,
53                &mut hal_module,
54            );
55            sys::iree_hal_device_group_release(group);
56            check(status)?;
57
58            // Bytecode module from the embedded .vmfb bytes (not copied).
59            let mut bytecode = core::ptr::null_mut();
60            let bc = sys::iree_vm_bytecode_module_create(
61                instance.raw(),
62                sys::IREE_VM_BYTECODE_MODULE_FLAG_NONE as _,
63                sys::iree_make_const_byte_span(vmfb.as_ptr() as *const _, vmfb.len()),
64                sys::iree_allocator_null(),
65                alloc,
66                &mut bytecode,
67            );
68            if !bc.is_null() {
69                sys::iree_vm_module_release(hal_module);
70                check(bc)?;
71            }
72
73            let mut modules = [hal_module, bytecode];
74            let mut raw = core::ptr::null_mut();
75            let ctx = sys::iree_vm_context_create_with_modules(
76                instance.raw(),
77                sys::IREE_VM_CONTEXT_FLAG_NONE as _,
78                modules.len() as sys::iree_host_size_t,
79                modules.as_mut_ptr(),
80                alloc,
81                &mut raw,
82            );
83            sys::iree_vm_module_release(hal_module);
84            sys::iree_vm_module_release(bytecode);
85            check(ctx)?;
86            Ok(Context {
87                raw,
88                _instance: PhantomData,
89            })
90        }
91    }
92
93    /// Look up an exported function by its fully qualified `name`
94    /// (for example `"module.main"`).
95    pub fn resolve(&self, name: &str) -> Result<Function> {
96        let mut raw: sys::iree_vm_function_t = unsafe { core::mem::zeroed() };
97        // SAFETY: name is a valid UTF-8 slice; out-pointer is valid.
98        unsafe {
99            check(sys::iree_vm_context_resolve_function(
100                self.raw,
101                sys::iree_string_view_t {
102                    data: name.as_ptr() as *const _,
103                    size: name.len(),
104                },
105                &mut raw,
106            ))?;
107        }
108        Ok(Function { raw })
109    }
110
111    /// Synchronously invoke `function` with the given tensor inputs, returning
112    /// the output tensors.
113    pub fn invoke(
114        &self,
115        function: Function,
116        inputs: &[&Tensor],
117        arena: &Arena,
118    ) -> Result<heapless::Vec<Tensor, 8>> {
119        let alloc = arena.as_iree_allocator();
120        // SAFETY: lists and refs are created/owned here and released below.
121        unsafe {
122            let mut in_list = core::ptr::null_mut();
123            check(sys::iree_vm_list_create(
124                sys::iree_vm_make_undefined_type_def(),
125                inputs.len() as sys::iree_host_size_t,
126                alloc,
127                &mut in_list,
128            ))?;
129            for t in inputs {
130                // retain_ref takes its own reference; the Tensor keeps its own.
131                let mut r = sys::iree_hal_buffer_view_retain_ref(t.raw());
132                let st = sys::iree_vm_list_push_ref_move(in_list, &mut r);
133                if !st.is_null() {
134                    sys::iree_vm_ref_release(&mut r);
135                    sys::iree_vm_list_release(in_list);
136                    check(st)?;
137                }
138            }
139
140            let mut out_list = core::ptr::null_mut();
141            let oc = sys::iree_vm_list_create(
142                sys::iree_vm_make_undefined_type_def(),
143                8,
144                alloc,
145                &mut out_list,
146            );
147            if !oc.is_null() {
148                sys::iree_vm_list_release(in_list);
149                check(oc)?;
150            }
151
152            let status = sys::iree_vm_invoke(
153                self.raw,
154                function.raw,
155                sys::IREE_VM_INVOCATION_FLAG_NONE as _,
156                core::ptr::null(),
157                in_list,
158                out_list,
159                alloc,
160            );
161            sys::iree_vm_list_release(in_list);
162            if !status.is_null() {
163                sys::iree_vm_list_release(out_list);
164                check(status)?;
165            }
166
167            let count = sys::iree_vm_list_size(out_list);
168            let mut results: heapless::Vec<Tensor, 8> = heapless::Vec::new();
169            for i in 0..count {
170                let mut r: sys::iree_vm_ref_t = core::mem::zeroed();
171                // get_ref_retain hands us a +1 reference; deref reads the
172                // pointer, and the Tensor takes ownership of that reference.
173                if sys::iree_vm_list_get_ref_retain(out_list, i, &mut r).is_null() {
174                    let bv = sys::iree_hal_buffer_view_deref(r);
175                    let _ = results.push(Tensor::from_raw(bv));
176                }
177            }
178            sys::iree_vm_list_release(out_list);
179            Ok(results)
180        }
181    }
182}
183
184impl<'i> Drop for Context<'i> {
185    fn drop(&mut self) {
186        // SAFETY: raw was created by iree_vm_context_create_with_modules.
187        unsafe { sys::iree_vm_context_release(self.raw) };
188    }
189}