use numpy::{PyArray1, PyReadonlyArray1};
use pyo3::prelude::*;
use crate::loss::{
sigmoid, softmax, BinaryLogLoss, LossFunction, MseLoss, MultiClassLogLoss, PseudoHuberLoss,
};
/// Python wrapper around the Rust [`MseLoss`] (mean squared error) loss.
#[pyclass(name = "MseLoss")]
#[derive(Clone)]
pub struct PyMseLoss {
    inner: MseLoss,
}
#[pymethods]
impl PyMseLoss {
    /// Create a new MSE loss instance.
    #[new]
    fn new() -> Self {
        let inner = MseLoss::new();
        Self { inner }
    }
    /// Loss-function name as reported by the underlying implementation.
    #[getter]
    fn name(&self) -> &'static str {
        self.inner.name()
    }
    /// Point-wise loss value for one (target, prediction) pair.
    fn loss(&self, target: f32, prediction: f32) -> f32 {
        self.inner.loss(target, prediction)
    }
    /// First derivative of the loss with respect to the prediction.
    fn gradient(&self, target: f32, prediction: f32) -> f32 {
        self.inner.gradient(target, prediction)
    }
    /// Second derivative of the loss with respect to the prediction.
    fn hessian(&self, target: f32, prediction: f32) -> f32 {
        self.inner.hessian(target, prediction)
    }
    /// Gradient and hessian computed in a single call.
    fn gradient_hessian(&self, target: f32, prediction: f32) -> (f32, f32) {
        self.inner.gradient_hessian(target, prediction)
    }
    fn __repr__(&self) -> &'static str {
        "MseLoss()"
    }
}
/// Python wrapper around the Rust [`PseudoHuberLoss`] implementation.
#[pyclass(name = "PseudoHuberLoss")]
#[derive(Clone)]
pub struct PyPseudoHuberLoss {
    inner: PseudoHuberLoss,
}
#[pymethods]
impl PyPseudoHuberLoss {
    /// Create a pseudo-Huber loss with transition parameter `delta`.
    ///
    /// Raises `ValueError` unless `delta` is a positive number (NaN is
    /// rejected too).
    #[new]
    #[pyo3(signature = (delta=1.0))]
    fn new(delta: f32) -> PyResult<Self> {
        // `!(delta > 0.0)` rejects zero, negatives AND NaN; the previous
        // `delta <= 0.0` check silently let NaN through, producing a loss
        // whose every evaluation is NaN.
        if !(delta > 0.0) {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "delta must be positive",
            ));
        }
        Ok(Self {
            inner: PseudoHuberLoss::new(delta),
        })
    }
    /// The configured delta parameter.
    #[getter]
    fn delta(&self) -> f32 {
        self.inner.delta()
    }
    /// Point-wise loss value for one (target, prediction) pair.
    fn loss(&self, target: f32, prediction: f32) -> f32 {
        self.inner.loss(target, prediction)
    }
    /// First derivative of the loss with respect to the prediction.
    fn gradient(&self, target: f32, prediction: f32) -> f32 {
        self.inner.gradient(target, prediction)
    }
    /// Second derivative of the loss with respect to the prediction.
    fn hessian(&self, target: f32, prediction: f32) -> f32 {
        self.inner.hessian(target, prediction)
    }
    /// Gradient and hessian computed in a single call.
    fn gradient_hessian(&self, target: f32, prediction: f32) -> (f32, f32) {
        self.inner.gradient_hessian(target, prediction)
    }
    /// Loss-function name as reported by the underlying implementation.
    #[getter]
    fn name(&self) -> &'static str {
        self.inner.name()
    }
    fn __repr__(&self) -> String {
        format!("PseudoHuberLoss(delta={})", self.inner.delta())
    }
}
/// Python wrapper around the Rust [`BinaryLogLoss`] implementation.
#[pyclass(name = "BinaryLogLoss")]
#[derive(Clone)]
pub struct PyBinaryLogLoss {
    inner: BinaryLogLoss,
}
#[pymethods]
impl PyBinaryLogLoss {
    /// Create a binary log loss with the default clipping epsilon.
    #[new]
    fn new() -> Self {
        Self {
            inner: BinaryLogLoss::new(),
        }
    }
    /// Create a binary log loss with a custom clipping epsilon.
    ///
    /// Raises `ValueError` unless `eps` is a positive number (NaN is
    /// rejected too).
    #[staticmethod]
    fn with_eps(eps: f32) -> PyResult<Self> {
        // `!(eps > 0.0)` rejects zero, negatives AND NaN; the previous
        // `eps <= 0.0` check silently let NaN through.
        if !(eps > 0.0) {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "eps must be positive",
            ));
        }
        Ok(Self {
            inner: BinaryLogLoss::with_eps(eps),
        })
    }
    /// Point-wise loss value for one (target, prediction) pair.
    fn loss(&self, target: f32, prediction: f32) -> f32 {
        self.inner.loss(target, prediction)
    }
    /// First derivative of the loss with respect to the prediction.
    fn gradient(&self, target: f32, prediction: f32) -> f32 {
        self.inner.gradient(target, prediction)
    }
    /// Second derivative of the loss with respect to the prediction.
    fn hessian(&self, target: f32, prediction: f32) -> f32 {
        self.inner.hessian(target, prediction)
    }
    /// Gradient and hessian computed in a single call.
    fn gradient_hessian(&self, target: f32, prediction: f32) -> (f32, f32) {
        self.inner.gradient_hessian(target, prediction)
    }
    /// Convert a raw (logit) score into a probability.
    fn to_probability(&self, raw: f32) -> f32 {
        self.inner.to_probability(raw)
    }
    /// Convert a probability into a 0/1 class label at `threshold`.
    #[pyo3(signature = (prob, threshold=0.5))]
    fn to_class(&self, prob: f32, threshold: f32) -> u32 {
        self.inner.to_class(prob, threshold)
    }
    /// Loss-function name as reported by the underlying implementation.
    #[getter]
    fn name(&self) -> &'static str {
        self.inner.name()
    }
    fn __repr__(&self) -> &'static str {
        "BinaryLogLoss()"
    }
}
/// Python wrapper around the Rust [`MultiClassLogLoss`] implementation.
///
/// Predictions are passed as 1-D float32 numpy arrays of raw (pre-softmax)
/// scores, one entry per class.
#[pyclass(name = "MultiClassLogLoss")]
#[derive(Clone)]
pub struct PyMultiClassLogLoss {
    inner: MultiClassLogLoss,
}
// Private validation helpers. These live in a plain impl block (not
// #[pymethods]) so they are not exposed to Python; they deduplicate the
// bounds checks that were previously copy-pasted across both gradient
// methods. Error messages are unchanged.
impl PyMultiClassLogLoss {
    /// Reject raw-prediction vectors whose length differs from `num_classes`.
    fn check_preds_len(&self, len: usize) -> PyResult<()> {
        if len != self.inner.num_classes() {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "raw_preds length {} doesn't match num_classes {}",
                len,
                self.inner.num_classes()
            )));
        }
        Ok(())
    }
    /// Reject class indices outside `0..num_classes`; `label` names the
    /// offending argument in the error message.
    fn check_class(&self, idx: usize, label: &str) -> PyResult<()> {
        if idx >= self.inner.num_classes() {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "{} {} >= num_classes {}",
                label,
                idx,
                self.inner.num_classes()
            )));
        }
        Ok(())
    }
}
#[pymethods]
impl PyMultiClassLogLoss {
    /// Create a multi-class log loss for `num_classes` classes (>= 2).
    #[new]
    fn new(num_classes: usize) -> PyResult<Self> {
        if num_classes < 2 {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "num_classes must be >= 2",
            ));
        }
        Ok(Self {
            inner: MultiClassLogLoss::new(num_classes),
        })
    }
    /// The configured number of classes.
    #[getter]
    fn num_classes(&self) -> usize {
        self.inner.num_classes()
    }
    /// Gradient and hessian for a single output class, given the true class
    /// and the full raw-score vector.
    fn gradient_hessian_for_class<'py>(
        &self,
        target_class: usize,
        class_idx: usize,
        raw_preds: PyReadonlyArray1<'py, f32>,
    ) -> PyResult<(f32, f32)> {
        let preds = raw_preds.as_array();
        self.check_preds_len(preds.len())?;
        self.check_class(target_class, "target_class")?;
        self.check_class(class_idx, "class_idx")?;
        let preds_vec: Vec<f32> = preds.to_vec();
        Ok(self
            .inner
            .gradient_hessian_for_class(target_class, class_idx, &preds_vec))
    }
    /// Gradients and hessians for every class at once, given the true class
    /// and the full raw-score vector.
    fn gradient_hessian_all_classes<'py>(
        &self,
        target_class: usize,
        raw_preds: PyReadonlyArray1<'py, f32>,
    ) -> PyResult<(Vec<f32>, Vec<f32>)> {
        let preds = raw_preds.as_array();
        self.check_preds_len(preds.len())?;
        self.check_class(target_class, "target_class")?;
        let preds_vec: Vec<f32> = preds.to_vec();
        Ok(self
            .inner
            .gradient_hessian_all_classes(target_class, &preds_vec))
    }
    /// Loss-function name.
    #[getter]
    fn name(&self) -> &'static str {
        // NOTE(review): hard-coded, unlike the other wrappers which delegate
        // to `self.inner.name()` — presumably because MultiClassLogLoss does
        // not implement the same trait; confirm against crate::loss.
        "multi_class_log_loss"
    }
    fn __repr__(&self) -> String {
        format!(
            "MultiClassLogLoss(num_classes={})",
            self.inner.num_classes()
        )
    }
}
/// Scalar sigmoid exposed to Python.
///
/// NOTE(review): registered under the Rust name `py_sigmoid` (see `register`),
/// unlike `sigmoid_batch` which carries no prefix — confirm the intended
/// Python-facing name.
#[pyfunction]
fn py_sigmoid(x: f32) -> f32 {
sigmoid(x)
}
/// Element-wise sigmoid over a 1-D float32 numpy array, returning a new array.
#[pyfunction]
fn sigmoid_batch<'py>(py: Python<'py>, x: PyReadonlyArray1<'py, f32>) -> Bound<'py, PyArray1<f32>> {
    let input = x.as_array();
    // Preallocate the output to the known input length, then fill it.
    let mut out = Vec::with_capacity(input.len());
    for &value in input.iter() {
        out.push(sigmoid(value));
    }
    PyArray1::from_vec(py, out)
}
/// Softmax of a 1-D float32 numpy array of raw scores, returned as a new
/// numpy array of the same length.
#[pyfunction]
fn py_softmax<'py>(
    py: Python<'py>,
    raw_scores: PyReadonlyArray1<'py, f32>,
) -> Bound<'py, PyArray1<f32>> {
    // Copy the (possibly non-contiguous) numpy view into a Vec for the
    // slice-based Rust softmax, then hand the result back as a numpy array.
    let scores_vec = raw_scores.as_array().to_vec();
    PyArray1::from_vec(py, softmax(&scores_vec))
}
/// Register all loss classes and helper functions on the given Python module.
///
/// Called from the crate's module-init function; errors propagate to Python
/// as an import failure.
pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> {
// Loss-function wrapper classes.
m.add_class::<PyMseLoss>()?;
m.add_class::<PyPseudoHuberLoss>()?;
m.add_class::<PyBinaryLogLoss>()?;
m.add_class::<PyMultiClassLogLoss>()?;
// Free functions (exposed under their Rust names, `py_` prefixes included).
m.add_function(wrap_pyfunction!(py_sigmoid, m)?)?;
m.add_function(wrap_pyfunction!(sigmoid_batch, m)?)?;
m.add_function(wrap_pyfunction!(py_softmax, m)?)?;
Ok(())
}