use crate::{Module, ModuleBase, Parameter};
use torsh_core::device::DeviceType;
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;
#[cfg(feature = "std")]
use std::{boxed::Box, collections::HashMap, sync::Arc, vec::Vec};
#[cfg(not(feature = "std"))]
use alloc::{boxed::Box, sync::Arc, vec::Vec};
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
use parking_lot::Mutex;
/// Sequential container whose child modules are created on demand.
///
/// Children are described by factories; the concrete modules are only built
/// the first time the container runs a forward pass, using the dimensions of
/// that first input.
pub struct LazySequential {
    /// Shared module state; this file uses it for the training flag and for
    /// device moves.
    base: ModuleBase,
    /// One factory per child, in execution order. Each receives a shape slice
    /// and returns the module to place at that position.
    module_factories: Vec<Box<dyn Fn(&[usize]) -> Result<Box<dyn Module>> + Send + Sync>>,
    /// The built children; `None` until initialisation has succeeded.
    modules: Mutex<Option<Vec<Box<dyn Module>>>>,
    /// Set to `true` right after `modules` has been populated.
    initialized: Mutex<bool>,
}
impl LazySequential {
    /// Creates an empty container with no factories and no built modules.
    pub fn new() -> Self {
        Self {
            base: ModuleBase::new(),
            module_factories: Vec::new(),
            modules: Mutex::new(None),
            initialized: Mutex::new(false),
        }
    }

    /// Appends a module factory and returns the container (builder style).
    ///
    /// During lazy initialisation the factory is called with the shape that
    /// reaches this position in the chain and must return the module to
    /// place there.
    pub fn add_factory<F>(mut self, factory: F) -> Self
    where
        F: Fn(&[usize]) -> Result<Box<dyn Module>> + Send + Sync + 'static,
    {
        self.module_factories.push(Box::new(factory));
        self
    }

    /// Appends an already constructed module and returns the container.
    ///
    /// The module is wrapped in a single-use factory: the first invocation
    /// hands the module over, every later invocation fails with
    /// `TorshError::Other("Module already used")`.
    pub fn add_module<M: Module + 'static>(mut self, module: M) -> Self {
        // Factories are `Fn`, so the module is parked in a shared slot and
        // moved out of it on the first call.
        let module = Arc::new(Mutex::new(Some(module)));
        self.module_factories.push(Box::new(move |_shape| {
            if let Some(m) = module.lock().take() {
                Ok(Box::new(m) as Box<dyn Module>)
            } else {
                Err(TorshError::Other("Module already used".to_string()))
            }
        }));
        self
    }

    /// Builds every child from its factory, threading shapes through the chain.
    ///
    /// Starting from `input_shape`, each factory receives the current shape.
    /// The freshly built module is then run on a zero tensor of that shape and
    /// the dimensions of its output become the shape handed to the next
    /// factory. Does nothing if the children have already been built.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a factory, by creating the zero
    /// tensor, or by the probing forward pass. In that case nothing is stored
    /// and the container stays uninitialised.
    ///
    /// NOTE: modules built before the failure are dropped, so single-use
    /// factories registered through `add_module` will report
    /// "Module already used" on a later attempt.
    fn initialize(&self, input_shape: &[usize]) -> Result<()> {
        // The lock is held for the whole build, so only one caller can
        // populate the children.
        let mut modules_guard = self.modules.lock();
        if modules_guard.is_some() {
            return Ok(());
        }

        let mut modules = Vec::with_capacity(self.module_factories.len());
        let mut current_shape = input_shape.to_vec();
        for factory in &self.module_factories {
            let module = factory(&current_shape)?;
            // Probe the new module with zeros to learn its output shape.
            let dummy_input = torsh_tensor::creation::zeros(&current_shape)?;
            let dummy_output = module.forward(&dummy_input)?;
            current_shape = dummy_output.shape().dims().to_vec();
            modules.push(module);
        }

        *modules_guard = Some(modules);
        *self.initialized.lock() = true;
        Ok(())
    }

    /// Returns `true` once the children have been built.
    pub fn is_initialized(&self) -> bool {
        *self.initialized.lock()
    }

    /// Number of registered factories (one child is built per factory).
    pub fn len(&self) -> usize {
        self.module_factories.len()
    }

    /// Returns `true` when no factory has been registered.
    pub fn is_empty(&self) -> bool {
        self.module_factories.is_empty()
    }
}
impl Default for LazySequential {
fn default() -> Self {
Self::new()
}
}
impl Module for LazySequential {
    /// Feeds `input` through every child in registration order.
    ///
    /// If the container is not yet initialised, the children are first built
    /// from the dimensions of `input`. With no children the result is a clone
    /// of `input`.
    ///
    /// # Errors
    ///
    /// Propagates any error from initialisation or from a child's forward pass.
    fn forward(&self, input: &Tensor) -> Result<Tensor> {
        if !self.is_initialized() {
            self.initialize(input.shape().dims())?;
        }

        let guard = self.modules.lock();
        let children = guard
            .as_ref()
            .expect("modules should be initialized before forward pass");

        // Each child consumes the previous child's output.
        let output = children
            .iter()
            .try_fold(input.clone(), |acc, child| child.forward(&acc))?;
        Ok(output)
    }

