use crate::{Module, ModuleBase, Parameter};
use torsh_core::device::DeviceType;
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;
#[cfg(feature = "std")]
use std::{collections::HashMap, vec::Vec};
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
pub struct ParameterList {
base: ModuleBase,
parameters: Vec<Parameter>,
}
impl ParameterList {
pub fn new() -> Self {
Self {
base: ModuleBase::new(),
parameters: Vec::new(),
}
}
pub fn len(&self) -> usize {
self.parameters.len()
}
pub fn is_empty(&self) -> bool {
self.parameters.is_empty()
}
pub fn push(&mut self, parameter: Parameter) {
self.parameters.push(parameter);
}
pub fn extend<I>(&mut self, parameters: I)
where
I: IntoIterator<Item = Parameter>,
{
self.parameters.extend(parameters);
}
pub fn get(&self, index: usize) -> Option<&Parameter> {
self.parameters.get(index)
}
pub fn get_mut(&mut self, index: usize) -> Option<&mut Parameter> {
self.parameters.get_mut(index)
}
pub fn iter(&self) -> impl Iterator<Item = &Parameter> {
self.parameters.iter()
}
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut Parameter> {
self.parameters.iter_mut()
}
}
impl Default for ParameterList {
fn default() -> Self {
Self::new()
}
}
impl Module for ParameterList {
fn forward(&self, _input: &Tensor) -> Result<Tensor> {
Err(TorshError::InvalidArgument(
"ParameterList doesn't define forward pass".to_string(),
))
}
fn parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
for (i, param) in self.parameters.iter().enumerate() {
params.insert(i.to_string(), param.clone());
}
params
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
for (i, param) in self.parameters.iter().enumerate() {
params.insert(i.to_string(), param.clone());
}
params
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn training(&self) -> bool {
self.base.training()
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)?;
for param in &mut self.parameters {
let tensor_guard = param.tensor();
let tensor = tensor_guard.read().clone();
let moved_tensor = tensor.to(device)?;
*tensor_guard.write() = moved_tensor;
}
Ok(())
}
}
pub struct ParameterDict {
base: ModuleBase,
parameters: HashMap<String, Parameter>,
}
impl ParameterDict {
pub fn new() -> Self {
Self {
base: ModuleBase::new(),
parameters: HashMap::new(),
}
}
pub fn len(&self) -> usize {
self.parameters.len()
}
pub fn is_empty(&self) -> bool {
self.parameters.is_empty()
}
pub fn insert(&mut self, key: String, parameter: Parameter) {
self.parameters.insert(key, parameter);
}
pub fn get(&self, key: &str) -> Option<&Parameter> {
self.parameters.get(key)
}
pub fn get_mut(&mut self, key: &str) -> Option<&mut Parameter> {
self.parameters.get_mut(key)
}
pub fn keys(&self) -> impl Iterator<Item = &String> {
self.parameters.keys()
}
pub fn values(&self) -> impl Iterator<Item = &Parameter> {
self.parameters.values()
}
pub fn values_mut(&mut self) -> impl Iterator<Item = &mut Parameter> {
self.parameters.values_mut()
}
pub fn iter(&self) -> impl Iterator<Item = (&String, &Parameter)> {
self.parameters.iter()
}
pub fn iter_mut(&mut self) -> impl Iterator<Item = (&String, &mut Parameter)> {
self.parameters.iter_mut()
}
}
impl Default for ParameterDict {
fn default() -> Self {
Self::new()
}
}
impl Module for ParameterDict {
fn forward(&self, _input: &Tensor) -> Result<Tensor> {
Err(TorshError::InvalidArgument(
"ParameterDict doesn't define forward pass".to_string(),
))
}
fn parameters(&self) -> HashMap<String, Parameter> {
self.parameters.clone()
}
fn named_parameters(&self) -> HashMap<String, Parameter> {
self.parameters.clone()
}
fn train(&mut self) {
self.base.set_training(true);
}
fn eval(&mut self) {
self.base.set_training(false);
}
fn training(&self) -> bool {
self.base.training()
}
fn set_training(&mut self, training: bool) {
self.base.set_training(training);
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
self.base.to_device(device)?;
for param in self.parameters.values_mut() {
let tensor_guard = param.tensor();
let tensor = tensor_guard.read().clone();
let moved_tensor = tensor.to(device)?;
*tensor_guard.write() = moved_tensor;
}
Ok(())
}
}