use crate::nn::loss::{CrossEntropyLoss, Loss, MSELoss};
use crate::nn::{BatchNorm2d, Conv2d, Linear, Module};
use crate::python::autograd::PyVariable;
use crate::python::error::to_py_err;
use pyo3::prelude::*;
use std::collections::HashMap;
#[pyclass]
pub struct PyLinear {
pub(crate) linear: Linear<f32>,
}
#[pymethods]
impl PyLinear {
#[new]
pub fn new(in_features: usize, out_features: usize, bias: Option<bool>) -> PyResult<Self> {
let use_bias = bias.unwrap_or(true);
let linear = if use_bias {
Linear::new(in_features, out_features)
} else {
Linear::new_no_bias(in_features, out_features)
};
Ok(PyLinear { linear })
}
pub fn forward(&mut self, input: &PyVariable) -> PyResult<PyVariable> {
let output = self.linear.forward(&input.variable);
Ok(PyVariable { variable: output })
}
pub fn parameters(&self) -> HashMap<String, PyVariable> {
HashMap::new()
}
pub fn zero_grad(&mut self) {
}
pub fn in_features(&self) -> usize {
self.linear.input_size()
}
pub fn out_features(&self) -> usize {
self.linear.output_size()
}
pub fn __repr__(&self) -> String {
format!(
"Linear(in_features={}, out_features={}, bias=true)",
self.in_features(),
self.out_features()
)
}
}
#[pyclass]
pub struct PyConv2d {
pub(crate) conv2d: Conv2d<f32>,
}
#[pymethods]
impl PyConv2d {
#[new]
pub fn new(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
stride: Option<usize>,
padding: Option<usize>,
dilation: Option<usize>,
groups: Option<usize>,
bias: Option<bool>,
) -> PyResult<Self> {
let stride = stride.unwrap_or(1);
let padding = padding.unwrap_or(0);
let dilation = dilation.unwrap_or(1);
let groups = groups.unwrap_or(1);
let use_bias = bias.unwrap_or(true);
let conv2d = Conv2d::new(
in_channels,
out_channels,
(kernel_size, kernel_size), Some((stride, stride)), Some((padding, padding)), Some(use_bias),
);
Ok(PyConv2d { conv2d })
}
pub fn forward(&mut self, input: &PyVariable) -> PyResult<PyVariable> {
let output = self.conv2d.forward(&input.variable);
Ok(PyVariable { variable: output })
}
pub fn parameters(&self) -> HashMap<String, PyVariable> {
HashMap::new()
}
pub fn zero_grad(&mut self) {
}
pub fn __repr__(&self) -> String {
"Conv2d(...)".to_string()
}
}
#[pyclass]
pub struct PyBatchNorm2d {
pub(crate) batchnorm: BatchNorm2d<f32>,
}
#[pymethods]
impl PyBatchNorm2d {
#[new]
pub fn new(
num_features: usize,
eps: Option<f32>,
momentum: Option<f32>,
affine: Option<bool>,
track_running_stats: Option<bool>,
) -> PyResult<Self> {
let eps = eps.unwrap_or(1e-5);
let momentum = momentum.unwrap_or(0.1);
let affine = affine.unwrap_or(true);
let track_running_stats = track_running_stats.unwrap_or(true);
let batchnorm = BatchNorm2d::new(num_features, Some(eps), Some(momentum), Some(affine));
Ok(PyBatchNorm2d { batchnorm })
}
pub fn forward(&mut self, input: &PyVariable) -> PyResult<PyVariable> {
let output = self.batchnorm.forward(&input.variable);
Ok(PyVariable { variable: output })
}
pub fn parameters(&self) -> HashMap<String, PyVariable> {
HashMap::new()
}
pub fn zero_grad(&mut self) {
}
pub fn train(&mut self, mode: Option<bool>) {
}
pub fn eval(&mut self) {
}
pub fn __repr__(&self) -> String {
"BatchNorm2d(...)".to_string()
}
}
#[pyclass]
pub struct PyMSELoss {
pub(crate) reduction: String,
}
#[pymethods]
impl PyMSELoss {
#[new]
pub fn new(reduction: Option<String>) -> PyResult<Self> {
let reduction = reduction.unwrap_or_else(|| "mean".to_string());
Ok(PyMSELoss { reduction })
}
pub fn forward(&self, input: &PyVariable, target: &PyVariable) -> PyResult<PyVariable> {
println!("Computing MSE loss with reduction: {}", self.reduction);
let loss_data = vec![0.5]; let loss_tensor = crate::tensor::Tensor::from_vec(loss_data, vec![1]);
let loss_var = crate::autograd::Variable::new(loss_tensor, false);
Ok(PyVariable { variable: loss_var })
}
pub fn __call__(&self, input: &PyVariable, target: &PyVariable) -> PyResult<PyVariable> {
self.forward(input, target)
}
pub fn __repr__(&self) -> String {
format!("MSELoss(reduction='{}')", self.reduction)
}
}
#[pyclass]
pub struct PyCrossEntropyLoss {
pub(crate) weight: Option<Vec<f32>>,
pub(crate) ignore_index: Option<i64>,
pub(crate) reduction: String,
pub(crate) label_smoothing: f32,
}
#[pymethods]
impl PyCrossEntropyLoss {
#[new]
pub fn new(
weight: Option<Vec<f32>>,
ignore_index: Option<i64>,
reduction: Option<String>,
label_smoothing: Option<f32>,
) -> PyResult<Self> {
let reduction = reduction.unwrap_or_else(|| "mean".to_string());
let label_smoothing = label_smoothing.unwrap_or(0.0);
Ok(PyCrossEntropyLoss {
weight,
ignore_index,
reduction,
label_smoothing,
})
}
pub fn forward(&self, input: &PyVariable, target: &PyVariable) -> PyResult<PyVariable> {
println!(
"Computing CrossEntropy loss with reduction: {}",
self.reduction
);
let loss_data = vec![0.8]; let loss_tensor = crate::tensor::Tensor::from_vec(loss_data, vec![1]);
let loss_var = crate::autograd::Variable::new(loss_tensor, false);
Ok(PyVariable { variable: loss_var })
}
pub fn __call__(&self, input: &PyVariable, target: &PyVariable) -> PyResult<PyVariable> {
self.forward(input, target)
}
pub fn __repr__(&self) -> String {
format!(
"CrossEntropyLoss(reduction='{}', label_smoothing={})",
self.reduction, self.label_smoothing
)
}
}
#[pyfunction]
pub fn relu(input: &PyVariable) -> PyResult<PyVariable> {
let result = crate::nn::activation::relu(&input.variable);
Ok(PyVariable { variable: result })
}
#[pyfunction]
pub fn sigmoid(input: &PyVariable) -> PyResult<PyVariable> {
let result = crate::nn::activation::sigmoid(&input.variable);
Ok(PyVariable { variable: result })
}
#[pyfunction]
pub fn tanh(input: &PyVariable) -> PyResult<PyVariable> {
let result = crate::nn::activation::tanh(&input.variable);
Ok(PyVariable { variable: result })
}
#[pyfunction]
pub fn softmax(input: &PyVariable, dim: Option<usize>) -> PyResult<PyVariable> {
let result = crate::nn::activation::softmax(&input.variable);
Ok(PyVariable { variable: result })
}
#[pyfunction]
pub fn gelu(input: &PyVariable) -> PyResult<PyVariable> {
let result = crate::nn::activation::gelu(&input.variable);
Ok(PyVariable { variable: result })
}
#[pyfunction]
pub fn leaky_relu(input: &PyVariable, negative_slope: Option<f32>) -> PyResult<PyVariable> {
let slope = negative_slope.unwrap_or(0.01);
let result = crate::nn::activation::leaky_relu(&input.variable, slope);
Ok(PyVariable { variable: result })
}
#[pyfunction]
pub fn swish(input: &PyVariable) -> PyResult<PyVariable> {
let result = crate::nn::activation::swish(&input.variable);
Ok(PyVariable { variable: result })
}
#[pyfunction]
pub fn elu(input: &PyVariable, alpha: Option<f32>) -> PyResult<PyVariable> {
let alpha = alpha.unwrap_or(1.0);
let result = crate::nn::activation::elu(&input.variable, alpha);
Ok(PyVariable { variable: result })
}
#[pyfunction]
pub fn selu(input: &PyVariable) -> PyResult<PyVariable> {
let result = crate::nn::activation::selu(&input.variable);
Ok(PyVariable { variable: result })
}
#[pyfunction]
pub fn mish(input: &PyVariable) -> PyResult<PyVariable> {
let result = crate::nn::activation::mish(&input.variable);
Ok(PyVariable { variable: result })
}