use std::ffi::{CStr, CString};
use std::path::Path;
use std::ptr::NonNull;
use std::slice;
use llama_cpp_sys_4 as sys;
use crate::model::LlamaModel;
/// Errors reported by the multimodal (`mtmd`) wrappers in this module.
#[derive(Debug, thiserror::Error)]
pub enum MtmdError {
/// `mtmd_init_from_file` returned a null pointer.
#[error("failed to create mtmd context (null return from mtmd_init_from_file)")]
ContextCreateFailed,
/// One of the native bitmap constructors returned a null pointer.
#[error("failed to create mtmd bitmap")]
BitmapCreateFailed,
/// A string could not be turned into a C string because it contains a NUL byte.
///
/// Produced through `?` on `CString::new`, both for file paths and for the media marker.
#[error("invalid path: {0}")]
InvalidPath(#[from] std::ffi::NulError),
/// A filesystem path could not be represented as UTF-8.
#[error("path is not valid UTF-8")]
PathNotUtf8,
/// `mtmd_tokenize` returned the contained non-zero status code.
#[error("tokenize error: code {0} (1 = bitmap count mismatch, 2 = preprocessing error)")]
TokenizeError(i32),
/// `mtmd_encode_chunk` returned the contained non-zero status code.
#[error("encode error: code {0}")]
EncodeError(i32),
/// One of the `mtmd_helper_eval_*` helpers returned the contained non-zero status code.
#[error("eval error: code {0}")]
EvalError(i32),
}
/// Convenience alias for results whose error type is [`MtmdError`].
pub type Result<T> = std::result::Result<T, MtmdError>;
/// Builder-style wrapper around the native `mtmd_context_params` struct.
///
/// The wrapped value is passed by value to `mtmd_init_from_file` when a context is created.
pub struct MtmdContextParams {
// Raw native parameter struct; the setters on this type write straight into it.
pub(crate) params: sys::mtmd_context_params,
}
impl std::fmt::Debug for MtmdContextParams {
    /// Formats a selection of the native parameter fields; `media_marker` is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let p = &self.params;
        let mut out = f.debug_struct("MtmdContextParams");
        out.field("use_gpu", &p.use_gpu);
        out.field("print_timings", &p.print_timings);
        out.field("n_threads", &p.n_threads);
        out.field("warmup", &p.warmup);
        out.field("image_min_tokens", &p.image_min_tokens);
        out.field("image_max_tokens", &p.image_max_tokens);
        out.finish()
    }
}
impl Default for MtmdContextParams {
    /// Starts from whatever `mtmd_context_params_default` reports.
    fn default() -> Self {
        Self {
            // SAFETY: the call takes no arguments and returns the parameter struct by value.
            params: unsafe { sys::mtmd_context_params_default() },
        }
    }
}
impl MtmdContextParams {
    /// Sets the native `use_gpu` flag.
    #[must_use]
    pub fn use_gpu(mut self, v: bool) -> Self {
        self.params.use_gpu = v;
        self
    }

    /// Sets the native `print_timings` flag.
    #[must_use]
    pub fn print_timings(mut self, v: bool) -> Self {
        self.params.print_timings = v;
        self
    }

    /// Sets the native `n_threads` field.
    #[must_use]
    pub fn n_threads(mut self, n: i32) -> Self {
        self.params.n_threads = n;
        self
    }

    /// Sets the native `warmup` flag.
    #[must_use]
    pub fn warmup(mut self, v: bool) -> Self {
        self.params.warmup = v;
        self
    }

    /// Sets the native `image_min_tokens` field.
    #[must_use]
    pub fn image_min_tokens(mut self, n: i32) -> Self {
        self.params.image_min_tokens = n;
        self
    }

    /// Sets the native `image_max_tokens` field.
    #[must_use]
    pub fn image_max_tokens(mut self, n: i32) -> Self {
        self.params.image_max_tokens = n;
        self
    }