    /// Collects the parameters of all built children, keyed as
    /// `"<child index>.<parameter name>"`. Empty before initialisation.
    fn parameters(&self) -> HashMap<String, Parameter> {
        let mut collected = HashMap::new();
        let guard = self.modules.lock();
        if let Some(children) = guard.as_ref() {
            for (idx, child) in children.iter().enumerate() {
                collected.extend(
                    child
                        .parameters()
                        .into_iter()
                        .map(|(name, param)| (format!("{}.{}", idx, name), param)),
                );
            }
        }
        collected
    }

    /// Collects the named parameters of all built children, keyed as
    /// `"<child index>.<parameter name>"`. Empty before initialisation.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut collected = HashMap::new();
        let guard = self.modules.lock();
        if let Some(children) = guard.as_ref() {
            for (idx, child) in children.iter().enumerate() {
                collected.extend(
                    child
                        .named_parameters()
                        .into_iter()
                        .map(|(name, param)| (format!("{}.{}", idx, name), param)),
                );
            }
        }
        collected
    }

    /// Switches the container and every built child to training mode.
    fn train(&mut self) {
        self.base.set_training(true);
        if let Some(children) = self.modules.lock().as_mut() {
            children.iter_mut().for_each(|child| child.train());
        }
    }

    /// Switches the container and every built child to evaluation mode.
    fn eval(&mut self) {
        self.base.set_training(false);
        if let Some(children) = self.modules.lock().as_mut() {
            children.iter_mut().for_each(|child| child.eval());
        }
    }

    /// Reports the training flag stored on the container's base state.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Sets the training flag on the container and on every built child.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
        if let Some(children) = self.modules.lock().as_mut() {
            children
                .iter_mut()
                .for_each(|child| child.set_training(training));
        }
    }

    /// Moves the base state and every built child to `device`, stopping at
    /// the first error.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)?;
        if let Some(children) = self.modules.lock().as_mut() {
            children
                .iter_mut()
                .try_for_each(|child| child.to_device(device))?;
        }
        Ok(())
    }

    /// Always returns an empty list; the built children are not exposed here.
    fn children(&self) -> Vec<&dyn Module> {
        Vec::new()
    }
}
/// Ordered list of modules that are created on demand.
///
/// Unlike a sequential container it defines no forward pass of its own; the
/// children are built by an explicit `initialize_lazy` call and then used
/// individually.
pub struct LazyModuleList {
    /// Shared module state; this file uses it for the training flag and for
    /// device moves.
    base: ModuleBase,
    /// One factory per list entry, each mapping a shape slice to a module.
    module_factories: Vec<Box<dyn Fn(&[usize]) -> Result<Box<dyn Module>> + Send + Sync>>,
    /// The built entries; `None` until initialisation has succeeded.
    modules: Mutex<Option<Vec<Box<dyn Module>>>>,
    /// Set to `true` right after `modules` has been populated.
    initialized: Mutex<bool>,
}
impl LazyModuleList {
    /// Creates an empty list with no factories and no built modules.
    pub fn new() -> Self {
        Self {
            base: ModuleBase::new(),
            module_factories: Vec::new(),
            modules: Mutex::new(None),
            initialized: Mutex::new(false),
        }
    }

    /// Appends a factory that will build one list entry from a shape slice.
    pub fn push_factory<F>(&mut self, factory: F)
    where
        F: Fn(&[usize]) -> Result<Box<dyn Module>> + Send + Sync + 'static,
    {
        self.module_factories.push(Box::new(factory));
    }

