use numpy::{PyArray1, PyReadonlyArray1, PyReadonlyArray2};
use pyo3::exceptions::{PyIOError, PyValueError};
use pyo3::prelude::*;
use crate::backend::{BackendType, GpuMode};
use crate::booster::{GBDTConfig, GBDTModel, GbdtPreset, LossType};
use crate::serialize;
use crate::tree::MonotonicConstraint;
use crate::tuner::ModelFormat;
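/// Flatten a 2D NumPy array of float64 or float32 into a row-major `Vec<f64>`,
/// returning the column count alongside the data. float32 input is widened to
/// float64 so callers see a single element type.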
fn extract_features_array<'py>(features: &Bound<'py, PyAny>) -> PyResult<(usize, Vec<f64>)> {
if let Ok(arr) = features.extract::<PyReadonlyArray2<'py, f64>>() {
let arr = arr.as_array();
let num_rows = arr.nrows();
let num_cols = arr.ncols();
let mut raw = Vec::with_capacity(num_rows * num_cols);
for row in arr.rows() {
raw.extend(row.iter().copied());
}
Ok((num_cols, raw))
} else if let Ok(arr) = features.extract::<PyReadonlyArray2<'py, f32>>() {
let arr = arr.as_array();
let num_rows = arr.nrows();
let num_cols = arr.ncols();
let mut raw = Vec::with_capacity(num_rows * num_cols);
for row in arr.rows() {
raw.extend(row.iter().map(|&v| v as f64));
}
Ok((num_cols, raw))
} else {
Err(PyValueError::new_err(
"features must be a 2D numpy array of float32 or float64",
))
}
}
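/// Raise a Python `ValueError` when the supplied feature count does not match
/// the count the model expects.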
fn validate_feature_count(actual: usize, expected: usize) -> PyResult<()> {
if actual != expected {
return Err(PyValueError::new_err(format!(
"Expected {} features, got {}",
expected, actual
)));
}
Ok(())
}
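/// Map a user-supplied format string ("rkyv", or "bincode"/"bin") onto `ModelFormat`.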
fn parse_model_format(format: &str) -> PyResult<ModelFormat> {
match format.to_lowercase().as_str() {
"rkyv" => Ok(ModelFormat::Rkyv),
"bincode" | "bin" => Ok(ModelFormat::Bincode),
_ => Err(PyValueError::new_err("format must be 'rkyv' or 'bincode'")),
}
}
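/// Per-feature monotonic constraint exposed to Python. Construct with the
/// `increasing()`, `decreasing()`, or `none()` static methods, or pass the
/// integer encoding (1, -1, 0) to `GBDTConfig.with_monotonic_constraints()`.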
#[pyclass(name = "MonotonicConstraint", eq)]
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct PyMonotonicConstraint {
inner: MonotonicConstraint,
}
#[pymethods]
impl PyMonotonicConstraint {
#[staticmethod]
fn increasing() -> Self {
Self {
inner: MonotonicConstraint::Increasing,
}
}
#[staticmethod]
fn decreasing() -> Self {
Self {
inner: MonotonicConstraint::Decreasing,
}
}
#[staticmethod]
fn none() -> Self {
Self {
inner: MonotonicConstraint::None,
}
}
#[getter]
fn is_increasing(&self) -> bool {
matches!(self.inner, MonotonicConstraint::Increasing)
}
#[getter]
fn is_decreasing(&self) -> bool {
matches!(self.inner, MonotonicConstraint::Decreasing)
}
#[getter]
fn is_none(&self) -> bool {
matches!(self.inner, MonotonicConstraint::None)
}
fn to_int(&self) -> i32 {
match self.inner {
MonotonicConstraint::Increasing => 1,
MonotonicConstraint::Decreasing => -1,
MonotonicConstraint::None => 0,
}
}
fn __repr__(&self) -> &'static str {
match self.inner {
MonotonicConstraint::Increasing => "MonotonicConstraint.increasing()",
MonotonicConstraint::Decreasing => "MonotonicConstraint.decreasing()",
MonotonicConstraint::None => "MonotonicConstraint.none()",
}
}
}
impl From<MonotonicConstraint> for PyMonotonicConstraint {
fn from(mc: MonotonicConstraint) -> Self {
Self { inner: mc }
}
}
impl From<PyMonotonicConstraint> for MonotonicConstraint {
fn from(pmc: PyMonotonicConstraint) -> Self {
pmc.inner
}
}
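/// Training configuration exposed to Python. Plain attribute access mutates
/// the config in place, while the `with_*` builder methods validate their
/// argument and return a new config, so they can be chained.
///
/// A minimal sketch of the builder style from Python, assuming the extension
/// module is imported as `gbdt` (the actual module name depends on how
/// `register_module` is wired up):
///
/// ```python
/// cfg = (
///     gbdt.GBDTConfig.preset("accuracy")
///     .with_num_rounds(500)
///     .with_learning_rate(0.05)
///     .with_backend("cpu")
/// )
/// ```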
#[pyclass(name = "GBDTConfig")]
#[derive(Clone)]
pub struct PyGBDTConfig {
inner: GBDTConfig,
}
#[pymethods]
impl PyGBDTConfig {
#[new]
fn new() -> Self {
Self {
inner: GBDTConfig::default(),
}
}
#[staticmethod]
fn preset(preset: &str) -> PyResult<Self> {
let preset = match preset.to_lowercase().as_str() {
"standard" => GbdtPreset::Standard,
"speed" => GbdtPreset::Speed,
"accuracy" => GbdtPreset::Accuracy,
"smalldata" | "small_data" => GbdtPreset::SmallData,
"largedata" | "large_data" => GbdtPreset::LargeData,
"conformal" => GbdtPreset::Conformal,
            _ => {
                return Err(PyValueError::new_err(
                    "unknown preset (use: standard, speed, accuracy, smalldata, largedata, conformal)",
                ));
            }
};
Ok(Self {
inner: GBDTConfig::default().with_preset(preset),
})
}
fn with_preset(&self, preset: &str) -> PyResult<Self> {
let mut inner = self.inner.clone();
inner = match preset.to_lowercase().as_str() {
"standard" => inner.with_preset(GbdtPreset::Standard),
"speed" => inner.with_preset(GbdtPreset::Speed),
"accuracy" => inner.with_preset(GbdtPreset::Accuracy),
"smalldata" | "small_data" => inner.with_preset(GbdtPreset::SmallData),
"largedata" | "large_data" => inner.with_preset(GbdtPreset::LargeData),
"conformal" => inner.with_preset(GbdtPreset::Conformal),
            _ => {
                return Err(PyValueError::new_err(
                    "unknown preset (use: standard, speed, accuracy, smalldata, largedata, conformal)",
                ));
            }
};
Ok(Self { inner })
}
#[getter]
fn num_rounds(&self) -> usize {
self.inner.num_rounds
}
#[setter]
fn set_num_rounds(&mut self, value: usize) {
self.inner.num_rounds = value;
}
#[getter]
fn learning_rate(&self) -> f32 {
self.inner.learning_rate
}
#[setter]
fn set_learning_rate(&mut self, value: f32) {
self.inner.learning_rate = value;
}
#[getter]
fn max_depth(&self) -> usize {
self.inner.max_depth
}
#[setter]
fn set_max_depth(&mut self, value: usize) {
self.inner.max_depth = value;
}
#[getter]
fn max_leaves(&self) -> usize {
self.inner.max_leaves
}
#[setter]
fn set_max_leaves(&mut self, value: usize) {
self.inner.max_leaves = value;
}
#[getter]
fn min_samples_leaf(&self) -> usize {
self.inner.min_samples_leaf
}
#[setter]
fn set_min_samples_leaf(&mut self, value: usize) {
self.inner.min_samples_leaf = value;
}
#[getter]
fn min_hessian_leaf(&self) -> f32 {
self.inner.min_hessian_leaf
}
#[setter]
fn set_min_hessian_leaf(&mut self, value: f32) {
self.inner.min_hessian_leaf = value;
}
#[getter]
fn min_gain(&self) -> f32 {
self.inner.min_gain
}
#[setter]
fn set_min_gain(&mut self, value: f32) {
self.inner.min_gain = value;
}
    /// L2 regularization strength. Exposed to Python as `lambda_` because
    /// `lambda` is a reserved keyword; the setter name is pinned so the getter
    /// and setter map to the same attribute.
    #[getter]
    fn lambda_(&self) -> f32 {
        self.inner.lambda
    }
    #[setter(lambda_)]
    fn set_lambda(&mut self, value: f32) {
        self.inner.lambda = value;
    }
#[getter]
fn entropy_weight(&self) -> f32 {
self.inner.entropy_weight
}
#[setter]
fn set_entropy_weight(&mut self, value: f32) {
self.inner.entropy_weight = value;
}
fn use_mse_loss(&mut self) {
self.inner.loss_type = LossType::Mse;
}
fn use_pseudo_huber_loss(&mut self, delta: f32) {
self.inner.loss_type = LossType::PseudoHuber { delta };
}
fn use_binary_logloss(&mut self) {
self.inner.loss_type = LossType::BinaryLogLoss;
}
fn use_multiclass_logloss(&mut self, num_classes: usize) -> PyResult<()> {
if num_classes < 2 {
return Err(PyValueError::new_err("num_classes must be >= 2"));
}
self.inner.loss_type = LossType::MultiClassLogLoss { num_classes };
Ok(())
}
#[getter]
fn subsample(&self) -> f32 {
self.inner.subsample
}
#[setter]
fn set_subsample(&mut self, value: f32) {
self.inner.subsample = value;
}
#[getter]
fn colsample(&self) -> f32 {
self.inner.colsample
}
#[setter]
fn set_colsample(&mut self, value: f32) {
self.inner.colsample = value;
}
#[getter]
fn num_bins(&self) -> usize {
self.inner.num_bins
}
#[setter]
fn set_num_bins(&mut self, value: usize) {
self.inner.num_bins = value;
}
#[getter]
fn calibration_ratio(&self) -> f32 {
self.inner.calibration_ratio
}
#[setter]
fn set_calibration_ratio(&mut self, value: f32) {
self.inner.calibration_ratio = value;
}
#[getter]
fn conformal_quantile(&self) -> f32 {
self.inner.conformal_quantile
}
#[setter]
fn set_conformal_quantile(&mut self, value: f32) {
self.inner.conformal_quantile = value;
}
#[getter]
fn early_stopping_rounds(&self) -> usize {
self.inner.early_stopping_rounds
}
#[setter]
fn set_early_stopping_rounds(&mut self, value: usize) {
self.inner.early_stopping_rounds = value;
}
#[getter]
fn validation_ratio(&self) -> f32 {
self.inner.validation_ratio
}
#[setter]
fn set_validation_ratio(&mut self, value: f32) {
self.inner.validation_ratio = value;
}
#[getter]
fn parallel_prediction(&self) -> bool {
self.inner.parallel_prediction
}
#[setter]
fn set_parallel_prediction(&mut self, value: bool) {
self.inner.parallel_prediction = value;
}
#[getter]
fn column_reordering(&self) -> bool {
self.inner.column_reordering
}
#[setter]
fn set_column_reordering(&mut self, value: bool) {
self.inner.column_reordering = value;
}
#[getter]
fn packed_dataset(&self) -> bool {
self.inner.packed_dataset
}
#[setter]
fn set_packed_dataset(&mut self, value: bool) {
self.inner.packed_dataset = value;
}
#[getter]
fn parallel_gradient(&self) -> bool {
self.inner.parallel_gradient
}
#[setter]
fn set_parallel_gradient(&mut self, value: bool) {
self.inner.parallel_gradient = value;
}
#[getter]
fn use_gpu_subgroups(&self) -> bool {
self.inner.use_gpu_subgroups
}
#[setter]
fn set_use_gpu_subgroups(&mut self, value: bool) {
self.inner.use_gpu_subgroups = value;
}
#[getter]
fn backend(&self) -> &'static str {
match self.inner.backend_type {
BackendType::Auto => "auto",
BackendType::Scalar => "cpu",
BackendType::Wgpu => "wgpu",
BackendType::Avx512 => "avx512",
BackendType::Sve2 => "sve2",
BackendType::Cuda => "cuda",
BackendType::Rocm => "rocm",
BackendType::Metal => "metal",
}
}
#[setter]
fn set_backend(&mut self, value: &str) -> PyResult<()> {
self.inner.backend_type = match value.to_lowercase().as_str() {
"auto" | "gpu" => BackendType::Auto, "cpu" | "scalar" => BackendType::Scalar,
"wgpu" => BackendType::Wgpu,
"cuda" => BackendType::Cuda,
"rocm" => BackendType::Rocm,
"metal" => BackendType::Metal,
_ => return Err(PyValueError::new_err(
"backend must be one of: 'auto' (or 'gpu'), 'cpu' (or 'scalar'), 'wgpu', 'cuda', 'rocm', 'metal'"
)),
};
Ok(())
}
#[getter]
fn gpu_mode(&self) -> &'static str {
match self.inner.gpu_mode {
GpuMode::Auto => "auto",
GpuMode::Hybrid => "hybrid",
GpuMode::Full => "full",
}
}
#[setter]
fn set_gpu_mode(&mut self, value: &str) -> PyResult<()> {
self.inner.gpu_mode = match value.to_lowercase().as_str() {
"auto" => GpuMode::Auto,
"hybrid" => GpuMode::Hybrid,
"full" => GpuMode::Full,
_ => {
return Err(PyValueError::new_err(
"gpu_mode must be one of: 'auto', 'hybrid', 'full'",
))
}
};
Ok(())
}
fn set_monotonic_constraints(&mut self, constraints: Vec<i32>) -> PyResult<()> {
let parsed: Result<Vec<MonotonicConstraint>, _> = constraints
.into_iter()
.map(|c| match c {
1 => Ok(MonotonicConstraint::Increasing),
-1 => Ok(MonotonicConstraint::Decreasing),
0 => Ok(MonotonicConstraint::None),
_ => Err(PyValueError::new_err(
"Constraint must be 1 (increasing), -1 (decreasing), or 0 (none)",
)),
})
.collect();
self.inner.monotonic_constraints = parsed?;
Ok(())
}
#[getter]
fn era_splitting(&self) -> bool {
self.inner.era_splitting
}
#[setter]
fn set_era_splitting(&mut self, value: bool) {
self.inner.era_splitting = value;
}
fn set_interaction_groups(&mut self, groups: Vec<Vec<usize>>) {
self.inner.interaction_groups = groups;
}
fn __repr__(&self) -> String {
format!(
"GBDTConfig(num_rounds={}, learning_rate={}, max_depth={}, max_leaves={}, backend='{}', gpu_mode='{}')",
self.inner.num_rounds,
self.inner.learning_rate,
self.inner.max_depth,
self.inner.max_leaves,
self.backend(),
self.gpu_mode()
)
}
fn with_num_rounds(&self, value: usize) -> PyResult<Self> {
if value == 0 {
return Err(PyValueError::new_err("num_rounds must be >= 1"));
}
let mut new = self.clone();
new.inner.num_rounds = value;
Ok(new)
}
fn with_learning_rate(&self, value: f32) -> PyResult<Self> {
if value <= 0.0 || value > 1.0 {
return Err(PyValueError::new_err("learning_rate must be in (0.0, 1.0]"));
}
let mut new = self.clone();
new.inner.learning_rate = value;
Ok(new)
}
fn with_max_depth(&self, value: usize) -> PyResult<Self> {
if value == 0 {
return Err(PyValueError::new_err("max_depth must be >= 1"));
}
let mut new = self.clone();
new.inner.max_depth = value;
Ok(new)
}
fn with_max_leaves(&self, value: usize) -> PyResult<Self> {
if value < 2 {
return Err(PyValueError::new_err("max_leaves must be >= 2"));
}
let mut new = self.clone();
new.inner.max_leaves = value;
Ok(new)
}
fn with_min_samples_leaf(&self, value: usize) -> PyResult<Self> {
if value == 0 {
return Err(PyValueError::new_err("min_samples_leaf must be >= 1"));
}
let mut new = self.clone();
new.inner.min_samples_leaf = value;
Ok(new)
}
fn with_min_hessian_leaf(&self, value: f32) -> PyResult<Self> {
if value < 0.0 {
return Err(PyValueError::new_err("min_hessian_leaf must be >= 0.0"));
}
let mut new = self.clone();
new.inner.min_hessian_leaf = value;
Ok(new)
}
fn with_min_gain(&self, value: f32) -> PyResult<Self> {
if value < 0.0 {
return Err(PyValueError::new_err("min_gain must be >= 0.0"));
}
let mut new = self.clone();
new.inner.min_gain = value;
Ok(new)
}
fn with_lambda(&self, value: f32) -> PyResult<Self> {
if value < 0.0 {
return Err(PyValueError::new_err("lambda must be >= 0.0"));
}
let mut new = self.clone();
new.inner.lambda = value;
Ok(new)
}
fn with_entropy_weight(&self, value: f32) -> PyResult<Self> {
if value < 0.0 {
return Err(PyValueError::new_err("entropy_weight must be >= 0.0"));
}
let mut new = self.clone();
new.inner.entropy_weight = value;
Ok(new)
}
fn with_subsample(&self, value: f32) -> PyResult<Self> {
if value <= 0.0 || value > 1.0 {
return Err(PyValueError::new_err("subsample must be in (0.0, 1.0]"));
}
let mut new = self.clone();
new.inner.subsample = value;
Ok(new)
}
fn with_colsample(&self, value: f32) -> PyResult<Self> {
if value <= 0.0 || value > 1.0 {
return Err(PyValueError::new_err("colsample must be in (0.0, 1.0]"));
}
let mut new = self.clone();
new.inner.colsample = value;
Ok(new)
}
fn with_num_bins(&self, value: usize) -> PyResult<Self> {
if value < 2 || value > 255 {
return Err(PyValueError::new_err("num_bins must be in [2, 255]"));
}
let mut new = self.clone();
new.inner.num_bins = value;
Ok(new)
}
fn with_calibration_ratio(&self, value: f32) -> PyResult<Self> {
if value < 0.0 || value >= 1.0 {
return Err(PyValueError::new_err(
"calibration_ratio must be in [0.0, 1.0)",
));
}
let mut new = self.clone();
new.inner.calibration_ratio = value;
Ok(new)
}
fn with_conformal_quantile(&self, value: f32) -> PyResult<Self> {
if value <= 0.0 || value >= 1.0 {
return Err(PyValueError::new_err(
"conformal_quantile must be in (0.0, 1.0)",
));
}
let mut new = self.clone();
new.inner.conformal_quantile = value;
Ok(new)
}
fn with_early_stopping_rounds(&self, value: usize) -> PyResult<Self> {
let mut new = self.clone();
new.inner.early_stopping_rounds = value;
Ok(new)
}
fn with_validation_ratio(&self, value: f32) -> PyResult<Self> {
if value < 0.0 || value >= 1.0 {
return Err(PyValueError::new_err(
"validation_ratio must be in [0.0, 1.0)",
));
}
let mut new = self.clone();
new.inner.validation_ratio = value;
Ok(new)
}
fn with_parallel_prediction(&self, value: bool) -> PyResult<Self> {
let mut new = self.clone();
new.inner.parallel_prediction = value;
Ok(new)
}
fn with_column_reordering(&self, value: bool) -> PyResult<Self> {
let mut new = self.clone();
new.inner.column_reordering = value;
Ok(new)
}
fn with_packed_dataset(&self, value: bool) -> PyResult<Self> {
let mut new = self.clone();
new.inner.packed_dataset = value;
Ok(new)
}
fn with_parallel_gradient(&self, value: bool) -> PyResult<Self> {
let mut new = self.clone();
new.inner.parallel_gradient = value;
Ok(new)
}
fn with_gpu_subgroups(&self, value: bool) -> PyResult<Self> {
let mut new = self.clone();
new.inner.use_gpu_subgroups = value;
Ok(new)
}
fn with_backend(&self, value: &str) -> PyResult<Self> {
let mut new = self.clone();
new.inner.backend_type = match value.to_lowercase().as_str() {
"auto" | "gpu" => BackendType::Auto,
"cpu" | "scalar" => BackendType::Scalar,
"wgpu" => BackendType::Wgpu,
"cuda" => BackendType::Cuda,
"rocm" => BackendType::Rocm,
"metal" => BackendType::Metal,
_ => return Err(PyValueError::new_err(
"backend must be one of: 'auto' (or 'gpu'), 'cpu' (or 'scalar'), 'wgpu', 'cuda', 'rocm', 'metal'"
)),
};
Ok(new)
}
fn with_gpu_mode(&self, value: &str) -> PyResult<Self> {
let mut new = self.clone();
new.inner.gpu_mode = match value.to_lowercase().as_str() {
"auto" => GpuMode::Auto,
"hybrid" => GpuMode::Hybrid,
"full" => GpuMode::Full,
_ => {
return Err(PyValueError::new_err(
"gpu_mode must be one of: 'auto', 'hybrid', 'full'",
))
}
};
Ok(new)
}
fn with_era_splitting(&self, value: bool) -> PyResult<Self> {
let mut new = self.clone();
new.inner.era_splitting = value;
Ok(new)
}
fn with_mse_loss(&self) -> PyResult<Self> {
let mut new = self.clone();
new.inner.loss_type = LossType::Mse;
Ok(new)
}
fn with_pseudo_huber_loss(&self, delta: f32) -> PyResult<Self> {
if delta <= 0.0 {
return Err(PyValueError::new_err("delta must be > 0.0"));
}
let mut new = self.clone();
new.inner.loss_type = LossType::PseudoHuber { delta };
Ok(new)
}
fn with_binary_logloss(&self) -> PyResult<Self> {
let mut new = self.clone();
new.inner.loss_type = LossType::BinaryLogLoss;
Ok(new)
}
fn with_multiclass_logloss(&self, num_classes: usize) -> PyResult<Self> {
if num_classes < 2 {
return Err(PyValueError::new_err("num_classes must be >= 2"));
}
let mut new = self.clone();
new.inner.loss_type = LossType::MultiClassLogLoss { num_classes };
Ok(new)
}
fn with_monotonic_constraints(&self, constraints: Vec<i32>) -> PyResult<Self> {
let parsed: Result<Vec<MonotonicConstraint>, _> = constraints
.into_iter()
.map(|c| match c {
1 => Ok(MonotonicConstraint::Increasing),
-1 => Ok(MonotonicConstraint::Decreasing),
0 => Ok(MonotonicConstraint::None),
_ => Err(PyValueError::new_err(
"Constraint must be 1 (increasing), -1 (decreasing), or 0 (none)",
)),
})
.collect();
let mut new = self.clone();
new.inner.monotonic_constraints = parsed?;
Ok(new)
}
fn with_constraints(&self, constraints: Vec<PyMonotonicConstraint>) -> PyResult<Self> {
let parsed: Vec<MonotonicConstraint> = constraints.into_iter().map(|c| c.into()).collect();
let mut new = self.clone();
new.inner.monotonic_constraints = parsed;
Ok(new)
}
fn with_interaction_groups(&self, groups: Vec<Vec<usize>>) -> PyResult<Self> {
let mut new = self.clone();
new.inner.interaction_groups = groups;
Ok(new)
}
}
impl PyGBDTConfig {
pub fn inner(&self) -> &GBDTConfig {
&self.inner
}
pub fn from_inner(config: GBDTConfig) -> Self {
Self { inner: config }
}
}
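/// A trained GBDT model exposed to Python. Training and prediction release the
/// GIL and operate on NumPy arrays.
///
/// A rough train/predict/save round trip from Python, under the same assumed
/// module name `gbdt` as above:
///
/// ```python
/// import numpy as np
/// X = np.random.rand(1_000, 10).astype(np.float32)
/// y = np.random.rand(1_000).astype(np.float32)
/// model = gbdt.GBDTModel.train(X, y, gbdt.GBDTConfig())
/// preds = model.predict(X)
/// model.save("model.rkyv", format="rkyv")
/// ```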
#[pyclass(name = "GBDTModel")]
pub struct PyGBDTModel {
model: GBDTModel,
}
#[pymethods]
impl PyGBDTModel {
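    /// Train a model on row-major float32 features and float32 targets,
    /// releasing the GIL for the duration of training. If `output_dir` is
    /// given, the trained model is also written there in rkyv format.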
#[staticmethod]
#[pyo3(signature = (features, targets, config, feature_names=None, output_dir=None))]
fn train<'py>(
py: Python<'py>,
features: PyReadonlyArray2<'py, f32>,
targets: PyReadonlyArray1<'py, f32>,
config: &PyGBDTConfig,
feature_names: Option<Vec<String>>,
output_dir: Option<String>,
) -> PyResult<Self> {
let features_arr = features.as_array();
let targets_arr = targets.as_array();
let num_rows = features_arr.nrows();
let num_features = features_arr.ncols();
let mut features_flat: Vec<f32> = Vec::with_capacity(num_rows * num_features);
for row in features_arr.rows() {
features_flat.extend(row.iter().copied());
}
let targets_vec: Vec<f32> = targets_arr.to_vec();
let model = py
.allow_threads(|| {
GBDTModel::train(
&features_flat,
num_features,
&targets_vec,
config.inner.clone(),
feature_names,
)
})
.map_err(|e| PyValueError::new_err(e.to_string()))?;
if let Some(ref dir) = output_dir {
model
.save_to_directory(dir, &config.inner, &[ModelFormat::Rkyv])
.map_err(|e| {
PyIOError::new_err(format!("Failed to save to output directory: {}", e))
})?;
}
Ok(Self { model })
}
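    /// Like `train`, but takes one era index per row for era-aware splitting.
    /// `config.era_splitting` must be True and `era_indices` must have one
    /// entry per feature row.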
#[staticmethod]
#[pyo3(signature = (features, targets, era_indices, config, feature_names=None))]
fn train_with_eras<'py>(
py: Python<'py>,
features: PyReadonlyArray2<'py, f32>,
targets: PyReadonlyArray1<'py, f32>,
era_indices: PyReadonlyArray1<'py, u16>,
config: &PyGBDTConfig,
feature_names: Option<Vec<String>>,
) -> PyResult<Self> {
let features_arr = features.as_array();
let targets_arr = targets.as_array();
let era_indices_arr = era_indices.as_array();
let num_rows = features_arr.nrows();
let num_features = features_arr.ncols();
if era_indices_arr.len() != num_rows {
return Err(PyValueError::new_err(format!(
"era_indices length {} doesn't match number of rows {}",
era_indices_arr.len(),
num_rows
)));
}
if !config.inner.era_splitting {
return Err(PyValueError::new_err(
"era_splitting must be True in config when using train_with_eras",
));
}
let mut features_flat: Vec<f32> = Vec::with_capacity(num_rows * num_features);
for row in features_arr.rows() {
features_flat.extend(row.iter().copied());
}
let targets_vec: Vec<f32> = targets_arr.to_vec();
let era_indices_vec: Vec<u16> = era_indices_arr.to_vec();
let model = py
.allow_threads(|| {
GBDTModel::train_with_eras(
&features_flat,
num_features,
&targets_vec,
&era_indices_vec,
config.inner.clone(),
feature_names,
)
})
.map_err(|e| PyValueError::new_err(e.to_string()))?;
Ok(Self { model })
}
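    /// Predict raw scores for a 2D float32 or float64 NumPy feature array.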
#[pyo3(signature = (features))]
fn predict<'py>(
&self,
py: Python<'py>,
features: &Bound<'py, PyAny>,
) -> PyResult<Bound<'py, PyArray1<f32>>> {
let (num_features, raw_features) = extract_features_array(features)?;
validate_feature_count(num_features, self.model.num_features())?;
let predictions = py.allow_threads(|| self.model.predict_raw(&raw_features));
Ok(PyArray1::from_vec(py, predictions))
}
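    /// Predict point estimates together with lower and upper interval bounds
    /// (from the conformal calibration configured at training time).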
#[pyo3(signature = (features))]
fn predict_with_intervals<'py>(
&self,
py: Python<'py>,
features: &Bound<'py, PyAny>,
) -> PyResult<(
Bound<'py, PyArray1<f32>>,
Bound<'py, PyArray1<f32>>,
Bound<'py, PyArray1<f32>>,
)> {
let (num_features, raw_features) = extract_features_array(features)?;
validate_feature_count(num_features, self.model.num_features())?;
let (preds, lower, upper) =
py.allow_threads(|| self.model.predict_raw_with_intervals(&raw_features));
Ok((
PyArray1::from_vec(py, preds),
PyArray1::from_vec(py, lower),
PyArray1::from_vec(py, upper),
))
}
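    /// Predict positive-class probabilities for a binary classification model;
    /// raises ValueError on multi-class models.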
#[pyo3(signature = (features))]
fn predict_proba<'py>(
&self,
py: Python<'py>,
features: &Bound<'py, PyAny>,
) -> PyResult<Bound<'py, PyArray1<f32>>> {
let (num_features, raw_features) = extract_features_array(features)?;
validate_feature_count(num_features, self.model.num_features())?;
if self.model.is_multiclass() {
return Err(PyValueError::new_err(
"predict_proba() is for binary classification only. \
For multi-class models, use predict_proba_multiclass() instead.",
));
}
let proba = py.allow_threads(|| self.model.predict_proba_raw(&raw_features));
Ok(PyArray1::from_vec(py, proba))
}
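    /// Predict binary class labels by thresholding the positive-class
    /// probability (default threshold 0.5).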
#[pyo3(signature = (features, threshold = 0.5))]
fn predict_class<'py>(
&self,
py: Python<'py>,
features: &Bound<'py, PyAny>,
threshold: f32,
) -> PyResult<Bound<'py, PyArray1<u32>>> {
let (num_features, raw_features) = extract_features_array(features)?;
validate_feature_count(num_features, self.model.num_features())?;
if self.model.is_multiclass() {
return Err(PyValueError::new_err(
"predict_class() is for binary classification only. \
For multi-class models, use predict_class_multiclass() instead.",
));
}
let classes = py.allow_threads(|| self.model.predict_class_raw(&raw_features, threshold));
Ok(PyArray1::from_vec(py, classes))
}
#[getter]
fn is_multiclass(&self) -> bool {
self.model.is_multiclass()
}
#[getter]
fn num_classes(&self) -> usize {
self.model.get_num_classes()
}
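    /// Predict an (n_samples, n_classes) probability matrix for a multi-class
    /// model; raises ValueError for binary models.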
#[pyo3(signature = (features))]
fn predict_proba_multiclass<'py>(
&self,
py: Python<'py>,
features: &Bound<'py, PyAny>,
) -> PyResult<Bound<'py, numpy::PyArray2<f32>>> {
use numpy::PyArray2;
let (num_features, raw_features) = extract_features_array(features)?;
validate_feature_count(num_features, self.model.num_features())?;
if !self.model.is_multiclass() {
return Err(PyValueError::new_err(
"predict_proba_multiclass() is for multi-class models only. \
For binary classification, use predict_proba() instead.",
));
}
let proba = py.allow_threads(|| self.model.predict_proba_multiclass_raw(&raw_features));
if proba.is_empty() {
return Err(PyValueError::new_err("No predictions returned"));
}
let expected_cols = proba[0].len();
if !proba.iter().all(|row| row.len() == expected_cols) {
return Err(PyValueError::new_err(
"Internal error: jagged probability array returned",
));
}
PyArray2::from_vec2(py, &proba)
.map_err(|e| PyValueError::new_err(format!("Failed to create numpy array: {:?}", e)))
}
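    /// Predict class labels for a multi-class model; raises ValueError for
    /// binary models.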
#[pyo3(signature = (features))]
fn predict_class_multiclass<'py>(
&self,
py: Python<'py>,
features: &Bound<'py, PyAny>,
) -> PyResult<Bound<'py, PyArray1<u32>>> {
let (num_features, raw_features) = extract_features_array(features)?;
validate_feature_count(num_features, self.model.num_features())?;
if !self.model.is_multiclass() {
return Err(PyValueError::new_err(
"predict_class_multiclass() is for multi-class models only. \
For binary classification, use predict_class() instead.",
));
}
let classes = py.allow_threads(|| self.model.predict_class_multiclass_raw(&raw_features));
Ok(PyArray1::from_vec(py, classes))
}
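    /// Per-feature importance scores as a 1D float32 array.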
fn feature_importance<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f32>> {
let importances = self.model.feature_importance();
PyArray1::from_vec(py, importances)
}
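    /// Serialize the model to `path` in the given format ("rkyv" or "bincode").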
#[pyo3(signature = (path, format="rkyv"))]
fn save(&self, path: &str, format: &str) -> PyResult<()> {
let model_format = parse_model_format(format)?;
match model_format {
ModelFormat::Rkyv => serialize::save_model(&self.model, path),
ModelFormat::Bincode => serialize::save_model_bincode(&self.model, path),
}
.map_err(|e| PyIOError::new_err(e.to_string()))
}
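    /// Save the model and its config into `output_dir`. `formats` may be a
    /// single format string or a non-empty list of format strings; it defaults
    /// to rkyv only.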
#[pyo3(signature = (output_dir, config, formats=None))]
fn save_to_directory(
&self,
output_dir: &str,
config: &PyGBDTConfig,
formats: Option<&Bound<'_, PyAny>>,
) -> PyResult<()> {
let model_formats = match formats {
            None => vec![ModelFormat::Rkyv],
            Some(obj) => {
if let Ok(s) = obj.extract::<String>() {
vec![parse_model_format(&s)?]
} else if let Ok(list) = obj.extract::<Vec<String>>() {
if list.is_empty() {
return Err(PyValueError::new_err("formats list must not be empty"));
}
list.iter()
.map(|s| parse_model_format(s))
.collect::<PyResult<Vec<_>>>()?
} else {
return Err(PyValueError::new_err(
"formats must be a string ('rkyv' or 'bincode') or a list of strings",
));
}
}
};
self.model
.save_to_directory(output_dir, &config.inner, &model_formats)
.map_err(|e| PyIOError::new_err(e.to_string()))
}
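    /// Load a model previously written by `save` in the given format.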
#[staticmethod]
#[pyo3(signature = (path, format="rkyv"))]
fn load(path: &str, format: &str) -> PyResult<Self> {
let model_format = parse_model_format(format)?;
let model = match model_format {
ModelFormat::Rkyv => serialize::load_model(path),
ModelFormat::Bincode => serialize::load_model_bincode(path),
}
.map_err(|e| PyIOError::new_err(e.to_string()))?;
Ok(Self { model })
}
#[getter]
fn num_trees(&self) -> usize {
self.model.num_trees()
}
#[getter]
fn num_features(&self) -> usize {
self.model.num_features()
}
#[getter]
fn base_prediction(&self) -> f32 {
self.model.base_prediction()
}
#[getter]
fn conformal_quantile(&self) -> Option<f32> {
self.model.conformal_quantile()
}
#[getter]
fn feature_names(&self) -> Vec<String> {
self.model
.feature_info()
.iter()
.map(|info| info.name.clone())
.collect()
}
fn __repr__(&self) -> String {
format!(
"GBDTModel(num_trees={}, num_features={}, base_prediction={:.4})",
self.model.num_trees(),
self.model.num_features(),
self.model.base_prediction()
)
}
}
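/// Register the GBDT binding classes on the given Python module.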
pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyMonotonicConstraint>()?;
m.add_class::<PyGBDTConfig>()?;
m.add_class::<PyGBDTModel>()?;
Ok(())
}