use std::ffi::{CStr, CString};
use std::ptr::NonNull;
use std::slice;
use crate::context::LlamaContext;
use crate::model::LlamaModel;
use crate::token::LlamaToken;
/// Kind of payload carried by an [`MtmdInputChunk`].
///
/// Each discriminant is the value of the matching native `MTMD_INPUT_CHUNK_TYPE_*` constant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum MtmdInputChunkType {
/// Text chunk; only these expose tokens via [`MtmdInputChunk::text_tokens`].
Text = fellhorn_llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_TEXT as _,
/// Image chunk.
Image = fellhorn_llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_IMAGE as _,
/// Audio chunk.
Audio = fellhorn_llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_AUDIO as _,
}
impl From<fellhorn_llama_cpp_sys_2::mtmd_input_chunk_type> for MtmdInputChunkType {
fn from(chunk_type: fellhorn_llama_cpp_sys_2::mtmd_input_chunk_type) -> Self {
match chunk_type {
fellhorn_llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_TEXT => MtmdInputChunkType::Text,
fellhorn_llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_IMAGE => MtmdInputChunkType::Image,
fellhorn_llama_cpp_sys_2::MTMD_INPUT_CHUNK_TYPE_AUDIO => MtmdInputChunkType::Audio,
_ => panic!("Unknown MTMD input chunk type: {chunk_type}"),
}
}
}
/// Safe, owned mirror of the native `mtmd_context_params`.
///
/// When converted to the native struct, any native field not listed here keeps the value
/// returned by `mtmd_context_params_default()`.
#[derive(Debug, Clone)]
pub struct MtmdContextParams {
/// Forwarded to `mtmd_context_params::use_gpu`.
pub use_gpu: bool,
/// Forwarded to `mtmd_context_params::print_timings`.
pub print_timings: bool,
/// Forwarded to `mtmd_context_params::n_threads`.
pub n_threads: i32,
/// Forwarded to `mtmd_context_params::media_marker`.
/// Presumably the placeholder that marks where each media item sits in the prompt text
/// (see [`mtmd_default_marker`]) — confirm against mtmd.h.
pub media_marker: CString,
}
impl Default for MtmdContextParams {
    /// Builds the wrapper from the native defaults returned by `mtmd_context_params_default`.
    fn default() -> Self {
        // SAFETY: `mtmd_context_params_default` takes no arguments and returns a struct by value.
        let native = unsafe { fellhorn_llama_cpp_sys_2::mtmd_context_params_default() };
        Self::from(native)
    }
}
impl From<&MtmdContextParams> for fellhorn_llama_cpp_sys_2::mtmd_context_params {
    /// Produces a native parameter struct: the wrapper's fields override the native defaults.
    ///
    /// The returned `media_marker` is a raw pointer into `params.media_marker`. Keep `params`
    /// alive and unmodified for as long as the native struct is in use.
    fn from(params: &MtmdContextParams) -> Self {
        // SAFETY: `mtmd_context_params_default` takes no arguments and returns a struct by value.
        let defaults = unsafe { fellhorn_llama_cpp_sys_2::mtmd_context_params_default() };
        Self {
            use_gpu: params.use_gpu,
            print_timings: params.print_timings,
            n_threads: params.n_threads,
            media_marker: params.media_marker.as_ptr(),
            ..defaults
        }
    }
}
impl From<fellhorn_llama_cpp_sys_2::mtmd_context_params> for MtmdContextParams {
    /// Copies a native parameter struct into the owned wrapper.
    ///
    /// `media_marker` is copied into an owned `CString`. A null marker is replaced by
    /// `mtmd_default_marker()` instead of being dereferenced. This matters because this safe
    /// `From` accepts any native struct, including a zero-initialised one.
    fn from(params: fellhorn_llama_cpp_sys_2::mtmd_context_params) -> Self {
        let marker_ptr = if params.media_marker.is_null() {
            // SAFETY: `mtmd_default_marker` takes no arguments and returns a C string pointer.
            unsafe { fellhorn_llama_cpp_sys_2::mtmd_default_marker() }
        } else {
            params.media_marker
        };
        // SAFETY: `marker_ptr` is non-null and is assumed to be a NUL-terminated string that
        // stays valid for the duration of this copy.
        let media_marker = unsafe { CStr::from_ptr(marker_ptr) }.to_owned();
        Self {
            use_gpu: params.use_gpu,
            print_timings: params.print_timings,
            n_threads: params.n_threads,
            media_marker,
        }
    }
}
/// Prompt text handed to [`MtmdContext::tokenize`].
#[derive(Debug, Clone)]
pub struct MtmdInputText {
/// Prompt text. It must not contain interior NUL bytes, otherwise `tokenize` fails with a
/// CString error.
pub text: String,
/// Forwarded to `mtmd_input_text::add_special`.
pub add_special: bool,
/// Forwarded to `mtmd_input_text::parse_special`.
pub parse_special: bool,
}
/// Owned handle to a native `mtmd_context`, created by [`MtmdContext::init_from_file`] and
/// released with `mtmd_free` on drop.
///
/// NOTE(review): the native context is created from a `LlamaModel` pointer, but no lifetime ties
/// this handle to that model. Confirm that callers keep the model alive.
#[derive(Debug)]
pub struct MtmdContext {
/// Native context pointer owned by this handle.
pub(crate) context: NonNull<fellhorn_llama_cpp_sys_2::mtmd_context>,
}
impl MtmdContext {
pub fn init_from_file(
mmproj_path: &str,
text_model: &LlamaModel,
params: &MtmdContextParams,
) -> Result<Self, MtmdInitError> {
let path_cstr = CString::new(mmproj_path)?;
let ctx_params = fellhorn_llama_cpp_sys_2::mtmd_context_params::from(params);
let context = unsafe {
fellhorn_llama_cpp_sys_2::mtmd_init_from_file(
path_cstr.as_ptr(),
text_model.model.as_ptr(),
ctx_params,
)
};
let context = NonNull::new(context).ok_or(MtmdInitError::NullResult)?;
Ok(Self { context })
}
#[must_use]
pub fn decode_use_non_causal(&self) -> bool {
unsafe { fellhorn_llama_cpp_sys_2::mtmd_decode_use_non_causal(self.context.as_ptr()) }
}
#[must_use]
pub fn decode_use_mrope(&self) -> bool {
unsafe { fellhorn_llama_cpp_sys_2::mtmd_decode_use_mrope(self.context.as_ptr()) }
}
#[must_use]
pub fn support_vision(&self) -> bool {
unsafe { fellhorn_llama_cpp_sys_2::mtmd_support_vision(self.context.as_ptr()) }
}
#[must_use]
pub fn support_audio(&self) -> bool {
unsafe { fellhorn_llama_cpp_sys_2::mtmd_support_audio(self.context.as_ptr()) }
}
#[must_use]
pub fn get_audio_bitrate(&self) -> Option<u32> {
let rate = unsafe { fellhorn_llama_cpp_sys_2::mtmd_get_audio_bitrate(self.context.as_ptr()) };
(rate > 0).then_some(rate.unsigned_abs())
}
pub fn tokenize(
&self,
text: MtmdInputText,
bitmaps: &[&MtmdBitmap],
) -> Result<MtmdInputChunks, MtmdTokenizeError> {
let chunks = MtmdInputChunks::new();
let text_cstring = CString::new(text.text)?;
let input_text = fellhorn_llama_cpp_sys_2::mtmd_input_text {
text: text_cstring.as_ptr(),
add_special: text.add_special,
parse_special: text.parse_special,
};
let bitmap_ptrs: Vec<*const fellhorn_llama_cpp_sys_2::mtmd_bitmap> = bitmaps
.iter()
.map(|b| b.bitmap.as_ptr().cast_const())
.collect();
let result = unsafe {
fellhorn_llama_cpp_sys_2::mtmd_tokenize(
self.context.as_ptr(),
chunks.chunks.as_ptr(),
&raw const input_text,
bitmap_ptrs.as_ptr().cast_mut(),
bitmaps.len(),
)
};
match result {
0 => Ok(chunks),
1 => Err(MtmdTokenizeError::BitmapCountMismatch),
2 => Err(MtmdTokenizeError::ImagePreprocessingError),
_ => Err(MtmdTokenizeError::UnknownError(result)),
}
}
pub fn encode_chunk(&self, chunk: &MtmdInputChunk) -> Result<(), MtmdEncodeError> {
let result = unsafe {
fellhorn_llama_cpp_sys_2::mtmd_encode_chunk(self.context.as_ptr(), chunk.chunk.as_ptr())
};
if result == 0 {
Ok(())
} else {
Err(MtmdEncodeError::EncodeFailure(result))
}
}
}
impl Drop for MtmdContext {
    /// Releases the native context with `mtmd_free`.
    fn drop(&mut self) {
        let raw = self.context.as_ptr();
        // SAFETY: `raw` came from `mtmd_init_from_file` and this handle never uses it again.
        unsafe { fellhorn_llama_cpp_sys_2::mtmd_free(raw) };
    }
}
#[derive(Debug, Clone)]
pub struct MtmdBitmap {
pub(crate) bitmap: NonNull<fellhorn_llama_cpp_sys_2::mtmd_bitmap>,
}
impl MtmdBitmap {
    /// Creates an image bitmap from raw pixel bytes.
    ///
    /// `data` must contain exactly `nx * ny * 3` bytes (3 bytes per pixel; assumed RGB order).
    ///
    /// # Errors
    ///
    /// - [`MtmdBitmapError::InvalidDataSize`] if `data.len()` differs from `nx * ny * 3`,
    ///   including when that product does not fit in `usize`.
    /// - [`MtmdBitmapError::NullResult`] if `mtmd_bitmap_init` returns null.
    pub fn from_image_data(nx: u32, ny: u32, data: &[u8]) -> Result<Self, MtmdBitmapError> {
        // Compute in `usize` with overflow checks. A plain `u32` product can wrap in release
        // builds, which would let an undersized buffer pass while the native side presumably
        // copies the full `nx * ny * 3` bytes.
        let expected_len = usize::try_from(nx)
            .ok()
            .zip(usize::try_from(ny).ok())
            .and_then(|(width, height)| width.checked_mul(height))
            .and_then(|pixels| pixels.checked_mul(3));
        if expected_len != Some(data.len()) {
            return Err(MtmdBitmapError::InvalidDataSize);
        }
        // SAFETY: `data` holds exactly `nx * ny * 3` bytes (checked above) for the whole call.
        let bitmap = unsafe { fellhorn_llama_cpp_sys_2::mtmd_bitmap_init(nx, ny, data.as_ptr()) };
        let bitmap = NonNull::new(bitmap).ok_or(MtmdBitmapError::NullResult)?;
        Ok(Self { bitmap })
    }

