//! llama_cpp_v3 — lib.rs
//!
//! Safe-ish wrapper types over a dynamically loaded llama.cpp shared library,
//! with symbols provided by `llama_cpp_sys_v3`.
1use llama_cpp_sys_v3::{LlamaLib, LoadError};
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5pub mod backend;
6pub mod downloader;
7
8pub use backend::Backend;
9
/// Errors produced while locating/loading the llama.cpp library,
/// loading models, creating contexts, or running inference.
#[derive(Debug, thiserror::Error)]
pub enum LlamaError {
    /// The shared library could not be opened (or a symbol was missing).
    #[error("Failed to load DLL: {0}")]
    DllLoad(#[from] LoadError),
    /// Auto-downloading the backend DLL failed.
    #[error("Failed to download backend: {0}")]
    Download(#[from] downloader::DownloadError),
    /// Backend initialization failed.
    #[error("Failed to initialize backend")]
    BackendInit,
    /// Model loading failed (null handle from llama.cpp, or a path/text
    /// argument contained an interior NUL byte).
    #[error("Failed to load model from file")]
    ModelLoad,
    /// `llama_init_from_model` returned a null context.
    #[error("Failed to create context")]
    ContextCreate,
    /// A llama.cpp call reported a failure; carries the raw status code.
    #[error("Decode error with status code {0}")]
    Decode(i32),
}
25
/// Options controlling how the llama.cpp shared library is resolved and loaded.
pub struct LoadOptions<'a> {
    /// Which compute backend flavor to load (see [`Backend`]).
    pub backend: Backend,
    /// Application name passed to the downloader — presumably used to build
    /// the default cache location; confirm in `downloader`.
    pub app_name: &'a str,
    pub version: Option<&'a str>,        // None = latest
    pub explicit_path: Option<&'a Path>, // Absolute bypass of all resolution
    pub cache_dir: Option<PathBuf>,      // Exact directory to save/load downloaded DLLs
}
33
/// The initialized Llama capabilities backend.
/// Holds the DLL handle alive.
pub struct LlamaBackend {
    // Shared with every LlamaModel/LlamaContext/LlamaSampler/LlamaVocab
    // created from this backend, so the DLL stays mapped while any of them
    // are alive.
    pub lib: Arc<LlamaLib>,
}
39
impl Drop for LlamaBackend {
    fn drop(&mut self) {
        // Technically backend free is global in llama.cpp
        // We only free if we are the very last reference to the library object.
        //
        // NOTE(review): if a LlamaModel/LlamaContext still holds a clone of
        // this Arc when the backend drops, the strong count is > 1 and
        // llama_backend_free is never called (those types do not repeat this
        // check when they drop the final clone) — the global state simply
        // leaks. Also, the count check and the free are not one atomic step;
        // confirm no other thread can clone `lib` while this runs.
        if Arc::strong_count(&self.lib) == 1 {
            unsafe {
                (self.lib.symbols.llama_backend_free)();
            }
        }
    }
}
51
52impl LlamaBackend {
53    /// Load the specified backend DLL, downloading it from GitHub releases if necessary.
54    pub fn load(options: LoadOptions<'_>) -> Result<Self, LlamaError> {
55        let dll_path = if let Some(path) = options.explicit_path {
56            path.to_path_buf()
57        } else if let Ok(env_path) = std::env::var("LLAMA_DLL_PATH") {
58            PathBuf::from(env_path)
59        } else {
60            // Attempt auto-download
61            downloader::Downloader::ensure_dll(
62                options.backend,
63                options.app_name,
64                options.version,
65                options.cache_dir,
66            )?
67        };
68
69        if let Some(parent) = dll_path.parent() {
70            if let Some(path_ext) = std::env::var_os("PATH") {
71                let mut paths = std::env::split_paths(&path_ext).collect::<Vec<_>>();
72                let parent_buf = parent.to_path_buf();
73                if !paths.contains(&parent_buf) {
74                    paths.insert(0, parent_buf);
75                    if let Ok(new_path) = std::env::join_paths(paths) {
76                        std::env::set_var("PATH", new_path);
77                    }
78                }
79            }
80        }
81
82        let lib = LlamaLib::open(&dll_path)?;
83
84        if let Some(parent) = dll_path.parent() {
85            let parent_str = parent.to_string_lossy().to_string();
86            let c_parent = std::ffi::CString::new(parent_str).unwrap();
87            unsafe {
88                (lib.symbols.ggml_backend_load_all_from_path)(c_parent.as_ptr());
89            }
90        } else {
91            unsafe {
92                (lib.symbols.ggml_backend_load_all)();
93            }
94        }
95
96        unsafe {
97            (lib.symbols.llama_backend_init)();
98        }
99
100        Ok(Self { lib: Arc::new(lib) })
101    }
102}
103
/// A loaded GGUF model
pub struct LlamaModel {
    // Keeps the DLL mapped for at least as long as this model handle.
    pub backend: Arc<LlamaLib>,
    // Owned native handle; freed via `llama_model_free` in Drop.
    pub handle: *mut llama_cpp_sys_v3::llama_model,
}
109
impl Drop for LlamaModel {
    fn drop(&mut self) {
        // SAFETY: `handle` was produced by llama_model_load_from_file and is
        // non-null (checked at construction); the `backend` Arc keeps the
        // DLL (and thus the function pointer) alive until after this call.
        unsafe {
            (self.backend.symbols.llama_model_free)(self.handle);
        }
    }
}
117
// SAFETY: NOTE(review) — these assert the raw model pointer may be sent and
// shared across threads. That relies on llama.cpp treating model data as
// immutable after load; that invariant is not visible in this file, so
// confirm it against llama.cpp's documented thread-safety guarantees.
unsafe impl Send for LlamaModel {}
unsafe impl Sync for LlamaModel {}
120
121impl LlamaModel {
122    pub fn load_from_file(
123        backend: &LlamaBackend,
124        path: &str,
125        params: llama_cpp_sys_v3::llama_model_params,
126    ) -> Result<Self, LlamaError> {
127        let c_path = std::ffi::CString::new(path).map_err(|_| LlamaError::ModelLoad)?;
128        let handle =
129            unsafe { (backend.lib.symbols.llama_model_load_from_file)(c_path.as_ptr(), params) };
130
131        if handle.is_null() {
132            return Err(LlamaError::ModelLoad);
133        }
134
135        Ok(Self {
136            backend: backend.lib.clone(),
137            handle,
138        })
139    }
140
141    pub fn default_params(backend: &LlamaBackend) -> llama_cpp_sys_v3::llama_model_params {
142        unsafe { (backend.lib.symbols.llama_model_default_params)() }
143    }
144
145    pub fn get_vocab(&self) -> LlamaVocab {
146        let handle = unsafe { (self.backend.symbols.llama_model_get_vocab)(self.handle) };
147        LlamaVocab {
148            backend: self.backend.clone(),
149            handle,
150        }
151    }
152
153    pub fn tokenize(
154        &self,
155        text: &str,
156        add_special: bool,
157        parse_special: bool,
158    ) -> Result<Vec<llama_cpp_sys_v3::llama_token>, LlamaError> {
159        let vocab = self.get_vocab();
160        let c_text = std::ffi::CString::new(text).map_err(|_| LlamaError::ModelLoad)?;
161
162        // First call to get required size
163        let n_tokens = unsafe {
164            (self.backend.symbols.llama_tokenize)(
165                vocab.handle,
166                c_text.as_ptr(),
167                text.len() as i32,
168                std::ptr::null_mut(),
169                0,
170                add_special,
171                parse_special,
172            )
173        };
174
175        if n_tokens < 0 {
176            let mut tokens = vec![0; (-n_tokens) as usize];
177            let actual_tokens = unsafe {
178                (self.backend.symbols.llama_tokenize)(
179                    vocab.handle,
180                    c_text.as_ptr(),
181                    text.len() as i32,
182                    tokens.as_mut_ptr(),
183                    tokens.len() as i32,
184                    add_special,
185                    parse_special,
186                )
187            };
188            if actual_tokens < 0 {
189                return Err(LlamaError::Decode(actual_tokens));
190            }
191            tokens.truncate(actual_tokens as usize);
192            Ok(tokens)
193        } else {
194            let mut tokens = vec![0; n_tokens as usize];
195            let actual_tokens = unsafe {
196                (self.backend.symbols.llama_tokenize)(
197                    vocab.handle,
198                    c_text.as_ptr(),
199                    text.len() as i32,
200                    tokens.as_mut_ptr(),
201                    tokens.len() as i32,
202                    add_special,
203                    parse_special,
204                )
205            };
206            if actual_tokens < 0 {
207                return Err(LlamaError::Decode(actual_tokens));
208            }
209            tokens.truncate(actual_tokens as usize);
210            Ok(tokens)
211        }
212    }
213
214    pub fn token_to_piece(&self, token: llama_cpp_sys_v3::llama_token) -> String {
215        let vocab = self.get_vocab();
216        let mut buf = vec![0u8; 128];
217        let n = unsafe {
218            (self.backend.symbols.llama_token_to_piece)(
219                vocab.handle,
220                token,
221                buf.as_mut_ptr() as *mut std::ffi::c_char,
222                buf.len() as i32,
223                0,
224                true,
225            )
226        };
227
228        if n < 0 {
229            buf.resize((-n) as usize, 0);
230            unsafe {
231                (self.backend.symbols.llama_token_to_piece)(
232                    vocab.handle,
233                    token,
234                    buf.as_mut_ptr() as *mut std::ffi::c_char,
235                    buf.len() as i32,
236                    0,
237                    true,
238                );
239            }
240        } else {
241            buf.truncate(n as usize);
242        }
243
244        String::from_utf8_lossy(&buf).to_string()
245    }
246}
247
/// Handle to a model's vocabulary.
///
/// NOTE(review): this holds the DLL `Arc` but not the owning `LlamaModel`,
/// so the vocab pointer can dangle if the model is dropped first — confirm
/// callers keep the model alive while using the vocab.
pub struct LlamaVocab {
    pub backend: Arc<LlamaLib>,
    // Borrowed from the model; no Drop impl here because llama.cpp owns it.
    pub handle: *const llama_cpp_sys_v3::llama_vocab,
}
252
impl LlamaVocab {
    /// Beginning-of-sequence token id.
    pub fn bos(&self) -> llama_cpp_sys_v3::llama_token {
        unsafe { (self.backend.symbols.llama_vocab_bos)(self.handle) }
    }

