litert_lm/lib.rs
1//! # LiteRT-LM Rust Bindings
2//!
3//! Safe, idiomatic Rust wrapper for the LiteRT-LM C API.
4//!
5//! ## Features
6//!
7//! - **Safe API**: Memory-safe wrappers around C FFI
8//! - **Automatic cleanup**: RAII-based resource management
9//! - **Thread-safe**: Proper Send/Sync implementations
10//! - **Error handling**: Result-based error handling
11//!
12//! ## Example
13//!
14//! ```no_run
15//! use litert_lm::{Engine, Backend};
16//!
17//! fn main() -> Result<(), Box<dyn std::error::Error>> {
18//! // Create engine
19//! let engine = Engine::new("model.tflite", Backend::Cpu)?;
20//!
21//! // Create session
22//! let session = engine.create_session()?;
23//!
24//! // Generate text
25//! let response = session.generate("Hello, how are you?")?;
26//! println!("Response: {}", response);
27//!
28//! Ok(())
29//! }
30//! ```
31
32use std::ffi::{CStr, CString};
33use std::fmt;
34
35// Include auto-generated bindings from bindgen
36#[allow(non_upper_case_globals)]
37#[allow(non_camel_case_types)]
38#[allow(non_snake_case)]
39#[allow(dead_code)]
40#[allow(clippy::all)]
41mod bindings {
42 include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
43}
44
45use bindings::*;
46
47// ============================================================================
48// Public Types
49// ============================================================================
50
/// Selects the hardware backend used for model execution.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Backend {
    /// Run on the CPU.
    Cpu,
    /// Run on the GPU (if available).
    Gpu,
}

impl Backend {
    /// Returns the lowercase identifier for this backend.
    ///
    /// `Engine::new` converts this string to a C string and hands it to the
    /// C API when creating the engine settings.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Cpu => "cpu",
            Self::Gpu => "gpu",
        }
    }
}
68
/// Error returned by the fallible operations in this crate.
///
/// It carries a single human-readable message; there are no error codes or
/// nested causes.
#[derive(Clone, Debug)]
pub struct Error {
    message: String,
}

impl Error {
    /// Builds an error from anything that converts into a `String`.
    fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl fmt::Display for Error {
    /// Writes the stored message prefixed with `LiteRT-LM Error: `.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "LiteRT-LM Error: {}", self.message)
    }
}

// No `source()` override: the error wraps nothing but its message.
impl std::error::Error for Error {}

