1use crate::error::{CatBoostError, CatBoostResult};
2use catboost_sys;
3use std::ffi::CString;
4use std::os::unix::ffi::OsStrExt;
5use std::path::Path;
6
7pub struct Model {
8 handle: *mut catboost_sys::ModelCalcerHandle,
9}
10
11impl Model {
12 fn new() -> Self {
13 let model_handle = unsafe { catboost_sys::ModelCalcerCreate() };
14 Model {
15 handle: model_handle,
16 }
17 }
18
19 pub fn load<P: AsRef<Path>>(path: P) -> CatBoostResult<Self> {
21 let model = Model::new();
22 let path_c_str = CString::new(path.as_ref().as_os_str().as_bytes()).unwrap();
23 CatBoostError::check_return_value(unsafe {
24 catboost_sys::LoadFullModelFromFile(model.handle, path_c_str.as_ptr())
25 })?;
26 Ok(model)
27 }
28
29 pub fn load_buffer<P: AsRef<Vec<u8>>>(buffer: P) -> CatBoostResult<Self> {
31 let model = Model::new();
32 CatBoostError::check_return_value(unsafe {
33 catboost_sys::LoadFullModelFromBuffer(
34 model.handle,
35 buffer.as_ref().as_ptr() as *const std::os::raw::c_void,
36 buffer.as_ref().len(),
37 )
38 })?;
39 Ok(model)
40 }
41
42 pub fn calc_model_prediction(
44 &self,
45 float_features: Vec<Vec<f32>>,
46 cat_features: Vec<Vec<String>>,
47 ) -> CatBoostResult<Vec<f64>> {
48 let mut float_features_ptr = float_features
49 .iter()
50 .map(|x| x.as_ptr())
51 .collect::<Vec<_>>();
52
53 let hashed_cat_features = cat_features
54 .iter()
55 .map(|doc_cat_features| {
56 doc_cat_features
57 .iter()
58 .map(|cat_feature| unsafe {
59 catboost_sys::GetStringCatFeatureHash(
60 cat_feature.as_ptr() as *const std::os::raw::c_char,
61 cat_feature.len(),
62 )
63 })
64 .collect::<Vec<_>>()
65 })
66 .collect::<Vec<_>>();
67
68 let mut hashed_cat_features_ptr = hashed_cat_features
69 .iter()
70 .map(|x| x.as_ptr())
71 .collect::<Vec<_>>();
72
73 let mut prediction = vec![0.0; float_features.len()];
74 CatBoostError::check_return_value(unsafe {
75 catboost_sys::CalcModelPredictionWithHashedCatFeatures(
76 self.handle,
77 float_features.len(),
78 float_features_ptr.as_mut_ptr(),
79 float_features[0].len(),
80 hashed_cat_features_ptr.as_mut_ptr(),
81 cat_features[0].len(),
82 prediction.as_mut_ptr(),
83 prediction.len(),
84 )
85 })?;
86 Ok(prediction)
87 }
88
89 pub fn calc_predict_proba(
92 &self,
93 float_features: Vec<Vec<f32>>,
94 cat_features: Vec<Vec<String>>,
95 ) -> CatBoostResult<Vec<f64>> {
96 let raw_results = self.calc_model_prediction(float_features, cat_features)?;
97 let probabilities = raw_results.into_iter().map(sigmoid).collect();
98 Ok(probabilities)
99 }
100
101 pub fn get_float_features_count(&self) -> usize {
103 unsafe { catboost_sys::GetFloatFeaturesCount(self.handle) }
104 }
105
106 pub fn get_cat_features_count(&self) -> usize {
108 unsafe { catboost_sys::GetCatFeaturesCount(self.handle) }
109 }
110
111 pub fn get_tree_count(&self) -> usize {
113 unsafe { catboost_sys::GetTreeCount(self.handle) }
114 }
115
116 pub fn get_dimensions_count(&self) -> usize {
118 unsafe { catboost_sys::GetDimensionsCount(self.handle) }
119 }
120}
121
122impl Drop for Model {
123 fn drop(&mut self) {
124 unsafe { catboost_sys::ModelCalcerDelete(self.handle) };
125 }
126}
127
// `Model` stores a raw pointer, so the compiler does not derive `Send` for it; this impl
// opts in manually.
// NOTE(review): this assumes a model handle may be moved to and used from another
// thread — confirm against the CatBoost C API documentation.
unsafe impl Send for Model {}

// Opts `Model` into `Sync`, allowing `&Model` (and thus every `&self` method, all of
// which pass the handle to the C library) to be used from several threads at once.
// NOTE(review): this assumes the C API calls made through `&self` are safe to run
// concurrently on the same handle — confirm.
unsafe impl Sync for Model {}
132
/// Logistic sigmoid: maps `x` to `1 / (1 + e^(-x))`.
fn sigmoid(x: f64) -> f64 {
    let decay = (-x).exp();
    (1.0 + decay).recip()
}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the whole file at `path` into memory.
    ///
    /// Delegates to `std::fs::read` instead of sizing and filling a buffer by hand.
    fn read_fast<P: AsRef<std::path::Path>>(path: P) -> std::io::Result<Vec<u8>> {
        std::fs::read(path)
    }

    /// A model stored on disk can be loaded by path.
    #[test]
    fn load_model() {
        let model = Model::load("files/model.bin");
        assert!(model.is_ok());
    }

    /// A model can be loaded from bytes that were read into memory first.
    #[test]
    fn load_model_buffer() {
        let buffer: Vec<u8> = read_fast("files/model.bin").unwrap();
        let model = Model::load_buffer(buffer);
        assert!(model.is_ok());
    }

    /// Raw predictions for three documents match the values pinned for the fixture model.
    #[test]
    fn calc_prediction() {
        let model = Model::load("files/model.bin").unwrap();
        let prediction = model
            .calc_model_prediction(
                vec![
                    vec![-10.0, 5.0, 753.0],
                    vec![30.0, 1.0, 760.0],
                    vec![40.0, 0.1, 705.0],
                ],
                vec![
                    vec![String::from("north")],
                    vec![String::from("south")],
                    vec![String::from("south")],
                ],
            )
            .unwrap();

        assert_eq!(prediction[0], 0.9980003729960197);
        assert_eq!(prediction[1], 0.00249414628534181);
        assert_eq!(prediction[2], -0.0013677527881450977);
    }

    /// The metadata getters report the shape of the fixture model.
    #[test]
    fn get_model_stats() {
        let model = Model::load("files/model.bin").unwrap();

        assert_eq!(model.get_cat_features_count(), 1);
        assert_eq!(model.get_float_features_count(), 3);
        assert_eq!(model.get_tree_count(), 1000);
        assert_eq!(model.get_dimensions_count(), 1);
    }
}