use std::ffi::{CString, c_char, c_void};
use std::sync::Arc;
use async_trait::async_trait;
use blazen_uniffi::errors::{BlazenError as InnerError, BlazenResult};
use blazen_uniffi::llm::{CompletionModel as InnerCompletionModel, TokenUsage as InnerTokenUsage};
use blazen_uniffi::streaming::{
CompletionStreamSink, StreamChunk as InnerStreamChunk, complete_streaming,
complete_streaming_blocking,
};
use crate::error::BlazenError;
use crate::future::BlazenFuture;
use crate::llm::BlazenCompletionModel;
use crate::llm_records::{BlazenCompletionRequest, BlazenTokenUsage};
use crate::streaming_records::BlazenStreamChunk;
/// Reports `e` to a C caller through the `out_err` out-parameter.
///
/// When `out_err` is non-null, `e` is converted into an FFI `BlazenError`,
/// boxed, and its pointer is stored in `*out_err`. When `out_err` is null, the
/// error is simply dropped. Either way the C failure status `-1` is returned.
///
/// # Safety
///
/// `out_err` must be null or valid for writing a single `*mut BlazenError`.
unsafe fn write_error(out_err: *mut *mut BlazenError, e: InnerError) -> i32 {
    const FAILURE_STATUS: i32 = -1;
    if out_err.is_null() {
        return FAILURE_STATUS;
    }
    let error_ptr = BlazenError::from(e).into_ptr();
    // SAFETY: `out_err` is non-null and the caller guarantees it is writable;
    // `write` does not read (or drop) whatever the slot previously held.
    unsafe { out_err.write(error_ptr) };
    FAILURE_STATUS
}
/// Convenience wrapper around `write_error` that reports `msg` as an
/// `Internal` error and returns the `-1` failure status.
///
/// # Safety
///
/// Same contract as `write_error`: `out_err` must be null or valid for writing
/// a single `*mut BlazenError`.
unsafe fn write_internal_error(out_err: *mut *mut BlazenError, msg: &str) -> i32 {
    let internal = InnerError::Internal {
        message: msg.into(),
    };
    // SAFETY: `out_err` is forwarded unchanged under the caller's contract.
    unsafe { write_error(out_err, internal) }
}
/// C-ABI vtable through which a foreign caller receives streaming completion
/// events.
///
/// It is passed by value into `blazen_complete_streaming` and
/// `blazen_complete_streaming_blocking`. Because the struct is `#[repr(C)]`,
/// field order and types are part of the C ABI and must not change.
///
/// Every event callback returns `0` on success. On failure it returns a
/// non-zero status and may store an owned `BlazenError` pointer in `*out_err`.
/// `*out_err` starts out null. If the callback fails without setting it, the
/// Rust side synthesizes an `Internal` error.
#[repr(C)]
pub struct BlazenCompletionStreamSinkVTable {
    /// Opaque caller state passed back as the first argument of every callback.
    pub user_data: *mut c_void,
    /// Releases `user_data`. The Rust side calls it when the sink is dropped,
    /// or immediately if an entry point rejects its arguments.
    pub drop_user_data: extern "C" fn(user_data: *mut c_void),
    /// Receives one streamed chunk. The chunk pointer is handed to the
    /// callback; Rust does not free it after the call.
    /// NOTE(review): presumably the C side releases it via the library's chunk
    /// free function — confirm.
    pub on_chunk: extern "C" fn(
        user_data: *mut c_void,
        chunk: *mut BlazenStreamChunk,
        out_err: *mut *mut BlazenError,
    ) -> i32,
    /// Called when the stream finishes. `finish_reason` is a NUL-terminated
    /// string created with `CString::into_raw`, and `usage` is a token-usage
    /// record. Rust frees neither pointer after the call.
    pub on_done: extern "C" fn(
        user_data: *mut c_void,
        finish_reason: *mut c_char,
        usage: *mut BlazenTokenUsage,
        out_err: *mut *mut BlazenError,
    ) -> i32,
    /// Receives an error reported by the stream. The `err` pointer is handed to
    /// the callback; Rust does not free it after the call.
    pub on_error: extern "C" fn(
        user_data: *mut c_void,
        err: *mut BlazenError,
        out_err: *mut *mut BlazenError,
    ) -> i32,
}
// SAFETY: the vtable holds only the opaque `user_data` pointer and plain
// function pointers. Callbacks are invoked from tokio blocking-pool threads
// (not necessarily the thread that built the vtable), and the sink wrapping the
// vtable is shared behind an `Arc`. Soundness therefore depends on the C caller
// supplying `user_data` and callbacks that tolerate cross-thread (and possibly
// concurrent) use.
// NOTE(review): this is an unchecked contract on the C caller.
unsafe impl Send for BlazenCompletionStreamSinkVTable {}
unsafe impl Sync for BlazenCompletionStreamSinkVTable {}
/// Adapter that exposes a C `BlazenCompletionStreamSinkVTable` as a Rust
/// `CompletionStreamSink`.
///
/// The adapter owns the vtable's `user_data`. Dropping the adapter hands
/// `user_data` back to the C side through `drop_user_data`.
pub(crate) struct CStreamSink {
    vtable: BlazenCompletionStreamSinkVTable,
}

