use std::collections::HashMap;
use torsh_core::error::Result;
use crate::{Module, Parameter};
/// Callback-based hooks for running a closure against a module.
///
/// Every method borrows the receiver mutably, takes the callback by shared
/// reference, and returns the project's `Result<()>`.
pub trait ModuleApply {
    /// Runs `f` with this module passed as `&mut dyn Module`.
    ///
    /// The blanket implementation in this file calls `f` exactly once, on
    /// `self`, and returns whatever `f` returns.
    fn apply<F: Fn(&mut dyn Module) -> Result<()>>(&mut self, f: &F) -> Result<()>;

    /// Hook for running `f` against `Parameter` values.
    ///
    /// NOTE(review): which parameters are visited is up to the implementor;
    /// the blanket implementation in this file visits none.
    fn apply_to_parameters<F: Fn(&mut Parameter) -> Result<()>>(&mut self, f: &F) -> Result<()>;

    /// Hook for running `f` against modules as `&mut dyn Module`.
    ///
    /// NOTE(review): which modules are visited is up to the implementor;
    /// the blanket implementation in this file visits none.
    fn apply_to_modules<F: Fn(&mut dyn Module) -> Result<()>>(&mut self, f: &F) -> Result<()>;
}
/// Blanket implementation: every sized `Module` type gets `ModuleApply`.
impl<T: Module> ModuleApply for T {
    /// Calls `f` once with `self` coerced to `&mut dyn Module` and forwards
    /// its result. No recursion into nested modules happens here.
    fn apply<F: Fn(&mut dyn Module) -> Result<()>>(&mut self, f: &F) -> Result<()> {
        let this: &mut dyn Module = self;
        f(this)
    }

    /// Currently a no-op: `_f` is never invoked and `Ok(())` is always
    /// returned.
    ///
    /// NOTE(review): looks like a placeholder awaiting mutable access to the
    /// module's parameters — confirm before relying on it.
    fn apply_to_parameters<F: Fn(&mut Parameter) -> Result<()>>(
        &mut self,
        _f: &F,
    ) -> Result<()> {
        Ok(())
    }

    /// Currently a no-op: `_f` is never invoked and `Ok(())` is always
    /// returned.
    ///
    /// NOTE(review): looks like a placeholder awaiting mutable access to
    /// child modules — confirm before relying on it.
    fn apply_to_modules<F: Fn(&mut dyn Module) -> Result<()>>(
        &mut self,
        _f: &F,
    ) -> Result<()> {
        Ok(())
    }
}
/// Read-only helpers for inspecting a module's parameters and mode.
pub mod analysis {
    use super::*;

    /// Number of elements in one parameter's tensor, taken from its shape.
    ///
    /// Acquires the tensor's read guard only for the duration of the call.
    fn element_count(param: &Parameter) -> usize {
        param.tensor().read().shape().numel()
    }

    /// Sums the element counts of every parameter returned by
    /// `module.parameters()`.
    pub fn count_parameters(module: &dyn Module) -> usize {
        module.parameters().values().map(element_count).sum()
    }

    /// Like [`count_parameters`], but only counts parameters whose
    /// `requires_grad()` is `true`.
    pub fn count_trainable_parameters(module: &dyn Module) -> usize {
        module
            .parameters()
            .values()
            .filter(|param| param.requires_grad())
            .map(element_count)
            .sum()
    }

    /// Collects total, trainable, frozen and per-name element counts.
    ///
    /// All figures come from a single walk over `module.parameters()`, so
    /// each tensor is read once and the totals are derived from the same
    /// per-parameter counts that end up in `parameter_count_by_layer`.
    pub fn parameter_statistics(module: &dyn Module) -> ModuleParameterStats {
        let parameters = module.parameters();

        let mut trainable_parameters = 0usize;
        let mut frozen_parameters = 0usize;
        let mut parameter_count_by_layer = HashMap::with_capacity(parameters.len());

        for (name, param) in parameters.iter() {
            let count = element_count(param);
            if param.requires_grad() {
                trainable_parameters += count;
            } else {
                frozen_parameters += count;
            }
            parameter_count_by_layer.insert(name.clone(), count);
        }

        ModuleParameterStats {
            total_parameters: trainable_parameters + frozen_parameters,
            trainable_parameters,
            frozen_parameters,
            parameter_count_by_layer,
        }
    }

    /// Returns `module.training()`.
    pub fn is_training(module: &dyn Module) -> bool {
        module.training()
    }

    /// Returns the keys of `module.parameters()` as owned strings.
    ///
    /// The order follows the map's iteration order, so callers needing a
    /// stable order should sort the result.
    pub fn parameter_names(module: &dyn Module) -> Vec<String> {
        module.parameters().keys().cloned().collect()
    }

    /// Returns the parameters whose name contains `pattern`.
    ///
    /// Matching is a plain substring test (`str::contains`), not a glob or
    /// regular expression; an empty `pattern` therefore matches every name.
    pub fn find_parameters_by_pattern(
        module: &dyn Module,
        pattern: &str,
    ) -> HashMap<String, Parameter> {
        module
            .parameters()
            .into_iter()
            .filter(|(name, _)| name.contains(pattern))
            .collect()
    }
}
/// Element-count summary for a module's parameters.
///
/// Counts are numbers of tensor elements, not numbers of `Parameter`
/// objects.
#[derive(Debug, Clone)]
pub struct ModuleParameterStats {
    /// Element count summed over all parameters.
    pub total_parameters: usize,
    /// Element count summed over parameters that require gradients.
    pub trainable_parameters: usize,
    /// Element count summed over parameters that do not require gradients.
    pub frozen_parameters: usize,
    /// Element count of each parameter, keyed by parameter name.
    pub parameter_count_by_layer: HashMap<String, usize>,
}

