Skip to main content

llama_cpp_bindings/model/params/
param_override_value.rs

1use crate::model::params::unknown_kv_override_tag::UnknownKvOverrideTag;
2
/// A typed value carried by a model key-value override.
///
/// Each variant corresponds to one field of the FFI union
/// `llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ParamOverrideValue {
    /// Boolean override.
    Bool(bool),
    /// 64-bit floating point override.
    Float(f64),
    /// 64-bit signed integer override.
    Int(i64),
    /// Fixed 128-element C character buffer.
    // NOTE(review): presumably a NUL-terminated C string — confirm against the C header.
    Str([std::ffi::c_char; 128]),
}
10
11impl ParamOverrideValue {
12    #[must_use]
13    pub const fn tag(&self) -> llama_cpp_bindings_sys::llama_model_kv_override_type {
14        match self {
15            Self::Bool(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL,
16            Self::Float(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
17            Self::Int(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT,
18            Self::Str(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR,
19        }
20    }
21
22    #[must_use]
23    pub const fn value(&self) -> llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
24        match self {
25            Self::Bool(value) => {
26                llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_bool: *value }
27            }
28            Self::Float(value) => {
29                llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_f64: *value }
30            }
31            Self::Int(value) => {
32                llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_i64: *value }
33            }
34            Self::Str(c_string) => {
35                llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 { val_str: *c_string }
36            }
37        }
38    }
39}
40
41impl TryFrom<&llama_cpp_bindings_sys::llama_model_kv_override> for ParamOverrideValue {
42    type Error = UnknownKvOverrideTag;
43
44    fn try_from(
45        llama_cpp_bindings_sys::llama_model_kv_override {
46            key: _,
47            tag,
48            __bindgen_anon_1,
49        }: &llama_cpp_bindings_sys::llama_model_kv_override,
50    ) -> Result<Self, Self::Error> {
51        match *tag {
52            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT => {
53                Ok(Self::Int(unsafe { __bindgen_anon_1.val_i64 }))
54            }
55            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT => {
56                Ok(Self::Float(unsafe { __bindgen_anon_1.val_f64 }))
57            }
58            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL => {
59                Ok(Self::Bool(unsafe { __bindgen_anon_1.val_bool }))
60            }
61            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR => {
62                Ok(Self::Str(unsafe { __bindgen_anon_1.val_str }))
63            }
64            unknown_tag => Err(UnknownKvOverrideTag(unknown_tag)),
65        }
66    }
67}
68
#[cfg(test)]
mod tests {
    //! Unit tests for `ParamOverrideValue`: tag mapping, union construction,
    //! and conversion back from the raw FFI override struct.

    use std::os::raw::c_char;

    use super::ParamOverrideValue;

    /// Converts an ASCII byte to the platform's `c_char`.
    ///
    /// `c_char` is `i8` on some targets and `u8` on others, so the tests must
    /// not hard-code either; `from_ne_bytes` exists on both integer types.
    fn ascii(byte: u8) -> c_char {
        c_char::from_ne_bytes([byte])
    }

    #[test]
    fn tag_bool() {
        let value = ParamOverrideValue::Bool(true);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL
        );
    }

    #[test]
    fn tag_float() {
        let value = ParamOverrideValue::Float(1.23);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT
        );
    }

    #[test]
    fn tag_int() {
        let value = ParamOverrideValue::Int(42);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT
        );
    }

    #[test]
    fn tag_str() {
        let value = ParamOverrideValue::Str([0; 128]);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR
        );
    }

    #[test]
    fn value_bool_roundtrip() {
        let value = ParamOverrideValue::Bool(true);
        let ffi_value = value.value();
        // SAFETY: `value()` initialised `val_bool` for the `Bool` variant.
        let result = unsafe { ffi_value.val_bool };

        assert!(result);
    }

    #[test]
    fn value_float_roundtrip() {
        let value = ParamOverrideValue::Float(1.23);
        let ffi_value = value.value();
        // SAFETY: `value()` initialised `val_f64` for the `Float` variant.
        let result = unsafe { ffi_value.val_f64 };

        assert!((result - 1.23).abs() < f64::EPSILON);
    }

    #[test]
    fn value_int_roundtrip() {
        let value = ParamOverrideValue::Int(99);
        let ffi_value = value.value();
        // SAFETY: `value()` initialised `val_i64` for the `Int` variant.
        let result = unsafe { ffi_value.val_i64 };

        assert_eq!(result, 99);
    }

    #[test]
    fn from_ffi_override_int() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_i64: 123,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Int(123));
    }

    #[test]
    fn from_ffi_override_float() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_f64: 1.5,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Float(1.5));
    }

    #[test]
    fn from_ffi_override_bool() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_bool: false,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Bool(false));
    }

    #[test]
    fn value_str_roundtrip() {
        // Typed as `c_char` (not `i8`) so the test builds on every target.
        let mut str_data: [c_char; 128] = [0; 128];
        str_data[0] = ascii(b'h');
        str_data[1] = ascii(b'i');

        let value = ParamOverrideValue::Str(str_data);
        let ffi_value = value.value();
        // SAFETY: `value()` initialised `val_str` for the `Str` variant.
        let result = unsafe { ffi_value.val_str };

        assert_eq!(result[0], ascii(b'h'));
        assert_eq!(result[1], ascii(b'i'));
    }

    #[test]
    fn from_ffi_override_str() {
        // Typed as `c_char` (not `i8`) so the test builds on every target.
        let mut str_data: [c_char; 128] = [0; 128];
        str_data[0] = ascii(b'a');
        str_data[1] = ascii(b'b');

        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_str: str_data,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Str(str_data));
    }

    #[test]
    fn unknown_tag_returns_error() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            // A tag value that matches none of the known override type constants.
            tag: 9999,
            // The union is never read on the unknown-tag path, so its contents
            // are irrelevant here.
            __bindgen_anon_1: unsafe { std::mem::zeroed() },
        };

        let result = ParamOverrideValue::try_from(&ffi_override);

        assert!(result.is_err());
    }
}