llama-cpp-bindings 0.4.2

llama.cpp bindings for Rust
Documentation
/// An override value for a model parameter.
///
/// Each variant pairs with one `LLAMA_KV_OVERRIDE_TYPE_*` tag and with one
/// member of the FFI override union (`val_bool`, `val_f64`, `val_i64`,
/// `val_str`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ParamOverrideValue {
    /// A boolean value.
    Bool(bool),
    /// A floating-point value.
    Float(f64),
    /// An integer value.
    Int(i64),
    /// A string value, held in a fixed-size buffer of 128 `c_char` elements.
    ///
    /// NOTE(review): presumably the buffer is expected to be NUL-terminated
    /// by the consumer; confirm against the llama.cpp side.
    Str([std::os::raw::c_char; 128]),
}

impl ParamOverrideValue {
    /// Maps this variant to the matching `LLAMA_KV_OVERRIDE_TYPE_*` constant.
    ///
    /// The payload is ignored; only the variant decides the result.
    #[must_use]
    pub const fn tag(&self) -> llama_cpp_bindings_sys::llama_model_kv_override_type {
        match *self {
            Self::Str(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR,
            Self::Int(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT,
            Self::Float(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
            Self::Bool(_) => llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL,
        }
    }

    /// Builds the FFI union holding a copy of this variant's payload.
    ///
    /// Only the union member that belongs to the variant is initialised, so
    /// the result should be read through the member selected by [`Self::tag`].
    #[must_use]
    pub const fn value(&self) -> llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
        // Short local name for the bindgen-generated union type.
        type FfiValue = llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1;

        match *self {
            Self::Str(buffer) => FfiValue { val_str: buffer },
            Self::Int(integer) => FfiValue { val_i64: integer },
            Self::Float(float) => FfiValue { val_f64: float },
            Self::Bool(flag) => FfiValue { val_bool: flag },
        }
    }
}

/// Unknown KV override tag from the FFI layer.
///
/// Produced when converting an FFI override into a [`ParamOverrideValue`] and
/// the override's tag matches none of the handled `LLAMA_KV_OVERRIDE_TYPE_*`
/// constants. The wrapped field is the raw tag value that was encountered.
#[derive(Debug, thiserror::Error)]
#[error("unknown KV override tag: {0}")]
pub struct UnknownKvOverrideTag(pub llama_cpp_bindings_sys::llama_model_kv_override_type);

impl TryFrom<&llama_cpp_bindings_sys::llama_model_kv_override> for ParamOverrideValue {
    type Error = UnknownKvOverrideTag;

    /// Decodes an FFI override into the typed Rust representation.
    ///
    /// The override's `tag` selects which union member is read; its `key` is
    /// ignored.
    ///
    /// # Errors
    ///
    /// Returns [`UnknownKvOverrideTag`] carrying the raw tag when it matches
    /// none of the `LLAMA_KV_OVERRIDE_TYPE_*` constants handled here.
    fn try_from(
        ffi_override: &llama_cpp_bindings_sys::llama_model_kv_override,
    ) -> Result<Self, Self::Error> {
        // Every field is named in the pattern, so a field added to the FFI
        // struct surfaces here as a compile error.
        let llama_cpp_bindings_sys::llama_model_kv_override {
            key: _,
            tag,
            __bindgen_anon_1: payload,
        } = ffi_override;

        // SAFETY (all union reads below): each arm reads only the member that
        // corresponds to the tag it matched. This relies on the producer of
        // the override having written that same member - an assumption about
        // the FFI side that cannot be checked here.
        let decoded = match *tag {
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL => {
                Self::Bool(unsafe { payload.val_bool })
            }
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT => {
                Self::Float(unsafe { payload.val_f64 })
            }
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT => {
                Self::Int(unsafe { payload.val_i64 })
            }
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR => {
                Self::Str(unsafe { payload.val_str })
            }
            unknown_tag => return Err(UnknownKvOverrideTag(unknown_tag)),
        };

        Ok(decoded)
    }
}

#[cfg(test)]
mod tests {
    use std::os::raw::c_char;

    use super::ParamOverrideValue;

    /// Builds a zero-filled 128-element `c_char` buffer that starts with `text`.
    ///
    /// `c_char` is `i8` on some targets and `u8` on others, so the bytes are
    /// converted with `from_ne_bytes` (available on both integer types)
    /// instead of hard-coding `i8`, which would not compile where `c_char`
    /// is unsigned. Bytes beyond the buffer length are ignored.
    fn str_buffer(text: &[u8]) -> [c_char; 128] {
        let mut buffer: [c_char; 128] = [0; 128];

        for (slot, byte) in buffer.iter_mut().zip(text) {
            *slot = c_char::from_ne_bytes([*byte]);
        }

        buffer
    }

    #[test]
    fn tag_bool() {
        let value = ParamOverrideValue::Bool(true);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL
        );
    }

    #[test]
    fn tag_float() {
        let value = ParamOverrideValue::Float(1.23);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT
        );
    }

