use crate::{api, check, sys, Result, RunInput};
use std::ffi::{c_char, c_void, CString};
use std::ptr;
/// Returns the model-editor sub-API table exposed by the loaded runtime.
///
/// Yields `None` when the runtime's getter hands back a null pointer.
pub fn model_editor_api() -> Option<&'static sys::ModelEditorApi> {
    let table = unsafe { api().get_model_editor_api()() };
    // SAFETY: `as_ref` maps null to `None`; a non-null table is assumed (by the `'static`
    // return type) to stay valid for the rest of the process.
    unsafe { table.as_ref() }
}
/// Returns the compile sub-API table, or `None` if the runtime reports a null pointer.
pub fn compile_api() -> Option<&'static sys::CompileApi> {
    let table = unsafe { api().get_compile_api()() };
    if table.is_null() {
        return None;
    }
    // SAFETY: checked non-null above; assumed to live for the whole process (`'static`).
    Some(unsafe { &*table })
}
/// Returns the execution-provider sub-API table, or `None` when it is unavailable (null).
pub fn ep_api() -> Option<&'static sys::EpApi> {
    match unsafe { api().get_ep_api()() } {
        raw if raw.is_null() => None,
        // SAFETY: non-null per the guard above; assumed valid for the process lifetime.
        raw => Some(unsafe { &*raw }),
    }
}
/// Returns the interop sub-API table, or `None` when the runtime returns a null pointer.
pub fn interop_api() -> Option<&'static sys::InteropApi> {
    let raw = unsafe { api().get_interop_api()() };
    Some(raw)
        .filter(|candidate| !candidate.is_null())
        // SAFETY: only non-null pointers reach this map; assumed valid for the process lifetime.
        .map(|candidate| unsafe { &*candidate })
}
fn me() -> Result<&'static sys::ModelEditorApi> {
model_editor_api().ok_or_else(|| crate::Error::new(-1, "ModelEditorApi unavailable"))
}
fn ca() -> Result<&'static sys::CompileApi> {
compile_api().ok_or_else(|| crate::Error::new(-1, "CompileApi unavailable"))
}
pub(crate) fn require_sub_api_fn<T: Copy>(
f: Option<T>, api_name: &str, function_name: &str,
) -> Result<T> {
f.ok_or_else(|| crate::Error::new(-1, format!("{api_name}.{function_name} unavailable")))
}
/// Owned handle to a type description built through the model-editor API.
///
/// The handle is released with `release_type_info` when the value is dropped.
pub struct TypeInfo {
    ptr: *mut sys::TypeInfoHandle,
}
impl TypeInfo {
    /// Shared tail of every constructor.
    ///
    /// Hands a null-initialised out-parameter to `call`, propagates its error, then wraps the
    /// produced handle after `crate::ensure_non_null` has vetted it.
    fn from_out_param(
        call: impl FnOnce(&mut *mut sys::TypeInfoHandle) -> Result<()>,
    ) -> Result<Self> {
        let mut raw: *mut sys::TypeInfoHandle = ptr::null_mut();
        call(&mut raw)?;
        let ptr = crate::ensure_non_null(raw, "model-editor type info")?;
        Ok(Self { ptr })
    }

    /// Builds a tensor type from `info` via `CreateTensorTypeInfo`.
    pub fn tensor(info: &crate::TensorTypeAndShapeInfo) -> Result<Self> {
        let editor = me()?;
        let create_fn =
            require_sub_api_fn(editor.CreateTensorTypeInfo, "ModelEditorApi", "CreateTensorTypeInfo")?;
        // SAFETY: `info` is a live shape/type handle and `out` is a valid out-pointer.
        Self::from_out_param(|out| check(unsafe { create_fn(info.as_ptr(), out) }))
    }

