use numpy::PyArray1;
use pyo3::prelude::*;
use crate::dataset::{split_holdout, split_kfold, HoldoutSplit, KFoldSplit};
/// Python-facing wrapper, exposed to Python as `HoldoutSplit`, around the crate's
/// [`HoldoutSplit`] train / validation / calibration index partition.
#[pyclass(name = "HoldoutSplit")]
pub struct PyHoldoutSplit {
    /// The wrapped Rust split that the Python methods and getters delegate to.
    inner: HoldoutSplit,
}
#[pymethods]
impl PyHoldoutSplit {
    /// Build a holdout split over `n_samples` sample indices.
    ///
    /// After validation, `n_samples`, `val_ratio`, `calib_ratio` and `seed` are
    /// forwarded unchanged to `split_holdout`. How the ratios become partition sizes
    /// and how `seed` is used are decided there.
    ///
    /// # Errors
    ///
    /// Raises `ValueError` when `n_samples == 0`, when `val_ratio` or `calib_ratio`
    /// is outside `[0.0, 1.0)` (NaN included), or when `val_ratio + calib_ratio >= 1.0`.
    #[staticmethod]
    #[pyo3(signature = (n_samples, val_ratio=0.2, calib_ratio=0.0, seed=42))]
    fn create(n_samples: usize, val_ratio: f32, calib_ratio: f32, seed: u64) -> PyResult<Self> {
        use pyo3::exceptions::PyValueError;

        if n_samples == 0 {
            return Err(PyValueError::new_err("n_samples must be > 0"));
        }
        // NaN fails every ordered comparison, so a test written as
        // `x < 0.0 || x >= 1.0` would accept it. `Range::contains` is false for NaN,
        // so NaN (and +/-inf) are rejected here instead of reaching `split_holdout`.
        if !(0.0..1.0).contains(&val_ratio) {
            return Err(PyValueError::new_err("val_ratio must be in [0.0, 1.0)"));
        }
        if !(0.0..1.0).contains(&calib_ratio) {
            return Err(PyValueError::new_err("calib_ratio must be in [0.0, 1.0)"));
        }
        // Both ratios are finite and in [0, 1) here, so the sum cannot be NaN.
        if val_ratio + calib_ratio >= 1.0 {
            return Err(PyValueError::new_err(
                "val_ratio + calib_ratio must be < 1.0",
            ));
        }
        Ok(Self {
            inner: split_holdout(n_samples, val_ratio, calib_ratio, seed),
        })
    }

    /// Training indices, as a new 1-D NumPy array copied from the wrapped split.
    #[getter]
    fn train_indices<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<usize>> {
        PyArray1::from_slice(py, &self.inner.train)
    }

    /// Validation indices, as a new 1-D NumPy array copied from the wrapped split.
    #[getter]
    fn validation_indices<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<usize>> {
        PyArray1::from_slice(py, &self.inner.validation)
    }

    /// Calibration indices, as a new 1-D NumPy array copied from the wrapped split.
    #[getter]
    fn calibration_indices<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<usize>> {
        PyArray1::from_slice(py, &self.inner.calibration)
    }

    /// Size of the training partition, as reported by the wrapped split's `train_len()`.
    #[getter]
    fn train_len(&self) -> usize {
        self.inner.train_len()
    }

    /// Size of the validation partition, as reported by the wrapped split's `val_len()`.
    #[getter]
    fn val_len(&self) -> usize {
        self.inner.val_len()
    }

    /// Size of the calibration partition, as reported by the wrapped split's `calib_len()`.
    #[getter]
    fn calib_len(&self) -> usize {
        self.inner.calib_len()
    }

