1#![allow(non_upper_case_globals)]
2#![allow(non_camel_case_types)]
3#![allow(non_snake_case)]
4
5include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
6
7pub mod prelude {
8 use std::ffi::CStr;
9 use std::ptr::null_mut;
10 use std::sync::Arc;
11 use std::sync::Mutex;
12
13 pub use super::RKLLMExtendParam;
14 pub use super::RKLLMLoraParam;
15 pub use super::RKLLMParam;
16 pub use super::RKLLMResultLastHiddenLayer;
17
/// Call state reported with every callback delivered to an [`RkllmCallbackHandler`].
///
/// The discriminants equal the integer state codes that `callback_passtrough` maps from
/// the C runtime. The enum is `Copy` so handlers can store or forward it freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LLMCallState {
    /// Runtime state code 0.
    Normal = 0,
    /// Runtime state code 1.
    Waiting = 1,
    /// Runtime state code 2.
    Finish = 2,
    /// Runtime state code 3.
    Error = 3,
    /// Runtime state code 4.
    GetLastHiddenLayer = 4,
}
32
/// Value for the C `keep_history` field of the inference parameters.
///
/// `LLMHandle::run` converts it with `as i32`. `Copy` makes that cast possible
/// without cloning, and `PartialEq` lets callers compare settings.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum KeepHistory {
    /// Code 0; the default.
    #[default]
    NoKeepHistory = 0,
    /// Code 1.
    KeepHistory = 1,
}
41
42 #[derive(Debug, Clone, Default)]
44 pub struct RKLLMInferParam {
45 pub mode: RKLLMInferMode,
47 pub lora_params: Option<String>,
49 pub prompt_cache_params: Option<RKLLMPromptCacheParam>,
51 pub keep_history: KeepHistory,
52 }
53
/// Inference mode for `LLMHandle::run`, stored in the C `mode` field as a `u32`.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)]
pub enum RKLLMInferMode {
    /// Code 0; the default.
    #[default]
    InferGenerate = 0,
    /// Code 1.
    InferGetLastHiddenLayer = 1,
}