    /// Creates an audio bitmap from `f32` samples.
    ///
    /// # Errors
    ///
    /// Returns [`MtmdBitmapError::NullResult`] if `mtmd_bitmap_init_from_audio` returns null.
    pub fn from_audio_data(data: &[f32]) -> Result<Self, MtmdBitmapError> {
        // SAFETY: `data` is a valid buffer of `data.len()` floats for the whole call.
        let bitmap =
            unsafe { fellhorn_llama_cpp_sys_2::mtmd_bitmap_init_from_audio(data.len(), data.as_ptr()) };
        let bitmap = NonNull::new(bitmap).ok_or(MtmdBitmapError::NullResult)?;
        Ok(Self { bitmap })
    }

    /// Loads a bitmap from the file at `path` using `mtmd_helper_bitmap_init_from_file`.
    ///
    /// # Errors
    ///
    /// - [`MtmdBitmapError::CStringError`] if `path` contains an interior NUL byte.
    /// - [`MtmdBitmapError::NullResult`] if the helper returns null (for example, an unreadable
    ///   or unsupported file — confirm against mtmd-helper).
    pub fn from_file(ctx: &MtmdContext, path: &str) -> Result<Self, MtmdBitmapError> {
        let path_cstr = CString::new(path)?;
        // SAFETY: the context is live and `path_cstr` outlives the call.
        let bitmap = unsafe {
            fellhorn_llama_cpp_sys_2::mtmd_helper_bitmap_init_from_file(
                ctx.context.as_ptr(),
                path_cstr.as_ptr(),
            )
        };
        let bitmap = NonNull::new(bitmap).ok_or(MtmdBitmapError::NullResult)?;
        Ok(Self { bitmap })
    }

