use super::module::PyModule;
use crate::{error::PyResult, py_result, tensor::PyTensor};
use pyo3::prelude::*;
use pyo3::types::PyAny;
use std::collections::HashMap;
use torsh_tensor::Tensor;
/// Python-visible 2-D convolution layer, exported as `Conv2d` and deriving
/// from the crate's `PyModule` base class.
///
/// Spatial hyper-parameters are stored as two-element pairs. The first element
/// maps to axis 2 of the weight tensor and the second to axis 3, presumably
/// `(height, width)`.
#[pyclass(name = "Conv2d", extends = PyModule)]
pub struct PyConv2d {
    // ---- configuration, fixed at construction ----
    /// Channel count the layer was configured to consume.
    in_channels: usize,
    /// Channel count the layer produces (leading axis of `weight`).
    out_channels: usize,
    /// Kernel extent per spatial axis.
    kernel_size: (usize, usize),
    /// Stride forwarded to `conv2d`.
    stride: (usize, usize),
    /// Padding forwarded to `conv2d`.
    padding: (usize, usize),
    /// Dilation forwarded to `conv2d`.
    dilation: (usize, usize),
    /// Number of channel groups forwarded to `conv2d`.
    groups: usize,
    /// Whether the layer was constructed with a bias term.
    has_bias: bool,

    // ---- learnable parameters ----
    /// Convolution kernel.
    weight: Tensor<f32>,
    /// Optional per-output-channel bias.
    bias: Option<Tensor<f32>>,

    // ---- mode ----
    /// Training-mode flag toggled by `train()` / `eval()`.
    training: bool,
}
#[pymethods]
impl PyConv2d {
#[new]
fn new(
in_channels: usize,
out_channels: usize,
kernel_size: Py<PyAny>,
stride: Option<Py<PyAny>>,
padding: Option<Py<PyAny>>,
dilation: Option<Py<PyAny>>,
groups: Option<usize>,
bias: Option<bool>,
) -> PyResult<(Self, PyModule)> {
let has_bias = bias.unwrap_or(true);
let groups = groups.unwrap_or(1);
let kernel_size = Python::attach(|py| -> PyResult<(usize, usize)> {
if let Ok(size) = kernel_size.extract::<usize>(py) {
Ok((size, size))
} else if let Ok(tuple) = kernel_size.extract::<(usize, usize)>(py) {
Ok(tuple)
} else {
Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"kernel_size must be an integer or tuple of integers",
))
}
})?;
let stride = if let Some(stride_obj) = stride {
Python::attach(|py| -> PyResult<(usize, usize)> {
if let Ok(stride) = stride_obj.extract::<usize>(py) {
Ok((stride, stride))
} else if let Ok(tuple) = stride_obj.extract::<(usize, usize)>(py) {
Ok(tuple)
} else {
Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"stride must be an integer or tuple of integers",
))
}
})?
} else {
(1, 1)
};
let padding = if let Some(padding_obj) = padding {
Python::attach(|py| -> PyResult<(usize, usize)> {
if let Ok(padding) = padding_obj.extract::<usize>(py) {
Ok((padding, padding))
} else if let Ok(tuple) = padding_obj.extract::<(usize, usize)>(py) {
Ok(tuple)
} else {
Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"padding must be an integer or tuple of integers",
))
}
})?
} else {
(0, 0)
};
let dilation = if let Some(dilation_obj) = dilation {
Python::attach(|py| -> PyResult<(usize, usize)> {
if let Ok(dilation) = dilation_obj.extract::<usize>(py) {
Ok((dilation, dilation))
} else if let Ok(tuple) = dilation_obj.extract::<(usize, usize)>(py) {
Ok(tuple)
} else {
Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"dilation must be an integer or tuple of integers",
))
}
})?
} else {
(1, 1)
};
let weight_shape = vec![
out_channels,
in_channels / groups,
kernel_size.0,
kernel_size.1,
];
let weight = py_result!(torsh_tensor::creation::randn(&weight_shape))?.requires_grad_(true);
let bias = if has_bias {
let bias_shape = vec![out_channels];
Some(py_result!(torsh_tensor::creation::zeros(&bias_shape))?.requires_grad_(true))
} else {
None
};
Ok((
Self {
weight,
bias,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
has_bias,
training: true,
},
PyModule::new(),
))
}
fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
let result = py_result!(input.tensor.conv2d(
&self.weight,
self.bias.as_ref(),
self.stride,
self.padding,
self.dilation,
self.groups
))?;
Ok(PyTensor { tensor: result })
}
fn parameters(&self) -> PyResult<Vec<PyTensor>> {
let mut params = vec![PyTensor {
tensor: self.weight.clone(),
}];
if let Some(ref bias) = self.bias {
params.push(PyTensor {
tensor: bias.clone(),
});
}
Ok(params)
}
fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
let mut params = HashMap::new();
params.insert(
"weight".to_string(),
PyTensor {
tensor: self.weight.clone(),
},
);
if let Some(ref bias) = self.bias {
params.insert(
"bias".to_string(),
PyTensor {
tensor: bias.clone(),
},
);
}
Ok(params)
}
fn train(&mut self, mode: Option<bool>) -> PyResult<()> {
self.training = mode.unwrap_or(true);
Ok(())
}
fn eval(&mut self) -> PyResult<()> {
self.training = false;
Ok(())
}
fn extra_repr(&self) -> String {
let bias_str = if self.has_bias {
"bias=True"
} else {
"bias=False"
};
format!(
"{}, {}, kernel_size={:?}, stride={:?}, padding={:?}, dilation={:?}, groups={}, {}",
self.in_channels,
self.out_channels,
self.kernel_size,
self.stride,
self.padding,
self.dilation,
self.groups,
bias_str
)
}
fn __repr__(&self) -> String {
format!("Conv2d({})", self.extra_repr())
}
}
/// Python-visible 1-D convolution layer, exported as `Conv1d` and deriving
/// from the crate's `PyModule` base class.
#[pyclass(name = "Conv1d", extends = PyModule)]
pub struct PyConv1d {
    // ---- configuration, fixed at construction ----
    /// Channel count the layer was configured to consume.
    in_channels: usize,
    /// Channel count the layer produces (leading axis of `weight`).
    out_channels: usize,
    /// Kernel extent along the single spatial axis.
    kernel_size: usize,
    /// Stride forwarded to `conv1d`.
    stride: usize,
    /// Padding forwarded to `conv1d`.
    padding: usize,
    /// Dilation forwarded to `conv1d`.
    dilation: usize,
    /// Number of channel groups forwarded to `conv1d`.
    groups: usize,
    /// Whether the layer was constructed with a bias term.
    has_bias: bool,

