llama_cpp_2/model/params/
kv_overrides.rs1use crate::model::params::LlamaModelParams;
4use std::ffi::{CStr, CString};
5use std::fmt::Debug;
6
7#[derive(Debug, Clone, Copy, PartialEq)]
9pub enum ParamOverrideValue {
10 Bool(bool),
12 Float(f64),
14 Int(i64),
16 Str([std::os::raw::c_char; 128]),
18}
19
20impl ParamOverrideValue {
21 pub(crate) fn tag(&self) -> llama_cpp_sys_2::llama_model_kv_override_type {
22 match self {
23 ParamOverrideValue::Bool(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_BOOL,
24 ParamOverrideValue::Float(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
25 ParamOverrideValue::Int(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_INT,
26 ParamOverrideValue::Str(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_STR,
27 }
28 }
29
30 pub(crate) fn value(&self) -> llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 {
31 match self {
32 ParamOverrideValue::Bool(value) => {
33 llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { val_bool: *value }
34 }
35 ParamOverrideValue::Float(value) => {
36 llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { val_f64: *value }
37 }
38 ParamOverrideValue::Int(value) => {
39 llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { val_i64: *value }
40 }
41 ParamOverrideValue::Str(c_string) => {
42 llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { val_str: *c_string }
43 }
44 }
45 }
46}
47
48impl From<&llama_cpp_sys_2::llama_model_kv_override> for ParamOverrideValue {
49 fn from(
50 llama_cpp_sys_2::llama_model_kv_override {
51 key: _,
52 tag,
53 __bindgen_anon_1,
54 }: &llama_cpp_sys_2::llama_model_kv_override,
55 ) -> Self {
56 match *tag {
57 llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_INT => {
58 ParamOverrideValue::Int(unsafe { __bindgen_anon_1.val_i64 })
59 }
60 llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_FLOAT => {
61 ParamOverrideValue::Float(unsafe { __bindgen_anon_1.val_f64 })
62 }
63 llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_BOOL => {
64 ParamOverrideValue::Bool(unsafe { __bindgen_anon_1.val_bool })
65 }
66 llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_STR => {
67 ParamOverrideValue::Str(unsafe { __bindgen_anon_1.val_str })
68 }
69 _ => unreachable!("Unknown tag of {tag}"),
70 }
71 }
72}
73
74#[derive(Debug)]
76pub struct KvOverrides<'a> {
77 model_params: &'a LlamaModelParams,
78}
79
80impl KvOverrides<'_> {
81 pub(super) fn new(model_params: &LlamaModelParams) -> KvOverrides {
82 KvOverrides { model_params }
83 }
84}
85
86impl<'a> IntoIterator for KvOverrides<'a> {
87 type Item = (CString, ParamOverrideValue);
90 type IntoIter = KvOverrideValueIterator<'a>;
91
92 fn into_iter(self) -> Self::IntoIter {
93 KvOverrideValueIterator {
94 model_params: self.model_params,
95 current: 0,
96 }
97 }
98}
99
100#[derive(Debug)]
102pub struct KvOverrideValueIterator<'a> {
103 model_params: &'a LlamaModelParams,
104 current: usize,
105}
106
107impl Iterator for KvOverrideValueIterator<'_> {
108 type Item = (CString, ParamOverrideValue);
109
110 fn next(&mut self) -> Option<Self::Item> {
111 let overrides = self.model_params.params.kv_overrides;
112 if overrides.is_null() {
113 return None;
114 }
115
116 let current = unsafe { *overrides.add(self.current) };
119
120 if current.key[0] == 0 {
121 return None;
122 }
123
124 let value = ParamOverrideValue::from(¤t);
125
126 let key = unsafe { CStr::from_ptr(current.key.as_ptr()).to_owned() };
127
128 self.current += 1;
129 Some((key, value))
130 }
131}