    /// End-of-sequence token id.
    pub fn eos(&self) -> llama_cpp_sys_v3::llama_token {
        unsafe { (self.backend.symbols.llama_vocab_eos)(self.handle) }
    }

    /// Whether `token` is an end-of-generation token.
    pub fn is_eog(&self, token: llama_cpp_sys_v3::llama_token) -> bool {
        unsafe { (self.backend.symbols.llama_vocab_is_eog)(self.handle, token) }
    }
}
266
/// A llama.cpp token sampler.
pub struct LlamaSampler {
    // Keeps the DLL mapped while the sampler exists.
    pub backend: Arc<LlamaLib>,
    // Owned native handle; freed via `llama_sampler_free` in Drop.
    pub handle: *mut llama_cpp_sys_v3::llama_sampler,
}
271
impl Drop for LlamaSampler {
    fn drop(&mut self) {
        // SAFETY: `handle` was produced by llama_sampler_init_greedy; the
        // `backend` Arc keeps the symbol pointer valid through this call.
        unsafe {
            (self.backend.symbols.llama_sampler_free)(self.handle);
        }
    }
}
279
impl LlamaSampler {
    /// Create a greedy (argmax) sampler.
    pub fn new_greedy(backend: Arc<LlamaLib>) -> Self {
        let handle = unsafe { (backend.symbols.llama_sampler_init_greedy)() };
        Self { backend, handle }
    }

