llama_cpp_2/model/params/kv_overrides.rs

1//! Key-value overrides for a model.
2
3use crate::model::params::LlamaModelParams;
4use std::ffi::{CStr, CString};
5use std::fmt::Debug;
6
/// An override value for a model parameter.
///
/// Each variant corresponds to one `LLAMA_KV_OVERRIDE_TYPE_*` tag of the
/// underlying `llama_cpp_sys_2` override union.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ParamOverrideValue {
    /// A boolean value
    Bool(bool),
    /// A floating-point value
    Float(f64),
    /// An integer value
    Int(i64),
    /// A string value, stored as a fixed-size buffer of 128 C characters
    /// (the same layout as the `val_str` field of the raw override).
    Str([std::os::raw::c_char; 128]),
}
19
20impl ParamOverrideValue {
21    pub(crate) fn tag(&self) -> llama_cpp_sys_2::llama_model_kv_override_type {
22        match self {
23            ParamOverrideValue::Bool(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_BOOL,
24            ParamOverrideValue::Float(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
25            ParamOverrideValue::Int(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_INT,
26            ParamOverrideValue::Str(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_STR,
27        }
28    }
29
30    pub(crate) fn value(&self) -> llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
31        match self {
32            ParamOverrideValue::Bool(value) => {
33                llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { val_bool: *value }
34            }
35            ParamOverrideValue::Float(value) => {
36                llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { val_f64: *value }
37            }
38            ParamOverrideValue::Int(value) => {
39                llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { val_i64: *value }
40            }
41            ParamOverrideValue::Str(c_string) => {
42                llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { val_str: *c_string }
43            }
44        }
45    }
46}
47
48impl From<&llama_cpp_sys_2::llama_model_kv_override> for ParamOverrideValue {
49    fn from(
50        llama_cpp_sys_2::llama_model_kv_override {
51            key: _,
52            tag,
53            __bindgen_anon_1,
54        }: &llama_cpp_sys_2::llama_model_kv_override,
55    ) -> Self {
56        match *tag {
57            llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_INT => {
58                ParamOverrideValue::Int(unsafe { __bindgen_anon_1.val_i64 })
59            }
60            llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_FLOAT => {
61                ParamOverrideValue::Float(unsafe { __bindgen_anon_1.val_f64 })
62            }
63            llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_BOOL => {
64                ParamOverrideValue::Bool(unsafe { __bindgen_anon_1.val_bool })
65            }
66            llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_STR => {
67                ParamOverrideValue::Str(unsafe { __bindgen_anon_1.val_str })
68            }
69            _ => unreachable!("Unknown tag of {tag}"),
70        }
71    }
72}
73
74/// A struct implementing [`IntoIterator`] over the key-value overrides for a model.
75#[derive(Debug)]
76pub struct KvOverrides<'a> {
77    model_params: &'a LlamaModelParams,
78}
79
80impl KvOverrides<'_> {
81    pub(super) fn new(model_params: &LlamaModelParams) -> KvOverrides {
82        KvOverrides { model_params }
83    }
84}
85
86impl<'a> IntoIterator for KvOverrides<'a> {
87    // I'm fairly certain this could be written returning by reference, but I'm not sure how to do it safely. I do not
88    // expect this to be a performance bottleneck so the copy should be fine. (let me know if it's not fine!)
89    type Item = (CString, ParamOverrideValue);
90    type IntoIter = KvOverrideValueIterator<'a>;
91
92    fn into_iter(self) -> Self::IntoIter {
93        KvOverrideValueIterator {
94            model_params: self.model_params,
95            current: 0,
96        }
97    }
98}
99
100/// An iterator over the key-value overrides for a model.
101#[derive(Debug)]
102pub struct KvOverrideValueIterator<'a> {
103    model_params: &'a LlamaModelParams,
104    current: usize,
105}
106
107impl Iterator for KvOverrideValueIterator<'_> {
108    type Item = (CString, ParamOverrideValue);
109
110    fn next(&mut self) -> Option<Self::Item> {
111        let overrides = self.model_params.params.kv_overrides;
112        if overrides.is_null() {
113            return None;
114        }
115
116        // SAFETY: llama.cpp seems to guarantee that the last element contains an empty key or is valid. We've checked
117        // the prev one in the last iteration, the next one should be valid or 0 (and thus safe to deref)
118        let current = unsafe { *overrides.add(self.current) };
119
120        if current.key[0] == 0 {
121            return None;
122        }
123
124        let value = ParamOverrideValue::from(&current);
125
126        let key = unsafe { CStr::from_ptr(current.key.as_ptr()).to_owned() };
127
128        self.current += 1;
129        Some((key, value))
130    }
131}