// blockless_sdk/llm.rs

1use json::JsonValue;
2use std::{str::FromStr, string::ToString};
3
/// Opaque handle identifying an LLM session owned by the host runtime.
type Handle = u32;
/// Status code returned by every host call: 0 = success, non-zero maps to
/// `LlmErrorKind` via `From<u8>`.
type ExitCode = u8;
6
// Raw FFI bindings to the host's `blockless_llm` wasm import module.
// Buffer-filling functions write at most `buf_len` bytes into `buf` and
// report the actual count through `bytes_written`.
#[cfg(not(feature = "mock-ffi"))]
#[link(wasm_import_module = "blockless_llm")]
extern "C" {
    /// Registers a model by name; the host writes the session handle through `h`.
    fn llm_set_model_request(h: *mut Handle, model_ptr: *const u8, model_len: u8) -> ExitCode;
    /// Reads back the model name currently set for session `h`.
    fn llm_get_model_response(
        h: Handle,
        buf: *mut u8,
        buf_len: u8,
        bytes_written: *mut u8,
    ) -> ExitCode;
    /// Sends JSON-encoded model options for session `h`.
    fn llm_set_model_options_request(
        h: Handle,
        options_ptr: *const u8,
        options_len: u16,
    ) -> ExitCode;
    /// Reads back the JSON-encoded options currently set for session `h`.
    fn llm_get_model_options(
        h: Handle,
        buf: *mut u8,
        buf_len: u16,
        bytes_written: *mut u16,
    ) -> ExitCode;
    /// Submits a prompt for session `h`; the response is fetched separately.
    fn llm_prompt_request(h: Handle, prompt_ptr: *const u8, prompt_len: u16) -> ExitCode;
    /// Reads the pending prompt response for session `h`.
    fn llm_read_prompt_response(
        h: Handle,
        buf: *mut u8,
        buf_len: u16,
        bytes_written: *mut u16,
    ) -> ExitCode;
    /// Closes session `h` and releases host-side resources.
    fn llm_close(h: Handle) -> ExitCode;
}
37
#[cfg(feature = "mock-ffi")]
mod mock_ffi {
    //! Stub implementations of the host FFI for builds without a Blockless
    //! runtime. Every function panics if actually called.
    //!
    //! All parameters are underscore-prefixed (they are intentionally
    //! unused), which makes the former module-wide
    //! `#[allow(unused_variables)]` unnecessary and keeps genuine
    //! unused-variable warnings visible.
    use super::*;

    pub unsafe fn llm_set_model_request(
        _h: *mut Handle,
        _model_ptr: *const u8,
        _model_len: u8,
    ) -> ExitCode {
        unimplemented!()
    }

    pub unsafe fn llm_get_model_response(
        _h: Handle,
        _buf: *mut u8,
        _buf_len: u8,
        _bytes_written: *mut u8,
    ) -> ExitCode {
        unimplemented!()
    }

    pub unsafe fn llm_set_model_options_request(
        _h: Handle,
        _options_ptr: *const u8,
        _options_len: u16,
    ) -> ExitCode {
        unimplemented!()
    }

    pub unsafe fn llm_get_model_options(
        _h: Handle,
        _buf: *mut u8,
        _buf_len: u16,
        _bytes_written: *mut u16,
    ) -> ExitCode {
        unimplemented!()
    }

    pub unsafe fn llm_prompt_request(
        _h: Handle,
        _prompt_ptr: *const u8,
        _prompt_len: u16,
    ) -> ExitCode {
        unimplemented!()
    }

    pub unsafe fn llm_read_prompt_response(
        _h: Handle,
        _buf: *mut u8,
        _buf_len: u16,
        _bytes_written: *mut u16,
    ) -> ExitCode {
        unimplemented!()
    }

