llama_cpp_bindings/model/
params.rs1use crate::LlamaCppError;
4use crate::context::params::LlamaContextParams;
5use crate::error::{FitError, ModelParamsError};
6use crate::model::params::fit_result::FitResult;
7use crate::model::params::kv_overrides::KvOverrides;
8use crate::model::split_mode::{LlamaSplitMode, LlamaSplitModeParseError};
9use std::ffi::{CStr, c_char};
10use std::fmt::{Debug, Formatter};
11use std::pin::Pin;
12use std::ptr::null;
13
14pub mod fit_result;
15pub mod kv_overrides;
16pub mod param_override_value;
17
/// Number of slots in the fixed-size device array stored inside
/// [`LlamaModelParams`].
pub const LLAMA_CPP_MAX_DEVICES: usize = 16;
23
24pub struct LlamaModelParams {
26 pub params: llama_cpp_bindings_sys::llama_model_params,
28 kv_overrides: Vec<llama_cpp_bindings_sys::llama_model_kv_override>,
29 buft_overrides: Vec<llama_cpp_bindings_sys::llama_model_tensor_buft_override>,
30 devices: Pin<Box<[llama_cpp_bindings_sys::ggml_backend_dev_t; LLAMA_CPP_MAX_DEVICES]>>,
31 tensor_split: Vec<f32>,
32}
33
34impl Debug for LlamaModelParams {
35 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("LlamaModelParams")
37 .field("n_gpu_layers", &self.params.n_gpu_layers)
38 .field("main_gpu", &self.params.main_gpu)
39 .field("vocab_only", &self.params.vocab_only)
40 .field("use_mmap", &self.params.use_mmap)
41 .field("use_mlock", &self.params.use_mlock)
42 .field("split_mode", &self.split_mode())
43 .field("devices", &self.devices)
44 .field("kv_overrides", &"vec of kv_overrides")
45 .finish_non_exhaustive()
46 }
47}
48
49impl LlamaModelParams {
50 #[must_use]
62 pub const fn kv_overrides(&self) -> KvOverrides<'_> {
63 KvOverrides::new(self)
64 }
65
66 pub fn append_kv_override(
92 mut self: Pin<&mut Self>,
93 key: &CStr,
94 value: param_override_value::ParamOverrideValue,
95 ) -> Result<(), ModelParamsError> {
96 let kv_override = self
97 .kv_overrides
98 .get_mut(0)
99 .ok_or(ModelParamsError::NoAvailableSlot)?;
100
101 if kv_override.key[0] != 0 {
102 return Err(ModelParamsError::SlotNotEmpty);
103 }
104
105 for (i, &byte) in key.to_bytes_with_nul().iter().enumerate() {
106 kv_override.key[i] = c_char::try_from(byte).map_err(|convert_error| {
107 ModelParamsError::InvalidCharacterInKey {
108 byte,
109 reason: convert_error.to_string(),
110 }
111 })?;
112 }
113
114 kv_override.tag = value.tag();
115 kv_override.__bindgen_anon_1 = value.value();
116
117 self.push_kv_override_terminator();
118
119 Ok(())
120 }
121
122 fn push_kv_override_terminator(mut self: Pin<&mut Self>) {
127 self.params.kv_overrides = null();
128
129 self.kv_overrides
130 .push(llama_cpp_bindings_sys::llama_model_kv_override {
131 key: [0; 128],
132 tag: 0,
133 __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
134 val_i64: 0,
135 },
136 });
137
138 self.params.kv_overrides = self.kv_overrides.as_ptr();
139 }
140}
141
142impl LlamaModelParams {
143 pub fn add_cpu_moe_override(self: Pin<&mut Self>) -> Result<(), ModelParamsError> {
149 self.add_cpu_buft_override(c"\\.ffn_(up|down|gate)_(ch|)exps")
150 }
151
152 pub fn add_cpu_buft_override(
159 mut self: Pin<&mut Self>,
160 key: &CStr,
161 ) -> Result<(), ModelParamsError> {
162 let buft_override = self
163 .buft_overrides
164 .get_mut(0)
165 .ok_or(ModelParamsError::NoAvailableSlot)?;
166
167 if !buft_override.pattern.is_null() {
168 return Err(ModelParamsError::SlotNotEmpty);
169 }
170
171 for &byte in key.to_bytes_with_nul() {
172 c_char::try_from(byte).map_err(|convert_error| {
173 ModelParamsError::InvalidCharacterInKey {
174 byte,
175 reason: convert_error.to_string(),
176 }
177 })?;
178 }
179
180 buft_override.pattern = key.as_ptr();
181 buft_override.buft = unsafe { llama_cpp_bindings_sys::ggml_backend_cpu_buffer_type() };
182
183 self.push_buft_override_terminator();
184
185 Ok(())
186 }
187
188 fn push_buft_override_terminator(mut self: Pin<&mut Self>) {
193 self.params.tensor_buft_overrides = null();
194
195 self.buft_overrides
196 .push(llama_cpp_bindings_sys::llama_model_tensor_buft_override {
197 pattern: null(),
198 buft: std::ptr::null_mut(),
199 });
200
201 self.params.tensor_buft_overrides = self.buft_overrides.as_ptr();
202 }
203}
204
205impl LlamaModelParams {
206 #[must_use]
208 pub const fn n_gpu_layers(&self) -> i32 {
209 self.params.n_gpu_layers
210 }
211
212 #[must_use]
214 pub const fn main_gpu(&self) -> i32 {
215 self.params.main_gpu
216 }
217
218 #[must_use]
220 pub const fn vocab_only(&self) -> bool {
221 self.params.vocab_only
222 }
223
224 #[must_use]
226 pub const fn use_mmap(&self) -> bool {
227 self.params.use_mmap
228 }
229
230 #[must_use]
232 pub const fn use_mlock(&self) -> bool {
233 self.params.use_mlock
234 }
235
236 pub fn split_mode(&self) -> Result<LlamaSplitMode, LlamaSplitModeParseError> {
241 LlamaSplitMode::try_from(self.params.split_mode)
242 }
243
244 #[must_use]
246 pub fn devices(&self) -> Vec<usize> {
247 let mut backend_devices = Vec::new();
248 for i in 0..unsafe { llama_cpp_bindings_sys::ggml_backend_dev_count() } {
249 let dev = unsafe { llama_cpp_bindings_sys::ggml_backend_dev_get(i) };
250 backend_devices.push(dev);
251 }
252 let mut devices = Vec::new();
253 for &dev in self.devices.iter() {
254 if dev.is_null() {
255 break;
256 }
257 let matched_index = backend_devices
258 .iter()
259 .enumerate()
260 .find(|&(_i, &d)| d == dev)
261 .map(|(index, _)| index);
262
263 if let Some(index) = matched_index {
264 devices.push(index);
265 }
266 }
267 devices
268 }
269
270 #[must_use]
278 pub fn with_n_gpu_layers(mut self, n_gpu_layers: u32) -> Self {
279 let n_gpu_layers = i32::try_from(n_gpu_layers).unwrap_or(i32::MAX);
280 self.params.n_gpu_layers = n_gpu_layers;
281 self
282 }
283
284 #[must_use]
288 pub const fn with_main_gpu(mut self, main_gpu: i32) -> Self {
289 self.params.main_gpu = main_gpu;
290 self
291 }
292
293 #[must_use]
295 pub const fn with_vocab_only(mut self, vocab_only: bool) -> Self {
296 self.params.vocab_only = vocab_only;
297 self
298 }
299
300 #[must_use]
310 pub const fn with_use_mmap(mut self, use_mmap: bool) -> Self {
311 self.params.use_mmap = use_mmap;
312 self
313 }
314
315 #[must_use]
317 pub const fn no_alloc(&self) -> bool {
318 self.params.no_alloc
319 }
320
321 #[must_use]
333 pub const fn with_no_alloc(mut self, no_alloc: bool) -> Self {
334 self.params.no_alloc = no_alloc;
335 if no_alloc {
336 self.params.use_mmap = false;
337 }
338 self
339 }
340
341 #[must_use]
343 pub const fn with_use_mlock(mut self, use_mlock: bool) -> Self {
344 self.params.use_mlock = use_mlock;
345 self
346 }
347
348 #[must_use]
350 pub fn with_split_mode(mut self, split_mode: LlamaSplitMode) -> Self {
351 self.params.split_mode = split_mode.into();
352 self
353 }
354
355 pub fn with_devices(mut self, devices: &[usize]) -> Result<Self, LlamaCppError> {
366 for dev in self.devices.iter_mut() {
367 *dev = std::ptr::null_mut();
368 }
369 let max_devices = crate::max_devices().min(LLAMA_CPP_MAX_DEVICES);
370 if devices.len() > max_devices {
371 return Err(LlamaCppError::MaxDevicesExceeded(max_devices));
372 }
373 for (i, &dev) in devices.iter().enumerate() {
374 if dev >= unsafe { llama_cpp_bindings_sys::ggml_backend_dev_count() } {
375 return Err(LlamaCppError::BackendDeviceNotFound(dev));
376 }
377 let backend_dev = unsafe { llama_cpp_bindings_sys::ggml_backend_dev_get(dev) };
378 self.devices[i] = backend_dev;
379 }
380 self.params.devices = self.devices.as_mut_ptr();
381
382 Ok(self)
383 }
384}
385
386impl LlamaModelParams {
387 pub fn fit_params(
425 mut self: Pin<&mut Self>,
426 model_path: &CStr,
427 context_params: &mut LlamaContextParams,
428 margins: &mut [usize],
429 n_ctx_min: u32,
430 log_level: llama_cpp_bindings_sys::ggml_log_level,
431 ) -> Result<FitResult, FitError> {
432 let max_devices = unsafe { llama_cpp_bindings_sys::llama_max_devices() };
433 let max_buft = unsafe { llama_cpp_bindings_sys::llama_max_tensor_buft_overrides() };
434
435 self.tensor_split.clear();
436 self.tensor_split.resize(max_devices, 0.0);
437
438 self.buft_overrides.clear();
439 self.buft_overrides.resize(
440 max_buft + 1,
441 llama_cpp_bindings_sys::llama_model_tensor_buft_override {
442 pattern: null(),
443 buft: std::ptr::null_mut(),
444 },
445 );
446
447 self.params.tensor_split = null::<f32>();
448 self.params.tensor_buft_overrides = null();
449
450 let status = unsafe {
451 llama_cpp_bindings_sys::llama_rs_fit_params(
452 model_path.as_ptr(),
453 &raw mut self.params,
454 &raw mut context_params.context_params,
455 self.tensor_split.as_mut_ptr(),
456 self.buft_overrides.as_mut_ptr(),
457 margins.as_mut_ptr(),
458 n_ctx_min,
459 log_level,
460 )
461 };
462
463 match status {
464 llama_cpp_bindings_sys::LLAMA_RS_FIT_STATUS_SUCCESS => {}
465 llama_cpp_bindings_sys::LLAMA_RS_FIT_STATUS_FAILURE => return Err(FitError::Failure),
466 _ => return Err(FitError::Error),
467 }
468
469 self.params.tensor_split = self.tensor_split.as_ptr();
470 self.params.tensor_buft_overrides = self.buft_overrides.as_ptr();
471
472 Ok(FitResult {
473 n_ctx: context_params.context_params.n_ctx,
474 })
475 }
476}
477
478impl Default for LlamaModelParams {
492 fn default() -> Self {
493 let default_params = unsafe { llama_cpp_bindings_sys::llama_model_default_params() };
494 Self {
495 params: default_params,
496 kv_overrides: vec![llama_cpp_bindings_sys::llama_model_kv_override {
497 key: [0; 128],
498 tag: 0,
499 __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
500 val_i64: 0,
501 },
502 }],
503 buft_overrides: vec![llama_cpp_bindings_sys::llama_model_tensor_buft_override {
504 pattern: null(),
505 buft: std::ptr::null_mut(),
506 }],
507 devices: Box::pin([std::ptr::null_mut(); 16]),
508 tensor_split: Vec::new(),
509 }
510 }
511}
512
#[cfg(test)]
mod tests {
    use crate::model::split_mode::LlamaSplitMode;

