use crate::session::{KgliteSession, SessionState};
use crate::status::KgliteStatusCode;
use kglite::api::Embedder;
use std::sync::Arc;
#[cfg(feature = "fastembed")]
use crate::strings::alloc_c_string;
#[cfg(feature = "fastembed")]
use std::ffi::{c_char, CStr};
/// Opaque handle type for an embedder, as seen from C.
///
/// C callers only ever hold pointers to this type. The allocation behind such a pointer
/// is really an `EmbedderState` (this module casts between the two when creating,
/// borrowing and freeing handles).
///
/// The private zero-sized array keeps the type unconstructible outside this module. The
/// `PhantomData` marker makes it neither `Send`/`Sync` (raw pointer) nor `Unpin`
/// (`PhantomPinned`), so Rust code cannot make unwarranted assumptions about the pointee.
#[repr(C)]
pub struct KgliteEmbedder {
    _opaque: [u8; 0],
    _marker: core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
}
/// Rust-side state that lives behind every `KgliteEmbedder` handle.
pub(crate) struct EmbedderState {
    /// Reference-counted embedder; `kglite_session_set_embedder` clones this `Arc` into a
    /// session, so sessions keep the embedder alive independently of the handle.
    pub(crate) inner: Arc<dyn Embedder>,
}
impl EmbedderState {
    /// Boxes `inner` and leaks it as an opaque handle for C callers.
    ///
    /// Ownership of the allocation moves to the caller, who must eventually release it
    /// with `kglite_embedder_free`.
    // Within this file the only caller is the `fastembed`-gated constructor, so the
    // function may be unused when that feature is off; only silence `dead_code` then.
    #[cfg_attr(not(feature = "fastembed"), allow(dead_code))]
    pub(crate) fn into_handle(inner: Arc<dyn Embedder>) -> *mut KgliteEmbedder {
        let boxed = Box::new(EmbedderState { inner });
        Box::into_raw(boxed).cast::<KgliteEmbedder>()
    }

    /// Borrows the state behind a handle produced by [`EmbedderState::into_handle`].
    ///
    /// # Safety
    ///
    /// `handle` must be non-null, must have been returned by `into_handle`, must not have
    /// been freed, and must stay alive for the whole caller-chosen lifetime `'a`.
    pub(crate) unsafe fn from_handle<'a>(handle: *const KgliteEmbedder) -> &'a EmbedderState {
        // SAFETY: the caller guarantees `handle` points at a live `EmbedderState`
        // allocated by `into_handle`.
        unsafe { &*handle.cast::<EmbedderState>() }
    }