impl ModuleParameterStats {
    /// Bytes assumed per element when estimating memory.
    ///
    /// NOTE(review): 4 presumably stands for 32-bit float storage; the
    /// estimate is off for parameters stored in any other dtype.
    const ASSUMED_BYTES_PER_ELEMENT: usize = 4;

    /// Share of elements that are trainable, in percent (0.0 to 100.0 when
    /// `trainable_parameters <= total_parameters`).
    ///
    /// Returns `0.0` when there are no parameters at all, avoiding a
    /// division by zero.
    pub fn trainable_percentage(&self) -> f32 {
        if self.total_parameters == 0 {
            return 0.0;
        }
        (self.trainable_parameters as f32 / self.total_parameters as f32) * 100.0
    }

    /// Estimated parameter storage in bytes, at 4 bytes per element.
    ///
    /// The multiplication saturates at `usize::MAX` instead of panicking
    /// (debug builds) or wrapping around (release builds) when the element
    /// count is very large, which is reachable on 32-bit targets.
    pub fn memory_usage_bytes(&self) -> usize {
        self.total_parameters
            .saturating_mul(Self::ASSUMED_BYTES_PER_ELEMENT)
    }

    /// [`Self::memory_usage_bytes`] expressed in units of 1024 * 1024 bytes.
    pub fn memory_usage_mb(&self) -> f32 {
        const BYTES_PER_MB: f32 = 1024.0 * 1024.0;
        self.memory_usage_bytes() as f32 / BYTES_PER_MB
    }
}
/// Human-readable reporting and sanity checks built on `analysis`.
pub mod introspection {
    use super::*;

    /// Prints a parameter report for `module` to standard output.
    ///
    /// The report shows total, trainable and frozen element counts, the
    /// estimated memory use, the training flag and, when there are any
    /// parameters, one line per parameter name in ascending name order.
    /// `module_name` is only used in the heading.
    pub fn print_parameter_summary(module: &dyn Module, module_name: &str) {
        let summary = analysis::parameter_statistics(module);

        println!("=== {} Parameter Summary ===", module_name);
        println!("Total parameters: {}", summary.total_parameters);
        println!(
            "Trainable parameters: {} ({:.1}%)",
            summary.trainable_parameters,
            summary.trainable_percentage()
        );
        println!("Frozen parameters: {}", summary.frozen_parameters);
        println!("Memory usage: {:.2} MB", summary.memory_usage_mb());
        println!("Training mode: {}", analysis::is_training(module));

        let mut per_layer: Vec<_> = summary.parameter_count_by_layer.iter().collect();
        if !per_layer.is_empty() {
            println!("\nParameters by layer:");
            // Map keys are unique, so an unstable sort yields the same order
            // a stable one would.
            per_layer.sort_unstable_by(|a, b| a.0.cmp(b.0));
            for (layer, count) in per_layer {
                println!(" {}: {}", layer, count);
            }
        }

        println!();
    }

    /// Returns a list of human-readable warnings about `module`.
    ///
    /// Checks performed:
    /// - the module has no parameters at all;
    /// - it has parameters but none of them are trainable;
    /// - the estimated parameter memory exceeds 1024^3 bytes.
    ///
    /// An empty vector means none of these conditions were detected.
    pub fn health_check(module: &dyn Module) -> Vec<String> {
        const ONE_GB_BYTES: usize = 1024 * 1024 * 1024;

        let summary = analysis::parameter_statistics(module);
        let mut findings = Vec::new();

        // The two parameter checks are mutually exclusive: the second one
        // only applies when at least one parameter exists.
        if summary.total_parameters == 0 {
            findings.push(String::from("Module has no parameters"));
        } else if summary.trainable_parameters == 0 {
            findings.push(String::from(
                "All parameters are frozen - module won't train",
            ));
        }

        let estimated_bytes = summary.memory_usage_bytes();
        if estimated_bytes > ONE_GB_BYTES {
            let gigabytes = estimated_bytes as f32 / (1024.0 * 1024.0 * 1024.0);
            findings.push(format!("Large model detected: {:.1} GB", gigabytes));
        }

        findings
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::layers::linear::Linear;

    /// `count_parameters` on a `Linear::new(10, 5, true)` layer reports 55
    /// elements (presumably 10 * 5 weights plus 5 bias terms — verify
    /// against `Linear`).
    #[test]
    fn test_parameter_counting() {
        let layer = Linear::new(10, 5, true);
        assert_eq!(analysis::count_parameters(&layer), 55);
    }

    /// A freshly built `Linear::new(4, 2, true)` layer reports 10 elements,
    /// all of them trainable.
    #[test]
    fn test_parameter_stats() {
        let layer = Linear::new(4, 2, true);
        let summary = analysis::parameter_statistics(&layer);

        assert_eq!(summary.total_parameters, 10);
        assert_eq!(summary.trainable_parameters, 10);
        assert_eq!(summary.frozen_parameters, 0);
        assert_eq!(summary.trainable_percentage(), 100.0);
    }

    /// Memory estimates use 4 bytes per element and 1024 * 1024 bytes per MB.
    #[test]
    fn test_memory_calculation() {
        let summary = ModuleParameterStats {
            total_parameters: 1000,
            trainable_parameters: 800,
            frozen_parameters: 200,
            parameter_count_by_layer: HashMap::new(),
        };

        assert_eq!(summary.memory_usage_bytes(), 4000);
        // Dividing by a power of two is exact here, so strict float equality
        // is safe.
        assert_eq!(summary.memory_usage_mb(), 4000.0 / (1024.0 * 1024.0));
    }
}