// rkllm_rs/lib.rs

#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]

include!(concat!(env!("OUT_DIR"), "/bindings.rs"));

pub mod prelude {

    use std::ffi::CStr;
    use std::ptr::null_mut;
    use std::sync::Mutex;

    pub use super::RKLLMExtendParam;
    pub use super::RKLLMLoraParam;
    pub use super::RKLLMParam;
    pub use super::RKLLMResultLastHiddenLayer;

    /// Safe callback signature: receives the decoded result (if any), the
    /// caller-supplied userdata pointer, and the current call state.
    pub type rkllm_callback =
        fn(result: Option<RKLLMResult>, userdata: *mut ::std::os::raw::c_void, state: LLMCallState);

    #[derive(Debug, PartialEq, Eq)]
    pub enum LLMCallState {
        #[doc = "< The LLM call is in a normal running state."]
        Normal = 0,
        #[doc = "< The LLM call is waiting for a complete UTF-8 encoded character."]
        Waiting = 1,
        #[doc = "< The LLM call has finished execution."]
        Finish = 2,
        #[doc = "< An error occurred during the LLM call."]
        Error = 3,
        #[doc = "< Retrieve the last hidden layer during inference."]
        GetLastHiddenLayer = 4,
    }

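    // A minimal callback sketch, not part of the C API: it streams generated
    // text to stdout and ends the line once the call reaches `Finish`. The
    // name `example_print_callback` is our own illustration.
    #[allow(dead_code)]
    fn example_print_callback(
        result: Option<RKLLMResult>,
        _userdata: *mut ::std::os::raw::c_void,
        state: LLMCallState,
    ) {
        if let Some(result) = result {
            print!("{}", result.text);
        }
        if state == LLMCallState::Finish {
            println!();
        }
    }
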
    #[doc = " @struct RKLLMInferParam\n @brief Structure for defining parameters during inference."]
    #[derive(Debug, Clone)]
    pub struct RKLLMInferParam {
        #[doc = "< Inference mode (e.g., generate or get last hidden layer)."]
        pub mode: RKLLMInferMode,
        #[doc = "< Pointer to Lora adapter parameters."]
        pub lora_params: Option<*mut RKLLMLoraParam>,
        #[doc = "< Prompt cache parameters."]
        pub prompt_cache_params: Option<RKLLMPromptCacheParam>,
    }

    #[derive(Debug, Copy, Clone)]
    pub enum RKLLMInferMode {
        #[doc = "< The LLM generates text based on input."]
        InferGenerate = 0,
        #[doc = "< The LLM retrieves the last hidden layer for further processing."]
        InferGetLastHiddenLayer = 1,
    }

    impl From<RKLLMInferMode> for u32 {
        fn from(mode: RKLLMInferMode) -> Self {
            mode as u32
        }
    }

    #[doc = " @struct RKLLMPromptCacheParam\n @brief Structure to define parameters for caching prompts."]
    #[derive(Debug, Clone)]
    pub struct RKLLMPromptCacheParam {
        #[doc = "< Whether to save the prompt cache (converted to the C API's 0/1 flag)."]
        pub save_prompt_cache: bool,
        #[doc = "< Path to the prompt cache file."]
        pub prompt_cache_path: String,
    }

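    // A construction sketch for the safe parameter structs above; the path
    // "./prompt.cache" is a hypothetical example. It requests plain text
    // generation while saving the prompt cache to disk.
    #[allow(dead_code)]
    fn example_infer_params() -> RKLLMInferParam {
        RKLLMInferParam {
            mode: RKLLMInferMode::InferGenerate,
            lora_params: None,
            prompt_cache_params: Some(RKLLMPromptCacheParam {
                save_prompt_cache: true,
                prompt_cache_path: "./prompt.cache".to_string(),
            }),
        }
    }
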
    impl Default for super::RKLLMParam {
        /// Defaults are supplied by the C library via `rkllm_createDefaultParam()`.
        fn default() -> Self {
            unsafe { super::rkllm_createDefaultParam() }
        }
    }

    #[doc = " @struct RKLLMResult\n @brief Structure to represent the result of LLM inference."]
    #[derive(Debug, Clone)]
    pub struct RKLLMResult {
        #[doc = "< Generated text result."]
        pub text: String,
        #[doc = "< ID of the generated token."]
        pub token_id: i32,
        #[doc = "< Hidden states of the last layer (if requested)."]
        pub last_hidden_layer: RKLLMResultLastHiddenLayer,
    }

    #[doc = " @struct LLMHandle\n @brief Safe wrapper around the raw LLM handle returned by `rkllm_init`."]
    pub struct LLMHandle {
        handle: super::LLMHandle,
    }

    impl LLMHandle {
        #[doc = " @brief Destroys the LLM instance and releases resources.\n @return Status code (0 for success, non-zero for failure)."]
        pub fn destroy(&self) -> i32 {
            unsafe { super::rkllm_destroy(self.handle) }
        }

        #[doc = " @brief Runs an LLM inference task asynchronously.\n @param rkllm_input Input data for the LLM.\n @param rkllm_infer_params Parameters for the inference task.\n @param userdata Pointer to user data for the callback.\n @return Status code (0 for success, non-zero for failure)."]
        pub fn run(
            &self,
            rkllm_input: RKLLMInput,
            rkllm_infer_params: Option<RKLLMInferParam>,
            userdata: *mut ::std::os::raw::c_void,
        ) -> i32 {
            // The CString is bound outside the match so it stays alive until
            // the rkllm_run call below has returned.
            let prompt_cstring;
            let prompt_cstring_ptr;
            let mut input = match rkllm_input {
                RKLLMInput::Prompt(prompt) => {
                    prompt_cstring = std::ffi::CString::new(prompt)
                        .expect("prompt contains a NUL byte");
                    prompt_cstring_ptr = prompt_cstring.as_ptr();
                    super::RKLLMInput {
                        input_type: super::RKLLMInputType_RKLLM_INPUT_PROMPT,
                        __bindgen_anon_1: super::RKLLMInput__bindgen_ty_1 {
                            prompt_input: prompt_cstring_ptr,
                        },
                    }
                }
                RKLLMInput::Token(_) => todo!(),
                RKLLMInput::Embed(_) => todo!(),
                RKLLMInput::Multimodal(_) => todo!(),
            };

            let prompt_cache_cstring;
            let prompt_cache_cstring_ptr;

            // The C structs are bound to stack slots that outlive the
            // rkllm_run call; taking `&mut` to a temporary here would hand
            // the C API a dangling pointer.
            let mut prompt_cache_struct;
            let mut infer_params_struct;
            let new_rkllm_infer_params: *mut super::RKLLMInferParam =
                if let Some(rkllm_infer_params) = rkllm_infer_params {
                    let prompt_cache_ptr: *mut super::RKLLMPromptCacheParam =
                        if let Some(cache_params) = rkllm_infer_params.prompt_cache_params {
                            prompt_cache_cstring =
                                std::ffi::CString::new(cache_params.prompt_cache_path)
                                    .expect("prompt cache path contains a NUL byte");
                            prompt_cache_cstring_ptr = prompt_cache_cstring.as_ptr();

                            prompt_cache_struct = super::RKLLMPromptCacheParam {
                                save_prompt_cache: if cache_params.save_prompt_cache {
                                    1
                                } else {
                                    0
                                },
                                prompt_cache_path: prompt_cache_cstring_ptr,
                            };
                            &mut prompt_cache_struct
                        } else {
                            null_mut()
                        };

                    infer_params_struct = super::RKLLMInferParam {
                        mode: rkllm_infer_params.mode.into(),
                        lora_params: rkllm_infer_params.lora_params.unwrap_or(null_mut()),
                        prompt_cache_params: prompt_cache_ptr,
                    };
                    &mut infer_params_struct
                } else {
                    null_mut()
                };

            unsafe { super::rkllm_run(self.handle, &mut input, new_rkllm_infer_params, userdata) }
        }

        #[doc = " @brief Loads a prompt cache from a file.\n @param prompt_cache_path Path to the prompt cache file.\n @return Status code (0 for success, non-zero for failure)."]
        pub fn load_prompt_cache(&self, cache_path: &str) -> i32 {
            let prompt_cache_path = std::ffi::CString::new(cache_path)
                .expect("prompt cache path contains a NUL byte");
            unsafe { super::rkllm_load_prompt_cache(self.handle, prompt_cache_path.as_ptr()) }
        }
    }

    // The C trampoline below forwards into this process-wide slot, so only
    // one Rust callback can be registered at a time. A Mutex avoids the
    // unsynchronized `static mut` access this previously relied on.
    static CALLBACK: Mutex<Option<rkllm_callback>> = Mutex::new(None);

    unsafe extern "C" fn callback_passthrough(
        result: *mut super::RKLLMResult,
        userdata: *mut ::std::os::raw::c_void,
        state: super::LLMCallState,
    ) {
        let new_state = match state {
            0 => LLMCallState::Normal,
            1 => LLMCallState::Waiting,
            2 => LLMCallState::Finish,
            3 => LLMCallState::Error,
            4 => LLMCallState::GetLastHiddenLayer,
            _ => panic!("unexpected LLMCallState: {state}"),
        };

        let new_result = if result.is_null() {
            None
        } else {
            Some(RKLLMResult {
                text: if (*result).text.is_null() {
                    String::new()
                } else {
                    CStr::from_ptr((*result).text)
                        .to_str()
                        .expect("result text is not valid UTF-8")
                        .to_owned()
                },
                token_id: (*result).token_id,
                last_hidden_layer: (*result).last_hidden_layer,
            })
        };

        // Copy the fn pointer out so the lock is not held across the call.
        let callback = *CALLBACK.lock().unwrap();
        if let Some(callback) = callback {
            callback(new_result, userdata, new_state);
        }
    }

    #[doc = " @brief Initializes the LLM with the given parameters.\n @param param Configuration parameters for the LLM.\n @param callback Callback function to handle LLM results.\n @return The initialized handle on success, or the non-zero status code as the error."]
    pub fn rkllm_init(
        param: *mut super::RKLLMParam,
        callback: rkllm_callback,
    ) -> Result<LLMHandle, i32> {
        let mut handle = LLMHandle {
            handle: std::ptr::null_mut(),
        };
        *CALLBACK.lock().unwrap() = Some(callback);
        let callback: Option<
            unsafe extern "C" fn(
                *mut super::RKLLMResult,
                *mut ::std::os::raw::c_void,
                super::LLMCallState,
            ),
        > = Some(callback_passthrough);
        let ret = unsafe { super::rkllm_init(&mut handle.handle, param, callback) };
        if ret == 0 {
            Ok(handle)
        } else {
            Err(ret)
        }
    }
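
    // An end-to-end sketch under stated assumptions: the prompt text is
    // arbitrary, setting the model path on RKLLMParam is elided because the
    // field layout comes from the generated bindings, and it needs a
    // Rockchip NPU at runtime, so it is not wired into any test.
    #[allow(dead_code)]
    fn example_generate() -> Result<(), i32> {
        let mut param = RKLLMParam::default();
        // param.model_path would be set here from a CString that stays
        // alive for the duration of the call.
        let handle = rkllm_init(&mut param, example_print_callback)?;
        handle.run(
            RKLLMInput::Prompt("Hello!".to_string()),
            Some(example_infer_params()),
            std::ptr::null_mut(),
        );
        handle.destroy();
        Ok(())
    }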

    pub enum RKLLMInput {
        #[doc = "< Input is a text prompt."]
        Prompt(String),
        #[doc = "< Input is a sequence of tokens (payload type is a placeholder; not yet implemented in `run`)."]
        Token(String),
        #[doc = "< Input is an embedding vector (payload type is a placeholder; not yet implemented in `run`)."]
        Embed(String),
        #[doc = "< Input is multimodal, e.g. text and image (payload type is a placeholder; not yet implemented in `run`)."]
        Multimodal(String),
    }
}