blockless_sdk/llm.rs

use json::JsonValue;
use std::str::FromStr;

type Handle = u32;
type ExitCode = u8;

// Host functions supplied by the Blockless runtime. Every call returns an
// exit code (0 = success; non-zero codes map onto `LlmErrorKind`). Reads take
// a caller-supplied buffer and report the number of bytes written through an
// out-parameter.
#[link(wasm_import_module = "blockless_llm")]
extern "C" {
    fn llm_set_model_request(h: *mut Handle, model_ptr: *const u8, model_len: u8) -> ExitCode;
    fn llm_get_model_response(
        h: Handle,
        buf: *mut u8,
        buf_len: u8,
        bytes_written: *mut u8,
    ) -> ExitCode;
    fn llm_set_model_options_request(
        h: Handle,
        options_ptr: *const u8,
        options_len: u16,
    ) -> ExitCode;
    fn llm_get_model_options(
        h: Handle,
        buf: *mut u8,
        buf_len: u16,
        bytes_written: *mut u16,
    ) -> ExitCode;
    fn llm_prompt_request(h: Handle, prompt_ptr: *const u8, prompt_len: u16) -> ExitCode;
    fn llm_read_prompt_response(
        h: Handle,
        buf: *mut u8,
        buf_len: u16,
        bytes_written: *mut u16,
    ) -> ExitCode;
    fn llm_close(h: Handle) -> ExitCode;
}
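
// Note the length-parameter widths above: model names travel with u8 lengths
// (at most 255 bytes), while options, prompts, and responses use u16 lengths
// (at most 65535 bytes). The wrapper below sizes its buffers accordingly.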

/// Models supported by the Blockless runtime. The optional `String` carries a
/// quantization tag (e.g. "Q6_K" or "q4f16_1") when one was requested.
#[derive(Debug, Clone)]
pub enum SupportedModels {
    Llama321BInstruct(Option<String>),
    Llama323BInstruct(Option<String>),
    Mistral7BInstructV03(Option<String>),
    Mixtral8x7BInstructV01(Option<String>),
    Gemma22BInstruct(Option<String>),
    Gemma27BInstruct(Option<String>),
    Gemma29BInstruct(Option<String>),
}

impl FromStr for SupportedModels {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            // Llama 3.2 1B
            "Llama-3.2-1B-Instruct" => Ok(SupportedModels::Llama321BInstruct(None)),
            "Llama-3.2-1B-Instruct-Q6_K"
            | "Llama-3.2-1B-Instruct_Q6_K"
            | "Llama-3.2-1B-Instruct.Q6_K" => {
                Ok(SupportedModels::Llama321BInstruct(Some("Q6_K".to_string())))
            }
            "Llama-3.2-1B-Instruct-q4f16_1" | "Llama-3.2-1B-Instruct.q4f16_1" => Ok(
                SupportedModels::Llama321BInstruct(Some("q4f16_1".to_string())),
            ),

            // Llama 3.2 3B
            "Llama-3.2-3B-Instruct" => Ok(SupportedModels::Llama323BInstruct(None)),
            "Llama-3.2-3B-Instruct-Q6_K"
            | "Llama-3.2-3B-Instruct_Q6_K"
            | "Llama-3.2-3B-Instruct.Q6_K" => {
                Ok(SupportedModels::Llama323BInstruct(Some("Q6_K".to_string())))
            }
            "Llama-3.2-3B-Instruct-q4f16_1" | "Llama-3.2-3B-Instruct.q4f16_1" => Ok(
                SupportedModels::Llama323BInstruct(Some("q4f16_1".to_string())),
            ),

            // Mistral 7B
            "Mistral-7B-Instruct-v0.3" => Ok(SupportedModels::Mistral7BInstructV03(None)),
            "Mistral-7B-Instruct-v0.3-q4f16_1" | "Mistral-7B-Instruct-v0.3.q4f16_1" => Ok(
                SupportedModels::Mistral7BInstructV03(Some("q4f16_1".to_string())),
            ),

            // Mixtral 8x7B
            "Mixtral-8x7B-Instruct-v0.1" => Ok(SupportedModels::Mixtral8x7BInstructV01(None)),
            "Mixtral-8x7B-Instruct-v0.1-q4f16_1" | "Mixtral-8x7B-Instruct-v0.1.q4f16_1" => Ok(
                SupportedModels::Mixtral8x7BInstructV01(Some("q4f16_1".to_string())),
            ),

            // Gemma models
            "gemma-2-2b-it" => Ok(SupportedModels::Gemma22BInstruct(None)),
            "gemma-2-2b-it-q4f16_1" | "gemma-2-2b-it.q4f16_1" => Ok(
                SupportedModels::Gemma22BInstruct(Some("q4f16_1".to_string())),
            ),

            "gemma-2-27b-it" => Ok(SupportedModels::Gemma27BInstruct(None)),
            "gemma-2-27b-it-q4f16_1" | "gemma-2-27b-it.q4f16_1" => Ok(
                SupportedModels::Gemma27BInstruct(Some("q4f16_1".to_string())),
            ),

            "gemma-2-9b-it" => Ok(SupportedModels::Gemma29BInstruct(None)),
            "gemma-2-9b-it-q4f16_1" | "gemma-2-9b-it.q4f16_1" => Ok(
                SupportedModels::Gemma29BInstruct(Some("q4f16_1".to_string())),
            ),

            _ => Err(format!("Unsupported model: {}", s)),
        }
    }
}

// Implementing Display (rather than ToString directly) is the idiomatic
// choice and provides `to_string()` via the blanket impl; it always yields
// the canonical, unquantized model name.
impl std::fmt::Display for SupportedModels {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            SupportedModels::Llama321BInstruct(_) => "Llama-3.2-1B-Instruct",
            SupportedModels::Llama323BInstruct(_) => "Llama-3.2-3B-Instruct",
            SupportedModels::Mistral7BInstructV03(_) => "Mistral-7B-Instruct-v0.3",
            SupportedModels::Mixtral8x7BInstructV01(_) => "Mixtral-8x7B-Instruct-v0.1",
            SupportedModels::Gemma22BInstruct(_) => "gemma-2-2b-it",
            SupportedModels::Gemma27BInstruct(_) => "gemma-2-27b-it",
            SupportedModels::Gemma29BInstruct(_) => "gemma-2-9b-it",
        };
        write!(f, "{}", name)
    }
}
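
// A hedged sketch, not part of the original SDK: model identifiers parse with
// or without a quantization suffix, and `to_string()` always yields the
// canonical, unquantized name.
#[cfg(test)]
#[test]
fn model_name_round_trip_sketch() {
    let model: SupportedModels = "Llama-3.2-1B-Instruct-Q6_K".parse().unwrap();
    assert!(matches!(model, SupportedModels::Llama321BInstruct(Some(_))));
    assert_eq!(model.to_string(), "Llama-3.2-1B-Instruct");
    assert!("not-a-model".parse::<SupportedModels>().is_err());
}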

#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Default)]
pub struct BlocklessLlm {
    inner: Handle,
    model_name: String,
    options: LlmOptions,
}

#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Default, PartialEq)]
pub struct LlmOptions {
    pub system_message: String,
    // pub max_tokens: u32,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    // pub frequency_penalty: f32,
    // pub presence_penalty: f32,
}

impl LlmOptions {
    pub fn new() -> Self {
        Self::default()
    }

    /// Serializes the options to JSON bytes; unset optional fields are omitted.
    pub fn dump(&self) -> Vec<u8> {
        let mut json = JsonValue::new_object();
        json["system_message"] = self.system_message.clone().into();
        if let Some(temperature) = self.temperature {
            json["temperature"] = temperature.into();
        }
        if let Some(top_p) = self.top_p {
            json["top_p"] = top_p.into();
        }
        json.dump().into_bytes()
    }
}
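
// For reference, `dump()` serializes to a compact JSON object; with both
// optional fields set it looks like:
//   {"system_message":"You are helpful.","temperature":0.5,"top_p":0.9}
// Unset `temperature`/`top_p` fields are omitted entirely.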

impl std::fmt::Display for LlmOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let bytes = self.dump();
        match String::from_utf8(bytes) {
            Ok(s) => write!(f, "{}", s),
            Err(_) => write!(f, "<invalid utf8>"),
        }
    }
}

impl TryFrom<Vec<u8>> for LlmOptions {
    type Error = LlmErrorKind;

    fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
        // Convert the raw bytes to a UTF-8 string
        let json_str = String::from_utf8(bytes).map_err(|_| LlmErrorKind::Utf8Error)?;

        // Parse the JSON string
        let json = json::parse(&json_str).map_err(|_| LlmErrorKind::ModelOptionsNotSet)?;

        // `system_message` is required; the numeric fields are optional
        let system_message = json["system_message"]
            .as_str()
            .ok_or(LlmErrorKind::ModelOptionsNotSet)?
            .to_string();

        Ok(LlmOptions {
            system_message,
            temperature: json["temperature"].as_f32(),
            top_p: json["top_p"].as_f32(),
        })
    }
}
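
// A hedged round-trip sketch, not part of the original SDK: `dump` followed
// by `try_from` should reproduce the options, since both sides agree on the
// JSON field names (0.5 is chosen as an exactly representable float).
#[cfg(test)]
#[test]
fn options_round_trip_sketch() {
    let options = LlmOptions {
        system_message: "You are helpful.".to_string(),
        temperature: Some(0.5),
        top_p: None,
    };
    let parsed = LlmOptions::try_from(options.dump()).unwrap();
    assert_eq!(options, parsed);
}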

impl BlocklessLlm {
    pub fn new(model: SupportedModels) -> Result<Self, LlmErrorKind> {
        let model_name = model.to_string();
        let mut llm: BlocklessLlm = Default::default();
        llm.set_model(&model_name)?;
        Ok(llm)
    }

    pub fn handle(&self) -> Handle {
        self.inner
    }

    pub fn get_model(&self) -> Result<String, LlmErrorKind> {
        let mut buf = [0u8; u8::MAX as usize];
        let mut num_bytes: u8 = 0;
        let code = unsafe {
            llm_get_model_response(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
        };
        if code != 0 {
            return Err(code.into());
        }
        // Propagate invalid UTF-8 as an error instead of panicking
        String::from_utf8(buf[0..num_bytes as usize].to_vec()).map_err(|_| LlmErrorKind::Utf8Error)
    }

    pub fn set_model(&mut self, model_name: &str) -> Result<(), LlmErrorKind> {
        self.model_name = model_name.to_string();
        let code = unsafe {
            llm_set_model_request(&mut self.inner, model_name.as_ptr(), model_name.len() as _)
        };
        if code != 0 {
            return Err(code.into());
        }

        // Validate that the model was set correctly in the host/runtime
        let host_model = self.get_model()?;
        if self.model_name != host_model {
            eprintln!(
                "Model not set correctly in host/runtime; model_name: {}, model_from_host: {}",
                self.model_name, host_model
            );
            return Err(LlmErrorKind::ModelNotSet);
        }
        Ok(())
    }

    pub fn get_options(&self) -> Result<LlmOptions, LlmErrorKind> {
        let mut buf = [0u8; u16::MAX as usize];
        let mut num_bytes: u16 = 0;
        let code = unsafe {
            llm_get_model_options(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
        };
        if code != 0 {
            return Err(code.into());
        }

        // Convert the buffer slice to Vec<u8> and try to parse it into LlmOptions
        LlmOptions::try_from(buf[0..num_bytes as usize].to_vec())
    }

    pub fn set_options(&mut self, options: LlmOptions) -> Result<(), LlmErrorKind> {
        let options_json = options.dump();
        self.options = options;
        let code = unsafe {
            llm_set_model_options_request(
                self.inner,
                options_json.as_ptr(),
                options_json.len() as _,
            )
        };
        if code != 0 {
            return Err(code.into());
        }

        // Verify the options were set correctly in the host/runtime
        let host_options = self.get_options()?;
        if self.options != host_options {
            eprintln!(
                "Options not set correctly in host/runtime; options: {:?}, options_from_host: {:?}",
                self.options, host_options
            );
            return Err(LlmErrorKind::ModelOptionsNotSet);
        }
        Ok(())
    }

    pub fn chat_request(&self, prompt: &str) -> Result<String, LlmErrorKind> {
        // Submit the prompt to the host
        let code = unsafe { llm_prompt_request(self.inner, prompt.as_ptr(), prompt.len() as _) };
        if code != 0 {
            return Err(code.into());
        }

        // Read the response back
        self.get_chat_response()
    }

    fn get_chat_response(&self) -> Result<String, LlmErrorKind> {
        let mut buf = [0u8; u16::MAX as usize];
        let mut num_bytes: u16 = 0;
        let code = unsafe {
            llm_read_prompt_response(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
        };
        if code != 0 {
            return Err(code.into());
        }

        let response_vec = buf[0..num_bytes as usize].to_vec();
        String::from_utf8(response_vec).map_err(|_| LlmErrorKind::Utf8Error)
    }
}
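
// A hedged end-to-end sketch, not part of the original SDK. It compiles
// against this module but only does something useful inside a Blockless
// runtime that provides the `blockless_llm` host imports declared above.
#[allow(dead_code)]
fn usage_sketch() -> Result<String, LlmErrorKind> {
    let mut llm = BlocklessLlm::new(SupportedModels::Llama321BInstruct(None))?;
    let mut options = LlmOptions::new();
    options.system_message = "You are a helpful assistant.".to_string();
    options.temperature = Some(0.7);
    llm.set_options(options)?;
    // Each prompt is a single-shot request; the reply is read back as UTF-8.
    llm.chat_request("Hello!")
}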

impl Drop for BlocklessLlm {
    fn drop(&mut self) {
        // Release the host-side handle when the wrapper goes out of scope
        let code = unsafe { llm_close(self.inner) };
        if code != 0 {
            eprintln!("Error closing LLM: {}", code);
        }
    }
}

/// Error kinds mirroring the non-zero exit codes returned by the host.
#[derive(Debug)]
pub enum LlmErrorKind {
    ModelNotSet,               // 1
    ModelNotSupported,         // 2
    ModelInitializationFailed, // 3
    ModelCompletionFailed,     // 4
    ModelOptionsNotSet,        // 5
    ModelShutdownFailed,       // 6
    Utf8Error,                 // 7
    RuntimeError,              // 8
}

impl From<u8> for LlmErrorKind {
    fn from(code: u8) -> Self {
        match code {
            1 => LlmErrorKind::ModelNotSet,
            2 => LlmErrorKind::ModelNotSupported,
            3 => LlmErrorKind::ModelInitializationFailed,
            4 => LlmErrorKind::ModelCompletionFailed,
            5 => LlmErrorKind::ModelOptionsNotSet,
            6 => LlmErrorKind::ModelShutdownFailed,
            7 => LlmErrorKind::Utf8Error,
            _ => LlmErrorKind::RuntimeError,
        }
    }
}
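
// Hedged sketches, not part of the original SDK: spot-check the numeric
// error-code mapping used by the host ABI.
#[cfg(test)]
mod error_code_tests {
    use super::*;

    #[test]
    fn error_codes_map_as_documented() {
        assert!(matches!(LlmErrorKind::from(1), LlmErrorKind::ModelNotSet));
        assert!(matches!(LlmErrorKind::from(5), LlmErrorKind::ModelOptionsNotSet));
        assert!(matches!(LlmErrorKind::from(7), LlmErrorKind::Utf8Error));
        // Any unmapped code (including 0) falls back to RuntimeError.
        assert!(matches!(LlmErrorKind::from(42), LlmErrorKind::RuntimeError));
    }
}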