blockless_sdk/
llm.rs

1use json::JsonValue;
2use std::{str::FromStr, string::ToString};
3
/// Identifier of an LLM session; passed to every `blockless_llm` host call.
type Handle = u32;

/// Status byte returned by every host call; this file treats `0` as success
/// and converts any other value into an `LlmErrorKind`.
type ExitCode = u8;
6
7#[cfg(not(feature = "mock-ffi"))]
8#[link(wasm_import_module = "blockless_llm")]
9extern "C" {
10    fn llm_set_model_request(h: *mut Handle, model_ptr: *const u8, model_len: u8) -> ExitCode;
11    fn llm_get_model_response(
12        h: Handle,
13        buf: *mut u8,
14        buf_len: u8,
15        bytes_written: *mut u8,
16    ) -> ExitCode;
17    fn llm_set_model_options_request(
18        h: Handle,
19        options_ptr: *const u8,
20        options_len: u16,
21    ) -> ExitCode;
22    fn llm_get_model_options(
23        h: Handle,
24        buf: *mut u8,
25        buf_len: u16,
26        bytes_written: *mut u16,
27    ) -> ExitCode;
28    fn llm_prompt_request(h: Handle, prompt_ptr: *const u8, prompt_len: u16) -> ExitCode;
29    fn llm_read_prompt_response(
30        h: Handle,
31        buf: *mut u8,
32        buf_len: u16,
33        bytes_written: *mut u16,
34    ) -> ExitCode;
35    fn llm_close(h: Handle) -> ExitCode;
36}
37
/// Stand-ins for the `blockless_llm` host imports, compiled only when the
/// `mock-ffi` feature is enabled.
///
/// Each function mirrors the signature of its host counterpart and panics
/// with `unimplemented!()` when called. All parameters carry a leading
/// underscore because none of them is read.
#[cfg(feature = "mock-ffi")]
mod mock_ffi {
    use super::*;

    /// Mock of the host's `llm_set_model_request`; always panics.
    pub unsafe fn llm_set_model_request(
        _h: *mut Handle,
        _model_ptr: *const u8,
        _model_len: u8,
    ) -> ExitCode {
        unimplemented!()
    }

    /// Mock of the host's `llm_get_model_response`; always panics.
    pub unsafe fn llm_get_model_response(
        _h: Handle,
        _buf: *mut u8,
        _buf_len: u8,
        _bytes_written: *mut u8,
    ) -> ExitCode {
        unimplemented!()
    }

    /// Mock of the host's `llm_set_model_options_request`; always panics.
    pub unsafe fn llm_set_model_options_request(
        _h: Handle,
        _options_ptr: *const u8,
        _options_len: u16,
    ) -> ExitCode {
        unimplemented!()
    }

    /// Mock of the host's `llm_get_model_options`; always panics.
    pub unsafe fn llm_get_model_options(
        _h: Handle,
        _buf: *mut u8,
        _buf_len: u16,
        _bytes_written: *mut u16,
    ) -> ExitCode {
        unimplemented!()
    }

    /// Mock of the host's `llm_prompt_request`; always panics.
    pub unsafe fn llm_prompt_request(
        _h: Handle,
        _prompt_ptr: *const u8,
        _prompt_len: u16,
    ) -> ExitCode {
        unimplemented!()
    }

    /// Mock of the host's `llm_read_prompt_response`; always panics.
    pub unsafe fn llm_read_prompt_response(
        _h: Handle,
        _buf: *mut u8,
        _buf_len: u16,
        _bytes_written: *mut u16,
    ) -> ExitCode {
        unimplemented!()
    }

