use std::collections::HashMap;
use pyo3::prelude::*;
use crate::tuner::{
ModelFormat, ParamBounds, ParameterSpace, SpacePreset, TunerConfig, TunerPreset,
};
use super::enums::{
PyEvalStrategy, PyGridStrategy, PyModelFormat, PyOptimizationMetric, PyTaskType, PyTuningMode,
};
/// Python-facing wrapper around [`ParamBounds`], the admissible range of a single
/// tunable hyperparameter.
///
/// Build instances through the static constructors (`continuous`, `log_continuous`,
/// `discrete`, `discrete_step`). Each one validates its arguments and raises
/// `ValueError` on an invalid range.
#[pyclass(name = "ParamBounds")]
#[derive(Clone)]
pub struct PyParamBounds {
    pub(crate) inner: ParamBounds,
}

#[pymethods]
impl PyParamBounds {
    /// Float range from `min` to `max` without log scaling.
    ///
    /// Raises `ValueError` if either bound is NaN or infinite, or if `min >= max`.
    #[staticmethod]
    fn continuous(min: f32, max: f32) -> PyResult<Self> {
        ensure_finite_bounds(min, max)?;
        ensure_min_below_max(min, max)?;
        Ok(Self {
            inner: ParamBounds::continuous(min, max),
        })
    }

    /// Float range from `min` to `max` with log scaling.
    ///
    /// Raises `ValueError` if `min <= 0`, if either bound is NaN or infinite, or if
    /// `min >= max`.
    #[staticmethod]
    fn log_continuous(min: f32, max: f32) -> PyResult<Self> {
        // Checked first so a non-positive `min` gets the log-specific message.
        if min <= 0.0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "min must be positive for log scaling",
            ));
        }
        ensure_finite_bounds(min, max)?;
        ensure_min_below_max(min, max)?;
        Ok(Self {
            inner: ParamBounds::log_continuous(min, max),
        })
    }

    /// Integer range from `min` to `max`. `__repr__` treats this form as step 1.
    ///
    /// Raises `ValueError` if `min >= max`.
    #[staticmethod]
    fn discrete(min: usize, max: usize) -> PyResult<Self> {
        ensure_min_below_max(min, max)?;
        Ok(Self {
            inner: ParamBounds::discrete(min, max),
        })
    }

    /// Integer range from `min` to `max` that moves in increments of `step`.
    ///
    /// Raises `ValueError` if `min >= max` or `step == 0`.
    #[staticmethod]
    fn discrete_step(min: usize, max: usize, step: usize) -> PyResult<Self> {
        ensure_min_below_max(min, max)?;
        if step == 0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "step must be positive",
            ));
        }
        Ok(Self {
            inner: ParamBounds::discrete_step(min, max, step),
        })
    }

    /// Returns `value` clamped by the underlying `ParamBounds::clamp`.
    fn clamp(&self, value: f32) -> f32 {
        self.inner.clamp(value)
    }

    /// Returns whether `value` lies within these bounds, per `ParamBounds::contains`.
    fn contains(&self, value: f32) -> bool {
        self.inner.contains(value)
    }

    /// Lower bound as a float.
    #[getter]
    fn min_value(&self) -> f32 {
        self.inner.min_value()
    }

    /// Upper bound as a float.
    #[getter]
    fn max_value(&self) -> f32 {
        self.inner.max_value()
    }

    /// Whether the range was built with log scaling.
    #[getter]
    fn is_log_scale(&self) -> bool {
        self.inner.is_log_scale()
    }

    /// Renders the constructor call that would rebuild these bounds.
    fn __repr__(&self) -> String {
        match &self.inner {
            ParamBounds::Continuous {
                min,
                max,
                log_scale,
            } => {
                if *log_scale {
                    format!("ParamBounds.log_continuous({}, {})", min, max)
                } else {
                    format!("ParamBounds.continuous({}, {})", min, max)
                }
            }
            ParamBounds::Discrete { min, max, step } => {
                if *step == 1 {
                    format!("ParamBounds.discrete({}, {})", min, max)
                } else {
                    format!("ParamBounds.discrete_step({}, {}, {})", min, max, step)
                }
            }
        }
    }
}

impl From<ParamBounds> for PyParamBounds {
    fn from(bounds: ParamBounds) -> Self {
        Self { inner: bounds }
    }
}

