use super::module::PyModule;
use crate::{device::PyDevice, error::PyResult, py_result, tensor::PyTensor};
use pyo3::prelude::*;
use std::collections::HashMap;
use torsh_tensor::Tensor;
/// Python-facing fully connected (dense) layer.
///
/// Extends the Python-side `Module` base class (`PyModule`) via pyo3's
/// `extends` inheritance. Holds its parameters directly as `torsh_tensor`
/// tensors rather than delegating to the base class.
#[pyclass(name = "Linear", extends = PyModule)]
pub struct PyLinear {
    // Learnable weight, created with shape `[out_features, in_features]` (see `new`).
    weight: Tensor<f32>,
    // Optional learnable bias of shape `[out_features]`; `None` when built with `bias=False`.
    bias: Option<Tensor<f32>>,
    // Size of each input sample.
    in_features: usize,
    // Size of each output sample.
    out_features: usize,
    // Whether the layer was constructed with a bias term (mirrors `bias.is_some()`).
    has_bias: bool,
    // Training-mode flag toggled by `train()` / `eval()`.
    training: bool,
}
#[pymethods]
impl PyLinear {
#[new]
fn new(
    in_features: usize,
    out_features: usize,
    bias: Option<bool>,
) -> PyResult<(Self, PyModule)> {
    // Python callers may omit `bias`; default to having one, as PyTorch does.
    let has_bias = bias.unwrap_or(true);

    // Weight follows the PyTorch layout: one row per output feature.
    let weight_shape = vec![out_features, in_features];
    let weight = py_result!(torsh_tensor::creation::randn(&weight_shape))?.requires_grad_(true);

    // Bias is zero-initialized and only materialized when requested.
    let bias = match has_bias {
        false => None,
        true => {
            let bias_shape = vec![out_features];
            let zeros = py_result!(torsh_tensor::creation::zeros(&bias_shape))?;
            Some(zeros.requires_grad_(true))
        }
    };

    let layer = Self {
        weight,
        bias,
        in_features,
        out_features,
        has_bias,
        training: true,
    };
    // pyo3 `extends` requires returning the base-class instance alongside `Self`.
    Ok((layer, PyModule::new()))
}
/// Applies the layer to `input`: matrix-multiplies by the stored weight,
/// then adds the bias when one exists.
///
/// NOTE(review): `weight` is created with shape `[out_features, in_features]`
/// (see `new`), but `matmul` is called here without transposing it. Under
/// standard matrix-multiplication semantics an input of shape
/// `[batch, in_features]` would not be conformable — a PyTorch-style linear
/// layer computes `input @ weight.T`. Confirm `torsh_tensor`'s `matmul`
/// layout convention; a transpose may be missing here.
fn forward(&self, input: &PyTensor) -> PyResult<PyTensor> {
    let result = py_result!(input.tensor.matmul(&self.weight))?;
    // Bias add is presumably broadcast over the leading dimensions — TODO confirm.
    let result = if let Some(ref bias) = self.bias {
        py_result!(result.add(bias))?
    } else {
        result
    };
    Ok(PyTensor { tensor: result })
}
fn parameters(&self) -> PyResult<Vec<PyTensor>> {
let mut params = Vec::new();
params.push(PyTensor {
tensor: self.weight.clone(),
});
if let Some(ref bias) = self.bias {
params.push(PyTensor {
tensor: bias.clone(),
});
}
Ok(params)
}
fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
let mut named_params = HashMap::new();
named_params.insert(
"weight".to_string(),
PyTensor {
tensor: self.weight.clone(),
},
);
if let Some(ref bias) = self.bias {
named_params.insert(
"bias".to_string(),
PyTensor {
tensor: bias.clone(),
},
);
}
Ok(named_params)
}
/// Sets the training flag; a `None` mode defaults to `true`, mirroring
/// PyTorch's `Module.train()` signature.
fn train(&mut self, mode: Option<bool>) {
    self.training = mode.unwrap_or(true);
}
/// Switches the layer to evaluation mode (equivalent to `train(False)`).
fn eval(&mut self) {
    self.training = false;
}
/// Moves the layer's parameters to `device`, replacing them in place.
///
/// Each tensor is cloned before the transfer because `Tensor::to` consumes
/// its receiver; the clone keeps the stored tensor intact until the
/// converted replacement is assigned.
fn to(&mut self, device: PyDevice) -> PyResult<()> {
    self.weight = py_result!(self.weight.clone().to(device.device))?;
    if let Some(bias) = self.bias.clone() {
        self.bias = Some(py_result!(bias.to(device.device))?);
    }
    Ok(())
}
fn zero_grad(&mut self) {
let _ = self.weight.zero_grad();
if let Some(ref mut bias) = self.bias {
let _ = bias.zero_grad();
}
}
fn __repr__(&self) -> String {
format!(
"Linear(in_features={}, out_features={}, bias={})",
self.in_features, self.out_features, self.has_bias
)
}
/// Python property: size of each input sample.
#[getter]
fn in_features(&self) -> usize {
    self.in_features
}
/// Python property: size of each output sample.
#[getter]
fn out_features(&self) -> usize {
    self.out_features
}
/// Python property: whether the layer was constructed with a bias term.
/// Note this returns a `bool`, not the bias tensor itself.
#[getter]
fn bias(&self) -> bool {
    self.has_bias
}
/// Reports whether the layer is in training mode.
///
/// NOTE(review): unlike the neighboring accessors this one is not marked
/// `#[getter]`, so Python sees a callable `training()` method rather than a
/// `training` property (PyTorch exposes it as an attribute). Adding
/// `#[getter]` would make it consistent but break callers that currently
/// invoke `linear.training()` — decide deliberately before changing.
fn training(&self) -> bool {
    self.training
}
/// Python property: a clone of the weight tensor.
///
/// NOTE(review): the `weight` property exposes the tensor, while the `bias`
/// property exposes only a `bool` flag — there is no property through which
/// Python can read the bias values (only `parameters()`/`named_parameters()`).
/// Consider a tensor-returning bias accessor for parity.
#[getter]
fn weight(&self) -> PyResult<PyTensor> {
    Ok(PyTensor {
        tensor: self.weight.clone(),
    })
}
/// Copies matching entries from `state_dict` into this layer's parameters.
///
/// This is a non-strict load: missing keys are silently skipped and no
/// shape validation is performed. A `"bias"` entry is only honored when the
/// layer was constructed with a bias.
fn load_state_dict(&mut self, state_dict: HashMap<String, PyTensor>) -> PyResult<()> {
    if let Some(w) = state_dict.get("weight") {
        self.weight = w.tensor.clone();
    }
    // Both conditions must hold: the layer uses a bias AND the dict has one.
    if let (true, Some(b)) = (self.has_bias, state_dict.get("bias")) {
        self.bias = Some(b.tensor.clone());
    }
    Ok(())
}
}