use std::path::Path;
use std::ptr::NonNull;
use llama_crab_sys as sys;
use super::bitmap::{MtmdBitmap, MtmdInputText};
use super::chunks::{MtmdInputChunk, MtmdInputChunks};
use crate::error::{LlamaError, Result};
use crate::model::LlamaModel;
/// Safe, owned subset of the native `mtmd_context_params`.
///
/// Only the fields below are exposed. Any other native field keeps whatever
/// `mtmd_context_params_default()` returns when the params are converted.
#[derive(Debug, Clone)]
pub struct MtmdContextParams {
    /// Copied into the native `use_gpu` field.
    pub use_gpu: bool,
    /// Copied into the native `print_timings` field.
    pub print_timings: bool,
    /// Copied into the native `n_threads` field.
    pub n_threads: i32,
}

impl Default for MtmdContextParams {
    /// Wrapper defaults: `use_gpu = true`, `print_timings = false`,
    /// `n_threads = 1`.
    fn default() -> Self {
        let use_gpu = true;
        let print_timings = false;
        let n_threads = 1;
        Self {
            use_gpu,
            print_timings,
            n_threads,
        }
    }
}
impl MtmdContextParams {
    /// Builds the native `mtmd_context_params`.
    ///
    /// Starts from `mtmd_context_params_default()` and overrides the three
    /// fields this wrapper exposes. Every other native field keeps its library
    /// default.
    fn to_c(&self) -> sys::mtmd_context_params {
        // SAFETY: `mtmd_context_params_default` takes no arguments and returns
        // the parameter struct by value; no pointers are involved.
        let base = unsafe { sys::mtmd_context_params_default() };
        sys::mtmd_context_params {
            use_gpu: self.use_gpu,
            print_timings: self.print_timings,
            n_threads: self.n_threads,
            ..base
        }
    }
}
/// Owning handle to a native `mtmd_context` created by `mtmd_init_from_file`.
///
/// Build one with [`MtmdContext::init_from_file`] or
/// [`MtmdContext::init_from_file_with`]. The native context is released with
/// `mtmd_free` when this value is dropped.
///
/// NOTE(review): the native context is created from a borrowed [`LlamaModel`]
/// pointer, but this type carries no lifetime tying it to that model. If
/// libmtmd retains the model pointer (e.g. for text tokenization), dropping the
/// model before this context would leave it dangling. Confirm, and consider
/// adding a lifetime parameter in a future breaking release.
#[derive(Debug)]
pub struct MtmdContext {
    pub(crate) handle: NonNull<sys::mtmd_context>,
}

/// Converts a filesystem path into a NUL-terminated C string without lossy
/// re-encoding.
///
/// On Unix the raw path bytes are passed through unchanged. On other platforms
/// the path must be valid UTF-8, which is what the native loader expects.
///
/// # Errors
/// Fails if the path contains an interior NUL byte. On non-Unix platforms it
/// also fails if the path is not valid UTF-8.
fn path_to_cstring(path: &Path) -> Result<std::ffi::CString> {
    #[cfg(unix)]
    {
        use std::os::unix::ffi::OsStrExt;
        Ok(std::ffi::CString::new(path.as_os_str().as_bytes())?)
    }
    #[cfg(not(unix))]
    {
        let utf8 = path.to_str().ok_or_else(|| {
            LlamaError::ModelLoad(format!("path is not valid UTF-8: {}", path.display()))
        })?;
        Ok(std::ffi::CString::new(utf8)?)
    }
}

impl MtmdContext {
    /// Loads a multimodal projector (`mmproj`) file for `text_model` using
    /// [`MtmdContextParams::default`].
    ///
    /// # Errors
    /// Same as [`MtmdContext::init_from_file_with`].
    pub fn init_from_file(
        mmproj_path: impl AsRef<Path>,
        text_model: &LlamaModel,
    ) -> Result<Self> {
        Self::init_from_file_with(mmproj_path, text_model, MtmdContextParams::default())
    }

    /// Loads a multimodal projector (`mmproj`) file for `text_model` with
    /// explicit `params`.
    ///
    /// # Errors
    /// - The path contains a NUL byte, or (on non-Unix) is not valid UTF-8.
    /// - `mtmd_init_from_file` returns null, reported as
    ///   [`LlamaError::ModelLoad`].
    pub fn init_from_file_with(
        mmproj_path: impl AsRef<Path>,
        text_model: &LlamaModel,
        params: MtmdContextParams,
    ) -> Result<Self> {
        let path = mmproj_path.as_ref();
        // Pass the exact path bytes: `display()` would replace non-UTF-8
        // sequences and make the loader open a different file.
        let cpath = path_to_cstring(path)?;
        // SAFETY: `cpath` is a valid NUL-terminated string that lives until the
        // end of this function, so it outlives the call. `text_model.raw()` is
        // presumably the model's live native handle, since `text_model` is
        // borrowed for the duration of the call.
        let raw = unsafe {
            sys::mtmd_init_from_file(cpath.as_ptr(), text_model.raw(), params.to_c())
        };
        NonNull::new(raw)
            .map(|handle| Self { handle })
            .ok_or_else(|| {
                LlamaError::ModelLoad(format!(
                    "mtmd_init_from_file({}) failed",
                    path.display()
                ))
            })
    }

