use super::common::*;
use pyo3::types::PyDict;
use pyo3::Bound;
use sklears_core::traits::{Fit, Predict, Score, Trained};
use sklears_linear::{BayesianRidge, BayesianRidgeConfig};
/// Python-side configuration for the `BayesianRidge` estimator.
///
/// Mirrors scikit-learn's `BayesianRidge` hyper-parameters and is converted
/// into the core `BayesianRidgeConfig` via the `From` impl in this file.
#[derive(Debug, Clone)]
pub struct PyBayesianRidgeConfig {
    /// Maximum number of iterations of the evidence-maximization loop.
    pub max_iter: usize,
    /// Convergence tolerance on the coefficients.
    pub tol: f64,
    /// Initial value for the noise precision alpha; `None` defers to the
    /// core library's default.
    pub alpha_init: Option<f64>,
    /// Initial value for the weights precision lambda; `None` defers to the
    /// core library's default.
    pub lambda_init: Option<f64>,
    /// Whether to estimate an intercept term.
    pub fit_intercept: bool,
    /// Whether to compute the log marginal likelihood at each iteration.
    pub compute_score: bool,
    /// Kept for scikit-learn API parity (`copy_X`); not forwarded to the
    /// core config — see the `From` impl below.
    pub copy_x: bool,
}
impl Default for PyBayesianRidgeConfig {
fn default() -> Self {
Self {
max_iter: 300,
tol: 1e-3,
alpha_init: Some(1.0),
lambda_init: Some(1.0),
fit_intercept: true,
compute_score: false,
copy_x: true,
}
}
}
impl From<PyBayesianRidgeConfig> for BayesianRidgeConfig {
    /// Convert the Python-facing configuration into the core library config.
    ///
    /// `None` priors fall back to the core defaults. `copy_x` has no
    /// counterpart field in `BayesianRidgeConfig` and is intentionally
    /// dropped here (it only exists for scikit-learn API parity).
    fn from(py_config: PyBayesianRidgeConfig) -> Self {
        // Build the core defaults once instead of constructing a fresh
        // `BayesianRidgeConfig::default()` inside each `unwrap_or_else`
        // closure (previously up to two separate constructions).
        let defaults = BayesianRidgeConfig::default();
        BayesianRidgeConfig {
            max_iter: py_config.max_iter,
            tol: py_config.tol,
            alpha_init: py_config.alpha_init.unwrap_or(defaults.alpha_init),
            lambda_init: py_config.lambda_init.unwrap_or(defaults.lambda_init),
            fit_intercept: py_config.fit_intercept,
            compute_score: py_config.compute_score,
        }
    }
}
/// Python-exposed Bayesian Ridge regression estimator.
///
/// Exported to Python under the class name `BayesianRidge`.
#[pyclass(name = "BayesianRidge")]
pub struct PyBayesianRidge {
    /// Hyper-parameters as supplied from Python.
    py_config: PyBayesianRidgeConfig,
    /// Trained core model; `None` until `fit()` succeeds. Reset to `None`
    /// by `set_params()`.
    fitted_model: Option<BayesianRidge<Trained>>,
}
#[pymethods]
impl PyBayesianRidge {
    /// Create a new, unfitted estimator.
    ///
    /// Validates that `max_iter > 0`, `tol > 0`, `alpha_init > 0`, and
    /// `lambda_init > 0`, raising `ValueError` otherwise.
    ///
    /// NOTE(review): the Python signature takes `alpha_init`/`lambda_init`
    /// as plain floats (always stored as `Some(..)`), so Python callers
    /// cannot pass `None` even though `PyBayesianRidgeConfig` models the
    /// option — confirm whether `None` should be accepted here.
    #[new]
    #[pyo3(signature = (max_iter=300, tol=1e-3, alpha_init=1.0, lambda_init=1.0, fit_intercept=true, compute_score=false, copy_x=true))]
    fn new(
        max_iter: usize,
        tol: f64,
        alpha_init: f64,
        lambda_init: f64,
        fit_intercept: bool,
        compute_score: bool,
        copy_x: bool,
    ) -> PyResult<Self> {
        if max_iter == 0 {
            return Err(PyValueError::new_err("max_iter must be greater than 0"));
        }
        if tol <= 0.0 {
            return Err(PyValueError::new_err("tol must be positive"));
        }
        if alpha_init <= 0.0 {
            return Err(PyValueError::new_err("alpha_init must be positive"));
        }
        if lambda_init <= 0.0 {
            return Err(PyValueError::new_err("lambda_init must be positive"));
        }
        let py_config = PyBayesianRidgeConfig {
            max_iter,
            tol,
            alpha_init: Some(alpha_init),
            lambda_init: Some(lambda_init),
            fit_intercept,
            compute_score,
            copy_x,
        };
        Ok(Self {
            py_config,
            fitted_model: None,
        })
    }