/// Raises `ValueError` unless both float bounds are finite (neither NaN nor ±inf).
fn ensure_finite_bounds(min: f32, max: f32) -> PyResult<()> {
    if min.is_finite() && max.is_finite() {
        Ok(())
    } else {
        Err(pyo3::exceptions::PyValueError::new_err(
            "min and max must be finite",
        ))
    }
}

/// Raises `ValueError("min must be less than max")` unless `min < max`.
///
/// Written as a positive `<` test so an unordered float pair (NaN) is rejected
/// rather than slipping past a `min >= max` check.
fn ensure_min_below_max<T: PartialOrd>(min: T, max: T) -> PyResult<()> {
    if min < max {
        Ok(())
    } else {
        Err(pyo3::exceptions::PyValueError::new_err(
            "min must be less than max",
        ))
    }
}
#[pyclass(name = "ParameterSpace")]
#[derive(Clone)]
pub struct PyParameterSpace {
pub(crate) inner: ParameterSpace,
}
#[pymethods]
impl PyParameterSpace {
#[new]
fn new() -> Self {
Self {
inner: ParameterSpace::new(),
}
}
#[staticmethod]
fn preset(preset: &str) -> PyResult<Self> {
let preset = match preset.to_lowercase().as_str() {
"minimal" => SpacePreset::Minimal,
"regression" => SpacePreset::Regression,
"classification" => SpacePreset::Classification,
"exhaustive" => SpacePreset::Exhaustive,
"universal" => SpacePreset::Universal,
_ => {
return Err(pyo3::exceptions::PyValueError::new_err(
"unknown preset (use: minimal, regression, classification, exhaustive, universal)",
));
}
};
Ok(Self {
inner: ParameterSpace::with_preset(preset),
})
}
fn with_param(&self, name: &str, bounds: &PyParamBounds, center: f32) -> Self {
Self {
inner: self
.inner
.clone()
.with_param(name, bounds.inner.clone(), center),
}
}
fn add_continuous(&self, name: &str, min: f32, max: f32, center: f32) -> PyResult<Self> {
if min >= max {
return Err(pyo3::exceptions::PyValueError::new_err(
"min must be less than max",
));
}
Ok(Self {
inner: self
.inner
.clone()
.with_param(name, ParamBounds::continuous(min, max), center),
})
}
fn add_log_continuous(&self, name: &str, min: f32, max: f32, center: f32) -> PyResult<Self> {
if min <= 0.0 {
return Err(pyo3::exceptions::PyValueError::new_err(
"min must be positive for log scaling",
));
}
if min >= max {
return Err(pyo3::exceptions::PyValueError::new_err(
"min must be less than max",
));
}
Ok(Self {
inner: self.inner.clone().with_param(
name,
ParamBounds::log_continuous(min, max),
center,
),
})
}
fn add_discrete(&self, name: &str, min: usize, max: usize, center: f32) -> PyResult<Self> {
if min >= max {
return Err(pyo3::exceptions::PyValueError::new_err(
"min must be less than max",
));
}
Ok(Self {
inner: self
.inner
.clone()
.with_param(name, ParamBounds::discrete(min, max), center),
})
}
fn add_integer_range(
&self,
name: &str,
min: usize,
max: usize,
step: usize,
center: f32,
) -> PyResult<Self> {
if min >= max {
return Err(pyo3::exceptions::PyValueError::new_err(
"min must be less than max",
));
}
if step == 0 {
return Err(pyo3::exceptions::PyValueError::new_err(
"step must be positive",
));
}
Ok(Self {
inner: self.inner.clone().with_param(
name,
ParamBounds::discrete_step(min, max, step),
center,
),
})
}
fn without_param(&self, name: &str) -> Self {
Self {
inner: self.inner.clone().without_param(name),
}
}
fn __len__(&self) -> usize {
self.inner.len()
}
#[getter]
fn is_empty(&self) -> bool {
self.inner.is_empty()
}
fn param_names(&self) -> Vec<String> {
self.inner.param_names()
}
fn centers(&self) -> HashMap<String, f32> {
self.inner.centers()
}
fn validate(&self) -> PyResult<()> {
self.inner
.validate()
.map_err(|e| pyo3::exceptions::PyValueError::new_err(e))
}
fn __repr__(&self) -> String {
let names = self.inner.param_names();
format!("ParameterSpace([{}])", names.join(", "))
}
}
impl From<ParameterSpace> for PyParameterSpace {
fn from(space: ParameterSpace) -> Self {
Self { inner: space }
}
}
/// Python-facing wrapper around [`TunerConfig`], the configuration of a tuning run.
///
/// Builder methods (`with_*`, `without_*`, `optimistic`, `realistic`) take `&self`.
/// Each one clones the inner config, applies a single change and returns a new
/// `TunerConfig`, leaving the receiver untouched.
#[pyclass(name = "TunerConfig")]
#[derive(Clone)]
pub struct PyTunerConfig {
    pub(crate) inner: TunerConfig,
}

