#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]

include!(concat!(env!("OUT_DIR"), "/bindings.rs"));

pub mod prelude {

    use std::ffi::CStr;
    use std::ptr::null_mut;

    pub use super::RKLLMExtendParam;
    pub use super::RKLLMLoraParam;
    pub use super::RKLLMParam;
    pub use super::RKLLMResultLastHiddenLayer;

    type rkllm_callback =
        fn(result: Option<RKLLMResult>, userdata: *mut ::std::os::raw::c_void, state: LLMCallState);

    #[derive(Debug, PartialEq, Eq)]
    pub enum LLMCallState {
        #[doc = "< The LLM call is in a normal running state."]
        Normal = 0,
        #[doc = "< The LLM call is waiting for a complete UTF-8 encoded character."]
        Waiting = 1,
        #[doc = "< The LLM call has finished execution."]
        Finish = 2,
        #[doc = "< An error occurred during the LLM call."]
        Error = 3,
        #[doc = "< Retrieve the last hidden layer during inference."]
        GetLastHiddenLayer = 4,
    }

    #[doc = " @struct RKLLMInferParam\n @brief Structure for defining parameters during inference."]
    #[derive(Debug, Clone)]
    pub struct RKLLMInferParam {
        #[doc = "< Inference mode (e.g., generate or get the last hidden layer)."]
        pub mode: RKLLMInferMode,
        #[doc = "< Optional pointer to LoRA adapter parameters."]
        pub lora_params: Option<*mut RKLLMLoraParam>,
        #[doc = "< Optional prompt cache parameters."]
        pub prompt_cache_params: Option<RKLLMPromptCacheParam>,
    }

    #[derive(Debug, Copy, Clone)]
    pub enum RKLLMInferMode {
        #[doc = "< The LLM generates text based on input."]
        InferGenerate = 0,
        #[doc = "< The LLM retrieves the last hidden layer for further processing."]
        InferGetLastHiddenLayer = 1,
    }

    impl From<RKLLMInferMode> for u32 {
        fn from(mode: RKLLMInferMode) -> u32 {
            mode as u32
        }
    }

    #[doc = " @struct RKLLMPromptCacheParam\n @brief Structure to define parameters for caching prompts."]
    #[derive(Debug, Clone)]
    pub struct RKLLMPromptCacheParam {
        #[doc = "< Whether to save the prompt cache (true = save, false = don't save)."]
        pub save_prompt_cache: bool,
        #[doc = "< Path to the prompt cache file."]
        pub prompt_cache_path: String,
    }
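
    // A minimal sketch of how these wrapper types are intended to be combined before
    // being passed to `LLMHandle::run` (the cache path below is purely illustrative):
    //
    //     let infer_params = RKLLMInferParam {
    //         mode: RKLLMInferMode::InferGenerate,
    //         lora_params: None,
    //         prompt_cache_params: Some(RKLLMPromptCacheParam {
    //             save_prompt_cache: true,
    //             prompt_cache_path: "prompt_cache.bin".to_owned(),
    //         }),
    //     };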

    impl Default for super::RKLLMParam {
        fn default() -> Self {
            unsafe { super::rkllm_createDefaultParam() }
        }
    }

    #[doc = " @struct RKLLMResult\n @brief Structure to represent the result of LLM inference."]
    #[derive(Debug, Clone)]
    pub struct RKLLMResult {
        #[doc = "< Generated text result."]
        pub text: String,
        #[doc = "< ID of the generated token."]
        pub token_id: i32,
        #[doc = "< Hidden states of the last layer (if requested)."]
        pub last_hidden_layer: RKLLMResultLastHiddenLayer,
    }

    #[doc = " @struct LLMHandle\n @brief Safe wrapper around the raw LLM handle."]
    pub struct LLMHandle {
        handle: super::LLMHandle,
    }

    impl LLMHandle {
        #[doc = " @brief Destroys the LLM instance and releases resources.\n @return Status code (0 for success, non-zero for failure)."]
        pub fn destroy(&self) -> i32 {
            unsafe { super::rkllm_destroy(self.handle) }
        }

        #[doc = " @brief Runs an LLM inference task asynchronously.\n @param rkllm_input Input data for the LLM.\n @param rkllm_infer_params Parameters for the inference task.\n @param userdata Pointer to user data for the callback.\n @return Status code (0 for success, non-zero for failure)."]
        pub fn run(
            &self,
            rkllm_input: RKLLMInput,
            rkllm_infer_params: Option<RKLLMInferParam>,
            userdata: *mut ::std::os::raw::c_void,
        ) -> i32 {
            let prompt_cstring;
            let prompt_cstring_ptr;
            let mut input = match rkllm_input {
                RKLLMInput::Prompt(prompt) => {
                    prompt_cstring = std::ffi::CString::new(prompt).unwrap();
                    prompt_cstring_ptr = prompt_cstring.as_ptr();
                    super::RKLLMInput {
                        input_type: super::RKLLMInputType_RKLLM_INPUT_PROMPT,
                        __bindgen_anon_1: super::RKLLMInput__bindgen_ty_1 {
                            prompt_input: prompt_cstring_ptr,
                        },
                    }
                }
                RKLLMInput::Token(_) => todo!(),
                RKLLMInput::Embed(_) => todo!(),
                RKLLMInput::Multimodal(_) => todo!(),
            };

            let prompt_cache_cstring;
            let prompt_cache_cstring_ptr;

            // The C-side parameter structs are bound to locals in this scope so that the
            // raw pointers handed to rkllm_run stay valid for the whole call; taking
            // `&mut` of a temporary here would leave the pointers dangling.
            let mut prompt_cache_param_c;
            let mut infer_param_c;

            let new_rkllm_infer_params: *mut super::RKLLMInferParam =
                if let Some(rkllm_infer_params) = rkllm_infer_params {
                    let prompt_cache_ptr: *mut super::RKLLMPromptCacheParam =
                        if let Some(cache_params) = rkllm_infer_params.prompt_cache_params {
                            prompt_cache_cstring =
                                std::ffi::CString::new(cache_params.prompt_cache_path).unwrap();
                            prompt_cache_cstring_ptr = prompt_cache_cstring.as_ptr();

                            prompt_cache_param_c = super::RKLLMPromptCacheParam {
                                save_prompt_cache: if cache_params.save_prompt_cache { 1 } else { 0 },
                                prompt_cache_path: prompt_cache_cstring_ptr,
                            };
                            &mut prompt_cache_param_c
                        } else {
                            null_mut()
                        };

                    infer_param_c = super::RKLLMInferParam {
                        mode: rkllm_infer_params.mode.into(),
                        lora_params: rkllm_infer_params.lora_params.unwrap_or(null_mut()),
                        prompt_cache_params: prompt_cache_ptr,
                    };
                    &mut infer_param_c
                } else {
                    null_mut()
                };

            unsafe { super::rkllm_run(self.handle, &mut input, new_rkllm_infer_params, userdata) }
        }

        #[doc = " @brief Loads a prompt cache from a file.\n @param cache_path Path to the prompt cache file.\n @return Status code (0 for success, non-zero for failure)."]
        pub fn load_prompt_cache(&self, cache_path: &str) -> i32 {
            let prompt_cache_path = std::ffi::CString::new(cache_path).unwrap();
            let prompt_cache_path_ptr = prompt_cache_path.as_ptr();
            unsafe { super::rkllm_load_prompt_cache(self.handle, prompt_cache_path_ptr) }
        }
    }

    // Only a single global callback is supported at a time; it is set by rkllm_init
    // and read from the C callback trampoline below. Access is not synchronized, so
    // initializing multiple LLM instances from different threads concurrently is unsafe.
    static mut CALLBACK: Option<rkllm_callback> = None;

    unsafe extern "C" fn callback_passthrough(
        result: *mut super::RKLLMResult,
        userdata: *mut ::std::os::raw::c_void,
        state: super::LLMCallState,
    ) {
        let new_state = match state {
            0 => LLMCallState::Normal,
            1 => LLMCallState::Waiting,
            2 => LLMCallState::Finish,
            3 => LLMCallState::Error,
            4 => LLMCallState::GetLastHiddenLayer,
            _ => panic!("Unexpected LLMCallState"),
        };

        let new_result = if result.is_null() {
            None
        } else {
            Some(RKLLMResult {
                text: if (*result).text.is_null() {
                    String::new()
                } else {
                    unsafe { CStr::from_ptr((*result).text) }
                        .to_str()
                        .expect("Failed to convert C string to UTF-8")
                        .to_owned()
                },
                token_id: (*result).token_id,
                last_hidden_layer: (*result).last_hidden_layer,
            })
        };

        if let Some(callback) = CALLBACK {
            callback(new_result, userdata, new_state);
        }
    }

    #[doc = " @brief Initializes the LLM with the given parameters.\n @param param Configuration parameters for the LLM.\n @param callback Callback function to handle LLM results.\n @return The initialized LLM handle on success, or the non-zero status code on failure."]
    pub fn rkllm_init(
        param: *mut super::RKLLMParam,
        callback: rkllm_callback,
    ) -> Result<LLMHandle, i32> {
        let mut handle = LLMHandle {
            handle: std::ptr::null_mut(),
        };
        unsafe { CALLBACK = Some(callback) };
        let callback: Option<
            unsafe extern "C" fn(
                *mut super::RKLLMResult,
                *mut ::std::os::raw::c_void,
                super::LLMCallState,
            ),
        > = Some(callback_passthrough);
        let ret = unsafe { super::rkllm_init(&mut handle.handle, param, callback) };
        if ret == 0 {
            Ok(handle)
        } else {
            Err(ret)
        }
    }
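
    // A minimal end-to-end usage sketch, assuming the bindgen-generated RKLLMParam
    // exposes a `model_path` field as in the upstream C header; the crate name, model
    // path, and prompt below are purely illustrative:
    //
    //     use rkllm_rs::prelude::*;   // hypothetical crate name
    //
    //     let mut param = RKLLMParam::default();
    //     // param.model_path = model_path_cstring.as_ptr();
    //     let handle = rkllm_init(&mut param, |result, _userdata, state| {
    //         if state == LLMCallState::Normal {
    //             if let Some(result) = result {
    //                 print!("{}", result.text);
    //             }
    //         }
    //     })
    //     .expect("failed to initialize the LLM");
    //     handle.run(
    //         RKLLMInput::Prompt("Hello, world!".to_owned()),
    //         None,
    //         std::ptr::null_mut(),
    //     );
    //     handle.destroy();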

    pub enum RKLLMInput {
        #[doc = "< Input is a text prompt."]
        Prompt(String),
        #[doc = "< Input is a sequence of tokens (not yet implemented)."]
        Token(String),
        #[doc = "< Input is an embedding vector (not yet implemented)."]
        Embed(String),
        #[doc = "< Input is multimodal, e.g. text and image (not yet implemented)."]
        Multimodal(String),
    }
}