#![allow(unsafe_code)]
use crate::mtmd_ffi as ffi;
use std::ffi::{CStr, CString};
use std::path::Path;
use std::ptr::NonNull;
/// Errors returned by the safe wrappers in this file around the `mtmd` FFI layer.
#[derive(Debug, thiserror::Error)]
pub enum MtmdError {
/// `mtmd_init_from_file` returned a null context pointer.
#[error("mtmd_init_from_file returned null (mmproj load failed)")]
InitFromFile,
/// The given path could not be converted to a C string because it contains an
/// interior NUL byte; carries the offending path.
#[error("path contains NUL byte: {0}")]
PathNul(std::path::PathBuf),
/// A bitmap constructor got a null pointer back from the C side. In this file the
/// variant is used for both the image and the audio constructor.
#[error("mtmd_bitmap_init returned null")]
BitmapAlloc,
/// `mtmd_tokenize` returned the contained non-zero status code.
#[error("mtmd_tokenize failed: code={0}")]
Tokenize(i32),
/// `mtmd_helper_eval_chunks` returned the contained non-zero status code.
#[error("mtmd_encode_chunk failed: code={0}")]
EncodeChunk(i32),
/// Any other wrapper-level failure (invalid argument, unexpected null pointer),
/// described by a static message.
#[error("mtmd internal: {0}")]
Internal(&'static str),
}
/// Input modalities reported for an mmproj file by [`probe_mmproj_caps`].
///
/// The derived `Default` has both flags set to `false`.
#[derive(Debug, Clone, Copy, Default)]
pub struct MmprojCaps {
/// Copied from the `inp_vision` field of the value returned by the C probe.
pub vision: bool,
/// Copied from the `inp_audio` field of the value returned by the C probe.
pub audio: bool,
}
/// Queries the capability flags of the mmproj file at `mmproj` through
/// `mtmd_get_cap_from_file`, without creating an [`Mtmd`] context.
///
/// # Errors
///
/// Returns [`MtmdError::PathNul`] when the path cannot be turned into a C string.
/// NOTE(review): the C call's result is copied as-is; no load-failure signal is
/// inspected here — confirm how the C side reports an unreadable file.
pub fn probe_mmproj_caps(mmproj: &Path) -> Result<MmprojCaps, MtmdError> {
    let c_mmproj = path_to_cstring(mmproj)?;
    // SAFETY: `c_mmproj` is a NUL-terminated buffer that stays alive for the whole call.
    let reported = unsafe { ffi::mtmd_get_cap_from_file(c_mmproj.as_ptr()) };
    let vision = reported.inp_vision;
    let audio = reported.inp_audio;
    Ok(MmprojCaps { vision, audio })
}
/// Owning handle to an `mtmd_context`.
///
/// Instances are created by `Mtmd::new`; the `Drop` impl in this file releases the
/// context with `mtmd_free`.
pub struct Mtmd {
/// Context pointer obtained from `mtmd_init_from_file`; non-null by construction.
ctx: NonNull<ffi::mtmd_context>,
}
// NOTE(review): these impls assert that the context may be moved to and shared between
// threads. Nothing in this file establishes that — confirm against the thread-safety
// notes of the upstream `mtmd` API (in particular for `eval_chunks`, which takes `&self`).
unsafe impl Send for Mtmd {}
unsafe impl Sync for Mtmd {}
/// Options that `Mtmd::new` copies into the C-side `mtmd_context_params`.
#[derive(Debug, Clone)]
pub struct MtmdConfig {
    /// Forwarded to `mtmd_context_params.use_gpu`.
    pub use_gpu: bool,
    /// Forwarded to `mtmd_context_params.print_timings`.
    pub print_timings: bool,
    /// Thread count override; `None` keeps whatever the C defaults provide.
    pub n_threads: Option<i32>,
    /// Forwarded to `mtmd_context_params.warmup`.
    pub warmup: bool,
}

impl Default for MtmdConfig {
    /// GPU and warmup enabled, timings off, thread count left to the library.
    fn default() -> Self {
        let use_gpu = true;
        let warmup = true;
        MtmdConfig {
            warmup,
            n_threads: None,
            print_timings: false,
            use_gpu,
        }
    }
}
impl Mtmd {
    /// Loads the projector at `mmproj` and creates a context bound to `text_model`.
    ///
    /// The parameters start from `mtmd_context_params_default`; `use_gpu`,
    /// `print_timings` and `warmup` are always overwritten from `config`, while
    /// `n_threads` is only overwritten when it is `Some`.
    ///
    /// # Errors
    ///
    /// * [`MtmdError::PathNul`] if `mmproj` cannot be converted to a C string.
    /// * [`MtmdError::InitFromFile`] if `mtmd_init_from_file` returns null.
    ///
    /// # Safety
    ///
    /// `text_model` is passed to `mtmd_init_from_file` without any check, so it must
    /// be a pointer that function accepts. NOTE(review): presumably it has to point to
    /// a live `llama_model` that outlives the returned `Mtmd` — confirm upstream.
    pub unsafe fn new(
        mmproj: &Path,
        text_model: *const crate::ffi::llama_model,
        config: MtmdConfig,
    ) -> Result<Self, MtmdError> {
        let c_mmproj = path_to_cstring(mmproj)?;

        // SAFETY: takes no arguments and returns the parameter struct by value.
        let mut ctx_params = unsafe { ffi::mtmd_context_params_default() };
        ctx_params.warmup = config.warmup;
        ctx_params.print_timings = config.print_timings;
        ctx_params.use_gpu = config.use_gpu;
        // `None` means "keep the library default" that is already in `ctx_params`.
        ctx_params.n_threads = config.n_threads.unwrap_or(ctx_params.n_threads);

        // SAFETY: `c_mmproj` is NUL-terminated and outlives the call; validity of
        // `text_model` is the caller's obligation (see `# Safety`).
        let handle =
            unsafe { ffi::mtmd_init_from_file(c_mmproj.as_ptr(), text_model.cast(), ctx_params) };
        match NonNull::new(handle) {
            Some(ctx) => Ok(Self { ctx }),
            None => Err(MtmdError::InitFromFile),
        }
    }

