use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use std::ops::{AddAssign, MulAssign, SubAssign};
/// An optimizer that writes its update directly into the parameter array
/// instead of returning a freshly allocated one.
pub trait InPlaceOptimizer<A: Float + ScalarOperand + Debug, D: Dimension> {
    /// Performs a single update of `params` using `gradients`, mutating
    /// `params` in place.
    fn step_inplace(&mut self, params: &mut Array<A, D>, gradients: &Array<A, D>) -> Result<()>;

    /// Applies [`step_inplace`](Self::step_inplace) to each parameter/gradient
    /// pair, in order.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidConfig` when the two slices differ in
    /// length, and forwards the first error produced by `step_inplace`
    /// (pairs processed before that error keep their updated values).
    fn step_list_inplace(
        &mut self,
        params_list: &mut [&mut Array<A, D>],
        gradients_list: &[&Array<A, D>],
    ) -> Result<()> {
        let (param_count, grad_count) = (params_list.len(), gradients_list.len());
        if param_count != grad_count {
            return Err(OptimError::InvalidConfig(format!(
                "Number of parameter arrays ({}) does not match number of gradient arrays ({})",
                param_count, grad_count
            )));
        }

        params_list
            .iter_mut()
            .zip(gradients_list)
            .try_for_each(|(params, grads)| self.step_inplace(params, grads))
    }
}
/// Stochastic gradient descent that overwrites the parameter array in place.
///
/// Stores a learning rate together with momentum and weight-decay
/// coefficients (both zero unless set through the builder methods).
///
/// NOTE(review): `momentum` is stored by [`InPlaceSGD::with_momentum`], but the
/// `step_inplace` implementation for this type does not read it — confirm
/// whether momentum support is still pending.
#[derive(Debug, Clone)]
pub struct InPlaceSGD<A: Float> {
    _learningrate: A,
    momentum: A,
    weight_decay: A,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> InPlaceSGD<A> {
    /// Creates an optimizer with the given learning rate, zero momentum and
    /// zero weight decay.
    pub fn new(_learningrate: A) -> Self {
        let zero = A::zero();
        Self {
            _learningrate,
            momentum: zero,
            weight_decay: zero,
        }
    }

    /// Returns the optimizer with its momentum coefficient replaced.
    pub fn with_momentum(self, momentum: A) -> Self {
        Self { momentum, ..self }
    }

    /// Returns the optimizer with its weight-decay coefficient replaced.
    pub fn with_weight_decay(self, weightdecay: A) -> Self {
        Self {
            weight_decay: weightdecay,
            ..self
        }
    }
}
impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> InPlaceOptimizer<A, D>
    for InPlaceSGD<A>
{
    /// Applies `p <- p - lr * g`, or `p <- p - lr * (g + p * weight_decay)`
    /// when the weight-decay coefficient is strictly positive.
    ///
    /// The stored momentum coefficient is not used by this update.
    /// Always returns `Ok(())`.
    fn step_inplace(&mut self, params: &mut Array<A, D>, gradients: &Array<A, D>) -> Result<()> {
        let lr = self._learningrate;
        let decay = self.weight_decay;

        if decay > A::zero() {
            // The decay term is folded into the gradient before scaling.
            params.zip_mut_with(gradients, |p, &g| *p = *p - lr * (g + *p * decay));
        } else {
            params.zip_mut_with(gradients, |p, &g| *p = *p - lr * g);
        }
        Ok(())
    }
}
/// Adam optimizer that overwrites the parameter array in place.
///
/// Besides its hyper-parameters it keeps a step counter `t` and the first and
/// second moment estimates `m` and `v`, which stay `None` until the first step.
#[derive(Debug)]
pub struct InPlaceAdam<A: Float, D: Dimension> {
    _learningrate: A,
    beta1: A,
    beta2: A,
    epsilon: A,
    weight_decay: A,
    t: i32,
    m: Option<Array<A, D>>,
    v: Option<Array<A, D>>,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> InPlaceAdam<A, D> {
    /// Creates an optimizer with the given learning rate and the defaults
    /// `beta1 = 0.9`, `beta2 = 0.999`, `epsilon = 1e-8`, `weight_decay = 0`.
    ///
    /// # Panics
    ///
    /// Panics if one of the default constants cannot be converted into `A`.
    pub fn new(_learningrate: A) -> Self {
        let constant = |value: f64| A::from(value).expect("unwrap failed");
        Self {
            _learningrate,
            beta1: constant(0.9),
            beta2: constant(0.999),
            epsilon: constant(1e-8),
            weight_decay: A::zero(),
            t: 0,
            m: None,
            v: None,
        }
    }

    /// Returns the optimizer with `beta1` (first-moment decay) replaced.
    pub fn with_beta1(self, beta1: A) -> Self {
        Self { beta1, ..self }
    }

    /// Returns the optimizer with `beta2` (second-moment decay) replaced.
    pub fn with_beta2(self, beta2: A) -> Self {
        Self { beta2, ..self }
    }

    /// Returns the optimizer with the weight-decay coefficient replaced.
    pub fn with_weight_decay(self, weightdecay: A) -> Self {
        Self {
            weight_decay: weightdecay,
            ..self
        }
    }

    /// Returns the optimizer with `epsilon` (denominator offset) replaced.
    pub fn with_epsilon(self, epsilon: A) -> Self {
        Self { epsilon, ..self }
    }

    /// Forgets all accumulated state: the step counter returns to zero and
    /// both moment buffers are dropped.
    pub fn reset(&mut self) {
        self.t = 0;
        self.m.take();
        self.v.take();
    }
}
impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> InPlaceOptimizer<A, D>
    for InPlaceAdam<A, D>
{
    /// Performs one Adam step, overwriting `params`.
    ///
    /// The step counter is incremented, the moment buffers are created
    /// (zero-filled, shaped like `params`) on first use, then updated from the
    /// gradient, and finally the bias-corrected update is applied.
    /// When the weight-decay coefficient is strictly positive, `p * weight_decay`
    /// is added to the gradient before the moments are updated.
    ///
    /// The gradient array is only copied when weight decay is active; without
    /// decay the moments are updated straight from `gradients`.
    ///
    /// Note that a single `m`/`v` pair and a single step counter are kept per
    /// optimizer instance, so every call (including each call made through
    /// `step_list_inplace`) shares and advances that same state.
    ///
    /// Always returns `Ok(())`.
    fn step_inplace(&mut self, params: &mut Array<A, D>, gradients: &Array<A, D>) -> Result<()> {
        self.t += 1;
        let step = self.t;
        let lr = self._learningrate;
        let (beta1, beta2) = (self.beta1, self.beta2);
        let epsilon = self.epsilon;
        let decay = self.weight_decay;
        let one = A::one();

        // Lazily allocate the moment buffers with the parameter shape.
        let m = self
            .m
            .get_or_insert_with(|| Array::zeros(params.raw_dim()));
        let v = self
            .v
            .get_or_insert_with(|| Array::zeros(params.raw_dim()));

        // Only materialize a temporary when the gradient actually has to be
        // modified; otherwise borrow the caller's gradient as-is.
        let decayed;
        let effective_grads: &Array<A, D> = if decay > A::zero() {
            let mut temp = gradients.clone();
            temp.zip_mut_with(params, |g, &p| *g = *g + p * decay);
            decayed = temp;
            &decayed
        } else {
            gradients
        };

        m.zip_mut_with(effective_grads, |m_i, &g| {
            *m_i = beta1 * *m_i + (one - beta1) * g;
        });
        v.zip_mut_with(effective_grads, |v_i, &g| {
            *v_i = beta2 * *v_i + (one - beta2) * g * g;
        });

        // Bias-correction denominators for the current step.
        let bias1 = one - beta1.powi(step);
        let bias2 = one - beta2.powi(step);

        for ((p, &m_i), &v_i) in params.iter_mut().zip(m.iter()).zip(v.iter()) {
            let m_hat = m_i / bias1;
            let v_hat = v_i / bias2;
            *p = *p - lr * m_hat / (v_hat.sqrt() + epsilon);
        }
        Ok(())
    }
}
pub mod utils {
use super::*;
/// Multiplies every element of `array` by `scalar`, in place.
pub fn scale_inplace<A, D>(array: &mut Array<A, D>, scalar: A)
where
    A: Float + ScalarOperand + MulAssign,
    D: Dimension,
{
    for element in array.iter_mut() {
        *element *= scalar;
    }
}
/// Adds `b` element-wise into `a`, leaving `b` untouched.
///
/// Shape handling is whatever `zip_mut_with` provides for the two arrays.
pub fn add_inplace<A, D>(a: &mut Array<A, D>, b: &Array<A, D>)
where
    A: Float + ScalarOperand + AddAssign,
    D: Dimension,
{
    a.zip_mut_with(b, |target, &addend| {
        *target += addend;
    });
}
/// Subtracts `b` element-wise from `a`, storing the result in `a`.
///
/// Shape handling is whatever `zip_mut_with` provides for the two arrays.
pub fn subtract_inplace<A, D>(a: &mut Array<A, D>, b: &Array<A, D>)
where
    A: Float + ScalarOperand + SubAssign,
    D: Dimension,
{
    a.zip_mut_with(b, |target, &subtrahend| {
        *target -= subtrahend;
    });
}
/// Runs `f` on a mutable reference to every element of `array`.
pub fn apply_inplace<A, D, F>(array: &mut Array<A, D>, f: F)
where
    A: Float + ScalarOperand,
    D: Dimension,
    F: Fn(&mut A),
{
    array.map_inplace(|element| f(element));
}
/// Clamps every element of `array` into `[min, max]`, in place.
///
/// Elements below `min` become `min`; otherwise elements above `max` become
/// `max`. Elements for which neither comparison holds are left unchanged.
pub fn clip_inplace<A, D>(array: &mut Array<A, D>, min: A, max: A)
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    for value in array.iter_mut() {
        if *value < min {
            *value = min;
            continue;
        }
        if *value > max {
            *value = max;
        }
    }
}
/// Rescales `array` in place so that its L2 norm becomes one.
///
/// The array is left untouched when its norm is not strictly positive
/// (for example an all-zero or empty array).
///
/// The squared norm is accumulated directly over the elements, so no
/// temporary array is allocated, and the reciprocal of the norm is computed
/// once rather than once per element.
pub fn normalize_inplace<A, D>(array: &mut Array<A, D>)
where
    A: Float + ScalarOperand + MulAssign,
    D: Dimension,
{
    let norm = array
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();
    if norm > A::zero() {
        let inv_norm = A::one() / norm;
        array.map_inplace(|x| *x *= inv_norm);
    }
}
}
pub mod fused {
use super::*;
/// Hyper-parameters consumed by [`fused_adam_update`].
///
/// `bias1` and `bias2` are used as the divisors of the first and second moment
/// respectively; the function does not derive them from a step count, so the
/// caller supplies them.
#[derive(Debug, Clone, Copy)]
pub struct AdamConfig<A> {
    /// Learning rate.
    pub lr: A,
    /// Decay factor of the first moment.
    pub beta1: A,
    /// Decay factor of the second moment.
    pub beta2: A,
    /// Offset added to the square root in the update denominator.
    pub epsilon: A,
    /// Divisor applied to the first moment.
    pub bias1: A,
    /// Divisor applied to the second moment.
    pub bias2: A,
    /// Optional coefficient; when set, `p * weight_decay` is added to the gradient.
    pub weight_decay: Option<A>,
}

/// Updates moments and parameters for one Adam step in a single pass.
///
/// For each element: the (optionally weight-decayed) gradient updates `m` and
/// `v`, both are divided by `config.bias1` / `config.bias2`, and the parameter
/// is moved by `lr * m_hat / (sqrt(v_hat) + epsilon)`.
///
/// The four arrays are walked in lockstep with iterator `zip`, so processing
/// stops at the shortest of them.
pub fn fused_adam_update<A, D>(
    params: &mut Array<A, D>,
    gradients: &Array<A, D>,
    m: &mut Array<A, D>,
    v: &mut Array<A, D>,
    config: AdamConfig<A>,
) where
    A: Float + ScalarOperand,
    D: Dimension,
{
    let AdamConfig {
        lr,
        beta1,
        beta2,
        epsilon,
        bias1,
        bias2,
        weight_decay,
    } = config;
    let unit = A::one();
    let grad_weight = unit - beta1;
    let sq_grad_weight = unit - beta2;

    let moments = m.iter_mut().zip(v.iter_mut());
    for ((p, &g), (m_val, v_val)) in params.iter_mut().zip(gradients.iter()).zip(moments) {
        // Weight decay, when configured, is folded into the gradient.
        let grad = match weight_decay {
            Some(wd) => g + *p * wd,
            None => g,
        };
        *m_val = beta1 * *m_val + grad_weight * grad;
        *v_val = beta2 * *v_val + sq_grad_weight * grad * grad;

        let m_hat = *m_val / bias1;
        let v_hat = *v_val / bias2;
        *p = *p - lr * m_hat / (v_hat.sqrt() + epsilon);
    }
}
/// Applies one SGD step in a single pass, with optional momentum buffer and
/// optional weight decay.
///
/// * With a momentum buffer: `buf <- momentum * buf + (1 - dampening) * g'`
///   followed by `p <- p - lr * buf`.
/// * Without one: `p <- p - lr * g'` (`momentum` and `dampening` are unused).
///
/// `g'` is `g + p * weight_decay` when a decay coefficient is given, else `g`.
/// Arrays are walked in lockstep with iterator `zip`, so processing stops at
/// the shortest of them.
pub fn fused_sgd_update<A, D>(
    params: &mut Array<A, D>,
    gradients: &Array<A, D>,
    momentum_buf: Option<&mut Array<A, D>>,
    lr: A,
    momentum: A,
    weight_decay: Option<A>,
    dampening: A,
) where
    A: Float + ScalarOperand,
    D: Dimension,
{
    // Gradient with the decay term folded in when requested.
    let effective_grad = |p: A, g: A| match weight_decay {
        Some(wd) => g + p * wd,
        None => g,
    };

    match momentum_buf {
        Some(buffer) => {
            let grad_weight = A::one() - dampening;
            let triples = params
                .iter_mut()
                .zip(gradients.iter())
                .zip(buffer.iter_mut());
            for ((p, &g), buf_val) in triples {
                *buf_val = momentum * *buf_val + grad_weight * effective_grad(*p, g);
                *p = *p - lr * *buf_val;
            }
        }
        None => {
            for (p, &g) in params.iter_mut().zip(gradients.iter()) {
                *p = *p - lr * effective_grad(*p, g);
            }
        }
    }
}
/// Clips gradients by value and then by L2 norm, in place.
///
/// 1. If `clip_value` is given, every element is limited to
///    `[-clip_value, clip_value]`.
/// 2. If `max_norm` is given and the L2 norm of the (already value-clipped)
///    gradients exceeds it, all elements are scaled by `max_norm / norm`.
pub fn fused_gradient_clip_normalize<A, D>(
    gradients: &mut Array<A, D>,
    max_norm: Option<A>,
    clip_value: Option<A>,
) where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if let Some(limit) = clip_value {
        let lower = -limit;
        gradients.iter_mut().for_each(|g| {
            if *g > limit {
                *g = limit;
            } else if *g < lower {
                *g = lower;
            }
        });
    }

