Skip to main content

llama_cpp_4/
fit.rs

1//! Memory estimation and parameter fitting from llama.cpp `common/fit`.
2//!
3//! - [`get_device_memory_data`] — project per-device memory for a parameter set
4//!   without keeping a context alive.
5//! - [`fit_params`] — adjust [`LlamaModelParams`] / [`LlamaContextParams`] to fit
6//!   available device memory (upstream `common_fit_params`).
7//!
8//! # Example — memory estimate
9//!
10//! ```no_run
11//! use llama_cpp_4::prelude::*;
12//! use std::path::Path;
13//!
14//! fn main() {
15//!     let _backend = LlamaBackend::init().unwrap();
16//!     let report = get_device_memory_data(
17//!         Path::new("model.gguf"),
18//!         &LlamaModelParams::default().with_n_gpu_layers(99),
19//!         &LlamaContextParams::default(),
20//!         llama_cpp_sys_4::GGML_LOG_LEVEL_ERROR,
21//!     )
22//!     .unwrap();
23//!
24//!     println!("training ctx: {}", report.hyperparams.n_ctx_train);
25//!     for (i, entry) in report.entries.iter().enumerate() {
26//!         println!(
27//!             "device {i}: {} bytes free / {} total (projected {})",
28//!             entry.free,
29//!             entry.total,
30//!             entry.used(),
31//!         );
32//!     }
33//! }
34//! ```
35//!
36//! # Example — auto-fit parameters
37//!
38//! ```no_run
39//! use llama_cpp_4::fit::{fit_params, FitParams};
40//! use llama_cpp_4::prelude::*;
41//! use std::path::Path;
42//!
43//! fn main() {
44//!     let backend = LlamaBackend::init().unwrap();
45//!     let result = fit_params(
46//!         &backend,
47//!         Path::new("model.gguf"),
48//!         FitParams::default().with_n_ctx_min(512),
49//!     )
50//!     .unwrap();
51//!
52//!     use std::num::NonZeroU32;
53//!
54//!     println!("n_ctx: {}", result.context_params.n_ctx().map_or(0, NonZeroU32::get));
55//!     println!("n_gpu_layers: {}", result.model_params.n_gpu_layers());
56//! }
57//! ```
58
59use std::ffi::CString;
60use std::path::Path;
61use std::ptr::{null, null_mut};
62
63use thiserror::Error;
64
65use crate::context::params::LlamaContextParams;
66use crate::llama_backend::LlamaBackend;
67use crate::model::params::LlamaModelParams;
68use crate::{max_devices, max_tensor_buft_overrides};
69
/// Per-device memory projection from [`get_device_memory_data`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceMemoryEntry {
    /// Total device memory in bytes.
    pub total: i64,
    /// Free device memory in bytes at query time.
    pub free: i64,
    /// Projected model weight bytes on this device.
    pub model: usize,
    /// Projected KV / recurrent cache bytes.
    pub context: usize,
    /// Projected temporary compute buffer bytes.
    pub compute: usize,
}