/// Shorthand for `std::result::Result` with this crate's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
93
94// ============================================================================
95// Engine
96// ============================================================================
97
98/// LiteRT-LM Engine - the main entry point for loading models
99///
100/// The Engine loads a model file and prepares it for inference.
101/// Create sessions from the engine to perform text generation.
102pub struct Engine {
103 raw: *mut LiteRtLmEngine,
104 _settings: *mut LiteRtLmEngineSettings, // Keep settings alive
105}
106
107// Safety: The C API allows engines to be shared between threads
108unsafe impl Send for Engine {}
109unsafe impl Sync for Engine {}
110
111impl Engine {
112 /// Create a new Engine from a model file
113 ///
114 /// # Arguments
115 ///
116 /// * `model_path` - Path to the .tflite model file
117 /// * `backend` - Backend to use (Cpu or Gpu)
118 ///
119 /// # Example
120 ///
121 /// ```no_run
122 /// use litert_lm::{Engine, Backend};
123 ///
124 /// let engine = Engine::new("model.tflite", Backend::Cpu)?;
125 /// # Ok::<(), litert_lm::Error>(())
126 /// ```
127 pub fn new(model_path: &str, backend: Backend) -> Result<Self> {
128 let model_path_cstr = CString::new(model_path)
129 .map_err(|e| Error::new(format!("Invalid model path: {}", e)))?;
130
131 let backend_cstr = CString::new(backend.as_str())
132 .map_err(|e| Error::new(format!("Invalid backend string: {}", e)))?;
133
134 unsafe {
135 // Create engine settings
136 let settings = litert_lm_engine_settings_create(
137 model_path_cstr.as_ptr(),
138 backend_cstr.as_ptr(),
139 );
140
141 if settings.is_null() {
142 return Err(Error::new("Failed to create engine settings"));
143 }
144
145 // Create engine
146 let engine = litert_lm_engine_create(settings);
147
148 if engine.is_null() {
149 litert_lm_engine_settings_delete(settings);
150 return Err(Error::new("Failed to create engine"));
151 }
152
153 Ok(Engine {
154 raw: engine,
155 _settings: settings,
156 })
157 }
158 }
159
160 /// Create a new session for text generation
161 ///
162 /// Sessions maintain conversation history and state.
163 /// You can create multiple sessions from the same engine.
164 ///
165 /// # Example
166 ///
167 /// ```no_run
168 /// use litert_lm::{Engine, Backend};
169 ///
170 /// let engine = Engine::new("model.tflite", Backend::Cpu)?;
171 /// let session = engine.create_session()?;
172 /// # Ok::<(), litert_lm::Error>(())
173 /// ```
174 pub fn create_session(&self) -> Result<Session> {
175 unsafe {
176 let session = litert_lm_engine_create_session(self.raw);
177
178 if session.is_null() {
179 return Err(Error::new("Failed to create session"));
180 }
181
182 Ok(Session { raw: session })
183 }
184 }
185}
186
187impl Drop for Engine {
188 fn drop(&mut self) {
189 unsafe {
190 litert_lm_engine_delete(self.raw);
191 litert_lm_engine_settings_delete(self._settings);
192 }
193 }
194}
195
196// ============================================================================
197// Session
198// ============================================================================
199
200/// LiteRT-LM Session - represents a conversation context
201///
202/// A session maintains the conversation history and can generate
203/// text responses to prompts.
204pub struct Session {
205 raw: *mut LiteRtLmSession,
206}
207
208// Safety: Sessions can be moved between threads but not shared
209unsafe impl Send for Session {}
210
211impl Session {
212 /// Generate text from a prompt
213 ///
214 /// # Arguments
215 ///
216 /// * `prompt` - The input text prompt
217 ///
218 /// # Returns
219 ///
220 /// The generated text response
221 ///
222 /// # Example
223 ///
224 /// ```no_run
225 /// use litert_lm::{Engine, Backend};
226 ///
227 /// let engine = Engine::new("model.tflite", Backend::Cpu)?;
228 /// let session = engine.create_session()?;
229 /// let response = session.generate("What is 2+2?")?;
230 /// println!("Response: {}", response);
231 /// # Ok::<(), litert_lm::Error>(())
232 /// ```
233 pub fn generate(&self, prompt: &str) -> Result<String> {
234 let prompt_cstr = CString::new(prompt)
235 .map_err(|e| Error::new(format!("Invalid prompt: {}", e)))?;
236
237 unsafe {
238 // Create InputData for text
239 let input_data = InputData {
240 type_: InputDataType_kInputText,
241 data: prompt_cstr.as_ptr() as *const std::ffi::c_void,
242 size: prompt.len(),
243 };
244
245 // Generate content
246 let responses = litert_lm_session_generate_content(self.raw, &input_data, 1);
247
248 if responses.is_null() {
249 return Err(Error::new("Failed to generate content"));
250 }
251
252 // Get response text
253 let text_ptr = litert_lm_responses_get_response_text_at(responses, 0);
254
255 let result = if !text_ptr.is_null() {
256 CStr::from_ptr(text_ptr).to_string_lossy().into_owned()
257 } else {
258 litert_lm_responses_delete(responses);
259 return Err(Error::new("No response generated"));
260 };
261
262 // Clean up responses
263 litert_lm_responses_delete(responses);
264
265 Ok(result)
266 }
267 }
268
269 /// Get benchmark information (if benchmarking is enabled)
270 ///
271 /// Returns information about performance metrics like tokens per second.
272 pub fn get_benchmark_info(&self) -> Result<BenchmarkInfo> {
273 unsafe {
274 let info = litert_lm_session_get_benchmark_info(self.raw);
275
276 if info.is_null() {
277 return Err(Error::new("Failed to get benchmark info"));
278 }
279
280 let time_to_first_token =
281 litert_lm_benchmark_info_get_time_to_first_token(info);
282 let num_prefill_turns = litert_lm_benchmark_info_get_num_prefill_turns(info);
283 let num_decode_turns = litert_lm_benchmark_info_get_num_decode_turns(info);
284
285 let result = BenchmarkInfo {
286 time_to_first_token,
287 num_prefill_turns: num_prefill_turns as usize,
288 num_decode_turns: num_decode_turns as usize,
289 };
290
291 litert_lm_benchmark_info_delete(info);
292
293 Ok(result)
294 }
295 }
296}
297
298impl Drop for Session {
299 fn drop(&mut self) {
300 unsafe {
301 litert_lm_session_delete(self.raw);
302 }
303 }
304}
305
306// ============================================================================
307// Benchmark Info
308// ============================================================================
309
/// Benchmark information for a session.
///
/// A plain snapshot of the values read from the C API by
/// `Session::get_benchmark_info`. It implements `PartialEq` so two snapshots
/// can be compared directly, for example in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct BenchmarkInfo {
    /// Time to first token in seconds
    pub time_to_first_token: f64,
    /// Number of prefill turns
    pub num_prefill_turns: usize,
    /// Number of decode turns
    pub num_decode_turns: usize,
}
320
321// ============================================================================
322// Tests
323// ============================================================================
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Each backend variant maps to its lowercase identifier string.
    #[test]
    fn test_backend_enum() {
        let expected = [(Backend::Cpu, "cpu"), (Backend::Gpu, "gpu")];
        for (backend, name) in expected.iter() {
            assert_eq!(backend.as_str(), *name);
        }
    }

    /// `Display` prefixes the message with the crate's error label.
    #[test]
    fn test_error_display() {
        let rendered = Error::new("test error").to_string();
        assert_eq!(rendered, "LiteRT-LM Error: test error");
    }
}
340}