// llama_cpp_bindings/model/params/kv_overrides.rs
use crate::model::params::LlamaModelParams;
use crate::model::params::param_override_value::ParamOverrideValue;
use std::ffi::{CStr, CString};
use std::fmt::Debug;
use std::iter::FusedIterator;
7
8#[derive(Debug)]
10pub struct KvOverrides<'model_params> {
11 model_params: &'model_params LlamaModelParams,
12}
13
14impl KvOverrides<'_> {
15 #[must_use]
17 pub const fn new(model_params: &LlamaModelParams) -> KvOverrides<'_> {
18 KvOverrides { model_params }
19 }
20}
21
22impl<'model_params> IntoIterator for KvOverrides<'model_params> {
23 type Item = (CString, ParamOverrideValue);
26 type IntoIter = KvOverrideValueIterator<'model_params>;
27
28 fn into_iter(self) -> Self::IntoIter {
29 KvOverrideValueIterator {
30 model_params: self.model_params,
31 current: 0,
32 }
33 }
34}
35
36#[derive(Debug)]
38pub struct KvOverrideValueIterator<'model_params> {
39 model_params: &'model_params LlamaModelParams,
40 current: usize,
41}
42
43impl Iterator for KvOverrideValueIterator<'_> {
44 type Item = (CString, ParamOverrideValue);
45
46 fn next(&mut self) -> Option<Self::Item> {
47 let overrides = self.model_params.params.kv_overrides;
48
49 if overrides.is_null() {
50 return None;
51 }
52
53 loop {
54 let current = unsafe { *overrides.add(self.current) };
58
59 if current.key[0] == 0 {
60 return None;
61 }
62
63 self.current += 1;
64
65 if let Ok(value) = ParamOverrideValue::try_from(¤t) {
66 let key = unsafe { CStr::from_ptr(current.key.as_ptr()).to_owned() };
67
68 return Some((key, value));
69 }
70 }
71 }
72}
73
#[cfg(test)]
mod tests {
    use std::ffi::CString;
    use std::pin::pin;

    use crate::model::params::LlamaModelParams;
    use crate::model::params::param_override_value::ParamOverrideValue;

    #[test]
    fn kv_overrides_empty_by_default() {
        let params = LlamaModelParams::default();
        let overrides = params.kv_overrides();
        let count = overrides.into_iter().count();

        assert_eq!(count, 0);
    }

    #[test]
    fn kv_overrides_iterates_single_entry() {
        let mut params = pin!(LlamaModelParams::default());
        let key = CString::new("test_key").unwrap();

        params
            .as_mut()
            .append_kv_override(&key, ParamOverrideValue::Int(42))
            .unwrap();

        let entries: Vec<_> = params.kv_overrides().into_iter().collect();

        assert_eq!(entries.len(), 1);
        let (entry_key, entry_value) = &entries[0];
        assert_eq!(entry_key.to_bytes(), b"test_key");
        assert_eq!(*entry_value, ParamOverrideValue::Int(42));
    }

    #[test]
    fn kv_overrides_iterates_entries_in_insertion_order() {
        let mut params = pin!(LlamaModelParams::default());
        let first = CString::new("first_key").unwrap();
        let second = CString::new("second_key").unwrap();

        params
            .as_mut()
            .append_kv_override(&first, ParamOverrideValue::Int(1))
            .unwrap();
        params
            .as_mut()
            .append_kv_override(&second, ParamOverrideValue::Int(2))
            .unwrap();

        let entries: Vec<_> = params.kv_overrides().into_iter().collect();

        assert_eq!(
            entries,
            vec![
                (first, ParamOverrideValue::Int(1)),
                (second, ParamOverrideValue::Int(2)),
            ]
        );
    }

    #[test]
    fn kv_overrides_new_creates_view() {
        let params = LlamaModelParams::default();
        let overrides = super::KvOverrides::new(&params);
        let count = overrides.into_iter().count();

        assert_eq!(count, 0);
    }

    #[test]
    fn kv_overrides_skips_entry_with_unknown_tag() {
        let mut params = pin!(LlamaModelParams::default());
        let key = CString::new("valid_key").unwrap();

        params
            .as_mut()
            .append_kv_override(&key, ParamOverrideValue::Int(99))
            .unwrap();

        // Corrupt the stored tag so the entry can no longer be converted into a
        // `ParamOverrideValue`; the iterator must skip it rather than yield it.
        params.kv_overrides[0].tag = 9999;

        assert_eq!(params.kv_overrides().into_iter().count(), 0);
    }
}