use crate::error::{NeuralError, Result};
use crate::layers::Layer;
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
use std::sync::{Arc, Mutex, RwLock};
/// Thread-safe, shareable byte-accounting for activation checkpointing.
///
/// Cloning a tracker clones the `Arc` handles, so every clone observes and
/// updates the *same* counters — a single tracker can be shared across many
/// checkpointed layers. A poisoned counter lock is tolerated: writes are
/// dropped and reads yield 0 instead of panicking.
#[derive(Debug, Default, Clone)]
pub struct MemoryUsageTracker {
    // Bytes of activations currently held for the backward pass.
    stored_activation_bytes: Arc<Mutex<u64>>,
    // Cumulative bytes of activations recomputed during backward passes.
    recomputed_bytes: Arc<Mutex<u64>>,
    // Cumulative bytes of gradients produced.
    gradient_bytes: Arc<Mutex<u64>>,
    // High-water mark of stored activations + gradients.
    peak_bytes: Arc<Mutex<u64>>,
}

impl MemoryUsageTracker {
    /// Creates a tracker with all counters at zero.
    pub fn new() -> Self {
        MemoryUsageTracker {
            stored_activation_bytes: Arc::new(Mutex::new(0)),
            recomputed_bytes: Arc::new(Mutex::new(0)),
            gradient_bytes: Arc::new(Mutex::new(0)),
            peak_bytes: Arc::new(Mutex::new(0)),
        }
    }

    /// Records `bytes` of newly stored activation data and refreshes the peak.
    pub fn record_stored_activation(&self, bytes: u64) {
        let new_total = {
            if let Ok(mut v) = self.stored_activation_bytes.lock() {
                // Saturating add: a raw `+=` would panic on u64 overflow in
                // debug builds and silently wrap in release builds.
                *v = v.saturating_add(bytes);
                *v
            } else {
                // Poisoned lock: drop the sample rather than panic.
                return;
            }
        };
        self.update_peak_external(new_total);
    }

    /// Releases previously recorded stored-activation bytes (floored at 0).
    pub fn release_stored_activation(&self, bytes: u64) {
        if let Ok(mut v) = self.stored_activation_bytes.lock() {
            *v = v.saturating_sub(bytes);
        }
    }

    /// Records `bytes` of recomputed activations (cumulative; does not
    /// contribute to the peak, since recomputed tensors are transient).
    pub fn record_recomputed(&self, bytes: u64) {
        if let Ok(mut v) = self.recomputed_bytes.lock() {
            *v = v.saturating_add(bytes);
        }
    }

    /// Records `bytes` of gradient data and refreshes the peak.
    pub fn record_gradient(&self, bytes: u64) {
        let new_total = {
            if let Ok(mut v) = self.gradient_bytes.lock() {
                *v = v.saturating_add(bytes);
                *v
            } else {
                return;
            }
        };
        self.update_peak_external(new_total);
    }

    /// Raises `peak_bytes` to stored + gradient bytes when that exceeds it.
    ///
    /// `current_component` is the freshly updated counter value; it keeps the
    /// peak from regressing when a sibling counter's lock is poisoned and
    /// therefore reads as 0 here.
    fn update_peak_external(&self, current_component: u64) {
        let stored = self.stored_activation_bytes.lock().map(|v| *v).unwrap_or(0);
        let grad = self.gradient_bytes.lock().map(|v| *v).unwrap_or(0);
        // Saturating: the two counters may legitimately sum past u64::MAX.
        let candidate = stored.saturating_add(grad).max(current_component);
        if let Ok(mut pk) = self.peak_bytes.lock() {
            if candidate > *pk {
                *pk = candidate;
            }
        }
    }

    /// Currently stored activation bytes (0 if the lock is poisoned).
    pub fn stored_activation_bytes(&self) -> u64 {
        self.stored_activation_bytes
            .lock()
            .map(|v| *v)
            .unwrap_or(0)
    }

    /// Cumulative recomputed bytes (0 if the lock is poisoned).
    pub fn recomputed_bytes(&self) -> u64 {
        self.recomputed_bytes.lock().map(|v| *v).unwrap_or(0)
    }

