use crate::error::{XGBoostError, XGBoostResult};
use crate::sys;
use std::convert::TryFrom;
use std::ffi::CString;
use std::path::Path;
use std::ptr;
/// Owning wrapper around a native XGBoost booster handle.
///
/// The handle is released with `XGBoosterFree` when the value is dropped.
pub struct Booster {
// Raw handle obtained from a successful `XGBoosterCreate` call.
handle: sys::BoosterHandle,
}
// SAFETY: these impls are only compiled when the `xgboost_thread_safe` cfg flag is set.
// They assert that the native booster handle may be moved to another thread (`Send`) and
// used through `&Booster` from several threads at once (`Sync`), e.g. concurrent `predict`.
// NOTE(review): soundness depends on the linked XGBoost build's C API being thread-safe for
// the operations exposed via `&self` — confirm against the XGBoost version this crate links.
#[cfg(xgboost_thread_safe)]
unsafe impl Send for Booster {}
#[cfg(xgboost_thread_safe)]
unsafe impl Sync for Booster {}
impl Booster {
    /// Converts `path` into a NUL-terminated C string suitable for the XGBoost C API.
    ///
    /// Shared by [`Booster::load`] and [`Booster::save`] so both report path problems
    /// identically.
    ///
    /// # Errors
    ///
    /// Returns an error if the path is not valid UTF-8 or contains an interior NUL byte.
    fn path_to_cstring(path: &Path) -> XGBoostResult<CString> {
        let path_str = path.to_str().ok_or_else(|| XGBoostError {
            description: "Path contains invalid UTF-8 characters".to_string(),
        })?;
        CString::new(path_str).map_err(|e| XGBoostError {
            description: format!("Path contains NUL byte: {}", e),
        })
    }

    /// Creates a booster with no cached matrices and wraps it immediately.
    ///
    /// Wrapping the handle right away means `Drop` releases it if any later
    /// initialisation step fails, so callers need no manual cleanup.
    ///
    /// # Errors
    ///
    /// Returns an error if `XGBoosterCreate` reports failure.
    fn create_empty() -> XGBoostResult<Self> {
        let mut handle: sys::BoosterHandle = ptr::null_mut();
        // SAFETY: an empty (null, 0) cache list is passed and `handle` is a valid
        // out-pointer for the duration of the call.
        XGBoostError::check_return_value(unsafe {
            sys::XGBoosterCreate(ptr::null(), 0, &mut handle)
        })?;
        Ok(Booster { handle })
    }

    /// Loads a model from the file at `path`.
    ///
    /// # Errors
    ///
    /// Returns an error if the path is not valid UTF-8, contains a NUL byte, or if XGBoost
    /// fails to create the booster or load the model. No native handle is leaked on failure.
    pub fn load<P: AsRef<Path>>(path: P) -> XGBoostResult<Self> {
        let path_c_str = Self::path_to_cstring(path.as_ref())?;
        let booster = Self::create_empty()?;
        // SAFETY: `booster.handle` is live and `path_c_str` outlives the call.
        XGBoostError::check_return_value(unsafe {
            sys::XGBoosterLoadModel(booster.handle, path_c_str.as_ptr())
        })?;
        Ok(booster)
    }

    /// Loads a model from an in-memory buffer.
    ///
    /// # Errors
    ///
    /// Returns an error if XGBoost fails to create the booster or to parse `buffer`.
    /// No native handle is leaked on failure.
    pub fn load_from_buffer(buffer: &[u8]) -> XGBoostResult<Self> {
        let booster = Self::create_empty()?;
        // SAFETY: `booster.handle` is live; the pointer/length pair describes exactly the
        // bytes of `buffer`, which outlives the call. `usize -> u64` is lossless on every
        // platform Rust supports.
        XGBoostError::check_return_value(unsafe {
            sys::XGBoosterLoadModelFromBuffer(
                booster.handle,
                buffer.as_ptr() as *const std::os::raw::c_void,
                buffer.len() as u64,
            )
        })?;
        Ok(booster)
    }

