#![allow(clippy::cast_precision_loss)]
use candle_core::{DType, Device, Tensor};
use rust_ai_core::{
estimate_tensor_bytes,
memory::{estimate_attention_memory, DEFAULT_OVERHEAD_FACTOR},
MemoryTracker, Result,
};
/// Walks through the memory-tracking API end to end.
///
/// The example prints, in order:
/// 1. the budget of a [`MemoryTracker`] created with a 1 GiB limit,
/// 2. size estimates for a few tensor shapes (F32, BF16, and F32 as reported
///    by `MemoryTracker::estimate_with_overhead`),
/// 3. attention memory estimates for increasing sequence lengths,
/// 4. a simulated training step that records allocations on the tracker and
///    then releases the K and V buffers,
/// 5. the shape and dtype of a real zero-filled tensor created on the CPU.
///
/// # Errors
///
/// Propagates any error returned by `MemoryTracker::allocate` or
/// `Tensor::zeros`.
fn main() -> Result<()> {
    // Divisor that turns a byte count into the "MB" figures printed below.
    const BYTES_PER_MB: f64 = 1024.0 * 1024.0;

    println!("=== Memory Tracking Example ===\n");

    let one_gb = 1024 * 1024 * 1024;
    let tracker = MemoryTracker::with_limit(one_gb).with_overhead_factor(DEFAULT_OVERHEAD_FACTOR);
    println!(
        "Memory budget: {} MB\n",
        tracker.limit_bytes() / (1024 * 1024)
    );

    println!("--- Tensor Memory Estimation ---");
    let shapes = [
        (
            &[1, 512, 4096][..],
            "Embedding (batch=1, seq=512, dim=4096)",
        ),
        (
            &[1, 32, 512, 128][..],
            "Attention QKV (batch=1, heads=32, seq=512, head_dim=128)",
        ),
        (
            &[1, 4096, 11008][..],
            "MLP hidden (batch=1, seq=4096, hidden=11008)",
        ),
    ];
    for (shape, desc) in shapes {
        let bytes_f32 = estimate_tensor_bytes(shape, DType::F32);
        let bytes_bf16 = estimate_tensor_bytes(shape, DType::BF16);
        let overhead = tracker.estimate_with_overhead(shape, DType::F32);
        println!("{desc}:");
        println!(" F32: {:>8.2} MB", bytes_f32 as f64 / BYTES_PER_MB);
        println!(" BF16: {:>8.2} MB", bytes_bf16 as f64 / BYTES_PER_MB);
        println!(
            " With overhead: {:>8.2} MB\n",
            overhead as f64 / BYTES_PER_MB
        );
    }

    println!("--- Attention Memory Scaling ---");
    for seq_len in [512, 1024, 2048, 4096] {
        let attn_bytes = estimate_attention_memory(1, 32, seq_len, 128, DType::BF16);
        println!(
            "seq_len={seq_len}: {:>8.2} MB",
            attn_bytes as f64 / BYTES_PER_MB
        );
    }
    println!();

    println!("--- Simulated Training Step ---");
    let embed_shape = [1, 512, 4096];
    let embed_bytes = tracker.estimate_with_overhead(&embed_shape, DType::F32);
    if tracker.would_fit(embed_bytes) {
        tracker.allocate(embed_bytes)?;
        println!(
            "Allocated embedding: {:.2} MB",
            embed_bytes as f64 / BYTES_PER_MB
        );
    } else {
        // Make the omission visible so the totals printed later are not misleading.
        println!(
            "Skipped embedding: {:.2} MB does not fit in the budget",
            embed_bytes as f64 / BYTES_PER_MB
        );
    }

    let attn_shape = [1, 32, 512, 128];
    let attn_bytes = tracker.estimate_with_overhead(&attn_shape, DType::F32);
    // Bytes actually recorded for K and V. Only these may be released later:
    // an allocation that was skipped must not be handed to `deallocate`.
    let mut kv_bytes = 0;
    for name in ["Q", "K", "V"] {
        if tracker.would_fit(attn_bytes) {
            tracker.allocate(attn_bytes)?;
            if name != "Q" {
                kv_bytes += attn_bytes;
            }
            println!(
                "Allocated {name}: {:.2} MB",
                attn_bytes as f64 / BYTES_PER_MB
            );
        } else {
            println!(
                "Skipped {name}: {:.2} MB does not fit in the budget",
                attn_bytes as f64 / BYTES_PER_MB
            );
        }
    }

    println!(
        "\nCurrent allocation: {:.2} MB",
        tracker.allocated_bytes() as f64 / BYTES_PER_MB
    );
    println!(
        "Peak allocation: {:.2} MB",
        tracker.peak_bytes() as f64 / BYTES_PER_MB
    );

    // Release exactly what was recorded for K and V above.
    if kv_bytes > 0 {
        tracker.deallocate(kv_bytes);
    }
    println!(
        "After freeing K, V: {:.2} MB\n",
        tracker.allocated_bytes() as f64 / BYTES_PER_MB
    );

    println!("--- Actual Tensor Creation ---");
    let device = Device::Cpu;
    // Reuse the embedding shape so the real tensor matches the estimate made above.
    let tensor = Tensor::zeros(&embed_shape, DType::F32, &device)?;
    println!(
        "Created tensor shape: {:?}, dtype: {:?}",
        tensor.dims(),
        tensor.dtype()
    );

    println!("\n=== Example Complete ===");
    Ok(())
}