// Implementing `From` (rather than `Into`) also provides `Into<u32>` through the std
// blanket impl, so existing `.into()` call sites keep compiling.
impl From<RKLLMInferMode> for u32 {
    /// Returns the mode's discriminant (0 or 1).
    fn from(mode: RKLLMInferMode) -> u32 {
        mode as u32
    }
}
70
/// Prompt-cache settings, translated by `LLMHandle::run` into the C
/// `RKLLMPromptCacheParam` struct.
#[derive(Clone, Debug)]
pub struct RKLLMPromptCacheParam {
    /// Converted to the integer 1 (`true`) or 0 (`false`) for the C struct.
    pub save_prompt_cache: bool,
    /// Cache file path, passed to C as a NUL-terminated string.
    pub prompt_cache_path: String,
}
79
80 impl Default for super::RKLLMParam {
81 fn default() -> Self {
83 unsafe { super::rkllm_createDefaultParam() }
84 }
85 }
86
87 #[derive(Debug, Clone)]
89 pub struct RKLLMResult {
90 pub text: String,
92 pub token_id: i32,
94 pub last_hidden_layer: RKLLMResultLastHiddenLayer,
96 }
97
/// Description of a LoRA adapter for `LLMHandle::load_lora`.
#[derive(Clone, Debug)]
pub struct RKLLMLoraAdapter {
    /// Adapter path, passed to C as a NUL-terminated string.
    pub lora_adapter_path: String,
    /// Adapter name, passed to C as a NUL-terminated string. Presumably the value to
    /// use later as `RKLLMInferParam::lora_params` to select this adapter.
    pub lora_adapter_name: String,
    /// Scale factor, forwarded unchanged to the C struct.
    pub scale: f32,
}
104
105 #[derive(Clone, Debug, Copy)]
107 pub struct LLMHandle {
108 handle: super::LLMHandle,
109 }
110
111 unsafe impl Send for LLMHandle {} unsafe impl Sync for LLMHandle {} pub trait RkllmCallbackHandler {
116 fn handle(&mut self, result: Option<RKLLMResult>, state: LLMCallState);
118 }
119
120 pub struct InstanceData {
122 pub callback_handler: Arc<Mutex<dyn RkllmCallbackHandler + Send + Sync>>,
124 }
125
126 impl LLMHandle {
127 pub fn destroy(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
129 let ret = unsafe { super::rkllm_destroy(self.handle) };
130
131 if ret == 0 {
132 return Ok(());
133 } else {
134 return Err(Box::new(std::io::Error::new(
135 std::io::ErrorKind::Other,
136 format!("rkllm_run returned non-zero: {}", ret),
137 )));
138 }
139 }
140
141 pub fn run(
151 &self,
152 rkllm_input: RKLLMInput,
153 rkllm_infer_params: Option<RKLLMInferParam>,
154 user_data: impl RkllmCallbackHandler + Send + Sync + 'static,
155 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
156 let instance_data = Arc::new(InstanceData {
157 callback_handler: Arc::new(Mutex::new(user_data)),
158 });
159
160 let userdata_ptr = Arc::into_raw(instance_data) as *mut std::ffi::c_void;
161 let prompt_cstring;
162 let prompt_cstring_ptr;
163 let mut input = match rkllm_input {
164 RKLLMInput::Prompt(prompt) => {
165 prompt_cstring = std::ffi::CString::new(prompt).unwrap();
166 prompt_cstring_ptr = prompt_cstring.as_ptr() as *const std::os::raw::c_char;
167 super::RKLLMInput {
168 input_type: super::RKLLMInputType_RKLLM_INPUT_PROMPT,
169 __bindgen_anon_1: super::RKLLMInput__bindgen_ty_1 {
170 prompt_input: prompt_cstring_ptr,
171 },
172 }
173 }
174 RKLLMInput::Token(_) => todo!(),
175 RKLLMInput::Embed(_) => todo!(),
176 RKLLMInput::Multimodal(_) => todo!(),
177 };
178
179 let prompt_cache_cstring;
180 let prompt_cache_cstring_ptr;
181
182 let lora_adapter_name;
183 let lora_adapter_name_ptr;
184 let mut loraparam;
185
186 let new_rkllm_infer_params: *mut super::RKLLMInferParam =
187 if let Some(rkllm_infer_params) = rkllm_infer_params {
188 &mut super::RKLLMInferParam {
189 keep_history: rkllm_infer_params.keep_history as i32,
190 mode: rkllm_infer_params.mode.into(),
191 lora_params: match rkllm_infer_params.lora_params {
192 Some(a) => {
193 lora_adapter_name = a;
194 lora_adapter_name_ptr =
195 lora_adapter_name.as_ptr() as *const std::os::raw::c_char;
196 loraparam = RKLLMLoraParam {
197 lora_adapter_name: lora_adapter_name_ptr,
198 };
199 &mut loraparam
200 }
201 None => null_mut(),
202 },
203 prompt_cache_params: if let Some(cache_params) =
204 rkllm_infer_params.prompt_cache_params
205 {
206 prompt_cache_cstring =
207 std::ffi::CString::new(cache_params.prompt_cache_path).unwrap();
208 prompt_cache_cstring_ptr =
209 prompt_cache_cstring.as_ptr() as *const std::os::raw::c_char;
210
211 &mut super::RKLLMPromptCacheParam {
212 save_prompt_cache: if cache_params.save_prompt_cache {
213 1
214 } else {
215 0
216 },
217 prompt_cache_path: prompt_cache_cstring_ptr,
218 }
219 } else {
220 null_mut()
221 },
222 }
223 } else {
224 null_mut()
225 };
226
227 let ret = unsafe {
228 super::rkllm_run(
229 self.handle,
230 &mut input,
231 new_rkllm_infer_params,
232 userdata_ptr,
233 )
234 };
235 if ret == 0 {
236 return Ok(());
237 } else {
238 return Err(Box::new(std::io::Error::new(
239 std::io::ErrorKind::Other,
240 format!("rkllm_run returned non-zero: {}", ret),
241 )));
242 }
243 }
244
245 pub fn load_prompt_cache(
250 &self,
251 cache_path: &str,
252 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
253 let prompt_cache_path = std::ffi::CString::new(cache_path).unwrap();
254 let prompt_cache_path_ptr = prompt_cache_path.as_ptr() as *const std::os::raw::c_char;
255 let ret = unsafe { super::rkllm_load_prompt_cache(self.handle, prompt_cache_path_ptr) };
256 if ret == 0 {
257 return Ok(());
258 } else {
259 return Err(Box::new(std::io::Error::new(
260 std::io::ErrorKind::Other,
261 format!("rkllm_load_prompt_cache returned non-zero: {}", ret),
262 )));
263 }
264 }
265
266 pub fn release_prompt_cache(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
268 let ret = unsafe { super::rkllm_release_prompt_cache(self.handle) };
269 if ret == 0 {
270 return Ok(());
271 } else {
272 return Err(Box::new(std::io::Error::new(
273 std::io::ErrorKind::Other,
274 format!("rkllm_release_prompt_cache returned non-zero: {}", ret),
275 )));
276 }
277 }
278
279 pub fn abort(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
281 let ret = unsafe { super::rkllm_abort(self.handle) };
282 if ret == 0 {
283 return Ok(());
284 } else {
285 return Err(Box::new(std::io::Error::new(
286 std::io::ErrorKind::Other,
287 format!("rkllm_abort returned non-zero: {}", ret),
288 )));
289 }
290 }
291
292 pub fn is_running(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
294 let ret = unsafe { super::rkllm_is_running(self.handle) };
295 if ret == 0 {
296 return Ok(());
297 } else {
298 return Err(Box::new(std::io::Error::new(
299 std::io::ErrorKind::Other,
300 format!("rkllm_is_running returned non-zero: {}", ret),
301 )));
302 }
303 }
304
305 pub fn load_lora(
307 &self,
308 lora_cfg: &RKLLMLoraAdapter,
309 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
310 let lora_adapter_name_cstring =
311 std::ffi::CString::new(lora_cfg.lora_adapter_name.clone()).unwrap();
312 let lora_adapter_name_cstring_ptr =
313 lora_adapter_name_cstring.as_ptr() as *const std::os::raw::c_char;
314 let lora_adapter_path_cstring =
315 std::ffi::CString::new(lora_cfg.lora_adapter_path.clone()).unwrap();
316 let lora_adapter_path_cstring_ptr =
317 lora_adapter_path_cstring.as_ptr() as *const std::os::raw::c_char;
318 let mut param = super::RKLLMLoraAdapter {
319 lora_adapter_path: lora_adapter_path_cstring_ptr,
320 lora_adapter_name: lora_adapter_name_cstring_ptr,
321 scale: lora_cfg.scale,
322 };
323 let ret = unsafe { super::rkllm_load_lora(self.handle, &mut param) };
324 if ret == 0 {
325 return Ok(());
326 } else {
327 return Err(Box::new(std::io::Error::new(
328 std::io::ErrorKind::Other,
329 format!("rkllm_load_lora returned non-zero: {}", ret),
330 )));
331 }
332 }
333 }
334
335 unsafe extern "C" fn callback_passtrough(
337 result: *mut super::RKLLMResult,
338 userdata: *mut ::std::os::raw::c_void,
339 state: super::LLMCallState,
340 ) {
341 Arc::increment_strong_count(userdata); let instance_data = unsafe { Arc::from_raw(userdata as *const InstanceData) };
343 let new_state = match state {
344 0 => LLMCallState::Normal,
345 1 => LLMCallState::Waiting,
346 2 => LLMCallState::Finish,
347 3 => LLMCallState::Error,
348 4 => LLMCallState::GetLastHiddenLayer,
349 _ => panic!("Unexpected LLMCallState"),
350 };
351
352 let new_result = if result.is_null() {
353 None
354 } else {
355 Some(RKLLMResult {
356 text: if (*result).text.is_null() {
357 String::new()
358 } else {
359 (unsafe { CStr::from_ptr((*result).text) })
360 .to_str()
361 .expect("Failed to convert C string")
362 .to_owned()
363 .clone()
364 },
365 token_id: (*result).token_id,
366 last_hidden_layer: (*result).last_hidden_layer,
367 })
368 };
369
370 instance_data
371 .callback_handler
372 .lock()
373 .unwrap()
374 .handle(new_result, new_state);
375 }
376
377 pub fn rkllm_init(
385 param: *mut super::RKLLMParam,
386 ) -> Result<LLMHandle, Box<dyn std::error::Error + Send + Sync>> {
387 let mut handle = LLMHandle {
388 handle: std::ptr::null_mut(),
389 };
390
391 let callback: Option<
392 unsafe extern "C" fn(
393 *mut super::RKLLMResult,
394 *mut ::std::os::raw::c_void,
395 super::LLMCallState,
396 ),
397 > = Some(callback_passtrough);
398 let ret = unsafe { super::rkllm_init(&mut handle.handle, param, callback) };
399 if ret == 0 {
400 return Ok(handle);
401 } else {
402 return Err(Box::new(std::io::Error::new(
403 std::io::ErrorKind::Other,
404 format!("rkllm_init returned non-zero: {}", ret),
405 )));
406 }
407 }
408
/// Input accepted by `LLMHandle::run`.
///
/// NOTE(review): `run` only translates `Prompt` to the C API. The `String` payloads of
/// the other variants look like placeholders — confirm before relying on them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RKLLMInput {
    /// Text prompt, passed to C as a NUL-terminated string.
    Prompt(String),
    /// Not yet supported by `run`.
    Token(String),
    /// Not yet supported by `run`.
    Embed(String),
    /// Not yet supported by `run`.
    Multimodal(String),
}
420}