Skip to main content

llama_cpp_bindings/model/params/
kv_overrides.rs

1//! Key-value overrides for a model.
2
3use crate::model::params::LlamaModelParams;
4use crate::model::params::param_override_value::ParamOverrideValue;
5use std::ffi::{CStr, CString};
6use std::fmt::Debug;
7
8/// A struct implementing [`IntoIterator`] over the key-value overrides for a model.
9#[derive(Debug)]
10pub struct KvOverrides<'model_params> {
11    model_params: &'model_params LlamaModelParams,
12}
13
14impl KvOverrides<'_> {
15    /// Creates a new `KvOverrides` view over the given model parameters.
16    #[must_use]
17    pub const fn new(model_params: &LlamaModelParams) -> KvOverrides<'_> {
18        KvOverrides { model_params }
19    }
20}
21
22impl<'model_params> IntoIterator for KvOverrides<'model_params> {
23    // I'm fairly certain this could be written returning by reference, but I'm not sure how to do it safely. I do not
24    // expect this to be a performance bottleneck so the copy should be fine. (let me know if it's not fine!)
25    type Item = (CString, ParamOverrideValue);
26    type IntoIter = KvOverrideValueIterator<'model_params>;
27
28    fn into_iter(self) -> Self::IntoIter {
29        KvOverrideValueIterator {
30            model_params: self.model_params,
31            current: 0,
32        }
33    }
34}
35
36/// An iterator over the key-value overrides for a model.
37#[derive(Debug)]
38pub struct KvOverrideValueIterator<'model_params> {
39    model_params: &'model_params LlamaModelParams,
40    current: usize,
41}
42
43impl Iterator for KvOverrideValueIterator<'_> {
44    type Item = (CString, ParamOverrideValue);
45
46    fn next(&mut self) -> Option<Self::Item> {
47        let overrides = self.model_params.params.kv_overrides;
48
49        if overrides.is_null() {
50            return None;
51        }
52
53        loop {
54            // SAFETY: llama.cpp guarantees the last element contains an empty key.
55            // We've checked the previous one in the last iteration, the next one
56            // should be valid or 0 (and thus safe to deref).
57            let current = unsafe { *overrides.add(self.current) };
58
59            if current.key[0] == 0 {
60                return None;
61            }
62
63            self.current += 1;
64
65            if let Ok(value) = ParamOverrideValue::try_from(&current) {
66                let key = unsafe { CStr::from_ptr(current.key.as_ptr()).to_owned() };
67
68                return Some((key, value));
69            }
70        }
71    }
72}
73
#[cfg(test)]
mod tests {
    use std::ffi::CString;
    use std::pin::pin;

    use crate::model::params::LlamaModelParams;
    use crate::model::params::param_override_value::ParamOverrideValue;

    /// Default parameters expose no overrides.
    #[test]
    fn kv_overrides_empty_by_default() {
        let model_params = LlamaModelParams::default();

        assert_eq!(model_params.kv_overrides().into_iter().count(), 0);
    }

    /// A single appended override comes back with its key and value intact.
    #[test]
    fn kv_overrides_iterates_single_entry() {
        let key = CString::new("test_key").unwrap();
        let mut model_params = pin!(LlamaModelParams::default());

        model_params
            .as_mut()
            .append_kv_override(&key, ParamOverrideValue::Int(42))
            .unwrap();

        let collected = model_params
            .kv_overrides()
            .into_iter()
            .collect::<Vec<_>>();

        assert_eq!(collected.len(), 1);
        let (found_key, found_value) = &collected[0];
        assert_eq!(found_key.to_bytes(), b"test_key");
        assert_eq!(*found_value, ParamOverrideValue::Int(42));
    }

    /// A view built directly with `KvOverrides::new` iterates like one obtained
    /// from the parameters.
    #[test]
    fn kv_overrides_new_creates_view() {
        let model_params = LlamaModelParams::default();
        let view = super::KvOverrides::new(&model_params);

        assert_eq!(view.into_iter().count(), 0);
    }

    /// An entry whose tag does not convert to a `ParamOverrideValue` is skipped.
    #[test]
    fn kv_overrides_skips_entry_with_unknown_tag() {
        let key = CString::new("valid_key").unwrap();
        let mut model_params = pin!(LlamaModelParams::default());

        model_params
            .as_mut()
            .append_kv_override(&key, ParamOverrideValue::Int(99))
            .unwrap();

        // Overwrite the first entry's tag with a value no variant maps to.
        model_params.kv_overrides[0].tag = 9999;

        let yielded = model_params.kv_overrides().into_iter().count();
        assert_eq!(yielded, 0);
    }
}
128        // Corrupt the tag of the first entry to an unknown value
129        params.kv_overrides[0].tag = 9999;
130
131        assert_eq!(params.kv_overrides().into_iter().count(), 0);
132    }
133}