    /// Sets the native `media_marker` pointer.
    ///
    /// `None` stores a null pointer. `Some(s)` stores a pointer to a freshly allocated,
    /// NUL-terminated copy of `s`.
    ///
    /// The copy is intentionally leaked: the native struct only holds a raw pointer, so the
    /// allocation has to outlive this builder. Every call with `Some` therefore leaks one small
    /// allocation, and a pointer stored by an earlier call is overwritten without being freed
    /// (it cannot be told apart from a pointer that was not allocated here).
    ///
    /// # Errors
    ///
    /// Returns [`MtmdError::InvalidPath`] if `s` contains a NUL byte.
    pub fn media_marker(mut self, marker: Option<&str>) -> std::result::Result<Self, MtmdError> {
        self.params.media_marker = match marker {
            None => std::ptr::null(),
            // `into_raw` hands the allocation over explicitly. The previous
            // `as_ptr()` + `mem::forget` moved the `CString` after deriving the pointer.
            Some(s) => CString::new(s)?.into_raw().cast_const(),
        };
        Ok(self)
    }
}
/// Owning handle to a native `mtmd_context`.
///
/// The pointer is obtained from `mtmd_init_from_file` and released with `mtmd_free` on drop.
pub struct MtmdContext {
// Non-null pointer to the native context.
ptr: NonNull<sys::mtmd_context>,
}
// NOTE(review): these impls assert that the native context may be moved to, and shared between,
// threads. Nothing in this file establishes that; several `&self` methods call into the native
// context, so confirm against the mtmd library's threading guarantees.
unsafe impl Send for MtmdContext {}
unsafe impl Sync for MtmdContext {}
impl std::fmt::Debug for MtmdContext {
    /// Prints only the address of the wrapped native context.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MtmdContext");
        out.field("ptr", &self.ptr);
        out.finish()
    }
}
impl Drop for MtmdContext {
    /// Releases the native context with `mtmd_free`.
    fn drop(&mut self) {
        let ctx = self.ptr.as_ptr();
        // SAFETY: `ctx` comes from a `NonNull`, so it is not null; `drop` runs at most once
        // per handle.
        unsafe { sys::mtmd_free(ctx) };
    }
}
impl MtmdContext {
    /// Returns the marker string reported by `mtmd_default_marker`.
    ///
    /// Falls back to `"<__media__>"` if the native string is not valid UTF-8.
    #[must_use]
    pub fn default_marker() -> &'static str {
        // SAFETY: NOTE(review) assumes the native function returns a non-null, NUL-terminated
        // string that lives for the whole program — confirm against the mtmd header.
        let marker = unsafe { CStr::from_ptr(sys::mtmd_default_marker()) };
        match marker.to_str() {
            Ok(s) => s,
            Err(_) => "<__media__>",
        }
    }

    /// Creates a context from a projector file and an already loaded text model.
    ///
    /// # Errors
    ///
    /// * [`MtmdError::PathNotUtf8`] if `mmproj_path` is not valid UTF-8.
    /// * [`MtmdError::InvalidPath`] if the path contains a NUL byte.
    /// * [`MtmdError::ContextCreateFailed`] if `mtmd_init_from_file` returns null.
    #[allow(clippy::needless_pass_by_value)]
    pub fn init_from_file(
        mmproj_path: impl AsRef<Path>,
        text_model: &LlamaModel,
        params: MtmdContextParams,
    ) -> Result<Self> {
        let Some(path) = mmproj_path.as_ref().to_str() else {
            return Err(MtmdError::PathNotUtf8);
        };
        let c_path = CString::new(path)?;
        // SAFETY: `c_path` is a live NUL-terminated string and the model pointer is borrowed
        // from `text_model` for the duration of the call.
        let raw_ctx = unsafe {
            sys::mtmd_init_from_file(c_path.as_ptr(), text_model.model.as_ptr(), params.params)
        };
        match NonNull::new(raw_ctx) {
            Some(ptr) => Ok(Self { ptr }),
            None => Err(MtmdError::ContextCreateFailed),
        }
    }

    /// Installs a log callback that discards every message passed to it.
    pub fn void_logs() {
        unsafe extern "C" fn discard(
            _level: sys::ggml_log_level,
            _text: *const ::std::os::raw::c_char,
            _ud: *mut ::std::os::raw::c_void,
        ) {
        }
        let no_user_data = std::ptr::null_mut();
        // SAFETY: `discard` matches the callback signature and ignores all of its arguments,
        // so the null user-data pointer is never dereferenced.
        unsafe { sys::mtmd_log_set(Some(discard), no_user_data) };
    }

    /// Returns the result of `mtmd_support_vision` for this context.
    #[must_use]
    pub fn supports_vision(&self) -> bool {
        let ctx = self.ptr.as_ptr();
        unsafe { sys::mtmd_support_vision(ctx) }
    }

    /// Returns the result of `mtmd_support_audio` for this context.
    #[must_use]
    pub fn supports_audio(&self) -> bool {
        let ctx = self.ptr.as_ptr();
        unsafe { sys::mtmd_support_audio(ctx) }
    }

    /// Deprecated alias that forwards to [`MtmdContext::audio_sample_rate`].
    #[must_use]
    #[deprecated(note = "use audio_sample_rate() instead")]
    pub fn audio_bitrate(&self) -> i32 {
        Self::audio_sample_rate(self)
    }

    /// Returns the value of `mtmd_get_audio_sample_rate` for this context.
    #[must_use]
    pub fn audio_sample_rate(&self) -> i32 {
        let ctx = self.ptr.as_ptr();
        unsafe { sys::mtmd_get_audio_sample_rate(ctx) }
    }

    /// Returns the result of `mtmd_decode_use_non_causal` for `chunk`.
    #[must_use]
    pub fn decode_use_non_causal(&self, chunk: &MtmdInputChunk<'_>) -> bool {
        let ctx = self.ptr.as_ptr();
        let chunk_ptr = chunk.as_ptr();
        unsafe { sys::mtmd_decode_use_non_causal(ctx, chunk_ptr) }
    }

    /// Returns the result of `mtmd_decode_use_mrope` for this context.
    #[must_use]
    pub fn decode_use_mrope(&self) -> bool {
        let ctx = self.ptr.as_ptr();
        unsafe { sys::mtmd_decode_use_mrope(ctx) }
    }

    /// Tokenizes `text` together with `bitmaps`, writing the chunks into `output`.
    ///
    /// # Errors
    ///
    /// Returns [`MtmdError::TokenizeError`] carrying the status code when `mtmd_tokenize`
    /// returns a non-zero value.
    pub fn tokenize(
        &self,
        text: &MtmdInputText<'_>,
        bitmaps: &[&MtmdBitmap],
        output: &mut MtmdInputChunks,
    ) -> Result<()> {
        // The native call wants a contiguous array of bitmap pointers.
        let mut bitmap_ptrs: Vec<*const sys::mtmd_bitmap> = Vec::with_capacity(bitmaps.len());
        bitmap_ptrs.extend(bitmaps.iter().map(|bitmap| bitmap.ptr.as_ptr().cast_const()));

        let c_text = sys::mtmd_input_text {
            text: text.c_text.as_ptr(),
            add_special: text.add_special,
            parse_special: text.parse_special,
        };

        // SAFETY: every pointer handed over is derived from a live borrow (`self`, `output`,
        // `text`, `bitmaps`) or from the locals above, all of which outlive the call.
        let status = unsafe {
            sys::mtmd_tokenize(
                self.ptr.as_ptr(),
                output.ptr.as_ptr(),
                &raw const c_text,
                bitmap_ptrs.as_mut_ptr(),
                bitmap_ptrs.len(),
            )
        };
        match status {
            0 => Ok(()),
            code => Err(MtmdError::TokenizeError(code)),
        }
    }

    /// Runs `mtmd_encode_chunk` on `chunk`.
    ///
    /// # Errors
    ///
    /// Returns [`MtmdError::EncodeError`] carrying the status code on a non-zero return.
    pub fn encode_chunk(&self, chunk: &MtmdInputChunk<'_>) -> Result<()> {
        let status = unsafe { sys::mtmd_encode_chunk(self.ptr.as_ptr(), chunk.as_ptr()) };
        match status {
            0 => Ok(()),
            code => Err(MtmdError::EncodeError(code)),
        }
    }

    /// Views the buffer returned by `mtmd_get_output_embd` as `n_elements` floats.
    ///
    /// Returns an empty slice when the native pointer is null or `n_elements` is zero.
    ///
    /// NOTE(review): `n_elements` is not checked against the real size of the native buffer;
    /// the caller has to pass a length that does not exceed it.
    #[must_use]
    pub fn output_embd(&self, n_elements: usize) -> &[f32] {
        let embd = unsafe { sys::mtmd_get_output_embd(self.ptr.as_ptr()) };
        if n_elements != 0 && !embd.is_null() {
            // SAFETY: `embd` is non-null; the length is trusted from the caller (see above).
            unsafe { slice::from_raw_parts(embd, n_elements) }
        } else {
            &[]
        }
    }

    /// Forwards all arguments to `mtmd_helper_eval_chunks`.
    ///
    /// `lctx` is passed through unchecked and must be a valid llama context pointer.
    /// `new_n_past` is handed to the native helper as an out-parameter.
    ///
    /// # Errors
    ///
    /// Returns [`MtmdError::EvalError`] carrying the status code on a non-zero return.
    #[allow(clippy::too_many_arguments, clippy::not_unsafe_ptr_arg_deref)]
    pub fn eval_chunks(
        &self,
        lctx: *mut sys::llama_context,
        chunks: &MtmdInputChunks,
        n_past: i32,
        seq_id: i32,
        n_batch: i32,
        logits_last: bool,
        new_n_past: &mut i32,
    ) -> Result<()> {
        let ctx = self.ptr.as_ptr();
        let list = chunks.ptr.as_ptr();
        let status = unsafe {
            sys::mtmd_helper_eval_chunks(
                ctx,
                lctx,
                list,
                n_past,
                seq_id,
                n_batch,
                logits_last,
                new_n_past,
            )
        };
        match status {
            0 => Ok(()),
            code => Err(MtmdError::EvalError(code)),
        }
    }

    /// Forwards all arguments to `mtmd_helper_eval_chunk_single`.
    ///
    /// `lctx` is passed through unchecked and must be a valid llama context pointer.
    /// `new_n_past` is handed to the native helper as an out-parameter.
    ///
    /// # Errors
    ///
    /// Returns [`MtmdError::EvalError`] carrying the status code on a non-zero return.
    #[allow(clippy::too_many_arguments, clippy::not_unsafe_ptr_arg_deref)]
    pub fn eval_chunk_single(
        &self,
        lctx: *mut sys::llama_context,
        chunk: &MtmdInputChunk<'_>,
        n_past: i32,
        seq_id: i32,
        n_batch: i32,
        logits_last: bool,
        new_n_past: &mut i32,
    ) -> Result<()> {
        let ctx = self.ptr.as_ptr();
        let chunk_ptr = chunk.as_ptr();
        let status = unsafe {
            sys::mtmd_helper_eval_chunk_single(
                ctx,
                lctx,
                chunk_ptr,
                n_past,
                seq_id,
                n_batch,
                logits_last,
                new_n_past,
            )
        };
        match status {
            0 => Ok(()),
            code => Err(MtmdError::EvalError(code)),
        }
    }

    /// Exposes the raw native context pointer; ownership stays with `self`.
    #[must_use]
    pub fn as_ptr(&self) -> *mut sys::mtmd_context {
        self.ptr.as_ptr()
    }
}
/// Prompt text plus the two tokenization flags, stored as an owned C string.
#[derive(Debug)]
pub struct MtmdInputText<'a> {
    // NUL-terminated copy of the text passed to the constructor.
    c_text: CString,
    add_special: bool,
    parse_special: bool,
    _marker: std::marker::PhantomData<&'a ()>,
}