    /// Fit the model on `(x, y)`; stores the trained core model on success.
    ///
    /// Raises `ValueError` on array-shape problems or fit failure.
    ///
    /// NOTE(review): only `max_iter`, `tol`, `fit_intercept`, and
    /// `compute_score` are forwarded to the core builder — the configured
    /// `alpha_init`/`lambda_init` (and the `From<PyBayesianRidgeConfig>`
    /// conversion defined above) are never used here, so user-supplied
    /// priors are silently ignored during fitting. Confirm whether the
    /// `BayesianRidge` builder exposes setters for them.
    fn fit(&mut self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<()> {
        let x_array = pyarray_to_core_array2(x)?;
        let y_array = pyarray_to_core_array1(y)?;
        validate_fit_arrays(&x_array, &y_array)?;
        let model = BayesianRidge::new()
            .max_iter(self.py_config.max_iter)
            .tol(self.py_config.tol)
            .fit_intercept(self.py_config.fit_intercept)
            .compute_score(self.py_config.compute_score);
        match model.fit(&x_array, &y_array) {
            Ok(fitted_model) => {
                self.fitted_model = Some(fitted_model);
                Ok(())
            }
            Err(e) => Err(PyValueError::new_err(format!(
                "Failed to fit Bayesian Ridge model: {:?}",
                e
            ))),
        }
    }

    /// Predict targets for `x` using the fitted model.
    ///
    /// Raises `ValueError` if the model is unfitted or prediction fails.
    fn predict(&self, py: Python<'_>, x: PyReadonlyArray2<f64>) -> PyResult<Py<PyArray1<f64>>> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
        let x_array = pyarray_to_core_array2(x)?;
        validate_predict_array(&x_array)?;
        match fitted.predict(&x_array) {
            Ok(predictions) => Ok(core_array1_to_py(py, &predictions)),
            Err(e) => Err(PyValueError::new_err(format!("Prediction failed: {:?}", e))),
        }
    }

    /// Fitted coefficients (sklearn's `coef_`); requires a fitted model.
    #[getter]
    fn coef_(&self, py: Python<'_>) -> PyResult<Py<PyArray1<f64>>> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
        let coef = fitted
            .coef()
            .map_err(|e| PyValueError::new_err(format!("Failed to get coefficients: {:?}", e)))?;
        Ok(core_array1_to_py(py, coef))
    }

    /// Fitted intercept (sklearn's `intercept_`).
    ///
    /// Falls back to 0.0 when the core model reports no intercept
    /// (e.g. when fitted with `fit_intercept=False` — presumably; confirm
    /// against the core `intercept()` contract).
    #[getter]
    fn intercept_(&self) -> PyResult<f64> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
        Ok(fitted.intercept().unwrap_or(0.0))
    }

