use super::Parameter;
use torsh_core::error::{Result, TorshError};
#[cfg(feature = "std")]
use std::collections::HashMap;
#[cfg(not(feature = "std"))]
use hashbrown::HashMap;
/// A named set of parameters bundled with per-group settings.
///
/// All fields are plain public data; this type itself does not apply any of
/// them. NOTE(review): presumably an optimizer elsewhere reads these values —
/// confirm against the consumers of this struct.
#[derive(Debug, Clone)]
pub struct ParameterGroup {
/// Identifier of the group, supplied by the caller.
pub name: String,
/// Parameters that belong to this group.
pub parameters: Vec<Parameter>,
/// Learning-rate multiplier; `ParameterGroup::new` initializes it to `1.0`.
pub lr_multiplier: f32,
/// Weight-decay value; `ParameterGroup::new` initializes it to `0.0`.
pub weight_decay: f32,
/// Gradient-clipping flag; `false` after `new`, set to `true` by
/// `with_gradient_clipping`.
pub clip_gradients: bool,
/// Gradient-norm limit; `1.0` after `new`, overwritten by
/// `with_gradient_clipping`.
pub max_grad_norm: f32,
}
impl ParameterGroup {
    /// Creates a group from a name and its parameters with default settings:
    /// `lr_multiplier = 1.0`, `weight_decay = 0.0`, gradient clipping
    /// disabled, and `max_grad_norm = 1.0`.
    pub fn new(name: String, parameters: Vec<Parameter>) -> Self {
        Self {
            lr_multiplier: 1.0,
            weight_decay: 0.0,
            clip_gradients: false,
            max_grad_norm: 1.0,
            name,
            parameters,
        }
    }

    /// Builder step: returns the group with `lr_multiplier` replaced by
    /// `multiplier`. The value is stored as given (no validation).
    pub fn with_lr_multiplier(self, multiplier: f32) -> Self {
        Self {
            lr_multiplier: multiplier,
            ..self
        }
    }

    /// Builder step: returns the group with `weight_decay` replaced by
    /// `decay`. The value is stored as given (no validation).
    pub fn with_weight_decay(self, decay: f32) -> Self {
        Self {
            weight_decay: decay,
            ..self
        }
    }

    /// Builder step: turns `clip_gradients` on and stores `max_norm` in
    /// `max_grad_norm`. The value is stored as given (no validation).
    pub fn with_gradient_clipping(self, max_norm: f32) -> Self {
        Self {
            clip_gradients: true,
            max_grad_norm: max_norm,
            ..self
        }
    }

    /// Sums `numel()` over every parameter in the group.
    ///
    /// A parameter whose `numel()` call fails contributes `0` rather than
    /// propagating the error.
    pub fn num_parameters(&self) -> usize {
        let element_counts = self
            .parameters
            .iter()
            .map(|param| param.numel().unwrap_or(0));
        element_counts.sum()
    }

    /// Number of `Parameter` entries in the group (not the element total;
    /// see `num_parameters` for that).
    pub fn parameter_count(&self) -> usize {
        self.parameters.len()
    }
}
/// Kinds of constraint that can be associated with a parameter.
///
/// NOTE(review): `ParameterConstraint::apply` in this file performs no
/// modification for any variant, so the per-variant notes below describe the
/// intent suggested by the names only — confirm before relying on them.
#[derive(Debug, Clone)]
pub enum ParameterConstraint {
/// Carries a `min` and a `max` bound; presumably values are meant to be
/// clamped into that range — TODO confirm (inclusive/exclusive not shown).
ClampRange { min: f32, max: f32 },
/// Presumably requires values to be non-negative — TODO confirm.
NonNegative,
/// Presumably requires the tensor to have unit norm — TODO confirm which norm.
UnitNorm,
/// Presumably requires values to form a valid probability — TODO confirm.
Probability,
/// Caller-defined constraint identified only by `name`; `name()` returns
/// this string.
Custom { name: String },
}
impl ParameterConstraint {
    /// Applies the constraint to `parameter`.
    ///
    /// Currently a placeholder: the call obtains the parameter's tensor,
    /// calls `write()` on it and keeps the returned value alive until the
    /// function returns, but no variant changes any data. Always returns
    /// `Ok(())`.
    pub fn apply(&self, parameter: &Parameter) -> Result<()> {
        let tensor = parameter.tensor();
        // Held for the rest of the function; presumably a write-lock guard.
        let _write_access = tensor.write();

        // Exhaustive on purpose (no `_` arm) so that adding a variant forces
        // this function to be revisited.
        match self {
            Self::ClampRange { min: _, max: _ }
            | Self::NonNegative
            | Self::UnitNorm
            | Self::Probability
            | Self::Custom { name: _ } => Ok(()),
        }
    }