impl<'a> MtmdInputText<'a> {
    /// Builds an input text from `text` and the two tokenization flags.
    ///
    /// # Panics
    ///
    /// Panics if `text` contains a NUL byte; use [`MtmdInputText::try_new`] to get an error
    /// instead.
    #[must_use]
    pub fn new(text: &'a str, add_special: bool, parse_special: bool) -> Self {
        Self::try_new(text, add_special, parse_special)
            .expect("MtmdInputText: text must not contain NUL bytes")
    }

    /// Fallible counterpart of [`MtmdInputText::new`].
    ///
    /// # Errors
    ///
    /// Returns the [`std::ffi::NulError`] from `CString::new` if `text` contains a NUL byte.
    pub fn try_new(
        text: &'a str,
        add_special: bool,
        parse_special: bool,
    ) -> std::result::Result<Self, std::ffi::NulError> {
        CString::new(text).map(|c_text| Self {
            c_text,
            add_special,
            parse_special,
            _marker: std::marker::PhantomData,
        })
    }
}
/// Owning handle to a native `mtmd_bitmap`, released with `mtmd_bitmap_free` on drop.
pub struct MtmdBitmap {
// Non-null pointer to the native bitmap.
ptr: NonNull<sys::mtmd_bitmap>,
}
// NOTE(review): these impls assert that the native bitmap may be moved to, and shared between,
// threads. Nothing in this file establishes that — confirm against the mtmd library.
unsafe impl Send for MtmdBitmap {}
unsafe impl Sync for MtmdBitmap {}
impl std::fmt::Debug for MtmdBitmap {
    /// Prints the bitmap's dimensions, byte count and audio flag, not its contents.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MtmdBitmap");
        out.field("nx", &self.nx());
        out.field("ny", &self.ny());
        out.field("n_bytes", &self.n_bytes());
        out.field("is_audio", &self.is_audio());
        out.finish()
    }
}
impl Drop for MtmdBitmap {
    /// Releases the native bitmap with `mtmd_bitmap_free`.
    fn drop(&mut self) {
        let bitmap = self.ptr.as_ptr();
        // SAFETY: `bitmap` comes from a `NonNull`, so it is not null; `drop` runs at most once
        // per handle.
        unsafe { sys::mtmd_bitmap_free(bitmap) };
    }
}
impl MtmdBitmap {
    /// Creates an image bitmap from raw RGB pixel data of `nx` by `ny` pixels.
    ///
    /// The native constructor only receives a pointer and the two dimensions, so the length of
    /// `data` is validated here: it must hold at least `nx * ny * 3` bytes (three bytes per
    /// pixel). Without this check a short slice would be read past its end.
    ///
    /// # Errors
    ///
    /// Returns [`MtmdError::BitmapCreateFailed`] if `data` is shorter than `nx * ny * 3` bytes,
    /// if that size does not fit in `usize`, or if `mtmd_bitmap_init` returns null.
    pub fn from_rgb(nx: u32, ny: u32, data: &[u8]) -> Result<Self> {
        let required = usize::try_from(nx)
            .ok()
            .zip(usize::try_from(ny).ok())
            .and_then(|(width, height)| width.checked_mul(height))
            .and_then(|pixels| pixels.checked_mul(3))
            .ok_or(MtmdError::BitmapCreateFailed)?;
        if data.len() < required {
            return Err(MtmdError::BitmapCreateFailed);
        }
        // SAFETY: `data` is a live slice holding at least `nx * ny * 3` readable bytes.
        let ptr = unsafe { sys::mtmd_bitmap_init(nx, ny, data.as_ptr()) };
        let ptr = NonNull::new(ptr).ok_or(MtmdError::BitmapCreateFailed)?;
        Ok(Self { ptr })
    }

