catboost_portable/
lib.rs

1#![doc = include_str!("../README.md")]
2#![deny(
3    missing_docs,
4    missing_debug_implementations,
5    rustdoc::broken_intra_doc_links,
6    rustdoc::bare_urls,
7    macro_use_extern_crate,
8    non_ascii_idents,
9    elided_lifetimes_in_paths
10)]
11
12use std::{
13    convert::TryInto,
14    ffi::{CStr, CString},
15    fmt,
16    path::Path,
17};
18
19#[cfg(unix)]
20use std::os::unix::ffi::OsStrExt;
21#[cfg(windows)]
22use std::os::windows::ffi::OsStrExt;
23
24/// Raw bindings to the catboost library.
25#[doc(inline)]
26pub use catboost_portable_sys as sys;
27
/// The result of a catboost operation.
pub type Result<T> = std::result::Result<T, Error>;

/// A catboost error.
///
/// This is a thin wrapper around an error message. Outside this crate it
/// cannot be built directly; it is normally produced by reading CatBoost's
/// last error message with the [`Error::fetch_catboost_error`] method.
///
/// The type is cheap to clone and hashable, so results carrying it can be
/// cloned or errors collected into sets and maps.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[repr(transparent)]
pub struct Error {
    /// Human-readable message describing what went wrong.
    description: String,
}
41
42impl Error {
43    /// Shorthand for checking the result of a FFI call.
44    ///
45    /// If the `ret_val` is `true`, the call was successfull and
46    /// `Ok(val)` is returned. If the return value is `false`,
47    /// the error message is fetched and returned.
48    pub fn call<T>(ret_val: bool, val: T) -> Result<T> {
49        if ret_val {
50            Ok(val)
51        } else {
52            Err(Error::fetch_catboost_error())
53        }
54    }
55
56    /// Fetch current error message from CatBoost
57    /// using the [`sys::GetErrorString`] FFI call.
58    pub fn fetch_catboost_error() -> Self {
59        let c_str = unsafe { CStr::from_ptr(sys::GetErrorString()) };
60        let str_slice = c_str
61            .to_str()
62            .expect("non-utf8 error message returned from catboost");
63        Error {
64            description: str_slice.to_owned(),
65        }
66    }
67}
68
69impl fmt::Display for Error {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        write!(f, "{}", self.description)
72    }
73}
74
75impl std::error::Error for Error {}
76
77/// Safe wrapper around a [raw catboost model](sys::ModelCalcerHandle).
78///
79/// This is the central type of this library and is the entry
80/// point for any use of this library.
81#[derive(Debug)]
82pub struct Model {
83    handle: *mut sys::ModelCalcerHandle,
84}
85
86impl Model {
87    fn new() -> Self {
88        let model_handle = unsafe { sys::ModelCalcerCreate() };
89        Model {
90            handle: model_handle,
91        }
92    }
93
94    /// Load a model from a file.
95    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
96        let model = Model::new();
97
98        #[cfg(unix)]
99        let path_c_str = CString::new(path.as_ref().as_os_str().as_bytes()).unwrap();
100        #[cfg(windows)]
101        let path_c_str =
102            CString::new(path.as_ref().as_os_str().to_string_lossy().as_bytes()).unwrap();
103
104        Error::call(
105            unsafe { sys::LoadFullModelFromFile(model.handle, path_c_str.as_ptr()) },
106            model,
107        )
108    }
109
110    /// Load a model from a buffer.
111    pub fn load_buffer<P: AsRef<Vec<u8>>>(buffer: P) -> Result<Self> {
112        let model = Model::new();
113        Error::call(
114            unsafe {
115                sys::LoadFullModelFromBuffer(
116                    model.handle,
117                    buffer.as_ref().as_ptr() as *const std::os::raw::c_void,
118                    buffer.as_ref().len().try_into().unwrap(),
119                )
120            },
121            model,
122        )
123    }
124
125    /// Calculate raw model predictions on float features and string categorical feature values
126    pub fn calc_model_prediction(
127        &self,
128        float_features: Vec<Vec<f32>>,
129        cat_features: Vec<Vec<String>>,
130    ) -> Result<Vec<f64>> {
131        let mut float_features_ptr = float_features
132            .iter()
133            .map(|x| x.as_ptr())
134            .collect::<Vec<_>>();
135
136        let hashed_cat_features = cat_features
137            .iter()
138            .map(|doc_cat_features| {
139                doc_cat_features
140                    .iter()
141                    .map(|cat_feature| unsafe {
142                        sys::GetStringCatFeatureHash(
143                            cat_feature.as_ptr() as *const std::os::raw::c_char,
144                            cat_feature.len().try_into().unwrap(),
145                        )
146                    })
147                    .collect::<Vec<_>>()
148            })
149            .collect::<Vec<_>>();
150
151        let mut hashed_cat_features_ptr = hashed_cat_features
152            .iter()
153            .map(|x| x.as_ptr())
154            .collect::<Vec<_>>();
155
156        let mut prediction = vec![0.0; float_features.len()];
157        Error::call(
158            unsafe {
159                sys::CalcModelPredictionWithHashedCatFeatures(
160                    self.handle,
161                    float_features.len().try_into().unwrap(),
162                    float_features_ptr.as_mut_ptr(),
163                    float_features[0].len().try_into().unwrap(),
164                    hashed_cat_features_ptr.as_mut_ptr(),
165                    cat_features[0].len().try_into().unwrap(),
166                    prediction.as_mut_ptr(),
167                    prediction.len().try_into().unwrap(),
168                )
169            },
170            prediction,
171        )
172    }
173
174    /// Get expected float feature count for model
175    pub fn get_float_features_count(&self) -> u64 {
176        unsafe { sys::GetFloatFeaturesCount(self.handle) }
177    }
178
179    /// Get expected categorical feature count for model
180    pub fn get_cat_features_count(&self) -> u64 {
181        unsafe { sys::GetCatFeaturesCount(self.handle) }
182    }
183
184    /// Get number of trees in model
185    pub fn get_tree_count(&self) -> u64 {
186        unsafe { sys::GetTreeCount(self.handle) }
187    }
188
189    /// Get number of dimensions in model
190    pub fn get_dimensions_count(&self) -> u64 {
191        unsafe { sys::GetDimensionsCount(self.handle) }
192    }
193}
194
195impl Drop for Model {
196    fn drop(&mut self) {
197        unsafe { sys::ModelCalcerDelete(self.handle) };
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Serialized fixture model shared by every test in this module.
    const MODEL_PATH: &str = "model.bin";

    #[test]
    fn load_model() {
        assert!(Model::load(MODEL_PATH).is_ok());
    }

    #[test]
    fn load_model_buffer() {
        let bytes: Vec<u8> = fs::read(MODEL_PATH).unwrap();
        assert!(Model::load_buffer(bytes).is_ok());
    }

    #[test]
    fn calc_prediction() {
        let model = Model::load(MODEL_PATH).unwrap();

        let floats = vec![
            vec![-10.0, 5.0, 753.0],
            vec![30.0, 1.0, 760.0],
            vec![40.0, 0.1, 705.0],
        ];
        let cats = ["north", "south", "south"]
            .iter()
            .map(|&value| vec![String::from(value)])
            .collect::<Vec<_>>();

        let prediction = model.calc_model_prediction(floats, cats).unwrap();

        let expected: [f64; 3] = [
            0.9980003729960197,
            0.00249414628534181,
            -0.0013677527881450977,
        ];
        for (index, want) in expected.iter().enumerate() {
            assert_eq!(prediction[index], *want);
        }
    }

    #[test]
    fn get_model_stats() {
        let model = Model::load(MODEL_PATH).unwrap();

        let stats = (
            model.get_cat_features_count(),
            model.get_float_features_count(),
            model.get_tree_count(),
            model.get_dimensions_count(),
        );
        assert_eq!(stats, (1, 3, 1000, 1));
    }
}
251}