    #[test]
    fn tag_int() {
        let value = ParamOverrideValue::Int(42);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT
        );
    }

    #[test]
    fn tag_str() {
        let value = ParamOverrideValue::Str([0; 128]);

        assert_eq!(
            value.tag(),
            llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR
        );
    }

    #[test]
    fn value_bool_roundtrip() {
        let value = ParamOverrideValue::Bool(true);
        let ffi_value = value.value();
        // SAFETY: `value()` initialised `val_bool` for the `Bool` variant.
        let result = unsafe { ffi_value.val_bool };

        assert!(result);
    }

    #[test]
    fn value_float_roundtrip() {
        let value = ParamOverrideValue::Float(1.23);
        let ffi_value = value.value();
        // SAFETY: `value()` initialised `val_f64` for the `Float` variant.
        let result = unsafe { ffi_value.val_f64 };

        assert!((result - 1.23).abs() < f64::EPSILON);
    }

    #[test]
    fn value_int_roundtrip() {
        let value = ParamOverrideValue::Int(99);
        let ffi_value = value.value();
        // SAFETY: `value()` initialised `val_i64` for the `Int` variant.
        let result = unsafe { ffi_value.val_i64 };

        assert_eq!(result, 99);
    }

    #[test]
    fn value_str_roundtrip() {
        let str_data = str_buffer(b"hi");

        let value = ParamOverrideValue::Str(str_data);
        let ffi_value = value.value();
        // SAFETY: `value()` initialised `val_str` for the `Str` variant.
        let result = unsafe { ffi_value.val_str };

        // Compare the whole buffer, not just the leading characters.
        assert_eq!(result, str_data);
    }

    #[test]
    fn from_ffi_override_int() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_INT,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_i64: 123,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Int(123));
    }

    #[test]
    fn from_ffi_override_float() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_FLOAT,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_f64: 1.5,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Float(1.5));
    }

    #[test]
    fn from_ffi_override_bool() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_BOOL,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_bool: false,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Bool(false));
    }

    #[test]
    fn from_ffi_override_str() {
        let str_data = str_buffer(b"ab");

        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: llama_cpp_bindings_sys::LLAMA_KV_OVERRIDE_TYPE_STR,
            __bindgen_anon_1: llama_cpp_bindings_sys::llama_model_kv_override__bindgen_ty_1 {
                val_str: str_data,
            },
        };

        let value = ParamOverrideValue::try_from(&ffi_override).unwrap();

        assert_eq!(value, ParamOverrideValue::Str(str_data));
    }

    #[test]
    fn unknown_tag_returns_error() {
        let ffi_override = llama_cpp_bindings_sys::llama_model_kv_override {
            key: [0; 128],
            tag: 9999,
            // SAFETY: the union members used in this file are `bool`, `f64`,
            // `i64` and a `c_char` array, all of which accept an all-zero bit
            // pattern. The payload is never read for an unknown tag anyway.
            __bindgen_anon_1: unsafe { std::mem::zeroed() },
        };

        let result = ParamOverrideValue::try_from(&ffi_override);

        // The error must carry the offending tag, not merely exist.
        assert!(matches!(result, Err(super::UnknownKvOverrideTag(9999))));
    }
}