    /// Creates an audio bitmap from `samples`.
    ///
    /// # Errors
    ///
    /// Returns [`MtmdError::BitmapCreateFailed`] if `mtmd_bitmap_init_from_audio` returns null.
    pub fn from_audio(samples: &[f32]) -> Result<Self> {
        // SAFETY: pointer and length describe the same live slice.
        let ptr = unsafe { sys::mtmd_bitmap_init_from_audio(samples.len(), samples.as_ptr()) };
        let ptr = NonNull::new(ptr).ok_or(MtmdError::BitmapCreateFailed)?;
        Ok(Self { ptr })
    }

    /// Creates a bitmap from the file at `path` via `mtmd_helper_bitmap_init_from_file`.
    ///
    /// # Errors
    ///
    /// * [`MtmdError::PathNotUtf8`] if `path` is not valid UTF-8.
    /// * [`MtmdError::InvalidPath`] if `path` contains a NUL byte.
    /// * [`MtmdError::BitmapCreateFailed`] if the native helper returns null.
    pub fn from_file(ctx: &MtmdContext, path: impl AsRef<Path>) -> Result<Self> {
        let path = path.as_ref().to_str().ok_or(MtmdError::PathNotUtf8)?;
        let c_path = CString::new(path)?;
        // SAFETY: `c_path` is a live NUL-terminated string and `ctx` is borrowed for the call.
        let ptr =
            unsafe { sys::mtmd_helper_bitmap_init_from_file(ctx.ptr.as_ptr(), c_path.as_ptr()) };
        let ptr = NonNull::new(ptr).ok_or(MtmdError::BitmapCreateFailed)?;
        Ok(Self { ptr })
    }

