llama_cpp_sys_4/common.rs
//! Manual wrapper for values in llama.cpp/common/common.h

use crate::{
    ggml_numa_strategy, llama_attention_type, llama_pooling_type, llama_rope_scaling_type,
    llama_split_mode, GGML_NUMA_STRATEGY_DISABLED, LLAMA_ATTENTION_TYPE_UNSPECIFIED,
    LLAMA_DEFAULT_SEED, LLAMA_POOLING_TYPE_UNSPECIFIED, LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
    LLAMA_SPLIT_MODE_LAYER,
};

pub const COMMON_SAMPLER_TYPE_NONE: common_sampler_type = 0;
pub const COMMON_SAMPLER_TYPE_DRY: common_sampler_type = 1;
pub const COMMON_SAMPLER_TYPE_TOP_K: common_sampler_type = 2;
pub const COMMON_SAMPLER_TYPE_TOP_P: common_sampler_type = 3;
pub const COMMON_SAMPLER_TYPE_MIN_P: common_sampler_type = 4;
pub const COMMON_SAMPLER_TYPE_TFS_Z: common_sampler_type = 5;
pub const COMMON_SAMPLER_TYPE_TYPICAL_P: common_sampler_type = 6;
pub const COMMON_SAMPLER_TYPE_TEMPERATURE: common_sampler_type = 7;
pub const COMMON_SAMPLER_TYPE_XTC: common_sampler_type = 8;
pub const COMMON_SAMPLER_TYPE_INFILL: common_sampler_type = 9;
pub type common_sampler_type = ::core::ffi::c_uint;
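
// A small convenience helper, not part of the upstream common.h; a sketch
// showing how the constants above map to the names llama.cpp prints for each
// sampler. The strings follow upstream's common_sampler_type_to_str and are
// an assumption worth verifying against the pinned llama.cpp revision.
pub fn common_sampler_type_name(t: common_sampler_type) -> &'static str {
    match t {
        COMMON_SAMPLER_TYPE_DRY => "dry",
        COMMON_SAMPLER_TYPE_TOP_K => "top_k",
        COMMON_SAMPLER_TYPE_TOP_P => "top_p",
        COMMON_SAMPLER_TYPE_MIN_P => "min_p",
        COMMON_SAMPLER_TYPE_TFS_Z => "tfs_z",
        COMMON_SAMPLER_TYPE_TYPICAL_P => "typ_p",
        COMMON_SAMPLER_TYPE_TEMPERATURE => "temperature",
        COMMON_SAMPLER_TYPE_XTC => "xtc",
        COMMON_SAMPLER_TYPE_INFILL => "infill",
        _ => "none", // COMMON_SAMPLER_TYPE_NONE and any unknown value
    }
}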

/// common sampler params
#[repr(C)]
#[derive(Debug, PartialEq)]
pub struct common_sampler_params {
    /// the seed used to initialize `llama_sampler`
    pub seed: u32,
    /// number of previous tokens to remember
    pub n_prev: i32,
    /// if greater than 0, output the probabilities of top `n_probs` tokens.
    pub n_probs: i32,
    /// 0 = disabled, otherwise samplers should return at least `min_keep` tokens
    pub min_keep: i32,
    /// <= 0 to use vocab size
    pub top_k: i32,
    /// 1.0 = disabled
    pub top_p: f32,
    /// 0.0 = disabled
    pub min_p: f32,
    /// 0.0 = disabled
    pub xtc_probability: f32,
    /// > 0.5 disables XTC
    pub xtc_threshold: f32,
    /// 1.0 = disabled
    pub tfs_z: f32,
    /// typical_p, 1.0 = disabled
    pub typ_p: f32,
    /// <= 0.0 to sample greedily, 0.0 to not output probabilities
    pub temp: f32,
    /// 0.0 = disabled
    pub dynatemp_range: f32,
    /// controls how entropy maps to temperature in dynamic temperature sampler
    pub dynatemp_exponent: f32,
    /// last n tokens to penalize (0 = disable penalty, -1 = context size)
    pub penalty_last_n: i32,
    /// 1.0 = disabled
    pub penalty_repeat: f32,
    /// 0.0 = disabled
    pub penalty_freq: f32,
    /// 0.0 = disabled
    pub penalty_present: f32,
    /// 0.0 = disabled; DRY repetition penalty for tokens extending repetition
    pub dry_multiplier: f32,
    /// 0.0 = disabled; multiplier * base ^ (length of repeating sequence before token - allowed length)
    pub dry_base: f32,
    /// tokens extending repetitions beyond this receive penalty
    pub dry_allowed_length: i32,
    /// how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
    pub dry_penalty_last_n: i32,
    /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
    pub mirostat: i32,
    /// target entropy
    pub mirostat_tau: f32,
    /// learning rate
    pub mirostat_eta: f32,
    /// consider newlines as a repeatable token
    pub penalize_nl: bool,
    /// ignore end-of-sequence token and continue generating
    pub ignore_eos: bool,
    /// disable performance metrics
    pub no_perf: bool,
    /// sequence breakers for DRY
    pub dry_sequence_breakers: Vec<String>,
    /// samplers, in the order they are applied
    pub samplers: Vec<common_sampler_type>,
    /// optional BNF-like grammar to constrain sampling
    pub grammar: Vec<String>,
    /// logit biases to apply, as (token, bias) pairs
    pub logit_bias: Vec<(i32, f64)>,
}

impl Default for common_sampler_params {
    fn default() -> Self {
        Self {
            seed: LLAMA_DEFAULT_SEED, // the seed used to initialize llama_sampler
            n_prev: 64,   // number of previous tokens to remember
            n_probs: 0,   // if greater than 0, output the probabilities of top n_probs tokens.
            min_keep: 0,  // 0 = disabled, otherwise samplers should return at least min_keep tokens
            top_k: 40,    // <= 0 to use vocab size
            top_p: 0.95,  // 1.0 = disabled
            min_p: 0.05,  // 0.0 = disabled
            xtc_probability: 0.00, // 0.0 = disabled
            xtc_threshold: 0.10,   // > 0.5 disables XTC
            tfs_z: 1.00,  // 1.0 = disabled
            typ_p: 1.00,  // typical_p, 1.0 = disabled
            temp: 0.80,   // <= 0.0 to sample greedily, 0.0 to not output probabilities
            dynatemp_range: 0.00,    // 0.0 = disabled
            dynatemp_exponent: 1.00, // controls how entropy maps to temperature in dynamic temperature sampler
            penalty_last_n: 64,      // last n tokens to penalize (0 = disable penalty, -1 = context size)
            penalty_repeat: 1.00,    // 1.0 = disabled
            penalty_freq: 0.00,      // 0.0 = disabled
            penalty_present: 0.00,   // 0.0 = disabled
            dry_multiplier: 0.0,     // 0.0 = disabled; DRY repetition penalty for tokens extending repetition
            dry_base: 1.75,          // 0.0 = disabled; multiplier * base ^ (length of repeating sequence before token - allowed length)
            dry_allowed_length: 2,   // tokens extending repetitions beyond this receive penalty
            dry_penalty_last_n: -1,  // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
            mirostat: 0,             // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
            mirostat_tau: 5.00,      // target entropy
            mirostat_eta: 0.10,      // learning rate
            penalize_nl: false,      // consider newlines as a repeatable token
            ignore_eos: false,
            no_perf: false,          // disable performance metrics

            dry_sequence_breakers: vec!["\n".into(), ":".into(), "\"".into(), "*".into()], // default sequence breakers for DRY

            samplers: vec![
                COMMON_SAMPLER_TYPE_DRY,
                COMMON_SAMPLER_TYPE_TOP_K,
                COMMON_SAMPLER_TYPE_TFS_Z,
                COMMON_SAMPLER_TYPE_TYPICAL_P,
                COMMON_SAMPLER_TYPE_TOP_P,
                COMMON_SAMPLER_TYPE_MIN_P,
                COMMON_SAMPLER_TYPE_XTC,
                COMMON_SAMPLER_TYPE_TEMPERATURE,
            ],

            grammar: vec![], // optional BNF-like grammar to constrain sampling

            logit_bias: vec![], // logit biases to apply
        }
    }
}
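
// A hedged convenience sketch, not part of upstream llama.cpp: per the `temp`
// field's documentation above, a temperature of 0.0 or below requests greedy
// sampling, so a "greedy" preset only needs to override that one field.
impl common_sampler_params {
    /// Returns the common.h defaults reconfigured for greedy sampling
    /// (`temp <= 0.0` samples greedily; see the field documentation).
    pub fn greedy() -> Self {
        Self {
            temp: 0.0, // <= 0.0 to sample greedily
            ..Self::default()
        }
    }
}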

#[repr(C)]
#[derive(Debug, PartialEq)]
pub struct common_params {
    /// new tokens to predict
    pub n_predict: i32,
    /// context size
    pub n_ctx: i32,
    /// logical batch size for prompt processing (must be >=32 to use BLAS)
    pub n_batch: i32,
    /// physical batch size for prompt processing (must be >=32 to use BLAS)
    pub n_ubatch: i32,
    /// number of tokens to keep from initial prompt
    pub n_keep: i32,
    /// number of tokens to draft during speculative decoding
    pub n_draft: i32,
    /// max number of chunks to process (-1 = unlimited)
    pub n_chunks: i32,
    /// number of parallel sequences to decode
    pub n_parallel: i32,
    /// number of sequences to decode
    pub n_sequences: i32,
    /// speculative decoding split probability
    pub p_split: f32,
    /// number of layers to store in VRAM (-1 = use default)
    pub n_gpu_layers: i32,
    /// number of layers to store in VRAM for the draft model (-1 = use default)
    pub n_gpu_layers_draft: i32,
    /// the GPU that is used for scratch and small tensors
    pub main_gpu: i32,
    // how split tensors should be distributed across GPUs
    // pub tensor_split: [f32; 128usize],
    /// group-attention factor
    pub grp_attn_n: i32,
    /// group-attention width
    pub grp_attn_w: i32,
    /// print token count every n tokens (-1 = disabled)
    pub n_print: i32,
    /// RoPE base frequency
    pub rope_freq_base: f32,
    /// RoPE frequency scaling factor
    pub rope_freq_scale: f32,
    /// YaRN extrapolation mix factor
    pub yarn_ext_factor: f32,
    /// YaRN magnitude scaling factor
    pub yarn_attn_factor: f32,
    /// YaRN low correction dim
    pub yarn_beta_fast: f32,
    /// YaRN high correction dim
    pub yarn_beta_slow: f32,
    /// YaRN original context length
    pub yarn_orig_ctx: i32,
    /// KV cache defragmentation threshold
    pub defrag_thold: f32,
    // pub cpuparams: cpu_params,
    // pub cpuparams_batch: cpu_params,
    // pub draft_cpuparams: cpu_params,
    // pub draft_cpuparams_batch: cpu_params,
    // pub cb_eval: ggml_backend_sched_eval_callback,
    // pub cb_eval_user_data: *mut ::core::ffi::c_void,
    pub numa: ggml_numa_strategy,
    /// how to split the model across GPUs
    pub split_mode: llama_split_mode,
    pub rope_scaling_type: llama_rope_scaling_type,
    /// pooling type for embeddings
    pub pooling_type: llama_pooling_type,
    /// attention type for embeddings
    pub attention_type: llama_attention_type,
    /// sampling parameters
    pub sparams: common_sampler_params,
    // pub model: std___1_string,
    // pub model_draft: std___1_string,
    // pub model_alias: std___1_string,
    // pub model_url: std___1_string,
    // pub hf_token: std___1_string,
    // pub hf_repo: std___1_string,
    // pub hf_file: std___1_string,
    // pub prompt: std___1_string,
    // pub prompt_file: std___1_string,
    // pub path_prompt_cache: std___1_string,
    // pub input_prefix: std___1_string,
    // pub input_suffix: std___1_string,
    // pub logdir: std___1_string,
    // pub lookup_cache_static: std___1_string,
    // pub lookup_cache_dynamic: std___1_string,
    // pub logits_file: std___1_string,
    // pub rpc_servers: std___1_string,
    // pub in_files: [u64; 3usize],
    // pub antiprompt: [u64; 3usize],
    // pub kv_overrides: [u64; 3usize],
    // pub lora_init_without_apply: bool,
    // pub lora_adapters: [u64; 3usize],
    // pub control_vectors: [u64; 3usize],
    // pub verbosity: i32,
    // pub control_vector_layer_start: i32,
    // pub control_vector_layer_end: i32,
    // pub ppl_stride: i32,
    // pub ppl_output_type: i32,
    // pub hellaswag: bool,
    // pub hellaswag_tasks: usize,
    // pub winogrande: bool,
    // pub winogrande_tasks: usize,
    // pub multiple_choice: bool,
    // pub multiple_choice_tasks: usize,
    // pub kl_divergence: bool,
    // pub usage: bool,
    // pub use_color: bool,
    // pub special: bool,
    // pub interactive: bool,
    // pub interactive_first: bool,
    // pub conversation: bool,
    // pub prompt_cache_all: bool,
    // pub prompt_cache_ro: bool,
    // pub escape: bool,
    // pub multiline_input: bool,
    // pub simple_io: bool,
    // pub cont_batching: bool,
    // pub flash_attn: bool,
    // pub no_perf: bool,
    // pub ctx_shift: bool,
    // pub input_prefix_bos: bool,
    // pub logits_all: bool,
    // pub use_mmap: bool,
    // pub use_mlock: bool,
    // pub verbose_prompt: bool,
    // pub display_prompt: bool,
    // pub dump_kv_cache: bool,
    // pub no_kv_offload: bool,
    // pub warmup: bool,
    // pub check_tensors: bool,
    // pub cache_type_k: std___1_string,
    // pub cache_type_v: std___1_string,
    // pub mmproj: std___1_string,
    // pub image: [u64; 3usize],
    // pub embedding: bool,
    // pub embd_normalize: i32,
    // pub embd_out: std___1_string,
    // pub embd_sep: std___1_string,
    // pub reranking: bool,
    // pub port: i32,
    // pub timeout_read: i32,
    // pub timeout_write: i32,
    // pub n_threads_http: i32,
    // pub n_cache_reuse: i32,
    // pub hostname: std___1_string,
    // pub public_path: std___1_string,
    // pub chat_template: std___1_string,
    // pub enable_chat_template: bool,
    // pub api_keys: [u64; 3usize],
    // pub ssl_file_key: std___1_string,
    // pub ssl_file_cert: std___1_string,
    // pub webui: bool,
    // pub endpoint_slots: bool,
    // pub endpoint_props: bool,
    // pub endpoint_metrics: bool,
    // pub log_json: bool,
    // pub slot_save_path: std___1_string,
    // pub slot_prompt_similarity: f32,
    // pub is_pp_shared: bool,
    // pub n_pp: [u64; 3usize],
    // pub n_tg: [u64; 3usize],
    // pub n_pl: [u64; 3usize],
    // pub context_files: [u64; 3usize],
    // pub chunk_size: i32,
    // pub chunk_separator: std___1_string,
    // pub n_junk: i32,
    // pub i_pos: i32,
    // pub out_file: std___1_string,
    // pub n_out_freq: i32,
    // pub n_save_freq: i32,
    // pub i_chunk: i32,
    // pub process_output: bool,
    // pub compute_ppl: bool,
    // pub n_pca_batch: ::core::ffi::c_int,
    // pub n_pca_iterations: ::core::ffi::c_int,
    // pub cvector_dimre_method: dimre_method,
    // pub cvector_outfile: std___1_string,
    // pub cvector_positive_file: std___1_string,
    // pub cvector_negative_file: std___1_string,
    // pub spm_infill: bool,
    // pub lora_outfile: std___1_string,
    // pub batched_bench_output_jsonl: bool,
}

impl Default for common_params {
    fn default() -> Self {
        Self {
            n_predict: -1,  // new tokens to predict
            n_ctx: 0,       // context size
            n_batch: 2048,  // logical batch size for prompt processing (must be >=32 to use BLAS)
            n_ubatch: 512,  // physical batch size for prompt processing (must be >=32 to use BLAS)
            n_keep: 0,      // number of tokens to keep from initial prompt
            n_draft: 5,     // number of tokens to draft during speculative decoding
            n_chunks: -1,   // max number of chunks to process (-1 = unlimited)
            n_parallel: 1,  // number of parallel sequences to decode
            n_sequences: 1, // number of sequences to decode
            p_split: 0.1,   // speculative decoding split probability
            n_gpu_layers: -1,       // number of layers to store in VRAM (-1 = use default)
            n_gpu_layers_draft: -1, // number of layers to store in VRAM for the draft model (-1 = use default)
            main_gpu: 0,    // the GPU that is used for scratch and small tensors
            // tensor_split[128]: {0}, // how split tensors should be distributed across GPUs
            grp_attn_n: 1,   // group-attention factor
            grp_attn_w: 512, // group-attention width
            n_print: -1,     // print token count every n tokens (-1 = disabled)
            rope_freq_base: 0.0,   // RoPE base frequency
            rope_freq_scale: 0.0,  // RoPE frequency scaling factor
            yarn_ext_factor: -1.0, // YaRN extrapolation mix factor
            yarn_attn_factor: 1.0, // YaRN magnitude scaling factor
            yarn_beta_fast: 32.0,  // YaRN low correction dim
            yarn_beta_slow: 1.0,   // YaRN high correction dim
            yarn_orig_ctx: 0,      // YaRN original context length
            defrag_thold: -1.0,    // KV cache defragmentation threshold

            // struct cpu_params cpuparams;
            // struct cpu_params cpuparams_batch;
            // struct cpu_params draft_cpuparams;
            // struct cpu_params draft_cpuparams_batch;

            // ggml_backend_sched_eval_callback cb_eval = nullptr;
            // void * cb_eval_user_data = nullptr;
            numa: GGML_NUMA_STRATEGY_DISABLED,

            split_mode: LLAMA_SPLIT_MODE_LAYER, // how to split the model across GPUs
            rope_scaling_type: LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
            pooling_type: LLAMA_POOLING_TYPE_UNSPECIFIED, // pooling type for embeddings
            attention_type: LLAMA_ATTENTION_TYPE_UNSPECIFIED, // attention type for embeddings

            sparams: common_sampler_params::default(),
        }
    }
}
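
// A usage sketch: both structs implement Default with the values from
// common.h, so callers can override individual knobs with struct-update
// syntax. The concrete values below (4096 context, seed 42) are illustrative
// assumptions, not recommendations.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn defaults_can_be_overridden() {
        let params = common_params {
            n_ctx: 4096, // request a 4k context instead of 0 (= model default)
            sparams: common_sampler_params {
                seed: 42, // fixed seed for reproducible sampling
                ..common_sampler_params::default()
            },
            ..common_params::default()
        };
        assert_eq!(params.n_batch, 2048); // untouched fields keep the common.h defaults
        assert_eq!(params.sparams.top_k, 40);
    }
}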