use crate::context::params::LlamaContextParams;
use crate::model::params::kv_overrides::KvOverrides;
use crate::LlamaCppError;
use std::ffi::{c_char, CStr};
use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::ptr::null;

pub mod kv_overrides;

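/// Result of a successful [`LlamaModelParams::fit_params`] call.
///
/// `n_ctx` is the context size the fitting procedure settled on.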
#[derive(Debug, Clone)]
pub struct FitResult {
    pub n_ctx: u32,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
pub enum FitError {
    #[error("could not find allocations that fit available memory")]
    Failure,
    #[error("hard error during parameter fitting")]
    Error,
}

#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
const LLAMA_SPLIT_MODE_NONE: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_NONE as i8;
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
const LLAMA_SPLIT_MODE_LAYER: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_LAYER as i8;
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
const LLAMA_SPLIT_MODE_ROW: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_ROW as i8;
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
const LLAMA_SPLIT_MODE_TENSOR: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_TENSOR as i8;

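/// How a model is split across multiple GPUs; mirrors llama.cpp's
/// `llama_split_mode` values re-exported through `llama_cpp_sys_2`.
///
/// A conversion round trip, mirroring the test at the bottom of this file
/// (the module path is assumed to be the crate's public one):
///
/// ```
/// use llama_cpp_2::model::params::LlamaSplitMode;
///
/// // Fallible conversion from the raw llama.cpp value...
/// let mode = LlamaSplitMode::try_from(llama_cpp_sys_2::LLAMA_SPLIT_MODE_LAYER)
///     .expect("LLAMA_SPLIT_MODE_LAYER is a known split mode");
/// assert_eq!(mode, LlamaSplitMode::Layer);
/// // ...and the infallible conversion back.
/// assert_eq!(i32::from(mode), llama_cpp_sys_2::LLAMA_SPLIT_MODE_LAYER as i32);
/// ```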
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaSplitMode {
    None = LLAMA_SPLIT_MODE_NONE,
    Layer = LLAMA_SPLIT_MODE_LAYER,
    Row = LLAMA_SPLIT_MODE_ROW,
    Tensor = LLAMA_SPLIT_MODE_TENSOR,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LlamaSplitModeParseError(pub i32);

impl TryFrom<i32> for LlamaSplitMode {
    type Error = LlamaSplitModeParseError;

    fn try_from(value: i32) -> Result<Self, Self::Error> {
        let i8_value = value
            .try_into()
            .map_err(|_| LlamaSplitModeParseError(value))?;
        match i8_value {
            LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
            LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
            LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
            LLAMA_SPLIT_MODE_TENSOR => Ok(Self::Tensor),
            _ => Err(LlamaSplitModeParseError(value)),
        }
    }
}

impl TryFrom<u32> for LlamaSplitMode {
    type Error = LlamaSplitModeParseError;

    fn try_from(value: u32) -> Result<Self, Self::Error> {
        // Saturate oversized values so the error still carries an `i32`,
        // then reuse the `i32` conversion instead of duplicating the match.
        let value = i32::try_from(value).unwrap_or(i32::MAX);
        Self::try_from(value)
    }
}

impl From<LlamaSplitMode> for i32 {
    fn from(value: LlamaSplitMode) -> Self {
        match value {
            LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE.into(),
            LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER.into(),
            LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW.into(),
            LlamaSplitMode::Tensor => LLAMA_SPLIT_MODE_TENSOR.into(),
        }
    }
}

impl From<LlamaSplitMode> for u32 {
    fn from(value: LlamaSplitMode) -> Self {
        match value {
            LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE as u32,
            LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER as u32,
            LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW as u32,
            LlamaSplitMode::Tensor => LLAMA_SPLIT_MODE_TENSOR as u32,
        }
    }
}

impl Default for LlamaSplitMode {
    fn default() -> Self {
        LlamaSplitMode::Layer
    }
}

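/// Maximum number of backend devices tracked by [`LlamaModelParams`];
/// matches the fixed-size device array it owns.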
pub const LLAMA_CPP_MAX_DEVICES: usize = 16;

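/// A safe wrapper over `llama_model_params`; owns the KV-override,
/// tensor-buft-override, device, and tensor-split buffers that the raw
/// params struct points into.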
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModelParams {
    pub(crate) params: llama_cpp_sys_2::llama_model_params,
    kv_overrides: Vec<llama_cpp_sys_2::llama_model_kv_override>,
    buft_overrides: Vec<llama_cpp_sys_2::llama_model_tensor_buft_override>,
    devices: Pin<Box<[llama_cpp_sys_2::ggml_backend_dev_t; LLAMA_CPP_MAX_DEVICES]>>,
    tensor_split: Vec<f32>,
}

impl Debug for LlamaModelParams {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaModelParams")
            .field("n_gpu_layers", &self.params.n_gpu_layers)
            .field("main_gpu", &self.params.main_gpu)
            .field("vocab_only", &self.params.vocab_only)
            .field("use_mmap", &self.params.use_mmap)
            .field("use_mlock", &self.params.use_mlock)
            .field("split_mode", &self.split_mode())
            .field("devices", &self.devices)
            .field("kv_overrides", &"vec of kv_overrides")
            .finish()
    }
}

impl LlamaModelParams {
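    /// Returns an iterator over the key-value overrides currently set on
    /// these parameters.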
    #[must_use]
    pub fn kv_overrides(&self) -> KvOverrides<'_> {
        KvOverrides::new(self)
    }

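    /// Appends a key-value override to the model parameters. The receiver is
    /// pinned because the override array is handed to llama.cpp as a raw
    /// pointer into storage owned by this struct.
    ///
    /// A minimal sketch; the key is illustrative, and `ParamOverrideValue::Int`
    /// is assumed to be the integer variant defined in the `kv_overrides`
    /// module:
    ///
    /// ```no_run
    /// use std::pin::pin;
    /// use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
    /// use llama_cpp_2::model::params::LlamaModelParams;
    ///
    /// let mut params = pin!(LlamaModelParams::default());
    /// // Hypothetical key; use whichever metadata key your model understands.
    /// params
    ///     .as_mut()
    ///     .append_kv_override(c"some.key", ParamOverrideValue::Int(2));
    /// ```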
    #[allow(clippy::missing_panics_doc)]
    pub fn append_kv_override(
        mut self: Pin<&mut Self>,
        key: &CStr,
        value: kv_overrides::ParamOverrideValue,
    ) {
        // The vector always ends with an empty terminator entry; fill it in.
        let kv_override = self
            .kv_overrides
            .last_mut()
            .expect("kv_overrides did not have a next allocated");

        assert_eq!(kv_override.key[0], 0, "last kv_override was not empty");

        for (i, &c) in key.to_bytes_with_nul().iter().enumerate() {
            kv_override.key[i] = c_char::try_from(c).expect("invalid character in key");
        }

        kv_override.tag = value.tag();
        kv_override.__bindgen_anon_1 = value.value();

        // Null the raw pointer before pushing: the push may reallocate the
        // vector, which would leave `params.kv_overrides` dangling.
        self.params.kv_overrides = null();

        // Push a fresh empty terminator so llama.cpp knows where the list ends.
        self.kv_overrides
            .push(llama_cpp_sys_2::llama_model_kv_override {
                key: [0; 128],
                tag: 0,
                __bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
                    val_i64: 0,
                },
            });

        self.params.kv_overrides = self.kv_overrides.as_ptr();
    }
}

impl LlamaModelParams {
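    /// Keeps the MoE expert tensors on the CPU, mirroring llama.cpp's
    /// `--cpu-moe` flag.
    ///
    /// A minimal sketch; no model is loaded here:
    ///
    /// ```no_run
    /// use std::pin::pin;
    /// use llama_cpp_2::model::params::LlamaModelParams;
    ///
    /// let mut params = pin!(LlamaModelParams::default());
    /// params.as_mut().add_cpu_moe_override();
    /// ```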
    pub fn add_cpu_moe_override(self: Pin<&mut Self>) {
        self.add_cpu_buft_override(c"\\.ffn_(up|down|gate)_(ch|)exps");
    }

    /// Pins tensors whose names match the `key` regex to the CPU buffer type.
    /// llama.cpp keeps a raw pointer to `key`, so the pattern must outlive
    /// these parameters.
    pub fn add_cpu_buft_override(mut self: Pin<&mut Self>, key: &CStr) {
        // The vector always ends with an empty (null-pattern) terminator.
        let buft_override = self
            .buft_overrides
            .last_mut()
            .expect("buft_overrides did not have a next allocated");

        assert!(
            buft_override.pattern.is_null(),
            "last buft_override was not empty"
        );

        // Validate that every byte of the pattern fits in a C `char`.
        for &c in key.to_bytes_with_nul() {
            c_char::try_from(c).expect("invalid character in key");
        }

        buft_override.pattern = key.as_ptr();
        buft_override.buft = unsafe { llama_cpp_sys_2::ggml_backend_cpu_buffer_type() };

        // Null the raw pointer before pushing: the push may reallocate the
        // vector, which would leave `tensor_buft_overrides` dangling.
        self.params.tensor_buft_overrides = null();

        // Push a fresh terminator so llama.cpp knows where the list ends.
        self.buft_overrides
            .push(llama_cpp_sys_2::llama_model_tensor_buft_override {
                pattern: std::ptr::null(),
                buft: std::ptr::null_mut(),
            });

        self.params.tensor_buft_overrides = self.buft_overrides.as_ptr();
    }
}

impl LlamaModelParams {
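    /// Adjusts these model parameters and the given context parameters so the
    /// model at `model_path` fits into available memory, wrapping the
    /// `llama_params_fit` FFI call. `margins` is passed through as a raw
    /// pointer (assumed to hold one margin per device, in bytes), and
    /// `n_ctx_min` bounds how far the context may shrink.
    ///
    /// # Errors
    ///
    /// Returns [`FitError::Failure`] when no fitting allocation exists and
    /// [`FitError::Error`] on a hard error.
    ///
    /// A hedged sketch; the model path, margin size, and the
    /// `GGML_LOG_LEVEL_INFO` constant name are assumptions for illustration:
    ///
    /// ```no_run
    /// use std::pin::pin;
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// use llama_cpp_2::model::params::LlamaModelParams;
    ///
    /// let mut mparams = pin!(LlamaModelParams::default());
    /// let mut cparams = LlamaContextParams::default();
    /// // One margin per device; 1 GiB is purely illustrative.
    /// let mut margins = vec![1usize << 30; llama_cpp_2::max_devices()];
    /// let fit = mparams
    ///     .as_mut()
    ///     .fit_params(
    ///         c"/path/to/model.gguf",
    ///         &mut cparams,
    ///         &mut margins,
    ///         512,
    ///         llama_cpp_sys_2::GGML_LOG_LEVEL_INFO,
    ///     )
    ///     .expect("model should fit");
    /// println!("fitted n_ctx = {}", fit.n_ctx);
    /// ```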
    pub fn fit_params(
        mut self: Pin<&mut Self>,
        model_path: &CStr,
        cparams: &mut LlamaContextParams,
        margins: &mut [usize],
        n_ctx_min: u32,
        log_level: llama_cpp_sys_2::ggml_log_level,
    ) -> Result<FitResult, FitError> {
        let max_devices = unsafe { llama_cpp_sys_2::llama_max_devices() };
        let max_buft = unsafe { llama_cpp_sys_2::llama_max_tensor_buft_overrides() };

        self.tensor_split.clear();
        self.tensor_split.resize(max_devices, 0.0);

        self.buft_overrides.clear();
        self.buft_overrides.resize(
            max_buft + 1,
            llama_cpp_sys_2::llama_model_tensor_buft_override {
                pattern: std::ptr::null(),
                buft: std::ptr::null_mut(),
            },
        );

        // Detach the raw pointers while llama.cpp fills the freshly resized
        // buffers through the out-parameters below.
        self.params.tensor_split = null::<f32>();
        self.params.tensor_buft_overrides = null();

        let status = unsafe {
            llama_cpp_sys_2::llama_params_fit(
                model_path.as_ptr(),
                &raw mut self.params,
                &raw mut cparams.context_params,
                self.tensor_split.as_mut_ptr(),
                self.buft_overrides.as_mut_ptr(),
                margins.as_mut_ptr(),
                n_ctx_min,
                log_level,
            )
        };

        match status {
            llama_cpp_sys_2::LLAMA_PARAMS_FIT_STATUS_SUCCESS => {}
            llama_cpp_sys_2::LLAMA_PARAMS_FIT_STATUS_FAILURE => return Err(FitError::Failure),
            _ => return Err(FitError::Error),
        }

        // Only re-attach the pointers once the fit succeeded.
        self.params.tensor_split = self.tensor_split.as_ptr();
        self.params.tensor_buft_overrides = self.buft_overrides.as_ptr();

        Ok(FitResult {
            n_ctx: cparams.context_params.n_ctx,
        })
    }
}

impl LlamaModelParams {
    #[must_use]
    pub fn n_gpu_layers(&self) -> i32 {
        self.params.n_gpu_layers
    }

    #[must_use]
    pub fn main_gpu(&self) -> i32 {
        self.params.main_gpu
    }

    #[must_use]
    pub fn vocab_only(&self) -> bool {
        self.params.vocab_only
    }

    #[must_use]
    pub fn use_mmap(&self) -> bool {
        self.params.use_mmap
    }

    #[must_use]
    pub fn use_mlock(&self) -> bool {
        self.params.use_mlock
    }

    pub fn split_mode(&self) -> Result<LlamaSplitMode, LlamaSplitModeParseError> {
        LlamaSplitMode::try_from(self.params.split_mode)
    }

    #[must_use]
    pub fn devices(&self) -> Vec<usize> {
        let mut backend_devices = Vec::new();
        for i in 0..unsafe { llama_cpp_sys_2::ggml_backend_dev_count() } {
            let dev = unsafe { llama_cpp_sys_2::ggml_backend_dev_get(i) };
            backend_devices.push(dev);
        }
        let mut devices = Vec::new();
        for &dev in self.devices.iter() {
            if dev.is_null() {
                break;
            }
            if let Some((index, _)) = backend_devices
                .iter()
                .enumerate()
                .find(|&(_i, &d)| d == dev)
            {
                devices.push(index);
            }
        }
        devices
    }

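    /// Sets the number of layers to offload to the GPU, saturating at
    /// `i32::MAX` if the requested count does not fit in an `i32`.
    ///
    /// The `with_*` setters chain builder-style:
    ///
    /// ```
    /// use llama_cpp_2::model::params::LlamaModelParams;
    ///
    /// let params = LlamaModelParams::default()
    ///     .with_n_gpu_layers(1)
    ///     .with_use_mmap(true);
    /// assert_eq!(params.n_gpu_layers(), 1);
    /// assert!(params.use_mmap());
    /// ```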
    #[must_use]
    pub fn with_n_gpu_layers(mut self, n_gpu_layers: u32) -> Self {
        let n_gpu_layers = i32::try_from(n_gpu_layers).unwrap_or(i32::MAX);
        self.params.n_gpu_layers = n_gpu_layers;
        self
    }

    #[must_use]
    pub fn with_main_gpu(mut self, main_gpu: i32) -> Self {
        self.params.main_gpu = main_gpu;
        self
    }

    #[must_use]
    pub fn with_vocab_only(mut self, vocab_only: bool) -> Self {
        self.params.vocab_only = vocab_only;
        self
    }

    #[must_use]
    pub fn with_use_mmap(mut self, use_mmap: bool) -> Self {
        self.params.use_mmap = use_mmap;
        self
    }

    #[must_use]
    pub fn with_use_mlock(mut self, use_mlock: bool) -> Self {
        self.params.use_mlock = use_mlock;
        self
    }

    #[must_use]
    pub fn with_split_mode(mut self, split_mode: LlamaSplitMode) -> Self {
        self.params.split_mode = split_mode.into();
        self
    }

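    /// Restricts the model to the given backend devices, identified by their
    /// global ggml device index, and stores them in the owned device array.
    ///
    /// # Errors
    ///
    /// Fails if more devices are requested than [`crate::max_devices`] allows,
    /// or if an index has no registered backend device.
    ///
    /// A sketch selecting the first device; it assumes at least one backend
    /// device is registered at runtime:
    ///
    /// ```no_run
    /// use llama_cpp_2::model::params::LlamaModelParams;
    ///
    /// let params = LlamaModelParams::default()
    ///     .with_devices(&[0])
    ///     .expect("backend device 0 should exist");
    /// assert_eq!(params.devices(), vec![0]);
    /// ```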
    pub fn with_devices(mut self, devices: &[usize]) -> Result<Self, LlamaCppError> {
        for dev in self.devices.iter_mut() {
            *dev = std::ptr::null_mut();
        }
        let max_devices = crate::max_devices().min(LLAMA_CPP_MAX_DEVICES);
        if devices.len() > max_devices {
            return Err(LlamaCppError::MaxDevicesExceeded(max_devices));
        }
        for (i, &dev) in devices.iter().enumerate() {
            if dev >= unsafe { llama_cpp_sys_2::ggml_backend_dev_count() } {
                return Err(LlamaCppError::BackendDeviceNotFound(dev));
            }
            let backend_dev = unsafe { llama_cpp_sys_2::ggml_backend_dev_get(dev) };
            self.devices[i] = backend_dev;
        }
        // Check the input slice: an empty selection means "use the defaults",
        // signalled with a null pointer. (The fixed-size `self.devices` array
        // is never empty, so it cannot be used for this check.)
        if devices.is_empty() {
            self.params.devices = std::ptr::null_mut();
        } else {
            self.params.devices = self.devices.as_mut_ptr();
        }
        Ok(self)
    }

    #[must_use]
    pub fn with_no_alloc(mut self, no_alloc: bool) -> Self {
        self.params.no_alloc = no_alloc;
        if no_alloc {
            // A no-alloc load skips buffer allocation, so memory-mapped
            // weights are disabled along with it.
            self = self.with_use_mmap(false);
        }
        self
    }

    #[must_use]
    pub fn no_alloc(&self) -> bool {
        self.params.no_alloc
    }
}

impl Default for LlamaModelParams {
    fn default() -> Self {
        let default_params = unsafe { llama_cpp_sys_2::llama_model_default_params() };
        LlamaModelParams {
            params: default_params,
            // Each override list starts with a single empty terminator entry.
            kv_overrides: vec![llama_cpp_sys_2::llama_model_kv_override {
                key: [0; 128],
                tag: 0,
                __bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
                    val_i64: 0,
                },
            }],
            buft_overrides: vec![llama_cpp_sys_2::llama_model_tensor_buft_override {
                pattern: std::ptr::null(),
                buft: std::ptr::null_mut(),
            }],
            devices: Box::pin([std::ptr::null_mut(); LLAMA_CPP_MAX_DEVICES]),
            tensor_split: Vec::new(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::LlamaSplitMode;

    #[test]
    fn tensor_split_mode_round_trips() {
        assert_eq!(
            LlamaSplitMode::try_from(llama_cpp_sys_2::LLAMA_SPLIT_MODE_TENSOR),
            Ok(LlamaSplitMode::Tensor)
        );
        assert_eq!(
            u32::from(LlamaSplitMode::Tensor),
            llama_cpp_sys_2::LLAMA_SPLIT_MODE_TENSOR as u32
        );
        assert_eq!(
            i32::from(LlamaSplitMode::Tensor),
            llama_cpp_sys_2::LLAMA_SPLIT_MODE_TENSOR as i32
        );
    }
}