use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use pyo3::types::{PyModule, PyModuleMethods};
use crate::{tensor::PyTensor, error::PyResult, py_result};
/// Applies the tensor's `relu` operation to `input` and wraps the result
/// in a new `PyTensor`.
///
/// `inplace` is accepted for API compatibility but is not consulted; a new
/// tensor is always returned.
#[pyfunction]
fn relu(input: &PyTensor, inplace: Option<bool>) -> PyResult<PyTensor> {
    let tensor = py_result!(input.tensor.relu())?;
    Ok(PyTensor { tensor })
}
/// ReLU6: clamps every element of `input` to the range `[0, 6]`.
///
/// `inplace` is accepted for API compatibility but is not consulted; a new
/// tensor is always returned.
#[pyfunction]
fn relu6(input: &PyTensor, inplace: Option<bool>) -> PyResult<PyTensor> {
    let tensor = py_result!(input.tensor.clamp(0.0, 6.0))?;
    Ok(PyTensor { tensor })
}
/// Leaky ReLU: `max(x, 0) + negative_slope * min(x, 0)`.
///
/// `negative_slope` defaults to `0.01`. `inplace` is accepted for API
/// compatibility but is not consulted; a new tensor is always returned.
#[pyfunction]
fn leaky_relu(input: &PyTensor, negative_slope: Option<f32>, inplace: Option<bool>) -> PyResult<PyTensor> {
    let slope = negative_slope.unwrap_or(0.01);
    let x = &input.tensor;

    // Split the input into its non-negative and non-positive halves.
    let non_negative = py_result!(x.clamp_min(0.0))?;
    let non_positive = py_result!(x.clamp_max(0.0))?;

    // Only the non-positive half is attenuated by the slope.
    let leaked = py_result!(non_positive.mul_scalar(slope))?;
    let tensor = py_result!(non_negative.add(&leaked))?;
    Ok(PyTensor { tensor })
}
/// ELU entry point.
///
/// The current implementation applies the tensor's `relu` operation only;
/// `alpha` (default `1.0`) and `inplace` do not influence the result.
#[pyfunction]
fn elu(input: &PyTensor, alpha: Option<f32>, inplace: Option<bool>) -> PyResult<PyTensor> {
    // TODO: the negative branch `alpha * (exp(x) - 1)` is not implemented yet,
    // so the resolved alpha is currently unused.
    let _alpha = alpha.unwrap_or(1.0);
    let tensor = py_result!(input.tensor.relu())?;
    Ok(PyTensor { tensor })
}
/// SELU entry point.
///
/// The current implementation applies the tensor's `relu` operation only;
/// the SELU `alpha`/`scale` constants are declared but not applied, and
/// `inplace` is not consulted.
#[pyfunction]
fn selu(input: &PyTensor, inplace: Option<bool>) -> PyResult<PyTensor> {
    // TODO: these constants are not yet used by the computation below.
    let _alpha = 1.6732632423543772848170429916717;
    let _scale = 1.0507009873554804934193349852946;

    let tensor = py_result!(input.tensor.relu())?;
    Ok(PyTensor { tensor })
}
#[pyfunction]
fn gelu(input: &PyTensor, approximate: Option<String>) -> PyResult<PyTensor> {
let result = py_result!(input.tensor.tanh())?;
Ok(PyTensor { tensor: result })
}
/// SiLU (swish): `x * sigmoid(x)`.
///
/// `inplace` is accepted for API compatibility but is not consulted; a new
/// tensor is always returned.
#[pyfunction]
fn silu(input: &PyTensor, inplace: Option<bool>) -> PyResult<PyTensor> {
    let x = &input.tensor;
    let gate = py_result!(x.sigmoid())?;
    let tensor = py_result!(x.mul(&gate))?;
    Ok(PyTensor { tensor })
}
/// Mish activation: `x * tanh(softplus(x))`.
///
/// `inplace` is accepted for API compatibility but is not consulted; a new
/// tensor is always returned.
#[pyfunction]
fn mish(input: &PyTensor, inplace: Option<bool>) -> PyResult<PyTensor> {
    let x = &input.tensor;

    // softplus(x) = ln(1 + e^x) = max(x, 0) - ln(sigmoid(|x|)).
    // sigmoid(|x|) lies in [0.5, 1), so the logarithm never sees zero and
    // the expression stays finite for large |x|.
    let abs_x = py_result!(x.abs())?;
    let sigmoid_abs = py_result!(abs_x.sigmoid())?;
    let log_sigmoid_abs = py_result!(sigmoid_abs.log())?;
    let positive_part = py_result!(x.relu())?;
    let softplus = py_result!(positive_part.sub(&log_sigmoid_abs))?;

    let gate = py_result!(softplus.tanh())?;
    let result = py_result!(x.mul(&gate))?;
    Ok(PyTensor { tensor: result })
}
/// Applies the tensor's `sigmoid` operation to `input` and wraps the result
/// in a new `PyTensor`.
#[pyfunction]
fn sigmoid(input: &PyTensor) -> PyResult<PyTensor> {
    let tensor = py_result!(input.tensor.sigmoid())?;
    Ok(PyTensor { tensor })
}
/// Applies the tensor's `tanh` operation to `input` and wraps the result
/// in a new `PyTensor`.
#[pyfunction]
fn tanh(input: &PyTensor) -> PyResult<PyTensor> {
    let tensor = py_result!(input.tensor.tanh())?;
    Ok(PyTensor { tensor })
}
/// Applies the tensor's `softmax` operation along dimension `dim`.
///
/// `dtype` is accepted for API compatibility but is not consulted.
#[pyfunction]
fn softmax(input: &PyTensor, dim: i32, dtype: Option<String>) -> PyResult<PyTensor> {
    let tensor = py_result!(input.tensor.softmax(dim))?;
    Ok(PyTensor { tensor })
}
/// Log-softmax along dimension `dim`, computed as `log(softmax(input, dim))`.
///
/// `dtype` is accepted for API compatibility but is not consulted.
#[pyfunction]
fn log_softmax(input: &PyTensor, dim: i32, dtype: Option<String>) -> PyResult<PyTensor> {
    // NOTE(review): taking the log after the softmax can underflow to -inf
    // for very small probabilities; a fused log-sum-exp form would be more
    // robust if the tensor type offers the required reductions.
    let probabilities = py_result!(input.tensor.softmax(dim))?;
    let tensor = py_result!(probabilities.log())?;
    Ok(PyTensor { tensor })
}
/// Softplus: `(1 / beta) * ln(1 + exp(beta * x))`.
///
/// `beta` defaults to `1.0`. `threshold` is accepted for API compatibility
/// but is not applied: the formulation used here does not overflow for
/// large inputs, so no switch to the linear function is performed.
#[pyfunction]
fn softplus(input: &PyTensor, beta: Option<f32>, threshold: Option<f32>) -> PyResult<PyTensor> {
    let beta = beta.unwrap_or(1.0);
    let scaled = py_result!(input.tensor.mul_scalar(beta))?;

    // ln(1 + e^z) = max(z, 0) - ln(sigmoid(|z|)).
    // sigmoid(|z|) lies in [0.5, 1), so the logarithm never sees zero and
    // the expression stays finite for large |z|.
    let abs_scaled = py_result!(scaled.abs())?;
    let sigmoid_abs = py_result!(abs_scaled.sigmoid())?;
    let log_sigmoid_abs = py_result!(sigmoid_abs.log())?;
    let positive_part = py_result!(scaled.relu())?;
    let softplus_scaled = py_result!(positive_part.sub(&log_sigmoid_abs))?;

    // Undo the beta scaling applied to the argument.
    let result = py_result!(softplus_scaled.mul_scalar(1.0 / beta))?;
    Ok(PyTensor { tensor: result })
}
/// Softsign entry point.
///
/// The current implementation applies the tensor's `tanh` operation rather
/// than `x / (1 + |x|)`.
#[pyfunction]
fn softsign(input: &PyTensor) -> PyResult<PyTensor> {
    // TODO: replace with x / (1 + |x|) once an elementwise division is wired in.
    let tensor = py_result!(input.tensor.tanh())?;
    Ok(PyTensor { tensor })
}
/// Mean-squared-error loss between `input` and `target`.
///
/// The elementwise squared difference is reduced according to `reduction`:
/// `"mean"` (the default when omitted) calls the tensor's `mean(None, false)`,
/// `"sum"` calls `sum()`, and `"none"` returns the unreduced tensor.
/// Any other value raises `ValueError("Invalid reduction")`.
#[pyfunction]
fn mse_loss(input: &PyTensor, target: &PyTensor, reduction: Option<String>) -> PyResult<PyTensor> {
    let delta = py_result!(input.tensor.sub(&target.tensor))?;
    let squared_error = py_result!(delta.mul(&delta))?;

    let tensor = match reduction.as_deref().unwrap_or("mean") {
        "mean" => py_result!(squared_error.mean(None, false))?,
        "sum" => py_result!(squared_error.sum())?,
        "none" => squared_error,
        _ => return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>("Invalid reduction")),
    };
    Ok(PyTensor { tensor })
}
/// Cross-entropy entry point.
///
/// The current implementation computes `log(softmax(input, -1))` and returns
/// the result of the tensor's `mean(None, false)` on it. `target`, `weight`,
/// `size_average`, `ignore_index`, `reduce`, `reduction` and
/// `label_smoothing` are accepted but do not influence the result.
#[pyfunction]
fn cross_entropy(
    input: &PyTensor,
    target: &PyTensor,
    weight: Option<&PyTensor>,
    size_average: Option<bool>,
    ignore_index: Option<i64>,
    reduce: Option<bool>,
    reduction: Option<String>,
    label_smoothing: Option<f32>,
) -> PyResult<PyTensor> {
    // TODO: select the target-class log-probabilities, negate, and honour
    // weight / ignore_index / reduction / label_smoothing.
    let probabilities = py_result!(input.tensor.softmax(-1))?;
    let log_probabilities = py_result!(probabilities.log())?;
    let tensor = py_result!(log_probabilities.mean(None, false))?;
    Ok(PyTensor { tensor })
}
/// L1 (absolute-error) loss between `input` and `target`.
///
/// The elementwise absolute difference is reduced according to `reduction`:
/// `"mean"` (the default when omitted) calls the tensor's `mean(None, false)`,
/// `"sum"` calls `sum()`, and `"none"` returns the unreduced tensor.
/// Any other value raises `ValueError("Invalid reduction")`.
#[pyfunction]
fn l1_loss(input: &PyTensor, target: &PyTensor, reduction: Option<String>) -> PyResult<PyTensor> {
    let delta = py_result!(input.tensor.sub(&target.tensor))?;
    let absolute_error = py_result!(delta.abs())?;

    let tensor = match reduction.as_deref().unwrap_or("mean") {
        "mean" => py_result!(absolute_error.mean(None, false))?,
        "sum" => py_result!(absolute_error.sum())?,
        "none" => absolute_error,
        _ => return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>("Invalid reduction")),
    };
    Ok(PyTensor { tensor })
}
/// Binary cross-entropy entry point.
///
/// The current implementation returns the result of the tensor's
/// `mean(None, false)` on `input`. `target`, `weight`, `size_average`,
/// `reduce` and `reduction` are accepted but do not influence the result.
#[pyfunction]
fn binary_cross_entropy(
    input: &PyTensor,
    target: &PyTensor,
    weight: Option<&PyTensor>,
    size_average: Option<bool>,
    reduce: Option<bool>,
    reduction: Option<String>,
) -> PyResult<PyTensor> {
    // TODO: compute -(t * ln(x) + (1 - t) * ln(1 - x)) and apply weight / reduction.
    let tensor = py_result!(input.tensor.mean(None, false))?;
    Ok(PyTensor { tensor })
}
/// 2-D max pooling: forwards `kernel_size`, `stride` and `padding` to the
/// tensor's `max_pool2d` operation.
///
/// `dilation`, `ceil_mode` and `return_indices` are accepted for API
/// compatibility but are not forwarded; only the pooled tensor is returned.
#[pyfunction]
fn max_pool2d(
    input: &PyTensor,
    kernel_size: (usize, usize),
    stride: Option<(usize, usize)>,
    padding: Option<(usize, usize)>,
    dilation: Option<(usize, usize)>,
    ceil_mode: Option<bool>,
    return_indices: Option<bool>,
) -> PyResult<PyTensor> {
    let tensor = py_result!(input.tensor.max_pool2d(kernel_size, stride, padding))?;
    Ok(PyTensor { tensor })
}
/// 2-D average pooling: forwards `kernel_size`, `stride` and `padding` to
/// the tensor's `avg_pool2d` operation.
///
/// `ceil_mode`, `count_include_pad` and `divisor_override` are accepted for
/// API compatibility but are not forwarded.
#[pyfunction]
fn avg_pool2d(
    input: &PyTensor,
    kernel_size: (usize, usize),
    stride: Option<(usize, usize)>,
    padding: Option<(usize, usize)>,
    ceil_mode: Option<bool>,
    count_include_pad: Option<bool>,
    divisor_override: Option<usize>,
) -> PyResult<PyTensor> {
    let tensor = py_result!(input.tensor.avg_pool2d(kernel_size, stride, padding))?;
    Ok(PyTensor { tensor })
}
/// Adaptive 2-D average pooling: forwards `output_size` to the tensor's
/// `adaptive_avg_pool2d` operation.
#[pyfunction]
fn adaptive_avg_pool2d(input: &PyTensor, output_size: (usize, usize)) -> PyResult<PyTensor> {
    let tensor = py_result!(input.tensor.adaptive_avg_pool2d(output_size))?;
    Ok(PyTensor { tensor })
}
/// Adaptive 2-D max pooling: forwards `output_size` to the tensor's
/// `adaptive_max_pool2d` operation.
///
/// `return_indices` is accepted for API compatibility but is not consulted;
/// only the pooled tensor is returned.
#[pyfunction]
fn adaptive_max_pool2d(input: &PyTensor, output_size: (usize, usize), return_indices: Option<bool>) -> PyResult<PyTensor> {
    let tensor = py_result!(input.tensor.adaptive_max_pool2d(output_size))?;
    Ok(PyTensor { tensor })
}
/// Batch-normalization entry point.
///
/// The current implementation returns a clone of `input` unchanged; none of
/// `running_mean`, `running_var`, `weight`, `bias`, `training`, `momentum`
/// or `eps` is read.
#[pyfunction]
fn batch_norm(
    input: &PyTensor,
    running_mean: Option<&PyTensor>,
    running_var: Option<&PyTensor>,
    weight: Option<&PyTensor>,
    bias: Option<&PyTensor>,
    training: Option<bool>,
    momentum: Option<f32>,
    eps: Option<f32>,
) -> PyResult<PyTensor> {
    // TODO: normalization is not implemented yet; this is an identity pass-through.
    let tensor = input.tensor.clone();
    Ok(PyTensor { tensor })
}
/// Layer-normalization entry point.
///
/// The current implementation returns a clone of `input` unchanged; none of
/// `normalized_shape`, `weight`, `bias` or `eps` is read.
#[pyfunction]
fn layer_norm(
    input: &PyTensor,
    normalized_shape: Vec<usize>,
    weight: Option<&PyTensor>,
    bias: Option<&PyTensor>,
    eps: Option<f32>,
) -> PyResult<PyTensor> {
    // TODO: normalization is not implemented yet; this is an identity pass-through.
    let tensor = input.tensor.clone();
    Ok(PyTensor { tensor })
}
#[pyfunction]
fn dropout(input: &PyTensor, p: Option<f32>, training: Option<bool>, inplace: Option<bool>) -> PyResult<PyTensor> {
let training = training.unwrap_or(true);
if !training {
Ok(PyTensor {
tensor: input.tensor.clone(),
})
} else {
let p = p.unwrap_or(0.5);
let scale = 1.0 / (1.0 - p);
let result = py_result!(input.tensor.mul_scalar(scale))?;
Ok(PyTensor { tensor: result })
}
}
/// Linear transform: `input.matmul(weight)`, plus `bias` when one is given.
#[pyfunction]
fn linear(input: &PyTensor, weight: &PyTensor, bias: Option<&PyTensor>) -> PyResult<PyTensor> {
    // NOTE(review): `weight` is multiplied as-is, without a transpose —
    // presumably it is laid out (in_features, out_features); confirm against callers.
    let projected = py_result!(input.tensor.matmul(&weight.tensor))?;

    let tensor = match bias {
        Some(b) => py_result!(projected.add(&b.tensor))?,
        None => projected,
    };
    Ok(PyTensor { tensor })
}
/// 2-D convolution: resolves the optional hyper-parameters to their defaults
/// and forwards everything to the tensor's `conv2d` operation.
///
/// Defaults: `stride = (1, 1)`, `padding = (0, 0)`, `dilation = (1, 1)`,
/// `groups = 1`.
#[pyfunction]
fn conv2d(
    input: &PyTensor,
    weight: &PyTensor,
    bias: Option<&PyTensor>,
    stride: Option<(usize, usize)>,
    padding: Option<(usize, usize)>,
    dilation: Option<(usize, usize)>,
    groups: Option<usize>,
) -> PyResult<PyTensor> {
    let stride = stride.unwrap_or((1, 1));
    let padding = padding.unwrap_or((0, 0));
    let dilation = dilation.unwrap_or((1, 1));
    let groups = groups.unwrap_or(1);
    let bias_tensor = bias.map(|b| &b.tensor);

    // NOTE(review): `padding` is passed before `stride` here; confirm this
    // matches the parameter order of the tensor's `conv2d`.
    let tensor = py_result!(input.tensor.conv2d(
        &weight.tensor,
        bias_tensor,
        padding,
        stride,
        dilation,
        groups
    ))?;
    Ok(PyTensor { tensor })
}
pub fn register_functional_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(relu, m)?)?;
m.add_function(wrap_pyfunction!(relu6, m)?)?;
m.add_function(wrap_pyfunction!(leaky_relu, m)?)?;
m.add_function(wrap_pyfunction!(elu, m)?)?;
m.add_function(wrap_pyfunction!(selu, m)?)?;
m.add_function(wrap_pyfunction!(gelu, m)?)?;
m.add_function(wrap_pyfunction!(silu, m)?)?;
m.add_function(wrap_pyfunction!(mish, m)?)?;
m.add_function(wrap_pyfunction!(sigmoid, m)?)?;
m.add_function(wrap_pyfunction!(tanh, m)?)?;
m.add_function(wrap_pyfunction!(softmax, m)?)?;
m.add_function(wrap_pyfunction!(log_softmax, m)?)?;
m.add_function(wrap_pyfunction!(softplus, m)?)?;
m.add_function(wrap_pyfunction!(softsign, m)?)?;
m.add_function(wrap_pyfunction!(mse_loss, m)?)?;
m.add_function(wrap_pyfunction!(cross_entropy, m)?)?;
m.add_function(wrap_pyfunction!(l1_loss, m)?)?;
m.add_function(wrap_pyfunction!(binary_cross_entropy, m)?)?;
m.add_function(wrap_pyfunction!(max_pool2d, m)?)?;
m.add_function(wrap_pyfunction!(avg_pool2d, m)?)?;
m.add_function(wrap_pyfunction!(adaptive_avg_pool2d, m)?)?;
m.add_function(wrap_pyfunction!(adaptive_max_pool2d, m)?)?;
m.add_function(wrap_pyfunction!(batch_norm, m)?)?;
m.add_function(wrap_pyfunction!(layer_norm, m)?)?;
m.add_function(wrap_pyfunction!(dropout, m)?)?;
m.add_function(wrap_pyfunction!(linear, m)?)?;
m.add_function(wrap_pyfunction!(conv2d, m)?)?;
Ok(())
}