    /// Decodes a bitmap from an in-memory encoded buffer using `mtmd_helper_bitmap_init_from_buf`.
    ///
    /// # Errors
    ///
    /// Returns [`MtmdBitmapError::NullResult`] if the helper returns null.
    pub fn from_buffer(ctx: &MtmdContext, data: &[u8]) -> Result<Self, MtmdBitmapError> {
        // SAFETY: the context is live and `data` is valid for `data.len()` bytes during the call.
        let bitmap = unsafe {
            fellhorn_llama_cpp_sys_2::mtmd_helper_bitmap_init_from_buf(
                ctx.context.as_ptr(),
                data.as_ptr(),
                data.len(),
            )
        };
        let bitmap = NonNull::new(bitmap).ok_or(MtmdBitmapError::NullResult)?;
        Ok(Self { bitmap })
    }

    /// Width reported by `mtmd_bitmap_get_nx`.
    #[must_use]
    pub fn nx(&self) -> u32 {
        // SAFETY: the bitmap pointer is live for the lifetime of `self`.
        unsafe { fellhorn_llama_cpp_sys_2::mtmd_bitmap_get_nx(self.bitmap.as_ptr()) }
    }

    /// Height reported by `mtmd_bitmap_get_ny`.
    #[must_use]
    pub fn ny(&self) -> u32 {
        // SAFETY: the bitmap pointer is live for the lifetime of `self`.
        unsafe { fellhorn_llama_cpp_sys_2::mtmd_bitmap_get_ny(self.bitmap.as_ptr()) }
    }

