use std::ptr::NonNull;
use litert_sys as sys;
use crate::{check, CompilationOptions, Environment, Error, Model, Result, TensorBuffer};
/// Index of a signature within a compiled model.
///
/// `#[repr(transparent)]` gives this wrapper exactly the same layout as the raw
/// `sys::LiteRtParamIndex`, which is passed unchanged to the C API.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(transparent)]
pub struct SignatureIndex(pub sys::LiteRtParamIndex);

impl SignatureIndex {
    /// The signature at index 0; this is what [`CompiledModel::run`] executes.
    pub const DEFAULT: Self = SignatureIndex(0);
}

impl Default for SignatureIndex {
    /// Returns [`SignatureIndex::DEFAULT`] (index 0).
    fn default() -> Self {
        SignatureIndex::DEFAULT
    }
}
/// A model compiled for execution by the LiteRT runtime.
///
/// Owns the native compiled-model handle, which is released in this type's `Drop`
/// impl. It also owns the `Environment` and `Model` it was created from, so those
/// stay alive for as long as the compiled model does.
pub struct CompiledModel {
    /// Non-null native handle produced by `LiteRtCreateCompiledModel`.
    ptr: NonNull<sys::LiteRtCompiledModelT>,
    // `_env` and `_model` are never read in the code visible here. They are stored
    // so they outlive `ptr`, presumably because the native compiled model may
    // reference them (TODO confirm against the LiteRT docs).
    // NOTE: fields drop in declaration order, after `Drop::drop` has run. So
    // `_env` is dropped before `_model`, and reordering these fields changes the
    // teardown order.
    _env: Environment,
    _model: Model,
}
impl std::fmt::Debug for CompiledModel {
    /// Formats as `CompiledModel { ptr: <address> }`.
    ///
    /// Only the native handle's address is shown; the owned environment and model
    /// are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let raw_handle = self.ptr.as_ptr();
        let mut builder = f.debug_struct("CompiledModel");
        builder.field("ptr", &raw_handle);
        builder.finish()
    }
}
impl CompiledModel {
pub fn new(env: Environment, model: Model, options: &CompilationOptions) -> Result<Self> {
let mut raw: sys::LiteRtCompiledModel = std::ptr::null_mut();
check(unsafe {
sys::LiteRtCreateCompiledModel(env.as_raw(), model.as_raw(), options.as_raw(), &mut raw)
})?;
let ptr = NonNull::new(raw).ok_or(Error::NullPointer)?;
Ok(Self {
ptr,
_env: env,
_model: model,
})
}
pub fn is_fully_accelerated(&self) -> Result<bool> {
let mut out: bool = false;
check(unsafe { sys::LiteRtCompiledModelIsFullyAccelerated(self.ptr.as_ptr(), &mut out) })?;
Ok(out)
}
pub fn run(&self, inputs: &mut [TensorBuffer], outputs: &mut [TensorBuffer]) -> Result<()> {
self.run_signature(SignatureIndex::DEFAULT, inputs, outputs)
}
pub fn run_signature(
&self,
signature: SignatureIndex,
inputs: &mut [TensorBuffer],
outputs: &mut [TensorBuffer],
) -> Result<()> {
let mut in_raw: Vec<sys::LiteRtTensorBuffer> =
inputs.iter().map(TensorBuffer::as_raw).collect();
let mut out_raw: Vec<sys::LiteRtTensorBuffer> =
outputs.iter().map(TensorBuffer::as_raw).collect();
check(unsafe {
sys::LiteRtRunCompiledModel(
self.ptr.as_ptr(),
signature.0,
in_raw.len(),
in_raw.as_mut_ptr(),
out_raw.len(),
out_raw.as_mut_ptr(),
)
})
}
}
impl Drop for CompiledModel {
    /// Releases the native compiled model.
    ///
    /// This runs before the `_env` and `_model` fields are dropped, so both are
    /// still alive while the handle is destroyed.
    fn drop(&mut self) {
        let handle = self.ptr.as_ptr();
        // SAFETY: `ptr` is private and, in the code visible here, is only set by
        // `CompiledModel::new` from a successful `LiteRtCreateCompiledModel` call.
        // `drop` runs exactly once, so the handle is destroyed exactly once.
        unsafe {
            sys::LiteRtDestroyCompiledModel(handle);
        }
    }
}
// SAFETY: NOTE(review): this asserts that the native compiled-model handle can be
// moved to, used from, and destroyed on another thread. Nothing in this file
// establishes that, so confirm it against the LiteRT threading contract.
// This impl is unconditional, which means it also bypasses the auto-trait check on
// the `Environment` and `Model` fields. It is only sound if those types are safe
// to send between threads as well.
// No `Sync` impl is visible here. Because `NonNull` is `!Sync`, `CompiledModel`
// is not `Sync` unless that is implemented elsewhere, so it cannot be shared by
// reference across threads.
unsafe impl Send for CompiledModel {}