    /// Appends an already constructed module.
    ///
    /// The module is wrapped in a single-use factory: the first invocation
    /// yields the module, any later one fails with
    /// `TorshError::Other("Module already used")`.
    pub fn push_module<M: Module + 'static>(&mut self, module: M) {
        // Factories are `Fn`, so the module waits in a shared slot until it
        // is taken out.
        let slot = Arc::new(Mutex::new(Some(module)));
        self.module_factories
            .push(Box::new(move |_shape| match slot.lock().take() {
                Some(inner) => Ok(Box::new(inner) as Box<dyn Module>),
                None => Err(TorshError::Other("Module already used".to_string())),
            }));
    }

    /// Builds every entry by calling each factory with `input_shape`.
    ///
    /// Does nothing when the entries already exist.
    ///
    /// # Errors
    ///
    /// Returns the first factory error; nothing is stored in that case.
    pub fn initialize_lazy(&self, input_shape: &[usize]) -> Result<()> {
        let mut slot = self.modules.lock();
        if slot.is_none() {
            // Collecting into `Result` stops at the first failing factory.
            let built = self
                .module_factories
                .iter()
                .map(|factory| factory(input_shape))
                .collect::<Result<Vec<_>>>()?;
            *slot = Some(built);
            *self.initialized.lock() = true;
        }
        Ok(())
    }

    /// Returns `true` once the entries have been built.
    pub fn is_initialized(&self) -> bool {
        *self.initialized.lock()
    }

    /// Number of registered factories (one entry is built per factory).
    pub fn len(&self) -> usize {
        self.module_factories.len()
    }

    /// Returns `true` when no factory has been registered.
    pub fn is_empty(&self) -> bool {
        self.module_factories.is_empty()
    }

    /// Always returns `None`, whatever the index or initialisation state.
    ///
    /// Use [`LazyModuleList::apply_to_module`] to work with a built entry.
    // NOTE(review): presumably a placeholder because the entries sit behind
    // a mutex and a plain reference cannot outlive the guard — confirm.
    pub fn get(&self, _index: usize) -> Option<&dyn Module> {
        None
    }

    /// Runs `f` on the built entry at `index` and returns its result.
    ///
    /// Returns `None` when the list is uninitialised or `index` is out of
    /// range; `f` is not called in either case.
    pub fn apply_to_module<F, R>(&self, index: usize, f: F) -> Option<R>
    where
        F: FnOnce(&dyn Module) -> R,
    {
        let guard = self.modules.lock();
        let outcome = match guard.as_ref() {
            Some(entries) => entries.get(index).map(|entry| f(entry.as_ref())),
            None => None,
        };
        outcome
    }
}
impl Default for LazyModuleList {
fn default() -> Self {
Self::new()
}
}
impl Module for LazyModuleList {
    /// Always fails: a module list has no forward pass of its own.
    ///
    /// # Errors
    ///
    /// Unconditionally returns `TorshError::InvalidArgument`.
    fn forward(&self, _input: &Tensor) -> Result<Tensor> {
        let message =
            "LazyModuleList doesn't define forward pass - initialize and use modules individually"
                .to_string();
        Err(TorshError::InvalidArgument(message))
    }

    /// Collects the parameters of all built entries, keyed as
    /// `"<entry index>.<parameter name>"`. Empty before initialisation.
    fn parameters(&self) -> HashMap<String, Parameter> {
        let mut collected = HashMap::new();
        let guard = self.modules.lock();
        if let Some(entries) = guard.as_ref() {
            for (idx, entry) in entries.iter().enumerate() {
                collected.extend(
                    entry
                        .parameters()
                        .into_iter()
                        .map(|(name, param)| (format!("{}.{}", idx, name), param)),
                );
            }
        }
        collected
    }

