1use libloading::Library;
2use std::path::{Path, PathBuf};
3
4pub mod types;
5pub use types::*;
6
7#[derive(Debug, thiserror::Error)]
8pub enum LoadError {
9 #[error("DLL not found: {0}")]
10 NotFound(PathBuf),
11 #[error("Failed to load DLL: {0}")]
12 LoadFailed(#[from] libloading::Error),
13 #[error("Symbol not found: {0}")]
14 SymbolMissing(&'static str),
15}
16
17pub struct LlamaLib {
20 _libs: Vec<Library>,
22 pub symbols: LlamaSymbols,
23}
24
25macro_rules! resolve_symbols {
26 ($libs:expr, { $( $name:ident : $type:ty ),* $(,)? }) => {
27 LlamaSymbols {
28 $(
29 $name: {
30 let mut found = None;
31 for lib in $libs.iter() {
32 if let Ok(sym) = unsafe { lib.get::<$type>(stringify!($name).as_bytes()) } {
33 found = Some(*sym);
34 break;
35 }
36 }
37 found.ok_or(LoadError::SymbolMissing(stringify!($name)))?
38 },
39 )*
40 }
41 };
42}
43
44impl LlamaLib {
45 pub fn open(path: &Path) -> Result<Self, LoadError> {
47 if !path.exists() {
48 return Err(LoadError::NotFound(path.to_path_buf()));
49 }
50
51 let mut libs = Vec::new();
52
53 if let Some(parent) = path.parent() {
55 let ggml_path = parent.join("ggml.dll");
56 if ggml_path.exists() {
57 #[cfg(target_os = "windows")]
58 let lib = unsafe {
59 libloading::os::windows::Library::load_with_flags(
60 &ggml_path,
61 libloading::os::windows::LOAD_WITH_ALTERED_SEARCH_PATH,
62 )?
63 };
64 #[cfg(not(target_os = "windows"))]
65 let lib = unsafe { libloading::Library::new(&ggml_path)? };
66
67 libs.push(libloading::Library::from(lib));
68 }
69 }
70
71 #[cfg(target_os = "windows")]
72 let main_lib = unsafe {
73 libloading::os::windows::Library::load_with_flags(
74 path,
75 libloading::os::windows::LOAD_WITH_ALTERED_SEARCH_PATH,
76 )?
77 };
78 #[cfg(not(target_os = "windows"))]
79 let main_lib = unsafe { Library::new(path)? };
80
81 libs.push(libloading::Library::from(main_lib));
82
83 let symbols = resolve_symbols!(libs, {
85 llama_backend_init: unsafe extern "C" fn(),
86 llama_backend_free: unsafe extern "C" fn(),
87 ggml_backend_load_all: unsafe extern "C" fn(),
88 ggml_backend_load_all_from_path: unsafe extern "C" fn(*const std::ffi::c_char),
89
90 llama_model_default_params: unsafe extern "C" fn() -> llama_model_params,
91 llama_model_load_from_file: unsafe extern "C" fn(*const std::ffi::c_char, llama_model_params) -> *mut llama_model,
92 llama_model_free: unsafe extern "C" fn(*mut llama_model),
93
94 llama_context_default_params: unsafe extern "C" fn() -> llama_context_params,
95 llama_init_from_model: unsafe extern "C" fn(*mut llama_model, llama_context_params) -> *mut llama_context,
96 llama_free: unsafe extern "C" fn(*mut llama_context),
97
98 llama_batch_get_one: unsafe extern "C" fn(*mut llama_token, i32) -> llama_batch,
99 llama_batch_init: unsafe extern "C" fn(i32, i32, i32) -> llama_batch,
100 llama_batch_free: unsafe extern "C" fn(llama_batch),
101
102 llama_decode: unsafe extern "C" fn(*mut llama_context, llama_batch) -> i32,
103 llama_get_memory: unsafe extern "C" fn(*const llama_context) -> *mut llama_memory,
104 llama_memory_clear: unsafe extern "C" fn(*mut llama_memory, bool),
105 llama_memory_seq_rm: unsafe extern "C" fn(*mut llama_memory, llama_seq_id, llama_pos, llama_pos) -> bool,
106
107 llama_set_n_threads: unsafe extern "C" fn(*mut llama_context, u32, u32),
108 llama_model_get_vocab: unsafe extern "C" fn(*const llama_model) -> *const llama_vocab,
109 llama_vocab_n_tokens: unsafe extern "C" fn(*const llama_vocab) -> i32,
110 llama_n_vocab: unsafe extern "C" fn(*const llama_vocab) -> i32,
111 llama_n_ctx: unsafe extern "C" fn(*const llama_context) -> u32,
112
113 llama_get_logits: unsafe extern "C" fn(*mut llama_context) -> *mut f32,
114 llama_get_logits_ith: unsafe extern "C" fn(*mut llama_context, i32) -> *mut f32,
115
116 llama_token_get_text: unsafe extern "C" fn(*const llama_vocab, llama_token) -> *const std::ffi::c_char,
117 llama_tokenize: unsafe extern "C" fn(*const llama_vocab, *const std::ffi::c_char, i32, *mut llama_token, i32, bool, bool) -> i32,
118 llama_token_to_piece: unsafe extern "C" fn(*const llama_vocab, llama_token, *mut std::ffi::c_char, i32, i32, bool) -> i32,
119
120 llama_vocab_bos: unsafe extern "C" fn(*const llama_vocab) -> llama_token,
121 llama_vocab_eos: unsafe extern "C" fn(*const llama_vocab) -> llama_token,
122 llama_vocab_nl: unsafe extern "C" fn(*const llama_vocab) -> llama_token,
123 llama_vocab_is_eog: unsafe extern "C" fn(*const llama_vocab, llama_token) -> bool,
124
125 llama_print_system_info: unsafe extern "C" fn() -> *const std::ffi::c_char,
126
127 llama_sampler_chain_init: unsafe extern "C" fn(llama_sampler_chain_params) -> *mut llama_sampler,
129 llama_sampler_init_greedy: unsafe extern "C" fn() -> *mut llama_sampler,
130 llama_sampler_free: unsafe extern "C" fn(*mut llama_sampler),
131 llama_sampler_init_temp: unsafe extern "C" fn(f32) -> *mut llama_sampler,
132 llama_sampler_init_top_k: unsafe extern "C" fn(i32) -> *mut llama_sampler,
133 llama_sampler_init_top_p: unsafe extern "C" fn(f32, usize) -> *mut llama_sampler,
134 llama_sampler_init_dist: unsafe extern "C" fn(u32) -> *mut llama_sampler,
135 llama_sampler_init_min_p: unsafe extern "C" fn(f32, usize) -> *mut llama_sampler,
136 llama_sampler_init_typical: unsafe extern "C" fn(f32, usize) -> *mut llama_sampler,
137 llama_sampler_init_mirostat_v2: unsafe extern "C" fn(u32, f32, f32) -> *mut llama_sampler,
138 llama_sampler_init_penalties: unsafe extern "C" fn(i32, f32, f32, f32) -> *mut llama_sampler,
139 llama_sampler_chain_add: unsafe extern "C" fn(*mut llama_sampler, *mut llama_sampler),
140 llama_sampler_sample: unsafe extern "C" fn(*mut llama_sampler, *mut llama_context, i32) -> llama_token,
141 llama_sampler_accept: unsafe extern "C" fn(*mut llama_sampler, llama_token),
142 llama_chat_apply_template: unsafe extern "C" fn(*const std::ffi::c_char, *const llama_chat_message, usize, bool, *mut std::ffi::c_char, i32) -> i32,
143 llama_model_chat_template: unsafe extern "C" fn(*const llama_model, *const std::ffi::c_char, *mut std::ffi::c_char, usize) -> i32,
144 });
145
146 Ok(Self {
147 _libs: libs,
148 symbols,
149 })
150 }
151}
152
153pub struct LlamaSymbols {
154 pub llama_backend_init: unsafe extern "C" fn(),
155 pub llama_backend_free: unsafe extern "C" fn(),
156 pub ggml_backend_load_all: unsafe extern "C" fn(),
157 pub ggml_backend_load_all_from_path: unsafe extern "C" fn(*const std::ffi::c_char),
158
159 pub llama_model_default_params: unsafe extern "C" fn() -> llama_model_params,
160 pub llama_model_load_from_file:
161 unsafe extern "C" fn(*const std::ffi::c_char, llama_model_params) -> *mut llama_model,
162 pub llama_model_free: unsafe extern "C" fn(*mut llama_model),
163
164 pub llama_context_default_params: unsafe extern "C" fn() -> llama_context_params,
165 pub llama_init_from_model:
166 unsafe extern "C" fn(*mut llama_model, llama_context_params) -> *mut llama_context,
167 pub llama_free: unsafe extern "C" fn(*mut llama_context),
168
169 pub llama_batch_get_one: unsafe extern "C" fn(*mut llama_token, i32) -> llama_batch,
170 pub llama_batch_init: unsafe extern "C" fn(i32, i32, i32) -> llama_batch,
171 pub llama_batch_free: unsafe extern "C" fn(llama_batch),
172
173 pub llama_decode: unsafe extern "C" fn(*mut llama_context, llama_batch) -> i32,
174 pub llama_get_memory: unsafe extern "C" fn(*const llama_context) -> *mut llama_memory,
175 pub llama_memory_clear: unsafe extern "C" fn(*mut llama_memory, bool),
176 pub llama_memory_seq_rm:
177 unsafe extern "C" fn(*mut llama_memory, llama_seq_id, llama_pos, llama_pos) -> bool,
178
179 pub llama_set_n_threads: unsafe extern "C" fn(*mut llama_context, u32, u32),
180 pub llama_model_get_vocab: unsafe extern "C" fn(*const llama_model) -> *const llama_vocab,
181 pub llama_vocab_n_tokens: unsafe extern "C" fn(*const llama_vocab) -> i32,
182 pub llama_n_vocab: unsafe extern "C" fn(*const llama_vocab) -> i32,
183 pub llama_n_ctx: unsafe extern "C" fn(*const llama_context) -> u32,
184
185 pub llama_get_logits: unsafe extern "C" fn(*mut llama_context) -> *mut f32,
186 pub llama_get_logits_ith: unsafe extern "C" fn(*mut llama_context, i32) -> *mut f32,
187
188 pub llama_token_get_text:
189 unsafe extern "C" fn(*const llama_vocab, llama_token) -> *const std::ffi::c_char,
190 pub llama_tokenize: unsafe extern "C" fn(
191 *const llama_vocab,
192 *const std::ffi::c_char,
193 i32,
194 *mut llama_token,
195 i32,
196 bool,
197 bool,
198 ) -> i32,
199 pub llama_token_to_piece: unsafe extern "C" fn(
200 *const llama_vocab,
201 llama_token,
202 *mut std::ffi::c_char,
203 i32,
204 i32,
205 bool,
206 ) -> i32,
207
208 pub llama_vocab_bos: unsafe extern "C" fn(*const llama_vocab) -> llama_token,
209 pub llama_vocab_eos: unsafe extern "C" fn(*const llama_vocab) -> llama_token,
210 pub llama_vocab_nl: unsafe extern "C" fn(*const llama_vocab) -> llama_token,
211 pub llama_vocab_is_eog: unsafe extern "C" fn(*const llama_vocab, llama_token) -> bool,
212
213 pub llama_print_system_info: unsafe extern "C" fn() -> *const std::ffi::c_char,
214
215 pub llama_sampler_chain_init:
216 unsafe extern "C" fn(llama_sampler_chain_params) -> *mut llama_sampler,
217 pub llama_sampler_init_greedy: unsafe extern "C" fn() -> *mut llama_sampler,
218 pub llama_sampler_free: unsafe extern "C" fn(*mut llama_sampler),
219 pub llama_sampler_init_temp: unsafe extern "C" fn(f32) -> *mut llama_sampler,
220 pub llama_sampler_init_top_k: unsafe extern "C" fn(i32) -> *mut llama_sampler,
221 pub llama_sampler_init_top_p: unsafe extern "C" fn(f32, usize) -> *mut llama_sampler,
222 pub llama_sampler_init_dist: unsafe extern "C" fn(u32) -> *mut llama_sampler,
223 pub llama_sampler_init_min_p: unsafe extern "C" fn(f32, usize) -> *mut llama_sampler,
224 pub llama_sampler_init_typical: unsafe extern "C" fn(f32, usize) -> *mut llama_sampler,
225 pub llama_sampler_init_mirostat_v2: unsafe extern "C" fn(u32, f32, f32) -> *mut llama_sampler,
226 pub llama_sampler_init_penalties:
227 unsafe extern "C" fn(i32, f32, f32, f32) -> *mut llama_sampler,
228 pub llama_sampler_chain_add: unsafe extern "C" fn(*mut llama_sampler, *mut llama_sampler),
229 pub llama_sampler_sample:
230 unsafe extern "C" fn(*mut llama_sampler, *mut llama_context, i32) -> llama_token,
231 pub llama_sampler_accept: unsafe extern "C" fn(*mut llama_sampler, llama_token),
232 pub llama_chat_apply_template: unsafe extern "C" fn(
233 *const std::ffi::c_char,
234 *const llama_chat_message,
235 usize,
236 bool,
237 *mut std::ffi::c_char,
238 i32,
239 ) -> i32,
240 pub llama_model_chat_template: unsafe extern "C" fn(
241 *const llama_model,
242 *const std::ffi::c_char,
243 *mut std::ffi::c_char,
244 usize,
245 ) -> i32,
246}