    if let Some(norm_cap) = max_norm {
        let norm = gradients
            .iter()
            .fold(A::zero(), |acc, &x| acc + x * x)
            .sqrt();
        if norm > norm_cap {
            let scale = norm_cap / norm;
            gradients.iter_mut().for_each(|g| *g = *g * scale);
        }
    }
}
/// Enforces value bounds and then an L2-norm bound on `params`, in place.
///
/// 1. If `value_bounds` is `Some((min, max))`, elements below `min` become
///    `min` and, otherwise, elements above `max` become `max`.
/// 2. If `l2_constraint` is given and the L2 norm of the (already bounded)
///    parameters exceeds it, all elements are scaled by `l2_constraint / norm`.
pub fn fused_apply_constraints<A, D>(
    params: &mut Array<A, D>,
    l2_constraint: Option<A>,
    value_bounds: Option<(A, A)>,
) where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if let Some((lower, upper)) = value_bounds {
        params.iter_mut().for_each(|p| {
            if *p < lower {
                *p = lower;
            } else if *p > upper {
                *p = upper;
            }
        });
    }

    if let Some(norm_cap) = l2_constraint {
        let norm = params
            .iter()
            .fold(A::zero(), |acc, &x| acc + x * x)
            .sqrt();
        if norm > norm_cap {
            let scale = norm_cap / norm;
            params.iter_mut().for_each(|p| *p = *p * scale);
        }
    }
}
}
pub mod mixed_precision {
use super::*;
/// Dynamic loss scaler: multiplies the loss by a scale factor, divides the
/// gradients by it again, and adapts the factor based on whether non-finite
/// gradients were observed.
#[derive(Debug, Clone)]
pub struct LossScaler {
    scale: f32,
    growth_factor: f32,
    backoff_factor: f32,
    growth_interval: usize,
    steps_since_update: usize,
}

impl LossScaler {
    /// Creates a scaler starting at `_initialscale`, with growth factor `2.0`,
    /// backoff factor `0.5` and a growth interval of `2000` steps.
    pub fn new(_initialscale: f32) -> Self {
        Self {
            scale: _initialscale,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval: 2000,
            steps_since_update: 0,
        }
    }

