// llama_cpp_bindings/model/params.rs
1//! A safe wrapper around `llama_model_params`.
2
3use crate::LlamaCppError;
4use crate::model::params::kv_overrides::KvOverrides;
5use std::ffi::{CStr, c_char};
6use std::fmt::{Debug, Formatter};
7use std::pin::Pin;
8use std::ptr::null;
9
10pub mod kv_overrides;
11
// Local `i8`-typed copies of the C split-mode constants so they can serve both
// as `LlamaSplitMode` discriminants and as `match` patterns below.
const LLAMA_SPLIT_MODE_NONE: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_NONE as i8;
const LLAMA_SPLIT_MODE_LAYER: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_LAYER as i8;
const LLAMA_SPLIT_MODE_ROW: i8 = llama_cpp_bindings_sys::LLAMA_SPLIT_MODE_ROW as i8;
15
/// A rusty wrapper around `llama_split_mode`.
///
/// `repr(i8)` so each variant's discriminant matches the corresponding
/// `LLAMA_SPLIT_MODE_*` constant from the C API.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaSplitMode {
    /// Single GPU
    None = LLAMA_SPLIT_MODE_NONE,
    /// Split layers and KV across GPUs
    Layer = LLAMA_SPLIT_MODE_LAYER,
    /// Split layers and KV across GPUs, use tensor parallelism if supported
    Row = LLAMA_SPLIT_MODE_ROW,
}
27
/// An error that occurs when an unknown split mode is encountered.
///
/// Carries the offending raw value, widened to `i32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LlamaSplitModeParseError(pub i32);
31
32/// Create a `LlamaSplitMode` from a `i32`.
33///
34/// # Errors
35/// Returns `LlamaSplitModeParseError` if the value does not correspond to a valid `LlamaSplitMode`.
36impl TryFrom<i32> for LlamaSplitMode {
37    type Error = LlamaSplitModeParseError;
38
39    fn try_from(value: i32) -> Result<Self, Self::Error> {
40        let i8_value = value
41            .try_into()
42            .map_err(|_| LlamaSplitModeParseError(value))?;
43        match i8_value {
44            LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
45            LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
46            LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
47            _ => Err(LlamaSplitModeParseError(value)),
48        }
49    }
50}
51
52/// Create a `LlamaSplitMode` from a `u32`.
53///
54/// # Errors
55/// Returns `LlamaSplitModeParseError` if the value does not correspond to a valid `LlamaSplitMode`.
56impl TryFrom<u32> for LlamaSplitMode {
57    type Error = LlamaSplitModeParseError;
58
59    fn try_from(value: u32) -> Result<Self, Self::Error> {
60        let i8_value = value
61            .try_into()
62            .map_err(|_| LlamaSplitModeParseError(value.try_into().unwrap_or(i32::MAX)))?;
63        match i8_value {
64            LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
65            LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
66            LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
67            _ => Err(LlamaSplitModeParseError(
68                value.try_into().unwrap_or(i32::MAX),
69            )),
70        }
71    }
72}
73
74/// Create a `i32` from a `LlamaSplitMode`.
75impl From<LlamaSplitMode> for i32 {
76    fn from(value: LlamaSplitMode) -> Self {
77        match value {
78            LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE.into(),
79            LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER.into(),
80            LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW.into(),
81        }
82    }
83}
84
85/// Create a `u32` from a `LlamaSplitMode`.
86impl From<LlamaSplitMode> for u32 {
87    fn from(value: LlamaSplitMode) -> Self {
88        match value {
89            LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE as u32,
90            LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER as u32,
91            LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW as u32,
92        }
93    }
94}
95
96/// The default split mode is `Layer` in llama.cpp.
97impl Default for LlamaSplitMode {
98    fn default() -> Self {
99        LlamaSplitMode::Layer
100    }
101}
102
/// The maximum number of devices supported.
///
/// This sizes the pinned `devices` array inside [`LlamaModelParams`].
/// The real maximum number of devices is the lesser one of this value and the value returned by
/// `llama_cpp_bindings::max_devices()`.
pub const LLAMA_CPP_MAX_DEVICES: usize = 16;
108
/// A safe wrapper around `llama_model_params`.
///
/// Owns the backing storage that pointer fields inside `params` refer to
/// (`kv_overrides`, `tensor_buft_overrides`, `devices`); the override-mutating
/// methods therefore take `Pin<&mut Self>` since the struct becomes
/// self-referential once those pointers are set.
pub struct LlamaModelParams {
    /// The underlying `llama_model_params` from the C API.
    pub params: llama_cpp_bindings_sys::llama_model_params,
    /// Backing storage for `params.kv_overrides`; always ends with an
    /// all-zero sentinel entry (the C side reads until an empty key).
    kv_overrides: Vec<llama_cpp_bindings_sys::llama_model_kv_override>,
    /// Backing storage for `params.tensor_buft_overrides`; always ends with a
    /// null-pattern sentinel entry.
    buft_overrides: Vec<llama_cpp_bindings_sys::llama_model_tensor_buft_override>,
    /// Heap-pinned, null-terminated device list backing `params.devices`.
    devices: Pin<Box<[llama_cpp_bindings_sys::ggml_backend_dev_t; LLAMA_CPP_MAX_DEVICES]>>,
}
117
118impl Debug for LlamaModelParams {
119    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
120        f.debug_struct("LlamaModelParams")
121            .field("n_gpu_layers", &self.params.n_gpu_layers)
122            .field("main_gpu", &self.params.main_gpu)
123            .field("vocab_only", &self.params.vocab_only)
124            .field("use_mmap", &self.params.use_mmap)
125            .field("use_mlock", &self.params.use_mlock)
126            .field("split_mode", &self.split_mode())
127            .field("devices", &self.devices)
128            .field("kv_overrides", &"vec of kv_overrides")
129            .finish()
130    }
131}
132
133impl LlamaModelParams {
134    /// See [`KvOverrides`]
135    ///
136    /// # Examples
137    ///
138    /// ```rust
139    /// # use llama_cpp_bindings::model::params::LlamaModelParams;
140    /// let params = Box::pin(LlamaModelParams::default());
141    /// let kv_overrides = params.kv_overrides();
142    /// let count = kv_overrides.into_iter().count();
143    /// assert_eq!(count, 0);
144    /// ```
145    #[must_use]
146    pub fn kv_overrides(&self) -> KvOverrides<'_> {
147        KvOverrides::new(self)
148    }
149
150    /// Appends a key-value override to the model parameters. It must be pinned as this creates a self-referential struct.
151    ///
152    /// # Examples
153    ///
154    /// ```rust
155    /// # use std::ffi::{CStr, CString};
156    /// use std::pin::pin;
157    /// # use llama_cpp_bindings::model::params::LlamaModelParams;
158    /// # use llama_cpp_bindings::model::params::kv_overrides::ParamOverrideValue;
159    /// let mut params = pin!(LlamaModelParams::default());
160    /// let key = CString::new("key").expect("CString::new failed");
161    /// params.as_mut().append_kv_override(&key, ParamOverrideValue::Int(50));
162    ///
163    /// let kv_overrides = params.kv_overrides().into_iter().collect::<Vec<_>>();
164    /// assert_eq!(kv_overrides.len(), 1);
165    ///
166    /// let (k, v) = &kv_overrides[0];
167    /// assert_eq!(v, &ParamOverrideValue::Int(50));
168    ///
169    /// assert_eq!(k.to_bytes(), b"key", "expected key to be 'key', was {:?}", k);
170    /// ```
171    pub fn append_kv_override(
172        mut self: Pin<&mut Self>,
173        key: &CStr,
174        value: kv_overrides::ParamOverrideValue,
175    ) {
176        let kv_override = self
177            .kv_overrides
178            .get_mut(0)
179            .expect("kv_overrides did not have a next allocated");
180
181        assert_eq!(kv_override.key[0], 0, "last kv_override was not empty");
182
183        // There should be some way to do this without iterating over everything.
184        for (i, &c) in key.to_bytes_with_nul().iter().enumerate() {
185            kv_override.key[i] = c_char::try_from(c).expect("invalid character in key");
186        }
187
188        kv_override.tag = value.tag();
189        kv_override.__bindgen_anon_1 = value.value();
190
191        // set to null pointer for panic safety (as push may move the vector, invalidating the pointer)
192        self.params.kv_overrides = null();
193
194        // push the next one to ensure we maintain the iterator invariant of ending with a 0
195        self.kv_overrides
196            .push(llama_cpp_bindings_sys::llama_model_kv_override {
197                key: [0; 128],
198                tag: 0,
199                __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
200                    val_i64: 0,
201                },
202            });
203
204        // set the pointer to the (potentially) new vector
205        self.params.kv_overrides = self.kv_overrides.as_ptr();
206
207        eprintln!("saved ptr: {:?}", self.params.kv_overrides);
208    }
209}
210
211impl LlamaModelParams {
212    /// Adds buffer type overides to move all mixture-of-experts layers to CPU.
213    pub fn add_cpu_moe_override(self: Pin<&mut Self>) {
214        self.add_cpu_buft_override(c"\\.ffn_(up|down|gate)_(ch|)exps");
215    }
216
217    /// Appends a buffer type override to the model parameters, to move layers matching pattern to CPU.
218    /// It must be pinned as this creates a self-referential struct.
219    ///
220    /// # Panics
221    /// Panics if the internal buffer type overrides vector is empty or the last entry is not empty.
222    pub fn add_cpu_buft_override(mut self: Pin<&mut Self>, key: &CStr) {
223        let buft_override = self
224            .buft_overrides
225            .get_mut(0)
226            .expect("buft_overrides did not have a next allocated");
227
228        assert!(
229            buft_override.pattern.is_null(),
230            "last buft_override was not empty"
231        );
232
233        // There should be some way to do this without iterating over everything.
234        for &c in key.to_bytes_with_nul() {
235            c_char::try_from(c).expect("invalid character in key");
236        }
237
238        buft_override.pattern = key.as_ptr();
239        buft_override.buft = unsafe { llama_cpp_bindings_sys::ggml_backend_cpu_buffer_type() };
240
241        // set to null pointer for panic safety (as push may move the vector, invalidating the pointer)
242        self.params.tensor_buft_overrides = null();
243
244        // push the next one to ensure we maintain the iterator invariant of ending with a 0
245        self.buft_overrides
246            .push(llama_cpp_bindings_sys::llama_model_tensor_buft_override {
247                pattern: std::ptr::null(),
248                buft: std::ptr::null_mut(),
249            });
250
251        // set the pointer to the (potentially) new vector
252        self.params.tensor_buft_overrides = self.buft_overrides.as_ptr();
253    }
254}
255
256impl LlamaModelParams {
257    /// Get the number of layers to offload to the GPU.
258    #[must_use]
259    pub fn n_gpu_layers(&self) -> i32 {
260        self.params.n_gpu_layers
261    }
262
263    /// The GPU that is used for scratch and small tensors
264    #[must_use]
265    pub fn main_gpu(&self) -> i32 {
266        self.params.main_gpu
267    }
268
269    /// only load the vocabulary, no weights
270    #[must_use]
271    pub fn vocab_only(&self) -> bool {
272        self.params.vocab_only
273    }
274
275    /// use mmap if possible
276    #[must_use]
277    pub fn use_mmap(&self) -> bool {
278        self.params.use_mmap
279    }
280
281    /// force system to keep model in RAM
282    #[must_use]
283    pub fn use_mlock(&self) -> bool {
284        self.params.use_mlock
285    }
286
287    /// get the split mode
288    ///
289    /// # Errors
290    /// Returns `LlamaSplitModeParseError` if the unknown split mode is encountered.
291    pub fn split_mode(&self) -> Result<LlamaSplitMode, LlamaSplitModeParseError> {
292        LlamaSplitMode::try_from(self.params.split_mode)
293    }
294
295    /// get the devices
296    #[must_use]
297    pub fn devices(&self) -> Vec<usize> {
298        let mut backend_devices = Vec::new();
299        for i in 0..unsafe { llama_cpp_bindings_sys::ggml_backend_dev_count() } {
300            let dev = unsafe { llama_cpp_bindings_sys::ggml_backend_dev_get(i) };
301            backend_devices.push(dev);
302        }
303        let mut devices = Vec::new();
304        for &dev in self.devices.iter() {
305            if dev.is_null() {
306                break;
307            }
308            if let Some((index, _)) = backend_devices
309                .iter()
310                .enumerate()
311                .find(|&(_i, &d)| d == dev)
312            {
313                devices.push(index);
314            }
315        }
316        devices
317    }
318
319    /// sets the number of gpu layers to offload to the GPU.
320    /// ```
321    /// # use llama_cpp_bindings::model::params::LlamaModelParams;
322    /// let params = LlamaModelParams::default();
323    /// let params = params.with_n_gpu_layers(1);
324    /// assert_eq!(params.n_gpu_layers(), 1);
325    /// ```
326    #[must_use]
327    pub fn with_n_gpu_layers(mut self, n_gpu_layers: u32) -> Self {
328        // The only way this conversion can fail is if u32 overflows the i32 - in which case we set
329        // to MAX
330        let n_gpu_layers = i32::try_from(n_gpu_layers).unwrap_or(i32::MAX);
331        self.params.n_gpu_layers = n_gpu_layers;
332        self
333    }
334
335    /// sets the main GPU
336    ///
337    /// To enable this option, you must set `split_mode` to `LlamaSplitMode::None` to enable single GPU mode.
338    #[must_use]
339    pub fn with_main_gpu(mut self, main_gpu: i32) -> Self {
340        self.params.main_gpu = main_gpu;
341        self
342    }
343
344    /// sets `vocab_only`
345    #[must_use]
346    pub fn with_vocab_only(mut self, vocab_only: bool) -> Self {
347        self.params.vocab_only = vocab_only;
348        self
349    }
350
351    /// sets `use_mlock`
352    #[must_use]
353    pub fn with_use_mlock(mut self, use_mlock: bool) -> Self {
354        self.params.use_mlock = use_mlock;
355        self
356    }
357
358    /// sets `split_mode`
359    #[must_use]
360    pub fn with_split_mode(mut self, split_mode: LlamaSplitMode) -> Self {
361        self.params.split_mode = split_mode.into();
362        self
363    }
364
365    /// sets `devices`
366    ///
367    /// The devices are specified as indices that correspond to the ggml backend device indices.
368    ///
369    /// The maximum number of devices is 16.
370    ///
371    /// You don't need to specify CPU or ACCEL devices.
372    ///
373    /// # Errors
374    /// Returns `LlamaCppError::BackendDeviceNotFound` if any device index is invalid.
375    pub fn with_devices(mut self, devices: &[usize]) -> Result<Self, LlamaCppError> {
376        for dev in self.devices.iter_mut() {
377            *dev = std::ptr::null_mut();
378        }
379        // Check device count
380        let max_devices = crate::max_devices().min(LLAMA_CPP_MAX_DEVICES);
381        if devices.len() > max_devices {
382            return Err(LlamaCppError::MaxDevicesExceeded(max_devices));
383        }
384        for (i, &dev) in devices.iter().enumerate() {
385            if dev >= unsafe { llama_cpp_bindings_sys::ggml_backend_dev_count() } {
386                return Err(LlamaCppError::BackendDeviceNotFound(dev));
387            }
388            let backend_dev = unsafe { llama_cpp_bindings_sys::ggml_backend_dev_get(dev) };
389            self.devices[i] = backend_dev;
390        }
391        if self.devices.is_empty() {
392            self.params.devices = std::ptr::null_mut();
393        } else {
394            self.params.devices = self.devices.as_mut_ptr();
395        }
396        Ok(self)
397    }
398}
399
400/// Default parameters for `LlamaModel`. (as defined in llama.cpp by `llama_model_default_params`)
401/// ```
402/// # use llama_cpp_bindings::model::params::LlamaModelParams;
403/// use llama_cpp_bindings::model::params::LlamaSplitMode;
404/// let params = LlamaModelParams::default();
405/// assert_eq!(params.n_gpu_layers(), -1, "n_gpu_layers should be -1");
406/// assert_eq!(params.main_gpu(), 0, "main_gpu should be 0");
407/// assert_eq!(params.vocab_only(), false, "vocab_only should be false");
408/// assert_eq!(params.use_mmap(), true, "use_mmap should be true");
409/// assert_eq!(params.use_mlock(), false, "use_mlock should be false");
410/// assert_eq!(params.split_mode(), Ok(LlamaSplitMode::Layer), "split_mode should be LAYER");
411/// assert_eq!(params.devices().len(), 0, "devices should be empty");
412/// ```
413impl Default for LlamaModelParams {
414    fn default() -> Self {
415        let default_params = unsafe { llama_cpp_bindings_sys::llama_model_default_params() };
416        LlamaModelParams {
417            params: default_params,
418            // push the next one to ensure we maintain the iterator invariant of ending with a 0
419            kv_overrides: vec![llama_cpp_bindings_sys::llama_model_kv_override {
420                key: [0; 128],
421                tag: 0,
422                __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
423                    val_i64: 0,
424                },
425            }],
426            buft_overrides: vec![llama_cpp_bindings_sys::llama_model_tensor_buft_override {
427                pattern: std::ptr::null(),
428                buft: std::ptr::null_mut(),
429            }],
430            devices: Box::pin([std::ptr::null_mut(); 16]),
431        }
432    }
433}