Skip to main content

litert_lm/
lib.rs

1//! # LiteRT-LM Rust Bindings
2//!
3//! Safe, idiomatic Rust wrapper for the LiteRT-LM C API.
4//!
5//! ## Features
6//!
7//! - **Safe API**: Memory-safe wrappers around C FFI
8//! - **Automatic cleanup**: RAII-based resource management
9//! - **Thread-safe**: Proper Send/Sync implementations
10//! - **Error handling**: Result-based error handling
11//!
12//! ## Example
13//!
14//! ```no_run
15//! use litert_lm::{Engine, Backend};
16//!
17//! fn main() -> Result<(), Box<dyn std::error::Error>> {
18//!     // Create engine
19//!     let engine = Engine::new("model.tflite", Backend::Cpu)?;
20//!
21//!     // Create session
22//!     let session = engine.create_session()?;
23//!
24//!     // Generate text
25//!     let response = session.generate("Hello, how are you?")?;
26//!     println!("Response: {}", response);
27//!
28//!     Ok(())
29//! }
30//! ```
31
32use std::ffi::{CStr, CString};
33use std::fmt;
34
35// Include auto-generated bindings from bindgen
36#[allow(non_upper_case_globals)]
37#[allow(non_camel_case_types)]
38#[allow(non_snake_case)]
39#[allow(dead_code)]
40#[allow(clippy::all)]
41mod bindings {
42    include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
43}
44
45use bindings::*;
46
47// ============================================================================
48// Public Types
49// ============================================================================
50
/// Execution backend selected when loading a model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Backend {
    /// Run the model on the CPU.
    Cpu,
    /// Run the model on the GPU (if available).
    Gpu,
}

impl Backend {
    /// Lower-case backend identifier that is handed to the C API as a string.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Cpu => "cpu",
            Self::Gpu => "gpu",
        }
    }
}
68
/// Error returned by the fallible LiteRT-LM operations in this crate.
///
/// Carries only a human-readable description, which [`fmt::Display`] renders
/// with a `LiteRT-LM Error: ` prefix.
#[derive(Debug, Clone)]
pub struct Error {
    message: String,
}

