use super::{TQModule, TQParameter};
use crate::error::{MLError, Result};
use scirs2_core::ndarray::{Array1, ArrayD, Axis, IxDyn};
use std::collections::{HashMap, HashSet};
#[derive(Debug, Clone)]
pub struct GradientAccumulator {
pub accumulation_steps: usize,
current_step: usize,
accumulated_grads: HashMap<String, ArrayD<f64>>,
average: bool,
}
impl GradientAccumulator {
pub fn new(accumulation_steps: usize) -> Self {
Self {
accumulation_steps,
current_step: 0,
accumulated_grads: HashMap::new(),
average: true,
}
}
pub fn with_sum(accumulation_steps: usize) -> Self {
Self {
accumulation_steps,
current_step: 0,
accumulated_grads: HashMap::new(),
average: false,
}
}
pub fn accumulate(&mut self, params: &[TQParameter]) -> Result<()> {
for param in params {
if !param.requires_grad {
continue;
}
if let Some(grad) = ¶m.grad {
let entry = self
.accumulated_grads
.entry(param.name.clone())
.or_insert_with(|| ArrayD::zeros(grad.raw_dim()));
*entry = &*entry + grad;
}
}
self.current_step += 1;
Ok(())
}
pub fn is_ready(&self) -> bool {
self.current_step >= self.accumulation_steps
}
pub fn get_and_reset(&mut self) -> HashMap<String, ArrayD<f64>> {
let mut result = std::mem::take(&mut self.accumulated_grads);
if self.average && self.accumulation_steps > 1 {
let scale = 1.0 / self.accumulation_steps as f64;
for grad in result.values_mut() {
*grad = &*grad * scale;
}
}
self.current_step = 0;
result
}
pub fn reset(&mut self) {
self.accumulated_grads.clear();
self.current_step = 0;
}
pub fn step_count(&self) -> usize {
self.current_step
}
}
#[derive(Debug)]
pub struct ParameterRegistry {
parameters: HashMap<String, TQParameter>,
frozen: Vec<String>,
}
impl ParameterRegistry {
pub fn new() -> Self {
Self {
parameters: HashMap::new(),
frozen: Vec::new(),
}
}
pub fn register_module(&mut self, module: &dyn TQModule) -> Result<()> {
let params = module.parameters();
for param in params {
self.parameters.insert(param.name.clone(), param);
}
Ok(())
}
pub fn register(&mut self, param: TQParameter) {
self.parameters.insert(param.name.clone(), param);
}
pub fn get(&self, name: &str) -> Option<&TQParameter> {
self.parameters.get(name)
}
pub fn get_mut(&mut self, name: &str) -> Option<&mut TQParameter> {
self.parameters.get_mut(name)
}
pub fn trainable_parameters(&self) -> Vec<&TQParameter> {
self.parameters
.values()
.filter(|p| p.requires_grad && !self.frozen.contains(&p.name))
.collect()
}
pub fn parameter_names(&self) -> Vec<&str> {
self.parameters.keys().map(|s| s.as_str()).collect()
}
pub fn count(&self) -> usize {
self.parameters.values().map(|p| p.numel()).sum()
}
pub fn trainable_count(&self) -> usize {
self.trainable_parameters().iter().map(|p| p.numel()).sum()
}
pub fn freeze(&mut self, name: &str) -> Result<()> {
if !self.parameters.contains_key(name) {
return Err(MLError::InvalidConfiguration(format!(
"Parameter '{}' not found",
name
)));
}
if !self.frozen.contains(&name.to_string()) {
self.frozen.push(name.to_string());
}
Ok(())
}
pub fn unfreeze(&mut self, name: &str) -> Result<()> {
self.frozen.retain(|n| n != name);
Ok(())
}
pub fn freeze_all(&mut self) {
self.frozen = self.parameters.keys().cloned().collect();
}
pub fn unfreeze_all(&mut self) {
self.frozen.clear();
}
pub fn zero_grad(&mut self) {
for param in self.parameters.values_mut() {
param.zero_grad();
}
}
pub fn memory_bytes(&self) -> usize {
self.parameters.values().map(|p| p.numel() * 8).sum() }
pub fn statistics(&self) -> ParameterStatistics {
let total_params = self.count();
let trainable_params = self.trainable_count();
let memory_mb = self.memory_bytes() as f64 / (1024.0 * 1024.0);
ParameterStatistics {
total_params,
trainable_params,
frozen_params: total_params - trainable_params,
memory_mb,
}
}
}
impl Default for ParameterRegistry {
fn default() -> Self {
Self::new()
}
}
/// Summary of parameter counts and memory, as produced by `ParameterRegistry::statistics`.
#[derive(Debug, Clone)]
pub struct ParameterStatistics {
/// Total number of scalar elements across all registered parameters.
pub total_params: usize,
/// Elements in parameters that require gradients and are not frozen.
pub trainable_params: usize,
/// `total_params - trainable_params`: every non-trainable element, whether the
/// parameter was frozen or has `requires_grad == false`.
pub frozen_params: usize,
/// Estimated parameter storage in MiB (8 bytes per element).
pub memory_mb: f64,
}
/// How a `GradientClipper` limits gradients.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ClippingStrategy {
/// Rescale all gradients together so their global L2 norm is at most `max_norm`.
Norm { max_norm: f64 },
/// Clamp each gradient element whose magnitude exceeds `clip_value` to `±clip_value`.
Value { clip_value: f64 },
/// Per parameter: rescale a gradient whose L2 norm exceeds `clip_factor` times a
/// threshold derived from the parameter's own L2 norm.
Adaptive { clip_factor: f64 },
}
/// Clips parameter gradients in place according to a [`ClippingStrategy`].
#[derive(Debug, Clone)]
pub struct GradientClipper {
    /// Strategy chosen at construction time.
    strategy: ClippingStrategy,
    /// Global L2 norm of all gradients measured (before scaling) by the most recent
    /// `Norm` clip. The other strategies never update it.
    pub last_norm: Option<f64>,
    /// Whether the most recent call to [`clip`](Self::clip) changed any gradient.
    pub was_clipped: bool,
}