    use super::{LLAMA_CPP_MAX_DEVICES, LlamaModelParams};

    /// Freshly defaulted parameters expose the expected library defaults.
    #[test]
    fn default_params_have_expected_values() {
        let params = LlamaModelParams::default();

        assert_eq!(params.n_gpu_layers(), -1);
        assert_eq!(params.main_gpu(), 0);
        assert!(!params.vocab_only());
        assert!(params.use_mmap());
        assert!(!params.use_mlock());
        assert_eq!(params.split_mode(), Ok(LlamaSplitMode::Layer));
        assert!(params.devices().is_empty());
    }

    /// Values that do not fit into `i32` are clamped instead of wrapping.
    #[test]
    fn n_gpu_layers_overflow_clamps_to_max() {
        let params = LlamaModelParams::default().with_n_gpu_layers(u32::MAX);

        assert_eq!(params.n_gpu_layers(), i32::MAX);
    }

    #[test]
    fn with_n_gpu_layers_sets_value() {
        let params = LlamaModelParams::default().with_n_gpu_layers(32);

        assert_eq!(params.n_gpu_layers(), 32);
    }

    #[test]
    fn with_main_gpu_sets_value() {
        let params = LlamaModelParams::default().with_main_gpu(2);

        assert_eq!(params.main_gpu(), 2);
    }