    /// Collects the named parameters of all built entries, keyed as
    /// `"<entry index>.<parameter name>"`. Empty before initialisation.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut collected = HashMap::new();
        let guard = self.modules.lock();
        if let Some(entries) = guard.as_ref() {
            for (idx, entry) in entries.iter().enumerate() {
                collected.extend(
                    entry
                        .named_parameters()
                        .into_iter()
                        .map(|(name, param)| (format!("{}.{}", idx, name), param)),
                );
            }
        }
        collected
    }

    /// Switches the list and every built entry to training mode.
    fn train(&mut self) {
        self.base.set_training(true);
        if let Some(entries) = self.modules.lock().as_mut() {
            entries.iter_mut().for_each(|entry| entry.train());
        }
    }

    /// Switches the list and every built entry to evaluation mode.
    fn eval(&mut self) {
        self.base.set_training(false);
        if let Some(entries) = self.modules.lock().as_mut() {
            entries.iter_mut().for_each(|entry| entry.eval());
        }
    }

    /// Reports the training flag stored on the list's base state.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Sets the training flag on the list and on every built entry.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
        if let Some(entries) = self.modules.lock().as_mut() {
            entries
                .iter_mut()
                .for_each(|entry| entry.set_training(training));
        }
    }

    /// Moves the base state and every built entry to `device`, stopping at
    /// the first error.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)?;
        if let Some(entries) = self.modules.lock().as_mut() {
            entries
                .iter_mut()
                .try_for_each(|entry| entry.to_device(device))?;
        }
        Ok(())
    }

    /// Always returns an empty list; the built entries are not exposed here.
    fn children(&self) -> Vec<&dyn Module> {
        Vec::new()
    }
}
/// String-keyed collection of modules that are created on demand.
///
/// It defines no forward pass of its own; the entries are built by an
/// explicit `initialize_lazy` call and then used individually.
pub struct LazyModuleDict {
    /// Shared module state; this file uses it for the training flag and for
    /// device moves.
    base: ModuleBase,
    /// Factory per key, each mapping a shape slice to a module.
    module_factories:
        HashMap<String, Box<dyn Fn(&[usize]) -> Result<Box<dyn Module>> + Send + Sync>>,
    /// The built entries under their keys; `None` until initialisation has
    /// succeeded.
    modules: Mutex<Option<HashMap<String, Box<dyn Module>>>>,
    /// Set to `true` right after `modules` has been populated.
    initialized: Mutex<bool>,
}
impl LazyModuleDict {
    /// Creates an empty dictionary with no factories and no built modules.
    pub fn new() -> Self {
        Self {
            base: ModuleBase::new(),
            module_factories: HashMap::new(),
            modules: Mutex::new(None),
            initialized: Mutex::new(false),
        }
    }

    /// Registers `factory` under `key`, replacing any factory already stored
    /// under the same key.
    pub fn insert_factory<F>(&mut self, key: String, factory: F)
    where
        F: Fn(&[usize]) -> Result<Box<dyn Module>> + Send + Sync + 'static,
    {
        self.module_factories.insert(key, Box::new(factory));
    }

    /// Registers an already constructed module under `key`.
    ///
    /// The module is wrapped in a single-use factory: the first invocation
    /// yields the module, any later one fails with
    /// `TorshError::Other("Module already used")`.
    pub fn insert_module<M: Module + 'static>(&mut self, key: String, module: M) {
        // Factories are `Fn`, so the module waits in a shared slot until it
        // is taken out.
        let slot = Arc::new(Mutex::new(Some(module)));
        self.module_factories.insert(
            key,
            Box::new(move |_shape| match slot.lock().take() {
                Some(inner) => Ok(Box::new(inner) as Box<dyn Module>),
                None => Err(TorshError::Other("Module already used".to_string())),
            }),
        );
    }

    /// Builds every entry by calling each factory with `input_shape`.
    ///
    /// Does nothing when the entries already exist.
    ///
    /// # Errors
    ///
    /// Returns the first factory error; nothing is stored in that case.
    pub fn initialize_lazy(&self, input_shape: &[usize]) -> Result<()> {
        let mut slot = self.modules.lock();
        if slot.is_none() {
            // Collecting into `Result` stops at the first failing factory.
            let built = self
                .module_factories
                .iter()
                .map(|(key, factory)| factory(input_shape).map(|module| (key.clone(), module)))
                .collect::<Result<HashMap<_, _>>>()?;
            *slot = Some(built);
            *self.initialized.lock() = true;
        }
        Ok(())
    }

    /// Returns `true` once the entries have been built.
    pub fn is_initialized(&self) -> bool {
        *self.initialized.lock()
    }

    /// Number of registered factories (one entry is built per factory).
    pub fn len(&self) -> usize {
        self.module_factories.len()
    }

    /// Returns `true` when no factory has been registered.
    pub fn is_empty(&self) -> bool {
        self.module_factories.is_empty()
    }

    /// Always returns `None`, whatever the key or initialisation state.
    ///
    /// Use [`LazyModuleDict::apply_to_module`] to work with a built entry.
    // NOTE(review): presumably a placeholder because the entries sit behind
    // a mutex and a plain reference cannot outlive the guard — confirm.
    pub fn get(&self, _key: &str) -> Option<&dyn Module> {
        None
    }

    /// Runs `f` on the built entry stored under `key` and returns its result.
    ///
    /// Returns `None` when the dictionary is uninitialised or has no such
    /// key; `f` is not called in either case.
    pub fn apply_to_module<F, R>(&self, key: &str, f: F) -> Option<R>
    where
        F: FnOnce(&dyn Module) -> R,
    {
        let guard = self.modules.lock();
        let outcome = match guard.as_ref() {
            Some(entries) => entries.get(key).map(|entry| f(entry.as_ref())),
            None => None,
        };
        outcome
    }

    /// Iterates over the keys of all registered factories.
    pub fn factory_keys(&self) -> impl Iterator<Item = &String> {
        self.module_factories.keys()
    }

    /// Returns the keys of the built entries, or an empty vector before
    /// initialisation.
    pub fn module_keys(&self) -> Vec<String> {
        let guard = self.modules.lock();
        match guard.as_ref() {
            Some(entries) => entries.keys().cloned().collect(),
            None => Vec::new(),
        }
    }
}
impl Default for LazyModuleDict {
fn default() -> Self {
Self::new()
}
}
impl Module for LazyModuleDict {
    /// Always fails: a module dictionary has no forward pass of its own.
    ///
    /// # Errors
    ///
    /// Unconditionally returns `TorshError::InvalidArgument`.
    fn forward(&self, _input: &Tensor) -> Result<Tensor> {
        let message =
            "LazyModuleDict doesn't define forward pass - initialize and use modules individually"
                .to_string();
        Err(TorshError::InvalidArgument(message))
    }

