Skip to main content

llama_cpp_4/model/
params.rs

1//! A safe wrapper around `llama_model_params`.
2
3use crate::model::params::kv_overrides::KvOverrides;
4use std::ffi::{c_char, CStr};
5use std::fmt::{Debug, Formatter};
6use std::pin::Pin;
7use std::ptr::null;
8
9pub mod kv_overrides;
10
/// A safe wrapper around `llama_model_params`.
///
/// Besides the raw llama.cpp parameter struct, this type owns the storage that the raw
/// struct's pointer fields refer to, so that storage lives exactly as long as the
/// parameters themselves.
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModelParams {
    /// The raw parameter struct handed to llama.cpp.
    pub(crate) params: llama_cpp_sys_4::llama_model_params,
    /// Backing storage for the key-value overrides. It is created with a single zeroed
    /// entry and every append pushes a fresh zeroed entry, so the list always ends with an
    /// entry whose key is empty. `params.kv_overrides` is pointed at this buffer by
    /// `append_kv_override`.
    kv_overrides: Vec<llama_cpp_sys_4::llama_model_kv_override>,
    /// Owned C string whose pointer is stored in `params.override_arch` by
    /// `with_override_arch`; `None` when no architecture override is configured.
    #[cfg(feature = "mtp")]
    override_arch: Option<std::ffi::CString>,
}
19
20impl Debug for LlamaModelParams {
21    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
22        f.debug_struct("LlamaModelParams")
23            .field("n_gpu_layers", &self.params.n_gpu_layers)
24            .field("main_gpu", &self.params.main_gpu)
25            .field("vocab_only", &self.params.vocab_only)
26            .field("use_mmap", &self.params.use_mmap)
27            .field("use_mlock", &self.params.use_mlock)
28            .field("kv_overrides", &"vec of kv_overrides")
29            .finish()
30    }
31}
32
33impl LlamaModelParams {
34    /// See [`KvOverrides`]
35    ///
36    /// # Examples
37    ///
38    /// ```rust
39    /// # use llama_cpp_4::model::params::LlamaModelParams;
40    /// let params = Box::pin(LlamaModelParams::default());
41    /// let kv_overrides = params.kv_overrides();
42    /// let count = kv_overrides.into_iter().count();
43    /// assert_eq!(count, 0);
44    /// ```
45    #[must_use]
46    pub fn kv_overrides(&self) -> KvOverrides<'_> {
47        KvOverrides::new(self)
48    }
49
50    /// Appends a key-value override to the model parameters. It must be pinned as this creates a self-referential struct.
51    ///
52    /// # Examples
53    ///
54    /// ```rust
55    /// # use std::ffi::{CStr, CString};
56    /// use std::pin::pin;
57    /// # use llama_cpp_4::model::params::LlamaModelParams;
58    /// # use llama_cpp_4::model::params::kv_overrides::ParamOverrideValue;
59    /// let mut params = pin!(LlamaModelParams::default());
60    /// let key = CString::new("key").expect("CString::new failed");
61    /// params.as_mut().append_kv_override(&key, ParamOverrideValue::Int(50));
62    ///
63    /// let kv_overrides = params.kv_overrides().into_iter().collect::<Vec<_>>();
64    /// assert_eq!(kv_overrides.len(), 1);
65    ///
66    /// let (k, v) = &kv_overrides[0];
67    /// assert_eq!(v, &ParamOverrideValue::Int(50));
68    ///
69    /// assert_eq!(k.to_bytes(), b"key", "expected key to be 'key', was {:?}", k);
70    /// ```
71    #[allow(clippy::missing_panics_doc)] // panics are just to enforce internal invariants, not user errors
72    pub fn append_kv_override(
73        mut self: Pin<&mut Self>,
74        key: &CStr,
75        value: kv_overrides::ParamOverrideValue,
76    ) {
77        let kv_override = self
78            .kv_overrides
79            .get_mut(0)
80            .expect("kv_overrides did not have a next allocated");
81
82        assert_eq!(kv_override.key[0], 0, "last kv_override was not empty");
83
84        // There should be some way to do this without iterating over everything.
85        for (i, &c) in key.to_bytes_with_nul().iter().enumerate() {
86            kv_override.key[i] = c_char::try_from(c).expect("invalid character in key");
87        }
88
89        kv_override.tag = value.tag();
90        kv_override.__bindgen_anon_1 = value.value();
91
92        // set to null pointer for panic safety (as push may move the vector, invalidating the pointer)
93        self.params.kv_overrides = null();
94
95        // push the next one to ensure we maintain the iterator invariant of ending with a 0
96        self.kv_overrides
97            .push(llama_cpp_sys_4::llama_model_kv_override {
98                key: [0; 128],
99                tag: 0,
100                __bindgen_anon_1: llama_cpp_sys_4::llama_model_kv_override__bindgen_ty_1 {
101                    val_i64: 0,
102                },
103            });
104
105        // set the pointer to the (potentially) new vector
106        self.params.kv_overrides = self.kv_overrides.as_ptr();
107
108        eprintln!("saved ptr: {:?}", self.params.kv_overrides);
109    }
110}
111
112impl LlamaModelParams {
113    /// Get the number of layers to offload to the GPU.
114    #[must_use]
115    pub fn n_gpu_layers(&self) -> i32 {
116        self.params.n_gpu_layers
117    }
118
119    /// The GPU that is used for scratch and small tensors
120    #[must_use]
121    pub fn main_gpu(&self) -> i32 {
122        self.params.main_gpu
123    }
124
125    /// only load the vocabulary, no weights
126    #[must_use]
127    pub fn vocab_only(&self) -> bool {
128        self.params.vocab_only
129    }
130
131    /// use mmap if possible
132    #[must_use]
133    pub fn use_mmap(&self) -> bool {
134        self.params.use_mmap
135    }
136
137    /// force system to keep model in RAM
138    #[must_use]
139    pub fn use_mlock(&self) -> bool {
140        self.params.use_mlock
141    }
142
143    /// sets the number of gpu layers to offload to the GPU.
144    /// ```
145    /// # use llama_cpp_4::model::params::LlamaModelParams;
146    /// let params = LlamaModelParams::default();
147    /// let params = params.with_n_gpu_layers(1);
148    /// assert_eq!(params.n_gpu_layers(), 1);
149    /// ```
150    #[must_use]
151    pub fn with_n_gpu_layers(mut self, n_gpu_layers: u32) -> Self {
152        // The only way this conversion can fail is if u32 overflows the i32 - in which case we set
153        // to MAX
154        let n_gpu_layers = i32::try_from(n_gpu_layers).unwrap_or(i32::MAX);
155        self.params.n_gpu_layers = n_gpu_layers;
156        self
157    }
158
159    /// sets the main GPU
160    #[must_use]
161    pub fn with_main_gpu(mut self, main_gpu: i32) -> Self {
162        self.params.main_gpu = main_gpu;
163        self
164    }
165
166    /// sets `vocab_only`
167    #[must_use]
168    pub fn with_vocab_only(mut self, vocab_only: bool) -> Self {
169        self.params.vocab_only = vocab_only;
170        self
171    }
172
173    /// sets `use_mlock`
174    #[must_use]
175    pub fn with_use_mlock(mut self, use_mlock: bool) -> Self {
176        self.params.use_mlock = use_mlock;
177        self
178    }
179
180    /// Override model architecture string used when loading.
181    ///
182    /// This is primarily used by MTP to load the draft head architecture
183    /// from the same GGUF (for example `qwen35_mtp` / `qwen35moe_mtp`).
184    ///
185    /// This API is only available when built with the `mtp` feature.
186    #[cfg(feature = "mtp")]
187    #[must_use]
188    pub fn with_override_arch(mut self, override_arch: Option<&str>) -> Self {
189        self.override_arch = override_arch
190            .map(|value| std::ffi::CString::new(value).expect("override_arch contains null bytes"));
191        self.params.override_arch = self
192            .override_arch
193            .as_ref()
194            .map_or(std::ptr::null(), |value| value.as_ptr());
195        self
196    }
197
198    /// Get the currently configured model architecture override.
199    ///
200    /// This API is only available when built with the `mtp` feature.
201    #[cfg(feature = "mtp")]
202    #[must_use]
203    pub fn override_arch(&self) -> Option<&str> {
204        self.override_arch
205            .as_ref()
206            .and_then(|value| value.to_str().ok())
207    }
208}
209
210/// Default parameters for `LlamaModel`. (as defined in llama.cpp by `llama_model_default_params`)
211/// ```
212/// # use llama_cpp_4::model::params::LlamaModelParams;
213/// let params = LlamaModelParams::default();
214/// #[cfg(not(target_os = "macos"))]
215/// assert_eq!(params.n_gpu_layers(), 0, "n_gpu_layers should be 0");
216/// #[cfg(target_os = "macos")]
217/// assert_eq!(params.n_gpu_layers(), -1, "n_gpu_layers should be -1 (all layers)");
218/// assert_eq!(params.main_gpu(), 0, "main_gpu should be 0");
219/// assert_eq!(params.vocab_only(), false, "vocab_only should be false");
220/// assert_eq!(params.use_mmap(), true, "use_mmap should be true");
221/// assert_eq!(params.use_mlock(), false, "use_mlock should be false");
222/// ```
223impl Default for LlamaModelParams {
224    fn default() -> Self {
225        let default_params = unsafe { llama_cpp_sys_4::llama_model_default_params() };
226        LlamaModelParams {
227            params: default_params,
228            // push the next one to ensure we maintain the iterator invariant of ending with a 0
229            kv_overrides: vec![llama_cpp_sys_4::llama_model_kv_override {
230                key: [0; 128],
231                tag: 0,
232                __bindgen_anon_1: llama_cpp_sys_4::llama_model_kv_override__bindgen_ty_1 {
233                    val_i64: 0,
234                },
235            }],
236            #[cfg(feature = "mtp")]
237            override_arch: None,
238        }
239    }
240}