    /// Returns the result of `mtmd_support_vision` for this context.
    pub fn supports_vision(&self) -> bool {
        let ctx = self.ctx.as_ptr();
        // SAFETY: `ctx` is the non-null pointer stored at construction.
        unsafe { ffi::mtmd_support_vision(ctx) }
    }

    /// Returns the result of `mtmd_support_audio` for this context.
    pub fn supports_audio(&self) -> bool {
        let ctx = self.ctx.as_ptr();
        // SAFETY: `ctx` is the non-null pointer stored at construction.
        unsafe { ffi::mtmd_support_audio(ctx) }
    }

    /// Returns the value of `mtmd_get_audio_sample_rate`, or `None` when the C side
    /// reports zero or a negative number.
    pub fn audio_sample_rate(&self) -> Option<u32> {
        // SAFETY: `self.ctx` is the non-null pointer stored at construction.
        let rate = unsafe { ffi::mtmd_get_audio_sample_rate(self.ctx.as_ptr()) };
        // The cast only runs for strictly positive values, so it cannot change the sign.
        (rate > 0).then(|| rate as u32)
    }

    /// Runs `mtmd_helper_eval_chunks` over `chunks` against `lctx` and returns the
    /// position value the helper writes through its output pointer.
    ///
    /// `n_past`, `seq_id`, `n_batch` and `logits_last` are forwarded unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`MtmdError::EncodeChunk`] carrying the status code when the helper
    /// returns non-zero.
    ///
    /// # Safety
    ///
    /// `lctx` is passed to the C helper without any check, so it must be a pointer the
    /// helper accepts. NOTE(review): presumably a live `llama_context` created from the
    /// same model as this `Mtmd`, not used concurrently — confirm upstream.
    pub unsafe fn eval_chunks(
        &self,
        lctx: *mut crate::ffi::llama_context,
        chunks: &InputChunks,
        n_past: i32,
        seq_id: i32,
        n_batch: i32,
        logits_last: bool,
    ) -> Result<i32, MtmdError> {
        let mut updated_n_past: i32 = 0;
        // SAFETY: `self.ctx` and `chunks.raw()` are non-null handles owned by their
        // wrappers; `updated_n_past` is a live local; `lctx` is the caller's obligation.
        let status = unsafe {
            ffi::mtmd_helper_eval_chunks(
                self.ctx.as_ptr(),
                lctx.cast(),
                chunks.raw(),
                n_past,
                seq_id,
                n_batch,
                logits_last,
                &mut updated_n_past,
            )
        };
        match status {
            0 => Ok(updated_n_past),
            code => Err(MtmdError::EncodeChunk(code)),
        }
    }

