catboost_rs/
model.rs

1use crate::error::{CatBoostError, CatBoostResult};
2use catboost_sys;
3use std::ffi::CString;
4use std::os::unix::ffi::OsStrExt;
5use std::path::Path;
6
/// Owning wrapper around a raw CatBoost model calcer handle.
///
/// The handle is obtained from `catboost_sys::ModelCalcerCreate` and released with
/// `catboost_sys::ModelCalcerDelete` when the value is dropped.
pub struct Model {
    /// Raw pointer handed to the `catboost_sys` functions that operate on the model.
    handle: *mut catboost_sys::ModelCalcerHandle,
}
10
11impl Model {
12    fn new() -> Self {
13        let model_handle = unsafe { catboost_sys::ModelCalcerCreate() };
14        Model {
15            handle: model_handle,
16        }
17    }
18
19    /// Load a model from a file
20    pub fn load<P: AsRef<Path>>(path: P) -> CatBoostResult<Self> {
21        let model = Model::new();
22        let path_c_str = CString::new(path.as_ref().as_os_str().as_bytes()).unwrap();
23        CatBoostError::check_return_value(unsafe {
24            catboost_sys::LoadFullModelFromFile(model.handle, path_c_str.as_ptr())
25        })?;
26        Ok(model)
27    }
28
29    /// Load a model from a buffer
30    pub fn load_buffer<P: AsRef<Vec<u8>>>(buffer: P) -> CatBoostResult<Self> {
31        let model = Model::new();
32        CatBoostError::check_return_value(unsafe {
33            catboost_sys::LoadFullModelFromBuffer(
34                model.handle,
35                buffer.as_ref().as_ptr() as *const std::os::raw::c_void,
36                buffer.as_ref().len(),
37            )
38        })?;
39        Ok(model)
40    }
41
42    /// Calculate raw model predictions on float features and string categorical feature values
43    pub fn calc_model_prediction(
44        &self,
45        float_features: Vec<Vec<f32>>,
46        cat_features: Vec<Vec<String>>,
47    ) -> CatBoostResult<Vec<f64>> {
48        let mut float_features_ptr = float_features
49            .iter()
50            .map(|x| x.as_ptr())
51            .collect::<Vec<_>>();
52
53        let hashed_cat_features = cat_features
54            .iter()
55            .map(|doc_cat_features| {
56                doc_cat_features
57                    .iter()
58                    .map(|cat_feature| unsafe {
59                        catboost_sys::GetStringCatFeatureHash(
60                            cat_feature.as_ptr() as *const std::os::raw::c_char,
61                            cat_feature.len(),
62                        )
63                    })
64                    .collect::<Vec<_>>()
65            })
66            .collect::<Vec<_>>();
67
68        let mut hashed_cat_features_ptr = hashed_cat_features
69            .iter()
70            .map(|x| x.as_ptr())
71            .collect::<Vec<_>>();
72
73        let mut prediction = vec![0.0; float_features.len()];
74        CatBoostError::check_return_value(unsafe {
75            catboost_sys::CalcModelPredictionWithHashedCatFeatures(
76                self.handle,
77                float_features.len(),
78                float_features_ptr.as_mut_ptr(),
79                float_features[0].len(),
80                hashed_cat_features_ptr.as_mut_ptr(),
81                cat_features[0].len(),
82                prediction.as_mut_ptr(),
83                prediction.len(),
84            )
85        })?;
86        Ok(prediction)
87    }
88
89    /// Apply sigmoid to get predict probability
90    // https://catboost.ai/en/docs/concepts/output-data_model-value-output#classification
91    pub fn calc_predict_proba(
92        &self,
93        float_features: Vec<Vec<f32>>,
94        cat_features: Vec<Vec<String>>,
95    ) -> CatBoostResult<Vec<f64>> {
96        let raw_results = self.calc_model_prediction(float_features, cat_features)?;
97        let probabilities = raw_results.into_iter().map(sigmoid).collect();
98        Ok(probabilities)
99    }
100
101    /// Get expected float feature count for model
102    pub fn get_float_features_count(&self) -> usize {
103        unsafe { catboost_sys::GetFloatFeaturesCount(self.handle) }
104    }
105
106    /// Get expected categorical feature count for model
107    pub fn get_cat_features_count(&self) -> usize {
108        unsafe { catboost_sys::GetCatFeaturesCount(self.handle) }
109    }
110
111    /// Get number of trees in model
112    pub fn get_tree_count(&self) -> usize {
113        unsafe { catboost_sys::GetTreeCount(self.handle) }
114    }
115
116    /// Get number of dimensions in model
117    pub fn get_dimensions_count(&self) -> usize {
118        unsafe { catboost_sys::GetDimensionsCount(self.handle) }
119    }
120}
121
122impl Drop for Model {
123    fn drop(&mut self) {
124        unsafe { catboost_sys::ModelCalcerDelete(self.handle) };
125    }
126}
127
// SAFETY (assumed, not verified in this file): `Model` stores a raw pointer, so `Send` and
// `Sync` are not implemented automatically. These impls rely on the upstream statement that
// the model calcer should be thread safe: https://github.com/catboost/catboost/issues/272
unsafe impl Send for Model {}

unsafe impl Sync for Model {}
132
/// Logistic function `1 / (1 + e^(-x))`, mapping a raw score to the open interval (0, 1)
/// (saturating to exactly 0 or 1 for extreme inputs).
fn sigmoid(x: f64) -> f64 {
    let decay = (-x).exp();
    let denominator = 1.0 + decay;
    1.0 / denominator
}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Model fixture used by every test, resolved relative to the working directory of the
    /// test run.
    const MODEL_PATH: &str = "files/model.bin";

    /// Loading the fixture from disk succeeds.
    #[test]
    fn load_model() {
        let model = Model::load(MODEL_PATH);
        assert!(model.is_ok());
    }

    /// Loading the same fixture from an in-memory buffer succeeds.
    #[test]
    fn load_model_buffer() {
        // `std::fs::read` reads the whole file into a `Vec<u8>`.
        let buffer: Vec<u8> = std::fs::read(MODEL_PATH).unwrap();
        let model = Model::load_buffer(buffer);
        assert!(model.is_ok());
    }

    /// Raw predictions for three documents match the recorded reference values.
    #[test]
    fn calc_prediction() {
        let model = Model::load(MODEL_PATH).unwrap();
        let prediction = model
            .calc_model_prediction(
                vec![
                    vec![-10.0, 5.0, 753.0],
                    vec![30.0, 1.0, 760.0],
                    vec![40.0, 0.1, 705.0],
                ],
                vec![
                    vec![String::from("north")],
                    vec![String::from("south")],
                    vec![String::from("south")],
                ],
            )
            .unwrap();

        assert_eq!(prediction[0], 0.9980003729960197);
        assert_eq!(prediction[1], 0.00249414628534181);
        assert_eq!(prediction[2], -0.0013677527881450977);
    }

    /// The metadata getters report the shape of the fixture model.
    #[test]
    fn get_model_stats() {
        let model = Model::load(MODEL_PATH).unwrap();

        assert_eq!(model.get_cat_features_count(), 1);
        assert_eq!(model.get_float_features_count(), 3);
        assert_eq!(model.get_tree_count(), 1000);
        assert_eq!(model.get_dimensions_count(), 1);
    }
}