    /// Cumulative gradient bytes (0 if the lock is poisoned).
    pub fn gradient_bytes(&self) -> u64 {
        self.gradient_bytes.lock().map(|v| *v).unwrap_or(0)
    }

    /// High-water mark of stored + gradient bytes (0 if the lock is poisoned).
    pub fn peak_bytes(&self) -> u64 {
        self.peak_bytes.lock().map(|v| *v).unwrap_or(0)
    }

    /// Resets every counter, including the peak, back to zero.
    pub fn reset(&self) {
        for m in &[
            &self.stored_activation_bytes,
            &self.recomputed_bytes,
            &self.gradient_bytes,
            &self.peak_bytes,
        ] {
            if let Ok(mut v) = m.lock() {
                *v = 0;
            }
        }
    }

    /// Human-readable one-line summary of all counters, in KiB.
    pub fn summary(&self) -> String {
        format!(
            "MemoryUsage {{ stored_activations: {} KiB, recomputed: {} KiB, \
             gradients: {} KiB, peak: {} KiB }}",
            self.stored_activation_bytes() / 1024,
            self.recomputed_bytes() / 1024,
            self.gradient_bytes() / 1024,
            self.peak_bytes() / 1024,
        )
    }
}
/// Wraps a layer so its input is checkpointed (saved) during `forward` and
/// the forward computation is replayed during `backward`, trading extra
/// compute for reduced activation memory.
pub struct CheckpointedLayer<F, L>
where
    F: Float + Debug + ScalarOperand + NumAssign + 'static,
    L: Layer<F> + 'static,
{
    // The wrapped layer whose forward pass will be replayed on backward.
    inner: L,
    // Input saved by the most recent `forward`, consumed by `backward`.
    // (`F` is used here, so no `PhantomData` marker is needed.)
    saved_input: RwLock<Option<Array<F, IxDyn>>>,
    // Optional shared tracker for stored/recomputed/gradient byte accounting.
    memory_tracker: Option<Arc<MemoryUsageTracker>>,
}