    pub unsafe fn llm_close(_h: Handle) -> ExitCode {
        unimplemented!()
    }
}
98
99#[cfg(feature = "mock-ffi")]
100use mock_ffi::*;
101
/// Model identifiers understood by the Blockless runtime.
///
/// Built-in variants carry an optional quantization tag (e.g. "Q6_K" or
/// "q4f16_1"); `None` selects the host's default quantization. `Custom`
/// holds an arbitrary model identifier passed through verbatim.
#[derive(Debug, Clone)]
pub enum Models {
    Llama321BInstruct(Option<String>),
    Llama323BInstruct(Option<String>),
    Mistral7BInstructV03(Option<String>),
    Mixtral8x7BInstructV01(Option<String>),
    Gemma22BInstruct(Option<String>),
    Gemma27BInstruct(Option<String>),
    Gemma29BInstruct(Option<String>),
    Custom(String),
}
113
114impl FromStr for Models {
115    type Err = String;
116    fn from_str(s: &str) -> Result<Self, Self::Err> {
117        match s {
118            // Llama 3.2 1B
119            "Llama-3.2-1B-Instruct" => Ok(Models::Llama321BInstruct(None)),
120            "Llama-3.2-1B-Instruct-Q6_K"
121            | "Llama-3.2-1B-Instruct_Q6_K"
122            | "Llama-3.2-1B-Instruct.Q6_K" => {
123                Ok(Models::Llama321BInstruct(Some("Q6_K".to_string())))
124            }
125            "Llama-3.2-1B-Instruct-q4f16_1" | "Llama-3.2-1B-Instruct.q4f16_1" => {
126                Ok(Models::Llama321BInstruct(Some("q4f16_1".to_string())))
127            }
128
129            // Llama 3.2 3B
130            "Llama-3.2-3B-Instruct" => Ok(Models::Llama323BInstruct(None)),
131            "Llama-3.2-3B-Instruct-Q6_K"
132            | "Llama-3.2-3B-Instruct_Q6_K"
133            | "Llama-3.2-3B-Instruct.Q6_K" => {
134                Ok(Models::Llama323BInstruct(Some("Q6_K".to_string())))
135            }
136            "Llama-3.2-3B-Instruct-q4f16_1" | "Llama-3.2-3B-Instruct.q4f16_1" => {
137                Ok(Models::Llama323BInstruct(Some("q4f16_1".to_string())))
138            }
139
140            // Mistral 7B
141            "Mistral-7B-Instruct-v0.3" => Ok(Models::Mistral7BInstructV03(None)),
142            "Mistral-7B-Instruct-v0.3-q4f16_1" | "Mistral-7B-Instruct-v0.3.q4f16_1" => {
143                Ok(Models::Mistral7BInstructV03(Some("q4f16_1".to_string())))
144            }
145
146            // Mixtral 8x7B
147            "Mixtral-8x7B-Instruct-v0.1" => Ok(Models::Mixtral8x7BInstructV01(None)),
148            "Mixtral-8x7B-Instruct-v0.1-q4f16_1" | "Mixtral-8x7B-Instruct-v0.1.q4f16_1" => {
149                Ok(Models::Mixtral8x7BInstructV01(Some("q4f16_1".to_string())))
150            }
151
152            // Gemma models
153            "gemma-2-2b-it" => Ok(Models::Gemma22BInstruct(None)),
154            "gemma-2-2b-it-q4f16_1" | "gemma-2-2b-it.q4f16_1" => {
155                Ok(Models::Gemma22BInstruct(Some("q4f16_1".to_string())))
156            }
157
158            "gemma-2-27b-it" => Ok(Models::Gemma27BInstruct(None)),
159            "gemma-2-27b-it-q4f16_1" | "gemma-2-27b-it.q4f16_1" => {
160                Ok(Models::Gemma27BInstruct(Some("q4f16_1".to_string())))
161            }
162
163            "gemma-2-9b-it" => Ok(Models::Gemma29BInstruct(None)),
164            "gemma-2-9b-it-q4f16_1" | "gemma-2-9b-it.q4f16_1" => {
165                Ok(Models::Gemma29BInstruct(Some("q4f16_1".to_string())))
166            }
167            _ => Ok(Models::Custom(s.to_string())),
168        }
169    }
170}
171
172impl std::fmt::Display for Models {
173    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174        match self {
175            Models::Llama321BInstruct(_) => write!(f, "Llama-3.2-1B-Instruct"),
176            Models::Llama323BInstruct(_) => write!(f, "Llama-3.2-3B-Instruct"),
177            Models::Mistral7BInstructV03(_) => write!(f, "Mistral-7B-Instruct-v0.3"),
178            Models::Mixtral8x7BInstructV01(_) => write!(f, "Mixtral-8x7B-Instruct-v0.1"),
179            Models::Gemma22BInstruct(_) => write!(f, "gemma-2-2b-it"),
180            Models::Gemma27BInstruct(_) => write!(f, "gemma-2-27b-it"),
181            Models::Gemma29BInstruct(_) => write!(f, "gemma-2-9b-it"),
182            Models::Custom(s) => write!(f, "{}", s),
183        }
184    }
185}
186
/// A session with a host-side LLM, identified by an opaque handle.
///
/// NOTE(review): this type derives `Clone` and `Default` while also
/// implementing `Drop` (which calls `llm_close`), so a clone and its source
/// close the same handle, and a dropped `Default` value closes handle 0 —
/// confirm the host tolerates both.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Default)]
pub struct BlocklessLlm {
    inner: Handle,      // host session handle; 0 until `set_model` succeeds
    model_name: String, // model name last sent to the host
    options: LlmOptions, // options last sent via `set_options`
}
194
/// Optional generation settings sent to the host as JSON.
///
/// Fields left as `None` are omitted from the serialized payload (see
/// `LlmOptions::dump`).
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Default, PartialEq)]
pub struct LlmOptions {
    pub system_message: Option<String>,
    // SSE endpoints for MCP tool servers; serialized as a JSON array, but
    // may come back from the host as a comma-separated string (see TryFrom).
    pub tools_sse_urls: Option<Vec<String>>,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
}
203
204impl LlmOptions {
205    pub fn with_system_message(mut self, system_message: String) -> Self {
206        self.system_message = Some(system_message);
207        self
208    }
209
210    pub fn with_tools_sse_urls(mut self, tools_sse_urls: Vec<String>) -> Self {
211        self.tools_sse_urls = Some(tools_sse_urls);
212        self
213    }
214
215    fn dump(&self) -> Vec<u8> {
216        let mut json = JsonValue::new_object();
217
218        if let Some(system_message) = &self.system_message {
219            json["system_message"] = system_message.clone().into();
220        }
221
222        if let Some(tools_sse_urls) = &self.tools_sse_urls {
223            json["tools_sse_urls"] = tools_sse_urls.clone().into();
224        }
225
226        if let Some(temperature) = self.temperature {
227            json["temperature"] = temperature.into();
228        }
229        if let Some(top_p) = self.top_p {
230            json["top_p"] = top_p.into();
231        }
232
233        // If json is empty, return an empty JSON object
234        if json.entries().count() == 0 {
235            return "{}".as_bytes().to_vec();
236        }
237
238        json.dump().into_bytes()
239    }
240}
241
242impl std::fmt::Display for LlmOptions {
243    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244        let bytes = self.dump();
245        match String::from_utf8(bytes) {
246            Ok(s) => write!(f, "{}", s),
247            Err(_) => write!(f, "<invalid utf8>"),
248        }
249    }
250}
251
252impl TryFrom<Vec<u8>> for LlmOptions {
253    type Error = LlmErrorKind;
254
255    fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
256        // Convert bytes to UTF-8 string
257        let json_str = String::from_utf8(bytes).map_err(|_| LlmErrorKind::Utf8Error)?;
258
259        // Parse the JSON string
260        let json = json::parse(&json_str).map_err(|_| LlmErrorKind::ModelOptionsNotSet)?;
261
262        // Extract system_message
263        let system_message = json["system_message"].as_str().map(|s| s.to_string());
264
265        // Extract tools_sse_urls - can be an array or a comma-separated string
266        let tools_sse_urls = if json["tools_sse_urls"].is_array() {
267            // Handle array format - native runtime
268            Some(
269                json["tools_sse_urls"]
270                    .members()
271                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
272                    .collect(),
273            )
274        } else {
275            json["tools_sse_urls"]
276                .as_str()
277                .map(|s| s.split(',').map(|s| s.trim().to_string()).collect())
278        };
279
280        Ok(LlmOptions {
281            system_message,
282            tools_sse_urls,
283            temperature: json["temperature"].as_f32(),
284            top_p: json["top_p"].as_f32(),
285        })
286    }
287}
288
289impl BlocklessLlm {
290    pub fn new(model: Models) -> Result<Self, LlmErrorKind> {
291        let model_name = model.to_string();
292        let mut llm: BlocklessLlm = Default::default();
293        llm.set_model(&model_name)?;
294        Ok(llm)
295    }
296
297    pub fn handle(&self) -> Handle {
298        self.inner
299    }
300
301    pub fn get_model(&self) -> Result<String, LlmErrorKind> {
302        let mut buf = [0u8; u8::MAX as usize];
303        let mut num_bytes: u8 = 0;
304        let code = unsafe {
305            llm_get_model_response(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
306        };
307        if code != 0 {
308            return Err(code.into());
309        }
310        let model = String::from_utf8(buf[0..num_bytes as _].to_vec()).unwrap();
311        Ok(model)
312    }
313
314    pub fn set_model(&mut self, model_name: &str) -> Result<(), LlmErrorKind> {
315        self.model_name = model_name.to_string();
316        let code = unsafe {
317            llm_set_model_request(&mut self.inner, model_name.as_ptr(), model_name.len() as _)
318        };
319        if code != 0 {
320            return Err(code.into());
321        }
322
323        // validate model is set correctly in host/runtime
324        let host_model = self.get_model()?;
325        if self.model_name != host_model {
326            eprintln!(
327                "Model not set correctly in host/runtime; model_name: {}, model_from_host: {}",
328                self.model_name, host_model
329            );
330            return Err(LlmErrorKind::ModelNotSet);
331        }
332        Ok(())
333    }
334
335    pub fn get_options(&self) -> Result<LlmOptions, LlmErrorKind> {
336        let mut buf = [0u8; u16::MAX as usize];
337        let mut num_bytes: u16 = 0;
338        let code = unsafe {
339            llm_get_model_options(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
340        };
341        if code != 0 {
342            return Err(code.into());
343        }
344
345        // Convert buffer slice to Vec<u8> and try to parse into LlmOptions
346        LlmOptions::try_from(buf[0..num_bytes as usize].to_vec())
347    }
348
349    pub fn set_options(&mut self, options: LlmOptions) -> Result<(), LlmErrorKind> {
350        let options_json = options.dump();
351        self.options = options;
352        let code = unsafe {
353            llm_set_model_options_request(
354                self.inner,
355                options_json.as_ptr(),
356                options_json.len() as _,
357            )
358        };
359        if code != 0 {
360            return Err(code.into());
361        }
362
363        // Verify options were set correctly
364        let host_options = self.get_options()?;
365        if self.options != host_options {
366            println!(
367                "Options not set correctly in host/runtime; options: {:?}, options_from_host: {:?}",
368                self.options, host_options
369            );
370            return Err(LlmErrorKind::ModelOptionsNotSet);
371        }
372        Ok(())
373    }
374
375    pub fn chat_request(&self, prompt: &str) -> Result<String, LlmErrorKind> {
376        // Perform the prompt request
377        let code = unsafe { llm_prompt_request(self.inner, prompt.as_ptr(), prompt.len() as _) };
378        if code != 0 {
379            return Err(code.into());
380        }
381
382        // Read the response
383        self.get_chat_response()
384    }
385
386    fn get_chat_response(&self) -> Result<String, LlmErrorKind> {
387        let mut buf = [0u8; u16::MAX as usize];
388        let mut num_bytes: u16 = 0;
389        let code = unsafe {
390            llm_read_prompt_response(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
391        };
392        if code != 0 {
393            return Err(code.into());
394        }
395
396        let response_vec = buf[0..num_bytes as usize].to_vec();
397        String::from_utf8(response_vec).map_err(|_| LlmErrorKind::Utf8Error)
398    }
399}
400
impl Drop for BlocklessLlm {
    /// Closes the host-side session when the value is dropped.
    ///
    /// NOTE(review): `BlocklessLlm` derives `Clone` and `Default`, so a
    /// clone and its source each close the same handle, and a dropped
    /// `Default` value closes handle 0 — confirm the host tolerates both.
    fn drop(&mut self) {
        let code = unsafe { llm_close(self.inner) };
        if code != 0 {
            // Drop cannot return an error, so log and continue.
            eprintln!("Error closing LLM: {}", code);
        }
    }
}
409
/// Error kinds for the Blockless LLM interface. The numeric comments are the
/// host `ExitCode` values each variant corresponds to (see `From<u8>`).
#[derive(Debug)]
pub enum LlmErrorKind {
    ModelNotSet,               // 1
    ModelNotSupported,         // 2
    ModelInitializationFailed, // 3
    ModelCompletionFailed,     // 4
    ModelOptionsNotSet,        // 5
    ModelShutdownFailed,       // 6
    Utf8Error,                 // 7
    RuntimeError,              // 8
    MCPFunctionCallError,      // 9
}

impl From<u8> for LlmErrorKind {
    /// Maps a host exit code to an error kind. Unknown codes fall back to
    /// `RuntimeError`.
    ///
    /// Code 8 was previously a commented-out arm and silently rode the
    /// catch-all; it is now mapped explicitly (same result, no dead code).
    fn from(code: u8) -> Self {
        match code {
            1 => LlmErrorKind::ModelNotSet,
            2 => LlmErrorKind::ModelNotSupported,
            3 => LlmErrorKind::ModelInitializationFailed,
            4 => LlmErrorKind::ModelCompletionFailed,
            5 => LlmErrorKind::ModelOptionsNotSet,
            6 => LlmErrorKind::ModelShutdownFailed,
            7 => LlmErrorKind::Utf8Error,
            8 => LlmErrorKind::RuntimeError,
            9 => LlmErrorKind::MCPFunctionCallError,
            _ => LlmErrorKind::RuntimeError,
        }
    }
}