impl GradientClipper {
    /// Floor applied to a parameter's L2 norm in adaptive clipping.
    ///
    /// Without it, a parameter whose values are all zero (a common initialisation,
    /// e.g. for biases) gets a threshold of exactly `0.0`. Its gradient would then be
    /// scaled to zero on every step, and the parameter could never leave zero.
    /// `1e-3` is the floor used by adaptive gradient clipping (AGC) in the NFNets
    /// paper (Brock et al., 2021).
    const MIN_PARAM_NORM: f64 = 1e-3;

    /// Added to norms before dividing, to avoid division by zero.
    const NORM_EPS: f64 = 1e-10;

    /// Creates a clipper that rescales all gradients so their global L2 norm is at
    /// most `max_norm`.
    pub fn by_norm(max_norm: f64) -> Self {
        Self::with_strategy(ClippingStrategy::Norm { max_norm })
    }

    /// Creates a clipper that clamps gradient elements to `[-clip_value, clip_value]`.
    pub fn by_value(clip_value: f64) -> Self {
        Self::with_strategy(ClippingStrategy::Value { clip_value })
    }

    /// Creates a clipper that limits each parameter's gradient L2 norm to
    /// `clip_factor * max(||param||, 1e-3)`.
    pub fn adaptive(clip_factor: f64) -> Self {
        Self::with_strategy(ClippingStrategy::Adaptive { clip_factor })
    }

    /// Shared constructor: no norm recorded yet, nothing clipped yet.
    fn with_strategy(strategy: ClippingStrategy) -> Self {
        Self {
            strategy,
            last_norm: None,
            was_clipped: false,
        }
    }