    /// Returns the current scale factor.
    pub fn get_scale(&self) -> f32 {
        self.scale
    }

    /// Returns `loss` multiplied by the current scale factor.
    pub fn scale_loss(&self, loss: f32) -> f32 {
        loss * self.scale
    }

    /// Multiplies every gradient element by the reciprocal of the current
    /// scale, in place.
    ///
    /// # Panics
    ///
    /// Panics if the scale cannot be converted into `A`.
    pub fn unscale_gradients<A, D>(&self, gradients: &mut Array<A, D>)
    where
        A: Float + ScalarOperand,
        D: Dimension,
    {
        let inv_scale = A::one() / A::from(self.scale).expect("unwrap failed");
        for g in gradients.iter_mut() {
            *g = *g * inv_scale;
        }
    }

    /// Advances the scaler by one step.
    ///
    /// When `foundinf` is true the scale is multiplied by the backoff factor;
    /// otherwise, once `growth_interval` steps have passed since the last
    /// change, it is multiplied by the growth factor. In both cases the step
    /// counter restarts from zero.
    ///
    /// A new scale is only adopted while it remains a normal floating-point
    /// number. Without this guard the scale could overflow to infinity or
    /// underflow to zero, and since `inf * 0.5 == inf` and `0 * 2 == 0` it
    /// would then stay stuck there (a zero scale would also make
    /// [`unscale_gradients`](Self::unscale_gradients) divide by zero).
    pub fn update(&mut self, foundinf: bool) {
        self.steps_since_update += 1;

        let factor = if foundinf {
            self.backoff_factor
        } else if self.steps_since_update >= self.growth_interval {
            self.growth_factor
        } else {
            return;
        };

        let candidate = self.scale * factor;
        if candidate.is_normal() {
            self.scale = candidate;
        }
        self.steps_since_update = 0;
    }

