use std::{ffi::CString, path::Path, ptr::NonNull, sync::Arc};
use litert_sys as sys;
use crate::{check, Error, Result};
/// A handle to a LiteRT model.
///
/// Cloning is cheap: every clone shares the same native handle through an [`Arc`], and
/// the native model is destroyed (see `ModelInner`'s `Drop`) when the last clone is dropped.
#[derive(Clone)]
pub struct Model {
    // Shared ownership of the native handle (and, for in-memory models, its backing buffer).
    inner: Arc<ModelInner>,
}
impl std::fmt::Debug for Model {
    /// Shows the raw native handle address and the strong reference count of the
    /// shared inner state (i.e. how many `Arc`s to it currently exist).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let handle = self.inner.ptr.as_ptr();
        let refs = Arc::strong_count(&self.inner);
        let mut out = f.debug_struct("Model");
        out.field("ptr", &handle);
        out.field("strong_count", &refs);
        out.finish()
    }
}
/// Shared state behind every clone of [`Model`]; owns the native model handle.
struct ModelInner {
    // Owned native handle; destroyed exactly once in `Drop for ModelInner`.
    ptr: NonNull<sys::LiteRtModelT>,
    // Backing buffer for models built by `Model::from_bytes` (`None` for file-backed models).
    // It is held here so it is freed only after the native handle is destroyed: Rust drops
    // struct fields after `Drop::drop` has run.
    _owned_bytes: Option<Box<[u8]>>,
}
// SAFETY: NOTE(review): these impls assert that the LiteRT model handle may be destroyed
// from any thread (Send) and used concurrently through shared `&Model` references (Sync).
// Nothing in this file establishes that guarantee — confirm against the LiteRT C API docs.
unsafe impl Send for ModelInner {}
unsafe impl Sync for ModelInner {}
impl Model {
pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
let path_str = path
.to_str()
.ok_or_else(|| Error::InvalidPath(path.to_path_buf()))?;
let cstr = CString::new(path_str).map_err(|_| Error::InvalidPath(path.to_path_buf()))?;
let mut raw: sys::LiteRtModel = std::ptr::null_mut();
check(unsafe { sys::LiteRtCreateModelFromFile(cstr.as_ptr(), &mut raw) })?;
let ptr = NonNull::new(raw).ok_or(Error::NullPointer)?;
Ok(Self {
inner: Arc::new(ModelInner {
ptr,
_owned_bytes: None,
}),
})
}
pub fn from_bytes(bytes: impl Into<Box<[u8]>>) -> Result<Self> {
let bytes: Box<[u8]> = bytes.into();
let mut raw: sys::LiteRtModel = std::ptr::null_mut();
check(unsafe {
sys::LiteRtCreateModelFromBuffer(bytes.as_ptr().cast(), bytes.len(), &mut raw)
})?;
let ptr = NonNull::new(raw).ok_or(Error::NullPointer)?;
Ok(Self {
inner: Arc::new(ModelInner {
ptr,
_owned_bytes: Some(bytes),
}),
})
}
pub fn signature_count(&self) -> Result<usize> {
let mut count: sys::LiteRtParamIndex = 0;
check(unsafe { sys::LiteRtGetNumModelSignatures(self.as_raw(), &mut count) })?;
Ok(count)
}
pub fn signature(&self, index: usize) -> Result<crate::Signature> {
crate::Signature::new(self.clone(), index)
}
pub(crate) fn as_raw(&self) -> sys::LiteRtModel {
self.inner.ptr.as_ptr()
}
}
impl Drop for ModelInner {
    /// Releases the native model handle.
    fn drop(&mut self) {
        let handle = self.ptr.as_ptr();
        // SAFETY: `handle` came from a successful LiteRT create call and is owned solely by
        // this (non-Clone) value, so this is the only place it is destroyed. Rust drops the
        // fields, including any `_owned_bytes` buffer, only after this returns, so the
        // backing buffer outlives the handle.
        unsafe { sys::LiteRtDestroyModel(handle) };
    }
}