    /// Creates a bitmap from an in-memory buffer via `mtmd_helper_bitmap_init_from_buf`.
    ///
    /// # Errors
    ///
    /// Returns [`MtmdError::BitmapCreateFailed`] if the native helper returns null.
    pub fn from_buf(ctx: &MtmdContext, buf: &[u8]) -> Result<Self> {
        // SAFETY: pointer and length describe the same live slice.
        let ptr = unsafe {
            sys::mtmd_helper_bitmap_init_from_buf(ctx.ptr.as_ptr(), buf.as_ptr(), buf.len())
        };
        let ptr = NonNull::new(ptr).ok_or(MtmdError::BitmapCreateFailed)?;
        Ok(Self { ptr })
    }

    /// Returns the value of `mtmd_bitmap_get_nx`.
    #[must_use]
    pub fn nx(&self) -> u32 {
        unsafe { sys::mtmd_bitmap_get_nx(self.ptr.as_ptr()) }
    }

    /// Returns the value of `mtmd_bitmap_get_ny`.
    #[must_use]
    pub fn ny(&self) -> u32 {
        unsafe { sys::mtmd_bitmap_get_ny(self.ptr.as_ptr()) }
    }

    /// Returns the value of `mtmd_bitmap_get_n_bytes`.
    #[must_use]
    pub fn n_bytes(&self) -> usize {
        unsafe { sys::mtmd_bitmap_get_n_bytes(self.ptr.as_ptr()) }
    }

