// The generated FFI bindings keep C naming (e.g. `RKLLMInputType_RKLLM_INPUT_PROMPT`,
// `__bindgen_anon_1`), which would otherwise trip Rust's naming lints across the crate.
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]

// Raw bindings written to `$OUT_DIR/bindings.rs` at build time (presumably by a bindgen
// build script — confirm in build.rs).
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
6
/// Higher-level Rust wrappers over the raw RKLLM bindings: owned parameter types, a copyable
/// handle with `Result`-returning methods, and a trait-based callback interface.
pub mod prelude {
    use std::ffi::CStr;
    use std::ptr::null_mut;
    use std::sync::Arc;
    use std::sync::Mutex;

    // Raw binding types re-exported unchanged for users of this prelude.
    pub use super::RKLLMExtendParam;
    pub use super::RKLLMLoraParam;
    pub use super::RKLLMParam;
    pub use super::RKLLMResultLastHiddenLayer;
17
18    #[derive(Debug, PartialEq, Eq)]
20    pub enum LLMCallState {
21        Normal = 0,
23        Waiting = 1,
25        Finish = 2,
27        Error = 3,
29        GetLastHiddenLayer = 4,
31    }
32
33    #[derive(Debug, Clone, Default)]
34    pub enum KeepHistory {
35        #[default]
36        NoKeepHistory = 0,
38        KeepHistory = 1,
40    }
41
42    #[derive(Debug, Clone, Default)]
44    pub struct RKLLMInferParam {
45        pub mode: RKLLMInferMode,
47        pub lora_params: Option<String>,
49        pub prompt_cache_params: Option<RKLLMPromptCacheParam>,
51        pub keep_history: KeepHistory,
52    }
53
54    #[derive(Debug, Copy, Clone, Default)]
56    pub enum RKLLMInferMode {
57        #[default]
59        InferGenerate = 0,
60        InferGetLastHiddenLayer = 1,
62    }
63
64    impl Into<u32> for RKLLMInferMode {
65        fn into(self) -> u32 {
67            self as u32
68        }
69    }
70
71    #[derive(Debug, Clone)]
73    pub struct RKLLMPromptCacheParam {
74        pub save_prompt_cache: bool,
76        pub prompt_cache_path: String,
78    }
79
80    impl Default for super::RKLLMParam {
81        fn default() -> Self {
83            unsafe { super::rkllm_createDefaultParam() }
84        }
85    }
86
87    #[derive(Debug, Clone)]
89    pub struct RKLLMResult {
90        pub text: String,
92        pub token_id: i32,
94        pub last_hidden_layer: RKLLMResultLastHiddenLayer,
96    }
97
98    #[derive(Debug, Clone)]
99    pub struct RKLLMLoraAdapter {
100        pub lora_adapter_path: String,
101        pub lora_adapter_name: String,
102        pub scale: f32,
103    }
104
105    #[derive(Clone, Debug, Copy)]
107    pub struct LLMHandle {
108        handle: super::LLMHandle,
109    }
110
111    unsafe impl Send for LLMHandle {} unsafe impl Sync for LLMHandle {} pub trait RkllmCallbackHandler {
116        fn handle(&mut self, result: Option<RKLLMResult>, state: LLMCallState);
118    }
119
120    pub struct InstanceData {
122        pub callback_handler: Arc<Mutex<dyn RkllmCallbackHandler + Send + Sync>>,
124    }
125
126    impl LLMHandle {
127        pub fn destroy(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
129            let ret = unsafe { super::rkllm_destroy(self.handle) };
130
131            if ret == 0 {
132                return Ok(());
133            } else {
134                return Err(Box::new(std::io::Error::new(
135                    std::io::ErrorKind::Other,
136                    format!("rkllm_run returned non-zero: {}", ret),
137                )));
138            }
139        }
140
141        pub fn run(
151            &self,
152            rkllm_input: RKLLMInput,
153            rkllm_infer_params: Option<RKLLMInferParam>,
154            user_data: impl RkllmCallbackHandler + Send + Sync + 'static,
155        ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
156            let instance_data = Arc::new(InstanceData {
157                callback_handler: Arc::new(Mutex::new(user_data)),
158            });
159
160            let userdata_ptr = Arc::into_raw(instance_data) as *mut std::ffi::c_void;
161            let prompt_cstring;
162            let prompt_cstring_ptr;
163            let mut input = match rkllm_input {
164                RKLLMInput::Prompt(prompt) => {
165                    prompt_cstring = std::ffi::CString::new(prompt).unwrap();
166                    prompt_cstring_ptr = prompt_cstring.as_ptr() as *const std::os::raw::c_char;
167                    super::RKLLMInput {
168                        input_type: super::RKLLMInputType_RKLLM_INPUT_PROMPT,
169                        __bindgen_anon_1: super::RKLLMInput__bindgen_ty_1 {
170                            prompt_input: prompt_cstring_ptr,
171                        },
172                    }
173                }
174                RKLLMInput::Token(_) => todo!(),
175                RKLLMInput::Embed(_) => todo!(),
176                RKLLMInput::Multimodal(_) => todo!(),
177            };
178
179            let prompt_cache_cstring;
180            let prompt_cache_cstring_ptr;
181
182            let lora_adapter_name;
183            let lora_adapter_name_ptr;
184            let mut loraparam;
185
186            let new_rkllm_infer_params: *mut super::RKLLMInferParam =
187                if let Some(rkllm_infer_params) = rkllm_infer_params {
188                    &mut super::RKLLMInferParam {
189                        keep_history: rkllm_infer_params.keep_history as i32,
190                        mode: rkllm_infer_params.mode.into(),
191                        lora_params: match rkllm_infer_params.lora_params {
192                            Some(a) => {
193                                lora_adapter_name = a;
194                                lora_adapter_name_ptr =
195                                    lora_adapter_name.as_ptr() as *const std::os::raw::c_char;
196                                loraparam = RKLLMLoraParam {
197                                    lora_adapter_name: lora_adapter_name_ptr,
198                                };
199                                &mut loraparam
200                            }
201                            None => null_mut(),
202                        },
203                        prompt_cache_params: if let Some(cache_params) =
204                            rkllm_infer_params.prompt_cache_params
205                        {
206                            prompt_cache_cstring =
207                                std::ffi::CString::new(cache_params.prompt_cache_path).unwrap();
208                            prompt_cache_cstring_ptr =
209                                prompt_cache_cstring.as_ptr() as *const std::os::raw::c_char;
210
211                            &mut super::RKLLMPromptCacheParam {
212                                save_prompt_cache: if cache_params.save_prompt_cache {
213                                    1
214                                } else {
215                                    0
216                                },
217                                prompt_cache_path: prompt_cache_cstring_ptr,
218                            }
219                        } else {
220                            null_mut()
221                        },
222                    }
223                } else {
224                    null_mut()
225                };
226
227            let ret = unsafe {
228                super::rkllm_run(
229                    self.handle,
230                    &mut input,
231                    new_rkllm_infer_params,
232                    userdata_ptr,
233                )
234            };
235            if ret == 0 {
236                return Ok(());
237            } else {
238                return Err(Box::new(std::io::Error::new(
239                    std::io::ErrorKind::Other,
240                    format!("rkllm_run returned non-zero: {}", ret),
241                )));
242            }
243        }
244
245        pub fn load_prompt_cache(
250            &self,
251            cache_path: &str,
252        ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
253            let prompt_cache_path = std::ffi::CString::new(cache_path).unwrap();
254            let prompt_cache_path_ptr = prompt_cache_path.as_ptr() as *const std::os::raw::c_char;
255            let ret = unsafe { super::rkllm_load_prompt_cache(self.handle, prompt_cache_path_ptr) };
256            if ret == 0 {
257                return Ok(());
258            } else {
259                return Err(Box::new(std::io::Error::new(
260                    std::io::ErrorKind::Other,
261                    format!("rkllm_load_prompt_cache returned non-zero: {}", ret),
262                )));
263            }
264        }
265
266        pub fn release_prompt_cache(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
268            let ret = unsafe { super::rkllm_release_prompt_cache(self.handle) };
269            if ret == 0 {
270                return Ok(());
271            } else {
272                return Err(Box::new(std::io::Error::new(
273                    std::io::ErrorKind::Other,
274                    format!("rkllm_release_prompt_cache returned non-zero: {}", ret),
275                )));
276            }
277        }
278
279        pub fn abort(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
281            let ret = unsafe { super::rkllm_abort(self.handle) };
282            if ret == 0 {
283                return Ok(());
284            } else {
285                return Err(Box::new(std::io::Error::new(
286                    std::io::ErrorKind::Other,
287                    format!("rkllm_abort returned non-zero: {}", ret),
288                )));
289            }
290        }
291
292        pub fn is_running(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
294            let ret = unsafe { super::rkllm_is_running(self.handle) };
295            if ret == 0 {
296                return Ok(());
297            } else {
298                return Err(Box::new(std::io::Error::new(
299                    std::io::ErrorKind::Other,
300                    format!("rkllm_is_running returned non-zero: {}", ret),
301                )));
302            }
303        }
304
305        pub fn load_lora(
307            &self,
308            lora_cfg: &RKLLMLoraAdapter,
309        ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
310            let lora_adapter_name_cstring =
311                std::ffi::CString::new(lora_cfg.lora_adapter_name.clone()).unwrap();
312            let lora_adapter_name_cstring_ptr =
313                lora_adapter_name_cstring.as_ptr() as *const std::os::raw::c_char;
314            let lora_adapter_path_cstring =
315                std::ffi::CString::new(lora_cfg.lora_adapter_path.clone()).unwrap();
316            let lora_adapter_path_cstring_ptr =
317                lora_adapter_path_cstring.as_ptr() as *const std::os::raw::c_char;
318            let mut param = super::RKLLMLoraAdapter {
319                lora_adapter_path: lora_adapter_path_cstring_ptr,
320                lora_adapter_name: lora_adapter_name_cstring_ptr,
321                scale: lora_cfg.scale,
322            };
323            let ret = unsafe { super::rkllm_load_lora(self.handle, &mut param) };
324            if ret == 0 {
325                return Ok(());
326            } else {
327                return Err(Box::new(std::io::Error::new(
328                    std::io::ErrorKind::Other,
329                    format!("rkllm_load_lora returned non-zero: {}", ret),
330                )));
331            }
332        }
333    }
334
335    unsafe extern "C" fn callback_passtrough(
337        result: *mut super::RKLLMResult,
338        userdata: *mut ::std::os::raw::c_void,
339        state: super::LLMCallState,
340    ) {
341        Arc::increment_strong_count(userdata); let instance_data = unsafe { Arc::from_raw(userdata as *const InstanceData) };
343        let new_state = match state {
344            0 => LLMCallState::Normal,
345            1 => LLMCallState::Waiting,
346            2 => LLMCallState::Finish,
347            3 => LLMCallState::Error,
348            4 => LLMCallState::GetLastHiddenLayer,
349            _ => panic!("Unexpected LLMCallState"),
350        };
351
352        let new_result = if result.is_null() {
353            None
354        } else {
355            Some(RKLLMResult {
356                text: if (*result).text.is_null() {
357                    String::new()
358                } else {
359                    (unsafe { CStr::from_ptr((*result).text) })
360                        .to_str()
361                        .expect("Failed to convert C string")
362                        .to_owned()
363                        .clone()
364                },
365                token_id: (*result).token_id,
366                last_hidden_layer: (*result).last_hidden_layer,
367            })
368        };
369
370        instance_data
371            .callback_handler
372            .lock()
373            .unwrap()
374            .handle(new_result, new_state);
375    }
376
377    pub fn rkllm_init(
385        param: *mut super::RKLLMParam,
386    ) -> Result<LLMHandle, Box<dyn std::error::Error + Send + Sync>> {
387        let mut handle = LLMHandle {
388            handle: std::ptr::null_mut(),
389        };
390
391        let callback: Option<
392            unsafe extern "C" fn(
393                *mut super::RKLLMResult,
394                *mut ::std::os::raw::c_void,
395                super::LLMCallState,
396            ),
397        > = Some(callback_passtrough);
398        let ret = unsafe { super::rkllm_init(&mut handle.handle, param, callback) };
399        if ret == 0 {
400            return Ok(handle);
401        } else {
402            return Err(Box::new(std::io::Error::new(
403                std::io::ErrorKind::Other,
404                format!("rkllm_init returned non-zero: {}", ret),
405            )));
406        }
407    }
408
409    pub enum RKLLMInput {
411        Prompt(String),
413        Token(String),
415        Embed(String),
417        Multimodal(String),
419    }
420}