// rkllm_rs/lib.rs

1#![allow(non_upper_case_globals)]
2#![allow(non_camel_case_types)]
3#![allow(non_snake_case)]
4
5include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
6
7pub mod prelude {
8    use std::ffi::CStr;
9    use std::ptr::null_mut;
10    use std::sync::Arc;
11    use std::sync::Mutex;
12
13    pub use super::RKLLMExtendParam;
14    pub use super::RKLLMLoraParam;
15    pub use super::RKLLMParam;
16    pub use super::RKLLMResultLastHiddenLayer;
17
18    #[derive(Debug, PartialEq, Eq)]
19    pub enum LLMCallState {
20        #[doc = "< The LLM call is in a normal running state."]
21        Normal = 0,
22        #[doc = "< The LLM call is waiting for complete UTF-8 encoded character."]
23        Waiting = 1,
24        #[doc = "< The LLM call has finished execution."]
25        Finish = 2,
26        #[doc = "< An error occurred during the LLM call."]
27        Error = 3,
28        #[doc = "< Retrieve the last hidden layer during inference."]
29        GetLastHiddenLayer = 4,
30    }
31
32    #[doc = " @struct RKLLMInferParam\n @brief Structure for defining parameters during inference."]
33    #[derive(Debug, Clone, Default)]
34    pub struct RKLLMInferParam {
35        #[doc = "< Inference mode (e.g., generate or get last hidden layer)."]
36        pub mode: RKLLMInferMode,
37        #[doc = "< Pointer to Lora adapter parameters."]
38        pub lora_params: Option<String>,
39        #[doc = "< Pointer to prompt cache parameters."]
40        pub prompt_cache_params: Option<RKLLMPromptCacheParam>,
41    }
42
43    #[derive(Debug, Copy, Clone, Default)]
44    pub enum RKLLMInferMode {
45        #[doc = "< The LLM generates text based on input."]
46        #[default]
47        InferGenerate = 0,
48        #[doc = "< The LLM retrieves the last hidden layer for further processing."]
49        InferGetLastHiddenLayer = 1,
50    }
51    impl Into<u32> for RKLLMInferMode {
52        fn into(self) -> u32 {
53            self as u32
54        }
55    }
56
57    #[doc = " @struct RKLLMPromptCacheParam\n @brief Structure to define parameters for caching prompts."]
58    #[derive(Debug, Clone)]
59    pub struct RKLLMPromptCacheParam {
60        #[doc = "< Flag to indicate whether to save the prompt cache (0 = don't save, 1 = save)."]
61        pub save_prompt_cache: bool,
62        #[doc = "< Path to the prompt cache file."]
63        pub prompt_cache_path: String,
64    }
65
    // Default parameters come from the C library itself rather than Rust-side
    // field defaults, so they always match the linked rkllm version.
    impl Default for super::RKLLMParam {
        fn default() -> Self {
            // SAFETY: rkllm_createDefaultParam takes no arguments and returns
            // the struct by value; assumed infallible on the C side —
            // TODO(review): confirm against rkllm.h.
            unsafe { super::rkllm_createDefaultParam() }
        }
    }
71
72    #[doc = " @struct RKLLMResult\n @brief Structure to represent the result of LLM inference."]
73    #[derive(Debug, Clone)]
74    pub struct RKLLMResult {
75        #[doc = "< Generated text result."]
76        pub text: String,
77        #[doc = "< ID of the generated token."]
78        pub token_id: i32,
79        #[doc = "< Hidden states of the last layer (if requested)."]
80        pub last_hidden_layer: RKLLMResultLastHiddenLayer,
81    }
82
    #[doc = " @struct LLMHandle\n @brief LLMHandle."]
    #[derive(Clone, Debug, Copy)]
    pub struct LLMHandle {
        // Opaque handle produced by rkllm_init and passed back to every
        // rkllm_* FFI call; the pointed-to state is owned by the C library.
        handle: super::LLMHandle,
    }

    // SAFETY: the wrapped value is an opaque pointer that this crate only
    // ever hands back to the C library. NOTE(review): these impls assume the
    // rkllm C API may be called from any thread — confirm against the
    // library's documentation.
    unsafe impl Send for LLMHandle {} // Asserts the pointer is safe to send
    unsafe impl Sync for LLMHandle {} // Asserts the pointer is safe to share
91
    /// Receiver for inference callbacks streamed from the C library.
    pub trait RkllmCallbackHandler {
        /// Called once per C callback invocation. `result` is `None` when the
        /// C side passed a null result pointer; `state` reports the call
        /// state (normal / waiting / finished / error / hidden-layer).
        fn handle(&mut self, result: Option<RKLLMResult>, state: LLMCallState);
    }
95
    /// Per-run state smuggled across the FFI boundary as the `userdata`
    /// pointer (via `Arc::into_raw`) and unpacked again in the C callback.
    pub struct InstanceData {
        // Mutex-guarded so the callback can take `&mut` access to the handler.
        pub callback_handler: Arc<Mutex<dyn RkllmCallbackHandler + Send + Sync>>,
    }
99
100    impl LLMHandle {
101        #[doc = " @brief Destroys the LLM instance and releases resources.\n @param handle LLM handle.\n @return Status code (0 for success, non-zero for failure)."]
102        pub fn destroy(&self) -> i32 {
103            unsafe { super::rkllm_destroy(self.handle) }
104        }
105
106        #[doc = " @brief Runs an LLM inference task asynchronously.\n @param handle LLM handle.\n @param rkllm_input Input data for the LLM.\n @param rkllm_infer_params Parameters for the inference task.\n @param userdata Pointer to user data for the callback.\n @return Status code (0 for success, non-zero for failure)."]
107        pub fn run(
108            &self,
109            rkllm_input: RKLLMInput,
110            rkllm_infer_params: Option<RKLLMInferParam>,
111            user_data: impl RkllmCallbackHandler + Send + Sync + 'static,
112        ) {
113            let instance_data = Arc::new(InstanceData {
114                callback_handler: Arc::new(Mutex::new(user_data)),
115            });
116
117            let userdata_ptr = Arc::into_raw(instance_data) as *mut std::ffi::c_void;
118            let prompt_cstring;
119            let prompt_cstring_ptr;
120            let mut input = match rkllm_input {
121                RKLLMInput::Prompt(prompt) => {
122                    prompt_cstring = std::ffi::CString::new(prompt).unwrap();
123                    prompt_cstring_ptr = prompt_cstring.as_ptr() as *const std::os::raw::c_char;
124                    super::RKLLMInput {
125                        input_type: super::RKLLMInputType_RKLLM_INPUT_PROMPT,
126                        __bindgen_anon_1: super::RKLLMInput__bindgen_ty_1 {
127                            prompt_input: prompt_cstring_ptr,
128                        },
129                    }
130                }
131                RKLLMInput::Token(_) => todo!(),
132                RKLLMInput::Embed(_) => todo!(),
133                RKLLMInput::Multimodal(_) => todo!(),
134            };
135
136            let prompt_cache_cstring;
137            let prompt_cache_cstring_ptr;
138
139            let lora_adapter_name;
140            let lora_adapter_name_ptr;
141            let mut loraparam;
142
143            let new_rkllm_infer_params: *mut super::RKLLMInferParam =
144                if let Some(rkllm_infer_params) = rkllm_infer_params {
145                    &mut super::RKLLMInferParam {
146                        mode: rkllm_infer_params.mode.into(),
147                        lora_params: match rkllm_infer_params.lora_params {
148                            Some(a) => {
149                                lora_adapter_name = a;
150                                lora_adapter_name_ptr = lora_adapter_name.as_ptr() as *const std::os::raw::c_char;
151                                loraparam = RKLLMLoraParam{
152                                    lora_adapter_name: lora_adapter_name_ptr
153                                };
154                                &mut loraparam
155                            }
156                            None => null_mut(),
157                        },
158                        prompt_cache_params: if let Some(cache_params) =
159                            rkllm_infer_params.prompt_cache_params
160                        {
161                            prompt_cache_cstring =
162                                std::ffi::CString::new(cache_params.prompt_cache_path).unwrap();
163                            prompt_cache_cstring_ptr = prompt_cache_cstring.as_ptr() as *const std::os::raw::c_char;
164
165                            &mut super::RKLLMPromptCacheParam {
166                                save_prompt_cache: if cache_params.save_prompt_cache {
167                                    1
168                                } else {
169                                    0
170                                },
171                                prompt_cache_path: prompt_cache_cstring_ptr,
172                            }
173                        } else {
174                            null_mut()
175                        },
176                    }
177                } else {
178                    null_mut()
179                };
180
181            unsafe {
182                super::rkllm_run(
183                    self.handle,
184                    &mut input,
185                    new_rkllm_infer_params,
186                    userdata_ptr,
187                )
188            };
189        }
190
191        #[doc = " @brief Loads a prompt cache from a file.\n @param handle LLM handle.\n @param prompt_cache_path Path to the prompt cache file.\n @return Status code (0 for success, non-zero for failure)."]
192        pub fn load_prompt_cache(&self, cache_path: &str) {
193            let prompt_cache_path = std::ffi::CString::new(cache_path).unwrap();
194            let prompt_cache_path_ptr = prompt_cache_path.as_ptr() as *const std::os::raw::c_char;
195            unsafe { super::rkllm_load_prompt_cache(self.handle, prompt_cache_path_ptr) };
196        }
197    }
198
199    unsafe extern "C" fn callback_passtrough(
200        result: *mut super::RKLLMResult,
201        userdata: *mut ::std::os::raw::c_void,
202        state: super::LLMCallState,
203    ) {
204        Arc::increment_strong_count(userdata); // 我們沒有真的要free掉它
205        let instance_data = unsafe { Arc::from_raw(userdata as *const InstanceData) };
206        let new_state = match state {
207            0 => LLMCallState::Normal,
208            1 => LLMCallState::Waiting,
209            2 => LLMCallState::Finish,
210            3 => LLMCallState::Error,
211            4 => LLMCallState::GetLastHiddenLayer,
212            _ => panic!("Not expect LLMCallState"),
213        };
214
215        let new_result = if result.is_null() {
216            None
217        } else {
218            Some(RKLLMResult {
219                text: if (*result).text.is_null() {
220                    String::new()
221                } else {
222                    (unsafe { CStr::from_ptr((*result).text) })
223                        .to_str()
224                        .expect("Convert cstr failed")
225                        .to_owned()
226                        .clone()
227                },
228                token_id: (*result).token_id,
229                last_hidden_layer: (*result).last_hidden_layer,
230            })
231        };
232
233        instance_data.callback_handler.lock().unwrap().handle(new_result, new_state);
234    }
235
236    #[doc = " @brief Initializes the LLM with the given parameters.\n @param handle Pointer to the LLM handle.\n @param param Configuration parameters for the LLM.\n @param callback Callback function to handle LLM results.\n @return Status code (0 for success, non-zero for failure)."]
237    pub fn rkllm_init(
238        param: *mut super::RKLLMParam,
239    ) -> Result<LLMHandle, Box<dyn std::error::Error + Send + Sync>> {
240        let mut handle = LLMHandle {
241            handle: std::ptr::null_mut(),
242        };
243
244        let callback: Option<
245            unsafe extern "C" fn(
246                *mut super::RKLLMResult,
247                *mut ::std::os::raw::c_void,
248                super::LLMCallState,
249            ),
250        > = Some(callback_passtrough);
251        let ret = unsafe { super::rkllm_init(&mut handle.handle, param, callback) };
252        if ret == 0 {
253            return Ok(handle);
254        } else {
255            return Err(Box::new(std::io::Error::new(
256                std::io::ErrorKind::Other,
257                format!("rkllm_init ret non zero: {}", ret),
258            )));
259        }
260    }
261
262    pub enum RKLLMInput {
263        #[doc = "< Input is a text prompt."]
264        Prompt(String),
265        #[doc = "< Input is a sequence of tokens."]
266        Token(String),
267        #[doc = "< Input is an embedding vector."]
268        Embed(String),
269        #[doc = "< Input is multimodal (e.g., text and image)."]
270        Multimodal(String),
271    }
272}