use std::{
ffi::CString,
path::{Path, PathBuf},
ptr::NonNull,
sync::Arc,
};
use litert_lm_sys as sys;
use crate::{conversation::Conversation, Error, Result, SamplerParams, Session};
/// Compute backend the engine can be asked to run on.
///
/// The default is [`Backend::Gpu`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum Backend {
    /// Rendered as `"CPU"` for the native API.
    Cpu,
    /// Rendered as `"GPU"` for the native API; this is the default variant.
    #[default]
    Gpu,
    /// Rendered as `"NPU"` for the native API.
    Npu,
}

impl Backend {
    /// Returns the upper-case identifier that is passed to the native library
    /// for this backend.
    fn as_str(self) -> &'static str {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Backend::Npu => "NPU",
            Backend::Gpu => "GPU",
            Backend::Cpu => "CPU",
        }
    }
}
/// Configuration consumed by `Engine::new`.
///
/// Created with `EngineSettings::new` and refined through the chained
/// setter methods; all fields are private.
pub struct EngineSettings {
/// Path of the model; converted to a C string and handed to the native
/// settings constructor.
model_path: PathBuf,
/// Main backend; always passed to the native settings constructor.
backend: Backend,
/// Optional vision backend; a null pointer is passed when this is `None`.
vision_backend: Option<Backend>,
/// Optional audio backend; a null pointer is passed when this is `None`.
audio_backend: Option<Backend>,
/// Optional token limit; only forwarded to the native settings when set.
max_num_tokens: Option<i32>,
/// Optional cache directory; only forwarded to the native settings when set.
cache_dir: Option<PathBuf>,
}
impl EngineSettings {
    /// Starts a configuration for the model located at `model_path`.
    ///
    /// The main backend is [`Backend::default()`]; every optional setting
    /// starts out unset.
    pub fn new(model_path: impl Into<PathBuf>) -> Self {
        let model_path = model_path.into();
        Self {
            model_path,
            backend: Backend::default(),
            vision_backend: None,
            audio_backend: None,
            max_num_tokens: None,
            cache_dir: None,
        }
    }

    /// Returns the settings with the main backend replaced by `backend`.
    #[must_use]
    pub fn backend(self, backend: Backend) -> Self {
        Self { backend, ..self }
    }

    /// Returns the settings with the vision backend set to `backend`.
    #[must_use]
    pub fn vision_backend(self, backend: Backend) -> Self {
        Self {
            vision_backend: Some(backend),
            ..self
        }
    }

    /// Returns the settings with the audio backend set to `backend`.
    #[must_use]
    pub fn audio_backend(self, backend: Backend) -> Self {
        Self {
            audio_backend: Some(backend),
            ..self
        }
    }

    /// Returns the settings with the token limit set to `n`.
    #[must_use]
    pub fn max_num_tokens(self, n: i32) -> Self {
        Self {
            max_num_tokens: Some(n),
            ..self
        }
    }

    /// Returns the settings with the cache directory set to `dir`.
    #[must_use]
    pub fn cache_dir(self, dir: impl Into<PathBuf>) -> Self {
        Self {
            cache_dir: Some(dir.into()),
            ..self
        }
    }
}
/// Handle to a native engine.
///
/// The engine state lives behind an `Arc`; `create_session` and
/// `create_conversation` hand a clone of that `Arc` to the object they build.
pub struct Engine {
/// Shared owner of the raw engine pointer.
inner: Arc<EngineInner>,
}
/// Owner of the raw native engine pointer.
///
/// Its `Drop` implementation passes the pointer to `litert_lm_engine_delete`.
pub(crate) struct EngineInner {
/// Non-null pointer to the native engine.
pub(crate) ptr: NonNull<sys::LiteRtLmEngine>,
}
// `NonNull` is neither `Send` nor `Sync`, so these impls are what allow
// `Arc<EngineInner>` to cross thread boundaries.
// NOTE(review): this asserts that the native engine may be used from, and
// destroyed on, any thread. Nothing in this file establishes that — confirm
// against the native library's threading guarantees.
unsafe impl Send for EngineInner {}
unsafe impl Sync for EngineInner {}
impl Engine {
pub fn new(settings: EngineSettings) -> Result<Self> {
unsafe { sys::litert_lm_set_min_log_level(3) }; std::env::set_var("TF_CPP_MIN_LOG_LEVEL", "2");
let model_str = path_to_cstring(&settings.model_path)?;
let backend_str = CString::new(settings.backend.as_str()).unwrap();
let vision_cstr = settings
.vision_backend
.map(|b| CString::new(b.as_str()).unwrap());
let audio_cstr = settings
.audio_backend
.map(|b| CString::new(b.as_str()).unwrap());
let raw_settings = unsafe {
sys::litert_lm_engine_settings_create(
model_str.as_ptr(),
backend_str.as_ptr(),
vision_cstr
.as_ref()
.map_or(std::ptr::null(), |s| s.as_ptr()),
audio_cstr.as_ref().map_or(std::ptr::null(), |s| s.as_ptr()),
)
};
if raw_settings.is_null() {
return Err(Error::NullPointer);
}
if let Some(n) = settings.max_num_tokens {
unsafe { sys::litert_lm_engine_settings_set_max_num_tokens(raw_settings, n) };
}
if let Some(ref dir) = settings.cache_dir {
let dir_str = path_to_cstring(dir)?;
unsafe { sys::litert_lm_engine_settings_set_cache_dir(raw_settings, dir_str.as_ptr()) };
}
let engine_ptr = unsafe { sys::litert_lm_engine_create(raw_settings) };
unsafe { sys::litert_lm_engine_settings_delete(raw_settings) };
let ptr = NonNull::new(engine_ptr).ok_or(Error::EngineCreationFailed)?;
Ok(Self {
inner: Arc::new(EngineInner { ptr }),
})
}
pub fn create_session(&self, params: SamplerParams) -> Result<Session> {
Session::new(self.inner.clone(), params)
}
pub fn create_conversation(&self, params: SamplerParams) -> Result<Conversation> {
Conversation::new(self.inner.clone(), params)
}
#[allow(dead_code)]
pub(crate) fn as_raw(&self) -> *mut sys::LiteRtLmEngine {
self.inner.ptr.as_ptr()
}
}
impl Drop for EngineInner {
    /// Hands the engine pointer back to the native library for destruction.
    fn drop(&mut self) {
        let raw = self.ptr.as_ptr();
        // SAFETY: `raw` is non-null because it comes from a `NonNull`.
        // NOTE(review): assumes nothing else deletes the engine behind this pointer.
        unsafe { sys::litert_lm_engine_delete(raw) }
    }
}
impl std::fmt::Debug for Engine {
    /// Formats the engine as `Engine { ptr: <address> }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let raw = self.inner.ptr.as_ptr();
        let mut out = f.debug_struct("Engine");
        out.field("ptr", &raw);
        out.finish()
    }
}
fn path_to_cstring(path: &Path) -> Result<CString> {
let s = path
.to_str()
.ok_or_else(|| Error::InvalidPath(path.to_path_buf()))?;
CString::new(s).map_err(|_| Error::InvalidPath(path.to_path_buf()))
}