//! A safe wrapper around `llama_model_params`.

use crate::model::params::kv_overrides::KvOverrides;
use crate::LLamaCppError;
use std::ffi::{c_char, CStr};
use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::ptr::null;

pub mod kv_overrides;

#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
const LLAMA_SPLIT_MODE_NONE: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_NONE as i8;
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
const LLAMA_SPLIT_MODE_LAYER: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_LAYER as i8;
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
const LLAMA_SPLIT_MODE_ROW: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_ROW as i8;

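/// How the model's weights are split across the available devices.
///
/// A minimal sketch of converting between this enum and the raw llama.cpp values
/// (round-tripping through `u32` avoids hard-coding the underlying constants):
///
/// ```
/// use llama_cpp_2::model::params::LlamaSplitMode;
///
/// let raw: u32 = LlamaSplitMode::Row.into();
/// assert_eq!(LlamaSplitMode::try_from(raw), Ok(LlamaSplitMode::Row));
/// assert_eq!(LlamaSplitMode::default(), LlamaSplitMode::Layer);
/// ```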
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaSplitMode {
    /// Load the whole model on a single device (no splitting).
    None = LLAMA_SPLIT_MODE_NONE,
    /// Split layers (and the KV cache) across devices.
    Layer = LLAMA_SPLIT_MODE_LAYER,
    /// Split rows across devices (tensor parallelism where supported).
    Row = LLAMA_SPLIT_MODE_ROW,
}

/// Error returned when a raw value does not correspond to any [`LlamaSplitMode`] variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LlamaSplitModeParseError(pub i32);

impl TryFrom<i32> for LlamaSplitMode {
    type Error = LlamaSplitModeParseError;

    fn try_from(value: i32) -> Result<Self, Self::Error> {
        let i8_value = value
            .try_into()
            .map_err(|_| LlamaSplitModeParseError(value))?;
        match i8_value {
            LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
            LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
            LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
            _ => Err(LlamaSplitModeParseError(value)),
        }
    }
}

impl TryFrom<u32> for LlamaSplitMode {
    type Error = LlamaSplitModeParseError;

    fn try_from(value: u32) -> Result<Self, Self::Error> {
        let i8_value = value
            .try_into()
            .map_err(|_| LlamaSplitModeParseError(value.try_into().unwrap_or(i32::MAX)))?;
        match i8_value {
            LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
            LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
            LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
            _ => Err(LlamaSplitModeParseError(
                value.try_into().unwrap_or(i32::MAX),
            )),
        }
    }
}

impl From<LlamaSplitMode> for i32 {
    fn from(value: LlamaSplitMode) -> Self {
        match value {
            LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE.into(),
            LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER.into(),
            LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW.into(),
        }
    }
}

impl From<LlamaSplitMode> for u32 {
    fn from(value: LlamaSplitMode) -> Self {
        match value {
            LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE as u32,
            LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER as u32,
            LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW as u32,
        }
    }
}

impl Default for LlamaSplitMode {
    fn default() -> Self {
        LlamaSplitMode::Layer
    }
}

/// The maximum number of devices these parameters can address; this is the size of the
/// fixed device array handed to llama.cpp.
pub const LLAMA_CPP_MAX_DEVICES: usize = 16;

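/// A safe wrapper around `llama_model_params`, configuring how a model is loaded.
///
/// A minimal sketch of the builder-style setters defined below (the remaining defaults
/// come from `llama_model_default_params()` and depend on how llama.cpp was built):
///
/// ```no_run
/// use llama_cpp_2::model::params::{LlamaModelParams, LlamaSplitMode};
///
/// let params = LlamaModelParams::default()
///     .with_n_gpu_layers(32)
///     .with_use_mlock(false)
///     .with_split_mode(LlamaSplitMode::Layer);
/// assert_eq!(params.n_gpu_layers(), 32);
/// ```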
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModelParams {
    pub(crate) params: llama_cpp_sys_2::llama_model_params,
    kv_overrides: Vec<llama_cpp_sys_2::llama_model_kv_override>,
    buft_overrides: Vec<llama_cpp_sys_2::llama_model_tensor_buft_override>,
    devices: Pin<Box<[llama_cpp_sys_2::ggml_backend_dev_t; LLAMA_CPP_MAX_DEVICES]>>,
}

impl Debug for LlamaModelParams {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaModelParams")
            .field("n_gpu_layers", &self.params.n_gpu_layers)
            .field("main_gpu", &self.params.main_gpu)
            .field("vocab_only", &self.params.vocab_only)
            .field("use_mmap", &self.params.use_mmap)
            .field("use_mlock", &self.params.use_mlock)
            .field("split_mode", &self.split_mode())
            .field("devices", &self.devices)
            .field("kv_overrides", &"vec of kv_overrides")
            .finish()
    }
}

impl LlamaModelParams {
    /// Returns the key-value overrides currently attached to these parameters.
    #[must_use]
    pub fn kv_overrides(&self) -> KvOverrides {
        KvOverrides::new(self)
    }

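    /// Appends a key-value override that is applied to the model's metadata when it is
    /// loaded. Panics if `key` (including its trailing NUL) does not fit in 128 bytes.
    ///
    /// A minimal usage sketch; the key and the `ParamOverrideValue::Int` constructor from
    /// the `kv_overrides` module are illustrative assumptions, not values taken from a
    /// real model:
    ///
    /// ```no_run
    /// use std::pin::pin;
    /// use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
    /// use llama_cpp_2::model::params::LlamaModelParams;
    ///
    /// let mut params = pin!(LlamaModelParams::default());
    /// params.as_mut().append_kv_override(c"some.key", ParamOverrideValue::Int(1));
    /// ```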
    #[allow(clippy::missing_panics_doc)]
    pub fn append_kv_override(
        mut self: Pin<&mut Self>,
        key: &CStr,
        value: kv_overrides::ParamOverrideValue,
    ) {
        // The last entry is always the zeroed terminator; fill it with the new override.
        let kv_override = self
            .kv_overrides
            .last_mut()
            .expect("kv_overrides did not have a next allocated");

        assert_eq!(kv_override.key[0], 0, "last kv_override was not empty");

        for (i, &c) in key.to_bytes_with_nul().iter().enumerate() {
            kv_override.key[i] = c_char::try_from(c).expect("invalid character in key");
        }

        kv_override.tag = value.tag();
        kv_override.__bindgen_anon_1 = value.value();

        // Clear the raw pointer before pushing: growing the Vec may reallocate and
        // invalidate the address previously handed to llama.cpp.
        self.params.kv_overrides = null();

        // Push a fresh zeroed entry to act as the new terminator (llama.cpp stops at
        // the first override with an empty key).
        self.kv_overrides
            .push(llama_cpp_sys_2::llama_model_kv_override {
                key: [0; 128],
                tag: 0,
                __bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
                    val_i64: 0,
                },
            });

        self.params.kv_overrides = self.kv_overrides.as_ptr();
    }
}

impl LlamaModelParams {
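    /// Keeps mixture-of-experts (MoE) expert tensors on the CPU buffer type, similar to
    /// llama.cpp's `--cpu-moe` option, while other tensors can still be offloaded.
    ///
    /// A minimal sketch of combining this with GPU offload (the layer count is
    /// illustrative):
    ///
    /// ```no_run
    /// use std::pin::pin;
    /// use llama_cpp_2::model::params::LlamaModelParams;
    ///
    /// let mut params = pin!(LlamaModelParams::default().with_n_gpu_layers(99));
    /// params.as_mut().add_cpu_moe_override();
    /// ```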
    pub fn add_cpu_moe_override(self: Pin<&mut Self>) {
        self.add_cpu_buft_override(c"\\.ffn_(up|down|gate)_(ch|)exps");
    }

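    /// Forces tensors whose names match the regex pattern `key` onto the CPU buffer type.
    ///
    /// The pattern pointer is stored as-is, so `key` must outlive these parameters; a
    /// `c"..."` literal (as used by [`Self::add_cpu_moe_override`]) is the safest choice.
    ///
    /// A minimal sketch, keeping all `ffn_down` tensors in host memory (the pattern is
    /// illustrative):
    ///
    /// ```no_run
    /// use std::pin::pin;
    /// use llama_cpp_2::model::params::LlamaModelParams;
    ///
    /// let mut params = pin!(LlamaModelParams::default());
    /// params.as_mut().add_cpu_buft_override(c"\\.ffn_down");
    /// ```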
    pub fn add_cpu_buft_override(mut self: Pin<&mut Self>, key: &CStr) {
        // The last entry is always the null-pattern terminator; fill it with the new override.
        let buft_override = self
            .buft_overrides
            .last_mut()
            .expect("buft_overrides did not have a next allocated");

        assert!(
            buft_override.pattern.is_null(),
            "last buft_override was not empty"
        );

        // Check that every byte of the pattern is representable as a C char.
        for &c in key.to_bytes_with_nul() {
            c_char::try_from(c).expect("invalid character in key");
        }

        buft_override.pattern = key.as_ptr();
        buft_override.buft = unsafe { llama_cpp_sys_2::ggml_backend_cpu_buffer_type() };

        // Clear the raw pointer before pushing: growing the Vec may reallocate and
        // invalidate the address previously handed to llama.cpp.
        self.params.tensor_buft_overrides = null();

        // Push a fresh null-pattern entry to act as the new terminator.
        self.buft_overrides
            .push(llama_cpp_sys_2::llama_model_tensor_buft_override {
                pattern: std::ptr::null(),
                buft: std::ptr::null_mut(),
            });

        self.params.tensor_buft_overrides = self.buft_overrides.as_ptr();
    }
}

impl LlamaModelParams {
    /// Gets the number of layers to offload to the GPU.
    #[must_use]
    pub fn n_gpu_layers(&self) -> i32 {
        self.params.n_gpu_layers
    }

    /// Gets the index of the GPU used for the whole model when splitting is disabled.
    #[must_use]
    pub fn main_gpu(&self) -> i32 {
        self.params.main_gpu
    }

    /// Whether only the vocabulary is loaded (no weights).
    #[must_use]
    pub fn vocab_only(&self) -> bool {
        self.params.vocab_only
    }

    /// Whether the model is mapped into memory with `mmap` when possible.
    #[must_use]
    pub fn use_mmap(&self) -> bool {
        self.params.use_mmap
    }

    /// Whether the model is locked in RAM with `mlock`.
    #[must_use]
    pub fn use_mlock(&self) -> bool {
        self.params.use_mlock
    }

    /// Gets how the model is split across multiple devices.
    ///
    /// # Errors
    ///
    /// Returns [`LlamaSplitModeParseError`] if the raw value does not map to a known split mode.
    pub fn split_mode(&self) -> Result<LlamaSplitMode, LlamaSplitModeParseError> {
        LlamaSplitMode::try_from(self.params.split_mode)
    }

    /// Gets the indices of the backend devices these parameters are restricted to.
    ///
    /// Indices refer to the devices reported by `ggml_backend_dev_get()`; an empty vector
    /// means no explicit device list has been set.
    #[must_use]
    pub fn devices(&self) -> Vec<usize> {
        let mut backend_devices = Vec::new();
        for i in 0..unsafe { llama_cpp_sys_2::ggml_backend_dev_count() } {
            let dev = unsafe { llama_cpp_sys_2::ggml_backend_dev_get(i) };
            backend_devices.push(dev);
        }
        let mut devices = Vec::new();
        for &dev in self.devices.iter() {
            if dev.is_null() {
                break;
            }
            if let Some((index, _)) = backend_devices
                .iter()
                .enumerate()
                .find(|&(_i, &d)| d == dev)
            {
                devices.push(index);
            }
        }
        devices
    }

    /// Sets the number of layers to offload to the GPU.
    ///
    /// Values greater than `i32::MAX` are clamped to `i32::MAX`.
    #[must_use]
    pub fn with_n_gpu_layers(mut self, n_gpu_layers: u32) -> Self {
        let n_gpu_layers = i32::try_from(n_gpu_layers).unwrap_or(i32::MAX);
        self.params.n_gpu_layers = n_gpu_layers;
        self
    }

    /// Sets the GPU used for the whole model when splitting is disabled.
    #[must_use]
    pub fn with_main_gpu(mut self, main_gpu: i32) -> Self {
        self.params.main_gpu = main_gpu;
        self
    }

    /// Sets whether only the vocabulary should be loaded.
    #[must_use]
    pub fn with_vocab_only(mut self, vocab_only: bool) -> Self {
        self.params.vocab_only = vocab_only;
        self
    }

    /// Sets whether the model should be locked in RAM with `mlock`.
    #[must_use]
    pub fn with_use_mlock(mut self, use_mlock: bool) -> Self {
        self.params.use_mlock = use_mlock;
        self
    }

    /// Sets how the model is split across multiple devices.
    #[must_use]
    pub fn with_split_mode(mut self, split_mode: LlamaSplitMode) -> Self {
        self.params.split_mode = split_mode.into();
        self
    }

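    /// Restricts the model to the given backend devices, identified by their index as
    /// reported by `ggml_backend_dev_get()`.
    ///
    /// # Errors
    ///
    /// Returns [`LLamaCppError::MaxDevicesExceeded`] if more devices are given than are
    /// supported, or [`LLamaCppError::BackendDeviceNotFound`] if an index is out of range.
    ///
    /// A minimal sketch, assuming the backend has been initialised and at least one
    /// device is available:
    ///
    /// ```no_run
    /// use llama_cpp_2::model::params::LlamaModelParams;
    ///
    /// let params = LlamaModelParams::default()
    ///     .with_devices(&[0])
    ///     .expect("device 0 should be available");
    /// assert_eq!(params.devices(), vec![0]);
    /// ```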
    pub fn with_devices(mut self, devices: &[usize]) -> Result<Self, LLamaCppError> {
        // Reset any previously configured devices.
        for dev in self.devices.iter_mut() {
            *dev = std::ptr::null_mut();
        }
        let max_devices = crate::max_devices().min(LLAMA_CPP_MAX_DEVICES);
        if devices.len() > max_devices {
            return Err(LLamaCppError::MaxDevicesExceeded(max_devices));
        }
        for (i, &dev) in devices.iter().enumerate() {
            if dev >= unsafe { llama_cpp_sys_2::ggml_backend_dev_count() } {
                return Err(LLamaCppError::BackendDeviceNotFound(dev));
            }
            let backend_dev = unsafe { llama_cpp_sys_2::ggml_backend_dev_get(dev) };
            self.devices[i] = backend_dev;
        }
        // An empty slice means "no explicit device list": leave the choice to llama.cpp.
        if devices.is_empty() {
            self.params.devices = std::ptr::null_mut();
        } else {
            self.params.devices = self.devices.as_mut_ptr();
        }
        Ok(self)
    }
}

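/// The defaults mirror `llama_model_default_params()` from llama.cpp.
///
/// A minimal sketch of inspecting the defaults (the concrete values depend on how
/// llama.cpp was compiled, so none are asserted here):
///
/// ```no_run
/// use llama_cpp_2::model::params::LlamaModelParams;
///
/// let params = LlamaModelParams::default();
/// println!("n_gpu_layers: {}", params.n_gpu_layers());
/// println!("split_mode: {:?}", params.split_mode());
/// ```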
impl Default for LlamaModelParams {
    fn default() -> Self {
        let default_params = unsafe { llama_cpp_sys_2::llama_model_default_params() };
        LlamaModelParams {
            params: default_params,
            // A single zeroed entry acts as the terminator of the kv-override list.
            kv_overrides: vec![llama_cpp_sys_2::llama_model_kv_override {
                key: [0; 128],
                tag: 0,
                __bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
                    val_i64: 0,
                },
            }],
            // A single null-pattern entry acts as the terminator of the buft-override list.
            buft_overrides: vec![llama_cpp_sys_2::llama_model_tensor_buft_override {
                pattern: std::ptr::null(),
                buft: std::ptr::null_mut(),
            }],
            devices: Box::pin([std::ptr::null_mut(); LLAMA_CPP_MAX_DEVICES]),
        }
    }
}