pub mod module_ext;
pub use module_ext::{ModuleExt, ParameterStats, ValidationReport};
use torsh_core::device::DeviceType;
use torsh_core::error::Result;
use torsh_tensor::Tensor;
#[cfg(feature = "std")]
use std::collections::HashMap;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
/// Core interface shared by layers and containers.
///
/// Only [`Module::forward`] is required. Every other method has a default built on a
/// small set of overridable hooks (`parameters`, `named_parameters`, `children`,
/// `named_children`, `set_training`, ...). Containers should override `children` /
/// `named_children` so the recursive helpers (`all_named_parameters`, `state_dict`,
/// `load_state_dict`, the parameter counters) can reach their sub-modules.
pub trait Module: Send + Sync {
    /// Runs the module on `input` and returns the output tensor.
    fn forward(&self, input: &Tensor) -> Result<Tensor>;

    /// Parameters registered on this module itself.
    ///
    /// The recursive helpers add the children's parameters separately. The default
    /// returns an empty map.
    fn parameters(&self) -> HashMap<String, crate::Parameter> {
        HashMap::new()
    }

    /// Named view of this module's own parameters.
    ///
    /// Defaults to [`Module::parameters`]. Override this when the names used for
    /// serialization differ from the ones returned by `parameters`.
    fn named_parameters(&self) -> HashMap<String, crate::Parameter> {
        self.parameters()
    }

    /// This module's parameters merged with those of its children, recursively.
    ///
    /// Child parameter names are inserted *without* a prefix. Two children that both
    /// own e.g. `"weight"` therefore collide, and only one survives. Use
    /// [`Module::all_named_parameters`] when every parameter must be kept.
    fn all_parameters(&self) -> HashMap<String, crate::Parameter> {
        let mut all_params = self.parameters();
        for child in self.children() {
            // Later entries overwrite earlier ones with the same (unprefixed) name.
            all_params.extend(child.all_parameters());
        }
        all_params
    }

    /// Every parameter in the module tree, keyed by its dotted path.
    ///
    /// Each child's parameters are prefixed with `"<child_name>."`, recursively. This is
    /// the key space used by [`Module::state_dict`] and [`Module::load_state_dict`].
    fn all_named_parameters(&self) -> HashMap<String, crate::Parameter> {
        let mut all_params = self.named_parameters();
        for (child_name, child) in self.named_children() {
            all_params.extend(
                child
                    .all_named_parameters()
                    .into_iter()
                    .map(|(param_name, param)| {
                        (format!("{}.{}", child_name, param_name), param)
                    }),
            );
        }
        all_params
    }

    /// Whether the module is in training mode.
    ///
    /// The default always reports `true`. Modules that track a mode flag should
    /// override this together with [`Module::set_training`].
    fn training(&self) -> bool {
        true
    }

    /// Switches to training mode via [`Module::set_training`]`(true)`.
    fn train(&mut self) {
        self.set_training(true);
    }

    /// Switches to evaluation mode via [`Module::set_training`]`(false)`.
    fn eval(&mut self) {
        self.set_training(false);
    }

    /// Sets the training flag. The default does nothing.
    fn set_training(&mut self, _training: bool) {}

    /// Moves the module to `_device`. The default does nothing and returns `Ok(())`.
    fn to_device(&mut self, _device: DeviceType) -> Result<()> {
        Ok(())
    }

    /// Copies tensors from `state_dict` into the matching parameters.
    ///
    /// Keys are matched against [`Module::all_named_parameters`].
    ///
    /// When `strict` is `true`, any parameter missing from `state_dict`, or any
    /// `state_dict` key without a matching parameter, is an error. When `strict` is
    /// `false`, such keys are ignored.
    ///
    /// All shapes are validated before any parameter is written. On error the module
    /// is therefore left unchanged.
    ///
    /// # Errors
    /// Returns `TorshError::Other` in these cases:
    /// - a strict-mode key mismatch (key lists sorted for a stable message);
    /// - a shape mismatch.
    ///
    /// Any error from `Parameter::shape` is propagated as-is.
    fn load_state_dict(
        &mut self,
        state_dict: &HashMap<String, Tensor>,
        strict: bool,
    ) -> Result<()> {
        let current_params = self.all_named_parameters();

        if strict {
            let mut missing_keys: Vec<&String> = current_params
                .keys()
                .filter(|name| !state_dict.contains_key(*name))
                .collect();
            let mut unexpected_keys: Vec<&String> = state_dict
                .keys()
                .filter(|name| !current_params.contains_key(*name))
                .collect();
            if !missing_keys.is_empty() || !unexpected_keys.is_empty() {
                // HashMap order is random; sort so the message is reproducible.
                missing_keys.sort();
                unexpected_keys.sort();
                return Err(torsh_core::error::TorshError::Other(format!(
                    "State dict loading failed. Missing keys: {:?}, Unexpected keys: {:?}",
                    missing_keys, unexpected_keys
                )));
            }
        }

        // Phase 1: validate every matched entry without mutating anything, so a
        // mismatch late in the iteration cannot leave a half-loaded module.
        let mut pending = Vec::new();
        for (name, param) in current_params {
            if let Some(new_tensor) = state_dict.get(&name) {
                let current_shape = param.shape()?;
                let new_shape = new_tensor.shape().dims().to_vec();
                if current_shape != new_shape {
                    return Err(torsh_core::error::TorshError::Other(format!(
                        "Shape mismatch for parameter '{}': expected {:?}, got {:?}",
                        name, current_shape, new_shape
                    )));
                }
                pending.push((param, new_tensor));
            }
        }

        // Phase 2: every check passed; commit the new values.
        for (param, new_tensor) in pending {
            *param.tensor().write() = new_tensor.clone();
        }
        Ok(())
    }