    /// Returns `true` when at least one gradient element is not finite
    /// (infinite or NaN).
    pub fn check_gradients<A, D>(&self, gradients: &Array<A, D>) -> bool
    where
        A: Float + ScalarOperand,
        D: Dimension,
    {
        gradients.iter().any(|&x| !x.is_finite())
    }
}
}
pub mod gradient_checkpointing {
use super::*;
use std::collections::VecDeque;
/// Policy that decides at which depths an activation is stored as a checkpoint.
#[derive(Debug, Clone, PartialEq)]
pub enum CheckpointStrategy {
/// Never checkpoint.
None,
/// Checkpoint at depths that are a multiple of `interval`.
Uniform {
interval: usize,
},
/// Checkpoint at depths that equal a power of `base` (truncated to an integer).
Logarithmic {
base: f64,
},
/// Checkpoint while the tracked memory usage ratio is above `memory_threshold`.
MemoryAware {
memory_threshold: f64,
},
/// Checkpoint at depth `d` when `pattern[d]` is `true`; depths past the end
/// of the pattern are not checkpointed.
Custom {
pattern: Vec<bool>,
},
}
/// Stores selected activations ("checkpoints") keyed by depth and tracks the
/// memory they occupy.
///
/// Checkpointing is inactive while `max_depth` is zero, which is the state
/// right after [`GradientCheckpointer::new`]; call
/// [`set_max_depth`](GradientCheckpointer::set_max_depth) with a non-zero
/// value to activate it.
#[derive(Debug)]
pub struct GradientCheckpointer<A: Float, D: Dimension> {
    strategy: CheckpointStrategy,
    checkpoints: std::collections::HashMap<usize, Array<A, D>>,
    memory_tracker: MemoryTracker,
    current_depth: usize,
    max_depth: usize,
    enabled: bool,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GradientCheckpointer<A, D> {
    /// Creates an enabled checkpointer with no stored checkpoints and
    /// `max_depth == 0`.
    pub fn new(strategy: CheckpointStrategy) -> Self {
        Self {
            strategy,
            checkpoints: std::collections::HashMap::new(),
            memory_tracker: MemoryTracker::new(),
            current_depth: 0,
            max_depth: 0,
            enabled: true,
        }
    }

    /// Sets the maximum depth; a value of zero turns checkpointing off.
    pub fn set_max_depth(&mut self, depth: usize) {
        self.max_depth = depth;
    }

    /// Enables or disables checkpointing.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Reports whether the current strategy wants a checkpoint at `depth`.
    ///
    /// Always `false` when the checkpointer is disabled or `max_depth` is zero.
    pub fn should_checkpoint(&self, depth: usize) -> bool {
        if !self.enabled || self.max_depth == 0 {
            return false;
        }
        match self.strategy {
            CheckpointStrategy::None => false,
            CheckpointStrategy::Uniform { interval } => depth.is_multiple_of(interval),
            CheckpointStrategy::Logarithmic { base } => {
                // `log` is computed in floating point and can land a hair
                // below the exact integer (e.g. log10(1000) = 2.999...96),
                // which would make `floor` pick the wrong exponent. The small
                // slack corrects that; the equality test below still decides.
                const LOG_SLACK: f64 = 1e-9;
                let exponent = ((depth as f64).log(base) + LOG_SLACK).floor() as usize;
                depth == base.powi(exponent as i32) as usize
            }
            CheckpointStrategy::MemoryAware { memory_threshold } => {
                self.memory_tracker.usage_ratio() > memory_threshold
            }
            CheckpointStrategy::Custom { ref pattern } => {
                pattern.get(depth).copied().unwrap_or(false)
            }
        }
    }

    /// Stores `activation` under `depth` if the strategy asks for a checkpoint
    /// there; otherwise the activation is dropped.
    ///
    /// When a checkpoint already exists at `depth` it is replaced and its
    /// bytes are released from the memory tracker, so the tracked usage keeps
    /// matching what is actually stored.
    pub fn store_checkpoint(&mut self, depth: usize, activation: Array<A, D>) {
        if !self.should_checkpoint(depth) {
            return;
        }
        let element_size = std::mem::size_of::<A>();
        let new_bytes = activation.len() * element_size;
        if let Some(replaced) = self.checkpoints.insert(depth, activation) {
            self.memory_tracker
                .remove_allocation(replaced.len() * element_size);
        }
        self.memory_tracker.add_allocation(new_bytes);
    }

    /// Returns the checkpoint stored at `depth`, if any.
    pub fn get_checkpoint(&self, depth: usize) -> Option<&Array<A, D>> {
        self.checkpoints.get(&depth)
    }

    /// Removes and returns the checkpoint stored at `depth`, releasing its
    /// bytes from the memory tracker.
    pub fn remove_checkpoint(&mut self, depth: usize) -> Option<Array<A, D>> {
        let removed = self.checkpoints.remove(&depth)?;
        self.memory_tracker
            .remove_allocation(removed.len() * std::mem::size_of::<A>());
        Some(removed)
    }

    /// Drops every checkpoint and resets the memory tracker.
    pub fn clear_checkpoints(&mut self) {
        self.checkpoints.clear();
        self.memory_tracker.reset();
    }

    /// Returns a snapshot of the tracked memory usage.
    pub fn memory_usage(&self) -> MemoryUsage {
        self.memory_tracker.usage()
    }

    /// Adjusts the strategy relative to `target_memoryusage` (a usage ratio).
    ///
    /// * Usage above the target: a `Uniform` interval is halved (never below
    ///   one); a `MemoryAware` threshold becomes `0.8 * target`.
    /// * Usage below half the target: a `Uniform` interval is doubled
    ///   (saturating instead of overflowing); a `MemoryAware` threshold
    ///   becomes `1.2 * target`.
    ///
    /// Other strategies, and usage between those two bands, are left as is.
    pub fn optimize_strategy(&mut self, target_memoryusage: f64) {
        let current_usage = self.memory_tracker.usage_ratio();

        let adjusted = if current_usage > target_memoryusage {
            match &self.strategy {
                CheckpointStrategy::Uniform { interval } => Some(CheckpointStrategy::Uniform {
                    interval: (interval / 2).max(1),
                }),
                CheckpointStrategy::MemoryAware { .. } => Some(CheckpointStrategy::MemoryAware {
                    memory_threshold: target_memoryusage * 0.8,
                }),
                CheckpointStrategy::None
                | CheckpointStrategy::Logarithmic { .. }
                | CheckpointStrategy::Custom { .. } => None,
            }
        } else if current_usage < target_memoryusage * 0.5 {
            match &self.strategy {
                CheckpointStrategy::Uniform { interval } => Some(CheckpointStrategy::Uniform {
                    interval: interval.saturating_mul(2),
                }),
                CheckpointStrategy::MemoryAware { .. } => Some(CheckpointStrategy::MemoryAware {
                    memory_threshold: target_memoryusage * 1.2,
                }),
                CheckpointStrategy::None
                | CheckpointStrategy::Logarithmic { .. }
                | CheckpointStrategy::Custom { .. } => None,
            }
        } else {
            None
        };

        if let Some(strategy) = adjusted {
            self.strategy = strategy;
        }
    }

    /// Runs `forward_fn` on `input` and checkpoints the activation it returns
    /// when the strategy asks for one at `depth`.
    ///
    /// Returns the forward output together with `Some(activation)` if a
    /// checkpoint was taken (a copy is stored internally) or `None` otherwise.
    ///
    /// # Errors
    ///
    /// Forwards any error returned by `forward_fn`.
    pub fn checkpointed_forward<F, Output>(
        &mut self,
        depth: usize,
        input: &Array<A, D>,
        forward_fn: F,
    ) -> Result<(Output, Option<Array<A, D>>)>
    where
        F: FnOnce(&Array<A, D>) -> Result<(Output, Array<A, D>)>,
    {
        self.current_depth = depth;
        let (output, activation) = forward_fn(input)?;
        if !self.should_checkpoint(depth) {
            return Ok((output, None));
        }
        // One copy is kept in the store, the original goes back to the caller.
        self.store_checkpoint(depth, activation.clone());
        Ok((output, Some(activation)))
    }

    /// Rebuilds the activation at `target_depth` from the nearest stored
    /// checkpoint at or below `start_depth`.
    ///
    /// Starting from a copy of that checkpoint, `recompute_fn(depth, &activation)`
    /// is called for every depth from the checkpoint's depth plus one up to
    /// and including `target_depth`. If `target_depth` does not exceed the
    /// checkpoint's depth, the checkpoint copy is returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns `OptimError::InvalidConfig` when no checkpoint exists at or
    /// below `start_depth`, and forwards any error from `recompute_fn`.
    pub fn recompute_from_checkpoint<F>(
        &self,
        start_depth: usize,
        target_depth: usize,
        recompute_fn: F,
    ) -> Result<Array<A, D>>
    where
        F: Fn(usize, &Array<A, D>) -> Result<Array<A, D>>,
    {
        let (checkpoint_depth, checkpoint) = (0..=start_depth)
            .rev()
            .find_map(|d| self.checkpoints.get(&d).map(|saved| (d, saved)))
            .ok_or_else(|| {
                OptimError::InvalidConfig("No checkpoint found for recomputation".to_string())
            })?;

        let mut current_activation = checkpoint.clone();
        for depth in (checkpoint_depth + 1)..=target_depth {
            current_activation = recompute_fn(depth, &current_activation)?;
        }
        Ok(current_activation)
    }
}
/// Byte counter for checkpoint storage: current allocation, peak allocation,
/// and an assumed total system memory used to express usage as a ratio.
#[derive(Debug, Clone)]
pub struct MemoryTracker {
    allocated_bytes: usize,
    peak_bytes: usize,
    total_system_memory: usize,
}

impl Default for MemoryTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl MemoryTracker {
    /// Creates a tracker with nothing allocated.
    pub fn new() -> Self {
        Self {
            allocated_bytes: 0,
            peak_bytes: 0,
            total_system_memory: Self::estimate_system_memory(),
        }
    }