    /// Tokenizes `text` together with `bitmaps` into a freshly allocated
    /// [`InputChunks`] list via `mtmd_tokenize`.
    ///
    /// Both `add_special` and `parse_special` are always passed as `true`.
    ///
    /// # Errors
    ///
    /// * [`MtmdError::Internal`] if `text` contains a NUL byte or the chunk list
    ///   cannot be allocated.
    /// * [`MtmdError::Tokenize`] carrying the status code when `mtmd_tokenize`
    ///   returns non-zero.
    pub fn tokenize(&self, text: &str, bitmaps: &[&Bitmap]) -> Result<InputChunks, MtmdError> {
        let prompt = match CString::new(text) {
            Ok(prompt) => prompt,
            Err(_) => return Err(MtmdError::Internal("text contains NUL")),
        };
        let input = ffi::mtmd_input_text {
            text: prompt.as_ptr(),
            add_special: true,
            parse_special: true,
        };

        let mut raw_bitmaps: Vec<*const ffi::mtmd_bitmap> = Vec::with_capacity(bitmaps.len());
        raw_bitmaps.extend(
            bitmaps
                .iter()
                .map(|bitmap| bitmap.raw() as *const ffi::mtmd_bitmap),
        );

        let output = InputChunks::new()?;
        // SAFETY: every pointer handed over (`self.ctx`, `output`, `input`, `prompt`,
        // `raw_bitmaps` and the bitmaps it refers to) is backed by a value that stays
        // alive until after the call returns.
        let status = unsafe {
            ffi::mtmd_tokenize(
                self.ctx.as_ptr(),
                output.raw(),
                &input,
                raw_bitmaps.as_mut_ptr(),
                raw_bitmaps.len(),
            )
        };
        // On failure `output` is dropped here, which frees the half-built list.
        if status == 0 {
            Ok(output)
        } else {
            Err(MtmdError::Tokenize(status))
        }
    }
}
impl Drop for Mtmd {
    /// Releases the wrapped context with `mtmd_free`.
    fn drop(&mut self) {
        let ctx = self.ctx.as_ptr();
        // SAFETY: `ctx` is the non-null pointer stored at construction, and `drop`
        // runs at most once per value.
        unsafe {
            ffi::mtmd_free(ctx);
        }
    }
}
/// Owning handle to an `mtmd_bitmap`, created from RGB pixels or from `f32` audio
/// samples by the constructors on this type.
///
/// The `Drop` impl in this file releases it with `mtmd_bitmap_free`.
pub struct Bitmap {
/// Pointer returned by one of the C bitmap constructors; non-null by construction.
ptr: NonNull<ffi::mtmd_bitmap>,
}
// NOTE(review): asserts the bitmap may be moved to and shared between threads; the only
// mutating method here (`set_id`) takes `&mut self`, but the thread-safety of the C
// accessors themselves is not shown in this file — confirm upstream.
unsafe impl Send for Bitmap {}
unsafe impl Sync for Bitmap {}
impl Bitmap {
    /// Creates an image bitmap from `rgb`, which must hold exactly
    /// `width * height * 3` bytes.
    ///
    /// The expected length is computed with checked arithmetic. An unchecked product
    /// could wrap for large dimensions and let a too-short slice through, after which
    /// the C side would read past the end of `rgb`.
    ///
    /// NOTE(review): the returned `Bitmap` does not borrow `rgb`, which relies on
    /// `mtmd_bitmap_init` copying the pixel data — confirm upstream.
    ///
    /// # Errors
    ///
    /// * [`MtmdError::Internal`] if `rgb.len()` does not equal `width * height * 3`
    ///   (including when that product does not fit in `usize`).
    /// * [`MtmdError::BitmapAlloc`] if `mtmd_bitmap_init` returns null.
    pub fn from_image_rgb(width: u32, height: u32, rgb: &[u8]) -> Result<Self, MtmdError> {
        // `None` means the byte count is not representable; no slice can be that long,
        // so it is reported as an ordinary length mismatch.
        let expected = usize::try_from(width)
            .ok()
            .zip(usize::try_from(height).ok())
            .and_then(|(w, h)| w.checked_mul(h))
            .and_then(|pixels| pixels.checked_mul(3));
        if expected != Some(rgb.len()) {
            return Err(MtmdError::Internal("rgb slice length != width*height*3"));
        }
        // SAFETY: `rgb` is a live slice whose length was just verified to match the
        // dimensions passed alongside it.
        let raw = unsafe { ffi::mtmd_bitmap_init(width, height, rgb.as_ptr()) };
        let ptr = NonNull::new(raw).ok_or(MtmdError::BitmapAlloc)?;
        Ok(Self { ptr })
    }

