1#![allow(non_upper_case_globals)]
2#![allow(non_camel_case_types)]
3#![allow(non_snake_case)]
4
5include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
6
7pub mod prelude {
8 use std::ffi::CStr;
9 use std::ptr::null_mut;
10 use std::sync::Arc;
11 use std::sync::Mutex;
12
13 pub use super::RKLLMExtendParam;
14 pub use super::RKLLMLoraParam;
15 pub use super::RKLLMParam;
16 pub use super::RKLLMResultLastHiddenLayer;
17
18 #[derive(Debug, PartialEq, Eq)]
20 pub enum LLMCallState {
21 Normal = 0,
23 Waiting = 1,
25 Finish = 2,
27 Error = 3,
29 GetLastHiddenLayer = 4,
31 }
32
33 #[derive(Debug, Clone, Default)]
34 pub enum KeepHistory {
35 #[default]
36 NoKeepHistory = 0,
38 KeepHistory = 1,
40 }
41
42 #[derive(Debug, Clone, Default)]
44 pub struct RKLLMInferParam {
45 pub mode: RKLLMInferMode,
47 pub lora_params: Option<String>,
49 pub prompt_cache_params: Option<RKLLMPromptCacheParam>,
51 pub keep_history: KeepHistory,
52 }
53
54 #[derive(Debug, Copy, Clone, Default)]
56 pub enum RKLLMInferMode {
57 #[default]
59 InferGenerate = 0,
60 InferGetLastHiddenLayer = 1,
62 }
63
64 impl Into<u32> for RKLLMInferMode {
65 fn into(self) -> u32 {
67 self as u32
68 }
69 }
70
71 #[derive(Debug, Clone)]
73 pub struct RKLLMPromptCacheParam {
74 pub save_prompt_cache: bool,
76 pub prompt_cache_path: String,
78 }
79
80 impl Default for super::RKLLMParam {
81 fn default() -> Self {
83 unsafe { super::rkllm_createDefaultParam() }
84 }
85 }
86
87 #[derive(Debug, Clone)]
89 pub struct RKLLMResult {
90 pub text: String,
92 pub token_id: i32,
94 pub last_hidden_layer: RKLLMResultLastHiddenLayer,
96 }
97
98 #[derive(Debug, Clone)]
99 pub struct RKLLMLoraAdapter {
100 pub lora_adapter_path: String,
101 pub lora_adapter_name: String,
102 pub scale: f32,
103 }
104
105 #[derive(Clone, Debug, Copy)]
107 pub struct LLMHandle {
108 handle: super::LLMHandle,
109 }
110
111 unsafe impl Send for LLMHandle {} unsafe impl Sync for LLMHandle {} pub trait RkllmCallbackHandler {
116 fn handle(&mut self, result: Option<RKLLMResult>, state: LLMCallState);
118 }
119
120 pub struct InstanceData {
122 pub callback_handler: Arc<Mutex<dyn RkllmCallbackHandler + Send + Sync>>,
124 }
125
126 impl LLMHandle {
127 pub fn destroy(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
129 let ret = unsafe { super::rkllm_destroy(self.handle) };
130
131 if ret == 0 {
132 return Ok(());
133 } else {
134 return Err(Box::new(std::io::Error::new(
135 std::io::ErrorKind::Other,
136 format!("rkllm_run returned non-zero: {}", ret),
137 )));
138 }
139 }
140
141 pub fn run(
151 &self,
152 rkllm_input: RKLLMInput,
153 rkllm_infer_params: Option<RKLLMInferParam>,
154 user_data: impl RkllmCallbackHandler + Send + Sync + 'static,
155 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
156 let instance_data = Arc::new(InstanceData {
157 callback_handler: Arc::new(Mutex::new(user_data)),
158 });
159
160 let userdata_ptr = Arc::into_raw(instance_data) as *mut std::ffi::c_void;
161 let prompt_cstring;
162 let prompt_cstring_ptr;
163 let role_text;
164 let role_text_ptr;
165 let mut input = match rkllm_input.input_type {
166 RKLLMInputType::Prompt(prompt) => {
167 prompt_cstring = std::ffi::CString::new(prompt).unwrap();
168 prompt_cstring_ptr = prompt_cstring.as_ptr() as *const std::os::raw::c_char;
169
170 role_text = match rkllm_input.role {
171 RKLLMInputRole::User => "user",
172 RKLLMInputRole::Tool => "tool",
173 };
174 role_text_ptr = role_text.as_ptr() as *const std::os::raw::c_char;
175
176 super::RKLLMInput {
177 input_type: super::RKLLMInputType_RKLLM_INPUT_PROMPT,
178 enable_thinking: rkllm_input.enable_thinking,
179 role: role_text_ptr,
180 __bindgen_anon_1: super::RKLLMInput__bindgen_ty_1 {
181 prompt_input: prompt_cstring_ptr,
182 },
183 }
184 }
185 RKLLMInputType::Token(_) => todo!(),
186 RKLLMInputType::Embed(_) => todo!(),
187 RKLLMInputType::Multimodal(_) => todo!(),
188 };
189
190 let prompt_cache_cstring;
191 let prompt_cache_cstring_ptr;
192
193 let lora_adapter_name;
194 let lora_adapter_name_ptr;
195 let mut loraparam;
196
197 let new_rkllm_infer_params: *mut super::RKLLMInferParam =
198 if let Some(rkllm_infer_params) = rkllm_infer_params {
199 &mut super::RKLLMInferParam {
200 keep_history: rkllm_infer_params.keep_history as i32,
201 mode: rkllm_infer_params.mode.into(),
202 lora_params: match rkllm_infer_params.lora_params {
203 Some(a) => {
204 lora_adapter_name = a;
205 lora_adapter_name_ptr =
206 lora_adapter_name.as_ptr() as *const std::os::raw::c_char;
207 loraparam = RKLLMLoraParam {
208 lora_adapter_name: lora_adapter_name_ptr,
209 };
210 &mut loraparam
211 }
212 None => null_mut(),
213 },
214 prompt_cache_params: if let Some(cache_params) =
215 rkllm_infer_params.prompt_cache_params
216 {
217 prompt_cache_cstring =
218 std::ffi::CString::new(cache_params.prompt_cache_path).unwrap();
219 prompt_cache_cstring_ptr =
220 prompt_cache_cstring.as_ptr() as *const std::os::raw::c_char;
221
222 &mut super::RKLLMPromptCacheParam {
223 save_prompt_cache: if cache_params.save_prompt_cache {
224 1
225 } else {
226 0
227 },
228 prompt_cache_path: prompt_cache_cstring_ptr,
229 }
230 } else {
231 null_mut()
232 },
233 }
234 } else {
235 null_mut()
236 };
237
238 let ret = unsafe {
239 super::rkllm_run(
240 self.handle,
241 &mut input,
242 new_rkllm_infer_params,
243 userdata_ptr,
244 )
245 };
246 if ret == 0 {
247 return Ok(());
248 } else {
249 return Err(Box::new(std::io::Error::new(
250 std::io::ErrorKind::Other,
251 format!("rkllm_run returned non-zero: {}", ret),
252 )));
253 }
254 }
255
256 pub fn load_prompt_cache(
261 &self,
262 cache_path: &str,
263 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
264 let prompt_cache_path = std::ffi::CString::new(cache_path).unwrap();
265 let prompt_cache_path_ptr = prompt_cache_path.as_ptr() as *const std::os::raw::c_char;
266 let ret = unsafe { super::rkllm_load_prompt_cache(self.handle, prompt_cache_path_ptr) };
267 if ret == 0 {
268 return Ok(());
269 } else {
270 return Err(Box::new(std::io::Error::new(
271 std::io::ErrorKind::Other,
272 format!("rkllm_load_prompt_cache returned non-zero: {}", ret),
273 )));
274 }
275 }
276
277 pub fn release_prompt_cache(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
279 let ret = unsafe { super::rkllm_release_prompt_cache(self.handle) };
280 if ret == 0 {
281 return Ok(());
282 } else {
283 return Err(Box::new(std::io::Error::new(
284 std::io::ErrorKind::Other,
285 format!("rkllm_release_prompt_cache returned non-zero: {}", ret),
286 )));
287 }
288 }
289
290 pub fn abort(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
292 let ret = unsafe { super::rkllm_abort(self.handle) };
293 if ret == 0 {
294 return Ok(());
295 } else {
296 return Err(Box::new(std::io::Error::new(
297 std::io::ErrorKind::Other,
298 format!("rkllm_abort returned non-zero: {}", ret),
299 )));
300 }
301 }
302
303 pub fn is_running(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
305 let ret = unsafe { super::rkllm_is_running(self.handle) };
306 if ret == 0 {
307 return Ok(());
308 } else {
309 return Err(Box::new(std::io::Error::new(
310 std::io::ErrorKind::Other,
311 format!("rkllm_is_running returned non-zero: {}", ret),
312 )));
313 }
314 }
315
316 pub fn load_lora(
318 &self,
319 lora_cfg: &RKLLMLoraAdapter,
320 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
321 let lora_adapter_name_cstring =
322 std::ffi::CString::new(lora_cfg.lora_adapter_name.clone()).unwrap();
323 let lora_adapter_name_cstring_ptr =
324 lora_adapter_name_cstring.as_ptr() as *const std::os::raw::c_char;
325 let lora_adapter_path_cstring =
326 std::ffi::CString::new(lora_cfg.lora_adapter_path.clone()).unwrap();
327 let lora_adapter_path_cstring_ptr =
328 lora_adapter_path_cstring.as_ptr() as *const std::os::raw::c_char;
329 let mut param = super::RKLLMLoraAdapter {
330 lora_adapter_path: lora_adapter_path_cstring_ptr,
331 lora_adapter_name: lora_adapter_name_cstring_ptr,
332 scale: lora_cfg.scale,
333 };
334 let ret = unsafe { super::rkllm_load_lora(self.handle, &mut param) };
335 if ret == 0 {
336 return Ok(());
337 } else {
338 return Err(Box::new(std::io::Error::new(
339 std::io::ErrorKind::Other,
340 format!("rkllm_load_lora returned non-zero: {}", ret),
341 )));
342 }
343 }
344 }
345
346 unsafe extern "C" fn callback_passtrough(
348 result: *mut super::RKLLMResult,
349 userdata: *mut ::std::os::raw::c_void,
350 state: super::LLMCallState,
351 ) -> i32 {
352 Arc::increment_strong_count(userdata); let instance_data = unsafe { Arc::from_raw(userdata as *const InstanceData) };
354 let new_state = match state {
355 0 => LLMCallState::Normal,
356 1 => LLMCallState::Waiting,
357 2 => LLMCallState::Finish,
358 3 => LLMCallState::Error,
359 4 => LLMCallState::GetLastHiddenLayer,
360 _ => panic!("Unexpected LLMCallState"),
361 };
362
363 let new_result = if result.is_null() {
364 None
365 } else {
366 Some(RKLLMResult {
367 text: if (*result).text.is_null() {
368 String::new()
369 } else {
370 (unsafe { CStr::from_ptr((*result).text) })
371 .to_str()
372 .expect("Failed to convert C string")
373 .to_owned()
374 .clone()
375 },
376 token_id: (*result).token_id,
377 last_hidden_layer: (*result).last_hidden_layer,
378 })
379 };
380
381 instance_data
382 .callback_handler
383 .lock()
384 .unwrap()
385 .handle(new_result, new_state);
386 0
387 }
388
389 pub fn rkllm_init(
397 param: *mut super::RKLLMParam,
398 ) -> Result<LLMHandle, Box<dyn std::error::Error + Send + Sync>> {
399 let mut handle = LLMHandle {
400 handle: std::ptr::null_mut(),
401 };
402
403 let callback: Option<
404 unsafe extern "C" fn(
405 *mut super::RKLLMResult,
406 *mut ::std::os::raw::c_void,
407 super::LLMCallState,
408 ) -> i32,
409 > = Some(callback_passtrough);
410 let ret = unsafe { super::rkllm_init(&mut handle.handle, param, callback) };
411 if ret == 0 {
412 return Ok(handle);
413 } else {
414 return Err(Box::new(std::io::Error::new(
415 std::io::ErrorKind::Other,
416 format!("rkllm_init returned non-zero: {}", ret),
417 )));
418 }
419 }
420
421 pub struct RKLLMInput {
423 pub input_type: RKLLMInputType,
425 pub enable_thinking: bool,
427 pub role: RKLLMInputRole,
429 }
430
431 pub enum RKLLMInputType {
433 Prompt(String),
435 Token(String),
437 Embed(String),
439 Multimodal(String),
441 }
442
443 pub enum RKLLMInputRole {
445 User,
447 Tool,
449 }
450}