use json::JsonValue;
use std::{str::FromStr, string::ToString};

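// Host-side session handle and status code: host calls return 0 on success;
// non-zero codes map onto `LlmErrorKind` (see `From<u8>` at the bottom).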
type Handle = u32;
type ExitCode = u8;

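// Imports from the Blockless host. The readers follow a caller-allocated
// buffer protocol: pass a buffer and its capacity, and the host reports the
// number of bytes it wrote through `bytes_written`.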
#[cfg(not(feature = "mock-ffi"))]
#[link(wasm_import_module = "blockless_llm")]
extern "C" {
    fn llm_set_model_request(h: *mut Handle, model_ptr: *const u8, model_len: u8) -> ExitCode;
    fn llm_get_model_response(
        h: Handle,
        buf: *mut u8,
        buf_len: u8,
        bytes_written: *mut u8,
    ) -> ExitCode;
    fn llm_set_model_options_request(
        h: Handle,
        options_ptr: *const u8,
        options_len: u16,
    ) -> ExitCode;
    fn llm_get_model_options(
        h: Handle,
        buf: *mut u8,
        buf_len: u16,
        bytes_written: *mut u16,
    ) -> ExitCode;
    fn llm_prompt_request(h: Handle, prompt_ptr: *const u8, prompt_len: u16) -> ExitCode;
    fn llm_read_prompt_response(
        h: Handle,
        buf: *mut u8,
        buf_len: u16,
        bytes_written: *mut u16,
    ) -> ExitCode;
    fn llm_close(h: Handle) -> ExitCode;
}

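// Stand-ins used when the `mock-ffi` feature is enabled, so the crate builds
// and tests on native targets without the wasm host. They panic if invoked.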
#[cfg(feature = "mock-ffi")]
mod mock_ffi {
    use super::*;

    pub unsafe fn llm_set_model_request(
        _h: *mut Handle,
        _model_ptr: *const u8,
        _model_len: u8,
    ) -> ExitCode {
        unimplemented!()
    }

    pub unsafe fn llm_get_model_response(
        _h: Handle,
        _buf: *mut u8,
        _buf_len: u8,
        _bytes_written: *mut u8,
    ) -> ExitCode {
        unimplemented!()
    }

    pub unsafe fn llm_set_model_options_request(
        _h: Handle,
        _options_ptr: *const u8,
        _options_len: u16,
    ) -> ExitCode {
        unimplemented!()
    }

    pub unsafe fn llm_get_model_options(
        _h: Handle,
        _buf: *mut u8,
        _buf_len: u16,
        _bytes_written: *mut u16,
    ) -> ExitCode {
        unimplemented!()
    }

    pub unsafe fn llm_prompt_request(
        _h: Handle,
        _prompt_ptr: *const u8,
        _prompt_len: u16,
    ) -> ExitCode {
        unimplemented!()
    }

    pub unsafe fn llm_read_prompt_response(
        _h: Handle,
        _buf: *mut u8,
        _buf_len: u16,
        _bytes_written: *mut u16,
    ) -> ExitCode {
        unimplemented!()
    }

    pub unsafe fn llm_close(_h: Handle) -> ExitCode {
        unimplemented!()
    }
}

#[cfg(feature = "mock-ffi")]
use mock_ffi::*;

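/// Model identifiers accepted by the host. The `Option<String>` carries an
/// optional quantization tag (e.g. "Q6_K", "q4f16_1") parsed from the name;
/// unrecognized names are preserved verbatim in `Custom`.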
#[derive(Debug, Clone)]
pub enum Models {
    Llama321BInstruct(Option<String>),
    Llama323BInstruct(Option<String>),
    Mistral7BInstructV03(Option<String>),
    Mixtral8x7BInstructV01(Option<String>),
    Gemma22BInstruct(Option<String>),
    Gemma27BInstruct(Option<String>),
    Gemma29BInstruct(Option<String>),
    Custom(String),
}

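// Parsing accepts the canonical name plus `-`, `_`, or `.`-separated
// quantization suffixes; anything unrecognized falls through to `Custom`,
// so `from_str` never actually fails despite the `Result` signature.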
impl FromStr for Models {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "Llama-3.2-1B-Instruct" => Ok(Models::Llama321BInstruct(None)),
            "Llama-3.2-1B-Instruct-Q6_K"
            | "Llama-3.2-1B-Instruct_Q6_K"
            | "Llama-3.2-1B-Instruct.Q6_K" => {
                Ok(Models::Llama321BInstruct(Some("Q6_K".to_string())))
            }
            "Llama-3.2-1B-Instruct-q4f16_1" | "Llama-3.2-1B-Instruct.q4f16_1" => {
                Ok(Models::Llama321BInstruct(Some("q4f16_1".to_string())))
            }

            "Llama-3.2-3B-Instruct" => Ok(Models::Llama323BInstruct(None)),
            "Llama-3.2-3B-Instruct-Q6_K"
            | "Llama-3.2-3B-Instruct_Q6_K"
            | "Llama-3.2-3B-Instruct.Q6_K" => {
                Ok(Models::Llama323BInstruct(Some("Q6_K".to_string())))
            }
            "Llama-3.2-3B-Instruct-q4f16_1" | "Llama-3.2-3B-Instruct.q4f16_1" => {
                Ok(Models::Llama323BInstruct(Some("q4f16_1".to_string())))
            }

            "Mistral-7B-Instruct-v0.3" => Ok(Models::Mistral7BInstructV03(None)),
            "Mistral-7B-Instruct-v0.3-q4f16_1" | "Mistral-7B-Instruct-v0.3.q4f16_1" => {
                Ok(Models::Mistral7BInstructV03(Some("q4f16_1".to_string())))
            }

            "Mixtral-8x7B-Instruct-v0.1" => Ok(Models::Mixtral8x7BInstructV01(None)),
            "Mixtral-8x7B-Instruct-v0.1-q4f16_1" | "Mixtral-8x7B-Instruct-v0.1.q4f16_1" => {
                Ok(Models::Mixtral8x7BInstructV01(Some("q4f16_1".to_string())))
            }

            "gemma-2-2b-it" => Ok(Models::Gemma22BInstruct(None)),
            "gemma-2-2b-it-q4f16_1" | "gemma-2-2b-it.q4f16_1" => {
                Ok(Models::Gemma22BInstruct(Some("q4f16_1".to_string())))
            }

            "gemma-2-27b-it" => Ok(Models::Gemma27BInstruct(None)),
            "gemma-2-27b-it-q4f16_1" | "gemma-2-27b-it.q4f16_1" => {
                Ok(Models::Gemma27BInstruct(Some("q4f16_1".to_string())))
            }

            "gemma-2-9b-it" => Ok(Models::Gemma29BInstruct(None)),
            "gemma-2-9b-it-q4f16_1" | "gemma-2-9b-it.q4f16_1" => {
                Ok(Models::Gemma29BInstruct(Some("q4f16_1".to_string())))
            }
            _ => Ok(Models::Custom(s.to_string())),
        }
    }
}

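// Display renders the canonical model name; any quantization tag is dropped.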
impl std::fmt::Display for Models {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Models::Llama321BInstruct(_) => write!(f, "Llama-3.2-1B-Instruct"),
            Models::Llama323BInstruct(_) => write!(f, "Llama-3.2-3B-Instruct"),
            Models::Mistral7BInstructV03(_) => write!(f, "Mistral-7B-Instruct-v0.3"),
            Models::Mixtral8x7BInstructV01(_) => write!(f, "Mixtral-8x7B-Instruct-v0.1"),
            Models::Gemma22BInstruct(_) => write!(f, "gemma-2-2b-it"),
            Models::Gemma27BInstruct(_) => write!(f, "gemma-2-27b-it"),
            Models::Gemma29BInstruct(_) => write!(f, "gemma-2-9b-it"),
            Models::Custom(s) => write!(f, "{}", s),
        }
    }
}

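/// Client wrapper over the host LLM API. Constructing one registers a model
/// with the host and stores the returned session handle, which is released
/// via `llm_close` when the value is dropped.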
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct BlocklessLlm {
    inner: Handle,
    model_name: String,
    options: LlmOptions,
}

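/// Options sent to the host as a JSON object; only the fields that are set
/// get serialized.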
#[derive(Debug, Clone, Default, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct LlmOptions {
    pub system_message: Option<String>,
    pub tools_sse_urls: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

impl LlmOptions {
    pub fn with_system_message(mut self, system_message: String) -> Self {
        self.system_message = Some(system_message);
        self
    }

    pub fn with_tools_sse_urls(mut self, tools_sse_urls: Vec<String>) -> Self {
        self.tools_sse_urls = Some(tools_sse_urls);
        self
    }

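    /// Serializes only the fields that are set; an all-`None` value dumps as
    /// the empty object `{}`.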
    fn dump(&self) -> Vec<u8> {
        let mut json = JsonValue::new_object();

        if let Some(system_message) = &self.system_message {
            json["system_message"] = system_message.clone().into();
        }

        if let Some(tools_sse_urls) = &self.tools_sse_urls {
            json["tools_sse_urls"] = tools_sse_urls.clone().into();
        }

        if let Some(temperature) = self.temperature {
            json["temperature"] = temperature.into();
        }
        if let Some(top_p) = self.top_p {
            json["top_p"] = top_p.into();
        }

        if json.entries().count() == 0 {
            return b"{}".to_vec();
        }

        json.dump().into_bytes()
    }
}

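// Displays the options as their JSON wire form (the same bytes `dump` emits).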
impl std::fmt::Display for LlmOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let bytes = self.dump();
        match String::from_utf8(bytes) {
            Ok(s) => write!(f, "{}", s),
            Err(_) => write!(f, "<invalid utf8>"),
        }
    }
}

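// Parses options back from the host's JSON. `tools_sse_urls` is accepted
// either as a JSON array or as a single comma-separated string.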
impl TryFrom<Vec<u8>> for LlmOptions {
    type Error = LlmErrorKind;

    fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
        let json_str = String::from_utf8(bytes).map_err(|_| LlmErrorKind::Utf8Error)?;

        let json = json::parse(&json_str).map_err(|_| LlmErrorKind::ModelOptionsNotSet)?;

        let system_message = json["system_message"].as_str().map(|s| s.to_string());

        let tools_sse_urls = if json["tools_sse_urls"].is_array() {
            Some(
                json["tools_sse_urls"]
                    .members()
                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
                    .collect(),
            )
        } else {
            json["tools_sse_urls"]
                .as_str()
                .map(|s| s.split(',').map(|s| s.trim().to_string()).collect())
        };

        Ok(LlmOptions {
            system_message,
            tools_sse_urls,
            temperature: json["temperature"].as_f32(),
            top_p: json["top_p"].as_f32(),
        })
    }
}

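// Typical call sequence (an illustrative sketch, not from the original
// source; it requires a Blockless runtime providing the host functions, and
// the prompt text is made up):
//
//     let mut llm = BlocklessLlm::new(Models::Llama321BInstruct(None))?;
//     llm.set_options(LlmOptions::default().with_system_message("Be brief.".into()))?;
//     let reply = llm.chat_request("What is WASI?")?;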
impl BlocklessLlm {
    pub fn new(model: Models) -> Result<Self, LlmErrorKind> {
        let model_name = model.to_string();
        let mut llm: BlocklessLlm = Default::default();
        llm.set_model(&model_name)?;
        Ok(llm)
    }

    pub fn handle(&self) -> Handle {
        self.inner
    }

    pub fn get_model(&self) -> Result<String, LlmErrorKind> {
        let mut buf = [0u8; u8::MAX as usize];
        let mut num_bytes: u8 = 0;
        let code = unsafe {
            llm_get_model_response(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
        };
        if code != 0 {
            return Err(code.into());
        }
        // Surface invalid UTF-8 as an error instead of panicking.
        String::from_utf8(buf[0..num_bytes as usize].to_vec()).map_err(|_| LlmErrorKind::Utf8Error)
    }

    pub fn set_model(&mut self, model_name: &str) -> Result<(), LlmErrorKind> {
        self.model_name = model_name.to_string();
        let code = unsafe {
            llm_set_model_request(&mut self.inner, model_name.as_ptr(), model_name.len() as _)
        };
        if code != 0 {
            return Err(code.into());
        }

        // Read the model back to verify the host accepted what we asked for.
        let host_model = self.get_model()?;
        if self.model_name != host_model {
            eprintln!(
                "Model not set correctly in host/runtime; model_name: {}, model_from_host: {}",
                self.model_name, host_model
            );
            return Err(LlmErrorKind::ModelNotSet);
        }
        Ok(())
    }

    pub fn get_options(&self) -> Result<LlmOptions, LlmErrorKind> {
        let mut buf = [0u8; u16::MAX as usize];
        let mut num_bytes: u16 = 0;
        let code = unsafe {
            llm_get_model_options(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
        };
        if code != 0 {
            return Err(code.into());
        }

        LlmOptions::try_from(buf[0..num_bytes as usize].to_vec())
    }

    pub fn set_options(&mut self, options: LlmOptions) -> Result<(), LlmErrorKind> {
        let options_json = options.dump();
        self.options = options;
        let code = unsafe {
            llm_set_model_options_request(
                self.inner,
                options_json.as_ptr(),
                options_json.len() as _,
            )
        };
        if code != 0 {
            return Err(code.into());
        }

        // Read back and compare, mirroring the check in `set_model`.
        let host_options = self.get_options()?;
        if self.options != host_options {
            eprintln!(
                "Options not set correctly in host/runtime; options: {:?}, options_from_host: {:?}",
                self.options, host_options
            );
            return Err(LlmErrorKind::ModelOptionsNotSet);
        }
        Ok(())
    }

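    /// Sends `prompt` to the host and returns the completion read back via
    /// `llm_read_prompt_response`.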
    pub fn chat_request(&self, prompt: &str) -> Result<String, LlmErrorKind> {
        let code = unsafe { llm_prompt_request(self.inner, prompt.as_ptr(), prompt.len() as _) };
        if code != 0 {
            return Err(code.into());
        }

        self.get_chat_response()
    }

    fn get_chat_response(&self) -> Result<String, LlmErrorKind> {
        let mut buf = [0u8; u16::MAX as usize];
        let mut num_bytes: u16 = 0;
        let code = unsafe {
            llm_read_prompt_response(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
        };
        if code != 0 {
            return Err(code.into());
        }

        let response_vec = buf[0..num_bytes as usize].to_vec();
        String::from_utf8(response_vec).map_err(|_| LlmErrorKind::Utf8Error)
    }
}

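// Best-effort cleanup: `drop` cannot return an error, so a failed close is
// only logged.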
impl Drop for BlocklessLlm {
    fn drop(&mut self) {
        let code = unsafe { llm_close(self.inner) };
        if code != 0 {
            eprintln!("Error closing LLM: {}", code);
        }
    }
}

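/// Failure modes surfaced by this module; numeric host exit codes are mapped
/// onto these variants via `From<u8>` below.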
#[derive(Debug)]
pub enum LlmErrorKind {
    ModelNotSet,
    ModelNotSupported,
    ModelInitializationFailed,
    ModelCompletionFailed,
    ModelOptionsNotSet,
    ModelShutdownFailed,
    Utf8Error,
    RuntimeError,
    MCPFunctionCallError,
}

impl From<u8> for LlmErrorKind {
    fn from(code: u8) -> Self {
        match code {
            1 => LlmErrorKind::ModelNotSet,
            2 => LlmErrorKind::ModelNotSupported,
            3 => LlmErrorKind::ModelInitializationFailed,
            4 => LlmErrorKind::ModelCompletionFailed,
            5 => LlmErrorKind::ModelOptionsNotSet,
            6 => LlmErrorKind::ModelShutdownFailed,
            7 => LlmErrorKind::Utf8Error,
            9 => LlmErrorKind::MCPFunctionCallError,
            // 8 and any unrecognized code collapse to a generic runtime error.
            _ => LlmErrorKind::RuntimeError,
        }
    }
}
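
// Host-independent sanity checks for the pure-Rust parts of this module.
// These tests are illustrative additions, not from the original source; they
// only exercise name parsing and the JSON round trip, so they run natively
// with `cargo test --features mock-ffi` (no wasm host required).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn model_name_round_trip() {
        // The quantization tag is captured on parse...
        let m: Models = "Llama-3.2-1B-Instruct-Q6_K".parse().unwrap();
        assert!(matches!(m, Models::Llama321BInstruct(Some(ref q)) if q == "Q6_K"));
        // ...and dropped again by Display, which yields the canonical name.
        assert_eq!(m.to_string(), "Llama-3.2-1B-Instruct");

        // Unknown names (this one is made up) are preserved verbatim.
        let custom: Models = "my-local-model".parse().unwrap();
        assert_eq!(custom.to_string(), "my-local-model");
    }

    #[test]
    fn options_json_round_trip() {
        let opts = LlmOptions::default()
            .with_system_message("You are terse.".to_string())
            .with_tools_sse_urls(vec!["http://localhost:3001/sse".to_string()]);
        let parsed = LlmOptions::try_from(opts.dump()).expect("options should round-trip");
        assert_eq!(opts, parsed);

        // An empty option set serializes as the empty JSON object.
        assert_eq!(LlmOptions::default().dump(), b"{}".to_vec());
    }
}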