    /// Collects the parameters of all built entries, keyed as
    /// `"<entry key>.<parameter name>"`. Empty before initialisation.
    fn parameters(&self) -> HashMap<String, Parameter> {
        let mut collected = HashMap::new();
        let guard = self.modules.lock();
        if let Some(entries) = guard.as_ref() {
            for (key, entry) in entries.iter() {
                collected.extend(
                    entry
                        .parameters()
                        .into_iter()
                        .map(|(name, param)| (format!("{}.{}", key, name), param)),
                );
            }
        }
        collected
    }

    /// Collects the named parameters of all built entries, keyed as
    /// `"<entry key>.<parameter name>"`. Empty before initialisation.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut collected = HashMap::new();
        let guard = self.modules.lock();
        if let Some(entries) = guard.as_ref() {
            for (key, entry) in entries.iter() {
                collected.extend(
                    entry
                        .named_parameters()
                        .into_iter()
                        .map(|(name, param)| (format!("{}.{}", key, name), param)),
                );
            }
        }
        collected
    }

    /// Switches the dictionary and every built entry to training mode.
    fn train(&mut self) {
        self.base.set_training(true);
        if let Some(entries) = self.modules.lock().as_mut() {
            entries.values_mut().for_each(|entry| entry.train());
        }
    }

    /// Switches the dictionary and every built entry to evaluation mode.
    fn eval(&mut self) {
        self.base.set_training(false);
        if let Some(entries) = self.modules.lock().as_mut() {
            entries.values_mut().for_each(|entry| entry.eval());
        }
    }

    /// Reports the training flag stored on the dictionary's base state.
    fn training(&self) -> bool {
        self.base.training()
    }

    /// Sets the training flag on the dictionary and on every built entry.
    fn set_training(&mut self, training: bool) {
        self.base.set_training(training);
        if let Some(entries) = self.modules.lock().as_mut() {
            entries
                .values_mut()
                .for_each(|entry| entry.set_training(training));
        }
    }

    /// Moves the base state and every built entry to `device`, stopping at
    /// the first error.
    fn to_device(&mut self, device: DeviceType) -> Result<()> {
        self.base.to_device(device)?;
        if let Some(entries) = self.modules.lock().as_mut() {
            entries
                .values_mut()
                .try_for_each(|entry| entry.to_device(device))?;
        }
        Ok(())
    }

    /// Always returns an empty list; the built entries are not exposed here.
    fn children(&self) -> Vec<&dyn Module> {
        Vec::new()
    }
}