1use core::marker::PhantomData;
7
8use crate::{Arena, Device, Instance, Result, Tensor, check};
9use iree_embedded_sys as sys;
10
11#[derive(Clone, Copy)]
14pub struct Function {
15 raw: sys::iree_vm_function_t,
16}
17
18pub struct Context<'i> {
21 raw: *mut sys::iree_vm_context_t,
22 _instance: PhantomData<&'i Instance>,
23}
24
25impl<'i> Context<'i> {
26 pub fn new(
29 instance: &'i Instance,
30 device: &Device,
31 vmfb: &'static [u8],
32 arena: &Arena,
33 ) -> Result<Self> {
34 let alloc = arena.as_iree_allocator();
35 unsafe {
37 let mut group = core::ptr::null_mut();
40 check(sys::iree_hal_device_group_create_from_device(
41 device.raw(),
42 alloc,
43 &mut group,
44 ))?;
45 let mut hal_module = core::ptr::null_mut();
46 let status = sys::iree_hal_module_create(
47 instance.raw(),
48 sys::iree_hal_module_device_policy_default(),
49 group,
50 sys::IREE_HAL_MODULE_FLAG_SYNCHRONOUS as _,
51 sys::iree_hal_module_debug_sink_null(),
52 alloc,
53 &mut hal_module,
54 );
55 sys::iree_hal_device_group_release(group);
56 check(status)?;
57
58 let mut bytecode = core::ptr::null_mut();
60 let bc = sys::iree_vm_bytecode_module_create(
61 instance.raw(),
62 sys::IREE_VM_BYTECODE_MODULE_FLAG_NONE as _,
63 sys::iree_make_const_byte_span(vmfb.as_ptr() as *const _, vmfb.len()),
64 sys::iree_allocator_null(),
65 alloc,
66 &mut bytecode,
67 );
68 if !bc.is_null() {
69 sys::iree_vm_module_release(hal_module);
70 check(bc)?;
71 }
72
73 let mut modules = [hal_module, bytecode];
74 let mut raw = core::ptr::null_mut();
75 let ctx = sys::iree_vm_context_create_with_modules(
76 instance.raw(),
77 sys::IREE_VM_CONTEXT_FLAG_NONE as _,
78 modules.len() as sys::iree_host_size_t,
79 modules.as_mut_ptr(),
80 alloc,
81 &mut raw,
82 );
83 sys::iree_vm_module_release(hal_module);
84 sys::iree_vm_module_release(bytecode);
85 check(ctx)?;
86 Ok(Context {
87 raw,
88 _instance: PhantomData,
89 })
90 }
91 }
92
93 pub fn resolve(&self, name: &str) -> Result<Function> {
96 let mut raw: sys::iree_vm_function_t = unsafe { core::mem::zeroed() };
97 unsafe {
99 check(sys::iree_vm_context_resolve_function(
100 self.raw,
101 sys::iree_string_view_t {
102 data: name.as_ptr() as *const _,
103 size: name.len(),
104 },
105 &mut raw,
106 ))?;
107 }
108 Ok(Function { raw })
109 }
110
111 pub fn invoke(
114 &self,
115 function: Function,
116 inputs: &[&Tensor],
117 arena: &Arena,
118 ) -> Result<heapless::Vec<Tensor, 8>> {
119 let alloc = arena.as_iree_allocator();
120 unsafe {
122 let mut in_list = core::ptr::null_mut();
123 check(sys::iree_vm_list_create(
124 sys::iree_vm_make_undefined_type_def(),
125 inputs.len() as sys::iree_host_size_t,
126 alloc,
127 &mut in_list,
128 ))?;
129 for t in inputs {
130 let mut r = sys::iree_hal_buffer_view_retain_ref(t.raw());
132 let st = sys::iree_vm_list_push_ref_move(in_list, &mut r);
133 if !st.is_null() {
134 sys::iree_vm_ref_release(&mut r);
135 sys::iree_vm_list_release(in_list);
136 check(st)?;
137 }
138 }
139
140 let mut out_list = core::ptr::null_mut();
141 let oc = sys::iree_vm_list_create(
142 sys::iree_vm_make_undefined_type_def(),
143 8,
144 alloc,
145 &mut out_list,
146 );
147 if !oc.is_null() {
148 sys::iree_vm_list_release(in_list);
149 check(oc)?;
150 }
151
152 let status = sys::iree_vm_invoke(
153 self.raw,
154 function.raw,
155 sys::IREE_VM_INVOCATION_FLAG_NONE as _,
156 core::ptr::null(),
157 in_list,
158 out_list,
159 alloc,
160 );
161 sys::iree_vm_list_release(in_list);
162 if !status.is_null() {
163 sys::iree_vm_list_release(out_list);
164 check(status)?;
165 }
166
167 let count = sys::iree_vm_list_size(out_list);
168 let mut results: heapless::Vec<Tensor, 8> = heapless::Vec::new();
169 for i in 0..count {
170 let mut r: sys::iree_vm_ref_t = core::mem::zeroed();
171 if sys::iree_vm_list_get_ref_retain(out_list, i, &mut r).is_null() {
174 let bv = sys::iree_hal_buffer_view_deref(r);
175 let _ = results.push(Tensor::from_raw(bv));
176 }
177 }
178 sys::iree_vm_list_release(out_list);
179 Ok(results)
180 }
181 }
182}
183
184impl<'i> Drop for Context<'i> {
185 fn drop(&mut self) {
186 unsafe { sys::iree_vm_context_release(self.raw) };
188 }
189}