use std::ffi::{CStr, CString};
use std::fmt;
// Raw FFI declarations for the native library, kept in a private module and
// re-exported into this file with `use bindings::*`.
//
// The lint allowances below exist because the included code is not written by
// hand in this file and does not follow Rust naming conventions.
// NOTE(review): identifiers such as `InputDataType_kInputText` and `type_`
// suggest the file is bindgen output produced by a build script — confirm
// against build.rs.
#[allow(non_upper_case_globals)]
#[allow(non_camel_case_types)]
#[allow(non_snake_case)]
#[allow(dead_code)]
#[allow(clippy::all)]
mod bindings {
// Textually includes `$OUT_DIR/bindings.rs`; `OUT_DIR` is resolved at compile
// time through `env!`, so the build fails if the variable is not set.
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
}
use bindings::*;
/// Backend selector handed to the native engine settings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Backend {
    /// Identified to the native library by the string `"cpu"`.
    Cpu,
    /// Identified to the native library by the string `"gpu"`.
    Gpu,
}

impl Backend {
    /// Returns the lowercase identifier used for this backend when building
    /// the native settings object.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Gpu => "gpu",
            Self::Cpu => "cpu",
        }
    }
}
/// Error returned by this wrapper; it carries only a human-readable message.
#[derive(Debug, Clone)]
pub struct Error {
    message: String,
}

impl Error {
    /// Builds an error from anything that converts into a `String`.
    fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl fmt::Display for Error {
    /// Writes the stored message prefixed with `LiteRT-LM Error: `.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("LiteRT-LM Error: {}", self.message))
    }
}

impl std::error::Error for Error {}
/// Shorthand for `std::result::Result` with this file's [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
/// Owning wrapper around a native `LiteRtLmEngine` handle.
///
/// Both raw pointers stored here are passed to their native delete functions
/// when the value is dropped.
pub struct Engine {
/// Engine handle returned by `litert_lm_engine_create`.
raw: *mut LiteRtLmEngine,
/// Settings handle the engine was created from; retained so it can be
/// deleted together with the engine.
_settings: *mut LiteRtLmEngineSettings, }
// NOTE(review): these impls assert that the native engine may be moved to and
// shared between threads. Nothing in this file demonstrates that guarantee —
// confirm against the native library's threading documentation.
unsafe impl Send for Engine {}
unsafe impl Sync for Engine {}
impl Engine {
pub fn new(model_path: &str, backend: Backend) -> Result<Self> {
let model_path_cstr = CString::new(model_path)
.map_err(|e| Error::new(format!("Invalid model path: {}", e)))?;
let backend_cstr = CString::new(backend.as_str())
.map_err(|e| Error::new(format!("Invalid backend string: {}", e)))?;
unsafe {
let settings = litert_lm_engine_settings_create(
model_path_cstr.as_ptr(),
backend_cstr.as_ptr(),
);
if settings.is_null() {
return Err(Error::new("Failed to create engine settings"));
}
let engine = litert_lm_engine_create(settings);
if engine.is_null() {
litert_lm_engine_settings_delete(settings);
return Err(Error::new("Failed to create engine"));
}
Ok(Engine {
raw: engine,
_settings: settings,
})
}
}
pub fn create_session(&self) -> Result<Session> {
unsafe {
let session = litert_lm_engine_create_session(self.raw);
if session.is_null() {
return Err(Error::new("Failed to create session"));
}
Ok(Session { raw: session })
}
}
}
impl Drop for Engine {
    /// Deletes the native engine first and then the settings it was built from.
    fn drop(&mut self) {
        let engine_handle = self.raw;
        let settings_handle = self._settings;

        // SAFETY: `Engine::new` only constructs an `Engine` after checking both
        // handles are non-null, and this is the only place they are deleted.
        unsafe { litert_lm_engine_delete(engine_handle) };
        // SAFETY: see above; the settings handle is released after the engine.
        unsafe { litert_lm_engine_settings_delete(settings_handle) };
    }
}
/// Owning wrapper around a native `LiteRtLmSession` handle.
///
/// The handle is passed to `litert_lm_session_delete` when the value is dropped.
pub struct Session {
/// Session handle returned by `litert_lm_engine_create_session`.
raw: *mut LiteRtLmSession,
}
// NOTE(review): asserts that a session may be moved to another thread; the
// native library's guarantee for this is not visible in this file — confirm.
unsafe impl Send for Session {}
impl Session {
pub fn generate(&self, prompt: &str) -> Result<String> {
let prompt_cstr = CString::new(prompt)
.map_err(|e| Error::new(format!("Invalid prompt: {}", e)))?;
unsafe {
let input_data = InputData {
type_: InputDataType_kInputText,
data: prompt_cstr.as_ptr() as *const std::ffi::c_void,
size: prompt.len(),
};
let responses = litert_lm_session_generate_content(self.raw, &input_data, 1);
if responses.is_null() {
return Err(Error::new("Failed to generate content"));
}
let text_ptr = litert_lm_responses_get_response_text_at(responses, 0);
let result = if !text_ptr.is_null() {
CStr::from_ptr(text_ptr).to_string_lossy().into_owned()
} else {
litert_lm_responses_delete(responses);
return Err(Error::new("No response generated"));
};
litert_lm_responses_delete(responses);
Ok(result)
}
}
pub fn get_benchmark_info(&self) -> Result<BenchmarkInfo> {
unsafe {
let info = litert_lm_session_get_benchmark_info(self.raw);
if info.is_null() {
return Err(Error::new("Failed to get benchmark info"));
}
let time_to_first_token =
litert_lm_benchmark_info_get_time_to_first_token(info);
let num_prefill_turns = litert_lm_benchmark_info_get_num_prefill_turns(info);
let num_decode_turns = litert_lm_benchmark_info_get_num_decode_turns(info);
let result = BenchmarkInfo {
time_to_first_token,
num_prefill_turns: num_prefill_turns as usize,
num_decode_turns: num_decode_turns as usize,
};
litert_lm_benchmark_info_delete(info);
Ok(result)
}
}
}
impl Drop for Session {
    /// Hands the native session handle back to the library for deletion.
    fn drop(&mut self) {
        let session_handle = self.raw;
        // SAFETY: `Engine::create_session` only builds a `Session` from a
        // non-null handle, and this is the only place the handle is deleted.
        unsafe { litert_lm_session_delete(session_handle) };
    }
}
/// Plain-data snapshot of the benchmark values read from a session.
#[derive(Debug, Clone)]
pub struct BenchmarkInfo {
/// Value reported by `litert_lm_benchmark_info_get_time_to_first_token`.
/// NOTE(review): the unit is not visible in this file — presumably seconds; confirm.
pub time_to_first_token: f64,
/// Value reported by `litert_lm_benchmark_info_get_num_prefill_turns`.
pub num_prefill_turns: usize,
/// Value reported by `litert_lm_benchmark_info_get_num_decode_turns`.
pub num_decode_turns: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every backend variant maps to its lowercase identifier.
    #[test]
    fn test_backend_enum() {
        let expected = [(Backend::Cpu, "cpu"), (Backend::Gpu, "gpu")];
        for (backend, name) in expected.iter() {
            assert_eq!(backend.as_str(), *name);
        }
    }

    /// `Display` output is the message with the `LiteRT-LM Error: ` prefix.
    #[test]
    fn test_error_display() {
        let rendered = Error::new("test error").to_string();
        assert_eq!(rendered, "LiteRT-LM Error: test error");
    }
}