use std::ffi::CString;
use std::path::Path;
use std::ptr::{null, null_mut};
use thiserror::Error;
use crate::context::params::LlamaContextParams;
use crate::llama_backend::LlamaBackend;
use crate::model::params::LlamaModelParams;
use crate::{max_devices, max_tensor_buft_overrides};
/// Per-entry memory figures returned by the native device-memory query.
///
/// Each field is copied verbatim from the corresponding field of the native
/// `common_device_memory_flat_entry`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceMemoryEntry {
    /// Native `total` value (presumably the device's total memory in bytes — confirm).
    pub total: i64,
    /// Native `free` value (presumably the device's free memory in bytes — confirm).
    pub free: i64,
    /// Memory attributed to the model on this entry.
    pub model: usize,
    /// Memory attributed to the context on this entry.
    pub context: usize,
    /// Memory attributed to compute buffers on this entry.
    pub compute: usize,
}

impl DeviceMemoryEntry {
    /// Returns `model + context + compute`.
    ///
    /// The sum saturates at `usize::MAX` rather than overflowing. Plain `+` would
    /// panic in debug builds and silently wrap in release builds, for example on
    /// 32-bit targets with multi-GiB figures.
    #[must_use]
    pub fn used(&self) -> usize {
        self.model
            .saturating_add(self.context)
            .saturating_add(self.compute)
    }
}
/// Scalar values reported alongside the per-entry figures by the native
/// `common_device_memory_collect` call made in `get_device_memory_data`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DeviceMemoryHyperParams {
/// Value written to the native `hp_ngl` out-parameter (presumably the number of GPU layers — confirm).
pub n_gpu_layers: u32,
/// Value written to the native `hp_nct` out-parameter (presumably the model's training context length — confirm).
pub n_ctx_train: u32,
/// Value written to the native `hp_nex` out-parameter (presumably the model's expert count — confirm).
pub n_expert: u32,
}
/// Result of `get_device_memory_data`: the memory entries plus the scalar values
/// reported by the same native query.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceMemoryReport {
/// Entries in the order the native query wrote them (presumably one per device — confirm).
pub entries: Vec<DeviceMemoryEntry>,
/// Scalar values written by the native query alongside `entries`.
pub hyperparams: DeviceMemoryHyperParams,
}
/// Errors returned by `get_device_memory_data`.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum DeviceMemoryError {
/// The model path could not be turned into a C string for the native call.
#[error("invalid model path")]
InvalidPath,
/// The native query returned its failure sentinel (`usize::MAX`).
#[error("device memory query failed")]
QueryFailed,
/// The native query still filled the largest entry buffer that was tried.
#[error("device memory entry buffer overflow")]
BufferOverflow,
}
/// Queries per-device memory figures for loading the model at `path_model` with the
/// given model and context parameters.
///
/// The native `common_device_memory_collect` call fills a caller-provided entry buffer.
/// The buffer starts with 8 slots. While the call reports at least as many entries as
/// there are slots, the buffer is doubled and the query is repeated, up to 256 slots.
///
/// # Errors
///
/// - [`DeviceMemoryError::InvalidPath`] if `path_model` contains an interior NUL byte
///   or, on non-Unix platforms, is not valid UTF-8.
/// - [`DeviceMemoryError::QueryFailed`] if the native call returns `usize::MAX`.
/// - [`DeviceMemoryError::BufferOverflow`] if 256 slots are still not enough.
pub fn get_device_memory_data(
    path_model: &Path,
    mparams: &LlamaModelParams,
    cparams: &LlamaContextParams,
    log_level: llama_cpp_sys_4::ggml_log_level,
) -> Result<DeviceMemoryReport, DeviceMemoryError> {
    // Number of entry slots offered on the first attempt.
    const INITIAL_CAPACITY: usize = 8;
    // Largest number of entry slots tried before giving up.
    const MAX_CAPACITY: usize = 256;

    // On Unix, pass the path bytes through unchanged. A lossy UTF-8 conversion would
    // replace invalid sequences with U+FFFD and silently name a different file.
    #[cfg(unix)]
    let path_bytes = {
        use std::os::unix::ffi::OsStrExt;
        Some(path_model.as_os_str().as_bytes())
    };
    #[cfg(not(unix))]
    let path_bytes = path_model.to_str().map(str::as_bytes);
    let path = path_bytes
        .and_then(|bytes| CString::new(bytes).ok())
        .ok_or(DeviceMemoryError::InvalidPath)?;

    // Local copies of the raw C structs so their addresses can be handed to the callee.
    let mparams = mparams.params;
    let cparams = cparams.context_params;

    let mut capacity = INITIAL_CAPACITY;
    loop {
        let mut raw = vec![
            llama_cpp_sys_4::common_device_memory_flat_entry {
                total: 0,
                free: 0,
                model: 0,
                context: 0,
                compute: 0,
            };
            capacity
        ];
        let mut hp_ngl = 0u32;
        let mut hp_nct = 0u32;
        let mut hp_nex = 0u32;
        // SAFETY: `path` is a NUL-terminated string that outlives the call, and `raw`
        // holds `capacity` initialised entries and is passed with that exact length.
        // Every other pointer refers to a live local. Assumes the callee writes at most
        // `capacity` entries — TODO confirm against the C declaration.
        let n = unsafe {
            llama_cpp_sys_4::common_device_memory_collect(
                path.as_ptr(),
                &raw const mparams,
                &raw const cparams,
                log_level,
                raw.as_mut_ptr(),
                capacity,
                &raw mut hp_ngl,
                &raw mut hp_nct,
                &raw mut hp_nex,
            )
        };
        // `usize::MAX` is the native failure sentinel.
        if n == usize::MAX {
            return Err(DeviceMemoryError::QueryFailed);
        }
        // Fewer entries than slots means the result fit. A full buffer may have been
        // truncated, so retry with a larger one.
        if n < capacity {
            let entries = raw
                .into_iter()
                .take(n)
                .map(|e| DeviceMemoryEntry {
                    total: e.total,
                    free: e.free,
                    model: e.model,
                    context: e.context,
                    compute: e.compute,
                })
                .collect();
            return Ok(DeviceMemoryReport {
                entries,
                hyperparams: DeviceMemoryHyperParams {
                    n_gpu_layers: hp_ngl,
                    n_ctx_train: hp_nct,
                    n_expert: hp_nex,
                },
            });
        }
        capacity = capacity.saturating_mul(2);
        if capacity > MAX_CAPACITY {
            return Err(DeviceMemoryError::BufferOverflow);
        }
    }
}
/// Default per-device memory margin in bytes (1 GiB).
///
/// Used by `FitParams::default`, and by `fit_params` to pad a margin list that is
/// shorter than the device count.
const DEFAULT_MARGIN_BYTES: usize = 1 << 30;
/// Inputs to `fit_params`.
///
/// Build one with `FitParams::default()` and the `with_*` methods. The defaults are:
/// a margin of `DEFAULT_MARGIN_BYTES` per device, `n_ctx_min = 4096`, and
/// `GGML_LOG_LEVEL_ERROR` logging.
#[derive(Debug)]
pub struct FitParams {
/// Model parameters the fitter starts from. `fit_params` overwrites the
/// `tensor_split` and `tensor_buft_overrides` pointers with its own buffers.
pub model_params: LlamaModelParams,
/// Context parameters the fitter starts from.
pub context_params: LlamaContextParams,
/// Per-device margins in bytes. If there are fewer entries than devices,
/// `fit_params` pads the list with `DEFAULT_MARGIN_BYTES`.
pub margins: Vec<usize>,
/// Passed through to the native fitter (presumably the smallest context size it may choose — confirm).
pub n_ctx_min: u32,
/// Log level passed through to the native fitter.
pub log_level: llama_cpp_sys_4::ggml_log_level,
}
impl Default for FitParams {
    /// Default fitting options:
    /// - default model params;
    /// - default context params with `n_ctx` unset;
    /// - one `DEFAULT_MARGIN_BYTES` margin per device reported by `max_devices()`;
    /// - a minimum context of 4096;
    /// - error-level logging.
    fn default() -> Self {
        let device_count = max_devices();
        let model_params = LlamaModelParams::default();
        // `n_ctx` left as `None` (presumably so the fitter picks it — confirm).
        let context_params = LlamaContextParams::default().with_n_ctx(None);
        Self {
            model_params,
            context_params,
            margins: vec![DEFAULT_MARGIN_BYTES; device_count],
            n_ctx_min: 4096,
            log_level: llama_cpp_sys_4::GGML_LOG_LEVEL_ERROR,
        }
    }
}
impl FitParams {
    /// Returns these options with the starting model parameters replaced.
    #[must_use]
    pub fn with_model_params(self, model_params: LlamaModelParams) -> Self {
        Self {
            model_params,
            ..self
        }
    }

