Skip to main content

llama_cpp_v3/
lib.rs

1use llama_cpp_sys_v3::{LlamaLib, LoadError};
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5pub mod backend;
6pub mod downloader;
7
8pub use backend::Backend;
9
10#[derive(Debug, thiserror::Error)]
11pub enum LlamaError {
12    #[error("Failed to load DLL: {0}")]
13    DllLoad(#[from] LoadError),
14    #[error("Failed to download backend: {0}")]
15    Download(#[from] downloader::DownloadError),
16    #[error("Failed to initialize backend")]
17    BackendInit,
18    #[error("Failed to load model from file")]
19    ModelLoad,
20    #[error("Failed to create context")]
21    ContextCreate,
22    #[error("Decode error with status code {0}")]
23    Decode(i32),
24    #[error("Missing or empty chat template")]
25    MissingChatTemplate,
26    #[error("Invalid string (contains internal null byte)")]
27    InvalidString,
28    #[error("Failed to apply chat template (check template syntax)")]
29    TemplateApply,
30}
31
32pub struct LoadOptions<'a> {
33    pub backend: Backend,
34    pub app_name: &'a str,
35    pub version: Option<&'a str>,        // None = latest
36    pub explicit_path: Option<&'a Path>, // Absolute bypass of all resolution
37    pub cache_dir: Option<PathBuf>,      // Exact directory to save/load downloaded DLLs
38}
39
40/// The initialized Llama capabilities backend.
41/// Holds the DLL handle alive.
42pub struct LlamaBackend {
43    pub lib: Arc<LlamaLib>,
44}
45
46impl Drop for LlamaBackend {
47    fn drop(&mut self) {
48        // Technically backend free is global in llama.cpp
49        // We only free if we are the very last reference to the library object.
50        if Arc::strong_count(&self.lib) == 1 {
51            unsafe {
52                (self.lib.symbols.llama_backend_free)();
53            }
54        }
55    }
56}
57
58impl LlamaBackend {
59    /// Load the specified backend DLL, downloading it from GitHub releases if necessary.
60    pub fn load(options: LoadOptions<'_>) -> Result<Self, LlamaError> {
61        let dll_path = if let Some(path) = options.explicit_path {
62            path.to_path_buf()
63        } else if let Ok(env_path) = std::env::var("LLAMA_DLL_PATH") {
64            PathBuf::from(env_path)
65        } else {
66            // Attempt auto-download
67            downloader::Downloader::ensure_dll(
68                options.backend,
69                options.app_name,
70                options.version,
71                options.cache_dir,
72            )?
73        };
74
75        if let Some(parent) = dll_path.parent() {
76            if let Some(path_ext) = std::env::var_os("PATH") {
77                let mut paths = std::env::split_paths(&path_ext).collect::<Vec<_>>();
78                let parent_buf = parent.to_path_buf();
79                if !paths.contains(&parent_buf) {
80                    paths.insert(0, parent_buf);
81                    if let Ok(new_path) = std::env::join_paths(paths) {
82                        std::env::set_var("PATH", new_path);
83                    }
84                }
85            }
86        }
87
88        let lib = LlamaLib::open(&dll_path)?;
89
90        if let Some(parent) = dll_path.parent() {
91            let parent_str = parent.to_string_lossy().to_string();
92            let c_parent = std::ffi::CString::new(parent_str).unwrap();
93            unsafe {
94                (lib.symbols.ggml_backend_load_all_from_path)(c_parent.as_ptr());
95            }
96        } else {
97            unsafe {
98                (lib.symbols.ggml_backend_load_all)();
99            }
100        }
101
102        unsafe {
103            (lib.symbols.llama_backend_init)();
104        }
105
106        Ok(Self { lib: Arc::new(lib) })
107    }
108}
109
110/// A loaded GGUF model
111pub struct LlamaModel {
112    pub backend: Arc<LlamaLib>,
113    pub handle: *mut llama_cpp_sys_v3::llama_model,
114}
115
116impl Drop for LlamaModel {
117    fn drop(&mut self) {
118        unsafe {
119            (self.backend.symbols.llama_model_free)(self.handle);
120        }
121    }
122}
123
124unsafe impl Send for LlamaModel {}
125unsafe impl Sync for LlamaModel {}
126
127impl LlamaModel {
128    pub fn load_from_file(
129        backend: &LlamaBackend,
130        path: &str,
131        params: llama_cpp_sys_v3::llama_model_params,
132    ) -> Result<Self, LlamaError> {
133        let c_path = std::ffi::CString::new(path).map_err(|_| LlamaError::InvalidString)?;
134        let handle =
135            unsafe { (backend.lib.symbols.llama_model_load_from_file)(c_path.as_ptr(), params) };
136
137        if handle.is_null() {
138            return Err(LlamaError::ModelLoad);
139        }
140
141        Ok(Self {
142            backend: backend.lib.clone(),
143            handle,
144        })
145    }
146
147    pub fn default_params(backend: &LlamaBackend) -> llama_cpp_sys_v3::llama_model_params {
148        unsafe { (backend.lib.symbols.llama_model_default_params)() }
149    }
150
151    pub fn get_vocab(&self) -> LlamaVocab {
152        let handle = unsafe { (self.backend.symbols.llama_model_get_vocab)(self.handle) };
153        LlamaVocab {
154            backend: self.backend.clone(),
155            handle,
156        }
157    }
158
159    pub fn tokenize(
160        &self,
161        text: &str,
162        add_special: bool,
163        parse_special: bool,
164    ) -> Result<Vec<llama_cpp_sys_v3::llama_token>, LlamaError> {
165        let vocab = self.get_vocab();
166        let c_text = std::ffi::CString::new(text).map_err(|_| LlamaError::InvalidString)?;
167
168        // First call to get required size
169        let n_tokens = unsafe {
170            (self.backend.symbols.llama_tokenize)(
171                vocab.handle,
172                c_text.as_ptr(),
173                text.len() as i32,
174                std::ptr::null_mut(),
175                0,
176                add_special,
177                parse_special,
178            )
179        };
180
181        if n_tokens < 0 {
182            let mut tokens = vec![0; (-n_tokens) as usize];
183            let actual_tokens = unsafe {
184                (self.backend.symbols.llama_tokenize)(
185                    vocab.handle,
186                    c_text.as_ptr(),
187                    text.len() as i32,
188                    tokens.as_mut_ptr(),
189                    tokens.len() as i32,
190                    add_special,
191                    parse_special,
192                )
193            };
194            if actual_tokens < 0 {
195                return Err(LlamaError::Decode(actual_tokens));
196            }
197            tokens.truncate(actual_tokens as usize);
198            Ok(tokens)
199        } else {
200            let mut tokens = vec![0; n_tokens as usize];
201            let actual_tokens = unsafe {
202                (self.backend.symbols.llama_tokenize)(
203                    vocab.handle,
204                    c_text.as_ptr(),
205                    text.len() as i32,
206                    tokens.as_mut_ptr(),
207                    tokens.len() as i32,
208                    add_special,
209                    parse_special,
210                )
211            };
212            if actual_tokens < 0 {
213                return Err(LlamaError::Decode(actual_tokens));
214            }
215            tokens.truncate(actual_tokens as usize);
216            Ok(tokens)
217        }
218    }
219
220    pub fn token_to_piece(&self, token: llama_cpp_sys_v3::llama_token) -> String {
221        let vocab = self.get_vocab();
222        let mut buf = vec![0u8; 128];
223        let n = unsafe {
224            (self.backend.symbols.llama_token_to_piece)(
225                vocab.handle,
226                token,
227                buf.as_mut_ptr() as *mut std::ffi::c_char,
228                buf.len() as i32,
229                0,
230                true,
231            )
232        };
233
234        if n < 0 {
235            buf.resize((-n) as usize, 0);
236            unsafe {
237                (self.backend.symbols.llama_token_to_piece)(
238                    vocab.handle,
239                    token,
240                    buf.as_mut_ptr() as *mut std::ffi::c_char,
241                    buf.len() as i32,
242                    0,
243                    true,
244                );
245            }
246        } else {
247            buf.truncate(n as usize);
248        }
249
250        String::from_utf8_lossy(&buf).to_string()
251    }
252
253    pub fn apply_chat_template(
254        &self,
255        tmpl: Option<&str>,
256        messages: &[ChatMessage],
257        add_ass: bool,
258    ) -> Result<String, LlamaError> {
259        let resolved_tmpl = match tmpl {
260            Some(s) => s.to_string(),
261            None => self
262                .get_chat_template(None)
263                .ok_or(LlamaError::MissingChatTemplate)?,
264        };
265
266        if resolved_tmpl.trim().is_empty() {
267            return Err(LlamaError::MissingChatTemplate);
268        }
269
270        let c_tmpl = std::ffi::CString::new(resolved_tmpl).map_err(|_| LlamaError::InvalidString)?;
271
272        let mut c_messages = Vec::with_capacity(messages.len());
273        let mut c_strings = Vec::with_capacity(messages.len() * 2);
274
275        for msg in messages {
276            let role =
277                std::ffi::CString::new(msg.role.as_str()).map_err(|_| LlamaError::InvalidString)?;
278            let content =
279                std::ffi::CString::new(msg.content.as_str()).map_err(|_| LlamaError::InvalidString)?;
280
281            let msg_struct = llama_cpp_sys_v3::llama_chat_message {
282                role: role.as_ptr(),
283                content: content.as_ptr(),
284            };
285
286            c_messages.push(msg_struct);
287            c_strings.push(role);
288            c_strings.push(content);
289        }
290
291        // First call to get required size
292        let n_chars = unsafe {
293            (self.backend.symbols.llama_chat_apply_template)(
294                c_tmpl.as_ptr(),
295                c_messages.as_ptr(),
296                c_messages.len(),
297                add_ass,
298                std::ptr::null_mut(),
299                0,
300            )
301        };
302
303        if n_chars < 0 {
304            return Err(LlamaError::Decode(n_chars));
305        }
306
307        let mut buf = vec![0u8; n_chars as usize + 1];
308        let actual_chars = unsafe {
309            (self.backend.symbols.llama_chat_apply_template)(
310                c_tmpl.as_ptr(),
311                c_messages.as_ptr(),
312                c_messages.len(),
313                add_ass,
314                buf.as_mut_ptr() as *mut std::ffi::c_char,
315                buf.len() as i32,
316            )
317        };
318
319        if actual_chars < 0 {
320            return Err(LlamaError::Decode(actual_chars));
321        }
322
323        buf.truncate(actual_chars as usize);
324        Ok(String::from_utf8_lossy(&buf).to_string())
325    }
326
327    pub fn get_chat_template(&self, name: Option<&str>) -> Option<String> {
328        let c_name = name.map(|s| std::ffi::CString::new(s).ok()).flatten();
329        let name_ptr = c_name
330            .as_ref()
331            .map(|c| c.as_ptr())
332            .unwrap_or(std::ptr::null());
333
334        let mut buf = vec![0u8; 1024];
335        let n = unsafe {
336            (self.backend.symbols.llama_model_chat_template)(
337                self.handle,
338                name_ptr,
339                buf.as_mut_ptr() as *mut std::ffi::c_char,
340                buf.len(),
341            )
342        };
343
344        if n < 0 {
345            return None;
346        }
347
348        if n as usize >= buf.len() {
349            buf.resize(n as usize + 1, 0);
350            unsafe {
351                (self.backend.symbols.llama_model_chat_template)(
352                    self.handle,
353                    name_ptr,
354                    buf.as_mut_ptr() as *mut std::ffi::c_char,
355                    buf.len(),
356                );
357            }
358        }
359
360        buf.truncate(n as usize);
361        Some(String::from_utf8_lossy(&buf).to_string())
362    }
363}
364
/// One message of a chat conversation, as consumed by `LlamaModel::apply_chat_template`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub struct ChatMessage {
    /// Speaker role understood by the chat template (typically `"system"`,
    /// `"user"` or `"assistant"`; the set accepted depends on the template).
    pub role: String,
    /// Message text.
    pub content: String,
}

