use pyo3::prelude::*;
use crate::tuner::{
EvalStrategy, GridStrategy, ModelFormat, OptimizationMetric, TaskType, TuningMode,
};
/// Python-facing wrapper around [`TuningMode`], exposed to Python as `TuningMode`.
///
/// Instances are created with the `optimistic()` / `realistic()` static
/// constructors; Python `==` uses the derived `PartialEq`.
#[pyclass(name = "TuningMode", eq)]
#[derive(PartialEq, Clone)]
pub struct PyTuningMode {
    /// The wrapped Rust-side tuning mode.
    pub(crate) inner: TuningMode,
}

#[pymethods]
impl PyTuningMode {
    /// Wraps `TuningMode::Optimistic`.
    #[staticmethod]
    fn optimistic() -> Self {
        Self::from(TuningMode::Optimistic)
    }

    /// Wraps `TuningMode::Realistic`.
    #[staticmethod]
    fn realistic() -> Self {
        Self::from(TuningMode::Realistic)
    }

    /// Delegates to `TuningMode::is_optimistic` on the wrapped value.
    #[getter]
    fn is_optimistic(&self) -> bool {
        self.inner.is_optimistic()
    }

    /// Delegates to `TuningMode::is_realistic` on the wrapped value.
    #[getter]
    fn is_realistic(&self) -> bool {
        self.inner.is_realistic()
    }

    /// Returns the Python constructor expression that rebuilds this mode.
    fn __repr__(&self) -> &'static str {
        match &self.inner {
            TuningMode::Optimistic => "TuningMode.optimistic()",
            TuningMode::Realistic => "TuningMode.realistic()",
        }
    }
}

impl From<TuningMode> for PyTuningMode {
    /// Wraps an existing Rust-side mode without copying any other state.
    fn from(value: TuningMode) -> Self {
        Self { inner: value }
    }
}
/// Python-facing wrapper around [`OptimizationMetric`], exposed to Python as
/// `OptimizationMetric`.
///
/// Instances are created with `val_loss()`, `f1_score()` or `roc_auc()`;
/// Python `==` uses the derived `PartialEq`.
#[pyclass(name = "OptimizationMetric", eq)]
#[derive(PartialEq, Clone)]
pub struct PyOptimizationMetric {
    /// The wrapped Rust-side metric.
    pub(crate) inner: OptimizationMetric,
}

#[pymethods]
impl PyOptimizationMetric {
    /// Wraps `OptimizationMetric::ValidationLoss`.
    #[staticmethod]
    fn val_loss() -> Self {
        OptimizationMetric::ValidationLoss.into()
    }

    /// Wraps `OptimizationMetric::F1Score`.
    #[staticmethod]
    fn f1_score() -> Self {
        OptimizationMetric::F1Score.into()
    }

    /// Wraps `OptimizationMetric::RocAuc`.
    #[staticmethod]
    fn roc_auc() -> Self {
        OptimizationMetric::RocAuc.into()
    }

    /// Delegates to `OptimizationMetric::higher_is_better` on the wrapped value.
    #[getter]
    fn higher_is_better(&self) -> bool {
        self.inner.higher_is_better()
    }

    /// Delegates to `OptimizationMetric::name` on the wrapped value.
    #[getter]
    fn name(&self) -> &'static str {
        self.inner.name()
    }

    /// Returns the Python constructor expression that rebuilds this metric.
    fn __repr__(&self) -> &'static str {
        match &self.inner {
            OptimizationMetric::ValidationLoss => "OptimizationMetric.val_loss()",
            OptimizationMetric::F1Score => "OptimizationMetric.f1_score()",
            OptimizationMetric::RocAuc => "OptimizationMetric.roc_auc()",
        }
    }
}

