1#![allow(non_upper_case_globals)]
2#![allow(non_camel_case_types)]
3#![allow(non_snake_case)]
4
5include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
6
7pub mod prelude {
8 use std::ffi::CStr;
9 use std::ptr::null_mut;
10 use std::sync::Arc;
11 use std::sync::Mutex;
12
13 pub use super::RKLLMExtendParam;
14 pub use super::RKLLMLoraParam;
15 pub use super::RKLLMParam;
16 pub use super::RKLLMResultLastHiddenLayer;
17
18 #[derive(Debug, PartialEq, Eq)]
19 pub enum LLMCallState {
20 #[doc = "< The LLM call is in a normal running state."]
21 Normal = 0,
22 #[doc = "< The LLM call is waiting for complete UTF-8 encoded character."]
23 Waiting = 1,
24 #[doc = "< The LLM call has finished execution."]
25 Finish = 2,
26 #[doc = "< An error occurred during the LLM call."]
27 Error = 3,
28 #[doc = "< Retrieve the last hidden layer during inference."]
29 GetLastHiddenLayer = 4,
30 }
31
    #[doc = " @struct RKLLMInferParam\n @brief Structure for defining parameters during inference. Converted to the C `RKLLMInferParam` inside `LLMHandle::run`."]
    #[derive(Debug, Clone, Default)]
    pub struct RKLLMInferParam {
        #[doc = "< Inference mode (e.g., generate or get last hidden layer)."]
        pub mode: RKLLMInferMode,
        // Rust-side replacement for the C API's `RKLLMLoraParam*`: the adapter
        // name string; `run` builds the C struct and pointer from it.
        #[doc = "< Optional Lora adapter name (passed to the C API as `RKLLMLoraParam`)."]
        pub lora_params: Option<String>,
        // Rust-side replacement for the C API's `RKLLMPromptCacheParam*`.
        #[doc = "< Optional prompt-cache settings (passed to the C API as `RKLLMPromptCacheParam`)."]
        pub prompt_cache_params: Option<RKLLMPromptCacheParam>,
    }
42
43 #[derive(Debug, Copy, Clone, Default)]
44 pub enum RKLLMInferMode {
45 #[doc = "< The LLM generates text based on input."]
46 #[default]
47 InferGenerate = 0,
48 #[doc = "< The LLM retrieves the last hidden layer for further processing."]
49 InferGetLastHiddenLayer = 1,
50 }
51 impl Into<u32> for RKLLMInferMode {
52 fn into(self) -> u32 {
53 self as u32
54 }
55 }
56
    #[doc = " @struct RKLLMPromptCacheParam\n @brief Structure to define parameters for caching prompts. Converted to the C struct (bool -> 0/1 int, String -> C string) inside `LLMHandle::run`."]
    #[derive(Debug, Clone)]
    pub struct RKLLMPromptCacheParam {
        // Mapped to the C API's int flag (1 = save, 0 = don't save) in `run`.
        #[doc = "< Whether to save the prompt cache."]
        pub save_prompt_cache: bool,
        // NUL-terminated via `CString` before crossing the FFI boundary.
        #[doc = "< Path to the prompt cache file."]
        pub prompt_cache_path: String,
    }
65
    // Let the C library fill in its own default parameter set.
    impl Default for super::RKLLMParam {
        fn default() -> Self {
            // SAFETY(review): `rkllm_createDefaultParam` takes no arguments and
            // returns the struct by value; assumed to have no preconditions —
            // TODO confirm against the RKLLM headers.
            unsafe { super::rkllm_createDefaultParam() }
        }
    }
71
    #[doc = " @struct RKLLMResult\n @brief Owned, Rust-side copy of one callback result from the C library (built in `callback_passtrough`)."]
    #[derive(Debug, Clone)]
    pub struct RKLLMResult {
        // Decoded from the C string; empty when the C side passed a null text.
        #[doc = "< Generated text result."]
        pub text: String,
        #[doc = "< ID of the generated token."]
        pub token_id: i32,
        // Copied verbatim from the C struct; the pointers inside presumably
        // belong to the C library — TODO(review) confirm their lifetime.
        #[doc = "< Hidden states of the last layer (if requested)."]
        pub last_hidden_layer: RKLLMResultLastHiddenLayer,
    }
82
    #[doc = " @struct LLMHandle\n @brief Thin wrapper around the opaque C handle produced by `rkllm_init`."]
    #[derive(Clone, Debug, Copy)]
    pub struct LLMHandle {
        // Raw opaque pointer owned by the C library; freed via `destroy`.
        handle: super::LLMHandle,
    }
88
89 unsafe impl Send for LLMHandle {} unsafe impl Sync for LLMHandle {} pub trait RkllmCallbackHandler {
93 fn handle(&mut self, result: Option<RKLLMResult>, state: LLMCallState);
94 }
95
    // Payload stored behind the `userdata` pointer handed to the C callback:
    // keeps the user's handler alive and serialized across invocations.
    pub struct InstanceData {
        // Locked and called from `callback_passtrough` on every C callback.
        pub callback_handler: Arc<Mutex<dyn RkllmCallbackHandler + Send + Sync>>,
    }
99
100 impl LLMHandle {
101 #[doc = " @brief Destroys the LLM instance and releases resources.\n @param handle LLM handle.\n @return Status code (0 for success, non-zero for failure)."]
102 pub fn destroy(&self) -> i32 {
103 unsafe { super::rkllm_destroy(self.handle) }
104 }
105
106 #[doc = " @brief Runs an LLM inference task asynchronously.\n @param handle LLM handle.\n @param rkllm_input Input data for the LLM.\n @param rkllm_infer_params Parameters for the inference task.\n @param userdata Pointer to user data for the callback.\n @return Status code (0 for success, non-zero for failure)."]
107 pub fn run(
108 &self,
109 rkllm_input: RKLLMInput,
110 rkllm_infer_params: Option<RKLLMInferParam>,
111 user_data: impl RkllmCallbackHandler + Send + Sync + 'static,
112 ) {
113 let instance_data = Arc::new(InstanceData {
114 callback_handler: Arc::new(Mutex::new(user_data)),
115 });
116
117 let userdata_ptr = Arc::into_raw(instance_data) as *mut std::ffi::c_void;
118 let prompt_cstring;
119 let prompt_cstring_ptr;
120 let mut input = match rkllm_input {
121 RKLLMInput::Prompt(prompt) => {
122 prompt_cstring = std::ffi::CString::new(prompt).unwrap();
123 prompt_cstring_ptr = prompt_cstring.as_ptr() as *const std::os::raw::c_char;
124 super::RKLLMInput {
125 input_type: super::RKLLMInputType_RKLLM_INPUT_PROMPT,
126 __bindgen_anon_1: super::RKLLMInput__bindgen_ty_1 {
127 prompt_input: prompt_cstring_ptr,
128 },
129 }
130 }
131 RKLLMInput::Token(_) => todo!(),
132 RKLLMInput::Embed(_) => todo!(),
133 RKLLMInput::Multimodal(_) => todo!(),
134 };
135
136 let prompt_cache_cstring;
137 let prompt_cache_cstring_ptr;
138
139 let lora_adapter_name;
140 let lora_adapter_name_ptr;
141 let mut loraparam;
142
143 let new_rkllm_infer_params: *mut super::RKLLMInferParam =
144 if let Some(rkllm_infer_params) = rkllm_infer_params {
145 &mut super::RKLLMInferParam {
146 mode: rkllm_infer_params.mode.into(),
147 lora_params: match rkllm_infer_params.lora_params {
148 Some(a) => {
149 lora_adapter_name = a;
150 lora_adapter_name_ptr = lora_adapter_name.as_ptr() as *const std::os::raw::c_char;
151 loraparam = RKLLMLoraParam{
152 lora_adapter_name: lora_adapter_name_ptr
153 };
154 &mut loraparam
155 }
156 None => null_mut(),
157 },
158 prompt_cache_params: if let Some(cache_params) =
159 rkllm_infer_params.prompt_cache_params
160 {
161 prompt_cache_cstring =
162 std::ffi::CString::new(cache_params.prompt_cache_path).unwrap();
163 prompt_cache_cstring_ptr = prompt_cache_cstring.as_ptr() as *const std::os::raw::c_char;
164
165 &mut super::RKLLMPromptCacheParam {
166 save_prompt_cache: if cache_params.save_prompt_cache {
167 1
168 } else {
169 0
170 },
171 prompt_cache_path: prompt_cache_cstring_ptr,
172 }
173 } else {
174 null_mut()
175 },
176 }
177 } else {
178 null_mut()
179 };
180
181 unsafe {
182 super::rkllm_run(
183 self.handle,
184 &mut input,
185 new_rkllm_infer_params,
186 userdata_ptr,
187 )
188 };
189 }
190
191 #[doc = " @brief Loads a prompt cache from a file.\n @param handle LLM handle.\n @param prompt_cache_path Path to the prompt cache file.\n @return Status code (0 for success, non-zero for failure)."]
192 pub fn load_prompt_cache(&self, cache_path: &str) {
193 let prompt_cache_path = std::ffi::CString::new(cache_path).unwrap();
194 let prompt_cache_path_ptr = prompt_cache_path.as_ptr() as *const std::os::raw::c_char;
195 unsafe { super::rkllm_load_prompt_cache(self.handle, prompt_cache_path_ptr) };
196 }
197 }
198
199 unsafe extern "C" fn callback_passtrough(
200 result: *mut super::RKLLMResult,
201 userdata: *mut ::std::os::raw::c_void,
202 state: super::LLMCallState,
203 ) {
204 Arc::increment_strong_count(userdata); let instance_data = unsafe { Arc::from_raw(userdata as *const InstanceData) };
206 let new_state = match state {
207 0 => LLMCallState::Normal,
208 1 => LLMCallState::Waiting,
209 2 => LLMCallState::Finish,
210 3 => LLMCallState::Error,
211 4 => LLMCallState::GetLastHiddenLayer,
212 _ => panic!("Not expect LLMCallState"),
213 };
214
215 let new_result = if result.is_null() {
216 None
217 } else {
218 Some(RKLLMResult {
219 text: if (*result).text.is_null() {
220 String::new()
221 } else {
222 (unsafe { CStr::from_ptr((*result).text) })
223 .to_str()
224 .expect("Convert cstr failed")
225 .to_owned()
226 .clone()
227 },
228 token_id: (*result).token_id,
229 last_hidden_layer: (*result).last_hidden_layer,
230 })
231 };
232
233 instance_data.callback_handler.lock().unwrap().handle(new_result, new_state);
234 }
235
236 #[doc = " @brief Initializes the LLM with the given parameters.\n @param handle Pointer to the LLM handle.\n @param param Configuration parameters for the LLM.\n @param callback Callback function to handle LLM results.\n @return Status code (0 for success, non-zero for failure)."]
237 pub fn rkllm_init(
238 param: *mut super::RKLLMParam,
239 ) -> Result<LLMHandle, Box<dyn std::error::Error + Send + Sync>> {
240 let mut handle = LLMHandle {
241 handle: std::ptr::null_mut(),
242 };
243
244 let callback: Option<
245 unsafe extern "C" fn(
246 *mut super::RKLLMResult,
247 *mut ::std::os::raw::c_void,
248 super::LLMCallState,
249 ),
250 > = Some(callback_passtrough);
251 let ret = unsafe { super::rkllm_init(&mut handle.handle, param, callback) };
252 if ret == 0 {
253 return Ok(handle);
254 } else {
255 return Err(Box::new(std::io::Error::new(
256 std::io::ErrorKind::Other,
257 format!("rkllm_init ret non zero: {}", ret),
258 )));
259 }
260 }
261
    // Rust-side input variants mirroring the C `RKLLMInput` union. Only
    // `Prompt` is implemented in `LLMHandle::run`; the rest hit `todo!()`.
    pub enum RKLLMInput {
        #[doc = "< Input is a text prompt."]
        Prompt(String),
        // NOTE(review): the `String` payloads below look provisional — token,
        // embedding and multimodal inputs presumably need structured data
        // (token ids, float vectors, image buffers); confirm intended types
        // before implementing the corresponding `run` arms.
        #[doc = "< Input is a sequence of tokens."]
        Token(String),
        #[doc = "< Input is an embedding vector."]
        Embed(String),
        #[doc = "< Input is multimodal (e.g., text and image)."]
        Multimodal(String),
    }
272}