// The generated FFI bindings keep the C header's naming (e.g.
// `RKLLMInputType_RKLLM_INPUT_PROMPT`, `rkllm_createDefaultParam`), which trips
// rustc's naming-convention lints, so those lints are silenced for this file.
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]

// Bindings produced at build time into `$OUT_DIR/bindings.rs` (bindgen output,
// judging by the `__bindgen_*` names referenced in `prelude`).
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
6
7pub mod prelude {
8 use std::ffi::CStr;
9 use std::ptr::null_mut;
10 use std::sync::Arc;
11 use std::sync::Mutex;
12
13 pub use super::RKLLMExtendParam;
14 pub use super::RKLLMLoraParam;
15 pub use super::RKLLMParam;
16 pub use super::RKLLMResultLastHiddenLayer;
17
18 #[derive(Debug, PartialEq, Eq)]
20 pub enum LLMCallState {
21 Normal = 0,
23 Waiting = 1,
25 Finish = 2,
27 Error = 3,
29 GetLastHiddenLayer = 4,
31 }
32
33 #[derive(Debug, Clone, Default)]
35 pub struct RKLLMInferParam {
36 pub mode: RKLLMInferMode,
38 pub lora_params: Option<String>,
40 pub prompt_cache_params: Option<RKLLMPromptCacheParam>,
42 }
43
44 #[derive(Debug, Copy, Clone, Default)]
46 pub enum RKLLMInferMode {
47 #[default]
49 InferGenerate = 0,
50 InferGetLastHiddenLayer = 1,
52 }
53
54 impl Into<u32> for RKLLMInferMode {
55 fn into(self) -> u32 {
57 self as u32
58 }
59 }
60
61 #[derive(Debug, Clone)]
63 pub struct RKLLMPromptCacheParam {
64 pub save_prompt_cache: bool,
66 pub prompt_cache_path: String,
68 }
69
70 impl Default for super::RKLLMParam {
71 fn default() -> Self {
73 unsafe { super::rkllm_createDefaultParam() }
74 }
75 }
76
77 #[derive(Debug, Clone)]
79 pub struct RKLLMResult {
80 pub text: String,
82 pub token_id: i32,
84 pub last_hidden_layer: RKLLMResultLastHiddenLayer,
86 }
87
88 #[derive(Debug, Clone)]
89 pub struct RKLLMLoraAdapter {
90 pub lora_adapter_path: String,
91 pub lora_adapter_name: String,
92 pub scale: f32,
93 }
94
95 #[derive(Clone, Debug, Copy)]
97 pub struct LLMHandle {
98 handle: super::LLMHandle,
99 }
100
101 unsafe impl Send for LLMHandle {} unsafe impl Sync for LLMHandle {} pub trait RkllmCallbackHandler {
106 fn handle(&mut self, result: Option<RKLLMResult>, state: LLMCallState);
108 }
109
110 pub struct InstanceData {
112 pub callback_handler: Arc<Mutex<dyn RkllmCallbackHandler + Send + Sync>>,
114 }
115
116 impl LLMHandle {
117 pub fn destroy(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
119 let ret = unsafe { super::rkllm_destroy(self.handle) };
120
121 if ret == 0 {
122 return Ok(());
123 } else {
124 return Err(Box::new(std::io::Error::new(
125 std::io::ErrorKind::Other,
126 format!("rkllm_run returned non-zero: {}", ret),
127 )));
128 }
129 }
130
131 pub fn run(
141 &self,
142 rkllm_input: RKLLMInput,
143 rkllm_infer_params: Option<RKLLMInferParam>,
144 user_data: impl RkllmCallbackHandler + Send + Sync + 'static,
145 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
146 let instance_data = Arc::new(InstanceData {
147 callback_handler: Arc::new(Mutex::new(user_data)),
148 });
149
150 let userdata_ptr = Arc::into_raw(instance_data) as *mut std::ffi::c_void;
151 let prompt_cstring;
152 let prompt_cstring_ptr;
153 let mut input = match rkllm_input {
154 RKLLMInput::Prompt(prompt) => {
155 prompt_cstring = std::ffi::CString::new(prompt).unwrap();
156 prompt_cstring_ptr = prompt_cstring.as_ptr() as *const std::os::raw::c_char;
157 super::RKLLMInput {
158 input_type: super::RKLLMInputType_RKLLM_INPUT_PROMPT,
159 __bindgen_anon_1: super::RKLLMInput__bindgen_ty_1 {
160 prompt_input: prompt_cstring_ptr,
161 },
162 }
163 }
164 RKLLMInput::Token(_) => todo!(),
165 RKLLMInput::Embed(_) => todo!(),
166 RKLLMInput::Multimodal(_) => todo!(),
167 };
168
169 let prompt_cache_cstring;
170 let prompt_cache_cstring_ptr;
171
172 let lora_adapter_name;
173 let lora_adapter_name_ptr;
174 let mut loraparam;
175
176 let new_rkllm_infer_params: *mut super::RKLLMInferParam =
177 if let Some(rkllm_infer_params) = rkllm_infer_params {
178 &mut super::RKLLMInferParam {
179 mode: rkllm_infer_params.mode.into(),
180 lora_params: match rkllm_infer_params.lora_params {
181 Some(a) => {
182 lora_adapter_name = a;
183 lora_adapter_name_ptr =
184 lora_adapter_name.as_ptr() as *const std::os::raw::c_char;
185 loraparam = RKLLMLoraParam {
186 lora_adapter_name: lora_adapter_name_ptr,
187 };
188 &mut loraparam
189 }
190 None => null_mut(),
191 },
192 prompt_cache_params: if let Some(cache_params) =
193 rkllm_infer_params.prompt_cache_params
194 {
195 prompt_cache_cstring =
196 std::ffi::CString::new(cache_params.prompt_cache_path).unwrap();
197 prompt_cache_cstring_ptr =
198 prompt_cache_cstring.as_ptr() as *const std::os::raw::c_char;
199
200 &mut super::RKLLMPromptCacheParam {
201 save_prompt_cache: if cache_params.save_prompt_cache {
202 1
203 } else {
204 0
205 },
206 prompt_cache_path: prompt_cache_cstring_ptr,
207 }
208 } else {
209 null_mut()
210 },
211 }
212 } else {
213 null_mut()
214 };
215
216 let ret = unsafe {
217 super::rkllm_run(
218 self.handle,
219 &mut input,
220 new_rkllm_infer_params,
221 userdata_ptr,
222 )
223 };
224 if ret == 0 {
225 return Ok(());
226 } else {
227 return Err(Box::new(std::io::Error::new(
228 std::io::ErrorKind::Other,
229 format!("rkllm_run returned non-zero: {}", ret),
230 )));
231 }
232 }
233
234 pub fn load_prompt_cache(
239 &self,
240 cache_path: &str,
241 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
242 let prompt_cache_path = std::ffi::CString::new(cache_path).unwrap();
243 let prompt_cache_path_ptr = prompt_cache_path.as_ptr() as *const std::os::raw::c_char;
244 let ret = unsafe { super::rkllm_load_prompt_cache(self.handle, prompt_cache_path_ptr) };
245 if ret == 0 {
246 return Ok(());
247 } else {
248 return Err(Box::new(std::io::Error::new(
249 std::io::ErrorKind::Other,
250 format!("rkllm_load_prompt_cache returned non-zero: {}", ret),
251 )));
252 }
253 }
254
255 pub fn release_prompt_cache(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
257 let ret = unsafe { super::rkllm_release_prompt_cache(self.handle) };
258 if ret == 0 {
259 return Ok(());
260 } else {
261 return Err(Box::new(std::io::Error::new(
262 std::io::ErrorKind::Other,
263 format!("rkllm_release_prompt_cache returned non-zero: {}", ret),
264 )));
265 }
266 }
267
268 pub fn abort(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
270 let ret = unsafe { super::rkllm_abort(self.handle) };
271 if ret == 0 {
272 return Ok(());
273 } else {
274 return Err(Box::new(std::io::Error::new(
275 std::io::ErrorKind::Other,
276 format!("rkllm_abort returned non-zero: {}", ret),
277 )));
278 }
279 }
280
281 pub fn is_running(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
283 let ret = unsafe { super::rkllm_is_running(self.handle) };
284 if ret == 0 {
285 return Ok(());
286 } else {
287 return Err(Box::new(std::io::Error::new(
288 std::io::ErrorKind::Other,
289 format!("rkllm_is_running returned non-zero: {}", ret),
290 )));
291 }
292 }
293
294 pub fn load_lora(
296 &self,
297 lora_cfg: &RKLLMLoraAdapter,
298 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
299 let lora_adapter_name_cstring =
300 std::ffi::CString::new(lora_cfg.lora_adapter_name.clone()).unwrap();
301 let lora_adapter_name_cstring_ptr =
302 lora_adapter_name_cstring.as_ptr() as *const std::os::raw::c_char;
303 let lora_adapter_path_cstring =
304 std::ffi::CString::new(lora_cfg.lora_adapter_path.clone()).unwrap();
305 let lora_adapter_path_cstring_ptr =
306 lora_adapter_path_cstring.as_ptr() as *const std::os::raw::c_char;
307 let mut param = super::RKLLMLoraAdapter {
308 lora_adapter_path: lora_adapter_path_cstring_ptr,
309 lora_adapter_name: lora_adapter_name_cstring_ptr,
310 scale: lora_cfg.scale,
311 };
312 let ret = unsafe { super::rkllm_load_lora(self.handle, &mut param) };
313 if ret == 0 {
314 return Ok(());
315 } else {
316 return Err(Box::new(std::io::Error::new(
317 std::io::ErrorKind::Other,
318 format!("rkllm_load_lora returned non-zero: {}", ret),
319 )));
320 }
321 }
322 }
323
324 unsafe extern "C" fn callback_passtrough(
326 result: *mut super::RKLLMResult,
327 userdata: *mut ::std::os::raw::c_void,
328 state: super::LLMCallState,
329 ) {
330 Arc::increment_strong_count(userdata); let instance_data = unsafe { Arc::from_raw(userdata as *const InstanceData) };
332 let new_state = match state {
333 0 => LLMCallState::Normal,
334 1 => LLMCallState::Waiting,
335 2 => LLMCallState::Finish,
336 3 => LLMCallState::Error,
337 4 => LLMCallState::GetLastHiddenLayer,
338 _ => panic!("Unexpected LLMCallState"),
339 };
340
341 let new_result = if result.is_null() {
342 None
343 } else {
344 Some(RKLLMResult {
345 text: if (*result).text.is_null() {
346 String::new()
347 } else {
348 (unsafe { CStr::from_ptr((*result).text) })
349 .to_str()
350 .expect("Failed to convert C string")
351 .to_owned()
352 .clone()
353 },
354 token_id: (*result).token_id,
355 last_hidden_layer: (*result).last_hidden_layer,
356 })
357 };
358
359 instance_data
360 .callback_handler
361 .lock()
362 .unwrap()
363 .handle(new_result, new_state);
364 }
365
366 pub fn rkllm_init(
374 param: *mut super::RKLLMParam,
375 ) -> Result<LLMHandle, Box<dyn std::error::Error + Send + Sync>> {
376 let mut handle = LLMHandle {
377 handle: std::ptr::null_mut(),
378 };
379
380 let callback: Option<
381 unsafe extern "C" fn(
382 *mut super::RKLLMResult,
383 *mut ::std::os::raw::c_void,
384 super::LLMCallState,
385 ),
386 > = Some(callback_passtrough);
387 let ret = unsafe { super::rkllm_init(&mut handle.handle, param, callback) };
388 if ret == 0 {
389 return Ok(handle);
390 } else {
391 return Err(Box::new(std::io::Error::new(
392 std::io::ErrorKind::Other,
393 format!("rkllm_init returned non-zero: {}", ret),
394 )));
395 }
396 }
397
398 pub enum RKLLMInput {
400 Prompt(String),
402 Token(String),
404 Embed(String),
406 Multimodal(String),
408 }
409}