    /// Returns a label for the constraint: the variant name for the built-in
    /// variants, or the stored `name` for `Custom`.
    pub fn name(&self) -> &str {
        match self {
            Self::Custom { name } => name.as_str(),
            Self::ClampRange { .. } => "ClampRange",
            Self::NonNegative => "NonNegative",
            Self::UnitNorm => "UnitNorm",
            Self::Probability => "Probability",
        }
    }
}
/// Summary statistics for a parameter's values, as produced by
/// `ParameterExt::analyze`.
#[derive(Debug, Clone)]
pub struct ParameterAnalysis {
/// Arithmetic mean of all values.
pub mean: f32,
/// Population standard deviation (variance is divided by `numel`, not
/// `numel - 1`).
pub std: f32,
/// Smallest value found.
pub min: f32,
/// Largest value found.
pub max: f32,
/// Number of values analyzed.
pub numel: usize,
/// Percentage (0–100) of values that compare equal to `0.0`.
pub sparsity: f32,
/// `true` if at least one value is NaN.
pub has_nan: bool,
/// `true` if at least one value is positive or negative infinity.
pub has_inf: bool,
}
/// Inspection and convenience helpers; implemented for [`Parameter`] in this
/// file.
pub trait ParameterExt {
/// Computes summary statistics over the values. The `Parameter`
/// implementation returns an error for an empty parameter.
fn analyze(&self) -> Result<ParameterAnalysis>;
/// Returns `Ok(true)` when every value is finite (neither NaN nor infinite).
fn is_finite(&self) -> Result<bool>;
/// Returns the L2 norm: the square root of the sum of squared values.
fn norm(&self) -> Result<f32>;
/// Returns the L1 norm: the sum of absolute values.
fn l1_norm(&self) -> Result<f32>;
/// Returns the norm of the gradient. NOTE: the `Parameter` implementation
/// in this file is a stub that always returns `Ok(0.0)`.
fn grad_norm(&self) -> Result<f32>;
/// Reports whether a gradient is available. NOTE: the `Parameter`
/// implementation in this file is a stub that always returns `false`.
fn has_grad(&self) -> bool;
/// Copies the values into a `Vec<f32>`.
fn to_vec(&self) -> Result<Vec<f32>>;
/// Name of the element type; the `Parameter` implementation returns `"f32"`.
fn dtype_name(&self) -> &str;
/// Size of the values in bytes; the `Parameter` implementation computes
/// `numel * 4`.
fn memory_bytes(&self) -> usize;
/// Builds a new `Parameter` from a copy of this one's data; `requires_grad`
/// selects between `Parameter::new` and `Parameter::new_no_grad`.
fn clone_with_grad(&self, requires_grad: bool) -> Parameter;
}
impl ParameterExt for Parameter {
    /// Computes mean, population standard deviation, min, max, element
    /// count, sparsity (percentage of exact zeros), and NaN/infinity flags.
    ///
    /// # Errors
    ///
    /// Propagates any error from the tensor's `to_vec`, and returns
    /// `TorshError::InvalidArgument` when the parameter has no elements.
    fn analyze(&self) -> Result<ParameterAnalysis> {
        let tensor = self.tensor();
        let guard = tensor.read();
        let values = guard.to_vec()?;

        if values.is_empty() {
            return Err(TorshError::InvalidArgument(
                "Cannot analyze empty parameter".to_string(),
            ));
        }

        let numel = values.len();
        let count = numel as f32;

        // Two passes: the mean must be known before the squared deviations.
        let mean = values.iter().sum::<f32>() / count;
        let variance = values.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / count;

        let zeros = values.iter().filter(|&&v| v == 0.0).count();

        Ok(ParameterAnalysis {
            mean,
            std: variance.sqrt(),
            min: values.iter().copied().fold(f32::INFINITY, f32::min),
            max: values.iter().copied().fold(f32::NEG_INFINITY, f32::max),
            numel,
            sparsity: (zeros as f32 / count) * 100.0,
            has_nan: values.iter().any(|&v| v.is_nan()),
            has_inf: values.iter().any(|&v| v.is_infinite()),
        })
    }

    /// Returns `Ok(true)` when no value is NaN or infinite (vacuously true
    /// for an empty parameter).
    fn is_finite(&self) -> Result<bool> {
        let tensor = self.tensor();
        let values = tensor.read().to_vec()?;
        let all_finite = values.iter().all(|&v| v.is_finite());
        Ok(all_finite)
    }

    /// L2 norm: square root of the sum of squared values.
    fn norm(&self) -> Result<f32> {
        let tensor = self.tensor();
        let values = tensor.read().to_vec()?;
        let squared_total = values.iter().map(|&v| v * v).sum::<f32>();
        Ok(squared_total.sqrt())
    }

    /// L1 norm: sum of absolute values.
    fn l1_norm(&self) -> Result<f32> {
        let tensor = self.tensor();
        let values = tensor.read().to_vec()?;
        let abs_total: f32 = values.iter().map(|&v| v.abs()).sum();
        Ok(abs_total)
    }