    /// Creates an audio bitmap from `samples` via `mtmd_bitmap_init_from_audio`.
    ///
    /// NOTE(review): the expected sample rate and channel layout are not visible
    /// here — see `Mtmd::audio_sample_rate` and confirm upstream.
    ///
    /// # Errors
    ///
    /// Returns [`MtmdError::BitmapAlloc`] if the C constructor returns null.
    pub fn from_audio_f32(samples: &[f32]) -> Result<Self, MtmdError> {
        // SAFETY: the pointer/length pair comes from one live slice.
        let raw = unsafe { ffi::mtmd_bitmap_init_from_audio(samples.len(), samples.as_ptr()) };
        let ptr = NonNull::new(raw).ok_or(MtmdError::BitmapAlloc)?;
        Ok(Self { ptr })
    }

    /// Returns the result of `mtmd_bitmap_is_audio` for this bitmap.
    pub fn is_audio(&self) -> bool {
        // SAFETY: `self.ptr` is the non-null pointer stored at construction.
        unsafe { ffi::mtmd_bitmap_is_audio(self.ptr.as_ptr()) }
    }

    /// Sets the bitmap's id string through `mtmd_bitmap_set_id`.
    ///
    /// # Errors
    ///
    /// Returns [`MtmdError::Internal`] if `id` contains a NUL byte.
    pub fn set_id(&mut self, id: &str) -> Result<(), MtmdError> {
        let cid = CString::new(id).map_err(|_| MtmdError::Internal("id contains NUL"))?;
        // SAFETY: `cid` is NUL-terminated and lives until after the call.
        // NOTE(review): `cid` is dropped on return, so this relies on the C side
        // copying the string — confirm upstream.
        unsafe { ffi::mtmd_bitmap_set_id(self.ptr.as_ptr(), cid.as_ptr()) };
        Ok(())
    }

