llama_cpp_bindings/model/params/
param_override_value.rs1#[derive(Debug, Clone, Copy, PartialEq)]
3pub enum ParamOverrideValue {
4 Bool(bool),
6 Float(f64),
8 Int(i64),
10 Str([std::os::raw::c_char; 128]),
12}
13
14impl ParamOverrideValue {
15 #[must_use]
17 pub fn tag(&self) -> llama_cpp_bindings_sys::llama_model_kv_override_type {
18 match self {
19 ParamOverrideValue::Bool(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL,
20 ParamOverrideValue::Float(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
21 ParamOverrideValue::Int(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT,
22 ParamOverrideValue::Str(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR,
23 }
24 }
25
26 #[must_use]
28 pub fn value(&self) -> llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
29 match self {
30 ParamOverrideValue::Bool(value) => {
31 llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_bool: *value }
32 }
33 ParamOverrideValue::Float(value) => {
34 llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_f64: *value }
35 }
36 ParamOverrideValue::Int(value) => {
37 llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_i64: *value }
38 }
39 ParamOverrideValue::Str(c_string) => {
40 llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_str: *c_string }
41 }
42 }
43 }
44}
45
46impl From<&llama_cpp_bindings_sys::llama_model_kv_override> for ParamOverrideValue {
47 fn from(
48 llama_cpp_bindings_sys::llama_model_kv_override {
49 key: _,
50 tag,
51 __bindgen_anon_1,
52 }: &llama_cpp_bindings_sys::llama_model_kv_override,
53 ) -> Self {
54 match *tag {
55 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT => {
56 ParamOverrideValue::Int(unsafe { __bindgen_anon_1.val_i64 })
57 }
58 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT => {
59 ParamOverrideValue::Float(unsafe { __bindgen_anon_1.val_f64 })
60 }
61 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL => {
62 ParamOverrideValue::Bool(unsafe { __bindgen_anon_1.val_bool })
63 }
64 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR => {
65 ParamOverrideValue::Str(unsafe { __bindgen_anon_1.val_str })
66 }
67 _ => unreachable!("Unknown tag of {tag}"),
68 }
69 }
70}
71
72#[cfg(test)]
73mod tests {
74 use super::ParamOverrideValue;
75
76 #[test]
77 fn tag_bool() {
78 let value = ParamOverrideValue::Bool(true);
79
80 assert_eq!(
81 value.tag(),
82 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL
83 );
84 }
85
86 #[test]
87 fn tag_float() {
88 let value = ParamOverrideValue::Float(1.23);
89
90 assert_eq!(
91 value.tag(),
92 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT
93 );
94 }
95
96 #[test]
97 fn tag_int() {
98 let value = ParamOverrideValue::Int(42);
99
100 assert_eq!(
101 value.tag(),
102 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT
103 );
104 }
105
106 #[test]
107 fn tag_str() {
108 let value = ParamOverrideValue::Str([0; 128]);
109
110 assert_eq!(
111 value.tag(),
112 llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR
113 );
114 }
115
116 #[test]
117 fn value_bool_roundtrip() {
118 let value = ParamOverrideValue::Bool(true);
119 let ffi_value = value.value();
120 let result = unsafe { ffi_value.val_bool };
121
122 assert!(result);
123 }
124
125 #[test]
126 fn value_float_roundtrip() {
127 let value = ParamOverrideValue::Float(1.23);
128 let ffi_value = value.value();
129 let result = unsafe { ffi_value.val_f64 };
130
131 assert!((result - 1.23).abs() < f64::EPSILON);
132 }
133
134 #[test]
135 fn value_int_roundtrip() {
136 let value = ParamOverrideValue::Int(99);
137 let ffi_value = value.value();
138 let result = unsafe { ffi_value.val_i64 };
139
140 assert_eq!(result, 99);
141 }
142
143 #[test]
144 fn from_ffi_override_int() {
145 let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
146 key: [0; 128],
147 tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT,
148 __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
149 val_i64: 123,
150 },
151 };
152
153 let value = ParamOverrideValue::from(&ffi_override);
154
155 assert_eq!(value, ParamOverrideValue::Int(123));
156 }
157
158 #[test]
159 fn from_ffi_override_float() {
160 let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
161 key: [0; 128],
162 tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
163 __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
164 val_f64: 1.5,
165 },
166 };
167
168 let value = ParamOverrideValue::from(&ffi_override);
169
170 assert_eq!(value, ParamOverrideValue::Float(1.5));
171 }
172
173 #[test]
174 fn from_ffi_override_bool() {
175 let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
176 key: [0; 128],
177 tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL,
178 __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
179 val_bool: false,
180 },
181 };
182
183 let value = ParamOverrideValue::from(&ffi_override);
184
185 assert_eq!(value, ParamOverrideValue::Bool(false));
186 }
187}