1use crate::model::params::kv_overrides::KvOverrides;
4use crate::LlamaCppError;
5use std::ffi::{c_char, CStr};
6use std::fmt::{Debug, Formatter};
7use std::pin::Pin;
8use std::ptr::null;
9
10pub mod kv_overrides;
11
// Raw split-mode constants from llama.cpp, narrowed to `i8` so they can serve
// as discriminants of the `repr(i8)` `LlamaSplitMode` enum below.
// The clippy allows are needed because the sys constants are a wider integer
// type; the casts are lossless in practice (the C enum values are small and
// non-negative — TODO confirm against `llama_cpp_sys_2` bindings).
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
const LLAMA_SPLIT_MODE_NONE: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_NONE as i8;
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
const LLAMA_SPLIT_MODE_LAYER: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_LAYER as i8;
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
const LLAMA_SPLIT_MODE_ROW: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_ROW as i8;
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
const LLAMA_SPLIT_MODE_TENSOR: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_TENSOR as i8;
24
/// How a model's weights are split across devices.
///
/// Each variant mirrors the corresponding `LLAMA_SPLIT_MODE_*` constant from
/// llama.cpp; the `repr(i8)` discriminants are the raw C values.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaSplitMode {
    /// Mirrors `LLAMA_SPLIT_MODE_NONE` — no splitting across devices.
    None = LLAMA_SPLIT_MODE_NONE,
    /// Mirrors `LLAMA_SPLIT_MODE_LAYER` — split by layer.
    Layer = LLAMA_SPLIT_MODE_LAYER,
    /// Mirrors `LLAMA_SPLIT_MODE_ROW` — split by row.
    Row = LLAMA_SPLIT_MODE_ROW,
    /// Mirrors `LLAMA_SPLIT_MODE_TENSOR` — split by tensor.
    Tensor = LLAMA_SPLIT_MODE_TENSOR,
}
38
/// Error returned when a raw integer cannot be converted into a
/// [`LlamaSplitMode`].
///
/// The payload is the offending value (saturated to `i32::MAX` when the
/// source was a `u32` outside the `i32` range).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LlamaSplitModeParseError(pub i32);

impl std::fmt::Display for LlamaSplitModeParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "invalid llama split mode: {}", self.0)
    }
}