    /// Same as [`Module::load_state_dict`] with `strict = true`.
    fn load_state_dict_strict(&mut self, state_dict: &HashMap<String, Tensor>) -> Result<()> {
        self.load_state_dict(state_dict, true)
    }

    /// Snapshot of every parameter tensor, keyed like [`Module::all_named_parameters`].
    ///
    /// Values are produced by `Parameter::clone_data`.
    fn state_dict(&self) -> HashMap<String, Tensor> {
        self.all_named_parameters()
            .into_iter()
            .map(|(name, param)| (name, param.clone_data()))
            .collect()
    }

    /// Optional human-readable module name. The default returns `None`.
    fn name(&self) -> Option<&str> {
        None
    }

    /// Non-parameter state tensors. The default returns an empty list.
    fn buffers(&self) -> Vec<std::sync::Arc<parking_lot::RwLock<Tensor>>> {
        Vec::new()
    }

    /// Named non-parameter state tensors. The default returns an empty map.
    fn named_buffers(&self) -> HashMap<String, std::sync::Arc<parking_lot::RwLock<Tensor>>> {
        HashMap::new()
    }

    /// Direct sub-modules, used by [`Module::all_parameters`]. The default is empty.
    fn children(&self) -> Vec<&dyn Module> {
        Vec::new()
    }

    /// Direct sub-modules with their names, used by [`Module::all_named_parameters`].
    ///
    /// The default is empty.
    fn named_children(&self) -> Vec<(String, &dyn Module)> {
        Vec::new()
    }

    /// `self` followed by its *direct* children.
    ///
    /// Grandchildren are not included.
    fn modules(&self) -> Vec<&dyn Module>
    where
        Self: Sized,
    {
        let mut modules: Vec<&dyn Module> = vec![self];
        modules.extend(self.children());
        modules
    }

    /// `("", self)` followed by each *direct* named child.
    ///
    /// Grandchildren are not included.
    fn named_modules(&self) -> Vec<(String, &dyn Module)>
    where
        Self: Sized,
    {
        let mut modules: Vec<(String, &dyn Module)> = vec![(String::new(), self)];
        modules.extend(self.named_children());
        modules
    }

    /// Clears accumulated gradients. The default does nothing.
    fn zero_grad(&mut self) {}

    /// Total element count across the whole parameter tree.
    ///
    /// Counts over [`Module::all_named_parameters`], so same-named parameters in
    /// sibling children are all counted, consistent with [`Module::state_dict`].
    /// Parameters whose `numel` fails count as 0.
    fn num_parameters(&self) -> usize {
        self.all_named_parameters()
            .values()
            .map(|p| p.numel().unwrap_or(0))
            .sum()
    }

    /// Like [`Module::num_parameters`], restricted to parameters with `requires_grad()`.
    fn num_trainable_parameters(&self) -> usize {
        self.all_named_parameters()
            .values()
            .filter(|p| p.requires_grad())
            .map(|p| p.numel().unwrap_or(0))
            .sum()
    }

    /// Approximate parameter memory in bytes.
    ///
    /// Assumes 4 bytes per element regardless of the parameter's actual dtype.
    /// Override this for an exact figure.
    fn memory_usage(&self) -> usize {
        // NOTE(review): fixed 4-byte element size (f32-sized); not dtype-aware.
        const BYTES_PER_ELEMENT: usize = 4;
        self.all_named_parameters()
            .values()
            .map(|p| p.numel().unwrap_or(0) * BYTES_PER_ELEMENT)
            .sum()
    }

