//! llama_cpp_v3/lib.rs — safe wrapper types over the raw `llama_cpp_sys_v3` FFI bindings.
1use llama_cpp_sys_v3::{LlamaLib, LoadError};
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5pub mod backend;
6pub mod downloader;
7
8pub use backend::Backend;
9
/// Errors produced by this crate's backend loading, model management,
/// and inference entry points.
#[derive(Debug, thiserror::Error)]
pub enum LlamaError {
    /// The backend shared library could not be opened or its symbols resolved.
    #[error("Failed to load DLL: {0}")]
    DllLoad(#[from] LoadError),
    /// Downloading the backend DLL from GitHub releases failed.
    #[error("Failed to download backend: {0}")]
    Download(#[from] downloader::DownloadError),
    /// Backend initialization failed.
    #[error("Failed to initialize backend")]
    BackendInit,
    /// Loading a GGUF model failed: bad path (e.g. interior NUL byte) or the
    /// loader returned a null handle.
    #[error("Failed to load model from file")]
    ModelLoad,
    /// `llama_init_from_model` returned a null context pointer.
    #[error("Failed to create context")]
    ContextCreate,
    /// A llama.cpp call returned a failing status; the raw code is preserved.
    #[error("Decode error with status code {0}")]
    Decode(i32),
    /// A chat template was required but none was provided and the model
    /// carries none (or it is blank).
    #[error("Missing or empty chat template")]
    MissingChatTemplate,
}
27
/// Options controlling how [`LlamaBackend::load`] resolves the backend DLL.
pub struct LoadOptions<'a> {
    /// Which compute backend variant to load or download.
    pub backend: Backend,
    /// Application name passed to the downloader — presumably used to
    /// namespace the download cache; confirm in `downloader`.
    pub app_name: &'a str,
    pub version: Option<&'a str>,        // None = latest
    pub explicit_path: Option<&'a Path>, // Absolute bypass of all resolution
    pub cache_dir: Option<PathBuf>,      // Exact directory to save/load downloaded DLLs
}
35
/// The initialized Llama capabilities backend.
/// Holds the DLL handle alive.
///
/// Clones of the inner [`Arc`] are handed to models, contexts, samplers and
/// batches so the library outlives every object that calls into it.
pub struct LlamaBackend {
    pub lib: Arc<LlamaLib>,
}
41
impl Drop for LlamaBackend {
    fn drop(&mut self) {
        // Technically backend free is global in llama.cpp
        // We only free if we are the very last reference to the library object.
        // NOTE(review): if models/contexts still hold Arc clones when this
        // drops, strong_count > 1 here and no other holder performs the free,
        // so `llama_backend_free` is never called — confirm this leak is
        // intentional. The count check is also not atomic with the call below,
        // so concurrent cloning of `lib` could race it.
        if Arc::strong_count(&self.lib) == 1 {
            // SAFETY: the function pointer was resolved when the DLL was
            // opened and the DLL is still loaded (we hold the last Arc).
            unsafe {
                (self.lib.symbols.llama_backend_free)();
            }
        }
    }
}
53
54impl LlamaBackend {
55    /// Load the specified backend DLL, downloading it from GitHub releases if necessary.
56    pub fn load(options: LoadOptions<'_>) -> Result<Self, LlamaError> {
57        let dll_path = if let Some(path) = options.explicit_path {
58            path.to_path_buf()
59        } else if let Ok(env_path) = std::env::var("LLAMA_DLL_PATH") {
60            PathBuf::from(env_path)
61        } else {
62            // Attempt auto-download
63            downloader::Downloader::ensure_dll(
64                options.backend,
65                options.app_name,
66                options.version,
67                options.cache_dir,
68            )?
69        };
70
71        if let Some(parent) = dll_path.parent() {
72            if let Some(path_ext) = std::env::var_os("PATH") {
73                let mut paths = std::env::split_paths(&path_ext).collect::<Vec<_>>();
74                let parent_buf = parent.to_path_buf();
75                if !paths.contains(&parent_buf) {
76                    paths.insert(0, parent_buf);
77                    if let Ok(new_path) = std::env::join_paths(paths) {
78                        std::env::set_var("PATH", new_path);
79                    }
80                }
81            }
82        }
83
84        let lib = LlamaLib::open(&dll_path)?;
85
86        if let Some(parent) = dll_path.parent() {
87            let parent_str = parent.to_string_lossy().to_string();
88            let c_parent = std::ffi::CString::new(parent_str).unwrap();
89            unsafe {
90                (lib.symbols.ggml_backend_load_all_from_path)(c_parent.as_ptr());
91            }
92        } else {
93            unsafe {
94                (lib.symbols.ggml_backend_load_all)();
95            }
96        }
97
98        unsafe {
99            (lib.symbols.llama_backend_init)();
100        }
101
102        Ok(Self { lib: Arc::new(lib) })
103    }
104}
105
106/// A loaded GGUF model
/// A loaded GGUF model
pub struct LlamaModel {
    /// Keeps the backing DLL alive for as long as the model exists.
    pub backend: Arc<LlamaLib>,
    /// Raw model handle owned by this struct; freed in `Drop`.
    pub handle: *mut llama_cpp_sys_v3::llama_model,
}
111
impl Drop for LlamaModel {
    fn drop(&mut self) {
        // SAFETY: `handle` was obtained from `llama_model_load_from_file`,
        // is non-null (checked at construction), and is freed exactly once.
        unsafe {
            (self.backend.symbols.llama_model_free)(self.handle);
        }
    }
}
119
// SAFETY: `handle` is an owned pointer freed exactly once in `Drop`, and the
// `Arc<LlamaLib>` keeps the DLL loaded for the model's lifetime.
// NOTE(review): this additionally asserts that llama.cpp model handles may be
// shared and used from multiple threads — confirm against upstream
// thread-safety guarantees for `llama_model`.
unsafe impl Send for LlamaModel {}
unsafe impl Sync for LlamaModel {}
122
123impl LlamaModel {
124    pub fn load_from_file(
125        backend: &LlamaBackend,
126        path: &str,
127        params: llama_cpp_sys_v3::llama_model_params,
128    ) -> Result<Self, LlamaError> {
129        let c_path = std::ffi::CString::new(path).map_err(|_| LlamaError::ModelLoad)?;
130        let handle =
131            unsafe { (backend.lib.symbols.llama_model_load_from_file)(c_path.as_ptr(), params) };
132
133        if handle.is_null() {
134            return Err(LlamaError::ModelLoad);
135        }
136
137        Ok(Self {
138            backend: backend.lib.clone(),
139            handle,
140        })
141    }
142
    /// Default model parameters as reported by the loaded library.
    pub fn default_params(backend: &LlamaBackend) -> llama_cpp_sys_v3::llama_model_params {
        unsafe { (backend.lib.symbols.llama_model_default_params)() }
    }
146
    /// Borrow the model's vocabulary as a [`LlamaVocab`] wrapper.
    ///
    /// The returned handle is owned by the model on the C side; the wrapper
    /// only keeps the DLL alive via its `Arc` and defines no `Drop`.
    pub fn get_vocab(&self) -> LlamaVocab {
        let handle = unsafe { (self.backend.symbols.llama_model_get_vocab)(self.handle) };
        LlamaVocab {
            backend: self.backend.clone(),
            handle,
        }
    }
154
155    pub fn tokenize(
156        &self,
157        text: &str,
158        add_special: bool,
159        parse_special: bool,
160    ) -> Result<Vec<llama_cpp_sys_v3::llama_token>, LlamaError> {
161        let vocab = self.get_vocab();
162        let c_text = std::ffi::CString::new(text).map_err(|_| LlamaError::ModelLoad)?;
163
164        // First call to get required size
165        let n_tokens = unsafe {
166            (self.backend.symbols.llama_tokenize)(
167                vocab.handle,
168                c_text.as_ptr(),
169                text.len() as i32,
170                std::ptr::null_mut(),
171                0,
172                add_special,
173                parse_special,
174            )
175        };
176
177        if n_tokens < 0 {
178            let mut tokens = vec![0; (-n_tokens) as usize];
179            let actual_tokens = unsafe {
180                (self.backend.symbols.llama_tokenize)(
181                    vocab.handle,
182                    c_text.as_ptr(),
183                    text.len() as i32,
184                    tokens.as_mut_ptr(),
185                    tokens.len() as i32,
186                    add_special,
187                    parse_special,
188                )
189            };
190            if actual_tokens < 0 {
191                return Err(LlamaError::Decode(actual_tokens));
192            }
193            tokens.truncate(actual_tokens as usize);
194            Ok(tokens)
195        } else {
196            let mut tokens = vec![0; n_tokens as usize];
197            let actual_tokens = unsafe {
198                (self.backend.symbols.llama_tokenize)(
199                    vocab.handle,
200                    c_text.as_ptr(),
201                    text.len() as i32,
202                    tokens.as_mut_ptr(),
203                    tokens.len() as i32,
204                    add_special,
205                    parse_special,
206                )
207            };
208            if actual_tokens < 0 {
209                return Err(LlamaError::Decode(actual_tokens));
210            }
211            tokens.truncate(actual_tokens as usize);
212            Ok(tokens)
213        }
214    }
215
216    pub fn token_to_piece(&self, token: llama_cpp_sys_v3::llama_token) -> String {
217        let vocab = self.get_vocab();
218        let mut buf = vec![0u8; 128];
219        let n = unsafe {
220            (self.backend.symbols.llama_token_to_piece)(
221                vocab.handle,
222                token,
223                buf.as_mut_ptr() as *mut std::ffi::c_char,
224                buf.len() as i32,
225                0,
226                true,
227            )
228        };
229
230        if n < 0 {
231            buf.resize((-n) as usize, 0);
232            unsafe {
233                (self.backend.symbols.llama_token_to_piece)(
234                    vocab.handle,
235                    token,
236                    buf.as_mut_ptr() as *mut std::ffi::c_char,
237                    buf.len() as i32,
238                    0,
239                    true,
240                );
241            }
242        } else {
243            buf.truncate(n as usize);
244        }
245
246        String::from_utf8_lossy(&buf).to_string()
247    }
248
    /// Render `messages` through a llama.cpp chat template.
    ///
    /// `tmpl` overrides the template; `None` falls back to the model's own
    /// template from `get_chat_template`. `add_ass` is forwarded to
    /// `llama_chat_apply_template` (upstream semantics: append the assistant
    /// generation prompt).
    ///
    /// # Errors
    /// * [`LlamaError::MissingChatTemplate`] — no template available or blank.
    /// * [`LlamaError::ModelLoad`] — a template/role/content string contains
    ///   an interior NUL byte.
    /// * [`LlamaError::Decode`] — the template engine returned a negative size.
    pub fn apply_chat_template(
        &self,
        tmpl: Option<&str>,
        messages: &[ChatMessage],
        add_ass: bool,
    ) -> Result<String, LlamaError> {
        let resolved_tmpl = match tmpl {
            Some(s) => s.to_string(),
            None => self
                .get_chat_template(None)
                .ok_or(LlamaError::MissingChatTemplate)?,
        };

        if resolved_tmpl.trim().is_empty() {
            return Err(LlamaError::MissingChatTemplate);
        }

        let c_tmpl = std::ffi::CString::new(resolved_tmpl).map_err(|_| LlamaError::ModelLoad)?;

        // `c_messages` stores raw pointers into the CStrings pushed onto
        // `c_strings`; keeping `c_strings` alive until after both FFI calls is
        // what keeps those pointers valid — do not drop or shrink it early.
        let mut c_messages = Vec::with_capacity(messages.len());
        let mut c_strings = Vec::with_capacity(messages.len() * 2);

        for msg in messages {
            let role =
                std::ffi::CString::new(msg.role.as_str()).map_err(|_| LlamaError::ModelLoad)?;
            let content =
                std::ffi::CString::new(msg.content.as_str()).map_err(|_| LlamaError::ModelLoad)?;

            let msg_struct = llama_cpp_sys_v3::llama_chat_message {
                role: role.as_ptr(),
                content: content.as_ptr(),
            };

            c_messages.push(msg_struct);
            c_strings.push(role);
            c_strings.push(content);
        }

        // First call to get required size
        let n_chars = unsafe {
            (self.backend.symbols.llama_chat_apply_template)(
                c_tmpl.as_ptr(),
                c_messages.as_ptr(),
                c_messages.len(),
                add_ass,
                std::ptr::null_mut(),
                0,
            )
        };

        if n_chars < 0 {
            return Err(LlamaError::Decode(n_chars));
        }

        // +1 leaves room for a trailing NUL; the excess is truncated below.
        let mut buf = vec![0u8; n_chars as usize + 1];
        let actual_chars = unsafe {
            (self.backend.symbols.llama_chat_apply_template)(
                c_tmpl.as_ptr(),
                c_messages.as_ptr(),
                c_messages.len(),
                add_ass,
                buf.as_mut_ptr() as *mut std::ffi::c_char,
                buf.len() as i32,
            )
        };

        if actual_chars < 0 {
            return Err(LlamaError::Decode(actual_chars));
        }

        buf.truncate(actual_chars as usize);
        Ok(String::from_utf8_lossy(&buf).to_string())
    }
322
323    pub fn get_chat_template(&self, name: Option<&str>) -> Option<String> {
324        let c_name = name.map(|s| std::ffi::CString::new(s).ok()).flatten();
325        let name_ptr = c_name
326            .as_ref()
327            .map(|c| c.as_ptr())
328            .unwrap_or(std::ptr::null());
329
330        let mut buf = vec![0u8; 1024];
331        let n = unsafe {
332            (self.backend.symbols.llama_model_chat_template)(
333                self.handle,
334                name_ptr,
335                buf.as_mut_ptr() as *mut std::ffi::c_char,
336                buf.len(),
337            )
338        };
339
340        if n < 0 {
341            return None;
342        }
343
344        if n as usize >= buf.len() {
345            buf.resize(n as usize + 1, 0);
346            unsafe {
347                (self.backend.symbols.llama_model_chat_template)(
348                    self.handle,
349                    name_ptr,
350                    buf.as_mut_ptr() as *mut std::ffi::c_char,
351                    buf.len(),
352                );
353            }
354        }
355
356        buf.truncate(n as usize);
357        Some(String::from_utf8_lossy(&buf).to_string())
358    }
359}
360
/// A single role/content pair fed to chat-template rendering.
///
/// Derives were added (backward compatible) so messages can be logged,
/// duplicated, and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChatMessage {
    /// Speaker role, e.g. "system", "user", or "assistant".
    pub role: String,
    /// Message body text.
    pub content: String,
}
365
/// Borrowed view of a model's vocabulary.
pub struct LlamaVocab {
    /// Keeps the backing DLL alive while the vocab handle is in use.
    pub backend: Arc<LlamaLib>,
    /// Const handle owned by the model on the C side; no `Drop` is defined.
    pub handle: *const llama_cpp_sys_v3::llama_vocab,
}
370
impl LlamaVocab {
    /// Beginning-of-sequence token id.
    pub fn bos(&self) -> llama_cpp_sys_v3::llama_token {
        unsafe { (self.backend.symbols.llama_vocab_bos)(self.handle) }
    }

