// llama_cpp_bindings/model/params/param_override_value.rs

/// A typed value for a llama.cpp model key/value override.
///
/// Each variant corresponds to one `LLAMA_KV_OVERRIDE_TYPE_*` tag. It carries
/// the payload stored in the matching field of the
/// `llama_model_kv_override` union.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ParamOverrideValue {
    /// Boolean override (`LLAMA_KV_OVERRIDE_TYPE_BOOL`, union field `val_bool`).
    Bool(bool),
    /// Floating-point override (`LLAMA_KV_OVERRIDE_TYPE_FLOAT`, union field `val_f64`).
    Float(f64),
    /// Integer override (`LLAMA_KV_OVERRIDE_TYPE_INT`, union field `val_i64`).
    Int(i64),
    /// String override stored as a fixed 128-element C character buffer
    /// (`LLAMA_KV_OVERRIDE_TYPE_STR`, union field `val_str`).
    ///
    /// The buffer is handed to C code, so keep the text NUL-terminated within it.
    Str([std::os::raw::c_char; 128]),
}

14impl ParamOverrideValue {
15 #[must_use]
17 pub const fn tag(&self) -> llama_cpp_bindings_sys::llama_model_kv_override_type {
18 match self {
19 Self::Bool(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL,
20 Self::Float(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
21 Self::Int(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT,
22 Self::Str(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR,
23 }
24 }
25
26 #[must_use]
28 pub const fn value(&self) -> llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
29 match self {
30 Self::Bool(value) => {
31 llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_bool: *value }
32 }
33 Self::Float(value) => {
34 llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_f64: *value }
35 }
36 Self::Int(value) => {
37 llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_i64: *value }
38 }
39 Self::Str(c_string) => {
40 llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_str: *c_string }
41 }
42 }
43 }
44}
45
46#[derive(Debug, thiserror::Error)]
48#[error("unknown KV override tag: {0}")]
49pub struct UnknownKvOverrideTag(pub llama_cpp_bindings_sys::llama_model_kv_override_type);
50
51impl TryFrom<&llama_cpp_bindings_sys::llama_model_kv_override> for ParamOverrideValue {
52 type Error = UnknownKvOverrideTag;
53
54 fn try_from(
55 llama_cpp_bindings_sys::llama_model_kv_override {
56 key: _,
57 tag,
58 __bindgen_anon_1,
59 }: &llama_cpp_bindings_sys::llama_model_kv_override,
60 ) -> Result<Self, Self::Error> {
61 match *tag {
62 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT => {
63 Ok(Self::Int(unsafe { __bindgen_anon_1.val_i64 }))
64 }
65 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT => {
66 Ok(Self::Float(unsafe { __bindgen_anon_1.val_f64 }))
67 }
68 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL => {
69 Ok(Self::Bool(unsafe { __bindgen_anon_1.val_bool }))
70 }
71 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR => {
72 Ok(Self::Str(unsafe { __bindgen_anon_1.val_str }))
73 }
74 unknown_tag => Err(UnknownKvOverrideTag(unknown_tag)),
75 }
76 }
77}
78
#[cfg(test)]
mod tests {
    use std::os::raw::c_char;

    use super::{ParamOverrideValue, UnknownKvOverrideTag};

    /// Builds a zero-padded 128-element C character buffer holding `ascii`.
    ///
    /// The buffer is typed through the `c_char` alias rather than `i8`.
    /// `c_char` is unsigned on some targets (for example aarch64 Linux), where
    /// an `[i8; 128]` literal would not type-check against the FFI field.
    // ASCII bytes are below 0x80, so the cast never wraps when `c_char` is `i8`.
    #[allow(clippy::cast_possible_wrap)]
    fn c_buffer(ascii: &[u8]) -> [c_char; 128] {
        let mut buffer: [c_char; 128] = [0; 128];
        for (slot, &byte) in buffer.iter_mut().zip(ascii) {
            *slot = byte as c_char;
        }
        buffer
    }

