llama_cpp_2/model/params.rs
1//! A safe wrapper around `llama_model_params`.
2
3use crate::model::params::kv_overrides::KvOverrides;
4use std::ffi::{c_char, CStr};
5use std::fmt::{Debug, Formatter};
6use std::pin::Pin;
7use std::ptr::null;
8
9pub mod kv_overrides;
10
/// A safe wrapper around `llama_model_params`.
///
/// Besides the raw parameters, this struct owns the storage for any key-value overrides added
/// with [`LlamaModelParams::append_kv_override`].
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModelParams {
    /// The raw bindgen parameter struct.
    ///
    /// After any `append_kv_override` call, its `kv_overrides` pointer refers to the buffer of
    /// the `kv_overrides` field below. This is why appending requires the struct to be pinned.
    pub(crate) params: llama_cpp_sys_2::llama_model_params,
    /// Backing storage for the key-value overrides.
    ///
    /// The list always ends with an all-zero entry (zeroed key, tag 0, value 0). Iteration
    /// relies on that entry to know where the list stops. With no overrides, the list holds
    /// only this terminator.
    kv_overrides: Vec<llama_cpp_sys_2::llama_model_kv_override>,
}
17
18impl Debug for LlamaModelParams {
19 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
20 f.debug_struct("LlamaModelParams")
21 .field("n_gpu_layers", &self.params.n_gpu_layers)
22 .field("main_gpu", &self.params.main_gpu)
23 .field("vocab_only", &self.params.vocab_only)
24 .field("use_mmap", &self.params.use_mmap)
25 .field("use_mlock", &self.params.use_mlock)
26 .field("kv_overrides", &"vec of kv_overrides")
27 .finish()
28 }
29}
30
31impl LlamaModelParams {
32 /// See [`KvOverrides`]
33 ///
34 /// # Examples
35 ///
36 /// ```rust
37 /// # use llama_cpp_2::model::params::LlamaModelParams;
38 /// let params = Box::pin(LlamaModelParams::default());
39 /// let kv_overrides = params.kv_overrides();
40 /// let count = kv_overrides.into_iter().count();
41 /// assert_eq!(count, 0);
42 /// ```
43 #[must_use]
44 pub fn kv_overrides(&self) -> KvOverrides {
45 KvOverrides::new(self)
46 }
47
48 /// Appends a key-value override to the model parameters. It must be pinned as this creates a self-referential struct.
49 ///
50 /// # Examples
51 ///
52 /// ```rust
53 /// # use std::ffi::{CStr, CString};
54 /// use std::pin::pin;
55 /// # use llama_cpp_2::model::params::LlamaModelParams;
56 /// # use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
57 /// let mut params = pin!(LlamaModelParams::default());
58 /// let key = CString::new("key").expect("CString::new failed");
59 /// params.as_mut().append_kv_override(&key, ParamOverrideValue::Int(50));
60 ///
61 /// let kv_overrides = params.kv_overrides().into_iter().collect::<Vec<_>>();
62 /// assert_eq!(kv_overrides.len(), 1);
63 ///
64 /// let (k, v) = &kv_overrides[0];
65 /// assert_eq!(v, &ParamOverrideValue::Int(50));
66 ///
67 /// assert_eq!(k.to_bytes(), b"key", "expected key to be 'key', was {:?}", k);
68 /// ```
69 #[allow(clippy::missing_panics_doc)] // panics are just to enforce internal invariants, not user errors
70 pub fn append_kv_override(
71 mut self: Pin<&mut Self>,
72 key: &CStr,
73 value: kv_overrides::ParamOverrideValue,
74 ) {
75 let kv_override = self
76 .kv_overrides
77 .get_mut(0)
78 .expect("kv_overrides did not have a next allocated");
79
80 assert_eq!(kv_override.key[0], 0, "last kv_override was not empty");
81
82 // There should be some way to do this without iterating over everything.
83 for (i, &c) in key.to_bytes_with_nul().iter().enumerate() {
84 kv_override.key[i] = c_char::try_from(c).expect("invalid character in key");
85 }
86
87 kv_override.tag = value.tag();
88 kv_override.__bindgen_anon_1 = value.value();
89
90 // set to null pointer for panic safety (as push may move the vector, invalidating the pointer)
91 self.params.kv_overrides = null();
92
93 // push the next one to ensure we maintain the iterator invariant of ending with a 0
94 self.kv_overrides
95 .push(llama_cpp_sys_2::llama_model_kv_override {
96 key: [0; 128],
97 tag: 0,
98 __bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
99 val_i64: 0,
100 },
101 });
102
103 // set the pointer to the (potentially) new vector
104 self.params.kv_overrides = self.kv_overrides.as_ptr();
105
106 eprintln!("saved ptr: {:?}", self.params.kv_overrides);
107 }
108}
109
110impl LlamaModelParams {
111 /// Get the number of layers to offload to the GPU.
112 #[must_use]
113 pub fn n_gpu_layers(&self) -> i32 {
114 self.params.n_gpu_layers
115 }
116
117 /// The GPU that is used for scratch and small tensors
118 #[must_use]
119 pub fn main_gpu(&self) -> i32 {
120 self.params.main_gpu
121 }
122
123 /// only load the vocabulary, no weights
124 #[must_use]
125 pub fn vocab_only(&self) -> bool {
126 self.params.vocab_only
127 }
128
129 /// use mmap if possible
130 #[must_use]
131 pub fn use_mmap(&self) -> bool {
132 self.params.use_mmap
133 }
134
135 /// force system to keep model in RAM
136 #[must_use]
137 pub fn use_mlock(&self) -> bool {
138 self.params.use_mlock
139 }
140
141 /// sets the number of gpu layers to offload to the GPU.
142 /// ```
143 /// # use llama_cpp_2::model::params::LlamaModelParams;
144 /// let params = LlamaModelParams::default();
145 /// let params = params.with_n_gpu_layers(1);
146 /// assert_eq!(params.n_gpu_layers(), 1);
147 /// ```
148 #[must_use]
149 pub fn with_n_gpu_layers(mut self, n_gpu_layers: u32) -> Self {
150 // The only way this conversion can fail is if u32 overflows the i32 - in which case we set
151 // to MAX
152 let n_gpu_layers = i32::try_from(n_gpu_layers).unwrap_or(i32::MAX);
153 self.params.n_gpu_layers = n_gpu_layers;
154 self
155 }
156
157 /// sets the main GPU
158 #[must_use]
159 pub fn with_main_gpu(mut self, main_gpu: i32) -> Self {
160 self.params.main_gpu = main_gpu;
161 self
162 }
163
164 /// sets `vocab_only`
165 #[must_use]
166 pub fn with_vocab_only(mut self, vocab_only: bool) -> Self {
167 self.params.vocab_only = vocab_only;
168 self
169 }
170
171 /// sets `use_mlock`
172 #[must_use]
173 pub fn with_use_mlock(mut self, use_mlock: bool) -> Self {
174 self.params.use_mlock = use_mlock;
175 self
176 }
177}
178
179/// Default parameters for `LlamaModel`. (as defined in llama.cpp by `llama_model_default_params`)
180/// ```
181/// # use llama_cpp_2::model::params::LlamaModelParams;
182/// let params = LlamaModelParams::default();
183/// #[cfg(not(target_os = "macos"))]
184/// assert_eq!(params.n_gpu_layers(), 0, "n_gpu_layers should be 0");
185/// #[cfg(target_os = "macos")]
186/// assert_eq!(params.n_gpu_layers(), 999, "n_gpu_layers should be 999");
187/// assert_eq!(params.main_gpu(), 0, "main_gpu should be 0");
188/// assert_eq!(params.vocab_only(), false, "vocab_only should be false");
189/// assert_eq!(params.use_mmap(), true, "use_mmap should be true");
190/// assert_eq!(params.use_mlock(), false, "use_mlock should be false");
191/// ```
192impl Default for LlamaModelParams {
193 fn default() -> Self {
194 let default_params = unsafe { llama_cpp_sys_2::llama_model_default_params() };
195 LlamaModelParams {
196 params: default_params,
197 // push the next one to ensure we maintain the iterator invariant of ending with a 0
198 kv_overrides: vec![llama_cpp_sys_2::llama_model_kv_override {
199 key: [0; 128],
200 tag: 0,
201 __bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
202 val_i64: 0,
203 },
204 }],
205 }
206 }
207}