    /// Returns the value of `mtmd_bitmap_is_audio`.
    #[must_use]
    pub fn is_audio(&self) -> bool {
        unsafe { sys::mtmd_bitmap_is_audio(self.ptr.as_ptr()) }
    }

    /// Borrows the bitmap's raw bytes.
    ///
    /// Returns an empty slice when the bitmap reports zero bytes or the native data pointer is
    /// null.
    #[must_use]
    pub fn data(&self) -> &[u8] {
        let n = self.n_bytes();
        if n == 0 {
            return &[];
        }
        let ptr = unsafe { sys::mtmd_bitmap_get_data(self.ptr.as_ptr()) };
        // `slice::from_raw_parts` requires a non-null pointer, so guard against it.
        if ptr.is_null() {
            return &[];
        }
        // SAFETY: `ptr` is non-null and `n` is the byte count reported for this same bitmap.
        unsafe { slice::from_raw_parts(ptr, n) }
    }

    /// Returns the bitmap's id, or `None` if it is unset (null) or not valid UTF-8.
    #[must_use]
    pub fn id(&self) -> Option<&str> {
        let ptr = unsafe { sys::mtmd_bitmap_get_id(self.ptr.as_ptr()) };
        if ptr.is_null() {
            return None;
        }
        unsafe { CStr::from_ptr(ptr) }.to_str().ok()
    }

    /// Sets the bitmap's id through `mtmd_bitmap_set_id`.
    ///
    /// # Errors
    ///
    /// Returns the [`std::ffi::NulError`] from `CString::new` if `id` contains a NUL byte.
    pub fn set_id(&mut self, id: &str) -> std::result::Result<(), std::ffi::NulError> {
        let cs = CString::new(id)?;
        // `cs` is dropped at the end of this function; NOTE(review): this relies on the native
        // side copying the string — confirm.
        unsafe { sys::mtmd_bitmap_set_id(self.ptr.as_ptr(), cs.as_ptr()) };
        Ok(())
    }
}
/// Owning handle to a native `mtmd_input_chunks` list, released with
/// `mtmd_input_chunks_free` on drop.
pub struct MtmdInputChunks {
// Non-null pointer to the native chunk list.
ptr: NonNull<sys::mtmd_input_chunks>,
}
impl std::fmt::Debug for MtmdInputChunks {
    /// Prints only the number of chunks in the list.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MtmdInputChunks");
        out.field("len", &self.len());
        out.finish()
    }
}
impl Drop for MtmdInputChunks {
    /// Releases the native chunk list with `mtmd_input_chunks_free`.
    fn drop(&mut self) {
        let list = self.ptr.as_ptr();
        // SAFETY: `list` comes from a `NonNull`, so it is not null; `drop` runs at most once
        // per handle.
        unsafe { sys::mtmd_input_chunks_free(list) };
    }
}
impl MtmdInputChunks {
    /// Allocates a new chunk list with `mtmd_input_chunks_init`.
    ///
    /// # Panics
    ///
    /// Panics if the native allocator returns null.
    #[must_use]
    pub fn new() -> Self {
        let ptr = NonNull::new(unsafe { sys::mtmd_input_chunks_init() })
            .expect("mtmd_input_chunks_init returned null");
        Self { ptr }
    }

    /// Number of chunks, as reported by `mtmd_input_chunks_size`.
    #[must_use]
    pub fn len(&self) -> usize {
        let list = self.ptr.as_ptr();
        unsafe { sys::mtmd_input_chunks_size(list) }
    }

