llama_cpp_4/fit.rs
1//! Memory estimation and parameter fitting from llama.cpp `common/fit`.
2//!
3//! - [`get_device_memory_data`] — project per-device memory for a parameter set
4//! without keeping a context alive.
5//! - [`fit_params`] — adjust [`LlamaModelParams`] / [`LlamaContextParams`] to fit
6//! available device memory (upstream `common_fit_params`).
7//!
8//! # Example — memory estimate
9//!
10//! ```no_run
11//! use llama_cpp_4::prelude::*;
12//! use std::path::Path;
13//!
14//! fn main() {
15//! let _backend = LlamaBackend::init().unwrap();
16//! let report = get_device_memory_data(
17//! Path::new("model.gguf"),
18//! &LlamaModelParams::default().with_n_gpu_layers(99),
19//! &LlamaContextParams::default(),
20//! llama_cpp_sys_4::GGML_LOG_LEVEL_ERROR,
21//! )
22//! .unwrap();
23//!
24//! println!("training ctx: {}", report.hyperparams.n_ctx_train);
25//! for (i, entry) in report.entries.iter().enumerate() {
26//! println!(
27//! "device {i}: {} bytes free / {} total (projected {})",
28//! entry.free,
29//! entry.total,
30//! entry.used(),
31//! );
32//! }
33//! }
34//! ```
35//!
36//! # Example — auto-fit parameters
37//!
38//! ```no_run
39//! use llama_cpp_4::fit::{fit_params, FitParams};
40//! use llama_cpp_4::prelude::*;
41//! use std::path::Path;
42//!
43//! fn main() {
44//! let backend = LlamaBackend::init().unwrap();
45//! let result = fit_params(
46//! &backend,
47//! Path::new("model.gguf"),
48//! FitParams::default().with_n_ctx_min(512),
49//! )
50//! .unwrap();
51//!
52//! use std::num::NonZeroU32;
53//!
54//! println!("n_ctx: {}", result.context_params.n_ctx().map_or(0, NonZeroU32::get));
55//! println!("n_gpu_layers: {}", result.model_params.n_gpu_layers());
56//! }
57//! ```
58
59use std::ffi::CString;
60use std::path::Path;
61use std::ptr::{null, null_mut};
62
63use thiserror::Error;
64
65use crate::context::params::LlamaContextParams;
66use crate::llama_backend::LlamaBackend;
67use crate::model::params::LlamaModelParams;
68use crate::{max_devices, max_tensor_buft_overrides};
69
/// Memory figures for a single device, as reported by [`get_device_memory_data`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceMemoryEntry {
    /// Total memory of the device, in bytes.
    pub total: i64,
    /// Memory that was free on the device when it was queried, in bytes.
    pub free: i64,
    /// Bytes projected for the model weights placed on this device.
    pub model: usize,
    /// Bytes projected for the KV / recurrent cache.
    pub context: usize,
    /// Bytes projected for temporary compute buffers.
    pub compute: usize,
}