    /// Summary of the three partition sizes, e.g. `HoldoutSplit(train=80, val=20, calib=0)`.
    fn __repr__(&self) -> String {
        format!(
            "HoldoutSplit(train={}, val={}, calib={})",
            self.inner.train_len(),
            self.inner.val_len(),
            self.inner.calib_len()
        )
    }
}
impl From<HoldoutSplit> for PyHoldoutSplit {
    /// Wrap an already-built Rust [`HoldoutSplit`] so it can be handed to Python.
    fn from(inner: HoldoutSplit) -> Self {
        Self { inner }
    }
}
/// Python-facing wrapper, exposed to Python as `KFoldSplit`, around the crate's
/// [`KFoldSplit`] k-fold cross-validation index split.
#[pyclass(name = "KFoldSplit")]
pub struct PyKFoldSplit {
    /// The wrapped Rust split that the Python methods and getters delegate to.
    inner: KFoldSplit,
}
#[pymethods]
impl PyKFoldSplit {
    /// Build a k-fold split over `n_samples` sample indices.
    ///
    /// After validation, the arguments are forwarded unchanged to `split_kfold`.
    ///
    /// # Errors
    ///
    /// Raises `ValueError` when `n_samples == 0`, when `k < 2`, or when `k > n_samples`.
    /// The checks run in that order, and only the first failure is reported.
    #[staticmethod]
    #[pyo3(signature = (n_samples, k=5, seed=42))]
    fn create(n_samples: usize, k: usize, seed: u64) -> PyResult<Self> {
        use pyo3::exceptions::PyValueError;

        let rejection = if n_samples == 0 {
            Some("n_samples must be > 0")
        } else if k < 2 {
            Some("k must be >= 2 for cross-validation")
        } else if k > n_samples {
            Some("k cannot exceed n_samples")
        } else {
            None
        };

        match rejection {
            Some(reason) => Err(PyValueError::new_err(reason)),
            None => Ok(Self {
                inner: split_kfold(n_samples, k, seed),
            }),
        }
    }

    /// Number of folds, as reported by the wrapped split's `k()`.
    #[getter]
    fn k(&self) -> usize {
        self.inner.k()
    }

    /// Return the `(train, val)` index arrays that the wrapped split's `get_fold`
    /// produces for `fold_idx`, each moved into a new 1-D NumPy array.
    ///
    /// # Errors
    ///
    /// Raises `IndexError` when `fold_idx >= k`.
    fn get_fold<'py>(
        &self,
        py: Python<'py>,
        fold_idx: usize,
    ) -> PyResult<(Bound<'py, PyArray1<usize>>, Bound<'py, PyArray1<usize>>)> {
        use pyo3::exceptions::PyIndexError;

        let k = self.inner.k();
        if fold_idx < k {
            let (train_idx, val_idx) = self.inner.get_fold(fold_idx);
            let train_arr = PyArray1::from_vec(py, train_idx);
            let val_arr = PyArray1::from_vec(py, val_idx);
            Ok((train_arr, val_arr))
        } else {
            Err(PyIndexError::new_err(format!(
                "fold_idx {fold_idx} out of range (k={k})"
            )))
        }
    }

    /// Length of every stored fold, in the order the folds are stored.
    fn fold_sizes(&self) -> Vec<usize> {
        self.inner
            .folds
            .iter()
            .map(|fold| fold.len())
            .collect()
    }

    /// Summary showing `k` and the smallest..largest fold size (`0..0` if there are no folds).
    fn __repr__(&self) -> String {
        let sizes = self.fold_sizes();
        let smallest = sizes.iter().copied().min().unwrap_or(0);
        let largest = sizes.iter().copied().max().unwrap_or(0);
        let k = self.inner.k();
        format!("KFoldSplit(k={k}, fold_sizes={smallest}..{largest})")
    }
}
impl From<KFoldSplit> for PyKFoldSplit {
    /// Wrap an already-built Rust [`KFoldSplit`] so it can be handed to Python.
    fn from(inner: KFoldSplit) -> Self {
        Self { inner }
    }
}
/// Add the `HoldoutSplit` and `KFoldSplit` Python classes to module `m`.
///
/// # Errors
///
/// Returns the first error produced by `add_class`. `KFoldSplit` is not added if
/// adding `HoldoutSplit` fails.
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyHoldoutSplit>()?;
    m.add_class::<PyKFoldSplit>()
}