Skip to main content

llama_cpp_bindings/model/params/
param_override_value.rs

/// An override value for a model parameter.
///
/// Each variant corresponds to exactly one `LLAMA_KV_OVERRIDE_TYPE_*` tag and one field of
/// the FFI `llama_model_kv_override` value union; see [`ParamOverrideValue::tag`] and
/// [`ParamOverrideValue::value`] for the mapping in each direction.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ParamOverrideValue {
    /// A boolean value.
    Bool(bool),
    /// A 64-bit floating-point value.
    Float(f64),
    /// A 64-bit signed integer value.
    Int(i64),
    /// A string value stored inline as a fixed buffer of 128 C characters.
    ///
    /// The buffer is copied verbatim to and from the FFI union's `val_str` field.
    /// NOTE(review): the C side presumably expects the text to be NUL-terminated within the
    /// 128 elements — confirm against llama.cpp before relying on unterminated buffers.
    Str([std::os::raw::c_char; 128]),
}
13
14impl ParamOverrideValue {
15    /// Returns the FFI tag corresponding to this override value variant.
16    #[must_use]
17    pub const fn tag(&self) -> llama_cpp_bindings_sys::llama_model_kv_override_type {
18        match self {
19            Self::Bool(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL,
20            Self::Float(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
21            Self::Int(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT,
22            Self::Str(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR,
23        }
24    }
25
26    /// Returns the FFI union value for this override.
27    #[must_use]
28    pub const fn value(&self) -> llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
29        match self {
30            Self::Bool(value) => {
31                llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_bool: *value }
32            }
33            Self::Float(value) => {
34                llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_f64: *value }
35            }
36            Self::Int(value) => {
37                llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_i64: *value }
38            }
39            Self::Str(c_string) => {
40                llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_str: *c_string }
41            }
42        }
43    }
44}
45
46/// Unknown KV override tag from the FFI layer.
47#[derive(Debug, thiserror::Error)]
48#[error("unknown KV override tag: {0}")]
49pub struct UnknownKvOverrideTag(pub llama_cpp_bindings_sys::llama_model_kv_override_type);
50
51impl TryFrom<&llama_cpp_bindings_sys::llama_model_kv_override> for ParamOverrideValue {
52    type Error = UnknownKvOverrideTag;
53
54    fn try_from(
55        llama_cpp_bindings_sys::llama_model_kv_override {
56            key: _,
57            tag,
58            __bindgen_anon_1,
59        }: &llama_cpp_bindings_sys::llama_model_kv_override,
60    ) -> Result<Self, Self::Error> {
61        match *tag {
62            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT => {
63                Ok(Self::Int(unsafe { __bindgen_anon_1.val_i64 }))
64            }
65            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT => {
66                Ok(Self::Float(unsafe { __bindgen_anon_1.val_f64 }))
67            }
68            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL => {
69                Ok(Self::Bool(unsafe { __bindgen_anon_1.val_bool }))
70            }
71            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR => {
72                Ok(Self::Str(unsafe { __bindgen_anon_1.val_str }))
73            }
74            unknown_tag => Err(UnknownKvOverrideTag(unknown_tag)),
75        }
76    }
77}
78
#[cfg(test)]
mod tests {
    use std::os::raw::c_char;

    use super::{ParamOverrideValue, UnknownKvOverrideTag};

    /// Builds a 128-element C character buffer holding `text`, padded with NUL bytes.
    ///
    /// Uses `c_char::from_ne_bytes` so the tests compile whether `c_char` is `i8` (x86_64)
    /// or `u8` (e.g. aarch64/ARM Linux). Bytes beyond 128 are silently dropped.
    fn c_string_buffer(text: &[u8]) -> [c_char; 128] {
        let mut buffer: [c_char; 128] = [0; 128];
        for (slot, &byte) in buffer.iter_mut().zip(text) {
            *slot = c_char::from_ne_bytes([byte]);
        }
        buffer
    }

    #[test]
    fn tag_bool() {
        let value = ParamOverrideValue::Bool(true);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL
        );
    }

    #[test]
    fn tag_float() {
        let value = ParamOverrideValue::Float(1.23);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT
        );
    }

    #[test]
    fn tag_int() {
        let value = ParamOverrideValue::Int(42);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT
        );
    }

    #[test]
    fn tag_str() {
        let value = ParamOverrideValue::Str([0; 128]);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR
        );
    }

    #[test]
    fn value_bool_roundtrip() {
        let value = ParamOverrideValue::Bool(true);
        let ffi_value = value.value();
        let result = unsafe { ffi_value.val_bool };

        assert!(result);
    }

    #[test]
    fn value_float_roundtrip() {
        let value = ParamOverrideValue::Float(1.23);
        let ffi_value = value.value();
        let result = unsafe { ffi_value.val_f64 };

        // Writing and reading the same union field is bit-exact, so compare bits rather
        // than using an epsilon tolerance.
        assert_eq!(result.to_bits(), 1.23_f64.to_bits());
    }

    #[test]
    fn value_int_roundtrip() {
        let value = ParamOverrideValue::Int(99);
        let ffi_value = value.value();
        let result = unsafe { ffi_value.val_i64 };

        assert_eq!(result, 99);
    }

    #[test]
    fn value_str_roundtrip() {
        let str_data = c_string_buffer(b"hi");

        let value = ParamOverrideValue::Str(str_data);
        let ffi_value = value.value();
        let result = unsafe { ffi_value.val_str };

        assert_eq!(result, str_data);
    }

    #[test]
    fn from_ffi_override_int() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_i64: 123,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Int(123));
    }

    #[test]
    fn from_ffi_override_float() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_f64: 1.5,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Float(1.5));
    }

    #[test]
    fn from_ffi_override_bool() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_bool: false,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Bool(false));
    }

    #[test]
    fn from_ffi_override_str() {
        let str_data = c_string_buffer(b"ab");

        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_str: str_data,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Str(str_data));
    }

    #[test]
    fn tag_and_value_roundtrip_through_ffi_struct() {
        let cases = [
            ParamOverrideValue::Bool(true),
            ParamOverrideValue::Float(-2.5),
            ParamOverrideValue::Int(i64::MIN),
            ParamOverrideValue::Str(c_string_buffer(b"roundtrip")),
        ];

        for original in cases {
            let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
                key: [0; 128],
                tag: original.tag(),
                __bindgen_anon_1: original.value(),
            };

            let decoded = ParamOverrideValue::try_from(&ffi_override).unwrap();

            assert_eq!(decoded, original);
        }
    }

    #[test]
    fn unknown_tag_returns_error() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: 9999,
            __bindgen_anon_1: unsafe { std::mem::zeroed() },
        };

        let error = ParamOverrideValue::try_from(&ffi_override).unwrap_err();

        assert_eq!(error.to_string(), "unknown KV override tag: 9999");
        let UnknownKvOverrideTag(tag) = error;
        assert_eq!(tag, 9999);
    }
}