use super::module::PyModule;
use crate::{device::PyDevice, error::PyResult, tensor::PyTensor};
use pyo3::prelude::*;
use pyo3::types::PyAny;
use std::collections::HashMap;
/// Python-visible `Sequential` container: an ordered list of child modules
/// whose outputs are chained, each child's output feeding the next child.
///
/// Children are stored as arbitrary Python objects and are driven by
/// duck-typed method calls (`forward`, `parameters`, `train`, ...).
#[pyclass(name = "Sequential", extends = PyModule)]
pub struct PySequential {
    /// Training-mode flag for this container; starts as `true`.
    training: bool,
    /// Child modules, in the order they are executed.
    modules: Vec<Py<PyAny>>,
}
#[pymethods]
impl PySequential {
#[new]
fn new(modules: Option<Vec<Py<PyAny>>>) -> (Self, PyModule) {
let modules = modules.unwrap_or_default();
(
Self {
modules,
training: true,
},
PyModule::new(),
)
}
fn add_module(&mut self, _name: &str, module: Py<PyAny>) {
self.modules.push(module);
}
fn forward(&self, mut input: PyTensor) -> PyResult<PyTensor> {
Python::attach(|py| {
for module in &self.modules {
if let Ok(forward_method) = module.getattr(py, "forward") {
let result = forward_method.call1(py, (input.clone(),))?;
input = result.extract::<PyTensor>(py)?;
} else {
let result = module.call1(py, (input.clone(),))?;
input = result.extract::<PyTensor>(py)?;
}
}
Ok(input)
})
}
fn parameters(&self) -> PyResult<Vec<PyTensor>> {
let mut all_params = Vec::new();
Python::attach(|py| {
for module in &self.modules {
if let Ok(params_method) = module.getattr(py, "parameters") {
let params_result = params_method.call0(py)?;
if let Ok(params) = params_result.extract::<Vec<PyTensor>>(py) {
all_params.extend(params);
}
}
}
Ok(all_params)
})
}
fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
let mut all_named_params = HashMap::new();
Python::attach(|py| {
for (i, module) in self.modules.iter().enumerate() {
if let Ok(named_params_method) = module.getattr(py, "named_parameters") {
let named_params_result = named_params_method.call0(py)?;
if let Ok(named_params) =
named_params_result.extract::<HashMap<String, PyTensor>>(py)
{
for (name, param) in named_params {
all_named_params.insert(format!("{}.{}", i, name), param);
}
}
}
}
Ok(all_named_params)
})
}
fn train(&mut self, mode: Option<bool>) {
let mode = mode.unwrap_or(true);
self.training = mode;
Python::attach(|py| {
for module in &self.modules {
if let Ok(train_method) = module.getattr(py, "train") {
let _ = train_method.call1(py, (mode,));
}
}
});
}
fn eval(&mut self) {
self.training = false;
Python::attach(|py| {
for module in &self.modules {
if let Ok(eval_method) = module.getattr(py, "eval") {
let _ = eval_method.call0(py);
}
}
});
}
fn to(&mut self, device: PyDevice) -> PyResult<()> {
Python::attach(|py| {
for module in &self.modules {
if let Ok(to_method) = module.getattr(py, "to") {
to_method.call1(py, (device.clone(),))?;
}
}
Ok(())
})
}
fn zero_grad(&mut self) {
Python::attach(|py| {
for module in &self.modules {
if let Ok(zero_grad_method) = module.getattr(py, "zero_grad") {
let _ = zero_grad_method.call0(py);
}
}
});
}
fn __repr__(&self) -> String {
format!("Sequential({} modules)", self.modules.len())
}
fn __len__(&self) -> usize {
self.modules.len()
}
fn __getitem__(&self, index: usize) -> PyResult<Py<PyAny>> {
Python::attach(|py| {
self.modules
.get(index)
.map(|obj| obj.clone_ref(py))
.ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyIndexError, _>("Index out of range")
})
})
}
fn training(&self) -> bool {
self.training
}
}
/// Python-visible `ModuleList` container: an indexable list of child modules
/// that are not chained together, only held and managed collectively.
///
/// Children are stored as arbitrary Python objects and are driven by
/// duck-typed method calls (`parameters`, `train`, `to`, ...).
#[pyclass(name = "ModuleList", extends = PyModule)]
pub struct PyModuleList {
    /// Training-mode flag for this container; starts as `true`.
    training: bool,
    /// Held child modules, in insertion order.
    modules: Vec<Py<PyAny>>,
}
#[pymethods]
impl PyModuleList {
#[new]
fn new(modules: Option<Vec<Py<PyAny>>>) -> (Self, PyModule) {
let modules = modules.unwrap_or_default();
(
Self {
modules,
training: true,
},
PyModule::new(),
)
}
fn append(&mut self, module: Py<PyAny>) {
self.modules.push(module);
}
fn extend(&mut self, modules: Vec<Py<PyAny>>) {
self.modules.extend(modules);
}
fn insert(&mut self, index: usize, module: Py<PyAny>) {
if index <= self.modules.len() {
self.modules.insert(index, module);
}
}
fn parameters(&self) -> PyResult<Vec<PyTensor>> {
let mut all_params = Vec::new();
Python::attach(|py| {
for module in &self.modules {
if let Ok(params_method) = module.getattr(py, "parameters") {
let params_result = params_method.call0(py)?;
if let Ok(params) = params_result.extract::<Vec<PyTensor>>(py) {
all_params.extend(params);
}
}
}
Ok(all_params)
})
}
fn named_parameters(&self) -> PyResult<HashMap<String, PyTensor>> {
let mut all_named_params = HashMap::new();
Python::attach(|py| {
for (i, module) in self.modules.iter().enumerate() {
if let Ok(named_params_method) = module.getattr(py, "named_parameters") {
let named_params_result = named_params_method.call0(py)?;
if let Ok(named_params) =
named_params_result.extract::<HashMap<String, PyTensor>>(py)
{
for (name, param) in named_params {
all_named_params.insert(format!("{}.{}", i, name), param);
}
}
}
}
Ok(all_named_params)
})
}
fn train(&mut self, mode: Option<bool>) {
let mode = mode.unwrap_or(true);
self.training = mode;
Python::attach(|py| {
for module in &self.modules {
if let Ok(train_method) = module.getattr(py, "train") {
let _ = train_method.call1(py, (mode,));
}
}
});
}
fn eval(&mut self) {
self.training = false;
Python::attach(|py| {
for module in &self.modules {
if let Ok(eval_method) = module.getattr(py, "eval") {
let _ = eval_method.call0(py);
}
}
});
}
fn to(&mut self, device: PyDevice) -> PyResult<()> {
Python::attach(|py| {
for module in &self.modules {
if let Ok(to_method) = module.getattr(py, "to") {
to_method.call1(py, (device.clone(),))?;
}
}
Ok(())
})
}
fn zero_grad(&mut self) {
Python::attach(|py| {
for module in &self.modules {
if let Ok(zero_grad_method) = module.getattr(py, "zero_grad") {
let _ = zero_grad_method.call0(py);
}
}
});
}
fn __repr__(&self) -> String {
format!("ModuleList({} modules)", self.modules.len())
}
fn __len__(&self) -> usize {
self.modules.len()
}
fn __getitem__(&self, index: usize) -> PyResult<Py<PyAny>> {
Python::attach(|py| {
self.modules
.get(index)
.map(|obj| obj.clone_ref(py))
.ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyIndexError, _>("Index out of range")
})
})
}
fn __setitem__(&mut self, index: usize, module: Py<PyAny>) -> PyResult<()> {
if index < self.modules.len() {
self.modules[index] = module;
Ok(())
} else {
Err(PyErr::new::<pyo3::exceptions::PyIndexError, _>(
"Index out of range",
))
}
}
fn training(&self) -> bool {
self.training
}
}