    /// Records `bytes` as newly allocated and raises the peak if needed.
    ///
    /// The counter saturates at `usize::MAX` instead of overflowing, mirroring
    /// the saturating behavior of [`remove_allocation`](Self::remove_allocation).
    pub fn add_allocation(&mut self, bytes: usize) {
        self.allocated_bytes = self.allocated_bytes.saturating_add(bytes);
        self.peak_bytes = self.peak_bytes.max(self.allocated_bytes);
    }

    /// Records `bytes` as released; the counter never drops below zero.
    pub fn remove_allocation(&mut self, bytes: usize) {
        self.allocated_bytes = self.allocated_bytes.saturating_sub(bytes);
    }

    /// Returns a snapshot of the current counters.
    pub fn usage(&self) -> MemoryUsage {
        MemoryUsage {
            current_bytes: self.allocated_bytes,
            peak_bytes: self.peak_bytes,
            total_system_bytes: self.total_system_memory,
        }
    }

    /// Returns allocated bytes divided by the assumed total system memory,
    /// or `0.0` when that total is zero.
    pub fn usage_ratio(&self) -> f64 {
        if self.total_system_memory == 0 {
            0.0
        } else {
            self.allocated_bytes as f64 / self.total_system_memory as f64
        }
    }

    /// Clears both the current and the peak allocation counters.
    pub fn reset(&mut self) {
        self.allocated_bytes = 0;
        self.peak_bytes = 0;
    }

    /// Returns the assumed total system memory: a fixed 8 GiB (the system is
    /// not queried).
    ///
    /// Written with a saturating multiply so that the constant does not
    /// overflow `usize` on targets where 8 GiB is not representable.
    fn estimate_system_memory() -> usize {
        const BYTES_PER_GIB: usize = 1 << 30;
        const ASSUMED_GIB: usize = 8;
        ASSUMED_GIB.saturating_mul(BYTES_PER_GIB)
    }
}

/// Snapshot of memory counters, all in bytes.
#[derive(Debug, Clone, Copy)]
pub struct MemoryUsage {
    /// Bytes currently allocated.
    pub current_bytes: usize,
    /// Highest value `current_bytes` has reached.
    pub peak_bytes: usize,
    /// Total system memory the ratios are measured against.
    pub total_system_bytes: usize,
}

impl MemoryUsage {
    /// Divides `bytes` by the total system memory, yielding `0.0` when the
    /// total is zero.
    fn ratio_of_total(&self, bytes: usize) -> f64 {
        if self.total_system_bytes == 0 {
            0.0
        } else {
            bytes as f64 / self.total_system_bytes as f64
        }
    }

    /// Current allocation as a fraction of total system memory.
    pub fn current_ratio(&self) -> f64 {
        self.ratio_of_total(self.current_bytes)
    }

    /// Peak allocation as a fraction of total system memory.
    pub fn peak_ratio(&self) -> f64 {
        self.ratio_of_total(self.peak_bytes)
    }

    /// Renders the snapshot as a one-line summary in MB (1 MB = 1024 * 1024
    /// bytes) with percentages of the total.
    pub fn format(&self) -> String {
        const BYTES_PER_MB: f64 = 1024.0 * 1024.0;
        format!(
            "Current: {:.1} MB ({:.1}%), Peak: {:.1} MB ({:.1}%), Total: {:.1} MB",
            self.current_bytes as f64 / BYTES_PER_MB,
            self.current_ratio() * 100.0,
            self.peak_bytes as f64 / BYTES_PER_MB,
            self.peak_ratio() * 100.0,
            self.total_system_bytes as f64 / BYTES_PER_MB
        )
    }
}
/// Wraps a [`GradientCheckpointer`] and periodically re-tunes its strategy
/// from a rolling history of memory-usage ratios.
///
/// The wrapped checkpointer is created with `GradientCheckpointer::new`; use
/// [`checkpointer_mut`](AutoCheckpointer::checkpointer_mut) for any further
/// configuration of it.
#[derive(Debug)]
pub struct AutoCheckpointer<A: Float, D: Dimension> {
    checkpointer: GradientCheckpointer<A, D>,
    memory_history: VecDeque<f64>,
    target_memoryratio: f64,
    adaptation_frequency: usize,
    step_count: usize,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> AutoCheckpointer<A, D> {
    /// Creates an auto-checkpointer.
    ///
    /// `target_memoryratio` is clamped into `[0.1, 0.9]`; the strategy is
    /// re-evaluated every 10 steps unless changed with
    /// [`with_adaptation_frequency`](Self::with_adaptation_frequency).
    pub fn new(_initial_strategy: CheckpointStrategy, target_memoryratio: f64) -> Self {
        let checkpointer = GradientCheckpointer::new(_initial_strategy);
        Self {
            checkpointer,
            memory_history: VecDeque::with_capacity(100),
            target_memoryratio: target_memoryratio.clamp(0.1, 0.9),
            adaptation_frequency: 10,
            step_count: 0,
        }
    }

    /// Sets how many steps pass between strategy re-evaluations (at least one).
    pub fn with_adaptation_frequency(self, frequency: usize) -> Self {
        Self {
            adaptation_frequency: frequency.max(1),
            ..self
        }
    }

    /// Runs one checkpointed forward step and records the resulting memory
    /// usage ratio.
    ///
    /// The history keeps at most the 100 most recent ratios. Every
    /// `adaptation_frequency`-th call the strategy is re-evaluated.
    ///
    /// # Errors
    ///
    /// Forwards any error from the underlying checkpointed forward call; in
    /// that case no ratio is recorded (the step counter has already advanced).
    pub fn auto_step<F, Output>(
        &mut self,
        depth: usize,
        input: &Array<A, D>,
        forward_fn: F,
    ) -> Result<(Output, Option<Array<A, D>>)>
    where
        F: FnOnce(&Array<A, D>) -> Result<(Output, Array<A, D>)>,
    {
        self.step_count += 1;
        let outcome = self
            .checkpointer
            .checkpointed_forward(depth, input, forward_fn)?;

        let ratio = self.checkpointer.memory_usage().current_ratio();
        self.memory_history.push_back(ratio);
        if self.memory_history.len() > 100 {
            self.memory_history.pop_front();
        }

        if self.step_count.is_multiple_of(self.adaptation_frequency) {
            self.adapt_strategy();
        }
        Ok(outcome)
    }

    /// Re-tunes the strategy when the average of the most recent (up to 10)
    /// usage ratios is more than `0.1` away from the target. Does nothing
    /// until at least 5 ratios have been recorded.
    fn adapt_strategy(&mut self) {
        let recorded = self.memory_history.len();
        if recorded < 5 {
            return;
        }
        let window = recorded.min(10);
        let recent_total: f64 = self.memory_history.iter().rev().take(window).sum();
        let recent_avg = recent_total / window as f64;

        if (recent_avg - self.target_memoryratio).abs() > 0.1 {
            self.checkpointer.optimize_strategy(self.target_memoryratio);
        }
    }

    /// Shared access to the wrapped checkpointer.
    pub fn checkpointer(&self) -> &GradientCheckpointer<A, D> {
        &self.checkpointer
    }

    /// Exclusive access to the wrapped checkpointer.
    pub fn checkpointer_mut(&mut self) -> &mut GradientCheckpointer<A, D> {
        &mut self.checkpointer
    }

    /// Summarizes current, peak and average usage ratios, the target ratio,
    /// and the number of stored checkpoints. The average is `0.0` while the
    /// history is empty.
    pub fn get_memory_stats(&self) -> MemoryStats {
        let usage = self.checkpointer.memory_usage();
        let samples = self.memory_history.len();
        let average_usage = match samples {
            0 => 0.0,
            _ => self.memory_history.iter().sum::<f64>() / samples as f64,
        };
        MemoryStats {
            current_usage: usage.current_ratio(),
            peak_usage: usage.peak_ratio(),
            average_usage,
            target_usage: self.target_memoryratio,
            checkpoints_stored: self.checkpointer.checkpoints.len(),
        }
    }
}
/// Memory usage summary expressed as ratios, plus the checkpoint count.
#[derive(Debug, Clone, Copy)]
pub struct MemoryStats {
    /// Current usage ratio.
    pub current_usage: f64,
    /// Peak usage ratio.
    pub peak_usage: f64,
    /// Average usage ratio over the recorded history.
    pub average_usage: f64,
    /// Target usage ratio.
    pub target_usage: f64,
    /// Number of checkpoints currently stored.
    pub checkpoints_stored: usize,
}

impl MemoryStats {
    /// Returns `true` when current usage is within `tolerance` of the target.
    pub fn is_within_target(&self, tolerance: f64) -> bool {
        (self.current_usage - self.target_usage).abs() <= tolerance
    }

