llama_cpp_2/model/params.rs

//! A safe wrapper around `llama_model_params`.

use crate::model::params::kv_overrides::KvOverrides;
use crate::LLamaCppError;
use std::ffi::{c_char, CStr};
use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::ptr::null;

pub mod kv_overrides;

#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
const LLAMA_SPLIT_MODE_NONE: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_NONE as i8;
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
const LLAMA_SPLIT_MODE_LAYER: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_LAYER as i8;
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
const LLAMA_SPLIT_MODE_ROW: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_ROW as i8;

/// A rusty wrapper around `llama_split_mode`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaSplitMode {
    /// Single GPU
    None = LLAMA_SPLIT_MODE_NONE,
    /// Split layers and KV across GPUs
    Layer = LLAMA_SPLIT_MODE_LAYER,
    /// Split layers and KV across GPUs, use tensor parallelism if supported
    Row = LLAMA_SPLIT_MODE_ROW,
}

/// An error that occurs when an unknown split mode is encountered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LlamaSplitModeParseError(pub i32);

/// Create a `LlamaSplitMode` from an `i32`.
///
/// # Errors
/// Returns `LlamaSplitModeParseError` if the value does not correspond to a valid `LlamaSplitMode`.
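///
/// # Examples
///
/// A round trip through `i32` recovers the same variant; values that do not match one of the
/// llama.cpp split-mode constants fail to parse.
///
/// ```rust
/// # use llama_cpp_2::model::params::LlamaSplitMode;
/// let raw = i32::from(LlamaSplitMode::Row);
/// assert_eq!(LlamaSplitMode::try_from(raw), Ok(LlamaSplitMode::Row));
/// assert!(LlamaSplitMode::try_from(-1_i32).is_err());
/// ```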
impl TryFrom<i32> for LlamaSplitMode {
    type Error = LlamaSplitModeParseError;

    fn try_from(value: i32) -> Result<Self, Self::Error> {
        let i8_value = value
            .try_into()
            .map_err(|_| LlamaSplitModeParseError(value))?;
        match i8_value {
            LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
            LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
            LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
            _ => Err(LlamaSplitModeParseError(value)),
        }
    }
}

/// Create a `LlamaSplitMode` from a `u32`.
///
/// # Errors
/// Returns `LlamaSplitModeParseError` if the value does not correspond to a valid `LlamaSplitMode`.
impl TryFrom<u32> for LlamaSplitMode {
    type Error = LlamaSplitModeParseError;

    fn try_from(value: u32) -> Result<Self, Self::Error> {
        let i8_value = value
            .try_into()
            .map_err(|_| LlamaSplitModeParseError(value.try_into().unwrap_or(i32::MAX)))?;
        match i8_value {
            LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
            LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
            LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
            _ => Err(LlamaSplitModeParseError(
                value.try_into().unwrap_or(i32::MAX),
            )),
        }
    }
}

/// Create an `i32` from a `LlamaSplitMode`.
impl From<LlamaSplitMode> for i32 {
    fn from(value: LlamaSplitMode) -> Self {
        match value {
            LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE.into(),
            LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER.into(),
            LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW.into(),
        }
    }
}

/// Create a `u32` from a `LlamaSplitMode`.
impl From<LlamaSplitMode> for u32 {
    fn from(value: LlamaSplitMode) -> Self {
        match value {
            LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE as u32,
            LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER as u32,
            LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW as u32,
        }
    }
}

/// The default split mode is `Layer` in llama.cpp.
impl Default for LlamaSplitMode {
    fn default() -> Self {
        LlamaSplitMode::Layer
    }
}

/// The maximum number of devices supported.
///
/// The real maximum number of devices is the lesser of this value and the value returned by
/// `llama_cpp_2::max_devices()`.
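///
/// A minimal sketch of computing the effective limit at runtime (compile-checked only):
///
/// ```rust,no_run
/// use llama_cpp_2::model::params::LLAMA_CPP_MAX_DEVICES;
///
/// let effective_max = llama_cpp_2::max_devices().min(LLAMA_CPP_MAX_DEVICES);
/// assert!(effective_max <= LLAMA_CPP_MAX_DEVICES);
/// ```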
pub const LLAMA_CPP_MAX_DEVICES: usize = 16;

/// A safe wrapper around `llama_model_params`.
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModelParams {
    pub(crate) params: llama_cpp_sys_2::llama_model_params,
    kv_overrides: Vec<llama_cpp_sys_2::llama_model_kv_override>,
    buft_overrides: Vec<llama_cpp_sys_2::llama_model_tensor_buft_override>,
    devices: Pin<Box<[llama_cpp_sys_2::ggml_backend_dev_t; LLAMA_CPP_MAX_DEVICES]>>,
}

impl Debug for LlamaModelParams {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaModelParams")
            .field("n_gpu_layers", &self.params.n_gpu_layers)
            .field("main_gpu", &self.params.main_gpu)
            .field("vocab_only", &self.params.vocab_only)
            .field("use_mmap", &self.params.use_mmap)
            .field("use_mlock", &self.params.use_mlock)
            .field("split_mode", &self.split_mode())
            .field("devices", &self.devices)
            .field("kv_overrides", &"vec of kv_overrides")
            .finish()
    }
}

impl LlamaModelParams {
    /// See [`KvOverrides`]
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use llama_cpp_2::model::params::LlamaModelParams;
    /// let params = Box::pin(LlamaModelParams::default());
    /// let kv_overrides = params.kv_overrides();
    /// let count = kv_overrides.into_iter().count();
    /// assert_eq!(count, 0);
    /// ```
    #[must_use]
    pub fn kv_overrides(&self) -> KvOverrides {
        KvOverrides::new(self)
    }

    /// Appends a key-value override to the model parameters. It must be pinned as this creates a self-referential struct.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::ffi::{CStr, CString};
    /// use std::pin::pin;
    /// # use llama_cpp_2::model::params::LlamaModelParams;
    /// # use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
    /// let mut params = pin!(LlamaModelParams::default());
    /// let key = CString::new("key").expect("CString::new failed");
    /// params.as_mut().append_kv_override(&key, ParamOverrideValue::Int(50));
    ///
    /// let kv_overrides = params.kv_overrides().into_iter().collect::<Vec<_>>();
    /// assert_eq!(kv_overrides.len(), 1);
    ///
    /// let (k, v) = &kv_overrides[0];
    /// assert_eq!(v, &ParamOverrideValue::Int(50));
    ///
    /// assert_eq!(k.to_bytes(), b"key", "expected key to be 'key', was {:?}", k);
    /// ```
    #[allow(clippy::missing_panics_doc)] // panics are just to enforce internal invariants, not user errors
    pub fn append_kv_override(
        mut self: Pin<&mut Self>,
        key: &CStr,
        value: kv_overrides::ParamOverrideValue,
    ) {
        let kv_override = self
            .kv_overrides
            .get_mut(0)
            .expect("kv_overrides did not have a next allocated");

        assert_eq!(kv_override.key[0], 0, "last kv_override was not empty");

        // There should be some way to do this without iterating over everything.
        for (i, &c) in key.to_bytes_with_nul().iter().enumerate() {
            kv_override.key[i] = c_char::try_from(c).expect("invalid character in key");
        }

        kv_override.tag = value.tag();
        kv_override.__bindgen_anon_1 = value.value();

        // set to null pointer for panic safety (as push may move the vector, invalidating the pointer)
        self.params.kv_overrides = null();

        // push the next one to ensure we maintain the iterator invariant of ending with a 0
        self.kv_overrides
            .push(llama_cpp_sys_2::llama_model_kv_override {
                key: [0; 128],
                tag: 0,
                __bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
                    val_i64: 0,
                },
            });

        // set the pointer to the (potentially) new vector
        self.params.kv_overrides = self.kv_overrides.as_ptr();
    }
}

impl LlamaModelParams {
    /// Adds buffer type overrides to move all mixture-of-experts layers to CPU.
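    ///
    /// A minimal sketch (compile-checked only; it only mutates the parameters and does not load a model):
    ///
    /// ```rust,no_run
    /// # use std::pin::pin;
    /// # use llama_cpp_2::model::params::LlamaModelParams;
    /// let mut params = pin!(LlamaModelParams::default());
    /// params.as_mut().add_cpu_moe_override();
    /// ```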
    pub fn add_cpu_moe_override(self: Pin<&mut Self>) {
        self.add_cpu_buft_override(c"\\.ffn_(up|down|gate)_(ch|)exps");
    }

    /// Appends a buffer type override to the model parameters, keeping tensors whose names match the pattern on the CPU.
    /// It must be pinned as this creates a self-referential struct.
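    ///
    /// A minimal sketch with an illustrative pattern (compile-checked only):
    ///
    /// ```rust,no_run
    /// # use std::pin::pin;
    /// # use llama_cpp_2::model::params::LlamaModelParams;
    /// let mut params = pin!(LlamaModelParams::default());
    /// // keep expert feed-forward tensors on the CPU (the pattern is an example, not a recommendation)
    /// params.as_mut().add_cpu_buft_override(c"\\.ffn_(up|down|gate)_exps");
    /// ```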
    pub fn add_cpu_buft_override(mut self: Pin<&mut Self>, key: &CStr) {
        let buft_override = self
            .buft_overrides
            .get_mut(0)
            .expect("buft_overrides did not have a next allocated");

        assert!(
            buft_override.pattern.is_null(),
            "last buft_override was not empty"
        );

        // validate that every byte of the pattern is a valid `c_char`
        for &c in key.to_bytes_with_nul() {
            c_char::try_from(c).expect("invalid character in key");
        }

        buft_override.pattern = key.as_ptr();
        buft_override.buft = unsafe { llama_cpp_sys_2::ggml_backend_cpu_buffer_type() };

        // set to null pointer for panic safety (as push may move the vector, invalidating the pointer)
        self.params.tensor_buft_overrides = null();

        // push the next one to ensure we maintain the iterator invariant of ending with a 0
        self.buft_overrides
            .push(llama_cpp_sys_2::llama_model_tensor_buft_override {
                pattern: std::ptr::null(),
                buft: std::ptr::null_mut(),
            });

        // set the pointer to the (potentially) new vector
        self.params.tensor_buft_overrides = self.buft_overrides.as_ptr();
    }
}

impl LlamaModelParams {
    /// Get the number of layers to offload to the GPU.
    #[must_use]
    pub fn n_gpu_layers(&self) -> i32 {
        self.params.n_gpu_layers
    }

    /// The GPU that is used for scratch and small tensors
    #[must_use]
    pub fn main_gpu(&self) -> i32 {
        self.params.main_gpu
    }

    /// only load the vocabulary, no weights
    #[must_use]
    pub fn vocab_only(&self) -> bool {
        self.params.vocab_only
    }

    /// use mmap if possible
    #[must_use]
    pub fn use_mmap(&self) -> bool {
        self.params.use_mmap
    }

    /// force system to keep model in RAM
    #[must_use]
    pub fn use_mlock(&self) -> bool {
        self.params.use_mlock
    }

    /// get the split mode
    ///
    /// # Errors
    /// Returns `LlamaSplitModeParseError` if an unknown split mode is encountered.
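    ///
    /// # Examples
    ///
    /// For example, the default parameters use the layer split mode (mirroring the `Default` doc test below):
    ///
    /// ```rust
    /// # use llama_cpp_2::model::params::{LlamaModelParams, LlamaSplitMode};
    /// let params = LlamaModelParams::default();
    /// assert_eq!(params.split_mode(), Ok(LlamaSplitMode::Layer));
    /// ```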
    pub fn split_mode(&self) -> Result<LlamaSplitMode, LlamaSplitModeParseError> {
        LlamaSplitMode::try_from(self.params.split_mode)
    }

    /// get the selected devices as indices into the ggml backend device list
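    ///
    /// By default no devices are selected (mirroring the `Default` doc test below):
    ///
    /// ```rust
    /// # use llama_cpp_2::model::params::LlamaModelParams;
    /// let params = LlamaModelParams::default();
    /// assert!(params.devices().is_empty());
    /// ```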
    #[must_use]
    pub fn devices(&self) -> Vec<usize> {
        let mut backend_devices = Vec::new();
        for i in 0..unsafe { llama_cpp_sys_2::ggml_backend_dev_count() } {
            let dev = unsafe { llama_cpp_sys_2::ggml_backend_dev_get(i) };
            backend_devices.push(dev);
        }
        let mut devices = Vec::new();
        for &dev in self.devices.iter() {
            if dev.is_null() {
                break;
            }
            if let Some(index) = backend_devices.iter().position(|&d| d == dev) {
                devices.push(index);
            }
        }
        devices
    }

    /// sets the number of model layers to offload to the GPU.
    /// ```
    /// # use llama_cpp_2::model::params::LlamaModelParams;
    /// let params = LlamaModelParams::default();
    /// let params = params.with_n_gpu_layers(1);
    /// assert_eq!(params.n_gpu_layers(), 1);
    /// ```
    #[must_use]
    pub fn with_n_gpu_layers(mut self, n_gpu_layers: u32) -> Self {
        // The only way this conversion can fail is if u32 overflows the i32 - in which case we set
        // to MAX
        let n_gpu_layers = i32::try_from(n_gpu_layers).unwrap_or(i32::MAX);
        self.params.n_gpu_layers = n_gpu_layers;
        self
    }

    /// sets the main GPU
    ///
    /// This setting only takes effect when `split_mode` is set to `LlamaSplitMode::None` (single-GPU mode).
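    ///
    /// For example (the device index `1` is illustrative and assumes a second backend device exists):
    ///
    /// ```rust
    /// # use llama_cpp_2::model::params::{LlamaModelParams, LlamaSplitMode};
    /// let params = LlamaModelParams::default()
    ///     .with_split_mode(LlamaSplitMode::None)
    ///     .with_main_gpu(1);
    /// assert_eq!(params.main_gpu(), 1);
    /// ```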
    #[must_use]
    pub fn with_main_gpu(mut self, main_gpu: i32) -> Self {
        self.params.main_gpu = main_gpu;
        self
    }

    /// sets `vocab_only`
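    ///
    /// For example, enabling vocabulary-only loading:
    ///
    /// ```rust
    /// # use llama_cpp_2::model::params::LlamaModelParams;
    /// let params = LlamaModelParams::default().with_vocab_only(true);
    /// assert!(params.vocab_only());
    /// ```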
    #[must_use]
    pub fn with_vocab_only(mut self, vocab_only: bool) -> Self {
        self.params.vocab_only = vocab_only;
        self
    }

    /// sets `use_mlock`
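    ///
    /// For example, asking the system to keep the model in RAM:
    ///
    /// ```rust
    /// # use llama_cpp_2::model::params::LlamaModelParams;
    /// let params = LlamaModelParams::default().with_use_mlock(true);
    /// assert!(params.use_mlock());
    /// ```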
    #[must_use]
    pub fn with_use_mlock(mut self, use_mlock: bool) -> Self {
        self.params.use_mlock = use_mlock;
        self
    }

    /// sets `split_mode`
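    ///
    /// For example, forcing single-GPU mode:
    ///
    /// ```rust
    /// # use llama_cpp_2::model::params::{LlamaModelParams, LlamaSplitMode};
    /// let params = LlamaModelParams::default().with_split_mode(LlamaSplitMode::None);
    /// assert_eq!(params.split_mode(), Ok(LlamaSplitMode::None));
    /// ```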
    #[must_use]
    pub fn with_split_mode(mut self, split_mode: LlamaSplitMode) -> Self {
        self.params.split_mode = split_mode.into();
        self
    }

    /// sets `devices`
    ///
    /// The devices are specified as indices that correspond to the ggml backend device indices.
    ///
    /// At most `LLAMA_CPP_MAX_DEVICES` devices can be set (fewer if `llama_cpp_2::max_devices()` is smaller).
    ///
    /// You don't need to specify CPU or ACCEL devices.
    ///
    /// # Errors
    /// Returns `LLamaCppError::MaxDevicesExceeded` if more devices than the supported maximum are given,
    /// and `LLamaCppError::BackendDeviceNotFound` if any device index is invalid.
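    ///
    /// # Examples
    ///
    /// A minimal sketch (compile-checked only; it assumes the backend exposes at least one device):
    ///
    /// ```rust,no_run
    /// # use llama_cpp_2::model::params::LlamaModelParams;
    /// let params = LlamaModelParams::default()
    ///     .with_devices(&[0])
    ///     .expect("backend device 0 should exist");
    /// assert_eq!(params.devices(), vec![0]);
    /// ```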
    pub fn with_devices(mut self, devices: &[usize]) -> Result<Self, LLamaCppError> {
        for dev in self.devices.iter_mut() {
            *dev = std::ptr::null_mut();
        }
        // Check device count
        let max_devices = crate::max_devices().min(LLAMA_CPP_MAX_DEVICES);
        if devices.len() > max_devices {
            return Err(LLamaCppError::MaxDevicesExceeded(max_devices));
        }
        for (i, &dev) in devices.iter().enumerate() {
            if dev >= unsafe { llama_cpp_sys_2::ggml_backend_dev_count() } {
                return Err(LLamaCppError::BackendDeviceNotFound(dev));
            }
            let backend_dev = unsafe { llama_cpp_sys_2::ggml_backend_dev_get(dev) };
            self.devices[i] = backend_dev;
        }
        // `self.devices` is a fixed-size array, so check the requested slice instead:
        // an empty selection falls back to the library default (null = all devices).
        if devices.is_empty() {
            self.params.devices = std::ptr::null_mut();
        } else {
            self.params.devices = self.devices.as_mut_ptr();
        }
        Ok(self)
    }
}

/// Default parameters for `LlamaModel`. (as defined in llama.cpp by `llama_model_default_params`)
/// ```
/// # use llama_cpp_2::model::params::LlamaModelParams;
/// use llama_cpp_2::model::params::LlamaSplitMode;
/// let params = LlamaModelParams::default();
/// assert_eq!(params.n_gpu_layers(), 999, "n_gpu_layers should be 999");
/// assert_eq!(params.main_gpu(), 0, "main_gpu should be 0");
/// assert_eq!(params.vocab_only(), false, "vocab_only should be false");
/// assert_eq!(params.use_mmap(), true, "use_mmap should be true");
/// assert_eq!(params.use_mlock(), false, "use_mlock should be false");
/// assert_eq!(params.split_mode(), Ok(LlamaSplitMode::Layer), "split_mode should be LAYER");
/// assert_eq!(params.devices().len(), 0, "devices should be empty");
/// ```
impl Default for LlamaModelParams {
    fn default() -> Self {
        let default_params = unsafe { llama_cpp_sys_2::llama_model_default_params() };
        LlamaModelParams {
            params: default_params,
            // start with a single zeroed element so the overrides always end with a 0 terminator
            kv_overrides: vec![llama_cpp_sys_2::llama_model_kv_override {
                key: [0; 128],
                tag: 0,
                __bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
                    val_i64: 0,
                },
            }],
            buft_overrides: vec![llama_cpp_sys_2::llama_model_tensor_buft_override {
                pattern: std::ptr::null(),
                buft: std::ptr::null_mut(),
            }],
            devices: Box::pin([std::ptr::null_mut(); LLAMA_CPP_MAX_DEVICES]),
        }
    }
}