    #[test]
    fn with_split_mode_none() {
        let params = LlamaModelParams::default().with_split_mode(LlamaSplitMode::None);

        assert_eq!(params.split_mode(), Ok(LlamaSplitMode::None));
    }

    #[test]
    fn with_split_mode_row() {
        let params = LlamaModelParams::default().with_split_mode(LlamaSplitMode::Row);

        assert_eq!(params.split_mode(), Ok(LlamaSplitMode::Row));
    }

    #[test]
    fn with_vocab_only_enables() {
        let params = LlamaModelParams::default().with_vocab_only(true);

        assert!(params.vocab_only());
    }

    #[test]
    fn with_vocab_only_disables() {
        let params = LlamaModelParams::default().with_vocab_only(false);

        assert!(!params.vocab_only());
    }

    #[test]
    fn with_use_mmap_enables() {
        let params = LlamaModelParams::default().with_use_mmap(true);

        assert!(params.use_mmap());
    }

    #[test]
    fn with_use_mmap_disables() {
        let params = LlamaModelParams::default().with_use_mmap(false);

        assert!(!params.use_mmap());
    }

    #[test]
    fn with_no_alloc_enables() {
        let params = LlamaModelParams::default().with_no_alloc(true);

        assert!(params.no_alloc());
    }

    #[test]
    fn with_no_alloc_disables() {
        let params = LlamaModelParams::default().with_no_alloc(false);

        assert!(!params.no_alloc());
    }

