1use llama_cpp_sys_v3::{LlamaLib, LoadError};
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
// Selection of which prebuilt llama.cpp variant to use (see `Backend`).
pub mod backend;
// Fetches and caches the prebuilt llama.cpp shared library (see
// `Downloader::ensure_dll`, used by `LlamaBackend::load`).
pub mod downloader;

// Re-exported so callers can name the backend variant without the extra path.
pub use backend::Backend;
9
/// Errors produced while loading, initializing, or driving the dynamically
/// loaded llama.cpp library.
#[derive(Debug, thiserror::Error)]
pub enum LlamaError {
    /// The shared library could not be opened (propagated from the sys crate).
    #[error("Failed to load DLL: {0}")]
    DllLoad(#[from] LoadError),
    /// Downloading or caching the prebuilt backend DLL failed.
    #[error("Failed to download backend: {0}")]
    Download(#[from] downloader::DownloadError),
    /// Reserved for backend initialization failures.
    // NOTE(review): not constructed anywhere in this file — confirm a caller
    // elsewhere uses it, otherwise it may be dead.
    #[error("Failed to initialize backend")]
    BackendInit,
    /// The model file could not be loaded (bad path, interior NUL, or the
    /// loader returned a null handle).
    #[error("Failed to load model from file")]
    ModelLoad,
    /// `llama_init_from_model` returned a null context pointer.
    #[error("Failed to create context")]
    ContextCreate,
    /// A llama.cpp call (decode/tokenize) returned a negative/non-zero status.
    #[error("Decode error with status code {0}")]
    Decode(i32),
}
25
26pub struct LoadOptions<'a> {
27 pub backend: Backend,
28 pub app_name: &'a str,
29 pub version: Option<&'a str>, pub explicit_path: Option<&'a Path>, pub cache_dir: Option<PathBuf>, }
33
/// Handle to the loaded llama.cpp library with the global backend initialized.
/// The `Arc` is cloned into every model/context/sampler so the library stays
/// loaded while any of them is alive.
pub struct LlamaBackend {
    pub lib: Arc<LlamaLib>,
}
39
impl Drop for LlamaBackend {
    fn drop(&mut self) {
        // Free the global llama.cpp backend only if nothing else still holds
        // the library. NOTE(review): this check is racy — strong_count can
        // change on another thread between the read and the free — and if a
        // model/context outlives this struct, the count is > 1 here and
        // `llama_backend_free` is never called (their Drop impls don't call
        // it either). Confirm that leak/race is acceptable or restructure so
        // a single shared owner performs the free.
        if Arc::strong_count(&self.lib) == 1 {
            unsafe {
                (self.lib.symbols.llama_backend_free)();
            }
        }
    }
}
51
52impl LlamaBackend {
53 pub fn load(options: LoadOptions<'_>) -> Result<Self, LlamaError> {
55 let dll_path = if let Some(path) = options.explicit_path {
56 path.to_path_buf()
57 } else if let Ok(env_path) = std::env::var("LLAMA_DLL_PATH") {
58 PathBuf::from(env_path)
59 } else {
60 downloader::Downloader::ensure_dll(
62 options.backend,
63 options.app_name,
64 options.version,
65 options.cache_dir,
66 )?
67 };
68
69 if let Some(parent) = dll_path.parent() {
70 if let Some(path_ext) = std::env::var_os("PATH") {
71 let mut paths = std::env::split_paths(&path_ext).collect::<Vec<_>>();
72 let parent_buf = parent.to_path_buf();
73 if !paths.contains(&parent_buf) {
74 paths.insert(0, parent_buf);
75 if let Ok(new_path) = std::env::join_paths(paths) {
76 std::env::set_var("PATH", new_path);
77 }
78 }
79 }
80 }
81
82 let lib = LlamaLib::open(&dll_path)?;
83
84 if let Some(parent) = dll_path.parent() {
85 let parent_str = parent.to_string_lossy().to_string();
86 let c_parent = std::ffi::CString::new(parent_str).unwrap();
87 unsafe {
88 (lib.symbols.ggml_backend_load_all_from_path)(c_parent.as_ptr());
89 }
90 } else {
91 unsafe {
92 (lib.symbols.ggml_backend_load_all)();
93 }
94 }
95
96 unsafe {
97 (lib.symbols.llama_backend_init)();
98 }
99
100 Ok(Self { lib: Arc::new(lib) })
101 }
102}
103
/// An owned llama.cpp model handle. Keeps the library alive via `backend`
/// so the free function in Drop remains callable.
pub struct LlamaModel {
    pub backend: Arc<LlamaLib>,
    // Non-null by construction (checked in `load_from_file`); freed in Drop.
    pub handle: *mut llama_cpp_sys_v3::llama_model,
}
109
impl Drop for LlamaModel {
    fn drop(&mut self) {
        // SAFETY: `handle` is non-null (checked at construction) and owned
        // exclusively by this struct; the `Arc<LlamaLib>` keeps the library
        // and its symbol table alive for the duration of this call.
        unsafe {
            (self.backend.symbols.llama_model_free)(self.handle);
        }
    }
}
117
// SAFETY: the raw handle is freed only in Drop, which takes `&mut self`.
// NOTE(review): `Sync` additionally assumes llama.cpp permits concurrent
// reads of a `llama_model` — that is a property of the C library not visible
// here; confirm before sharing a model across threads.
unsafe impl Send for LlamaModel {}
unsafe impl Sync for LlamaModel {}
120
121impl LlamaModel {
122 pub fn load_from_file(
123 backend: &LlamaBackend,
124 path: &str,
125 params: llama_cpp_sys_v3::llama_model_params,
126 ) -> Result<Self, LlamaError> {
127 let c_path = std::ffi::CString::new(path).map_err(|_| LlamaError::ModelLoad)?;
128 let handle =
129 unsafe { (backend.lib.symbols.llama_model_load_from_file)(c_path.as_ptr(), params) };
130
131 if handle.is_null() {
132 return Err(LlamaError::ModelLoad);
133 }
134
135 Ok(Self {
136 backend: backend.lib.clone(),
137 handle,
138 })
139 }
140
141 pub fn default_params(backend: &LlamaBackend) -> llama_cpp_sys_v3::llama_model_params {
142 unsafe { (backend.lib.symbols.llama_model_default_params)() }
143 }
144
145 pub fn get_vocab(&self) -> LlamaVocab {
146 let handle = unsafe { (self.backend.symbols.llama_model_get_vocab)(self.handle) };
147 LlamaVocab {
148 backend: self.backend.clone(),
149 handle,
150 }
151 }
152
153 pub fn tokenize(
154 &self,
155 text: &str,
156 add_special: bool,
157 parse_special: bool,
158 ) -> Result<Vec<llama_cpp_sys_v3::llama_token>, LlamaError> {
159 let vocab = self.get_vocab();
160 let c_text = std::ffi::CString::new(text).map_err(|_| LlamaError::ModelLoad)?;
161
162 let n_tokens = unsafe {
164 (self.backend.symbols.llama_tokenize)(
165 vocab.handle,
166 c_text.as_ptr(),
167 text.len() as i32,
168 std::ptr::null_mut(),
169 0,
170 add_special,
171 parse_special,
172 )
173 };
174
175 if n_tokens < 0 {
176 let mut tokens = vec![0; (-n_tokens) as usize];
177 let actual_tokens = unsafe {
178 (self.backend.symbols.llama_tokenize)(
179 vocab.handle,
180 c_text.as_ptr(),
181 text.len() as i32,
182 tokens.as_mut_ptr(),
183 tokens.len() as i32,
184 add_special,
185 parse_special,
186 )
187 };
188 if actual_tokens < 0 {
189 return Err(LlamaError::Decode(actual_tokens));
190 }
191 tokens.truncate(actual_tokens as usize);
192 Ok(tokens)
193 } else {
194 let mut tokens = vec![0; n_tokens as usize];
195 let actual_tokens = unsafe {
196 (self.backend.symbols.llama_tokenize)(
197 vocab.handle,
198 c_text.as_ptr(),
199 text.len() as i32,
200 tokens.as_mut_ptr(),
201 tokens.len() as i32,
202 add_special,
203 parse_special,
204 )
205 };
206 if actual_tokens < 0 {
207 return Err(LlamaError::Decode(actual_tokens));
208 }
209 tokens.truncate(actual_tokens as usize);
210 Ok(tokens)
211 }
212 }
213
214 pub fn token_to_piece(&self, token: llama_cpp_sys_v3::llama_token) -> String {
215 let vocab = self.get_vocab();
216 let mut buf = vec![0u8; 128];
217 let n = unsafe {
218 (self.backend.symbols.llama_token_to_piece)(
219 vocab.handle,
220 token,
221 buf.as_mut_ptr() as *mut std::ffi::c_char,
222 buf.len() as i32,
223 0,
224 true,
225 )
226 };
227
228 if n < 0 {
229 buf.resize((-n) as usize, 0);
230 unsafe {
231 (self.backend.symbols.llama_token_to_piece)(
232 vocab.handle,
233 token,
234 buf.as_mut_ptr() as *mut std::ffi::c_char,
235 buf.len() as i32,
236 0,
237 true,
238 );
239 }
240 } else {
241 buf.truncate(n as usize);
242 }
243
244 String::from_utf8_lossy(&buf).to_string()
245 }
246}
247
/// Borrowed view of a model's vocabulary.
/// NOTE(review): `handle` is obtained from `llama_model_get_vocab` and there
/// is no Drop here — presumably the model owns it; do not use after the
/// model is freed.
pub struct LlamaVocab {
    pub backend: Arc<LlamaLib>,
    pub handle: *const llama_cpp_sys_v3::llama_vocab,
}
252
impl LlamaVocab {
    /// Returns the beginning-of-sequence token id.
    pub fn bos(&self) -> llama_cpp_sys_v3::llama_token {
        unsafe { (self.backend.symbols.llama_vocab_bos)(self.handle) }
    }

