// blockless_sdk/llm.rs — bindings for the Blockless LLM host interface.

1use json::JsonValue;
2use std::{str::FromStr, string::ToString};
3
/// Opaque host-side handle identifying an LLM session.
type Handle = u32;
/// Status code returned by every host call; non-zero indicates failure
/// (see `From<u8> for LlmErrorKind` for the mapping).
type ExitCode = u8;

// Host functions imported from the Blockless runtime under the
// "blockless_llm" WASM import module. Each returns an ExitCode;
// callers treat any non-zero value as an error.
#[link(wasm_import_module = "blockless_llm")]
extern "C" {
    /// Registers a model for a session and writes the session handle into `h`.
    /// Note the model-name length is limited to `u8` (255 bytes).
    fn llm_set_model_request(h: *mut Handle, model_ptr: *const u8, model_len: u8) -> ExitCode;
    /// Reads back the model name for session `h` into `buf`;
    /// `bytes_written` receives the number of bytes actually written.
    fn llm_get_model_response(
        h: Handle,
        buf: *mut u8,
        buf_len: u8,
        bytes_written: *mut u8,
    ) -> ExitCode;
    /// Sends serialized (JSON) model options for session `h`.
    fn llm_set_model_options_request(
        h: Handle,
        options_ptr: *const u8,
        options_len: u16,
    ) -> ExitCode;
    /// Reads back the serialized model options for session `h`.
    fn llm_get_model_options(
        h: Handle,
        buf: *mut u8,
        buf_len: u16,
        bytes_written: *mut u16,
    ) -> ExitCode;
    /// Submits a prompt for session `h`; the response is read separately
    /// via `llm_read_prompt_response`.
    fn llm_prompt_request(h: Handle, prompt_ptr: *const u8, prompt_len: u16) -> ExitCode;
    /// Reads the pending prompt response for session `h` into `buf`.
    fn llm_read_prompt_response(
        h: Handle,
        buf: *mut u8,
        buf_len: u16,
        bytes_written: *mut u16,
    ) -> ExitCode;
    /// Closes session `h`, releasing host-side resources.
    fn llm_close(h: Handle) -> ExitCode;
}

/// Models available through the Blockless LLM host.
///
/// The `Option<String>` payload carries an optional quantization tag
/// (e.g. `"Q6_K"`, `"q4f16_1"`); `None` selects the model's default build.
/// `Custom` passes an arbitrary model identifier through unchanged.
#[derive(Debug, Clone)]
pub enum Models {
    Llama321BInstruct(Option<String>),
    Llama323BInstruct(Option<String>),
    Mistral7BInstructV03(Option<String>),
    Mixtral8x7BInstructV01(Option<String>),
    Gemma22BInstruct(Option<String>),
    Gemma27BInstruct(Option<String>),
    Gemma29BInstruct(Option<String>),
    Custom(String),
}

