use std;
use std::ffi::CStr;
use std::fmt::{self, Display};
use std::error::Error;
use xgboost_sys;
/// Shorthand for a `Result` whose error type is [`XGBError`].
pub type XGBResult<T> = Result<T, XGBError>;

/// An error raised while interacting with XGBoost.
///
/// Wraps a single human-readable description of what went wrong.
#[derive(PartialEq, Eq, Debug)]
pub struct XGBError {
    // Human-readable error text.
    desc: String,
}
impl XGBError {
pub(crate) fn new<S: Into<String>>(desc: S) -> Self {
XGBError { desc: desc.into() }
}
pub(crate) fn check_return_value(ret_val: i32) -> XGBResult<()> {
match ret_val {
0 => Ok(()),
-1 => Err(XGBError::from_xgboost()),
_ => panic!(format!("unexpected return value '{}', expected 0 or -1", ret_val)),
}
}
fn from_xgboost() -> Self {
let c_str = unsafe { CStr::from_ptr(xgboost_sys::XGBGetLastError()) };
let str_slice = c_str.to_str().unwrap();
XGBError { desc: str_slice.to_owned() }
}
}
/// Lets `XGBError` be used as a standard library error.
/// Its messages come from the `Display` and `Debug` impls.
impl Error for XGBError {}

/// Renders the error as `XGBoost error: <description>`.
impl Display for XGBError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("XGBoost error: ")?;
        f.write_str(&self.desc)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `0` maps to `Ok(())`; `-1` maps to an error carrying XGBoost's last-error text.
    #[test]
    fn return_value_handling() {
        assert_eq!(XGBError::check_return_value(0), Ok(()));
        // Presumably empty because no XGBoost call has failed yet in this test.
        // NOTE(review): this relies on XGBoost's last-error state, so confirm it is
        // per-thread if tests that trigger real XGBoost failures are added.
        assert_eq!(
            XGBError::check_return_value(-1),
            Err(XGBError { desc: "".to_owned() })
        );
    }

    /// Any return code other than `0` or `-1` is treated as an invariant violation.
    #[test]
    #[should_panic(expected = "unexpected return value '1', expected 0 or -1")]
    fn unexpected_return_value_panics() {
        let _ = XGBError::check_return_value(1);
    }

    /// The `Display` form prefixes the description with `XGBoost error: `.
    #[test]
    fn display_prefixes_description() {
        let err = XGBError::new("boom");
        assert_eq!(err.to_string(), "XGBoost error: boom");
    }
}