impl Drop for CStreamSink {
    fn drop(&mut self) {
        let release = self.vtable.drop_user_data;
        let state = self.vtable.user_data;
        release(state);
    }
}
#[async_trait]
impl CompletionStreamSink for CStreamSink {
#[allow(clippy::result_large_err)]
async fn on_chunk(&self, chunk: InnerStreamChunk) -> BlazenResult<()> {
let chunk_ptr = BlazenStreamChunk::from(chunk).into_ptr();
let user_data_addr = self.vtable.user_data as usize;
let on_chunk_fn = self.vtable.on_chunk;
let chunk_addr = chunk_ptr as usize;
let join = tokio::task::spawn_blocking(move || -> Result<(), InnerError> {
let user_data = user_data_addr as *mut c_void;
let chunk_ptr = chunk_addr as *mut BlazenStreamChunk;
let mut out_err: *mut BlazenError = std::ptr::null_mut();
let status = on_chunk_fn(user_data, chunk_ptr, &raw mut out_err);
if status == 0 {
Ok(())
} else {
if out_err.is_null() {
return Err(InnerError::Internal {
message: format!(
"stream sink on_chunk returned non-zero status ({status}) without setting out_err"
),
});
}
let be = unsafe { Box::from_raw(out_err) };
Err(be.inner)
}
})
.await;
match join {
Ok(Ok(())) => Ok(()),
Ok(Err(e)) => Err(e),
Err(join_err) => Err(InnerError::Internal {
message: format!("stream sink on_chunk task panicked: {join_err}"),
}),
}
}
#[allow(clippy::result_large_err)]
async fn on_done(&self, finish_reason: String, usage: InnerTokenUsage) -> BlazenResult<()> {
let Ok(finish_cstring) = CString::new(finish_reason) else {
return Err(InnerError::Internal {
message: "stream sink on_done: finish_reason contains interior NUL byte".into(),
});
};
let finish_raw = finish_cstring.into_raw();
let usage_ptr = BlazenTokenUsage::from(usage).into_ptr();
let user_data_addr = self.vtable.user_data as usize;
let on_done_fn = self.vtable.on_done;
let finish_addr = finish_raw as usize;
let usage_addr = usage_ptr as usize;
let join = tokio::task::spawn_blocking(move || -> Result<(), InnerError> {
let user_data = user_data_addr as *mut c_void;
let finish_ptr = finish_addr as *mut c_char;
let usage_ptr = usage_addr as *mut BlazenTokenUsage;
let mut out_err: *mut BlazenError = std::ptr::null_mut();
let status = on_done_fn(user_data, finish_ptr, usage_ptr, &raw mut out_err);
if status == 0 {
Ok(())
} else {
if out_err.is_null() {
return Err(InnerError::Internal {
message: format!(
"stream sink on_done returned non-zero status ({status}) without setting out_err"
),
});
}
let be = unsafe { Box::from_raw(out_err) };
Err(be.inner)
}
})
.await;
match join {
Ok(Ok(())) => Ok(()),
Ok(Err(e)) => Err(e),
Err(join_err) => Err(InnerError::Internal {
message: format!("stream sink on_done task panicked: {join_err}"),
}),
}
}
#[allow(clippy::result_large_err)]
async fn on_error(&self, err: InnerError) -> BlazenResult<()> {
let err_ptr = BlazenError::from(err).into_ptr();
let user_data_addr = self.vtable.user_data as usize;
let on_error_fn = self.vtable.on_error;
let err_addr = err_ptr as usize;
let join = tokio::task::spawn_blocking(move || -> Result<(), InnerError> {
let user_data = user_data_addr as *mut c_void;
let err_ptr = err_addr as *mut BlazenError;
let mut out_err: *mut BlazenError = std::ptr::null_mut();
let status = on_error_fn(user_data, err_ptr, &raw mut out_err);
if status == 0 {
Ok(())
} else {
if out_err.is_null() {
return Err(InnerError::Internal {
message: format!(
"stream sink on_error returned non-zero status ({status}) without setting out_err"
),
});
}
let be = unsafe { Box::from_raw(out_err) };
Err(be.inner)
}
})
.await;
match join {
Ok(Ok(())) => Ok(()),
Ok(Err(e)) => Err(e),
Err(join_err) => Err(InnerError::Internal {
message: format!("stream sink on_error task panicked: {join_err}"),
}),
}
}
}
/// Runs a streaming completion on the calling thread, delivering events to
/// `sink`, and returns once the stream has finished.
///
/// Ownership:
/// - `model` is borrowed; its inner model handle is cloned (an `Arc` clone).
/// - `request` is consumed (freed here even when `model` is null).
/// - `sink.user_data` is released through `drop_user_data`. This happens
///   immediately if the arguments are rejected; otherwise it happens when the
///   Rust-side sink is dropped.
///
/// Returns `0` on success. On failure it returns `-1` and, if `out_err` is
/// non-null, stores an owned `BlazenError` in `*out_err`.
///
/// # Safety
///
/// `model` must be null or point to a live `BlazenCompletionModel`. `request`
/// must be null or an owned, boxed `BlazenCompletionRequest` not used again by
/// the caller. `out_err` must be null or writable.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn blazen_complete_streaming_blocking(
    model: *const BlazenCompletionModel,
    request: *mut BlazenCompletionRequest,
    sink: BlazenCompletionStreamSinkVTable,
    out_err: *mut *mut BlazenError,
) -> i32 {
    // SAFETY: the caller guarantees `model` is either null or a live handle.
    let Some(model_handle) = (unsafe { model.as_ref() }) else {
        (sink.drop_user_data)(sink.user_data);
        if !request.is_null() {
            // SAFETY: a non-null `request` is an owned box handed to us.
            drop(unsafe { Box::from_raw(request) });
        }
        // SAFETY: `out_err` contract is forwarded from our caller.
        return unsafe { write_internal_error(out_err, "blazen_complete_streaming: null model") };
    };
    if request.is_null() {
        (sink.drop_user_data)(sink.user_data);
        // SAFETY: `out_err` contract is forwarded from our caller.
        return unsafe { write_internal_error(out_err, "blazen_complete_streaming: null request") };
    }

    let model_arc: Arc<InnerCompletionModel> = Arc::clone(&model_handle.0);
    // SAFETY: `request` is non-null and ownership is transferred to us.
    let owned_request = unsafe { Box::from_raw(request) };
    let inner_request = owned_request.0;
    let sink_arc: Arc<dyn CompletionStreamSink> = Arc::new(CStreamSink { vtable: sink });

    complete_streaming_blocking(model_arc, inner_request, sink_arc).map_or_else(
        // SAFETY: `out_err` contract is forwarded from our caller.
        |e| unsafe { write_error(out_err, e) },
        |()| 0,
    )
}
/// Starts a streaming completion in the background and returns a future
/// handle whose result is `()` on success.
///
/// Ownership:
/// - `model` is borrowed; its inner model handle is cloned (an `Arc` clone).
/// - `request` is consumed (freed here even when `model` is null).
/// - `sink.user_data` is released through `drop_user_data`. This happens
///   immediately if the arguments are rejected; otherwise it happens when the
///   Rust-side sink is dropped.
///
/// Returns null if `model` or `request` is null. No error detail is available
/// in that case.
///
/// # Safety
///
/// `model` must be null or point to a live `BlazenCompletionModel`. `request`
/// must be null or an owned, boxed `BlazenCompletionRequest` not used again by
/// the caller.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn blazen_complete_streaming(
    model: *const BlazenCompletionModel,
    request: *mut BlazenCompletionRequest,
    sink: BlazenCompletionStreamSinkVTable,
) -> *mut BlazenFuture {
    // SAFETY: the caller guarantees `model` is either null or a live handle.
    let Some(model_handle) = (unsafe { model.as_ref() }) else {
        (sink.drop_user_data)(sink.user_data);
        if !request.is_null() {
            // SAFETY: a non-null `request` is an owned box handed to us.
            drop(unsafe { Box::from_raw(request) });
        }
        return std::ptr::null_mut();
    };
    if request.is_null() {
        (sink.drop_user_data)(sink.user_data);
        return std::ptr::null_mut();
    }

    let model_arc: Arc<InnerCompletionModel> = Arc::clone(&model_handle.0);
    // SAFETY: `request` is non-null and ownership is transferred to us.
    let owned_request = unsafe { Box::from_raw(request) };
    let inner_request = owned_request.0;
    let sink_arc: Arc<dyn CompletionStreamSink> = Arc::new(CStreamSink { vtable: sink });

    BlazenFuture::spawn::<(), _>(async move {
        complete_streaming(model_arc, inner_request, sink_arc).await
    })
}