    /// End-of-sequence token id.
    pub fn eos(&self) -> llama_cpp_sys_v3::llama_token {
        unsafe { (self.backend.symbols.llama_vocab_eos)(self.handle) }
    }

    /// Whether `token` marks end of generation (EOS/EOT and similar).
    pub fn is_eog(&self, token: llama_cpp_sys_v3::llama_token) -> bool {
        unsafe { (self.backend.symbols.llama_vocab_is_eog)(self.handle, token) }
    }
}
384
/// Owned llama.cpp sampler; freed in `Drop`.
pub struct LlamaSampler {
    /// Keeps the backing DLL alive for the sampler's lifetime.
    pub backend: Arc<LlamaLib>,
    /// Raw sampler handle owned by this struct.
    pub handle: *mut llama_cpp_sys_v3::llama_sampler,
}
389
impl Drop for LlamaSampler {
    fn drop(&mut self) {
        // SAFETY: `handle` came from `llama_sampler_init_greedy` and is freed
        // exactly once here.
        unsafe {
            (self.backend.symbols.llama_sampler_free)(self.handle);
        }
    }
}
397
impl LlamaSampler {
    /// Create a greedy (argmax) sampler.
    pub fn new_greedy(backend: Arc<LlamaLib>) -> Self {
        let handle = unsafe { (backend.symbols.llama_sampler_init_greedy)() };
        Self { backend, handle }
    }