    /// Raw bytes of the bitmap (`mtmd_bitmap_get_n_bytes` bytes from `mtmd_bitmap_get_data`).
    ///
    /// Returns an empty slice when the native buffer is null or empty. `slice::from_raw_parts`
    /// requires a non-null pointer even for length zero, and an empty native buffer (for example,
    /// audio created from no samples) may report null.
    #[must_use]
    pub fn data(&self) -> &[u8] {
        // SAFETY: the bitmap pointer is live for the lifetime of `self`.
        let ptr = unsafe { fellhorn_llama_cpp_sys_2::mtmd_bitmap_get_data(self.bitmap.as_ptr()) };
        // SAFETY: as above.
        let len = unsafe { fellhorn_llama_cpp_sys_2::mtmd_bitmap_get_n_bytes(self.bitmap.as_ptr()) };
        if ptr.is_null() || len == 0 {
            return &[];
        }
        // SAFETY: `ptr` is non-null and the native bitmap owns `len` bytes there; the slice is
        // tied to the borrow of `self`, which keeps the bitmap alive.
        unsafe { slice::from_raw_parts(ptr, len) }
    }

    /// Whether `mtmd_bitmap_is_audio` reports this as an audio bitmap.
    #[must_use]
    pub fn is_audio(&self) -> bool {
        // SAFETY: the bitmap pointer is live for the lifetime of `self`.
        unsafe { fellhorn_llama_cpp_sys_2::mtmd_bitmap_is_audio(self.bitmap.as_ptr()) }
    }

    /// Returns the bitmap id, lossily converted to UTF-8, or `None` if the native id is null.
    #[must_use]
    pub fn id(&self) -> Option<String> {
        // SAFETY: the bitmap pointer is live for the lifetime of `self`.
        let ptr = unsafe { fellhorn_llama_cpp_sys_2::mtmd_bitmap_get_id(self.bitmap.as_ptr()) };
        if ptr.is_null() {
            return None;
        }
        // SAFETY: `ptr` is a non-null C string owned by the live bitmap.
        Some(unsafe { CStr::from_ptr(ptr) }.to_string_lossy().into_owned())
    }

