use crate::Module;
use torsh_core::device::DeviceType;
use torsh_core::error::Result;
use torsh_tensor::Tensor;
#[cfg(feature = "std")]
use std::collections::HashMap;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
/// Convenience helpers available on every [`Module`] through a blanket implementation.
///
/// Every method is provided by default. The helpers wrap `forward`, inspect
/// the module's parameters, and produce summaries / validation reports.
pub trait ModuleExt: Module {
    /// Runs `forward(input)` and then feeds the output to the fallible `f`.
    ///
    /// # Errors
    /// Returns the error from `forward` (in which case `f` is never called),
    /// or whatever error `f` returns.
    fn and_then<F>(&self, input: &Tensor, f: F) -> Result<Tensor>
    where
        F: FnOnce(Tensor) -> Result<Tensor>,
    {
        let output = self.forward(input)?;
        f(output)
    }

    /// Runs `forward(input)` and maps the output through the infallible `f`.
    ///
    /// # Errors
    /// Returns the error from `forward`; `f` is only called on success.
    fn map<F>(&self, input: &Tensor, f: F) -> Result<Tensor>
    where
        F: FnOnce(Tensor) -> Tensor,
    {
        self.forward(input).map(f)
    }

    /// Pre-processes `input` with `f` and runs `forward` on the transformed tensor.
    ///
    /// # Errors
    /// Returns the error from `f` (in which case `forward` is never called),
    /// or the error from `forward`.
    fn with_input<F>(&self, input: &Tensor, f: F) -> Result<Tensor>
    where
        F: FnOnce(&Tensor) -> Result<Tensor>,
    {
        let transformed = f(input)?;
        self.forward(&transformed)
    }

    /// Builds a multi-line, human-readable summary from `module_info()`.
    ///
    /// Memory is shown in MB (bytes / 1024²) with two decimal places.
    fn summary(&self) -> String {
        let info = self.module_info();
        // The trailing `\` continues the literal and skips the next line's
        // leading whitespace, so the indentation below is not part of the output.
        format!(
            "Module: {}\n\
             Training: {}\n\
             Parameters: {} ({} trainable)\n\
             Memory: {:.2} MB\n\
             Children: {}",
            info.name,
            info.training,
            info.parameter_count,
            info.trainable_parameter_count,
            info.memory_usage_bytes as f64 / (1024.0 * 1024.0),
            info.children_count
        )
    }

    /// Prints [`ModuleExt::summary`] to standard output.
    fn print_summary(&self) {
        println!("{}", self.summary());
    }

    /// Aggregates element counts and an estimated memory footprint over all parameters.
    ///
    /// Parameters whose `numel()` fails are counted as having 0 elements.
    /// A parameter is "trainable" when `requires_grad()` is `true`, otherwise "frozen".
    fn parameter_stats(&self) -> ParameterStats {
        // NOTE(review): the byte estimate assumes 4 bytes per element
        // (presumably f32); parameters of other dtypes will be mis-reported.
        const BYTES_PER_ELEMENT: usize = 4;

        let params = self.all_parameters();
        let mut total_params = 0;
        let mut trainable_params = 0;
        let mut frozen_params = 0;
        let mut total_memory = 0;
        for param in params.values() {
            let numel = param.numel().unwrap_or(0);
            total_params += numel;
            total_memory += numel * BYTES_PER_ELEMENT;
            if param.requires_grad() {
                trainable_params += numel;
            } else {
                frozen_params += numel;
            }
        }
        ParameterStats {
            total_parameters: total_params,
            trainable_parameters: trainable_params,
            frozen_parameters: frozen_params,
            total_memory_bytes: total_memory,
            parameter_count: params.len(),
        }
    }

    /// Returns `true` when every parameter reports finite values.
    ///
    /// A parameter whose `is_finite()` check itself errors is treated as
    /// non-finite. Vacuously `true` for a module without parameters.
    fn has_finite_parameters(&self) -> bool {
        self.all_parameters()
            .values()
            .all(|p| p.is_finite().unwrap_or(false))
    }

    /// Returns all parameter names, sorted lexicographically.
    fn parameter_names(&self) -> Vec<String> {
        let mut names: Vec<String> = self.all_named_parameters().keys().cloned().collect();
        names.sort();
        names
    }

    /// Returns a clone of the parameter registered under `name`, if any.
    fn get_parameter(&self, name: &str) -> Option<crate::Parameter> {
        self.all_named_parameters().get(name).cloned()
    }