    /// Clips the gradients of `params` in place using the configured strategy.
    ///
    /// Parameters without a gradient are left untouched. Currently always returns `Ok(())`.
    pub fn clip(&mut self, params: &mut [TQParameter]) -> Result<()> {
        match self.strategy {
            ClippingStrategy::Norm { max_norm } => self.clip_by_norm(params, max_norm),
            ClippingStrategy::Value { clip_value } => self.clip_by_value(params, clip_value),
            ClippingStrategy::Adaptive { clip_factor } => self.clip_adaptive(params, clip_factor),
        }
    }

    /// Euclidean norm of a stream of values.
    fn l2_norm<'a>(values: impl Iterator<Item = &'a f64>) -> f64 {
        values.fold(0.0, |acc, &v| acc + v * v).sqrt()
    }

    /// Global-norm clipping: if the combined L2 norm of every gradient exceeds
    /// `max_norm`, all gradients are scaled by the same factor.
    fn clip_by_norm(&mut self, params: &mut [TQParameter], max_norm: f64) -> Result<()> {
        let total_norm = Self::l2_norm(
            params
                .iter()
                .filter_map(|p| p.grad.as_ref())
                .flat_map(|g| g.iter()),
        );
        self.last_norm = Some(total_norm);
        self.was_clipped = total_norm > max_norm;
        if self.was_clipped {
            let scale = max_norm / (total_norm + Self::NORM_EPS);
            for grad in params.iter_mut().filter_map(|p| p.grad.as_mut()) {
                *grad = &*grad * scale;
            }
        }
        Ok(())
    }

    /// Element-wise clipping: any element with magnitude above `clip_value` is
    /// replaced by `±clip_value`, keeping its sign.
    fn clip_by_value(&mut self, params: &mut [TQParameter], clip_value: f64) -> Result<()> {
        self.was_clipped = false;
        for grad in params.iter_mut().filter_map(|p| p.grad.as_mut()) {
            for val in grad.iter_mut() {
                if val.abs() > clip_value {
                    *val = val.signum() * clip_value;
                    self.was_clipped = true;
                }
            }
        }
        Ok(())
    }

    /// Adaptive clipping: each parameter's gradient norm is capped at
    /// `clip_factor * max(||param||, MIN_PARAM_NORM)`.
    fn clip_adaptive(&mut self, params: &mut [TQParameter], clip_factor: f64) -> Result<()> {
        self.was_clipped = false;
        for param in params.iter_mut() {
            if let Some(grad) = &mut param.grad {
                let param_norm = Self::l2_norm(param.data.iter()).max(Self::MIN_PARAM_NORM);
                let max_grad = param_norm * clip_factor;
                let grad_norm = Self::l2_norm(grad.iter());
                if grad_norm > max_grad {
                    let scale = max_grad / (grad_norm + Self::NORM_EPS);
                    *grad = &*grad * scale;
                    self.was_clipped = true;
                }
            }
        }
        Ok(())
    }

    /// Snapshot of the clipper's state after its most recent `clip` call.
    pub fn statistics(&self) -> ClippingStatistics {
        ClippingStatistics {
            was_clipped: self.was_clipped,
            last_norm: self.last_norm,
            strategy: self.strategy,
        }
    }
}
/// Snapshot of a `GradientClipper`'s state, returned by `GradientClipper::statistics`.
#[derive(Debug, Clone)]
pub struct ClippingStatistics {
/// Whether the most recent `clip` call modified any gradient.
pub was_clipped: bool,
/// Global gradient L2 norm recorded by the `Norm` strategy; `None` if never measured.
pub last_norm: Option<f64>,
/// Strategy the clipper was configured with.
pub strategy: ClippingStrategy,
}
pub struct GradientChecker {
pub epsilon: f64,
pub rtol: f64,
pub atol: f64,
}
impl GradientChecker {
pub fn new() -> Self {
Self {
epsilon: 1e-5,
rtol: 1e-3,
atol: 1e-5,
}
}
pub fn with_epsilon(epsilon: f64) -> Self {
Self {
epsilon,
rtol: 1e-3,
atol: 1e-5,
}
}
pub fn with_tolerances(epsilon: f64, rtol: f64, atol: f64) -> Self {
Self {
epsilon,
rtol,
atol,
}
}
pub fn numerical_gradient<F>(
&self,
param: &mut TQParameter,
param_idx: usize,
loss_fn: &mut F,
) -> Result<f64>
where
F: FnMut() -> Result<f64>,
{
let flat_idx = self.flat_index(param_idx, param.shape());
let original =
param.data.as_slice_mut().ok_or_else(|| {
MLError::InvalidConfiguration("Cannot get mutable slice".to_string())
})?[flat_idx];
param.data.as_slice_mut().ok_or_else(|| {
MLError::InvalidConfiguration("Cannot get mutable slice".to_string())
})?[flat_idx] = original + self.epsilon;
let loss_plus = loss_fn()?;
param.data.as_slice_mut().ok_or_else(|| {
MLError::InvalidConfiguration("Cannot get mutable slice".to_string())
})?[flat_idx] = original - self.epsilon;
let loss_minus = loss_fn()?;
param.data.as_slice_mut().ok_or_else(|| {
MLError::InvalidConfiguration("Cannot get mutable slice".to_string())
})?[flat_idx] = original;
Ok((loss_plus - loss_minus) / (2.0 * self.epsilon))
}
pub fn check_gradient(&self, analytical: f64, numerical: f64) -> GradientCheckResult {
let abs_diff = (analytical - numerical).abs();
let rel_diff = if numerical.abs() > 1e-10 {
abs_diff / numerical.abs()
} else {
abs_diff
};
let matches = abs_diff <= self.atol || rel_diff <= self.rtol;
GradientCheckResult {
analytical,
numerical,
abs_diff,
rel_diff,
matches,
}
}
fn flat_index(&self, idx: usize, shape: &[usize]) -> usize {
idx
}
}
impl Default for GradientChecker {
fn default() -> Self {
Self::new()
}
}
/// Outcome of `GradientChecker::check_gradient` for a single gradient entry.
#[derive(Debug, Clone)]
pub struct GradientCheckResult {
/// Analytical gradient value passed to `check_gradient`.
pub analytical: f64,
/// Finite-difference estimate passed to `check_gradient`.
pub numerical: f64,
/// `|analytical - numerical|`.
pub abs_diff: f64,
/// `abs_diff / |numerical|`, or `abs_diff` itself when `|numerical| <= 1e-10`.
pub rel_diff: f64,
/// `true` when `abs_diff <= atol` or `rel_diff <= rtol`.
pub matches: bool,
}
/// A named set of parameters that share optimizer settings.
#[derive(Debug, Clone)]
pub struct ParameterGroup {
    /// Group label.
    pub name: String,
    /// Names of the parameters belonging to this group.
    pub param_names: Vec<String>,
    /// Learning-rate multiplier for the group (defaults to `1.0`).
    pub lr_multiplier: f64,
    /// Weight-decay coefficient for the group (defaults to `0.0`).
    pub weight_decay: f64,
    /// Whether the group's parameters should receive gradients (defaults to `true`).
    pub requires_grad: bool,
}

