use numpy::{PyArray1, PyReadonlyArray1, PyReadonlyArray2};
use pyo3::prelude::*;
use crate::dataset::{BinnedDataset, FeatureInfo, FeatureType};
/// Python-visible wrapper around the crate's `FeatureType` enum.
///
/// Exposed to Python under the class name `FeatureType`; equality comparison
/// is enabled through the `eq` pyclass option and the derived `PartialEq`.
#[pyclass(name = "FeatureType", eq)]
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct PyFeatureType {
// The wrapped Rust value; private to this module.
inner: FeatureType,
}
#[pymethods]
impl PyFeatureType {
#[staticmethod]
fn numeric() -> Self {
Self {
inner: FeatureType::Numeric,
}
}
#[staticmethod]
fn categorical() -> Self {
Self {
inner: FeatureType::Categorical,
}
}
#[getter]
fn is_numeric(&self) -> bool {
matches!(self.inner, FeatureType::Numeric)
}
#[getter]
fn is_categorical(&self) -> bool {
matches!(self.inner, FeatureType::Categorical)
}
fn __repr__(&self) -> &'static str {
match self.inner {
FeatureType::Numeric => "FeatureType.numeric()",
FeatureType::Categorical => "FeatureType.categorical()",
}
}
}
impl From<FeatureType> for PyFeatureType {
fn from(ft: FeatureType) -> Self {
Self { inner: ft }
}
}
impl From<PyFeatureType> for FeatureType {
fn from(pft: PyFeatureType) -> Self {
pft.inner
}
}
/// Python-visible wrapper around the crate's `FeatureInfo`.
///
/// Exposed to Python under the class name `FeatureInfo`. The wrapper is
/// `Clone`, so each copy carries its own clone of the wrapped value.
#[pyclass(name = "FeatureInfo")]
#[derive(Clone)]
pub struct PyFeatureInfo {
// The wrapped Rust value; private to this module.
inner: FeatureInfo,
}
#[pymethods]
impl PyFeatureInfo {
#[getter]
fn name(&self) -> &str {
&self.inner.name
}
#[getter]
fn feature_type(&self) -> PyFeatureType {
self.inner.feature_type.into()
}
#[getter]
fn num_bins(&self) -> u8 {
self.inner.num_bins
}
#[getter]
fn bin_boundaries<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f64>> {
PyArray1::from_slice(py, &self.inner.bin_boundaries)
}
fn __repr__(&self) -> String {
format!(
"FeatureInfo(name='{}', type={:?}, num_bins={})",
self.inner.name, self.inner.feature_type, self.inner.num_bins
)
}
}
impl From<FeatureInfo> for PyFeatureInfo {
fn from(fi: FeatureInfo) -> Self {
Self { inner: fi }
}
}
impl From<&FeatureInfo> for PyFeatureInfo {
fn from(fi: &FeatureInfo) -> Self {
Self { inner: fi.clone() }
}
}
impl From<PyFeatureInfo> for FeatureInfo {
fn from(pfi: PyFeatureInfo) -> Self {
pfi.inner
}
}
/// Python-visible wrapper around the crate's `BinnedDataset`.
///
/// Exposed to Python under the class name `BinnedDataset`.
#[pyclass(name = "BinnedDataset")]
pub struct PyBinnedDataset {
// The wrapped dataset; crate-visible so other modules in this crate can
// reach it directly.
pub(crate) inner: BinnedDataset,
}
#[pymethods]
impl PyBinnedDataset {
#[staticmethod]
#[pyo3(signature = (features, targets, feature_info))]
fn from_arrays<'py>(
features: PyReadonlyArray2<'py, u8>,
targets: PyReadonlyArray1<'py, f32>,
feature_info: Vec<PyFeatureInfo>,
) -> PyResult<Self> {
let features_arr = features.as_array();
let targets_arr = targets.as_array();
let num_features = features_arr.nrows();
let num_rows = features_arr.ncols();
if targets_arr.len() != num_rows {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"targets length {} doesn't match num_rows {}",
targets_arr.len(),
num_rows
)));
}
if feature_info.len() != num_features {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"feature_info length {} doesn't match num_features {}",
feature_info.len(),
num_features
)));
}
let mut features_flat = Vec::with_capacity(num_rows * num_features);
for f in 0..num_features {
for r in 0..num_rows {
features_flat.push(features_arr[[f, r]]);
}
}
let targets_vec: Vec<f32> = targets_arr.to_vec();
let info_vec: Vec<FeatureInfo> = feature_info.into_iter().map(|fi| fi.inner).collect();
Ok(Self {
inner: BinnedDataset::new(num_rows, features_flat, targets_vec, info_vec),
})
}
#[getter]
fn num_rows(&self) -> usize {
self.inner.num_rows()
}
#[getter]
fn num_features(&self) -> usize {
self.inner.num_features()
}
#[getter]
fn targets<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f32>> {
PyArray1::from_slice(py, self.inner.targets())
}
#[getter]
fn feature_info(&self) -> Vec<PyFeatureInfo> {
self.inner
.all_feature_info()
.iter()
.map(|fi| fi.into())
.collect()
}
fn feature_info_at(&self, feature_idx: usize) -> PyResult<PyFeatureInfo> {
if feature_idx >= self.inner.num_features() {
return Err(pyo3::exceptions::PyIndexError::new_err(format!(
"feature index {} out of range (num_features={})",
feature_idx,
self.inner.num_features()
)));
}
Ok(self.inner.feature_info(feature_idx).into())
}
fn feature_column<'py>(
&self,
py: Python<'py>,
feature_idx: usize,
) -> PyResult<Bound<'py, PyArray1<u8>>> {
if feature_idx >= self.inner.num_features() {
return Err(pyo3::exceptions::PyIndexError::new_err(format!(
"feature index {} out of range (num_features={})",
feature_idx,
self.inner.num_features()
)));
}
Ok(PyArray1::from_slice(
py,
self.inner.feature_column(feature_idx),
))
}
fn get_bin(&self, row_idx: usize, feature_idx: usize) -> PyResult<u8> {
if row_idx >= self.inner.num_rows() {
return Err(pyo3::exceptions::PyIndexError::new_err(format!(
"row index {} out of range (num_rows={})",
row_idx,
self.inner.num_rows()
)));
}
if feature_idx >= self.inner.num_features() {
return Err(pyo3::exceptions::PyIndexError::new_err(format!(
"feature index {} out of range (num_features={})",
feature_idx,
self.inner.num_features()
)));
}
Ok(self.inner.get_bin(row_idx, feature_idx))
}
fn is_sparse(&self, feature_idx: usize) -> PyResult<bool> {
if feature_idx >= self.inner.num_features() {
return Err(pyo3::exceptions::PyIndexError::new_err(format!(
"feature index {} out of range",
feature_idx
)));
}
Ok(self.inner.is_sparse(feature_idx))
}
#[getter]
fn num_sparse_features(&self) -> usize {
self.inner.num_sparse_features()
}
#[getter]
fn max_bins(&self) -> u8 {
self.inner.max_bins()
}
#[getter]
fn supports_4bit(&self) -> bool {
self.inner.supports_4bit()
}
#[getter]
fn has_eras(&self) -> bool {
self.inner.has_eras()
}
#[getter]
fn num_eras(&self) -> usize {
self.inner.num_eras()
}
fn __repr__(&self) -> String {
format!(
"BinnedDataset(num_rows={}, num_features={}, max_bins={}, sparse_features={})",
self.inner.num_rows(),
self.inner.num_features(),
self.inner.max_bins(),
self.inner.num_sparse_features()
)
}
fn __len__(&self) -> usize {
self.inner.num_rows()
}
}
impl From<BinnedDataset> for PyBinnedDataset {
fn from(ds: BinnedDataset) -> Self {
Self { inner: ds }
}
}
/// Adds the `FeatureType`, `FeatureInfo` and `BinnedDataset` classes to `m`.
///
/// # Errors
///
/// Returns the first error reported by `add_class`; classes after the failing
/// one are not registered.
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyFeatureType>()
        .and_then(|()| m.add_class::<PyFeatureInfo>())
        .and_then(|()| m.add_class::<PyBinnedDataset>())
}