llama_cpp_bindings/model/params/
param_override_value.rs1use crate::model::params::unknown_kv_override_tag::UnknownKvOverrideTag;
2
/// A typed value for a single llama.cpp model key/value metadata override.
///
/// Each variant pairs with one `LLAMA_KV_OVERRIDE_TYPE_*` tag and one member of the
/// `llama_model_kv_override` value union; see `tag()`, `value()` and the `TryFrom` impl.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ParamOverrideValue {
    /// Boolean override, carried in the union's `val_bool` member.
    Bool(bool),
    /// 64-bit floating-point override, carried in `val_f64`.
    ///
    /// This variant is why the type derives only `PartialEq` and not `Eq`: a `NaN` payload
    /// never compares equal to itself.
    Float(f64),
    /// 64-bit signed integer override, carried in `val_i64`.
    Int(i64),
    /// String override held inline as a fixed buffer of 128 C characters, copied verbatim to
    /// and from `val_str`.
    ///
    /// NOTE(review): presumably expected to be NUL-terminated by llama.cpp — confirm before
    /// relying on it.
    Str([core::ffi::c_char; 128]),
}
10
11impl ParamOverrideValue {
12 #[must_use]
13 pub const fn tag(&self) -> llama_cpp_bindings_sys::llama_model_kv_override_type {
14 match self {
15 Self::Bool(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL,
16 Self::Float(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
17 Self::Int(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT,
18 Self::Str(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR,
19 }
20 }
21
22 #[must_use]
23 pub const fn value(&self) -> llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
24 match self {
25 Self::Bool(value) => {
26 llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_bool: *value }
27 }
28 Self::Float(value) => {
29 llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_f64: *value }
30 }
31 Self::Int(value) => {
32 llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_i64: *value }
33 }
34 Self::Str(c_string) => {
35 llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_str: *c_string }
36 }
37 }
38 }
39}
40
41impl TryFrom<&llama_cpp_bindings_sys::llama_model_kv_override> for ParamOverrideValue {
42 type Error = UnknownKvOverrideTag;
43
44 fn try_from(
45 llama_cpp_bindings_sys::llama_model_kv_override {
46 key: _,
47 tag,
48 __bindgen_anon_1,
49 }: &llama_cpp_bindings_sys::llama_model_kv_override,
50 ) -> Result<Self, Self::Error> {
51 match *tag {
52 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT => {
53 Ok(Self::Int(unsafe { __bindgen_anon_1.val_i64 }))
54 }
55 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT => {
56 Ok(Self::Float(unsafe { __bindgen_anon_1.val_f64 }))
57 }
58 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL => {
59 Ok(Self::Bool(unsafe { __bindgen_anon_1.val_bool }))
60 }
61 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR => {
62 Ok(Self::Str(unsafe { __bindgen_anon_1.val_str }))
63 }
64 unknown_tag => Err(UnknownKvOverrideTag(unknown_tag)),
65 }
66 }
67}
68
#[cfg(test)]
mod tests {
    use std::os::raw::c_char;

    use super::ParamOverrideValue;

    /// Builds a zero-padded `[c_char; 128]` buffer whose leading bytes are `text`.
    ///
    /// `c_char` is `i8` on some targets and `u8` on others (e.g. aarch64 Linux), so each byte
    /// goes through `from_ne_bytes` instead of a hard-coded `i8` array or an `as` cast. This
    /// keeps the tests compiling on every target.
    fn c_string(text: &[u8]) -> [c_char; 128] {
        let mut buffer: [c_char; 128] = [0; 128];
        for (slot, &byte) in buffer.iter_mut().zip(text) {
            *slot = c_char::from_ne_bytes([byte]);
        }
        buffer
    }