    /// Stops gradient tracking for parameters. The default does nothing.
    fn freeze(&mut self) {}

    /// Re-enables gradient tracking for parameters. The default does nothing.
    fn unfreeze(&mut self) {}

    /// Extra text for a module's printed representation. The default is empty.
    fn extra_repr(&self) -> String {
        String::new()
    }

    /// Registers a hook.
    ///
    /// The default ignores the hook and returns `None`, meaning no handle was created.
    fn register_hook(
        &mut self,
        _hook_type: crate::HookType,
        _callback: crate::HookCallback,
    ) -> Option<crate::HookHandle> {
        None
    }

    /// Removes a previously registered hook.
    ///
    /// The default returns `false`.
    fn remove_hook(&mut self, _hook_type: crate::HookType, _handle: crate::HookHandle) -> bool {
        false
    }

    /// Runs every hook of `_hook_type`.
    ///
    /// The default has no hooks and returns `Ok(())`.
    fn execute_hooks(
        &self,
        _hook_type: crate::HookType,
        _input: &Tensor,
        _output: Option<&Tensor>,
    ) -> Result<()> {
        Ok(())
    }

    /// Runs [`Module::forward`] wrapped by the pre- and post-forward hooks.
    ///
    /// `PreForward` hooks see only the input; `PostForward` hooks see the input and
    /// the output. Any hook error aborts the call.
    fn forward_with_hooks(&self, input: &Tensor) -> Result<Tensor> {
        self.execute_hooks(crate::HookType::PreForward, input, None)?;
        let output = self.forward(input)?;
        self.execute_hooks(crate::HookType::PostForward, input, Some(&output))?;
        Ok(output)
    }

    /// Whether any hook of `_hook_type` is registered. The default returns `false`.
    fn has_hooks(&self, _hook_type: crate::HookType) -> bool {
        false
    }

    /// Alias for [`Module::forward`]. Hooks are not run.
    fn call(&self, input: &Tensor) -> Result<Tensor> {
        self.forward(input)
    }

    /// Alias for [`Module::forward`].
    fn apply(&self, input: &Tensor) -> Result<Tensor> {
        self.forward(input)
    }

    /// Whether this module has parameters of its own (children are not consulted).
    fn has_parameters(&self) -> bool {
        !self.parameters().is_empty()
    }

    /// Whether [`Module::children`] is non-empty.
    fn has_children(&self) -> bool {
        !self.children().is_empty()
    }

    /// Alias for [`Module::num_parameters`].
    fn parameter_count(&self) -> usize {
        self.num_parameters()
    }

    /// Alias for [`Module::num_trainable_parameters`].
    fn trainable_parameter_count(&self) -> usize {
        self.num_trainable_parameters()
    }

    /// [`Module::memory_usage`] expressed in MiB.
    fn memory_usage_mb(&self) -> f64 {
        self.memory_usage() as f64 / (1024.0 * 1024.0)
    }

    /// Flips the training flag via [`Module::set_training`].
    fn toggle_training(&mut self) {
        self.set_training(!self.training());
    }

    /// `true` when the module is not in training mode.
    fn eval_mode(&self) -> bool {
        !self.training()
    }

    /// Feeds `input` through `modules` in order, returning the last output.
    ///
    /// Stops at the first error.
    fn sequential_forward(modules: &[&dyn Module], input: Tensor) -> Result<Tensor>
    where
        Self: Sized,
    {
        modules
            .iter()
            .try_fold(input, |current, module| module.forward(&current))
    }

    /// Applies [`Module::forward`] to each input in order.
    ///
    /// Stops at the first error.
    fn batch_forward(&self, inputs: &[Tensor]) -> Result<Vec<Tensor>> {
        inputs.iter().map(|input| self.forward(input)).collect()
    }

    /// Runs [`Module::forward`] when `condition` is true.
    ///
    /// Otherwise returns a clone of `input`.
    fn conditional_forward(&self, input: &Tensor, condition: bool) -> Result<Tensor> {
        if condition {
            self.forward(input)
        } else {
            Ok(input.clone())
        }
    }

    /// Currently identical to [`Module::forward`].
    ///
    /// NOTE(review): despite the name, the input is *not* added to the output. Confirm
    /// whether a skip connection was intended.
    fn residual_forward(&self, input: &Tensor) -> Result<Tensor> {
        self.forward(input)
    }

    /// Summary of name, mode, parameter counts, memory and direct children.
    fn module_info(&self) -> crate::ModuleInfo {
        crate::ModuleInfo {
            name: self.name().unwrap_or("Unknown").to_string(),
            training: self.training(),
            parameter_count: self.num_parameters(),
            trainable_parameter_count: self.num_trainable_parameters(),
            memory_usage_bytes: self.memory_usage(),
            has_children: self.has_children(),
            children_count: self.children().len(),
        }
    }