    /// Counts the parameters whose name contains `pattern`.
    ///
    /// NOTE(review): despite its name this method does NOT currently change any
    /// parameter's `requires_grad` flag — it only counts matches. Implement the
    /// actual freezing once the `Parameter` mutation API is confirmed.
    fn freeze_matching(&mut self, pattern: &str) -> usize {
        self.all_named_parameters()
            .keys()
            .filter(|name| name.contains(pattern))
            .count()
    }

    /// Counts the parameters whose name contains `pattern`.
    ///
    /// NOTE(review): despite its name this method does NOT currently change any
    /// parameter's `requires_grad` flag — it only counts matches. Implement the
    /// actual unfreezing once the `Parameter` mutation API is confirmed.
    fn unfreeze_matching(&mut self, pattern: &str) -> usize {
        self.all_named_parameters()
            .keys()
            .filter(|name| name.contains(pattern))
            .count()
    }

    /// Returns the names of parameters with `requires_grad() == false`, sorted.
    fn frozen_parameters(&self) -> Vec<String> {
        let mut names: Vec<String> = self
            .all_named_parameters()
            .into_iter()
            .filter(|(_, p)| !p.requires_grad())
            .map(|(name, _)| name)
            .collect();
        // Sorted for deterministic output, matching `parameter_names`.
        names.sort();
        names
    }

    /// Returns the names of parameters with `requires_grad() == true`, sorted.
    fn trainable_parameters(&self) -> Vec<String> {
        let mut names: Vec<String> = self
            .all_named_parameters()
            .into_iter()
            .filter(|(_, p)| p.requires_grad())
            .map(|(name, _)| name)
            .collect();
        // Sorted for deterministic output, matching `parameter_names`.
        names.sort();
        names
    }

    /// Returns the module's `state_dict()`.
    ///
    /// NOTE(review): whether the returned tensors are deep copies or share
    /// storage depends on `state_dict`'s implementation — confirm before
    /// relying on independence from the live module.
    fn clone_state_dict(&self) -> HashMap<String, Tensor> {
        self.state_dict()
    }

    /// Calls `f(name, parameter)` once for every named parameter.
    ///
    /// Iteration order follows the map returned by `all_named_parameters()`.
    fn apply_to_parameters<F>(&self, mut f: F)
    where
        F: FnMut(&str, &crate::Parameter),
    {
        for (name, param) in self.all_named_parameters() {
            f(&name, &param);
        }
    }

    /// Sums element counts grouped by the first `.`-separated component of each
    /// parameter name (e.g. `"layer1.weight"` is grouped under `"layer1"`).
    ///
    /// Names whose leading component is empty are grouped under `"unknown"`.
    /// Parameters whose `numel()` fails contribute 0 elements.
    fn parameters_by_type(&self) -> HashMap<String, usize> {
        let mut counts = HashMap::new();
        for (name, param) in self.all_named_parameters() {
            // `split` always yields at least one item, so the emptiness filter is
            // what makes the "unknown" fallback reachable.
            let prefix = name
                .split('.')
                .next()
                .filter(|segment| !segment.is_empty())
                .unwrap_or("unknown")
                .to_string();
            let numel = param.numel().unwrap_or(0);
            *counts.entry(prefix).or_insert(0) += numel;
        }
        counts
    }

    /// Runs basic sanity checks and collects the findings into a [`ValidationReport`].
    ///
    /// - Warning if the module has no parameters.
    /// - Error if any parameter is non-finite (see [`ModuleExt::has_finite_parameters`]).
    /// - Warning if `memory_usage_mb()` exceeds 1024 MB.
    /// - Warning if `num_parameters()` exceeds 100 million.
    ///
    /// `is_valid` is set to `true` exactly when no errors were recorded.
    /// The current implementation always returns `Ok`.
    fn validate(&self) -> Result<ValidationReport> {
        let mut report = ValidationReport::default();
        if !self.has_parameters() {
            report.warnings.push("Module has no parameters".to_string());
        }
        if !self.has_finite_parameters() {
            report
                .errors
                .push("Module has non-finite parameters (NaN or Inf)".to_string());
        }
        let memory_mb = self.memory_usage_mb();
        if memory_mb > 1024.0 {
            report
                .warnings
                .push(format!("Large memory usage: {:.2} GB", memory_mb / 1024.0));
        }
        let param_count = self.num_parameters();
        if param_count > 100_000_000 {
            report
                .warnings
                .push(format!("Very large model: {} parameters", param_count));
        }
        report.is_valid = report.errors.is_empty();
        Ok(report)
    }