    /// Wraps a tag and a value union into an FFI override entry with an all-zero key.
    fn ffi_override(
        tag: llama_cpp_bindings_sys::llama_model_kv_override_type,
        value: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1,
    ) -> llama_cpp_bindings_sys::llama_model_kv_override {
        llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag,
            __bindgen_anon_1: value,
        }
    }

    #[test]
    fn tag_bool() {
        assert_eq!(
            ParamOverrideValue::Bool(true).tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL
        );
    }

    #[test]
    fn tag_float() {
        assert_eq!(
            ParamOverrideValue::Float(1.23).tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT
        );
    }

    #[test]
    fn tag_int() {
        assert_eq!(
            ParamOverrideValue::Int(42).tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT
        );
    }

    #[test]
    fn tag_str() {
        assert_eq!(
            ParamOverrideValue::Str([0; 128]).tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR
        );
    }

    #[test]
    fn value_bool_roundtrip() {
        let ffi_value = ParamOverrideValue::Bool(true).value();
        // SAFETY: `value()` initialised the `val_bool` member for the Bool variant.
        let result = unsafe { ffi_value.val_bool };

        assert!(result);
    }

    #[test]
    fn value_float_roundtrip() {
        let ffi_value = ParamOverrideValue::Float(1.23).value();
        // SAFETY: `value()` initialised the `val_f64` member for the Float variant.
        let result = unsafe { ffi_value.val_f64 };

        // The value is copied bit-for-bit, so compare exactly rather than within an epsilon.
        assert_eq!(result.to_bits(), 1.23_f64.to_bits());
    }

    #[test]
    fn value_int_roundtrip() {
        let ffi_value = ParamOverrideValue::Int(99).value();
        // SAFETY: `value()` initialised the `val_i64` member for the Int variant.
        let result = unsafe { ffi_value.val_i64 };

        assert_eq!(result, 99);
    }

    #[test]
    fn value_str_roundtrip() {
        let text = c_string(b"hi");
        let ffi_value = ParamOverrideValue::Str(text).value();
        // SAFETY: `value()` initialised the `val_str` member for the Str variant.
        let result = unsafe { ffi_value.val_str };

        assert_eq!(result, text);
    }

    #[test]
    fn from_ffi_override_int() {
        let entry = ffi_override(
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT,
            llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_i64: 123 },
        );

        let value = ParamOverrideValue::try_from(&entry).unwrap();

        assert_eq!(value, ParamOverrideValue::Int(123));
    }

    #[test]
    fn from_ffi_override_float() {
        let entry = ffi_override(
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
            llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_f64: 1.5 },
        );

        let value = ParamOverrideValue::try_from(&entry).unwrap();

        assert_eq!(value, ParamOverrideValue::Float(1.5));
    }

    #[test]
    fn from_ffi_override_bool() {
        let entry = ffi_override(
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL,
            llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_bool: false },
        );

        let value = ParamOverrideValue::try_from(&entry).unwrap();

        assert_eq!(value, ParamOverrideValue::Bool(false));
    }

    #[test]
    fn from_ffi_override_str() {
        let text = c_string(b"ab");
        let entry = ffi_override(
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR,
            llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_str: text },
        );

        let value = ParamOverrideValue::try_from(&entry).unwrap();

        assert_eq!(value, ParamOverrideValue::Str(text));
    }

    #[test]
    fn tag_and_value_roundtrip_through_ffi_struct() {
        let originals = [
            ParamOverrideValue::Bool(true),
            ParamOverrideValue::Float(-2.5),
            ParamOverrideValue::Int(i64::MIN),
            ParamOverrideValue::Str(c_string(b"roundtrip")),
        ];

        for original in &originals {
            let entry = ffi_override(original.tag(), original.value());

            assert_eq!(ParamOverrideValue::try_from(&entry).unwrap(), *original);
        }
    }

    #[test]
    fn unknown_tag_returns_error() {
        // A fully initialised union literal avoids `unsafe { mem::zeroed() }`.
        let entry = ffi_override(
            9999,
            llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_i64: 0 },
        );

        let error = ParamOverrideValue::try_from(&entry).unwrap_err();

        // The error must carry the offending raw tag back to the caller.
        assert_eq!(error.0, 9999);
    }
}