//! PyO3 bindings for the data pipeline: the `PipelineConfig` builder, the
//! fitted `PipelineState`, and the `DataPipeline` loader exposed to Python.

use pyo3::prelude::*;
use crate::dataset::{DataPipeline, PipelineConfig, PipelineState};
use super::types::{PyBinnedDataset, PyFeatureInfo};
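
/// Python-facing wrapper around [`PipelineConfig`].
///
/// Setters are builder-style: each validates its argument and returns a new
/// config, leaving `self` untouched, so calls chain naturally from Python
/// (illustrative):
///
/// ```text
/// cfg = PipelineConfig().with_num_bins(128).with_smoothing(10.0)
/// ```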
#[pyclass(name = "PipelineConfig")]
#[derive(Clone)]
pub struct PyPipelineConfig {
pub(crate) inner: PipelineConfig,
}

#[pymethods]
impl PyPipelineConfig {
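    /// Create a config with the library defaults.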
#[new]
fn new() -> Self {
Self {
inner: PipelineConfig::default(),
}
}
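
    /// Return a copy with `num_bins` set.
    ///
    /// Errors if `num_bins` is 0 or greater than 255.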
fn with_num_bins(&self, num_bins: usize) -> PyResult<Self> {
if num_bins == 0 || num_bins > 255 {
return Err(pyo3::exceptions::PyValueError::new_err(
"num_bins must be between 1 and 255",
));
}
Ok(Self {
inner: self.inner.clone().with_num_bins(num_bins),
})
}
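
    /// Return a copy with the count-min sketch (CMS) parameters set: `eps`
    /// (relative error) and `confidence` must both lie strictly between 0
    /// and 1; `min_count` is exposed back as `min_category_count`.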
#[pyo3(signature = (eps=0.001, confidence=0.99, min_count=5))]
fn with_cms_params(&self, eps: f64, confidence: f64, min_count: u64) -> PyResult<Self> {
if eps <= 0.0 || eps >= 1.0 {
return Err(pyo3::exceptions::PyValueError::new_err(
"eps must be between 0 and 1",
));
}
if confidence <= 0.0 || confidence >= 1.0 {
return Err(pyo3::exceptions::PyValueError::new_err(
"confidence must be between 0 and 1",
));
}
Ok(Self {
inner: self
.inner
.clone()
.with_cms_params(eps, confidence, min_count),
})
}
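
    /// Return a copy with the target-encoding smoothing factor set; the
    /// value must be non-negative.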
fn with_smoothing(&self, smoothing: f64) -> PyResult<Self> {
if smoothing < 0.0 {
return Err(pyo3::exceptions::PyValueError::new_err(
"smoothing must be non-negative",
));
}
Ok(Self {
inner: self.inner.clone().with_smoothing(smoothing),
})
}
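
    // Read-only getters mirroring the underlying `PipelineConfig` fields.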
#[getter]
fn num_bins(&self) -> usize {
self.inner.num_bins
}
#[getter]
fn cms_eps(&self) -> f64 {
self.inner.cms_eps
}
#[getter]
fn cms_confidence(&self) -> f64 {
self.inner.cms_confidence
}
#[getter]
fn min_category_count(&self) -> u64 {
self.inner.min_category_count
}
#[getter]
fn target_encoding_smoothing(&self) -> f64 {
self.inner.target_encoding_smoothing
}
    fn __repr__(&self) -> String {
        format!(
            "PipelineConfig(num_bins={}, cms_eps={}, cms_confidence={}, min_count={}, smoothing={})",
            self.inner.num_bins,
            self.inner.cms_eps,
            self.inner.cms_confidence,
            self.inner.min_category_count,
            self.inner.target_encoding_smoothing
        )
    }
}
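
/// Fitted preprocessing state returned by the `load_*_for_training` methods:
/// column order, per-feature metadata, and the categorical encodings learned
/// from the training data. Pass it back to the `load_*_for_inference`
/// methods so new data is transformed consistently.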
#[pyclass(name = "PipelineState")]
#[derive(Clone)]
pub struct PyPipelineState {
pub(crate) inner: PipelineState,
}

#[pymethods]
impl PyPipelineState {
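    /// Per-column metadata captured when the pipeline was fitted.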
#[getter]
fn feature_info(&self) -> Vec<PyFeatureInfo> {
self.inner.feature_info.iter().map(|fi| fi.into()).collect()
}
#[getter]
fn column_order(&self) -> Vec<String> {
self.inner.column_order.clone()
}
#[getter]
fn categorical_indices(&self) -> Vec<usize> {
self.inner.categorical_indices.clone()
}
#[getter]
fn num_categorical(&self) -> usize {
self.inner.categorical_encodings.len()
}
fn __len__(&self) -> usize {
self.inner.column_order.len()
}
fn __repr__(&self) -> String {
format!(
"PipelineState(columns={}, categorical={})",
self.inner.column_order.len(),
self.inner.categorical_encodings.len()
)
}
}

impl From<PipelineState> for PyPipelineState {
fn from(state: PipelineState) -> Self {
Self { inner: state }
}
}
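
/// Loads tabular files (CSV or Parquet), bins features, and fits or applies
/// the preprocessing state.
///
/// A sketch of typical usage from Python (file and column names are
/// illustrative):
///
/// ```text
/// pipeline = DataPipeline(PipelineConfig().with_num_bins(64))
/// train, state = pipeline.load_csv_for_training("train.csv", "label")
/// holdout = pipeline.load_csv_for_inference("holdout.csv", state)
/// ```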
#[pyclass(name = "DataPipeline")]
pub struct PyDataPipeline {
inner: DataPipeline,
}

#[pymethods]
impl PyDataPipeline {
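    /// Create a pipeline with the given configuration, or the defaults if
    /// `config` is omitted.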
#[new]
#[pyo3(signature = (config=None))]
fn new(config: Option<PyPipelineConfig>) -> Self {
let cfg = config.map(|c| c.inner).unwrap_or_default();
Self {
inner: DataPipeline::new(cfg),
}
}
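
    /// Convenience constructor: a pipeline with the default configuration.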
#[staticmethod]
fn default() -> Self {
Self {
inner: DataPipeline::default_config(),
}
}
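
    /// Read a CSV, fit the preprocessing state, and return the binned
    /// dataset together with that state. `categoricals` optionally names the
    /// columns to treat as categorical. The GIL is released while the Rust
    /// pipeline runs.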
#[pyo3(signature = (path, target, categoricals=None))]
fn load_csv_for_training(
&self,
py: Python<'_>,
path: &str,
target: &str,
categoricals: Option<Vec<String>>,
) -> PyResult<(PyBinnedDataset, PyPipelineState)> {
let cat_refs: Option<Vec<&str>> = categoricals
.as_ref()
.map(|v| v.iter().map(|s| s.as_str()).collect());
let result = py.allow_threads(|| {
self.inner
.load_csv_for_training(path, target, cat_refs.as_deref())
});
match result {
Ok((dataset, state)) => {
Ok((PyBinnedDataset::from(dataset), PyPipelineState::from(state)))
}
Err(e) => Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
"Failed to process CSV for training: {}",
e
))),
}
}
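
    /// Read a Parquet file, fit the preprocessing state, and return the
    /// binned dataset together with that state. `categoricals` optionally
    /// names the columns to treat as categorical. The GIL is released while
    /// the Rust pipeline runs.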
#[pyo3(signature = (path, target, categoricals=None))]
fn load_parquet_for_training(
&self,
py: Python<'_>,
path: &str,
target: &str,
categoricals: Option<Vec<String>>,
) -> PyResult<(PyBinnedDataset, PyPipelineState)> {
let cat_refs: Option<Vec<&str>> = categoricals
.as_ref()
.map(|v| v.iter().map(|s| s.as_str()).collect());
let result = py.allow_threads(|| {
self.inner
.load_parquet_for_training(path, target, cat_refs.as_deref())
});
match result {
Ok((dataset, state)) => {
Ok((PyBinnedDataset::from(dataset), PyPipelineState::from(state)))
}
Err(e) => Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
"Failed to process Parquet for training: {}",
e
))),
}
}
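
    /// Re-bin a CSV by applying a previously fitted `PipelineState` rather
    /// than computing new statistics. The GIL is released while the Rust
    /// pipeline runs.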
fn load_csv_for_inference(
&self,
py: Python<'_>,
path: &str,
state: &PyPipelineState,
) -> PyResult<PyBinnedDataset> {
let result = py.allow_threads(|| self.inner.load_csv_for_inference(path, &state.inner));
match result {
Ok(dataset) => Ok(PyBinnedDataset::from(dataset)),
Err(e) => Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
"Failed to process CSV for inference: {}",
e
))),
}
}
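
    /// Re-bin a Parquet file by applying a previously fitted
    /// `PipelineState`. The GIL is released while the Rust pipeline runs.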
fn load_parquet_for_inference(
&self,
py: Python<'_>,
path: &str,
state: &PyPipelineState,
) -> PyResult<PyBinnedDataset> {
let result = py.allow_threads(|| self.inner.load_parquet_for_inference(path, &state.inner));
match result {
Ok(dataset) => Ok(PyBinnedDataset::from(dataset)),
Err(e) => Err(pyo3::exceptions::PyRuntimeError::new_err(format!(
"Failed to process Parquet for inference: {}",
e
))),
}
}
fn __repr__(&self) -> &'static str {
"DataPipeline(...)"
}
}
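
/// Register the pipeline classes on the given Python module.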
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyPipelineConfig>()?;
m.add_class::<PyPipelineState>()?;
m.add_class::<PyDataPipeline>()?;
Ok(())
}