use std::{
ffi::{c_char, c_void, CStr},
ptr::NonNull,
sync::Arc,
};
use litert_lm_sys as sys;
use crate::{engine::EngineInner, input::Input, Error, Result, SamplerParams};
/// A generation session created from a LiteRT-LM engine.
///
/// Owns the native `LiteRtLmSession` handle, which is released in `Drop`.
/// Also holds an `Arc` to the engine it was created from, so the engine
/// cannot be freed while this session still exists.
pub struct Session {
    /// Non-null handle returned by `litert_lm_engine_create_session`.
    ptr: NonNull<sys::LiteRtLmSession>,
    // Keeps the engine alive. `Drop for Session` deletes the native session
    // first; this field is dropped only after that, so the engine outlives
    // the session it created.
    _engine: Arc<EngineInner>,
}
// SAFETY: NOTE(review) — asserted here, not proven. `NonNull` is `!Send`, so
// this impl is what allows a `Session` to move between threads. It is sound
// only if the native session may be used from, and deleted on, a thread other
// than the one that created it, and if sharing `EngineInner` across threads
// through the `Arc` is safe. Confirm both against the LiteRT-LM C API and
// `EngineInner`'s own thread-safety. `Sync` is deliberately not implemented;
// every generating method takes `&mut self`.
unsafe impl Send for Session {}
impl Session {
pub(crate) fn new(engine: Arc<EngineInner>, params: SamplerParams) -> Result<Self> {
let config = unsafe { sys::litert_lm_session_config_create() };
if config.is_null() {
return Err(Error::NullPointer);
}
let raw_params = params.to_raw();
unsafe {
sys::litert_lm_session_config_set_sampler_params(config, &raw_params);
}
let session_ptr =
unsafe { sys::litert_lm_engine_create_session(engine.ptr.as_ptr(), config) };
unsafe { sys::litert_lm_session_config_delete(config) };
let ptr = NonNull::new(session_ptr).ok_or(Error::SessionCreationFailed)?;
Ok(Self {
ptr,
_engine: engine,
})
}
pub fn generate(&mut self, prompt: &str) -> Result<String> {
self.generate_with_inputs(&[Input::Text(prompt)])
}
pub fn generate_with_inputs(&mut self, inputs: &[Input<'_>]) -> Result<String> {
let raw_inputs: Vec<sys::InputData> = inputs.iter().map(Input::to_raw).collect();
let responses = unsafe {
sys::litert_lm_session_generate_content(
self.ptr.as_ptr(),
raw_inputs.as_ptr(),
raw_inputs.len(),
)
};
if responses.is_null() {
return Err(Error::GenerationFailed("returned null".into()));
}
let num = unsafe { sys::litert_lm_responses_get_num_candidates(responses) };
let text = if num > 0 {
let raw = unsafe { sys::litert_lm_responses_get_response_text_at(responses, 0) };
if raw.is_null() {
String::new()
} else {
unsafe { CStr::from_ptr(raw) }
.to_string_lossy()
.into_owned()
}
} else {
String::new()
};
unsafe { sys::litert_lm_responses_delete(responses) };
Ok(text)
}
pub fn generate_stream(
&mut self,
prompt: &str,
mut on_token: impl FnMut(&str) -> bool,
) -> Result<()> {
use std::sync::{Condvar, Mutex};
struct State<'a> {
cb: &'a mut dyn FnMut(&str) -> bool,
error: Option<String>,
done: &'a Mutex<bool>,
cond: &'a Condvar,
}
unsafe extern "C" fn trampoline(
data: *mut c_void,
chunk: *const c_char,
is_final: bool,
error_msg: *const c_char,
) {
let state = &mut *(data as *mut State);
if !error_msg.is_null() {
state.error = Some(CStr::from_ptr(error_msg).to_string_lossy().into_owned());
*state.done.lock().unwrap() = true;
state.cond.notify_one();
return;
}
if !chunk.is_null() {
let s = CStr::from_ptr(chunk).to_string_lossy();
(state.cb)(s.as_ref());
}
if is_final {
*state.done.lock().unwrap() = true;
state.cond.notify_one();
}
}
let input = sys::InputData {
type_: sys::kInputText,
data: prompt.as_ptr().cast(),
size: prompt.len(),
};
let done = Mutex::new(false);
let cond = Condvar::new();
let mut state = State {
cb: &mut on_token,
error: None,
done: &done,
cond: &cond,
};
let ret = unsafe {
sys::litert_lm_session_generate_content_stream(
self.ptr.as_ptr(),
&input,
1,
Some(trampoline),
&mut state as *mut State as *mut c_void,
)
};
if ret != 0 {
return Err(Error::GenerationFailed(format!("stream returned {ret}")));
}
let guard = done.lock().unwrap();
let _guard = cond.wait_while(guard, |d| !*d).unwrap();
if let Some(err) = state.error {
return Err(Error::GenerationFailed(err));
}
Ok(())
}
}
impl Drop for Session {
    /// Deletes the native session handle.
    ///
    /// This runs before the struct's fields are dropped, so the engine `Arc`
    /// held by the session is still alive while the session is deleted.
    fn drop(&mut self) {
        let handle = self.ptr.as_ptr();
        // SAFETY: `handle` was returned non-null by `litert_lm_engine_create_session`,
        // is owned by this `Session`, and is deleted only here.
        unsafe {
            sys::litert_lm_session_delete(handle);
        }
    }
}
impl std::fmt::Debug for Session {
    /// Formats as `Session { ptr: <address> }`. The engine reference is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let raw = self.ptr.as_ptr();
        let mut out = f.debug_struct("Session");
        out.field("ptr", &raw);
        out.finish()
    }
}