impl<F, L> CheckpointedLayer<F, L>
where
    F: Float + Debug + ScalarOperand + NumAssign + 'static,
    L: Layer<F> + 'static,
{
    /// Wraps `layer` with no memory tracker attached.
    pub fn new(layer: L) -> Self {
        CheckpointedLayer {
            inner: layer,
            saved_input: RwLock::new(None),
            memory_tracker: None,
        }
    }

    /// Attaches a shared memory-usage tracker (builder style).
    pub fn with_tracker(mut self, tracker: Arc<MemoryUsageTracker>) -> Self {
        self.memory_tracker = Some(tracker);
        self
    }

    /// Borrows the wrapped layer.
    pub fn inner(&self) -> &L {
        &self.inner
    }

    /// Mutably borrows the wrapped layer.
    pub fn inner_mut(&mut self) -> &mut L {
        &mut self.inner
    }

    /// Unwraps, returning the inner layer.
    pub fn into_inner(self) -> L {
        self.inner
    }

    /// Size in bytes of a tensor's element storage.
    fn tensor_bytes(arr: &Array<F, IxDyn>) -> u64 {
        arr.len() as u64 * std::mem::size_of::<F>() as u64
    }
}
impl<F, L> Layer<F> for CheckpointedLayer<F, L>
where
    F: Float + Debug + ScalarOperand + NumAssign + Send + Sync + 'static,
    L: Layer<F> + Send + Sync + 'static,
{
    /// Runs the wrapped layer and checkpoints `input` for later replay.
    fn forward(
        &self,
        input: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        let result = self.inner.forward(input)?;
        let input_bytes = Self::tensor_bytes(input);
        {
            let mut slot = self.saved_input.write().map_err(|_| {
                NeuralError::ComputationError(
                    "lock poisoned in checkpoint forward".to_string(),
                )
            })?;
            // Credit back the previously stored checkpoint before replacing it.
            if let (Some(previous), Some(tracker)) = (slot.as_ref(), &self.memory_tracker) {
                tracker.release_stored_activation(Self::tensor_bytes(previous));
            }
            *slot = Some(input.clone());
        }
        // Account for the new checkpoint after the write lock is released.
        if let Some(tracker) = &self.memory_tracker {
            tracker.record_stored_activation(input_bytes);
        }
        Ok(result)
    }

    /// Replays the forward pass from the checkpointed input, then delegates
    /// to the wrapped layer's backward. The `_input` argument is ignored in
    /// favor of the saved checkpoint.
    fn backward(
        &self,
        _input: &Array<F, IxDyn>,
        grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        let guard = self.saved_input.read().map_err(|_| {
            NeuralError::ComputationError(
                "lock poisoned in checkpoint backward".to_string(),
            )
        })?;
        let checkpoint = match guard.as_ref() {
            Some(saved) => saved,
            None => {
                return Err(NeuralError::ComputationError(
                    "CheckpointedLayer: backward called before forward; no saved input".to_string(),
                ))
            }
        };
        // Replay the forward pass to reconstruct the activation.
        let replayed = self.inner.forward(checkpoint)?;
        if let Some(tracker) = &self.memory_tracker {
            tracker.record_recomputed(Self::tensor_bytes(&replayed));
        }
        let grad_input = self.inner.backward(&replayed, grad_output)?;
        if let Some(tracker) = &self.memory_tracker {
            tracker.record_gradient(Self::tensor_bytes(&grad_input));
        }
        Ok(grad_input)
    }

    /// Delegates parameter updates to the wrapped layer.
    fn update(&mut self, learning_rate: F) -> Result<()> {
        self.inner.update(learning_rate)
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    // Parameter and gradient access is forwarded untouched to the inner layer.
    fn params(&self) -> Vec<Array<F, IxDyn>> {
        self.inner.params()
    }

    fn gradients(&self) -> Vec<Array<F, IxDyn>> {
        self.inner.gradients()
    }

    fn set_gradients(&mut self, gradients: &[Array<F, IxDyn>]) -> Result<()> {
        self.inner.set_gradients(gradients)
    }

    fn set_params(&mut self, params: &[Array<F, IxDyn>]) -> Result<()> {
        self.inner.set_params(params)
    }

    // Training-mode state is likewise forwarded.
    fn set_training(&mut self, training: bool) {
        self.inner.set_training(training);
    }

    fn is_training(&self) -> bool {
        self.inner.is_training()
    }

    fn layer_type(&self) -> &str {
        "CheckpointedLayer"
    }

    fn parameter_count(&self) -> usize {
        self.inner.parameter_count()
    }

    fn layer_description(&self) -> String {
        format!("CheckpointedLayer({})", self.inner.layer_description())
    }
}
/// Runs a sequence of layers with segment-level activation checkpointing:
/// during `forward` only every `checkpoint_every`-th activation is stored;
/// the rest are recomputed on demand during `backward`.
pub struct SegmentCheckpointer<F>
where
    F: Float + Debug + ScalarOperand + NumAssign + Send + Sync + 'static,
{
    // The layer pipeline, applied in order during forward.
    layers: Vec<Box<dyn Layer<F>>>,
    // Store a checkpoint every N layers (0 means only the input is stored).
    checkpoint_every: usize,
    // Shared byte-accounting for stored/recomputed/gradient tensors.
    memory_tracker: Arc<MemoryUsageTracker>,
    // Checkpointed activations captured by the most recent forward pass.
    segment_starts: RwLock<Vec<Array<F, IxDyn>>>,
}
impl<F> SegmentCheckpointer<F>
where
    F: Float + Debug + ScalarOperand + NumAssign + Send + Sync + 'static,
{
    /// Creates a checkpointer over `layers`, storing an activation before
    /// every `checkpoint_every`-th layer. `0` disables intermediate
    /// checkpoints: only the pipeline input is stored.
    pub fn new(layers: Vec<Box<dyn Layer<F>>>, checkpoint_every: usize) -> Self {
        SegmentCheckpointer {
            layers,
            checkpoint_every,
            memory_tracker: Arc::new(MemoryUsageTracker::new()),
            segment_starts: RwLock::new(Vec::new()),
        }
    }

    /// Replaces the default tracker with a shared one (builder style).
    pub fn with_tracker(mut self, tracker: Arc<MemoryUsageTracker>) -> Self {
        self.memory_tracker = tracker;
        self
    }

    /// The tracker receiving this checkpointer's byte accounting.
    pub fn memory_tracker(&self) -> &Arc<MemoryUsageTracker> {
        &self.memory_tracker
    }

    /// Number of layers in the pipeline.
    pub fn len(&self) -> usize {
        self.layers.len()
    }

    /// True when the pipeline has no layers.
    pub fn is_empty(&self) -> bool {
        self.layers.is_empty()
    }

    /// Configured checkpoint interval (0 = input-only checkpointing).
    pub fn checkpoint_every(&self) -> usize {
        self.checkpoint_every
    }

    /// Checkpoint interval with `0` normalized to `usize::MAX` so that
    /// `idx % interval == 0` holds only at index 0 (input-only mode).
    fn effective_interval(&self) -> usize {
        if self.checkpoint_every == 0 {
            usize::MAX
        } else {
            self.checkpoint_every
        }
    }

    /// Size in bytes of a tensor's element storage.
    fn tensor_bytes(arr: &Array<F, IxDyn>) -> u64 {
        arr.len() as u64 * std::mem::size_of::<F>() as u64
    }

    /// Runs all layers, storing a checkpoint activation at every segment
    /// boundary. Checkpoints from an earlier forward pass are discarded
    /// (and their bytes released from the tracker) first.
    ///
    /// # Errors
    /// Propagates layer errors and reports a poisoned checkpoint lock.
    pub fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        if self.layers.is_empty() {
            return Ok(input.clone());
        }
        let mut segment_starts = self
            .segment_starts
            .write()
            .map_err(|_| NeuralError::ComputationError("lock poisoned".to_string()))?;
        // Credit back checkpoints from any previous forward pass so the
        // stored-bytes counter does not drift upward across iterations
        // (consistent with `CheckpointedLayer::forward`).
        for old in segment_starts.iter() {
            self.memory_tracker
                .release_stored_activation(Self::tensor_bytes(old));
        }
        segment_starts.clear();
        let ckpt = self.effective_interval();
        let mut current = input.clone();
        for (idx, layer) in self.layers.iter().enumerate() {
            if idx % ckpt == 0 {
                // Segment boundary: store the activation so backward can
                // recompute this segment from here.
                self.memory_tracker
                    .record_stored_activation(Self::tensor_bytes(&current));
                segment_starts.push(current.clone());
            }
            // BUG FIX: was `layer.forward(¤t)` — an HTML-entity-mangled
            // `&current` that did not compile.
            current = layer.forward(&current)?;
        }
        Ok(current)
    }

    /// Backpropagates `grad_output` through the pipeline, recomputing each
    /// layer's input from the nearest preceding checkpoint.
    ///
    /// Recomputation is done per layer (from the segment start each time),
    /// keeping extra memory O(1) at an O(segment_len^2) recompute cost.
    ///
    /// # Errors
    /// Fails if called before `forward` (no checkpoints), on a poisoned
    /// lock, or if any layer's forward/backward fails.
    pub fn backward(&self, grad_output: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        if self.layers.is_empty() {
            return Ok(grad_output.clone());
        }
        let segment_starts = self
            .segment_starts
            .read()
            .map_err(|_| NeuralError::ComputationError("lock poisoned in backward".to_string()))?;
        if segment_starts.is_empty() {
            return Err(NeuralError::ComputationError(
                "SegmentCheckpointer::backward called before forward; no saved segment starts"
                    .to_string(),
            ));
        }
        let ckpt = self.effective_interval();
        let n = self.layers.len();
        let mut grad = grad_output.clone();
        for idx in (0..n).rev() {
            // Locate the checkpoint that starts this layer's segment; clamp
            // in case fewer checkpoints exist than idx / ckpt would imply.
            let seg_idx = (idx / ckpt).min(segment_starts.len().saturating_sub(1));
            let seg_start = &segment_starts[seg_idx];
            let seg_offset = seg_idx * ckpt;
            // Recompute from the segment start up to (but not including)
            // `idx` to obtain this layer's input.
            let mut activation = seg_start.clone();
            for recomp_idx in seg_offset..idx {
                activation = self.layers[recomp_idx].forward(&activation)?;
                self.memory_tracker
                    .record_recomputed(Self::tensor_bytes(&activation));
            }
            grad = self.layers[idx].backward(&activation, &grad)?;
            self.memory_tracker
                .record_gradient(Self::tensor_bytes(&grad));
        }
        Ok(grad)
    }

    /// Appends a layer to the end of the pipeline.
    pub fn push(&mut self, layer: Box<dyn Layer<F>>) {
        self.layers.push(layer);
    }

    /// Borrows the layer pipeline.
    pub fn layers(&self) -> &[Box<dyn Layer<F>>] {
        &self.layers
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `MemoryUsageTracker`, `CheckpointedLayer`, and
    //! `SegmentCheckpointer`.
    use super::*;
    use crate::layers::Dense;
    use scirs2_core::ndarray::{Array, IxDyn};
    use scirs2_core::random::rng;

    /// Builds a small randomly-initialized dense layer for the tests.
    fn make_dense(in_dim: usize, out_dim: usize) -> Dense<f32> {
        let mut rng_state = rng();
        Dense::<f32>::new(in_dim, out_dim, None, &mut rng_state)
            .expect("dense layer creation failed")
    }

    /// All-ones f32 tensor of the given dynamic shape.
    fn ones_input(shape: &[usize]) -> Array<f32, IxDyn> {
        Array::<f32, IxDyn>::ones(IxDyn(shape))
    }

    // Counters accumulate on record and decrement on release.
    #[test]
    fn test_memory_tracker_basic() {
        let tracker = MemoryUsageTracker::new();
        tracker.record_stored_activation(4096);
        assert_eq!(tracker.stored_activation_bytes(), 4096);
        tracker.release_stored_activation(1024);
        assert_eq!(tracker.stored_activation_bytes(), 3072);
        tracker.record_recomputed(8192);
        assert_eq!(tracker.recomputed_bytes(), 8192);
        tracker.record_gradient(2048);
        assert_eq!(tracker.gradient_bytes(), 2048);
    }

    // Peak tracks at least the stored-activation high-water mark.
    #[test]
    fn test_memory_tracker_peak() {
        let tracker = MemoryUsageTracker::new();
        tracker.record_stored_activation(1000);
        tracker.record_gradient(500);
        assert!(tracker.peak_bytes() >= 1000);
    }

    // reset() zeroes every counter.
    #[test]
    fn test_memory_tracker_reset() {
        let tracker = MemoryUsageTracker::new();
        tracker.record_stored_activation(4096);
        tracker.record_gradient(1024);
        tracker.reset();
        assert_eq!(tracker.stored_activation_bytes(), 0);
        assert_eq!(tracker.gradient_bytes(), 0);
    }

    // summary() produces the expected human-readable markers.
    #[test]
    fn test_memory_tracker_summary_format() {
        let tracker = MemoryUsageTracker::new();
        let summary = tracker.summary();
        assert!(summary.contains("MemoryUsage"));
        assert!(summary.contains("KiB"));
    }

    // Forward through the wrapper preserves the inner layer's output shape.
    #[test]
    fn test_checkpointed_layer_forward() {
        let dense = make_dense(8, 4);
        let ckpt = CheckpointedLayer::new(dense);
        let input = ones_input(&[2, 8]);
        let output = ckpt.forward(&input).expect("forward failed");
        assert_eq!(output.shape(), &[2, 4]);
    }

    // Backward after forward yields a gradient with the input's shape.
    #[test]
    fn test_checkpointed_layer_backward() {
        let dense = make_dense(8, 4);
        let ckpt = CheckpointedLayer::new(dense);
        let input = ones_input(&[2, 8]);
        let _output = ckpt.forward(&input).expect("forward");
        let grad_out = ones_input(&[2, 4]);
        let grad_in = ckpt.backward(&input, &grad_out).expect("backward");
        assert_eq!(grad_in.shape(), &[2, 8]);
    }

    // Backward without a prior forward must fail (no saved checkpoint).
    #[test]
    fn test_checkpointed_layer_backward_without_forward_errors() {
        let dense = make_dense(4, 2);
        let ckpt = CheckpointedLayer::new(dense);
        let grad = ones_input(&[1, 2]);
        let input_dummy = ones_input(&[1, 4]);
        let result = ckpt.backward(&input_dummy, &grad);
        assert!(result.is_err(), "backward without forward should error");
    }

    // An attached tracker records the stored checkpoint bytes.
    #[test]
    fn test_checkpointed_layer_memory_tracker() {
        let tracker = Arc::new(MemoryUsageTracker::new());
        let dense = make_dense(8, 4);
        let ckpt = CheckpointedLayer::new(dense).with_tracker(tracker.clone());
        let input = ones_input(&[3, 8]);
        let _out = ckpt.forward(&input).expect("forward");
        assert!(tracker.stored_activation_bytes() > 0);
    }

    // Trait metadata delegates to the inner layer where appropriate.
    // Dense(4, 2) has 4*2 weights + 2 biases = 10 parameters.
    #[test]
    fn test_checkpointed_layer_trait_delegation() {
        let dense = make_dense(4, 2);
        let ckpt = CheckpointedLayer::new(dense);
        assert_eq!(ckpt.layer_type(), "CheckpointedLayer");
        assert!(ckpt.layer_description().contains("CheckpointedLayer"));
        assert_eq!(ckpt.parameter_count(), 10);
    }

    // A 3-layer pipeline maps [2, 8] down to [2, 2].
    #[test]
    fn test_segment_checkpointer_forward() {
        let layers: Vec<Box<dyn Layer<f32>>> = vec![
            Box::new(make_dense(8, 6)),
            Box::new(make_dense(6, 4)),
            Box::new(make_dense(4, 2)),
        ];
        let ckpt = SegmentCheckpointer::new(layers, 2);
        let input = ones_input(&[2, 8]);
        let output = ckpt.forward(&input).expect("forward");
        assert_eq!(output.shape(), &[2, 2]);
    }

    // Full forward/backward round trip restores the input's shape.
    #[test]
    fn test_segment_checkpointer_backward() {
        let layers: Vec<Box<dyn Layer<f32>>> = vec![
            Box::new(make_dense(8, 6)),
            Box::new(make_dense(6, 4)),
            Box::new(make_dense(4, 2)),
        ];
        let ckpt = SegmentCheckpointer::new(layers, 2);
        let input = ones_input(&[2, 8]);
        let _output = ckpt.forward(&input).expect("forward");
        let grad = ones_input(&[2, 2]);
        let grad_in = ckpt.backward(&grad).expect("backward");
        assert_eq!(grad_in.shape(), &[2, 8]);
    }

    // Backward without a prior forward must fail (no segment starts saved).
    #[test]
    fn test_segment_checkpointer_backward_without_forward_errors() {
        let layers: Vec<Box<dyn Layer<f32>>> = vec![Box::new(make_dense(4, 2))];
        let ckpt = SegmentCheckpointer::new(layers, 1);
        let grad = ones_input(&[1, 2]);
        let result = ckpt.backward(&grad);
        assert!(result.is_err());
    }

    // A tracker installed via with_tracker() sees the checkpoint bytes.
    #[test]
    fn test_segment_checkpointer_tracks_memory() {
        let tracker = Arc::new(MemoryUsageTracker::new());
        let layers: Vec<Box<dyn Layer<f32>>> = vec![
            Box::new(make_dense(8, 4)),
            Box::new(make_dense(4, 2)),
        ];
        let ckpt = SegmentCheckpointer::new(layers, 1).with_tracker(tracker.clone());
        let input = ones_input(&[4, 8]);
        let _out = ckpt.forward(&input).expect("forward");
        assert!(tracker.stored_activation_bytes() > 0);
    }

    // An empty pipeline is the identity on forward.
    #[test]
    fn test_segment_checkpointer_empty() {
        let ckpt: SegmentCheckpointer<f32> = SegmentCheckpointer::new(vec![], 2);
        let input = ones_input(&[2, 4]);
        let out = ckpt.forward(&input).expect("empty forward");
        assert_eq!(out.shape(), input.shape());
        assert!(ckpt.is_empty());
        assert_eq!(ckpt.len(), 0);
    }

    // The configured checkpoint interval is reported back verbatim.
    #[test]
    fn test_segment_checkpointer_checkpoint_every() {
        let ckpt: SegmentCheckpointer<f32> = SegmentCheckpointer::new(vec![], 3);
        assert_eq!(ckpt.checkpoint_every(), 3);
    }
}