#[pymethods]
impl PyTunerConfig {
    /// Creates a config from `TunerConfig::default()`.
    #[new]
    fn new() -> Self {
        Self {
            inner: TunerConfig::default(),
        }
    }

    /// Creates a default config with the named preset applied.
    ///
    /// Accepted names (case-insensitive): `smoketest` / `smoke_test`, `quick`,
    /// `balanced`, `thorough`. Raises `ValueError` otherwise.
    #[staticmethod]
    fn preset(preset: &str) -> PyResult<Self> {
        // Equivalent to `TunerConfig::default().with_preset(..)`. This reuses the
        // single preset-name parser in `with_preset`.
        Self::new().with_preset(preset)
    }

    /// Returns a copy with the named preset applied (same names as `preset`).
    fn with_preset(&self, preset: &str) -> PyResult<Self> {
        let preset = match preset.to_lowercase().as_str() {
            "smoketest" | "smoke_test" => TunerPreset::SmokeTest,
            "quick" => TunerPreset::Quick,
            "balanced" => TunerPreset::Balanced,
            "thorough" => TunerPreset::Thorough,
            _ => {
                return Err(pyo3::exceptions::PyValueError::new_err(
                    "unknown preset (use: smoketest, quick, balanced, thorough)",
                ));
            }
        };
        Ok(Self {
            inner: self.inner.clone().with_preset(preset),
        })
    }

    /// Returns a copy that searches over `space`.
    fn with_space(&self, space: &PyParameterSpace) -> Self {
        Self {
            inner: self.inner.clone().with_space(space.inner.clone()),
        }
    }

