rkllm_rs/lib.rs

#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]

include!(concat!(env!("OUT_DIR"), "/bindings.rs"));

pub mod prelude {
    use std::ffi::CStr;
    use std::ptr::null_mut;
    use std::sync::Arc;
    use std::sync::Mutex;

    pub use super::RKLLMExtendParam;
    pub use super::RKLLMLoraParam;
    pub use super::RKLLMParam;
    pub use super::RKLLMResultLastHiddenLayer;

    /// Represents the state of an LLM call.
    #[derive(Debug, PartialEq, Eq)]
    pub enum LLMCallState {
        /// The LLM call is in a normal running state.
        Normal = 0,
        /// The LLM call is waiting for a complete UTF-8 encoded character.
        Waiting = 1,
        /// The LLM call has finished execution.
        Finish = 2,
        /// An error occurred during the LLM call.
        Error = 3,
        /// Retrieve the last hidden layer during inference.
        GetLastHiddenLayer = 4,
    }

    /// Whether to keep the conversation history across calls.
    #[derive(Debug, Clone, Default)]
    pub enum KeepHistory {
        /// Do not keep the history of the conversation.
        #[default]
        NoKeepHistory = 0,
        /// Keep the history of the conversation.
        KeepHistory = 1,
    }

    /// Structure for defining parameters during inference.
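    ///
    /// # Example
    /// A minimal sketch of a parameter set that saves the prompt cache to a
    /// hypothetical path and keeps conversation history across calls:
    /// ```
    /// use rkllm_rs::prelude::*;
    ///
    /// let params = RKLLMInferParam {
    ///     mode: RKLLMInferMode::InferGenerate,
    ///     lora_params: None,
    ///     prompt_cache_params: Some(RKLLMPromptCacheParam {
    ///         save_prompt_cache: true,
    ///         prompt_cache_path: "/tmp/prompt.cache".to_string(),
    ///     }),
    ///     keep_history: KeepHistory::KeepHistory,
    /// };
    /// ```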
    #[derive(Debug, Clone, Default)]
    pub struct RKLLMInferParam {
        /// Inference mode, such as generating text or getting the last hidden layer.
        pub mode: RKLLMInferMode,
        /// Optional name of a previously loaded Lora adapter to apply during inference.
        pub lora_params: Option<String>,
        /// Optional prompt cache parameters.
        pub prompt_cache_params: Option<RKLLMPromptCacheParam>,
        /// Whether to keep the conversation history across calls.
        pub keep_history: KeepHistory,
    }

    /// Defines the inference mode for the LLM.
    #[derive(Debug, Copy, Clone, Default)]
    pub enum RKLLMInferMode {
        /// The LLM generates text based on the input. This is the default mode.
        #[default]
        InferGenerate = 0,
        /// The LLM retrieves the last hidden layer for further processing.
        InferGetLastHiddenLayer = 1,
    }

    impl From<RKLLMInferMode> for u32 {
        /// Converts the enum variant to its underlying u32 value.
        fn from(mode: RKLLMInferMode) -> u32 {
            mode as u32
        }
    }

    /// Structure to define parameters for caching prompts.
    #[derive(Debug, Clone)]
    pub struct RKLLMPromptCacheParam {
        /// Indicates whether to save the prompt cache. If `true`, the cache is saved.
        pub save_prompt_cache: bool,
        /// Path to the prompt cache file.
        pub prompt_cache_path: String,
    }

    impl Default for super::RKLLMParam {
        /// Creates a default `RKLLMParam` by calling the underlying C function.
        fn default() -> Self {
            unsafe { super::rkllm_createDefaultParam() }
        }
    }

    /// Represents the result of an LLM inference.
    #[derive(Debug, Clone)]
    pub struct RKLLMResult {
        /// The generated text from the LLM.
        pub text: String,
        /// The ID of the generated token.
        pub token_id: i32,
        /// The last hidden layer's states, if requested during inference.
        pub last_hidden_layer: RKLLMResultLastHiddenLayer,
    }

    /// Describes a Lora adapter to be loaded with [`LLMHandle::load_lora`].
    #[derive(Debug, Clone)]
    pub struct RKLLMLoraAdapter {
        /// Path to the Lora adapter file.
        pub lora_adapter_path: String,
        /// Name under which the adapter is registered; referenced later via
        /// `RKLLMInferParam::lora_params`.
        pub lora_adapter_name: String,
        /// Scaling factor applied to the adapter weights.
        pub scale: f32,
    }

    /// Handle to an LLM instance.
    #[derive(Clone, Debug, Copy)]
    pub struct LLMHandle {
        handle: super::LLMHandle,
    }

    unsafe impl Send for LLMHandle {} // Asserts that the handle is safe to send across threads.
    unsafe impl Sync for LLMHandle {} // Asserts that the handle is safe to share across threads.

    /// Trait for handling callbacks from LLM operations.
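    ///
    /// # Example
    /// A minimal sketch of a handler that streams generated text to stdout;
    /// the `PrintHandler` type is illustrative, not part of this crate.
    /// ```no_run
    /// use rkllm_rs::prelude::*;
    ///
    /// struct PrintHandler;
    ///
    /// impl RkllmCallbackHandler for PrintHandler {
    ///     fn handle(&mut self, result: Option<RKLLMResult>, state: LLMCallState) {
    ///         // Print each text fragment as it arrives.
    ///         if let Some(r) = result {
    ///             print!("{}", r.text);
    ///         }
    ///         // Terminate the line once generation finishes.
    ///         if state == LLMCallState::Finish {
    ///             println!();
    ///         }
    ///     }
    /// }
    /// ```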
    pub trait RkllmCallbackHandler {
        /// Handles the result and state of an LLM call.
        fn handle(&mut self, result: Option<RKLLMResult>, state: LLMCallState);
    }

    /// Internal structure to hold the callback handler.
    pub struct InstanceData {
        /// The callback handler wrapped in `Arc` and `Mutex` for thread safety.
        pub callback_handler: Arc<Mutex<dyn RkllmCallbackHandler + Send + Sync>>,
    }

    impl LLMHandle {
        /// Destroys the LLM instance and releases its resources.
        pub fn destroy(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            let ret = unsafe { super::rkllm_destroy(self.handle) };
            if ret == 0 {
                Ok(())
            } else {
                Err(Box::new(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("rkllm_destroy returned non-zero: {}", ret),
                )))
            }
        }

        /// Runs an LLM inference task, delivering results through the provided
        /// callback handler.
        ///
        /// # Parameters
        /// - `rkllm_input`: The input data for the LLM.
        /// - `rkllm_infer_params`: Optional parameters for the inference task.
        /// - `user_data`: The callback handler that processes the results.
        ///
        /// # Returns
        /// `Ok(())` on success; generated tokens and state changes are passed to
        /// the callback handler rather than returned directly.
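        ///
        /// # Example
        /// A minimal sketch, assuming an initialized `handle` and a handler type
        /// like the `PrintHandler` shown on [`RkllmCallbackHandler`]:
        /// ```no_run
        /// # use rkllm_rs::prelude::*;
        /// # fn demo(handle: LLMHandle, handler: impl RkllmCallbackHandler + Send + Sync + 'static) {
        /// let input = RKLLMInput {
        ///     input_type: RKLLMInputType::Prompt("Hello!".to_string()),
        ///     enable_thinking: false,
        ///     role: RKLLMInputRole::User,
        /// };
        /// handle.run(input, Some(RKLLMInferParam::default()), handler).unwrap();
        /// # }
        /// ```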
        pub fn run(
            &self,
            rkllm_input: RKLLMInput,
            rkllm_infer_params: Option<RKLLMInferParam>,
            user_data: impl RkllmCallbackHandler + Send + Sync + 'static,
        ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            let instance_data = Arc::new(InstanceData {
                callback_handler: Arc::new(Mutex::new(user_data)),
            });

            // The Arc is handed to the C library as a raw pointer; the callback
            // balances its `Arc::from_raw` with an increment, so this initial
            // reference is intentionally kept alive for the C side.
            let userdata_ptr = Arc::into_raw(instance_data) as *mut std::ffi::c_void;
            let prompt_cstring;
            let role_cstring;
            let mut input = match rkllm_input.input_type {
                RKLLMInputType::Prompt(prompt) => {
                    prompt_cstring = std::ffi::CString::new(prompt).unwrap();

                    // The role string must be NUL-terminated before it is passed
                    // to C, so it goes through `CString` rather than `str::as_ptr`.
                    role_cstring = std::ffi::CString::new(match rkllm_input.role {
                        RKLLMInputRole::User => "user",
                        RKLLMInputRole::Tool => "tool",
                    })
                    .unwrap();

                    super::RKLLMInput {
                        input_type: super::RKLLMInputType_RKLLM_INPUT_PROMPT,
                        enable_thinking: rkllm_input.enable_thinking,
                        role: role_cstring.as_ptr(),
                        __bindgen_anon_1: super::RKLLMInput__bindgen_ty_1 {
                            prompt_input: prompt_cstring.as_ptr(),
                        },
                    }
                }
                RKLLMInputType::Token(_) => todo!(),
                RKLLMInputType::Embed(_) => todo!(),
                RKLLMInputType::Multimodal(_) => todo!(),
            };

            let prompt_cache_cstring;
            let lora_adapter_name_cstring;
            let mut loraparam;
            let mut cacheparam;
            // The structs passed to C must outlive the `rkllm_run` call, so they
            // are stored in locals here rather than as pointers to temporaries
            // that would be dropped at the end of the statement.
            let mut infer_params_storage;

            let new_rkllm_infer_params: *mut super::RKLLMInferParam =
                if let Some(rkllm_infer_params) = rkllm_infer_params {
                    infer_params_storage = super::RKLLMInferParam {
                        keep_history: rkllm_infer_params.keep_history as i32,
                        mode: rkllm_infer_params.mode.into(),
                        lora_params: match rkllm_infer_params.lora_params {
                            Some(name) => {
                                // NUL-terminate the adapter name for C.
                                lora_adapter_name_cstring =
                                    std::ffi::CString::new(name).unwrap();
                                loraparam = RKLLMLoraParam {
                                    lora_adapter_name: lora_adapter_name_cstring.as_ptr(),
                                };
                                &mut loraparam
                            }
                            None => null_mut(),
                        },
                        prompt_cache_params: if let Some(cache_params) =
                            rkllm_infer_params.prompt_cache_params
                        {
                            prompt_cache_cstring =
                                std::ffi::CString::new(cache_params.prompt_cache_path).unwrap();
                            cacheparam = super::RKLLMPromptCacheParam {
                                save_prompt_cache: if cache_params.save_prompt_cache {
                                    1
                                } else {
                                    0
                                },
                                prompt_cache_path: prompt_cache_cstring.as_ptr(),
                            };
                            &mut cacheparam
                        } else {
                            null_mut()
                        },
                    };
                    &mut infer_params_storage
                } else {
                    null_mut()
                };

            let ret = unsafe {
                super::rkllm_run(
                    self.handle,
                    &mut input,
                    new_rkllm_infer_params,
                    userdata_ptr,
                )
            };
            if ret == 0 {
                Ok(())
            } else {
                Err(Box::new(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("rkllm_run returned non-zero: {}", ret),
                )))
            }
        }

        /// Loads a prompt cache from a file.
        ///
        /// # Parameters
        /// - `cache_path`: The path to the prompt cache file.
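        ///
        /// A minimal sketch, using a hypothetical cache file produced by an
        /// earlier run with `save_prompt_cache` enabled:
        /// ```no_run
        /// # use rkllm_rs::prelude::*;
        /// # fn demo(handle: LLMHandle) {
        /// handle.load_prompt_cache("/tmp/prompt.cache").unwrap();
        /// # }
        /// ```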
        pub fn load_prompt_cache(
            &self,
            cache_path: &str,
        ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            let prompt_cache_path = std::ffi::CString::new(cache_path).unwrap();
            let ret = unsafe {
                super::rkllm_load_prompt_cache(self.handle, prompt_cache_path.as_ptr())
            };
            if ret == 0 {
                Ok(())
            } else {
                Err(Box::new(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("rkllm_load_prompt_cache returned non-zero: {}", ret),
                )))
            }
        }

        /// Releases the prompt cache held by the LLM instance.
        pub fn release_prompt_cache(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            let ret = unsafe { super::rkllm_release_prompt_cache(self.handle) };
            if ret == 0 {
                Ok(())
            } else {
                Err(Box::new(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("rkllm_release_prompt_cache returned non-zero: {}", ret),
                )))
            }
        }

        /// Aborts an ongoing LLM task.
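        ///
        /// A minimal sketch, assuming a task was started from another thread:
        /// ```no_run
        /// # use rkllm_rs::prelude::*;
        /// # fn demo(handle: LLMHandle) {
        /// // `is_running` reports success while a task is in flight.
        /// if handle.is_running().is_ok() {
        ///     handle.abort().unwrap();
        /// }
        /// # }
        /// ```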
        pub fn abort(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            let ret = unsafe { super::rkllm_abort(self.handle) };
            if ret == 0 {
                Ok(())
            } else {
                Err(Box::new(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("rkllm_abort returned non-zero: {}", ret),
                )))
            }
        }

        /// Checks whether an LLM task is currently running.
        ///
        /// Returns `Ok(())` when the underlying `rkllm_is_running` call reports
        /// zero (a task is running); otherwise returns an error wrapping the
        /// non-zero status code.
        pub fn is_running(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            let ret = unsafe { super::rkllm_is_running(self.handle) };
            if ret == 0 {
                Ok(())
            } else {
                Err(Box::new(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("rkllm_is_running returned non-zero: {}", ret),
                )))
            }
        }

        /// Loads a Lora adapter into the LLM.
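        ///
        /// A minimal sketch, with a hypothetical adapter path and name; the same
        /// name is later passed via `RKLLMInferParam::lora_params` to apply the
        /// adapter during a run:
        /// ```no_run
        /// # use rkllm_rs::prelude::*;
        /// # fn demo(handle: LLMHandle) {
        /// let adapter = RKLLMLoraAdapter {
        ///     lora_adapter_path: "/path/to/adapter.rkllm".to_string(),
        ///     lora_adapter_name: "my_adapter".to_string(),
        ///     scale: 1.0,
        /// };
        /// handle.load_lora(&adapter).unwrap();
        /// # }
        /// ```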
        pub fn load_lora(
            &self,
            lora_cfg: &RKLLMLoraAdapter,
        ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            let lora_adapter_name_cstring =
                std::ffi::CString::new(lora_cfg.lora_adapter_name.clone()).unwrap();
            let lora_adapter_path_cstring =
                std::ffi::CString::new(lora_cfg.lora_adapter_path.clone()).unwrap();
            let mut param = super::RKLLMLoraAdapter {
                lora_adapter_path: lora_adapter_path_cstring.as_ptr(),
                lora_adapter_name: lora_adapter_name_cstring.as_ptr(),
                scale: lora_cfg.scale,
            };
            let ret = unsafe { super::rkllm_load_lora(self.handle, &mut param) };
            if ret == 0 {
                Ok(())
            } else {
                Err(Box::new(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    format!("rkllm_load_lora returned non-zero: {}", ret),
                )))
            }
        }
    }

    /// Internal callback that bridges results from the C library to the Rust
    /// callback handler.
    unsafe extern "C" fn callback_passthrough(
        result: *mut super::RKLLMResult,
        userdata: *mut ::std::os::raw::c_void,
        state: super::LLMCallState,
    ) -> i32 {
        // Balance the `Arc::from_raw` below with an extra strong count so the
        // handler data is not freed while the C library still holds the pointer.
        Arc::increment_strong_count(userdata as *const InstanceData);
        let instance_data = unsafe { Arc::from_raw(userdata as *const InstanceData) };
        let new_state = match state {
            0 => LLMCallState::Normal,
            1 => LLMCallState::Waiting,
            2 => LLMCallState::Finish,
            3 => LLMCallState::Error,
            4 => LLMCallState::GetLastHiddenLayer,
            _ => panic!("Unexpected LLMCallState"),
        };

        let new_result = if result.is_null() {
            None
        } else {
            Some(RKLLMResult {
                text: if (*result).text.is_null() {
                    String::new()
                } else {
                    unsafe { CStr::from_ptr((*result).text) }
                        .to_str()
                        .expect("Failed to convert C string")
                        .to_owned()
                },
                token_id: (*result).token_id,
                last_hidden_layer: (*result).last_hidden_layer,
            })
        };

        instance_data
            .callback_handler
            .lock()
            .unwrap()
            .handle(new_result, new_state);
        0
    }

    /// Initializes the LLM with the given parameters.
    ///
    /// # Parameters
    /// - `param`: A pointer to the LLM configuration parameters.
    ///
    /// # Returns
    /// If successful, returns a `Result` containing the `LLMHandle`; otherwise, returns an error.
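    ///
    /// # Example
    /// A minimal sketch, assuming a converted `.rkllm` model file at a
    /// hypothetical path; `model_path` is assumed to be a `*const c_char` field
    /// of the bindgen-generated `RKLLMParam`:
    /// ```no_run
    /// use rkllm_rs::prelude::*;
    ///
    /// // The CString must outlive the init call, since only a pointer is passed.
    /// let model_path = std::ffi::CString::new("/path/to/model.rkllm").unwrap();
    /// let mut param = RKLLMParam::default();
    /// param.model_path = model_path.as_ptr();
    /// let handle = rkllm_init(&mut param).expect("rkllm_init failed");
    /// handle.destroy().unwrap();
    /// ```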
    pub fn rkllm_init(
        param: *mut super::RKLLMParam,
    ) -> Result<LLMHandle, Box<dyn std::error::Error + Send + Sync>> {
        let mut handle = LLMHandle {
            handle: std::ptr::null_mut(),
        };

        let callback: Option<
            unsafe extern "C" fn(
                *mut super::RKLLMResult,
                *mut ::std::os::raw::c_void,
                super::LLMCallState,
            ) -> i32,
        > = Some(callback_passthrough);
        let ret = unsafe { super::rkllm_init(&mut handle.handle, param, callback) };
        if ret == 0 {
            Ok(handle)
        } else {
            Err(Box::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("rkllm_init returned non-zero: {}", ret),
            )))
        }
    }

    /// Represents an input to the LLM, combining the payload with its role and options.
    pub struct RKLLMInput {
        /// The type of input being provided to the LLM.
        pub input_type: RKLLMInputType,
        /// Whether to enable thinking during inference.
        pub enable_thinking: bool,
        /// The role of the party providing the input.
        pub role: RKLLMInputRole,
    }

    /// The type of input being provided to the LLM.
    pub enum RKLLMInputType {
        /// Input is a text prompt.
        Prompt(String),
        /// Input is a sequence of tokens (not yet implemented in `run`).
        Token(String),
        /// Input is an embedding vector (not yet implemented in `run`).
        Embed(String),
        /// Input is multimodal, such as text and image (not yet implemented in `run`).
        Multimodal(String),
    }

    /// The role of the party providing the input.
    pub enum RKLLMInputRole {
        /// Input comes from the user.
        User,
        /// Input comes from a tool.
        Tool,
    }
}