catboost_sys/
lib.rs

1#![allow(non_upper_case_globals)]
2#![allow(non_camel_case_types)]
3#![allow(non_snake_case)]
4
5include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
6
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::{CStr, CString};
    use std::os::raw::c_void;

    /// Model fixture used by both tests. The path is relative; `cargo test`
    /// runs tests with the package root as the working directory.
    const MODEL_PATH: &str = "files/model.bin";

    /// Runs the wrapped closure when dropped.
    ///
    /// Used to release the native model handle even when an assertion panics
    /// partway through a test, instead of leaking it on the failure path.
    struct OnDrop<F: FnMut()>(F);

    impl<F: FnMut()> Drop for OnDrop<F> {
        fn drop(&mut self) {
            (self.0)();
        }
    }

    /// Panics with CatBoost's last error message when a load call returned `false`.
    ///
    /// The message is decoded lossily, so a non-UTF-8 error string cannot
    /// replace the real error with a secondary `Utf8Error` panic.
    fn assert_loaded(ret_val: bool) {
        if ret_val {
            return;
        }
        // SAFETY: argument-less FFI call; the returned pointer is null-checked
        // below before it is dereferenced.
        let err_ptr = unsafe { GetErrorString() };
        let message = if err_ptr.is_null() {
            String::from("model load failed (no error string available)")
        } else {
            // SAFETY: err_ptr is non-null. Assumes the C API returns a
            // NUL-terminated string that stays valid until the next CatBoost call
            // on this thread (TODO: confirm against the C header).
            unsafe { CStr::from_ptr(err_ptr) }
                .to_string_lossy()
                .into_owned()
        };
        // A non-literal `panic!(message)` is rejected in Rust 2021, so always
        // go through a format string.
        panic!("{}", message);
    }

    /// Loads the fixture model by file path and checks its tree and feature counts.
    #[test]
    fn it_works() {
        // SAFETY: argument-less FFI constructor; the result is checked for null.
        let model_handle = unsafe { ModelCalcerCreate() };
        assert!(!model_handle.is_null(), "ModelCalcerCreate returned null");
        // Bound to a named variable (not `_`) so the guard lives until the end
        // of the test and frees the handle on both success and panic.
        let _guard = OnDrop(move || unsafe {
            ModelCalcerDelete(model_handle);
        });

        // Keep the CString in a local so the pointer passed to C visibly
        // outlives the call (avoids the temporary-CString `as_ptr` pitfall).
        let model_path = CString::new(MODEL_PATH).expect("model path contains a NUL byte");
        // SAFETY: model_handle is non-null and model_path outlives this call.
        let ret_val = unsafe { LoadFullModelFromFile(model_handle, model_path.as_ptr()) };
        assert_loaded(ret_val);

        // SAFETY: model_handle holds a successfully loaded model.
        let tree_count = unsafe { GetTreeCount(model_handle) };
        assert_eq!(tree_count, 1000);

        // SAFETY: as above.
        let float_features_count = unsafe { GetFloatFeaturesCount(model_handle) };
        assert_eq!(float_features_count, 3);
        // SAFETY: as above.
        let cat_features_count = unsafe { GetCatFeaturesCount(model_handle) };
        assert_eq!(cat_features_count, 1);
    }

    /// Loads the same fixture from an in-memory buffer and checks the same counts.
    #[test]
    fn it_works_buffer() {
        let buffer = std::fs::read(MODEL_PATH).expect("failed to read model fixture");

        // SAFETY: argument-less FFI constructor; the result is checked for null.
        let model_handle = unsafe { ModelCalcerCreate() };
        assert!(!model_handle.is_null(), "ModelCalcerCreate returned null");
        // Named guard: frees the handle when the test returns or panics.
        let _guard = OnDrop(move || unsafe {
            ModelCalcerDelete(model_handle);
        });

        // SAFETY: the pointer/length pair describes `buffer`, which outlives the call.
        let ret_val = unsafe {
            LoadFullModelFromBuffer(
                model_handle,
                buffer.as_ptr().cast::<c_void>(),
                buffer.len(),
            )
        };
        assert_loaded(ret_val);

        // SAFETY: model_handle holds a successfully loaded model.
        let tree_count = unsafe { GetTreeCount(model_handle) };
        assert_eq!(tree_count, 1000);

        // SAFETY: as above.
        let float_features_count = unsafe { GetFloatFeaturesCount(model_handle) };
        assert_eq!(float_features_count, 3);
        // SAFETY: as above.
        let cat_features_count = unsafe { GetCatFeaturesCount(model_handle) };
        assert_eq!(cat_features_count, 1);
    }
}