    /// Estimated noise precision (sklearn's `alpha_`); requires a fitted model.
    #[getter]
    fn alpha_(&self) -> PyResult<f64> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
        fitted
            .alpha()
            .map_err(|e| PyValueError::new_err(format!("Failed to get alpha: {:?}", e)))
    }

    /// Estimated weights precision (sklearn's `lambda_`); requires a fitted model.
    #[getter]
    fn lambda_(&self) -> PyResult<f64> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
        fitted
            .lambda()
            .map_err(|e| PyValueError::new_err(format!("Failed to get lambda: {:?}", e)))
    }

    /// Score the fitted model on `(x, y)` via the core `Score` trait
    /// (presumably R²，as in sklearn — confirm against the core impl).
    fn score(&self, x: PyReadonlyArray2<f64>, y: PyReadonlyArray1<f64>) -> PyResult<f64> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
        let x_array = pyarray_to_core_array2(x)?;
        let y_array = pyarray_to_core_array1(y)?;
        match fitted.score(&x_array, &y_array) {
            Ok(score) => Ok(score),
            Err(e) => Err(PyValueError::new_err(format!(
                "Score calculation failed: {:?}",
                e
            ))),
        }
    }

    /// Number of features seen during fit, derived from the coefficient
    /// vector length (sklearn's `n_features_in_`).
    #[getter]
    fn n_features_in_(&self) -> PyResult<usize> {
        let fitted = self
            .fitted_model
            .as_ref()
            .ok_or_else(|| PyValueError::new_err("Model not fitted. Call fit() first."))?;
        let coef = fitted
            .coef()
            .map_err(|e| PyValueError::new_err(format!("Failed to get coefficients: {:?}", e)))?;
        Ok(coef.len())
    }

    /// Return hyper-parameters as a dict (sklearn `get_params`).
    ///
    /// `deep` is accepted for API compatibility but ignored — this
    /// estimator has no nested sub-estimators. Note the sklearn-style
    /// capitalized key `"copy_X"` for the `copy_x` field.
    fn get_params(&self, py: Python<'_>, deep: Option<bool>) -> PyResult<Py<PyDict>> {
        let _deep = deep.unwrap_or(true);
        let dict = PyDict::new(py);
        dict.set_item("max_iter", self.py_config.max_iter)?;
        dict.set_item("tol", self.py_config.tol)?;
        dict.set_item("alpha_init", self.py_config.alpha_init)?;
        dict.set_item("lambda_init", self.py_config.lambda_init)?;
        dict.set_item("fit_intercept", self.py_config.fit_intercept)?;
        dict.set_item("compute_score", self.py_config.compute_score)?;
        dict.set_item("copy_X", self.py_config.copy_x)?;
        Ok(dict.into())
    }

    /// Update hyper-parameters from keyword arguments (sklearn `set_params`).
    ///
    /// Applies the same validation as the constructor to each recognized
    /// key and clears any previously fitted model so stale state cannot be
    /// used for prediction. Unrecognized keys are silently ignored.
    ///
    /// NOTE(review): validation failure partway through leaves earlier
    /// keys already applied (no rollback) — confirm this is acceptable.
    fn set_params(&mut self, kwargs: &Bound<'_, PyDict>) -> PyResult<()> {
        if let Some(max_iter) = kwargs.get_item("max_iter")? {
            let max_iter_val: usize = max_iter.extract()?;
            if max_iter_val == 0 {
                return Err(PyValueError::new_err("max_iter must be greater than 0"));
            }
            self.py_config.max_iter = max_iter_val;
        }
        if let Some(tol) = kwargs.get_item("tol")? {
            let tol_val: f64 = tol.extract()?;
            if tol_val <= 0.0 {
                return Err(PyValueError::new_err("tol must be positive"));
            }
            self.py_config.tol = tol_val;
        }
        if let Some(alpha_init) = kwargs.get_item("alpha_init")? {
            let alpha_init_val: f64 = alpha_init.extract()?;
            if alpha_init_val <= 0.0 {
                return Err(PyValueError::new_err("alpha_init must be positive"));
            }
            self.py_config.alpha_init = Some(alpha_init_val);
        }
        if let Some(lambda_init) = kwargs.get_item("lambda_init")? {
            let lambda_init_val: f64 = lambda_init.extract()?;
            if lambda_init_val <= 0.0 {
                return Err(PyValueError::new_err("lambda_init must be positive"));
            }
            self.py_config.lambda_init = Some(lambda_init_val);
        }
        if let Some(fit_intercept) = kwargs.get_item("fit_intercept")? {
            self.py_config.fit_intercept = fit_intercept.extract()?;
        }
        if let Some(compute_score) = kwargs.get_item("compute_score")? {
            self.py_config.compute_score = compute_score.extract()?;
        }
        if let Some(copy_x) = kwargs.get_item("copy_X")? {
            self.py_config.copy_x = copy_x.extract()?;
        }
        // Hyper-parameters changed: invalidate any fitted state.
        self.fitted_model = None;
        Ok(())
    }

    /// sklearn-style repr showing the current hyper-parameters.
    fn __repr__(&self) -> String {
        format!(
            "BayesianRidge(max_iter={}, tol={}, alpha_init={:?}, lambda_init={:?}, fit_intercept={}, compute_score={}, copy_X={})",
            self.py_config.max_iter,
            self.py_config.tol,
            self.py_config.alpha_init,
            self.py_config.lambda_init,
            self.py_config.fit_intercept,
            self.py_config.compute_score,
            self.py_config.copy_x
        )
    }
}