impl From<OptimizationMetric> for PyOptimizationMetric {
    /// Wraps an existing Rust-side metric.
    fn from(value: OptimizationMetric) -> Self {
        Self { inner: value }
    }
}
/// Python-facing wrapper around [`TaskType`], exposed to Python as `TaskType`.
///
/// Instances are created with `regression()`, `binary_classification()` or
/// `multi_class_classification()`; Python `==` uses the derived `PartialEq`.
#[pyclass(name = "TaskType", eq)]
#[derive(PartialEq, Clone)]
pub struct PyTaskType {
    /// The wrapped Rust-side task type.
    pub(crate) inner: TaskType,
}

#[pymethods]
impl PyTaskType {
    /// Wraps `TaskType::Regression`.
    #[staticmethod]
    fn regression() -> Self {
        Self::from(TaskType::Regression)
    }

    /// Wraps `TaskType::BinaryClassification`.
    #[staticmethod]
    fn binary_classification() -> Self {
        Self::from(TaskType::BinaryClassification)
    }

    /// Wraps `TaskType::MultiClassClassification`.
    #[staticmethod]
    fn multi_class_classification() -> Self {
        Self::from(TaskType::MultiClassClassification)
    }

    /// Delegates to `TaskType::is_classification` on the wrapped value.
    #[getter]
    fn is_classification(&self) -> bool {
        self.inner.is_classification()
    }

    /// Delegates to `TaskType::is_binary` on the wrapped value.
    #[getter]
    fn is_binary(&self) -> bool {
        self.inner.is_binary()
    }

    /// Delegates to `TaskType::is_regression` on the wrapped value.
    #[getter]
    fn is_regression(&self) -> bool {
        self.inner.is_regression()
    }

    /// Returns the Python constructor expression that rebuilds this task type.
    fn __repr__(&self) -> &'static str {
        match &self.inner {
            TaskType::Regression => "TaskType.regression()",
            TaskType::BinaryClassification => "TaskType.binary_classification()",
            TaskType::MultiClassClassification => "TaskType.multi_class_classification()",
        }
    }
}

impl From<TaskType> for PyTaskType {
    /// Wraps an existing Rust-side task type.
    fn from(value: TaskType) -> Self {
        Self { inner: value }
    }
}
/// Python-facing wrapper around [`ModelFormat`], exposed to Python as `ModelFormat`.
///
/// Instances are created with `rkyv()` or `bincode()`; Python `==` uses the
/// derived `PartialEq`.
#[pyclass(name = "ModelFormat", eq)]
#[derive(PartialEq, Clone)]
pub struct PyModelFormat {
    /// The wrapped Rust-side serialization format.
    pub(crate) inner: ModelFormat,
}

#[pymethods]
impl PyModelFormat {
    /// Wraps `ModelFormat::Rkyv`.
    #[staticmethod]
    fn rkyv() -> Self {
        ModelFormat::Rkyv.into()
    }

    /// Wraps `ModelFormat::Bincode`.
    #[staticmethod]
    fn bincode() -> Self {
        ModelFormat::Bincode.into()
    }

    /// Delegates to `ModelFormat::extension` on the wrapped value.
    #[getter]
    fn extension(&self) -> &'static str {
        self.inner.extension()
    }

    /// Delegates to `ModelFormat::filename` on the wrapped value.
    #[getter]
    fn filename(&self) -> &'static str {
        self.inner.filename()
    }

    /// Returns the Python constructor expression that rebuilds this format.
    fn __repr__(&self) -> &'static str {
        match &self.inner {
            ModelFormat::Rkyv => "ModelFormat.rkyv()",
            ModelFormat::Bincode => "ModelFormat.bincode()",
        }
    }
}

impl From<ModelFormat> for PyModelFormat {
    /// Wraps an existing Rust-side format.
    fn from(value: ModelFormat) -> Self {
        Self { inner: value }
    }
}
/// Python-facing wrapper around [`GridStrategy`], exposed to Python as `GridStrategy`.
///
/// Instances are created with `cartesian`, `lhs` or `random`. Each constructor
/// validates its size argument and raises `ValueError` when it is out of range.
/// Python `==` uses the derived `PartialEq`.
#[pyclass(name = "GridStrategy", eq)]
#[derive(PartialEq, Clone)]
pub struct PyGridStrategy {
    /// The wrapped Rust-side grid strategy.
    pub(crate) inner: GridStrategy,
}