impl DeviceMemoryEntry {
    /// Total projected bytes on this device: `model + context + compute`.
    ///
    /// The sum saturates at [`usize::MAX`] instead of overflowing, so an
    /// implausibly large estimate coming back from llama.cpp cannot panic in
    /// debug builds (or silently wrap around in release builds).
    #[must_use]
    pub fn used(&self) -> usize {
        self.model
            .saturating_add(self.context)
            .saturating_add(self.compute)
    }
}
92
/// Model hyper-parameters observed by [`get_device_memory_data`] while it
/// produced its estimate.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DeviceMemoryHyperParams {
    /// The `n_gpu_layers` value the estimate was computed with.
    pub n_gpu_layers: u32,
    /// Context length the model was trained with.
    pub n_ctx_train: u32,
    /// Number of `MoE` experts; `0` for a dense model.
    pub n_expert: u32,
}
103
104/// Result of [`get_device_memory_data`].
105#[derive(Debug, Clone, PartialEq, Eq)]
106pub struct DeviceMemoryReport {
107    /// Per-device memory breakdown (one entry per backend device).
108    pub entries: Vec<DeviceMemoryEntry>,
109    /// Hyper-parameters read from the checkpoint during estimation.
110    pub hyperparams: DeviceMemoryHyperParams,
111}
112
113/// Errors from [`get_device_memory_data`].
114#[derive(Debug, Error, PartialEq, Eq)]
115pub enum DeviceMemoryError {
116    /// The model path could not be encoded as a C string.
117    #[error("invalid model path")]
118    InvalidPath,
119    /// The underlying C++ helper failed (model missing, incompatible params, …).
120    #[error("device memory query failed")]
121    QueryFailed,
122    /// More devices were reported than the internal buffer allows.
123    #[error("device memory entry buffer overflow")]
124    BufferOverflow,
125}
126
127/// Estimate per-device memory for a model path and parameter set.
128///
129/// This wraps `common_get_device_memory_data` through `ext_shim`. The model is
130/// loaded with `no_alloc` and freed before returning; no context is kept alive.
131///
132/// # Errors
133///
134/// Returns [`DeviceMemoryError`] when the path is invalid or llama.cpp cannot
135/// produce an estimate.
136pub fn get_device_memory_data(
137    path_model: &Path,
138    mparams: &LlamaModelParams,
139    cparams: &LlamaContextParams,
140    log_level: llama_cpp_sys_4::ggml_log_level,
141) -> Result<DeviceMemoryReport, DeviceMemoryError> {
142    let path = CString::new(path_model.to_string_lossy().as_ref())
143        .map_err(|_| DeviceMemoryError::InvalidPath)?;
144
145    let mparams = mparams.params;
146    let cparams = cparams.context_params;
147
148    let mut capacity = 8usize;
149    loop {
150        let mut raw = vec![
151            llama_cpp_sys_4::common_device_memory_flat_entry {
152                total: 0,
153                free: 0,
154                model: 0,
155                context: 0,
156                compute: 0,
157            };
158            capacity
159        ];
160        let mut hp_ngl = 0u32;
161        let mut hp_nct = 0u32;
162        let mut hp_nex = 0u32;
163
164        let n = unsafe {
165            llama_cpp_sys_4::common_device_memory_collect(
166                path.as_ptr(),
167                &raw const mparams,
168                &raw const cparams,
169                log_level,
170                raw.as_mut_ptr(),
171                capacity,
172                &raw mut hp_ngl,
173                &raw mut hp_nct,
174                &raw mut hp_nex,
175            )
176        };
177
178        if n == usize::MAX {
179            return Err(DeviceMemoryError::QueryFailed);
180        }
181
182        if n < capacity {
183            let entries = raw
184                .into_iter()
185                .take(n)
186                .map(|e| DeviceMemoryEntry {
187                    total: e.total,
188                    free: e.free,
189                    model: e.model,
190                    context: e.context,
191                    compute: e.compute,
192                })
193                .collect();
194            return Ok(DeviceMemoryReport {
195                entries,
196                hyperparams: DeviceMemoryHyperParams {
197                    n_gpu_layers: hp_ngl,
198                    n_ctx_train: hp_nct,
199                    n_expert: hp_nex,
200                },
201            });
202        }
203
204        capacity = capacity.saturating_mul(2);
205        if capacity > 256 {
206            return Err(DeviceMemoryError::BufferOverflow);
207        }
208    }
209}
210
/// Default per-device headroom (1 GiB) that [`fit_params`] leaves free.
const DEFAULT_MARGIN_BYTES: usize = 1 << 30;
212
213/// Input to [`fit_params`].
214///
215/// Defaults mirror upstream `common_params`: unset `n_ctx` (`0`) so context size
216/// can be reduced, default model params so `n_gpu_layers` may be adjusted, and
217/// 1 GiB per-device memory margins.
218#[derive(Debug)]
219pub struct FitParams {
220    /// Starting model parameters. Only fields still at their defaults are modified.
221    pub model_params: LlamaModelParams,
222    /// Starting context parameters. Set `n_ctx` to `0` via
223    /// [`LlamaContextParams::with_n_ctx`]`(None)` to let fitting pick a context size.
224    pub context_params: LlamaContextParams,
225    /// Minimum free memory to leave on each device, in bytes (one entry per device).
226    pub margins: Vec<usize>,
227    /// Minimum context size when fitting must reduce `n_ctx`.
228    pub n_ctx_min: u32,
229    /// Minimum log level printed during fitting.
230    pub log_level: llama_cpp_sys_4::ggml_log_level,
231}
232
233impl Default for FitParams {
234    fn default() -> Self {
235        let nd = max_devices();
236        Self {
237            model_params: LlamaModelParams::default(),
238            context_params: LlamaContextParams::default().with_n_ctx(None),
239            margins: vec![DEFAULT_MARGIN_BYTES; nd],
240            n_ctx_min: 4096,
241            log_level: llama_cpp_sys_4::GGML_LOG_LEVEL_ERROR,
242        }
243    }
244}
245
246impl FitParams {
247    /// Override starting model parameters.
248    #[must_use]
249    pub fn with_model_params(mut self, model_params: LlamaModelParams) -> Self {
250        self.model_params = model_params;
251        self
252    }
253
254    /// Override starting context parameters.
255    #[must_use]
256    pub fn with_context_params(mut self, context_params: LlamaContextParams) -> Self {
257        self.context_params = context_params;
258        self
259    }
260
261    /// Per-device memory margins in bytes (length must be at least [`max_devices`]`()`).
262    #[must_use]
263    pub fn with_margins(mut self, margins: Vec<usize>) -> Self {
264        self.margins = margins;
265        self
266    }
267
268    /// Minimum context size when fitting reduces memory by shrinking `n_ctx`.
269    #[must_use]
270    pub fn with_n_ctx_min(mut self, n_ctx_min: u32) -> Self {
271        self.n_ctx_min = n_ctx_min;
272        self
273    }
274
275    /// Minimum log level emitted while fitting.
276    #[must_use]
277    pub fn with_log_level(mut self, log_level: llama_cpp_sys_4::ggml_log_level) -> Self {
278        self.log_level = log_level;
279        self
280    }
281}
282
283/// Fitted model/context parameters plus auxiliary buffers.
284///
285/// [`LlamaModelParams`] inside this struct may point at [`Self::tensor_split`] and
286/// its internal tensor buffer-type override storage; keep the whole `FitParamsResult` alive while
287/// loading a model with these parameters.
288#[derive(Debug)]
289pub struct FitParamsResult {
290    /// Model parameters after fitting (`n_gpu_layers`, tensor split, …).
291    pub model_params: LlamaModelParams,
292    /// Context parameters after fitting (`n_ctx`, …).
293    pub context_params: LlamaContextParams,
294    /// Layer split ratios per device (writable buffer passed to llama.cpp).
295    pub tensor_split: Vec<f32>,
296    /// Tensor buffer-type overrides written by fitting (keeps pointers valid).
297    #[allow(dead_code)]
298    pub(crate) tensor_buft_overrides: Vec<llama_cpp_sys_4::llama_model_tensor_buft_override>,
299    /// Per-device memory margins used during fitting (bytes).
300    pub margins: Vec<usize>,
301}
302
303impl FitParamsResult {
304    /// Tensor split values for active devices, trimming trailing zeros.
305    #[must_use]
306    pub fn active_tensor_split(&self) -> &[f32] {
307        let mut nd = self.tensor_split.len();
308        while nd > 1 && self.tensor_split[nd - 1] == 0.0 {
309            nd -= 1;
310        }
311        &self.tensor_split[..nd]
312    }
313}
314
315/// Errors from [`fit_params`].
316#[derive(Debug, Error, PartialEq, Eq)]
317pub enum FitParamsError {
318    /// The model path could not be encoded as a C string.
319    #[error("invalid model path")]
320    InvalidPath,
321    /// Fitting could not find allocations that fit device memory.
322    #[error("could not fit parameters to available device memory")]
323    CouldNotFit,
324    /// A hard error occurred (e.g. model file missing).
325    #[error("parameter fitting failed")]
326    Failed,
327}
328
329/// Adjust model and context parameters to fit available device memory.
330///
331/// Wraps `common_fit_params`. Requires an initialized [`LlamaBackend`]. The model
332/// is probed with `no_alloc` internally; nothing is kept loaded on return.
333///
334/// Only model fields still equal to [`LlamaModelParams::default`] are modified
335/// (except `n_gpu_layers` on macOS where the default is `-1`). Context `n_ctx`
336/// is adjusted only when it is `0` — use [`LlamaContextParams::with_n_ctx`] with `None`.
337///
338/// # Errors
339///
340/// Returns [`FitParamsError::InvalidPath`] for bad paths,
341/// [`FitParamsError::CouldNotFit`] when no allocation fits, and
342/// [`FitParamsError::Failed`] on hard errors (missing model, incompatible params, …).
343pub fn fit_params(
344    _backend: &LlamaBackend,
345    path_model: &Path,
346    options: FitParams,
347) -> Result<FitParamsResult, FitParamsError> {
348    let path = CString::new(path_model.to_string_lossy().as_ref())
349        .map_err(|_| FitParamsError::InvalidPath)?;
350
351    let nd = max_devices();
352    let mut tensor_split = vec![0.0_f32; nd];
353
354    let ntbo = max_tensor_buft_overrides();
355    let mut tensor_buft_overrides = vec![
356        llama_cpp_sys_4::llama_model_tensor_buft_override {
357            pattern: null(),
358            buft: null_mut(),
359        };
360        ntbo + 1
361    ];
362
363    let mut margins = options.margins;
364    if margins.len() < nd {
365        margins.resize(nd, DEFAULT_MARGIN_BYTES);
366    }
367
368    let mut model_params = options.model_params;
369    model_params.params.tensor_split = tensor_split.as_mut_ptr();
370    model_params.params.tensor_buft_overrides = tensor_buft_overrides.as_mut_ptr();
371
372    let mut context_params = options.context_params;
373
374    let status = unsafe {
375        llama_cpp_sys_4::common_fit_params(
376            path.as_ptr(),
377            &raw mut model_params.params,
378            &raw mut context_params.context_params,
379            tensor_split.as_mut_ptr(),
380            tensor_buft_overrides.as_mut_ptr(),
381            margins.as_mut_ptr(),
382            options.n_ctx_min,
383            options.log_level,
384        )
385    };
386
387    match status {
388        llama_cpp_sys_4::COMMON_PARAMS_FIT_STATUS_SUCCESS => {
389            model_params.params.tensor_split = tensor_split.as_mut_ptr();
390            model_params.params.tensor_buft_overrides = tensor_buft_overrides.as_mut_ptr();
391            Ok(FitParamsResult {
392                model_params,
393                context_params,
394                tensor_split,
395                tensor_buft_overrides,
396                margins,
397            })
398        }
399        llama_cpp_sys_4::COMMON_PARAMS_FIT_STATUS_FAILURE => Err(FitParamsError::CouldNotFit),
400        _ => Err(FitParamsError::Failed),
401    }
402}