impl ParameterGroup {
    /// Builds an empty group with neutral settings.
    pub fn new(name: impl Into<String>) -> Self {
        let label: String = name.into();
        Self {
            name: label,
            param_names: vec![],
            lr_multiplier: 1.0,
            weight_decay: 0.0,
            requires_grad: true,
        }
    }

    /// Appends a parameter name to the group (duplicates are not filtered).
    pub fn add_param(&mut self, param_name: impl Into<String>) {
        let owned: String = param_name.into();
        self.param_names.push(owned);
    }

    /// Builder: replaces the learning-rate multiplier.
    pub fn with_lr_multiplier(self, multiplier: f64) -> Self {
        Self {
            lr_multiplier: multiplier,
            ..self
        }
    }

    /// Builder: replaces the weight-decay coefficient.
    pub fn with_weight_decay(self, decay: f64) -> Self {
        Self {
            weight_decay: decay,
            ..self
        }
    }

    /// Builder: replaces the `requires_grad` flag.
    pub fn with_requires_grad(self, requires_grad: bool) -> Self {
        Self {
            requires_grad,
            ..self
        }
    }

    /// Whether `param_name` is listed in this group.
    pub fn contains(&self, param_name: &str) -> bool {
        self.param_names
            .iter()
            .map(String::as_str)
            .any(|candidate| candidate == param_name)
    }
}

