use crate::model::params::LlamaModelParams;
use std::ffi::{CStr, CString};
use std::fmt::Debug;
/// A typed value for a llama.cpp key/value override.
///
/// Each variant corresponds to one of the `LLAMA_KV_OVERRIDE_TYPE_*` tags and carries the
/// payload that is stored in the matching field of the override's value union.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ParamOverrideValue {
    /// A boolean override (`LLAMA_KV_OVERRIDE_TYPE_BOOL`).
    Bool(bool),
    /// A floating-point override (`LLAMA_KV_OVERRIDE_TYPE_FLOAT`).
    Float(f64),
    /// An integer override (`LLAMA_KV_OVERRIDE_TYPE_INT`).
    Int(i64),
}
impl ParamOverrideValue {
pub(crate) fn tag(&self) -> llama_cpp_sys_2::llama_model_kv_override_type {
match self {
ParamOverrideValue::Bool(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_BOOL,
ParamOverrideValue::Float(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
ParamOverrideValue::Int(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_INT,
}
}
pub(crate) fn value(&self) -> llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
match self {
ParamOverrideValue::Bool(value) => {
llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { bool_value: *value }
}
ParamOverrideValue::Float(value) => {
llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
float_value: *value,
}
}
ParamOverrideValue::Int(value) => {
llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { int_value: *value }
}
}
}
}
impl From<&llama_cpp_sys_2::llama_model_kv_override> for ParamOverrideValue {
    /// Reads the typed value out of a raw llama.cpp override entry. The entry's key is ignored.
    ///
    /// # Panics
    ///
    /// Panics if the entry's tag is not one of the INT, FLOAT or BOOL override types.
    fn from(kv_override: &llama_cpp_sys_2::llama_model_kv_override) -> Self {
        let tag = kv_override.tag;
        let payload = &kv_override.__bindgen_anon_1;

        // SAFETY (every arm): each arm reads only the union field that `tag` names. This
        // assumes the entry was written consistently, i.e. that the field named by `tag` is
        // the one that was initialized.
        match tag {
            llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_INT => {
                Self::Int(unsafe { payload.int_value })
            }
            llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_FLOAT => {
                Self::Float(unsafe { payload.float_value })
            }
            llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_BOOL => {
                Self::Bool(unsafe { payload.bool_value })
            }
            _ => unreachable!("Unknown tag of {tag}"),
        }
    }
}
/// A borrowed view of the key/value overrides stored in a [`LlamaModelParams`].
///
/// Convert it with `into_iter` (or use it in a `for` loop) to walk the overrides as
/// `(key, value)` pairs.
#[derive(Debug)]
pub struct KvOverrides<'a> {
    /// The parameters whose `kv_overrides` array will be iterated.
    model_params: &'a LlamaModelParams,
}
impl<'a> KvOverrides<'a> {
    /// Wraps `model_params` so that its key/value overrides can be iterated.
    pub(super) fn new(model_params: &'a LlamaModelParams) -> Self {
        Self { model_params }
    }
}
impl<'a> IntoIterator for KvOverrides<'a> {
type Item = (CString, ParamOverrideValue);
type IntoIter = KvOverrideValueIterator<'a>;
fn into_iter(self) -> Self::IntoIter {
KvOverrideValueIterator {
model_params: self.model_params,
current: 0,
}
}
}
/// Iterator over the `(key, value)` pairs of a [`LlamaModelParams`]'s key/value overrides.
///
/// Created by calling `into_iter` on a [`KvOverrides`].
#[derive(Debug)]
pub struct KvOverrideValueIterator<'a> {
    /// The parameters whose `kv_overrides` array is being walked.
    model_params: &'a LlamaModelParams,
    /// Index of the next entry to read from the overrides array.
    current: usize,
}
impl<'a> Iterator for KvOverrideValueIterator<'a> {
    type Item = (CString, ParamOverrideValue);

    /// Yields the next `(key, value)` override.
    ///
    /// Returns `None` when no overrides array is set, or when the entry whose key is empty
    /// (the terminator) is reached. The index is not advanced past the terminator, so later
    /// calls keep returning `None`.
    ///
    /// # Panics
    ///
    /// Panics if an entry's key has no nul terminator within its fixed-size buffer. It also
    /// panics, via [`ParamOverrideValue::from`], if an entry's tag is not one of the supported
    /// override types.
    fn next(&mut self) -> Option<Self::Item> {
        let overrides = self.model_params.params.kv_overrides;
        if overrides.is_null() {
            return None;
        }

        // SAFETY: `overrides` is non-null. This assumes (as llama.cpp requires) that the array
        // ends with an entry whose key starts with a nul byte. Iteration stops at that entry
        // without advancing, so `self.current` never indexes past it.
        let current = unsafe { &*overrides.add(self.current) };

        // Locate the key's terminator inside its fixed-size buffer before handing the pointer
        // to `CStr::from_ptr`. Without this check, an unterminated key would be read past the
        // end of the array.
        let key_len = current
            .key
            .iter()
            .position(|&c| c == 0)
            .expect("kv override key is not nul-terminated");
        if key_len == 0 {
            // An empty key marks the end of the overrides array.
            return None;
        }

        let value = ParamOverrideValue::from(current);
        // SAFETY: `key` contains a nul terminator at index `key_len`, which lies within the
        // array, so `from_ptr` stays in bounds.
        let key = unsafe { CStr::from_ptr(current.key.as_ptr()) }.to_owned();

        self.current += 1;
        Some((key, value))
    }
}