use json::JsonValue;
use std::str::FromStr;

type Handle = u32;
type ExitCode = u8;

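// Host functions exposed by the Blockless runtime under the `blockless_llm`
// WASM import module. Each call returns an `ExitCode` (0 indicates success);
// the host fills caller-provided buffers and reports the byte count through
// the `bytes_written` out-parameter.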
#[link(wasm_import_module = "blockless_llm")]
extern "C" {
    fn llm_set_model_request(h: *mut Handle, model_ptr: *const u8, model_len: u8) -> ExitCode;
    fn llm_get_model_response(
        h: Handle,
        buf: *mut u8,
        buf_len: u8,
        bytes_written: *mut u8,
    ) -> ExitCode;
    fn llm_set_model_options_request(
        h: Handle,
        options_ptr: *const u8,
        options_len: u16,
    ) -> ExitCode;
    fn llm_get_model_options(
        h: Handle,
        buf: *mut u8,
        buf_len: u16,
        bytes_written: *mut u16,
    ) -> ExitCode;
    fn llm_prompt_request(h: Handle, prompt_ptr: *const u8, prompt_len: u16) -> ExitCode;
    fn llm_read_prompt_response(
        h: Handle,
        buf: *mut u8,
        buf_len: u16,
        bytes_written: *mut u16,
    ) -> ExitCode;
    fn llm_close(h: Handle) -> ExitCode;
}

/// Models supported by the Blockless runtime. The optional payload carries a
/// quantization tag (e.g. "Q6_K" or "q4f16_1") when one is requested.
#[derive(Debug, Clone)]
pub enum SupportedModels {
    Llama321BInstruct(Option<String>),
    Llama323BInstruct(Option<String>),
    Mistral7BInstructV03(Option<String>),
    Mixtral8x7BInstructV01(Option<String>),
    Gemma22BInstruct(Option<String>),
    Gemma27BInstruct(Option<String>),
    Gemma29BInstruct(Option<String>),
}

impl FromStr for SupportedModels {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "Llama-3.2-1B-Instruct" => Ok(SupportedModels::Llama321BInstruct(None)),
            "Llama-3.2-1B-Instruct-Q6_K"
            | "Llama-3.2-1B-Instruct_Q6_K"
            | "Llama-3.2-1B-Instruct.Q6_K" => {
                Ok(SupportedModels::Llama321BInstruct(Some("Q6_K".to_string())))
            }
            "Llama-3.2-1B-Instruct-q4f16_1" | "Llama-3.2-1B-Instruct.q4f16_1" => Ok(
                SupportedModels::Llama321BInstruct(Some("q4f16_1".to_string())),
            ),

            "Llama-3.2-3B-Instruct" => Ok(SupportedModels::Llama323BInstruct(None)),
            "Llama-3.2-3B-Instruct-Q6_K"
            | "Llama-3.2-3B-Instruct_Q6_K"
            | "Llama-3.2-3B-Instruct.Q6_K" => {
                Ok(SupportedModels::Llama323BInstruct(Some("Q6_K".to_string())))
            }
            "Llama-3.2-3B-Instruct-q4f16_1" | "Llama-3.2-3B-Instruct.q4f16_1" => Ok(
                SupportedModels::Llama323BInstruct(Some("q4f16_1".to_string())),
            ),

            "Mistral-7B-Instruct-v0.3" => Ok(SupportedModels::Mistral7BInstructV03(None)),
            "Mistral-7B-Instruct-v0.3-q4f16_1" | "Mistral-7B-Instruct-v0.3.q4f16_1" => Ok(
                SupportedModels::Mistral7BInstructV03(Some("q4f16_1".to_string())),
            ),

            "Mixtral-8x7B-Instruct-v0.1" => Ok(SupportedModels::Mixtral8x7BInstructV01(None)),
            "Mixtral-8x7B-Instruct-v0.1-q4f16_1" | "Mixtral-8x7B-Instruct-v0.1.q4f16_1" => Ok(
                SupportedModels::Mixtral8x7BInstructV01(Some("q4f16_1".to_string())),
            ),

            "gemma-2-2b-it" => Ok(SupportedModels::Gemma22BInstruct(None)),
            "gemma-2-2b-it-q4f16_1" | "gemma-2-2b-it.q4f16_1" => Ok(
                SupportedModels::Gemma22BInstruct(Some("q4f16_1".to_string())),
            ),

            "gemma-2-27b-it" => Ok(SupportedModels::Gemma27BInstruct(None)),
            "gemma-2-27b-it-q4f16_1" | "gemma-2-27b-it.q4f16_1" => Ok(
                SupportedModels::Gemma27BInstruct(Some("q4f16_1".to_string())),
            ),

            "gemma-2-9b-it" => Ok(SupportedModels::Gemma29BInstruct(None)),
            "gemma-2-9b-it-q4f16_1" | "gemma-2-9b-it.q4f16_1" => Ok(
                SupportedModels::Gemma29BInstruct(Some("q4f16_1".to_string())),
            ),

            _ => Err(format!("Unsupported model: {}", s)),
        }
    }
}
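
// A quick check of the parsing above (a sketch): bare names map to the
// un-quantized variant, a suffixed name captures the quantization tag, and
// unknown names are rejected. Note that the `Display` impl below yields the
// base model name, so a quantization tag does not survive a parse/format
// round trip.
#[cfg(test)]
mod model_parsing_tests {
    use super::*;

    #[test]
    fn parses_names_and_quantization_suffixes() {
        let plain: SupportedModels = "Llama-3.2-1B-Instruct".parse().unwrap();
        assert!(matches!(plain, SupportedModels::Llama321BInstruct(None)));

        let quantized: SupportedModels = "Llama-3.2-1B-Instruct-Q6_K".parse().unwrap();
        match &quantized {
            SupportedModels::Llama321BInstruct(Some(q)) => assert_eq!(q, "Q6_K"),
            other => panic!("unexpected variant: {:?}", other),
        }
        assert_eq!(quantized.to_string(), "Llama-3.2-1B-Instruct");

        assert!("not-a-model".parse::<SupportedModels>().is_err());
    }
}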

impl std::fmt::Display for SupportedModels {
    // Renders the canonical model name; any quantization tag is ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            SupportedModels::Llama321BInstruct(_) => "Llama-3.2-1B-Instruct",
            SupportedModels::Llama323BInstruct(_) => "Llama-3.2-3B-Instruct",
            SupportedModels::Mistral7BInstructV03(_) => "Mistral-7B-Instruct-v0.3",
            SupportedModels::Mixtral8x7BInstructV01(_) => "Mixtral-8x7B-Instruct-v0.1",
            SupportedModels::Gemma22BInstruct(_) => "gemma-2-2b-it",
            SupportedModels::Gemma27BInstruct(_) => "gemma-2-27b-it",
            SupportedModels::Gemma29BInstruct(_) => "gemma-2-9b-it",
        };
        write!(f, "{}", name)
    }
}

#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Default)]
pub struct BlocklessLlm {
    inner: Handle,
    model_name: String,
    options: LlmOptions,
}

#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Default, PartialEq)]
pub struct LlmOptions {
    pub system_message: String,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

impl LlmOptions {
    pub fn new() -> Self {
        Self::default()
    }

    /// Serializes the options to JSON bytes; `temperature` and `top_p` are
    /// emitted only when set.
    pub fn dump(&self) -> Vec<u8> {
        let mut json = JsonValue::new_object();
        json["system_message"] = self.system_message.clone().into();
        if let Some(temperature) = self.temperature {
            json["temperature"] = temperature.into();
        }
        if let Some(top_p) = self.top_p {
            json["top_p"] = top_p.into();
        }
        json.dump().into_bytes()
    }
}

impl std::fmt::Display for LlmOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let bytes = self.dump();
        match String::from_utf8(bytes) {
            Ok(s) => write!(f, "{}", s),
            Err(_) => write!(f, "<invalid utf8>"),
        }
    }
}

impl TryFrom<Vec<u8>> for LlmOptions {
    type Error = LlmErrorKind;

    fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
        let json_str = String::from_utf8(bytes).map_err(|_| LlmErrorKind::Utf8Error)?;
        let json = json::parse(&json_str).map_err(|_| LlmErrorKind::ModelOptionsNotSet)?;

        // `system_message` is required; `temperature` and `top_p` are optional.
        let system_message = json["system_message"]
            .as_str()
            .ok_or(LlmErrorKind::ModelOptionsNotSet)?
            .to_string();

        Ok(LlmOptions {
            system_message,
            temperature: json["temperature"].as_f32(),
            top_p: json["top_p"].as_f32(),
        })
    }
}
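
// A minimal sanity check of the `dump`/`TryFrom<Vec<u8>>` round trip (a
// sketch; the values are chosen to be exactly representable as `f32` so the
// equality assertion holds after the JSON round trip).
#[cfg(test)]
mod options_round_trip_tests {
    use super::*;

    #[test]
    fn dump_then_parse_round_trips() {
        let options = LlmOptions {
            system_message: "You are a helpful assistant.".to_string(),
            temperature: Some(0.5),
            top_p: Some(0.25),
        };
        let parsed = LlmOptions::try_from(options.dump()).expect("round trip failed");
        assert_eq!(options, parsed);
    }

    #[test]
    fn missing_system_message_is_an_error() {
        // `try_from` requires `system_message`, so an empty object is rejected.
        assert!(LlmOptions::try_from(b"{}".to_vec()).is_err());
    }
}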

impl BlocklessLlm {
    /// Creates a client for `model`. The host assigns the handle when the
    /// model is registered via `set_model`.
    pub fn new(model: SupportedModels) -> Result<Self, LlmErrorKind> {
        let model_name = model.to_string();
        let mut llm = BlocklessLlm::default();
        llm.set_model(&model_name)?;
        Ok(llm)
    }

    pub fn handle(&self) -> Handle {
        self.inner
    }

    /// Reads the currently configured model name back from the host.
    pub fn get_model(&self) -> Result<String, LlmErrorKind> {
        let mut buf = [0u8; u8::MAX as usize];
        let mut num_bytes: u8 = 0;
        let code = unsafe {
            llm_get_model_response(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
        };
        if code != 0 {
            return Err(code.into());
        }
        String::from_utf8(buf[0..num_bytes as usize].to_vec()).map_err(|_| LlmErrorKind::Utf8Error)
    }

    /// Registers the model with the host (which assigns `self.inner`) and
    /// verifies the assignment by reading the model name back.
    pub fn set_model(&mut self, model_name: &str) -> Result<(), LlmErrorKind> {
        self.model_name = model_name.to_string();
        let code = unsafe {
            llm_set_model_request(&mut self.inner, model_name.as_ptr(), model_name.len() as _)
        };
        if code != 0 {
            return Err(code.into());
        }

        let host_model = self.get_model()?;
        if self.model_name != host_model {
            eprintln!(
                "Model not set correctly in host/runtime; model_name: {}, model_from_host: {}",
                self.model_name, host_model
            );
            return Err(LlmErrorKind::ModelNotSet);
        }
        Ok(())
    }

    pub fn get_options(&self) -> Result<LlmOptions, LlmErrorKind> {
        let mut buf = [0u8; u16::MAX as usize];
        let mut num_bytes: u16 = 0;
        let code = unsafe {
            llm_get_model_options(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
        };
        if code != 0 {
            return Err(code.into());
        }

        LlmOptions::try_from(buf[0..num_bytes as usize].to_vec())
    }

    /// Sends the options to the host and verifies them by reading them back.
    pub fn set_options(&mut self, options: LlmOptions) -> Result<(), LlmErrorKind> {
        let options_json = options.dump();
        self.options = options;
        let code = unsafe {
            llm_set_model_options_request(
                self.inner,
                options_json.as_ptr(),
                options_json.len() as _,
            )
        };
        if code != 0 {
            return Err(code.into());
        }

        let host_options = self.get_options()?;
        if self.options != host_options {
            eprintln!(
                "Options not set correctly in host/runtime; options: {:?}, options_from_host: {:?}",
                self.options, host_options
            );
            return Err(LlmErrorKind::ModelOptionsNotSet);
        }
        Ok(())
    }

    /// Submits a prompt and reads back the full response.
    pub fn chat_request(&self, prompt: &str) -> Result<String, LlmErrorKind> {
        let code = unsafe { llm_prompt_request(self.inner, prompt.as_ptr(), prompt.len() as _) };
        if code != 0 {
            return Err(code.into());
        }

        self.get_chat_response()
    }

    fn get_chat_response(&self) -> Result<String, LlmErrorKind> {
        let mut buf = [0u8; u16::MAX as usize];
        let mut num_bytes: u16 = 0;
        let code = unsafe {
            llm_read_prompt_response(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
        };
        if code != 0 {
            return Err(code.into());
        }

        let response_vec = buf[0..num_bytes as usize].to_vec();
        String::from_utf8(response_vec).map_err(|_| LlmErrorKind::Utf8Error)
    }
}

impl Drop for BlocklessLlm {
    fn drop(&mut self) {
        // Best effort: report close failures but never panic in `drop`.
        let code = unsafe { llm_close(self.inner) };
        if code != 0 {
            eprintln!("Error closing LLM: {}", code);
        }
    }
}

#[derive(Debug)]
pub enum LlmErrorKind {
    ModelNotSet,
    ModelNotSupported,
    ModelInitializationFailed,
    ModelCompletionFailed,
    ModelOptionsNotSet,
    ModelShutdownFailed,
    Utf8Error,
    RuntimeError,
}

impl From<u8> for LlmErrorKind {
    fn from(code: u8) -> Self {
        match code {
            1 => LlmErrorKind::ModelNotSet,
            2 => LlmErrorKind::ModelNotSupported,
            3 => LlmErrorKind::ModelInitializationFailed,
            4 => LlmErrorKind::ModelCompletionFailed,
            5 => LlmErrorKind::ModelOptionsNotSet,
            6 => LlmErrorKind::ModelShutdownFailed,
            7 => LlmErrorKind::Utf8Error,
            _ => LlmErrorKind::RuntimeError,
        }
    }
}
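
// Illustrative end-to-end usage (a sketch, kept as a comment because it can
// only run inside a Blockless runtime that provides the `blockless_llm`
// import module; the prompt and option values are made up):
//
//     let model: SupportedModels = "Llama-3.2-1B-Instruct".parse().unwrap();
//     let mut llm = BlocklessLlm::new(model).unwrap();
//     llm.set_options(LlmOptions {
//         system_message: "You are a terse assistant.".to_string(),
//         temperature: Some(0.5),
//         top_p: None,
//     }).unwrap();
//     let reply = llm.chat_request("What is WASI?").unwrap();
//     println!("{}", reply);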