    /// Builds a sparse-tensor type from `info` via `CreateSparseTensorTypeInfo`.
    pub fn sparse_tensor(info: &crate::TensorTypeAndShapeInfo) -> Result<Self> {
        let editor = me()?;
        let create_fn = require_sub_api_fn(
            editor.CreateSparseTensorTypeInfo,
            "ModelEditorApi",
            "CreateSparseTensorTypeInfo",
        )?;
        // SAFETY: `info` is a live shape/type handle and `out` is a valid out-pointer.
        Self::from_out_param(|out| check(unsafe { create_fn(info.as_ptr(), out) }))
    }

    /// Builds a map type with keys of `key_type` and values described by `value_type`.
    pub fn map(key_type: sys::ElementType, value_type: &TypeInfo) -> Result<Self> {
        let editor = me()?;
        let create_fn =
            require_sub_api_fn(editor.CreateMapTypeInfo, "ModelEditorApi", "CreateMapTypeInfo")?;
        // SAFETY: `value_type` owns a live handle and `out` is a valid out-pointer.
        Self::from_out_param(|out| check(unsafe { create_fn(key_type, value_type.ptr, out) }))
    }

    /// Builds a sequence type whose elements are described by `element`.
    pub fn sequence(element: &TypeInfo) -> Result<Self> {
        let editor = me()?;
        let create_fn = require_sub_api_fn(
            editor.CreateSequenceTypeInfo,
            "ModelEditorApi",
            "CreateSequenceTypeInfo",
        )?;
        // SAFETY: `element` owns a live handle and `out` is a valid out-pointer.
        Self::from_out_param(|out| check(unsafe { create_fn(element.ptr, out) }))
    }

    /// Builds an optional type wrapping `contained`.
    pub fn optional(contained: &TypeInfo) -> Result<Self> {
        let editor = me()?;
        let create_fn = require_sub_api_fn(
            editor.CreateOptionalTypeInfo,
            "ModelEditorApi",
            "CreateOptionalTypeInfo",
        )?;
        // SAFETY: `contained` owns a live handle and `out` is a valid out-pointer.
        Self::from_out_param(|out| check(unsafe { create_fn(contained.ptr, out) }))
    }

    /// Borrows the raw handle for FFI calls within the crate.
    pub(crate) fn as_ptr(&self) -> *const sys::TypeInfoHandle {
        self.ptr
    }
}
impl Drop for TypeInfo {
    fn drop(&mut self) {
        // SAFETY: `ptr` came from a successful, non-null constructor and is released only here.
        unsafe { api().release_type_info()(self.ptr) }
    }
}
/// Owned handle pairing a value name with its type, as used for graph inputs and outputs.
///
/// Released with `release_value_info` on drop unless ownership was handed to a graph.
pub struct ValueInfo {
    ptr: *mut sys::ValueInfoHandle,
}
impl ValueInfo {
    /// Creates a value info called `name` whose type is described by `ty`.
    ///
    /// # Errors
    /// Fails if `CreateValueInfo` is unavailable, if `name` contains an interior NUL, or if the
    /// runtime reports an error or yields a null handle.
    pub fn new(name: &str, ty: &TypeInfo) -> Result<Self> {
        let editor = me()?;
        let create_fn = require_sub_api_fn(editor.CreateValueInfo, "ModelEditorApi", "CreateValueInfo")?;
        let value_name = CString::new(name)?;
        let mut handle = ptr::null_mut::<sys::ValueInfoHandle>();
        // SAFETY: `value_name` is NUL-terminated and outlives the call; `ty` owns a live handle.
        let status = unsafe { create_fn(value_name.as_ptr(), ty.as_ptr(), &mut handle) };
        check(status)?;
        let ptr = crate::ensure_non_null(handle, "model-editor value info")?;
        Ok(Self { ptr })
    }
}
impl Drop for ValueInfo {
    fn drop(&mut self) {
        // SAFETY: `ptr` is a live handle still owned by this wrapper.
        unsafe { api().release_value_info()(self.ptr) }
    }
}
/// Owned handle to a graph node created through the model-editor API.
///
/// Build one with `Node::new` / `Node::with_attributes`. `Graph::add_node` consumes it and
/// forgets the wrapper on success.
pub struct Node {
    ptr: *mut sys::NodeHandle,
}
/// Owned handle to a single operator attribute (`int`, `float` or `ints`).
///
/// When passed to `Node::with_attributes`, the wrappers are forgotten once the node is
/// created successfully. Otherwise the attribute is released on drop.
pub struct NodeAttr {
    ptr: *mut sys::OpAttrHandle,
}
impl NodeAttr {
    /// Views a slice of plain numeric values as raw bytes without copying.
    ///
    /// The returned slice starts at the same address as `values`, so it keeps the element
    /// type's alignment. Only used with `i64` / `f32`, which have no padding bytes.
    fn raw_bytes<T: Copy>(values: &[T]) -> &[u8] {
        // SAFETY: `values` is a valid, fully initialised slice of padding-free numbers; the
        // byte view covers exactly `size_of_val(values)` bytes and borrows from `values`.
        unsafe {
            std::slice::from_raw_parts(values.as_ptr().cast::<u8>(), std::mem::size_of_val(values))
        }
    }

