use llama_cpp_sys_v3::{LlamaLib, LoadError};
use std::path::{Path, PathBuf};
use std::sync::Arc;

pub mod backend;
pub mod downloader;

pub use backend::Backend;

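/// Errors surfaced by this wrapper around the llama.cpp C API.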
#[derive(Debug, thiserror::Error)]
pub enum LlamaError {
    #[error("Failed to load DLL: {0}")]
    DllLoad(#[from] LoadError),
    #[error("Failed to download backend: {0}")]
    Download(#[from] downloader::DownloadError),
    #[error("Failed to initialize backend")]
    BackendInit,
    #[error("Failed to load model from file")]
    ModelLoad,
    #[error("Failed to create context")]
    ContextCreate,
    #[error("Decode error with status code {0}")]
    Decode(i32),
    #[error("Missing or empty chat template")]
    MissingChatTemplate,
}

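/// Options controlling how the llama.cpp shared library is located.
///
/// Resolution order in [`LlamaBackend::load`]: an `explicit_path` wins,
/// then the `LLAMA_DLL_PATH` environment variable, and finally the
/// downloader provides a build for the requested `backend`, cached in
/// `cache_dir` (or a default cache keyed by `app_name`).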
pub struct LoadOptions<'a> {
    pub backend: Backend,
    pub app_name: &'a str,
    pub version: Option<&'a str>,
    pub explicit_path: Option<&'a Path>,
    pub cache_dir: Option<PathBuf>,
}

pub struct LlamaBackend {
    pub lib: Arc<LlamaLib>,
}

impl Drop for LlamaBackend {
    fn drop(&mut self) {
        // Only free the backend when this is the last `Arc` referencing the
        // library; models, contexts, and samplers hold their own clones, so
        // the library stays alive while any of them is still in use.
        if Arc::strong_count(&self.lib) == 1 {
            unsafe {
                (self.lib.symbols.llama_backend_free)();
            }
        }
    }
}

impl LlamaBackend {
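    /// Resolves, loads, and initializes the llama.cpp shared library,
    /// following the resolution order documented on [`LoadOptions`].
    ///
    /// A minimal sketch (the `Backend::Cpu` variant is illustrative; use
    /// whatever variants `backend::Backend` actually defines):
    ///
    /// ```ignore
    /// let backend = LlamaBackend::load(LoadOptions {
    ///     backend: Backend::Cpu,
    ///     app_name: "my-app",
    ///     version: None,
    ///     explicit_path: None,
    ///     cache_dir: None,
    /// })?;
    /// ```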
    pub fn load(options: LoadOptions<'_>) -> Result<Self, LlamaError> {
        // Resolution order: explicit path, then LLAMA_DLL_PATH, then the
        // downloader's cached (or freshly downloaded) build.
        let dll_path = if let Some(path) = options.explicit_path {
            path.to_path_buf()
        } else if let Ok(env_path) = std::env::var("LLAMA_DLL_PATH") {
            PathBuf::from(env_path)
        } else {
            downloader::Downloader::ensure_dll(
                options.backend,
                options.app_name,
                options.version,
                options.cache_dir,
            )?
        };

        // Prepend the DLL's directory to PATH so its transitive library
        // dependencies resolve (this matters primarily on Windows).
        if let Some(parent) = dll_path.parent() {
            if let Some(path_ext) = std::env::var_os("PATH") {
                let mut paths = std::env::split_paths(&path_ext).collect::<Vec<_>>();
                let parent_buf = parent.to_path_buf();
                if !paths.contains(&parent_buf) {
                    paths.insert(0, parent_buf);
                    if let Ok(new_path) = std::env::join_paths(paths) {
                        std::env::set_var("PATH", new_path);
                    }
                }
            }
        }

        let lib = LlamaLib::open(&dll_path)?;

        // Load the ggml compute backends shipped alongside the main library,
        // falling back to the default search when there is no parent dir.
        if let Some(parent) = dll_path.parent() {
            let parent_str = parent.to_string_lossy().to_string();
            let c_parent =
                std::ffi::CString::new(parent_str).map_err(|_| LlamaError::BackendInit)?;
            unsafe {
                (lib.symbols.ggml_backend_load_all_from_path)(c_parent.as_ptr());
            }
        } else {
            unsafe {
                (lib.symbols.ggml_backend_load_all)();
            }
        }

        unsafe {
            (lib.symbols.llama_backend_init)();
        }

        Ok(Self { lib: Arc::new(lib) })
    }
}

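/// A loaded model; the raw handle is freed when the struct drops.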
pub struct LlamaModel {
    pub backend: Arc<LlamaLib>,
    pub handle: *mut llama_cpp_sys_v3::llama_model,
}

impl Drop for LlamaModel {
    fn drop(&mut self) {
        unsafe {
            (self.backend.symbols.llama_model_free)(self.handle);
        }
    }
}

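// SAFETY: the raw pointer blocks the auto-derive; a llama.cpp model is
// assumed immutable after loading and shareable across threads (contexts,
// by contrast, are not thread-safe and do not get these impls).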
unsafe impl Send for LlamaModel {}
unsafe impl Sync for LlamaModel {}

impl LlamaModel {
    pub fn load_from_file(
        backend: &LlamaBackend,
        path: &str,
        params: llama_cpp_sys_v3::llama_model_params,
    ) -> Result<Self, LlamaError> {
        let c_path = std::ffi::CString::new(path).map_err(|_| LlamaError::ModelLoad)?;
        let handle =
            unsafe { (backend.lib.symbols.llama_model_load_from_file)(c_path.as_ptr(), params) };

        if handle.is_null() {
            return Err(LlamaError::ModelLoad);
        }

        Ok(Self {
            backend: backend.lib.clone(),
            handle,
        })
    }

    pub fn default_params(backend: &LlamaBackend) -> llama_cpp_sys_v3::llama_model_params {
        unsafe { (backend.lib.symbols.llama_model_default_params)() }
    }

    pub fn get_vocab(&self) -> LlamaVocab {
        let handle = unsafe { (self.backend.symbols.llama_model_get_vocab)(self.handle) };
        LlamaVocab {
            backend: self.backend.clone(),
            handle,
        }
    }

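    /// Tokenizes `text` using llama.cpp's two-pass sizing convention: a
    /// first call with a null buffer reports the required token count
    /// (negated when the buffer is too small), and a second call fills a
    /// buffer sized to that count.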
    pub fn tokenize(
        &self,
        text: &str,
        add_special: bool,
        parse_special: bool,
    ) -> Result<Vec<llama_cpp_sys_v3::llama_token>, LlamaError> {
        let vocab = self.get_vocab();
        let c_text = std::ffi::CString::new(text).map_err(|_| LlamaError::ModelLoad)?;

        // First pass: query the required token count with a null buffer.
        let n_tokens = unsafe {
            (self.backend.symbols.llama_tokenize)(
                vocab.handle,
                c_text.as_ptr(),
                text.len() as i32,
                std::ptr::null_mut(),
                0,
                add_special,
                parse_special,
            )
        };

        // Second pass: fill a buffer sized to the reported count. A negative
        // first result means "buffer too small, this many tokens needed".
        let mut tokens = vec![0; n_tokens.unsigned_abs() as usize];
        let actual_tokens = unsafe {
            (self.backend.symbols.llama_tokenize)(
                vocab.handle,
                c_text.as_ptr(),
                text.len() as i32,
                tokens.as_mut_ptr(),
                tokens.len() as i32,
                add_special,
                parse_special,
            )
        };
        if actual_tokens < 0 {
            return Err(LlamaError::Decode(actual_tokens));
        }
        tokens.truncate(actual_tokens as usize);
        Ok(tokens)
    }

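    /// Converts a single token to its text piece, growing the buffer when
    /// the first call reports (as a negative value) that 128 bytes was not
    /// enough. Invalid UTF-8 is replaced lossily.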
    pub fn token_to_piece(&self, token: llama_cpp_sys_v3::llama_token) -> String {
        let vocab = self.get_vocab();
        let mut buf = vec![0u8; 128];
        let n = unsafe {
            (self.backend.symbols.llama_token_to_piece)(
                vocab.handle,
                token,
                buf.as_mut_ptr() as *mut std::ffi::c_char,
                buf.len() as i32,
                0,    // lstrip: do not strip leading characters
                true, // render special tokens
            )
        };

        if n < 0 {
            // Buffer was too small; -n is the required size, so the retry
            // fills the buffer exactly.
            buf.resize((-n) as usize, 0);
            unsafe {
                (self.backend.symbols.llama_token_to_piece)(
                    vocab.handle,
                    token,
                    buf.as_mut_ptr() as *mut std::ffi::c_char,
                    buf.len() as i32,
                    0,
                    true,
                );
            }
        } else {
            buf.truncate(n as usize);
        }

        String::from_utf8_lossy(&buf).to_string()
    }

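    /// Renders `messages` through the model's chat template (or an explicit
    /// `tmpl` override), appending the assistant prompt prefix when
    /// `add_ass` is true.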
    pub fn apply_chat_template(
        &self,
        tmpl: Option<&str>,
        messages: &[ChatMessage],
        add_ass: bool,
    ) -> Result<String, LlamaError> {
        let resolved_tmpl = match tmpl {
            Some(s) => s.to_string(),
            None => self
                .get_chat_template(None)
                .ok_or(LlamaError::MissingChatTemplate)?,
        };

        if resolved_tmpl.trim().is_empty() {
            return Err(LlamaError::MissingChatTemplate);
        }

        let c_tmpl = std::ffi::CString::new(resolved_tmpl).map_err(|_| LlamaError::ModelLoad)?;

        let mut c_messages = Vec::with_capacity(messages.len());
        // `c_strings` owns the CStrings so the raw pointers stored in
        // `c_messages` remain valid for the duration of the FFI calls.
        let mut c_strings = Vec::with_capacity(messages.len() * 2);

        for msg in messages {
            let role =
                std::ffi::CString::new(msg.role.as_str()).map_err(|_| LlamaError::ModelLoad)?;
            let content =
                std::ffi::CString::new(msg.content.as_str()).map_err(|_| LlamaError::ModelLoad)?;

            let msg_struct = llama_cpp_sys_v3::llama_chat_message {
                role: role.as_ptr(),
                content: content.as_ptr(),
            };

            c_messages.push(msg_struct);
            c_strings.push(role);
            c_strings.push(content);
        }

        // First pass with a null buffer returns the rendered length.
        let n_chars = unsafe {
            (self.backend.symbols.llama_chat_apply_template)(
                c_tmpl.as_ptr(),
                c_messages.as_ptr(),
                c_messages.len(),
                add_ass,
                std::ptr::null_mut(),
                0,
            )
        };

        if n_chars < 0 {
            return Err(LlamaError::Decode(n_chars));
        }

        let mut buf = vec![0u8; n_chars as usize + 1];
        let actual_chars = unsafe {
            (self.backend.symbols.llama_chat_apply_template)(
                c_tmpl.as_ptr(),
                c_messages.as_ptr(),
                c_messages.len(),
                add_ass,
                buf.as_mut_ptr() as *mut std::ffi::c_char,
                buf.len() as i32,
            )
        };

        if actual_chars < 0 {
            return Err(LlamaError::Decode(actual_chars));
        }

        buf.truncate(actual_chars as usize);
        Ok(String::from_utf8_lossy(&buf).to_string())
    }

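    /// Reads a chat template embedded in the model's metadata, or `None`
    /// if the model does not define one.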
    pub fn get_chat_template(&self, name: Option<&str>) -> Option<String> {
        let c_name = name.and_then(|s| std::ffi::CString::new(s).ok());
        let name_ptr = c_name
            .as_ref()
            .map(|c| c.as_ptr())
            .unwrap_or(std::ptr::null());

        let mut buf = vec![0u8; 1024];
        let n = unsafe {
            (self.backend.symbols.llama_model_chat_template)(
                self.handle,
                name_ptr,
                buf.as_mut_ptr() as *mut std::ffi::c_char,
                buf.len(),
            )
        };

        if n < 0 {
            return None;
        }

        // A return value >= the buffer length means the template was
        // truncated; retry with a buffer sized to the reported length.
        if n as usize >= buf.len() {
            buf.resize(n as usize + 1, 0);
            unsafe {
                (self.backend.symbols.llama_model_chat_template)(
                    self.handle,
                    name_ptr,
                    buf.as_mut_ptr() as *mut std::ffi::c_char,
                    buf.len(),
                );
            }
        }

        buf.truncate(n as usize);
        Some(String::from_utf8_lossy(&buf).to_string())
    }
}

/// A single chat turn passed to [`LlamaModel::apply_chat_template`].
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

pub struct LlamaVocab {
    pub backend: Arc<LlamaLib>,
    pub handle: *const llama_cpp_sys_v3::llama_vocab,
}

impl LlamaVocab {
    pub fn bos(&self) -> llama_cpp_sys_v3::llama_token {
        unsafe { (self.backend.symbols.llama_vocab_bos)(self.handle) }
    }

    pub fn eos(&self) -> llama_cpp_sys_v3::llama_token {
        unsafe { (self.backend.symbols.llama_vocab_eos)(self.handle) }
    }

    pub fn is_eog(&self, token: llama_cpp_sys_v3::llama_token) -> bool {
        unsafe { (self.backend.symbols.llama_vocab_is_eog)(self.handle, token) }
    }
}

pub struct LlamaSampler {
    pub backend: Arc<LlamaLib>,
    pub handle: *mut llama_cpp_sys_v3::llama_sampler,
}

impl Drop for LlamaSampler {
    fn drop(&mut self) {
        unsafe {
            (self.backend.symbols.llama_sampler_free)(self.handle);
        }
    }
}

impl LlamaSampler {
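    /// Creates a greedy sampler, which always picks the highest-probability
    /// token; deterministic, and the simplest choice for smoke tests.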
    pub fn new_greedy(backend: Arc<LlamaLib>) -> Self {
        let handle = unsafe { (backend.symbols.llama_sampler_init_greedy)() };
        Self { backend, handle }
    }

    pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> llama_cpp_sys_v3::llama_token {
        unsafe { (self.backend.symbols.llama_sampler_sample)(self.handle, ctx.handle, idx) }
    }
}

pub struct LlamaContext {
    pub backend: Arc<LlamaLib>,
    pub handle: *mut llama_cpp_sys_v3::llama_context,
}

impl Drop for LlamaContext {
    fn drop(&mut self) {
        unsafe {
            (self.backend.symbols.llama_free)(self.handle);
        }
    }
}

impl LlamaContext {
    pub fn new(
        model: &LlamaModel,
        params: llama_cpp_sys_v3::llama_context_params,
    ) -> Result<Self, LlamaError> {
        let handle =
            unsafe { (model.backend.symbols.llama_init_from_model)(model.handle, params) };

        if handle.is_null() {
            return Err(LlamaError::ContextCreate);
        }

        Ok(Self {
            backend: model.backend.clone(),
            handle,
        })
    }

    pub fn default_params(model: &LlamaModel) -> llama_cpp_sys_v3::llama_context_params {
        unsafe { (model.backend.symbols.llama_context_default_params)() }
    }

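    /// Runs the model on a prepared batch, returning `LlamaError::Decode`
    /// with the non-zero status code on failure.
    ///
    /// A sketch of a typical greedy generation loop built from this file's
    /// types (illustrative only; `prompt_tokens`, sequence id `0`, and the
    /// `-1` "last logits" index follow common llama.cpp usage):
    ///
    /// ```ignore
    /// // Feed the prompt; request logits only for the final token.
    /// for (i, &tok) in prompt_tokens.iter().enumerate() {
    ///     batch.add(tok, i as i32, &[0], i + 1 == prompt_tokens.len());
    /// }
    /// ctx.decode(&batch)?;
    ///
    /// let vocab = model.get_vocab();
    /// let mut pos = prompt_tokens.len() as i32;
    /// loop {
    ///     let tok = sampler.sample(&ctx, -1); // sample from the last logits
    ///     if vocab.is_eog(tok) {
    ///         break;
    ///     }
    ///     print!("{}", model.token_to_piece(tok));
    ///     batch.clear();
    ///     batch.add(tok, pos, &[0], true);
    ///     ctx.decode(&batch)?;
    ///     pos += 1;
    /// }
    /// ```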
    pub fn decode(&mut self, batch: &LlamaBatch) -> Result<(), LlamaError> {
        let res = unsafe { (self.backend.symbols.llama_decode)(self.handle, batch.handle) };
        if res != 0 {
            Err(LlamaError::Decode(res))
        } else {
            Ok(())
        }
    }

    /// Clears the KV cache via the memory API (`llama_get_memory` plus
    /// `llama_memory_clear`), which supersedes the older `llama_kv_cache_*`
    /// entry points in recent llama.cpp.
    pub fn kv_cache_clear(&mut self) {
        unsafe {
            let memory = (self.backend.symbols.llama_get_memory)(self.handle);
            (self.backend.symbols.llama_memory_clear)(memory, true);
        }
    }
}

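/// An owned `llama_batch` in llama.cpp's structure-of-arrays layout: the
/// token, position, sequence-id, and logits arrays are parallel, with
/// `n_tokens` entries currently in use.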
pub struct LlamaBatch {
    pub backend: Arc<LlamaLib>,
    pub handle: llama_cpp_sys_v3::llama_batch,
}

impl Drop for LlamaBatch {
    fn drop(&mut self) {
        unsafe {
            (self.backend.symbols.llama_batch_free)(self.handle);
        }
    }
}

impl LlamaBatch {
    pub fn new(backend: Arc<LlamaLib>, n_tokens: i32, embd: i32, n_seq_max: i32) -> Self {
        let handle = unsafe { (backend.symbols.llama_batch_init)(n_tokens, embd, n_seq_max) };
        Self { backend, handle }
    }

    /// Resets the batch for reuse without freeing its buffers.
    pub fn clear(&mut self) {
        self.handle.n_tokens = 0;
    }

    pub fn add(
        &mut self,
        token: llama_cpp_sys_v3::llama_token,
        pos: llama_cpp_sys_v3::llama_pos,
        seq_ids: &[i32],
        logits: bool,
    ) {
        let n = self.handle.n_tokens as usize;
        // Write one entry into each of the batch's parallel arrays. The
        // caller must not push more tokens than `llama_batch_init` allocated,
        // nor more sequence ids than `n_seq_max`; these raw writes are not
        // bounds-checked by llama.cpp.
        unsafe {
            *self.handle.token.add(n) = token;
            *self.handle.pos.add(n) = pos;
            *self.handle.n_seq_id.add(n) = seq_ids.len() as i32;
            for (j, &seq_id) in seq_ids.iter().enumerate() {
                *(*self.handle.seq_id.add(n)).add(j) = seq_id;
            }
            *self.handle.logits.add(n) = if logits { 1 } else { 0 };
        }
        self.handle.n_tokens += 1;
    }
}