    /// Checks that the module looks ready for training.
    ///
    /// Requires the following:
    /// - the parameter tree ([`Module::all_named_parameters`]) is non-empty;
    /// - the module is in training mode;
    /// - every parameter is finite.
    ///
    /// # Errors
    /// Returns `TorshError::Other` describing the first failed check. A parameter
    /// whose `is_finite` call itself fails is treated as non-finite.
    fn check_training_readiness(&self) -> Result<()> {
        // Look at the whole tree so containers whose parameters live in children
        // are not reported as parameter-less.
        let params = self.all_named_parameters();
        if params.is_empty() {
            return Err(torsh_core::error::TorshError::Other(
                "Module has no parameters - may not be trainable".to_string(),
            ));
        }
        if !self.training() {
            return Err(torsh_core::error::TorshError::Other(
                "Module is in evaluation mode - switch to training mode first".to_string(),
            ));
        }
        if params
            .values()
            .any(|param| !param.is_finite().unwrap_or(false))
        {
            return Err(torsh_core::error::TorshError::Other(
                "Module contains non-finite parameters (NaN or infinity)".to_string(),
            ));
        }
        Ok(())
    }

    /// Dotted parameter names that contain `pattern` as a substring.
    fn parameter_names_matching(&self, pattern: &str) -> Vec<String> {
        self.all_named_parameters()
            .keys()
            .filter(|name| name.contains(pattern))
            .cloned()
            .collect()
    }

    /// Parameters whose dotted name contains `param_type` as a substring.
    fn parameters_by_type(&self, param_type: &str) -> HashMap<String, crate::Parameter> {
        self.all_named_parameters()
            .into_iter()
            .filter(|(name, _)| name.contains(param_type))
            .collect()
    }

    /// Copies of every parameter tensor, keyed by dotted name.
    fn clone_parameters(&self) -> HashMap<String, Tensor> {
        self.all_named_parameters()
            .into_iter()
            .map(|(name, param)| (name, param.clone_data()))
            .collect()
    }

    /// Collects per-parameter diagnostics plus a training-readiness warning.
    ///
    /// Parameter issues and warnings are prefixed with `"<param name>: "`. Parameters
    /// whose `diagnose` call fails are skipped.
    fn diagnose(&self) -> crate::ModuleDiagnostics {
        // Diagnose each parameter exactly once and reuse the results below.
        let diagnosed: Vec<_> = self
            .all_named_parameters()
            .into_iter()
            .filter_map(|(name, param)| param.diagnose().ok().map(|d| (name, d)))
            .collect();

        let mut issues = Vec::new();
        let mut warnings = Vec::new();
        for (name, diag) in &diagnosed {
            issues.extend(
                diag.issues
                    .iter()
                    .map(|issue| format!("{}: {}", name, issue)),
            );
            warnings.extend(
                diag.warnings
                    .iter()
                    .map(|warning| format!("{}: {}", name, warning)),
            );
        }
        if let Err(e) = self.check_training_readiness() {
            warnings.push(format!("Training readiness: {}", e));
        }

        crate::ModuleDiagnostics {
            module_info: self.module_info(),
            issues,
            warnings,
            parameter_diagnostics: diagnosed.into_iter().collect(),
        }
    }
}
impl Module for Box<dyn Module> {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
(**self).forward(x)
}
fn parameters(&self) -> HashMap<String, crate::Parameter> {
(**self).parameters()
}
fn train(&mut self) {
(**self).train()
}
fn eval(&mut self) {
(**self).eval()
}
fn training(&self) -> bool {
(**self).training()
}
fn children(&self) -> Vec<&dyn Module> {
(**self).children()
}
fn named_children(&self) -> Vec<(String, &dyn Module)> {
(**self).named_children()
}
fn set_training(&mut self, training: bool) {
(**self).set_training(training)
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
(**self).to_device(device)
}
}
impl Module for &mut Box<dyn Module> {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
(***self).forward(x)
}
fn parameters(&self) -> HashMap<String, crate::Parameter> {
(***self).parameters()
}
fn train(&mut self) {
(***self).train()
}
fn eval(&mut self) {
(***self).eval()
}
fn training(&self) -> bool {
(***self).training()
}
fn children(&self) -> Vec<&dyn Module> {
(***self).children()
}
fn named_children(&self) -> Vec<(String, &dyn Module)> {
(***self).named_children()
}
fn set_training(&mut self, training: bool) {
(***self).set_training(training)
}
fn to_device(&mut self, device: DeviceType) -> Result<()> {
(***self).to_device(device)
}
}