    /// Sample the next token from the logits at output index `idx` of `ctx`
    /// (upstream convention: -1 means the last decoded position — confirm
    /// against the pinned llama.cpp version).
    pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> llama_cpp_sys_v3::llama_token {
        unsafe { (self.backend.symbols.llama_sampler_sample)(self.handle, ctx.handle, idx) }
    }
}
408
/// Inference context attached to a model
pub struct LlamaContext {
    /// Keeps the backing DLL alive for the context's lifetime.
    pub backend: Arc<LlamaLib>,
    /// Raw context handle owned by this struct; freed in `Drop`.
    pub handle: *mut llama_cpp_sys_v3::llama_context,
}
414
impl Drop for LlamaContext {
    fn drop(&mut self) {
        // SAFETY: `handle` came from `llama_init_from_model`, is non-null
        // (checked at construction), and is freed exactly once.
        unsafe {
            (self.backend.symbols.llama_free)(self.handle);
        }
    }
}
422
423impl LlamaContext {
424    pub fn new(
425        model: &LlamaModel,
426        params: llama_cpp_sys_v3::llama_context_params,
427    ) -> Result<Self, LlamaError> {
428        let handle = unsafe { (model.backend.symbols.llama_init_from_model)(model.handle, params) };
429
430        if handle.is_null() {
431            return Err(LlamaError::ContextCreate);
432        }
433
434        Ok(Self {
435            backend: model.backend.clone(),
436            handle,
437        })
438    }
439
440    pub fn default_params(model: &LlamaModel) -> llama_cpp_sys_v3::llama_context_params {
441        unsafe { (model.backend.symbols.llama_context_default_params)() }
442    }
443
444    pub fn decode(&mut self, batch: &LlamaBatch) -> Result<(), LlamaError> {
445        let res = unsafe { (self.backend.symbols.llama_decode)(self.handle, batch.handle) };
446        if res != 0 {
447            Err(LlamaError::Decode(res))
448        } else {
449            Ok(())
450        }
451    }
452
453    /// Clear the KV cache for this context.
454    /// Resets all cached key/value state, allowing the context to be reused
455    /// for a fresh generation without reallocating.
456    pub fn kv_cache_clear(&mut self) {
457        unsafe { (self.backend.symbols.llama_kv_cache_clear)(self.handle) }
458    }
459}
460
/// Owned token batch for decoding; freed in `Drop` via `llama_batch_free`.
pub struct LlamaBatch {
    /// Keeps the backing DLL alive for the batch's lifetime.
    pub backend: Arc<LlamaLib>,
    /// Raw batch value (by-value struct with interior pointers owned here).
    pub handle: llama_cpp_sys_v3::llama_batch,
}
465
impl Drop for LlamaBatch {
    fn drop(&mut self) {
        // SAFETY: `handle` came from `llama_batch_init` and its buffers are
        // released exactly once here.
        unsafe {
            (self.backend.symbols.llama_batch_free)(self.handle);
        }
    }
}
473
impl LlamaBatch {
    /// Allocate a batch via `llama_batch_init` with capacity for `n_tokens`
    /// entries and up to `n_seq_max` sequence ids per entry. `embd` is passed
    /// through (upstream: 0 selects token-id mode — confirm against the
    /// pinned llama.cpp version).
    pub fn new(backend: Arc<LlamaLib>, n_tokens: i32, embd: i32, n_seq_max: i32) -> Self {
        let handle = unsafe { (backend.symbols.llama_batch_init)(n_tokens, embd, n_seq_max) };
        Self { backend, handle }
    }