/// Ordered collection of parameter groups. Lookups return the first group that
/// lists a parameter; unlisted parameters fall back to neutral defaults.
#[derive(Debug)]
pub struct ParameterGroupManager {
    groups: Vec<ParameterGroup>,
}

impl ParameterGroupManager {
    /// Creates a manager with no groups.
    pub fn new() -> Self {
        Self { groups: vec![] }
    }

    /// Appends a group; earlier groups take precedence in lookups.
    pub fn add_group(&mut self, group: ParameterGroup) {
        self.groups.push(group)
    }

    /// First group (in insertion order) containing `param_name`, if any.
    pub fn get_group(&self, param_name: &str) -> Option<&ParameterGroup> {
        for group in &self.groups {
            if group.contains(param_name) {
                return Some(group);
            }
        }
        None
    }

    /// All groups in insertion order.
    pub fn groups(&self) -> &[ParameterGroup] {
        self.groups.as_slice()
    }

    /// Learning-rate multiplier for `param_name`; `1.0` if it belongs to no group.
    pub fn lr_multiplier(&self, param_name: &str) -> f64 {
        match self.get_group(param_name) {
            Some(group) => group.lr_multiplier,
            None => 1.0,
        }
    }

    /// Weight decay for `param_name`; `0.0` if it belongs to no group.
    pub fn weight_decay(&self, param_name: &str) -> f64 {
        match self.get_group(param_name) {
            Some(group) => group.weight_decay,
            None => 0.0,
        }
    }

    /// `requires_grad` flag for `param_name`; `true` if it belongs to no group.
    pub fn requires_grad(&self, param_name: &str) -> bool {
        match self.get_group(param_name) {
            Some(group) => group.requires_grad,
            None => true,
        }
    }
}

impl Default for ParameterGroupManager {
    fn default() -> Self {
        Self { groups: vec![] }
    }
}
pub fn gradient_norm(params: &[TQParameter]) -> f64 {
let mut norm_sq = 0.0;
for param in params {
if let Some(grad) = ¶m.grad {
for &val in grad.iter() {
norm_sq += val * val;
}
}
}
norm_sq.sqrt()
}
/// Computes mean, population standard deviation, min, max and global L2 norm over
/// every gradient element of `params`.
///
/// Parameters without a gradient are ignored. Returns `GradientStatistics::default()`
/// (all zeros) when no parameter has a gradient.
pub fn gradient_statistics(params: &[TQParameter]) -> GradientStatistics {
    let all_grads: Vec<f64> = params
        .iter()
        .filter_map(|param| param.grad.as_ref())
        .flat_map(|grad| grad.iter().copied())
        .collect();
    if all_grads.is_empty() {
        return GradientStatistics::default();
    }

    let n = all_grads.len() as f64;
    let mean = all_grads.iter().sum::<f64>() / n;
    let variance = all_grads.iter().map(|&g| (g - mean).powi(2)).sum::<f64>() / n;

    // `f64::min`/`f64::max` skip a NaN operand, so the extremes no longer depend on
    // where a NaN sits in the sequence (the previous `partial_cmp(..).unwrap_or(Equal)`
    // comparator did). A NaN still surfaces in `mean`, `std` and `norm`.
    let min = all_grads.iter().copied().reduce(f64::min).unwrap_or(0.0);
    let max = all_grads.iter().copied().reduce(f64::max).unwrap_or(0.0);

    // Same elements in the same order as `gradient_norm`, reusing the collected buffer.
    let norm = all_grads.iter().fold(0.0, |acc, &g| acc + g * g).sqrt();

    GradientStatistics {
        mean,
        std: variance.sqrt(),
        min,
        max,
        norm,
    }
}
/// Summary statistics over every gradient element of a set of parameters.
///
/// The `Default` value (all zeros) is what `gradient_statistics` returns when no
/// parameter has a gradient.
#[derive(Debug, Clone, Default)]
pub struct GradientStatistics {
/// Arithmetic mean of all gradient elements.
pub mean: f64,
/// Population standard deviation (divides by the element count, not count - 1).
pub std: f64,
/// Smallest gradient element.
pub min: f64,
/// Largest gradient element.
pub max: f64,
/// Global L2 norm of all gradient elements.
pub norm: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Zero-valued 1-D parameter of length `len`, with no gradient.
    fn zeros_param(name: &'static str, len: usize) -> TQParameter {
        TQParameter::new(ArrayD::zeros(IxDyn(&[len])), name)
    }