    /// Returns these options with the starting context parameters replaced.
    #[must_use]
    pub fn with_context_params(self, context_params: LlamaContextParams) -> Self {
        Self {
            context_params,
            ..self
        }
    }

    /// Returns these options with the per-device margin list (in bytes) replaced.
    #[must_use]
    pub fn with_margins(self, margins: Vec<usize>) -> Self {
        Self { margins, ..self }
    }

    /// Returns these options with `n_ctx_min` replaced.
    #[must_use]
    pub fn with_n_ctx_min(self, n_ctx_min: u32) -> Self {
        Self { n_ctx_min, ..self }
    }

    /// Returns these options with the native log level replaced.
    #[must_use]
    pub fn with_log_level(self, log_level: llama_cpp_sys_4::ggml_log_level) -> Self {
        Self { log_level, ..self }
    }
}
/// Output of a successful `fit_params` call.
///
/// When produced by `fit_params`, `model_params.params.tensor_split` points into
/// `tensor_split`'s heap buffer, and `model_params.params.tensor_buft_overrides`
/// points into `tensor_buft_overrides`'s heap buffer. Consequences:
/// - Keep this value alive for as long as `model_params` is used.
/// - Do not grow or shrink `tensor_split`, because a reallocation would leave
///   `model_params` pointing at freed memory.
/// - Moving `model_params` out and dropping the rest has the same effect.
#[derive(Debug)]
pub struct FitParamsResult {
/// Fitted model parameters (see the pointer caveat above).
pub model_params: LlamaModelParams,
/// Fitted context parameters.
pub context_params: LlamaContextParams,
/// Per-device split written by the native fitter; one slot per `max_devices()`.
pub tensor_split: Vec<f32>,
// Held to own the buffer `model_params` points at rather than to be read,
// hence the `dead_code` allowance.
#[allow(dead_code)]
pub(crate) tensor_buft_overrides: Vec<llama_cpp_sys_4::llama_model_tensor_buft_override>,
/// Margins handed to the native fitter, after padding to the device count.
/// The buffer was passed as a mutable pointer, so the callee may have updated it — confirm.
pub margins: Vec<usize>,
}
impl FitParamsResult {
    /// Returns `tensor_split` with trailing zero entries trimmed.
    ///
    /// At least one entry is always kept, unless `tensor_split` itself is empty.
    #[must_use]
    pub fn active_tensor_split(&self) -> &[f32] {
        let split = self.tensor_split.as_slice();
        // Length up to and including the last non-zero share; 1 if every share is zero.
        let keep = split
            .iter()
            .rposition(|&share| share != 0.0)
            .map_or(1, |last| last + 1)
            .min(split.len());
        &split[..keep]
    }
}
/// Errors returned by `fit_params`.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum FitParamsError {
/// The model path could not be turned into a C string for the native call.
#[error("invalid model path")]
InvalidPath,
/// The native fitter returned `COMMON_PARAMS_FIT_STATUS_FAILURE`.
#[error("could not fit parameters to available device memory")]
CouldNotFit,
/// The native fitter returned a status other than success or `FAILURE`.
#[error("parameter fitting failed")]
Failed,
}
/// Fits model and context parameters to the memory available on the devices.
///
/// Runs the native `common_fit_params` routine for the model at `path_model`,
/// starting from `options.model_params` and `options.context_params`.
/// `options.margins` is padded with `DEFAULT_MARGIN_BYTES` to `max_devices()` entries
/// before the call.
///
/// On success, the returned [`FitParamsResult`] owns the `tensor_split` and
/// tensor-buffer-override buffers that its `model_params` point into. Keep the result
/// alive, and do not resize `tensor_split`, for as long as those model params are used.
///
/// `_backend` is not used here. Presumably it is required as evidence that the llama
/// backend has been initialised — confirm.
///
/// # Errors
///
/// - [`FitParamsError::InvalidPath`] if `path_model` contains an interior NUL byte or,
///   on non-Unix platforms, is not valid UTF-8.
/// - [`FitParamsError::CouldNotFit`] if the native fitter returns
///   `COMMON_PARAMS_FIT_STATUS_FAILURE`.
/// - [`FitParamsError::Failed`] for any other non-success status.
pub fn fit_params(
    _backend: &LlamaBackend,
    path_model: &Path,
    options: FitParams,
) -> Result<FitParamsResult, FitParamsError> {
    // On Unix, pass the path bytes through unchanged. A lossy UTF-8 conversion would
    // replace invalid sequences with U+FFFD and silently name a different file.
    #[cfg(unix)]
    let path_bytes = {
        use std::os::unix::ffi::OsStrExt;
        Some(path_model.as_os_str().as_bytes())
    };
    #[cfg(not(unix))]
    let path_bytes = path_model.to_str().map(str::as_bytes);
    let path = path_bytes
        .and_then(|bytes| CString::new(bytes).ok())
        .ok_or(FitParamsError::InvalidPath)?;

    let nd = max_devices();
    // Output buffer for the per-device split chosen by the fitter.
    let mut tensor_split = vec![0.0_f32; nd];
    // One slot beyond `ntbo`, so there is room for a trailing null-pattern entry even
    // when every override slot is used (presumably the native terminator — confirm).
    let ntbo = max_tensor_buft_overrides();
    let mut tensor_buft_overrides = vec![
        llama_cpp_sys_4::llama_model_tensor_buft_override {
            pattern: null(),
            buft: null_mut(),
        };
        ntbo + 1
    ];
    // The native side receives no length for `margins`. Ensure it holds at least one
    // entry per device (presumably what the callee reads — confirm).
    let mut margins = options.margins;
    if margins.len() < nd {
        margins.resize(nd, DEFAULT_MARGIN_BYTES);
    }

    let mut model_params = options.model_params;
    model_params.params.tensor_split = tensor_split.as_mut_ptr();
    model_params.params.tensor_buft_overrides = tensor_buft_overrides.as_mut_ptr();
    let mut context_params = options.context_params;

    // SAFETY: `path` is a NUL-terminated string that outlives the call. `tensor_split`
    // holds `max_devices()` entries, `tensor_buft_overrides` holds
    // `max_tensor_buft_overrides() + 1` entries, and `margins` holds at least
    // `max_devices()` entries. The params are exclusively borrowed locals. Assumes the
    // callee never indexes these buffers past those sizes — TODO confirm against the C
    // declaration.
    let status = unsafe {
        llama_cpp_sys_4::common_fit_params(
            path.as_ptr(),
            &raw mut model_params.params,
            &raw mut context_params.context_params,
            tensor_split.as_mut_ptr(),
            tensor_buft_overrides.as_mut_ptr(),
            margins.as_mut_ptr(),
            options.n_ctx_min,
            options.log_level,
        )
    };

    match status {
        llama_cpp_sys_4::COMMON_PARAMS_FIT_STATUS_SUCCESS => {
            // Re-point at our buffers in case the fitter replaced the pointers. Moving
            // the Vecs into the result does not move their heap storage, so these
            // pointers stay valid while the result owns the Vecs.
            model_params.params.tensor_split = tensor_split.as_mut_ptr();
            model_params.params.tensor_buft_overrides = tensor_buft_overrides.as_mut_ptr();
            Ok(FitParamsResult {
                model_params,
                context_params,
                tensor_split,
                tensor_buft_overrides,
                margins,
            })
        }
        llama_cpp_sys_4::COMMON_PARAMS_FIT_STATUS_FAILURE => Err(FitParamsError::CouldNotFit),
        _ => Err(FitParamsError::Failed),
    }
}