llama_cpp_bindings/model/
params.rs1use crate::LlamaCppError;
4use crate::model::params::kv_overrides::KvOverrides;
5use std::ffi::{CStr, c_char};
6use std::fmt::{Debug, Formatter};
7use std::pin::Pin;
8use std::ptr::null;
9
10pub mod kv_overrides;
11
// Mirror the llama.cpp split-mode discriminants as `i8` so they can be used
// directly as `#[repr(i8)]` enum discriminants and in `match` arms below.
const LLAMA_SPLIT_MODE_NONE: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_NONE as i8;
const LLAMA_SPLIT_MODE_LAYER: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_LAYER as i8;
const LLAMA_SPLIT_MODE_ROW: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_ROW as i8;
15
/// How model weights are distributed across multiple GPUs.
///
/// Mirrors llama.cpp's `llama_split_mode`; the discriminants match the C
/// values so the enum round-trips through the FFI layer unchanged.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaSplitMode {
    /// `LLAMA_SPLIT_MODE_NONE`: use a single GPU only.
    None = LLAMA_SPLIT_MODE_NONE,
    /// `LLAMA_SPLIT_MODE_LAYER`: split layers across GPUs.
    Layer = LLAMA_SPLIT_MODE_LAYER,
    /// `LLAMA_SPLIT_MODE_ROW`: split rows across GPUs.
    Row = LLAMA_SPLIT_MODE_ROW,
}
27
/// Error returned when a raw integer is not a known split-mode value.
///
/// Carries the offending value, clamped into `i32` range when it came from a
/// wider unsigned source.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LlamaSplitModeParseError(pub i32);
31
32impl TryFrom<i32> for LlamaSplitMode {
37 type Error = LlamaSplitModeParseError;
38
39 fn try_from(value: i32) -> Result<Self, Self::Error> {
40 let i8_value = value
41 .try_into()
42 .map_err(|_| LlamaSplitModeParseError(value))?;
43 match i8_value {
44 LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
45 LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
46 LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
47 _ => Err(LlamaSplitModeParseError(value)),
48 }
49 }
50}
51
52impl TryFrom<u32> for LlamaSplitMode {
57 type Error = LlamaSplitModeParseError;
58
59 fn try_from(value: u32) -> Result<Self, Self::Error> {
60 let i8_value = value
61 .try_into()
62 .map_err(|_| LlamaSplitModeParseError(value.try_into().unwrap_or(i32::MAX)))?;
63 match i8_value {
64 LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
65 LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
66 LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
67 _ => Err(LlamaSplitModeParseError(
68 value.try_into().unwrap_or(i32::MAX),
69 )),
70 }
71 }
72}
73
74impl From<LlamaSplitMode> for i32 {
76 fn from(value: LlamaSplitMode) -> Self {
77 match value {
78 LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE.into(),
79 LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER.into(),
80 LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW.into(),
81 }
82 }
83}
84
85impl From<LlamaSplitMode> for u32 {
87 fn from(value: LlamaSplitMode) -> Self {
88 match value {
89 LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE as u32,
90 LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER as u32,
91 LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW as u32,
92 }
93 }
94}
95
impl Default for LlamaSplitMode {
    /// Defaults to [`LlamaSplitMode::Layer`] (llama.cpp's default split mode).
    fn default() -> Self {
        LlamaSplitMode::Layer
    }
}
102
/// Maximum number of backend devices this wrapper supports; sizes the fixed,
/// pinned device array stored in `LlamaModelParams`.
pub const LLAMA_CPP_MAX_DEVICES: usize = 16;
108
/// Safe owner of a raw `llama_model_params` plus the heap storage backing the
/// raw pointers inside it (KV overrides, tensor buffer-type overrides, and
/// the device list). The vectors are kept alive for the lifetime of `self`
/// so the pointers stored in `params` remain valid.
pub struct LlamaModelParams {
    /// Raw llama.cpp parameter struct handed to model loading.
    pub params: llama_cpp_bindings_sys::llama_model_params,
    // Backing storage for `params.kv_overrides`; always ends with an empty
    // sentinel entry (the terminator llama.cpp expects).
    kv_overrides: Vec<llama_cpp_bindings_sys::llama_model_kv_override>,
    // Backing storage for `params.tensor_buft_overrides`; always ends with a
    // null-pattern sentinel entry.
    buft_overrides: Vec<llama_cpp_bindings_sys::llama_model_tensor_buft_override>,
    // Null-terminated device list referenced by `params.devices`; pinned so
    // the array's address never moves while the raw pointer is live.
    devices: Pin<Box<[llama_cpp_bindings_sys::ggml_backend_dev_t; LLAMA_CPP_MAX_DEVICES]>>,
}
117
118impl Debug for LlamaModelParams {
119 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
120 f.debug_struct("LlamaModelParams")
121 .field("n_gpu_layers", &self.params.n_gpu_layers)
122 .field("main_gpu", &self.params.main_gpu)
123 .field("vocab_only", &self.params.vocab_only)
124 .field("use_mmap", &self.params.use_mmap)
125 .field("use_mlock", &self.params.use_mlock)
126 .field("split_mode", &self.split_mode())
127 .field("devices", &self.devices)
128 .field("kv_overrides", &"vec of kv_overrides")
129 .finish()
130 }
131}
132
133impl LlamaModelParams {
134 #[must_use]
146 pub fn kv_overrides(&self) -> KvOverrides<'_> {
147 KvOverrides::new(self)
148 }
149
150 pub fn append_kv_override(
172 mut self: Pin<&mut Self>,
173 key: &CStr,
174 value: kv_overrides::ParamOverrideValue,
175 ) {
176 let kv_override = self
177 .kv_overrides
178 .get_mut(0)
179 .expect("kv_overrides did not have a next allocated");
180
181 assert_eq!(kv_override.key[0], 0, "last kv_override was not empty");
182
183 for (i, &c) in key.to_bytes_with_nul().iter().enumerate() {
185 kv_override.key[i] = c_char::try_from(c).expect("invalid character in key");
186 }
187
188 kv_override.tag = value.tag();
189 kv_override.__bindgen_anon_1 = value.value();
190
191 self.params.kv_overrides = null();
193
194 self.kv_overrides
196 .push(llama_cpp_bindings_sys::llama_model_kv_override {
197 key: [0; 128],
198 tag: 0,
199 __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
200 val_i64: 0,
201 },
202 });
203
204 self.params.kv_overrides = self.kv_overrides.as_ptr();
206
207 eprintln!("saved ptr: {:?}", self.params.kv_overrides);
208 }
209}
210
211impl LlamaModelParams {
212 pub fn add_cpu_moe_override(self: Pin<&mut Self>) {
214 self.add_cpu_buft_override(c"\\.ffn_(up|down|gate)_(ch|)exps");
215 }
216
217 pub fn add_cpu_buft_override(mut self: Pin<&mut Self>, key: &CStr) {
223 let buft_override = self
224 .buft_overrides
225 .get_mut(0)
226 .expect("buft_overrides did not have a next allocated");
227
228 assert!(
229 buft_override.pattern.is_null(),
230 "last buft_override was not empty"
231 );
232
233 for &c in key.to_bytes_with_nul() {
235 c_char::try_from(c).expect("invalid character in key");
236 }
237
238 buft_override.pattern = key.as_ptr();
239 buft_override.buft = unsafe { llama_cpp_bindings_sys::ggml_backend_cpu_buffer_type() };
240
241 self.params.tensor_buft_overrides = null();
243
244 self.buft_overrides
246 .push(llama_cpp_bindings_sys::llama_model_tensor_buft_override {
247 pattern: std::ptr::null(),
248 buft: std::ptr::null_mut(),
249 });
250
251 self.params.tensor_buft_overrides = self.buft_overrides.as_ptr();
253 }
254}
255
256impl LlamaModelParams {
257 #[must_use]
259 pub fn n_gpu_layers(&self) -> i32 {
260 self.params.n_gpu_layers
261 }
262
263 #[must_use]
265 pub fn main_gpu(&self) -> i32 {
266 self.params.main_gpu
267 }
268
269 #[must_use]
271 pub fn vocab_only(&self) -> bool {
272 self.params.vocab_only
273 }
274
275 #[must_use]
277 pub fn use_mmap(&self) -> bool {
278 self.params.use_mmap
279 }
280
281 #[must_use]
283 pub fn use_mlock(&self) -> bool {
284 self.params.use_mlock
285 }
286
287 pub fn split_mode(&self) -> Result<LlamaSplitMode, LlamaSplitModeParseError> {
292 LlamaSplitMode::try_from(self.params.split_mode)
293 }
294
295 #[must_use]
297 pub fn devices(&self) -> Vec<usize> {
298 let mut backend_devices = Vec::new();
299 for i in 0..unsafe { llama_cpp_bindings_sys::ggml_backend_dev_count() } {
300 let dev = unsafe { llama_cpp_bindings_sys::ggml_backend_dev_get(i) };
301 backend_devices.push(dev);
302 }
303 let mut devices = Vec::new();
304 for &dev in self.devices.iter() {
305 if dev.is_null() {
306 break;
307 }
308 if let Some((index, _)) = backend_devices
309 .iter()
310 .enumerate()
311 .find(|&(_i, &d)| d == dev)
312 {
313 devices.push(index);
314 }
315 }
316 devices
317 }
318
319 #[must_use]
327 pub fn with_n_gpu_layers(mut self, n_gpu_layers: u32) -> Self {
328 let n_gpu_layers = i32::try_from(n_gpu_layers).unwrap_or(i32::MAX);
331 self.params.n_gpu_layers = n_gpu_layers;
332 self
333 }
334
335 #[must_use]
339 pub fn with_main_gpu(mut self, main_gpu: i32) -> Self {
340 self.params.main_gpu = main_gpu;
341 self
342 }
343
344 #[must_use]
346 pub fn with_vocab_only(mut self, vocab_only: bool) -> Self {
347 self.params.vocab_only = vocab_only;
348 self
349 }
350
351 #[must_use]
353 pub fn with_use_mlock(mut self, use_mlock: bool) -> Self {
354 self.params.use_mlock = use_mlock;
355 self
356 }
357
358 #[must_use]
360 pub fn with_split_mode(mut self, split_mode: LlamaSplitMode) -> Self {
361 self.params.split_mode = split_mode.into();
362 self
363 }
364
365 pub fn with_devices(mut self, devices: &[usize]) -> Result<Self, LlamaCppError> {
376 for dev in self.devices.iter_mut() {
377 *dev = std::ptr::null_mut();
378 }
379 let max_devices = crate::max_devices().min(LLAMA_CPP_MAX_DEVICES);
381 if devices.len() > max_devices {
382 return Err(LlamaCppError::MaxDevicesExceeded(max_devices));
383 }
384 for (i, &dev) in devices.iter().enumerate() {
385 if dev >= unsafe { llama_cpp_bindings_sys::ggml_backend_dev_count() } {
386 return Err(LlamaCppError::BackendDeviceNotFound(dev));
387 }
388 let backend_dev = unsafe { llama_cpp_bindings_sys::ggml_backend_dev_get(dev) };
389 self.devices[i] = backend_dev;
390 }
391 if self.devices.is_empty() {
392 self.params.devices = std::ptr::null_mut();
393 } else {
394 self.params.devices = self.devices.as_mut_ptr();
395 }
396 Ok(self)
397 }
398}
399
400impl Default for LlamaModelParams {
414 fn default() -> Self {
415 let default_params = unsafe { llama_cpp_bindings_sys::llama_model_default_params() };
416 LlamaModelParams {
417 params: default_params,
418 kv_overrides: vec![llama_cpp_bindings_sys::llama_model_kv_override {
420 key: [0; 128],
421 tag: 0,
422 __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
423 val_i64: 0,
424 },
425 }],
426 buft_overrides: vec![llama_cpp_bindings_sys::llama_model_tensor_buft_override {
427 pattern: std::ptr::null(),
428 buft: std::ptr::null_mut(),
429 }],
430 devices: Box::pin([std::ptr::null_mut(); 16]),
431 }
432 }
433}