catboost2 0.1.1+catboost.1.0.5

Unofficial bindings for the catboost library.
Documentation
#![doc = include_str!("../README.md")]
#![deny(
    missing_docs,
    missing_debug_implementations,
    rustdoc::broken_intra_doc_links,
    rustdoc::bare_urls,
    macro_use_extern_crate,
    non_ascii_idents,
    elided_lifetimes_in_paths
)]

use std::{
    convert::TryInto,
    ffi::{CStr, CString},
    fmt,
    path::Path,
};

#[cfg(unix)]
use std::os::unix::ffi::OsStrExt;
#[cfg(windows)]
use std::os::windows::ffi::OsStrExt;

/// Raw bindings to the catboost library.
#[doc(inline)]
pub use catboost2_sys as sys;

/// The result of a catboost operation.
///
/// Shorthand for [`std::result::Result`] with the error type fixed to [`Error`].
pub type Result<T> = std::result::Result<T, Error>;

/// A catboost error.
///
/// This is a thin wrapper around a human-readable error message.
/// There is no public constructor that accepts a message; the
/// [`Error::fetch_catboost_error`] method builds a value from the
/// library's last error string.
#[derive(Debug, Eq, PartialEq)]
#[repr(transparent)]
pub struct Error {
    /// The message text, written out unchanged by the `Display` implementation.
    description: String,
}

impl Error {
    /// Shorthand for checking the result of a FFI call.
    ///
    /// If the `ret_val` is `true`, the call was successfull and
    /// `Ok(val)` is returned. If the return value is `false`,
    /// the error message is fetched and returned.
    pub fn call<T>(ret_val: bool, val: T) -> Result<T> {
        if ret_val {
            Ok(val)
        } else {
            Err(Error::fetch_catboost_error())
        }
    }

    /// Fetch current error message from CatBoost
    /// using the [`sys::GetErrorString`] FFI call.
    pub fn fetch_catboost_error() -> Self {
        let c_str = unsafe { CStr::from_ptr(sys::GetErrorString()) };
        let str_slice = c_str
            .to_str()
            .expect("non-utf8 error message returned from catboost");
        Error {
            description: str_slice.to_owned(),
        }
    }
}

impl fmt::Display for Error {
    /// Writes the stored message unchanged; width and alignment flags on the
    /// formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.description)
    }
}

// `Error` holds only a message, so the trait's default methods are used as-is.
impl std::error::Error for Error {}

/// Safe wrapper around a [raw catboost model](sys::ModelCalcerHandle).
///
/// This is the central type of this library and is the entry
/// point for any use of this library.
///
/// Values are obtained through [`Model::load`] or [`Model::load_buffer`].
/// The underlying handle is released when the value is dropped.
#[derive(Debug)]
pub struct Model {
    /// Raw handle returned by `sys::ModelCalcerCreate` and handed to
    /// `sys::ModelCalcerDelete` in the `Drop` implementation.
    handle: *mut sys::ModelCalcerHandle,
}

impl Model {
    fn new() -> Self {
        let model_handle = unsafe { sys::ModelCalcerCreate() };
        Model {
            handle: model_handle,
        }
    }

    /// Load a model from a file.
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let model = Model::new();

        #[cfg(unix)]
        let path_c_str = CString::new(path.as_ref().as_os_str().as_bytes()).unwrap();
        #[cfg(windows)]
        let path_c_str =
            CString::new(path.as_ref().as_os_str().to_string_lossy().as_bytes()).unwrap();