    /// Creates an attribute named `name` from the raw payload `data`.
    ///
    /// `len` is the element count forwarded to `CreateOpAttr` and must fit in `i32`. The
    /// runtime is expected to read `data` as typed values (e.g. `int64_t` / `float`), so `data`
    /// should be aligned for the element type implied by `ty`.
    fn new(name: &str, data: &[u8], len: usize, ty: sys::OpAttrType) -> Result<Self> {
        let c_name = CString::new(name)?;
        let len = i32::try_from(len)
            .map_err(|_| crate::Error::new(-1, "model-editor attribute length overflows i32"))?;
        let mut out: *mut sys::OpAttrHandle = ptr::null_mut();
        // SAFETY: `c_name` and `data` outlive the call; `out` is a valid out-pointer.
        check(unsafe {
            api().create_op_attr()(
                c_name.as_ptr(),
                data.as_ptr() as *const c_void,
                len,
                ty,
                &mut out,
            )
        })?;
        let out = crate::ensure_non_null(out, "model-editor op attribute")?;
        Ok(Self { ptr: out })
    }

    /// Creates a scalar integer attribute.
    pub fn int(name: &str, value: i64) -> Result<Self> {
        // View the i64 in place rather than copying it into a `[u8; 8]`. The byte array would
        // only be 1-byte aligned, while the runtime reads the payload as an `int64_t`.
        Self::new(
            name,
            Self::raw_bytes(std::slice::from_ref(&value)),
            1,
            sys::OpAttrType::Int,
        )
    }

    /// Creates a scalar float attribute.
    pub fn float(name: &str, value: f32) -> Result<Self> {
        // Same alignment reasoning as `int`: pass the f32's own (4-byte aligned) address.
        Self::new(
            name,
            Self::raw_bytes(std::slice::from_ref(&value)),
            1,
            sys::OpAttrType::Float,
        )
    }

    /// Creates an integer-list attribute; `len` passed to the runtime is `values.len()`.
    pub fn ints(name: &str, values: &[i64]) -> Result<Self> {
        Self::new(name, Self::raw_bytes(values), values.len(), sys::OpAttrType::Ints)
    }