    /// Stub: gradients are not inspected here; always returns `Ok(0.0)`.
    fn grad_norm(&self) -> Result<f32> {
        Ok(0.0)
    }

    /// Stub: gradients are not inspected here; always returns `false`.
    fn has_grad(&self) -> bool {
        false
    }

    /// Copies the tensor's values into a `Vec<f32>` while holding the value
    /// returned by `tensor.read()`.
    fn to_vec(&self) -> Result<Vec<f32>> {
        let tensor = self.tensor();
        let guard = tensor.read();
        guard.to_vec()
    }

    /// Always `"f32"`.
    fn dtype_name(&self) -> &str {
        "f32"
    }

    /// Element count times the size of an `f32` (4 bytes). A failing
    /// `numel()` is treated as zero elements.
    fn memory_bytes(&self) -> usize {
        self.numel().unwrap_or(0) * core::mem::size_of::<f32>()
    }

    /// Builds a new `Parameter` from `clone_data()`, using `Parameter::new`
    /// when `requires_grad` is true and `Parameter::new_no_grad` otherwise.
    fn clone_with_grad(&self, requires_grad: bool) -> Parameter {
        let data = self.clone_data();
        if !requires_grad {
            return Parameter::new_no_grad(data);
        }
        Parameter::new(data)
    }
}
/// Query helpers over a collection of named parameters; implemented for
/// `super::ParameterCollection` in this file.
pub trait ParameterCollectionExt {
/// Total number of elements across all parameters in the collection.
fn total_numel(&self) -> usize;
/// Builds parameter groups from name patterns.
///
/// `groups` maps a group name to a list of patterns. In the implementation
/// in this file a parameter joins a group when its name contains any of the
/// group's patterns as a substring, and groups that match nothing are left
/// out of the result.
fn group_by_patterns(
&self,
groups: &HashMap<String, Vec<String>>,
) -> HashMap<String, ParameterGroup>;
/// Returns clones of the parameters for which `predicate(name, parameter)`
/// is `true`, keyed by parameter name.
fn filter<F>(&self, predicate: F) -> HashMap<String, Parameter>
where
F: Fn(&str, &Parameter) -> bool;
/// Parameters whose `requires_grad()` is `true`.
fn trainable(&self) -> HashMap<String, Parameter>;
/// Parameters whose `requires_grad()` is `false`.
fn frozen(&self) -> HashMap<String, Parameter>;
}
impl ParameterCollectionExt for super::ParameterCollection {
    /// Adds up `numel()` for every name that resolves to a parameter.
    ///
    /// Names for which `get` returns `None` are skipped, and a failing
    /// `numel()` counts as `0`.
    fn total_numel(&self) -> usize {
        let mut total = 0;
        for name in self.names().iter() {
            if let Some(param) = self.get(name) {
                total += param.numel().unwrap_or(0);
            }
        }
        total
    }

    /// Builds one `ParameterGroup` per entry of `groups`.
    ///
    /// A parameter is included in a group when its name contains at least one
    /// of that group's patterns as a substring. Each group scans all names
    /// independently, so one parameter may be cloned into several groups.
    /// Groups that end up with no parameters are omitted from the result.
    fn group_by_patterns(
        &self,
        groups: &HashMap<String, Vec<String>>,
    ) -> HashMap<String, ParameterGroup> {
        groups
            .iter()
            .filter_map(|(group_name, patterns)| {
                let members: Vec<Parameter> = self
                    .names()
                    .into_iter()
                    .filter(|param_name| {
                        patterns.iter().any(|pattern| param_name.contains(pattern))
                    })
                    .filter_map(|param_name| self.get(param_name))
                    .cloned()
                    .collect();

                if members.is_empty() {
                    None
                } else {
                    let group = ParameterGroup::new(group_name.clone(), members);
                    Some((group_name.clone(), group))
                }
            })
            .collect()
    }

    /// Collects clones of every parameter accepted by `predicate`, keyed by
    /// a clone of its name.
    fn filter<F>(&self, predicate: F) -> HashMap<String, Parameter>
    where
        F: Fn(&str, &Parameter) -> bool,
    {
        let mut selected = HashMap::new();
        for name in self.names() {
            match self.get(name) {
                Some(param) if predicate(name, param) => {
                    selected.insert(name.clone(), param.clone());
                }
                _ => {}
            }
        }
        selected
    }

    /// Parameters for which `requires_grad()` returns `true`.
    fn trainable(&self) -> HashMap<String, Parameter> {
        self.filter(|_name, candidate| candidate.requires_grad())
    }

    /// Parameters for which `requires_grad()` returns `false`.
    fn frozen(&self) -> HashMap<String, Parameter> {
        self.filter(|_name, candidate| !candidate.requires_grad())
    }
}
#[cfg(test)]
mod tests {
// Empty: no unit tests are defined in this module yet.
}