    /// Returns the end-of-sequence token id.
    pub fn eos(&self) -> llama_cpp_sys_v3::llama_token {
        unsafe { (self.backend.symbols.llama_vocab_eos)(self.handle) }
    }

    /// Returns true when `token` marks end-of-generation.
    pub fn is_eog(&self, token: llama_cpp_sys_v3::llama_token) -> bool {
        unsafe { (self.backend.symbols.llama_vocab_is_eog)(self.handle, token) }
    }
}
266
/// An owned llama.cpp sampler; freed in Drop via `llama_sampler_free`.
pub struct LlamaSampler {
    pub backend: Arc<LlamaLib>,
    pub handle: *mut llama_cpp_sys_v3::llama_sampler,
}
271
impl Drop for LlamaSampler {
    fn drop(&mut self) {
        // SAFETY: `handle` is owned exclusively by this struct and the Arc
        // keeps the library alive for the duration of the call.
        // NOTE(review): unlike the model/context constructors, `new_greedy`
        // does not null-check `handle` — confirm the init symbol cannot
        // return null, or this frees a null pointer.
        unsafe {
            (self.backend.symbols.llama_sampler_free)(self.handle);
        }
    }
}
279
impl LlamaSampler {
    /// Creates a greedy (argmax) sampler.
    pub fn new_greedy(backend: Arc<LlamaLib>) -> Self {
        let handle = unsafe { (backend.symbols.llama_sampler_init_greedy)() };
        Self { backend, handle }
    }

