use crate::model::params::kv_overrides::KvOverrides;
use crate::LlamaCppError;
use std::ffi::{c_char, CStr};
use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::ptr::null;
pub mod kv_overrides;
// Split-mode constants re-exported from the sys crate, narrowed to `i8` so
// they can serve as discriminants of the `#[repr(i8)]` `LlamaSplitMode` enum
// below. The clippy lints are allowed on each cast because the upstream
// values are assumed to be small non-negative integers that fit in `i8`
// without wrapping or truncating — TODO confirm against llama.cpp's header.
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
const LLAMA_SPLIT_MODE_NONE: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_NONE as i8;
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
const LLAMA_SPLIT_MODE_LAYER: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_LAYER as i8;
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
const LLAMA_SPLIT_MODE_ROW: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_ROW as i8;
#[allow(clippy::cast_possible_wrap)]
#[allow(clippy::cast_possible_truncation)]
const LLAMA_SPLIT_MODE_TENSOR: i8 = llama_cpp_sys_2::LLAMA_SPLIT_MODE_TENSOR as i8;
/// How model weights are split across devices.
///
/// Mirrors the sys crate's `LLAMA_SPLIT_MODE_*` constants; the `#[repr(i8)]`
/// discriminants are taken directly from them so conversions to/from raw
/// integers are exact.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaSplitMode {
/// No splitting (`LLAMA_SPLIT_MODE_NONE`).
None = LLAMA_SPLIT_MODE_NONE,
/// Split by layer (`LLAMA_SPLIT_MODE_LAYER`) — the default.
Layer = LLAMA_SPLIT_MODE_LAYER,
/// Split by row (`LLAMA_SPLIT_MODE_ROW`).
Row = LLAMA_SPLIT_MODE_ROW,
/// Split by tensor (`LLAMA_SPLIT_MODE_TENSOR`).
Tensor = LLAMA_SPLIT_MODE_TENSOR,
}
/// Error returned when a raw integer does not correspond to any
/// [`LlamaSplitMode`] variant; carries the offending value widened to `i32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LlamaSplitModeParseError(pub i32);
impl TryFrom<i32> for LlamaSplitMode {
type Error = LlamaSplitModeParseError;
fn try_from(value: i32) -> Result<Self, Self::Error> {
let i8_value = value
.try_into()
.map_err(|_| LlamaSplitModeParseError(value))?;
match i8_value {
LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
LLAMA_SPLIT_MODE_TENSOR => Ok(Self::Tensor),
_ => Err(LlamaSplitModeParseError(value)),
}
}
}
impl TryFrom<u32> for LlamaSplitMode {
type Error = LlamaSplitModeParseError;
fn try_from(value: u32) -> Result<Self, Self::Error> {
let i8_value = value
.try_into()
.map_err(|_| LlamaSplitModeParseError(value.try_into().unwrap_or(i32::MAX)))?;
match i8_value {
LLAMA_SPLIT_MODE_NONE => Ok(Self::None),
LLAMA_SPLIT_MODE_LAYER => Ok(Self::Layer),
LLAMA_SPLIT_MODE_ROW => Ok(Self::Row),
LLAMA_SPLIT_MODE_TENSOR => Ok(Self::Tensor),
_ => Err(LlamaSplitModeParseError(
value.try_into().unwrap_or(i32::MAX),
)),
}
}
}
impl From<LlamaSplitMode> for i32 {
fn from(value: LlamaSplitMode) -> Self {
match value {
LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE.into(),
LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER.into(),
LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW.into(),
LlamaSplitMode::Tensor => LLAMA_SPLIT_MODE_TENSOR.into(),
}
}
}
impl From<LlamaSplitMode> for u32 {
fn from(value: LlamaSplitMode) -> Self {
match value {
LlamaSplitMode::None => LLAMA_SPLIT_MODE_NONE as u32,
LlamaSplitMode::Layer => LLAMA_SPLIT_MODE_LAYER as u32,
LlamaSplitMode::Row => LLAMA_SPLIT_MODE_ROW as u32,
LlamaSplitMode::Tensor => LLAMA_SPLIT_MODE_TENSOR as u32,
}
}
}
impl Default for LlamaSplitMode {
fn default() -> Self {
LlamaSplitMode::Layer
}
}
/// Capacity of the pinned device array in [`LlamaModelParams`]; the most
/// devices that can be selected via `with_devices`.
pub const LLAMA_CPP_MAX_DEVICES: usize = 16;
/// A safe wrapper around `llama_cpp_sys_2::llama_model_params`.
///
/// The raw struct's pointer fields are kept pointing into the owned buffers
/// below, which is why mutation happens through `Pin<&mut Self>` methods.
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModelParams {
// Raw FFI parameter struct; its pointer fields reference the sibling fields.
pub(crate) params: llama_cpp_sys_2::llama_model_params,
// Owned storage behind `params.kv_overrides`; always ends with a zeroed
// sentinel entry that terminates the list for the C side.
kv_overrides: Vec<llama_cpp_sys_2::llama_model_kv_override>,
// Owned storage behind `params.tensor_buft_overrides`; null-terminated the
// same way.
buft_overrides: Vec<llama_cpp_sys_2::llama_model_tensor_buft_override>,
// Pinned, null-terminated device list referenced by `params.devices`.
devices: Pin<Box<[llama_cpp_sys_2::ggml_backend_dev_t; LLAMA_CPP_MAX_DEVICES]>>,
}
impl Debug for LlamaModelParams {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Build the debug output field by field; kv_overrides is summarized
        // rather than dumped because the raw entries are opaque FFI unions.
        let mut out = f.debug_struct("LlamaModelParams");
        out.field("n_gpu_layers", &self.params.n_gpu_layers);
        out.field("main_gpu", &self.params.main_gpu);
        out.field("vocab_only", &self.params.vocab_only);
        out.field("use_mmap", &self.params.use_mmap);
        out.field("use_mlock", &self.params.use_mlock);
        out.field("split_mode", &self.split_mode());
        out.field("devices", &self.devices);
        out.field("kv_overrides", &"vec of kv_overrides");
        out.finish()
    }
}
impl LlamaModelParams {
    /// Returns a view over the key-value overrides currently stored.
    #[must_use]
    pub fn kv_overrides(&self) -> KvOverrides<'_> {
        KvOverrides::new(self)
    }

    /// Appends a key-value override passed to the model loader.
    ///
    /// The override vector always ends with one zeroed sentinel entry: it
    /// terminates the list for the C side and is the slot the next append
    /// fills in.
    ///
    /// # Panics
    ///
    /// Panics if `key` (including its NUL terminator) is longer than the
    /// 128-byte key buffer, if a byte of `key` is not representable as
    /// `c_char`, or if internal invariants are violated.
    pub fn append_kv_override(
        mut self: Pin<&mut Self>,
        key: &CStr,
        value: kv_overrides::ParamOverrideValue,
    ) {
        // Fill the trailing sentinel. (Using `get_mut(0)` here was a bug: the
        // first element is already populated after one append, so the assert
        // below would panic on every subsequent call.)
        let kv_override = self
            .kv_overrides
            .last_mut()
            .expect("kv_overrides did not have a next allocated");
        assert_eq!(kv_override.key[0], 0, "last kv_override was not empty");
        for (i, &c) in key.to_bytes_with_nul().iter().enumerate() {
            kv_override.key[i] = c_char::try_from(c).expect("invalid character in key");
        }
        kv_override.tag = value.tag();
        kv_override.__bindgen_anon_1 = value.value();
        // Clear the raw pointer before pushing: a push may reallocate the
        // vector's buffer, which would leave `params.kv_overrides` dangling.
        self.params.kv_overrides = null();
        // Push a fresh zeroed sentinel so the list stays terminated and the
        // next append has an empty slot to fill.
        self.kv_overrides
            .push(llama_cpp_sys_2::llama_model_kv_override {
                key: [0; 128],
                tag: 0,
                __bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
                    val_i64: 0,
                },
            });
        self.params.kv_overrides = self.kv_overrides.as_ptr();
    }
}
impl LlamaModelParams {
    /// Keeps MoE expert FFN tensors (up/down/gate projections) on the CPU
    /// buffer type instead of offloading them.
    pub fn add_cpu_moe_override(self: Pin<&mut Self>) {
        self.add_cpu_buft_override(c"\\.ffn_(up|down|gate)_(ch|)exps");
    }