    /// Wraps a union payload in a complete FFI override with an all-zero key.
    fn ffi_override(
        tag: llama_cpp_bindings_sys::llama_model_kv_override_type,
        payload: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1,
    ) -> llama_cpp_bindings_sys::llama_model_kv_override {
        llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag,
            __bindgen_anon_1: payload,
        }
    }

    #[test]
    fn tag_bool() {
        assert_eq!(
            ParamOverrideValue::Bool(true).tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL
        );
    }

    #[test]
    fn tag_float() {
        assert_eq!(
            ParamOverrideValue::Float(1.23).tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT
        );
    }

    #[test]
    fn tag_int() {
        assert_eq!(
            ParamOverrideValue::Int(42).tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT
        );
    }

    #[test]
    fn tag_str() {
        assert_eq!(
            ParamOverrideValue::Str([0; 128]).tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR
        );
    }

    #[test]
    fn value_bool_roundtrip() {
        let ffi_value = ParamOverrideValue::Bool(true).value();
        // SAFETY: `value()` initializes `val_bool` for the `Bool` variant.
        let result = unsafe { ffi_value.val_bool };

        assert!(result);
    }

    #[test]
    fn value_float_roundtrip() {
        let ffi_value = ParamOverrideValue::Float(1.23).value();
        // SAFETY: `value()` initializes `val_f64` for the `Float` variant.
        let result = unsafe { ffi_value.val_f64 };

        assert!((result - 1.23).abs() < f64::EPSILON);
    }

    #[test]
    fn value_int_roundtrip() {
        let ffi_value = ParamOverrideValue::Int(99).value();
        // SAFETY: `value()` initializes `val_i64` for the `Int` variant.
        let result = unsafe { ffi_value.val_i64 };

        assert_eq!(result, 99);
    }

    #[test]
    fn value_str_roundtrip() {
        let str_data = c_buffer(b"hi");

        let ffi_value = ParamOverrideValue::Str(str_data).value();
        // SAFETY: `value()` initializes `val_str` for the `Str` variant.
        let result = unsafe { ffi_value.val_str };

        assert_eq!(result, str_data);
    }

    #[test]
    fn from_ffi_override_int() {
        let ffi = ffi_override(
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT,
            llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_i64: 123 },
        );

        let value = ParamOverrideValue::try_from(&ffi).unwrap();

        assert_eq!(value, ParamOverrideValue::Int(123));
    }

    #[test]
    fn from_ffi_override_float() {
        let ffi = ffi_override(
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
            llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_f64: 1.5 },
        );

        let value = ParamOverrideValue::try_from(&ffi).unwrap();

        assert_eq!(value, ParamOverrideValue::Float(1.5));
    }

    #[test]
    fn from_ffi_override_bool() {
        let ffi = ffi_override(
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL,
            llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_bool: false },
        );

        let value = ParamOverrideValue::try_from(&ffi).unwrap();

        assert_eq!(value, ParamOverrideValue::Bool(false));
    }

    #[test]
    fn from_ffi_override_str() {
        let str_data = c_buffer(b"ab");
        let ffi = ffi_override(
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR,
            llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_str: str_data },
        );

        let value = ParamOverrideValue::try_from(&ffi).unwrap();

        assert_eq!(value, ParamOverrideValue::Str(str_data));
    }

    #[test]
    fn ffi_roundtrip_all_variants() {
        let originals = [
            ParamOverrideValue::Bool(true),
            ParamOverrideValue::Float(-2.5),
            ParamOverrideValue::Int(-7),
            ParamOverrideValue::Str(c_buffer(b"abc")),
        ];

        for original in originals {
            let ffi = ffi_override(original.tag(), original.value());

            assert_eq!(ParamOverrideValue::try_from(&ffi).unwrap(), original);
        }
    }

    #[test]
    fn unknown_tag_returns_error() {
        // SAFETY: all-zero bytes are valid for every union field used here:
        // `i64`, `f64`, `bool` and a `c_char` array.
        let ffi = ffi_override(9999, unsafe { std::mem::zeroed() });

        let result = ParamOverrideValue::try_from(&ffi);

        assert!(matches!(result, Err(UnknownKvOverrideTag(9999))));
    }
}