llama_cpp_bindings/model/params/
param_override_value.rs1use crate::model::params::unknown_kv_override_tag::UnknownKvOverrideTag;
/// A typed value for a single llama.cpp model key/value metadata override.
///
/// Each variant corresponds to exactly one `LLAMA_KV_OVERRIDE_TYPE_*` tag (see
/// [`ParamOverrideValue::tag`]) and to one field of the FFI override value union.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ParamOverrideValue {
    /// Boolean override, stored in the union's `val_bool` field.
    Bool(bool),
    /// Floating-point override, stored in the union's `val_f64` field.
    Float(f64),
    /// Integer override, stored in the union's `val_i64` field.
    Int(i64),
    /// String override held as a fixed 128-element C character buffer, copied
    /// verbatim into the union's `val_str` field.
    ///
    /// NOTE(review): llama.cpp presumably expects this buffer to be NUL-terminated;
    /// this type does not enforce that — confirm against the C API.
    Str([std::os::raw::c_char; 128]),
}

16impl ParamOverrideValue {
17 #[must_use]
19 pub const fn tag(&self) -> llama_cpp_bindings_sys::llama_model_kv_override_type {
20 match self {
21 Self::Bool(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL,
22 Self::Float(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
23 Self::Int(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT,
24 Self::Str(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR,
25 }
26 }
27
28 #[must_use]
30 pub const fn value(&self) -> llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
31 match self {
32 Self::Bool(value) => {
33 llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_bool: *value }
34 }
35 Self::Float(value) => {
36 llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_f64: *value }
37 }
38 Self::Int(value) => {
39 llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_i64: *value }
40 }
41 Self::Str(c_string) => {
42 llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_str: *c_string }
43 }
44 }
45 }
46}
47
48impl TryFrom<&llama_cpp_bindings_sys::llama_model_kv_override> for ParamOverrideValue {
49 type Error = UnknownKvOverrideTag;
50
51 fn try_from(
52 llama_cpp_bindings_sys::llama_model_kv_override {
53 key: _,
54 tag,
55 __bindgen_anon_1,
56 }: &llama_cpp_bindings_sys::llama_model_kv_override,
57 ) -> Result<Self, Self::Error> {
58 match *tag {
59 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT => {
60 Ok(Self::Int(unsafe { __bindgen_anon_1.val_i64 }))
61 }
62 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT => {
63 Ok(Self::Float(unsafe { __bindgen_anon_1.val_f64 }))
64 }
65 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL => {
66 Ok(Self::Bool(unsafe { __bindgen_anon_1.val_bool }))
67 }
68 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR => {
69 Ok(Self::Str(unsafe { __bindgen_anon_1.val_str }))
70 }
71 unknown_tag => Err(UnknownKvOverrideTag(unknown_tag)),
72 }
73 }
74}
75
#[cfg(test)]
mod tests {
    //! Unit tests for the FFI conversions of [`ParamOverrideValue`].

    use std::os::raw::c_char;

    use super::ParamOverrideValue;

    /// Builds a zero-padded 128-element C string buffer from `bytes`.
    ///
    /// `c_char` is `i8` on some targets and `u8` on others (e.g. aarch64 Linux). Each
    /// byte is therefore converted with `from_ne_bytes`, which exists on both types,
    /// instead of an `i8`-only conversion.
    fn c_char_buffer(bytes: &[u8]) -> [c_char; 128] {
        let mut buffer: [c_char; 128] = [0; 128];
        for (slot, &byte) in buffer.iter_mut().zip(bytes) {
            *slot = c_char::from_ne_bytes([byte]);
        }
        buffer
    }

    #[test]
    fn tag_bool() {
        let value = ParamOverrideValue::Bool(true);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL
        );
    }

    #[test]
    fn tag_float() {
        let value = ParamOverrideValue::Float(1.23);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT
        );
    }

    #[test]
    fn tag_int() {
        let value = ParamOverrideValue::Int(42);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT
        );
    }

    #[test]
    fn tag_str() {
        let value = ParamOverrideValue::Str([0; 128]);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR
        );
    }

    #[test]
    fn value_bool_roundtrip() {
        let ffi_value = ParamOverrideValue::Bool(true).value();
        // SAFETY: `value()` initialises `val_bool` for the `Bool` variant.
        let result = unsafe { ffi_value.val_bool };

        assert!(result);
    }

    #[test]
    fn value_float_roundtrip() {
        let ffi_value = ParamOverrideValue::Float(1.23).value();
        // SAFETY: `value()` initialises `val_f64` for the `Float` variant.
        let result = unsafe { ffi_value.val_f64 };

        // The value is copied, not computed, so it must be bit-for-bit identical.
        assert_eq!(result.to_bits(), 1.23_f64.to_bits());
    }

    #[test]
    fn value_int_roundtrip() {
        let ffi_value = ParamOverrideValue::Int(99).value();
        // SAFETY: `value()` initialises `val_i64` for the `Int` variant.
        let result = unsafe { ffi_value.val_i64 };

        assert_eq!(result, 99);
    }

    #[test]
    fn value_str_roundtrip() {
        let str_data = c_char_buffer(b"hi");

        let ffi_value = ParamOverrideValue::Str(str_data).value();
        // SAFETY: `value()` initialises the whole `val_str` array for the `Str` variant.
        let result = unsafe { ffi_value.val_str };

        assert_eq!(result, str_data);
    }

    #[test]
    fn from_ffi_override_int() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_i64: 123,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Int(123));
    }

    #[test]
    fn from_ffi_override_float() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_f64: 1.5,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Float(1.5));
    }

    #[test]
    fn from_ffi_override_bool() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_bool: false,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Bool(false));
    }

    #[test]
    fn from_ffi_override_str() {
        let str_data = c_char_buffer(b"ab");

        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_str: str_data,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Str(str_data));
    }

    #[test]
    fn ffi_roundtrip_preserves_every_variant() {
        let originals = [
            ParamOverrideValue::Bool(true),
            ParamOverrideValue::Float(-2.5),
            ParamOverrideValue::Int(i64::MIN),
            ParamOverrideValue::Str(c_char_buffer(b"text")),
        ];

        for original in originals {
            // `tag()` and `value()` must agree, so decoding the pair yields the input back.
            let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
                key: [0; 128],
                tag: original.tag(),
                __bindgen_anon_1: original.value(),
            };

            assert_eq!(ParamOverrideValue::try_from(&ffi_override).unwrap(), original);
        }
    }

    #[test]
    fn unknown_tag_returns_error() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: 9999,
            // SAFETY: every union field (bool, f64, i64, c_char array) is valid as all-zero bytes.
            __bindgen_anon_1: unsafe { std::mem::zeroed() },
        };

        let error = ParamOverrideValue::try_from(&ffi_override).unwrap_err();

        // The error must carry the offending raw tag, not just signal failure.
        assert_eq!(error.0, 9999);
    }
}