use super::module::PyModule;
use crate::{error::PyResult, py_result, tensor::PyTensor};
use pyo3::prelude::*;
/// Mean-squared-error loss, exposed to Python as `MSELoss`.
///
/// `forward` squares the element-wise difference between `input` and `target`
/// and reduces it with `mean(None, false)`.
// NOTE(review): `reduction` is stored and echoed by `__repr__`, but `forward`
// never reads it, so the result is always the mean. TODO: honour "sum"/"none".
#[pyclass(name = "MSELoss", extends = PyModule)]
pub struct PyMSELoss {
    /// Reduction mode requested at construction (`"mean"` when omitted).
    reduction: String,
    /// Training-mode flag; starts as `true`, toggled by `train`/`eval`.
    training: bool,
}

#[pymethods]
impl PyMSELoss {
    /// Builds the loss together with its `PyModule` base.
    ///
    /// `reduction` defaults to `"mean"` when not supplied.
    #[new]
    fn new(reduction: Option<String>) -> (Self, PyModule) {
        let loss = Self {
            reduction: reduction.unwrap_or_else(|| String::from("mean")),
            training: true,
        };
        (loss, PyModule::new())
    }

    /// Computes the squared-error loss between `input` and `target`.
    ///
    /// Any failure from the subtraction, squaring or mean reduction is
    /// propagated to the caller.
    fn forward(&self, input: &PyTensor, target: &PyTensor) -> PyResult<PyTensor> {
        let residual = py_result!(input.tensor.sub(&target.tensor))?;
        let residual_sq = py_result!(residual.pow(2.0))?;
        let averaged = py_result!(residual_sq.mean(None, false))?;
        Ok(PyTensor { tensor: averaged })
    }

    /// Returns `MSELoss(reduction='<mode>')`.
    fn __repr__(&self) -> String {
        format!("MSELoss(reduction='{mode}')", mode = self.reduction)
    }

    /// Sets training mode; an omitted `mode` means "on".
    fn train(&mut self, mode: Option<bool>) {
        // Only an explicit `False` disables training; `None` and `True` enable it.
        self.training = !matches!(mode, Some(false));
    }

    /// Switches the module to evaluation mode.
    fn eval(&mut self) {
        self.train(Some(false));
    }

    /// Reports whether the module is currently in training mode.
    fn training(&self) -> bool {
        self.training
    }
}
/// Cross-entropy loss module, exposed to Python as `CrossEntropyLoss`.
// NOTE(review): `forward` currently returns only `input - target`. That is not
// a cross-entropy value, and `reduction` is never read there. Looks like a
// placeholder; TODO implement the real loss once the needed tensor ops exist.
#[pyclass(name = "CrossEntropyLoss", extends = PyModule)]
pub struct PyCrossEntropyLoss {
    /// Reduction mode requested at construction (`"mean"` when omitted).
    reduction: String,
    /// Training-mode flag; starts as `true`, toggled by `train`/`eval`.
    training: bool,
}

#[pymethods]
impl PyCrossEntropyLoss {
    /// Builds the loss together with its `PyModule` base.
    ///
    /// `reduction` defaults to `"mean"` when not supplied.
    #[new]
    fn new(reduction: Option<String>) -> (Self, PyModule) {
        let loss = Self {
            reduction: reduction.unwrap_or_else(|| String::from("mean")),
            training: true,
        };
        (loss, PyModule::new())
    }

    /// Returns the element-wise difference `input - target`.
    ///
    /// Errors from the subtraction are propagated to the caller.
    fn forward(&self, input: &PyTensor, target: &PyTensor) -> PyResult<PyTensor> {
        // Placeholder arithmetic; see the review note on the struct.
        let difference = py_result!(input.tensor.sub(&target.tensor))?;
        Ok(PyTensor { tensor: difference })
    }

    /// Returns `CrossEntropyLoss(reduction='<mode>')`.
    fn __repr__(&self) -> String {
        format!("CrossEntropyLoss(reduction='{mode}')", mode = self.reduction)
    }

    /// Sets training mode; an omitted `mode` means "on".
    fn train(&mut self, mode: Option<bool>) {
        // Only an explicit `False` disables training; `None` and `True` enable it.
        self.training = !matches!(mode, Some(false));
    }

    /// Switches the module to evaluation mode.
    fn eval(&mut self) {
        self.train(Some(false));
    }

    /// Reports whether the module is currently in training mode.
    fn training(&self) -> bool {
        self.training
    }
}
/// Binary cross-entropy loss module, exposed to Python as `BCELoss`.
// NOTE(review): `forward` currently returns only `input - target`. That is not
// a binary cross-entropy value, and `reduction` is never read there. Looks like
// a placeholder; TODO implement the real loss once the needed tensor ops exist.
#[pyclass(name = "BCELoss", extends = PyModule)]
pub struct PyBCELoss {
    /// Reduction mode requested at construction (`"mean"` when omitted).
    reduction: String,
    /// Training-mode flag; starts as `true`, toggled by `train`/`eval`.
    training: bool,
}

#[pymethods]
impl PyBCELoss {
    /// Builds the loss together with its `PyModule` base.
    ///
    /// `reduction` defaults to `"mean"` when not supplied.
    #[new]
    fn new(reduction: Option<String>) -> (Self, PyModule) {
        let loss = Self {
            reduction: reduction.unwrap_or_else(|| String::from("mean")),
            training: true,
        };
        (loss, PyModule::new())
    }

    /// Returns the element-wise difference `input - target`.
    ///
    /// Errors from the subtraction are propagated to the caller.
    fn forward(&self, input: &PyTensor, target: &PyTensor) -> PyResult<PyTensor> {
        // Placeholder arithmetic; see the review note on the struct.
        let difference = py_result!(input.tensor.sub(&target.tensor))?;
        Ok(PyTensor { tensor: difference })
    }

    /// Returns `BCELoss(reduction='<mode>')`.
    fn __repr__(&self) -> String {
        format!("BCELoss(reduction='{mode}')", mode = self.reduction)
    }

    /// Sets training mode; an omitted `mode` means "on".
    fn train(&mut self, mode: Option<bool>) {
        // Only an explicit `False` disables training; `None` and `True` enable it.
        self.training = !matches!(mode, Some(false));
    }

    /// Switches the module to evaluation mode.
    fn eval(&mut self) {
        self.train(Some(false));
    }

    /// Reports whether the module is currently in training mode.
    fn training(&self) -> bool {
        self.training
    }
}