    /// Borrows the raw handle for FFI calls within the crate.
    pub(crate) fn as_ptr(&self) -> *mut sys::OpAttrHandle {
        self.ptr
    }
}
impl Drop for NodeAttr {
    fn drop(&mut self) {
        // SAFETY: `ptr` is a live handle still owned by this wrapper.
        unsafe { api().release_op_attr()(self.ptr) }
    }
}
impl Node {
pub fn new(
op: &str, domain: &str, name: &str, inputs: &[&str], outputs: &[&str],
) -> Result<Self> {
Self::with_attributes(op, domain, name, inputs, outputs, Vec::new())
}
pub fn with_attributes(
op: &str, domain: &str, name: &str, inputs: &[&str], outputs: &[&str],
attributes: Vec<NodeAttr>,
) -> Result<Self> {
let cop = CString::new(op)?;
let cdom = CString::new(domain)?;
let cname = CString::new(name)?;
let in_c: Vec<CString> = inputs
.iter()
.map(|s| CString::new(*s))
.collect::<std::result::Result<_, _>>()?;
let out_c: Vec<CString> = outputs
.iter()
.map(|s| CString::new(*s))
.collect::<std::result::Result<_, _>>()?;
let in_p: Vec<*const c_char> = in_c.iter().map(|c| c.as_ptr()).collect();
let out_p: Vec<*const c_char> = out_c.iter().map(|c| c.as_ptr()).collect();
let mut attr_p: Vec<*mut sys::OpAttrHandle> =
attributes.iter().map(|attr| attr.as_ptr()).collect();
let attr_ptr = if attr_p.is_empty() {
ptr::null_mut()
} else {
attr_p.as_mut_ptr()
};
let create = require_sub_api_fn(me()?.CreateNode, "ModelEditorApi", "CreateNode")?;
let mut p: *mut sys::NodeHandle = ptr::null_mut();
check(unsafe {
create(
cop.as_ptr(),
cdom.as_ptr(),
cname.as_ptr(),
in_p.as_ptr(),
in_p.len(),
out_p.as_ptr(),
out_p.len(),
attr_ptr,
attr_p.len(),
&mut p,
)
})?;
let p = crate::ensure_non_null(p, "model-editor node")?;
for attr in attributes {
std::mem::forget(attr); }
Ok(Self { ptr: p })
}
}
impl Drop for Node {
fn drop(&mut self) {
unsafe { api().release_node()(self.ptr) }
}
}
pub struct Graph {
ptr: *mut sys::GraphHandle,
}
impl Graph {
pub fn new() -> Result<Self> {
let create = require_sub_api_fn(me()?.CreateGraph, "ModelEditorApi", "CreateGraph")?;
let mut p: *mut sys::GraphHandle = ptr::null_mut();
check(unsafe { create(&mut p) })?;
let p = crate::ensure_non_null(p, "model-editor graph")?;
Ok(Self { ptr: p })
}
pub fn set_inputs(&self, inputs: Vec<ValueInfo>) -> Result<()> {
let set_inputs =
require_sub_api_fn(me()?.SetGraphInputs, "ModelEditorApi", "SetGraphInputs")?;
let mut ptrs: Vec<*mut sys::ValueInfoHandle> = inputs.iter().map(|v| v.ptr).collect();
check(unsafe { set_inputs(self.ptr, ptrs.as_mut_ptr(), ptrs.len()) })?;
for v in inputs {
std::mem::forget(v); }
Ok(())
}
pub fn set_outputs(&self, outputs: Vec<ValueInfo>) -> Result<()> {
let set_outputs =
require_sub_api_fn(me()?.SetGraphOutputs, "ModelEditorApi", "SetGraphOutputs")?;
let mut ptrs: Vec<*mut sys::ValueInfoHandle> = outputs.iter().map(|v| v.ptr).collect();
check(unsafe { set_outputs(self.ptr, ptrs.as_mut_ptr(), ptrs.len()) })?;
for v in outputs {
std::mem::forget(v);
}
Ok(())
}
pub fn add_node(&self, node: Node) -> Result<()> {
let add = require_sub_api_fn(me()?.AddNodeToGraph, "ModelEditorApi", "AddNodeToGraph")?;
check(unsafe { add(self.ptr, node.ptr) })?;
std::mem::forget(node);
Ok(())
}
pub fn add_initializer(
&self, name: &str, tensor: crate::Tensor<'_>, data_is_external: bool,
) -> Result<()> {
let add = require_sub_api_fn(
me()?.AddInitializerToGraph,
"ModelEditorApi",
"AddInitializerToGraph",
)?;
let cname = CString::new(name)?;
let value = tensor.as_value_ptr() as *mut sys::ValueHandle;
check(unsafe { add(self.ptr, cname.as_ptr(), value, data_is_external) })?;
std::mem::forget(tensor); Ok(())
}
}
impl Drop for Graph {
fn drop(&mut self) {
unsafe { api().release_graph()(self.ptr) }
}
}
pub struct Model {
ptr: *mut sys::ModelHandle,
}
impl Model {
pub fn new(opsets: &[(&str, i32)]) -> Result<Self> {
let doms: Vec<CString> = opsets
.iter()
.map(|(d, _)| CString::new(*d))
.collect::<std::result::Result<_, _>>()?;
let vers: Vec<i32> = opsets.iter().map(|(_, v)| *v).collect();
let dom_p: Vec<*const c_char> = doms.iter().map(|c| c.as_ptr()).collect();
let create = require_sub_api_fn(me()?.CreateModel, "ModelEditorApi", "CreateModel")?;
let mut p: *mut sys::ModelHandle = ptr::null_mut();
check(unsafe { create(dom_p.as_ptr(), vers.as_ptr(), opsets.len(), &mut p) })?;
let p = crate::ensure_non_null(p, "model-editor model")?;
Ok(Self { ptr: p })
}
pub fn add_graph(&self, graph: Graph) -> Result<()> {
let add = require_sub_api_fn(me()?.AddGraphToModel, "ModelEditorApi", "AddGraphToModel")?;
check(unsafe { add(self.ptr, graph.ptr) })?;
std::mem::forget(graph);
Ok(())
}
pub fn to_bytes(
&self, env: &crate::Environment, opts: &crate::SessionOptions,
) -> Result<Vec<u8>> {
let ca = compile_api().ok_or_else(|| crate::Error::new(-1, "CompileApi unavailable"))?;
let create_options = require_sub_api_fn(
ca.CreateModelCompilationOptionsFromSessionOptions,
"CompileApi",
"CreateModelCompilationOptionsFromSessionOptions",
)?;
let opts_handle = opts.build_handle()?;
let mut copts: *mut sys::ModelCompilationOptionsHandle = ptr::null_mut();
let build = check(unsafe {
create_options(
env.as_ptr(),
opts_handle as *const sys::SessionOptionsHandle,
&mut copts,
)
});
unsafe { api().release_session_options()(opts_handle) };
build?;
let copts = crate::ensure_non_null(copts, "model compilation options")?;
let outcome: Result<Vec<u8>> = (|| {
let set_input = require_sub_api_fn(
ca.ModelCompilationOptions_SetInputModel,
"CompileApi",
"ModelCompilationOptions_SetInputModel",
)?;
check(unsafe { set_input(copts, self.as_ptr()) })?;
let alloc = crate::allocator::Allocator::get_default()?;
let mut buf_ptr: *mut c_void = ptr::null_mut();
let mut buf_size: usize = 0;
let set_output_buffer = require_sub_api_fn(
ca.ModelCompilationOptions_SetOutputModelBuffer,
"CompileApi",
"ModelCompilationOptions_SetOutputModelBuffer",
)?;
check(unsafe { set_output_buffer(copts, alloc.alloc, &mut buf_ptr, &mut buf_size) })?;
let compile = require_sub_api_fn(ca.CompileModel, "CompileApi", "CompileModel")?;
let compile = check(unsafe { compile(env.as_ptr(), copts) });
let bytes = if buf_ptr.is_null() || buf_size == 0 {
Vec::new()
} else {
unsafe { std::slice::from_raw_parts(buf_ptr as *const u8, buf_size).to_vec() }
};
let free = if buf_ptr.is_null() {
Ok(())
} else {
unsafe { alloc.free(buf_ptr) }
};
match (compile, free) {
(Ok(()), Ok(())) => Ok(bytes),
(Err(err), _) => Err(err),
(Ok(()), Err(err)) => Err(err),
}
})();
if let Some(release) = ca.ReleaseModelCompilationOptions {
unsafe { release(copts) };
}
outcome
}
pub fn to_file(
&self, env: &crate::Environment, opts: &crate::SessionOptions, path: &str,
) -> Result<()> {
let copts = ModelCompilationOptions::new(env, opts)?;
copts.set_input_model(self)?;
copts.set_output_model_path(path)?;
copts.compile(env)
}
pub(crate) fn as_ptr(&self) -> *const sys::ModelHandle {
self.ptr
}
}
impl Drop for Model {
fn drop(&mut self) {
unsafe { api().release_model()(self.ptr) }
}
}
/// Owned handle to compile-API options describing one model compilation run.
///
/// Released through `ReleaseModelCompilationOptions` on drop when the runtime provides it.
pub struct ModelCompilationOptions {
    ptr: *mut sys::ModelCompilationOptionsHandle,
}
impl ModelCompilationOptions {
    /// Looks up one compile-API entry point chosen by `pick`.
    ///
    /// Reports `"CompileApi unavailable"` when the table is missing and
    /// `"CompileApi.<function_name> unavailable"` when the entry is missing.
    fn entry<T: Copy>(
        pick: impl FnOnce(&'static sys::CompileApi) -> Option<T>, function_name: &str,
    ) -> Result<T> {
        require_sub_api_fn(pick(ca()?), "CompileApi", function_name)
    }

    /// Creates compilation options seeded from the session options in `opts`.
    ///
    /// The temporary session-options handle built from `opts` is released whether or not
    /// creation succeeds.
    pub fn new(env: &crate::Environment, opts: &crate::SessionOptions) -> Result<Self> {
        let create_options = Self::entry(
            |api| api.CreateModelCompilationOptionsFromSessionOptions,
            "CreateModelCompilationOptionsFromSessionOptions",
        )?;
        let session_opts = opts.build_handle()?;
        let mut handle: *mut sys::ModelCompilationOptionsHandle = ptr::null_mut();
        // SAFETY: environment and session-options handles are live; `handle` is an out-pointer.
        let status = check(unsafe {
            create_options(
                env.as_ptr(),
                session_opts as *const sys::SessionOptionsHandle,
                &mut handle,
            )
        });
        unsafe { api().release_session_options()(session_opts) };
        status?;
        let ptr = crate::ensure_non_null(handle, "model compilation options")?;
        Ok(Self { ptr })
    }

    /// Uses the in-memory `model` as the compilation input.
    pub fn set_input_model(&self, model: &Model) -> Result<()> {
        let apply = Self::entry(
            |api| api.ModelCompilationOptions_SetInputModel,
            "ModelCompilationOptions_SetInputModel",
        )?;
        check(unsafe { apply(self.ptr, model.as_ptr()) })
    }

    /// Uses the serialized model in `bytes` as the compilation input.
    pub fn set_input_model_from_buffer(&self, bytes: &[u8]) -> Result<()> {
        let apply = Self::entry(
            |api| api.ModelCompilationOptions_SetInputModelFromBuffer,
            "ModelCompilationOptions_SetInputModelFromBuffer",
        )?;
        check(unsafe { apply(self.ptr, bytes.as_ptr().cast::<c_void>(), bytes.len()) })
    }

    /// Uses the model file at `path` as the compilation input.
    pub fn set_input_model_path(&self, path: &str) -> Result<()> {
        let apply = Self::entry(
            |api| api.ModelCompilationOptions_SetInputModelPath,
            "ModelCompilationOptions_SetInputModelPath",
        )?;
        let c_path = CString::new(path)?;
        check(unsafe { apply(self.ptr, c_path.as_ptr()) })
    }

    /// Sets the file path the compiled model is written to.
    pub fn set_output_model_path(&self, path: &str) -> Result<()> {
        let apply = Self::entry(
            |api| api.ModelCompilationOptions_SetOutputModelPath,
            "ModelCompilationOptions_SetOutputModelPath",
        )?;
        let c_path = CString::new(path)?;
        check(unsafe { apply(self.ptr, c_path.as_ptr()) })
    }

    /// Forwards `path` and `threshold` to
    /// `ModelCompilationOptions_SetOutputModelExternalInitializersFile`.
    pub fn set_output_model_external_initializers_file(
        &self, path: &str, threshold: usize,
    ) -> Result<()> {
        let apply = Self::entry(
            |api| api.ModelCompilationOptions_SetOutputModelExternalInitializersFile,
            "ModelCompilationOptions_SetOutputModelExternalInitializersFile",
        )?;
        let c_path = CString::new(path)?;
        check(unsafe { apply(self.ptr, c_path.as_ptr(), threshold) })
    }

    /// Forwards `embed` to `ModelCompilationOptions_SetEpContextEmbedMode`.
    pub fn set_ep_context_embed_mode(&self, embed: bool) -> Result<()> {
        let apply = Self::entry(
            |api| api.ModelCompilationOptions_SetEpContextEmbedMode,
            "ModelCompilationOptions_SetEpContextEmbedMode",
        )?;
        check(unsafe { apply(self.ptr, embed) })
    }

    /// Forwards the raw `flags` bitfield to `ModelCompilationOptions_SetFlags`.
    pub fn set_flags(&self, flags: u32) -> Result<()> {
        let apply = Self::entry(
            |api| api.ModelCompilationOptions_SetFlags,
            "ModelCompilationOptions_SetFlags",
        )?;
        check(unsafe { apply(self.ptr, flags) })
    }

    /// Forwards `directory` and `model_name` to
    /// `ModelCompilationOptions_SetEpContextBinaryInformation`.
    pub fn set_ep_context_binary_information(
        &self, directory: &str, model_name: &str,
    ) -> Result<()> {
        let apply = Self::entry(
            |api| api.ModelCompilationOptions_SetEpContextBinaryInformation,
            "ModelCompilationOptions_SetEpContextBinaryInformation",
        )?;
        let c_directory = CString::new(directory)?;
        let c_model_name = CString::new(model_name)?;
        check(unsafe { apply(self.ptr, c_directory.as_ptr(), c_model_name.as_ptr()) })
    }

    /// Sets the graph optimization level applied during compilation.
    pub fn set_graph_optimization_level(&self, level: sys::GraphOptimizationLevel) -> Result<()> {
        let apply = Self::entry(
            |api| api.ModelCompilationOptions_SetGraphOptimizationLevel,
            "ModelCompilationOptions_SetGraphOptimizationLevel",
        )?;
        check(unsafe { apply(self.ptr, level) })
    }

    /// Runs `CompileModel` in `env` with these options.
    pub fn compile(&self, env: &crate::Environment) -> Result<()> {
        let run = Self::entry(|api| api.CompileModel, "CompileModel")?;
        check(unsafe { run(env.as_ptr(), self.ptr) })
    }
}
impl Drop for ModelCompilationOptions {
    fn drop(&mut self) {
        // If the runtime lacks the compile API or its release entry point, the handle is
        // leaked rather than panicking inside `drop`.
        let release = compile_api().and_then(|api| api.ReleaseModelCompilationOptions);
        if let Some(release) = release {
            // SAFETY: `ptr` is a live handle owned by this wrapper and released only here.
            unsafe { release(self.ptr) };
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every safe sub-API accessor should resolve against the linked runtime.
    #[test]
    fn sub_apis_available() {
        // Probes run lazily and in order, so the first missing API is the one reported.
        let probes: [(fn() -> bool, &str); 4] = [
            (|| model_editor_api().is_some(), "ModelEditorApi missing"),
            (|| compile_api().is_some(), "CompileApi missing"),
            (|| ep_api().is_some(), "EpApi missing"),
            (|| interop_api().is_some(), "InteropApi missing"),
        ];
        for (probe, failure_message) in probes {
            assert!(probe(), "{failure_message}");
        }
        eprintln!("all four sub-APIs available via the safe accessors");
    }
}