    /// Sets the bitmap id.
    ///
    /// `id_cstr` is dropped when this returns, so this relies on `mtmd_bitmap_set_id` copying
    /// the string.
    ///
    /// # Errors
    ///
    /// Returns [`std::ffi::NulError`] if `id` contains an interior NUL byte.
    pub fn set_id(&self, id: &str) -> Result<(), std::ffi::NulError> {
        let id_cstr = CString::new(id)?;
        // SAFETY: the bitmap pointer is live and `id_cstr` outlives the call.
        unsafe {
            fellhorn_llama_cpp_sys_2::mtmd_bitmap_set_id(self.bitmap.as_ptr(), id_cstr.as_ptr());
        }
        Ok(())
    }
}
impl Drop for MtmdBitmap {
    /// Returns the native bitmap to `mtmd_bitmap_free`.
    fn drop(&mut self) {
        let raw = self.bitmap.as_ptr();
        // SAFETY: `raw` came from a native bitmap constructor and this handle never uses it again.
        unsafe { fellhorn_llama_cpp_sys_2::mtmd_bitmap_free(raw) };
    }
}
/// Owned native list of input chunks, as filled by [`MtmdContext::tokenize`].
///
/// Released with `mtmd_input_chunks_free` on drop.
#[derive(Debug)]
pub struct MtmdInputChunks {
/// Native chunk list owned by this handle.
pub(crate) chunks: NonNull<fellhorn_llama_cpp_sys_2::mtmd_input_chunks>,
}
impl Default for MtmdInputChunks {
fn default() -> Self {
Self::new()
}
}
impl MtmdInputChunks {
#[must_use]
pub fn new() -> Self {
let chunks = unsafe { fellhorn_llama_cpp_sys_2::mtmd_input_chunks_init() };
let chunks = NonNull::new(chunks).unwrap();
Self { chunks }
}
#[must_use]
pub fn len(&self) -> usize {
unsafe { fellhorn_llama_cpp_sys_2::mtmd_input_chunks_size(self.chunks.as_ptr()) }
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[must_use]
pub fn get(&self, index: usize) -> Option<MtmdInputChunk> {
if index >= self.len() {
return None;
}
let chunk_ptr =
unsafe { fellhorn_llama_cpp_sys_2::mtmd_input_chunks_get(self.chunks.as_ptr(), index) };
NonNull::new(chunk_ptr.cast_mut()).map(|ptr| MtmdInputChunk {
chunk: ptr,
owned: false,
})
}
#[must_use]
pub fn total_tokens(&self) -> usize {
unsafe { fellhorn_llama_cpp_sys_2::mtmd_helper_get_n_tokens(self.chunks.as_ptr()) }
}
#[must_use]
pub fn total_positions(&self) -> i32 {
unsafe { fellhorn_llama_cpp_sys_2::mtmd_helper_get_n_pos(self.chunks.as_ptr()) }
}
pub fn eval_chunks(
&self,
mtmd_ctx: &MtmdContext,
llama_ctx: &LlamaContext,
n_past: fellhorn_llama_cpp_sys_2::llama_pos,
seq_id: fellhorn_llama_cpp_sys_2::llama_seq_id,
n_batch: i32,
logits_last: bool,
) -> Result<fellhorn_llama_cpp_sys_2::llama_pos, MtmdEvalError> {
let mut new_n_past: fellhorn_llama_cpp_sys_2::llama_pos = 0;
let result = unsafe {
fellhorn_llama_cpp_sys_2::mtmd_helper_eval_chunks(
mtmd_ctx.context.as_ptr(),
llama_ctx.context.as_ptr(),
self.chunks.as_ptr(),
n_past,
seq_id,
n_batch,
logits_last,
&raw mut new_n_past,
)
};
if result == 0 {
Ok(new_n_past)
} else {
Err(MtmdEvalError::EvalFailure(result))
}
}
}
impl Drop for MtmdInputChunks {
    /// Releases the native list with `mtmd_input_chunks_free`.
    fn drop(&mut self) {
        let raw = self.chunks.as_ptr();
        // SAFETY: `raw` came from `mtmd_input_chunks_init` and this handle never uses it again.
        unsafe { fellhorn_llama_cpp_sys_2::mtmd_input_chunks_free(raw) };
    }
}
/// A single text, image, or audio chunk.
///
/// There are two kinds of handle:
/// - Non-owning (`owned == false`), as returned by [`MtmdInputChunks::get`].
/// - Owning (`owned == true`), as returned by [`MtmdInputChunk::copy`].
///
/// Only owning handles free the native chunk on drop.
/// NOTE(review): non-owning handles carry no lifetime and presumably point into their parent
/// list, so they should not outlive it.
#[derive(Debug)]
pub struct MtmdInputChunk {
/// Native chunk pointer.
pub(crate) chunk: NonNull<fellhorn_llama_cpp_sys_2::mtmd_input_chunk>,
/// Whether this handle frees `chunk` with `mtmd_input_chunk_free` on drop.
owned: bool,
}
impl MtmdInputChunk {
    /// Returns the kind of payload stored in this chunk.
    ///
    /// # Panics
    ///
    /// Panics if the native type value is not a known [`MtmdInputChunkType`].
    #[must_use]
    pub fn chunk_type(&self) -> MtmdInputChunkType {
        let raw = self.chunk.as_ptr();
        // SAFETY: `raw` points to a live native chunk.
        let native = unsafe { fellhorn_llama_cpp_sys_2::mtmd_input_chunk_get_type(raw) };
        native.into()
    }