    /// Enabling `no_alloc` must switch memory mapping off again.
    #[test]
    fn with_no_alloc_true_disables_mmap() {
        let params = LlamaModelParams::default()
            .with_use_mmap(true)
            .with_no_alloc(true);

        assert!(params.no_alloc());
        assert!(!params.use_mmap());
    }

    #[test]
    fn default_no_alloc_is_false() {
        let params = LlamaModelParams::default();

        assert!(!params.no_alloc());
    }

    #[test]
    fn with_use_mlock_enables() {
        let params = LlamaModelParams::default().with_use_mlock(true);

        assert!(params.use_mlock());
    }

    #[test]
    fn with_use_mlock_disables() {
        let params = LlamaModelParams::default().with_use_mlock(false);

        assert!(!params.use_mlock());
    }

    #[test]
    fn debug_format_contains_field_names() {
        let params = LlamaModelParams::default();
        let debug_output = format!("{params:?}");

        assert!(debug_output.contains("n_gpu_layers"));
        assert!(debug_output.contains("main_gpu"));
        assert!(debug_output.contains("vocab_only"));
        assert!(debug_output.contains("use_mmap"));
        assert!(debug_output.contains("use_mlock"));
        assert!(debug_output.contains("split_mode"));
    }

    #[test]
    fn builder_chaining_preserves_all_values() {
        let params = LlamaModelParams::default()
            .with_n_gpu_layers(10)
            .with_main_gpu(1)
            .with_split_mode(LlamaSplitMode::Row)
            .with_vocab_only(true)
            .with_use_mlock(true);

        assert_eq!(params.n_gpu_layers(), 10);
        assert_eq!(params.main_gpu(), 1);
        assert_eq!(params.split_mode(), Ok(LlamaSplitMode::Row));
        assert!(params.vocab_only());
        assert!(params.use_mlock());
    }

    #[test]
    fn with_devices_empty_list_succeeds() {
        let params = LlamaModelParams::default().with_devices(&[]);

        assert!(params.is_ok());
        assert!(params.unwrap().devices().is_empty());
    }

    #[test]
    fn with_devices_invalid_index_returns_error() {
        let result = LlamaModelParams::default().with_devices(&[999_999]);

        assert_eq!(
            result.unwrap_err(),
            crate::LlamaCppError::BackendDeviceNotFound(999_999)
        );
    }

    #[test]
    fn add_cpu_buft_override_succeeds() {
        let mut params = std::pin::pin!(LlamaModelParams::default());
        let result = params.as_mut().add_cpu_buft_override(c"test_pattern");

        assert!(result.is_ok());
    }

    #[test]
    fn add_cpu_buft_override_twice_fails_with_slot_not_empty() {
        let mut params = std::pin::pin!(LlamaModelParams::default());
        params
            .as_mut()
            .add_cpu_buft_override(c"first_pattern")
            .unwrap();
        let result = params.as_mut().add_cpu_buft_override(c"second_pattern");

        assert_eq!(
            result.unwrap_err(),
            crate::error::ModelParamsError::SlotNotEmpty
        );
    }