#[pymethods]
impl PyGridStrategy {
    /// Builds `GridStrategy::cartesian(points_per_dim)`; `points_per_dim` defaults to 3.
    ///
    /// Raises `ValueError` when `points_per_dim` is smaller than 2.
    #[staticmethod]
    #[pyo3(signature = (points_per_dim=3))]
    fn cartesian(points_per_dim: usize) -> PyResult<Self> {
        if points_per_dim >= 2 {
            Ok(GridStrategy::cartesian(points_per_dim).into())
        } else {
            Err(pyo3::exceptions::PyValueError::new_err(
                "points_per_dim must be >= 2",
            ))
        }
    }

    /// Builds `GridStrategy::lhs(n_samples)`.
    ///
    /// Raises `ValueError` when `n_samples` is zero.
    #[staticmethod]
    fn lhs(n_samples: usize) -> PyResult<Self> {
        match n_samples {
            0 => Err(pyo3::exceptions::PyValueError::new_err(
                "n_samples must be >= 1",
            )),
            count => Ok(GridStrategy::lhs(count).into()),
        }
    }

    /// Builds `GridStrategy::random(n_samples)`.
    ///
    /// Raises `ValueError` when `n_samples` is zero.
    #[staticmethod]
    fn random(n_samples: usize) -> PyResult<Self> {
        match n_samples {
            0 => Err(pyo3::exceptions::PyValueError::new_err(
                "n_samples must be >= 1",
            )),
            count => Ok(GridStrategy::random(count).into()),
        }
    }

    /// Delegates to `GridStrategy::num_candidates` for `num_params` parameters.
    fn num_candidates(&self, num_params: usize) -> usize {
        self.inner.num_candidates(num_params)
    }

    /// Returns the Python constructor expression that rebuilds this strategy,
    /// e.g. `GridStrategy.cartesian(3)`.
    fn __repr__(&self) -> String {
        // Each variant maps to the Python constructor name plus its single argument.
        let (ctor, arg) = match self.inner {
            GridStrategy::Cartesian { points_per_dim } => ("cartesian", points_per_dim.to_string()),
            GridStrategy::LatinHypercube { n_samples } => ("lhs", n_samples.to_string()),
            GridStrategy::Random { n_samples } => ("random", n_samples.to_string()),
        };
        format!("GridStrategy.{}({})", ctor, arg)
    }
}

impl From<GridStrategy> for PyGridStrategy {
    /// Wraps an existing Rust-side strategy.
    fn from(value: GridStrategy) -> Self {
        Self { inner: value }
    }
}
/// Python-facing wrapper around [`EvalStrategy`], exposed to Python as `EvalStrategy`.
///
/// Instances are created with the `holdout`, `conformal` or `auto` static
/// constructors and can optionally be refined with `with_folds`. Python `==`
/// uses the derived `PartialEq`.
#[pyclass(name = "EvalStrategy", eq)]
#[derive(Clone, PartialEq)]
pub struct PyEvalStrategy {
    /// The wrapped Rust-side evaluation strategy.
    pub(crate) inner: EvalStrategy,
}

/// Returns `Ok(())` when `value` lies strictly inside `(0, 1)`. Otherwise it
/// returns a `ValueError` with the message "`<name>` must be in (0, 1)".
///
/// The test is written in positive form (`value > 0.0 && value < 1.0`) so that
/// NaN is rejected: every comparison against NaN is false, so NaN would slip
/// through the negated form `value <= 0.0 || value >= 1.0`. Infinities are
/// rejected as well.
fn ensure_open_unit_interval(name: &str, value: f32) -> PyResult<()> {
    if value > 0.0 && value < 1.0 {
        Ok(())
    } else {
        Err(pyo3::exceptions::PyValueError::new_err(format!(
            "{} must be in (0, 1)",
            name
        )))
    }
}