impl DeviceMemoryEntry {
    /// Total projected usage on this device: `model + context + compute`.
    #[must_use]
    pub fn used(&self) -> usize {
        [self.model, self.context, self.compute].into_iter().sum()
    }
}
92
/// Model hyper-parameters reported alongside a device memory estimate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DeviceMemoryHyperParams {
    /// The model's `n_gpu_layers` hyper-parameter that the estimate was based on.
    pub n_gpu_layers: u32,
    /// Context length the model was trained with.
    pub n_ctx_train: u32,
    /// `MoE` expert count; `0` for a dense model.
    pub n_expert: u32,
}
103
104/// Result of [`get_device_memory_data`].
105#[derive(Debug, Clone, PartialEq, Eq)]
106pub struct DeviceMemoryReport {
107 /// Per-device memory breakdown (one entry per backend device).
108 pub entries: Vec<DeviceMemoryEntry>,
109 /// Hyper-parameters read from the checkpoint during estimation.
110 pub hyperparams: DeviceMemoryHyperParams,
111}
112
113/// Errors from [`get_device_memory_data`].
114#[derive(Debug, Error, PartialEq, Eq)]
115pub enum DeviceMemoryError {
116 /// The model path could not be encoded as a C string.
117 #[error("invalid model path")]
118 InvalidPath,
119 /// The underlying C++ helper failed (model missing, incompatible params, …).
120 #[error("device memory query failed")]
121 QueryFailed,
122 /// More devices were reported than the internal buffer allows.
123 #[error("device memory entry buffer overflow")]
124 BufferOverflow,
125}
126
127/// Estimate per-device memory for a model path and parameter set.
128///
129/// This wraps `common_get_device_memory_data` through `ext_shim`. The model is
130/// loaded with `no_alloc` and freed before returning; no context is kept alive.
131///
132/// # Errors
133///
134/// Returns [`DeviceMemoryError`] when the path is invalid or llama.cpp cannot
135/// produce an estimate.
136pub fn get_device_memory_data(
137 path_model: &Path,
138 mparams: &LlamaModelParams,
139 cparams: &LlamaContextParams,
140 log_level: llama_cpp_sys_4::ggml_log_level,
141) -> Result<DeviceMemoryReport, DeviceMemoryError> {
142 let path = CString::new(path_model.to_string_lossy().as_ref())
143 .map_err(|_| DeviceMemoryError::InvalidPath)?;
144
145 let mparams = mparams.params;
146 let cparams = cparams.context_params;
147
148 let mut capacity = 8usize;
149 loop {
150 let mut raw = vec![
151 llama_cpp_sys_4::common_device_memory_flat_entry {
152 total: 0,
153 free: 0,
154 model: 0,
155 context: 0,
156 compute: 0,
157 };
158 capacity
159 ];
160 let mut hp_ngl = 0u32;
161 let mut hp_nct = 0u32;
162 let mut hp_nex = 0u32;
163
164 let n = unsafe {
165 llama_cpp_sys_4::common_device_memory_collect(
166 path.as_ptr(),
167 &raw const mparams,
168 &raw const cparams,
169 log_level,
170 raw.as_mut_ptr(),
171 capacity,
172 &raw mut hp_ngl,
173 &raw mut hp_nct,
174 &raw mut hp_nex,
175 )
176 };
177
178 if n == usize::MAX {
179 return Err(DeviceMemoryError::QueryFailed);
180 }
181
182 if n < capacity {
183 let entries = raw
184 .into_iter()
185 .take(n)
186 .map(|e| DeviceMemoryEntry {
187 total: e.total,
188 free: e.free,
189 model: e.model,
190 context: e.context,
191 compute: e.compute,
192 })
193 .collect();
194 return Ok(DeviceMemoryReport {
195 entries,
196 hyperparams: DeviceMemoryHyperParams {
197 n_gpu_layers: hp_ngl,
198 n_ctx_train: hp_nct,
199 n_expert: hp_nex,
200 },
201 });
202 }
203
204 capacity = capacity.saturating_mul(2);
205 if capacity > 256 {
206 return Err(DeviceMemoryError::BufferOverflow);
207 }
208 }
209}
210
/// Default per-device memory margin: 1 GiB, expressed in bytes.
const DEFAULT_MARGIN_BYTES: usize = 1 << 30;
212
213/// Input to [`fit_params`].
214///
215/// Defaults mirror upstream `common_params`: unset `n_ctx` (`0`) so context size
216/// can be reduced, default model params so `n_gpu_layers` may be adjusted, and
217/// 1 GiB per-device memory margins.
218#[derive(Debug)]
219pub struct FitParams {
220 /// Starting model parameters. Only fields still at their defaults are modified.
221 pub model_params: LlamaModelParams,
222 /// Starting context parameters. Set `n_ctx` to `0` via
223 /// [`LlamaContextParams::with_n_ctx`]`(None)` to let fitting pick a context size.
224 pub context_params: LlamaContextParams,
225 /// Minimum free memory to leave on each device, in bytes (one entry per device).
226 pub margins: Vec<usize>,
227 /// Minimum context size when fitting must reduce `n_ctx`.
228 pub n_ctx_min: u32,
229 /// Minimum log level printed during fitting.
230 pub log_level: llama_cpp_sys_4::ggml_log_level,
231}
232
233impl Default for FitParams {
234 fn default() -> Self {
235 let nd = max_devices();
236 Self {
237 model_params: LlamaModelParams::default(),
238 context_params: LlamaContextParams::default().with_n_ctx(None),
239 margins: vec![DEFAULT_MARGIN_BYTES; nd],
240 n_ctx_min: 4096,
241 log_level: llama_cpp_sys_4::GGML_LOG_LEVEL_ERROR,
242 }
243 }
244}
245
246impl FitParams {
247 /// Override starting model parameters.
248 #[must_use]
249 pub fn with_model_params(mut self, model_params: LlamaModelParams) -> Self {
250 self.model_params = model_params;
251 self
252 }
253
254 /// Override starting context parameters.
255 #[must_use]
256 pub fn with_context_params(mut self, context_params: LlamaContextParams) -> Self {
257 self.context_params = context_params;
258 self
259 }
260
261 /// Per-device memory margins in bytes (length must be at least [`max_devices`]`()`).
262 #[must_use]
263 pub fn with_margins(mut self, margins: Vec<usize>) -> Self {
264 self.margins = margins;
265 self
266 }
267
268 /// Minimum context size when fitting reduces memory by shrinking `n_ctx`.
269 #[must_use]
270 pub fn with_n_ctx_min(mut self, n_ctx_min: u32) -> Self {
271 self.n_ctx_min = n_ctx_min;
272 self
273 }
274
275 /// Minimum log level emitted while fitting.
276 #[must_use]
277 pub fn with_log_level(mut self, log_level: llama_cpp_sys_4::ggml_log_level) -> Self {
278 self.log_level = log_level;
279 self
280 }
281}
282
283/// Fitted model/context parameters plus auxiliary buffers.
284///
285/// [`LlamaModelParams`] inside this struct may point at [`Self::tensor_split`] and
286/// its internal tensor buffer-type override storage; keep the whole `FitParamsResult` alive while
287/// loading a model with these parameters.
288#[derive(Debug)]
289pub struct FitParamsResult {
290 /// Model parameters after fitting (`n_gpu_layers`, tensor split, …).
291 pub model_params: LlamaModelParams,
292 /// Context parameters after fitting (`n_ctx`, …).
293 pub context_params: LlamaContextParams,
294 /// Layer split ratios per device (writable buffer passed to llama.cpp).
295 pub tensor_split: Vec<f32>,
296 /// Tensor buffer-type overrides written by fitting (keeps pointers valid).
297 #[allow(dead_code)]
298 pub(crate) tensor_buft_overrides: Vec<llama_cpp_sys_4::llama_model_tensor_buft_override>,
299 /// Per-device memory margins used during fitting (bytes).
300 pub margins: Vec<usize>,
301}
302
303impl FitParamsResult {
304 /// Tensor split values for active devices, trimming trailing zeros.
305 #[must_use]
306 pub fn active_tensor_split(&self) -> &[f32] {
307 let mut nd = self.tensor_split.len();
308 while nd > 1 && self.tensor_split[nd - 1] == 0.0 {
309 nd -= 1;
310 }
311 &self.tensor_split[..nd]
312 }
313}
314
315/// Errors from [`fit_params`].
316#[derive(Debug, Error, PartialEq, Eq)]
317pub enum FitParamsError {
318 /// The model path could not be encoded as a C string.
319 #[error("invalid model path")]
320 InvalidPath,
321 /// Fitting could not find allocations that fit device memory.
322 #[error("could not fit parameters to available device memory")]
323 CouldNotFit,
324 /// A hard error occurred (e.g. model file missing).
325 #[error("parameter fitting failed")]
326 Failed,
327}
328
329/// Adjust model and context parameters to fit available device memory.
330///
331/// Wraps `common_fit_params`. Requires an initialized [`LlamaBackend`]. The model
332/// is probed with `no_alloc` internally; nothing is kept loaded on return.
333///
334/// Only model fields still equal to [`LlamaModelParams::default`] are modified
335/// (except `n_gpu_layers` on macOS where the default is `-1`). Context `n_ctx`
336/// is adjusted only when it is `0` — use [`LlamaContextParams::with_n_ctx`] with `None`.
337///
338/// # Errors
339///
340/// Returns [`FitParamsError::InvalidPath`] for bad paths,
341/// [`FitParamsError::CouldNotFit`] when no allocation fits, and
342/// [`FitParamsError::Failed`] on hard errors (missing model, incompatible params, …).
343pub fn fit_params(
344 _backend: &LlamaBackend,
345 path_model: &Path,
346 options: FitParams,
347) -> Result<FitParamsResult, FitParamsError> {
348 let path = CString::new(path_model.to_string_lossy().as_ref())
349 .map_err(|_| FitParamsError::InvalidPath)?;
350
351 let nd = max_devices();
352 let mut tensor_split = vec![0.0_f32; nd];
353
354 let ntbo = max_tensor_buft_overrides();
355 let mut tensor_buft_overrides = vec![
356 llama_cpp_sys_4::llama_model_tensor_buft_override {
357 pattern: null(),
358 buft: null_mut(),
359 };
360 ntbo + 1
361 ];
362
363 let mut margins = options.margins;
364 if margins.len() < nd {
365 margins.resize(nd, DEFAULT_MARGIN_BYTES);
366 }
367
368 let mut model_params = options.model_params;
369 model_params.params.tensor_split = tensor_split.as_mut_ptr();
370 model_params.params.tensor_buft_overrides = tensor_buft_overrides.as_mut_ptr();
371
372 let mut context_params = options.context_params;
373
374 let status = unsafe {
375 llama_cpp_sys_4::common_fit_params(
376 path.as_ptr(),
377 &raw mut model_params.params,
378 &raw mut context_params.context_params,
379 tensor_split.as_mut_ptr(),
380 tensor_buft_overrides.as_mut_ptr(),
381 margins.as_mut_ptr(),
382 options.n_ctx_min,
383 options.log_level,
384 )
385 };
386
387 match status {
388 llama_cpp_sys_4::COMMON_PARAMS_FIT_STATUS_SUCCESS => {
389 model_params.params.tensor_split = tensor_split.as_mut_ptr();
390 model_params.params.tensor_buft_overrides = tensor_buft_overrides.as_mut_ptr();
391 Ok(FitParamsResult {
392 model_params,
393 context_params,
394 tensor_split,
395 tensor_buft_overrides,
396 margins,
397 })
398 }
399 llama_cpp_sys_4::COMMON_PARAMS_FIT_STATUS_FAILURE => Err(FitParamsError::CouldNotFit),
400 _ => Err(FitParamsError::Failed),
401 }
402}