    /// Pins tensors whose names match the `key` pattern to the CPU buffer
    /// type.
    ///
    /// NOTE(review): the raw `key` pointer is stored in the override table,
    /// so `key` must outlive this params struct — safe for the `c"..."`
    /// literals used above; TODO confirm no shorter-lived `CStr` reaches here.
    ///
    /// # Panics
    ///
    /// Panics if a byte of `key` is not representable as `c_char`, or if
    /// internal invariants are violated.
    pub fn add_cpu_buft_override(mut self: Pin<&mut Self>, key: &CStr) {
        // Fill the trailing zeroed sentinel. (`get_mut(0)` was a bug: after
        // the first call, element 0 has a non-null pattern and the assert
        // below would panic on every subsequent call.)
        let buft_override = self
            .buft_overrides
            .last_mut()
            .expect("buft_overrides did not have a next allocated");
        assert!(
            buft_override.pattern.is_null(),
            "last buft_override was not empty"
        );
        // Validate every byte is representable as c_char before handing the
        // pointer to C; the converted values themselves are not needed.
        for &c in key.to_bytes_with_nul().iter() {
            c_char::try_from(c).expect("invalid character in key");
        }
        buft_override.pattern = key.as_ptr();
        buft_override.buft = unsafe { llama_cpp_sys_2::ggml_backend_cpu_buffer_type() };
        // Clear the raw pointer before pushing: reallocation would leave it
        // dangling.
        self.params.tensor_buft_overrides = null();
        // New zeroed sentinel terminates the list and gives the next call an
        // empty slot.
        self.buft_overrides
            .push(llama_cpp_sys_2::llama_model_tensor_buft_override {
                pattern: std::ptr::null(),
                buft: std::ptr::null_mut(),
            });
        self.params.tensor_buft_overrides = self.buft_overrides.as_ptr();
    }
}
impl LlamaModelParams {
    /// Number of layers to offload to the GPU.
    #[must_use]
    pub fn n_gpu_layers(&self) -> i32 {
        self.params.n_gpu_layers
    }