#[pymethods]
impl PyEvalStrategy {
    /// Builds `EvalStrategy::holdout(validation_ratio)`; `validation_ratio`
    /// defaults to 0.2.
    ///
    /// Raises `ValueError` unless `validation_ratio` is strictly between 0 and 1.
    /// NaN is rejected.
    #[staticmethod]
    #[pyo3(signature = (validation_ratio=0.2))]
    fn holdout(validation_ratio: f32) -> PyResult<Self> {
        ensure_open_unit_interval("validation_ratio", validation_ratio)?;
        Ok(Self {
            inner: EvalStrategy::holdout(validation_ratio),
        })
    }

    /// Builds `EvalStrategy::conformal(calibration_ratio, coverage)`. The
    /// defaults are 0.2 and 0.9.
    ///
    /// Raises `ValueError` unless both arguments are strictly between 0 and 1.
    /// NaN is rejected. `calibration_ratio` is checked first.
    #[staticmethod]
    #[pyo3(signature = (calibration_ratio=0.2, coverage=0.9))]
    fn conformal(calibration_ratio: f32, coverage: f32) -> PyResult<Self> {
        ensure_open_unit_interval("calibration_ratio", calibration_ratio)?;
        ensure_open_unit_interval("coverage", coverage)?;
        Ok(Self {
            inner: EvalStrategy::conformal(calibration_ratio, coverage),
        })
    }

    /// Builds `EvalStrategy::auto(num_samples)`. No validation is done here;
    /// `num_samples` is passed through as-is.
    #[staticmethod]
    fn auto(num_samples: usize) -> Self {
        Self {
            inner: EvalStrategy::auto(num_samples),
        }
    }

    /// Returns a new wrapper around `self.inner.with_folds(folds)`.
    ///
    /// Raises `ValueError` when `folds` is zero.
    fn with_folds(&self, folds: usize) -> PyResult<Self> {
        if folds == 0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "folds must be >= 1",
            ));
        }
        Ok(Self {
            inner: self.inner.with_folds(folds),
        })
    }

    /// Delegates to `EvalStrategy::folds` on the wrapped value.
    #[getter]
    fn folds(&self) -> usize {
        self.inner.folds()
    }

    /// Returns a Python expression describing this strategy.
    ///
    /// A `.with_folds(n)` suffix is appended only when the fold count is not 1.
    fn __repr__(&self) -> String {
        match self.inner {
            EvalStrategy::Holdout {
                validation_ratio,
                folds,
            } => {
                if folds == 1 {
                    format!("EvalStrategy.holdout({})", validation_ratio)
                } else {
                    format!(
                        "EvalStrategy.holdout({}).with_folds({})",
                        validation_ratio, folds
                    )
                }
            }
            // NOTE(review): the Python argument is named `coverage`, but this prints
            // the inner `quantile` field. The repr only round-trips if
            // `EvalStrategy::conformal` stores `coverage` unchanged — confirm.
            EvalStrategy::Conformal {
                calibration_ratio,
                quantile,
                folds,
            } => {
                if folds == 1 {
                    format!(
                        "EvalStrategy.conformal({}, {})",
                        calibration_ratio, quantile
                    )
                } else {
                    format!(
                        "EvalStrategy.conformal({}, {}).with_folds({})",
                        calibration_ratio, quantile, folds
                    )
                }
            }
        }
    }
}

impl From<EvalStrategy> for PyEvalStrategy {
    /// Wraps an existing Rust-side strategy.
    fn from(strategy: EvalStrategy) -> Self {
        Self { inner: strategy }
    }
}
/// Adds every tuner wrapper class to the Python module `m`.
///
/// Classes are added in a fixed order. The first error is returned, and the
/// classes after it are not added.
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyTuningMode>()
        .and_then(|()| m.add_class::<PyOptimizationMetric>())
        .and_then(|()| m.add_class::<PyTaskType>())
        .and_then(|()| m.add_class::<PyModelFormat>())
        .and_then(|()| m.add_class::<PyGridStrategy>())
        .and_then(|()| m.add_class::<PyEvalStrategy>())
}