49impl FromStr for Models {
50    type Err = String;
51    fn from_str(s: &str) -> Result<Self, Self::Err> {
52        match s {
53            // Llama 3.2 1B
54            "Llama-3.2-1B-Instruct" => Ok(Models::Llama321BInstruct(None)),
55            "Llama-3.2-1B-Instruct-Q6_K"
56            | "Llama-3.2-1B-Instruct_Q6_K"
57            | "Llama-3.2-1B-Instruct.Q6_K" => {
58                Ok(Models::Llama321BInstruct(Some("Q6_K".to_string())))
59            }
60            "Llama-3.2-1B-Instruct-q4f16_1" | "Llama-3.2-1B-Instruct.q4f16_1" => {
61                Ok(Models::Llama321BInstruct(Some("q4f16_1".to_string())))
62            }
63
64            // Llama 3.2 3B
65            "Llama-3.2-3B-Instruct" => Ok(Models::Llama323BInstruct(None)),
66            "Llama-3.2-3B-Instruct-Q6_K"
67            | "Llama-3.2-3B-Instruct_Q6_K"
68            | "Llama-3.2-3B-Instruct.Q6_K" => {
69                Ok(Models::Llama323BInstruct(Some("Q6_K".to_string())))
70            }
71            "Llama-3.2-3B-Instruct-q4f16_1" | "Llama-3.2-3B-Instruct.q4f16_1" => {
72                Ok(Models::Llama323BInstruct(Some("q4f16_1".to_string())))
73            }
74
75            // Mistral 7B
76            "Mistral-7B-Instruct-v0.3" => Ok(Models::Mistral7BInstructV03(None)),
77            "Mistral-7B-Instruct-v0.3-q4f16_1" | "Mistral-7B-Instruct-v0.3.q4f16_1" => {
78                Ok(Models::Mistral7BInstructV03(Some("q4f16_1".to_string())))
79            }
80
81            // Mixtral 8x7B
82            "Mixtral-8x7B-Instruct-v0.1" => Ok(Models::Mixtral8x7BInstructV01(None)),
83            "Mixtral-8x7B-Instruct-v0.1-q4f16_1" | "Mixtral-8x7B-Instruct-v0.1.q4f16_1" => {
84                Ok(Models::Mixtral8x7BInstructV01(Some("q4f16_1".to_string())))
85            }
86
87            // Gemma models
88            "gemma-2-2b-it" => Ok(Models::Gemma22BInstruct(None)),
89            "gemma-2-2b-it-q4f16_1" | "gemma-2-2b-it.q4f16_1" => {
90                Ok(Models::Gemma22BInstruct(Some("q4f16_1".to_string())))
91            }
92
93            "gemma-2-27b-it" => Ok(Models::Gemma27BInstruct(None)),
94            "gemma-2-27b-it-q4f16_1" | "gemma-2-27b-it.q4f16_1" => {
95                Ok(Models::Gemma27BInstruct(Some("q4f16_1".to_string())))
96            }
97
98            "gemma-2-9b-it" => Ok(Models::Gemma29BInstruct(None)),
99            "gemma-2-9b-it-q4f16_1" | "gemma-2-9b-it.q4f16_1" => {
100                Ok(Models::Gemma29BInstruct(Some("q4f16_1".to_string())))
101            }
102            _ => Ok(Models::Custom(s.to_string())),
103        }
104    }
105}
106
107impl std::fmt::Display for Models {
108    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
109        match self {
110            Models::Llama321BInstruct(_) => write!(f, "Llama-3.2-1B-Instruct"),
111            Models::Llama323BInstruct(_) => write!(f, "Llama-3.2-3B-Instruct"),
112            Models::Mistral7BInstructV03(_) => write!(f, "Mistral-7B-Instruct-v0.3"),
113            Models::Mixtral8x7BInstructV01(_) => write!(f, "Mixtral-8x7B-Instruct-v0.1"),
114            Models::Gemma22BInstruct(_) => write!(f, "gemma-2-2b-it"),
115            Models::Gemma27BInstruct(_) => write!(f, "gemma-2-27b-it"),
116            Models::Gemma29BInstruct(_) => write!(f, "gemma-2-9b-it"),
117            Models::Custom(s) => write!(f, "{}", s),
118        }
119    }
120}
121
/// Client-side wrapper around a host-managed LLM session.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Default)]
pub struct BlocklessLlm {
    // Host-assigned session handle (written by `llm_set_model_request`).
    inner: Handle,
    // Model identifier most recently passed to `set_model`.
    model_name: String,
    // Options most recently passed to `set_options`.
    options: LlmOptions,
}

/// Per-session model options; every field is optional and only set
/// fields are serialized when sent to the host (see `dump`).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Default, PartialEq)]
pub struct LlmOptions {
    // System prompt prepended to the conversation.
    pub system_message: Option<String>,
    // SSE endpoints for tool integrations (MCP).
    pub tools_sse_urls: Option<Vec<String>>,
    // Sampling temperature.
    pub temperature: Option<f32>,
    // Nucleus-sampling top_p.
    pub top_p: Option<f32>,
}