    /// Returns the tokens of a text chunk.
    ///
    /// Returns `None` in three cases: a non-text chunk, a null native buffer, or zero tokens.
    #[must_use]
    pub fn text_tokens(&self) -> Option<&[LlamaToken]> {
        if !matches!(self.chunk_type(), MtmdInputChunkType::Text) {
            return None;
        }
        let mut count: usize = 0;
        // SAFETY: the chunk is live and `count` is a valid out-parameter.
        let raw_tokens = unsafe {
            fellhorn_llama_cpp_sys_2::mtmd_input_chunk_get_tokens_text(
                self.chunk.as_ptr(),
                &raw mut count,
            )
        };
        let has_tokens = !raw_tokens.is_null() && count != 0;
        // SAFETY: the native buffer holds `count` tokens owned by the chunk. `LlamaToken` is
        // assumed layout-compatible with the native token type (TODO confirm repr).
        has_tokens.then(|| unsafe { slice::from_raw_parts(raw_tokens.cast::<LlamaToken>(), count) })
    }

    /// Token count reported by `mtmd_input_chunk_get_n_tokens`.
    #[must_use]
    pub fn n_tokens(&self) -> usize {
        let raw = self.chunk.as_ptr();
        // SAFETY: `raw` points to a live native chunk.
        unsafe { fellhorn_llama_cpp_sys_2::mtmd_input_chunk_get_n_tokens(raw) }
    }

    /// Position count reported by `mtmd_input_chunk_get_n_pos`.
    #[must_use]
    pub fn n_positions(&self) -> i32 {
        let raw = self.chunk.as_ptr();
        // SAFETY: `raw` points to a live native chunk.
        unsafe { fellhorn_llama_cpp_sys_2::mtmd_input_chunk_get_n_pos(raw) }
    }

    /// Returns the chunk id, lossily converted to UTF-8, or `None` if the native id is null.
    #[must_use]
    pub fn id(&self) -> Option<String> {
        // SAFETY: the chunk is live.
        let raw_id = unsafe { fellhorn_llama_cpp_sys_2::mtmd_input_chunk_get_id(self.chunk.as_ptr()) };
        if raw_id.is_null() {
            return None;
        }
        // SAFETY: `raw_id` is a non-null C string owned by the live chunk.
        let id = unsafe { CStr::from_ptr(raw_id) };
        Some(id.to_string_lossy().into_owned())
    }