// Public error types should implement `std::error::Error` so callers can use
// them with `?`, `Box<dyn Error>`, anyhow, etc.
impl std::error::Error for LlamaSplitModeParseError {}
42
43impl TryFrom<i32> for LlamaSplitMode {
48 type Error = LlamaSplitModeParseError;
49
50 fn try_from(value: i32) -> Result<Self, Self::Error> {
51 let i8_value = value
52 .try_into()
53 .map_err(|_| LlamaSplitModeParseError(value))?;
54 match i8_value {
55 LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
56 LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
57 LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
58 LLAMA_SPLIT_MODE_TENSOR => Ok(Self::Tensor),
59 _ => Err(LlamaSplitModeParseError(value)),
60 }
61 }
62}
63
64impl TryFrom<u32> for LlamaSplitMode {
69 type Error = LlamaSplitModeParseError;
70
71 fn try_from(value: u32) -> Result<Self, Self::Error> {
72 let i8_value = value
73 .try_into()
74 .map_err(|_| LlamaSplitModeParseError(value.try_into().unwrap_or(i32::MAX)))?;
75 match i8_value {
76 LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
77 LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
78 LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
79 LLAMA_SPLIT_MODE_TENSOR => Ok(Self::Tensor),
80 _ => Err(LlamaSplitModeParseError(
81 value.try_into().unwrap_or(i32::MAX),
82 )),
83 }
84 }
85}
86
87impl From<LlamaSplitMode> for i32 {
89 fn from(value: LlamaSplitMode) -> Self {
90 match value {
91 LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE.into(),
92 LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER.into(),
93 LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW.into(),
94 LlamaSplitMode::Tensor => LLAMA_SPLIT_MODE_TENSOR.into(),
95 }
96 }
97}
98
99impl From<LlamaSplitMode> for u32 {
101 fn from(value: LlamaSplitMode) -> Self {
102 match value {
103 LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE as u32,
104 LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER as u32,
105 LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW as u32,
106 LlamaSplitMode::Tensor => LLAMA_SPLIT_MODE_TENSOR as u32,
107 }
108 }
109}
110
111impl Default for LlamaSplitMode {
113 fn default() -> Self {
114 LlamaSplitMode::Layer
115 }
116}
117
/// Maximum number of backend devices selectable via
/// [`LlamaModelParams::with_devices`]; sizes the fixed device-pointer array.
pub const LLAMA_CPP_MAX_DEVICES: usize = 16;
123
/// A safe wrapper around `llama_model_params`.
///
/// The raw `params` struct stores pointers into `kv_overrides`,
/// `buft_overrides`, and `devices`, so those containers must stay alive (and,
/// for `devices`, stay at a stable address — hence the `Pin<Box<…>>`) for as
/// long as `params` is handed to the C side.
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModelParams {
    /// The raw C struct passed to `llama_model_load_*`; may point into the
    /// fields below.
    pub(crate) params: llama_cpp_sys_2::llama_model_params,
    /// KV overrides, kept with one zeroed sentinel entry at the end as the
    /// C-side terminator.
    kv_overrides: Vec<llama_cpp_sys_2::llama_model_kv_override>,
    /// Tensor buffer-type overrides, likewise null-terminated by a sentinel.
    buft_overrides: Vec<llama_cpp_sys_2::llama_model_tensor_buft_override>,
    /// Null-terminated device list. NOTE(review): if all
    /// `LLAMA_CPP_MAX_DEVICES` slots are filled there is no trailing null
    /// terminator — verify the C side reads a bounded count in that case.
    devices: Pin<Box<[llama_cpp_sys_2::ggml_backend_dev_t; LLAMA_CPP_MAX_DEVICES]>>,
}
132
133impl Debug for LlamaModelParams {
134 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
135 f.debug_struct("LlamaModelParams")
136 .field("n_gpu_layers", &self.params.n_gpu_layers)
137 .field("main_gpu", &self.params.main_gpu)
138 .field("vocab_only", &self.params.vocab_only)
139 .field("use_mmap", &self.params.use_mmap)
140 .field("use_mlock", &self.params.use_mlock)
141 .field("split_mode", &self.split_mode())
142 .field("devices", &self.devices)
143 .field("kv_overrides", &"vec of kv_overrides")
144 .finish()
145 }
146}
147
148impl LlamaModelParams {
149 #[must_use]
161 pub fn kv_overrides<'a>(&'a self) -> KvOverrides<'a> {
162 KvOverrides::new(self)
163 }
164
165 #[allow(clippy::missing_panics_doc)] pub fn append_kv_override(
188 mut self: Pin<&mut Self>,
189 key: &CStr,
190 value: kv_overrides::ParamOverrideValue,
191 ) {
192 let kv_override = self
193 .kv_overrides
194 .get_mut(0)
195 .expect("kv_overrides did not have a next allocated");
196
197 assert_eq!(kv_override.key[0], 0, "last kv_override was not empty");
198
199 for (i, &c) in key.to_bytes_with_nul().iter().enumerate() {
201 kv_override.key[i] = c_char::try_from(c).expect("invalid character in key");
202 }
203
204 kv_override.tag = value.tag();
205 kv_override.__bindgen_anon_1 = value.value();
206
207 self.params.kv_overrides = null();
209
210 self.kv_overrides
212 .push(llama_cpp_sys_2::llama_model_kv_override {
213 key: [0; 128],
214 tag: 0,
215 __bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
216 val_i64: 0,
217 },
218 });
219
220 self.params.kv_overrides = self.kv_overrides.as_ptr();
222
223 eprintln!("saved ptr: {:?}", self.params.kv_overrides);
224 }
225}
226
227impl LlamaModelParams {
228 pub fn add_cpu_moe_override(self: Pin<&mut Self>) {
230 self.add_cpu_buft_override(c"\\.ffn_(up|down|gate)_(ch|)exps");
231 }
232
233 pub fn add_cpu_buft_override(mut self: Pin<&mut Self>, key: &CStr) {
236 let buft_override = self
237 .buft_overrides
238 .get_mut(0)
239 .expect("buft_overrides did not have a next allocated");
240
241 assert!(
242 buft_override.pattern.is_null(),
243 "last buft_override was not empty"
244 );
245
246 for &c in key.to_bytes_with_nul().iter() {
248 c_char::try_from(c).expect("invalid character in key");
249 }
250
251 buft_override.pattern = key.as_ptr();
252 buft_override.buft = unsafe { llama_cpp_sys_2::ggml_backend_cpu_buffer_type() };
253
254 self.params.tensor_buft_overrides = null();
256
257 self.buft_overrides
259 .push(llama_cpp_sys_2::llama_model_tensor_buft_override {
260 pattern: std::ptr::null(),
261 buft: std::ptr::null_mut(),
262 });
263
264 self.params.tensor_buft_overrides = self.buft_overrides.as_ptr();
266 }
267}
268
269impl LlamaModelParams {
270 #[must_use]
272 pub fn n_gpu_layers(&self) -> i32 {
273 self.params.n_gpu_layers
274 }
275
276 #[must_use]
278 pub fn main_gpu(&self) -> i32 {
279 self.params.main_gpu
280 }
281
282 #[must_use]
284 pub fn vocab_only(&self) -> bool {
285 self.params.vocab_only
286 }
287
288 #[must_use]
290 pub fn use_mmap(&self) -> bool {
291 self.params.use_mmap
292 }
293
294 #[must_use]
296 pub fn use_mlock(&self) -> bool {
297 self.params.use_mlock
298 }
299
300 pub fn split_mode(&self) -> Result<LlamaSplitMode, LlamaSplitModeParseError> {
305 LlamaSplitMode::try_from(self.params.split_mode)
306 }
307
308 #[must_use]
310 pub fn devices(&self) -> Vec<usize> {
311 let mut backend_devices = Vec::new();
312 for i in 0..unsafe { llama_cpp_sys_2::ggml_backend_dev_count() } {
313 let dev = unsafe { llama_cpp_sys_2::ggml_backend_dev_get(i) };
314 backend_devices.push(dev);
315 }
316 let mut devices = Vec::new();
317 for &dev in self.devices.iter() {
318 if dev.is_null() {
319 break;
320 }
321 if let Some((index, _)) = backend_devices
322 .iter()
323 .enumerate()
324 .find(|&(_i, &d)| d == dev)
325 {
326 devices.push(index);
327 }
328 }
329 devices
330 }
331
332 #[must_use]
340 pub fn with_n_gpu_layers(mut self, n_gpu_layers: u32) -> Self {
341 let n_gpu_layers = i32::try_from(n_gpu_layers).unwrap_or(i32::MAX);
344 self.params.n_gpu_layers = n_gpu_layers;
345 self
346 }
347
348 #[must_use]
352 pub fn with_main_gpu(mut self, main_gpu: i32) -> Self {
353 self.params.main_gpu = main_gpu;
354 self
355 }
356
357 #[must_use]
359 pub fn with_vocab_only(mut self, vocab_only: bool) -> Self {
360 self.params.vocab_only = vocab_only;
361 self
362 }
363
364 #[must_use]
366 pub fn with_use_mmap(mut self, use_mmap: bool) -> Self {
367 self.params.use_mmap = use_mmap;
368 self
369 }
370
371 #[must_use]
373 pub fn with_use_mlock(mut self, use_mlock: bool) -> Self {
374 self.params.use_mlock = use_mlock;
375 self
376 }
377
378 #[must_use]
380 pub fn with_split_mode(mut self, split_mode: LlamaSplitMode) -> Self {
381 self.params.split_mode = split_mode.into();
382 self
383 }
384
385 pub fn with_devices(mut self, devices: &[usize]) -> Result<Self, LlamaCppError> {
396 for dev in self.devices.iter_mut() {
397 *dev = std::ptr::null_mut();
398 }
399 let max_devices = crate::max_devices().min(LLAMA_CPP_MAX_DEVICES);
401 if devices.len() > max_devices {
402 return Err(LlamaCppError::MaxDevicesExceeded(max_devices));
403 }
404 for (i, &dev) in devices.iter().enumerate() {
405 if dev >= unsafe { llama_cpp_sys_2::ggml_backend_dev_count() } {
406 return Err(LlamaCppError::BackendDeviceNotFound(dev));
407 }
408 let backend_dev = unsafe { llama_cpp_sys_2::ggml_backend_dev_get(dev) };
409 self.devices[i] = backend_dev;
410 }
411 if self.devices.is_empty() {
412 self.params.devices = std::ptr::null_mut();
413 } else {
414 self.params.devices = self.devices.as_mut_ptr();
415 }
416 Ok(self)
417 }
418
419 #[must_use]
425 pub fn with_no_alloc(mut self, no_alloc: bool) -> Self {
426 self.params.no_alloc = no_alloc;
427 if no_alloc {
428 self = self.with_use_mmap(false);
429 }
430 self
431 }
432
433 #[must_use]
437 pub fn no_alloc(&self) -> bool {
438 self.params.no_alloc
439 }
440}
441
442impl Default for LlamaModelParams {
457 fn default() -> Self {
458 let default_params = unsafe { llama_cpp_sys_2::llama_model_default_params() };
459 LlamaModelParams {
460 params: default_params,
461 kv_overrides: vec![llama_cpp_sys_2::llama_model_kv_override {
463 key: [0; 128],
464 tag: 0,
465 __bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
466 val_i64: 0,
467 },
468 }],
469 buft_overrides: vec![llama_cpp_sys_2::llama_model_tensor_buft_override {
470 pattern: std::ptr::null(),
471 buft: std::ptr::null_mut(),
472 }],
473 devices: Box::pin([std::ptr::null_mut(); 16]),
474 }
475 }
476}
477
#[cfg(test)]
mod tests {
    use super::LlamaSplitMode;

    /// Converting the raw tensor split-mode constant into the enum and back
    /// (as both `u32` and `i32`) must be lossless.
    #[test]
    fn tensor_split_mode_round_trips() {
        let raw = llama_cpp_sys_2::LLAMA_SPLIT_MODE_TENSOR;

        let parsed = LlamaSplitMode::try_from(raw);
        assert_eq!(parsed, Ok(LlamaSplitMode::Tensor));

        assert_eq!(u32::from(LlamaSplitMode::Tensor), raw as u32);
        assert_eq!(i32::from(LlamaSplitMode::Tensor), raw as i32);
    }
}