    /// Reset the batch to empty without releasing its buffers.
    pub fn clear(&mut self) {
        self.handle.n_tokens = 0;
    }

    /// Append one token at position `pos` for the given sequence ids;
    /// `logits` marks whether logits should be produced for this slot.
    ///
    /// NOTE(review): there is no bounds check against the capacity given to
    /// `llama_batch_init`; writing more than that many entries (or passing
    /// more seq_ids than `n_seq_max`) is undefined behavior — callers must
    /// not overfill. Consider adding an assert.
    pub fn add(
        &mut self,
        token: llama_cpp_sys_v3::llama_token,
        pos: llama_cpp_sys_v3::llama_pos,
        seq_ids: &[i32],
        logits: bool,
    ) {
        let n = self.handle.n_tokens as usize;
        // SAFETY: writes slot `n` of the buffers allocated by
        // `llama_batch_init`; sound only while `n` is within the allocated
        // capacity and `seq_ids.len()` is within `n_seq_max` (see NOTE above).
        unsafe {
            *self.handle.token.add(n) = token;
            *self.handle.pos.add(n) = pos;
            *self.handle.n_seq_id.add(n) = seq_ids.len() as i32;
            for (j, &seq_id) in seq_ids.iter().enumerate() {
                *(*self.handle.seq_id.add(n)).add(j) = seq_id;
            }
            *self.handle.logits.add(n) = if logits { 1 } else { 0 };
        }
        self.handle.n_tokens += 1;
    }
}
503}