    /// Creates an independent, owning copy of this chunk via `mtmd_input_chunk_copy`.
    ///
    /// # Errors
    ///
    /// Returns [`MtmdInputChunkError::NullResult`] if the native copy returns null.
    pub fn copy(&self) -> Result<Self, MtmdInputChunkError> {
        // SAFETY: the chunk is live for the duration of the call.
        let duplicate = unsafe { fellhorn_llama_cpp_sys_2::mtmd_input_chunk_copy(self.chunk.as_ptr()) };
        match NonNull::new(duplicate) {
            Some(chunk) => Ok(Self { chunk, owned: true }),
            None => Err(MtmdInputChunkError::NullResult),
        }
    }
}
impl Drop for MtmdInputChunk {
    /// Frees the native chunk only when this handle owns it; non-owning handles are not freed
    /// here.
    fn drop(&mut self) {
        if !self.owned {
            return;
        }
        let raw = self.chunk.as_ptr();
        // SAFETY: owning handles come from `mtmd_input_chunk_copy` and are freed only here.
        unsafe { fellhorn_llama_cpp_sys_2::mtmd_input_chunk_free(raw) };
    }
}
/// Returns the default media marker string reported by the native `mtmd_default_marker`.
///
/// Falls back to `"<__media__>"` if the native string is not valid UTF-8.
#[must_use]
pub fn mtmd_default_marker() -> &'static str {
    // SAFETY: the native function returns a NUL-terminated string that is treated as living for
    // the whole program (assumed to be a static literal upstream — confirm against mtmd.cpp).
    let marker: &'static CStr =
        unsafe { CStr::from_ptr(fellhorn_llama_cpp_sys_2::mtmd_default_marker()) };
    marker.to_str().unwrap_or("<__media__>")
}
/// Errors returned by [`MtmdContext::init_from_file`].
#[derive(thiserror::Error, Debug)]
pub enum MtmdInitError {
/// The projector path contained an interior NUL byte.
#[error("Failed to create CString: {0}")]
CStringError(#[from] std::ffi::NulError),
/// `mtmd_init_from_file` returned null.
#[error("MTMD context initialization returned null")]
NullResult,
}
/// Errors returned by the [`MtmdBitmap`] constructors.
#[derive(thiserror::Error, Debug)]
pub enum MtmdBitmapError {
/// A path argument contained an interior NUL byte.
#[error("Failed to create CString: {0}")]
CStringError(#[from] std::ffi::NulError),
/// The pixel buffer passed to `from_image_data` was not `nx * ny * 3` bytes long.
#[error("Invalid data size for bitmap")]
InvalidDataSize,
/// The native bitmap constructor returned null.
#[error("Bitmap creation returned null")]
NullResult,
}
/// Error for a failed native chunk-list allocation.
#[derive(thiserror::Error, Debug)]
pub enum MtmdInputChunksError {
/// `mtmd_input_chunks_init` returned null.
#[error("Input chunks creation returned null")]
NullResult,
}
/// Errors returned by [`MtmdInputChunk::copy`].
#[derive(thiserror::Error, Debug)]
pub enum MtmdInputChunkError {
/// `mtmd_input_chunk_copy` returned null.
#[error("Input chunk operation returned null")]
NullResult,
}
/// Errors returned by [`MtmdContext::tokenize`].
#[derive(thiserror::Error, Debug)]
pub enum MtmdTokenizeError {
/// Native status `1`.
#[error("Number of bitmaps does not match number of markers")]
BitmapCountMismatch,
/// Native status `2`.
#[error("Image preprocessing error")]
ImagePreprocessingError,
/// The prompt text contained an interior NUL byte.
#[error("Failed to create CString from text: {0}")]
CStringError(#[from] std::ffi::NulError),
/// Any other non-zero native status, carried verbatim.
#[error("Unknown error: {0}")]
UnknownError(i32),
}
/// Errors returned by [`MtmdContext::encode_chunk`].
#[derive(thiserror::Error, Debug)]
pub enum MtmdEncodeError {
/// Non-zero status returned by `mtmd_encode_chunk`.
#[error("Encode failed with code: {0}")]
EncodeFailure(i32),
}
/// Errors returned by [`MtmdInputChunks::eval_chunks`].
#[derive(thiserror::Error, Debug)]
pub enum MtmdEvalError {
/// Non-zero status returned by `mtmd_helper_eval_chunks`.
#[error("Eval failed with code: {0}")]
EvalFailure(i32),
}