impl ChatMessage {
    /// Build a message from anything convertible into `String`.
    pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            role: role.into(),
            content: content.into(),
        }
    }
}
369
370pub struct LlamaVocab {
371    pub backend: Arc<LlamaLib>,
372    pub handle: *const llama_cpp_sys_v3::llama_vocab,
373}
374
375impl LlamaVocab {
376    pub fn bos(&self) -> llama_cpp_sys_v3::llama_token {
377        unsafe { (self.backend.symbols.llama_vocab_bos)(self.handle) }
378    }
379
380    pub fn eos(&self) -> llama_cpp_sys_v3::llama_token {
381        unsafe { (self.backend.symbols.llama_vocab_eos)(self.handle) }
382    }
383
384    pub fn is_eog(&self, token: llama_cpp_sys_v3::llama_token) -> bool {
385        unsafe { (self.backend.symbols.llama_vocab_is_eog)(self.handle, token) }
386    }
387}
388
389pub struct LlamaSampler {
390    pub backend: Arc<LlamaLib>,
391    pub handle: *mut llama_cpp_sys_v3::llama_sampler,
392}
393
394impl Drop for LlamaSampler {
395    fn drop(&mut self) {
396        unsafe {
397            (self.backend.symbols.llama_sampler_free)(self.handle);
398        }
399    }
400}
401
402impl LlamaSampler {
403    pub fn new_chain(backend: Arc<LlamaLib>, no_perf: bool) -> Self {
404        let params = llama_cpp_sys_v3::llama_sampler_chain_params { no_perf };
405        let handle = unsafe { (backend.symbols.llama_sampler_chain_init)(params) };
406        Self { backend, handle }
407    }
408
409    pub fn new_greedy(backend: Arc<LlamaLib>) -> Self {
410        let handle = unsafe { (backend.symbols.llama_sampler_init_greedy)() };
411        Self { backend, handle }
412    }
413
414    pub fn new_temp(backend: Arc<LlamaLib>, temp: f32) -> Self {
415        let handle = unsafe { (backend.symbols.llama_sampler_init_temp)(temp) };
416        Self { backend, handle }
417    }
418
419    pub fn new_top_k(backend: Arc<LlamaLib>, k: i32) -> Self {
420        let handle = unsafe { (backend.symbols.llama_sampler_init_top_k)(k) };
421        Self { backend, handle }
422    }
423
424    pub fn new_top_p(backend: Arc<LlamaLib>, p: f32, min_keep: usize) -> Self {
425        let handle = unsafe { (backend.symbols.llama_sampler_init_top_p)(p, min_keep) };
426        Self { backend, handle }
427    }
428
429    pub fn new_min_p(backend: Arc<LlamaLib>, p: f32, min_keep: usize) -> Self {
430        let handle = unsafe { (backend.symbols.llama_sampler_init_min_p)(p, min_keep) };
431        Self { backend, handle }
432    }
433
434    pub fn new_typical(backend: Arc<LlamaLib>, p: f32, min_keep: usize) -> Self {
435        let handle = unsafe { (backend.symbols.llama_sampler_init_typical)(p, min_keep) };
436        Self { backend, handle }
437    }
438
439    pub fn new_mirostat_v2(backend: Arc<LlamaLib>, seed: u32, tau: f32, eta: f32) -> Self {
440        let handle = unsafe { (backend.symbols.llama_sampler_init_mirostat_v2)(seed, tau, eta) };
441        Self { backend, handle }
442    }
443
444    pub fn new_penalties(
445        backend: Arc<LlamaLib>,
446        last_n: i32,
447        repeat: f32,
448        freq: f32,
449        present: f32,
450    ) -> Self {
451        let handle = unsafe {
452            (backend.symbols.llama_sampler_init_penalties)(last_n, repeat, freq, present)
453        };
454        Self { backend, handle }
455    }
456
457    pub fn new_dist(backend: Arc<LlamaLib>, seed: u32) -> Self {
458        let handle = unsafe { (backend.symbols.llama_sampler_init_dist)(seed) };
459        Self { backend, handle }
460    }
461
462    pub fn add(&mut self, other: LlamaSampler) {
463        unsafe {
464            (self.backend.symbols.llama_sampler_chain_add)(self.handle, other.handle);
465        }
466        std::mem::forget(other);
467    }
468
469    pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> llama_cpp_sys_v3::llama_token {
470        unsafe { (self.backend.symbols.llama_sampler_sample)(self.handle, ctx.handle, idx) }
471    }
472
473    pub fn accept(&self, token: llama_cpp_sys_v3::llama_token) {
474        unsafe {
475            (self.backend.symbols.llama_sampler_accept)(self.handle, token);
476        }
477    }
478}
479
480/// Inference context attached to a model
481pub struct LlamaContext {
482    pub backend: Arc<LlamaLib>,
483    pub handle: *mut llama_cpp_sys_v3::llama_context,
484}
485
486impl Drop for LlamaContext {
487    fn drop(&mut self) {
488        unsafe {
489            (self.backend.symbols.llama_free)(self.handle);
490        }
491    }
492}
493
494unsafe impl Send for LlamaContext {}
495unsafe impl Sync for LlamaContext {}
496
497impl LlamaContext {
498    pub fn new(
499        model: &LlamaModel,
500        params: llama_cpp_sys_v3::llama_context_params,
501    ) -> Result<Self, LlamaError> {
502        let handle = unsafe { (model.backend.symbols.llama_init_from_model)(model.handle, params) };
503
504        if handle.is_null() {
505            return Err(LlamaError::ContextCreate);
506        }
507
508        Ok(Self {
509            backend: model.backend.clone(),
510            handle,
511        })
512    }
513
514    pub fn default_params(model: &LlamaModel) -> llama_cpp_sys_v3::llama_context_params {
515        unsafe { (model.backend.symbols.llama_context_default_params)() }
516    }
517
518    pub fn decode(&mut self, batch: &LlamaBatch) -> Result<(), LlamaError> {
519        let res = unsafe { (self.backend.symbols.llama_decode)(self.handle, batch.handle) };
520        if res != 0 {
521            Err(LlamaError::Decode(res))
522        } else {
523            Ok(())
524        }
525    }
526
527    /// Clear the KV cache for this context.
528    /// Resets all cached key/value state, allowing the context to be reused
529    /// for a fresh generation without reallocating.
530    pub fn kv_cache_clear(&mut self) {
531        unsafe {
532            let memory = (self.backend.symbols.llama_get_memory)(self.handle);
533            (self.backend.symbols.llama_memory_clear)(memory, true);
534        }
535    }
536
537    /// Remove KV cache entries for sequence `seq_id` in position range `[p0, p1)`.
538    ///
539    /// If `p0 < 0`, removes from the beginning. If `p1 < 0`, removes to the end.
540    /// Returns `true` if the operation succeeded.
541    ///
542    /// This is used for incremental prompt encoding: when the conversation
543    /// diverges from the cached prefix, only the divergent suffix needs to
544    /// be removed and re-decoded, avoiding a full KV cache clear.
545    pub fn kv_cache_seq_rm(
546        &mut self,
547        seq_id: llama_cpp_sys_v3::llama_seq_id,
548        p0: llama_cpp_sys_v3::llama_pos,
549        p1: llama_cpp_sys_v3::llama_pos,
550    ) -> bool {
551        unsafe {
552            let memory = (self.backend.symbols.llama_get_memory)(self.handle);
553            (self.backend.symbols.llama_memory_seq_rm)(memory, seq_id, p0, p1)
554        }
555    }
556}
557
558pub struct LlamaBatch {
559    pub backend: Arc<LlamaLib>,
560    pub handle: llama_cpp_sys_v3::llama_batch,
561}
562
563impl Drop for LlamaBatch {
564    fn drop(&mut self) {
565        unsafe {
566            (self.backend.symbols.llama_batch_free)(self.handle);
567        }
568    }
569}
570
571impl LlamaBatch {
572    pub fn new(backend: Arc<LlamaLib>, n_tokens: i32, embd: i32, n_seq_max: i32) -> Self {
573        let handle = unsafe { (backend.symbols.llama_batch_init)(n_tokens, embd, n_seq_max) };
574        Self { backend, handle }
575    }
576
577    pub fn clear(&mut self) {
578        self.handle.n_tokens = 0;
579    }
580
581    pub fn add(
582        &mut self,
583        token: llama_cpp_sys_v3::llama_token,
584        pos: llama_cpp_sys_v3::llama_pos,
585        seq_ids: &[i32],
586        logits: bool,
587    ) {
588        let n = self.handle.n_tokens as usize;
589        unsafe {
590            *self.handle.token.add(n) = token;
591            *self.handle.pos.add(n) = pos;
592            *self.handle.n_seq_id.add(n) = seq_ids.len() as i32;
593            for (j, &seq_id) in seq_ids.iter().enumerate() {
594                *(*self.handle.seq_id.add(n)).add(j) = seq_id;
595            }
596            *self.handle.logits.add(n) = if logits { 1 } else { 0 };
597        }
598        self.handle.n_tokens += 1;
599    }
600}