Skip to main content

llama_cpp_sys_v3/
lib.rs

1use libloading::Library;
2use std::path::{Path, PathBuf};
3
4pub mod types;
5pub use types::*;
6
7#[derive(Debug, thiserror::Error)]
8pub enum LoadError {
9    #[error("DLL not found: {0}")]
10    NotFound(PathBuf),
11    #[error("Failed to load DLL: {0}")]
12    LoadFailed(#[from] libloading::Error),
13    #[error("Symbol not found: {0}")]
14    SymbolMissing(&'static str),
15}
16
17/// A loaded instance of the llama.cpp dynamic library.
18/// This struct holds the library handle and all resolved function pointers.
19pub struct LlamaLib {
20    // We must keep the libraries alive as long as the functions are used.
21    _libs: Vec<Library>,
22    pub symbols: LlamaSymbols,
23}
24
25macro_rules! resolve_symbols {
26    ($libs:expr, { $( $name:ident : $type:ty ),* $(,)? }) => {
27        LlamaSymbols {
28            $(
29                $name: {
30                    let mut found = None;
31                    for lib in $libs.iter() {
32                        if let Ok(sym) = unsafe { lib.get::<$type>(stringify!($name).as_bytes()) } {
33                            found = Some(*sym);
34                            break;
35                        }
36                    }
37                    found.ok_or(LoadError::SymbolMissing(stringify!($name)))?
38                },
39            )*
40        }
41    };
42}
43
44impl LlamaLib {
45    /// Attempt to load the llama.cpp library from the given path.
46    pub fn open(path: &Path) -> Result<Self, LoadError> {
47        if !path.exists() {
48            return Err(LoadError::NotFound(path.to_path_buf()));
49        }
50
51        let mut libs = Vec::new();
52
53        // Try to load ggml.dll first if it's in the same directory
54        if let Some(parent) = path.parent() {
55            let ggml_path = parent.join("ggml.dll");
56            if ggml_path.exists() {
57                #[cfg(target_os = "windows")]
58                let lib = unsafe {
59                    libloading::os::windows::Library::load_with_flags(
60                        &ggml_path,
61                        libloading::os::windows::LOAD_WITH_ALTERED_SEARCH_PATH,
62                    )?
63                };
64                #[cfg(not(target_os = "windows"))]
65                let lib = unsafe { libloading::Library::new(&ggml_path)? };
66
67                libs.push(libloading::Library::from(lib));
68            }
69        }
70
71        #[cfg(target_os = "windows")]
72        let main_lib = unsafe {
73            libloading::os::windows::Library::load_with_flags(
74                path,
75                libloading::os::windows::LOAD_WITH_ALTERED_SEARCH_PATH,
76            )?
77        };
78        #[cfg(not(target_os = "windows"))]
79        let main_lib = unsafe { Library::new(path)? };
80
81        libs.push(libloading::Library::from(main_lib));
82
83        // Resolve all required symbols here
84        let symbols = resolve_symbols!(libs, {
85            llama_backend_init: unsafe extern "C" fn(),
86            llama_backend_free: unsafe extern "C" fn(),
87            ggml_backend_load_all: unsafe extern "C" fn(),
88            ggml_backend_load_all_from_path: unsafe extern "C" fn(*const std::ffi::c_char),
89
90            llama_model_default_params: unsafe extern "C" fn() -> llama_model_params,
91            llama_model_load_from_file: unsafe extern "C" fn(*const std::ffi::c_char, llama_model_params) -> *mut llama_model,
92            llama_model_free: unsafe extern "C" fn(*mut llama_model),
93
94            llama_context_default_params: unsafe extern "C" fn() -> llama_context_params,
95            llama_init_from_model: unsafe extern "C" fn(*mut llama_model, llama_context_params) -> *mut llama_context,
96            llama_free: unsafe extern "C" fn(*mut llama_context),
97
98            llama_batch_get_one: unsafe extern "C" fn(*mut llama_token, i32) -> llama_batch,
99            llama_batch_init: unsafe extern "C" fn(i32, i32, i32) -> llama_batch,
100            llama_batch_free: unsafe extern "C" fn(llama_batch),
101
102            llama_decode: unsafe extern "C" fn(*mut llama_context, llama_batch) -> i32,
103            llama_get_memory: unsafe extern "C" fn(*const llama_context) -> *mut llama_memory,
104            llama_memory_clear: unsafe extern "C" fn(*mut llama_memory, bool),
105            llama_memory_seq_rm: unsafe extern "C" fn(*mut llama_memory, llama_seq_id, llama_pos, llama_pos) -> bool,
106
107            llama_set_n_threads: unsafe extern "C" fn(*mut llama_context, u32, u32),
108            llama_model_get_vocab: unsafe extern "C" fn(*const llama_model) -> *const llama_vocab,
109            llama_vocab_n_tokens: unsafe extern "C" fn(*const llama_vocab) -> i32,
110            llama_n_vocab: unsafe extern "C" fn(*const llama_vocab) -> i32,
111            llama_n_ctx: unsafe extern "C" fn(*const llama_context) -> u32,
112
113            llama_get_logits: unsafe extern "C" fn(*mut llama_context) -> *mut f32,
114            llama_get_logits_ith: unsafe extern "C" fn(*mut llama_context, i32) -> *mut f32,
115
116            llama_token_get_text: unsafe extern "C" fn(*const llama_vocab, llama_token) -> *const std::ffi::c_char,
117            llama_tokenize: unsafe extern "C" fn(*const llama_vocab, *const std::ffi::c_char, i32, *mut llama_token, i32, bool, bool) -> i32,
118            llama_token_to_piece: unsafe extern "C" fn(*const llama_vocab, llama_token, *mut std::ffi::c_char, i32, i32, bool) -> i32,
119
120            llama_vocab_bos: unsafe extern "C" fn(*const llama_vocab) -> llama_token,
121            llama_vocab_eos: unsafe extern "C" fn(*const llama_vocab) -> llama_token,
122            llama_vocab_nl: unsafe extern "C" fn(*const llama_vocab) -> llama_token,
123            llama_vocab_is_eog: unsafe extern "C" fn(*const llama_vocab, llama_token) -> bool,
124
125            llama_print_system_info: unsafe extern "C" fn() -> *const std::ffi::c_char,
126
127            // Sampler API
128            llama_sampler_chain_init: unsafe extern "C" fn(llama_sampler_chain_params) -> *mut llama_sampler,
129            llama_sampler_init_greedy: unsafe extern "C" fn() -> *mut llama_sampler,
130            llama_sampler_free: unsafe extern "C" fn(*mut llama_sampler),
131            llama_sampler_init_temp: unsafe extern "C" fn(f32) -> *mut llama_sampler,
132            llama_sampler_init_top_k: unsafe extern "C" fn(i32) -> *mut llama_sampler,
133            llama_sampler_init_top_p: unsafe extern "C" fn(f32, usize) -> *mut llama_sampler,
134            llama_sampler_init_dist: unsafe extern "C" fn(u32) -> *mut llama_sampler,
135            llama_sampler_init_min_p: unsafe extern "C" fn(f32, usize) -> *mut llama_sampler,
136            llama_sampler_init_typical: unsafe extern "C" fn(f32, usize) -> *mut llama_sampler,
137            llama_sampler_init_mirostat_v2: unsafe extern "C" fn(u32, f32, f32) -> *mut llama_sampler,
138            llama_sampler_init_penalties: unsafe extern "C" fn(i32, f32, f32, f32) -> *mut llama_sampler,
139            llama_sampler_chain_add: unsafe extern "C" fn(*mut llama_sampler, *mut llama_sampler),
140            llama_sampler_sample: unsafe extern "C" fn(*mut llama_sampler, *mut llama_context, i32) -> llama_token,
141            llama_sampler_accept: unsafe extern "C" fn(*mut llama_sampler, llama_token),
142            llama_chat_apply_template: unsafe extern "C" fn(*const std::ffi::c_char, *const llama_chat_message, usize, bool, *mut std::ffi::c_char, i32) -> i32,
143            llama_model_chat_template: unsafe extern "C" fn(*const llama_model, *const std::ffi::c_char, *mut std::ffi::c_char, usize) -> i32,
144        });
145
146        Ok(Self {
147            _libs: libs,
148            symbols,
149        })
150    }
151}
152
153pub struct LlamaSymbols {
154    pub llama_backend_init: unsafe extern "C" fn(),
155    pub llama_backend_free: unsafe extern "C" fn(),
156    pub ggml_backend_load_all: unsafe extern "C" fn(),
157    pub ggml_backend_load_all_from_path: unsafe extern "C" fn(*const std::ffi::c_char),
158
159    pub llama_model_default_params: unsafe extern "C" fn() -> llama_model_params,
160    pub llama_model_load_from_file:
161        unsafe extern "C" fn(*const std::ffi::c_char, llama_model_params) -> *mut llama_model,
162    pub llama_model_free: unsafe extern "C" fn(*mut llama_model),
163
164    pub llama_context_default_params: unsafe extern "C" fn() -> llama_context_params,
165    pub llama_init_from_model:
166        unsafe extern "C" fn(*mut llama_model, llama_context_params) -> *mut llama_context,
167    pub llama_free: unsafe extern "C" fn(*mut llama_context),
168
169    pub llama_batch_get_one: unsafe extern "C" fn(*mut llama_token, i32) -> llama_batch,
170    pub llama_batch_init: unsafe extern "C" fn(i32, i32, i32) -> llama_batch,
171    pub llama_batch_free: unsafe extern "C" fn(llama_batch),
172
173    pub llama_decode: unsafe extern "C" fn(*mut llama_context, llama_batch) -> i32,
174    pub llama_get_memory: unsafe extern "C" fn(*const llama_context) -> *mut llama_memory,
175    pub llama_memory_clear: unsafe extern "C" fn(*mut llama_memory, bool),
176    pub llama_memory_seq_rm:
177        unsafe extern "C" fn(*mut llama_memory, llama_seq_id, llama_pos, llama_pos) -> bool,
178
179    pub llama_set_n_threads: unsafe extern "C" fn(*mut llama_context, u32, u32),
180    pub llama_model_get_vocab: unsafe extern "C" fn(*const llama_model) -> *const llama_vocab,
181    pub llama_vocab_n_tokens: unsafe extern "C" fn(*const llama_vocab) -> i32,
182    pub llama_n_vocab: unsafe extern "C" fn(*const llama_vocab) -> i32,
183    pub llama_n_ctx: unsafe extern "C" fn(*const llama_context) -> u32,
184
185    pub llama_get_logits: unsafe extern "C" fn(*mut llama_context) -> *mut f32,
186    pub llama_get_logits_ith: unsafe extern "C" fn(*mut llama_context, i32) -> *mut f32,
187
188    pub llama_token_get_text:
189        unsafe extern "C" fn(*const llama_vocab, llama_token) -> *const std::ffi::c_char,
190    pub llama_tokenize: unsafe extern "C" fn(
191        *const llama_vocab,
192        *const std::ffi::c_char,
193        i32,
194        *mut llama_token,
195        i32,
196        bool,
197        bool,
198    ) -> i32,
199    pub llama_token_to_piece: unsafe extern "C" fn(
200        *const llama_vocab,
201        llama_token,
202        *mut std::ffi::c_char,
203        i32,
204        i32,
205        bool,
206    ) -> i32,
207
208    pub llama_vocab_bos: unsafe extern "C" fn(*const llama_vocab) -> llama_token,
209    pub llama_vocab_eos: unsafe extern "C" fn(*const llama_vocab) -> llama_token,
210    pub llama_vocab_nl: unsafe extern "C" fn(*const llama_vocab) -> llama_token,
211    pub llama_vocab_is_eog: unsafe extern "C" fn(*const llama_vocab, llama_token) -> bool,
212
213    pub llama_print_system_info: unsafe extern "C" fn() -> *const std::ffi::c_char,
214
215    pub llama_sampler_chain_init:
216        unsafe extern "C" fn(llama_sampler_chain_params) -> *mut llama_sampler,
217    pub llama_sampler_init_greedy: unsafe extern "C" fn() -> *mut llama_sampler,
218    pub llama_sampler_free: unsafe extern "C" fn(*mut llama_sampler),
219    pub llama_sampler_init_temp: unsafe extern "C" fn(f32) -> *mut llama_sampler,
220    pub llama_sampler_init_top_k: unsafe extern "C" fn(i32) -> *mut llama_sampler,
221    pub llama_sampler_init_top_p: unsafe extern "C" fn(f32, usize) -> *mut llama_sampler,
222    pub llama_sampler_init_dist: unsafe extern "C" fn(u32) -> *mut llama_sampler,
223    pub llama_sampler_init_min_p: unsafe extern "C" fn(f32, usize) -> *mut llama_sampler,
224    pub llama_sampler_init_typical: unsafe extern "C" fn(f32, usize) -> *mut llama_sampler,
225    pub llama_sampler_init_mirostat_v2: unsafe extern "C" fn(u32, f32, f32) -> *mut llama_sampler,
226    pub llama_sampler_init_penalties:
227        unsafe extern "C" fn(i32, f32, f32, f32) -> *mut llama_sampler,
228    pub llama_sampler_chain_add: unsafe extern "C" fn(*mut llama_sampler, *mut llama_sampler),
229    pub llama_sampler_sample:
230        unsafe extern "C" fn(*mut llama_sampler, *mut llama_context, i32) -> llama_token,
231    pub llama_sampler_accept: unsafe extern "C" fn(*mut llama_sampler, llama_token),
232    pub llama_chat_apply_template: unsafe extern "C" fn(
233        *const std::ffi::c_char,
234        *const llama_chat_message,
235        usize,
236        bool,
237        *mut std::ffi::c_char,
238        i32,
239    ) -> i32,
240    pub llama_model_chat_template: unsafe extern "C" fn(
241        *const llama_model,
242        *const std::ffi::c_char,
243        *mut std::ffi::c_char,
244        usize,
245    ) -> i32,
246}