use crate::{
Error, Parameters, Result,
mat::AsMat,
to_result,
utils::{get_strings, path_to_cstring, to_cstring},
};
use lgbm_sys::{
C_API_DTYPE_FLOAT32, C_API_DTYPE_FLOAT64, C_API_DTYPE_INT32, C_API_DTYPE_INT64, DatasetHandle,
LGBM_DatasetCreateFromFile, LGBM_DatasetCreateFromMat, LGBM_DatasetCreateFromMats,
LGBM_DatasetDumpText, LGBM_DatasetFree, LGBM_DatasetGetFeatureNames, LGBM_DatasetGetField,
LGBM_DatasetGetNumData, LGBM_DatasetGetNumFeature, LGBM_DatasetSetFeatureNames,
LGBM_DatasetSetField,
};
use std::{
marker::PhantomData,
os::raw::{c_char, c_int, c_void},
path::Path,
ptr::{null, null_mut},
slice,
};
/// Element types that can be handed to LightGBM's C API as untyped buffers.
///
/// Each implementor pairs itself with the `C_API_DTYPE_*` code that LightGBM uses to
/// interpret a `const void*` buffer of that element type.
// NOTE(review): this trait is public and not sealed, so a downstream impl could pair a type
// with the wrong `DATA_TYPE`; LightGBM would then misread the buffer. Consider sealing it.
pub trait Data: Sized {
    /// LightGBM dtype code describing `Self`.
    const DATA_TYPE: c_int;
    /// Erases the element type of `data` so it can be passed where the C API expects `const void*`.
    fn as_data_ptr(data: *const Self) -> *const c_void;
}

/// Implements [`Data`] for each `type => C_API_DTYPE_*` pair, casting the dtype code to `c_int`.
macro_rules! impl_data {
    ($($ty:ty => $dtype:ident),* $(,)?) => {
        $(
            impl Data for $ty {
                const DATA_TYPE: c_int = $dtype as c_int;
                fn as_data_ptr(data: *const Self) -> *const c_void {
                    data.cast::<c_void>()
                }
            }
        )*
    };
}

impl_data! {
    f32 => C_API_DTYPE_FLOAT32,
    f64 => C_API_DTYPE_FLOAT64,
    i32 => C_API_DTYPE_INT32,
    i64 => C_API_DTYPE_INT64,
}
/// Element types accepted for feature matrices by `Dataset::from_mat` and
/// `Dataset::from_mats`.
///
/// Only `f32` and `f64` implement it in this module.
pub trait FeatureData: Data {}
impl FeatureData for f32 {}
impl FeatureData for f64 {}
/// A typed name for one of the dataset fields that LightGBM can store alongside the features.
///
/// The type parameter records the element type used for the field. Because the constructor is
/// private, fields can only be obtained through the constants defined in this module, which
/// keeps the element type of `Dataset::set_field`/`Dataset::get_field` in sync with the name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Field<T> {
    /// Field name as a NUL-terminated byte string, passed directly to the C API.
    name: &'static [u8],
    _type: PhantomData<T>,
}

impl<T> Field<T> {
    /// Creates a field from a NUL-terminated byte-string name.
    ///
    /// # Panics
    /// Panics if `name` does not end with a `0` byte. Every caller in this module is a
    /// `const`, so the check is enforced during compilation.
    const fn new(name: &'static [u8]) -> Self {
        assert!(matches!(name.last(), Some(&0)));
        Field {
            _type: PhantomData,
            name,
        }
    }

    /// Pointer to the NUL-terminated name, suitable for a `const char*` parameter.
    ///
    /// The pointee is `'static`, so the pointer never dangles.
    fn name_ptr(&self) -> *const c_char {
        let bytes: &'static [u8] = self.name;
        bytes.as_ptr().cast::<c_char>()
    }
}

impl Field<f32> {
    /// The `label` field, stored as `f32`.
    pub const LABEL: Self = Field::new(b"label\0");
    /// The `weight` field, stored as `f32`.
    pub const WEIGHT: Self = Field::new(b"weight\0");
}

impl Field<f64> {
    /// The `init_score` field, stored as `f64`.
    pub const INIT_SCORE: Self = Field::new(b"init_score\0");
}