    /// Sample a token from `ctx` at logit index `idx` — presumably the
    /// llama.cpp convention where -1 means the last output; confirm against
    /// the llama.h docs for `llama_sampler_sample`.
    pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> llama_cpp_sys_v3::llama_token {
        unsafe { (self.backend.symbols.llama_sampler_sample)(self.handle, ctx.handle, idx) }
    }
}
290
/// Inference context attached to a model
pub struct LlamaContext {
    // Keeps the DLL mapped while the context exists. NOTE(review): this does
    // not keep the owning LlamaModel alive — confirm callers hold the model
    // for the context's lifetime.
    pub backend: Arc<LlamaLib>,
    // Owned native handle; freed via `llama_free` in Drop.
    pub handle: *mut llama_cpp_sys_v3::llama_context,
}
296
impl Drop for LlamaContext {
    fn drop(&mut self) {
        // SAFETY: `handle` was produced by llama_init_from_model and is
        // non-null (checked at construction); `backend` keeps the DLL alive.
        unsafe {
            (self.backend.symbols.llama_free)(self.handle);
        }
    }
}
304
305impl LlamaContext {
306    pub fn new(
307        model: &LlamaModel,
308        params: llama_cpp_sys_v3::llama_context_params,
309    ) -> Result<Self, LlamaError> {
310        let handle = unsafe { (model.backend.symbols.llama_init_from_model)(model.handle, params) };
311
312        if handle.is_null() {
313            return Err(LlamaError::ContextCreate);
314        }
315
316        Ok(Self {
317            backend: model.backend.clone(),
318            handle,
319        })
320    }
321
322    pub fn default_params(model: &LlamaModel) -> llama_cpp_sys_v3::llama_context_params {
323        unsafe { (model.backend.symbols.llama_context_default_params)() }
324    }
325
326    pub fn decode(&mut self, batch: llama_cpp_sys_v3::llama_batch) -> Result<(), LlamaError> {
327        let res = unsafe { (self.backend.symbols.llama_decode)(self.handle, batch) };
328        if res != 0 {
329            Err(LlamaError::Decode(res))
330        } else {
331            Ok(())
332        }
333    }
334}