use crate::error::{LightGBMError, LightGBMResult};
use crate::sys;
use std::ffi::CString;
use std::path::Path;
use std::ptr;
/// Owning wrapper around a raw LightGBM booster handle.
///
/// The handle is released through `LGBM_BoosterFree` by this type's `Drop`
/// implementation, so a `Booster` frees its native resources when it goes out
/// of scope.
pub struct Booster {
    /// Raw handle handed to every `sys::LGBM_Booster*` call made on behalf of
    /// this booster.
    handle: sys::BoosterHandle,
}
impl Booster {
    /// Loads a booster from a LightGBM model file.
    ///
    /// # Errors
    ///
    /// Returns an error when `path` is not valid UTF-8, when it contains an
    /// interior NUL byte, or when `LGBM_BoosterCreateFromModelfile` reports a
    /// failure.
    pub fn load<P: AsRef<Path>>(path: P) -> LightGBMResult<Self> {
        let path_str = path.as_ref().to_str().ok_or_else(|| LightGBMError {
            description: "Path contains invalid UTF-8 characters".to_string(),
        })?;
        let path_c_str = CString::new(path_str).map_err(|e| LightGBMError {
            description: format!("Path contains NUL byte: {}", e),
        })?;

        let mut handle: sys::BoosterHandle = ptr::null_mut();
        // Out-parameter required by the C signature; its value is not used here.
        let mut num_iterations = 0i32;
        // SAFETY: `path_c_str` is a live NUL-terminated string and both
        // out-pointers refer to locals that outlive the call.
        LightGBMError::check_return_value(unsafe {
            sys::LGBM_BoosterCreateFromModelfile(
                path_c_str.as_ptr(),
                &mut num_iterations,
                &mut handle,
            )
        })?;
        Ok(Booster { handle })
    }

    /// Loads a booster from the textual contents of a LightGBM model.
    ///
    /// # Errors
    ///
    /// Returns an error when `model_str` contains an interior NUL byte or when
    /// `LGBM_BoosterLoadModelFromString` reports a failure.
    pub fn load_from_string(model_str: &str) -> LightGBMResult<Self> {
        let model_c_str = CString::new(model_str).map_err(|e| LightGBMError {
            description: format!("Model string contains NUL byte: {}", e),
        })?;

        let mut handle: sys::BoosterHandle = ptr::null_mut();
        // Out-parameter required by the C signature; its value is not used here.
        let mut num_iterations = 0i32;
        // SAFETY: `model_c_str` is a live NUL-terminated string and both
        // out-pointers refer to locals that outlive the call.
        LightGBMError::check_return_value(unsafe {
            sys::LGBM_BoosterLoadModelFromString(
                model_c_str.as_ptr(),
                &mut num_iterations,
                &mut handle,
            )
        })?;
        Ok(Booster { handle })
    }

    /// Loads a booster from an in-memory model that must be valid UTF-8 text.
    ///
    /// The bytes are validated as UTF-8 and then handed to
    /// [`Booster::load_from_string`].
    ///
    /// # Errors
    ///
    /// Returns an error when `buffer` is not valid UTF-8, plus every error
    /// [`Booster::load_from_string`] can return.
    pub fn load_from_buffer(buffer: &[u8]) -> LightGBMResult<Self> {
        let model_str = std::str::from_utf8(buffer).map_err(|e| LightGBMError {
            description: format!("Invalid UTF-8 in model buffer: {}", e),
        })?;
        Self::load_from_string(model_str)
    }

    /// Returns the feature count reported by `LGBM_BoosterGetNumFeature`.
    ///
    /// # Errors
    ///
    /// Returns an error when the C call reports a failure.
    pub fn num_features(&self) -> LightGBMResult<i32> {
        let mut num_features = 0i32;
        // SAFETY: `self.handle` was set by a successful creation call in this
        // impl's constructors, and the out-pointer refers to a live local.
        LightGBMError::check_return_value(unsafe {
            sys::LGBM_BoosterGetNumFeature(self.handle, &mut num_features)
        })?;
        Ok(num_features)
    }

    /// Returns the class count reported by `LGBM_BoosterGetNumClasses`.
    ///
    /// # Errors
    ///
    /// Returns an error when the C call reports a failure.
    pub fn num_classes(&self) -> LightGBMResult<i32> {
        let mut num_classes = 0i32;
        // SAFETY: `self.handle` was set by a successful creation call in this
        // impl's constructors, and the out-pointer refers to a live local.
        LightGBMError::check_return_value(unsafe {
            sys::LGBM_BoosterGetNumClasses(self.handle, &mut num_classes)
        })?;
        Ok(num_classes)
    }

    /// Runs prediction on a dense, row-major `f64` matrix.
    ///
    /// `data` must hold exactly `num_rows * num_cols` values. `predict_type`
    /// is forwarded unchanged as the C API's `predict_type` argument. All
    /// iterations of the model are used (start iteration `0`, iteration count
    /// `-1`).
    ///
    /// # Errors
    ///
    /// Returns an error when a dimension is negative, when the dimensions do
    /// not match `data.len()`, or when LightGBM reports a failure.
    pub fn predict(
        &self,
        data: &[f64],
        num_rows: i32,
        num_cols: i32,
        predict_type: i32,
    ) -> LightGBMResult<Vec<f64>> {
        // SAFETY: pointer and length come from the same live `&[f64]`, and the
        // dtype tag passed alongside is the FLOAT64 one.
        unsafe {
            self.predict_for_mat(
                data.as_ptr().cast(),
                data.len(),
                sys::C_API_DTYPE_FLOAT64 as i32,
                num_rows,
                num_cols,
                predict_type,
            )
        }
    }

