Skip to main content

llama_cpp_bindings/model/params/
kv_overrides.rs

1//! Key-value overrides for a model.
2
3use crate::model::params::LlamaModelParams;
4use crate::model::params::param_override_value::ParamOverrideValue;
5use std::ffi::{CStr, CString};
6use std::fmt::Debug;
7
8/// A struct implementing [`IntoIterator`] over the key-value overrides for a model.
9#[derive(Debug)]
10pub struct KvOverrides<'model_params> {
11    model_params: &'model_params LlamaModelParams,
12}
13
14impl KvOverrides<'_> {
15    /// Creates a new `KvOverrides` view over the given model parameters.
16    #[must_use]
17    pub const fn new(model_params: &LlamaModelParams) -> KvOverrides<'_> {
18        KvOverrides { model_params }
19    }
20}
21
22impl<'model_params> IntoIterator for KvOverrides<'model_params> {
23    type Item = (CString, ParamOverrideValue);
24    type IntoIter = KvOverrideValueIterator<'model_params>;
25
26    fn into_iter(self) -> Self::IntoIter {
27        KvOverrideValueIterator {
28            model_params: self.model_params,
29            current: 0,
30        }
31    }
32}
33
34/// An iterator over the key-value overrides for a model.
35#[derive(Debug)]
36pub struct KvOverrideValueIterator<'model_params> {
37    model_params: &'model_params LlamaModelParams,
38    current: usize,
39}
40
41impl Iterator for KvOverrideValueIterator<'_> {
42    type Item = (CString, ParamOverrideValue);
43
44    fn next(&mut self) -> Option<Self::Item> {
45        let overrides = self.model_params.params.kv_overrides;
46
47        if overrides.is_null() {
48            return None;
49        }
50
51        loop {
52            // SAFETY: llama.cpp guarantees the last element contains an empty key.
53            // We've checked the previous one in the last iteration, the next one
54            // should be valid or 0 (and thus safe to deref).
55            let current = unsafe { *overrides.add(self.current) };
56
57            if current.key[0] == 0 {
58                return None;
59            }
60
61            self.current += 1;
62
63            if let Ok(value) = ParamOverrideValue::try_from(&current) {
64                let key = unsafe { CStr::from_ptr(current.key.as_ptr()).to_owned() };
65
66                return Some((key, value));
67            }
68        }
69    }
70}
71
#[cfg(test)]
mod tests {
    use std::ffi::CString;
    use std::pin::pin;

    use crate::model::params::LlamaModelParams;
    use crate::model::params::param_override_value::ParamOverrideValue;

    #[test]
    fn kv_overrides_empty_by_default() {
        let params = LlamaModelParams::default();
        let overrides = params.kv_overrides();
        let count = overrides.into_iter().count();

        assert_eq!(count, 0);
    }

    #[test]
    fn kv_overrides_iterates_single_entry() {
        let mut params = pin!(LlamaModelParams::default());
        let key = CString::new("test_key").unwrap();

        params
            .as_mut()
            .append_kv_override(&key, ParamOverrideValue::Int(42))
            .unwrap();

        let entries: Vec<_> = params.kv_overrides().into_iter().collect();

        assert_eq!(entries.len(), 1);
        let (entry_key, entry_value) = &entries[0];
        assert_eq!(entry_key.to_bytes(), b"test_key");
        assert_eq!(*entry_value, ParamOverrideValue::Int(42));
    }

    #[test]
    fn kv_overrides_new_creates_view() {
        let params = LlamaModelParams::default();
        let overrides = super::KvOverrides::new(&params);
        let count = overrides.into_iter().count();

        assert_eq!(count, 0);
    }

    #[test]
    fn kv_overrides_skips_entry_with_unknown_tag() {
        let mut params = pin!(LlamaModelParams::default());
        let key = CString::new("valid_key").unwrap();

        params
            .as_mut()
            .append_kv_override(&key, ParamOverrideValue::Int(99))
            .unwrap();

        params.kv_overrides[0].tag = 9999;

        assert_eq!(params.kv_overrides().into_iter().count(), 0);
    }

    /// Once the terminating entry is reached, the iterator must keep
    /// returning `None` instead of reading beyond the override array.
    #[test]
    fn kv_overrides_iterator_stays_exhausted() {
        let mut params = pin!(LlamaModelParams::default());
        let key = CString::new("test_key").unwrap();

        params
            .as_mut()
            .append_kv_override(&key, ParamOverrideValue::Int(7))
            .unwrap();

        let mut iter = params.kv_overrides().into_iter();

        assert!(iter.next().is_some());
        assert!(iter.next().is_none());
        assert!(iter.next().is_none());
    }
}