    /// Index of the main GPU.
    #[must_use]
    pub fn main_gpu(&self) -> i32 {
        self.params.main_gpu
    }

    /// Whether only the vocabulary should be loaded.
    #[must_use]
    pub fn vocab_only(&self) -> bool {
        self.params.vocab_only
    }

    /// Whether the model is memory-mapped.
    #[must_use]
    pub fn use_mmap(&self) -> bool {
        self.params.use_mmap
    }

    /// Whether the model memory is locked (mlock).
    #[must_use]
    pub fn use_mlock(&self) -> bool {
        self.params.use_mlock
    }

    /// The configured split mode.
    ///
    /// # Errors
    ///
    /// Returns [`LlamaSplitModeParseError`] if the raw field holds a value
    /// that does not map to a [`LlamaSplitMode`] variant.
    pub fn split_mode(&self) -> Result<LlamaSplitMode, LlamaSplitModeParseError> {
        LlamaSplitMode::try_from(self.params.split_mode)
    }

    /// Global backend-device indices currently selected via `with_devices`.
    #[must_use]
    pub fn devices(&self) -> Vec<usize> {
        // Snapshot every known backend device so raw handles can be mapped
        // back to their global indices.
        let mut backend_devices = Vec::new();
        for i in 0..unsafe { llama_cpp_sys_2::ggml_backend_dev_count() } {
            let dev = unsafe { llama_cpp_sys_2::ggml_backend_dev_get(i) };
            backend_devices.push(dev);
        }
        let mut devices = Vec::new();
        for &dev in self.devices.iter() {
            // The pinned array is null-terminated; stop at the first null.
            if dev.is_null() {
                break;
            }
            if let Some((index, _)) = backend_devices
                .iter()
                .enumerate()
                .find(|&(_i, &d)| d == dev)
            {
                devices.push(index);
            }
        }
        devices
    }

    /// Sets the number of layers to offload to the GPU.
    #[must_use]
    pub fn with_n_gpu_layers(mut self, n_gpu_layers: u32) -> Self {
        // The raw field is i32; saturate rather than wrap for huge requests.
        let n_gpu_layers = i32::try_from(n_gpu_layers).unwrap_or(i32::MAX);
        self.params.n_gpu_layers = n_gpu_layers;
        self
    }

    /// Sets the main GPU index.
    #[must_use]
    pub fn with_main_gpu(mut self, main_gpu: i32) -> Self {
        self.params.main_gpu = main_gpu;
        self
    }

