use llama_cpp_sys_v3::{LlamaLib, LoadError};
use std::path::{Path, PathBuf};
use std::sync::Arc;

pub mod backend;
pub mod downloader;

pub use backend::Backend;

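/// Errors that can occur while locating or loading the llama.cpp shared
/// library, loading a model, or running inference through it.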
#[derive(Debug, thiserror::Error)]
pub enum LlamaError {
    #[error("Failed to load DLL: {0}")]
    DllLoad(#[from] LoadError),
    #[error("Failed to download backend: {0}")]
    Download(#[from] downloader::DownloadError),
    #[error("Failed to initialize backend")]
    BackendInit,
    #[error("Failed to load model from file")]
    ModelLoad,
    #[error("Failed to create context")]
    ContextCreate,
    #[error("Decode error with status code {0}")]
    Decode(i32),
    #[error("Missing or empty chat template")]
    MissingChatTemplate,
    #[error("Invalid string (contains internal null byte)")]
    InvalidString,
    #[error("Failed to apply chat template (check template syntax)")]
    TemplateApply,
}

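/// Options controlling how the llama.cpp shared library is resolved.
///
/// [`LlamaBackend::load`] tries `explicit_path` first, then the
/// `LLAMA_DLL_PATH` environment variable, and finally downloads a build for
/// `backend` into `cache_dir` via the `downloader` module.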
pub struct LoadOptions<'a> {
    pub backend: Backend,
    pub app_name: &'a str,
    pub version: Option<&'a str>,
    pub explicit_path: Option<&'a Path>,
    pub cache_dir: Option<PathBuf>,
}

pub struct LlamaBackend {
    pub lib: Arc<LlamaLib>,
}

impl Drop for LlamaBackend {
    fn drop(&mut self) {
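        // Free the global llama.cpp backend state only when this is the last
        // Arc handle to the library; models and contexts hold their own clones.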
        if Arc::strong_count(&self.lib) == 1 {
            unsafe {
                (self.lib.symbols.llama_backend_free)();
            }
        }
    }
}

impl LlamaBackend {
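    /// Resolves the llama.cpp shared library, opens it, registers the ggml
    /// device backends, and calls `llama_backend_init`.
    ///
    /// A minimal usage sketch; the `Backend::Cpu` variant and the literal
    /// option values below are illustrative assumptions, not part of this API
    /// surface:
    ///
    /// ```ignore
    /// let backend = LlamaBackend::load(LoadOptions {
    ///     backend: Backend::Cpu,      // assumed variant of `backend::Backend`
    ///     app_name: "my-app",
    ///     version: None,
    ///     explicit_path: None,        // fall back to LLAMA_DLL_PATH / download
    ///     cache_dir: None,
    /// })?;
    /// ```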
    pub fn load(options: LoadOptions<'_>) -> Result<Self, LlamaError> {
        let dll_path = if let Some(path) = options.explicit_path {
            path.to_path_buf()
        } else if let Ok(env_path) = std::env::var("LLAMA_DLL_PATH") {
            PathBuf::from(env_path)
        } else {
            downloader::Downloader::ensure_dll(
                options.backend,
                options.app_name,
                options.version,
                options.cache_dir,
            )?
        };

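        // Prepend the DLL's directory to PATH (if it is not already there) so
        // the OS loader can resolve dependent libraries shipped next to it.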
        if let Some(parent) = dll_path.parent() {
            if let Some(path_ext) = std::env::var_os("PATH") {
                let mut paths = std::env::split_paths(&path_ext).collect::<Vec<_>>();
                let parent_buf = parent.to_path_buf();
                if !paths.contains(&parent_buf) {
                    paths.insert(0, parent_buf);
                    if let Ok(new_path) = std::env::join_paths(paths) {
                        std::env::set_var("PATH", new_path);
                    }
                }
            }
        }

        let lib = LlamaLib::open(&dll_path)?;

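        // Register every ggml compute backend found next to the main library
        // (CPU, CUDA, etc.); without a parent directory, fall back to ggml's
        // default search behaviour.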
        if let Some(parent) = dll_path.parent() {
            let parent_str = parent.to_string_lossy().to_string();
            let c_parent = std::ffi::CString::new(parent_str).unwrap();
            unsafe {
                (lib.symbols.ggml_backend_load_all_from_path)(c_parent.as_ptr());
            }
        } else {
            unsafe {
                (lib.symbols.ggml_backend_load_all)();
            }
        }

        unsafe {
            (lib.symbols.llama_backend_init)();
        }

        Ok(Self { lib: Arc::new(lib) })
    }
}

pub struct LlamaModel {
    pub backend: Arc<LlamaLib>,
    pub handle: *mut llama_cpp_sys_v3::llama_model,
}

impl Drop for LlamaModel {
    fn drop(&mut self) {
        unsafe {
            (self.backend.symbols.llama_model_free)(self.handle);
        }
    }
}

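// SAFETY (assumption): llama.cpp does not mutate a model after it has been
// loaded, so the raw model pointer can be sent to and shared between threads.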
unsafe impl Send for LlamaModel {}
unsafe impl Sync for LlamaModel {}

impl LlamaModel {
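    /// Loads a model from `path` with the given parameters; a null handle from
    /// llama.cpp is reported as [`LlamaError::ModelLoad`].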
    pub fn load_from_file(
        backend: &LlamaBackend,
        path: &str,
        params: llama_cpp_sys_v3::llama_model_params,
    ) -> Result<Self, LlamaError> {
        let c_path = std::ffi::CString::new(path).map_err(|_| LlamaError::InvalidString)?;
        let handle =
            unsafe { (backend.lib.symbols.llama_model_load_from_file)(c_path.as_ptr(), params) };

        if handle.is_null() {
            return Err(LlamaError::ModelLoad);
        }

        Ok(Self {
            backend: backend.lib.clone(),
            handle,
        })
    }

    pub fn default_params(backend: &LlamaBackend) -> llama_cpp_sys_v3::llama_model_params {
        unsafe { (backend.lib.symbols.llama_model_default_params)() }
    }

    pub fn get_vocab(&self) -> LlamaVocab {
        let handle = unsafe { (self.backend.symbols.llama_model_get_vocab)(self.handle) };
        LlamaVocab {
            backend: self.backend.clone(),
            handle,
        }
    }

    pub fn tokenize(
        &self,
        text: &str,
        add_special: bool,
        parse_special: bool,
    ) -> Result<Vec<llama_cpp_sys_v3::llama_token>, LlamaError> {
        let vocab = self.get_vocab();
        let c_text = std::ffi::CString::new(text).map_err(|_| LlamaError::InvalidString)?;

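        // Two-pass tokenization: probe with a null buffer to learn the required
        // token count (reported as a negative value), then tokenize for real.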
        let n_tokens = unsafe {
            (self.backend.symbols.llama_tokenize)(
                vocab.handle,
                c_text.as_ptr(),
                text.len() as i32,
                std::ptr::null_mut(),
                0,
                add_special,
                parse_special,
            )
        };

        let mut tokens = vec![0; n_tokens.unsigned_abs() as usize];
        let actual_tokens = unsafe {
            (self.backend.symbols.llama_tokenize)(
                vocab.handle,
                c_text.as_ptr(),
                text.len() as i32,
                tokens.as_mut_ptr(),
                tokens.len() as i32,
                add_special,
                parse_special,
            )
        };
        if actual_tokens < 0 {
            return Err(LlamaError::Decode(actual_tokens));
        }
        tokens.truncate(actual_tokens as usize);
        Ok(tokens)
    }

    pub fn token_to_piece(&self, token: llama_cpp_sys_v3::llama_token) -> String {
        let vocab = self.get_vocab();
        let mut buf = vec![0u8; 128];
        let n = unsafe {
            (self.backend.symbols.llama_token_to_piece)(
                vocab.handle,
                token,
                buf.as_mut_ptr() as *mut std::ffi::c_char,
                buf.len() as i32,
                0,
                true,
            )
        };

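        // A negative return value means the 128-byte buffer was too small; its
        // absolute value is the exact size needed, so grow the buffer and retry.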
        if n < 0 {
            buf.resize((-n) as usize, 0);
            unsafe {
                (self.backend.symbols.llama_token_to_piece)(
                    vocab.handle,
                    token,
                    buf.as_mut_ptr() as *mut std::ffi::c_char,
                    buf.len() as i32,
                    0,
                    true,
                );
            }
        } else {
            buf.truncate(n as usize);
        }

        String::from_utf8_lossy(&buf).to_string()
    }

    pub fn apply_chat_template(
        &self,
        tmpl: Option<&str>,
        messages: &[ChatMessage],
        add_ass: bool,
    ) -> Result<String, LlamaError> {
        let resolved_tmpl = match tmpl {
            Some(s) => s.to_string(),
            None => self
                .get_chat_template(None)
                .ok_or(LlamaError::MissingChatTemplate)?,
        };

        if resolved_tmpl.trim().is_empty() {
            return Err(LlamaError::MissingChatTemplate);
        }

        let c_tmpl = std::ffi::CString::new(resolved_tmpl).map_err(|_| LlamaError::InvalidString)?;

        let mut c_messages = Vec::with_capacity(messages.len());
        let mut c_strings = Vec::with_capacity(messages.len() * 2);

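        // `c_strings` keeps the CStrings alive so the raw pointers stored in
        // `c_messages` stay valid until the template call below has finished.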
        for msg in messages {
            let role =
                std::ffi::CString::new(msg.role.as_str()).map_err(|_| LlamaError::InvalidString)?;
            let content =
                std::ffi::CString::new(msg.content.as_str()).map_err(|_| LlamaError::InvalidString)?;

            let msg_struct = llama_cpp_sys_v3::llama_chat_message {
                role: role.as_ptr(),
                content: content.as_ptr(),
            };

            c_messages.push(msg_struct);
            c_strings.push(role);
            c_strings.push(content);
        }

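        // First call with a null buffer to ask for the rendered prompt length,
        // then render into a buffer of exactly that size.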
        let n_chars = unsafe {
            (self.backend.symbols.llama_chat_apply_template)(
                c_tmpl.as_ptr(),
                c_messages.as_ptr(),
                c_messages.len(),
                add_ass,
                std::ptr::null_mut(),
                0,
            )
        };

        if n_chars < 0 {
            return Err(LlamaError::TemplateApply);
        }

        let mut buf = vec![0u8; n_chars as usize + 1];
        let actual_chars = unsafe {
            (self.backend.symbols.llama_chat_apply_template)(
                c_tmpl.as_ptr(),
                c_messages.as_ptr(),
                c_messages.len(),
                add_ass,
                buf.as_mut_ptr() as *mut std::ffi::c_char,
                buf.len() as i32,
            )
        };

        if actual_chars < 0 {
            return Err(LlamaError::TemplateApply);
        }

        buf.truncate(actual_chars as usize);
        Ok(String::from_utf8_lossy(&buf).to_string())
    }

    pub fn get_chat_template(&self, name: Option<&str>) -> Option<String> {
        let c_name = name.and_then(|s| std::ffi::CString::new(s).ok());
        let name_ptr = c_name
            .as_ref()
            .map(|c| c.as_ptr())
            .unwrap_or(std::ptr::null());

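        // Start with a 1 KiB buffer; if the template is longer, the call
        // reports the required length and we retry with a larger buffer.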
        let mut buf = vec![0u8; 1024];
        let n = unsafe {
            (self.backend.symbols.llama_model_chat_template)(
                self.handle,
                name_ptr,
                buf.as_mut_ptr() as *mut std::ffi::c_char,
                buf.len(),
            )
        };

        if n < 0 {
            return None;
        }

        if n as usize >= buf.len() {
            buf.resize(n as usize + 1, 0);
            unsafe {
                (self.backend.symbols.llama_model_chat_template)(
                    self.handle,
                    name_ptr,
                    buf.as_mut_ptr() as *mut std::ffi::c_char,
                    buf.len(),
                );
            }
        }

        buf.truncate(n as usize);
        Some(String::from_utf8_lossy(&buf).to_string())
    }
}

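/// A single chat turn (role plus content) passed to
/// [`LlamaModel::apply_chat_template`].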
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

pub struct LlamaVocab {
    pub backend: Arc<LlamaLib>,
    pub handle: *const llama_cpp_sys_v3::llama_vocab,
}

impl LlamaVocab {
    pub fn bos(&self) -> llama_cpp_sys_v3::llama_token {
        unsafe { (self.backend.symbols.llama_vocab_bos)(self.handle) }
    }

    pub fn eos(&self) -> llama_cpp_sys_v3::llama_token {
        unsafe { (self.backend.symbols.llama_vocab_eos)(self.handle) }
    }

    pub fn is_eog(&self, token: llama_cpp_sys_v3::llama_token) -> bool {
        unsafe { (self.backend.symbols.llama_vocab_is_eog)(self.handle, token) }
    }
}

pub struct LlamaSampler {
    pub backend: Arc<LlamaLib>,
    pub handle: *mut llama_cpp_sys_v3::llama_sampler,
}

impl Drop for LlamaSampler {
    fn drop(&mut self) {
        unsafe {
            (self.backend.symbols.llama_sampler_free)(self.handle);
        }
    }
}

impl LlamaSampler {
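    /// Creates an empty sampler chain; compose it by pushing individual
    /// samplers with [`LlamaSampler::add`].
    ///
    /// A sketch of building a chain; `lib` (an `Arc<LlamaLib>`) and `ctx` (a
    /// [`LlamaContext`]) are assumed to exist, and the parameter values are
    /// illustrative rather than recommendations:
    ///
    /// ```ignore
    /// let mut chain = LlamaSampler::new_chain(lib.clone(), true);
    /// chain.add(LlamaSampler::new_top_k(lib.clone(), 40));
    /// chain.add(LlamaSampler::new_top_p(lib.clone(), 0.95, 1));
    /// chain.add(LlamaSampler::new_temp(lib.clone(), 0.8));
    /// chain.add(LlamaSampler::new_dist(lib.clone(), 42));
    /// let token = chain.sample(&ctx, -1); // sample from the last logits
    /// ```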
    pub fn new_chain(backend: Arc<LlamaLib>, no_perf: bool) -> Self {
        let params = llama_cpp_sys_v3::llama_sampler_chain_params { no_perf };
        let handle = unsafe { (backend.symbols.llama_sampler_chain_init)(params) };
        Self { backend, handle }
    }

    pub fn new_greedy(backend: Arc<LlamaLib>) -> Self {
        let handle = unsafe { (backend.symbols.llama_sampler_init_greedy)() };
        Self { backend, handle }
    }

    pub fn new_temp(backend: Arc<LlamaLib>, temp: f32) -> Self {
        let handle = unsafe { (backend.symbols.llama_sampler_init_temp)(temp) };
        Self { backend, handle }
    }

    pub fn new_top_k(backend: Arc<LlamaLib>, k: i32) -> Self {
        let handle = unsafe { (backend.symbols.llama_sampler_init_top_k)(k) };
        Self { backend, handle }
    }

    pub fn new_top_p(backend: Arc<LlamaLib>, p: f32, min_keep: usize) -> Self {
        let handle = unsafe { (backend.symbols.llama_sampler_init_top_p)(p, min_keep) };
        Self { backend, handle }
    }

    pub fn new_min_p(backend: Arc<LlamaLib>, p: f32, min_keep: usize) -> Self {
        let handle = unsafe { (backend.symbols.llama_sampler_init_min_p)(p, min_keep) };
        Self { backend, handle }
    }

    pub fn new_typical(backend: Arc<LlamaLib>, p: f32, min_keep: usize) -> Self {
        let handle = unsafe { (backend.symbols.llama_sampler_init_typical)(p, min_keep) };
        Self { backend, handle }
    }

    pub fn new_mirostat_v2(backend: Arc<LlamaLib>, seed: u32, tau: f32, eta: f32) -> Self {
        let handle = unsafe { (backend.symbols.llama_sampler_init_mirostat_v2)(seed, tau, eta) };
        Self { backend, handle }
    }

    pub fn new_penalties(
        backend: Arc<LlamaLib>,
        last_n: i32,
        repeat: f32,
        freq: f32,
        present: f32,
    ) -> Self {
        let handle = unsafe {
            (backend.symbols.llama_sampler_init_penalties)(last_n, repeat, freq, present)
        };
        Self { backend, handle }
    }

    pub fn new_dist(backend: Arc<LlamaLib>, seed: u32) -> Self {
        let handle = unsafe { (backend.symbols.llama_sampler_init_dist)(seed) };
        Self { backend, handle }
    }

    pub fn add(&mut self, other: LlamaSampler) {
        unsafe {
            (self.backend.symbols.llama_sampler_chain_add)(self.handle, other.handle);
        }
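        // The chain now owns `other`'s handle; skip its Drop so the sampler is
        // not freed twice.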
        std::mem::forget(other);
    }

    pub fn sample(&self, ctx: &LlamaContext, idx: i32) -> llama_cpp_sys_v3::llama_token {
        unsafe { (self.backend.symbols.llama_sampler_sample)(self.handle, ctx.handle, idx) }
    }

    pub fn accept(&self, token: llama_cpp_sys_v3::llama_token) {
        unsafe {
            (self.backend.symbols.llama_sampler_accept)(self.handle, token);
        }
    }
}

pub struct LlamaContext {
    pub backend: Arc<LlamaLib>,
    pub handle: *mut llama_cpp_sys_v3::llama_context,
}

impl Drop for LlamaContext {
    fn drop(&mut self) {
        unsafe {
            (self.backend.symbols.llama_free)(self.handle);
        }
    }
}

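// SAFETY (assumption): the context is only mutated through `&mut self`
// methods, so the borrow checker already serializes access to the raw pointer.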
unsafe impl Send for LlamaContext {}
unsafe impl Sync for LlamaContext {}

impl LlamaContext {
    pub fn new(
        model: &LlamaModel,
        params: llama_cpp_sys_v3::llama_context_params,
    ) -> Result<Self, LlamaError> {
        let handle = unsafe { (model.backend.symbols.llama_init_from_model)(model.handle, params) };

        if handle.is_null() {
            return Err(LlamaError::ContextCreate);
        }

        Ok(Self {
            backend: model.backend.clone(),
            handle,
        })
    }

    pub fn default_params(model: &LlamaModel) -> llama_cpp_sys_v3::llama_context_params {
        unsafe { (model.backend.symbols.llama_context_default_params)() }
    }

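    /// Runs the model on the tokens queued in `batch`; a non-zero status from
    /// `llama_decode` is surfaced as [`LlamaError::Decode`].
    ///
    /// A sketch of feeding a prompt and sampling the next token; `lib`,
    /// `tokens`, `ctx`, and `chain` are assumed to already exist:
    ///
    /// ```ignore
    /// let mut batch = LlamaBatch::new(lib.clone(), 512, 0, 1);
    /// for (i, &tok) in tokens.iter().enumerate() {
    ///     // Request logits only for the last prompt token.
    ///     batch.add(tok, i as i32, &[0], i == tokens.len() - 1);
    /// }
    /// ctx.decode(&batch)?;
    /// let next = chain.sample(&ctx, -1);
    /// ```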
    pub fn decode(&mut self, batch: &LlamaBatch) -> Result<(), LlamaError> {
        let res = unsafe { (self.backend.symbols.llama_decode)(self.handle, batch.handle) };
        if res != 0 {
            Err(LlamaError::Decode(res))
        } else {
            Ok(())
        }
    }

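    /// Clears the context's entire KV cache through `llama_get_memory` /
    /// `llama_memory_clear`.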
    pub fn kv_cache_clear(&mut self) {
        unsafe {
            let memory = (self.backend.symbols.llama_get_memory)(self.handle);
            (self.backend.symbols.llama_memory_clear)(memory, true);
        }
    }

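    /// Removes sequence `seq_id`'s entries in the position range `[p0, p1)`
    /// from the KV cache via `llama_memory_seq_rm`.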
    pub fn kv_cache_seq_rm(
        &mut self,
        seq_id: llama_cpp_sys_v3::llama_seq_id,
        p0: llama_cpp_sys_v3::llama_pos,
        p1: llama_cpp_sys_v3::llama_pos,
    ) -> bool {
        unsafe {
            let memory = (self.backend.symbols.llama_get_memory)(self.handle);
            (self.backend.symbols.llama_memory_seq_rm)(memory, seq_id, p0, p1)
        }
    }
}

pub struct LlamaBatch {
    pub backend: Arc<LlamaLib>,
    pub handle: llama_cpp_sys_v3::llama_batch,
}

impl Drop for LlamaBatch {
    fn drop(&mut self) {
        unsafe {
            (self.backend.symbols.llama_batch_free)(self.handle);
        }
    }
}

impl LlamaBatch {
    pub fn new(backend: Arc<LlamaLib>, n_tokens: i32, embd: i32, n_seq_max: i32) -> Self {
        let handle = unsafe { (backend.symbols.llama_batch_init)(n_tokens, embd, n_seq_max) };
        Self { backend, handle }
    }

    pub fn clear(&mut self) {
        self.handle.n_tokens = 0;
    }

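    /// Appends one token to the batch, writing directly into the C-allocated
    /// batch arrays. The caller must not push more tokens than the capacity
    /// given to `llama_batch_init`, nor more sequence ids than `n_seq_max`.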
    pub fn add(
        &mut self,
        token: llama_cpp_sys_v3::llama_token,
        pos: llama_cpp_sys_v3::llama_pos,
        seq_ids: &[i32],
        logits: bool,
    ) {
        let n = self.handle.n_tokens as usize;
        unsafe {
            *self.handle.token.add(n) = token;
            *self.handle.pos.add(n) = pos;
            *self.handle.n_seq_id.add(n) = seq_ids.len() as i32;
            for (j, &seq_id) in seq_ids.iter().enumerate() {
                *(*self.handle.seq_id.add(n)).add(j) = seq_id;
            }
            *self.handle.logits.add(n) = if logits { 1 } else { 0 };
        }
        self.handle.n_tokens += 1;
    }
}