#![doc = include_str!("../README.md")]
#![deny(
missing_docs,
missing_debug_implementations,
rustdoc::broken_intra_doc_links,
rustdoc::bare_urls,
macro_use_extern_crate,
non_ascii_idents,
elided_lifetimes_in_paths
)]
use std::{
convert::TryInto,
ffi::{CStr, CString},
fmt,
path::Path,
};
#[cfg(unix)]
use std::os::unix::ffi::OsStrExt;
#[cfg(windows)]
use std::os::windows::ffi::OsStrExt;
#[doc(inline)]
pub use catboost2_sys as sys;
/// Convenience alias for results whose error type is this crate's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// Error carrying a textual description, as produced by
/// `Error::fetch_catboost_error` from the CatBoost error string.
#[derive(Debug, Eq, PartialEq)]
#[repr(transparent)]
pub struct Error {
// Human-readable message; this is the text written out by the `Display` impl.
description: String,
}
impl Error {
pub fn call<T>(ret_val: bool, val: T) -> Result<T> {
if ret_val {
Ok(val)
} else {
Err(Error::fetch_catboost_error())
}
}
pub fn fetch_catboost_error() -> Self {
let c_str = unsafe { CStr::from_ptr(sys::GetErrorString()) };
let str_slice = c_str
.to_str()
.expect("non-utf8 error message returned from catboost");
Error {
description: str_slice.to_owned(),
}
}
}
impl fmt::Display for Error {
    /// Writes the stored description verbatim, without applying any formatter flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.description.as_str())
    }
}
// Empty impl: every `std::error::Error` method keeps its default implementation.
impl std::error::Error for Error {}
/// Owning wrapper around a CatBoost model calcer handle.
#[derive(Debug)]
pub struct Model {
// Raw handle obtained from `sys::ModelCalcerCreate`; released with
// `sys::ModelCalcerDelete` when the `Model` is dropped.
handle: *mut sys::ModelCalcerHandle,
}
impl Model {
fn new() -> Self {
let model_handle = unsafe { sys::ModelCalcerCreate() };
Model {
handle: model_handle,
}
}
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
let model = Model::new();
#[cfg(unix)]
let path_c_str = CString::new(path.as_ref().as_os_str().as_bytes()).unwrap();
#[cfg(windows)]
let path_c_str =
CString::new(path.as_ref().as_os_str().to_string_lossy().as_bytes()).unwrap();
Error::call(
unsafe { sys::LoadFullModelFromFile(model.handle, path_c_str.as_ptr()) },
model,
)
}
pub fn load_buffer<P: AsRef<Vec<u8>>>(buffer: P) -> Result<Self> {
let model = Model::new();
Error::call(
unsafe {
sys::LoadFullModelFromBuffer(
model.handle,
buffer.as_ref().as_ptr() as *const std::os::raw::c_void,
buffer.as_ref().len().try_into().unwrap(),
)
},
model,
)
}
pub fn calc_model_prediction(
&self,
float_features: Vec<Vec<f32>>,
cat_features: Vec<Vec<String>>,
) -> Result<Vec<f64>> {
let mut float_features_ptr = float_features
.iter()
.map(|x| x.as_ptr())
.collect::<Vec<_>>();
let hashed_cat_features = cat_features
.iter()
.map(|doc_cat_features| {
doc_cat_features
.iter()
.map(|cat_feature| unsafe {
sys::GetStringCatFeatureHash(
cat_feature.as_ptr() as *const std::os::raw::c_char,
cat_feature.len().try_into().unwrap(),
)
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
let mut hashed_cat_features_ptr = hashed_cat_features
.iter()
.map(|x| x.as_ptr())
.collect::<Vec<_>>();
let mut prediction = vec![0.0; float_features.len()];
Error::call(
unsafe {
sys::CalcModelPredictionWithHashedCatFeatures(
self.handle,
float_features.len().try_into().unwrap(),
float_features_ptr.as_mut_ptr(),
float_features[0].len().try_into().unwrap(),
hashed_cat_features_ptr.as_mut_ptr(),
cat_features[0].len().try_into().unwrap(),
prediction.as_mut_ptr(),
prediction.len().try_into().unwrap(),
)
},
prediction,
)
}
pub fn get_float_features_count(&self) -> u64 {
unsafe { sys::GetFloatFeaturesCount(self.handle) }
}
pub fn get_cat_features_count(&self) -> u64 {
unsafe { sys::GetCatFeaturesCount(self.handle) }
}
pub fn get_tree_count(&self) -> u64 {
unsafe { sys::GetTreeCount(self.handle) }
}
pub fn get_dimensions_count(&self) -> u64 {
unsafe { sys::GetDimensionsCount(self.handle) }
}
}
impl Drop for Model {
    /// Releases the underlying calcer handle through `ModelCalcerDelete`.
    fn drop(&mut self) {
        let handle = self.handle;
        // SAFETY: `handle` was obtained from `ModelCalcerCreate` in `Model::new`
        // and is released exactly once, here.
        unsafe {
            sys::ModelCalcerDelete(handle);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::read;

    /// Loading the fixture model from its path succeeds.
    #[test]
    fn load_model() {
        assert!(Model::load("model.bin").is_ok());
    }

    /// Loading the fixture model from an in-memory copy of the file succeeds.
    #[test]
    fn load_model_buffer() {
        let buffer: Vec<u8> = read("model.bin").unwrap();
        assert!(Model::load_buffer(buffer).is_ok());
    }

    /// Predictions for three documents match the reference values exactly.
    #[test]
    fn calc_prediction() {
        let model = Model::load("model.bin").unwrap();

        let float_features = vec![
            vec![-10.0, 5.0, 753.0],
            vec![30.0, 1.0, 760.0],
            vec![40.0, 0.1, 705.0],
        ];
        let cat_features = vec![
            vec!["north".to_string()],
            vec!["south".to_string()],
            vec!["south".to_string()],
        ];

        let prediction = model
            .calc_model_prediction(float_features, cat_features)
            .unwrap();

        assert_eq!(prediction[0], 0.9980003729960197);
        assert_eq!(prediction[1], 0.00249414628534181);
        assert_eq!(prediction[2], -0.0013677527881450977);
    }

    /// The fixture model reports the expected feature, tree and dimension counts.
    #[test]
    fn get_model_stats() {
        let model = Model::load("model.bin").unwrap();

        let cat_count = model.get_cat_features_count();
        let float_count = model.get_float_features_count();
        let tree_count = model.get_tree_count();
        let dimensions = model.get_dimensions_count();

        assert_eq!(cat_count, 1);
        assert_eq!(float_count, 3);
        assert_eq!(tree_count, 1000);
        assert_eq!(dimensions, 1);
    }
}