    /// Returns a copy with `n` tuning iterations. Raises `ValueError` if `n == 0`.
    fn with_iterations(&self, n: usize) -> PyResult<Self> {
        if n == 0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "iterations must be > 0",
            ));
        }
        Ok(Self {
            inner: self.inner.clone().with_iterations(n),
        })
    }

    /// Returns a copy with the given initial spread. Raises `ValueError` unless
    /// `spread` is in (0, 1].
    fn with_initial_spread(&self, spread: f32) -> PyResult<Self> {
        // Phrased as a positive range test so NaN, which fails every comparison,
        // is rejected.
        let in_range = spread > 0.0 && spread <= 1.0;
        if !in_range {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "initial_spread must be in (0, 1]",
            ));
        }
        Ok(Self {
            inner: self.inner.clone().with_initial_spread(spread),
        })
    }

    /// Returns a copy with the given zoom factor. Raises `ValueError` unless
    /// `factor` is in (0, 1).
    fn with_zoom_factor(&self, factor: f32) -> PyResult<Self> {
        // Positive range test: rejects NaN as well as out-of-range values.
        let in_range = factor > 0.0 && factor < 1.0;
        if !in_range {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "zoom_factor must be in (0, 1)",
            ));
        }
        Ok(Self {
            inner: self.inner.clone().with_zoom_factor(factor),
        })
    }

    /// Returns a copy using `strategy` as its grid strategy.
    fn with_grid_strategy(&self, strategy: &PyGridStrategy) -> Self {
        Self {
            inner: self.inner.clone().with_grid_strategy(strategy.inner),
        }
    }

    /// Returns a copy using `strategy` as its evaluation strategy.
    fn with_eval_strategy(&self, strategy: &PyEvalStrategy) -> Self {
        Self {
            inner: self.inner.clone().with_eval_strategy(strategy.inner),
        }
    }

    /// Returns a copy with parallel trials enabled or disabled.
    fn with_parallel(&self, enabled: bool) -> Self {
        Self {
            inner: self.inner.clone().with_parallel(enabled),
        }
    }

    /// Returns a copy with the parallelism level set to `n`.
    ///
    /// `n` is passed through unchecked. NOTE(review): confirm what `0` means
    /// downstream.
    fn with_n_parallel(&self, n: usize) -> Self {
        Self {
            inner: self.inner.clone().with_n_parallel(n),
        }
    }

    /// Returns a copy with `rounds` training rounds. Raises `ValueError` if `rounds == 0`.
    fn with_num_rounds(&self, rounds: usize) -> PyResult<Self> {
        if rounds == 0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "num_rounds must be > 0",
            ));
        }
        Ok(Self {
            inner: self.inner.clone().with_num_rounds(rounds),
        })
    }

    /// Returns a copy with early stopping after `rounds` rounds, evaluated on a
    /// `validation_ratio` hold-out. Raises `ValueError` unless the ratio is in (0, 1).
    fn with_early_stopping(&self, rounds: usize, validation_ratio: f32) -> PyResult<Self> {
        // Positive range test: rejects NaN as well as out-of-range values.
        let in_range = validation_ratio > 0.0 && validation_ratio < 1.0;
        if !in_range {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "validation_ratio must be in (0, 1)",
            ));
        }
        Ok(Self {
            inner: self
                .inner
                .clone()
                .with_early_stopping(rounds, validation_ratio),
        })
    }

    /// Returns a copy with early stopping turned off.
    fn without_early_stopping(&self) -> Self {
        Self {
            inner: self.inner.clone().without_early_stopping(),
        }
    }

    /// Returns a copy with the given improvement threshold. Raises `ValueError`
    /// if `threshold` is negative or NaN.
    fn with_improvement_threshold(&self, threshold: f32) -> PyResult<Self> {
        if threshold.is_nan() || threshold < 0.0 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "threshold must be non-negative",
            ));
        }
        Ok(Self {
            inner: self.inner.clone().with_improvement_threshold(threshold),
        })
    }

    /// Returns a copy with the given minimum F1 score. Raises `ValueError`
    /// unless `min_f1` is in [0, 1].
    fn with_min_f1_score(&self, min_f1: f32) -> PyResult<Self> {
        // `RangeInclusive::contains` is false for NaN, so NaN is rejected too.
        if !(0.0..=1.0).contains(&min_f1) {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "min_f1 must be in [0, 1]",
            ));
        }
        Ok(Self {
            inner: self.inner.clone().with_min_f1_score(min_f1),
        })
    }

    /// Returns a copy using `mode` as its tuning mode.
    fn with_tuning_mode(&self, mode: &PyTuningMode) -> Self {
        Self {
            inner: self.inner.clone().with_tuning_mode(mode.inner),
        }
    }

    /// Returns a copy after applying `TunerConfig::optimistic`.
    fn optimistic(&self) -> Self {
        Self {
            inner: self.inner.clone().optimistic(),
        }
    }

    /// Returns a copy after applying `TunerConfig::realistic`.
    fn realistic(&self) -> Self {
        Self {
            inner: self.inner.clone().realistic(),
        }
    }

    /// Returns a copy with the given random seed.
    fn with_seed(&self, seed: u64) -> Self {
        Self {
            inner: self.inner.clone().with_seed(seed),
        }
    }

    /// Returns a copy with verbose output enabled or disabled.
    fn with_verbose(&self, verbose: bool) -> Self {
        Self {
            inner: self.inner.clone().with_verbose(verbose),
        }
    }

    /// Returns a copy that optimizes `metric`.
    fn with_optimization_metric(&self, metric: &PyOptimizationMetric) -> Self {
        Self {
            inner: self.inner.clone().with_optimization_metric(metric.inner),
        }
    }

    /// Returns a copy configured for `task_type`.
    fn with_task_type(&self, task_type: &PyTaskType) -> Self {
        Self {
            inner: self.inner.clone().with_task_type(task_type.inner),
        }
    }

    /// Returns a copy with the output directory set to `path`.
    fn with_output_dir(&self, path: &str) -> Self {
        Self {
            inner: self.inner.clone().with_output_dir(path),
        }
    }

    /// Returns a copy that saves models in each of `formats`.
    fn with_save_model_formats(&self, formats: Vec<PyModelFormat>) -> Self {
        let rust_formats: Vec<ModelFormat> = formats.into_iter().map(|f| f.inner).collect();
        Self {
            inner: self.inner.clone().with_save_model_formats(rust_formats),
        }
    }

    /// Number of tuning iterations.
    #[getter]
    fn n_iterations(&self) -> usize {
        self.inner.n_iterations
    }

    /// Initial spread value.
    #[getter]
    fn initial_spread(&self) -> f32 {
        self.inner.initial_spread
    }

    /// Zoom factor value.
    #[getter]
    fn zoom_factor(&self) -> f32 {
        self.inner.zoom_factor
    }

    /// Number of training rounds.
    #[getter]
    fn num_rounds(&self) -> usize {
        self.inner.num_rounds
    }

    /// Early-stopping round count.
    #[getter]
    fn early_stopping_rounds(&self) -> usize {
        self.inner.early_stopping_rounds
    }

    /// Validation hold-out ratio.
    #[getter]
    fn validation_ratio(&self) -> f32 {
        self.inner.validation_ratio
    }

    /// Improvement threshold.
    #[getter]
    fn improvement_threshold(&self) -> f32 {
        self.inner.improvement_threshold
    }

    /// Minimum F1 score.
    #[getter]
    fn min_f1_score(&self) -> f32 {
        self.inner.min_f1_score
    }

    /// Whether parallel trials are enabled.
    #[getter]
    fn parallel_trials(&self) -> bool {
        self.inner.parallel_trials
    }

    /// Random seed.
    #[getter]
    fn seed(&self) -> u64 {
        self.inner.seed
    }

    /// Whether verbose output is enabled.
    #[getter]
    fn verbose(&self) -> bool {
        self.inner.verbose
    }

    /// A copy of the configured parameter space.
    #[getter]
    fn space(&self) -> PyParameterSpace {
        self.inner.space.clone().into()
    }

    /// Configured tuning mode.
    #[getter]
    fn tuning_mode(&self) -> PyTuningMode {
        self.inner.tuning_mode.into()
    }

    /// Configured optimization metric.
    #[getter]
    fn optimization_metric(&self) -> PyOptimizationMetric {
        self.inner.optimization_metric.into()
    }

    /// Configured task type.
    #[getter]
    fn task_type(&self) -> PyTaskType {
        self.inner.task_type.into()
    }

    /// Configured grid strategy.
    #[getter]
    fn grid_strategy(&self) -> PyGridStrategy {
        self.inner.grid_strategy.into()
    }

    /// Configured evaluation strategy.
    #[getter]
    fn eval_strategy(&self) -> PyEvalStrategy {
        self.inner.eval_strategy.into()
    }

    /// Output directory as a string (lossily converted), or `None` if unset.
    #[getter]
    fn output_dir(&self) -> Option<String> {
        self.inner
            .output_dir
            .as_ref()
            .map(|p| p.to_string_lossy().into_owned())
    }

    /// Runs `TunerConfig::validate` and raises `ValueError` with its message on failure.
    fn validate(&self) -> PyResult<()> {
        self.inner
            .validate()
            .map_err(pyo3::exceptions::PyValueError::new_err)
    }

    /// Trial count as estimated by `TunerConfig::estimated_trials`.
    fn estimated_trials(&self) -> usize {
        self.inner.estimated_trials()
    }

    /// Spread used at `iteration`, as computed by `TunerConfig::spread_for_iteration`.
    fn spread_for_iteration(&self, iteration: usize) -> f32 {
        self.inner.spread_for_iteration(iteration)
    }

    fn __repr__(&self) -> String {
        format!(
            "TunerConfig(iterations={}, rounds={}, grid={:?})",
            self.inner.n_iterations, self.inner.num_rounds, self.inner.grid_strategy
        )
    }
}

impl From<TunerConfig> for PyTunerConfig {
    fn from(config: TunerConfig) -> Self {
        Self { inner: config }
    }
}
/// Adds the tuner configuration classes (`ParamBounds`, `ParameterSpace`,
/// `TunerConfig`) to the Python module `m`.
///
/// # Errors
/// Stops at, and returns, the first error PyO3 reports while adding a class.
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyParamBounds>()
        .and_then(|()| m.add_class::<PyParameterSpace>())
        .and_then(|()| m.add_class::<PyTunerConfig>())
}