        Error::call(
            unsafe { sys::LoadFullModelFromFile(model.handle, path_c_str.as_ptr()) },
            model,
        )
    }

    /// Load a model from a buffer.
    pub fn load_buffer<P: AsRef<Vec<u8>>>(buffer: P) -> Result<Self> {
        let model = Model::new();
        Error::call(
            unsafe {
                sys::LoadFullModelFromBuffer(
                    model.handle,
                    buffer.as_ref().as_ptr() as *const std::os::raw::c_void,
                    buffer.as_ref().len().try_into().unwrap(),
                )
            },
            model,
        )
    }

    /// Calculate raw model predictions on float features and string categorical feature values
    pub fn calc_model_prediction(
        &self,
        float_features: Vec<Vec<f32>>,
        cat_features: Vec<Vec<String>>,
    ) -> Result<Vec<f64>> {
        let mut float_features_ptr = float_features
            .iter()
            .map(|x| x.as_ptr())
            .collect::<Vec<_>>();

        let hashed_cat_features = cat_features
            .iter()
            .map(|doc_cat_features| {
                doc_cat_features
                    .iter()
                    .map(|cat_feature| unsafe {
                        sys::GetStringCatFeatureHash(
                            cat_feature.as_ptr() as *const std::os::raw::c_char,
                            cat_feature.len().try_into().unwrap(),
                        )
                    })
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();

        let mut hashed_cat_features_ptr = hashed_cat_features
            .iter()
            .map(|x| x.as_ptr())
            .collect::<Vec<_>>();

        let mut prediction = vec![0.0; float_features.len()];
        Error::call(
            unsafe {
                sys::CalcModelPredictionWithHashedCatFeatures(
                    self.handle,
                    float_features.len().try_into().unwrap(),
                    float_features_ptr.as_mut_ptr(),
                    float_features[0].len().try_into().unwrap(),
                    hashed_cat_features_ptr.as_mut_ptr(),
                    cat_features[0].len().try_into().unwrap(),
                    prediction.as_mut_ptr(),
                    prediction.len().try_into().unwrap(),
                )
            },
            prediction,
        )
    }

    /// Get expected float feature count for model
    pub fn get_float_features_count(&self) -> u64 {
        unsafe { sys::GetFloatFeaturesCount(self.handle) }
    }

    /// Get expected categorical feature count for model
    pub fn get_cat_features_count(&self) -> u64 {
        unsafe { sys::GetCatFeaturesCount(self.handle) }
    }

    /// Get number of trees in model
    pub fn get_tree_count(&self) -> u64 {
        unsafe { sys::GetTreeCount(self.handle) }
    }

    /// Get number of dimensions in model
    pub fn get_dimensions_count(&self) -> u64 {
        unsafe { sys::GetDimensionsCount(self.handle) }
    }
}

impl Drop for Model {
    /// Hands the wrapped handle to `sys::ModelCalcerDelete`.
    fn drop(&mut self) {
        let raw = self.handle;
        unsafe {
            sys::ModelCalcerDelete(raw);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Model file used by every test; the relative path is resolved against
    /// the working directory of the test process.
    const MODEL_PATH: &str = "model.bin";

    /// Loads the fixture model, panicking if loading fails.
    fn fixture() -> Model {
        Model::load(MODEL_PATH).unwrap()
    }

    // Loading straight from the file path succeeds.
    #[test]
    fn load_model() {
        assert!(Model::load(MODEL_PATH).is_ok());
    }

    // Loading from the file's bytes held in memory succeeds.
    #[test]
    fn load_model_buffer() {
        let bytes: Vec<u8> = fs::read(MODEL_PATH).unwrap();
        assert!(Model::load_buffer(bytes).is_ok());
    }

    // Three documents produce the three expected raw predictions.
    #[test]
    fn calc_prediction() {
        let model = fixture();

        let floats = vec![
            vec![-10.0, 5.0, 753.0],
            vec![30.0, 1.0, 760.0],
            vec![40.0, 0.1, 705.0],
        ];
        let cats: Vec<Vec<String>> = ["north", "south", "south"]
            .iter()
            .map(|label| vec![String::from(*label)])
            .collect();

        let prediction = model.calc_model_prediction(floats, cats).unwrap();

        let expected = [
            0.9980003729960197,
            0.00249414628534181,
            -0.0013677527881450977,
        ];
        for (row, want) in expected.iter().enumerate() {
            assert_eq!(prediction[row], *want);
        }
    }

    // The metadata getters report the fixture model's shape.
    #[test]
    fn get_model_stats() {
        let stats = fixture();

        assert_eq!(stats.get_cat_features_count(), 1);
        assert_eq!(stats.get_float_features_count(), 3);
        assert_eq!(stats.get_tree_count(), 1000);
        assert_eq!(stats.get_dimensions_count(), 1);
    }
}