    /// `true` when the list holds no chunks.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Borrows the chunk at `idx`.
    ///
    /// Returns `None` when `idx` is out of range or the native lookup returns null.
    #[must_use]
    pub fn get(&self, idx: usize) -> Option<MtmdInputChunk<'_>> {
        if idx < self.len() {
            // SAFETY: `idx` was just checked against the list's reported size.
            let chunk = unsafe { sys::mtmd_input_chunks_get(self.ptr.as_ptr(), idx) };
            if !chunk.is_null() {
                return Some(MtmdInputChunk {
                    ptr: chunk,
                    _marker: std::marker::PhantomData,
                });
            }
        }
        None
    }

    /// Iterates over the chunks in index order, skipping any index whose lookup fails.
    pub fn iter(&self) -> impl Iterator<Item = MtmdInputChunk<'_>> {
        let count = self.len();
        (0..count).filter_map(move |idx| self.get(idx))
    }

    /// Returns the value of `mtmd_helper_get_n_tokens` for this list.
    #[must_use]
    pub fn n_tokens(&self) -> usize {
        let list = self.ptr.as_ptr();
        unsafe { sys::mtmd_helper_get_n_tokens(list) }
    }

    /// Returns the value of `mtmd_helper_get_n_pos` for this list.
    #[must_use]
    pub fn n_pos(&self) -> i32 {
        let list = self.ptr.as_ptr();
        unsafe { sys::mtmd_helper_get_n_pos(list) }
    }
}
impl Default for MtmdInputChunks {
fn default() -> Self {
Self::new()
}
}
/// Kind of content held by an input chunk.
///
/// Converted from the native `mtmd_input_chunk_type`; any native value that is neither the
/// image nor the audio constant becomes [`MtmdInputChunkType::Text`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MtmdInputChunkType {
/// Text chunk; also the fallback for unrecognised native values.
Text,
/// Chunk whose native type equals `MTMD_INPUT_CHUNK_TYPE_IMAGE`.
Image,
/// Chunk whose native type equals `MTMD_INPUT_CHUNK_TYPE_AUDIO`.
Audio,
}
impl From<sys::mtmd_input_chunk_type> for MtmdInputChunkType {
    /// Maps the native chunk type; anything that is not image or audio is treated as text.
    fn from(v: sys::mtmd_input_chunk_type) -> Self {
        if v == sys::MTMD_INPUT_CHUNK_TYPE_IMAGE {
            return Self::Image;
        }
        if v == sys::MTMD_INPUT_CHUNK_TYPE_AUDIO {
            return Self::Audio;
        }
        Self::Text
    }
}
/// Borrowed view of a single native `mtmd_input_chunk`.
///
/// The `'chunks` lifetime ties the view to the [`MtmdInputChunks`] list it was obtained from.
/// This type has no `Drop` impl in this file, so it does not free the pointer.
#[derive(Debug)]
pub struct MtmdInputChunk<'chunks> {
// Raw pointer to the native chunk.
ptr: *const sys::mtmd_input_chunk,
// Carries the borrow of the owning list without storing a reference.
_marker: std::marker::PhantomData<&'chunks MtmdInputChunks>,
}
impl<'chunks> MtmdInputChunk<'chunks> {
    /// Returns the chunk's kind, converted from `mtmd_input_chunk_get_type`.
    #[must_use]
    pub fn chunk_type(&self) -> MtmdInputChunkType {
        let native = unsafe { sys::mtmd_input_chunk_get_type(self.ptr) };
        native.into()
    }

    /// Returns the value of `mtmd_input_chunk_get_n_tokens`.
    #[must_use]
    pub fn n_tokens(&self) -> usize {
        let chunk = self.ptr;
        unsafe { sys::mtmd_input_chunk_get_n_tokens(chunk) }
    }

    /// Returns the value of `mtmd_input_chunk_get_n_pos`.
    #[must_use]
    pub fn n_pos(&self) -> i32 {
        let chunk = self.ptr;
        unsafe { sys::mtmd_input_chunk_get_n_pos(chunk) }
    }