    /// Scores how close current usage is to the target: the smaller of the
    /// two divided by the larger, so `1.0` means exactly on target.
    ///
    /// When current usage equals the target the score is `1.0` directly; this
    /// also covers the case where both are zero, which would otherwise
    /// evaluate `0.0 / 0.0` and yield NaN.
    pub fn efficiency_score(&self) -> f64 {
        if self.current_usage == self.target_usage {
            return 1.0;
        }
        if self.current_usage <= self.target_usage {
            self.current_usage / self.target_usage
        } else {
            self.target_usage / self.current_usage
        }
    }
}
}
pub mod adaptive {
use super::*;
/// Adjusts a batch size up or down according to a reported memory usage ratio.
#[derive(Debug, Clone)]
pub struct MemoryAwareBatchSizer {
    _initial_batchsize: usize,
    max_batch_size: usize,
    min_batch_size: usize,
    current_batch_size: usize,
    memory_threshold: f64,
    adaptation_factor: f64,
}

impl MemoryAwareBatchSizer {
    /// Creates a sizer starting at `_initial_batchsize`.
    ///
    /// The batch size may later range from a quarter of the initial size (but
    /// never less than one) up to four times the initial size (saturating
    /// instead of overflowing, and never below the minimum). Defaults:
    /// memory threshold `0.8`, adaptation factor `1.2`.
    pub fn new(_initial_batchsize: usize) -> Self {
        // Dividing first and flooring at one keeps the minimum usable: a
        // batch size of zero must never be reachable by shrinking.
        let min_batch_size = (_initial_batchsize / 4).max(1);
        let max_batch_size = _initial_batchsize.saturating_mul(4).max(min_batch_size);
        Self {
            _initial_batchsize,
            max_batch_size,
            min_batch_size,
            current_batch_size: _initial_batchsize,
            memory_threshold: 0.8,
            adaptation_factor: 1.2,
        }
    }

    /// Sets the usage ratio above which the batch shrinks, clamped into
    /// `[0.1, 0.95]`.
    pub fn with_memory_threshold(mut self, threshold: f64) -> Self {
        self.memory_threshold = threshold.clamp(0.1, 0.95);
        self
    }

    /// Sets the multiplicative step used when growing or shrinking (at least
    /// `1.0`).
    pub fn with_adaptation_factor(mut self, factor: f64) -> Self {
        self.adaptation_factor = factor.max(1.0);
        self
    }

    /// Returns the batch size currently in effect.
    pub fn current_batch_size(&self) -> usize {
        self.current_batch_size
    }

    /// Adapts the batch size to `memory_usageratio`.
    ///
    /// * Above the threshold: the size is divided by the adaptation factor
    ///   (truncated), but not below the minimum.
    /// * Below 70% of the threshold: the size is multiplied by the adaptation
    ///   factor (truncated), but not above the maximum. If truncation would
    ///   leave the size unchanged although the factor is greater than one,
    ///   the size advances by one so that small batches can still grow.
    /// * Otherwise the size is unchanged.
    pub fn adapt(&mut self, memory_usageratio: f64) {
        let current = self.current_batch_size;
        if memory_usageratio > self.memory_threshold {
            let shrunk = (current as f64 / self.adaptation_factor) as usize;
            self.current_batch_size = shrunk.max(self.min_batch_size);
        } else if memory_usageratio < self.memory_threshold * 0.7 {
            let scaled = (current as f64 * self.adaptation_factor) as usize;
            let grown = if self.adaptation_factor > 1.0 {
                scaled.max(current.saturating_add(1))
            } else {
                scaled
            };
            self.current_batch_size = grown.min(self.max_batch_size);
        }
    }

