Skip to main content

llama_cpp_bindings/model/params/
param_override_value.rs

1use crate::model::params::unknown_kv_override_tag::UnknownKvOverrideTag;
2
/// A typed value used to override a single model parameter.
///
/// Each variant carries the payload for one of the override kinds that the
/// FFI layer distinguishes by tag.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ParamOverrideValue {
    /// A boolean override.
    Bool(bool),
    /// A 64-bit floating-point override.
    Float(f64),
    /// A 64-bit signed integer override.
    Int(i64),
    /// A string override, stored inline as a fixed 128-element `c_char` buffer.
    Str([std::ffi::c_char; 128]),
}
15
16impl ParamOverrideValue {
17    /// Returns the FFI tag corresponding to this override value variant.
18    #[must_use]
19    pub const fn tag(&self) -> llama_cpp_bindings_sys::llama_model_kv_override_type {
20        match self {
21            Self::Bool(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL,
22            Self::Float(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
23            Self::Int(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT,
24            Self::Str(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR,
25        }
26    }
27
28    /// Returns the FFI union value for this override.
29    #[must_use]
30    pub const fn value(&self) -> llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
31        match self {
32            Self::Bool(value) => {
33                llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_bool: *value }
34            }
35            Self::Float(value) => {
36                llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_f64: *value }
37            }
38            Self::Int(value) => {
39                llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_i64: *value }
40            }
41            Self::Str(c_string) => {
42                llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_str: *c_string }
43            }
44        }
45    }
46}
47
48impl TryFrom<&llama_cpp_bindings_sys::llama_model_kv_override> for ParamOverrideValue {
49    type Error = UnknownKvOverrideTag;
50
51    fn try_from(
52        llama_cpp_bindings_sys::llama_model_kv_override {
53            key: _,
54            tag,
55            __bindgen_anon_1,
56        }: &llama_cpp_bindings_sys::llama_model_kv_override,
57    ) -> Result<Self, Self::Error> {
58        match *tag {
59            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT => {
60                Ok(Self::Int(unsafe { __bindgen_anon_1.val_i64 }))
61            }
62            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT => {
63                Ok(Self::Float(unsafe { __bindgen_anon_1.val_f64 }))
64            }
65            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL => {
66                Ok(Self::Bool(unsafe { __bindgen_anon_1.val_bool }))
67            }
68            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR => {
69                Ok(Self::Str(unsafe { __bindgen_anon_1.val_str }))
70            }
71            unknown_tag => Err(UnknownKvOverrideTag(unknown_tag)),
72        }
73    }
74}
75
#[cfg(test)]
mod tests {
    use std::os::raw::c_char;

    use super::ParamOverrideValue;

    /// Builds a zero-padded 128-element `c_char` buffer that starts with `text`.
    ///
    /// `c_char` is `i8` on some targets and `u8` on others. `from_ne_bytes`
    /// exists for both, so this compiles everywhere, unlike an `[i8; 128]`
    /// literal filled via `cast_signed()`.
    fn c_char_buffer(text: &[u8]) -> [c_char; 128] {
        let mut buffer = [0; 128];

        for (slot, byte) in buffer.iter_mut().zip(text) {
            *slot = c_char::from_ne_bytes([*byte]);
        }

        buffer
    }

    /// `Bool` maps to the boolean FFI tag.
    #[test]
    fn tag_bool() {
        let value = ParamOverrideValue::Bool(true);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL
        );
    }

    /// `Float` maps to the float FFI tag.
    #[test]
    fn tag_float() {
        let value = ParamOverrideValue::Float(1.23);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT
        );
    }

    /// `Int` maps to the integer FFI tag.
    #[test]
    fn tag_int() {
        let value = ParamOverrideValue::Int(42);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT
        );
    }

    /// `Str` maps to the string FFI tag.
    #[test]
    fn tag_str() {
        let value = ParamOverrideValue::Str([0; 128]);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR
        );
    }

    /// The boolean payload is readable back from the FFI union.
    #[test]
    fn value_bool_roundtrip() {
        let value = ParamOverrideValue::Bool(true);
        let ffi_value = value.value();
        // SAFETY: `value()` wrote `val_bool` for the `Bool` variant.
        let result = unsafe { ffi_value.val_bool };

        assert!(result);
    }

    /// The float payload is readable back from the FFI union.
    #[test]
    fn value_float_roundtrip() {
        let value = ParamOverrideValue::Float(1.23);
        let ffi_value = value.value();
        // SAFETY: `value()` wrote `val_f64` for the `Float` variant.
        let result = unsafe { ffi_value.val_f64 };

        assert!((result - 1.23).abs() < f64::EPSILON);
    }

    /// The integer payload is readable back from the FFI union.
    #[test]
    fn value_int_roundtrip() {
        let value = ParamOverrideValue::Int(99);
        let ffi_value = value.value();
        // SAFETY: `value()` wrote `val_i64` for the `Int` variant.
        let result = unsafe { ffi_value.val_i64 };

        assert_eq!(result, 99);
    }

    /// An FFI entry tagged as integer converts to `Int`.
    #[test]
    fn from_ffi_override_int() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_i64: 123,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Int(123));
    }

    /// An FFI entry tagged as float converts to `Float`.
    #[test]
    fn from_ffi_override_float() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_f64: 1.5,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Float(1.5));
    }

    /// An FFI entry tagged as boolean converts to `Bool`.
    #[test]
    fn from_ffi_override_bool() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_bool: false,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Bool(false));
    }

    /// The string buffer is copied unchanged into the FFI union.
    #[test]
    fn value_str_roundtrip() {
        let str_data = c_char_buffer(b"hi");

        let value = ParamOverrideValue::Str(str_data);
        let ffi_value = value.value();
        // SAFETY: `value()` wrote `val_str` for the `Str` variant.
        let result = unsafe { ffi_value.val_str };

        // Compare the whole buffer so the zero padding is checked as well.
        assert_eq!(result, str_data);
    }

    /// An FFI entry tagged as string converts to `Str` with the same buffer.
    #[test]
    fn from_ffi_override_str() {
        let str_data = c_char_buffer(b"ab");

        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_str: str_data,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Str(str_data));
    }

    /// A tag outside the known set is rejected instead of being guessed at.
    #[test]
    fn unknown_tag_returns_error() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: 9999,
            // The payload is never read for an unknown tag, so any safely
            // constructed union value works here; no `unsafe` is needed.
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_i64: 0,
            },
        };

        let result = ParamOverrideValue::try_from(&ffi_override);

        assert!(result.is_err());
    }
}