    /// Mock of the host's `llm_close`; always panics.
    pub unsafe fn llm_close(_h: Handle) -> ExitCode {
        unimplemented!()
    }
}
98
99#[cfg(feature = "mock-ffi")]
100use mock_ffi::*;
101
/// Models that can be requested from the host runtime.
///
/// Every built-in variant carries an optional quantization tag (for example
/// `"Q6_K"` or `"q4f16_1"`); `None` means no tag was given. `Custom` holds an
/// arbitrary model identifier verbatim.
///
/// `PartialEq`/`Eq` are derived so values can be compared directly, e.g. in
/// tests or when checking the result of parsing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Models {
    /// `Llama-3.2-1B-Instruct`, with an optional quantization tag.
    Llama321BInstruct(Option<String>),
    /// `Llama-3.2-3B-Instruct`, with an optional quantization tag.
    Llama323BInstruct(Option<String>),
    /// `Mistral-7B-Instruct-v0.3`, with an optional quantization tag.
    Mistral7BInstructV03(Option<String>),
    /// `Mixtral-8x7B-Instruct-v0.1`, with an optional quantization tag.
    Mixtral8x7BInstructV01(Option<String>),
    /// `gemma-2-2b-it`, with an optional quantization tag.
    Gemma22BInstruct(Option<String>),
    /// `gemma-2-27b-it`, with an optional quantization tag.
    Gemma27BInstruct(Option<String>),
    /// `gemma-2-9b-it`, with an optional quantization tag.
    Gemma29BInstruct(Option<String>),
    /// Any other model, identified by the contained string.
    Custom(String),
}
113
114impl FromStr for Models {
115    type Err = String;
116    fn from_str(s: &str) -> Result<Self, Self::Err> {
117        match s {
118            // Llama 3.2 1B
119            "Llama-3.2-1B-Instruct" => Ok(Models::Llama321BInstruct(None)),
120            "Llama-3.2-1B-Instruct-Q6_K"
121            | "Llama-3.2-1B-Instruct_Q6_K"
122            | "Llama-3.2-1B-Instruct.Q6_K" => {
123                Ok(Models::Llama321BInstruct(Some("Q6_K".to_string())))
124            }
125            "Llama-3.2-1B-Instruct-q4f16_1" | "Llama-3.2-1B-Instruct.q4f16_1" => {
126                Ok(Models::Llama321BInstruct(Some("q4f16_1".to_string())))
127            }
128
129            // Llama 3.2 3B
130            "Llama-3.2-3B-Instruct" => Ok(Models::Llama323BInstruct(None)),
131            "Llama-3.2-3B-Instruct-Q6_K"
132            | "Llama-3.2-3B-Instruct_Q6_K"
133            | "Llama-3.2-3B-Instruct.Q6_K" => {
134                Ok(Models::Llama323BInstruct(Some("Q6_K".to_string())))
135            }
136            "Llama-3.2-3B-Instruct-q4f16_1" | "Llama-3.2-3B-Instruct.q4f16_1" => {
137                Ok(Models::Llama323BInstruct(Some("q4f16_1".to_string())))
138            }
139
140            // Mistral 7B
141            "Mistral-7B-Instruct-v0.3" => Ok(Models::Mistral7BInstructV03(None)),
142            "Mistral-7B-Instruct-v0.3-q4f16_1" | "Mistral-7B-Instruct-v0.3.q4f16_1" => {
143                Ok(Models::Mistral7BInstructV03(Some("q4f16_1".to_string())))
144            }
145
146            // Mixtral 8x7B
147            "Mixtral-8x7B-Instruct-v0.1" => Ok(Models::Mixtral8x7BInstructV01(None)),
148            "Mixtral-8x7B-Instruct-v0.1-q4f16_1" | "Mixtral-8x7B-Instruct-v0.1.q4f16_1" => {
149                Ok(Models::Mixtral8x7BInstructV01(Some("q4f16_1".to_string())))
150            }
151
152            // Gemma models
153            "gemma-2-2b-it" => Ok(Models::Gemma22BInstruct(None)),
154            "gemma-2-2b-it-q4f16_1" | "gemma-2-2b-it.q4f16_1" => {
155                Ok(Models::Gemma22BInstruct(Some("q4f16_1".to_string())))
156            }
157
158            "gemma-2-27b-it" => Ok(Models::Gemma27BInstruct(None)),
159            "gemma-2-27b-it-q4f16_1" | "gemma-2-27b-it.q4f16_1" => {
160                Ok(Models::Gemma27BInstruct(Some("q4f16_1".to_string())))
161            }
162
163            "gemma-2-9b-it" => Ok(Models::Gemma29BInstruct(None)),
164            "gemma-2-9b-it-q4f16_1" | "gemma-2-9b-it.q4f16_1" => {
165                Ok(Models::Gemma29BInstruct(Some("q4f16_1".to_string())))
166            }
167            _ => Ok(Models::Custom(s.to_string())),
168        }
169    }
170}
171
172impl std::fmt::Display for Models {
173    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174        match self {
175            Models::Llama321BInstruct(_) => write!(f, "Llama-3.2-1B-Instruct"),
176            Models::Llama323BInstruct(_) => write!(f, "Llama-3.2-3B-Instruct"),
177            Models::Mistral7BInstructV03(_) => write!(f, "Mistral-7B-Instruct-v0.3"),
178            Models::Mixtral8x7BInstructV01(_) => write!(f, "Mixtral-8x7B-Instruct-v0.1"),
179            Models::Gemma22BInstruct(_) => write!(f, "gemma-2-2b-it"),
180            Models::Gemma27BInstruct(_) => write!(f, "gemma-2-27b-it"),
181            Models::Gemma29BInstruct(_) => write!(f, "gemma-2-9b-it"),
182            Models::Custom(s) => write!(f, "{}", s),
183        }
184    }
185}
186
187#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
188pub struct BlocklessLlm {
189    inner: Handle,
190    model_name: String,
191    options: LlmOptions,
192}
193
194#[derive(Debug, Clone, Default, PartialEq, serde::Serialize, serde::Deserialize)]
195pub struct LlmOptions {
196    pub system_message: Option<String>,
197    pub tools_sse_urls: Option<Vec<String>>,
198    pub temperature: Option<f32>,
199    pub top_p: Option<f32>,
200}
201
202impl LlmOptions {
203    pub fn with_system_message(mut self, system_message: String) -> Self {
204        self.system_message = Some(system_message);
205        self
206    }
207
208    pub fn with_tools_sse_urls(mut self, tools_sse_urls: Vec<String>) -> Self {
209        self.tools_sse_urls = Some(tools_sse_urls);
210        self
211    }
212
213    fn dump(&self) -> Vec<u8> {
214        let mut json = JsonValue::new_object();
215
216        if let Some(system_message) = &self.system_message {
217            json["system_message"] = system_message.clone().into();
218        }
219
220        if let Some(tools_sse_urls) = &self.tools_sse_urls {
221            json["tools_sse_urls"] = tools_sse_urls.clone().into();
222        }
223
224        if let Some(temperature) = self.temperature {
225            json["temperature"] = temperature.into();
226        }
227        if let Some(top_p) = self.top_p {
228            json["top_p"] = top_p.into();
229        }
230
231        // If json is empty, return an empty JSON object
232        if json.entries().count() == 0 {
233            return "{}".as_bytes().to_vec();
234        }
235
236        json.dump().into_bytes()
237    }
238}
239
240impl std::fmt::Display for LlmOptions {
241    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242        let bytes = self.dump();
243        match String::from_utf8(bytes) {
244            Ok(s) => write!(f, "{}", s),
245            Err(_) => write!(f, "<invalid utf8>"),
246        }
247    }
248}
249
250impl TryFrom<Vec<u8>> for LlmOptions {
251    type Error = LlmErrorKind;
252
253    fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
254        // Convert bytes to UTF-8 string
255        let json_str = String::from_utf8(bytes).map_err(|_| LlmErrorKind::Utf8Error)?;
256
257        // Parse the JSON string
258        let json = json::parse(&json_str).map_err(|_| LlmErrorKind::ModelOptionsNotSet)?;
259
260        // Extract system_message
261        let system_message = json["system_message"].as_str().map(|s| s.to_string());
262
263        // Extract tools_sse_urls - can be an array or a comma-separated string
264        let tools_sse_urls = if json["tools_sse_urls"].is_array() {
265            // Handle array format - native runtime
266            Some(
267                json["tools_sse_urls"]
268                    .members()
269                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
270                    .collect(),
271            )
272        } else {
273            json["tools_sse_urls"]
274                .as_str()
275                .map(|s| s.split(',').map(|s| s.trim().to_string()).collect())
276        };
277
278        Ok(LlmOptions {
279            system_message,
280            tools_sse_urls,
281            temperature: json["temperature"].as_f32(),
282            top_p: json["top_p"].as_f32(),
283        })
284    }
285}
286
287impl BlocklessLlm {
288    pub fn new(model: Models) -> Result<Self, LlmErrorKind> {
289        let model_name = model.to_string();
290        let mut llm: BlocklessLlm = Default::default();
291        llm.set_model(&model_name)?;
292        Ok(llm)
293    }
294
295    pub fn handle(&self) -> Handle {
296        self.inner
297    }
298
299    pub fn get_model(&self) -> Result<String, LlmErrorKind> {
300        let mut buf = [0u8; u8::MAX as usize];
301        let mut num_bytes: u8 = 0;
302        let code = unsafe {
303            llm_get_model_response(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
304        };
305        if code != 0 {
306            return Err(code.into());
307        }
308        let model = String::from_utf8(buf[0..num_bytes as _].to_vec()).unwrap();
309        Ok(model)
310    }
311
312    pub fn set_model(&mut self, model_name: &str) -> Result<(), LlmErrorKind> {
313        self.model_name = model_name.to_string();
314        let code = unsafe {
315            llm_set_model_request(&mut self.inner, model_name.as_ptr(), model_name.len() as _)
316        };
317        if code != 0 {
318            return Err(code.into());
319        }
320
321        // validate model is set correctly in host/runtime
322        let host_model = self.get_model()?;
323        if self.model_name != host_model {
324            eprintln!(
325                "Model not set correctly in host/runtime; model_name: {}, model_from_host: {}",
326                self.model_name, host_model
327            );
328            return Err(LlmErrorKind::ModelNotSet);
329        }
330        Ok(())
331    }
332
333    pub fn get_options(&self) -> Result<LlmOptions, LlmErrorKind> {
334        let mut buf = [0u8; u16::MAX as usize];
335        let mut num_bytes: u16 = 0;
336        let code = unsafe {
337            llm_get_model_options(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
338        };
339        if code != 0 {
340            return Err(code.into());
341        }
342
343        // Convert buffer slice to Vec<u8> and try to parse into LlmOptions
344        LlmOptions::try_from(buf[0..num_bytes as usize].to_vec())
345    }
346
347    pub fn set_options(&mut self, options: LlmOptions) -> Result<(), LlmErrorKind> {
348        let options_json = options.dump();
349        self.options = options;
350        let code = unsafe {
351            llm_set_model_options_request(
352                self.inner,
353                options_json.as_ptr(),
354                options_json.len() as _,
355            )
356        };
357        if code != 0 {
358            return Err(code.into());
359        }
360
361        // Verify options were set correctly
362        let host_options = self.get_options()?;
363        if self.options != host_options {
364            println!(
365                "Options not set correctly in host/runtime; options: {:?}, options_from_host: {:?}",
366                self.options, host_options
367            );
368            return Err(LlmErrorKind::ModelOptionsNotSet);
369        }
370        Ok(())
371    }
372
373    pub fn chat_request(&self, prompt: &str) -> Result<String, LlmErrorKind> {
374        // Perform the prompt request
375        let code = unsafe { llm_prompt_request(self.inner, prompt.as_ptr(), prompt.len() as _) };
376        if code != 0 {
377            return Err(code.into());
378        }
379
380        // Read the response
381        self.get_chat_response()
382    }
383
384    fn get_chat_response(&self) -> Result<String, LlmErrorKind> {
385        let mut buf = [0u8; u16::MAX as usize];
386        let mut num_bytes: u16 = 0;
387        let code = unsafe {
388            llm_read_prompt_response(self.inner, buf.as_mut_ptr(), buf.len() as _, &mut num_bytes)
389        };
390        if code != 0 {
391            return Err(code.into());
392        }
393
394        let response_vec = buf[0..num_bytes as usize].to_vec();
395        String::from_utf8(response_vec).map_err(|_| LlmErrorKind::Utf8Error)
396    }
397}
398
399impl Drop for BlocklessLlm {
400    fn drop(&mut self) {
401        let code = unsafe { llm_close(self.inner) };
402        if code != 0 {
403            eprintln!("Error closing LLM: {}", code);
404        }
405    }
406}
407
/// Errors reported by the LLM bindings.
///
/// Each variant corresponds to a non-zero host exit code (noted per variant);
/// some are also produced locally by this module.
///
/// `Clone`, `Copy`, `PartialEq` and `Eq` are derived so callers can compare
/// errors (`err == LlmErrorKind::Utf8Error`) and pass them around freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LlmErrorKind {
    /// Exit code 1. Also returned when the host reports a different model
    /// name than the one that was requested.
    ModelNotSet,
    /// Exit code 2.
    ModelNotSupported,
    /// Exit code 3.
    ModelInitializationFailed,
    /// Exit code 4.
    ModelCompletionFailed,
    /// Exit code 5. Also returned when options cannot be parsed or do not
    /// match what the host reports back.
    ModelOptionsNotSet,
    /// Exit code 6.
    ModelShutdownFailed,
    /// Exit code 7. Also returned when bytes from the host are not valid UTF-8.
    Utf8Error,
    /// Exit code 8, and the fallback for any exit code without its own variant.
    RuntimeError,
    /// Exit code 9.
    MCPFunctionCallError,
}
420
421impl From<u8> for LlmErrorKind {
422    fn from(code: u8) -> Self {
423        match code {
424            1 => LlmErrorKind::ModelNotSet,
425            2 => LlmErrorKind::ModelNotSupported,
426            3 => LlmErrorKind::ModelInitializationFailed,
427            4 => LlmErrorKind::ModelCompletionFailed,
428            5 => LlmErrorKind::ModelOptionsNotSet,
429            6 => LlmErrorKind::ModelShutdownFailed,
430            7 => LlmErrorKind::Utf8Error,
431            // 8 => LlmErrorKind::RuntimeError,
432            9 => LlmErrorKind::MCPFunctionCallError,
433            _ => LlmErrorKind::RuntimeError,
434        }
435    }
436}