    /// Restores the batch size to the initial value.
    pub fn reset(&mut self) {
        self.current_batch_size = self._initial_batchsize;
    }
}
/// Estimates the bytes occupied by the elements of `arrays`: the total
/// element count multiplied by `size_of::<A>()`.
///
/// Only element storage is counted; any per-array bookkeeping is ignored.
pub fn estimate_memory_usage<A, D>(arrays: &[&Array<A, D>]) -> usize
where
    A: Sized,
    D: Dimension,
{
    let element_size = std::mem::size_of::<A>();
    let mut total_bytes = 0;
    for array in arrays {
        total_bytes += array.len() * element_size;
    }
    total_bytes
}
/// Returns the memory usage ratio.
///
/// This is a placeholder: it always returns `0.5` and does not inspect the
/// system.
pub fn get_memory_usage_ratio() -> f64 {
    const PLACEHOLDER_RATIO: f64 = 0.5;
    PLACEHOLDER_RATIO
}
}
pub use utils::{
add_inplace, apply_inplace, clip_inplace, normalize_inplace, scale_inplace, subtract_inplace,
};
pub use adaptive::*;
pub use fused::*;
pub use gradient_checkpointing::*;
pub use mixed_precision::*;
// Unit tests for the in-place optimizers, the fused kernels, loss scaling,
// adaptive batch sizing and gradient checkpointing defined in this file.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use scirs2_core::ndarray::Array1;
// Plain SGD: each parameter moves by `lr * gradient`.
#[test]
fn test_inplace_sgd() {
let mut optimizer = InPlaceSGD::new(0.1);
let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
optimizer
.step_inplace(&mut params, &gradients)
.expect("unwrap failed");
assert_relative_eq!(params[0], 0.99, epsilon = 1e-6);
assert_relative_eq!(params[1], 1.98, epsilon = 1e-6);
assert_relative_eq!(params[2], 2.97, epsilon = 1e-6);
}
// Adam: positive gradients must decrease every parameter after a few steps.
#[test]
fn test_inplace_adam() {
let mut optimizer = InPlaceAdam::new(0.001);
let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
for _ in 0..5 {
optimizer
.step_inplace(&mut params, &gradients)
.expect("unwrap failed");
}
assert!(params[0] < 1.0);
assert!(params[1] < 2.0);
assert!(params[2] < 3.0);
}
// `scale_inplace` multiplies every element by the scalar.
#[test]
fn test_utils_scale_inplace() {
let mut array = Array1::from_vec(vec![1.0, 2.0, 3.0]);
utils::scale_inplace(&mut array, 2.0);
assert_eq!(array.as_slice().expect("unwrap failed"), &[2.0, 4.0, 6.0]);
}
// `clip_inplace` clamps values into the given range.
#[test]
fn test_utils_clip_inplace() {
let mut array = Array1::from_vec(vec![0.5, 1.5, 2.5]);
utils::clip_inplace(&mut array, 1.0, 2.0);
assert_eq!(array.as_slice().expect("unwrap failed"), &[1.0, 1.5, 2.0]);
}
// The parameter buffer must keep its address across a step (no reallocation).
#[test]
fn test_memory_efficiency() {
let mut params = Array1::from_vec(vec![1.0; 1000]);
let gradients = Array1::from_vec(vec![0.01; 1000]);
let params_ptr = params.as_ptr();
let mut optimizer = InPlaceSGD::new(0.1);
optimizer
.step_inplace(&mut params, &gradients)
.expect("unwrap failed");
assert_eq!(params_ptr, params.as_ptr());
}
// Fused Adam: parameters decrease and both moment buffers become positive.
#[test]
fn test_fused_adam_update() {
let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
let mut m = Array1::zeros(3);
let mut v = Array1::zeros(3);
let config = fused::AdamConfig {
lr: 0.01,
beta1: 0.9,
beta2: 0.999,
epsilon: 1e-8,
bias1: 0.1,
bias2: 0.001,
weight_decay: None,
};
fused::fused_adam_update(&mut params, &gradients, &mut m, &mut v, config);
assert!(params[0] < 1.0);
assert!(params[1] < 2.0);
assert!(params[2] < 3.0);
assert!(m[0] > 0.0);
assert!(v[0] > 0.0);
}
// Fused SGD with momentum buffer and weight decay: parameters decrease.
// Positional arguments after the buffer: lr, momentum, weight_decay, dampening.
#[test]
fn test_fused_sgd_update() {
let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
let mut momentum_buf = Array1::zeros(3);
fused::fused_sgd_update(
&mut params,
&gradients,
Some(&mut momentum_buf),
0.1, 0.9, Some(0.01), 0.0, );
assert!(params[0] < 1.0);
assert!(params[1] < 2.0);
assert!(params[2] < 3.0);
}
// Value clipping (clip_value = 1.0) followed by norm clipping (max_norm = 2.0).
#[test]
fn test_fused_gradient_clip_normalize() {
let mut gradients = Array1::from_vec(vec![5.0, -3.0, 2.0]);
fused::fused_gradient_clip_normalize(
&mut gradients,
Some(2.0), Some(1.0), );
assert!(gradients.iter().all(|&x| x.abs() <= 1.0));
let norm = gradients.iter().map(|&x| x * x).sum::<f64>().sqrt();
assert!(norm <= 2.0 + 1e-6);
}
// Loss scaler: scaling, unscaling, and detection of non-finite gradients.
#[test]
fn test_mixed_precision_loss_scaler() {
let scaler = mixed_precision::LossScaler::new(65536.0);
let loss = 0.5;
let scaled_loss = scaler.scale_loss(loss);
assert_eq!(scaled_loss, 0.5 * 65536.0);
let mut gradients = Array1::from_vec(vec![65536.0, 131072.0]);
scaler.unscale_gradients(&mut gradients);
assert_relative_eq!(gradients[0], 1.0, epsilon = 1e-6);
assert_relative_eq!(gradients[1], 2.0, epsilon = 1e-6);
let inf_gradients = Array1::from_vec(vec![f64::INFINITY, 1.0]);
assert!(scaler.check_gradients(&inf_gradients));
let finite_gradients = Array1::from_vec(vec![1.0, 2.0]);
assert!(!scaler.check_gradients(&finite_gradients));
}
// Batch sizer: shrinks under high memory, grows back under low memory, resets.
#[test]
fn test_memory_aware_batch_sizer() {
let mut sizer = adaptive::MemoryAwareBatchSizer::new(32)
.with_memory_threshold(0.8)
.with_adaptation_factor(1.3);
assert_eq!(sizer.current_batch_size(), 32);
sizer.adapt(0.9);
let reduced_size = sizer.current_batch_size();
assert!(reduced_size < 32);
sizer.adapt(0.3);
sizer.adapt(0.3); assert!(sizer.current_batch_size() >= 32);
sizer.reset();
assert_eq!(sizer.current_batch_size(), 32);
}
// Memory estimate equals total element count times the element size.
#[test]
fn test_memory_estimation() {
let array1 = Array1::from_vec(vec![1.0; 100]);
let array2 = Array1::from_vec(vec![2.0; 200]);
let arrays = vec![&array1, &array2];
let estimated_size = adaptive::estimate_memory_usage(&arrays);
let expected_size = 300 * std::mem::size_of::<f64>();
assert_eq!(estimated_size, expected_size);
}
// Uniform strategy with interval 2: even depths are checkpointed.
#[test]
fn test_gradient_checkpointing_uniform() {
let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
f64,
scirs2_core::ndarray::Ix1,
> = gradient_checkpointing::GradientCheckpointer::new(
gradient_checkpointing::CheckpointStrategy::Uniform { interval: 2 },
);
checkpointer.set_max_depth(10);
assert!(checkpointer.should_checkpoint(0));
assert!(!checkpointer.should_checkpoint(1));
assert!(checkpointer.should_checkpoint(2));
assert!(!checkpointer.should_checkpoint(3));
assert!(checkpointer.should_checkpoint(4));
let activation = Array1::from_vec(vec![1.0, 2.0, 3.0]);
checkpointer.store_checkpoint(2, activation.clone());
let retrieved = checkpointer.get_checkpoint(2).expect("unwrap failed");
assert_eq!(
retrieved.as_slice().expect("unwrap failed"),
activation.as_slice().expect("unwrap failed")
);
assert!(checkpointer.get_checkpoint(1).is_none());
}
// Logarithmic strategy with base 2: only powers of two are checkpointed.
#[test]
fn test_gradient_checkpointing_logarithmic() {
let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
f64,
scirs2_core::ndarray::Ix1,
> = gradient_checkpointing::GradientCheckpointer::new(
gradient_checkpointing::CheckpointStrategy::Logarithmic { base: 2.0 },
);
checkpointer.set_max_depth(10);
assert!(checkpointer.should_checkpoint(1));
assert!(checkpointer.should_checkpoint(2));
assert!(!checkpointer.should_checkpoint(3));
assert!(checkpointer.should_checkpoint(4));
assert!(!checkpointer.should_checkpoint(5));
assert!(!checkpointer.should_checkpoint(6));
assert!(!checkpointer.should_checkpoint(7));
assert!(checkpointer.should_checkpoint(8));
}
// Custom strategy follows the pattern; depths past its end are not checkpointed.
#[test]
fn test_gradient_checkpointing_custom() {
let pattern = vec![true, false, false, true, false];
let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
f64,
scirs2_core::ndarray::Ix1,
> = gradient_checkpointing::GradientCheckpointer::new(
gradient_checkpointing::CheckpointStrategy::Custom { pattern },
);
checkpointer.set_max_depth(10);
assert!(checkpointer.should_checkpoint(0));
assert!(!checkpointer.should_checkpoint(1));
assert!(!checkpointer.should_checkpoint(2));
assert!(checkpointer.should_checkpoint(3));
assert!(!checkpointer.should_checkpoint(4));
assert!(!checkpointer.should_checkpoint(5)); }
// Tracked bytes grow when checkpoints are stored and shrink when removed.
#[test]
fn test_gradient_checkpointing_memory_tracking() {
let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
f64,
scirs2_core::ndarray::Ix1,
> = gradient_checkpointing::GradientCheckpointer::new(
gradient_checkpointing::CheckpointStrategy::Uniform { interval: 1 },
);
checkpointer.set_max_depth(5);
let activation1 = Array1::from_vec(vec![1.0; 100]);
let activation2 = Array1::from_vec(vec![2.0; 200]);
checkpointer.store_checkpoint(0, activation1);
let usage_after_first = checkpointer.memory_usage();
assert!(usage_after_first.current_bytes > 0);
checkpointer.store_checkpoint(1, activation2);
let usage_after_second = checkpointer.memory_usage();
assert!(usage_after_second.current_bytes > usage_after_first.current_bytes);
checkpointer.remove_checkpoint(0);
let usage_after_removal = checkpointer.memory_usage();
assert!(usage_after_removal.current_bytes < usage_after_second.current_bytes);
}
// `checkpointed_forward` returns the forward output and the checkpointed activation.
#[test]
fn test_checkpointed_forward() {
let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
f64,
scirs2_core::ndarray::Ix1,
> = gradient_checkpointing::GradientCheckpointer::new(
gradient_checkpointing::CheckpointStrategy::Uniform { interval: 1 },
);
checkpointer.set_max_depth(5);
let input = Array1::from_vec(vec![1.0, 2.0, 3.0]);
let forward_fn = |x: &Array1<f64>| -> Result<(f64, Array1<f64>)> {
let output = x.sum();
let activation = x.mapv(|val| val * 2.0);
Ok((output, activation))
};
let (output, checkpoint) = checkpointer
.checkpointed_forward(0, &input, forward_fn)
.expect("unwrap failed");
assert_eq!(output, 6.0); assert!(checkpoint.is_some());
let checkpoint = checkpoint.expect("unwrap failed");
assert_eq!(
checkpoint.as_slice().expect("unwrap failed"),
&[2.0, 4.0, 6.0]
);
}
// Recomputation starts from the checkpoint at depth 2 and applies two steps (depths 3 and 4).
#[test]
fn test_recompute_from_checkpoint() {
let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
f64,
scirs2_core::ndarray::Ix1,
> = gradient_checkpointing::GradientCheckpointer::new(
gradient_checkpointing::CheckpointStrategy::Uniform { interval: 2 },
);
checkpointer.set_max_depth(10);
let checkpoint0 = Array1::from_vec(vec![1.0, 2.0]);
let checkpoint2 = Array1::from_vec(vec![3.0, 4.0]);
checkpointer.store_checkpoint(0, checkpoint0);
checkpointer.store_checkpoint(2, checkpoint2);
let recompute_fn =
|_depth: usize, x: &Array1<f64>| -> Result<Array1<f64>> { Ok(x.mapv(|val| val + 1.0)) };
let result = checkpointer
.recompute_from_checkpoint(2, 4, recompute_fn)
.expect("unwrap failed");
assert_eq!(result.as_slice().expect("unwrap failed"), &[5.0, 6.0]);
}
// Auto-checkpointer forwards the output of each step and reports its target ratio.
#[test]
fn test_auto_checkpointer() {
let mut auto_checkpointer: AutoCheckpointer<f64, scirs2_core::ndarray::Ix1> =
gradient_checkpointing::AutoCheckpointer::new(
gradient_checkpointing::CheckpointStrategy::Uniform { interval: 2 },
0.6, );
let input = Array1::from_vec(vec![1.0, 2.0]);
let forward_fn = |x: &Array1<f64>| -> Result<(f64, Array1<f64>)> {
let output = x.sum();
let activation = x.clone();
Ok((output, activation))
};
for depth in 0..5 {
let (output_checkpoint, _) = auto_checkpointer
.auto_step(depth, &input, forward_fn)
.expect("unwrap failed");
assert_eq!(output_checkpoint, 3.0); }
let stats = auto_checkpointer.get_memory_stats();
assert!(stats.target_usage > 0.0);
}
// `MemoryStats` tolerance check and efficiency score.
#[test]
fn test_memory_stats() {
let stats = gradient_checkpointing::MemoryStats {
current_usage: 0.5,
peak_usage: 0.7,
average_usage: 0.6,
target_usage: 0.6,
checkpoints_stored: 3,
};
assert!(stats.is_within_target(0.1));
assert!(!stats.is_within_target(0.01));
let efficiency = stats.efficiency_score();
assert!(efficiency > 0.8 && efficiency <= 1.0);
}
// `MemoryUsage::format` output and ratio helpers (1 MB current, 2 MB peak, 8192 MB total).
#[test]
fn test_memory_usage_formatting() {
let usage = gradient_checkpointing::MemoryUsage {
current_bytes: 1024 * 1024, peak_bytes: 2 * 1024 * 1024, total_system_bytes: 8 * 1024 * 1024 * 1024, };
let formatted = usage.format();
assert!(formatted.contains("1.0 MB"));
assert!(formatted.contains("2.0 MB"));
assert!(formatted.contains("8192.0 MB"));
assert_relative_eq!(usage.current_ratio(), 1.0 / 8192.0, epsilon = 1e-6);
assert_relative_eq!(usage.peak_ratio(), 2.0 / 8192.0, epsilon = 1e-6);
}
// After `optimize_strategy`, at least one of the first depths is still checkpointed.
#[test]
fn test_checkpointing_strategy_optimization() {
let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
f64,
scirs2_core::ndarray::Ix1,
> = gradient_checkpointing::GradientCheckpointer::new(
gradient_checkpointing::CheckpointStrategy::Uniform { interval: 4 },
);
checkpointer.set_max_depth(10);
let checkpoint = Array1::from_vec(vec![1.0, 2.0, 3.0]);
checkpointer.store_checkpoint(0, checkpoint);
checkpointer.optimize_strategy(0.3);
assert!(
checkpointer.should_checkpoint(0)
|| checkpointer.should_checkpoint(1)
|| checkpointer.should_checkpoint(2)
);
}
// A disabled checkpointer never asks for a checkpoint.
#[test]
fn test_checkpointing_disabled() {
let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
f64,
scirs2_core::ndarray::Ix1,
> = gradient_checkpointing::GradientCheckpointer::new(
gradient_checkpointing::CheckpointStrategy::Uniform { interval: 1 },
);
checkpointer.set_enabled(false);
assert!(!checkpointer.should_checkpoint(0));
assert!(!checkpointer.should_checkpoint(1));
assert!(!checkpointer.should_checkpoint(2));
}
}