    /// Forwards to `mtmd_decode_use_non_causal` for `chunk`. Presumably this
    /// reports whether the chunk must be decoded with a non-causal attention
    /// mask.
    #[must_use]
    pub fn decode_use_non_causal(&self, chunk: &MtmdInputChunk) -> bool {
        // SAFETY: `self.handle` is a non-null context owned by `self`, and
        // `chunk` is borrowed for the duration of the call.
        unsafe { sys::mtmd_decode_use_non_causal(self.handle.as_ptr(), chunk.as_ptr()) }
    }

    /// Forwards to `mtmd_decode_use_mrope`.
    #[must_use]
    pub fn decode_use_mrope(&self) -> bool {
        // SAFETY: `self.handle` is a non-null context owned by `self`.
        unsafe { sys::mtmd_decode_use_mrope(self.handle.as_ptr()) }
    }

    /// Forwards to `mtmd_support_vision`: whether the loaded projector accepts
    /// image input.
    #[must_use]
    pub fn support_vision(&self) -> bool {
        // SAFETY: `self.handle` is a non-null context owned by `self`.
        unsafe { sys::mtmd_support_vision(self.handle.as_ptr()) }
    }

    /// Forwards to `mtmd_support_audio`: whether the loaded projector accepts
    /// audio input.
    #[must_use]
    pub fn support_audio(&self) -> bool {
        // SAFETY: `self.handle` is a non-null context owned by `self`.
        unsafe { sys::mtmd_support_audio(self.handle.as_ptr()) }
    }

    /// Forwards to `mtmd_get_audio_sample_rate`.
    ///
    /// NOTE(review): upstream documents this as the expected sample rate in Hz,
    /// with a negative value when audio is unsupported. Confirm against the
    /// bundled `mtmd.h`.
    #[must_use]
    pub fn audio_sample_rate(&self) -> i32 {
        // SAFETY: `self.handle` is a non-null context owned by `self`.
        unsafe { sys::mtmd_get_audio_sample_rate(self.handle.as_ptr()) }
    }

    /// Splits `text` plus `bitmaps` into a list of text/media chunks via
    /// `mtmd_tokenize`.
    ///
    /// # Errors
    /// - Allocating the output chunk list fails.
    /// - `mtmd_tokenize` returns non-zero, reported as [`LlamaError::Batch`]
    ///   with the raw return code and, for known codes, a short reason.
    pub fn tokenize(
        &self,
        text: MtmdInputText<'_>,
        bitmaps: &[&MtmdBitmap],
    ) -> Result<MtmdInputChunks> {
        let c_text = text.into_c();
        // The C API takes `const mtmd_bitmap **`, which bindgen exposes as a
        // `*mut *const` pointer, hence the mutable vector.
        let mut bitmap_ptrs: Vec<*const sys::mtmd_bitmap> =
            bitmaps.iter().map(|b| b.as_ptr_const()).collect();
        let mut chunks = MtmdInputChunks::new()?;
        // SAFETY: `self.handle` and `chunks.handle` are live native objects
        // owned by `self` and `chunks`. `c_text` lives on this stack frame.
        // `bitmap_ptrs` holds `bitmap_ptrs.len()` pointers to bitmaps that stay
        // borrowed through `bitmaps` for the duration of the call. With zero
        // bitmaps, the dangling-but-aligned pointer is paired with a length of 0.
        let rc = unsafe {
            sys::mtmd_tokenize(
                self.handle.as_ptr(),
                chunks.handle.as_ptr(),
                &c_text,
                bitmap_ptrs.as_mut_ptr(),
                bitmap_ptrs.len(),
            )
        };
        if rc != 0 {
            // Return codes as documented in upstream `mtmd.h`.
            let reason = match rc {
                1 => " (number of bitmaps does not match number of media markers)",
                2 => " (media preprocessing failed)",
                _ => "",
            };
            return Err(LlamaError::Batch(format!("mtmd_tokenize: {rc}{reason}")));
        }
        Ok(chunks)
    }
}
impl Drop for MtmdContext {
    /// Releases the native context with `mtmd_free`.
    fn drop(&mut self) {
        let raw = self.handle.as_ptr();
        // SAFETY: `raw` is the non-null pointer stored at construction. This
        // type is not `Clone`, so `drop` is the only place it is freed, and it
        // runs once.
        unsafe {
            sys::mtmd_free(raw);
        }
    }
}

// SAFETY: `MtmdContext` uniquely owns its native handle (no `Clone`, freed only
// in `Drop`), so moving it to another thread moves sole ownership with it.
// NOTE(review): this assumes libmtmd keeps no thread-affine state tied to the
// creating thread — confirm upstream. `Sync` is deliberately not implemented.
unsafe impl Send for MtmdContext {}