rkllm_rs/
lib.rs

#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]

include!(concat!(env!("OUT_DIR"), "/bindings.rs"));

pub mod prelude {
    use std::ffi::CStr;
    use std::ptr::null_mut;
    use std::sync::Arc;
    use std::sync::Mutex;

    pub use super::RKLLMExtendParam;
    pub use super::RKLLMLoraParam;
    pub use super::RKLLMParam;
    pub use super::RKLLMResultLastHiddenLayer;

    /// Represents the state of an LLM call.
    #[derive(Debug, PartialEq, Eq)]
    pub enum LLMCallState {
        /// The LLM call is in a normal running state.
        Normal = 0,
        /// The LLM call is waiting for a complete UTF-8 encoded character.
        Waiting = 1,
        /// The LLM call has finished execution.
        Finish = 2,
        /// An error occurred during the LLM call.
        Error = 3,
        /// Retrieve the last hidden layer during inference.
        GetLastHiddenLayer = 4,
    }

    /// Structure for defining parameters during inference.
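    ///
    /// # Example
    ///
    /// A minimal sketch of a text-generation configuration with prompt caching;
    /// it assumes this crate is built as `rkllm_rs`, and the cache path is
    /// hypothetical:
    ///
    /// ```no_run
    /// use rkllm_rs::prelude::*;
    ///
    /// let infer_params = RKLLMInferParam {
    ///     mode: RKLLMInferMode::InferGenerate,
    ///     lora_params: None,
    ///     prompt_cache_params: Some(RKLLMPromptCacheParam {
    ///         save_prompt_cache: true,
    ///         prompt_cache_path: "/tmp/prompt.cache".to_string(),
    ///     }),
    /// };
    /// ```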
    #[derive(Debug, Clone, Default)]
    pub struct RKLLMInferParam {
        /// Inference mode, such as generating text or getting the last hidden layer.
        pub mode: RKLLMInferMode,
        /// Optional name of a previously loaded Lora adapter to apply during inference.
        pub lora_params: Option<String>,
        /// Optional prompt cache parameters.
        pub prompt_cache_params: Option<RKLLMPromptCacheParam>,
    }

    /// Defines the inference mode for the LLM.
    #[derive(Debug, Copy, Clone, Default)]
    pub enum RKLLMInferMode {
        /// The LLM generates text based on the input. This is the default mode.
        #[default]
        InferGenerate = 0,
        /// The LLM retrieves the last hidden layer for further processing.
        InferGetLastHiddenLayer = 1,
    }

    impl From<RKLLMInferMode> for u32 {
        /// Converts the enum variant to its underlying `u32` value.
        fn from(mode: RKLLMInferMode) -> u32 {
            mode as u32
        }
    }

    /// Structure to define parameters for caching prompts.
    #[derive(Debug, Clone)]
    pub struct RKLLMPromptCacheParam {
        /// Indicates whether to save the prompt cache. If `true`, the cache is saved.
        pub save_prompt_cache: bool,
        /// Path to the prompt cache file.
        pub prompt_cache_path: String,
    }

    impl Default for super::RKLLMParam {
        /// Creates a default `RKLLMParam` by calling the underlying C function.
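        ///
        /// # Example
        ///
        /// A minimal sketch, assuming this crate is built as `rkllm_rs`:
        ///
        /// ```no_run
        /// use rkllm_rs::prelude::*;
        ///
        /// // Start from the library defaults, then adjust fields as needed.
        /// let param = RKLLMParam::default();
        /// ```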
        fn default() -> Self {
            unsafe { super::rkllm_createDefaultParam() }
        }
    }

    /// Represents the result of an LLM inference.
    #[derive(Debug, Clone)]
    pub struct RKLLMResult {
        /// The generated text from the LLM.
        pub text: String,
        /// The ID of the generated token.
        pub token_id: i32,
        /// The last hidden layer's states, if requested during inference.
        pub last_hidden_layer: RKLLMResultLastHiddenLayer,
    }

    /// Configuration for a Lora adapter to load into the LLM.
    #[derive(Debug, Clone)]
    pub struct RKLLMLoraAdapter {
        /// Path to the Lora adapter file.
        pub lora_adapter_path: String,
        /// Name under which the adapter is registered.
        pub lora_adapter_name: String,
        /// Scale factor applied to the adapter's weights.
        pub scale: f32,
    }

    /// Handle to an LLM instance.
    #[derive(Clone, Debug, Copy)]
    pub struct LLMHandle {
        handle: super::LLMHandle,
    }

    unsafe impl Send for LLMHandle {} // Asserts that the handle is safe to send across threads.
    unsafe impl Sync for LLMHandle {} // Asserts that the handle is safe to share across threads.

    /// Trait for handling callbacks from LLM operations.
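    ///
    /// # Example
    ///
    /// A minimal sketch of a handler that prints streamed tokens; the type name
    /// `PrintHandler` is hypothetical:
    ///
    /// ```no_run
    /// use rkllm_rs::prelude::*;
    ///
    /// struct PrintHandler;
    ///
    /// impl RkllmCallbackHandler for PrintHandler {
    ///     fn handle(&mut self, result: Option<RKLLMResult>, state: LLMCallState) {
    ///         if let Some(result) = result {
    ///             print!("{}", result.text);
    ///         }
    ///         if state == LLMCallState::Finish {
    ///             println!();
    ///         }
    ///     }
    /// }
    /// ```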
    pub trait RkllmCallbackHandler {
        /// Handles the result and state of an LLM call.
        fn handle(&mut self, result: Option<RKLLMResult>, state: LLMCallState);
    }

    /// Internal structure to hold the callback handler.
    pub struct InstanceData {
        /// The callback handler wrapped in `Arc` and `Mutex` for thread safety.
        pub callback_handler: Arc<Mutex<dyn RkllmCallbackHandler + Send + Sync>>,
    }

    impl LLMHandle {
        /// Destroys the LLM instance and releases its resources.
        pub fn destroy(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            let ret = unsafe { super::rkllm_destroy(self.handle) };

            if ret == 0 {
                Ok(())
            } else {
                Err(Box::new(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("rkllm_destroy returned non-zero: {}", ret),
                )))
            }
        }

        /// Runs an LLM inference task asynchronously.
        ///
        /// # Parameters
        /// - `rkllm_input`: The input data for the LLM.
        /// - `rkllm_infer_params`: Optional parameters for the inference task.
        /// - `user_data`: The callback handler to process the results.
        ///
        /// # Returns
        /// Returns `Ok(())` once the task has been submitted successfully; generated
        /// results are delivered through the provided callback handler rather than
        /// the return value.
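        ///
        /// # Example
        ///
        /// A minimal sketch, assuming `handle` came from [`rkllm_init`] and
        /// `PrintHandler` is a hypothetical type implementing
        /// [`RkllmCallbackHandler`]:
        ///
        /// ```ignore
        /// handle.run(
        ///     RKLLMInput::Prompt("Hello, how are you?".to_string()),
        ///     Some(RKLLMInferParam::default()),
        ///     PrintHandler,
        /// )?;
        /// ```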
        pub fn run(
            &self,
            rkllm_input: RKLLMInput,
            rkllm_infer_params: Option<RKLLMInferParam>,
            user_data: impl RkllmCallbackHandler + Send + Sync + 'static,
        ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            let instance_data = Arc::new(InstanceData {
                callback_handler: Arc::new(Mutex::new(user_data)),
            });

            // Hand one strong count to the C library. The callback re-borrows it on
            // every invocation, and this count is intentionally never released, so
            // the handler stays alive for all asynchronous callbacks.
            let userdata_ptr = Arc::into_raw(instance_data) as *mut std::ffi::c_void;

            let prompt_cstring;
            let mut input = match rkllm_input {
                RKLLMInput::Prompt(prompt) => {
                    prompt_cstring = std::ffi::CString::new(prompt).unwrap();
                    super::RKLLMInput {
                        input_type: super::RKLLMInputType_RKLLM_INPUT_PROMPT,
                        __bindgen_anon_1: super::RKLLMInput__bindgen_ty_1 {
                            prompt_input: prompt_cstring.as_ptr(),
                        },
                    }
                }
                RKLLMInput::Token(_) => todo!(),
                RKLLMInput::Embed(_) => todo!(),
                RKLLMInput::Multimodal(_) => todo!(),
            };

            // These locals must outlive the rkllm_run call below: the C-visible
            // parameter structs, and the strings they point into, are referenced by
            // raw pointer, so they cannot be temporaries that drop at the end of a
            // statement.
            let lora_name_cstring;
            let mut loraparam;
            let prompt_cache_cstring;
            let mut cacheparam;
            let mut inferparam;

            let new_rkllm_infer_params: *mut super::RKLLMInferParam =
                if let Some(rkllm_infer_params) = rkllm_infer_params {
                    let lora_params: *mut RKLLMLoraParam = match rkllm_infer_params.lora_params {
                        Some(name) => {
                            // The adapter name must be passed as a NUL-terminated C string.
                            lora_name_cstring = std::ffi::CString::new(name).unwrap();
                            loraparam = RKLLMLoraParam {
                                lora_adapter_name: lora_name_cstring.as_ptr(),
                            };
                            &mut loraparam
                        }
                        None => null_mut(),
                    };

                    let prompt_cache_params: *mut super::RKLLMPromptCacheParam =
                        if let Some(cache_params) = rkllm_infer_params.prompt_cache_params {
                            prompt_cache_cstring =
                                std::ffi::CString::new(cache_params.prompt_cache_path).unwrap();
                            cacheparam = super::RKLLMPromptCacheParam {
                                save_prompt_cache: if cache_params.save_prompt_cache {
                                    1
                                } else {
                                    0
                                },
                                prompt_cache_path: prompt_cache_cstring.as_ptr(),
                            };
                            &mut cacheparam
                        } else {
                            null_mut()
                        };

                    inferparam = super::RKLLMInferParam {
                        mode: rkllm_infer_params.mode.into(),
                        lora_params,
                        prompt_cache_params,
                    };
                    &mut inferparam
                } else {
                    null_mut()
                };

            let ret = unsafe {
                super::rkllm_run(
                    self.handle,
                    &mut input,
                    new_rkllm_infer_params,
                    userdata_ptr,
                )
            };
            if ret == 0 {
                Ok(())
            } else {
                Err(Box::new(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("rkllm_run returned non-zero: {}", ret),
                )))
            }
        }

        /// Loads a prompt cache from a file.
        ///
        /// # Parameters
        /// - `cache_path`: The path to the prompt cache file.
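        ///
        /// # Example
        ///
        /// A minimal sketch; the cache path is hypothetical:
        ///
        /// ```ignore
        /// handle.load_prompt_cache("/tmp/prompt.cache")?;
        /// ```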
        pub fn load_prompt_cache(
            &self,
            cache_path: &str,
        ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            let prompt_cache_path = std::ffi::CString::new(cache_path).unwrap();
            let ret =
                unsafe { super::rkllm_load_prompt_cache(self.handle, prompt_cache_path.as_ptr()) };
            if ret == 0 {
                Ok(())
            } else {
                Err(Box::new(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("rkllm_load_prompt_cache returned non-zero: {}", ret),
                )))
            }
        }

        /// Releases the prompt cache, freeing the resources it holds.
        pub fn release_prompt_cache(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            let ret = unsafe { super::rkllm_release_prompt_cache(self.handle) };
            if ret == 0 {
                Ok(())
            } else {
                Err(Box::new(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("rkllm_release_prompt_cache returned non-zero: {}", ret),
                )))
            }
        }

        /// Aborts an ongoing LLM task.
        pub fn abort(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            let ret = unsafe { super::rkllm_abort(self.handle) };
            if ret == 0 {
                Ok(())
            } else {
                Err(Box::new(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("rkllm_abort returned non-zero: {}", ret),
                )))
            }
        }

        /// Checks whether an LLM task is currently running.
        ///
        /// Returns `Ok(())` when the underlying C call reports a running task (i.e.
        /// it returns zero); otherwise returns an error.
        pub fn is_running(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            let ret = unsafe { super::rkllm_is_running(self.handle) };
            if ret == 0 {
                Ok(())
            } else {
                Err(Box::new(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("rkllm_is_running returned non-zero: {}", ret),
                )))
            }
        }

        /// Loads a Lora adapter into the LLM.
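        ///
        /// # Example
        ///
        /// A minimal sketch; the adapter path, name, and scale are hypothetical:
        ///
        /// ```ignore
        /// handle.load_lora(&RKLLMLoraAdapter {
        ///     lora_adapter_path: "/path/to/adapter.rkllm".to_string(),
        ///     lora_adapter_name: "my_adapter".to_string(),
        ///     scale: 1.0,
        /// })?;
        /// ```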
        pub fn load_lora(
            &self,
            lora_cfg: &RKLLMLoraAdapter,
        ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            let lora_adapter_name_cstring =
                std::ffi::CString::new(lora_cfg.lora_adapter_name.as_str()).unwrap();
            let lora_adapter_path_cstring =
                std::ffi::CString::new(lora_cfg.lora_adapter_path.as_str()).unwrap();
            let mut param = super::RKLLMLoraAdapter {
                lora_adapter_path: lora_adapter_path_cstring.as_ptr(),
                lora_adapter_name: lora_adapter_name_cstring.as_ptr(),
                scale: lora_cfg.scale,
            };
            let ret = unsafe { super::rkllm_load_lora(self.handle, &mut param) };
            if ret == 0 {
                Ok(())
            } else {
                Err(Box::new(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("rkllm_load_lora returned non-zero: {}", ret),
                )))
            }
        }
    }

    /// Internal callback function that bridges results from the C library to the
    /// registered Rust handler.
    unsafe extern "C" fn callback_passthrough(
        result: *mut super::RKLLMResult,
        userdata: *mut ::std::os::raw::c_void,
        state: super::LLMCallState,
    ) {
        // Bump the strong count before reconstructing the Arc so that dropping it
        // at the end of this callback does not free the InstanceData.
        Arc::increment_strong_count(userdata as *const InstanceData);
        let instance_data = unsafe { Arc::from_raw(userdata as *const InstanceData) };
        let new_state = match state {
            0 => LLMCallState::Normal,
            1 => LLMCallState::Waiting,
            2 => LLMCallState::Finish,
            3 => LLMCallState::Error,
            4 => LLMCallState::GetLastHiddenLayer,
            _ => panic!("Unexpected LLMCallState"),
        };

        let new_result = if result.is_null() {
            None
        } else {
            Some(RKLLMResult {
                text: if (*result).text.is_null() {
                    String::new()
                } else {
                    unsafe { CStr::from_ptr((*result).text) }
                        .to_string_lossy()
                        .into_owned()
                },
                token_id: (*result).token_id,
                last_hidden_layer: (*result).last_hidden_layer,
            })
        };

        instance_data
            .callback_handler
            .lock()
            .unwrap()
            .handle(new_result, new_state);
    }

    /// Initializes the LLM with the given parameters.
    ///
    /// # Parameters
    /// - `param`: A pointer to the LLM configuration parameters.
    ///
    /// # Returns
    /// If successful, returns a `Result` containing the `LLMHandle`; otherwise, returns an error.
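    ///
    /// # Example
    ///
    /// A minimal sketch; the model path is hypothetical, and `model_path` is
    /// assumed to be a field of the bindgen-generated `RKLLMParam` (as declared
    /// in rkllm.h):
    ///
    /// ```ignore
    /// use rkllm_rs::prelude::*;
    ///
    /// let mut param = RKLLMParam::default();
    /// let model_path = std::ffi::CString::new("/path/to/model.rkllm").unwrap();
    /// param.model_path = model_path.as_ptr();
    /// let handle = rkllm_init(&mut param)?;
    /// ```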
    pub fn rkllm_init(
        param: *mut super::RKLLMParam,
    ) -> Result<LLMHandle, Box<dyn std::error::Error + Send + Sync>> {
        let mut handle = LLMHandle {
            handle: std::ptr::null_mut(),
        };

        let callback: Option<
            unsafe extern "C" fn(
                *mut super::RKLLMResult,
                *mut ::std::os::raw::c_void,
                super::LLMCallState,
            ),
        > = Some(callback_passthrough);
        let ret = unsafe { super::rkllm_init(&mut handle.handle, param, callback) };
        if ret == 0 {
            Ok(handle)
        } else {
            Err(Box::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("rkllm_init returned non-zero: {}", ret),
            )))
        }
    }

    /// Represents different types of input that can be provided to the LLM.
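    ///
    /// Only the `Prompt` variant is currently handled by [`LLMHandle::run`]; the
    /// others hit a `todo!()`. A minimal sketch, assuming this crate is built as
    /// `rkllm_rs`:
    ///
    /// ```no_run
    /// use rkllm_rs::prelude::*;
    ///
    /// let input = RKLLMInput::Prompt("What is an NPU?".to_string());
    /// ```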
    pub enum RKLLMInput {
        /// Input is a text prompt.
        Prompt(String),
        /// Input is a sequence of tokens (not yet implemented by [`LLMHandle::run`]).
        Token(String),
        /// Input is an embedding vector (not yet implemented by [`LLMHandle::run`]).
        Embed(String),
        /// Input is multimodal, such as text and image (not yet implemented by [`LLMHandle::run`]).
        Multimodal(String),
    }
}