use llama_crab_sys as sys;
/// A typed value used to override a model metadata key/value entry.
///
/// Converted into the C `llama_model_kv_override` representation by `as_c`.
#[derive(Debug, Clone)]
pub enum ParamOverrideValue {
    /// Boolean override value.
    Bool(bool),
    /// Floating-point override value, stored as a 64-bit double.
    Float(f64),
    /// Integer override value, stored as a signed 64-bit integer.
    Int(i64),
    /// String override value; it may be truncated when converted to C (see `as_c`).
    Str(String),
}
impl ParamOverrideValue {
    /// Builds the C `llama_model_kv_override` representation of this value.
    ///
    /// Only `tag` and the value union are written; every other field of the
    /// returned struct is left zero-initialised. NOTE(review): the override key
    /// is presumably filled in by the caller — confirm.
    ///
    /// For [`ParamOverrideValue::Str`], the text is copied into the fixed-size
    /// `val_str` buffer and is always NUL-terminated. If it does not fit, it is
    /// truncated to the longest prefix that fits *and* ends on a UTF-8 character
    /// boundary, so C readers never see a split multi-byte sequence. A string
    /// containing an interior NUL byte will appear shorter to C readers.
    pub(crate) fn as_c(&self) -> sys::llama_model_kv_override {
        // SAFETY: the struct is plain C data (enum tag, char arrays, a union of
        // POD values), for which an all-zero bit pattern is a valid value.
        // NOTE(review): this assumes the tag type accepts 0 as a discriminant
        // (LLAMA_KV_OVERRIDE_TYPE_INT = 0 in llama.h); `tag` is overwritten below.
        let mut out: sys::llama_model_kv_override = unsafe { std::mem::zeroed() };
        match self {
            Self::Bool(b) => {
                out.tag = sys::llama_model_kv_override_type::LLAMA_KV_OVERRIDE_TYPE_BOOL;
                out.__bindgen_anon_1.val_bool = *b;
            }
            Self::Float(f) => {
                out.tag = sys::llama_model_kv_override_type::LLAMA_KV_OVERRIDE_TYPE_FLOAT;
                out.__bindgen_anon_1.val_f64 = *f;
            }
            Self::Int(i) => {
                out.tag = sys::llama_model_kv_override_type::LLAMA_KV_OVERRIDE_TYPE_INT;
                out.__bindgen_anon_1.val_i64 = *i;
            }
            Self::Str(s) => {
                out.tag = sys::llama_model_kv_override_type::LLAMA_KV_OVERRIDE_TYPE_STR;
                // SAFETY: `out` was zero-initialised, so `val_str` holds valid
                // `c_char`s; the unique reference is used only to overwrite them.
                let dst = unsafe { &mut out.__bindgen_anon_1.val_str };
                // Reserve the final slot for the NUL terminator.
                let capacity = dst.len().saturating_sub(1);
                let text = truncate_at_char_boundary(s, capacity);
                for (slot, &byte) in dst.iter_mut().zip(text.as_bytes()) {
                    // `c_char` is `i8` or `u8` depending on the target; this
                    // reinterprets the byte without changing its bits.
                    *slot = byte as _;
                }
                // `text.len() <= capacity < dst.len()`, so this index is in bounds.
                dst[text.len()] = 0;
            }
        }
        out
    }
}

/// Returns the longest prefix of `s` that is at most `max_len` bytes long and
/// ends on a UTF-8 character boundary.
fn truncate_at_char_boundary(s: &str, max_len: usize) -> &str {
    if s.len() <= max_len {
        return s;
    }
    // `is_char_boundary(0)` is always true, so a boundary is always found.
    let end = (0..=max_len)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    &s[..end]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the bytes of `val_str` up to (not including) the first NUL,
    /// panicking if the buffer is not NUL-terminated within its bounds.
    fn str_payload(c: &sys::llama_model_kv_override) -> Vec<u8> {
        // SAFETY: callers only pass overrides built from `ParamOverrideValue::Str`,
        // so `val_str` is the active union member.
        let buf = unsafe { &c.__bindgen_anon_1.val_str };
        let len = buf
            .iter()
            .position(|&b| b == 0)
            .expect("val_str must be NUL-terminated");
        buf[..len].iter().map(|&b| b as u8).collect()
    }

    #[test]
    fn bool_override() {
        let c = ParamOverrideValue::Bool(true).as_c();
        assert_eq!(
            c.tag,
            sys::llama_model_kv_override_type::LLAMA_KV_OVERRIDE_TYPE_BOOL
        );
        // SAFETY: tag is BOOL, so `val_bool` is the active member.
        assert!(unsafe { c.__bindgen_anon_1.val_bool });
    }

    #[test]
    fn float_override() {
        let c = ParamOverrideValue::Float(1.5).as_c();
        assert_eq!(
            c.tag,
            sys::llama_model_kv_override_type::LLAMA_KV_OVERRIDE_TYPE_FLOAT
        );
        // SAFETY: tag is FLOAT, so `val_f64` is the active member.
        assert_eq!(unsafe { c.__bindgen_anon_1.val_f64 }, 1.5);
    }

    #[test]
    fn int_override() {
        let c = ParamOverrideValue::Int(-42).as_c();
        assert_eq!(
            c.tag,
            sys::llama_model_kv_override_type::LLAMA_KV_OVERRIDE_TYPE_INT
        );
        // SAFETY: tag is INT, so `val_i64` is the active member.
        assert_eq!(unsafe { c.__bindgen_anon_1.val_i64 }, -42);
    }

    #[test]
    fn string_override_roundtrips_short_value() {
        let c = ParamOverrideValue::Str(String::from("hello")).as_c();
        assert_eq!(
            c.tag,
            sys::llama_model_kv_override_type::LLAMA_KV_OVERRIDE_TYPE_STR
        );
        assert_eq!(str_payload(&c), b"hello");
    }

    #[test]
    fn string_override_truncates() {
        let big = "x".repeat(200);
        let c = ParamOverrideValue::Str(big).as_c();
        assert_eq!(
            c.tag,
            sys::llama_model_kv_override_type::LLAMA_KV_OVERRIDE_TYPE_STR
        );
        let payload = str_payload(&c);
        // 128-byte buffer minus one slot for the NUL terminator.
        assert_eq!(payload.len(), 127);
        assert!(payload.iter().all(|&b| b == b'x'));
    }
}