llama_cpp_bindings/model/params/
kv_overrides.rs1use crate::model::params::LlamaModelParams;
4use crate::model::params::param_override_value::ParamOverrideValue;
5use std::ffi::{CStr, CString};
6use std::fmt::Debug;
7
8#[derive(Debug)]
10pub struct KvOverrides<'model_params> {
11 model_params: &'model_params LlamaModelParams,
12}
13
14impl KvOverrides<'_> {
15 #[must_use]
17 pub const fn new(model_params: &LlamaModelParams) -> KvOverrides<'_> {
18 KvOverrides { model_params }
19 }
20}
21
22impl<'model_params> IntoIterator for KvOverrides<'model_params> {
23 type Item = (CString, ParamOverrideValue);
24 type IntoIter = KvOverrideValueIterator<'model_params>;
25
26 fn into_iter(self) -> Self::IntoIter {
27 KvOverrideValueIterator {
28 model_params: self.model_params,
29 current: 0,
30 }
31 }
32}
33
34#[derive(Debug)]
36pub struct KvOverrideValueIterator<'model_params> {
37 model_params: &'model_params LlamaModelParams,
38 current: usize,
39}
40
41impl Iterator for KvOverrideValueIterator<'_> {
42 type Item = (CString, ParamOverrideValue);
43
44 fn next(&mut self) -> Option<Self::Item> {
45 let overrides = self.model_params.params.kv_overrides;
46
47 if overrides.is_null() {
48 return None;
49 }
50
51 loop {
52 let current = unsafe { *overrides.add(self.current) };
56
57 if current.key[0] == 0 {
58 return None;
59 }
60
61 self.current += 1;
62
63 if let Ok(value) = ParamOverrideValue::try_from(¤t) {
64 let key = unsafe { CStr::from_ptr(current.key.as_ptr()).to_owned() };
65
66 return Some((key, value));
67 }
68 }
69 }
70}
71
#[cfg(test)]
mod tests {
    use std::ffi::CString;
    use std::pin::pin;

    use crate::model::params::LlamaModelParams;
    use crate::model::params::param_override_value::ParamOverrideValue;

    #[test]
    fn kv_overrides_empty_by_default() {
        let params = LlamaModelParams::default();
        let overrides = params.kv_overrides();
        let count = overrides.into_iter().count();

        assert_eq!(count, 0);
    }

    #[test]
    fn kv_overrides_iterates_single_entry() {
        let mut params = pin!(LlamaModelParams::default());
        let key = CString::new("test_key").unwrap();

        params
            .as_mut()
            .append_kv_override(&key, ParamOverrideValue::Int(42))
            .unwrap();

        let entries: Vec<_> = params.kv_overrides().into_iter().collect();

        assert_eq!(entries.len(), 1);
        let (entry_key, entry_value) = &entries[0];
        assert_eq!(entry_key.to_bytes(), b"test_key");
        assert_eq!(*entry_value, ParamOverrideValue::Int(42));
    }

    #[test]
    fn kv_overrides_iterates_entries_in_insertion_order() {
        let mut params = pin!(LlamaModelParams::default());
        let first = CString::new("first_key").unwrap();
        let second = CString::new("second_key").unwrap();

        params
            .as_mut()
            .append_kv_override(&first, ParamOverrideValue::Int(1))
            .unwrap();
        params
            .as_mut()
            .append_kv_override(&second, ParamOverrideValue::Int(2))
            .unwrap();

        let entries: Vec<_> = params.kv_overrides().into_iter().collect();

        assert_eq!(
            entries,
            vec![
                (first, ParamOverrideValue::Int(1)),
                (second, ParamOverrideValue::Int(2)),
            ]
        );
    }

    #[test]
    fn kv_overrides_iterator_keeps_returning_none_after_exhaustion() {
        let mut params = pin!(LlamaModelParams::default());
        let key = CString::new("only_key").unwrap();

        params
            .as_mut()
            .append_kv_override(&key, ParamOverrideValue::Int(7))
            .unwrap();

        let mut iter = params.kv_overrides().into_iter();

        assert!(iter.next().is_some());
        assert!(iter.next().is_none());
        assert!(iter.next().is_none());
    }

    #[test]
    fn kv_overrides_new_creates_view() {
        let params = LlamaModelParams::default();
        let overrides = super::KvOverrides::new(&params);
        let count = overrides.into_iter().count();

        assert_eq!(count, 0);
    }

    #[test]
    fn kv_overrides_skips_entry_with_unknown_tag() {
        let mut params = pin!(LlamaModelParams::default());
        let key = CString::new("valid_key").unwrap();

        params
            .as_mut()
            .append_kv_override(&key, ParamOverrideValue::Int(99))
            .unwrap();

        params.kv_overrides[0].tag = 9999;

        assert_eq!(params.kv_overrides().into_iter().count(), 0);
    }
}