    /// Runs the booster on a dense `num_rows × num_features` matrix of `f32` values.
    ///
    /// `data` must hold exactly `num_rows * num_features` elements, laid out row by row as
    /// `XGDMatrixCreateFromMat` expects. `NaN` entries are treated as missing values.
    /// `option_mask` and `training` are forwarded unchanged to `XGBoosterPredict`.
    ///
    /// Returns a copy of the flat prediction buffer produced by XGBoost. Its length depends
    /// on the model and `option_mask`. An empty batch (`num_rows == 0`) yields an empty
    /// vector without calling into XGBoost.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - `num_rows * num_features` overflows;
    /// - `data` has the wrong length;
    /// - `option_mask` does not fit in a C `int`;
    /// - XGBoost fails to build the matrix or to predict;
    /// - XGBoost returns a null or empty result for a non-empty batch;
    /// - the result length does not fit in `usize`.
    pub fn predict(
        &self,
        data: &[f32],
        num_rows: usize,
        num_features: usize,
        option_mask: u32,
        training: bool,
    ) -> XGBoostResult<Vec<f32>> {
        let expected_len = num_rows
            .checked_mul(num_features)
            .ok_or_else(|| XGBoostError {
                description: format!(
                    "Integer overflow: num_rows ({}) * num_features ({}) exceeds usize::MAX",
                    num_rows, num_features
                ),
            })?;
        if data.len() != expected_len {
            return Err(XGBoostError {
                description: format!(
                    "Data length mismatch: expected {} elements ({}×{}), got {}",
                    expected_len,
                    num_rows,
                    num_features,
                    data.len()
                ),
            });
        }
        // The C API takes a signed int; reject masks that would wrap to a negative value
        // instead of silently passing a different mask.
        let option_mask = i32::try_from(option_mask).map_err(|_| XGBoostError {
            description: format!("option_mask ({}) does not fit in a C int", option_mask),
        })?;
        // Nothing to predict. Calling into XGBoost would yield no output (or reject the
        // matrix), and the emptiness check below would turn that into a spurious error.
        if num_rows == 0 {
            return Ok(Vec::new());
        }

        /// Frees the temporary DMatrix on every exit path, including early `?` returns.
        struct DMatrixGuard(sys::DMatrixHandle);
        impl Drop for DMatrixGuard {
            fn drop(&mut self) {
                // SAFETY: the guard is only constructed from a handle that
                // `XGDMatrixCreateFromMat` successfully produced, and is dropped once.
                unsafe {
                    sys::XGDMatrixFree(self.0);
                }
            }
        }

        let mut dmatrix_handle: sys::DMatrixHandle = ptr::null_mut();
        // SAFETY: `data` holds exactly `num_rows * num_features` contiguous `f32` values
        // (checked above) and outlives the call; `dmatrix_handle` is a valid out-pointer.
        XGBoostError::check_return_value(unsafe {
            sys::XGDMatrixCreateFromMat(
                data.as_ptr(),
                num_rows as u64,
                num_features as u64,
                f32::NAN,
                &mut dmatrix_handle,
            )
        })?;
        let _guard = DMatrixGuard(dmatrix_handle);

        let mut out_len: u64 = 0;
        let mut out_result: *const f32 = ptr::null();
        // SAFETY: both handles are live for the duration of the call; `out_len` and
        // `out_result` are valid out-pointers. The `0` argument is XGBoost's `ntree_limit`.
        XGBoostError::check_return_value(unsafe {
            sys::XGBoosterPredict(
                self.handle,
                dmatrix_handle,
                option_mask,
                0,
                training as i32,
                &mut out_len,
                &mut out_result,
            )
        })?;
        if out_result.is_null() || out_len == 0 {
            return Err(XGBoostError {
                description: "XGBoost returned null or empty prediction result".to_string(),
            });
        }
        let out_len = usize::try_from(out_len).map_err(|_| XGBoostError {
            description: format!("Prediction length {} does not fit in usize", out_len),
        })?;
        // SAFETY: XGBoost reported success and handed back a non-null pointer to `out_len`
        // floats. The buffer belongs to XGBoost (presumably valid until the next call on
        // this booster), so it is copied out immediately.
        let results = unsafe { std::slice::from_raw_parts(out_result, out_len) }.to_vec();
        Ok(results)
    }

    /// Returns the number of input features the loaded model expects.
    ///
    /// # Errors
    ///
    /// Returns an error if `XGBoosterGetNumFeature` fails or the count does not fit in
    /// `usize`.
    pub fn num_features(&self) -> XGBoostResult<usize> {
        let mut out_num_features: u64 = 0;
        // SAFETY: `self.handle` is live and `out_num_features` is a valid out-pointer.
        XGBoostError::check_return_value(unsafe {
            sys::XGBoosterGetNumFeature(self.handle, &mut out_num_features)
        })?;
        usize::try_from(out_num_features).map_err(|_| XGBoostError {
            description: format!(
                "Feature count {} does not fit in usize",
                out_num_features
            ),
        })
    }

    /// Saves the model to the file at `path`.
    ///
    /// # Errors
    ///
    /// Returns an error if the path is not valid UTF-8, contains a NUL byte, or if
    /// `XGBoosterSaveModel` fails.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> XGBoostResult<()> {
        let path_c_str = Self::path_to_cstring(path.as_ref())?;
        // SAFETY: `self.handle` is live and `path_c_str` outlives the call.
        XGBoostError::check_return_value(unsafe {
            sys::XGBoosterSaveModel(self.handle, path_c_str.as_ptr())
        })
    }
}
impl Drop for Booster {
    /// Releases the native booster handle.
    ///
    /// The status code returned by `XGBoosterFree` is deliberately discarded: `drop` has no
    /// way to report a failure.
    fn drop(&mut self) {
        // SAFETY: `self.handle` came from a successful `XGBoosterCreate` and is freed exactly
        // once, here; `Booster` derives neither `Clone` nor `Copy`, so no second owner exists.
        let _ = unsafe { sys::XGBoosterFree(self.handle) };
    }
}