    /// Reports the module's device.
    ///
    /// NOTE(review): placeholder — returns `Some(DeviceType::Cpu)` whenever the
    /// module has any parameters, without inspecting where the tensors live,
    /// and `None` for parameterless modules.
    fn device(&self) -> Option<DeviceType> {
        if self.has_parameters() {
            Some(DeviceType::Cpu)
        } else {
            None
        }
    }

    /// Returns `true` when [`ModuleExt::device`] reports `DeviceType::Cpu`.
    fn is_cpu(&self) -> bool {
        self.device() == Some(DeviceType::Cpu)
    }

    /// Returns `true` when [`ModuleExt::device`] reports a CUDA device.
    ///
    /// With the current placeholder `device()` this never returns `true`.
    fn is_cuda(&self) -> bool {
        matches!(self.device(), Some(DeviceType::Cuda(_)))
    }
}
/// Blanket implementation: every [`Module`] implementor, including unsized
/// ones (e.g. trait objects), automatically receives the [`ModuleExt`] helpers.
impl<T> ModuleExt for T where T: Module + ?Sized {}
/// Aggregate parameter statistics, as produced by `ModuleExt::parameter_stats`.
#[derive(Debug, Clone)]
pub struct ParameterStats {
    /// Total number of scalar elements across all parameters.
    pub total_parameters: usize,
    /// Elements belonging to trainable parameters.
    pub trainable_parameters: usize,
    /// Elements belonging to frozen (non-trainable) parameters.
    pub frozen_parameters: usize,
    /// Memory footprint in bytes, as computed by the producer of these stats.
    pub total_memory_bytes: usize,
    /// Number of parameter tensors (not elements).
    pub parameter_count: usize,
}

impl ParameterStats {
    /// Memory footprint in mebibytes (bytes / 1024²).
    pub fn memory_mb(&self) -> f64 {
        const BYTES_PER_MIB: f64 = 1024.0 * 1024.0;
        let bytes = self.total_memory_bytes as f64;
        bytes / BYTES_PER_MIB
    }

    /// Memory footprint in gibibytes (mebibytes / 1024).
    pub fn memory_gb(&self) -> f64 {
        let mib = self.memory_mb();
        mib / 1024.0
    }

    /// Share of elements that are trainable, expressed as a percentage.
    ///
    /// Returns `0.0` when there are no parameters at all, avoiding a division by zero.
    pub fn trainable_percentage(&self) -> f64 {
        match self.total_parameters {
            0 => 0.0,
            total => {
                let fraction = self.trainable_parameters as f64 / total as f64;
                fraction * 100.0
            }
        }
    }
}
/// Outcome of a module validation: an overall flag plus collected errors and warnings.
#[derive(Debug, Clone, Default)]
pub struct ValidationReport {
    /// Overall validity flag (`false` for a `Default` report).
    pub is_valid: bool,
    /// Problems that make the module invalid.
    pub errors: Vec<String>,
    /// Non-fatal observations.
    pub warnings: Vec<String>,
}

impl ValidationReport {
    /// Returns `true` only when no errors were recorded and `is_valid` is set.
    pub fn passed(&self) -> bool {
        self.errors.is_empty() && self.is_valid
    }

    /// Total number of recorded errors and warnings.
    pub fn issue_count(&self) -> usize {
        [&self.errors, &self.warnings]
            .iter()
            .map(|list| list.len())
            .sum()
    }

    /// Renders the report as plain text.
    ///
    /// Starts with a `PASSED`/`FAILED` line driven by `is_valid`, followed by
    /// an `Errors:` and then a `Warnings:` section, each emitted only when non-empty.
    pub fn format(&self) -> String {
        let verdict = if self.is_valid { "PASSED" } else { "FAILED" };
        let mut text = format!("Validation: {}\n", verdict);
        let sections = [("\nErrors:\n", &self.errors), ("\nWarnings:\n", &self.warnings)];
        for (header, items) in sections {
            if items.is_empty() {
                continue;
            }
            text.push_str(header);
            for item in items {
                text.push_str(&format!(" - {}\n", item));
            }
        }
        text
    }
}
// NOTE(review): placeholder test module — it currently contains no tests.
// Candidates: `ParameterStats` arithmetic and `ValidationReport::format` output.
#[cfg(test)]
mod tests {
}