    /// Borrows the token ids of a text chunk.
    ///
    /// Returns `None` for non-text chunks, and an empty slice when the native call yields a
    /// null pointer or a zero count.
    #[must_use]
    pub fn text_tokens(&self) -> Option<&[i32]> {
        if !matches!(self.chunk_type(), MtmdInputChunkType::Text) {
            return None;
        }
        // The native call reports the token count through this out-parameter.
        let mut count: usize = 0;
        let tokens = unsafe { sys::mtmd_input_chunk_get_tokens_text(self.ptr, &raw mut count) };
        let view: &[i32] = if tokens.is_null() || count == 0 {
            &[]
        } else {
            // SAFETY: `tokens` is non-null and `count` was written by the same native call.
            unsafe { slice::from_raw_parts(tokens, count) }
        };
        Some(view)
    }

    /// Borrows the image-token descriptor of a non-text chunk.
    ///
    /// Returns `None` for text chunks and whenever `mtmd_input_chunk_get_tokens_image` returns
    /// null.
    #[must_use]
    pub fn image_tokens(&self) -> Option<MtmdImageTokens<'chunks>> {
        if self.chunk_type() == MtmdInputChunkType::Text {
            return None;
        }
        let tokens = unsafe { sys::mtmd_input_chunk_get_tokens_image(self.ptr) };
        if tokens.is_null() {
            None
        } else {
            Some(MtmdImageTokens {
                ptr: tokens,
                _marker: std::marker::PhantomData,
            })
        }
    }

    /// Returns the chunk's id, or `None` if it is unset (null) or not valid UTF-8.
    #[must_use]
    pub fn id(&self) -> Option<&str> {
        let id_ptr = unsafe { sys::mtmd_input_chunk_get_id(self.ptr) };
        if id_ptr.is_null() {
            None
        } else {
            // SAFETY: `id_ptr` is non-null; NOTE(review) assumes it is NUL-terminated.
            unsafe { CStr::from_ptr(id_ptr) }.to_str().ok()
        }
    }

    /// Exposes the raw native chunk pointer.
    #[must_use]
    pub fn as_ptr(&self) -> *const sys::mtmd_input_chunk {
        self.ptr
    }
}
/// Borrowed view of a native `mtmd_image_tokens` object.
///
/// The `'chunks` lifetime ties the view to the [`MtmdInputChunks`] list it ultimately came
/// from. This type has no `Drop` impl in this file, so it does not free the pointer.
#[derive(Debug)]
pub struct MtmdImageTokens<'chunks> {
// Raw pointer to the native image-token descriptor.
ptr: *const sys::mtmd_image_tokens,
// Carries the borrow of the owning list without storing a reference.
_marker: std::marker::PhantomData<&'chunks MtmdInputChunks>,
}
impl MtmdImageTokens<'_> {
    /// Returns the value of `mtmd_image_tokens_get_n_tokens`.
    #[must_use]
    pub fn n_tokens(&self) -> usize {
        let tokens = self.ptr;
        unsafe { sys::mtmd_image_tokens_get_n_tokens(tokens) }
    }

    /// Returns the value of `mtmd_image_tokens_get_nx`.
    #[must_use]
    pub fn nx(&self) -> usize {
        let tokens = self.ptr;
        unsafe { sys::mtmd_image_tokens_get_nx(tokens) }
    }

    /// Returns the value of `mtmd_image_tokens_get_ny`.
    #[must_use]
    pub fn ny(&self) -> usize {
        let tokens = self.ptr;
        unsafe { sys::mtmd_image_tokens_get_ny(tokens) }
    }

    /// Returns the value of `mtmd_image_tokens_get_n_pos`.
    #[must_use]
    pub fn n_pos(&self) -> i32 {
        let tokens = self.ptr;
        unsafe { sys::mtmd_image_tokens_get_n_pos(tokens) }
    }

    /// Returns the id, or `None` if it is unset (null) or not valid UTF-8.
    #[must_use]
    pub fn id(&self) -> Option<&str> {
        let id_ptr = unsafe { sys::mtmd_image_tokens_get_id(self.ptr) };
        if id_ptr.is_null() {
            None
        } else {
            // SAFETY: `id_ptr` is non-null; NOTE(review) assumes it is NUL-terminated.
            unsafe { CStr::from_ptr(id_ptr) }.to_str().ok()
        }
    }
}
use crate::context::LlamaContext;
impl LlamaContext<'_> {
    /// Exposes the raw `llama_context` pointer, e.g. for the `lctx` argument of the mtmd eval
    /// helpers. Ownership stays with `self`.
    #[must_use]
    pub fn as_ptr(&self) -> *mut sys::llama_context {
        let ctx = &self.context;
        ctx.as_ptr()
    }
}