    /// Samples a token from the logits at output index `idx` of `ctx`.
    pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> llama_cpp_sys_v3::llama_token {
        unsafe { (self.backend.symbols.llama_sampler_sample)(self.handle, ctx.handle, idx) }
    }
}
290
/// An owned llama.cpp inference context; freed in Drop via `llama_free`.
pub struct LlamaContext {
    pub backend: Arc<LlamaLib>,
    // Non-null by construction (checked in `new`).
    pub handle: *mut llama_cpp_sys_v3::llama_context,
}
296
impl Drop for LlamaContext {
    fn drop(&mut self) {
        // SAFETY: `handle` is non-null (checked at construction) and owned
        // exclusively by this struct; the Arc keeps the library alive.
        unsafe {
            (self.backend.symbols.llama_free)(self.handle);
        }
    }
}
304
305impl LlamaContext {
306 pub fn new(
307 model: &LlamaModel,
308 params: llama_cpp_sys_v3::llama_context_params,
309 ) -> Result<Self, LlamaError> {
310 let handle = unsafe { (model.backend.symbols.llama_init_from_model)(model.handle, params) };
311
312 if handle.is_null() {
313 return Err(LlamaError::ContextCreate);
314 }
315
316 Ok(Self {
317 backend: model.backend.clone(),
318 handle,
319 })
320 }
321
322 pub fn default_params(model: &LlamaModel) -> llama_cpp_sys_v3::llama_context_params {
323 unsafe { (model.backend.symbols.llama_context_default_params)() }
324 }
325
326 pub fn decode(&mut self, batch: llama_cpp_sys_v3::llama_batch) -> Result<(), LlamaError> {
327 let res = unsafe { (self.backend.symbols.llama_decode)(self.handle, batch) };
328 if res != 0 {
329 Err(LlamaError::Decode(res))
330 } else {
331 Ok(())
332 }
333 }
334}