impl Error {
    /// Builds an error from anything convertible into a `String`.
    fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl fmt::Display for Error {
    /// Writes the fixed prefix followed by the stored message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("LiteRT-LM Error: ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for Error {}

/// Convenience alias: a `std::result::Result` whose error type is [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
93
94// ============================================================================
95// Engine
96// ============================================================================
97
/// LiteRT-LM Engine - the main entry point for loading models
///
/// The Engine loads a model file and prepares it for inference.
/// Create sessions from the engine to perform text generation.
///
/// Owns two handles obtained from the C API: the engine itself and the
/// settings object it was created from. Both are released when the value is
/// dropped.
pub struct Engine {
    // Engine handle returned by `litert_lm_engine_create`; `Engine::new` rejects null.
    raw: *mut LiteRtLmEngine,
    // Settings handle the engine was built from. It is retained for as long as
    // the engine exists and freed in `Drop` — presumably because the C engine
    // may keep referring to it; TODO confirm against the C API documentation.
    _settings: *mut LiteRtLmEngineSettings, // Keep settings alive
}

// Safety: The C API allows engines to be shared between threads
// NOTE(review): the raw-pointer fields mean `Engine` is neither `Send` nor
// `Sync` automatically; these impls rest entirely on the claim above, which
// cannot be checked from this file — confirm against the LiteRT-LM C API docs.
unsafe impl Send for Engine {}
unsafe impl Sync for Engine {}
110
111impl Engine {
112    /// Create a new Engine from a model file
113    ///
114    /// # Arguments
115    ///
116    /// * `model_path` - Path to the .tflite model file
117    /// * `backend` - Backend to use (Cpu or Gpu)
118    ///
119    /// # Example
120    ///
121    /// ```no_run
122    /// use litert_lm::{Engine, Backend};
123    ///
124    /// let engine = Engine::new("model.tflite", Backend::Cpu)?;
125    /// # Ok::<(), litert_lm::Error>(())
126    /// ```
127    pub fn new(model_path: &str, backend: Backend) -> Result<Self> {
128        let model_path_cstr = CString::new(model_path)
129            .map_err(|e| Error::new(format!("Invalid model path: {}", e)))?;
130
131        let backend_cstr = CString::new(backend.as_str())
132            .map_err(|e| Error::new(format!("Invalid backend string: {}", e)))?;
133
134        unsafe {
135            // Create engine settings
136            let settings = litert_lm_engine_settings_create(
137                model_path_cstr.as_ptr(),
138                backend_cstr.as_ptr(),
139            );
140
141            if settings.is_null() {
142                return Err(Error::new("Failed to create engine settings"));
143            }
144
145            // Create engine
146            let engine = litert_lm_engine_create(settings);
147
148            if engine.is_null() {
149                litert_lm_engine_settings_delete(settings);
150                return Err(Error::new("Failed to create engine"));
151            }
152
153            Ok(Engine {
154                raw: engine,
155                _settings: settings,
156            })
157        }
158    }
159
160    /// Create a new session for text generation
161    ///
162    /// Sessions maintain conversation history and state.
163    /// You can create multiple sessions from the same engine.
164    ///
165    /// # Example
166    ///
167    /// ```no_run
168    /// use litert_lm::{Engine, Backend};
169    ///
170    /// let engine = Engine::new("model.tflite", Backend::Cpu)?;
171    /// let session = engine.create_session()?;
172    /// # Ok::<(), litert_lm::Error>(())
173    /// ```
174    pub fn create_session(&self) -> Result<Session> {
175        unsafe {
176            let session = litert_lm_engine_create_session(self.raw);
177
178            if session.is_null() {
179                return Err(Error::new("Failed to create session"));
180            }
181
182            Ok(Session { raw: session })
183        }
184    }
185}
186
187impl Drop for Engine {
188    fn drop(&mut self) {
189        unsafe {
190            litert_lm_engine_delete(self.raw);
191            litert_lm_engine_settings_delete(self._settings);
192        }
193    }
194}
195
196// ============================================================================
197// Session
198// ============================================================================
199
/// LiteRT-LM Session - represents a conversation context
///
/// A session maintains the conversation history and can generate
/// text responses to prompts.
///
/// NOTE(review): this type carries no lifetime or reference tying it to the
/// [`Engine`] that created it, so the compiler will not stop a session from
/// outliving its engine. Whether the C API tolerates that is not visible
/// here — TODO confirm; until then, drop sessions before their engine.
pub struct Session {
    // Session handle returned by `litert_lm_engine_create_session`; freed in `Drop`.
    raw: *mut LiteRtLmSession,
}

// Safety: Sessions can be moved between threads but not shared
// NOTE(review): only `Send` is implemented, so `&Session` cannot be shared
// across threads. That moving a session between threads is sound is a claim
// about the C API that cannot be checked from this file — confirm upstream.
unsafe impl Send for Session {}
210
211impl Session {
212    /// Generate text from a prompt
213    ///
214    /// # Arguments
215    ///
216    /// * `prompt` - The input text prompt
217    ///
218    /// # Returns
219    ///
220    /// The generated text response
221    ///
222    /// # Example
223    ///
224    /// ```no_run
225    /// use litert_lm::{Engine, Backend};
226    ///
227    /// let engine = Engine::new("model.tflite", Backend::Cpu)?;
228    /// let session = engine.create_session()?;
229    /// let response = session.generate("What is 2+2?")?;
230    /// println!("Response: {}", response);
231    /// # Ok::<(), litert_lm::Error>(())
232    /// ```
233    pub fn generate(&self, prompt: &str) -> Result<String> {
234        let prompt_cstr = CString::new(prompt)
235            .map_err(|e| Error::new(format!("Invalid prompt: {}", e)))?;
236
237        unsafe {
238            // Create InputData for text
239            let input_data = InputData {
240                type_: InputDataType_kInputText,
241                data: prompt_cstr.as_ptr() as *const std::ffi::c_void,
242                size: prompt.len(),
243            };
244
245            // Generate content
246            let responses = litert_lm_session_generate_content(self.raw, &input_data, 1);
247
248            if responses.is_null() {
249                return Err(Error::new("Failed to generate content"));
250            }
251
252            // Get response text
253            let text_ptr = litert_lm_responses_get_response_text_at(responses, 0);
254
255            let result = if !text_ptr.is_null() {
256                CStr::from_ptr(text_ptr).to_string_lossy().into_owned()
257            } else {
258                litert_lm_responses_delete(responses);
259                return Err(Error::new("No response generated"));
260            };
261
262            // Clean up responses
263            litert_lm_responses_delete(responses);
264
265            Ok(result)
266        }
267    }
268
269    /// Get benchmark information (if benchmarking is enabled)
270    ///
271    /// Returns information about performance metrics like tokens per second.
272    pub fn get_benchmark_info(&self) -> Result<BenchmarkInfo> {
273        unsafe {
274            let info = litert_lm_session_get_benchmark_info(self.raw);
275
276            if info.is_null() {
277                return Err(Error::new("Failed to get benchmark info"));
278            }
279
280            let time_to_first_token =
281                litert_lm_benchmark_info_get_time_to_first_token(info);
282            let num_prefill_turns = litert_lm_benchmark_info_get_num_prefill_turns(info);
283            let num_decode_turns = litert_lm_benchmark_info_get_num_decode_turns(info);
284
285            let result = BenchmarkInfo {
286                time_to_first_token,
287                num_prefill_turns: num_prefill_turns as usize,
288                num_decode_turns: num_decode_turns as usize,
289            };
290
291            litert_lm_benchmark_info_delete(info);
292
293            Ok(result)
294        }
295    }
296}
297
298impl Drop for Session {
299    fn drop(&mut self) {
300        unsafe {
301            litert_lm_session_delete(self.raw);
302        }
303    }
304}
305
306// ============================================================================
307// Benchmark Info
308// ============================================================================
309
/// Benchmark information for a session
///
/// A plain-data snapshot: the values are copied out of the C API when it is
/// built. `PartialEq` is derived so snapshots can be compared and asserted on.
#[derive(Debug, Clone, PartialEq)]
pub struct BenchmarkInfo {
    /// Time to first token in seconds (passed through unchanged from the C API)
    pub time_to_first_token: f64,
    /// Number of prefill turns
    pub num_prefill_turns: usize,
    /// Number of decode turns
    pub num_decode_turns: usize,
}
320
321// ============================================================================
322// Tests
323// ============================================================================
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Each backend maps to the lower-case identifier handed to the C API.
    #[test]
    fn test_backend_enum() {
        let expected = [(Backend::Cpu, "cpu"), (Backend::Gpu, "gpu")];
        for (backend, name) in expected.iter() {
            assert_eq!(backend.as_str(), *name);
        }
    }

    /// `Display` output is the fixed prefix followed by the message.
    #[test]
    fn test_error_display() {
        let rendered = Error::new("test error").to_string();
        assert_eq!(rendered, "LiteRT-LM Error: test error");
    }
}