impl Field<i32> {
    /// The `group` field, stored as `i32`.
    pub const GROUP: Self = Field::new(b"group\0");
}
/// Owned handle to a native LightGBM dataset.
///
/// The wrapped handle is released with `LGBM_DatasetFree` when the value is dropped.
pub struct Dataset(pub(crate) DatasetHandle);
impl Dataset {
#[doc(alias = "LGBM_DatasetCreateFromFile")]
pub fn from_file(
filename: &Path,
reference: Option<&Dataset>,
parameters: &Parameters,
) -> Result<Self> {
let mut handle = null_mut();
unsafe {
to_result(LGBM_DatasetCreateFromFile(
path_to_cstring(filename)?.as_ptr(),
parameters.to_cstring()?.as_ptr(),
to_dataset_handle(reference),
&mut handle,
))?;
}
Ok(Self(handle))
}
#[doc(alias = "LGBM_DatasetCreateFromMat")]
pub fn from_mat<T: FeatureData>(
mat: impl AsMat<T>,
reference: Option<&Dataset>,
parameters: &Parameters,
) -> Result<Self> {
let mat = mat.as_mat();
let mut handle = null_mut();
unsafe {
to_result(LGBM_DatasetCreateFromMat(
mat.as_data_ptr(),
T::DATA_TYPE,
mat.nrow().try_into()?,
mat.ncol().try_into()?,
mat.is_row_major(),
parameters.to_cstring()?.as_ptr(),
to_dataset_handle(reference),
&mut handle,
))?;
}
Ok(Self(handle))
}
#[doc(alias = "LGBM_DatasetCreateFromMats")]
pub fn from_mats<M: AsMat<T>, T: FeatureData>(
mats: impl IntoIterator<Item = M>,
reference: Option<&Dataset>,
parameters: &Parameters,
) -> Result<Self> {
let as_mats = mats.into_iter().collect::<Vec<_>>();
let mats = as_mats.iter().map(|x| x.as_mat()).collect::<Vec<_>>();
if mats.is_empty() {
return Err(Error::from_message("mats must not be empty"));
}
let ncol = mats[0].ncol();
let mut is_row_major = Vec::new();
let mut nrows: Vec<i32> = Vec::with_capacity(mats.len());
let mut mat_ptrs = Vec::with_capacity(mats.len());
for mat in &mats {
if mat.ncol() != ncol {
return Err(Error::from_message(
"mats must have the same number of columns",
));
}
is_row_major.push(mat.is_row_major());
nrows.push(mat.nrow().try_into()?);
mat_ptrs.push(mat.as_data_ptr());
}
let mut handle = null_mut();
unsafe {
to_result(LGBM_DatasetCreateFromMats(
mats.len().try_into()?,
mat_ptrs.as_mut_ptr(),
T::DATA_TYPE,
nrows.as_mut_ptr(),
ncol.try_into()?,
is_row_major.as_mut_ptr(),
parameters.to_cstring()?.as_ptr(),
to_dataset_handle(reference),
&mut handle,
))?;
}
Ok(Self(handle))
}
#[doc(alias = "LGBM_DatasetSetField")]
pub fn set_field<T: Data>(&mut self, field: Field<T>, data: &[T]) -> Result<()> {
unsafe {
to_result(LGBM_DatasetSetField(
self.0,
field.name_ptr(),
data.as_ptr() as *const c_void,
data.len().try_into()?,
T::DATA_TYPE,
))
}
}
#[doc(alias = "LGBM_DatasetGetField")]
pub fn get_field<T: Data>(&self, field: Field<T>) -> Result<&[T]> {
unsafe {
let mut out_len = 0;
let mut out_ptr = null();
let mut out_type = 0;
to_result(LGBM_DatasetGetField(
self.0,
field.name_ptr(),
&mut out_len,
&mut out_ptr,
&mut out_type,
))?;
if out_type != T::DATA_TYPE {
return Err(Error::from_message("element type mismatch"));
}
Ok(slice::from_raw_parts(out_ptr as *const T, out_len as usize))
}
}
#[doc(alias = "LGBM_DatasetGetNumFeature")]
pub fn get_num_feature(&self) -> Result<usize> {
let mut num_feature = 0;
unsafe {
to_result(LGBM_DatasetGetNumFeature(self.0, &mut num_feature))?;
}
Ok(num_feature as usize)
}
#[doc(alias = "LGBM_DatasetGetNumData")]
pub fn get_num_data(&self) -> Result<usize> {
let mut num_data = 0;
unsafe {
to_result(LGBM_DatasetGetNumData(self.0, &mut num_data))?;
}
Ok(num_data as usize)
}
#[doc(alias = "LGBM_DatasetDumpText")]
pub fn dump_text(&self, path: &Path) -> Result<()> {
unsafe {
to_result(LGBM_DatasetDumpText(
self.0,
path_to_cstring(path)?.as_ptr(),
))
}
}
#[doc(alias = "LGBM_DatasetSetFeatureNames")]
pub fn set_feature_names<T: AsRef<str>>(
&mut self,
names: impl IntoIterator<Item = T>,
) -> Result<()> {
let mut cstr_names = Vec::new();
for name in names {
cstr_names.push(to_cstring(name.as_ref())?);
}
let mut pcstr_names = cstr_names.iter().map(|x| x.as_ptr()).collect::<Vec<_>>();
unsafe {
to_result(LGBM_DatasetSetFeatureNames(
self.0,
pcstr_names.as_mut_ptr(),
pcstr_names.len().try_into()?,
))
}
}
#[doc(alias = "LGBM_DatasetGetFeatureNames")]
pub fn get_feature_names(&self) -> Result<Vec<String>> {
get_strings(
|len, out_len, buffer_len, out_buffer_len, out_strs| unsafe {
LGBM_DatasetGetFeatureNames(
self.0,
len,
out_len,
buffer_len,
out_buffer_len,
out_strs,
)
},
)
}
}
impl Drop for Dataset {
    /// Releases the native dataset through `LGBM_DatasetFree`.
    ///
    /// # Panics
    /// Panics if LightGBM reports a failure while freeing the handle.
    // NOTE(review): a panic here while already unwinding aborts the process; consider
    // logging instead of unwrapping.
    fn drop(&mut self) {
        let freed = unsafe { to_result(LGBM_DatasetFree(self.0)) };
        freed.unwrap();
    }
}
// NOTE(review): these impls assert that the native dataset handle may be moved to another
// thread (`Send`) and shared by reference across threads (`Sync`).
// - Shared access reaches LightGBM only through the `&self` methods (`LGBM_DatasetGet*`,
//   `LGBM_DatasetDumpText`) and through the `reference` argument of the constructors.
// - The mutators (`set_field`, `set_feature_names`) take `&mut self`.
// This file does not establish that concurrent calls of those read-style C functions are
// thread-safe in LightGBM — confirm before relying on `Sync`.
unsafe impl Send for Dataset {}
unsafe impl Sync for Dataset {}
/// Converts an optional dataset reference into the raw handle expected by the C API.
///
/// `None` maps to a null handle, which LightGBM treats as "no reference dataset".
fn to_dataset_handle(dataset: Option<&Dataset>) -> DatasetHandle {
    match dataset {
        Some(reference) => reference.0,
        None => null_mut(),
    }
}