    /// Zero-valued 1-D parameter whose gradient is `grad`.
    fn param_with_grad(name: &'static str, grad: Vec<f64>) -> TQParameter {
        let len = grad.len();
        let mut param = zeros_param(name, len);
        param.grad = Some(ArrayD::from_shape_vec(IxDyn(&[len]), grad).unwrap());
        param
    }

    #[test]
    fn test_gradient_accumulator() {
        let param = param_with_grad("test", vec![1.0, 2.0]);
        let mut acc = GradientAccumulator::new(3);
        for _ in 0..3 {
            acc.accumulate(std::slice::from_ref(&param)).unwrap();
        }
        assert!(acc.is_ready());

        // Averaging three identical steps reproduces the single-step gradient.
        let grads = acc.get_and_reset();
        let averaged = &grads["test"];
        assert!((averaged[[0]] - 1.0).abs() < 1e-10);
        assert!((averaged[[1]] - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_parameter_registry() {
        let mut registry = ParameterRegistry::new();
        registry.register(zeros_param("layer1", 5));
        registry.register(zeros_param("layer2", 10));
        assert_eq!(registry.count(), 15);
        assert_eq!(registry.trainable_count(), 15);

        registry.freeze("layer1").unwrap();
        assert_eq!(registry.trainable_count(), 10);

        let stats = registry.statistics();
        assert_eq!(stats.total_params, 15);
        assert_eq!(stats.trainable_params, 10);
        assert_eq!(stats.frozen_params, 5);
    }

    #[test]
    fn test_gradient_clipper_by_norm() {
        let mut params = [param_with_grad("test", vec![3.0, 4.0])];
        let mut clipper = GradientClipper::by_norm(1.0);
        clipper.clip(&mut params).unwrap();
        assert!(clipper.was_clipped);
        assert!((clipper.last_norm.unwrap() - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_gradient_clipper_by_value() {
        let mut params = [param_with_grad("test", vec![3.0, -4.0])];
        let mut clipper = GradientClipper::by_value(2.0);
        clipper.clip(&mut params).unwrap();
        assert!(clipper.was_clipped);
    }

    #[test]
    fn test_parameter_group() {
        let mut backbone = ParameterGroup::new("backbone")
            .with_lr_multiplier(0.1)
            .with_weight_decay(0.01);
        backbone.add_param("layer1");
        backbone.add_param("layer2");

        let mut head = ParameterGroup::new("head")
            .with_lr_multiplier(1.0)
            .with_weight_decay(0.0);
        head.add_param("output");

        let mut manager = ParameterGroupManager::new();
        manager.add_group(backbone);
        manager.add_group(head);

        assert_eq!(manager.lr_multiplier("layer1"), 0.1);
        assert_eq!(manager.lr_multiplier("output"), 1.0);
        assert_eq!(manager.weight_decay("layer1"), 0.01);
        assert_eq!(manager.weight_decay("output"), 0.0);
    }

    #[test]
    fn test_gradient_statistics() {
        let params = [
            param_with_grad("p1", vec![1.0, 2.0]),
            param_with_grad("p2", vec![3.0, 4.0]),
        ];
        let stats = gradient_statistics(&params);
        assert!((stats.mean - 2.5).abs() < 1e-10);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 4.0);
    }
}