    /// Reclaims and drops the state behind `handle`; a null handle is a no-op.
    ///
    /// Dropping the state releases this handle's reference to the embedder only; any
    /// `Arc` clones held elsewhere keep the embedder alive.
    ///
    /// # Safety
    ///
    /// `handle` must be null or a pointer returned by `into_handle` that has not already
    /// been freed. It must not be used after this call.
    unsafe fn free_handle(handle: *mut KgliteEmbedder) {
        if handle.is_null() {
            return;
        }
        // SAFETY: `handle` is non-null and, per the caller's contract, came from
        // `Box::into_raw` in `into_handle` and has not been freed yet.
        drop(unsafe { Box::from_raw(handle.cast::<EmbedderState>()) });
    }
}
/// Releases an embedder handle.
///
/// Passing null is a no-op. Sessions configured with this embedder through
/// `kglite_session_set_embedder` hold their own reference-counted copy, so they keep
/// working after the handle is freed.
///
/// # Safety
///
/// `embedder` must be null or a handle created by this library that has not already been
/// freed. The handle must not be used (or freed again) after this call.
#[no_mangle]
pub unsafe extern "C" fn kglite_embedder_free(embedder: *mut KgliteEmbedder) {
    // SAFETY: the requirements of `free_handle` are exactly this function's contract.
    unsafe { EmbedderState::free_handle(embedder) };
}
/// Attaches `embedder` to `session`, replacing any embedder the session had before.
///
/// The session takes its own reference to the embedder (an `Arc` clone). The caller
/// therefore still owns `embedder` and must release it with `kglite_embedder_free` when
/// done; freeing it does not detach it from the session.
///
/// Returns `NullPointer` if either argument is null, otherwise `Ok`.
///
/// # Safety
///
/// - `session` must be null or a live session handle that is not accessed elsewhere for the
///   duration of this call (it is borrowed mutably).
/// - `embedder` must be null or a live embedder handle that has not been freed.
#[no_mangle]
pub unsafe extern "C" fn kglite_session_set_embedder(
    session: *mut KgliteSession,
    embedder: *const KgliteEmbedder,
) -> KgliteStatusCode {
    if session.is_null() || embedder.is_null() {
        return KgliteStatusCode::NullPointer;
    }
    // SAFETY: `session` is non-null and the caller guarantees exclusive access to it.
    let session_state = unsafe { SessionState::from_handle_mut(session) };
    // SAFETY: `embedder` is non-null and the caller guarantees it has not been freed.
    let embedder_state = unsafe { EmbedderState::from_handle(embedder) };
    // Clone the Arc so the session's embedder outlives the caller's handle.
    session_state.embedder = Some(Arc::clone(&embedder_state.inner));
    KgliteStatusCode::Ok
}
/// Creates an embedder backed by the FastEmbed model named `model_name`.
///
/// Every writable out-parameter is reset to null before any validation. A caller can
/// therefore always inspect (and conditionally free) `*out_embedder` and `*out_error_msg`
/// after the call, whatever status is returned.
///
/// On success, `*out_embedder` receives a new handle that must be released with
/// `kglite_embedder_free`.
///
/// # Errors
///
/// - `NullPointer` if `model_name` or `out_embedder` is null.
/// - `InvalidUtf8` if `model_name` is not valid UTF-8.
/// - `InvalidArgument` if the FastEmbed adapter cannot be created. In that case, if
///   `out_error_msg` is non-null, it receives a message allocated by `alloc_c_string`.
///   NOTE(review): release it with the crate's matching string-free function — confirm
///   the name.
///
/// # Safety
///
/// - `model_name` must be null or point to a NUL-terminated string valid for reads.
/// - `out_embedder` must be null or valid for writing one pointer.
/// - `out_error_msg` must be null or valid for writing one pointer.
#[cfg(feature = "fastembed")]
#[no_mangle]
pub unsafe extern "C" fn kglite_embedder_fastembed_new(
    model_name: *const c_char,
    out_embedder: *mut *mut KgliteEmbedder,
    out_error_msg: *mut *const c_char,
) -> KgliteStatusCode {
    // Initialise the out-parameters first so no early return leaves a stale or
    // uninitialised pointer for the caller to read or free.
    if !out_embedder.is_null() {
        // SAFETY: a non-null `out_embedder` is valid for writes per the contract.
        unsafe { *out_embedder = std::ptr::null_mut() };
    }
    if !out_error_msg.is_null() {
        // SAFETY: a non-null `out_error_msg` is valid for writes per the contract.
        unsafe { *out_error_msg = std::ptr::null() };
    }
    if model_name.is_null() || out_embedder.is_null() {
        return KgliteStatusCode::NullPointer;
    }
    // SAFETY: `model_name` is non-null and the caller guarantees NUL termination.
    let model_str = match unsafe { CStr::from_ptr(model_name) }.to_str() {
        Ok(s) => s,
        Err(_) => return KgliteStatusCode::InvalidUtf8,
    };
    match kglite::api::FastEmbedAdapter::new(model_str) {
        Ok(adapter) => {
            let arc: Arc<dyn Embedder> = Arc::new(adapter);
            // SAFETY: `out_embedder` was checked non-null above.
            unsafe { *out_embedder = EmbedderState::into_handle(arc) };
            KgliteStatusCode::Ok
        }
        Err(msg) => {
            if !out_error_msg.is_null() {
                // SAFETY: checked non-null; valid for writes per the contract.
                unsafe { *out_error_msg = alloc_c_string(&msg) };
            }
            KgliteStatusCode::InvalidArgument
        }
    }
}