Skip to main content

llama_cpp_bindings/model/params/
param_override_value.rs

1/// An override value for a model parameter.
2#[derive(Debug, Clone, Copy, PartialEq)]
3pub enum ParamOverrideValue {
4    /// A boolean value
5    Bool(bool),
6    /// A float value
7    Float(f64),
8    /// A integer value
9    Int(i64),
10    /// A string value
11    Str([std::os::raw::c_char; 128]),
12}
13
14impl ParamOverrideValue {
15    /// Returns the FFI tag corresponding to this override value variant.
16    #[must_use]
17    pub fn tag(&self) -> llama_cpp_bindings_sys::llama_model_kv_override_type {
18        match self {
19            ParamOverrideValue::Bool(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL,
20            ParamOverrideValue::Float(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
21            ParamOverrideValue::Int(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT,
22            ParamOverrideValue::Str(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR,
23        }
24    }
25
26    /// Returns the FFI union value for this override.
27    #[must_use]
28    pub fn value(&self) -> llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
29        match self {
30            ParamOverrideValue::Bool(value) => {
31                llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_bool: *value }
32            }
33            ParamOverrideValue::Float(value) => {
34                llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_f64: *value }
35            }
36            ParamOverrideValue::Int(value) => {
37                llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_i64: *value }
38            }
39            ParamOverrideValue::Str(c_string) => {
40                llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_str: *c_string }
41            }
42        }
43    }
44}
45
46impl From<&llama_cpp_bindings_sys::llama_model_kv_override> for ParamOverrideValue {
47    fn from(
48        llama_cpp_bindings_sys::llama_model_kv_override {
49            key: _,
50            tag,
51            __bindgen_anon_1,
52        }: &llama_cpp_bindings_sys::llama_model_kv_override,
53    ) -> Self {
54        match *tag {
55            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT => {
56                ParamOverrideValue::Int(unsafe { __bindgen_anon_1.val_i64 })
57            }
58            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT => {
59                ParamOverrideValue::Float(unsafe { __bindgen_anon_1.val_f64 })
60            }
61            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL => {
62                ParamOverrideValue::Bool(unsafe { __bindgen_anon_1.val_bool })
63            }
64            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR => {
65                ParamOverrideValue::Str(unsafe { __bindgen_anon_1.val_str })
66            }
67            _ => unreachable!("Unknown tag of {tag}"),
68        }
69    }
70}
71
72#[cfg(test)]
73mod tests {
74    use super::ParamOverrideValue;
75
76    #[test]
77    fn tag_bool() {
78        let value = ParamOverrideValue::Bool(true);
79
80        assert_eq!(
81            value.tag(),
82            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL
83        );
84    }
85
86    #[test]
87    fn tag_float() {
88        let value = ParamOverrideValue::Float(1.23);
89
90        assert_eq!(
91            value.tag(),
92            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT
93        );
94    }
95
96    #[test]
97    fn tag_int() {
98        let value = ParamOverrideValue::Int(42);
99
100        assert_eq!(
101            value.tag(),
102            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT
103        );
104    }
105
106    #[test]
107    fn tag_str() {
108        let value = ParamOverrideValue::Str([0; 128]);
109
110        assert_eq!(
111            value.tag(),
112            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR
113        );
114    }
115
116    #[test]
117    fn value_bool_roundtrip() {
118        let value = ParamOverrideValue::Bool(true);
119        let ffi_value = value.value();
120        let result = unsafe { ffi_value.val_bool };
121
122        assert!(result);
123    }
124
125    #[test]
126    fn value_float_roundtrip() {
127        let value = ParamOverrideValue::Float(1.23);
128        let ffi_value = value.value();
129        let result = unsafe { ffi_value.val_f64 };
130
131        assert!((result - 1.23).abs() < f64::EPSILON);
132    }
133
134    #[test]
135    fn value_int_roundtrip() {
136        let value = ParamOverrideValue::Int(99);
137        let ffi_value = value.value();
138        let result = unsafe { ffi_value.val_i64 };
139
140        assert_eq!(result, 99);
141    }
142
143    #[test]
144    fn from_ffi_override_int() {
145        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
146            key: [0; 128],
147            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT,
148            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
149                val_i64: 123,
150            },
151        };
152
153        let value = ParamOverrideValue::from(&ffi_override);
154
155        assert_eq!(value, ParamOverrideValue::Int(123));
156    }
157
158    #[test]
159    fn from_ffi_override_float() {
160        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
161            key: [0; 128],
162            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
163            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
164                val_f64: 1.5,
165            },
166        };
167
168        let value = ParamOverrideValue::from(&ffi_override);
169
170        assert_eq!(value, ParamOverrideValue::Float(1.5));
171    }
172
173    #[test]
174    fn from_ffi_override_bool() {
175        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
176            key: [0; 128],
177            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL,
178            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
179                val_bool: false,
180            },
181        };
182
183        let value = ParamOverrideValue::from(&ffi_override);
184
185        assert_eq!(value, ParamOverrideValue::Bool(false));
186    }
187}