1#![doc = include_str!("../README.md")]
2#![deny(
3 missing_docs,
4 missing_debug_implementations,
5 rustdoc::broken_intra_doc_links,
6 rustdoc::bare_urls,
7 macro_use_extern_crate,
8 non_ascii_idents,
9 elided_lifetimes_in_paths
10)]
11
12use std::{
13 convert::TryInto,
14 ffi::{CStr, CString},
15 fmt,
16 path::Path,
17};
18
19#[cfg(unix)]
20use std::os::unix::ffi::OsStrExt;
21#[cfg(windows)]
22use std::os::windows::ffi::OsStrExt;
23
24#[doc(inline)]
26pub use catboost_portable_sys as sys;
27
/// Convenience alias for [`std::result::Result`] with this crate's [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
30
/// Error returned by the fallible operations of this crate.
///
/// Carries a human-readable description; for failed CatBoost C API calls this
/// is the message obtained from the library's last-error string.
#[derive(Debug, Eq, PartialEq)]
#[repr(transparent)]
pub struct Error {
    /// Human-readable message shown by the `Display` implementation.
    description: String,
}
41
42impl Error {
43 pub fn call<T>(ret_val: bool, val: T) -> Result<T> {
49 if ret_val {
50 Ok(val)
51 } else {
52 Err(Error::fetch_catboost_error())
53 }
54 }
55
56 pub fn fetch_catboost_error() -> Self {
59 let c_str = unsafe { CStr::from_ptr(sys::GetErrorString()) };
60 let str_slice = c_str
61 .to_str()
62 .expect("non-utf8 error message returned from catboost");
63 Error {
64 description: str_slice.to_owned(),
65 }
66 }
67}
68
69impl fmt::Display for Error {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 write!(f, "{}", self.description)
72 }
73}
74
75impl std::error::Error for Error {}
76
/// A CatBoost model loaded through the C API.
///
/// Owns a raw `ModelCalcerHandle`; the handle is released with
/// `sys::ModelCalcerDelete` when the `Model` is dropped. Because it stores a raw
/// pointer, `Model` is neither `Send` nor `Sync`.
#[derive(Debug)]
pub struct Model {
    /// Handle passed to every `sys` call made on this model; owned by this struct.
    handle: *mut sys::ModelCalcerHandle,
}
85
86impl Model {
87 fn new() -> Self {
88 let model_handle = unsafe { sys::ModelCalcerCreate() };
89 Model {
90 handle: model_handle,
91 }
92 }
93
94 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
96 let model = Model::new();
97
98 #[cfg(unix)]
99 let path_c_str = CString::new(path.as_ref().as_os_str().as_bytes()).unwrap();
100 #[cfg(windows)]
101 let path_c_str =
102 CString::new(path.as_ref().as_os_str().to_string_lossy().as_bytes()).unwrap();
103
104 Error::call(
105 unsafe { sys::LoadFullModelFromFile(model.handle, path_c_str.as_ptr()) },
106 model,
107 )
108 }
109
110 pub fn load_buffer<P: AsRef<Vec<u8>>>(buffer: P) -> Result<Self> {
112 let model = Model::new();
113 Error::call(
114 unsafe {
115 sys::LoadFullModelFromBuffer(
116 model.handle,
117 buffer.as_ref().as_ptr() as *const std::os::raw::c_void,
118 buffer.as_ref().len().try_into().unwrap(),
119 )
120 },
121 model,
122 )
123 }
124
125 pub fn calc_model_prediction(
127 &self,
128 float_features: Vec<Vec<f32>>,
129 cat_features: Vec<Vec<String>>,
130 ) -> Result<Vec<f64>> {
131 let mut float_features_ptr = float_features
132 .iter()
133 .map(|x| x.as_ptr())
134 .collect::<Vec<_>>();
135
136 let hashed_cat_features = cat_features
137 .iter()
138 .map(|doc_cat_features| {
139 doc_cat_features
140 .iter()
141 .map(|cat_feature| unsafe {
142 sys::GetStringCatFeatureHash(
143 cat_feature.as_ptr() as *const std::os::raw::c_char,
144 cat_feature.len().try_into().unwrap(),
145 )
146 })
147 .collect::<Vec<_>>()
148 })
149 .collect::<Vec<_>>();
150
151 let mut hashed_cat_features_ptr = hashed_cat_features
152 .iter()
153 .map(|x| x.as_ptr())
154 .collect::<Vec<_>>();
155
156 let mut prediction = vec![0.0; float_features.len()];
157 Error::call(
158 unsafe {
159 sys::CalcModelPredictionWithHashedCatFeatures(
160 self.handle,
161 float_features.len().try_into().unwrap(),
162 float_features_ptr.as_mut_ptr(),
163 float_features[0].len().try_into().unwrap(),
164 hashed_cat_features_ptr.as_mut_ptr(),
165 cat_features[0].len().try_into().unwrap(),
166 prediction.as_mut_ptr(),
167 prediction.len().try_into().unwrap(),
168 )
169 },
170 prediction,
171 )
172 }
173
174 pub fn get_float_features_count(&self) -> u64 {
176 unsafe { sys::GetFloatFeaturesCount(self.handle) }
177 }
178
179 pub fn get_cat_features_count(&self) -> u64 {
181 unsafe { sys::GetCatFeaturesCount(self.handle) }
182 }
183
184 pub fn get_tree_count(&self) -> u64 {
186 unsafe { sys::GetTreeCount(self.handle) }
187 }
188
189 pub fn get_dimensions_count(&self) -> u64 {
191 unsafe { sys::GetDimensionsCount(self.handle) }
192 }
193}
194
impl Drop for Model {
    /// Releases the underlying CatBoost model handle.
    fn drop(&mut self) {
        // SAFETY: `handle` is only ever set by `Model::new` (from `ModelCalcerCreate`),
        // and `drop` runs at most once, so the handle is deleted exactly once.
        unsafe { sys::ModelCalcerDelete(self.handle) };
    }
}
200
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::read;

    /// Fixture model shared by every test; resolved relative to the working directory.
    const MODEL_PATH: &str = "model.bin";

    #[test]
    fn load_model() {
        assert!(Model::load(MODEL_PATH).is_ok());
    }

    #[test]
    fn load_model_buffer() {
        let bytes: Vec<u8> = read(MODEL_PATH).unwrap();
        assert!(Model::load_buffer(bytes).is_ok());
    }

    #[test]
    fn calc_prediction() {
        let model = Model::load(MODEL_PATH).unwrap();

        let float_features = vec![
            vec![-10.0, 5.0, 753.0],
            vec![30.0, 1.0, 760.0],
            vec![40.0, 0.1, 705.0],
        ];
        // One categorical value per document.
        let cat_features = ["north", "south", "south"]
            .iter()
            .map(|&name| vec![String::from(name)])
            .collect();

        let prediction = model
            .calc_model_prediction(float_features, cat_features)
            .unwrap();

        assert_eq!(prediction[0], 0.9980003729960197);
        assert_eq!(prediction[1], 0.00249414628534181);
        assert_eq!(prediction[2], -0.0013677527881450977);
    }

    #[test]
    fn get_model_stats() {
        let model = Model::load(MODEL_PATH).unwrap();

        // (categorical features, float features, trees, dimensions)
        let stats = (
            model.get_cat_features_count(),
            model.get_float_features_count(),
            model.get_tree_count(),
            model.get_dimensions_count(),
        );
        assert_eq!(stats, (1, 3, 1000, 1));
    }
}