    /// Raw pointer for passing this bitmap to other FFI calls; ownership stays with
    /// `self`.
    pub(crate) fn raw(&self) -> *mut ffi::mtmd_bitmap {
        self.ptr.as_ptr()
    }
}
impl Drop for Bitmap {
    /// Releases the wrapped bitmap with `mtmd_bitmap_free`.
    fn drop(&mut self) {
        let bitmap = self.ptr.as_ptr();
        // SAFETY: `bitmap` is the non-null pointer stored at construction, and `drop`
        // runs at most once per value.
        unsafe {
            ffi::mtmd_bitmap_free(bitmap);
        }
    }
}
/// Owning handle to an `mtmd_input_chunks` list, as filled in by `Mtmd::tokenize`.
///
/// The `Drop` impl in this file releases it with `mtmd_input_chunks_free`.
pub struct InputChunks {
/// Pointer returned by `mtmd_input_chunks_init`; non-null by construction.
ptr: NonNull<ffi::mtmd_input_chunks>,
}
// NOTE(review): asserts the list may be moved to and shared between threads; nothing in
// this file establishes that for the C accessors — confirm upstream.
unsafe impl Send for InputChunks {}
unsafe impl Sync for InputChunks {}
impl InputChunks {
    /// Allocates an empty chunk list with `mtmd_input_chunks_init`.
    ///
    /// Fails with [`MtmdError::Internal`] when the C side returns null.
    fn new() -> Result<Self, MtmdError> {
        // SAFETY: the constructor takes no arguments.
        let list = unsafe { ffi::mtmd_input_chunks_init() };
        match NonNull::new(list) {
            Some(ptr) => Ok(Self { ptr }),
            None => Err(MtmdError::Internal("mtmd_input_chunks_init")),
        }
    }

    /// Number of chunks, as reported by `mtmd_input_chunks_size`.
    pub fn len(&self) -> usize {
        let list = self.ptr.as_ptr();
        // SAFETY: `list` is the non-null pointer stored at construction.
        unsafe { ffi::mtmd_input_chunks_size(list) }
    }

    /// `true` when the list holds no chunks.
    pub fn is_empty(&self) -> bool {
        0 == self.len()
    }

    /// Borrows chunk `i`, or returns `None` when `i` is out of range or the C side
    /// hands back a null pointer.
    pub fn get(&self, i: usize) -> Option<InputChunk<'_>> {
        if i < self.len() {
            // SAFETY: `self.ptr` is non-null and `i` was just checked against the
            // current length.
            let chunk = unsafe { ffi::mtmd_input_chunks_get(self.ptr.as_ptr(), i) };
            let ptr = NonNull::new(chunk as *mut ffi::mtmd_input_chunk)?;
            Some(InputChunk {
                ptr,
                _marker: std::marker::PhantomData,
            })
        } else {
            None
        }
    }