    /// Runs prediction on a dense, row-major `f32` matrix.
    ///
    /// Identical to [`Booster::predict`] except for the input element type;
    /// results are still returned as `f64`.
    ///
    /// # Errors
    ///
    /// Returns an error when a dimension is negative, when the dimensions do
    /// not match `data.len()`, or when LightGBM reports a failure.
    pub fn predict_f32(
        &self,
        data: &[f32],
        num_rows: i32,
        num_cols: i32,
        predict_type: i32,
    ) -> LightGBMResult<Vec<f64>> {
        // SAFETY: pointer and length come from the same live `&[f32]`, and the
        // dtype tag passed alongside is the FLOAT32 one.
        unsafe {
            self.predict_for_mat(
                data.as_ptr().cast(),
                data.len(),
                sys::C_API_DTYPE_FLOAT32 as i32,
                num_rows,
                num_cols,
                predict_type,
            )
        }
    }

    /// Shared implementation behind [`Booster::predict`] and
    /// [`Booster::predict_f32`]: validates the shape, sizes the output buffer
    /// and performs the dense-matrix prediction.
    ///
    /// # Safety
    ///
    /// `data` must point to `data_len` initialized, contiguous elements of the
    /// type identified by `data_type`, and must stay valid for reads for the
    /// duration of the call.
    unsafe fn predict_for_mat(
        &self,
        data: *const std::os::raw::c_void,
        data_len: usize,
        data_type: i32,
        num_rows: i32,
        num_cols: i32,
        predict_type: i32,
    ) -> LightGBMResult<Vec<f64>> {
        // Empty parameter string: the C API takes a NUL-terminated string
        // here, so pass "" rather than a null pointer.
        const NO_PARAMETERS: &[u8] = b"\0";
        const IS_ROW_MAJOR: i32 = 1;
        const START_ITERATION: i32 = 0;
        // A non-positive iteration count means "use every iteration".
        const NUM_ITERATION: i32 = -1;

        // Reject negative dimensions before casting: `as usize` would wrap
        // them into huge values and could let a bogus shape pass the length
        // check below.
        if num_rows < 0 || num_cols < 0 {
            return Err(LightGBMError {
                description: format!(
                    "Matrix dimensions must be non-negative: num_rows ({}), num_cols ({})",
                    num_rows, num_cols
                ),
            });
        }
        // Both values are non-negative here, so the casts are lossless.
        let expected_len = (num_rows as usize)
            .checked_mul(num_cols as usize)
            .ok_or_else(|| LightGBMError {
                description: format!(
                    "Integer overflow when computing expected data size: num_rows ({}) * num_cols ({})",
                    num_rows, num_cols
                ),
            })?;
        if expected_len != data_len {
            return Err(LightGBMError {
                description: format!(
                    "Input data size mismatch: expected {} elements ({}×{}), got {}",
                    expected_len, num_rows, num_cols, data_len
                ),
            });
        }

        // Checked i64 -> usize conversion for lengths reported by LightGBM.
        let to_len = |reported: i64| {
            <usize as std::convert::TryFrom<i64>>::try_from(reported).map_err(|_| LightGBMError {
                description: format!(
                    "LightGBM reported an invalid prediction length: {}",
                    reported
                ),
            })
        };

        // Ask LightGBM how many values it will write. The predict call itself
        // expects a pre-allocated buffer and cannot be used as a size query.
        let mut out_len = 0i64;
        // SAFETY: `self.handle` was set by a successful creation call in this
        // impl's constructors, and `out_len` is a live local.
        LightGBMError::check_return_value(unsafe {
            sys::LGBM_BoosterCalcNumPredict(
                self.handle,
                num_rows,
                predict_type,
                START_ITERATION,
                NUM_ITERATION,
                &mut out_len,
            )
        })?;

        let mut out_result = vec![0.0f64; to_len(out_len)?];
        // SAFETY: the caller guarantees `data` matches `data_type` and
        // `data_len`, which was checked against `num_rows * num_cols` above;
        // `out_result` holds the number of slots LightGBM asked for, and the
        // parameter string is NUL-terminated.
        LightGBMError::check_return_value(unsafe {
            sys::LGBM_BoosterPredictForMat(
                self.handle,
                data,
                data_type,
                num_rows,
                num_cols,
                IS_ROW_MAJOR,
                predict_type,
                START_ITERATION,
                NUM_ITERATION,
                NO_PARAMETERS.as_ptr().cast(),
                &mut out_len,
                out_result.as_mut_ptr(),
            )
        })?;

        // Keep only what was actually written (no-op when the counts agree).
        out_result.truncate(to_len(out_len)?);
        Ok(out_result)
    }
}
impl Drop for Booster {
    /// Releases the native booster by passing the stored handle to
    /// `LGBM_BoosterFree`.
    fn drop(&mut self) {
        // SAFETY: `drop` runs at most once per value and the visible struct
        // definition derives neither `Clone` nor `Copy`, so the handle is not
        // freed twice through this path.
        // The status code is discarded on purpose: `drop` cannot report errors.
        let _ = unsafe { sys::LGBM_BoosterFree(self.handle) };
    }
}