    #[test]
    fn add_cpu_moe_override_succeeds() {
        let mut params = std::pin::pin!(LlamaModelParams::default());
        let result = params.as_mut().add_cpu_moe_override();

        assert!(result.is_ok());
    }

    #[test]
    fn append_kv_override_twice_fails_with_slot_not_empty() {
        use crate::model::params::param_override_value::ParamOverrideValue;
        use std::ffi::CString;

        let mut params = std::pin::pin!(LlamaModelParams::default());
        let key = CString::new("first_key").unwrap();
        params
            .as_mut()
            .append_kv_override(&key, ParamOverrideValue::Int(1))
            .unwrap();

        let key2 = CString::new("second_key").unwrap();
        let result = params
            .as_mut()
            .append_kv_override(&key2, ParamOverrideValue::Int(2));

        assert_eq!(
            result.unwrap_err(),
            crate::error::ModelParamsError::SlotNotEmpty
        );
    }

    /// One device more than the device array can hold must be rejected.
    #[test]
    fn with_devices_too_many_returns_max_exceeded() {
        // Derived from the constant so the test keeps exceeding the limit if
        // the array size ever changes.
        let too_many: Vec<usize> = (0..=LLAMA_CPP_MAX_DEVICES).collect();
        let result = LlamaModelParams::default().with_devices(&too_many);

        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Max devices exceeded")
        );
    }

    #[test]
    fn with_devices_sets_devices_when_available() {
        #[cfg(feature = "dynamic-backends")]
        crate::load_backends::load_backends().unwrap();

        let dev_count = unsafe { llama_cpp_bindings_sys::ggml_backend_dev_count() };
        assert!(dev_count > 0, "Test requires at least one backend device");

        let params = LlamaModelParams::default().with_devices(&[0]).unwrap();

        assert_eq!(params.devices().len(), 1);
        assert_eq!(params.devices()[0], 0);
    }

    #[test]
    fn with_devices_invalid_index_returns_not_found() {
        let invalid_index = usize::MAX;
        let result = LlamaModelParams::default().with_devices(&[invalid_index]);

        assert!(result.unwrap_err().to_string().contains("Backend device"));
    }

    #[test]
    #[cfg(not(target_os = "windows"))]
    fn append_kv_override_with_high_byte_returns_invalid_character_error() {
        use crate::model::params::param_override_value::ParamOverrideValue;

        let key_bytes: &[u8] = b"\xff\0";
        let key = std::ffi::CStr::from_bytes_with_nul(key_bytes).unwrap();
        let mut params = std::pin::pin!(LlamaModelParams::default());
        let result = params
            .as_mut()
            .append_kv_override(key, ParamOverrideValue::Int(1));

        assert!(matches!(
            result,
            Err(crate::error::ModelParamsError::InvalidCharacterInKey { byte: 0xff, .. })
        ));
    }

    #[test]
    #[cfg(not(target_os = "windows"))]
    fn add_cpu_buft_override_with_high_byte_returns_invalid_character_error() {
        let key_bytes: &[u8] = b"\xff\0";
        let key = std::ffi::CStr::from_bytes_with_nul(key_bytes).unwrap();
        let mut params = std::pin::pin!(LlamaModelParams::default());
        let result = params.as_mut().add_cpu_buft_override(key);

        assert!(matches!(
            result,
            Err(crate::error::ModelParamsError::InvalidCharacterInKey { byte: 0xff, .. })
        ));
    }

    /// Fitting against a path that does not exist reports a hard error.
    #[test]
    #[serial_test::serial]
    fn fit_params_invalid_model_path_returns_error() {
        use crate::context::params::LlamaContextParams;
        use crate::error::FitError;
        use crate::llama_backend::LlamaBackend;

        let _backend = LlamaBackend::init();
        let mut params = std::pin::pin!(LlamaModelParams::default());
        let mut context_params = LlamaContextParams::default();
        let mut margins = vec![0usize; crate::max_devices()];

        let bogus_path = c"/nonexistent/path/to/model.gguf";
        let result = params.as_mut().fit_params(
            bogus_path,
            &mut context_params,
            &mut margins,
            512,
            llama_cpp_bindings_sys::GGML_LOG_LEVEL_NONE,
        );

        assert_eq!(result, Err(FitError::Error));
    }
}