1use crate::context::params::LlamaContextParams;
4use crate::model::params::kv_overrides::KvOverrides;
5use crate::LlamaCppError;
6use std::ffi::{c_char, CStr};
7use std::fmt::{Debug, Formatter};
8use std::pin::Pin;
9use std::ptr::null;
10
11pub mod kv_overrides;
12
/// Values reported back by a successful [`LlamaModelParams::fit_params`] call.
#[cfg(feature = "common")]
#[derive(Clone, Debug)]
pub struct FitResult {
    /// The `n_ctx` value found in the context parameters once fitting has finished.
    pub n_ctx: u32,
}
20
/// Reasons why [`LlamaModelParams::fit_params`] can fail.
#[cfg(feature = "common")]
#[derive(Clone, Copy, Debug, Eq, PartialEq, thiserror::Error)]
pub enum FitError {
    /// The native fitting routine returned status `1`.
    #[error("could not find allocations that fit available memory")]
    Failure,
    /// The native fitting routine returned a status other than `0` or `1`.
    #[error("hard error during parameter fitting")]
    Error,
}
32
33#[allow(clippy::cast_possible_wrap)]
34#[allow(clippy::cast_possible_truncation)]
35const LLAMA_SPLIT_MODE_NONE: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_NONE as i8;
36#[allow(clippy::cast_possible_wrap)]
37#[allow(clippy::cast_possible_truncation)]
38const LLAMA_SPLIT_MODE_LAYER: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_LAYER as i8;
39#[allow(clippy::cast_possible_wrap)]
40#[allow(clippy::cast_possible_truncation)]
41const LLAMA_SPLIT_MODE_ROW: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_ROW as i8;
42#[allow(clippy::cast_possible_wrap)]
43#[allow(clippy::cast_possible_truncation)]
44const LLAMA_SPLIT_MODE_TENSOR: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_TENSOR as i8;
45
46#[repr(i8)]
48#[derive(Copy, Clone, Debug, PartialEq, Eq)]
49pub enum LlamaSplitMode {
50 None = LLAMA_SPLIT_MODE_NONE,
52 Layer = LLAMA_SPLIT_MODE_LAYER,
54 Row = LLAMA_SPLIT_MODE_ROW,
56 Tensor = LLAMA_SPLIT_MODE_TENSOR,
58}
59
/// Error produced when an integer does not correspond to any [`LlamaSplitMode`].
///
/// The wrapped value is the rejected input; `u32` inputs that do not fit in an `i32`
/// are reported as `i32::MAX`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct LlamaSplitModeParseError(pub i32);
63
64impl TryFrom<i32> for LlamaSplitMode {
69 type Error = LlamaSplitModeParseError;
70
71 fn try_from(value: i32) -> Result<Self, Self::Error> {
72 let i8_value = value
73 .try_into()
74 .map_err(|_| LlamaSplitModeParseError(value))?;
75 match i8_value {
76 LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
77 LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
78 LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
79 LLAMA_SPLIT_MODE_TENSOR => Ok(Self::Tensor),
80 _ => Err(LlamaSplitModeParseError(value)),
81 }
82 }
83}
84
85impl TryFrom<u32> for LlamaSplitMode {
90 type Error = LlamaSplitModeParseError;
91
92 fn try_from(value: u32) -> Result<Self, Self::Error> {
93 let i8_value = value
94 .try_into()
95 .map_err(|_| LlamaSplitModeParseError(value.try_into().unwrap_or(i32::MAX)))?;
96 match i8_value {
97 LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
98 LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
99 LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
100 LLAMA_SPLIT_MODE_TENSOR => Ok(Self::Tensor),
101 _ => Err(LlamaSplitModeParseError(
102 value.try_into().unwrap_or(i32::MAX),
103 )),
104 }
105 }
106}
107
108impl From<LlamaSplitMode> for i32 {
110 fn from(value: LlamaSplitMode) -> Self {
111 match value {
112 LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE.into(),
113 LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER.into(),
114 LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW.into(),
115 LlamaSplitMode::Tensor => LLAMA_SPLIT_MODE_TENSOR.into(),
116 }
117 }
118}
119
120impl From<LlamaSplitMode> for u32 {
122 fn from(value: LlamaSplitMode) -> Self {
123 match value {
124 LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE as u32,
125 LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER as u32,
126 LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW as u32,
127 LlamaSplitMode::Tensor => LLAMA_SPLIT_MODE_TENSOR as u32,
128 }
129 }
130}
131
132impl Default for LlamaSplitMode {
134 fn default() -> Self {
135 LlamaSplitMode::Layer
136 }
137}
138
/// Number of slots in the fixed-size device array held by [`LlamaModelParams`].
///
/// [`LlamaModelParams::with_devices`] also uses it as an upper bound on how many
/// devices may be selected.
pub const LLAMA_CPP_MAX_DEVICES: usize = 16;
144
145#[allow(clippy::module_name_repetitions)]
147pub struct LlamaModelParams {
148 pub(crate) params: llama_cpp_sys_2::llama_model_params,
149 kv_overrides: Vec<llama_cpp_sys_2::llama_model_kv_override>,
150 buft_overrides: Vec<llama_cpp_sys_2::llama_model_tensor_buft_override>,
151 devices: Pin<Box<[llama_cpp_sys_2::ggml_backend_dev_t; LLAMA_CPP_MAX_DEVICES]>>,
152 tensor_split: Vec<f32>,
153}
154
155impl Debug for LlamaModelParams {
156 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
157 f.debug_struct("LlamaModelParams")
158 .field("n_gpu_layers", &self.params.n_gpu_layers)
159 .field("main_gpu", &self.params.main_gpu)
160 .field("vocab_only", &self.params.vocab_only)
161 .field("use_mmap", &self.params.use_mmap)
162 .field("use_mlock", &self.params.use_mlock)
163 .field("split_mode", &self.split_mode())
164 .field("devices", &self.devices)
165 .field("kv_overrides", &"vec of kv_overrides")
166 .finish()
167 }
168}
169
170impl LlamaModelParams {
171 #[must_use]
183 pub fn kv_overrides<'a>(&'a self) -> KvOverrides<'a> {
184 KvOverrides::new(self)
185 }
186
187 #[allow(clippy::missing_panics_doc)] pub fn append_kv_override(
210 mut self: Pin<&mut Self>,
211 key: &CStr,
212 value: kv_overrides::ParamOverrideValue,
213 ) {
214 let kv_override = self
215 .kv_overrides
216 .get_mut(0)
217 .expect("kv_overrides did not have a next allocated");
218
219 assert_eq!(kv_override.key[0], 0, "last kv_override was not empty");
220
221 for (i, &c) in key.to_bytes_with_nul().iter().enumerate() {
223 kv_override.key[i] = c_char::try_from(c).expect("invalid character in key");
224 }
225
226 kv_override.tag = value.tag();
227 kv_override.__bindgen_anon_1 = value.value();
228
229 self.params.kv_overrides = null();
231
232 self.kv_overrides
234 .push(llama_cpp_sys_2::llama_model_kv_override {
235 key: [0; 128],
236 tag: 0,
237 __bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
238 val_i64: 0,
239 },
240 });
241
242 self.params.kv_overrides = self.kv_overrides.as_ptr();
244
245 eprintln!("saved ptr: {:?}", self.params.kv_overrides);
246 }
247}
248
249impl LlamaModelParams {
250 pub fn add_cpu_moe_override(self: Pin<&mut Self>) {
252 self.add_cpu_buft_override(c"\\.ffn_(up|down|gate)_(ch|)exps");
253 }
254
255 pub fn add_cpu_buft_override(mut self: Pin<&mut Self>, key: &CStr) {
258 let buft_override = self
259 .buft_overrides
260 .get_mut(0)
261 .expect("buft_overrides did not have a next allocated");
262
263 assert!(
264 buft_override.pattern.is_null(),
265 "last buft_override was not empty"
266 );
267
268 for &c in key.to_bytes_with_nul().iter() {
270 c_char::try_from(c).expect("invalid character in key");
271 }
272
273 buft_override.pattern = key.as_ptr();
274 buft_override.buft = unsafe { llama_cpp_sys_2::ggml_backend_cpu_buffer_type() };
275
276 self.params.tensor_buft_overrides = null();
278
279 self.buft_overrides
281 .push(llama_cpp_sys_2::llama_model_tensor_buft_override {
282 pattern: std::ptr::null(),
283 buft: std::ptr::null_mut(),
284 });
285
286 self.params.tensor_buft_overrides = self.buft_overrides.as_ptr();
288 }
289}
290
#[cfg(feature = "common")]
impl LlamaModelParams {
    /// Runs the native parameter-fitting routine (`llama_rs_fit_params`) for the model
    /// at `model_path`, letting it adjust these model parameters and `cparams`.
    ///
    /// Before the call the internal `tensor_split` buffer is reset to
    /// `llama_max_devices()` zeros and the buffer-type override list is replaced by
    /// `llama_max_tensor_buft_overrides() + 1` empty entries, so overrides added
    /// earlier are discarded. Both buffers are passed to the native routine and, on
    /// success, installed into the raw params.
    ///
    /// # Arguments
    /// * `model_path` - path of the model file, as a C string.
    /// * `cparams` - context parameters; passed mutably to the native routine.
    /// * `margins` - passed to the native routine as a raw pointer without a length.
    ///   NOTE(review): presumably one entry per device is read, so the slice should
    ///   hold at least `llama_max_devices()` elements — confirm against llama.cpp.
    /// * `n_ctx_min` - forwarded unchanged.
    /// * `log_level` - forwarded unchanged.
    ///
    /// # Errors
    /// * [`FitError::Failure`] when the native routine returns status `1`.
    /// * [`FitError::Error`] for any other non-zero status.
    ///
    /// On error the `tensor_split` and `tensor_buft_overrides` pointers in the raw
    /// params are left null.
    pub fn fit_params(
        mut self: Pin<&mut Self>,
        model_path: &CStr,
        cparams: &mut LlamaContextParams,
        margins: &mut [usize],
        n_ctx_min: u32,
        log_level: llama_cpp_sys_2::ggml_log_level,
    ) -> Result<FitResult, FitError> {
        let max_devices = unsafe { llama_cpp_sys_2::llama_max_devices() };
        let max_buft = unsafe { llama_cpp_sys_2::llama_max_tensor_buft_overrides() };

        // Fresh, zeroed split buffer for the native side to fill in.
        self.tensor_split.clear();
        self.tensor_split.resize(max_devices, 0.0);

        // Fresh override buffer; the extra entry leaves room for a trailing null pattern.
        let blank_override = llama_cpp_sys_2::llama_model_tensor_buft_override {
            pattern: std::ptr::null(),
            buft: std::ptr::null_mut(),
        };
        self.buft_overrides.clear();
        self.buft_overrides.resize(max_buft + 1, blank_override);

        // The raw params must not reference the buffers while the native side is
        // given them separately.
        self.params.tensor_split = null::<f32>();
        self.params.tensor_buft_overrides = null();

        let status = unsafe {
            llama_cpp_sys_2::llama_rs_fit_params(
                model_path.as_ptr(),
                &raw mut self.params,
                &raw mut cparams.context_params,
                self.tensor_split.as_mut_ptr(),
                self.buft_overrides.as_mut_ptr(),
                margins.as_mut_ptr(),
                n_ctx_min,
                log_level,
            )
        };

        // 0 = success, 1 = nothing fits, anything else = hard error.
        if status == 1 {
            return Err(FitError::Failure);
        }
        if status != 0 {
            return Err(FitError::Error);
        }

        // Success: install the filled buffers into the raw params.
        self.params.tensor_split = self.tensor_split.as_ptr();
        self.params.tensor_buft_overrides = self.buft_overrides.as_ptr();

        let n_ctx = cparams.context_params.n_ctx;
        Ok(FitResult { n_ctx })
    }
}
387
388impl LlamaModelParams {
389 #[must_use]
391 pub fn n_gpu_layers(&self) -> i32 {
392 self.params.n_gpu_layers
393 }
394
395 #[must_use]
397 pub fn main_gpu(&self) -> i32 {
398 self.params.main_gpu
399 }
400
401 #[must_use]
403 pub fn vocab_only(&self) -> bool {
404 self.params.vocab_only
405 }
406
407 #[must_use]
409 pub fn use_mmap(&self) -> bool {
410 self.params.use_mmap
411 }
412
413 #[must_use]
415 pub fn use_mlock(&self) -> bool {
416 self.params.use_mlock
417 }
418
419 pub fn split_mode(&self) -> Result<LlamaSplitMode, LlamaSplitModeParseError> {
424 LlamaSplitMode::try_from(self.params.split_mode)
425 }
426
427 #[must_use]
429 pub fn devices(&self) -> Vec<usize> {
430 let mut backend_devices = Vec::new();
431 for i in 0..unsafe { llama_cpp_sys_2::ggml_backend_dev_count() } {
432 let dev = unsafe { llama_cpp_sys_2::ggml_backend_dev_get(i) };
433 backend_devices.push(dev);
434 }
435 let mut devices = Vec::new();
436 for &dev in self.devices.iter() {
437 if dev.is_null() {
438 break;
439 }
440 if let Some((index, _)) = backend_devices
441 .iter()
442 .enumerate()
443 .find(|&(_i, &d)| d == dev)
444 {
445 devices.push(index);
446 }
447 }
448 devices
449 }
450
451 #[must_use]
459 pub fn with_n_gpu_layers(mut self, n_gpu_layers: u32) -> Self {
460 let n_gpu_layers = i32::try_from(n_gpu_layers).unwrap_or(i32::MAX);
463 self.params.n_gpu_layers = n_gpu_layers;
464 self
465 }
466
467 #[must_use]
471 pub fn with_main_gpu(mut self, main_gpu: i32) -> Self {
472 self.params.main_gpu = main_gpu;
473 self
474 }
475
476 #[must_use]
478 pub fn with_vocab_only(mut self, vocab_only: bool) -> Self {
479 self.params.vocab_only = vocab_only;
480 self
481 }
482
483 #[must_use]
485 pub fn with_use_mmap(mut self, use_mmap: bool) -> Self {
486 self.params.use_mmap = use_mmap;
487 self
488 }
489
490 #[must_use]
492 pub fn with_use_mlock(mut self, use_mlock: bool) -> Self {
493 self.params.use_mlock = use_mlock;
494 self
495 }
496
497 #[must_use]
499 pub fn with_split_mode(mut self, split_mode: LlamaSplitMode) -> Self {
500 self.params.split_mode = split_mode.into();
501 self
502 }
503
504 pub fn with_devices(mut self, devices: &[usize]) -> Result<Self, LlamaCppError> {
515 for dev in self.devices.iter_mut() {
516 *dev = std::ptr::null_mut();
517 }
518 let max_devices = crate::max_devices().min(LLAMA_CPP_MAX_DEVICES);
520 if devices.len() > max_devices {
521 return Err(LlamaCppError::MaxDevicesExceeded(max_devices));
522 }
523 for (i, &dev) in devices.iter().enumerate() {
524 if dev >= unsafe { llama_cpp_sys_2::ggml_backend_dev_count() } {
525 return Err(LlamaCppError::BackendDeviceNotFound(dev));
526 }
527 let backend_dev = unsafe { llama_cpp_sys_2::ggml_backend_dev_get(dev) };
528 self.devices[i] = backend_dev;
529 }
530 if self.devices.is_empty() {
531 self.params.devices = std::ptr::null_mut();
532 } else {
533 self.params.devices = self.devices.as_mut_ptr();
534 }
535 Ok(self)
536 }
537
538 #[must_use]
544 pub fn with_no_alloc(mut self, no_alloc: bool) -> Self {
545 self.params.no_alloc = no_alloc;
546 if no_alloc {
547 self = self.with_use_mmap(false);
548 }
549 self
550 }
551
552 #[must_use]
556 pub fn no_alloc(&self) -> bool {
557 self.params.no_alloc
558 }
559}
560
561impl Default for LlamaModelParams {
576 fn default() -> Self {
577 let default_params = unsafe { llama_cpp_sys_2::llama_model_default_params() };
578 LlamaModelParams {
579 params: default_params,
580 kv_overrides: vec![llama_cpp_sys_2::llama_model_kv_override {
582 key: [0; 128],
583 tag: 0,
584 __bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
585 val_i64: 0,
586 },
587 }],
588 buft_overrides: vec![llama_cpp_sys_2::llama_model_tensor_buft_override {
589 pattern: std::ptr::null(),
590 buft: std::ptr::null_mut(),
591 }],
592 devices: Box::pin([std::ptr::null_mut(); 16]),
593 tensor_split: Vec::new(),
594 }
595 }
596}
597
#[cfg(test)]
mod tests {
    use super::LlamaSplitMode;

    /// The tensor split mode must convert from its raw constant and back to the same
    /// value, both as `u32` and as `i32`.
    #[test]
    fn tensor_split_mode_round_trips() {
        let raw = llama_cpp_sys_2::LLAMA_SPLIT_MODE_TENSOR;

        assert_eq!(LlamaSplitMode::try_from(raw), Ok(LlamaSplitMode::Tensor));
        assert_eq!(u32::from(LlamaSplitMode::Tensor), raw as u32);
        assert_eq!(i32::from(LlamaSplitMode::Tensor), raw as i32);
    }
}