139impl LlmOptions {
140    pub fn with_system_message(mut self, system_message: String) -> Self {
141        self.system_message = Some(system_message);
142        self
143    }
144
145    pub fn with_tools_sse_urls(mut self, tools_sse_urls: Vec<String>) -> Self {
146        self.tools_sse_urls = Some(tools_sse_urls);
147        self
148    }
149
150    fn dump(&self) -> Vec<u8> {
151        let mut json = JsonValue::new_object();
152
153        if let Some(system_message) = &self.system_message {
154            json["system_message"] = system_message.clone().into();
155        }
156
157        if let Some(tools_sse_urls) = &self.tools_sse_urls {
158            json["tools_sse_urls"] = tools_sse_urls.clone().into();
159        }
160
161        if let Some(temperature) = self.temperature {
162            json["temperature"] = temperature.into();
163        }
164        if let Some(top_p) = self.top_p {
165            json["top_p"] = top_p.into();
166        }
167
168        // If json is empty, return an empty JSON object
169        if json.entries().count() == 0 {
170            return "{}".as_bytes().to_vec();
171        }
172
173        json.dump().into_bytes()
174    }
175}
176
177impl std::fmt::Display for LlmOptions {
178    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179        let bytes = self.dump();
180        match String::from_utf8(bytes) {
181            Ok(s) => write!(f, "{}", s),
182            Err(_) => write!(f, "<invalid utf8>"),
183        }
184    }
185}
186
187impl TryFrom<Vec<u8>> for LlmOptions {
188    type Error = LlmErrorKind;
189
190    fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
191        // Convert bytes to UTF-8 string
192        let json_str = String::from_utf8(bytes).map_err(|_| LlmErrorKind::Utf8Error)?;
193
194        // Parse the JSON string
195        let json = json::parse(&json_str).map_err(|_| LlmErrorKind::ModelOptionsNotSet)?;
196
197        // Extract system_message
198        let system_message = json["system_message"].as_str().map(|s| s.to_string());
199
200        // Extract tools_sse_urls - can be an array or a comma-separated string
201        let tools_sse_urls = if json["tools_sse_urls"].is_array() {
202            // Handle array format - native runtime
203            Some(
204                json["tools_sse_urls"]
205                    .members()
206                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
207                    .collect(),
208            )
209        } else {
210            json["tools_sse_urls"]
211                .as_str()
212                .map(|s| s.split(',').map(|s| s.trim().to_string()).collect())
213        };
214
215        Ok(LlmOptions {
216            system_message,
217            tools_sse_urls,
218            temperature: json["temperature"].as_f32(),
219            top_p: json["top_p"].as_f32(),
220        })
221    }
222}
223
224impl BlocklessLlm {
225    pub fn new(model: Models) -> Result<Self, LlmErrorKind> {
226        let model_name = model.to_string();
227        let mut llm: BlocklessLlm = Default::default();
228        llm.set_model(&model_name)?;
229        Ok(llm)
230    }
231
232    pub fn handle(&self) -> Handle {
233        self.inner
234    }
235
236    pub fn get_model(&self) -> Result<String, LlmErrorKind> {
237        let mut buf = [0u8; u8::MAX as usize];
238        let mut num_bytes: u8 = 0;
239        let code = unsafe {
240            llm_get_model_response(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
241        };
242        if code != 0 {
243            return Err(code.into());
244        }
245        let model = String::from_utf8(buf[0..num_bytes as _].to_vec()).unwrap();
246        Ok(model)
247    }
248
249    pub fn set_model(&mut self, model_name: &str) -> Result<(), LlmErrorKind> {
250        self.model_name = model_name.to_string();
251        let code = unsafe {
252            llm_set_model_request(&mut self.inner, model_name.as_ptr(), model_name.len() as _)
253        };
254        if code != 0 {
255            return Err(code.into());
256        }
257
258        // validate model is set correctly in host/runtime
259        let host_model = self.get_model()?;
260        if self.model_name != host_model {
261            eprintln!(
262                "Model not set correctly in host/runtime; model_name: {}, model_from_host: {}",
263                self.model_name, host_model
264            );
265            return Err(LlmErrorKind::ModelNotSet);
266        }
267        Ok(())
268    }
269
270    pub fn get_options(&self) -> Result<LlmOptions, LlmErrorKind> {
271        let mut buf = [0u8; u16::MAX as usize];
272        let mut num_bytes: u16 = 0;
273        let code = unsafe {
274            llm_get_model_options(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
275        };
276        if code != 0 {
277            return Err(code.into());
278        }
279
280        // Convert buffer slice to Vec<u8> and try to parse into LlmOptions
281        LlmOptions::try_from(buf[0..num_bytes as usize].to_vec())
282    }
283
284    pub fn set_options(&mut self, options: LlmOptions) -> Result<(), LlmErrorKind> {
285        let options_json = options.dump();
286        self.options = options;
287        let code = unsafe {
288            llm_set_model_options_request(
289                self.inner,
290                options_json.as_ptr(),
291                options_json.len() as _,
292            )
293        };
294        if code != 0 {
295            return Err(code.into());
296        }
297
298        // Verify options were set correctly
299        let host_options = self.get_options()?;
300        if self.options != host_options {
301            println!(
302                "Options not set correctly in host/runtime; options: {:?}, options_from_host: {:?}",
303                self.options, host_options
304            );
305            return Err(LlmErrorKind::ModelOptionsNotSet);
306        }
307        Ok(())
308    }
309
310    pub fn chat_request(&self, prompt: &str) -> Result<String, LlmErrorKind> {
311        // Perform the prompt request
312        let code = unsafe { llm_prompt_request(self.inner, prompt.as_ptr(), prompt.len() as _) };
313        if code != 0 {
314            return Err(code.into());
315        }
316
317        // Read the response
318        self.get_chat_response()
319    }
320
321    fn get_chat_response(&self) -> Result<String, LlmErrorKind> {
322        let mut buf = [0u8; u16::MAX as usize];
323        let mut num_bytes: u16 = 0;
324        let code = unsafe {
325            llm_read_prompt_response(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
326        };
327        if code != 0 {
328            return Err(code.into());
329        }
330
331        let response_vec = buf[0..num_bytes as usize].to_vec();
332        String::from_utf8(response_vec).map_err(|_| LlmErrorKind::Utf8Error)
333    }
334}
335
336impl Drop for BlocklessLlm {
337    fn drop(&mut self) {
338        let code = unsafe { llm_close(self.inner) };
339        if code != 0 {
340            eprintln!("Error closing LLM: {}", code);
341        }
342    }
343}
344
/// Error kinds surfaced by the Blockless LLM host interface.
///
/// The trailing comments give the host's numeric exit code for each
/// variant (see `From<u8>` below).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlmErrorKind {
    ModelNotSet,               // 1
    ModelNotSupported,         // 2
    ModelInitializationFailed, // 3
    ModelCompletionFailed,     // 4
    ModelOptionsNotSet,        // 5
    ModelShutdownFailed,       // 6
    Utf8Error,                 // 7
    RuntimeError,              // 8
    MCPFunctionCallError,      // 9
}

impl From<u8> for LlmErrorKind {
    /// Maps a host exit code to an error kind. Code 8 and every
    /// unrecognized code (including 0) map to `RuntimeError`.
    fn from(code: u8) -> Self {
        match code {
            1 => LlmErrorKind::ModelNotSet,
            2 => LlmErrorKind::ModelNotSupported,
            3 => LlmErrorKind::ModelInitializationFailed,
            4 => LlmErrorKind::ModelCompletionFailed,
            5 => LlmErrorKind::ModelOptionsNotSet,
            6 => LlmErrorKind::ModelShutdownFailed,
            7 => LlmErrorKind::Utf8Error,
            9 => LlmErrorKind::MCPFunctionCallError,
            _ => LlmErrorKind::RuntimeError,
        }
    }
}