use json::JsonValue;
use std::{str::FromStr, string::ToString};

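/// Opaque identifier for an LLM session owned by the host runtime.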
type Handle = u32;
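/// Status code returned by every host call; `0` means success, any other
/// value converts into an `LlmErrorKind` (see `From<u8>` below).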
type ExitCode = u8;

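// Host functions imported from the `blockless_llm` WASM module. These are
// the real bindings used on the wasm target; the `mock-ffi` feature swaps
// them for the stubs further down.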
#[cfg(not(feature = "mock-ffi"))]
#[link(wasm_import_module = "blockless_llm")]
extern "C" {
    fn llm_set_model_request(h: *mut Handle, model_ptr: *const u8, model_len: u8) -> ExitCode;
    fn llm_get_model_response(
        h: Handle,
        buf: *mut u8,
        buf_len: u8,
        bytes_written: *mut u8,
    ) -> ExitCode;
    fn llm_set_model_options_request(
        h: Handle,
        options_ptr: *const u8,
        options_len: u16,
    ) -> ExitCode;
    fn llm_get_model_options(
        h: Handle,
        buf: *mut u8,
        buf_len: u16,
        bytes_written: *mut u16,
    ) -> ExitCode;
    fn llm_prompt_request(h: Handle, prompt_ptr: *const u8, prompt_len: u16) -> ExitCode;
    fn llm_read_prompt_response(
        h: Handle,
        buf: *mut u8,
        buf_len: u16,
        bytes_written: *mut u16,
    ) -> ExitCode;
    fn llm_close(h: Handle) -> ExitCode;
}

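// Stub implementations with the same signatures as the host imports,
// compiled in their place when the `mock-ffi` feature is enabled (e.g. for
// native builds or tests that never exercise the FFI).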
#[cfg(feature = "mock-ffi")]
#[allow(unused_variables)]
mod mock_ffi {
    use super::*;

    pub unsafe fn llm_set_model_request(
        h: *mut Handle,
        _model_ptr: *const u8,
        _model_len: u8,
    ) -> ExitCode {
        unimplemented!()
    }

    pub unsafe fn llm_get_model_response(
        _h: Handle,
        buf: *mut u8,
        buf_len: u8,
        bytes_written: *mut u8,
    ) -> ExitCode {
        unimplemented!()
    }

    pub unsafe fn llm_set_model_options_request(
        _h: Handle,
        _options_ptr: *const u8,
        _options_len: u16,
    ) -> ExitCode {
        unimplemented!()
    }

    pub unsafe fn llm_get_model_options(
        _h: Handle,
        buf: *mut u8,
        buf_len: u16,
        bytes_written: *mut u16,
    ) -> ExitCode {
        unimplemented!()
    }

    pub unsafe fn llm_prompt_request(
        _h: Handle,
        _prompt_ptr: *const u8,
        _prompt_len: u16,
    ) -> ExitCode {
        unimplemented!()
    }

    pub unsafe fn llm_read_prompt_response(
        _h: Handle,
        buf: *mut u8,
        buf_len: u16,
        bytes_written: *mut u16,
    ) -> ExitCode {
        unimplemented!()
    }

    pub unsafe fn llm_close(_h: Handle) -> ExitCode {
        unimplemented!()
    }
}

#[cfg(feature = "mock-ffi")]
use mock_ffi::*;

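/// Models known to the runtime. Each variant carries an optional quantization
/// tag (e.g. `Q6_K`, `q4f16_1`); `Custom` holds any unrecognized model name.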
#[derive(Debug, Clone)]
pub enum Models {
    Llama321BInstruct(Option<String>),
    Llama323BInstruct(Option<String>),
    Mistral7BInstructV03(Option<String>),
    Mixtral8x7BInstructV01(Option<String>),
    Gemma22BInstruct(Option<String>),
    Gemma27BInstruct(Option<String>),
    Gemma29BInstruct(Option<String>),
    Custom(String),
}

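// Parses a model identifier. A quantization suffix may be attached to the
// canonical name with `-`, `_`, or `.`; anything unrecognized falls through
// to `Models::Custom`, so parsing never fails.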
impl FromStr for Models {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "Llama-3.2-1B-Instruct" => Ok(Models::Llama321BInstruct(None)),
            "Llama-3.2-1B-Instruct-Q6_K"
            | "Llama-3.2-1B-Instruct_Q6_K"
            | "Llama-3.2-1B-Instruct.Q6_K" => {
                Ok(Models::Llama321BInstruct(Some("Q6_K".to_string())))
            }
            "Llama-3.2-1B-Instruct-q4f16_1" | "Llama-3.2-1B-Instruct.q4f16_1" => {
                Ok(Models::Llama321BInstruct(Some("q4f16_1".to_string())))
            }

            "Llama-3.2-3B-Instruct" => Ok(Models::Llama323BInstruct(None)),
            "Llama-3.2-3B-Instruct-Q6_K"
            | "Llama-3.2-3B-Instruct_Q6_K"
            | "Llama-3.2-3B-Instruct.Q6_K" => {
                Ok(Models::Llama323BInstruct(Some("Q6_K".to_string())))
            }
            "Llama-3.2-3B-Instruct-q4f16_1" | "Llama-3.2-3B-Instruct.q4f16_1" => {
                Ok(Models::Llama323BInstruct(Some("q4f16_1".to_string())))
            }

            "Mistral-7B-Instruct-v0.3" => Ok(Models::Mistral7BInstructV03(None)),
            "Mistral-7B-Instruct-v0.3-q4f16_1" | "Mistral-7B-Instruct-v0.3.q4f16_1" => {
                Ok(Models::Mistral7BInstructV03(Some("q4f16_1".to_string())))
            }

            "Mixtral-8x7B-Instruct-v0.1" => Ok(Models::Mixtral8x7BInstructV01(None)),
            "Mixtral-8x7B-Instruct-v0.1-q4f16_1" | "Mixtral-8x7B-Instruct-v0.1.q4f16_1" => {
                Ok(Models::Mixtral8x7BInstructV01(Some("q4f16_1".to_string())))
            }

            "gemma-2-2b-it" => Ok(Models::Gemma22BInstruct(None)),
            "gemma-2-2b-it-q4f16_1" | "gemma-2-2b-it.q4f16_1" => {
                Ok(Models::Gemma22BInstruct(Some("q4f16_1".to_string())))
            }

            "gemma-2-27b-it" => Ok(Models::Gemma27BInstruct(None)),
            "gemma-2-27b-it-q4f16_1" | "gemma-2-27b-it.q4f16_1" => {
                Ok(Models::Gemma27BInstruct(Some("q4f16_1".to_string())))
            }

            "gemma-2-9b-it" => Ok(Models::Gemma29BInstruct(None)),
            "gemma-2-9b-it-q4f16_1" | "gemma-2-9b-it.q4f16_1" => {
                Ok(Models::Gemma29BInstruct(Some("q4f16_1".to_string())))
            }
            _ => Ok(Models::Custom(s.to_string())),
        }
    }
}

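// Note: `Display` renders only the canonical model name; any quantization
// tag carried by the variant is not included in the output. This is the
// string sent to (and compared against) the host in `set_model`.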
impl std::fmt::Display for Models {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Models::Llama321BInstruct(_) => write!(f, "Llama-3.2-1B-Instruct"),
            Models::Llama323BInstruct(_) => write!(f, "Llama-3.2-3B-Instruct"),
            Models::Mistral7BInstructV03(_) => write!(f, "Mistral-7B-Instruct-v0.3"),
            Models::Mixtral8x7BInstructV01(_) => write!(f, "Mixtral-8x7B-Instruct-v0.1"),
            Models::Gemma22BInstruct(_) => write!(f, "gemma-2-2b-it"),
            Models::Gemma27BInstruct(_) => write!(f, "gemma-2-27b-it"),
            Models::Gemma29BInstruct(_) => write!(f, "gemma-2-9b-it"),
            Models::Custom(s) => write!(f, "{}", s),
        }
    }
}

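/// A handle to an LLM session in the Blockless runtime, together with the
/// model name and options most recently pushed to the host.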
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Default)]
pub struct BlocklessLlm {
    inner: Handle,
    model_name: String,
    options: LlmOptions,
}

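/// Options forwarded to the host as JSON. All fields are optional; unset
/// fields are omitted from the serialized object.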
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Default, PartialEq)]
pub struct LlmOptions {
    pub system_message: Option<String>,
    pub tools_sse_urls: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}

impl LlmOptions {
    pub fn with_system_message(mut self, system_message: String) -> Self {
        self.system_message = Some(system_message);
        self
    }

    pub fn with_tools_sse_urls(mut self, tools_sse_urls: Vec<String>) -> Self {
        self.tools_sse_urls = Some(tools_sse_urls);
        self
    }

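    /// Serializes the options to a JSON byte buffer, skipping unset fields.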
    fn dump(&self) -> Vec<u8> {
        let mut json = JsonValue::new_object();

        if let Some(system_message) = &self.system_message {
            json["system_message"] = system_message.clone().into();
        }

        if let Some(tools_sse_urls) = &self.tools_sse_urls {
            json["tools_sse_urls"] = tools_sse_urls.clone().into();
        }

        if let Some(temperature) = self.temperature {
            json["temperature"] = temperature.into();
        }
        if let Some(top_p) = self.top_p {
            json["top_p"] = top_p.into();
        }

        if json.entries().count() == 0 {
            return b"{}".to_vec();
        }

        json.dump().into_bytes()
    }
}

impl std::fmt::Display for LlmOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let bytes = self.dump();
        match String::from_utf8(bytes) {
            Ok(s) => write!(f, "{}", s),
            Err(_) => write!(f, "<invalid utf8>"),
        }
    }
}

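// Parses options from the host's JSON. `tools_sse_urls` is accepted either
// as a JSON array of strings or as a single comma-separated string.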
impl TryFrom<Vec<u8>> for LlmOptions {
    type Error = LlmErrorKind;

    fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
        let json_str = String::from_utf8(bytes).map_err(|_| LlmErrorKind::Utf8Error)?;

        let json = json::parse(&json_str).map_err(|_| LlmErrorKind::ModelOptionsNotSet)?;

        let system_message = json["system_message"].as_str().map(|s| s.to_string());

        let tools_sse_urls = if json["tools_sse_urls"].is_array() {
            Some(
                json["tools_sse_urls"]
                    .members()
                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
                    .collect(),
            )
        } else {
            json["tools_sse_urls"]
                .as_str()
                .map(|s| s.split(',').map(|s| s.trim().to_string()).collect())
        };

        Ok(LlmOptions {
            system_message,
            tools_sse_urls,
            temperature: json["temperature"].as_f32(),
            top_p: json["top_p"].as_f32(),
        })
    }
}

impl BlocklessLlm {
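    /// Registers `model` with the host and returns a session handle.
    ///
    /// A minimal usage sketch (marked `ignore` since it needs a running
    /// Blockless runtime; the import path is illustrative, not the crate's
    /// confirmed module layout):
    ///
    /// ```ignore
    /// use blockless_sdk::llm::{BlocklessLlm, LlmOptions, Models};
    ///
    /// let mut llm = BlocklessLlm::new(Models::Llama321BInstruct(None))?;
    /// llm.set_options(LlmOptions::default().with_system_message("Be brief.".to_string()))?;
    /// let reply = llm.chat_request("What is the capital of France?")?;
    /// ```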
    pub fn new(model: Models) -> Result<Self, LlmErrorKind> {
        let model_name = model.to_string();
        let mut llm: BlocklessLlm = Default::default();
        llm.set_model(&model_name)?;
        Ok(llm)
    }

    pub fn handle(&self) -> Handle {
        self.inner
    }

    pub fn get_model(&self) -> Result<String, LlmErrorKind> {
        let mut buf = [0u8; u8::MAX as usize];
        let mut num_bytes: u8 = 0;
        let code = unsafe {
            llm_get_model_response(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
        };
        if code != 0 {
            return Err(code.into());
        }
        // Propagate invalid UTF-8 as an error instead of panicking.
        String::from_utf8(buf[0..num_bytes as usize].to_vec()).map_err(|_| LlmErrorKind::Utf8Error)
    }

    pub fn set_model(&mut self, model_name: &str) -> Result<(), LlmErrorKind> {
        self.model_name = model_name.to_string();
        let code = unsafe {
            llm_set_model_request(&mut self.inner, model_name.as_ptr(), model_name.len() as _)
        };
        if code != 0 {
            return Err(code.into());
        }

        let host_model = self.get_model()?;
        if self.model_name != host_model {
            eprintln!(
                "Model not set correctly in host/runtime; model_name: {}, model_from_host: {}",
                self.model_name, host_model
            );
            return Err(LlmErrorKind::ModelNotSet);
        }
        Ok(())
    }

    pub fn get_options(&self) -> Result<LlmOptions, LlmErrorKind> {
        let mut buf = [0u8; u16::MAX as usize];
        let mut num_bytes: u16 = 0;
        let code = unsafe {
            llm_get_model_options(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
        };
        if code != 0 {
            return Err(code.into());
        }

        LlmOptions::try_from(buf[0..num_bytes as usize].to_vec())
    }

    pub fn set_options(&mut self, options: LlmOptions) -> Result<(), LlmErrorKind> {
        let options_json = options.dump();
        self.options = options;
        let code = unsafe {
            llm_set_model_options_request(
                self.inner,
                options_json.as_ptr(),
                options_json.len() as _,
            )
        };
        if code != 0 {
            return Err(code.into());
        }

        let host_options = self.get_options()?;
        if self.options != host_options {
            eprintln!(
                "Options not set correctly in host/runtime; options: {:?}, options_from_host: {:?}",
                self.options, host_options
            );
            return Err(LlmErrorKind::ModelOptionsNotSet);
        }
        Ok(())
    }

    pub fn chat_request(&self, prompt: &str) -> Result<String, LlmErrorKind> {
        let code = unsafe { llm_prompt_request(self.inner, prompt.as_ptr(), prompt.len() as _) };
        if code != 0 {
            return Err(code.into());
        }

        self.get_chat_response()
    }

    fn get_chat_response(&self) -> Result<String, LlmErrorKind> {
        let mut buf = [0u8; u16::MAX as usize];
        let mut num_bytes: u16 = 0;
        let code = unsafe {
            llm_read_prompt_response(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
        };
        if code != 0 {
            return Err(code.into());
        }

        let response_vec = buf[0..num_bytes as usize].to_vec();
        String::from_utf8(response_vec).map_err(|_| LlmErrorKind::Utf8Error)
    }
}

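// `Drop` releases the host-side session. Failures are only logged, since
// `drop` cannot return an error.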
impl Drop for BlocklessLlm {
    fn drop(&mut self) {
        let code = unsafe { llm_close(self.inner) };
        if code != 0 {
            eprintln!("Error closing LLM: {}", code);
        }
    }
}

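/// Errors surfaced by the host runtime (decoded from nonzero exit codes) or
/// raised locally while decoding UTF-8 and JSON.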
#[derive(Debug)]
pub enum LlmErrorKind {
    ModelNotSet,
    ModelNotSupported,
    ModelInitializationFailed,
    ModelCompletionFailed,
    ModelOptionsNotSet,
    ModelShutdownFailed,
    Utf8Error,
    RuntimeError,
    MCPFunctionCallError,
}

impl From<u8> for LlmErrorKind {
    fn from(code: u8) -> Self {
        match code {
            1 => LlmErrorKind::ModelNotSet,
            2 => LlmErrorKind::ModelNotSupported,
            3 => LlmErrorKind::ModelInitializationFailed,
            4 => LlmErrorKind::ModelCompletionFailed,
            5 => LlmErrorKind::ModelOptionsNotSet,
            6 => LlmErrorKind::ModelShutdownFailed,
            7 => LlmErrorKind::Utf8Error,
            9 => LlmErrorKind::MCPFunctionCallError,
            // Any other code (including 8, which has no dedicated variant)
            // maps to the generic runtime error.
            _ => LlmErrorKind::RuntimeError,
        }
    }
}
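
// A small test sketch covering the pure (non-FFI) JSON round-trip of
// `LlmOptions`. The float values (0.5, 0.25) are chosen to survive the
// f32 -> JSON -> f32 conversion exactly; the URL is illustrative sample data.
#[cfg(test)]
mod tests {
    use super::*;

    // Fully populated options should survive a dump/parse round-trip unchanged.
    #[test]
    fn options_json_round_trip() {
        let options = LlmOptions {
            system_message: Some("You are terse.".to_string()),
            tools_sse_urls: Some(vec!["http://localhost:3001/sse".to_string()]),
            temperature: Some(0.5),
            top_p: Some(0.25),
        };
        let parsed = LlmOptions::try_from(options.dump()).expect("valid JSON");
        assert_eq!(options, parsed);
    }

    // Empty options serialize to `{}` and parse back to the default struct.
    #[test]
    fn empty_options_round_trip() {
        let options = LlmOptions::default();
        assert_eq!(options.dump(), b"{}".to_vec());
        let parsed = LlmOptions::try_from(options.dump()).expect("valid JSON");
        assert_eq!(options, parsed);
    }
}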