    /// Sets whether only the vocabulary should be loaded.
    #[must_use]
    pub fn with_vocab_only(mut self, vocab_only: bool) -> Self {
        self.params.vocab_only = vocab_only;
        self
    }

    /// Sets whether to memory-map the model.
    #[must_use]
    pub fn with_use_mmap(mut self, use_mmap: bool) -> Self {
        self.params.use_mmap = use_mmap;
        self
    }

    /// Sets whether to mlock the model memory.
    #[must_use]
    pub fn with_use_mlock(mut self, use_mlock: bool) -> Self {
        self.params.use_mlock = use_mlock;
        self
    }

    /// Sets the split mode.
    #[must_use]
    pub fn with_split_mode(mut self, split_mode: LlamaSplitMode) -> Self {
        self.params.split_mode = split_mode.into();
        self
    }

    /// Selects the backend devices (by global index) to run the model on.
    ///
    /// An empty slice resets `params.devices` to null, i.e. "use defaults".
    ///
    /// # Errors
    ///
    /// Returns [`LlamaCppError::MaxDevicesExceeded`] if more devices are
    /// requested than supported, or [`LlamaCppError::BackendDeviceNotFound`]
    /// if an index is out of range.
    pub fn with_devices(mut self, devices: &[usize]) -> Result<Self, LlamaCppError> {
        // Reset the pinned array; unused tail entries act as the terminator.
        for dev in self.devices.iter_mut() {
            *dev = std::ptr::null_mut();
        }
        let max_devices = crate::max_devices().min(LLAMA_CPP_MAX_DEVICES);
        if devices.len() > max_devices {
            return Err(LlamaCppError::MaxDevicesExceeded(max_devices));
        }
        for (i, &dev) in devices.iter().enumerate() {
            if dev >= unsafe { llama_cpp_sys_2::ggml_backend_dev_count() } {
                return Err(LlamaCppError::BackendDeviceNotFound(dev));
            }
            let backend_dev = unsafe { llama_cpp_sys_2::ggml_backend_dev_get(dev) };
            self.devices[i] = backend_dev;
        }
        // BUGFIX: test the *request*, not the fixed-size array — the array is
        // `[_; LLAMA_CPP_MAX_DEVICES]` and therefore never empty, so the old
        // `self.devices.is_empty()` check could never reset the pointer.
        if devices.is_empty() {
            self.params.devices = std::ptr::null_mut();
        } else {
            self.params.devices = self.devices.as_mut_ptr();
        }
        Ok(self)
    }

    /// Sets no-alloc mode; implies disabling mmap, since no buffers are
    /// allocated for the weights.
    #[must_use]
    pub fn with_no_alloc(mut self, no_alloc: bool) -> Self {
        self.params.no_alloc = no_alloc;
        if no_alloc {
            self = self.with_use_mmap(false);
        }
        self
    }

    /// Whether no-alloc mode is enabled.
    #[must_use]
    pub fn no_alloc(&self) -> bool {
        self.params.no_alloc
    }
}
impl Default for LlamaModelParams {
fn default() -> Self {
let default_params = unsafe { llama_cpp_sys_2::llama_model_default_params() };
LlamaModelParams {
params: default_params,
kv_overrides: vec![llama_cpp_sys_2::llama_model_kv_override {
key: [0; 128],
tag: 0,
__bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
val_i64: 0,
},
}],
buft_overrides: vec![llama_cpp_sys_2::llama_model_tensor_buft_override {
pattern: std::ptr::null(),
buft: std::ptr::null_mut(),
}],
devices: Box::pin([std::ptr::null_mut(); 16]),
}
}
}
#[cfg(test)]
mod tests {
use super::LlamaSplitMode;
// Round-trip check: the sys crate's TENSOR constant parses into the enum
// variant, and converting the variant back via both `u32` and `i32`
// reproduces the same numeric value.
#[test]
fn tensor_split_mode_round_trips() {
assert_eq!(
LlamaSplitMode::try_from(llama_cpp_sys_2::LLAMA_SPLIT_MODE_TENSOR),
Ok(LlamaSplitMode::Tensor)
);
assert_eq!(
u32::from(LlamaSplitMode::Tensor),
llama_cpp_sys_2::LLAMA_SPLIT_MODE_TENSOR as u32
);
assert_eq!(
i32::from(LlamaSplitMode::Tensor),
llama_cpp_sys_2::LLAMA_SPLIT_MODE_TENSOR as i32
);
}
}