    /// Raw pointer for passing this list to other FFI calls; ownership stays with
    /// `self`.
    pub(crate) fn raw(&self) -> *mut ffi::mtmd_input_chunks {
        self.ptr.as_ptr()
    }
}
impl Drop for InputChunks {
    /// Releases the wrapped list with `mtmd_input_chunks_free`.
    fn drop(&mut self) {
        let list = self.ptr.as_ptr();
        // SAFETY: `list` is the non-null pointer stored at construction, and `drop`
        // runs at most once per value.
        unsafe {
            ffi::mtmd_input_chunks_free(list);
        }
    }
}
/// Borrowed view of one chunk inside an [`InputChunks`] list.
///
/// Obtained from `InputChunks::get`; the `'a` lifetime keeps the view from outliving
/// the list it came from. There is no `Drop` impl for this type in this file, so the
/// chunk itself is not freed through it.
pub struct InputChunk<'a> {
/// Pointer returned by `mtmd_input_chunks_get`; checked for null before wrapping.
ptr: NonNull<ffi::mtmd_input_chunk>,
/// Ties this view's lifetime to a shared borrow of the parent list.
_marker: std::marker::PhantomData<&'a InputChunks>,
}
/// Kind of content held by an [`InputChunk`], as returned by `InputChunk::kind`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InputChunkKind {
/// Maps from `MTMD_INPUT_CHUNK_TYPE_TEXT`; `InputChunk::kind` also returns this for
/// any type value it does not recognize.
Text,
/// Maps from `MTMD_INPUT_CHUNK_TYPE_IMAGE`.
Image,
/// Maps from `MTMD_INPUT_CHUNK_TYPE_AUDIO`.
Audio,
}
impl InputChunk<'_> {
    /// Classifies the chunk using `mtmd_input_chunk_get_type`.
    ///
    /// Type values other than the three known constants are reported as
    /// [`InputChunkKind::Text`].
    pub fn kind(&self) -> InputChunkKind {
        // SAFETY: `self.ptr` is non-null and the parent list is still borrowed.
        let chunk_type = unsafe { ffi::mtmd_input_chunk_get_type(self.ptr.as_ptr()) };
        match chunk_type {
            ffi::MTMD_INPUT_CHUNK_TYPE_AUDIO => InputChunkKind::Audio,
            ffi::MTMD_INPUT_CHUNK_TYPE_IMAGE => InputChunkKind::Image,
            ffi::MTMD_INPUT_CHUNK_TYPE_TEXT => InputChunkKind::Text,
            // Unknown values deliberately fall back to text.
            _ => InputChunkKind::Text,
        }
    }

    /// Token count reported by `mtmd_input_chunk_get_n_tokens`.
    pub fn n_tokens(&self) -> usize {
        let chunk = self.ptr.as_ptr();
        // SAFETY: `chunk` is non-null and the parent list is still borrowed.
        unsafe { ffi::mtmd_input_chunk_get_n_tokens(chunk) }
    }

    /// Position count reported by `mtmd_input_chunk_get_n_pos`.
    pub fn n_pos(&self) -> i32 {
        let chunk = self.ptr.as_ptr();
        // SAFETY: `chunk` is non-null and the parent list is still borrowed.
        unsafe { ffi::mtmd_input_chunk_get_n_pos(chunk) }
    }

    /// Raw pointer for passing this chunk to other FFI calls.
    #[allow(dead_code)]
    pub(crate) fn raw(&self) -> *const ffi::mtmd_input_chunk {
        self.ptr.as_ptr()
    }

    /// Id string reported by `mtmd_input_chunk_get_id`.
    ///
    /// Returns `None` when the C side returns null or the bytes are not valid UTF-8.
    pub fn id(&self) -> Option<&str> {
        // SAFETY: `self.ptr` is non-null and the parent list is still borrowed.
        let id_ptr = unsafe { ffi::mtmd_input_chunk_get_id(self.ptr.as_ptr()) };
        if id_ptr.is_null() {
            None
        } else {
            // SAFETY: `id_ptr` is non-null. NOTE(review): assumes it points to a
            // NUL-terminated string that stays valid for as long as `self` is
            // borrowed — confirm upstream.
            let c_id = unsafe { CStr::from_ptr(id_ptr) };
            c_id.to_str().ok()
        }
    }
}
/// Converts `p` into a NUL-terminated C string suitable for the `mtmd` file APIs.
///
/// On Unix the path's raw bytes are used unchanged, so paths that are not valid UTF-8
/// still name the same file on the C side. On other platforms the path is converted
/// lossily to UTF-8, with invalid sequences becoming U+FFFD.
///
/// # Errors
///
/// Returns [`MtmdError::PathNul`] carrying the path when it contains an interior NUL
/// byte.
fn path_to_cstring(p: &Path) -> Result<CString, MtmdError> {
    #[cfg(unix)]
    let bytes: Vec<u8> = {
        use std::os::unix::ffi::OsStrExt;
        p.as_os_str().as_bytes().to_vec()
    };
    #[cfg(not(unix))]
    let bytes: Vec<u8> = p.as_os_str().to_string_lossy().into_owned().into_bytes();

    CString::new(bytes).map_err(|_| MtmdError::PathNul(p.to_path_buf()))
}
/// Returns the literal placeholder string `<__media__>`.
///
/// NOTE(review): presumably this is the marker that prompt text passed to
/// `Mtmd::tokenize` uses to show where each bitmap goes — confirm it matches the
/// upstream default marker.
pub fn default_media_marker() -> &'static str {
    const MEDIA_MARKER: &str = "<__media__>";
    MEDIA_MARKER
}