    // ---- learnable parameters ----
    /// Convolution kernel.
    weight: Tensor<f32>,
    /// Optional per-output-channel bias.
    bias: Option<Tensor<f32>>,

    // ---- mode ----
    /// Training-mode flag toggled by `train()` / `eval()`.
    training: bool,
}
#[pymethods]
impl PyConv1d {
#[new]
fn new(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
stride: Option<usize>,
padding: Option<usize>,
dilation: Option<usize>,
groups: Option<usize>,
bias: Option<bool>,
) -> PyResult<(Self, PyModule)> {
let has_bias = bias.unwrap_or(true);
let stride = stride.unwrap_or(1);
let padding = padding.unwrap_or(0);
let dilation = dilation.unwrap_or(1);
let groups = groups.unwrap_or(1);
let weight_shape = vec![out_channels, in_channels / groups, kernel_size];
let weight = py_result!(torsh_tensor::creation::randn(&weight_shape))?.requires_grad_(true);
let bias = if has_bias {
let bias_shape = vec![out_channels];
Some(py_result!(torsh_tensor::creation::zeros(&bias_shape))?.requires_grad_(true))
} else {
None
};
Ok((
Self {
weight,
bias,
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
has_bias,
training: true,
},
PyModule::new(),
))
}
fn forward(&mut self, input: &PyTensor) -> PyResult<PyTensor> {
let result = py_result!(input.tensor.conv1d(
&self.weight,
self.bias.as_ref(),
self.stride,
self.padding,
self.dilation,
self.groups
))?;
Ok(PyTensor { tensor: result })
}
fn parameters(&self) -> PyResult<Vec<PyTensor>> {
let mut params = vec![PyTensor {
tensor: self.weight.clone(),
}];
if let Some(ref bias) = self.bias {
params.push(PyTensor {
tensor: bias.clone(),
});
}
Ok(params)
}
fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
let mut params = HashMap::new();
params.insert(
"weight".to_string(),
PyTensor {
tensor: self.weight.clone(),
},
);
if let Some(ref bias) = self.bias {
params.insert(
"bias".to_string(),
PyTensor {
tensor: bias.clone(),
},
);
}
Ok(params)
}
fn train(&mut self, mode: Option<bool>) -> PyResult<()> {
self.training = mode.unwrap_or(true);
Ok(())
}
fn eval(&mut self) -> PyResult<()> {
self.training = false;
Ok(())
}
fn __repr__(&self) -> String {
let bias_str = if self.has_bias {
"bias=True"
} else {
"bias=False"
};
format!(
"Conv1d({}, {}, kernel_size={}, stride={}, padding={}, dilation={}, groups={}, {})",
self.in_channels,
self.out_channels,
self.kernel_size,
self.stride,
self.padding,
self.dilation,
self.groups,
bias_str
)
}
}